In [1]:
from retrievers import (
  BaseRetriever,
  SBERTRetriever,
  HiTRetriever,
  OnTRetriever
)
from pathlib import Path
from math_functools import (
  batch_cosine_similarity,
  batch_poincare_dist_with_adaptive_curv_k,
  entity_subsumption,
  concept_subsumption
)
from llm_utils_improvements import (
  MistralLLM,
  BaseEntitySelector,
  SimilarityEntitySelector,
  ApproximateNearestNeighbourEntitySelector,
  SubsumptionEntitySelector,
  chat_prompt_template_no_rag,
  chat_prompt_template_with_axioms
)
from harness_utils import (
  QATestHarness
)

# ------------------------------------
# LLM options:
# ------------------------------------
# "mistralai/Mistral-7B-Instruct-v0.1"
# "mistralai/Mistral-7B-Instruct-v0.3"
# "BioMistral/BioMistral-7B"
# ------------------------------------

LLM_MODEL_ID = "BioMistral/BioMistral-7B"
SEED = 42

# instanciate a retriever
sbert_ret = SBERTRetriever(
  embeddings_fp=Path(f"./embeddings/sbert-plm-embeddings.npy"),
  meta_map_fp=Path("./embeddings/axiom-mappings.json"),
  verbalisations_fp=Path("./embeddings/axiom-verbalisations.json"),
  model_str="all-MiniLM-L12-v2",
  score_fn=batch_cosine_similarity
)

# and an entity selector
sbert_entity_selector = SimilarityEntitySelector(sbert_ret)

# and an LLM
mistral_llm = MistralLLM(LLM_MODEL_ID)

mistral_llm.load_tokenizer(use_fast=True).load_model(
    device_map="auto",
    torch_dtype="auto",
    low_cpu_mem_usage=True
).register_generation_config(
    do_sample=False,
    num_beams=1,
    pad_token_id=mistral_llm._tokenizer.pad_token_id,
    eos_token_id=mistral_llm._tokenizer.eos_token_id
)
mistral_llm.register_prompt_template_fn("mirage_mcqa_no_rag_chat", chat_prompt_template_no_rag)
mistral_llm.register_prompt_template_fn("mirage_mcqa_axiom_rag_chat", chat_prompt_template_with_axioms)

# ideally, we would load from config (TODO: load cfgNode \w yacs or hydra)
tests = QATestHarness(
  Path("./data/MIRAGE/benchmark.json"), 
  Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"), 
  Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json")
).set_shuffle_question_options(True).set_permute_question_options(
  True
).set_retrieval_k(100).set_append_k(10).set_top_k(1).set_use_rag(True).register_retriever(
  sbert_ret
).register_entity_selector(
  sbert_entity_selector
).register_llm(
  mistral_llm
)

# quick tests:

QATestHarness.set_random_seed(SEED)

tests.set_use_rag(False)
tests.run_multiple(['pubmedqa'])

tests.set_use_rag(False)
tests.run_multiple(['pubmedqa'])

# quick test to compare results to:

tests.set_use_rag(True)
tests.run_multiple(['pubmedqa'])

tests.set_use_rag(True)
tests.run_multiple(['pubmedqa'])

Using RAG: False
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1533.76it/s]


Processing pubmedqa ...


100%|██████████| 500/500 [02:26<00:00,  3.41it/s]


Total correct:   204
Total incorrect: 296
Accuracy:        40.8%


Using RAG: False
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 618.76it/s] 


Processing pubmedqa ...


100%|██████████| 500/500 [02:25<00:00,  3.43it/s]


Total correct:   216
Total incorrect: 284
Accuracy:        43.2%


Using RAG: True
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1530.64it/s]


Processing pubmedqa ...


100%|██████████| 500/500 [02:44<00:00,  3.04it/s]


Total correct:   226
Total incorrect: 274
Accuracy:        45.2%


Using RAG: True
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 809.01it/s] 


Processing pubmedqa ...


100%|██████████| 500/500 [02:35<00:00,  3.22it/s]

Total correct:   212
Total incorrect: 288
Accuracy:        42.4%







In [1]:
from retrievers import (
  BaseRetriever,
  SBERTRetriever,
  HiTRetriever,
  OnTRetriever
)
from pathlib import Path
from math_functools import (
  batch_cosine_similarity,
  batch_poincare_dist_with_adaptive_curv_k,
  batch_poincare_dist_with_adaptive_curv_k_torch,
  entity_subsumption,
  concept_subsumption
)
from llm_utils import (
  MistralLLM,
  BaseEntitySelector,
  SimilarityEntitySelector,
  ApproximateNearestNeighbourEntitySelector,
  SubsumptionEntitySelector,
  chat_prompt_template_no_rag,
  chat_prompt_template_with_axioms
)
from harness_utils import (
  QATestHarness
)

##########
# GLOBALS
##########

# ------------------------------------
# LLM options:
# ------------------------------------
# "mistralai/Mistral-7B-Instruct-v0.1"
# "mistralai/Mistral-7B-Instruct-v0.3"
# "BioMistral/BioMistral-7B"
# ------------------------------------

LLM_MODEL_ID = "BioMistral/BioMistral-7B"
SEED = 42

##################################################
# BOOTSTRAP: encoder, retriever & entity selector
##################################################

common_map = Path("./embeddings/axiom-mappings.json") # *entity mappings
common_verbalisations = Path("./embeddings/axiom-verbalisations.json") # rdfs:label(s) & verbs
embeddings_dir = "./embeddings" # dir for embeddings

# fine-tuned embedding model, for embedding entity mentions
retriever_model_fp = hit_SNOMED25_model_path = Path('./models/HiT-mixed-SNOMED-25/final')

# accepts an entity mention &
# 1. produces an embedding
# 2. measures the `score_fn` agaisnt existing embeddings
# 3. returns a ranked list of entities as tuples: (rank, iri, score, verbalisation)
hit_retriever = HiTRetriever(
  embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_fp=retriever_model_fp,
  score_fn=batch_poincare_dist_with_adaptive_curv_k_torch
)

# provides a ranking for a pool of entities (drawn from multiple multiple mentions for the same question)
entity_selector = ApproximateNearestNeighbourEntitySelector(hit_retriever)

##################
#  BOOTSTRAP: LLM
##################

# initialises a LLM & exposes methods for 
# RAG \w axiom verbalisation-based prompt enrichment
mistral_llm = MistralLLM(LLM_MODEL_ID)

mistral_llm.load_tokenizer(use_fast=True).load_model(
    device_map="auto",
    torch_dtype="auto",
    low_cpu_mem_usage=True
).register_generation_config(
    do_sample=False,
    num_beams=1,
    pad_token_id=mistral_llm._tokenizer.pad_token_id,
    eos_token_id=mistral_llm._tokenizer.eos_token_id
)
mistral_llm.register_prompt_template_fn("mirage_mcqa_no_rag_chat", chat_prompt_template_no_rag)
mistral_llm.register_prompt_template_fn("mirage_mcqa_axiom_rag_chat", chat_prompt_template_with_axioms)

# ideally, we would load from config (TODO: load cfgNode \w yacs or hydra)
tests = QATestHarness(
  Path("./data/MIRAGE/benchmark.json"), 
  Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"), 
  Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json")
).set_shuffle_question_options(True).set_permute_question_options(
  True
).set_retrieval_k(100).set_append_k(10).set_top_k(1).set_use_rag(True).register_retriever(
  hit_retriever
).register_entity_selector(
  entity_selector # type: ignore
).register_llm(
  mistral_llm # type: ignore
)

# single test run:

QATestHarness.set_random_seed(SEED)

tests.set_use_rag(True)
tests.run_multiple(['pubmedqa', 'bioasq', 'mmlu', 'medqa', 'medmcqa'])

Using RAG: True
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1030.32it/s]


Processing pubmedqa ...


  vs_t = torch.as_tensor(vs_chunk, device=device, dtype=dtype)
100%|██████████| 500/500 [03:07<00:00,  2.66it/s]


Total correct: 232
Total incorrect: 268
Accuracy: 46.4% 

Using RAG: True
Processing bioasq ... 


100%|██████████| 618/618 [00:00<00:00, 1458.95it/s]


Processing bioasq ...


100%|██████████| 618/618 [03:48<00:00,  2.70it/s]


Total correct: 366
Total incorrect: 252
Accuracy: 59.22% 

Using RAG: True
Processing mmlu ... 


100%|██████████| 1089/1089 [00:00<00:00, 1268.71it/s]


Processing mmlu ...


100%|██████████| 1089/1089 [11:21<00:00,  1.60it/s]


Total correct: 544
Total incorrect: 545
Accuracy: 49.95% 

Using RAG: True
Processing medqa ... 


100%|██████████| 1273/1273 [00:00<00:00, 17134.00it/s]


Processing medqa ...


100%|██████████| 1273/1273 [30:43<00:00,  1.45s/it]


Total correct: 515
Total incorrect: 758
Accuracy: 40.46% 

Using RAG: True
Processing medmcqa ... 


100%|██████████| 4183/4183 [00:02<00:00, 1600.22it/s]


Processing medmcqa ...


100%|██████████| 4183/4183 [25:58<00:00,  2.68it/s]

Total correct: 1485
Total incorrect: 2698
Accuracy: 35.5% 






In [None]:
from retrievers import (
  BaseRetriever,
  SBERTRetriever,
  HiTRetriever,
  OnTRetriever
)
from pathlib import Path
from math_functools import (
  batch_cosine_similarity,
  batch_poincare_dist_with_adaptive_curv_k,
  batch_poincare_dist_with_adaptive_curv_k_torch,
  efficient_batch_poincare_distance_with_curv_k,
  entity_subsumption,
  concept_subsumption
)
from llm_utils import (
  MistralLLM,
  BaseEntitySelector,
  SimilarityEntitySelector,
  ApproximateNearestNeighbourEntitySelector,
  SubsumptionEntitySelector,
  chat_prompt_template_no_rag,
  chat_prompt_template_with_axioms
)
from harness_utils import (
  QATestHarness
)

##########
# GLOBALS
##########

# ------------------------------------
# LLM options:
# ------------------------------------
# "mistralai/Mistral-7B-Instruct-v0.1"
# "mistralai/Mistral-7B-Instruct-v0.3"
# "BioMistral/BioMistral-7B"
# ------------------------------------

LLM_MODEL_ID = "BioMistral/BioMistral-7B"
SEED = 42

##################################################
# BOOTSTRAP: encoder, retriever & entity selector
##################################################

common_map = Path("./embeddings/axiom-mappings.json") # *entity mappings
common_verbalisations = Path("./embeddings/axiom-verbalisations.json") # rdfs:label(s) & verbs
embeddings_dir = "./embeddings" # dir for embeddings

# fine-tuned embedding model, for embedding entity mentions
retriever_model_fp = hit_SNOMED25_model_path = Path('./models/HiT-mixed-SNOMED-25/final')

# accepts an entity mention &
# 1. produces an embedding
# 2. measures the `score_fn` agaisnt existing embeddings
# 3. returns a ranked list of entities as tuples: (rank, iri, score, verbalisation)
hit_retriever = HiTRetriever(
  embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_fp=retriever_model_fp,
  score_fn=efficient_batch_poincare_distance_with_curv_k
)

# provides a ranking for a pool of entities (drawn from multiple multiple mentions for the same question)
entity_selector = ApproximateNearestNeighbourEntitySelector(hit_retriever)

##################
#  BOOTSTRAP: LLM
##################

# initialises a LLM & exposes methods for 
# RAG \w axiom verbalisation-based prompt enrichment
mistral_llm = MistralLLM(LLM_MODEL_ID)

mistral_llm.load_tokenizer(use_fast=True).load_model(
    device_map="auto",
    torch_dtype="auto",
    low_cpu_mem_usage=True
).register_generation_config(
    do_sample=False,
    num_beams=1,
    pad_token_id=mistral_llm._tokenizer.pad_token_id,
    eos_token_id=mistral_llm._tokenizer.eos_token_id
)
mistral_llm.register_prompt_template_fn("mirage_mcqa_no_rag_chat", chat_prompt_template_no_rag)
mistral_llm.register_prompt_template_fn("mirage_mcqa_axiom_rag_chat", chat_prompt_template_with_axioms)

# ideally, we would load from config (TODO: load cfgNode \w yacs or hydra)
tests = QATestHarness(
  Path("./data/MIRAGE/benchmark.json"), 
  Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"), 
  Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json")
).set_shuffle_question_options(True).set_permute_question_options(
  True
).set_retrieval_k(100).set_append_k(10).set_top_k(1).set_use_rag(True).register_retriever(
  hit_retriever
).register_entity_selector(
  entity_selector # type: ignore
).register_llm(
  mistral_llm # type: ignore
)

# single test run:

QATestHarness.set_random_seed(SEED)

tests.set_use_rag(True)
tests.run_multiple(['pubmedqa', 'bioasq', 'mmlu', 'medqa', 'medmcqa'])

In [None]:
from retrievers import (
  BaseRetriever,
  SBERTRetriever,
  HiTRetriever,
  OnTRetriever
)
from pathlib import Path
from math_functools import (
  batch_cosine_similarity,
  batch_poincare_dist_with_adaptive_curv_k,
  batch_poincare_dist_with_adaptive_curv_k_torch,
  entity_subsumption,
  concept_subsumption
)
from llm_utils import (
  MistralLLM,
  BaseEntitySelector,
  SimilarityEntitySelector,
  ApproximateNearestNeighbourEntitySelector,
  SubsumptionEntitySelector,
  chat_prompt_template_no_rag,
  chat_prompt_template_with_axioms
)
from harness_utils import (
  QATestHarness
)

##########
# GLOBALS
##########

# ------------------------------------
# LLM options:
# ------------------------------------
# "mistralai/Mistral-7B-Instruct-v0.1"
# "mistralai/Mistral-7B-Instruct-v0.3"
# "BioMistral/BioMistral-7B"
# ------------------------------------

LLM_MODEL_ID = "BioMistral/BioMistral-7B"
SEED = 42

##################################################
# BOOTSTRAP: encoder, retriever & entity selector
##################################################

common_map = Path("./embeddings/axiom-mappings.json") # *entity mappings
common_verbalisations = Path("./embeddings/axiom-verbalisations.json") # rdfs:label(s) & verbs
embeddings_dir = "./embeddings" # dir for embeddings

# fine-tuned embedding model, for embedding entity mentions
retriever_model_fp = hit_SNOMED25_model_path = Path('./models/HiT-mixed-SNOMED-25/final')

# accepts an entity mention &
# 1. produces an embedding
# 2. measures the `score_fn` agaisnt existing embeddings
# 3. returns a ranked list of entities as tuples: (rank, iri, score, verbalisation)
hit_retriever = HiTRetriever(
  embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_fp=retriever_model_fp,
  score_fn=batch_poincare_dist_with_adaptive_curv_k_torch
)

# provides a ranking for a pool of entities (drawn from multiple multiple mentions for the same question)
entity_selector = ApproximateNearestNeighbourEntitySelector(hit_retriever)

##################
#  BOOTSTRAP: LLM
##################

# initialises a LLM & exposes methods for 
# RAG \w axiom verbalisation-based prompt enrichment
mistral_llm = MistralLLM(LLM_MODEL_ID)

mistral_llm.load_tokenizer(use_fast=True).load_model(
    device_map="auto",
    torch_dtype="auto",
    low_cpu_mem_usage=True
).register_generation_config(
    do_sample=False,
    num_beams=1,
    pad_token_id=mistral_llm._tokenizer.pad_token_id,
    eos_token_id=mistral_llm._tokenizer.eos_token_id
)
mistral_llm.register_prompt_template_fn("mirage_mcqa_no_rag_chat", chat_prompt_template_no_rag)
mistral_llm.register_prompt_template_fn("mirage_mcqa_axiom_rag_chat", chat_prompt_template_with_axioms)

# ideally, we would load from config (TODO: load cfgNode \w yacs or hydra)
tests = QATestHarness(
  Path("./data/MIRAGE/benchmark.json"), 
  Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"), 
  Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json")
).set_shuffle_question_options(True).set_permute_question_options(
  True
).set_retrieval_k(100).set_append_k(10).set_top_k(1).set_use_rag(True).register_retriever(
  hit_retriever
).register_entity_selector(
  entity_selector # type: ignore
).register_llm(
  mistral_llm # type: ignore
)

# single test run:

QATestHarness.set_random_seed(SEED)

tests.set_use_rag(True)
tests.run_multiple(['pubmedqa', 'bioasq', 'mmlu', 'medqa', 'medmcqa'])

In [None]:
from retrievers import (
  BaseRetriever,
  SBERTRetriever,
  HiTRetriever,
  OnTRetriever
)
from pathlib import Path
from math_functools import (
  batch_cosine_similarity,
  batch_poincare_dist_with_adaptive_curv_k,
  batch_poincare_dist_with_adaptive_curv_k_torch,
  entity_subsumption,
  concept_subsumption
)
from llm_utils import (
  MistralLLM,
  BaseEntitySelector,
  SimilarityEntitySelector,
  ApproximateNearestNeighbourEntitySelector,
  SubsumptionEntitySelector,
  chat_prompt_template_no_rag,
  chat_prompt_template_with_axioms
)
from harness_utils import (
  QATestHarness
)
import torch

##########
# GLOBALS
##########

# ------------------------------------
# LLM options:
# ------------------------------------
# "mistralai/Mistral-7B-Instruct-v0.1"
# "mistralai/Mistral-7B-Instruct-v0.3"
# "BioMistral/BioMistral-7B"
# ------------------------------------

LLM_MODEL_ID = "BioMistral/BioMistral-7B"
SEED = 42

##################################################
# BOOTSTRAP: encoder, retriever & entity selector
##################################################

common_map = Path("./embeddings/axiom-mappings.json") # *entity mappings
common_verbalisations = Path("./embeddings/axiom-verbalisations.json") # rdfs:label(s) & verbs
embeddings_dir = "./embeddings" # dir for embeddings

retriever_model_fp = hit_SNOMED25_model_path = Path('./models/HiT-mixed-SNOMED-25/final')
# self,
#         verbalisations_fp: Path,
#         meta_map_fp: Path,
#         embeddings_fp: Path,
#         *,
#         score_fn: Callable | None = None,
#         model_fp: Path | None = None,
#         model_str: str | None = None,
#         backend: Backend = "numpy",            # <— control CPU vs GPU
#         resident: bool = True,                 # <— copy once & keep on device
#         eps: float = 1e-7
# Poincaré (HiT/OnT) on GPU Torch, resident bank:
hit_retriever = HiTRetriever(
    verbalisations_fp=common_verbalisations,
    meta_map_fp=common_map,
    embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
    model_fp=retriever_model_fp,
    backend="cupy",
    resident=True
)

hit_retriever._set_curvature_from_model()

# hit_retriever.register_local_model(retriever_model_fp)
# hit_retriever._set_curvature_from_model()
# print(hit_retriever._curv_k)

# results = ret.retrieve("what is ...", top_k=10)

# # SBERT cosine on CPU NumPy:
# sret = SBERTRetriever(
#     Path("verbalisations.json"),
#     Path("meta_map.json"),
#     Path("sbert_embeddings.npy"),
#     backend="numpy",
# )
# sret.use_cosine()
# results2 = sret.retrieve("a query", top_k=20, reverse_candidate_scores=True)  # if you want highest cosine

# fine-tuned embedding model, for embedding entity mentions
# retriever_model_fp = hit_SNOMED25_model_path = Path('./models/HiT-mixed-SNOMED-25/final')

# accepts an entity mention &
# 1. produces an embedding
# 2. measures the `score_fn` agaisnt existing embeddings
# 3. returns a ranked list of entities as tuples: (rank, iri, score, verbalisation)
# hit_retriever = HiTRetriever(
#   embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
#   meta_map_fp=common_map,
#   verbalisations_fp=common_verbalisations,
#   model_fp=retriever_model_fp,
#   score_fn=batch_poincare_dist_with_adaptive_curv_k_torch
# )

# provides a ranking for a pool of entities (drawn from multiple multiple mentions for the same question)
entity_selector = ApproximateNearestNeighbourEntitySelector(hit_retriever)

##################
#  BOOTSTRAP: LLM
##################

# initialises a LLM & exposes methods for 
# RAG \w axiom verbalisation-based prompt enrichment
mistral_llm = MistralLLM(LLM_MODEL_ID)

mistral_llm.load_tokenizer(use_fast=True).load_model(
    device_map="auto",
    torch_dtype="auto",
    low_cpu_mem_usage=True
).register_generation_config(
    do_sample=False,
    num_beams=1,
    pad_token_id=mistral_llm._tokenizer.pad_token_id,
    eos_token_id=mistral_llm._tokenizer.eos_token_id
)
mistral_llm.register_prompt_template_fn("mirage_mcqa_no_rag_chat", chat_prompt_template_no_rag)
mistral_llm.register_prompt_template_fn("mirage_mcqa_axiom_rag_chat", chat_prompt_template_with_axioms)

# ideally, we would load from config (TODO: load cfgNode \w yacs or hydra)
tests = QATestHarness(
  Path("./data/MIRAGE/benchmark.json"), 
  Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"), 
  Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json")
).set_shuffle_question_options(True).set_permute_question_options(
  True
).set_retrieval_k(100).set_append_k(10).set_top_k(1).set_use_rag(True).register_retriever(
  hit_retriever
).register_entity_selector(
  entity_selector # type: ignore
).register_llm(
  mistral_llm # type: ignore
)

# single test run:

QATestHarness.set_random_seed(SEED)

tests.set_use_rag(True)
tests.run_multiple(['pubmedqa', 'bioasq', 'mmlu', 'medqa', 'medmcqa'])

Using RAG: True
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1411.10it/s]


Processing pubmedqa ...


  0%|          | 0/500 [00:00<?, ?it/s]


ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()