# Init'ing Encoders, etc (temporary)

In [15]:
from __future__ import annotations

from pathlib import Path
from pydantic import validate_call
from rank_bm25 import BM25Okapi
from multiprocessing import Pool, cpu_count
from sklearn.feature_extraction.text import TfidfVectorizer
from sentence_transformers import SentenceTransformer
from argparse import ArgumentParser, Namespace
from scipy.sparse import coo_matrix
from hierarchy_transformers import HierarchyTransformer
from OnT.OnT import OntologyTransformer

from collections.abc import Generator, MutableMapping, Mapping, Sequence
from typing import Union, NamedTuple, override, Any
from functools import reduce
import math

from typing import override, Callable, Union, NamedTuple
from abc import ABC, abstractmethod
import numpy as np

from sklearn.metrics import auc as sk_auc

import statistics
import copy

import numpy as np
import json
import pickle
import logging

import torch

from rank_bm25 import BM25Okapi

from tqdm import tqdm
import re # RegEx

from typing import overload

VERBOSE = False

_regex_parens = re.compile(r"\s*\([^)]*\)") # for parentheses removal (prevents leakage)

def strip_parens(s: str) -> str:
    return _regex_parens.sub("", s)

def load_json(file_path: Path) -> dict[str, str]:
    with file_path.open('r', encoding='utf-8') as fp:
        return json.load(fp)

def save_json(file_path: Path, payload: dict | list, encoding: str = "utf-8", indentation: int = 4) -> None:
    with open(file_path, "w", encoding=encoding) as fp:
        json.dump(payload, fp, indent=indentation)

def load_concepts_to_list(concepts_file_path: Path) -> list[str]:
    return list(load_json(concepts_file_path).values())

def naive_tokenise(seq: str) -> list[str]:
    return seq.lower().split()

# ~800k concepts, assuming 16 cores, 32 threads; 
# 800k / 32 = 25k (note: subprocesses != threads)
def parallel_tokenise(seq_list: list[str], workers: int, chunksize: int = 25000) -> list[list[str]]:
    with Pool(workers) as pool:
        return list(pool.map(naive_tokenise, seq_list, chunksize=chunksize))


def batch_euclidian_l2_distance(u: np.ndarray, vs: np.ndarray) -> np.ndarray:
    return np.linalg.norm(u - vs, axis=1)


def l2_norm(x: np.ndarray) -> np.ndarray:
    x = np.asarray(x, dtype=np.float32)
    return np.sqrt(np.sum(x**2))


def batch_l2_norm(x: np.ndarray) -> np.ndarray:
    x = np.asarray(x, dtype=np.float32)
    return np.asarray(np.sqrt(np.sum(x**2, axis=1)))


def inner_product(p_u: np.ndarray, p_v: np.ndarray) -> np.ndarray:
    u = np.asarray(p_u, dtype=np.float32)
    v = np.asarray(p_v, dtype=np.float32)
    return np.inner(u, v)


def batch_inner_product(p_u: np.ndarray, p_vs: np.ndarray) -> np.ndarray:
    u = np.asarray(p_u, dtype=np.float32).ravel()
    vs = np.asarray(p_vs, dtype=np.float32)
    return vs.dot(u)


def cosine_similarity(u, v, normalised=True):
    u = np.asarray(u, dtype=np.float32)
    v = np.asarray(v, dtype=np.float32)
    return np.inner(u, v) if normalised else np.inner(u, v) / (l2_norm(u) * l2_norm(v))


def batch_cosine_similarity(p_u, p_vs, normalised=True):
    u  = np.asarray(p_u,  dtype=np.float32)
    vs = np.asarray(p_vs, dtype=np.float32)
    return batch_inner_product(u, vs) if normalised else batch_inner_product(u, vs) / (l2_norm(u) * batch_l2_norm(vs))


def batch_poincare_distance_with_curv_k(u: np.ndarray, vs: np.ndarray, k: np.float64 | np.float32) -> np.float64 | np.float32:
    u_norm_sqd = np.sum(u**2)
    vs_norms_sqd = np.sum(vs**2, axis=1)
    l2_dist_sqd = np.sum((u - vs)**2, axis=1)
    offset = 1e-7 # tiny-offset: guard agaisnt division by zero & floating point arithmatic inaccuracies
    arg = 1 + ((2 * k * l2_dist_sqd) / ((1 - (k * u_norm_sqd + offset)) * (1 - (k * vs_norms_sqd + offset)))) # acosh
    arg = np.maximum(1.0, arg) # bounds check: domain of acosh is bound to [1, \inf)
    acosh_scaling = np.float64(1) / np.float64(np.sqrt(k)) # scaling factor: k
    return (acosh_scaling * np.arccosh(arg, dtype=np.float64)) # 1 / sqrt(k) * acosh(arg)


def batch_poincare_dist_with_adaptive_curv_k(u: np.ndarray, vs:np.ndarray, model: HierarchyTransformer | OntologyTransformer, **kwargs):
    if isinstance(model, HierarchyTransformer):
        k = np.float64(model.get_circum_poincareball(model.embed_dim).c)
    elif isinstance(model, OntologyTransformer):
        hierarchy_model = model.hit_model
        hierarchy_poincare_ball = hierarchy_model.get_circum_poincareball(hierarchy_model.embed_dim)
        k = np.float64(hierarchy_poincare_ball.c)    
    else:
        raise Exception("Hyperbolic distance should only be only calculated in B^n or H^n")
    return np.asarray(batch_poincare_distance_with_curv_k(u, vs, k))

# additional metrics

def identity(x):
    return x

def subsumption_score_hit(hit_transformer: HierarchyTransformer, child_emb: np.ndarray | torch.Tensor, parent_emd: np.ndarray | torch.Tensor, centri_weight: float = 1.0):
    child_emb_t = torch.Tensor(child_emb)
    parent_emb_t = torch.Tensor(parent_emd)
    dists = hit_transformer.manifold.dist(child_emb_t, parent_emb_t)
    child_norms = hit_transformer.manifold.dist0(child_emb_t)
    parent_norms = hit_transformer.manifold.dist0(parent_emb_t)
    return -(dists + centri_weight * (parent_norms - child_norms))

def subsumption_score_ont(ontology_transformer: OntologyTransformer, child_emb: np.ndarray | torch.Tensor, parent_emb: np.ndarray | torch.Tensor, weight_lambda: float = 1.0):
    child_emb_t = torch.Tensor(child_emb)
    parent_emb_t = torch.Tensor(parent_emb)
    return ontology_transformer.score_hierarchy(child_emb_t, parent_emb_t, weight_lambda)


def entity_subsumption(u: np.ndarray, vs: np.ndarray, model: HierarchyTransformer, *, weight: float = 0.4):
    return np.asarray(subsumption_score_hit(model, u, vs, centri_weight=weight))


def concept_subsumption(u: np.ndarray, vs: np.ndarray, model: OntologyTransformer, *, weight: float = 0.4, **kwargs):
    return np.asarray(subsumption_score_ont(model, u, vs, weight_lambda=weight))

# data mapping: utils

def make_signature(obj):
  if isinstance(obj, Mapping):
    return tuple((key, make_signature(obj[key])) for key in sorted(obj))
  elif isinstance(obj, Sequence) and not isinstance(obj, (str, bytes, bytearray)):
    return tuple(make_signature(item) for item in obj)
  # else:
  return obj

def unique_unhashable_object_list(obj_xs: list[dict]) -> list[dict]:
    object_signatures = set()
    unique_obj_list = []
    for obj in obj_xs:
      signature = make_signature(sorted(obj.items()))
      if signature not in object_signatures:
        object_signatures.add(signature)
        unique_obj_list.append(obj)
    return unique_obj_list

def obj_max_depth(x: int, obj: Any, key: str = "depth") -> int:
  return x if x > obj[key] else obj[key]

def dcg_exp_relevancy_at_pos(relevancy: int, rank_position: int) -> float:
  if relevancy <= 0:
    return float(0.0)
  numerator = (2**relevancy) - 1
  denominator = math.log2(rank_position + 1)
  return float(numerator / denominator)

def dcg_linear_relevancy_at_pos(relevancy: int, rank_position: int) -> float:
  if relevancy <= 0:
    return float(0.0)
  numerator = relevancy
  denominator = math.log2(rank_position + 1)
  return float(numerator / denominator)

def accumulate(a, b, key='dcg'):
  return a + b[key]

class Query:
  
  _query_obj_repr: dict
  _query_string: str
  _target: dict[str, Union[str, int]]
  _entity_mention: dict[str, Union[str, int]]

  def __init__(self, query_obj_repr: dict):
    self._query_obj_repr = query_obj_repr
    self._query_string = query_obj_repr['entity_mention']['entity_literal']
    self._target = query_obj_repr['target_entity']
    self._entity_mention = query_obj_repr['entity_mention']

  def get_query_string(self) -> str:
    return self._query_string
  
  def get_target_iri(self) -> str:
    return str(self._target['iri'])
  
  def get_target_label(self) -> str:
    return str(self._target['rdfs:label'])
  
  def get_target_depth(self) -> int:
    return int(self._target['depth'])
  
  def get_target(self) -> dict:
    return self._target


# weighted TF-IDF at run-time
def aggregate_posting_scores(query_weights, inverted):
    scores = {}
    for term, weight in query_weights.items():
        if term not in inverted:
            continue
        for doc_id, tfidf_score in inverted[term]:
            scores[doc_id] = scores.get(doc_id, 0.0) + weight * tfidf_score
    return scores


# sort TF-IDF result set
def topk(scores: dict[int, float], k: int = 10):
    return sorted(scores.items(), key=lambda x: x[1], reverse=True)[:k]


def build_tf_idf_index(axiom_list: list[str], tfidf_dest: str, *args, **kwargs):
    vectoriser = TfidfVectorizer(**kwargs)
    doc_term_matrix = vectoriser.fit_transform(axiom_list)
    vocab = vectoriser.get_feature_names_out()
    # prep for storing to disk: create empty postings struct
    inverted_index: dict[str, list[tuple[int, float]]] = {term: [] for term in vocab}
    # see: https://matteding.github.io/2019/04/25/sparse-matrices/
    coo = coo_matrix(doc_term_matrix)
    # populate the inverted index
    for row, col, score in zip(coo.row, coo.col, coo.data):
        inverted_index[str(vocab[col])].append((int(row), float(score)))
    # order: desc
    for postings in inverted_index.values():
        postings.sort(key=lambda x: x[1], reverse=True)
    # save to disk
    with open(tfidf_dest, "wb") as fp:
        pickle.dump(
        {
            "vectorizer": vectoriser,
            "postings": postings, # type: ignore
            "verbalisations": axiom_list
        },
        fp,
        protocol=pickle.HIGHEST_PROTOCOL,
    )
    return vectoriser, inverted_index


def build_bm_25_index(concept_list: list[str], bm_25_dest: str = "bm25-index.pkl", **kwargs):
    tokenised_concepts = parallel_tokenise(concept_list, **kwargs)
    bm25 = BM25Okapi(tokenised_concepts)
    with open(bm_25_dest, "wb") as fp:
        pickle.dump({
            "bm25": bm25,
            "verbalisations": concept_list,
        }, fp, protocol=pickle.HIGHEST_PROTOCOL
    )

class QueryResult(NamedTuple):
  rank: int
  iri: str
  score: float
  verbalisation: str

class EquivQuery(Query):
  
  _equiv_class_expression: Union[str, None]

  def __init__(self, query_obj_repr: dict):
    super().__init__(query_obj_repr)

class SubsumptionQuery(Query):
  
  _parents: list
  _ancestors: list

  def __init__(self, query_obj_repr: dict):
    super().__init__(query_obj_repr)
    self._set_parents()
    self._set_ancestors()

  def _set_parents(self):
    if self._query_obj_repr['parent_entities'] and len(self._query_obj_repr['parent_entities']) > 0:
      self._parents = self._query_obj_repr['parent_entities']
    else:
      self._parents = []

  def _set_ancestors(self):
    if self._query_obj_repr['ancestors'] and len(self._query_obj_repr['ancestors']) > 0:
      self._ancestors = self._query_obj_repr['ancestors']
    else:
      self._ancestors = []

  def get_parents(self) -> list:
    return self._parents
  
  def get_ancestors(self) -> list:
    return self._ancestors
  
  def get_all_subsumptive_targets(self) -> list:
    return [self._target, *self._parents, *self._ancestors]
  
  def get_unique_subsumptive_targets(self) -> list:
    return unique_unhashable_object_list(
      self.get_all_subsumptive_targets()
    )

  def get_sorted_subsumptive_targets(self, key="depth", reverse=False, depth_cutoff=3) -> list:
    xs = self.get_all_subsumptive_targets()
    xs.sort(key=lambda x: x[key], reverse=reverse)
    return xs[:depth_cutoff]
  
  def get_unique_sorted_subsumptive_targets(self, key="depth", reverse=False, depth_cutoff=3) -> list:
    return unique_unhashable_object_list(
      self.get_sorted_subsumptive_targets(key=key, reverse=reverse, depth_cutoff=depth_cutoff)
    )
  
  def get_targets_with_dcg(self, type="exp", depth_cutoff=3, **kwargs) -> tuple[float, list[dict]]:
    # get targets (target, parents, ancestors) ordered in ascending via depth
    targets_asc_depth = self.get_unique_sorted_subsumptive_targets(key="depth", depth_cutoff=depth_cutoff)
    # increase depth by 1 (offsetting zero-based index)
    targets_w_offset = [
      {**x, "depth": x["depth"] + 1}
      for x in targets_asc_depth
    ]
    # find the max depth (to calculate relevancy): (max_depth - depth_at_pos_k) + zero_based_offset
    # which we refer to as: relevancy := ascent height + zero_based_offset
    max_target_depth = reduce(
      obj_max_depth, 
      targets_w_offset, 
      0
    )
    # calculate relevance for each target (node/parent/ancestor) 
    targets_with_rel = [
      {**x, "relevance": (max_target_depth - x["depth"]) + 1}
      for x in targets_w_offset
    ]
    # ensure targets are sorted by relevance
    targets_with_rel.sort(key=lambda x: x['relevance'], reverse=True)
    # calculate dcg:
    if type == "linear":
      targets_with_dcg = [
        {**x, "dcg": dcg_linear_relevancy_at_pos(x['relevance'], rank)}
        for rank, x in enumerate(targets_with_rel, start=1)
      ]
    else:
      targets_with_dcg = [
        {**x, "dcg": dcg_exp_relevancy_at_pos(x['relevance'], rank)}
        for rank, x in enumerate(targets_with_rel, start=1)
      ]
    iDCG = reduce(accumulate, targets_with_dcg, 0)
    self._idcg = iDCG
    self._targets_with_dcg = targets_with_dcg
    return iDCG, targets_with_dcg

  def get_ideal_dcg(self, type="exp"):
    if self._idcg:
      return self._idcg
    iDCG, targets = self.get_targets_with_dcg()
    return iDCG
  
class QueryObjectMapping:

  _loaded: bool
  _data_file_path: Path
  _data: list
  _equiv_queries: list
  _subsumpt_queries: list

  @validate_call
  def __init__(self, json_data_fp: Path):
    self._loaded = False
    self._data_file_path = json_data_fp
    self._load()
    self._map()

  def _load(self) -> None:
    with self._data_file_path.open('r', encoding='utf-8') as fp:
      self._data = json.load(fp)
    self._loaded = True

  @validate_call
  def load_from_path(self, json_data_fp: Path) -> None:
    # overwrite an existing file path
    self._data_file_path = json_data_fp
    self._load()

  # equivalence_retrieval: bool = True, subsumption_retrieval: bool = False
  def _map(self) -> None:
    equiv_queries = []
    subsumpt_queries = []
    for query_obj_repr in self._data:
      # if the query obj within the data contains an equiv class
      if len(query_obj_repr['equivalent_classes']) > 0:
        # we're dealing with an equiv query
        equiv_queries.append(EquivQuery(query_obj_repr))
      else:
        subsumpt_queries.append(SubsumptionQuery(query_obj_repr))
    self._equiv_queries = equiv_queries
    self._subsumpt_queries = subsumpt_queries

  def get_queries(self) -> tuple[list, list]:
    return (self._equiv_queries, self._subsumpt_queries)
  
  def get_subsumpt_queries_with_no_transformations(self):
    tmp_subsumpt_queries = copy.deepcopy(self._subsumpt_queries)
    result_queries = []
    for query in tmp_subsumpt_queries:
      if query._entity_mention['transformed_entity_literal_for_type_alignment'] == "":
        result_queries.append(query)
    return result_queries
  
  def get_subsumpt_queries_with_transformations_only(self):
    tmp_subsumpt_queries = copy.deepcopy(self._subsumpt_queries)
    result_queries = []
    for query in tmp_subsumpt_queries:
      if query._entity_mention['transformed_entity_literal_for_type_alignment'] != "":
        result_queries.append(query)
    return result_queries
  
class BaseRetriever(ABC):
    
  _verbalisations: list
  _meta_map: list

  @validate_call
  def __init__(self, verbalisations_fp: Path, meta_map_fp: Path):
    with open(verbalisations_fp, 'r', encoding='utf-8') as fp:
      self._verbalisations = json.load(fp)
    with open(meta_map_fp, 'r', encoding='utf-8') as fp:
       self._meta_map = json.load(fp)

  @abstractmethod
  def retrieve(self, query_string: str, *, top_k: int = 10, **kwargs) -> list[QueryResult]:
    pass

from typing import overload

class BaseModelRetriever(BaseRetriever):
   
  _embeddings: np.ndarray
  _candidate_indicies: np.ndarray
  _model: Union[SentenceTransformer, HierarchyTransformer, OntologyTransformer]
  _score_fn: Callable

  @overload
  def __init__(self, verbalisations_fp: Path, meta_map_fp: Path, embeddings_fp: Path): ...

  @overload
  def __init__(self, verbalisations_fp: Path, meta_map_fp: Path, embeddings_fp: Path, *, score_fn: Callable | None = None): ...
    
  @overload
  def __init__(self, verbalisations_fp: Path, meta_map_fp: Path, embeddings_fp: Path, *, score_fn: Callable | None = None, model_fp: Path | None = None): ...
    
  @overload
  def __init__(self, verbalisations_fp: Path, meta_map_fp: Path, embeddings_fp: Path, *, score_fn: Callable | None = None, model_fp: Path | None = None, model_str: str | None = None): ...

  @validate_call
  def __init__(self, verbalisations_fp: Path, meta_map_fp: Path, embeddings_fp: Path, *, score_fn: Callable | None = None, model_fp: Path | None = None, model_str: str | None = None):
    super().__init__(verbalisations_fp, meta_map_fp)
    self._embeddings = np.load(embeddings_fp, mmap_mode="r")
    self._candidate_indicies = np.arange(len(self._embeddings))
    if score_fn:
      self.register_score_function(score_fn)
    if model_fp:
      try:
        self.register_local_model(model_fp.expanduser().resolve())
      except FileNotFoundError:
        self.register_model(str(model_fp))
    elif (not model_fp and model_str):
      self.register_model(model_str)

  def register_score_function(self, score_fn: Callable):
    self._score_fn = score_fn

  @override
  def retrieve(self, query_string: str, *, top_k: int | None = None, reverse_candidate_scores=False, **kwargs) -> list[QueryResult]:
    """
    TODO: 1. add docstring explaining why **kwargs is accepted and pass through to _score_fn
          2. add explaination of parameters
          3. types (args/return)
    """
    query_embedding = self._embed(query_string)
    scored_embeddings = self._score_fn(query_embedding, self._embeddings, **kwargs)
    if reverse_candidate_scores and top_k is not None:
      top_k_indicies = self._candidate_indicies[np.flip(np.argsort(scored_embeddings))[:top_k]]
    elif not reverse_candidate_scores and top_k is not None:
      top_k_indicies = self._candidate_indicies[np.argsort(scored_embeddings)[:top_k]]
    elif reverse_candidate_scores and top_k is None:
      top_k_indicies = self._candidate_indicies[np.flip(np.argsort(scored_embeddings))]
    elif not reverse_candidate_scores and top_k is None:
      top_k_indicies = self._candidate_indicies[np.argsort(scored_embeddings)]
    else:
      raise KeyError("Valid arguments for reverse_candidate_scores and top_k must be set.")
    results = []
    for rank, candidate_index in enumerate(top_k_indicies):
      candidate_score = scored_embeddings[candidate_index]
      candidate_meta_map = self._meta_map[candidate_index]
      candidate_verbalisation = candidate_meta_map['verbalisation']
      candidate_iri = candidate_meta_map['iri']
      results.append((rank, candidate_iri, candidate_score, candidate_verbalisation))
    return results

  @abstractmethod
  def register_model(self, model: str) -> None:
    pass

  @abstractmethod
  def register_local_model(self, model_fp: Path) -> None:
    pass

  @abstractmethod
  def _embed(self, query_string: str) -> np.ndarray:
    pass


class HiTRetriever(BaseModelRetriever):

  @override
  def register_model(self, model: str) -> None:
    self._model = HierarchyTransformer.from_pretrained(model)

  @override
  def register_local_model(self, model_fp: Path) -> None:
    self._model = HierarchyTransformer.from_pretrained(str(model_fp.expanduser().resolve()))  

  @override
  def _embed(self, query_string: str) -> np.ndarray:
    return (self._model.encode(
      [query_string]
    ).astype("float32"))[0] # type: ignore
  

class OnTRetriever(BaseModelRetriever):

  @override
  def register_model(self, model: str) -> None:
    self._model = OntologyTransformer.load(model)

  @override
  def register_local_model(self, model_fp: Path) -> None:
    self._model = OntologyTransformer.load(str(model_fp.expanduser().resolve()))

  @override
  def _embed(self, query_string: str) -> np.ndarray:
    return (self._model.encode_concept(
      [query_string]
    ).astype("float32"))[0] # type: ignore


class SBERTRetriever(BaseModelRetriever):

  @override
  def register_model(self, model: str) -> None:
    self._model = SentenceTransformer.load(model)

  @override
  def register_local_model(self, model_fp: Path) -> None:
    self._model = SentenceTransformer.load(str(model_fp.expanduser().resolve()))

  @override
  def _embed(self, query_string: str) -> np.ndarray:
    return (self._model.encode(
      [query_string]
    ).astype("float32"))[0] # type: ignore
  

def custom_mixed_product_distance(d_ont_32: np.ndarray, d_ont_128: np.ndarray, d_sbert: np.ndarray, 
                                  sigma: tuple[float, float, float] = (1.0, 1.0, 1.0),
                                  to_similarity: bool = True, kernel: str = "exp") -> np.ndarray:
    
    sigma_hit, sigma_ont, sigma_sbert = sigma
    d2 = (sigma_hit * d_ont_32)**2 + (sigma_ont * d_ont_128)**2 + (sigma_sbert * d_sbert)**2
    if kernel == "dist":
      return np.sqrt(d2)
    if kernel == "exp":
      return np.exp(-np.sqrt(d2)) # rbf
    if kernel in {"inv", "inverse"}:
        return 1.0 / (1.0 + d2) # inverse‑quad
    raise ValueError("no valid kernel given")


class CustomMixedModelRetriever(BaseRetriever):

    _ont_model_32: OntologyTransformer
    _ont_model_128: OntologyTransformer
    _sbert_model: SentenceTransformer

    _ont_embs_32:   np.ndarray
    _ont_embs_128:   np.ndarray
    _sbert_embs: np.ndarray

    _sigma: np.ndarray

    def __init__(self, verbalisations_fp: Path, meta_map_fp: Path, *,
        ont_model_32: OntologyTransformer, ont_32_embeddings_fp: Path,
        ont_model_128: OntologyTransformer, ont_128_embeddings_fp: Path,
        sbert_model: SentenceTransformer, sbert_embeddings_fp: Path,
        sigma: tuple[float, float, float] = (1.0, 1.0, 1.0),
        kernel: str = "exp") -> None:

        super().__init__(verbalisations_fp, meta_map_fp)

        self._ont_model_32 = ont_model_32
        self._ont_model_128 = ont_model_128
        self._sbert_model = sbert_model

        self._ont_embs_32 = np.load(ont_32_embeddings_fp, mmap_mode="r")
        self._ont_embs_128 = np.load(ont_128_embeddings_fp, mmap_mode="r")
        self._sbert_embs = np.load(sbert_embeddings_fp, mmap_mode="r")

        assert len(self._ont_embs_32) == len(self._ont_embs_128) == len(self._sbert_embs), \
            "all embedding files must contain the same number of rows"

        self._candidate_indices = np.arange(len(self._ont_embs_32))
        self._sigma  = np.asarray(sigma, dtype=np.float32)
        self._kernel = kernel

    def set_sigma(self, sigma: tuple[float, float, float]) -> None:
        self._sigma = np.asarray(sigma, dtype=np.float32)

    def get_sigma(self) -> tuple[float, float, float]:
        return tuple(float(x) for x in self._sigma) # type: ignore

    def retrieve(self, query_string: str, *, top_k: int | None = None, reverse_candidate_scores: bool = False, **kwargs) -> list[QueryResult]:

        q_ont_32 = self._ont_model_32.encode_concept([query_string])[0]
        q_ont_128 = self._ont_model_128.encode_concept([query_string])[0]
        q_sbert = self._sbert_model.encode([query_string], normalize_embeddings=True)[0]

        d_ont_32 = batch_poincare_dist_with_adaptive_curv_k(q_ont_32, self._ont_embs_32, self._ont_model_32)
        d_ont_128 = batch_poincare_dist_with_adaptive_curv_k(q_ont_128, self._ont_embs_128, self._ont_model_128)
        d_sbert = batch_euclidian_l2_distance(q_sbert, self._sbert_embs)

        scores = custom_mixed_product_distance(
            d_ont_32=d_ont_32,
            d_ont_128=d_ont_128,
            d_sbert=d_sbert,
            sigma=tuple(self._sigma),
            to_similarity=True,
            kernel=self._kernel,
        )

        if reverse_candidate_scores and top_k is not None:
            top_idx = np.argsort(scores)[:top_k]
        elif not reverse_candidate_scores and top_k is not None:
            top_idx = np.argsort(-scores)[:top_k]
        elif reverse_candidate_scores and top_k is None:
            top_idx = np.argsort(scores)
        elif not reverse_candidate_scores and top_k is None:
            top_idx = np.argsort(-scores)
        else:
            raise KeyError("Invalid Argument Exception.")

        results: list[QueryResult] = []
        for rank, idx in enumerate(top_idx):
            meta = self._meta_map[idx]
            results.append(
                QueryResult(
                    rank = rank,
                    iri = meta["iri"],
                    score = float(scores[idx]),
                    verbalisation = meta["verbalisation"],
                )
            )
        return results
    

from functools import reduce
from typing import Any
import math


def dcg_exp_relevancy_at_pos(relevancy: int, rank_position: int) -> float:
  if relevancy <= 0:
    return float(0.0)
  numerator = (2**relevancy) - 1
  denominator = math.log2(rank_position + 1)
  return float(numerator / denominator)


def add(a, b, key='dcg'):
  return a + b[key]


def compute_ndcg_at_k(results: list[tuple[int, str, float, str]], targets_with_dcg_exp: list[dict], k: int = 20) -> float:
  relevance_map = {target['iri']: target['relevance'] for target in targets_with_dcg_exp}
  dcg = 0.0
  for rank, (idx, iri, score, label) in enumerate(results[:k], start=1):
    rel = relevance_map.get(iri, 0)
    dcg += dcg_exp_relevancy_at_pos(rel, rank)
  ideal_dcg = sum(target['dcg'] for target in targets_with_dcg_exp[:k])
  if ideal_dcg == 0:
    return 0.0
  
  return dcg / ideal_dcg


data_dir = "./data"
entity_lexicon_fp = Path(f"{data_dir}/snomed_entity_lexicon_2025.json")
verbalisation_list_fp = Path(f"{data_dir}/verbalisations_2025.json")
entity_map_fp = Path(f"{data_dir}/entity_map_2025.json")
entity_mappings_list_fp = Path(f"{data_dir}/entity_mappings_2025.json")

entity_lexicon = load_json(entity_lexicon_fp)
iris = entity_lexicon.keys()
entity_map = {}
entity_verbalisation_list = []
list_of_entity_mappings = []

for entity_idx, entity_iri in enumerate(tqdm(iris)):
    entity_map[str(entity_idx)] = {
        "mapping_id": str(entity_idx),
        "label": entity_lexicon[entity_iri].get('name'), # type: ignore
        "verbalisation": strip_parens(str(entity_lexicon[entity_iri].get('name'))).lower(), # type: ignore
        "iri": entity_iri
    }
    entity_verbalisation_list.append(strip_parens(str(entity_lexicon[entity_iri].get('name'))).lower()) # type: ignore
    list_of_entity_mappings.append(entity_map[str(entity_idx)])

save_json(verbalisation_list_fp, entity_verbalisation_list)
save_json(entity_map_fp, entity_map)
save_json(entity_mappings_list_fp, list_of_entity_mappings)


sbert_plm_hf_string = "all-MiniLM-L12-v2"
sbert_plm_encoder = SentenceTransformer.load(sbert_plm_hf_string)

ont_anatomy_23_pred_model_fp = "models/prediction/OnTr-all-MiniLM-L12-v2-ANATOMY"
ont_anatomy_23_pred_encoder = OntologyTransformer.load(ont_anatomy_23_pred_model_fp)

ontr_snomed_25_uni_model_fp = './models/OnTr-snomed25-uni'
ontr_snomed_25_uni_encoder = OntologyTransformer.load(ontr_snomed_25_uni_model_fp)

ontr_snomed_minified_model_fp = './models/OnTr-minified-64'
ontr_snomed_encoder = OntologyTransformer.load(ontr_snomed_minified_model_fp)

embeddings_dir = "./embeddings"
save_json(Path(f"{embeddings_dir}/axiom-verbalisations.json"), entity_verbalisation_list)
save_json(Path(f"{embeddings_dir}/axiom-mappings.json"), list_of_entity_mappings)

common_map = Path("./embeddings/axiom-mappings.json")
common_verbalisations = Path("./embeddings/axiom-verbalisations.json")
embeddings_dir = "./embeddings"

sbert_plm_embs = np.load(f"{embeddings_dir}/sbert-plm-embeddings.npy", mmap_mode="r")
hit_snomed_23_embs = np.load(f"{embeddings_dir}/hit-snomed-23-embeddings.npy", mmap_mode="r")
hit_snomed_25_embs = np.load(f"{embeddings_dir}/hit-snomed-25-embeddings.npy", mmap_mode="r")
ont_anatomy_23_pred_embs = np.load(f"{embeddings_dir}/ont-anatomy-23-pred-embeddings.npy", mmap_mode="r")
ont_snomed_25_latest_embs = np.load(f"{embeddings_dir}/ont-snomed-25-latest-embeddings.npy", mmap_mode="r")
ont_minified_embs = np.load(f"{embeddings_dir}/ont-snomed-minified-embeddings.npy", mmap_mode="r")

## SBERT Model ##

sbert_ret_plm_w_cosine_sim = SBERTRetriever(
  embeddings_fp=Path(f"{embeddings_dir}/sbert-plm-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_str="all-MiniLM-L12-v2",
  score_fn=batch_cosine_similarity
)

sbert_ret_plm_w_euclid_dist = SBERTRetriever(
  embeddings_fp=Path(f"{embeddings_dir}/sbert-plm-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_str="all-MiniLM-L12-v2",
  score_fn=batch_euclidian_l2_distance
)

## HiT Models ##

hit_ret_snomed_23_w_hyp_dist = HiTRetriever(
  embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-23-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_str="Hierarchy-Transformers/HiT-MiniLM-L12-SnomedCT",
  score_fn=batch_poincare_dist_with_adaptive_curv_k
)

hit_ret_snomed_23_w_ent_sub = hit_ret_snomed_23_w_hyp_dist = HiTRetriever(
  embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-23-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_str="Hierarchy-Transformers/HiT-MiniLM-L12-SnomedCT",
  score_fn=entity_subsumption
)

hit_snomed_23_model = HierarchyTransformer.from_pretrained("Hierarchy-Transformers/HiT-MiniLM-L12-SnomedCT")

# HiT-SNOMED-25 #

hit_SNOMED25_model_path = Path('./models/HiT-mixed-SNOMED-25/final')

hit_ret_snomed_25_w_hyp_dist = HiTRetriever(
  embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_fp=hit_SNOMED25_model_path,
  score_fn=batch_poincare_dist_with_adaptive_curv_k
)

hit_ret_snomed_25_w_ent_sub = HiTRetriever(
  embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_fp=hit_SNOMED25_model_path,
  score_fn=entity_subsumption
)

hit_snomed_25_model = HierarchyTransformer.from_pretrained(str(hit_SNOMED25_model_path.expanduser().resolve()))

## ONT - ANATONMY PREDICTION

ont_anatonmy_pred_model_path = Path("models/prediction/OnTr-all-MiniLM-L12-v2-ANATOMY")

ont_ret_anatomy_pred_w_hyp_dist = OnTRetriever(
  embeddings_fp=Path(f"{embeddings_dir}/ont-anatomy-23-pred-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_fp=ont_anatonmy_pred_model_path,
  score_fn=batch_poincare_dist_with_adaptive_curv_k
)

ont_ret_anatomy_pred_w_con_sub = OnTRetriever(
  embeddings_fp=Path(f"{embeddings_dir}/ont-anatomy-23-pred-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_fp=ont_anatonmy_pred_model_path,
  score_fn=concept_subsumption
)

ont_anatomy_model_pred = OntologyTransformer.load(str(ont_anatonmy_pred_model_path.expanduser().resolve()))



## OnTr - snomed_25 (Hui Trained)

ont_snomed_25_updated_model_path = Path("./models/OnTr-snomed25-uni")

ont_ret_snomed_25_updtd_w_hyp_dist = OnTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/ont-snomed-25-latest-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=ont_snomed_25_updated_model_path,
    score_fn=batch_poincare_dist_with_adaptive_curv_k
)

ont_ret_snomed_25_updtd_w_con_sub = OnTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/ont-snomed-25-latest-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=ont_snomed_25_updated_model_path,
    score_fn=concept_subsumption
)

ont_snomed_25_updtd_model = OntologyTransformer.load(str(ont_snomed_25_updated_model_path.expanduser().resolve()))

## OnTr snomed minified ##

ontr_snomed_minified_model_fp = Path('./models/OnTr-minified-64')

ontr_ret_snomed_minified_w_hyp_dist = OnTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/ont-snomed-minified-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=ontr_snomed_minified_model_fp,
    score_fn=batch_poincare_dist_with_adaptive_curv_k
)

ontr_ret_snomed_minified_w_con_sub = OnTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/ont-snomed-minified-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=ontr_snomed_minified_model_fp,
    score_fn=concept_subsumption
)

ontr_minified_model = OntologyTransformer.load(str(ontr_snomed_minified_model_fp.expanduser().resolve()))

## Tripple Mini OnT-Mini + SBERT Product Manifold

ont_m_32_emb_fp    = Path(f"{embeddings_dir}/ont-snomed-minified-32-embeddings.npy")
ont_m_128_emb_fp   = Path(f"{embeddings_dir}/ont-snomed-minified-128-embeddings.npy")
sbert_emb_fp       = Path("./embeddings/sbert-plm-embeddings.npy")

product_ont_model_32  = OntologyTransformer.load('./models/OnTr-m-32')
product_ont_model_128 = OntologyTransformer.load('./models/OnTr-m-128')
product_sbert_model   = SentenceTransformer("all-MiniLM-L12-v2")

mixed_ret_mini = CustomMixedModelRetriever(
    verbalisations_fp = common_verbalisations,
    meta_map_fp = common_map,
    ont_model_32 = product_ont_model_32,
    ont_32_embeddings_fp = ont_m_32_emb_fp,
    ont_model_128 = product_ont_model_128,
    ont_128_embeddings_fp = ont_m_128_emb_fp,
    sbert_model = product_sbert_model,
    sbert_embeddings_fp = sbert_emb_fp,
    sigma = (1.0, 1.0, 0.35),
    kernel = "exp",
)

100%|██████████| 374536/374536 [00:03<00:00, 119117.74it/s]
  sbert_plm_encoder = SentenceTransformer.load(sbert_plm_hf_string)
  self._model = SentenceTransformer.load(model)


In [16]:
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, LogitsProcessorList
from pydantic import BaseModel
from typing import Any, Callable
from pathlib import Path
from logits_processor_zoo.transformers import (
  CiteFromPromptLogitsProcessor,
  MultipleChoiceLogitsProcessor,
)

from tqdm import tqdm
import torch
import json
import random

In [24]:
def load_json(file_path: Path) -> dict[str, str]:
  with file_path.open('r', encoding='utf-8') as fp:
    return json.load(fp)

def get_dataset(dataset_key: str, benchmark_data: dict):
  return benchmark_data[dataset_key]

def get_question_obj(question_id: str, dataset: dict):
  return dataset[question_id]

def get_dataset_question_mapping(dataset_key: str, benchmark_data: dict):
  data = get_dataset(dataset_key, benchmark_data)
  mapping = {}
  for question_index, question_id in enumerate(data):
    mapping[question_index] = question_id
  return mapping

def get_question_str(question_id: str, dataset: dict):
  return dataset[question_id]['question']

def get_question_opts(question_id: str, dataset: dict):
  return dataset[question_id]['options']

def get_question_ans(question_id: str, dataset: dict):
  return dataset[question_id]['answer']

def get_dataset_names(benchmark_data: dict):
  return list(benchmark_data.keys())

def get_question_count(dataset_name: str, benchmark_data: dict):
  return len(benchmark_data[dataset_name])

def get_random_question_sample(benchmark_data: dict, allowable_datasets: list[str] = ['medqa', 'medmcqa', 'pubmedqa', 'bioasq', 'mmlu']):
  random_dataset_name = allowable_datasets[random.randint(0, len(allowable_datasets) - 1)]
  dataset_question_mapping = get_dataset_question_mapping(random_dataset_name, benchmark_data)
  # ^ provides a map from custom indicies used to access questions specific to each dataset
  dataset_questons_xs = benchmark_data[random_dataset_name]
  random_question_index = dataset_question_mapping[random.randint(0, len(dataset_question_mapping) - 1)]
  return dataset_questons_xs[random_question_index]

def xs_of_all_questions(benchmark_data: dict, allowable_datasets: list[str] = ['medqa', 'medmcqa', 'pubmedqa', 'bioasq', 'mmlu']):
  xs = []
  for dataset_name in allowable_datasets:
    dataset_question_mappings = get_dataset_question_mapping(dataset_name, benchmark_data)
    question_list = benchmark_data[dataset_name]
    for itr, mapping_idx in dataset_question_mappings.items():
      xs.append(question_list[mapping_idx])
  return xs

def get_question_entity_mentions(entity_mention_data: dict, dataset: str, question_id: str):
  for question in entity_mention_data["questions"]:
    if question['source_dataset'] == dataset and question['question_id'] == question_id:
      return question['entities'] # warning: it is possible to be []

def merge_entity_mentions(benchmark_data: dict, biomed_entities: dict, head_entities: dict, allowable_datasets: list[str] = ['medqa', 'medmcqa', 'pubmedqa', 'bioasq', 'mmlu']):
  xs = []
  for dataset_name in allowable_datasets:
    print(f"Processing {dataset_name} ... ")
    dataset_question_mappings = get_dataset_question_mapping(dataset_name, benchmark_data)
    question_list = benchmark_data[dataset_name]
    for itr, mapping_idx in tqdm(dataset_question_mappings.items()):
      question_entities_biomedical = get_question_entity_mentions(biomed_entities, dataset_name, mapping_idx)
      question_entities_head = get_question_entity_mentions(head_entities, dataset_name, mapping_idx)
      question_list[mapping_idx]['entities'] = []
      question_list[mapping_idx]['entities'].extend(question_entities_biomedical)
      question_list[mapping_idx]['entities'].extend(question_entities_head)
      xs.append(question_list[mapping_idx])
  return xs

In [25]:
mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))

def get_random_mirage_question():
  return get_random_question_sample(mirage_benchmark)

In [26]:
mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))

max_options = 0
option_keys = set()

for dataset in mirage_benchmark:
  print(f"Processing {dataset} with {len(mirage_benchmark[dataset])} questions ...")
  for dataset_question in mirage_benchmark[dataset]:
    num_options_in_question = len(mirage_benchmark[dataset][dataset_question]['options'])
    options = list(mirage_benchmark[dataset][dataset_question]['options'].keys())
    for option in options:
      option_keys.add(option)
    if num_options_in_question > max_options:
      if VERBOSE:
        print(f"New maximum number of options founds: {num_options_in_question}")
      max_options = num_options_in_question

print(f"\nFinished! The maximum number of options in any question is {max_options}")
print(f"The set of option keys consists of the following elements: {option_keys}")

Processing medqa with 1273 questions ...
Processing medmcqa with 4183 questions ...
Processing pubmedqa with 500 questions ...
Processing bioasq with 618 questions ...
Processing mmlu with 1089 questions ...

Finished! The maximum number of options in any question is 4
The set of option keys consists of the following elements: {'A', 'D', 'C', 'B'}


In [27]:
dataset_names = get_dataset_names(mirage_benchmark)

medqa_mapping = get_dataset_question_mapping("medqa", mirage_benchmark)
medmcqa_mapping = get_dataset_question_mapping("medmcqa", mirage_benchmark)
pubmedqa_mapping = get_dataset_question_mapping("pubmedqa", mirage_benchmark)
bioasq_mapping = get_dataset_question_mapping("bioasq", mirage_benchmark)
mmlu_mapping = get_dataset_question_mapping("mmlu", mirage_benchmark)

total_questions = 0
for name in dataset_names:
  print(f"{name} contains {get_question_count(name, mirage_benchmark)} questions.")
  total_questions += get_question_count(name, mirage_benchmark)
print(f"\nAll datasets contain a total of {total_questions} questions.")

if VERBOSE:
  random_index = random.randint(0, len(medqa_mapping) - 1)
  sample_random_medqa_mapping = medqa_mapping[random_index]
  print(random_index, " -> ", sample_random_medqa_mapping)
  print(dataset_names)

medqa contains 1273 questions.
medmcqa contains 4183 questions.
pubmedqa contains 500 questions.
bioasq contains 618 questions.
mmlu contains 1089 questions.

All datasets contain a total of 7663 questions.


In [28]:
mirage_questions = xs_of_all_questions(mirage_benchmark)

In [29]:
def foldr_xs_to_csv(xs: list[str]) -> str:
  """recursive foldr (right fold) for creating a single csv row"""
  if len(xs) == 1:
    return xs[0]
  return str(f"{xs[0]},{foldr_xs_to_csv(xs[1:])}")


def format_options(options: dict[str, str]) -> str:
  """produces a list of options for inclusion within prompts"""
  return "\n".join(f"{k}. {opt}" for k, opt in options.items())


def opt_letters(options: dict[str, str]) -> str:
  """produces a comma seperated list of a dicts keys, e.g. A,B,C,D"""
  return foldr_xs_to_csv(list(options.keys()))


def prompt_template_no_rag(question: str, options: dict[str, str], **kwargs) -> str:
  """produce a simple biomedical question answering (MC) template, discards additional kwargs (e.g. answer)"""
  return (
    "You are a helpful medical expert, and your task is to answer a multi-choice medical question. Your response will be used for research purposes only.\n"
    f"Return only the letter of the best answer ({opt_letters(options)}).\n\n"
    f"Here is the question:\n{question}\n\n"
    f"Here are the potential choices:\n{format_options(options)}\n\n"
    "Answer (letter only): "
  )


def prompt_template_with_axioms(question: str, options: dict[str, str], axioms: list[str], **kwargs) -> str:
  """produce a biomedical MCQA prompt for RAG, discards additional kwargs (e.g. answer)"""
  axiomatic_context = "\n".join(axiom for axiom in axioms)
  return (
    "You are a helpful medical expert, your task is to answer a multi-choice medical question.\n"
    f"Return only the letter of the best answer ({opt_letters(options)}).\n\n"
    f"Helpful context:\n{axiomatic_context}\n\n"
    f"Here is the question:\n{question}\n\n"
    f"Here are the potential choices:\n{format_options(options)}\n\n"
    "Answer (letter only): "
  )


# TODO: create a prompt template registry (singleton container)
PROMPT_TEMPLATES = {
  "mirage_mcqa_no_rag": prompt_template_no_rag,
  "mirage_mcqa_axiom_rag": prompt_template_with_axioms
}

In [30]:
# LLMs

class MistralLLM:

  _hf_id: str
  _model: Any
  _tokenizer: Any
  _callable_prompt_templates: dict[str, Callable]

  def __init__(self, hf_identifier: str, **kwargs):
    self._hf_id = hf_identifier
    self._model = None
    self._tokenizer = None
    self._callable_prompt_templates = {}
    
  def load_model(self, **kwargs):
    self._model = AutoModelForCausalLM.from_pretrained(self._hf_id, **kwargs)
    self._model.eval()
    return self
    
  def load_tokenizer(self, **kwargs):
    self._tokenizer = AutoTokenizer.from_pretrained(self._hf_id, **kwargs)
    if self._tokenizer.pad_token_id is None:
      self._tokenizer.pad_token = self._tokenizer.eos_token
    return self

  def register_generation_config(self, **kwargs):
    self._model.generation_config = GenerationConfig(**kwargs)
    return self

  def register_prompt_template_fn(self, callback_key: str, fn: Callable):
    self._callable_prompt_templates[callback_key] = fn
    return self

  @torch.inference_mode()
  def generate(self, prompt: str, **kwargs):
    inputs = self._tokenizer(
      prompt,
      return_tensors="pt", 
    ).to(self._model.device)
    # generate output
    out = self._model.generate(
      **inputs,
      **kwargs
    ) # decode & return
    return self._tokenizer.decode(out[0], skip_special_tokens=True)

  def generate_inject_template(self, template_key: str, template_args: dict, **kwargs):
    template_fn = self._callable_prompt_templates[template_key]
    prompt = template_fn(**template_args)
    return self.generate(prompt, **kwargs)
  
  def generate_constrain_logits(self, prompt: str, logits_processor_list: list | None = None, max_tokens: int = 1000, **kwargs):
    if logits_processor_list is None:
      logits_processor_list = []
    return self.generate(
      prompt,
      max_new_tokens = max_tokens,
      min_new_tokens = 1,
      logits_processor = LogitsProcessorList(logits_processor_list),
      **kwargs
    )
  
  # TODO: clean this up a little bit, fn should accept the delimiter
  def generate_inject_template_and_constrain_logits_for_mcqa(self, template_key: str, template_args: dict, **kwargs):
    mclp = MultipleChoiceLogitsProcessor(
      tokenizer=self._tokenizer,
      choices=list(template_args["options"].keys()),
      delimiter="."
    )
    template_fn = self._callable_prompt_templates[template_key]
    prompt = template_fn(**template_args)
    return self.generate_constrain_logits(prompt, [mclp], max_tokens=1, **kwargs)

  # TODO: fix this, simply grabbing the last character from the string *may* result in inaccuracies (works fine for now!)
  def generate_single_letter_for_mcqa(self, template_key: str, template_args: dict, **kwargs):
    response = self.generate_inject_template_and_constrain_logits_for_mcqa(template_key, template_args, **kwargs)
    return str(response)[len(str(response)) - 1:]



class BaseEntitySelector:

  _STOPWORDS = [
    "patient","pt","pts","person","people","individual","individuals",
    "male","female","man","woman","boy","girl","child","children","kid",
    "infant","newborn","neonate","adult","elderly","senior","parent","parents",
    "he","she","they","them","him","her","his","hers","their","theirs",
    "you","your","yours","we","our","ours","i","me","my","mine","one",
    "someone","anyone","everyone","nobody","somebody","this","that","these","those",
    "presented","presents","presenting","complains","complained","reports","reported",
    "history","h/o","c/o","since","for","with","had","found","noted","developed",
    "exhibits","demonstrates","shows","reveals","diagnosed","diagnosis","examination","exam",
    "on","during","before","after","prior",
    "which","what","when","where","why","how","whose","whom",
    "true","false","correct","incorrect","appropriate","best","most","least",
    "except","not","all","following","choose","select","mark","option","options",
    "both","none","either","neither","above","below",
    "feature","features","sign","signs","symptom","symptoms","finding","findings",
    "test","tests","result","results","value","values","level","levels","rate","ratio",
    "management","treatment","therapy","mechanism","complication","evaluation","investigation","investigations",
    "method","methods","technique","techniques","approach","approaches","procedure","procedures",
    "cause","causes","type","types","class","classes","category","categories","group","groups",
    "age","aged","old","year","years","yr","yrs","month","months","mo","mos",
    "week","weeks","day","days","hour","hours","hr","hrs","minute","minutes","min","mins",
    "always","never","usually","commonly","rarely","frequently","sometimes","generally","typically",
    "mainly","mostly","predominantly","severe","mild","moderate","acute","chronic","subacute","persistent","recurrent",
    "according","guidelines","classification","defined","definition","called","known","named","term","terminology",
    "left","right","bilateral","unilateral","anterior","posterior","medial","lateral","superior","inferior",
    "proximal","distal","upper","lower","central","peripheral",
    "hospital","clinic","ward","opd","er","icu","casualty",
    "region","area","part","portion","site","surface","margin","border","apex","base",
    "volume","pressure","temperature","saturation","score","grade","stage","index",
  ]
  _all_mention_results: list[QueryResult]
  _retriever: BaseRetriever

  def __init__(self, retriever: BaseRetriever):
    self._retriever = retriever

  def encode_and_rank_candidates(self, entities):
    pass

  def get_top_candidates(self, top_k=3):
    return self._all_mention_results[:top_k]


class SubsumptionEntitySelector(BaseEntitySelector):
  @override
  def encode_and_rank_candidates(self, entities):
    self._all_mention_results = []
    for mention in entities:
      if mention['entity_literal'] in self._STOPWORDS:
        continue
      # else:
      self._all_mention_results.append(
        self._retriever.retrieve(mention['entity_literal'], top_k=1, reverse_candidate_scores=True, weight=0.4, model=self._retriever._model)[0]
      )
    self._all_mention_results.sort(key=lambda x: x[2], reverse=True)


class ApproximateNearestNeighbourEntitySelector(BaseEntitySelector):
  @override
  def encode_and_rank_candidates(self, entities):
    self._all_mention_results = []
    for mention in entities:
      if mention['entity_literal'] in self._STOPWORDS:
        continue
      # else:
      self._all_mention_results.append(
        self._retriever.retrieve(mention['entity_literal'], top_k=1, reverse_candidate_scores=False, model=self._retriever._model)[0]
      )
    self._all_mention_results.sort(key=lambda x: x[2], reverse=False)


class SimilarityEntitySelector(BaseEntitySelector):
  @override
  def encode_and_rank_candidates(self, entities):
    self._all_mention_results = []
    for mention in entities:
      if mention['entity_literal'] in self._STOPWORDS:
        continue
      # else:
      self._all_mention_results.append(
        self._retriever.retrieve(mention['entity_literal'], top_k=1, reverse_candidate_scores=True)[0]
      )
    self._all_mention_results.sort(key=lambda x: x[2], reverse=True)


In [31]:
mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
mirage_questions = xs_of_all_questions(mirage_benchmark)
biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
mirage_questions_with_entity_mentions = merge_entity_mentions(mirage_benchmark, biomedical_entity_mentions, head_entity_mentions)

mistral_lm = MistralLLM("mistralai/Mistral-7B-Instruct-v0.3")

mistral_lm.load_tokenizer(
  use_fast=True
).load_model(
  device_map="auto",
  torch_dtype=torch.bfloat16,
  low_cpu_mem_usage=True
).register_generation_config(
  do_sample=False,
  num_beams=1,
  pad_token_id=mistral_lm._tokenizer.pad_token_id,
  eos_token_id=mistral_lm._tokenizer.eos_token_id
)

mistral_lm.register_prompt_template_fn("mirage_mcqa_no_rag", prompt_template_no_rag)
mistral_lm.register_prompt_template_fn("mirage_mcqa_axiom_rag", prompt_template_with_axioms)

print("Loaded!")

Processing medqa ... 


100%|██████████| 1273/1273 [00:00<00:00, 16618.11it/s]


Processing medmcqa ... 


100%|██████████| 4183/4183 [00:04<00:00, 1018.17it/s]


Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 724.08it/s]


Processing bioasq ... 


100%|██████████| 618/618 [00:00<00:00, 702.01it/s]


Processing mmlu ... 


100%|██████████| 1089/1089 [00:01<00:00, 650.71it/s]


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded!


In [110]:
random_question_sample = get_random_mirage_question()

snomed_concept_information: dict = load_json(Path("./data/snomed_axioms.json"))
entity_selector = ApproximateNearestNeighbourEntitySelector(ont_ret_snomed_25_updtd_w_hyp_dist)

entity_selector.encode_and_rank_candidates(random_question_sample['entities'])
entities_for_rag = entity_selector.get_top_candidates(top_k=5)
entity_mention_iris_for_rag = []
for mention in entities_for_rag:
  entity_mention_iris_for_rag.append(mention[1])

# obtain axiom verbalisations (or produce `concept cards`) for each IRI for prompt enrichment

axiom_verbalisations = []

for iri in entity_mention_iris_for_rag:
  label: str = snomed_concept_information[iri]['label']
  subclass_axioms: list[str] = snomed_concept_information[iri]['verbalization']['subclass_of']
  equiv_axioms: list[str] = snomed_concept_information[iri]['verbalization']['equivalent_to']
  for idx, axiom in enumerate(subclass_axioms):
    axiom_verbalisations.append(f"{label}: {subclass_axioms[idx]}")
  for idx, axiom in enumerate(equiv_axioms):
    axiom_verbalisations.append(f"{label}: {equiv_axioms[idx]}")

# bind the axioms verbalisations to the question object for prompt injection

random_question_sample['axioms'] = axiom_verbalisations

In [111]:
random_question_sample['question']

'A previously healthy 29-year-old man comes to the emergency department because of burning with urination for several days. He has also had pain in the right ankle for 3 days and pain and swelling in the left knee for 1 day. Two weeks ago, he had several days of fever and bloody diarrhea, for which he was treated with antibiotics. Examination shows a small left knee effusion and bilateral conjunctival injection. Which of the following is the most likely additional finding in this patient?'

In [112]:
response = mistral_lm.generate_inject_template_and_constrain_logits_for_mcqa("mirage_mcqa_no_rag", random_question_sample)
print(f"Language model response: \n\n{response}\n")
print("-" * 72)
print(f"The correct answer is: {random_question_sample['answer']}")

Language model response: 

You are a helpful medical expert, and your task is to answer a multi-choice medical question. Your response will be used for research purposes only.
Return only the letter of the best answer (A,B,C,D).

Here is the question:
A previously healthy 29-year-old man comes to the emergency department because of burning with urination for several days. He has also had pain in the right ankle for 3 days and pain and swelling in the left knee for 1 day. Two weeks ago, he had several days of fever and bloody diarrhea, for which he was treated with antibiotics. Examination shows a small left knee effusion and bilateral conjunctival injection. Which of the following is the most likely additional finding in this patient?

Here are the potential choices:
A. Circular erythematous rash with central clearing
B. Pain on passive extension of the fingers
C. Palpable mass in the right lower quadrant
D. Tenderness at the insertion of the Achilles tendon

Answer (letter only):  A



In [113]:
response = mistral_lm.generate_inject_template_and_constrain_logits_for_mcqa("mirage_mcqa_axiom_rag", random_question_sample)
print(f"Language model response: \n\n{response}\n")
print("-" * 72)
print(f"The correct answer is: {random_question_sample['answer']}")

Language model response: 

You are a helpful medical expert, your task is to answer a multi-choice medical question.
Return only the letter of the best answer (A,B,C,D).

Helpful context:
Examination: is a type of Additional values
Conjunctival concretion: defined as Disease with morphology: Focal calcification, site: Conjunctiva.
HEA: is a type of Anatomical site notations for tumor staging
Age more than 50 years, male: is a type of Age more than 50 years
Each: is a type of Additional dosage instructions

Here is the question:
A previously healthy 29-year-old man comes to the emergency department because of burning with urination for several days. He has also had pain in the right ankle for 3 days and pain and swelling in the left knee for 1 day. Two weeks ago, he had several days of fever and bloody diarrhea, for which he was treated with antibiotics. Examination shows a small left knee effusion and bilateral conjunctival injection. Which of the following is the most likely additiona

### Example(s) of Single Axiom Verbalisation Injection (k=1) Response:

*The correct answer is: B*

**Original prompt (\w response):**

    You are a helpful medical expert, and your task is to answer a multi-choice medical question.
    Return only the letter of the best answer (A,B,C,D).

    Here is the question:
    A young girl hospitalised with anorexia nervosa is on treatment, Even after taking adequate food according to the recommended diet plan for last 1 week, there is no gain in weight, what is the next step in management:

    Here are the potential choices:
    A. Increase fluid intake
    B. Observe patient for 2 hours after meal
    C. Increase the do se of anxiolytics
    D. Increase the caloric intake from 1500 kcal to 2000 kcal per day

    Answer (letter only):  D

**RAG-based prompt (\w response):**

    You are a helpful medical expert, your task is to answer a multi-choice medical question.
    Return only the letter of the best answer (A,B,C,D).

    Helpful context:
    Atypical anorexia nervosa is a type of Eating disorder

    Here is the question:
    A young girl hospitalised with anorexia nervosa is on treatment, Even after taking adequate food according to the recommended diet plan for last 1 week, there is no gain in weight, what is the next step in management:

    Here are the potential choices:
    A. Increase fluid intake
    B. Observe patient for 2 hours after meal
    C. Increase the do se of anxiolytics
    D. Increase the caloric intake from 1500 kcal to 2000 kcal per day

    Answer (letter only):  D



### Example(s) of Multi-Axiom Verbalisation Injection (k=5) Response:

*The correct answer is D*

**Original Prompt:**

    You are a helpful medical expert, and your task is to answer a multi-choice medical question. Your response will be used for research purposes only.
    Return only the letter of the best answer (A,B,C,D).

    Here is the question:
    A previously healthy 29-year-old man comes to the emergency department because of burning with urination for several days. He has also had pain in the right ankle for 3 days and pain and swelling in the left knee for 1 day. Two weeks ago, he had several days of fever and bloody diarrhea, for which he was treated with antibiotics. Examination shows a small left knee effusion and bilateral conjunctival injection. Which of the following is the most likely additional finding in this patient?

    Here are the potential choices:
    A. Circular erythematous rash with central clearing
    B. Pain on passive extension of the fingers
    C. Palpable mass in the right lower quadrant
    D. Tenderness at the insertion of the Achilles tendon

    Answer (letter only):  A

**RAG-based Prompt:**

    You are a helpful medical expert, your task is to answer a multi-choice medical question.
    Return only the letter of the best answer (A,B,C,D).

    Helpful context:
    Examination is a type of Additional values
    Conjunctival concretion defined as Disease with morphology Focal calcification, site Conjunctiva
    HEA is a type of Anatomical site notations for tumor staging
    Age more than 50 years, male is a type of Age more than 50 years
    Each is a type of Additional dosage instructions

    Here is the question:
    A previously healthy 29-year-old man comes to the emergency department because of burning with urination for several days. He has also had pain in the right ankle for 3 days and pain and swelling in the left knee for 1 day. Two weeks ago, he had several days of fever and bloody diarrhea, for which he was treated with antibiotics. Examination shows a small left knee effusion and bilateral conjunctival injection. Which of the following is the most likely additional finding in this patient?

    Here are the potential choices:
    A. Circular erythematous rash with central clearing
    B. Pain on passive extension of the fingers
    C. Palpable mass in the right lower quadrant
    D. Tenderness at the insertion of the Achilles tendon

    Answer (letter only):  A

### Example(s) of Single Axiom Verbalisation Injection (k=1) Response:

*The correct answer is: B*

**Original prompt (\w response):**

    You are a helpful medical expert, and your task is to answer a multi-choice medical question.
    Return only the letter of the best answer (A,B,C,D).

    Here is the question:
    A young girl hospitalised with anorexia nervosa is on treatment, Even after taking adequate food according to the recommended diet plan for last 1 week, there is no gain in weight, what is the next step in management:

    Here are the potential choices:
    A. Increase fluid intake
    B. Observe patient for 2 hours after meal
    C. Increase the do se of anxiolytics
    D. Increase the caloric intake from 1500 kcal to 2000 kcal per day

    Answer (letter only):  D

**RAG-based prompt (\w response):**

    You are a helpful medical expert, your task is to answer a multi-choice medical question.
    Return only the letter of the best answer (A,B,C,D).

    Helpful context:
    Atypical anorexia nervosa is a type of Eating disorder

    Here is the question:
    A young girl hospitalised with anorexia nervosa is on treatment, Even after taking adequate food according to the recommended diet plan for last 1 week, there is no gain in weight, what is the next step in management:

    Here are the potential choices:
    A. Increase fluid intake
    B. Observe patient for 2 hours after meal
    C. Increase the do se of anxiolytics
    D. Increase the caloric intake from 1500 kcal to 2000 kcal per day

    Answer (letter only):  D



### Example(s) of Multi-Axiom Verbalisation Injection (k=5) Response:

*The correct answer is D*

**Original Prompt:**

    You are a helpful medical expert, and your task is to answer a multi-choice medical question. Your response will be used for research purposes only.
    Return only the letter of the best answer (A,B,C,D).

    Here is the question:
    A previously healthy 29-year-old man comes to the emergency department because of burning with urination for several days. He has also had pain in the right ankle for 3 days and pain and swelling in the left knee for 1 day. Two weeks ago, he had several days of fever and bloody diarrhea, for which he was treated with antibiotics. Examination shows a small left knee effusion and bilateral conjunctival injection. Which of the following is the most likely additional finding in this patient?

    Here are the potential choices:
    A. Circular erythematous rash with central clearing
    B. Pain on passive extension of the fingers
    C. Palpable mass in the right lower quadrant
    D. Tenderness at the insertion of the Achilles tendon

    Answer (letter only):  A

**RAG-based Prompt:**

    You are a helpful medical expert, your task is to answer a multi-choice medical question.
    Return only the letter of the best answer (A,B,C,D).

    Helpful context:
    Examination is a type of Additional values
    Conjunctival concretion defined as Disease with morphology Focal calcification, site Conjunctiva
    HEA is a type of Anatomical site notations for tumor staging
    Age more than 50 years, male is a type of Age more than 50 years
    Each is a type of Additional dosage instructions

    Here is the question:
    A previously healthy 29-year-old man comes to the emergency department because of burning with urination for several days. He has also had pain in the right ankle for 3 days and pain and swelling in the left knee for 1 day. Two weeks ago, he had several days of fever and bloody diarrhea, for which he was treated with antibiotics. Examination shows a small left knee effusion and bilateral conjunctival injection. Which of the following is the most likely additional finding in this patient?

    Here are the potential choices:
    A. Circular erythematous rash with central clearing
    B. Pain on passive extension of the fingers
    C. Palpable mass in the right lower quadrant
    D. Tenderness at the insertion of the Achilles tendon

    Answer (letter only):  A

In [97]:
another_random_question_sample = get_random_mirage_question()

In [54]:
another_random_question_sample['question']

'A 10 days old neonate is posted for pyloric stenosis in surgery. The investigation report shows a serum calcium level of 6 mg/dL. What information would you like to know before you supplement calcium to this neonate –'

In [55]:
response = mistral_lm.generate_inject_template_and_constrain_logits_for_mcqa("mirage_mcqa_no_rag", another_random_question_sample)
print(f"Language model response: \n\n{response}\n")

Language model response: 

You are a helpful medical expert, and your task is to answer a multi-choice medical question. Your response will be used for research purposes only.
Return only the letter of the best answer (A,B,C,D).

Here is the question:
A 10 days old neonate is posted for pyloric stenosis in surgery. The investigation report shows a serum calcium level of 6 mg/dL. What information would you like to know before you supplement calcium to this neonate –

Here are the potential choices:
A. Blood glucose
B. Serum protein
C. Serum bilirubin
D. Oxygen saturation

Answer (letter only):  A



In [56]:
print("-" * 72)
print(f"The correct answer is: {random_question_sample['answer']}")

------------------------------------------------------------------------
The correct answer is: D


In [None]:
response = mistral_lm.generate_inject_template_and_constrain_logits_for_mcqa("mirage_mcqa_axiom_rag", random_question_sample)
print(f"Language model response: \n\n{response}\n")

TypeError: prompt_template_with_axioms() missing 1 required positional argument: 'axioms'