In [34]:
from rdflib import Graph, URIRef, RDFS, RDF, OWL, Literal, Namespace
import os
import json
import numpy as np

In [26]:
import faiss

In [25]:
from sentence_transformers import SentenceTransformer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import numpy as np
import re
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

In [25]:
MOUSE_OWL_PATH = "./data/datasets/anatomy-dataset/mouse.owl"
HUMAN_OWL_PATH = "./data/datasets/anatomy-dataset/human.owl"
ALIGNMENT_RDF_PATH = "./data/datasets/anatomy-dataset/reference.rdf"
HUMAN_NAMESPACE = "http://human.owl#"
MOUSE_NAMESPACE = "http://mouse.owl#"

Getting ontology enriched terms from human.owl to be FAISS friendly

In [24]:
# Paths
ONTOLOGY_INDEX_PATH = "./data/models/ontology_index.faiss"
ONTOLOGY_TRACKER_PATH = "./data/models/ontology_id_tracker.json"
EMBEDDING_MODEL_NAME = "intfloat/multilingual-e5-large"

In [27]:
mouse_graph = Graph()
mouse_graph.parse(MOUSE_OWL_PATH)

<Graph identifier=N8ff3d966ec7147be8ce0f2ea00a8b7cc (<class 'rdflib.graph.Graph'>)>

In [28]:
human_graph = Graph()
human_graph.parse(HUMAN_OWL_PATH)

<Graph identifier=N72502864318f4c94904b45445d0fcae0 (<class 'rdflib.graph.Graph'>)>

In [29]:
alignment_graph = Graph()
alignment_graph.parse(ALIGNMENT_RDF_PATH)

<Graph identifier=N393057ed888f4e838964e48977b6da3f (<class 'rdflib.graph.Graph'>)>

In [30]:
def build_text_for_embedding(label, definition=None, synonyms=None, superclasses=None):
    parts = [f"Concept: {label}"]

    if synonyms:
        parts.append(f"Also known as: {', '.join(synonyms)}")

    if superclasses:
        parts.append(f"Part of: {', '.join(superclasses)}")

    if definition:
        parts.append(f"Defined as: {definition}")

    return ". ".join(parts)


In [31]:
def get_label(graph, uri):
    label = graph.value(uri, RDFS.label)
    return str(label) if isinstance(label, Literal) else None

def extract_related_uris(graph, subject, predicate):
    """Dereferences URIs linked by the predicate and returns their rdfs:label."""
    values = []
    for obj in graph.objects(subject, predicate):
        label = get_label(graph, obj)
        if label:
            values.append(label)
    return values

def extract_superclass_labels(graph, subject):
    """Get human-readable labels of direct superclasses."""
    super_labels = []
    for superclass in graph.objects(subject, RDFS.subClassOf):
        if isinstance(superclass, URIRef):
            label = get_label(graph, superclass)
            if label:
                super_labels.append(label)
        elif (superclass, RDF.type, OWL.Restriction) in graph:
            filler = graph.value(superclass, OWL.someValuesFrom)
            if isinstance(filler, URIRef):
                super_labels.append(str(filler).split("#")[-1])
    return super_labels

In [32]:
OBO = Namespace("http://www.geneontology.org/formats/oboInOwl#")

# Now re-run the enrichment step
terms = []
for s in human_graph.subjects(RDF.type, OWL.Class):
    label = get_label(human_graph, s)
    if not label:
        continue

    definition = extract_related_uris(human_graph, s, OBO.hasDefinition)
    synonyms = extract_related_uris(human_graph, s, OBO.hasRelatedSynonym)
    superclasses = extract_superclass_labels(human_graph, s)

    text_parts = [label]
    if definition:
        text_parts.append(definition[0])
    if synonyms:
        text_parts.append("Synonyms: " + "; ".join(synonyms))
    if superclasses:
        text_parts.append("Superclass: " + "; ".join(superclasses))

    enriched_text = build_text_for_embedding(
        label=label,
        definition=definition[0] if definition else None,
        synonyms=synonyms,
        superclasses=superclasses
    )

    terms.append({
        "uri": str(s),
        "label": label,
        "definition": definition[0] if definition else "",
        "synonyms": synonyms,
        "superclasses": superclasses,
        "text_for_embedding": enriched_text
    })

# Save enriched ontology terms to JSON
output_path = "./data/ontology/ontology_terms_enriched.json"
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(terms, f, indent=2, ensure_ascii=False)

output_path, len(terms)

('./data/ontology/ontology_terms_enriched.json', 3298)

In [33]:
OBO = Namespace("http://www.geneontology.org/formats/oboInOwl#")

# Now re-run the enrichment step
terms = []
for s in mouse_graph.subjects(RDF.type, OWL.Class):
    label = get_label(mouse_graph, s)
    if not label:
        continue

    definition = extract_related_uris(mouse_graph, s, OBO.hasDefinition)
    synonyms = extract_related_uris(mouse_graph, s, OBO.hasRelatedSynonym)
    superclasses = extract_superclass_labels(mouse_graph, s)

    text_parts = [label]
    if definition:
        text_parts.append(definition[0])
    if synonyms:
        text_parts.append("Synonyms: " + "; ".join(synonyms))
    if superclasses:
        text_parts.append("Superclass: " + "; ".join(superclasses))

    enriched_text = build_text_for_embedding(
        label=label,
        definition=definition[0] if definition else None,
        synonyms=synonyms,
        superclasses=superclasses
    )

    terms.append({
        "uri": str(s),
        "label": label,
        "definition": definition[0] if definition else "",
        "synonyms": synonyms,
        "superclasses": superclasses,
        "text_for_embedding": enriched_text
    })

# Save enriched ontology terms to JSON
output_path = "./data/ontology/mouse_terms_enriched.json"
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(terms, f, indent=2, ensure_ascii=False)

output_path, len(terms)

('./data/ontology/mouse_terms_enriched.json', 2737)

In [34]:
# Namespaces
ALIGN = Namespace("http://knowledgeweb.semanticweb.org/heterogeneity/alignment")

# Step 1: Extract gold URI mappings from reference.rdf
gold_mappings = {}
for cell in alignment_graph.subjects(RDF.type, ALIGN.Cell):
    mouse_uri = alignment_graph.value(cell, ALIGN.entity1)
    human_uri = alignment_graph.value(cell, ALIGN.entity2)
    if isinstance(mouse_uri, URIRef) and isinstance(human_uri, URIRef):
        gold_mappings[str(mouse_uri)] = str(human_uri)

In [35]:
testset = []
for s in mouse_graph.subjects(RDF.type, OWL.Class):
    label = mouse_graph.value(s, RDFS.label)
    comment = mouse_graph.value(s, RDFS.comment)
    if label:
        entry = {
            "uri": str(s),
            "label": str(label),
            "description": str(comment) if isinstance(comment, Literal) else "",
            "gold_uri": gold_mappings.get(str(s), "")
        }
        if entry["gold_uri"]:
            testset.append(entry)

In [36]:
testset_path = "./data/datasets/anatomy-dataset/mouse_testset.json"
with open(testset_path, "w", encoding="utf-8") as f:
    json.dump(testset, f, indent=2, ensure_ascii=False)

testset_path, len(testset)

('./data/datasets/anatomy-dataset/mouse_testset.json', 1497)

In [37]:
import json
with open("./data/ontology/ontology_terms_enriched.json", "r", encoding="utf-8") as f:
    human_terms = json.load(f)

In [38]:
import torch
print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())

CUDA available: True
CUDA device count: 2


In [39]:
device = torch.device("cuda:1")

In [28]:
ONTOLOGY_JSON = "./data/ontology/ontology_terms_enriched.json"
FAISS_INDEX_PATH = "./data/models/ontology_index.faiss"
MOUSE_FAISS_INDEX_PATH = "./data/models/mouse_ontology_index.faiss"
ID_TRACKER_JSON = "./data/models/ontology_id_tracker.json"
MOUSE_ID_TRACKER_JSON = "./data/models/mouse_ontology_id_tracker.json"

In [41]:
def normalize_embeddings(embeddings):
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / norms

In [29]:
def ontology_indexing(ontology_json=ONTOLOGY_JSON, faiss_index_path=FAISS_INDEX_PATH, id_tracker_json=ID_TRACKER_JSON):

    with open(ontology_json, "r", encoding="utf-8") as f:
            ontology_terms = json.load(f)

    texts = []
    ids = []
    valid_terms = []

    for i, term in enumerate(ontology_terms):
        text = term.get("text_for_embedding")
        if not text:
            raise ValueError("Missing 'text_for_embedding'")
        texts.append(text)
        ids.append(abs(hash(term["uri"])) % (10**12))
        valid_terms.append(term)

    # Embedding
    model = SentenceTransformer(EMBEDDING_MODEL_NAME)
    embeddings = model.encode(texts, batch_size=16, show_progress_bar=True)
    embeddings = normalize_embeddings(np.array(embeddings))

    # FAISS indexing
    dimension = embeddings.shape[1]
    base_index = faiss.IndexFlatIP(dimension)
    index = faiss.IndexIDMap(base_index)
    index.add_with_ids(embeddings, np.array(ids))
    os.makedirs(os.path.dirname(faiss_index_path), exist_ok=True)
    faiss.write_index(index, faiss_index_path)

    # Save ID tracker
    id_map = {str(id_): term for id_, term in zip(ids, valid_terms)}
    with open(id_tracker_json, "w", encoding="utf-8") as f:
        json.dump(id_map, f, indent=4, ensure_ascii=False)

In [43]:
ontology_indexing()

Batches: 100%|██████████| 207/207 [00:14<00:00, 13.97it/s]


In [44]:
# Load model + index
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device="cpu")
ontology_index = faiss.read_index(ONTOLOGY_INDEX_PATH)
with open(ONTOLOGY_TRACKER_PATH, "r", encoding="utf-8") as f:
    ontology_id_map = {int(k): v for k, v in json.load(f).items()}

In [31]:
def normalize_embeddings(embeddings):
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / norms

def match_to_ontology(text, top_k=1):
    embedding = embedding_model.encode([text])
    embedding = normalize_embeddings(np.array(embedding)).astype(np.float32)
    D, I = ontology_index.search(embedding, top_k)
    matches = []
    for idx, score in zip(I[0], D[0]):
        if idx == -1:
            continue
        term = ontology_id_map.get(idx, {})
        matches.append({
            "uri": term.get("uri"),
            "label": term.get("label"),
            "score": float(score),
            "description": term.get("description")
        })
    return matches


In [46]:
with open("./data/datasets/anatomy-dataset/mouse_testset.json", "r", encoding="utf-8") as f:
    testset = json.load(f)

print("Total test examples:", len(testset))

Total test examples: 1497


In [47]:
# Re-execute helper functions after kernel reset

def normalize_label(label):
    """Lowercase and remove non-alphanumeric characters."""
    return re.sub(r'[^a-z0-9 ]+', '', label.lower().strip())

def build_lexical_index(human_terms, top_k=50):
    """
    Builds a TF-IDF index over human labels and prepares for quick filtering.

    Args:
        human_terms (List[Dict]): Each term should have a 'label' key.
        top_k (int): Number of top lexical candidates to return.

    Returns:
        lexical_filter_fn: A function that returns top_k lexical matches for a given mouse label.
    """
    labels = [normalize_label(t["label"]) for t in human_terms]
    vectorizer = TfidfVectorizer(analyzer='word', ngram_range=(1, 2)).fit(labels)
    label_matrix = vectorizer.transform(labels)

    def lexical_filter_fn(query_label):
        query_vec = vectorizer.transform([normalize_label(query_label)])
        sim = cosine_similarity(query_vec, label_matrix).flatten()
        top_indices = np.argsort(sim)[-top_k:][::-1]
        return [(i, sim[i]) for i in top_indices if sim[i] > 0]

    return lexical_filter_fn

def run_batch_faiss(query_embeddings, faiss_index, top_k=5):
    """
    Runs batch FAISS search.

    Args:
        query_embeddings (np.ndarray): Shape (n_queries, dim), should be normalized.
        faiss_index: FAISS index of normalized human embeddings.
        top_k (int): Number of top results to return.

    Returns:
        D, I: FAISS distances and indices
    """
    return faiss_index.search(query_embeddings, top_k)


In [48]:
results = []

# Prepare: build lexical filter
lexical_filter_fn = build_lexical_index(human_terms, top_k=50)

# Precompute embeddings for all human terms
human_embeddings = np.array([t["text_for_embedding"] for t in human_terms])  # already normalized
human_labels = [t["label"] for t in human_terms]
human_uris = [t["uri"] for t in human_terms]

for case in testset:
    query_label = case["label"]
    gold = case["gold_uri"]

    # 1. Lexical pre-filter
    filtered_indices = [idx for idx, _ in lexical_filter_fn(query_label)]
    if not filtered_indices:
        results.append({
            "query": query_label,
            "gold_uri": gold,
            "predicted_uri": None,
            "confidence": 0,
            "found_at_rank": None
        })
        continue

    # 2. Get filtered embeddings + metadata
    filtered_embeddings = human_embeddings[filtered_indices]
    filtered_uris = [human_uris[i] for i in filtered_indices]

    # 3. Encode query and normalize
    query_emb = embedding_model.encode([query_label])
    query_emb = normalize_embeddings(np.array(query_emb)).astype(np.float32)

    # 4. FAISS match only against filtered pool (manual since sub-index)
    similarities = (query_emb @ filtered_embeddings.T).flatten()
    top_k = np.argsort(similarities)[-5:][::-1]  # top-5 by similarity

    top_matches = [
        {"uri": filtered_uris[i], "score": float(similarities[i])}
        for i in top_k
    ]

    found_at = next((i for i, match in enumerate(top_matches) if match["uri"] == gold), None)

    results.append({
        "query": query_label,
        "gold_uri": gold,
        "predicted_uri": top_matches[0]["uri"] if top_matches else None,
        "confidence": top_matches[0]["score"] if top_matches else 0,
        "found_at_rank": found_at
    })

# Final metrics
at_1 = sum(1 for r in results if r["found_at_rank"] == 0) / len(results)
at_5 = sum(1 for r in results if r["found_at_rank"] is not None) / len(results)

print(f"mapping@1: {at_1:.2%}")
print(f"mapping@5: {at_5:.2%}")



UFuncTypeError: ufunc 'matmul' did not contain a loop with signature matching types (dtype('float32'), dtype('<U1121')) -> None

In [49]:
index = faiss.read_index(FAISS_INDEX_PATH)
model = SentenceTransformer(EMBEDDING_MODEL_NAME)

In [50]:
with open("./data/ontology/mouse_terms_enriched.json", "r", encoding="utf-8") as f:
    mouse_terms = json.load(f)

with open(ID_TRACKER_JSON, "r", encoding="utf-8") as f:
    human_map = json.load(f)
    human_map = {int(k): v for k, v in human_map.items()}

In [51]:
matches = []

for term in mouse_terms:
    text = term.get("text_for_embedding")
    if not text:
        continue

    emb = model.encode([text])
    emb = normalize_embeddings(np.array(emb).astype(np.float32))

    D, I = index.search(emb, 5)

    results = []
    for idx, score in zip(I[0], D[0]):
        if idx == -1:
            continue
        h_term = human_map.get(idx)
        results.append({
            "uri": h_term.get("uri"),
            "label": h_term.get("label"),
            "score": float(score)
        })

    matches.append({
        "mouse_uri": term["uri"],
        "mouse_label": term["label"],
        "top_match": results[0] if results else None,
        "top_k_matches": results
    })

In [52]:
output_path = "./data/datasets/anatomy-dataset/mouse_to_human_matches.json"
with open(output_path, "w", encoding="utf-8") as f:
    json.dump(matches, f, indent=2, ensure_ascii=False)

print(f"Matching complete. Results saved to {output_path}")

Matching complete. Results saved to ./data/datasets/anatomy-dataset/mouse_to_human_matches.json


In [1]:
from rapidfuzz.fuzz import ratio

def rerank_by_label_similarity(mouse_label, top_k_matches, weight_faiss=0.7, weight_label=0.3):
    """
    Re-ranks top_k_matches based on a combined score of FAISS embedding similarity and label similarity.

    Each match must contain: label, score
    """
    reranked = []
    for match in top_k_matches:
        candidate_label = match.get("label", "")
        label_sim = ratio(mouse_label, candidate_label) / 100  # normalize to [0,1]

        combined_score = weight_faiss * match["score"] + weight_label * label_sim

        reranked.append({
            **match,
            "label_similarity": label_sim,
            "combined_score": combined_score
        })

    return sorted(reranked, key=lambda m: -m["combined_score"])


In [35]:
ontology_indexing(ontology_json="./data/ontology/mouse_terms_enriched.json", faiss_index_path=MOUSE_FAISS_INDEX_PATH, id_tracker_json=MOUSE_ID_TRACKER_JSON)

Batches: 100%|██████████| 172/172 [00:04<00:00, 36.52it/s]


In [44]:
with open(MOUSE_ID_TRACKER_JSON, "r", encoding="utf-8") as f:
    mouse_index = faiss.read_index(MOUSE_FAISS_INDEX_PATH)
    mouse_id_map = {int(k): v for k, v in json.load(f).items()}

In [45]:
def is_bidirectional_match(mouse_uri, predicted_human_uri, human_label):
    """Returns True if human→mouse returns the original mouse_uri as top match"""
    # Use FAISS over mouse concepts to match human_label back
    model = SentenceTransformer(EMBEDDING_MODEL_NAME)
    emb = model.encode([human_label])
    emb = normalize_embeddings(np.array(emb).astype(np.float32))

    D, I = mouse_index.search(emb, 1)
    top_idx = I[0][0]
    if top_idx == -1:
        return False
    top_mouse_uri = mouse_id_map.get(top_idx, {}).get("uri")
    return top_mouse_uri == mouse_uri


In [47]:
# Load matcher results and testset (now that both files are re-uploaded)
with open("./data/datasets/anatomy-dataset/mouse_to_human_matches.json", "r", encoding="utf-8") as f:
    matches = json.load(f)

with open("./data/datasets/anatomy-dataset/mouse_testset.json", "r", encoding="utf-8") as f:
    testset = json.load(f)

THRESHOLD = 0.8
TP, FP, FN = 0, 0, 0

# Build a lookup from URI to gold URI
gold_lookup = {entry["uri"]: entry["gold_uri"] for entry in testset}

# Evaluate results
results = []
for m in matches:
    mouse_uri = m["mouse_uri"]
    mouse_label = m["mouse_label"]
    gold_uri = gold_lookup.get(mouse_uri)
    if not gold_uri:
        continue
    #top_k = m["top_k_matches"]
    reranked = rerank_by_label_similarity(mouse_label, m["top_k_matches"], 0.8, 0.2)
    top_match = reranked[0]

    if top_match["combined_score"] < THRESHOLD:
        FN += 1  # prediction too weak
        continue

    predicted_uri = top_match["uri"]
    is_correct = (predicted_uri == gold_uri)

    # if not is_bidirectional_match(mouse_uri, predicted_uri, top_match["label"]):
    #     FN += 1
    #     continue

    if is_correct:
        TP += 1
    else:
        FP += 1

    found_at = next((i for i, match in enumerate(reranked) if match["uri"] == gold_uri), None)

    results.append({
        "mouse_uri": mouse_uri,
        "gold_uri": gold_uri,
        "predicted_uri": reranked[0]["uri"] if reranked else None,
        "confidence": reranked[0]["score"] if reranked else 0,
        "found_at_rank": found_at
    })

# Compute metrics
# total = len(results)
# at_1 = sum(1 for r in results if r["found_at_rank"] == 0) / total
# at_5 = sum(1 for r in results if r["found_at_rank"] is not None) / total

# at_1, at_5, total

precision = TP / (TP + FP) if (TP + FP) > 0 else 0
recall    = TP / (TP + FN) if (TP + FN) > 0 else 0
f1        = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

print(f"Precision: {precision:.2%}")
print(f"Recall:    {recall:.2%}")
print(f"F1 Score:  {f1:.2%}")
print(f"Total evaluated: {TP + FP + FN}")
print(f"Total skipped by threshold: {FN}")


Precision: 73.57%
Recall:    94.36%
F1 Score:  82.68%
Total evaluated: 1497
Total skipped by threshold: 63
