In [8]:
from __future__ import annotations
from typing import Union, NamedTuple, Any, Sequence, Callable, override, overload

from multiprocessing import Pool, cpu_count

from abc import ABC, abstractmethod
from collections.abc import Generator, MutableMapping, Mapping, Sequence

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import auc as sk_auc
from scipy.sparse import coo_matrix
from rank_bm25 import BM25Okapi

from sentence_transformers import SentenceTransformer

from argparse import ArgumentParser, Namespace
from pathlib import Path
from functools import reduce

from pydantic import validate_call
from tqdm import tqdm

import math
import numpy as np
import statistics

import json
import pickle
import logging
import sys
import copy
import re

import torch

from hierarchy_transformers import HierarchyTransformer
from OnT.OnT import OntologyTransformer

In [9]:
# Notebook-wide verbosity flag. It is not read by any cell visible in this part of the
# notebook; presumably it toggles extra output further down (TODO confirm).
VERBOSE = False
# NOTE(review): shell-escape install of an unpinned package at run time. `%pip install`
# with a pinned version would target the kernel's own environment instead.
! pip install json2html

import json2html
import latextable
from latextable import texttable

[0m

## Basic Utils

*Source: [data_utils.py](../src/hroov/utils/data_utils.py)*

Small helper functions used for tokenisation and for stripping higher-level concept tags (e.g. '$Heart\ Disease\ (Condition) \rightarrow Heart\ Disease$') to prevent label leakage.

TODO: update description.

In [7]:
# Matches one parenthesised group together with any whitespace directly in front of it.
# Removing these groups prevents leakage of the bracketed tag.
_regex_parens = re.compile(r"\s*\([^)]*\)")

def strip_parens(s: str) -> str:
    """
    Return ``s`` with every ``(...)`` group, and the whitespace preceding it, removed.

    Text outside parentheses is left untouched; a string without parentheses is
    returned unchanged.
    """
    without_groups = re.sub(_regex_parens, "", s)
    return without_groups

def load_json(file_path: Path) -> dict[str, str]:
    """
    Read the UTF-8 encoded JSON document at ``file_path`` and return the parsed value.

    Args:
        file_path: ``Path`` of the JSON file to read.

    Returns:
        Whatever ``json`` parses from the file; annotated as a ``str -> str`` mapping.
    """
    raw_text = file_path.read_text(encoding='utf-8')
    return json.loads(raw_text)

def save_json(file_path: Path, payload: dict | list, encoding: str = "utf-8", indentation: int = 4) -> None:
    """
    Serialise ``payload`` as indented JSON and write it to ``file_path``.

    Args:
        file_path: destination file; opened in write mode, so existing content is replaced.
        payload: JSON-serialisable dict or list.
        encoding: text encoding used for the output file.
        indentation: number of spaces per nesting level.
    """
    with open(file_path, mode="w", encoding=encoding) as out_file:
        json.dump(payload, out_file, indent=indentation)

def load_concepts_to_list(concepts_file_path: Path) -> list[str]:
    """
    Load a JSON mapping from ``concepts_file_path`` and return its values as a list.

    The keys of the mapping are discarded; value order follows the order in the file.
    """
    concepts = load_json(concepts_file_path)
    return [concept for concept in concepts.values()]

def naive_tokenise(seq: str) -> list[str]:
    """Lower-case ``seq`` and split it on runs of whitespace (no punctuation handling)."""
    lowered = seq.lower()
    return lowered.split()

def parallel_tokenise(seq_list: list[str], workers: int, chunksize: int = 25000) -> list[list[str]]:
    """
    Tokenise every string in ``seq_list`` with ``naive_tokenise`` across a process pool.

    Args:
        seq_list: strings to tokenise.
        workers: number of processes in the pool.
        chunksize: number of strings handed to a worker per task.

    Returns:
        One token list per input string, in input order.
    """
    with Pool(workers) as worker_pool:
        tokenised = worker_pool.map(naive_tokenise, seq_list, chunksize=chunksize)
    return list(tokenised)
    
def produce_candidate_ids_for_embs(embeddings_ds):
    """Return the integer ids ``0 .. len(embeddings_ds) - 1`` as a NumPy array."""
    candidate_count = len(embeddings_ds)
    return np.arange(candidate_count)

## Indexing Utils

*Source: [retrievers.py](../src/hroov/utils/retrievers.py)*

TODO: add description here.

In [57]:
# weighted TF-IDF at run-time
def aggregate_posting_scores(query_weights, inverted):
    """
    Accumulate a weighted score per document from an inverted index.

    Args:
        query_weights: mapping of term -> query-side weight.
        inverted: mapping of term -> iterable of ``(doc_id, tfidf_score)`` postings.

    Returns:
        dict of doc_id -> sum over matching query terms of ``weight * tfidf_score``.
        Query terms absent from ``inverted`` contribute nothing.
    """
    accumulated = {}
    for term, weight in query_weights.items():
        # unknown terms yield an empty postings iterable and are skipped
        for doc_id, tfidf_score in inverted.get(term, ()):
            accumulated[doc_id] = accumulated.get(doc_id, 0.0) + weight * tfidf_score
    return accumulated

# sort TF-IDF result set
def topk(scores: dict[int, float], k: int = 10):
    """
    Return the ``k`` highest-scoring ``(doc_id, score)`` pairs, best first.

    Ties keep the insertion order of ``scores`` because the sort is stable.
    """
    ranked = list(scores.items())
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

def build_tf_idf_index(axiom_list: list[str], tfidf_dest: str, *args, **kwargs):
    """
    Fit a TF-IDF vectoriser over ``axiom_list``, derive a term -> postings inverted
    index from the resulting matrix, and pickle everything to ``tfidf_dest``.

    Args:
        axiom_list: documents to index; a document's position in the list is its doc id.
        tfidf_dest: path of the pickle file to write (overwritten if present).
        *args: accepted for call-site compatibility; not used.
        **kwargs: forwarded verbatim to ``TfidfVectorizer``.

    Returns:
        ``(vectoriser, inverted_index)``, where ``inverted_index`` maps each vocabulary
        term to a list of ``(doc_id, tfidf_score)`` tuples sorted by score, descending.
    """
    vectoriser = TfidfVectorizer(**kwargs)
    doc_term_matrix = vectoriser.fit_transform(axiom_list)
    vocab = vectoriser.get_feature_names_out()
    # one (initially empty) postings list per vocabulary term
    inverted_index: dict[str, list[tuple[int, float]]] = {str(term): [] for term in vocab}
    # COO format exposes the non-zero entries as parallel (row, col, value) arrays
    # see: https://matteding.github.io/2019/04/25/sparse-matrices/
    coo = coo_matrix(doc_term_matrix)
    for row, col, score in zip(coo.row, coo.col, coo.data):
        inverted_index[str(vocab[col])].append((int(row), float(score)))
    # order each postings list by score, descending
    for term_postings in inverted_index.values():
        term_postings.sort(key=lambda posting: posting[1], reverse=True)
    # Persist the WHOLE inverted index. Previously the leaked loop variable was stored,
    # i.e. only the postings list of the last term (and a NameError for an empty vocab).
    with open(tfidf_dest, "wb") as fp:
        pickle.dump(
            {
                "vectorizer": vectoriser,
                "postings": inverted_index,
                "verbalisations": axiom_list,
            },
            fp,
            protocol=pickle.HIGHEST_PROTOCOL,
        )
    return vectoriser, inverted_index

<style>
*{
    line-height: 24px;
}
</style>

## Math Functions

*Source: [math_functools.py](../src/hroov/utils/math_functools.py)*

TODO: add description.

**L2 distance:**

$$ L_2 = \| v - u \|_2 \equiv \sqrt{\sum_{i=1}^d (v_i - u_i)^2} $$

**Inner product:**

$$\sum_{i=1}^{d} u_i \cdot v_i $$

**Cosine Similarity:**

$$ sim(u,v) = \operatorname{cos}(\theta) = \frac{u \cdot v }{\| u \|_2 \| v \|_2} $$

**Geodesic Distance (with adaptive/sectional curvature $\kappa$):**

$$
d_{\kappa}(u,v) = \frac{1}{\sqrt{\kappa}} 
\cdot 
\operatorname{arcosh} 
\Biggl( 1 + \frac{2\kappa \|u - v\|^{2}}{
  \bigl( 1 - \kappa \|u\|^{2} \bigr) \cdot \bigl( 1 - \kappa \|v\|^{2} \bigr)
}
\Biggr),
\qquad
\|u\|,\|v\|<\frac{1}{\sqrt{\kappa}}
$$

In [10]:
def batch_euclidian_l2_distance(u: np.ndarray, vs: np.ndarray) -> np.ndarray:
    """
    Euclidean (L2) distance between ``u`` and each row of ``vs``.

    ``u`` is broadcast against ``vs`` and the norm is taken along axis 1, giving one
    distance per row of ``vs``.
    """
    differences = u - vs
    return np.linalg.norm(differences, axis=1)

def l2_norm(x: np.ndarray) -> np.ndarray:
    """
    L2 norm over all elements of ``x`` after casting to float32.

    The sum runs over every element (no axis), so the result is a scalar.
    """
    values = np.asarray(x, dtype=np.float32)
    sum_of_squares = np.sum(values * values)
    return np.sqrt(sum_of_squares)

def batch_l2_norm(x: np.ndarray) -> np.ndarray:
    """Row-wise L2 norm of ``x`` (cast to float32): one value per row, reduced over axis 1."""
    rows = np.asarray(x, dtype=np.float32)
    row_sums = np.sum(rows * rows, axis=1)
    return np.asarray(np.sqrt(row_sums))

def inner_product(p_u: np.ndarray, p_v: np.ndarray) -> np.ndarray:
    """Inner product of ``p_u`` and ``p_v`` via ``np.inner``, both cast to float32 first."""
    left = np.asarray(p_u, dtype=np.float32)
    right = np.asarray(p_v, dtype=np.float32)
    product = np.inner(left, right)
    return product

def batch_inner_product(p_u: np.ndarray, p_vs: np.ndarray) -> np.ndarray:
    """
    Inner product of a single vector with every row of a matrix.

    ``p_u`` is flattened to 1-D and both inputs are cast to float32; the result holds
    one value per row of ``p_vs``.
    """
    query = np.asarray(p_u, dtype=np.float32).ravel()
    candidates = np.asarray(p_vs, dtype=np.float32)
    return np.dot(candidates, query)

def cosine_similarity(u, v, normalised=True):
    """
    Cosine similarity between ``u`` and ``v`` (both cast to float32).

    With ``normalised=True`` the inputs are taken to be unit-length already and the
    plain inner product is returned; otherwise the inner product is divided by the
    product of the two L2 norms.

    NOTE(review): this definition rebinds the name ``cosine_similarity`` that the
    import cell also pulls in from ``sklearn.metrics.pairwise``.
    """
    u = np.asarray(u, dtype=np.float32)
    v = np.asarray(v, dtype=np.float32)
    dot = np.inner(u, v)
    if normalised:
        return dot
    return dot / (l2_norm(u) * l2_norm(v))

def batch_cosine_similarity(p_u, p_vs, normalised=True):
    """
    Cosine similarity between one vector and every row of a matrix (float32).

    With ``normalised=True`` the inputs are taken to be unit-length and the batch
    inner product is returned as-is; otherwise each inner product is divided by
    ``l2_norm(u) * batch_l2_norm(vs)``.
    """
    query = np.asarray(p_u, dtype=np.float32)
    candidates = np.asarray(p_vs, dtype=np.float32)
    dots = batch_inner_product(query, candidates)
    if normalised:
        return dots
    return dots / (l2_norm(query) * batch_l2_norm(candidates))

<style>
*{
    line-height: 24px;
}
</style>

## Hyperbolic Space

A Riemannian manifold $\mathcal{M}$, of dimension $d$ is defined as a smooth (differentiable) manifold equipped with a Riemannian metric tensor $g$, such that the manifold is represented by $(\mathcal{M}, g)$. For any point $x \in \mathcal{M}$, there exists a local neighbourhood whose geometry resembles Euclidean geometry. Hyperbolic space $\mathbb{H}^n$ is a Riemannian manifold with a constant sectional curvature $-\kappa$, which can be represented in the Poincaré ball model whose points lie within the open ball, given by:
 
$$
B_{\kappa}^n = \{\ x \in \mathbb{R}^n : \|x\| < r\ \}, \qquad r=\frac{1}{\sqrt{\kappa}},
$$

where $r$ is the radius of the ball. The Poincaré metric $g_{\kappa}$ induces the hyperbolic distance function $d_{\kappa}$ between any two points $x,y \in B^n_{\kappa}$, which is the function used for scoring during retrieval, and is given by:


$$
d_{\kappa}(x,y) = \frac{1}{\sqrt{\kappa}} 
\cdot 
\operatorname{arcosh} 
\Biggl( 1 + \frac{2\kappa \|x - y\|^{2}}{
  \bigl( 1 - \kappa \|x\|^{2} \bigr) \cdot \bigl( 1 - \kappa \|y\|^{2} \bigr)
}
\Biggr),
\qquad
\|x\|,\|y\|<\frac{1}{\sqrt{\kappa}}.
$$

As $\|x\|$ and $\|y\|$ approach the boundary of the ball (norm $\to \frac{1}{\sqrt{\kappa}}$), distances diverge even if the Euclidean norm difference $\|x-y\|$ is not itself significant, meaning that points situated near the boundary can represent more specific concepts (since their hyperbolic separation becomes large). This is in contrast to points situated toward the center, that represent more generic concepts.

In [11]:
def batch_poincare_distance_with_curv_k(u: np.ndarray, vs: np.ndarray, k: np.float64 | np.float32) -> np.float64 | np.float32:
    """
    Poincaré-ball distance, with curvature parameter ``k``, from ``u`` to each row of ``vs``.

    Computes ``1/sqrt(k) * arcosh(1 + 2k|u-v|^2 / ((1 - k|u|^2)(1 - k|v|^2)))`` per row.
    A small epsilon is added to each ``k * norm^2`` term to guard against division by
    zero and floating-point error, and the arcosh argument is clamped to be at least 1.

    Args:
        u: query point; its squared norm is summed over all elements.
        vs: candidate points, reduced over axis 1.
        k: curvature parameter.

    Returns:
        float64 distances, one per row of ``vs`` (the reductions over axis 1 make the
        result an array even though the annotation names a scalar type).
    """
    eps = 1e-7  # tiny offset: avoids a zero denominator and absorbs rounding error
    u_sq_norm = np.sum(u**2)
    vs_sq_norms = np.sum(vs**2, axis=1)
    sq_gaps = np.sum((u - vs)**2, axis=1)
    numerator = 2 * k * sq_gaps
    denominator = (1 - (k * u_sq_norm + eps)) * (1 - (k * vs_sq_norms + eps))
    # arcosh is only defined on [1, inf), so clamp from below
    acosh_arg = np.maximum(1.0, 1 + (numerator / denominator))
    inv_sqrt_k = np.float64(1) / np.float64(np.sqrt(k))
    return (inv_sqrt_k * np.arccosh(acosh_arg, dtype=np.float64))

def batch_poincare_dist_with_adaptive_curv_k(u: np.ndarray, vs:np.ndarray, model: HierarchyTransformer | OntologyTransformer, **kwargs):
    """
    Poincaré distance from ``u`` to each row of ``vs`` using the curvature of ``model``.

    The curvature is read from ``get_circum_poincareball(embed_dim).c`` on the model
    itself (``HierarchyTransformer``) or on its ``hit_model`` (``OntologyTransformer``).

    Args:
        u: query embedding.
        vs: candidate embeddings.
        model: model whose Poincaré ball supplies the curvature.
        **kwargs: accepted and ignored.

    Raises:
        Exception: if ``model`` is neither of the two supported types.
    """
    if isinstance(model, HierarchyTransformer):
        curvature_source = model
    elif isinstance(model, OntologyTransformer):
        curvature_source = model.hit_model
    else:
        raise Exception("Hyperbolic distance should only be only calculated in B^n or H^n")
    poincare_ball = curvature_source.get_circum_poincareball(curvature_source.embed_dim)
    k = np.float64(poincare_ball.c)
    return np.asarray(batch_poincare_distance_with_curv_k(u, vs, k))

<style>
*{
    line-height: 24px;
}
</style>

## Subsumption Score

TODO

In [12]:
def identity(x):
    """Return the argument unchanged (the very same object)."""
    unchanged = x
    return unchanged

def subsumption_score_hit(hit_transformer: HierarchyTransformer, child_emb: np.ndarray | torch.Tensor, parent_emd: np.ndarray | torch.Tensor, centri_weight: float = 1.0):
    """
    Subsumption score of ``child_emb`` against ``parent_emd`` on the model's manifold.

    The score is ``-(dist(child, parent) + centri_weight * (dist0(parent) - dist0(child)))``,
    with ``dist`` and ``dist0`` provided by ``hit_transformer.manifold``.

    Args:
        hit_transformer: object exposing ``manifold.dist`` and ``manifold.dist0``.
        child_emb: child embedding(s); wrapped in ``torch.Tensor``.
        parent_emd: parent embedding(s); wrapped in ``torch.Tensor``.
        centri_weight: weight of the ``dist0`` difference term.

    Returns:
        Whatever tensor the manifold arithmetic produces.
    """
    child_t = torch.Tensor(child_emb)
    parent_t = torch.Tensor(parent_emd)
    pairwise_dist = hit_transformer.manifold.dist(child_t, parent_t)
    child_depth = hit_transformer.manifold.dist0(child_t)
    parent_depth = hit_transformer.manifold.dist0(parent_t)
    depth_gap = parent_depth - child_depth
    return -(pairwise_dist + centri_weight * depth_gap)

def subsumption_score_ont(ontology_transformer: OntologyTransformer, child_emb: np.ndarray | torch.Tensor, parent_emb: np.ndarray | torch.Tensor, weight_lambda: float = 1.0):
    """
    Delegate subsumption scoring to ``ontology_transformer.score_hierarchy``.

    Both embeddings are wrapped in ``torch.Tensor`` and passed, together with
    ``weight_lambda``, positionally to ``score_hierarchy``; its result is returned as-is.
    """
    child_t = torch.Tensor(child_emb)
    parent_t = torch.Tensor(parent_emb)
    hierarchy_score = ontology_transformer.score_hierarchy(child_t, parent_t, weight_lambda)
    return hierarchy_score

def entity_subsumption(u: np.ndarray, vs: np.ndarray, model: HierarchyTransformer, *, weight: float = 0.35):
    """
    Score function wrapper: HiT subsumption score of ``u`` (child) against ``vs`` (parents).

    ``weight`` is forwarded as ``centri_weight`` to ``subsumption_score_hit`` and the
    result is converted with ``np.asarray``.
    """
    raw_scores = subsumption_score_hit(model, u, vs, centri_weight=weight)
    return np.asarray(raw_scores)

def concept_subsumption(u: np.ndarray, vs: np.ndarray, model: OntologyTransformer, *, weight: float = 0.35, **kwargs):
    """
    Score function wrapper: OnT subsumption score of ``u`` (child) against ``vs`` (parents).

    ``weight`` is forwarded as ``weight_lambda`` to ``subsumption_score_ont`` and the
    result is converted with ``np.asarray``. Extra keyword arguments are ignored.
    """
    raw_scores = subsumption_score_ont(model, u, vs, weight_lambda=weight)
    return np.asarray(raw_scores)

<style>
*{
    line-height: 24px;
}
</style>

## Data Mapping & Management

*Source: [query_utils.py](../src/hroov/utils/query_utils.py)*

TODO

In [13]:
def make_signature(obj):
  """
  Recursively convert ``obj`` into a hashable, order-normalised signature.

  Mappings become a tuple of ``(key, signature(value))`` pairs ordered by key,
  non-string sequences become a tuple of their items' signatures, and anything else
  (including ``str``/``bytes``/``bytearray``) is returned unchanged.
  """
  if isinstance(obj, Mapping):
    ordered_keys = sorted(obj)
    return tuple((key, make_signature(obj[key])) for key in ordered_keys)
  is_text_like = isinstance(obj, (str, bytes, bytearray))
  if isinstance(obj, Sequence) and not is_text_like:
    return tuple(make_signature(item) for item in obj)
  # leaf value: already usable as part of a signature
  return obj

def unique_unhashable_object_list(obj_xs: list[dict]) -> list[dict]:
    """
    De-duplicate a list of dicts by content, keeping the first occurrence of each.

    Each dict is reduced to a hashable signature (built from its key-sorted items via
    ``make_signature``); dicts sharing a signature are treated as duplicates. The
    relative order of the surviving dicts is preserved.
    """
    first_by_signature = {}
    for obj in obj_xs:
        signature = make_signature(sorted(obj.items()))
        # setdefault only stores the first dict seen for a given signature
        first_by_signature.setdefault(signature, obj)
    return list(first_by_signature.values())

def obj_max_depth(x: int, obj: Any, key: str = "depth") -> int:
  """
  Reducer step for a running maximum: return ``x`` if it exceeds ``obj[key]``,
  otherwise ``obj[key]``.
  """
  candidate = obj[key]
  if x > candidate:
    return x
  return candidate

def dcg_exp_relevancy_at_pos(relevancy: int, rank_position: int) -> float:
  """
  DCG contribution with exponential gain: ``(2**relevancy - 1) / log2(rank_position + 1)``.

  Non-positive relevancy contributes ``0.0``. ``rank_position`` is expected to be
  1-based (a value of 0 would divide by ``log2(1) == 0``).
  """
  if relevancy <= 0:
    return float(0.0)
  gain = (2**relevancy) - 1
  discount = math.log2(rank_position + 1)
  return float(gain / discount)

def dcg_linear_relevancy_at_pos(relevancy: int, rank_position: int) -> float:
  """
  DCG contribution with linear gain: ``relevancy / log2(rank_position + 1)``.

  Non-positive relevancy contributes ``0.0``. ``rank_position`` is expected to be
  1-based (a value of 0 would divide by ``log2(1) == 0``).
  """
  if relevancy <= 0:
    return float(0.0)
  discount = math.log2(rank_position + 1)
  return float(relevancy / discount)

def accumulate(a, b, key='dcg'):
  """Reducer step for a running sum: add ``b[key]`` to the accumulator ``a``."""
  increment = b[key]
  return a + increment

#### Data Mapping: Query Models

TODO

**Query Base Model**

In [14]:
class Query:
  """
  Thin read-only wrapper around one query record (a dict).

  The record must provide ``entity_mention`` (with an ``entity_literal`` entry, used
  as the query string) and ``target_entity``; the target getters read the ``iri``,
  ``rdfs:label`` and ``depth`` entries of the latter.
  """

  _query_obj_repr: dict
  _query_string: str
  _target: dict[str, Union[str, int]]
  _entity_mention: dict[str, Union[str, int]]

  def __init__(self, query_obj_repr: dict):
    """Keep the raw record and cache the parts the getters need."""
    self._query_obj_repr = query_obj_repr
    mention = query_obj_repr['entity_mention']
    self._query_string = mention['entity_literal']
    self._target = query_obj_repr['target_entity']
    self._entity_mention = mention

  def get_query_string(self) -> str:
    """Return the entity literal of the mention."""
    return self._query_string

  def get_target_iri(self) -> str:
    """Return the target's ``iri`` as ``str``."""
    iri = self._target['iri']
    return str(iri)

  def get_target_label(self) -> str:
    """Return the target's ``rdfs:label`` as ``str``."""
    label = self._target['rdfs:label']
    return str(label)

  def get_target_depth(self) -> int:
    """Return the target's ``depth`` coerced to ``int``."""
    depth = self._target['depth']
    return int(depth)

  def get_target(self) -> dict:
    """Return the target entity dict itself (not a copy)."""
    return self._target

**Equivalent Query Model**

In [15]:
class EquivQuery(Query):
  """
  Query variant used for records that carry a non-empty ``equivalent_classes`` list.

  Adds no behaviour on top of ``Query``. ``_equiv_class_expression`` is declared
  below but is not assigned anywhere in this class.
  """

  _equiv_class_expression: Union[str, None]

  def __init__(self, query_obj_repr: dict):
    """Initialise exactly like the base ``Query``."""
    super().__init__(query_obj_repr=query_obj_repr)

**Subsumption Query Model**

In [16]:
class SubsumptionQuery(Query):
  """
  Query whose relevant answers are the target entity plus its parents and ancestors.

  On top of ``Query`` the record must provide ``parent_entities`` and ``ancestors``
  (either may be empty or null). Each target/parent/ancestor entry is a dict that
  carries at least a ``depth`` value, which drives sorting and graded relevance.
  """

  _parents: list
  _ancestors: list
  _idcg: float | None
  _targets_with_dcg: list[dict] | None

  def __init__(self, query_obj_repr: dict):
    super().__init__(query_obj_repr)
    # Initialise the DCG cache up front so get_ideal_dcg() can be called before
    # get_targets_with_dcg() without raising AttributeError.
    self._idcg = None
    self._targets_with_dcg = None
    self._set_parents()
    self._set_ancestors()

  def _set_parents(self):
    """Store ``parent_entities`` from the record, normalising empty/null to ``[]``."""
    if self._query_obj_repr['parent_entities'] and len(self._query_obj_repr['parent_entities']) > 0:
      self._parents = self._query_obj_repr['parent_entities']
    else:
      self._parents = []

  def _set_ancestors(self):
    """Store ``ancestors`` from the record, normalising empty/null to ``[]``."""
    if self._query_obj_repr['ancestors'] and len(self._query_obj_repr['ancestors']) > 0:
      self._ancestors = self._query_obj_repr['ancestors']
    else:
      self._ancestors = []

  def get_parents(self) -> list:
    """Return the parent entities (possibly empty)."""
    return self._parents

  def get_ancestors(self) -> list:
    """Return the ancestor entities (possibly empty)."""
    return self._ancestors

  def get_all_subsumptive_targets(self) -> list:
    """Return a new list: the target, then all parents, then all ancestors."""
    return [self._target, *self._parents, *self._ancestors]

  def get_unique_subsumptive_targets(self) -> list:
    """Return all subsumptive targets with content-duplicates removed."""
    return unique_unhashable_object_list(
      self.get_all_subsumptive_targets()
    )

  def get_sorted_subsumptive_targets(self, key="depth", reverse=False, depth_cutoff=3) -> list:
    """
    Return the first ``depth_cutoff`` subsumptive targets after sorting by ``key``.

    NOTE: ``depth_cutoff`` limits the NUMBER of entries kept from the sorted list;
    it is not a threshold on the ``depth`` value, and it is applied before any
    de-duplication.
    """
    xs = self.get_all_subsumptive_targets()
    xs.sort(key=lambda x: x[key], reverse=reverse)
    return xs[:depth_cutoff]

  def get_unique_sorted_subsumptive_targets(self, key="depth", reverse=False, depth_cutoff=3) -> list:
    """Sorted and truncated targets (see above) with content-duplicates removed."""
    return unique_unhashable_object_list(
      self.get_sorted_subsumptive_targets(key=key, reverse=reverse, depth_cutoff=depth_cutoff)
    )

  def get_targets_with_dcg(self, type="exp", depth_cutoff=3, **kwargs) -> tuple[float, list[dict]]:
    """
    Attach graded relevance and ideal-rank DCG to the subsumptive targets.

    Args:
      type: ``"linear"`` for linear gain; any other value selects exponential gain.
      depth_cutoff: number of depth-sorted targets considered (see
        ``get_sorted_subsumptive_targets``).
      **kwargs: accepted and ignored.

    Returns:
      ``(iDCG, targets)`` where each target dict gains ``relevance`` and ``dcg`` keys
      and ``depth`` is shifted by +1. The result is also cached on the instance.
    """
    # targets (target, parents, ancestors) in ascending depth order
    targets_asc_depth = self.get_unique_sorted_subsumptive_targets(key="depth", depth_cutoff=depth_cutoff)
    # shift depth by 1 to offset the zero-based depth of the target itself
    targets_w_offset = [
      {**x, "depth": x["depth"] + 1}
      for x in targets_asc_depth
    ]
    # relevance := ascent height + 1 = (max_depth - depth) + 1
    max_target_depth = reduce(
      obj_max_depth,
      targets_w_offset,
      0
    )
    targets_with_rel = [
      {**x, "relevance": (max_target_depth - x["depth"]) + 1}
      for x in targets_w_offset
    ]
    # ideal ordering: most relevant first
    targets_with_rel.sort(key=lambda x: x['relevance'], reverse=True)
    # pick the gain function once instead of duplicating the comprehension
    dcg_at_pos = dcg_linear_relevancy_at_pos if type == "linear" else dcg_exp_relevancy_at_pos
    targets_with_dcg = [
      {**x, "dcg": dcg_at_pos(x['relevance'], rank)}
      for rank, x in enumerate(targets_with_rel, start=1)
    ]
    iDCG = reduce(accumulate, targets_with_dcg, 0)
    self._idcg = iDCG
    self._targets_with_dcg = targets_with_dcg
    return iDCG, targets_with_dcg

  def get_ideal_dcg(self, type="exp"):
    """
    Return the ideal DCG, computing it on first use.

    A previously cached value (from any earlier ``get_targets_with_dcg`` call,
    whatever its ``type``/``depth_cutoff``) is returned as-is; only when nothing is
    cached is the value computed, using ``type`` and the default cutoff.
    """
    # getattr also covers instances created without running __init__ (e.g. unpickled)
    cached = getattr(self, "_idcg", None)
    if cached:
      return cached
    iDCG, _ = self.get_targets_with_dcg(type=type)
    return iDCG

**QueryObjectMapping**

In [17]:
class QueryObjectMapping:
  """
  Loads a JSON list of query records and maps each record to a query object.

  Records with a non-empty ``equivalent_classes`` list become ``EquivQuery``
  instances; all others become ``SubsumptionQuery`` instances.
  """

  _loaded: bool
  _data_file_path: Path
  _data: list
  _equiv_queries: list
  _subsumpt_queries: list

  @validate_call
  def __init__(self, json_data_fp: Path):
    """Load ``json_data_fp`` and build the query objects."""
    self._loaded = False
    self._data_file_path = json_data_fp
    self._load()
    self._map()

  def _load(self) -> None:
    """Parse the JSON file at the current path into ``_data``."""
    with self._data_file_path.open('r', encoding='utf-8') as fp:
      self._data = json.load(fp)
    self._loaded = True

  @validate_call
  def load_from_path(self, json_data_fp: Path) -> None:
    """Replace the data file, reload it and rebuild the query objects."""
    # overwrite an existing file path
    self._data_file_path = json_data_fp
    self._load()
    # re-map, otherwise get_queries() keeps returning queries from the old file
    self._map()

  # equivalence_retrieval: bool = True, subsumption_retrieval: bool = False
  def _map(self) -> None:
    """Split the loaded records into equivalence and subsumption queries."""
    equiv_queries = []
    subsumpt_queries = []
    for query_obj_repr in self._data:
      # a record with at least one equivalent class is an equivalence query
      if len(query_obj_repr['equivalent_classes']) > 0:
        equiv_queries.append(EquivQuery(query_obj_repr))
      else:
        subsumpt_queries.append(SubsumptionQuery(query_obj_repr))
    self._equiv_queries = equiv_queries
    self._subsumpt_queries = subsumpt_queries

  def get_queries(self) -> tuple[list, list]:
    """Return ``(equivalence_queries, subsumption_queries)``."""
    return (self._equiv_queries, self._subsumpt_queries)

  def _copy_subsumpt_queries_by_transformation(self, transformed: bool) -> list:
    """Deep-copy the subsumption queries whose transformed literal is (non-)empty."""
    field = 'transformed_entity_literal_for_type_alignment'
    selected = [
      query for query in self._subsumpt_queries
      if (query._entity_mention[field] != "") == transformed
    ]
    # filter first, then copy: only the returned queries need to be duplicated
    return copy.deepcopy(selected)

  def get_subsumpt_queries_with_no_transformations(self):
    """Return copies of the subsumption queries whose transformed literal is ``""``."""
    return self._copy_subsumpt_queries_by_transformation(transformed=False)

  def get_subsumpt_queries_with_transformations_only(self):
    """Return copies of the subsumption queries whose transformed literal is not ``""``."""
    return self._copy_subsumpt_queries_by_transformation(transformed=True)

**QueryResult**

In [18]:
class QueryResult(NamedTuple):
  """One ranked retrieval hit, as produced by the retrievers' ``retrieve`` methods."""
  # position of the hit in the ranked list, as numbered by the producing retriever
  rank: int
  # IRI of the matched entry, read from the retriever's meta map
  iri: str
  # retriever-specific score of the hit
  score: float
  # textual form (verbalisation) of the matched entry
  verbalisation: str

<style>
*{
    line-height: 24px;
}
</style>

## Retrievers

*Source: [retrievers.py](../src/hroov/utils/retrievers.py), also see: [gpu_retrievers.py](../src/hroov/utils/gpu_retrievers.py) for GPU accelerated retrieval.*

TODO

**BaseRetriever**

In [19]:
class BaseRetriever(ABC):
  """
  Common base for all retrievers: loads the verbalisations and the meta map.

  Both files are UTF-8 JSON. The retrievers in this notebook index ``_meta_map`` by
  candidate position and read its ``iri`` (and ``verbalisation``) entries.
  """

  _verbalisations: list
  _meta_map: list

  @validate_call
  def __init__(self, verbalisations_fp: Path, meta_map_fp: Path):
    """Read both JSON files into memory."""
    with open(verbalisations_fp, 'r', encoding='utf-8') as verbalisations_file:
      self._verbalisations = json.load(verbalisations_file)
    with open(meta_map_fp, 'r', encoding='utf-8') as meta_map_file:
      self._meta_map = json.load(meta_map_file)

  @abstractmethod
  def retrieve(self, query_string: str, *, top_k: int = 10, **kwargs) -> list[QueryResult]:
    """Return the ranked results for ``query_string``; implemented by subclasses."""

**BaseModelRetriever**

In [20]:
class BaseModelRetriever(BaseRetriever):
  """
  Base class for embedding-model retrievers.

  Holds precomputed candidate embeddings (memory-mapped from ``embeddings_fp``), a
  scoring function and a model used to embed the query. Subclasses supply model
  loading (``register_model`` / ``register_local_model``) and ``_embed``.
  """

  _embeddings: np.ndarray
  _candidate_indicies: np.ndarray
  _model: Union[SentenceTransformer, HierarchyTransformer, OntologyTransformer]
  _score_fn: Callable

  @overload
  def __init__(self, verbalisations_fp: Path, meta_map_fp: Path, embeddings_fp: Path): ...

  @overload
  def __init__(self, verbalisations_fp: Path, meta_map_fp: Path, embeddings_fp: Path, *, score_fn: Callable | None = None): ...

  @overload
  def __init__(self, verbalisations_fp: Path, meta_map_fp: Path, embeddings_fp: Path, *, score_fn: Callable | None = None, model_fp: Path | None = None): ...

  @overload
  def __init__(self, verbalisations_fp: Path, meta_map_fp: Path, embeddings_fp: Path, *, score_fn: Callable | None = None, model_fp: Path | None = None, model_str: str | None = None): ...

  @validate_call
  def __init__(self, verbalisations_fp: Path, meta_map_fp: Path, embeddings_fp: Path, *, score_fn: Callable | None = None, model_fp: Path | None = None, model_str: str | None = None):
    """
    Args:
      verbalisations_fp: JSON file with the candidate verbalisations.
      meta_map_fp: JSON file with one meta entry per candidate.
      embeddings_fp: ``.npy`` file with the candidate embeddings (opened read-only
        as a memory map).
      score_fn: optional scoring function; may also be set later with
        ``register_score_function``.
      model_fp: optional local model path. If loading it raises
        ``FileNotFoundError`` the same value is retried as a model identifier.
      model_str: model identifier, used only when ``model_fp`` is not given.
    """
    super().__init__(verbalisations_fp, meta_map_fp)
    self._embeddings = np.load(embeddings_fp, mmap_mode="r")
    self._candidate_indicies = np.arange(len(self._embeddings))
    if score_fn:
      self.register_score_function(score_fn)
    if model_fp:
      try:
        self.register_local_model(model_fp.expanduser().resolve())
      except FileNotFoundError:
        self.register_model(str(model_fp))
    elif model_str:
      self.register_model(model_str)

  def register_score_function(self, score_fn: Callable):
    """Set the function used to score the query embedding against all candidates."""
    self._score_fn = score_fn

  @override
  def retrieve(self, query_string: str, *, top_k: int | None = None, reverse_candidate_scores=False, **kwargs) -> list[QueryResult]:
    """
    Embed ``query_string``, score it against every candidate and return ranked hits.

    Args:
      query_string: text to embed and search for.
      top_k: number of hits to return; ``None`` returns every candidate.
      reverse_candidate_scores: ``False`` ranks ascending by score (smallest first,
        suitable for distances); ``True`` ranks descending (largest first, suitable
        for similarities).
      **kwargs: forwarded verbatim to the registered score function, so that
        score-function-specific options (e.g. a model or a weight) can be supplied
        per call.

    Returns:
      ``QueryResult`` tuples in ranked order. ``rank`` is 0-based here
      (NOTE(review): ``TFIDFRetriever.retrieve`` numbers its ranks from 1).
    """
    query_embedding = self._embed(query_string)
    scored_embeddings = self._score_fn(query_embedding, self._embeddings, **kwargs)
    # ordering and truncation are independent, so apply them as two separate steps
    ordering = np.argsort(scored_embeddings)
    if reverse_candidate_scores:
      ordering = np.flip(ordering)
    if top_k is not None:
      ordering = ordering[:top_k]
    ranked_indicies = self._candidate_indicies[ordering]
    results = []
    for rank, candidate_index in enumerate(ranked_indicies):
      candidate_meta_map = self._meta_map[candidate_index]
      results.append(
        QueryResult(
          rank=rank,
          iri=candidate_meta_map['iri'],
          # plain float instead of a NumPy scalar, matching the declared field type
          score=float(scored_embeddings[candidate_index]),
          verbalisation=candidate_meta_map['verbalisation'],
        )
      )
    return results

  @abstractmethod
  def register_model(self, model: str) -> None:
    """Load the model identified by ``model`` into ``_model``."""

  @abstractmethod
  def register_local_model(self, model_fp: Path) -> None:
    """Load the model stored at ``model_fp`` into ``_model``."""

  @abstractmethod
  def _embed(self, query_string: str) -> np.ndarray:
    """Return the embedding of ``query_string``."""


**HiTRetriever**

In [21]:
class HiTRetriever(BaseModelRetriever):
  """Retriever that embeds queries with a ``HierarchyTransformer``."""

  @override
  def register_model(self, model: str) -> None:
    """Load the model via ``HierarchyTransformer.from_pretrained(model)``."""
    loaded_model = HierarchyTransformer.from_pretrained(model)
    self._model = loaded_model

  @override
  def register_local_model(self, model_fp: Path) -> None:
    """Load the model from ``model_fp`` after expanding ``~`` and resolving the path."""
    resolved_fp = model_fp.expanduser().resolve()
    self._model = HierarchyTransformer.from_pretrained(str(resolved_fp))

  @override
  def _embed(self, query_string: str) -> np.ndarray:
    """Encode the query as a one-item batch and return its float32 embedding."""
    encoded_batch = self._model.encode([query_string])
    return encoded_batch.astype("float32")[0] # type: ignore

**OnTRetriever**

In [22]:
class OnTRetriever(BaseModelRetriever):
  """Retriever that embeds queries with an ``OntologyTransformer``."""

  @override
  def register_model(self, model: str) -> None:
    """Load the model via ``OntologyTransformer.load(model)``."""
    loaded_model = OntologyTransformer.load(model)
    self._model = loaded_model

  @override
  def register_local_model(self, model_fp: Path) -> None:
    """Load the model from ``model_fp`` after expanding ``~`` and resolving the path."""
    resolved_fp = model_fp.expanduser().resolve()
    self._model = OntologyTransformer.load(str(resolved_fp))

  @override
  def _embed(self, query_string: str) -> np.ndarray:
    """Encode the query with ``encode_concept`` as a one-item batch; return it as float32."""
    encoded_batch = self._model.encode_concept([query_string])
    return encoded_batch.astype("float32")[0] # type: ignore

**SBERTRetriever**

TODO

In [23]:
class SBERTRetriever(BaseModelRetriever):
  """Retriever that embeds queries with a ``SentenceTransformer``."""

  @override
  def register_model(self, model: str) -> None:
    """Load the model via ``SentenceTransformer.load(model)``."""
    loaded_model = SentenceTransformer.load(model)
    self._model = loaded_model

  @override
  def register_local_model(self, model_fp: Path) -> None:
    """Load the model from ``model_fp`` after expanding ``~`` and resolving the path."""
    resolved_fp = model_fp.expanduser().resolve()
    self._model = SentenceTransformer.load(str(resolved_fp))

  @override
  def _embed(self, query_string: str) -> np.ndarray:
    """Encode the query as a one-item batch and return its float32 embedding."""
    encoded_batch = self._model.encode([query_string])
    return encoded_batch.astype("float32")[0] # type: ignore

**BM25Retriever**

TODO

In [26]:
class BM25Retriever(BaseRetriever):
  """
  Lexical retriever backed by ``rank_bm25.BM25Okapi`` over whitespace-tokenised,
  lower-cased verbalisations.
  """

  _bm25: BM25Okapi

  @validate_call
  def __init__(self, verbalisations_fp: Path, meta_map_fp: Path, k1: float = 1.3, b: float = 0.7, workers: int = 4):
    """
    Args:
      verbalisations_fp: JSON file with the candidate verbalisations.
      meta_map_fp: JSON file with one meta entry per candidate.
      k1: BM25 ``k1`` parameter, passed to ``BM25Okapi``.
      b: BM25 ``b`` parameter, passed to ``BM25Okapi``.
      workers: number of processes used to tokenise the verbalisations.
    """
    super().__init__(verbalisations_fp, meta_map_fp)
    self._tokenised_verbalisations = parallel_tokenise(self._verbalisations, workers=workers)
    self._bm25 = BM25Okapi(self._tokenised_verbalisations, k1=k1, b=b)

  @classmethod
  def build_from_index(cls, index_fp: Path | str):
    """
    Create a retriever from an index previously written by ``save_index``.

    ``__init__`` is bypassed, so nothing is re-tokenised or re-indexed; as a
    consequence ``_tokenised_verbalisations`` is not set on the returned instance.
    """
    retriever = cls.__new__(cls)
    retriever.load_index(index_fp)
    return retriever

  def save_index(self, index_fp: Path | str):
    """Pickle the BM25 index, the verbalisations and the meta map to ``index_fp``."""
    if isinstance(index_fp, Path):
      index_fp = str(index_fp.expanduser().resolve())
    with open(index_fp, "wb") as fp:
      pickle.dump({
        "index": self._bm25,
        "verbalisations": self._verbalisations,
        "meta_map": self._meta_map
      }, fp, protocol=pickle.HIGHEST_PROTOCOL)
    print("Saved BM25 index to disk.")

  def load_index(self, index_fp: Path | str):
    """
    Replace the index, verbalisations and meta map with those pickled at ``index_fp``.

    SECURITY: unpickling can execute arbitrary code; only load index files that
    were produced by ``save_index`` from a trusted source.
    """
    if isinstance(index_fp, Path):
      index_fp = str(index_fp.expanduser().resolve())
    with open(index_fp, "rb") as fp:
      bm25_bin = pickle.load(fp)
    self._bm25 = bm25_bin['index']
    self._verbalisations = bm25_bin['verbalisations']
    self._meta_map = bm25_bin['meta_map']

  def retrieve(self, query_string: str, *, top_k: int | None = None, **kwargs) -> list[QueryResult]:
    """
    Score every candidate against the tokenised query and return hits, best first.

    Args:
      query_string: text to search for; tokenised with ``naive_tokenise``.
      top_k: number of hits to return; ``None`` returns every candidate.
      **kwargs: accepted and ignored.

    Returns:
      ``QueryResult`` tuples ordered by descending BM25 score; ``rank`` is 0-based
      (NOTE(review): ``TFIDFRetriever.retrieve`` numbers its ranks from 1).
    """
    tokens = naive_tokenise(query_string)
    scores = self._bm25.get_scores(tokens)
    # descending order; slicing with None keeps everything
    top_idx = np.argsort(scores)[::-1][:top_k]
    results = []
    for rank, idx in enumerate(top_idx):
      results.append(
        QueryResult(
          rank=rank,
          iri=self._meta_map[idx]['iri'],
          score=float(scores[idx]),
          verbalisation=self._verbalisations[idx]
        )
      )
    return results

In [25]:
class TFIDFRetriever(BaseRetriever):
  """
  Lexical retriever that scores candidates through a TF-IDF inverted index.

  The index maps each vocabulary term to ``(doc_id, tfidf_score)`` postings; a query
  is vectorised with the same vectoriser and scored as the sum of
  ``query_weight * tfidf_score`` over shared terms.
  """

  # TODO: clean up and implement `save` and `load` methods, similar to BM25.

  _vectorizer: TfidfVectorizer
  _inverted_index: dict[str, list[tuple[int, float]]]
  # declared for future use; not populated by this class
  _tfidf_matrix: "scipy.sparse.csr_matrix"
  _tokenizer: Callable[[str], Sequence[str]] | None

  @validate_call
  def __init__(self, verbalisations_fp: Path, meta_map_fp: Path, *,
    lowercase: bool = True, stop_words: str | None = "english",
    ngram_range: tuple[int, int] = (1, 1),
    tokenizer: Callable[[str], Sequence[str]] | None = None,
    max_features: int | None = None,
  ) -> None:
    """
    Args:
      verbalisations_fp: JSON file with the candidate verbalisations.
      meta_map_fp: JSON file with one meta entry per candidate.
      lowercase, stop_words, ngram_range, tokenizer, max_features: forwarded to
        ``TfidfVectorizer``. The defaults reproduce the previously hard-coded
        configuration (English stop words, unigrams, no feature cap).
    """
    super().__init__(verbalisations_fp, meta_map_fp)
    self._tokenizer = tokenizer
    # Honour the constructor arguments (previously accepted but ignored).
    # norm=None keeps raw, un-normalised TF-IDF weights.
    self._vectorizer = TfidfVectorizer(
      lowercase=lowercase,
      stop_words=stop_words,
      ngram_range=ngram_range,
      tokenizer=tokenizer,
      max_features=max_features,
      use_idf=True,
      smooth_idf=True,
      norm=None,
    )
    doc_term_matrix = self._vectorizer.fit_transform(self._verbalisations)
    # cache the vocabulary once; retrieve() needs it for every query
    self._vocab = self._vectorizer.get_feature_names_out()
    inverted_index: dict[str, list[tuple[int, float]]] = {str(term): [] for term in self._vocab}
    coo = coo_matrix(doc_term_matrix)

    for row, col, score in zip(coo.row, coo.col, coo.data):
      inverted_index[str(self._vocab[col])].append((int(row), float(score)))
    # postings ordered by score, descending
    for term_postings in inverted_index.values():
      term_postings.sort(key=lambda posting: posting[1], reverse=True)

    self._inverted_index = inverted_index

  def retrieve(self, query_string: str, *, top_k: int | None = None, **kwargs) -> list[QueryResult]:
    """
    Return the candidates sharing at least one term with the query, best first.

    Args:
      query_string: text to search for; vectorised with the fitted vectoriser.
      top_k: number of hits to return; ``None`` returns every matching candidate
        and ``0`` returns none.
      **kwargs: accepted and ignored.

    Returns:
      ``QueryResult`` tuples ordered by descending score. ``rank`` starts at 1 here
      (NOTE(review): the BM25 and model retrievers number their ranks from 0).
      Candidates with no term in common with the query are not returned.
    """
    query_vec = self._vectorizer.transform([query_string])
    q_weights = {
      str(self._vocab[col]): float(val)
      for col, val in zip(query_vec.indices, query_vec.data) # type: ignore
      if val > 0.0
    }
    tfidf_scores = aggregate_posting_scores(q_weights, self._inverted_index)
    # `is not None` so that top_k=0 means "no results" rather than "all results"
    limit = top_k if top_k is not None else len(tfidf_scores)
    tfidf_top = topk(tfidf_scores, limit)
    results: list[QueryResult] = []
    for rank, (doc_id, score) in enumerate(tfidf_top, 1):
      meta = self._meta_map[doc_id]
      results.append(
        QueryResult(
          rank=rank,
          iri=meta['iri'],
          score=float(score),
          verbalisation=meta['verbalisation'],
        )
      )
    return results

<style>
*{
    line-height: 24px;
}
</style>

## Evaluation Metrics

Most evaluation metrics are actually defined within the context of the experiments themselves; i.e. they're not pre-defined as functions prior to implementation. However, some intuition for nDCG is provided below.

Other evaluation metrics include:

Hit rate / h@k, MRR, MR, Median Rank, PR-AUC, mAP and Recall@k.

TODO: update.
TODO: convert nDCG implementation details from python comments to something a little more palatable.

In [27]:
# nDCG NOTES:

# We don't want to directly modify the depth (as that might be confusing)
# So.. we're going to create a new set of subsumptive targets with a relevancy score
# where the relevancy := ascent_height + 1 \forall t \in T
# where the ascent_height := max(depth) - depth_at_t
# -> relevancy := (max(depth) - depth_at_t) + 1 \forall t \in T
#
# Example (POLYHIERARCHICAL ONTOLOGY):
#
# Say, you've got a query_string with a target entity, two parent entities and seven ancestors 
# (exclusive of SNOMED CT CONCEPT & owl:Thing):
# 
#                              ---------------------------------------------------------------------------------
#                  ENTITY_TYPE | DEPTH | ASCENT_HEIGHT | ASCENT_HEIGHT + 1 (REL) | 2^{r} - 1 |  = val  |  DCG  |
#                              |-------|---------------|-------------------------|-----------|---------|-------|
#      X      <- TARGET_ENTITY |   0   |       4       |            5            |  2^5 - 1  |    31   |  31.0 |  
#     / \                      |       |               |                         |           |         |       |
#    X   X          <- PARENTS |   1   |       3       |            4            |  2^4 - 1  |    15   |  9.46 |
#   /   / \                    |       |               |                         |           |         |       |
#  X   X   X      <- ANCESTORS |   2   |       2       |            3            |  2^3 - 1  |    7    |  7.50 |
#  |   |    \                  |       |               |                         |           |         |       |
#  |   X     X    <- ANCESTORS |   3   |       1       |            2            |  2^2 - 1  |    3    |  3.57 |
#  |   |    / \                |       |               |                         |           |         |       |
#  |   |   X   X  <- ANCESTORS |   4   |       0       |            1            |  2^1 - 1  |    1    |  1.61 |
#  |   |   |   |               ---------------------------------------------------------------------------------
#  -------------------------                     
#  | TOP SNOMED CT CONCEPT |
#  -------------------------
#  |       owl:Thing       |
#  -------------------------------------------------------------------------------------------------------
#  |              IDEAL ORDERING                   |                  EXAMPLE RESULTS                    |
#  -------------------------------------------------------------------------------------------------------
#  |  RANK      ENTITY NAME          REL    DCG    |  RANK      ENTITY               REL     DCG         |
#  -------------------------------------------------------------------------------------------------------
#  |  1         TARGET_ENTITY        5      31.0   |  1         ANCESTOR_@_3_1|2     2       3.00        |
#  |  2         PARENT_1|2           4      9.46   |  2         ANCESTOR_@_2_1|2|3   3       4.42        |
#  |  3         PARENT_2|2           4      7.50   |  3         TARGET_ENTITY        5       15.5        |
#  |  4         ANCESTOR_@_2_1|2|3   3      3.01   |  4         PARENT_1|2           4       6.46        |
#  |  5         ANCESTOR_@_2_1|2|3   3      2.71   |  5         ANCESTOR_@_2_1|2|3   3       2.71        |
#  |  6         ANCESTOR_@_2_1|2|3   3      2.49   |  6         ANCESTOR_@_2_1|2|3   3       2.49        |
#  |  7         ANCESTOR_@_3_1|2     2      1.00   |  7         ANCESTOR_@_3_1|2     2       1.00        |
#  |  8         ANCESTOR_@_3_1|2     2      0.95   |  8         NOT-RELEVANT-RESULT  0       0.00        |
#  |  9         ANCESTOR_@_4_1|2     1      0.30   |  9         ANCESTOR_@_4_1|2     1       0.30        |
#  |  10        ANCESTOR_@_4_1|2     1      0.29   |  10        NOT-RELEVANT-RESULT  0       0.00        |
# --------------------------------------------------------------------------------------------------------
#  |     iDCG = \sum_i^{\|T\|}t_{dcg} = 58.72      |  DCG@10 = \sum_i^{\|Q_{results}\|}q_dcg = 35.88     |
# --------------------------------------------------------------------------------------------------------
#
#     nDCG@10 =  DCG@10        35.88
#               --------   =   -----   =  0.611
#                iDCG@10       58.72
#
#
#  * depth is measured from the target.
#  * Quick Recap: we're taking an OOV phrase/string (from a set of entity mentions on a QA dataset)
#    -> assigning a target SNOMED entity to that entity mention, s.t. SNOMED entity ~= entity mention
#    -> for the entity mention, we traverse the ontology from the target, up the hierarchy, until we reach the top
#    -> as we traverse the structure, we record the depth, rdfs:label, pref:label alt:labels and IRI of each node
#    -> that allows us to construct a graph (that kind of looks like a tree, as its a fragment of an ontology, which is
#       largely a taxonomy, but it is polyhierarchical, so it ends up being a graph)
#    -> as we get further away from the target, the concepts get more general/broad, so assume/consider the relevancy
#       decreases monotonically as a function of the depth (we implement two relevancy scores for DCG:
#           
#           (1) Relevancy \w exponential decay:  \frac{2^{rel} - 1}{log_2(rank+1)}
#           (2) Relevancy \w linear decay: \frac{rel}{log_2(rank+1)}
#
#    -> We opt to use exponential decay, as it more suitably approximates distance in hyperbolic space, though, it is noted
#       that the result is normalised anyway...

In [28]:
from functools import reduce
from typing import Any
import math

def add(a, b, key='dcg'):
  """Reducer step: return the running total `a` plus the value stored under `key` in `b`.

  Shaped for `functools.reduce`, where `a` is the accumulator and `b` is the
  next mapping-like item (it only needs to support `b[key]`).
  """
  increment = b[key]
  return a + increment

def obj_max_depth(x: int, obj: Any, key: str = "depth") -> int:
  """Reducer step: return the larger of the running maximum `x` and `obj[key]`.

  When the two compare equal, the value taken from `obj` is returned.
  """
  if x > obj[key]:
    return x
  return obj[key]

def dcg_exp_relevancy_at_pos(relevancy: int, rank_position: int) -> float:
  """Exponential-gain DCG contribution of one result.

  Computes (2**relevancy - 1) / log2(rank_position + 1).

  Args:
    relevancy: graded relevance of the result; values <= 0 contribute nothing.
    rank_position: rank of the result (the callers in this notebook pass 1-based ranks).

  Returns:
    The discounted gain as a float, or 0.0 for non-positive relevancy.
  """
  if relevancy > 0:
    gain = (2**relevancy) - 1
    discount = math.log2(rank_position + 1)
    return float(gain / discount)
  return 0.0

def compute_ndcg_at_k(results: list[tuple[int, str, float, str]], targets_with_dcg_exp: list[dict], k: int = 20) -> float:
  """Normalised DCG of a ranked result list, truncated at `k`.

  Args:
    results: ranked 4-tuples; only the second field (the IRI) is used here.
    targets_with_dcg_exp: target dicts carrying 'iri', 'relevance' and a
      pre-computed 'dcg' value.
    k: cut-off applied to both the ranking and the target list.

  Returns:
    DCG@k / iDCG@k, or 0.0 when the ideal DCG is zero.

  The ideal DCG is the sum of the 'dcg' values of the first `k` targets, so the
  targets are assumed to arrive in ideal (most-relevant-first) order -- TODO confirm
  against the caller that builds them.
  """
  gain_by_iri = {}
  for target in targets_with_dcg_exp:
    gain_by_iri[target['iri']] = target['relevance']

  # Accumulate the discounted gain of the top-k results; unknown IRIs score 0.
  dcg = 0.0
  rank = 0
  for _idx, iri, _score, _label in results[:k]:
    rank += 1
    dcg += dcg_exp_relevancy_at_pos(gain_by_iri.get(iri, 0), rank)

  ideal_dcg = sum(target['dcg'] for target in targets_with_dcg_exp[:k])
  return dcg / ideal_dcg if ideal_dcg != 0 else 0.0

<style>
*{
    line-height: 24px;
}
</style>

## Embedding & Retrieval 

#### Precompute Embeddings

1. Process the entity lexicon (or $\mathcal{EL}$-normalised concept strings) and extract a verbalisation list.
2. Encode each concept's textual representation into each model's native embedding space.
3. Save the embeddings to disk, along with a map between $emb \leftrightarrow concept_{text\ repr}$.

#### Retrieval & Ranking

1. Load the embeddings + their mappings.
2. Accept a query string $q$.
3. Compute a score for $q$ using a retrieval method: $\{ TFIDF, BM25, SBERT, HiT, OnT \}$ using the pre-computed embeddings.
4. `argsort` the embs according to their scores
5. Return the sorted list.

**Embedding:**

In [29]:
# Prep for embedding: turn the entity lexicon into the three JSON artefacts
# (verbalisation list, id -> record map, record list) shared by every retriever.

print("Preparing data for indexing/encoding...")

data_dir = "../data"
embeddings_dir = "../embeddings"

if not Path(data_dir).exists():
  print("[WARNING] No data directory exists. The notebook will fail. Review the README.md, or the docs dir.")

# make sure the embeddings directory is there before anything is written into it
Path(embeddings_dir).expanduser().resolve().mkdir(parents=True, exist_ok=True)

# input: generated during SNOMED CT processing
entity_lexicon_fp = Path(f"{data_dir}/preprocessed_entity_lexicon.json")

# output: list of the verbalisations (label text, or deeponto verbs)
verbalisation_list_fp = Path(f"{embeddings_dir}/verbalisations.json")
# output: str(index) -> {mapping_id, label, verbalisation, iri}
entity_map_fp = Path(f"{embeddings_dir}/entity_map.json")
# output: the same records as a list, so a position from argsort indexes straight into it
entity_mappings_list_fp = Path(f"{embeddings_dir}/entity_mappings.json")

entity_lexicon = load_json(entity_lexicon_fp)
iris = entity_lexicon.keys()

entity_map = {}
entity_verbalisation_list = []
list_of_entity_mappings = []

for entity_idx, entity_iri in enumerate(tqdm(iris)):
    entity_key = str(entity_idx)
    entity_label = entity_lexicon[entity_iri].get('name') # type: ignore
    # parenthesised tags are stripped and the text lower-cased before it is indexed/encoded
    entity_verbalisation = strip_parens(str(entity_label)).lower()
    entity_record = {
        "mapping_id": entity_key,
        "label": entity_label,
        "verbalisation": entity_verbalisation,
        "iri": entity_iri
    }
    # the map and the list deliberately share one record object per entity
    entity_map[entity_key] = entity_record
    entity_verbalisation_list.append(entity_verbalisation)
    list_of_entity_mappings.append(entity_record)

save_json(verbalisation_list_fp, entity_verbalisation_list)
save_json(entity_map_fp, entity_map)
save_json(entity_mappings_list_fp, list_of_entity_mappings)

print("Complete!")

Preparing data for indexing/encoding...


100%|██████████| 375724/375724 [00:01<00:00, 344608.53it/s]


Complete!


In [30]:
# Manual switch for the embedding cells that follow: while True, every
# `if not embs_already_exist:` block is skipped and the cached .npy files are used;
# set it to False to (re)encode the verbalisations and write the .npy files again.
embs_already_exist = True

In [31]:
if not embs_already_exist:

  # Encode every verbalisation with the off-the-shelf SBERT model (normalised
  # embeddings requested) and cache the float32 matrix on disk.
  sbert_plm_hf_string = "all-MiniLM-L12-v2"
  sbert_plm_encoder = SentenceTransformer.load(sbert_plm_hf_string)
  _sbert_plm_encode_options = {
      "batch_size": 128,
      "show_progress_bar": True,
      "normalize_embeddings": True,
  }
  sbert_plm_embeddings = sbert_plm_encoder.encode(entity_verbalisation_list, **_sbert_plm_encode_options).astype("float32")
  np.save(f"{embeddings_dir}/sbert-plm-embeddings.npy", sbert_plm_embeddings)

In [32]:
if not embs_already_exist:

  # Encode every verbalisation with the locally stored HiT checkpoint and cache
  # the float32 matrix on disk.
  hit_snomed_25_model_fp = '../models/snomed_models/HiT-mixed-SNOMED-25/final'
  hit_snomed_25_encoder = HierarchyTransformer.from_pretrained(hit_snomed_25_model_fp)
  _hit_snomed_25_encode_options = {"batch_size": 128, "show_progress_bar": True}
  hit_snomed_25_embeddings = hit_snomed_25_encoder.encode(entity_verbalisation_list, **_hit_snomed_25_encode_options).astype("float32")
  np.save(f"{embeddings_dir}/hit-snomed-25-embeddings.npy", hit_snomed_25_embeddings)

In [33]:
if not embs_already_exist:

  # Encode every verbalisation with the GALEN-trained OnT model and cache the
  # float32 matrix on disk.
  ont_galen_23_pred_model_fp = "../models/models/prediction/OnTr-all-MiniLM-L12-v2-GALEN"
  ont_galen_23_pred_encoder = OntologyTransformer.load(ont_galen_23_pred_model_fp)
  _ont_galen_encode_options = {"batch_size": 128, "show_progress_bar": True}
  ont_galen_23_pred_embeddings = ont_galen_23_pred_encoder.encode_concept(entity_verbalisation_list, **_ont_galen_encode_options).astype("float32")
  np.save(f"{embeddings_dir}/ont-galen-23-pred-embeddings.npy", ont_galen_23_pred_embeddings)

In [34]:
if not embs_already_exist:

  # Encode every verbalisation with the ANATOMY-trained OnT model and cache the
  # float32 matrix on disk.
  ont_anatomy_23_pred_model_fp = "../models/models/prediction/OnTr-all-MiniLM-L12-v2-ANATOMY"
  ont_anatomy_23_pred_encoder = OntologyTransformer.load(ont_anatomy_23_pred_model_fp)
  _ont_anatomy_encode_options = {"batch_size": 128, "show_progress_bar": True}
  ont_anatomy_23_pred_embeddings = ont_anatomy_23_pred_encoder.encode_concept(entity_verbalisation_list, **_ont_anatomy_encode_options).astype("float32")
  np.save(f"{embeddings_dir}/ont-anatomy-23-pred-embeddings.npy", ont_anatomy_23_pred_embeddings)

In [35]:
if not embs_already_exist:

  # Encode every verbalisation with the GO-trained OnT model and cache the
  # float32 matrix on disk.
  ont_gene_ontology_23_pred_model_fp = "../models/models/prediction/OnTr-all-MiniLM-L12-v2-GO"
  ont_gene_ontology_23_pred_encoder = OntologyTransformer.load(ont_gene_ontology_23_pred_model_fp)
  _ont_go_encode_options = {"batch_size": 128, "show_progress_bar": True}
  ont_gene_ontology_23_pred_embeddings = ont_gene_ontology_23_pred_encoder.encode_concept(entity_verbalisation_list, **_ont_go_encode_options).astype("float32")
  np.save(f"{embeddings_dir}/ont-go-23-pred-embeddings.npy", ont_gene_ontology_23_pred_embeddings)

In [36]:
if not embs_already_exist:

  # Encode every verbalisation with the OnT-LG SNOMED model and cache the
  # float32 matrix on disk.
  ont_snomed_LG_model_fp = '../models/snomed_models/OnT-LG'
  ont_snomed_LG_encoder = OntologyTransformer.load(ont_snomed_LG_model_fp)
  _ont_snomed_LG_encode_options = {"batch_size": 128, "show_progress_bar": True}
  ont_snomed_LG_embeddings = ont_snomed_LG_encoder.encode_concept(entity_verbalisation_list, **_ont_snomed_LG_encode_options).astype("float32")
  np.save(f"{embeddings_dir}/ont-snomed-LG-embeddings.npy", ont_snomed_LG_embeddings)

In [37]:
if not embs_already_exist:

  # Encode every verbalisation with the OnTr-m-32 SNOMED model and cache the
  # float32 matrix on disk.
  ont_snomed_minified_32_model_fp = '../models/snomed_models/OnTr-m-32'
  ont_snomed_m_32_encoder = OntologyTransformer.load(ont_snomed_minified_32_model_fp)
  _ont_snomed_m_32_encode_options = {"batch_size": 128, "show_progress_bar": True}
  ont_snomed_minified_32_embeddings = ont_snomed_m_32_encoder.encode_concept(entity_verbalisation_list, **_ont_snomed_m_32_encode_options).astype("float32")
  np.save(f"{embeddings_dir}/ont-snomed-minified-32-embeddings.npy", ont_snomed_minified_32_embeddings)

In [38]:
if not embs_already_exist:

  # Encode every verbalisation with the OnTr-minified-64 SNOMED model and cache
  # the float32 matrix on disk.
  ont_snomed_minified_model_fp = '../models/snomed_models/OnTr-minified-64'
  ont_snomed_encoder = OntologyTransformer.load(ont_snomed_minified_model_fp)
  _ont_snomed_m_64_encode_options = {"batch_size": 128, "show_progress_bar": True}
  ont_snomed_minified_embeddings = ont_snomed_encoder.encode_concept(entity_verbalisation_list, **_ont_snomed_m_64_encode_options).astype("float32")
  np.save(f"{embeddings_dir}/ont-snomed-minified-embeddings.npy", ont_snomed_minified_embeddings)

In [39]:
if not embs_already_exist:

  # Encode every verbalisation with the OnTr-m-128 SNOMED model and cache the
  # float32 matrix on disk.
  ont_snomed_minified_128_model_fp = '../models/snomed_models/OnTr-m-128'
  ont_snomed_m_128_encoder = OntologyTransformer.load(ont_snomed_minified_128_model_fp)
  _ont_snomed_m_128_encode_options = {"batch_size": 128, "show_progress_bar": True}
  ont_snomed_minified_128_embeddings = ont_snomed_m_128_encoder.encode_concept(entity_verbalisation_list, **_ont_snomed_m_128_encode_options).astype("float32")
  np.save(f"{embeddings_dir}/ont-snomed-minified-128-embeddings.npy", ont_snomed_minified_128_embeddings)

**Retrieval**

*Load the pre-computed embeddings.*

In [40]:
# Open each cached embedding matrix as a read-only memory map (mmap_mode="r"),
# so the arrays are not read into memory up front.
# NOTE(review): the retrievers constructed later are given the .npy paths rather than
# these arrays -- confirm whether these handles are still needed.
sbert_plm_embs = np.load(f"{embeddings_dir}/sbert-plm-embeddings.npy", mmap_mode="r") # SBERT baseline
hit_snomed_25_embs = np.load(f"{embeddings_dir}/hit-snomed-25-embeddings.npy", mmap_mode="r") # HiT FULL
ont_galen_23_pred_embs = np.load(f"{embeddings_dir}/ont-galen-23-pred-embeddings.npy", mmap_mode="r") # OnT GALEN
ont_anatomy_23_pred_embs = np.load(f"{embeddings_dir}/ont-anatomy-23-pred-embeddings.npy", mmap_mode="r") # OnT ANATOMY
ont_gene_ontology_23_pred_embs = np.load(f"{embeddings_dir}/ont-go-23-pred-embeddings.npy", mmap_mode="r") # OnT GO
ont_snomed_LG_embs = np.load(f"{embeddings_dir}/ont-snomed-LG-embeddings.npy", mmap_mode="r") # SNOMED FULL
ont_minified_32_embs = np.load(f"{embeddings_dir}/ont-snomed-minified-32-embeddings.npy", mmap_mode="r") # M-32
ont_minified_embs = np.load(f"{embeddings_dir}/ont-snomed-minified-embeddings.npy", mmap_mode="r") # M-64
ont_minified_128_embs = np.load(f"{embeddings_dir}/ont-snomed-minified-128-embeddings.npy", mmap_mode="r") # M-128

In [41]:
# Shared inputs for every retriever: the entity-record list and the verbalisation list
# written into the embeddings directory by the data-preparation cell.
embeddings_dir = "../embeddings"
common_map = Path(embeddings_dir) / "entity_mappings.json"
common_verbalisations = Path(embeddings_dir) / "verbalisations.json"

**Retrievers: Lexical Methods**

In [42]:
# Lexical baselines, both built from the shared verbalisation list and entity mapping files.
tfidf_ret = TFIDFRetriever(common_verbalisations, common_map)
# BM25 gets explicit k1 / b values -- presumably the usual Okapi term-saturation and
# length-normalisation parameters; confirm against BM25Retriever.
bm25_ret = BM25Retriever(common_verbalisations, common_map, k1=1.3, b=0.7)

**Retrievers: SBERT**

In [43]:
# Two SBERT retrievers over the same cached embeddings and model name;
# they differ only in the scoring function.
sbert_plm_hf_string = "all-MiniLM-L12-v2"

_sbert_plm_shared_kwargs = dict(
  embeddings_fp=Path(f"{embeddings_dir}/sbert-plm-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_str=sbert_plm_hf_string,
)

sbert_ret_plm_w_cosine_sim = SBERTRetriever(**_sbert_plm_shared_kwargs, score_fn=batch_cosine_similarity)

sbert_ret_plm_w_euclid_dist = SBERTRetriever(**_sbert_plm_shared_kwargs, score_fn=batch_euclidian_l2_distance)

**Retrievers: HiT**

In [46]:
# Hierarchy Transformer-based retrievers (HiT Full): same cached embeddings and
# checkpoint, scored either by hyperbolic distance or by entity subsumption.

hit_snomed_25_model_fp = '../models/snomed_models/HiT-mixed-SNOMED-25/final'
hit_SNOMED25_model_path = Path(hit_snomed_25_model_fp)

_hit_snomed_25_shared_kwargs = dict(
  embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_fp=hit_SNOMED25_model_path,
)

hit_ret_snomed_25_w_hyp_dist = HiTRetriever(**_hit_snomed_25_shared_kwargs, score_fn=batch_poincare_dist_with_adaptive_curv_k)

hit_ret_snomed_25_w_ent_sub = HiTRetriever(**_hit_snomed_25_shared_kwargs, score_fn=entity_subsumption)

You are trying to use a model that was created with Sentence Transformers version 5.0.0, but you're currently using version 4.1.0. This might cause unexpected behavior or errors. In that case, try to update to the latest version.


**Retrievers OnT:**

#### OnT Encoders (ANATOMY, GALEN, GO)

In [47]:
# OnT (ANATOMY) retrievers: same cached embeddings and model directory, scored either
# by hyperbolic distance or by concept subsumption.
ont_anatomy_23_pred_model_fp = "../models/models/prediction/OnTr-all-MiniLM-L12-v2-ANATOMY"
ont_anatonmy_pred_model_path = Path(ont_anatomy_23_pred_model_fp)  # (sic) name kept as-is; it is a notebook global

_ont_anatomy_shared_kwargs = dict(
  embeddings_fp=Path(f"{embeddings_dir}/ont-anatomy-23-pred-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_fp=ont_anatonmy_pred_model_path,
)

ont_ret_anatomy_pred_w_hyp_dist = OnTRetriever(**_ont_anatomy_shared_kwargs, score_fn=batch_poincare_dist_with_adaptive_curv_k)

ont_ret_anatomy_pred_w_con_sub = OnTRetriever(**_ont_anatomy_shared_kwargs, score_fn=concept_subsumption)

In [48]:
# OnT (GALEN) retrievers: same cached embeddings and model directory, scored either
# by hyperbolic distance or by concept subsumption.
ont_galen_23_pred_model_fp = "../models/models/prediction/OnTr-all-MiniLM-L12-v2-GALEN"
ont_galen_pred_model_path = Path(ont_galen_23_pred_model_fp)

_ont_galen_shared_kwargs = dict(
  embeddings_fp=Path(f"{embeddings_dir}/ont-galen-23-pred-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_fp=ont_galen_pred_model_path,
)

ont_ret_galen_pred_w_hyp_dist = OnTRetriever(**_ont_galen_shared_kwargs, score_fn=batch_poincare_dist_with_adaptive_curv_k)

ont_ret_galen_pred_w_con_sub = OnTRetriever(**_ont_galen_shared_kwargs, score_fn=concept_subsumption)

In [49]:
# OnT (GO) retrievers: same cached embeddings and model directory, scored either
# by hyperbolic distance or by concept subsumption.
ont_gene_ontology_23_pred_model_fp = "../models/models/prediction/OnTr-all-MiniLM-L12-v2-GO"
ont_go_pred_model_path = Path(ont_gene_ontology_23_pred_model_fp)

_ont_go_shared_kwargs = dict(
    embeddings_fp=Path(f"{embeddings_dir}/ont-go-23-pred-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=ont_go_pred_model_path,
)

ont_ret_go_pred_w_hyp_dist = OnTRetriever(**_ont_go_shared_kwargs, score_fn=batch_poincare_dist_with_adaptive_curv_k)

ont_ret_go_pred_w_con_sub = OnTRetriever(**_ont_go_shared_kwargs, score_fn=concept_subsumption)

#### OnT SNOMED-CT Tuned Models

In [50]:
# OnT-LG (SNOMED) retrievers: same cached embeddings and model directory, scored either
# by hyperbolic distance or by concept subsumption.
ont_snomed_LG_model_fp = "../models/snomed_models/OnT-LG"
ont_snomed_LG_model_path = Path(ont_snomed_LG_model_fp)

_ont_snomed_LG_shared_kwargs = dict(
    embeddings_fp=Path(f"{embeddings_dir}/ont-snomed-LG-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=ont_snomed_LG_model_path,
)

ont_snomed_LG_w_hyp_dist = OnTRetriever(**_ont_snomed_LG_shared_kwargs, score_fn=batch_poincare_dist_with_adaptive_curv_k)

ont_snomed_LG_w_con_sub = OnTRetriever(**_ont_snomed_LG_shared_kwargs, score_fn=concept_subsumption)

In [51]:
# OnTr-m-32 (SNOMED) retrievers: same cached embeddings and model directory, scored
# either by hyperbolic distance or by concept subsumption.
ontr_snomed_minified_32_model_fp = '../models/snomed_models/OnTr-m-32'
ontr_snomed_minified_32_model_path = Path(ontr_snomed_minified_32_model_fp)

_ontr_m_32_shared_kwargs = dict(
    embeddings_fp=Path(f"{embeddings_dir}/ont-snomed-minified-32-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=ontr_snomed_minified_32_model_path,
)

ontr_ret_snomed_minified_32_w_hyp_dist = OnTRetriever(**_ontr_m_32_shared_kwargs, score_fn=batch_poincare_dist_with_adaptive_curv_k)

ontr_ret_snomed_minified_32_w_con_sub = OnTRetriever(**_ontr_m_32_shared_kwargs, score_fn=concept_subsumption)



In [52]:
# OnTr-minified-64 (SNOMED) retrievers: same cached embeddings and model directory,
# scored either by hyperbolic distance or by concept subsumption.
ontr_snomed_minified_model_fp = '../models/snomed_models/OnTr-minified-64'
# NOTE: the same name is rebound from str to Path here (the sibling cells use a separate *_model_path name).
ontr_snomed_minified_model_fp = Path(ontr_snomed_minified_model_fp)

_ontr_m_64_shared_kwargs = dict(
    embeddings_fp=Path(f"{embeddings_dir}/ont-snomed-minified-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=ontr_snomed_minified_model_fp,
)

ontr_ret_snomed_minified_w_hyp_dist = OnTRetriever(**_ontr_m_64_shared_kwargs, score_fn=batch_poincare_dist_with_adaptive_curv_k)

ontr_ret_snomed_minified_w_con_sub = OnTRetriever(**_ontr_m_64_shared_kwargs, score_fn=concept_subsumption)



In [53]:
# OnTr-m-128 (SNOMED) retrievers: same cached embeddings and model directory, scored
# either by hyperbolic distance or by concept subsumption.
ontr_snomed_minified_128_model_fp = '../models/snomed_models/OnTr-m-128'
# NOTE: the same name is rebound from str to Path here (the sibling cells use a separate *_model_path name).
ontr_snomed_minified_128_model_fp = Path(ontr_snomed_minified_128_model_fp)

_ontr_m_128_shared_kwargs = dict(
    embeddings_fp=Path(f"{embeddings_dir}/ont-snomed-minified-128-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=ontr_snomed_minified_128_model_fp,
)

ontr_ret_snomed_minified_128_w_hyp_dist = OnTRetriever(**_ontr_m_128_shared_kwargs, score_fn=batch_poincare_dist_with_adaptive_curv_k)

ontr_ret_snomed_minified_128_w_con_sub = OnTRetriever(**_ontr_m_128_shared_kwargs, score_fn=concept_subsumption)



# Experiments

### Configuration & Data

TODO

In [54]:
# Load the evaluation queries through QueryObjectMapping and derive the two query sets
# used by the experiments from the subsumption queries.

data_query_mapping = QueryObjectMapping(Path("../data/eval_dataset_50.json"))
equiv_queries, subsumpt_queries = data_query_mapping.get_queries()

global_cutoff_depth = 5 # global cutoff parameter for multi-target retrieval

# Independent deep copies, so the loaded queries themselves are never mutated
# and can be re-used later.

# multi-target variant: untouched copy
oov_match_all = copy.deepcopy(subsumpt_queries)

# single-target variant: `_parents` / `_ancestors` are emptied on every copied query
# (presumably leaving only the target itself as a relevant result -- confirm against the query class)
oov_single_target_queries = copy.deepcopy(subsumpt_queries)
for q in oov_single_target_queries:
    q._parents = []
    q._ancestors = []

In [None]:
# set up the 'models dict' ready for experimental runs (single target):
# Maps a display name to a configured retriever. The single-target experiment splits
# each name on whitespace and uses the first two tokens as the table's Model / Variant
# columns, so every key needs at least two whitespace-separated tokens.

models_dict_single_target = {
  # BASELINES
  "BoW TFIDF": tfidf_ret,
  "BoW BM25": bm25_ret,
  # BASELINE CONTEXTUAL EMBEDDINGS
  "SBERT cos-sim": sbert_ret_plm_w_cosine_sim,
  # HiT (Full)
  "HiT SNO-25(F)": hit_ret_snomed_25_w_hyp_dist,
  # OnT Transfer Models
  "OnT GALEN(P)": ont_ret_galen_pred_w_hyp_dist,
  "OnT ANATOMY(P)": ont_ret_anatomy_pred_w_hyp_dist,
  "OnT GO(P)": ont_ret_go_pred_w_hyp_dist,
  # OnT SNOMED Models (Full, batch_size=64, Mini, batch_size=[32,64,128])
  "OnT SNO-25(FULL-LG)": ont_snomed_LG_w_hyp_dist,
  "OnT SNO-25(M-32)": ontr_ret_snomed_minified_32_w_hyp_dist,
  "OnT SNO-25(M-64)": ontr_ret_snomed_minified_w_hyp_dist,
  "OnT SNO-25(M-128)": ontr_ret_snomed_minified_128_w_hyp_dist,
} 

In [None]:
# OOV QUERIES (SINGLE TARGET) [50]
#
# For every retriever in `models_dict_single_target`: rank the whole concept index for
# each OOV query, locate the single gold IRI in that ranking and accumulate rank-based
# metrics (MRR, Hits@k, P/R/F1@k, mean/median rank, coverage, PR-AUC). The summary is
# printed per model, collected into a text/LaTeX table and dumped to ../logs as JSON.

# PREP TABLE START #
# (`latextable` / `texttable` are already imported in the notebook's setup cell.)
experiment_one_table = texttable.Texttable()
experiment_one_table.set_deco(texttable.Texttable.HEADER)
experiment_one_table.set_precision(2)
experiment_one_table.set_cols_dtype(['t', 't', 'f', 'f', 'f', 'f', 'f', 'f', 'f'])
experiment_one_table.set_cols_align(["l", "l", "c", "c", "c", "c", "c", "c", "c"])
experiment_one_table.header(["Model", "Variant", "MRR", "H@1", "H@3", "H@5", "Med", "MR", "R@100"])
# END-PREP TABLE #

ks      = [1, 3, 5, 100, len(entity_verbalisation_list)]
MAX_K   = max(ks)

# Fail early with a clear message instead of a ZeroDivisionError during normalisation.
if len(oov_single_target_queries) == 0:
    raise ValueError("No single-target queries loaded; cannot compute metrics.")

all_results = {}

for model_name, model in models_dict_single_target.items():

    # init accumulators (sums over queries; normalised after the query loop)
    results = {
      "MRR": 0.0, # Mean Reciprocal Rank
      **{f"H@{k}": 0.0 for k in ks}, # Hits@k
      **{f"P@{k}": 0.0 for k in ks}, # Precision@k
      **{f"R@{k}": 0.0 for k in ks}, # Recall@k
      **{f"F1@{k}": 0.0 for k in ks}, # F1@k
      "MR": 0.0 # Mean Rank
    }
    # PR-AUC, Median Rank & Coverage are calculated during the test procedure
    hit_count = 0 # for coverage
    all_ranks = [] # for median rank

    for query in oov_single_target_queries:

        qstr = query.get_query_string()
        gold_iri = query.get_target_iri()

        ranked_results = [] # an empty ranking (unlikely) is treated as a full miss

        # Dispatch on retriever type / scoring function. Similarity-style scores are
        # retrieved with reverse_candidate_scores=True, distance-style scores with False.
        # TODO: replace with match (?) - i.e. switch
        if isinstance(model, (HiTRetriever, OnTRetriever)):
            if model._score_fn in (entity_subsumption, concept_subsumption):
                ranked_results = model.retrieve(qstr, top_k=None, reverse_candidate_scores=True, model=model._model, weight=0.0)
            elif model._score_fn == batch_poincare_dist_with_adaptive_curv_k:
                ranked_results = model.retrieve(qstr, top_k=None, reverse_candidate_scores=False, model=model._model)
            else:
                # Previously this fell through with an empty ranking, silently scoring
                # every query as a miss for a mis-configured retriever.
                raise ValueError(f"Unsupported score function for retriever '{model_name}'.")
        #
        elif isinstance(model, SBERTRetriever):
            if model._score_fn == batch_cosine_similarity:
                ranked_results = model.retrieve(qstr, top_k=None, reverse_candidate_scores=True)
            else:
                ranked_results = model.retrieve(qstr, top_k=None, reverse_candidate_scores=False)
        #
        elif isinstance(model, BM25Retriever):
            ranked_results = model.retrieve(qstr, top_k=None)
        #
        elif isinstance(model, TFIDFRetriever):
            ranked_results = model.retrieve(qstr, top_k=None)
        #
        elif isinstance(model, MixedModelRetriever):
            ranked_results = model.retrieve(qstr, top_k=None)
        #
        elif isinstance(model, CustomMixedModelRetriever):
            ranked_results = model.retrieve(qstr, top_k=MAX_K)
        #
        else:
            raise ValueError("No appropriate retriever has been set.")

        # 1-based rank of the first result carrying the gold IRI (None if absent);
        # the second field of each result tuple is the IRI.
        rank_pos = next(
            (rank_idx for rank_idx, (_, iri, _, _) in enumerate(ranked_results, start=1) if iri == gold_iri), # type: ignore
            None,
        )

        if rank_pos is None:
            # Penalise a miss with rank := MAX_K + 1 rather than dropping the query, which
            # would artificially inflate the mean/median rank of models that miss often.
            final_rank = MAX_K + 1
        else:
            final_rank = rank_pos
            hit_count += 1 # for calculating coverage
            results["MRR"] += 1.0 / rank_pos

        results["MR"] += final_rank
        all_ranks.append(final_rank)

        for k in ks:
            hit = 1 if (rank_pos is not None and rank_pos <= k) else 0
            p_at_k = hit / k # Precision@K
            r_at_k = hit     # Recall@K (single target: recall is 1 iff the gold is in the top k)
            results[f"H@{k}"] += hit # Hits@K
            results[f"P@{k}"] += p_at_k
            results[f"R@{k}"] += r_at_k
            if (p_at_k + r_at_k) > 0:
                results[f"F1@{k}"] += 2 * (p_at_k * r_at_k) / (p_at_k + r_at_k)

    # normalise over queries & compute coverage
    N = len(oov_single_target_queries)
    normalized = {metric: value / N for metric, value in results.items()}
    normalized['Cov'] = (hit_count / N) # calculate the coverage of this model
    normalized['Med'] = statistics.median(all_ranks) # median rank
    # area under precision-recall curve (trapezoidal rule)
    recall_at_k_xs    = [normalized[f"R@{k}"] for k in ks]
    # check for monotonic recall
    if any(r2 < r1 for r1, r2 in zip(recall_at_k_xs, recall_at_k_xs[1:])):
        raise ValueError("Recall must be non-decreasing for PR-AUC")
    precision_at_k_xs = [normalized[f"P@{k}"] for k in ks]
    normalized["AUC"] = float(sk_auc(recall_at_k_xs, precision_at_k_xs))

    print(f"Model: {model_name}")
    print(f"  MRR:    {normalized['MRR']:.2f}")
    for k in [1, 3, 5]:
        print(f"  H@{k}:    {normalized[f'H@{k}']:.2f}")
    print(f"  Med:    {normalized['Med']:.1f}")
    print(f"  MR:     {normalized['MR']:.1f}")
    print(f"  R@100:  {normalized['R@100']:.2f}")
    print("-"*60)

    # first two whitespace-separated tokens of the display name -> Model / Variant columns
    model_metric_string = model_name.split()
    experiment_one_table.add_row([model_metric_string[0], model_metric_string[1],
                                  normalized['MRR'],
                                  normalized['H@1'], normalized['H@3'], normalized['H@5'],
                                  normalized['Med'], normalized['MR'], normalized['R@100']])

    all_results[model_name] = normalized

Path('../logs').mkdir(parents=True, exist_ok=True)
output_file = '../logs/oov_entity_mentions_single_target_ANN_50_queries.json'
with open(output_file, 'w', encoding='utf-8') as f:
    json.dump(all_results, f, indent=2)

print(f"All results saved to {output_file}")

print("Printing table: \n\n")

print(experiment_one_table.draw())

print("\n\n Printing LaTeX: \n\n")

print(latextable.draw_latex(
    table=experiment_one_table, 
    caption="Single target retrieval performance of OOV entity mentions measured across multiple models (50 Queries)", 
    use_booktabs=True, position="H", caption_above=True, caption_short="Single target performance of OOV mentions", 
    label="tab:single-target-oov"
  )
)



Model: BoW TFIDF
  MRR:    0.26
  H@1:    0.16
  H@3:    0.30
  H@5:    0.34
  Med:    18.5
  MR:     90308.1
  R@100:  0.62
------------------------------------------------------------
Model: BoW BM25
  MRR:    0.30
  H@1:    0.20
  H@3:    0.36
  H@5:    0.44
  Med:    11.5
  MR:     44793.7
  R@100:  0.6
------------------------------------------------------------
Model: SBERT cos-sim
  MRR:    0.51
  H@1:    0.34
  H@3:    0.66
  H@5:    0.70
  Med:    2.0
  MR:     27.3
  R@100:  0.92
------------------------------------------------------------
Model: HiT SNO-25(F)
  MRR:    0.39
  H@1:    0.28
  H@3:    0.46
  H@5:    0.50
  Med:    5.5
  MR:     170.7
  R@100:  0.78
------------------------------------------------------------
Model: OnTr GALEN(P)
  MRR:    0.54
  H@1:    0.34
  H@3:    0.70
  H@5:    0.80
  Med:    2.0
  MR:     11.2
  R@100:  0.96
------------------------------------------------------------
Model: OnTr ANATOMY(P)
  MRR:    0.62
  H@1:    0.48
  H@3:    0.70
  H

## Multi-target Experiments

In [61]:
# Display name -> configured retriever for the multi-target experiments.
# Each key has three whitespace-separated tokens; presumably these feed the
# Model / Variant / Metric columns of the multi-target results table -- confirm
# against the experiment cell. Suffixes as used below: d_k = hyperbolic distance,
# d_l2 = Euclidean distance, s_e = entity subsumption, s_c = concept subsumption.
models_dict_multi_target = {
  # baselines
  "BoW Lexical TFIDF": tfidf_ret,
  "BoW Lexical BM25": bm25_ret,
  # baseline contextual
  "SBERT SNOMED25 cos-sim": sbert_ret_plm_w_cosine_sim,
  "SBERT SNOMED25 d_l2": sbert_ret_plm_w_euclid_dist,
  # HiT models
  "HiT SNOMED25(F) d_k": hit_ret_snomed_25_w_hyp_dist,
  "HiT SNOMED25(F) s_e": hit_ret_snomed_25_w_ent_sub,
  # OnT transferability (prediction)
  "OnTr GALEN(P) d_k": ont_ret_galen_pred_w_hyp_dist,
  "OnTr GALEN(P) s_c": ont_ret_galen_pred_w_con_sub,
  "OnTr ANATOMY(P) d_k": ont_ret_anatomy_pred_w_hyp_dist,
  "OnTr ANATOMY(P) s_c": ont_ret_anatomy_pred_w_con_sub,
  "OnTr GO(P) d_k": ont_ret_go_pred_w_hyp_dist,
  "OnTr GO(P) s_c": ont_ret_go_pred_w_con_sub,
  # OnT SNOMED models
  "OnTr SNO-25(FULL-LG) d_k": ont_snomed_LG_w_hyp_dist,
  "OnTr SNO-25(FULL-LG) s_c": ont_snomed_LG_w_con_sub,
  "OnTr SNO-25(M-32) d_k": ontr_ret_snomed_minified_32_w_hyp_dist,
  "OnTr SNO-25(M-32) s_c": ontr_ret_snomed_minified_32_w_con_sub,
  "OnTr SNO-25(M-64) d_k": ontr_ret_snomed_minified_w_hyp_dist,
  "OnTr SNO-25(M-64) s_c": ontr_ret_snomed_minified_w_con_sub,
  "OnTr SNO-25(M-128) d_k": ontr_ret_snomed_minified_128_w_hyp_dist,
  "OnTr SNO-25(M-128) s_c": ontr_ret_snomed_minified_128_w_con_sub,
}

In [None]:
def average_precision_binary(rels: Iterable[int], total_relevant: int | None = None) -> float:
    """Average precision of one ranked list with binarised relevance.

    Args:
        rels: relevance values in rank order; any value > 0 counts as relevant.
        total_relevant: denominator of the average. When None, the number of
            relevant items found in `rels` is used (AP over what was seen).

    Returns:
        Sum of precision@rank over the relevant ranks divided by the denominator,
        or 0.0 when the denominator is zero.
    """
    binary = [1 if r > 0 else 0 for r in rels]
    denominator = sum(binary) if total_relevant is None else total_relevant
    if denominator == 0:
        return 0.0

    precision_sum = 0.0
    seen_relevant = 0
    for rank, is_relevant in enumerate(binary, start=1):
        if not is_relevant:
            continue
        seen_relevant += 1
        precision_sum += seen_relevant / rank
    return precision_sum / denominator


def pr_points_from_binary(rels: Iterable[int], total_relevant: int | None = None) -> tuple[np.ndarray, np.ndarray]:
    """Precision/recall points of one ranked list, sampled at every relevant rank.

    Args:
        rels: relevance values in rank order; any value > 0 counts as relevant.
        total_relevant: recall denominator. When None, the number of relevant
            items found in `rels` is used.

    Returns:
        (recalls, precisions): two arrays with one entry per relevant item, in rank
        order. Both are empty when the denominator is zero.
    """
    flags = np.asarray([1 if r > 0 else 0 for r in rels], dtype=int)
    n_relevant = int(flags.sum()) if total_relevant is None else total_relevant
    if n_relevant == 0:
        return np.array([]), np.array([])

    hit_positions = np.flatnonzero(flags == 1)       # 0-based positions of the hits
    hits_so_far = np.cumsum(flags)[hit_positions]    # number of hits up to and including each hit
    hit_ranks = hit_positions + 1                    # 1-based ranks of hits
    precisions = hits_so_far / hit_ranks.astype(float)
    recalls = hits_so_far / float(n_relevant)
    return recalls, precisions


def interpolate_precision(recall: np.ndarray, precision: np.ndarray, recall_grid: np.ndarray) -> np.ndarray:
    """Interpolated precision of one PR curve evaluated on a recall grid.

    For each grid value g the result is the highest precision attained at any
    recall >= g, and 0.0 where g exceeds every observed recall.

    Args:
        recall: observed recall values (any order).
        precision: precision values aligned with `recall`.
        recall_grid: recall levels at which to evaluate the curve.

    Returns:
        Float array shaped like `recall_grid`; all zeros when `recall` is empty.
    """
    interp = np.zeros_like(recall_grid, dtype=float)
    if recall.size == 0:
        return interp

    # Order the points by ascending recall.
    ascending = np.argsort(recall)
    recall_sorted = recall[ascending]
    envelope = precision[ascending].copy()

    # Sweep right-to-left so each entry becomes the max precision at that recall or beyond.
    for pos in reversed(range(len(envelope) - 1)):
        if envelope[pos] < envelope[pos + 1]:
            envelope[pos] = envelope[pos + 1]

    # For every grid value, the first observed recall that reaches it (if any).
    first_reaching = np.searchsorted(recall_sorted, recall_grid, side="left")
    reachable = first_reaching < len(envelope)
    interp[reachable] = envelope[first_reaching[reachable]]
    return interp


def macro_pr_curve(all_query_rels: list[tuple[Iterable[int], int]], recall_points: int = 101) -> tuple[np.ndarray, np.ndarray]:
    """Macro-average the interpolated precision-recall curves of many queries.

    Each query's PR points are interpolated onto a shared, evenly spaced
    recall grid on [0, 1]; the interpolated precisions are then averaged over
    queries with equal weight.

    Args:
        all_query_rels: One ``(rels, total_relevant)`` pair per query, passed
            on to ``pr_points_from_binary``.
        recall_points: Number of grid points between 0.0 and 1.0 inclusive.

    Returns:
        ``(recall_grid, mean_precision)``. With no queries the precision
        array is all zeros.
    """
    recall_grid = np.linspace(0.0, 1.0, recall_points)
    precision_sum = np.zeros_like(recall_grid, dtype=float)
    n_queries = len(all_query_rels)
    if not n_queries:
        return recall_grid, precision_sum

    for query_rels, query_total in all_query_rels:
        recalls, precisions = pr_points_from_binary(query_rels, total_relevant=query_total)
        precision_sum = precision_sum + interpolate_precision(recalls, precisions, recall_grid)
    return recall_grid, precision_sum / n_queries

In [None]:
# OOV (TARGET + ANCESTORS) [weighted subsumption retrieval, d=5, \lambda = 0.35] [50]

# PREP TABLE START #
# Summary table: three text columns identifying the model followed by five
# float metric columns rendered with two decimals.
experiment_three_table = texttable.Texttable()
experiment_three_table.set_deco(texttable.Texttable.HEADER)
experiment_three_table.set_precision(2)
experiment_three_table.set_cols_dtype(['t', 't', 't', 'f', 'f', 'f', 'f', 'f'])
experiment_three_table.set_cols_align(["l", "l", "l", "c", "c", "c", "c", "c"])
experiment_three_table.header(["Model", "Variant", "Metric", "mAP", "MRR*", "nDCG@10", "PR-AUC", "R@100"])
# END-PREP TABLE #

# Rank cut-offs for the @k metrics; the last one covers the full candidate list.
# Deduplicated and sorted: a repeated k would be accumulated twice per query in
# the `for k in ks` loop, and an unsorted list breaks the monotonic-recall check
# that precedes the PR-AUC computation.
ks      = sorted({1, 3, 5, 10, 100, len(entity_verbalisation_list)})
MAX_K   = max(ks)

all_results = {}            # model name -> normalised metrics
macro_avg_PR_AUC_data = {}  # model name -> macro-averaged PR curve points

# Weight passed to the subsumption-based score functions (the lambda of this experiment).
SUBSUMPTION_LAMBDA = 0.35

for model_name, model in models_dict_multi_target.items():

    # Per-model accumulators: each value is summed over the queries and divided
    # by the number of queries once the query loop has finished.
    results = {
        "MRR": 0.0, # Mean Reciprocal Rank (first relevant hit)
        "MAP": 0.0, # Mean Average Precision
        **{f"Hits@{k}": 0.0 for k in ks},
        **{f"P@{k}": 0.0 for k in ks}, # Precision@k
        **{f"R@{k}": 0.0 for k in ks}, # Recall@k
        **{f"F1@{k}": 0.0 for k in ks}, # F1@k
        **{f"nDCG@{k}": 0.0 for k in ks}, # normalised Discounted Cumulative Gain @ k
        "MR": 0.0, # Mean Rank (first relevant hit)
        "aRP": 0.0  # R-Precision
    }
    # AUC-PR, median rank & coverage are derived after the query loop
    hit_count = 0 # relevant items found anywhere in the rankings (for coverage)
    total_possible_hits = 0 # coverage := hit_count / total_possible_hits, i.e. recall@MAX_K pooled over queries
    all_ranks = [] # first-hit rank per query (for median rank)
    # deprecated: the P@k/R@k based "AUC" below is only a rough approximation;
    # the macro-averaged PR curve built from these per-query labels supersedes it
    per_query_rels_for_PR = []

    for query in oov_match_all:

        qstr = query.get_query_string()
        gold_targets = query.get_unique_sorted_subsumptive_targets(key="depth", reverse=False, depth_cutoff=global_cutoff_depth) # [*parents, *ancestors]
        g_target_iris = {x["iri"] for x in gold_targets}
        num_targets = len(g_target_iris)
        total_possible_hits += num_targets

        ranked_results: list[QueryResult] = [] # an empty ranking (unlikely) is scored as a full miss

        # Dispatch on retriever type and score function.
        # NOTE(review): a HiT/OnT retriever whose score function matches neither
        # option keeps the empty ranking and is scored as a full miss for every
        # query - confirm that silently doing so is intended.
        if isinstance(model, HiTRetriever):
            if model._score_fn == entity_subsumption:
                ranked_results = model.retrieve(qstr, top_k=MAX_K, reverse_candidate_scores=True, model=model._model, weight=SUBSUMPTION_LAMBDA)
            elif model._score_fn == batch_poincare_dist_with_adaptive_curv_k:
                ranked_results = model.retrieve(qstr, top_k=MAX_K, reverse_candidate_scores=False, model=model._model)
        elif isinstance(model, OnTRetriever):
            if model._score_fn == concept_subsumption:
                ranked_results = model.retrieve(qstr, top_k=MAX_K, reverse_candidate_scores=True, model=model._model, weight=SUBSUMPTION_LAMBDA)
            elif model._score_fn == batch_poincare_dist_with_adaptive_curv_k:
                ranked_results = model.retrieve(qstr, top_k=MAX_K, reverse_candidate_scores=False, model=model._model)
        elif isinstance(model, SBERTRetriever):
            ranked_results = model.retrieve(
                qstr, top_k=MAX_K,
                reverse_candidate_scores=(model._score_fn == batch_cosine_similarity),
            )
        elif isinstance(model, (BM25Retriever, TFIDFRetriever, MixedModelRetriever, CustomMixedModelRetriever)):
            ranked_results = model.retrieve(qstr, top_k=MAX_K)
        else:
            raise ValueError("No appropriate retriever has been set.")

        # each result unpacks as a 4-tuple with the IRI in second position
        retrieved_iris = [iri for (_, iri, _, _) in ranked_results] # type: ignore

        # Binary relevance per rank; reused for the macro PR curve, AP and R-Precision
        rel_binary = [1 if iri in g_target_iris else 0 for iri in retrieved_iris]
        per_query_rels_for_PR.append((rel_binary, num_targets))

        # Single pass over the ranking: rank of the first hit and the AP numerator
        rank_pos = None
        hit_count_this_query = 0
        average_precision = 0.0
        for rank_idx, rel in enumerate(rel_binary, start=1):
            if not rel:
                continue
            if rank_pos is None:
                rank_pos = rank_idx
            hit_count_this_query += 1
            average_precision += hit_count_this_query / rank_idx
        hit_count += hit_count_this_query

        # Average Precision (this query); a query without gold targets scores 0
        # instead of raising ZeroDivisionError
        average_precision = average_precision / num_targets if num_targets else 0.0
        results["MAP"] += average_precision

        # MRR & Mean Rank (on the first hit). A query with no relevant hit is
        # not dropped (that would artificially inflate MR); it is charged a
        # penalty rank of MAX_K + 1 instead.
        final_rank = rank_pos if rank_pos is not None else MAX_K + 1
        if rank_pos is not None:
            results["MRR"] += 1.0 / rank_pos
        results["MR"] += final_rank
        all_ranks.append(final_rank)

        # R-Precision (this query): precision within the first num_targets ranks.
        # Slicing also covers rankings shorter than num_targets, which previously
        # contributed nothing even when they contained hits.
        if num_targets:
            results["aRP"] += sum(rel_binary[:num_targets]) / num_targets

        # The graded gold targets do not depend on k, so fetch them once per query
        # (iDCG is returned alongside but not used here)
        iDCG, targets_with_dcg = query.get_targets_with_dcg(type="exp", depth_cutoff=global_cutoff_depth)

        for k in ks:
            results[f"Hits@{k}"] += 1 if (rank_pos is not None and rank_pos <= k) else 0
            total_hits_at_k = len(g_target_iris.intersection(retrieved_iris[:k]))
            p_at_k = total_hits_at_k / k # Precision@k
            r_at_k = total_hits_at_k / num_targets if num_targets else 0.0 # Recall@k
            results[f"P@{k}"] += p_at_k
            results[f"R@{k}"] += r_at_k
            if (p_at_k + r_at_k) > 0:
                results[f"F1@{k}"] += 2 * (p_at_k * r_at_k) / (p_at_k + r_at_k)
            results[f"nDCG@{k}"] += compute_ndcg_at_k(ranked_results, targets_with_dcg, k) # type: ignore

    # (macro) PR-AUC: area under the macro-averaged interpolated PR curve
    R_grid, P_macro = macro_pr_curve(per_query_rels_for_PR, recall_points=101)
    macro_pr_auc = float(np.trapezoid(P_macro, R_grid))

    # normalise over queries & compute coverage
    N = len(oov_match_all)
    if N == 0:
        raise ValueError("oov_match_all is empty; there are no queries to evaluate.")
    normalized = {metric: value / N for metric, value in results.items()}
    normalized['Cov'] = (hit_count / total_possible_hits) if total_possible_hits else 0.0 # coverage of this model
    normalized['Med'] = statistics.median(all_ranks) # median rank
    # area under the precision-recall curve (trapezoidal rule) over the mean (R@k, P@k) points
    recall_at_k_xs    = [normalized[f"R@{k}"] for k in ks]
    # check for monotonic recall
    if any(r2 < r1 for r1, r2 in zip(recall_at_k_xs, recall_at_k_xs[1:])):
        raise ValueError("Recall must be non-decreasing for PR-AUC")
    precision_at_k_xs = [normalized[f"P@{k}"] for k in ks]
    normalized["AUC"] = float(sk_auc(recall_at_k_xs, precision_at_k_xs))
    normalized["MacroPR_AUC"] = macro_pr_auc

    print(f"Model: {model_name}")
    print(f" mAP: \t  {normalized['MAP']:.2f}") # Mean Average Precision
    print(f" MRR*:   {normalized['MRR']:.2f}") # MRR at first hit ranks
    print(f" nDCG@10: {normalized['nDCG@10']:.2f}") # nDCG@10
    print(f" PR-AUC:  {normalized['AUC']:.2f}") # area under precision-recall curve
    print(f" mPR-AUC: {normalized['MacroPR_AUC']:.2f}") # PR-AUC (macro averaged)
    print(f" R@100:   {normalized['R@100']:.2f}") # Recall@100
    print("-"*60)

    all_results[model_name] = normalized

    # The model name is split on whitespace into the three text columns
    # (Model, Variant, Metric) of the summary table.
    model_metric_string = model_name.split()

    experiment_three_table.add_row([
        model_metric_string[0],
        model_metric_string[1],
        model_metric_string[2],
        normalized['MAP'],
        normalized['MRR'],
        normalized['nDCG@10'],
        normalized['MacroPR_AUC'], # the "PR-AUC" column reports the macro-averaged value, not normalized['AUC']
        normalized['R@100']
    ])

    macro_avg_PR_AUC_data[model_name] = {
        "recall": R_grid.tolist(),
        "precision": P_macro.tolist()
    }

# Persist the metrics and the macro PR-curve points as JSON under ../logs,
# then print the summary table as plain text and as LaTeX.
Path('../logs').mkdir(parents=True, exist_ok=True)

output_file = '../logs/oov_entity_mentions_multi_relevant_targets_weight_w035_50_queries.json'
output_macro_pr_auc_file = '../logs/oov_entity_mentions_multi_target_WEIGHTED_w035_50_q__PR_AUC_POINTS_4_PLOT.json'

for dump_path, dump_payload, dump_notice in (
    (output_file, all_results, "All results saved to "),
    (output_macro_pr_auc_file, macro_avg_PR_AUC_data, "Macro PR AUC plot data dumped to "),
):
    with open(dump_path, 'w') as f:
        json.dump(dump_payload, f, indent=2)
    print(dump_notice + dump_path)

print("Printing table: \n\n")

print(experiment_three_table.draw())

print("\n\n Printing LaTeX: \n\n")

latex_source = latextable.draw_latex(
    table=experiment_three_table,
    caption="Performance of fetching multiple relevant entities using OOV mentions with lambda=0.35 (50 Queries)",
    caption_short="Multi target performance of OOV mentions, lambda=0.35",
    label="tab:multi-target-oov-weighted",
    position="H",
    use_booktabs=True,
    caption_above=True,
)
print(latex_source)

Model: BoW Lexical TFIDF
 mAP: 	  0.08
 MRR*:   0.29
 nDCG@10: 0.22
 PR-AUC:  0.01
 mPR-AUC: 0.08
 R@100:   0.20
------------------------------------------------------------
Model: BoW Lexical BM25
 mAP: 	  0.08
 MRR*:   0.32
 nDCG@10: 0.24
 PR-AUC:  0.02
 mPR-AUC: 0.08
 R@100:   0.19
------------------------------------------------------------
Model: SBERT SNOMED25 cos-sim
 mAP: 	  0.15
 MRR*:   0.55
 nDCG@10: 0.41
 PR-AUC:  0.05
 mPR-AUC: 0.15
 R@100:   0.37
------------------------------------------------------------
Model: SBERT SNOMED25 d_l2
 mAP: 	  0.15
 MRR*:   0.55
 nDCG@10: 0.41
 PR-AUC:  0.05
 mPR-AUC: 0.15
 R@100:   0.37
------------------------------------------------------------
Model: HiT SNOMED25(F) d_k
 mAP: 	  0.14
 MRR*:   0.44
 nDCG@10: 0.32
 PR-AUC:  0.05
 mPR-AUC: 0.15
 R@100:   0.45
------------------------------------------------------------
Model: HiT SNOMED25(F) s_e
 mAP: 	  0.15
 MRR*:   0.39
 nDCG@10: 0.28
 PR-AUC:  0.04
 mPR-AUC: 0.15
 R@100:   0.38
-------