In [None]:
import torch
from pathlib import Path
from data_utils import (
    load_json,
    xs_of_all_questions,
    merge_entity_mentions,
    set_seed
)
from llm_utils_improvements import (
    MistralLLM,
    prompt_template_no_rag,
    prompt_template_with_axioms,
    permute_mcqa_options,
    SubsumptionEntitySelector
)
from retrievers import (
   HiTRetriever,
)
from math_functools import (
   entity_subsumption
)
from tqdm import tqdm

set_seed(42)

######
# DATA PREP
######

mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))

######
# DATASET PARTITION
# *(add all five datasets the list if you wish to run this script agaisnt the entire benchmark)*
# select from: ['medqa', 'medmcqa', 'pubmedqa', 'bioasq', 'mmlu']
# ----
#   allowable_datasets=['medqa', 'medmcqa', 'pubmedqa', 'bioasq', 'mmlu']
# ----
######

allowable_datasets=['medqa', 'medmcqa', 'pubmedqa', 'bioasq', 'mmlu']
mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets)
mirage_questions_with_entity_mentions = merge_entity_mentions(mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets)

######
# MISTRAL LLM VERSION !! IMPORTANT !!
# (currently supports Mistral-7B-Instruct-v0.1, v0.3 & BioMistral)
# select from:
#   * "mistralai/Mistral-7B-Instruct-v0.3"
#   * "BioMistral/BioMistral-7B"
######

mistral_llm_version = "BioMistral/BioMistral-7B"

######
# SCRIPT STARTS
######

mistral_lm = MistralLLM(mistral_llm_version)

mistral_lm.load_tokenizer(
  use_fast=True
).load_model(
  device_map="auto",
  torch_dtype="auto",
  low_cpu_mem_usage=True
).register_generation_config(
  do_sample=False,
  num_beams=1,
  pad_token_id=mistral_lm._tokenizer.pad_token_id,
  eos_token_id=mistral_lm._tokenizer.eos_token_id
)

mistral_lm.register_prompt_template_fn("mirage_mcqa_no_rag", prompt_template_no_rag)
mistral_lm.register_prompt_template_fn("mirage_mcqa_axiom_rag", prompt_template_with_axioms)

print("Loaded!")

Processing medqa ... 


100%|██████████| 1273/1273 [00:00<00:00, 17236.22it/s]


Processing medmcqa ... 


100%|██████████| 4183/4183 [00:02<00:00, 1750.62it/s]


Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1456.03it/s]


Processing bioasq ... 


100%|██████████| 618/618 [00:00<00:00, 1450.63it/s]


Processing mmlu ... 


100%|██████████| 1089/1089 [00:00<00:00, 1113.14it/s]


Loaded!


In [2]:
embeddings_dir = "./embeddings"
common_map = Path("./embeddings/axiom-mappings.json")
common_verbalisations = Path("./embeddings/axiom-verbalisations.json")
hit_SNOMED25_model_path = Path('./models/HiT-mixed-SNOMED-25/final')

hit_ret_snomed_25_w_ent_sub = HiTRetriever(
  embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
  meta_map_fp=common_map,
  verbalisations_fp=common_verbalisations,
  model_fp=hit_SNOMED25_model_path,
  score_fn=entity_subsumption
)

In [3]:
dataset_name = "pubmedqa"
mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
mirage_questions_with_entity_mentions = merge_entity_mentions(mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name])

Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1575.27it/s]


In [4]:
snomed_concept_information: dict = load_json(Path("./data/snomed_axioms.json"))
entity_selector = SubsumptionEntitySelector(hit_ret_snomed_25_w_ent_sub)

In [5]:
example_question_index = 225

example_question = mirage_questions_with_entity_mentions[example_question_index]

print("Question Dict: ", example_question)
print("Question Text:", example_question['question'])
print("Question Options:", example_question['options'])
print("Question Answer:", example_question['answer'])

Question Dict:  {'question': 'Does head positioning influence anterior chamber depth in pseudoexfoliation syndrome?', 'options': {'A': 'yes', 'B': 'no', 'C': 'maybe'}, 'answer': 'A', 'PMID': [10877371], 'entities': [{'entity_literal': 'head', 'start_position': 5, 'end_position': 9, 'entity_type': 'PATHOLOGICAL_FORMATION'}, {'entity_literal': 'anterior chamber', 'start_position': 32, 'end_position': 48, 'entity_type': 'MULTI_TISSUE_STRUCTURE'}, {'entity_literal': 'influence anterior chamber depth', 'start_position': 22, 'end_position': 54, 'entity_type': 'HEAD'}]}
Question Text: Does head positioning influence anterior chamber depth in pseudoexfoliation syndrome?
Question Options: {'A': 'yes', 'B': 'no', 'C': 'maybe'}
Question Answer: A


In [None]:
from copy import deepcopy

prepared_questions = []

for question_index, question in enumerate(tqdm(mirage_questions_with_entity_mentions)):

  current_clean_question = deepcopy(mirage_questions_with_entity_mentions[question_index])

  entity_selector.encode_and_rank_candidates(current_clean_question['entities'])
  entities_for_rag = entity_selector.get_top_candidates(top_k=1)
  entity_mention_iris_for_rag = []
  for mention in entities_for_rag:
    entity_mention_iris_for_rag.append(mention[1])

  # obtain axiom verbalisations (or produce `concept cards`) for each IRI for prompt enrichment

  axiom_verbalisations = []

  for iri in entity_mention_iris_for_rag:
    if iri == "owl:Thing": # Thing is a Thing, not sure how this ended up in the IRI field..?
      continue
    label: str = snomed_concept_information[iri]['label']
    subclass_axioms: list[str] = snomed_concept_information[iri]['verbalization']['subclass_of']
    equiv_axioms: list[str] = snomed_concept_information[iri]['verbalization']['equivalent_to']
    for idx, axiom in enumerate(subclass_axioms):
      axiom_verbalisations.append(f"{label}: {subclass_axioms[idx]}")
    for idx, axiom in enumerate(equiv_axioms):
      axiom_verbalisations.append(f"{label}: {equiv_axioms[idx]}")

  # bind the axioms verbalisations to the question object for prompt injection

  current_clean_question['axioms'] = axiom_verbalisations

  prepared_questions.append(current_clean_question)

  if question_index > 10:
    print("STOPPING EARLY!")
    break

  parent_emb_t = torch.Tensor(parent_emd)
  2%|▏         | 11/500 [00:13<10:16,  1.26s/it]

STOPPING EARLY!





In [7]:
prepared_questions

[{'question': 'Is anorectal endosonography valuable in dyschesia?',
  'options': {'A': 'yes', 'B': 'no', 'C': 'maybe'},
  'answer': 'A',
  'PMID': [12377809],
  'entities': [{'entity_literal': 'anorectal',
    'start_position': 3,
    'end_position': 12,
    'entity_type': 'PATHOLOGICAL_FORMATION'},
   {'entity_literal': 'anorectal endosonography valuable in dyschesia',
    'start_position': 3,
    'end_position': 49,
    'entity_type': 'HEAD'}],
  'axioms': ['Dose form intended site: is a type of Qualifier value']},
 {'question': 'Is there a connection between sublingual varices and hypertension?',
  'options': {'A': 'yes', 'B': 'no', 'C': 'maybe'},
  'answer': 'A',
  'PMID': [26163474],
  'entities': [{'entity_literal': 'sublingual',
    'start_position': 30,
    'end_position': 40,
    'entity_type': 'ORGAN'},
   {'entity_literal': 'a connection',
    'start_position': 9,
    'end_position': 21,
    'entity_type': 'HEAD'}],
  'axioms': ['Sublingual: is a type of Sub-location']},
 {'

In [8]:
print(prepared_questions[0]['axioms'])

['Dose form intended site: is a type of Qualifier value']


In [9]:
new_text = mistral_lm.generate_mc_letter('mirage_mcqa_axiom_rag', prepared_questions[0])
print(new_text)

A


In [10]:
for q_idx, question in enumerate(prepared_questions):
  ans = mistral_lm.generate_mc_letter('mirage_mcqa_axiom_rag', question)
  print(f"Answer to question {q_idx}: {ans}")
  print("-" * 20)
  print(f"The answer is: {question['answer']}")
  if len(question['axioms']) > 0:
    print(f"(included axiom: {question['axioms'][0]})")
  else:
    print(f"(NO AXIOMS)")
  print("-" * 20)
  

Answer to question 0: A
--------------------
The answer is: A
(included axiom: Dose form intended site: is a type of Qualifier value)
--------------------
Answer to question 1: A
--------------------
The answer is: A
(included axiom: Sublingual: is a type of Sub-location)
--------------------
Answer to question 2: A
--------------------
The answer is: A
(NO AXIOMS)
--------------------
Answer to question 3: A
--------------------
The answer is: A
(included axiom: Patient: is a type of Person in healthcare environment)
--------------------
Answer to question 4: A
--------------------
The answer is: A
(NO AXIOMS)
--------------------
Answer to question 5: A
--------------------
The answer is: A
(included axiom: Intracranial: is a type of Intra-location)
--------------------
Answer to question 6: A
--------------------
The answer is: A
(included axiom: Woman: is a type of Adult)
--------------------
Answer to question 7: A
--------------------
The answer is: A
(NO AXIOMS)
----------------

In [11]:
import random
from copy import deepcopy

# for question_i in mirage_questions_with_entity_mentions:
#   print(question_i['answer'])

def jumble_dict_list(dict_list: list[dict]) -> list[dict]:
  shuffled_list = deepcopy(dict_list[:])
  random.shuffle(shuffled_list)
  return shuffled_list

print("-" * 72)
print("JUMBLING...")
print("-" * 72)

jumbled_questions = jumble_dict_list(mirage_questions_with_entity_mentions)

# for question_i in jumbled_questions:
#   print(question_i['answer'])

------------------------------------------------------------------------
JUMBLING...
------------------------------------------------------------------------


In [12]:
jumbled_questions

[{'question': 'Are the long-term results of the transanal pull-through equal to those of the transabdominal pull-through?',
  'options': {'A': 'yes', 'B': 'no', 'C': 'maybe'},
  'answer': 'B',
  'PMID': [17208539],
  'entities': [{'entity_literal': 'the long-term results of the transanal pull-through equal to those of the transabdominal pull-through',
    'start_position': 4,
    'end_position': 105,
    'entity_type': 'HEAD'}]},
 {'question': "Midwives' competence: is it affected by working in a rural location?",
  'options': {'A': 'yes', 'B': 'no', 'C': 'maybe'},
  'answer': 'C',
  'PMID': [17691856],
  'entities': [{'entity_literal': "Midwives' competence:",
    'start_position': 0,
    'end_position': 21,
    'entity_type': 'HEAD'}]},
 {'question': 'Is hypoalbuminemia an independent prognostic factor in patients with gastric cancer?',
  'options': {'A': 'yes', 'B': 'no', 'C': 'maybe'},
  'answer': 'B',
  'PMID': [20602101],
  'entities': [{'entity_literal': 'patients',
    'start_p

In [13]:
from copy import deepcopy

permute_options = True
shuffle_options = True
jumbled_prepared_questions = []

for question_index, question in enumerate(tqdm(jumbled_questions)):

  current_clean_question = deepcopy(jumbled_questions[question_index])

  entity_selector.encode_and_rank_candidates(current_clean_question['entities'])
  entities_for_rag = entity_selector.get_top_candidates(top_k=1)
  entity_mention_iris_for_rag = []
  for mention in entities_for_rag:
    entity_mention_iris_for_rag.append(mention[1])

  # obtain axiom verbalisations (or produce `concept cards`) for each IRI for prompt enrichment

  axiom_verbalisations = []

  for iri in entity_mention_iris_for_rag:
    if iri == "owl:Thing": # Thing is a Thing, not sure how this ended up in the IRI field..?
      continue
    label: str = snomed_concept_information[iri]['label']
    subclass_axioms: list[str] = snomed_concept_information[iri]['verbalization']['subclass_of']
    equiv_axioms: list[str] = snomed_concept_information[iri]['verbalization']['equivalent_to']
    for idx, axiom in enumerate(subclass_axioms):
      axiom_verbalisations.append(f"{label}: {subclass_axioms[idx]}")
    for idx, axiom in enumerate(equiv_axioms):
      axiom_verbalisations.append(f"{label}: {equiv_axioms[idx]}")

  # bind the axioms verbalisations to the question object for prompt injection

  current_clean_question['axioms'] = axiom_verbalisations

  if permute_options:
    randomised_options, new_answer_key = permute_mcqa_options(question['options'], question['answer'])
    del current_clean_question['options']
    del current_clean_question['answer']
    current_clean_question['options'] = randomised_options
    current_clean_question['answer'] = new_answer_key

  if shuffle_options:
    question['shuffle'] = True

  jumbled_prepared_questions.append(current_clean_question)

  if question_index > 10:
    print("STOPPING EARLY!")
    break

  2%|▏         | 11/500 [00:09<06:44,  1.21it/s]

STOPPING EARLY!





In [14]:
for q_idx, question in enumerate(jumbled_prepared_questions):
  ans = mistral_lm.generate_mc_letter('mirage_mcqa_axiom_rag', question)
  print(f"Answer to question {q_idx}: {ans}")
  print("-" * 20)
  print(f"The answer is: {question['answer']}")
  if len(question['axioms']) > 0:
    print(f"(included axiom: {question['axioms'][0]})")
  else:
    print(f"(NO AXIOMS)")
  print("-" * 20)
  

Answer to question 0: C
--------------------
The answer is: B
(included axiom: Procedure: is a type of SNOMED CT Concept)
--------------------
Answer to question 1: A
--------------------
The answer is: C
(included axiom: Staging and scales: is a type of SNOMED CT Concept)
--------------------
Answer to question 2: B
--------------------
The answer is: C
(included axiom: Patient: is a type of Person in healthcare environment)
--------------------
Answer to question 3: C
--------------------
The answer is: C
(included axiom: Neutrophil inclusion: is a type of Cytoplasmic inclusion)
--------------------
Answer to question 4: B
--------------------
The answer is: B
(included axiom: Food: is a type of Edible substance)
--------------------
Answer to question 5: A
--------------------
The answer is: A
(included axiom: Disease: is a type of Clinical finding)
--------------------
Answer to question 6: C
--------------------
The answer is: A
(NO AXIOMS)
--------------------
Answer to question 

In [15]:
jumbled_questions

[{'question': 'Are the long-term results of the transanal pull-through equal to those of the transabdominal pull-through?',
  'options': {'A': 'yes', 'B': 'no', 'C': 'maybe'},
  'answer': 'B',
  'PMID': [17208539],
  'entities': [{'entity_literal': 'the long-term results of the transanal pull-through equal to those of the transabdominal pull-through',
    'start_position': 4,
    'end_position': 105,
    'entity_type': 'HEAD'}],
  'shuffle': True},
 {'question': "Midwives' competence: is it affected by working in a rural location?",
  'options': {'A': 'yes', 'B': 'no', 'C': 'maybe'},
  'answer': 'C',
  'PMID': [17691856],
  'entities': [{'entity_literal': "Midwives' competence:",
    'start_position': 0,
    'end_position': 21,
    'entity_type': 'HEAD'}],
  'shuffle': True},
 {'question': 'Is hypoalbuminemia an independent prognostic factor in patients with gastric cancer?',
  'options': {'A': 'yes', 'B': 'no', 'C': 'maybe'},
  'answer': 'B',
  'PMID': [20602101],
  'entities': [{'ent

In [16]:
jumbled_prepared_questions

[{'question': 'Are the long-term results of the transanal pull-through equal to those of the transabdominal pull-through?',
  'PMID': [17208539],
  'entities': [{'entity_literal': 'the long-term results of the transanal pull-through equal to those of the transabdominal pull-through',
    'start_position': 4,
    'end_position': 105,
    'entity_type': 'HEAD'}],
  'axioms': ['Procedure: is a type of SNOMED CT Concept'],
  'options': {'A': 'maybe', 'B': 'no', 'C': 'yes'},
  'answer': 'B'},
 {'question': "Midwives' competence: is it affected by working in a rural location?",
  'PMID': [17691856],
  'entities': [{'entity_literal': "Midwives' competence:",
    'start_position': 0,
    'end_position': 21,
    'entity_type': 'HEAD'}],
  'axioms': ['Staging and scales: is a type of SNOMED CT Concept'],
  'options': {'A': 'yes', 'B': 'no', 'C': 'maybe'},
  'answer': 'C'},
 {'question': 'Is hypoalbuminemia an independent prognostic factor in patients with gastric cancer?',
  'PMID': [20602101],


In [17]:
for q_idx, question in enumerate(jumbled_prepared_questions):
  ans = mistral_lm.generate_mc_letter('mirage_mcqa_axiom_rag', question)
  print(f"Answer to question {q_idx}: {ans}")
  print(f"Correct Answer: {jumbled_prepared_questions[q_idx]['answer']}")
  print("-" * 20)
  print("\n")
  # print("-" * 20)
  # print(f"The answer is: {question['answer']}")
  # if len(question['axioms']) > 0:
  #   print(f"(included axiom: {question['axioms'][0]})")
  # else:
  #   print(f"(NO AXIOMS)")
  # print("-" * 20)
  

Answer to question 0: C
Correct Answer: B
--------------------


Answer to question 1: A
Correct Answer: C
--------------------


Answer to question 2: B
Correct Answer: C
--------------------


Answer to question 3: C
Correct Answer: C
--------------------


Answer to question 4: B
Correct Answer: B
--------------------


Answer to question 5: A
Correct Answer: A
--------------------


Answer to question 6: C
Correct Answer: A
--------------------


Answer to question 7: A
Correct Answer: A
--------------------


Answer to question 8: A
Correct Answer: C
--------------------


Answer to question 9: C
Correct Answer: A
--------------------


Answer to question 10: C
Correct Answer: C
--------------------


Answer to question 11: A
Correct Answer: A
--------------------




In [18]:
correct = 0
incorrect = 0

In [None]:
import random

def shuffle_dict(d: dict, *, seed: int | None = None) -> dict:
  rng = random.Random(seed) if seed is not None else random
  keys = list(d.keys())
  rng.shuffle(keys)
  return {k: d[k] for k in keys}

# Example usage
data = {
    "A": "apple",
    "B": "banana",
    "C": "cherry",
    "D": "date"
}

shuffled = shuffle_dict(data)
print(shuffled)  # Keys will appear in random order
print(shuffled.keys())

{'A': 'apple', 'C': 'cherry', 'B': 'banana', 'D': 'date'}
dict_keys(['A', 'C', 'B', 'D'])


In [None]:
from llm_utils_improvements import prompt_template_with_axioms

idx = 7
another_example_question = jumbled_prepared_questions[idx]
print(another_example_question)
print("\n\n")
print(prompt_template_with_axioms(another_example_question['question'], another_example_question['options'], another_example_question['axioms'], shuffle=True))

{'question': 'Can magnetic resonance imaging accurately predict concordant pain provocation during provocative disc injection?', 'PMID': [19430778], 'entities': [{'entity_literal': 'magnetic resonance imaging', 'start_position': 4, 'end_position': 30, 'entity_type': 'HEAD'}], 'axioms': ['MRI: defined as Procedure with method: Magnetic resonance imaging.'], 'options': {'A': 'yes', 'B': 'no', 'C': 'maybe'}, 'answer': 'A'}



You are a helpful medical expert, your task is to answer a multi-choice medical question.
Return only the letter of the best answer.

Helpful context:
MRI: defined as Procedure with method: Magnetic resonance imaging.

Here is the question:
Can magnetic resonance imaging accurately predict concordant pain provocation during provocative disc injection?

Here are the potential choices:
A. yes
C. maybe
B. no

Answer (letter only): 


# SETUP

In [1]:
import torch
from pathlib import Path
from data_utils import load_json, xs_of_all_questions, merge_entity_mentions
from llm_utils_improvements import (
    MistralLLM,
    prompt_template_no_rag, prompt_template_with_axioms,
    chat_prompt_template_no_rag, chat_prompt_template_with_axioms,
    BaseEntitySelector,
    SubsumptionEntitySelector,
    ApproximateNearestNeighbourEntitySelector,
    SimilarityEntitySelector
)
from retrievers import HiTRetriever, SBERTRetriever
from math_functools import entity_subsumption, batch_poincare_dist_with_adaptive_curv_k, batch_cosine_similarity
from tqdm import tqdm

In [2]:
#### Config (quick A/B toggles)
USE_RAG = True
USE_CHAT_TEMPLATES = True          # << switch this to False to A/B against the old path
SHUFFLE_DISPLAY_ORDER = True       # shuffles how choices are displayed (letters stay tied to the same text)
FORCE_DTYPE_AUTO = True            # set True if bf16 isn’t supported

In [None]:
# Data
mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))

allowable_datasets = ['medqa', 'medmcqa', 'pubmedqa', 'bioasq', 'mmlu']
mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets)
mirage_questions_with_entity_mentions = merge_entity_mentions(
    mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets
)

# Model
mistral_llm_version = "BioMistral/BioMistral-7B"
mistral_lm = MistralLLM(mistral_llm_version)

dtype = "auto" if FORCE_DTYPE_AUTO else torch.bfloat16
mistral_lm.load_tokenizer(use_fast=True).load_model(
    device_map="auto",
    torch_dtype=dtype,
    low_cpu_mem_usage=True
).register_generation_config(
    do_sample=False,
    num_beams=1,
    pad_token_id=mistral_lm._tokenizer.pad_token_id,
    eos_token_id=mistral_lm._tokenizer.eos_token_id
)

# Register BOTH string and chat templates so we can A/B with a flag
mistral_lm.register_prompt_template_fn("mirage_mcqa_no_rag", prompt_template_no_rag)
mistral_lm.register_prompt_template_fn("mirage_mcqa_axiom_rag", prompt_template_with_axioms)
mistral_lm.register_prompt_template_fn("mirage_mcqa_no_rag_chat", chat_prompt_template_no_rag)
mistral_lm.register_prompt_template_fn("mirage_mcqa_axiom_rag_chat", chat_prompt_template_with_axioms)

print("Loaded!")
print(f"LLM: {mistral_llm_version} | RAG: {USE_RAG} | Chat templates: {USE_CHAT_TEMPLATES}")

Processing medqa ... 


100%|██████████| 1273/1273 [00:00<00:00, 17109.35it/s]


Processing medmcqa ... 


100%|██████████| 4183/4183 [00:04<00:00, 941.38it/s] 


Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 704.26it/s]


Processing bioasq ... 


100%|██████████| 618/618 [00:00<00:00, 883.30it/s] 


Processing mmlu ... 


100%|██████████| 1089/1089 [00:01<00:00, 790.03it/s] 


Loaded!
LLM: BioMistral/BioMistral-7B | RAG: True | Chat templates: True


In [None]:
# HiT retriever (for RAG)
embeddings_dir = "./embeddings"
common_map = Path("./embeddings/axiom-mappings.json")
common_verbalisations = Path("./embeddings/axiom-verbalisations.json")
hit_SNOMED25_model_path = Path('./models/HiT-mixed-SNOMED-25/final')

hit_ret_snomed_25_w_ent_sub = HiTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=hit_SNOMED25_model_path,
    score_fn=entity_subsumption
)

hit_ret_snomed_25_w_hyp_dist = HiTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=hit_SNOMED25_model_path,
    score_fn=batch_poincare_dist_with_adaptive_curv_k
)

sbert_ret_cosine_sim = SBERTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/sbert-plm-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_str="all-MiniLM-L12-v2",
    score_fn=batch_cosine_similarity
)

In [5]:
import re

snomed_concept_information: dict = load_json(Path("./data/snomed_axioms.json"))

COMMON_STOPWORDS = {
  "the","a","an","and","or","of","to","in","on","for","with","without","from",
  "by","at","as","is","are","was","were","be","been","being","that","which",
  "this","these","those","it","its","their","there","then","than","such"
}

SNOMED_CONCEPTS_NEAR_TOP = {
  "clinical finding","finding","disease","disorder","procedure","event", "body structure", "organism",
  "substance", "situation with explicit context","qualifier value","morphologic abnormality"
}

def lexical_repr_without_stopwords(s: str) -> set[str]:
  """set difference between the input string and 'common stopwords'"""
  return set(re.findall(r"[a-z]+", s.lower())) - COMMON_STOPWORDS

def lexical_overlap_score(input: str, reference_set: set[str]) -> int:
  """set (lexical) intersection between input_string and the reference vocab (typically, the question)"""
  return len(lexical_repr_without_stopwords(input) & reference_set)

def is_broad_concept(rdfs_label: str) -> bool:
  """basic lexical operation to check whether an rdfs_label is a high level SNOMED CT concept"""
  return (rdfs_label.split(" (", 1)[0]).lower() in SNOMED_CONCEPTS_NEAR_TOP

def select_axioms_for_prompt(axioms: list[str], question_text: str, *, min_overlap: int = 1) -> list[str]:
  """Return axioms that exceed the lexical overlap of `min_overlap` with the `question_text`"""
  clean_question_text = lexical_repr_without_stopwords(question_text)
  retained_axioms = []
  for this_axiom in axioms:
    lex_score = lexical_overlap_score(this_axiom, clean_question_text)
    if lex_score >= min_overlap:
      retained_axioms.append(this_axiom)
  return retained_axioms

def get_axiom_verbalisations(iris: list[str], *, max_verbalisations: int = 3, question_text: str, min_overlap: int = 1) -> list[str]:
  returnable_verbalisations = []
  for iri in iris: # skip any 'owl:Thing'(s)
    if iri == "owl:Thing": # Though these shouldn't be in IRI
      continue
    snomed_concept = snomed_concept_information[iri]
    label = snomed_concept['label'] # rdfs_label
    if is_broad_concept(label):
      continue
    verbalisations = snomed_concept['verbalization']
    subclass_axioms = verbalisations['subclass_of'][:max_verbalisations]
    equiv_axioms = verbalisations['equivalent_to'][:max_verbalisations]
    # this is pretty horrible, TODO: make this a little less, difficult to read ...
    complete_subclass_verbalisations = [f"{label} {verb_str}" for verb_str in subclass_axioms] if subclass_axioms is not [] else []
    complete_equiv_verbalisations = [f"{label} {verb_str}" for verb_str in equiv_axioms] if equiv_axioms is not [] else []
    returnable_verbalisations = [*complete_subclass_verbalisations, *complete_equiv_verbalisations]
  return select_axioms_for_prompt(
    returnable_verbalisations,
    question_text,
    min_overlap=min_overlap
  )

def per_question_seed(dataset: str, q: dict, idx: int) -> int:
    qid = q.get("id", idx)
    return (hash((dataset, str(qid))) & 0xFFFFFFFF)

# MANUAL

In [6]:
def count_predictions(incorrect_questions):
  selected_options = set()
  predictions = []
  for response in incorrect_questions:
    selected_options.add(response['pred'])
    predictions.append(response['pred'])
  counts = {}
  for selected_option in selected_options:
    counts[selected_option] = 0
  for prediction in predictions:
    counts[prediction] += 1
  return counts

In [7]:
def count_actual_answers(incorrect_questions):
  answer_set = set()
  answers = []
  for response in incorrect_questions:
    answer_set.add(response['question']['answer'])
    answers.append(response['question']['answer'])
  counts = {}
  for possible_answer in answer_set:
    counts[possible_answer] = 0
  for answer in answers:
    counts[answer] += 1
  return counts

In [8]:
#### Config (quick A/B toggles)
USE_RAG = False
USE_CHAT_TEMPLATES = True          # << switch this to False to A/B against the old path
SHUFFLE_DISPLAY_ORDER = True       # shuffles how choices are displayed (letters stay tied to the same text)
FORCE_DTYPE_AUTO = True            # set True if bf16 isn’t supported

In [None]:
USE_RAG = False
SHUFFLE_DISPLAY_ORDER = True
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

datasets_within_mirage_to_process = ['pubmedqa']

for dataset_name in datasets_within_mirage_to_process:
    
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
      mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)
    template_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"

    for question_index, question in enumerate(tqdm(mirage_questions_with_entity_mentions)):

      question['axioms'] = []

      if USE_RAG:
        entity_selector.encode_and_rank_candidates(question['entities'], ret_k=RET_K, append_k_per_entity=APPEND_K)
        top_candidate = entity_selector._all_mention_results[0] if len(entity_selector._all_mention_results) > 0 else []
        if top_candidate:
          rank, iri, score, label = top_candidate
          axioms = get_axiom_verbalisations(
            [iri],
            max_verbalisations=5,
            question_text=question['question'],
            min_overlap=2
          )
          question['axioms'] = axioms if axioms else []
        # end: if
      # end: if

      shuffle_seed = per_question_seed(dataset_name, question, question_index)
      
      question['shuffle'] = SHUFFLE_DISPLAY_ORDER
      question['seed'] = shuffle_seed

      response_letter = mistral_lm.generate_mcqa_letter(
        "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat",
        question,
        max_new_tokens=1
      )

      if response_letter == question['answer']:
        correct_questions += 1
      else:
        incorrect_questions += 1

    total = len(mirage_questions_with_entity_mentions)
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")


Processing pubmedqa ...
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1364.56it/s]
100%|██████████| 500/500 [02:03<00:00,  4.05it/s]

Total correct:   196
Total incorrect: 304
Accuracy:        39.2%


Incorrect Answer Distribution: {}
Actual Answer Distribution: {}





In [None]:
USE_RAG = False
USE_CHAT_TEMPLATES = True     
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

datasets_within_mirage_to_process = ['bioasq']

for dataset_name in datasets_within_mirage_to_process:
    
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
      mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)
    template_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"

    for question_index, question in enumerate(tqdm(mirage_questions_with_entity_mentions)):

      question['axioms'] = []

      if USE_RAG:
        entity_selector.encode_and_rank_candidates(question['entities'], ret_k=RET_K, append_k_per_entity=APPEND_K)
        top_candidate = entity_selector._all_mention_results[0] if len(entity_selector._all_mention_results) > 0 else []
        if top_candidate:
          rank, iri, score, label = top_candidate
          axioms = get_axiom_verbalisations(
            [iri],
            max_verbalisations=5,
            question_text=question['question'],
            min_overlap=2
          )
          question['axioms'] = axioms if axioms else []
        # end: if
      # end: if

      shuffle_seed = per_question_seed(dataset_name, question, question_index)
      
      question['shuffle'] = SHUFFLE_DISPLAY_ORDER
      question['seed'] = shuffle_seed

      response_letter = mistral_lm.generate_mcqa_letter(
        "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat",
        question,
        max_new_tokens=1
      )

      if response_letter == question['answer']:
        correct_questions += 1
      else:
        incorrect_questions += 1

    total = len(mirage_questions_with_entity_mentions)
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")


Processing bioasq ...
Processing bioasq ... 


100%|██████████| 618/618 [00:00<00:00, 1261.72it/s]
100%|██████████| 618/618 [02:25<00:00,  4.23it/s]

Total correct:   350
Total incorrect: 268
Accuracy:        56.63%







In [None]:
USE_RAG = False
SHUFFLE_DISPLAY_ORDER = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

datasets_within_mirage_to_process = ['mmlu']

for dataset_name in datasets_within_mirage_to_process:
    
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
      mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)
    template_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"

    for question_index, question in enumerate(tqdm(mirage_questions_with_entity_mentions)):

      question['axioms'] = []

      if USE_RAG:
        entity_selector.encode_and_rank_candidates(question['entities'], ret_k=RET_K, append_k_per_entity=APPEND_K)
        top_candidate = entity_selector._all_mention_results[0] if len(entity_selector._all_mention_results) > 0 else []
        if top_candidate:
          rank, iri, score, label = top_candidate
          axioms = get_axiom_verbalisations(
            [iri],
            max_verbalisations=5,
            question_text=question['question'],
            min_overlap=2
          )
          question['axioms'] = axioms if axioms else []
        # end: if
      # end: if

      shuffle_seed = per_question_seed(dataset_name, question, question_index)
      
      question['shuffle'] = SHUFFLE_DISPLAY_ORDER
      question['seed'] = shuffle_seed

      response_letter = mistral_lm.generate_mcqa_letter(
        "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat",
        question,
        max_new_tokens=1
      )

      if response_letter == question['answer']:
        correct_questions += 1
      else:
        incorrect_questions += 1

    total = len(mirage_questions_with_entity_mentions)
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")


Processing mmlu ...
Processing mmlu ... 


100%|██████████| 1089/1089 [00:02<00:00, 540.03it/s]
100%|██████████| 1089/1089 [04:28<00:00,  4.05it/s]

Total correct:   560
Total incorrect: 529
Accuracy:        51.42%







In [None]:
import torch
from pathlib import Path
from data_utils import load_json, xs_of_all_questions, merge_entity_mentions
from llm_utils_improvements import (
    MistralLLM,
    prompt_template_no_rag, prompt_template_with_axioms,
    chat_prompt_template_no_rag, chat_prompt_template_with_axioms,
    SubsumptionEntitySelector,
    ApproximateNearestNeighbourEntitySelector,
    SimilarityEntitySelector
)
from retrievers import HiTRetriever, SBERTRetriever
from math_functools import entity_subsumption, batch_poincare_dist_with_adaptive_curv_k, batch_cosine_similarity
from tqdm import tqdm

USE_RAG = True
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1       

# HiT retriever (for RAG)
embeddings_dir = "./embeddings"
common_map = Path("./embeddings/axiom-mappings.json")
common_verbalisations = Path("./embeddings/axiom-verbalisations.json")
hit_SNOMED25_model_path = Path('./models/HiT-mixed-SNOMED-25/final')

sbert_ret_cosine_sim = SBERTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/sbert-plm-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=hit_SNOMED25_model_path,
    score_fn=batch_cosine_similarity
)

datasets_within_mirage_to_process = ['pubmedqa']

for dataset_name in datasets_within_mirage_to_process:
    
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
      mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)
    template_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"

    for question_index, question in enumerate(tqdm(mirage_questions_with_entity_mentions)):

      question['axioms'] = []

      if USE_RAG:
        entity_selector.encode_and_rank_candidates(question['entities'], ret_k=RET_K, append_k_per_entity=APPEND_K)
        top_candidate = entity_selector._all_mention_results[0] if len(entity_selector._all_mention_results) > 0 else []
        if top_candidate:
          rank, iri, score, label = top_candidate
          axioms = get_axiom_verbalisations(
            [iri],
            max_verbalisations=5,
            question_text=question['question'],
            min_overlap=2
          )
          question['axioms'] = axioms if axioms else []
        # end: if
      # end: if

      shuffle_seed = per_question_seed(dataset_name, question, question_index)
      
      question['shuffle'] = SHUFFLE_DISPLAY_ORDER
      question['seed'] = shuffle_seed

      response_letter = mistral_lm.generate_mcqa_letter(
        "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat",
        question,
        max_new_tokens=1
      )

      if response_letter == question['answer']:
        correct_questions += 1
      else:
        incorrect_questions += 1

    total = len(mirage_questions_with_entity_mentions)
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")


Processing pubmedqa ...
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1020.61it/s]
100%|██████████| 500/500 [02:14<00:00,  3.73it/s]

Total correct:   217
Total incorrect: 283
Accuracy:        43.4%


Incorrect Answer Distribution: {}
Actual Answer Distribution: {}





In [15]:
USE_RAG = True
USE_CHAT_TEMPLATES = True     
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

datasets_within_mirage_to_process = ['bioasq']

for dataset_name in datasets_within_mirage_to_process:
    
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
      mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)
    template_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"

    for question_index, question in enumerate(tqdm(mirage_questions_with_entity_mentions)):

      question['axioms'] = []

      if USE_RAG:
        entity_selector.encode_and_rank_candidates(question['entities'], ret_k=RET_K, append_k_per_entity=APPEND_K)
        top_candidate = entity_selector._all_mention_results[0] if len(entity_selector._all_mention_results) > 0 else []
        if top_candidate:
          rank, iri, score, label = top_candidate
          axioms = get_axiom_verbalisations(
            [iri],
            max_verbalisations=5,
            question_text=question['question'],
            min_overlap=2
          )
          question['axioms'] = axioms if axioms else []
        # end: if
      # end: if

      shuffle_seed = per_question_seed(dataset_name, question, question_index)
      
      question['shuffle'] = SHUFFLE_DISPLAY_ORDER
      question['seed'] = shuffle_seed

      response_letter = mistral_lm.generate_mcqa_letter(
        "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat",
        question,
        max_new_tokens=1
      )

      if response_letter == question['answer']:
        correct_questions += 1
      else:
        incorrect_questions += 1

    total = len(mirage_questions_with_entity_mentions)
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")


Processing bioasq ...
Processing bioasq ... 


100%|██████████| 618/618 [00:00<00:00, 665.72it/s] 
100%|██████████| 618/618 [02:03<00:00,  4.98it/s]

Total correct:   336
Total incorrect: 282
Accuracy:        54.37%







In [16]:
USE_RAG = True
USE_CHAT_TEMPLATES = True     
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

datasets_within_mirage_to_process = ['mmlu']

for dataset_name in datasets_within_mirage_to_process:
    
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
      mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)
    template_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"

    for question_index, question in enumerate(tqdm(mirage_questions_with_entity_mentions)):

      question['axioms'] = []

      if USE_RAG:
        entity_selector.encode_and_rank_candidates(question['entities'], ret_k=RET_K, append_k_per_entity=APPEND_K)
        top_candidate = entity_selector._all_mention_results[0] if len(entity_selector._all_mention_results) > 0 else []
        if top_candidate:
          rank, iri, score, label = top_candidate
          axioms = get_axiom_verbalisations(
            [iri],
            max_verbalisations=5,
            question_text=question['question'],
            min_overlap=2
          )
          question['axioms'] = axioms if axioms else []
        # end: if
      # end: if

      shuffle_seed = per_question_seed(dataset_name, question, question_index)
      
      question['shuffle'] = SHUFFLE_DISPLAY_ORDER
      question['seed'] = shuffle_seed

      response_letter = mistral_lm.generate_mcqa_letter(
        "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat",
        question,
        max_new_tokens=1
      )

      if response_letter == question['answer']:
        correct_questions += 1
      else:
        incorrect_questions += 1

    total = len(mirage_questions_with_entity_mentions)
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")


Processing mmlu ...
Processing mmlu ... 


100%|██████████| 1089/1089 [00:00<00:00, 1148.55it/s]
100%|██████████| 1089/1089 [05:21<00:00,  3.39it/s]

Total correct:   558
Total incorrect: 531
Accuracy:        51.24%







In [None]:
# # 66 -> 68, 56 -> 56, 50 -> 50

# question_with_axioms = 0
# for q in mirage_questions_with_entity_mentions:
#     print(q)

In [1]:
import torch
from pathlib import Path
from data_utils import load_json, xs_of_all_questions, merge_entity_mentions
from llm_utils_improvements import (
    MistralLLM,
    prompt_template_no_rag, prompt_template_with_axioms,
    chat_prompt_template_no_rag, chat_prompt_template_with_axioms,
    BaseEntitySelector,
    SubsumptionEntitySelector,
    ApproximateNearestNeighbourEntitySelector,
    SimilarityEntitySelector
)
from retrievers import HiTRetriever, SBERTRetriever
from math_functools import entity_subsumption, batch_poincare_dist_with_adaptive_curv_k, batch_cosine_similarity
from tqdm import tqdm

#### Config (quick A/B toggles)
USE_RAG = True
USE_CHAT_TEMPLATES = True          # << switch this to False to A/B against the old path
SHUFFLE_DISPLAY_ORDER = True       # shuffles how choices are displayed (letters stay tied to the same text)
FORCE_DTYPE_AUTO = True            # set True if bf16 isn’t supported

# Data
mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))

allowable_datasets = ['medqa', 'medmcqa', 'pubmedqa', 'bioasq', 'mmlu']
mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets)
mirage_questions_with_entity_mentions = merge_entity_mentions(
    mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets
)

# Model
mistral_llm_version = "mistralai/Mistral-7B-Instruct-v0.3"
mistral_lm = MistralLLM(mistral_llm_version)

dtype = "auto" if FORCE_DTYPE_AUTO else torch.bfloat16
mistral_lm.load_tokenizer(use_fast=True).load_model(
    device_map="auto",
    torch_dtype=dtype,
    low_cpu_mem_usage=True
).register_generation_config(
    do_sample=False,
    num_beams=1,
    pad_token_id=mistral_lm._tokenizer.pad_token_id,
    eos_token_id=mistral_lm._tokenizer.eos_token_id
)

# Register BOTH string and chat templates so we can A/B with a flag
mistral_lm.register_prompt_template_fn("mirage_mcqa_no_rag", prompt_template_no_rag)
mistral_lm.register_prompt_template_fn("mirage_mcqa_axiom_rag", prompt_template_with_axioms)
mistral_lm.register_prompt_template_fn("mirage_mcqa_no_rag_chat", chat_prompt_template_no_rag)
mistral_lm.register_prompt_template_fn("mirage_mcqa_axiom_rag_chat", chat_prompt_template_with_axioms)

print("Loaded!")
print(f"LLM: {mistral_llm_version} | RAG: {USE_RAG} | Chat templates: {USE_CHAT_TEMPLATES}")

# HiT retriever (for RAG)
embeddings_dir = "./embeddings"
common_map = Path("./embeddings/axiom-mappings.json")
common_verbalisations = Path("./embeddings/axiom-verbalisations.json")
hit_SNOMED25_model_path = Path('./models/HiT-mixed-SNOMED-25/final')

hit_ret_snomed_25_w_ent_sub = HiTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=hit_SNOMED25_model_path,
    score_fn=entity_subsumption
)

hit_ret_snomed_25_w_hyp_dist = HiTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=hit_SNOMED25_model_path,
    score_fn=batch_poincare_dist_with_adaptive_curv_k
)

sbert_ret_cosine_sim = SBERTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/sbert-plm-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=hit_SNOMED25_model_path,
    score_fn=batch_cosine_similarity
)

import re

snomed_concept_information: dict = load_json(Path("./data/snomed_axioms.json"))

COMMON_STOPWORDS = {
  "the","a","an","and","or","of","to","in","on","for","with","without","from",
  "by","at","as","is","are","was","were","be","been","being","that","which",
  "this","these","those","it","its","their","there","then","than","such"
}

SNOMED_CONCEPTS_NEAR_TOP = {
  "clinical finding","finding","disease","disorder","procedure","event", "body structure", "organism",
  "substance", "situation with explicit context","qualifier value","morphologic abnormality"
}

def lexical_repr_without_stopwords(s: str) -> set[str]:
  """set difference between the input string and 'common stopwords'"""
  return set(re.findall(r"[a-z]+", s.lower())) - COMMON_STOPWORDS

def lexical_overlap_score(input: str, reference_set: set[str]) -> int:
  """set (lexical) intersection between input_string and the reference vocab (typically, the question)"""
  return len(lexical_repr_without_stopwords(input) & reference_set)

def is_broad_concept(rdfs_label: str) -> bool:
  """basic lexical operation to check whether an rdfs_label is a high level SNOMED CT concept"""
  return (rdfs_label.split(" (", 1)[0]).lower() in SNOMED_CONCEPTS_NEAR_TOP

def select_axioms_for_prompt(axioms: list[str], question_text: str, *, min_overlap: int = 1) -> list[str]:
  """Return axioms that exceed the lexical overlap of `min_overlap` with the `question_text`"""
  clean_question_text = lexical_repr_without_stopwords(question_text)
  retained_axioms = []
  for this_axiom in axioms:
    lex_score = lexical_overlap_score(this_axiom, clean_question_text)
    if lex_score >= min_overlap:
      retained_axioms.append(this_axiom)
  return retained_axioms

def get_axiom_verbalisations(iris: list[str], *, max_verbalisations: int = 3, question_text: str, min_overlap: int = 1) -> list[str]:
  returnable_verbalisations = []
  for iri in iris: # skip any 'owl:Thing'(s)
    if iri == "owl:Thing": # Though these shouldn't be in IRI
      continue
    snomed_concept = snomed_concept_information[iri]
    label = snomed_concept['label'] # rdfs_label
    if is_broad_concept(label):
      continue
    verbalisations = snomed_concept['verbalization']
    subclass_axioms = verbalisations['subclass_of'][:max_verbalisations]
    equiv_axioms = verbalisations['equivalent_to'][:max_verbalisations]
    # this is pretty horrible, TODO: make this a little less, difficult to read ...
    complete_subclass_verbalisations = [f"{label} {verb_str}" for verb_str in subclass_axioms] if subclass_axioms is not [] else []
    complete_equiv_verbalisations = [f"{label} {verb_str}" for verb_str in equiv_axioms] if equiv_axioms is not [] else []
    returnable_verbalisations = [*complete_subclass_verbalisations, *complete_equiv_verbalisations]
  return select_axioms_for_prompt(
    returnable_verbalisations,
    question_text,
    min_overlap=min_overlap
  )

def per_question_seed(dataset: str, q: dict, idx: int) -> int:
    qid = q.get("id", idx)
    return (hash((dataset, str(qid))) & 0xFFFFFFFF)

def count_predictions(incorrect_questions):
  selected_options = set()
  predictions = []
  for response in incorrect_questions:
    selected_options.add(response['pred'])
    predictions.append(response['pred'])
  counts = {}
  for selected_option in selected_options:
    counts[selected_option] = 0
  for prediction in predictions:
    counts[prediction] += 1
  return counts

def count_actual_answers(incorrect_questions):
  answer_set = set()
  answers = []
  for response in incorrect_questions:
    answer_set.add(response['question']['answer'])
    answers.append(response['question']['answer'])
  counts = {}
  for possible_answer in answer_set:
    counts[possible_answer] = 0
  for answer in answers:
    counts[answer] += 1
  return counts

Processing medqa ... 


100%|██████████| 1273/1273 [00:00<00:00, 17179.48it/s]


Processing medmcqa ... 


100%|██████████| 4183/4183 [00:02<00:00, 1459.86it/s]


Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1533.22it/s]


Processing bioasq ... 


100%|██████████| 618/618 [00:01<00:00, 590.47it/s] 


Processing mmlu ... 


100%|██████████| 1089/1089 [00:01<00:00, 706.82it/s]


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded!
LLM: mistralai/Mistral-7B-Instruct-v0.3 | RAG: True | Chat templates: True


In [2]:
USE_RAG = False
USE_CHAT_TEMPLATES = True     
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

# Main
datasets_within_mirage_to_process = ['pubmedqa']

for dataset_name in datasets_within_mirage_to_process:
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
        mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)

    # choose template keys once
    if USE_CHAT_TEMPLATES:
        tmpl_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"
        generate_fn = mistral_lm.chat_generate_mc_letter
    else:
        tmpl_key = "mirage_mcqa_axiom_rag" if USE_RAG else "mirage_mcqa_no_rag"
        generate_fn = mistral_lm.generate_mc_letter
        

    for idx, q in enumerate(tqdm(mirage_questions_with_entity_mentions[125:225])):
        
        def _top_candidate_and_score(sel: BaseEntitySelector):
          if not getattr(sel, "_all_mention_results", None):
            return None
          return sel._all_mention_results[0]
        
        if USE_RAG:
            entity_selector.encode_and_rank_candidates(q.get('entities', []), ret_k=RET_K, append_k_per_entity=APPEND_K)
            top = _top_candidate_and_score(entity_selector)
            use_context = False
            iris = []
            if top:
              rank, iri, score, label = top
              if score >= 0.25:
                use_context = True
              if use_context:
                iris = [iri]
            axioms = []
            if iris:
                axioms = get_axiom_verbalisations(
                    iris,
                    max_verbalisations=5,
                    question_text=q.get("question",""),
                    min_overlap=2
                )
            
            q['axioms'] = axioms if axioms else []
        else:
            q['axioms'] = []

        # Deterministic display-order shuffle seed (does NOT change letters, only line order)
        shuffle_seed = per_question_seed(dataset_name, q, idx)
        q['shuffle'] = SHUFFLE_DISPLAY_ORDER
        q['seed'] = shuffle_seed

        # --- Generate ---
        # Both string and chat templates accept extra kwargs, so passing the whole question dict is fine.
        try:
            response_letter = generate_fn(tmpl_key, q, max_new_tokens=2)
        except Exception as e:
            # be defensive: treat parsing failures as wrong
            print(f"HIT EXCEPTION: {e}")
            response_letter = "?"
        
        # --- Score ---
        gold = q['answer']
        if response_letter == gold:
            correct_questions += 1
        else:
            incorrect_questions += 1
            incorrect_question_response.append({"question": q, "pred": response_letter})

    # total = len(mirage_questions_with_entity_mentions)
    total = 100
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")
    print(f"Incorrect Answer Distribution: {count_predictions(incorrect_question_response)}")
    print(f"Actual Answer Distribution: {count_actual_answers(incorrect_question_response)}")


Processing pubmedqa ...
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 672.27it/s]
100%|██████████| 100/100 [00:25<00:00,  3.97it/s]

Total correct:   25
Total incorrect: 75
Accuracy:        25.0%


Incorrect Answer Distribution: {'C': 68, 'B': 7}
Actual Answer Distribution: {'A': 75}





In [3]:
USE_RAG = False
USE_CHAT_TEMPLATES = True     
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

# Main
datasets_within_mirage_to_process = ['bioasq']

for dataset_name in datasets_within_mirage_to_process:
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
        mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)

    # choose template keys once
    if USE_CHAT_TEMPLATES:
        tmpl_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"
        generate_fn = mistral_lm.chat_generate_mc_letter
    else:
        tmpl_key = "mirage_mcqa_axiom_rag" if USE_RAG else "mirage_mcqa_no_rag"
        generate_fn = mistral_lm.generate_mc_letter
        

    for idx, q in enumerate(tqdm(mirage_questions_with_entity_mentions[125:225])):
        
        def _top_candidate_and_score(sel: BaseEntitySelector):
          if not getattr(sel, "_all_mention_results", None):
            return None
          return sel._all_mention_results[0]
        
        if USE_RAG:
            entity_selector.encode_and_rank_candidates(q.get('entities', []), ret_k=RET_K, append_k_per_entity=APPEND_K)
            top = _top_candidate_and_score(entity_selector)
            use_context = False
            iris = []
            if top:
              rank, iri, score, label = top
              if score >= 0.25:
                use_context = True
              if use_context:
                iris = [iri]
            axioms = []
            if iris:
                axioms = get_axiom_verbalisations(
                    iris,
                    max_verbalisations=5,
                    question_text=q.get("question",""),
                    min_overlap=2
                )
                
            q['axioms'] = axioms if axioms else []
        else:
            q['axioms'] = []

        # Deterministic display-order shuffle seed (does NOT change letters, only line order)
        shuffle_seed = per_question_seed(dataset_name, q, idx)
        q['shuffle'] = SHUFFLE_DISPLAY_ORDER
        q['seed'] = shuffle_seed

        # --- Generate ---
        # Both string and chat templates accept extra kwargs, so passing the whole question dict is fine.
        try:
            response_letter = generate_fn(tmpl_key, q, max_new_tokens=2)
        except Exception as e:
            # be defensive: treat parsing failures as wrong
            print(f"HIT EXCEPTION: {e}")
            response_letter = "?"
        
        # --- Score ---
        gold = q['answer']
        if response_letter == gold:
            correct_questions += 1
        else:
            incorrect_questions += 1
            incorrect_question_response.append({"question": q, "pred": response_letter})

    # total = len(mirage_questions_with_entity_mentions)
    total = 100
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")
    print(f"Incorrect Answer Distribution: {count_predictions(incorrect_question_response)}")
    print(f"Actual Answer Distribution: {count_actual_answers(incorrect_question_response)}")


Processing bioasq ...
Processing bioasq ... 


100%|██████████| 618/618 [00:00<00:00, 820.19it/s] 
100%|██████████| 100/100 [00:29<00:00,  3.45it/s]

Total correct:   66
Total incorrect: 34
Accuracy:        66.0%


Incorrect Answer Distribution: {'A': 13, 'B': 21}
Actual Answer Distribution: {'A': 21, 'B': 13}





In [4]:
USE_RAG = False
USE_CHAT_TEMPLATES = True     
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

# Main
datasets_within_mirage_to_process = ['mmlu']

for dataset_name in datasets_within_mirage_to_process:
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
        mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)

    # choose template keys once
    if USE_CHAT_TEMPLATES:
        tmpl_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"
        generate_fn = mistral_lm.chat_generate_mc_letter
    else:
        tmpl_key = "mirage_mcqa_axiom_rag" if USE_RAG else "mirage_mcqa_no_rag"
        generate_fn = mistral_lm.generate_mc_letter
        

    for idx, q in enumerate(tqdm(mirage_questions_with_entity_mentions[125:225])):
        
        def _top_candidate_and_score(sel: BaseEntitySelector):
          if not getattr(sel, "_all_mention_results", None):
            return None
          return sel._all_mention_results[0]
        
        if USE_RAG:
            entity_selector.encode_and_rank_candidates(q.get('entities', []), ret_k=RET_K, append_k_per_entity=APPEND_K)
            top = _top_candidate_and_score(entity_selector)
            use_context = False
            iris = []
            if top:
              rank, iri, score, label = top
              if score >= 0.25:
                use_context = True
              if use_context:
                iris = [iri]
            axioms = []
            if iris:
                axioms = get_axiom_verbalisations(
                    iris,
                    max_verbalisations=5,
                    question_text=q.get("question",""),
                    min_overlap=2
                )
            
            q['axioms'] = axioms if axioms else []
        else:
            q['axioms'] = []

        # Deterministic display-order shuffle seed (does NOT change letters, only line order)
        shuffle_seed = per_question_seed(dataset_name, q, idx)
        q['shuffle'] = SHUFFLE_DISPLAY_ORDER
        q['seed'] = shuffle_seed

        # --- Generate ---
        # Both string and chat templates accept extra kwargs, so passing the whole question dict is fine.
        try:
            response_letter = generate_fn(tmpl_key, q, max_new_tokens=2)
        except Exception as e:
            # be defensive: treat parsing failures as wrong
            print(f"HIT EXCEPTION: {e}")
            response_letter = "?"
        
        # --- Score ---
        gold = q['answer']
        if response_letter == gold:
            correct_questions += 1
        else:
            incorrect_questions += 1
            incorrect_question_response.append({"question": q, "pred": response_letter})

    # total = len(mirage_questions_with_entity_mentions)
    total = 100
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")
    print(f"Incorrect Answer Distribution: {count_predictions(incorrect_question_response)}")
    print(f"Actual Answer Distribution: {count_actual_answers(incorrect_question_response)}")


Processing mmlu ...
Processing mmlu ... 


100%|██████████| 1089/1089 [00:01<00:00, 768.31it/s]
100%|██████████| 100/100 [00:32<00:00,  3.12it/s]

Total correct:   59
Total incorrect: 41
Accuracy:        59.0%


Incorrect Answer Distribution: {'D': 11, 'A': 12, 'C': 11, 'B': 7}
Actual Answer Distribution: {'D': 16, 'A': 5, 'C': 4, 'B': 16}





In [5]:
USE_RAG = True
USE_CHAT_TEMPLATES = True     
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

# Main
datasets_within_mirage_to_process = ['pubmedqa']

for dataset_name in datasets_within_mirage_to_process:
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
        mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)

    # choose template keys once
    if USE_CHAT_TEMPLATES:
        tmpl_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"
        generate_fn = mistral_lm.chat_generate_mc_letter
    else:
        tmpl_key = "mirage_mcqa_axiom_rag" if USE_RAG else "mirage_mcqa_no_rag"
        generate_fn = mistral_lm.generate_mc_letter
        

    for idx, q in enumerate(tqdm(mirage_questions_with_entity_mentions[125:225])):
        
        def _top_candidate_and_score(sel: BaseEntitySelector):
          if not getattr(sel, "_all_mention_results", None):
            return None
          return sel._all_mention_results[0]
        
        if USE_RAG:
            entity_selector.encode_and_rank_candidates(q.get('entities', []), ret_k=RET_K, append_k_per_entity=APPEND_K)
            top = _top_candidate_and_score(entity_selector)
            use_context = False
            iris = []
            if top:
              rank, iri, score, label = top
              if score >= 0.25:
                use_context = True
              if use_context:
                iris = [iri]
            axioms = []
            if iris:
                axioms = get_axiom_verbalisations(
                    iris,
                    max_verbalisations=5,
                    question_text=q.get("question",""),
                    min_overlap=2
                )
                
            q['axioms'] = axioms if axioms else []
        else:
            q['axioms'] = []

        # Deterministic display-order shuffle seed (does NOT change letters, only line order)
        shuffle_seed = per_question_seed(dataset_name, q, idx)
        q['shuffle'] = SHUFFLE_DISPLAY_ORDER
        q['seed'] = shuffle_seed

        # --- Generate ---
        # Both string and chat templates accept extra kwargs, so passing the whole question dict is fine.
        try:
            response_letter = generate_fn(tmpl_key, q, max_new_tokens=2)
        except Exception as e:
            # be defensive: treat parsing failures as wrong
            print(f"HIT EXCEPTION: {e}")
            response_letter = "?"
        
        # --- Score ---
        gold = q['answer']
        if response_letter == gold:
            correct_questions += 1
        else:
            incorrect_questions += 1
            incorrect_question_response.append({"question": q, "pred": response_letter})

    # total = len(mirage_questions_with_entity_mentions)
    total = 100
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")
    print(f"Incorrect Answer Distribution: {count_predictions(incorrect_question_response)}")
    print(f"Actual Answer Distribution: {count_actual_answers(incorrect_question_response)}")


Processing pubmedqa ...
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1268.37it/s]
100%|██████████| 100/100 [00:31<00:00,  3.14it/s]

Total correct:   43
Total incorrect: 57
Accuracy:        43.0%


Incorrect Answer Distribution: {'C': 32, 'B': 25}
Actual Answer Distribution: {'A': 57}





In [6]:
USE_RAG = True
USE_CHAT_TEMPLATES = True     
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

# Main
datasets_within_mirage_to_process = ['bioasq']

for dataset_name in datasets_within_mirage_to_process:
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
        mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)

    # choose template keys once
    if USE_CHAT_TEMPLATES:
        tmpl_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"
        generate_fn = mistral_lm.chat_generate_mc_letter
    else:
        tmpl_key = "mirage_mcqa_axiom_rag" if USE_RAG else "mirage_mcqa_no_rag"
        generate_fn = mistral_lm.generate_mc_letter
        

    for idx, q in enumerate(tqdm(mirage_questions_with_entity_mentions[125:225])):
        
        def _top_candidate_and_score(sel: BaseEntitySelector):
          if not getattr(sel, "_all_mention_results", None):
            return None
          return sel._all_mention_results[0]
        
        if USE_RAG:
            entity_selector.encode_and_rank_candidates(q.get('entities', []), ret_k=RET_K, append_k_per_entity=APPEND_K)
            top = _top_candidate_and_score(entity_selector)
            use_context = False
            iris = []
            if top:
              rank, iri, score, label = top
              if score >= 0.25:
                use_context = True
              if use_context:
                iris = [iri]
            axioms = []
            if iris:
                axioms = get_axiom_verbalisations(
                    iris,
                    max_verbalisations=5,
                    question_text=q.get("question",""),
                    min_overlap=2
                )
                
            q['axioms'] = axioms if axioms else []
        else:
            q['axioms'] = []

        # Deterministic display-order shuffle seed (does NOT change letters, only line order)
        shuffle_seed = per_question_seed(dataset_name, q, idx)
        q['shuffle'] = SHUFFLE_DISPLAY_ORDER
        q['seed'] = shuffle_seed

        # --- Generate ---
        # Both string and chat templates accept extra kwargs, so passing the whole question dict is fine.
        try:
            response_letter = generate_fn(tmpl_key, q, max_new_tokens=2)
        except Exception as e:
            # be defensive: treat parsing failures as wrong
            print(f"HIT EXCEPTION: {e}")
            response_letter = "?"
        
        # --- Score ---
        gold = q['answer']
        if response_letter == gold:
            correct_questions += 1
        else:
            incorrect_questions += 1
            incorrect_question_response.append({"question": q, "pred": response_letter})

    # total = len(mirage_questions_with_entity_mentions)
    total = 100
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")
    print(f"Incorrect Answer Distribution: {count_predictions(incorrect_question_response)}")
    print(f"Actual Answer Distribution: {count_actual_answers(incorrect_question_response)}")


Processing bioasq ...
Processing bioasq ... 


100%|██████████| 618/618 [00:00<00:00, 716.66it/s] 
100%|██████████| 100/100 [00:29<00:00,  3.42it/s]

Total correct:   62
Total incorrect: 38
Accuracy:        62.0%


Incorrect Answer Distribution: {'A': 14, 'B': 24}
Actual Answer Distribution: {'A': 24, 'B': 14}





In [7]:
USE_RAG = True
USE_CHAT_TEMPLATES = True     
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

# Main
datasets_within_mirage_to_process = ['mmlu']

for dataset_name in datasets_within_mirage_to_process:
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
        mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)

    # choose template keys once
    if USE_CHAT_TEMPLATES:
        tmpl_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"
        generate_fn = mistral_lm.chat_generate_mc_letter
    else:
        tmpl_key = "mirage_mcqa_axiom_rag" if USE_RAG else "mirage_mcqa_no_rag"
        generate_fn = mistral_lm.generate_mc_letter
        

    for idx, q in enumerate(tqdm(mirage_questions_with_entity_mentions[125:225])):
        
        def _top_candidate_and_score(sel: BaseEntitySelector):
          if not getattr(sel, "_all_mention_results", None):
            return None
          return sel._all_mention_results[0]
        
        if USE_RAG:
            entity_selector.encode_and_rank_candidates(q.get('entities', []), ret_k=RET_K, append_k_per_entity=APPEND_K)
            top = _top_candidate_and_score(entity_selector)
            use_context = False
            iris = []
            if top:
              rank, iri, score, label = top
              if score >= 0.25:
                use_context = True
              if use_context:
                iris = [iri]
            axioms = []
            if iris:
                axioms = get_axiom_verbalisations(
                    iris,
                    max_verbalisations=5,
                    question_text=q.get("question",""),
                    min_overlap=2
                )
                
            q['axioms'] = axioms if axioms else []
        else:
            q['axioms'] = []

        # Deterministic display-order shuffle seed (does NOT change letters, only line order)
        shuffle_seed = per_question_seed(dataset_name, q, idx)
        q['shuffle'] = SHUFFLE_DISPLAY_ORDER
        q['seed'] = shuffle_seed

        # --- Generate ---
        # Both string and chat templates accept extra kwargs, so passing the whole question dict is fine.
        try:
            response_letter = generate_fn(tmpl_key, q, max_new_tokens=2)
        except Exception as e:
            # be defensive: treat parsing failures as wrong
            print(f"HIT EXCEPTION: {e}")
            response_letter = "?"
        
        # --- Score ---
        gold = q['answer']
        if response_letter == gold:
            correct_questions += 1
        else:
            incorrect_questions += 1
            incorrect_question_response.append({"question": q, "pred": response_letter})

    # total = len(mirage_questions_with_entity_mentions)
    total = 100
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")
    print(f"Incorrect Answer Distribution: {count_predictions(incorrect_question_response)}")
    print(f"Actual Answer Distribution: {count_actual_answers(incorrect_question_response)}")


Processing mmlu ...
Processing mmlu ... 


100%|██████████| 1089/1089 [00:01<00:00, 662.72it/s]
100%|██████████| 100/100 [00:27<00:00,  3.57it/s]

Total correct:   57
Total incorrect: 43
Accuracy:        57.0%


Incorrect Answer Distribution: {'D': 11, 'A': 12, 'C': 12, 'B': 8}
Actual Answer Distribution: {'D': 17, 'A': 4, 'C': 5, 'B': 17}





In [1]:
import torch
from pathlib import Path
from data_utils import load_json, xs_of_all_questions, merge_entity_mentions
from llm_utils_improvements import (
    MistralLLM,
    prompt_template_no_rag, prompt_template_with_axioms,
    chat_prompt_template_no_rag, chat_prompt_template_with_axioms,
    BaseEntitySelector,
    SubsumptionEntitySelector,
    ApproximateNearestNeighbourEntitySelector,
    SimilarityEntitySelector
)
from retrievers import HiTRetriever, SBERTRetriever
from math_functools import entity_subsumption, batch_poincare_dist_with_adaptive_curv_k, batch_cosine_similarity
from tqdm import tqdm

#### Config (quick A/B toggles)
USE_RAG = True
USE_CHAT_TEMPLATES = True          # << switch this to False to A/B against the old path
SHUFFLE_DISPLAY_ORDER = True       # shuffles how choices are displayed (letters stay tied to the same text)
FORCE_DTYPE_AUTO = True            # set True if bf16 isn’t supported

# Data
mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))

allowable_datasets = ['medqa', 'medmcqa', 'pubmedqa', 'bioasq', 'mmlu']
mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets)
mirage_questions_with_entity_mentions = merge_entity_mentions(
    mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets
)

# Model
mistral_llm_version = "mistralai/Mistral-7B-Instruct-v0.1"
mistral_lm = MistralLLM(mistral_llm_version)

dtype = "auto" if FORCE_DTYPE_AUTO else torch.bfloat16
mistral_lm.load_tokenizer(use_fast=True).load_model(
    device_map="auto",
    torch_dtype=dtype,
    low_cpu_mem_usage=True
).register_generation_config(
    do_sample=False,
    num_beams=1,
    pad_token_id=mistral_lm._tokenizer.pad_token_id,
    eos_token_id=mistral_lm._tokenizer.eos_token_id
)

# Register BOTH string and chat templates so we can A/B with a flag
mistral_lm.register_prompt_template_fn("mirage_mcqa_no_rag", prompt_template_no_rag)
mistral_lm.register_prompt_template_fn("mirage_mcqa_axiom_rag", prompt_template_with_axioms)
mistral_lm.register_prompt_template_fn("mirage_mcqa_no_rag_chat", chat_prompt_template_no_rag)
mistral_lm.register_prompt_template_fn("mirage_mcqa_axiom_rag_chat", chat_prompt_template_with_axioms)

print("Loaded!")
print(f"LLM: {mistral_llm_version} | RAG: {USE_RAG} | Chat templates: {USE_CHAT_TEMPLATES}")

# HiT retriever (for RAG)
embeddings_dir = "./embeddings"
common_map = Path("./embeddings/axiom-mappings.json")
common_verbalisations = Path("./embeddings/axiom-verbalisations.json")
hit_SNOMED25_model_path = Path('./models/HiT-mixed-SNOMED-25/final')

hit_ret_snomed_25_w_ent_sub = HiTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=hit_SNOMED25_model_path,
    score_fn=entity_subsumption
)

hit_ret_snomed_25_w_hyp_dist = HiTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=hit_SNOMED25_model_path,
    score_fn=batch_poincare_dist_with_adaptive_curv_k
)

sbert_ret_cosine_sim = SBERTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/sbert-plm-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=hit_SNOMED25_model_path,
    score_fn=batch_cosine_similarity
)

import re

snomed_concept_information: dict = load_json(Path("./data/snomed_axioms.json"))

COMMON_STOPWORDS = {
  "the","a","an","and","or","of","to","in","on","for","with","without","from",
  "by","at","as","is","are","was","were","be","been","being","that","which",
  "this","these","those","it","its","their","there","then","than","such"
}

SNOMED_CONCEPTS_NEAR_TOP = {
  "clinical finding","finding","disease","disorder","procedure","event", "body structure", "organism",
  "substance", "situation with explicit context","qualifier value","morphologic abnormality"
}

def lexical_repr_without_stopwords(s: str) -> set[str]:
  """set difference between the input string and 'common stopwords'"""
  return set(re.findall(r"[a-z]+", s.lower())) - COMMON_STOPWORDS

def lexical_overlap_score(input: str, reference_set: set[str]) -> int:
  """set (lexical) intersection between input_string and the reference vocab (typically, the question)"""
  return len(lexical_repr_without_stopwords(input) & reference_set)

def is_broad_concept(rdfs_label: str) -> bool:
  """basic lexical operation to check whether an rdfs_label is a high level SNOMED CT concept"""
  return (rdfs_label.split(" (", 1)[0]).lower() in SNOMED_CONCEPTS_NEAR_TOP

def select_axioms_for_prompt(axioms: list[str], question_text: str, *, min_overlap: int = 1) -> list[str]:
  """Return axioms that exceed the lexical overlap of `min_overlap` with the `question_text`"""
  clean_question_text = lexical_repr_without_stopwords(question_text)
  retained_axioms = []
  for this_axiom in axioms:
    lex_score = lexical_overlap_score(this_axiom, clean_question_text)
    if lex_score >= min_overlap:
      retained_axioms.append(this_axiom)
  return retained_axioms

def get_axiom_verbalisations(iris: list[str], *, max_verbalisations: int = 3, question_text: str, min_overlap: int = 1) -> list[str]:
  returnable_verbalisations = []
  for iri in iris: # skip any 'owl:Thing'(s)
    if iri == "owl:Thing": # Though these shouldn't be in IRI
      continue
    snomed_concept = snomed_concept_information[iri]
    label = snomed_concept['label'] # rdfs_label
    if is_broad_concept(label):
      continue
    verbalisations = snomed_concept['verbalization']
    subclass_axioms = verbalisations['subclass_of'][:max_verbalisations]
    equiv_axioms = verbalisations['equivalent_to'][:max_verbalisations]
    # this is pretty horrible, TODO: make this a little less, difficult to read ...
    complete_subclass_verbalisations = [f"{label} {verb_str}" for verb_str in subclass_axioms] if subclass_axioms is not [] else []
    complete_equiv_verbalisations = [f"{label} {verb_str}" for verb_str in equiv_axioms] if equiv_axioms is not [] else []
    returnable_verbalisations = [*complete_subclass_verbalisations, *complete_equiv_verbalisations]
  return select_axioms_for_prompt(
    returnable_verbalisations,
    question_text,
    min_overlap=min_overlap
  )

def per_question_seed(dataset: str, q: dict, idx: int) -> int:
    qid = q.get("id", idx)
    return (hash((dataset, str(qid))) & 0xFFFFFFFF)

def count_predictions(incorrect_questions):
  selected_options = set()
  predictions = []
  for response in incorrect_questions:
    selected_options.add(response['pred'])
    predictions.append(response['pred'])
  counts = {}
  for selected_option in selected_options:
    counts[selected_option] = 0
  for prediction in predictions:
    counts[prediction] += 1
  return counts

def count_actual_answers(incorrect_questions):
  answer_set = set()
  answers = []
  for response in incorrect_questions:
    answer_set.add(response['question']['answer'])
    answers.append(response['question']['answer'])
  counts = {}
  for possible_answer in answer_set:
    counts[possible_answer] = 0
  for answer in answers:
    counts[answer] += 1
  return counts

Processing medqa ... 


100%|██████████| 1273/1273 [00:00<00:00, 17377.19it/s]


Processing medmcqa ... 


100%|██████████| 4183/4183 [00:02<00:00, 2044.08it/s]


Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1534.70it/s]


Processing bioasq ... 


100%|██████████| 618/618 [00:00<00:00, 1380.26it/s]


Processing mmlu ... 


100%|██████████| 1089/1089 [00:00<00:00, 1251.73it/s]


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded!
LLM: mistralai/Mistral-7B-Instruct-v0.1 | RAG: True | Chat templates: True


In [2]:
USE_RAG = False
USE_CHAT_TEMPLATES = True     
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

# Main
datasets_within_mirage_to_process = ['pubmedqa']

for dataset_name in datasets_within_mirage_to_process:
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
        mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)

    # choose template keys once
    if USE_CHAT_TEMPLATES:
        tmpl_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"
        generate_fn = mistral_lm.chat_generate_mc_letter
    else:
        tmpl_key = "mirage_mcqa_axiom_rag" if USE_RAG else "mirage_mcqa_no_rag"
        generate_fn = mistral_lm.generate_mc_letter
        

    for idx, q in enumerate(tqdm(mirage_questions_with_entity_mentions[125:225])):
        
        def _top_candidate_and_score(sel: BaseEntitySelector):
          if not getattr(sel, "_all_mention_results", None):
            return None
          return sel._all_mention_results[0]
        
        if USE_RAG:
            entity_selector.encode_and_rank_candidates(q.get('entities', []), ret_k=RET_K, append_k_per_entity=APPEND_K)
            top = _top_candidate_and_score(entity_selector)
            use_context = False
            iris = []
            if top:
              rank, iri, score, label = top
              if score >= 0.25:
                use_context = True
              if use_context:
                iris = [iri]
            axioms = []
            if iris:
                axioms = get_axiom_verbalisations(
                    iris,
                    max_verbalisations=5,
                    question_text=q.get("question",""),
                    min_overlap=2
                )
                
            q['axioms'] = axioms if axioms else []
        else:
            q['axioms'] = []

        # Deterministic display-order shuffle seed (does NOT change letters, only line order)
        shuffle_seed = per_question_seed(dataset_name, q, idx)
        q['shuffle'] = SHUFFLE_DISPLAY_ORDER
        q['seed'] = shuffle_seed

        # --- Generate ---
        # Both string and chat templates accept extra kwargs, so passing the whole question dict is fine.
        try:
            response_letter = generate_fn(tmpl_key, q, max_new_tokens=2)
        except Exception as e:
            # be defensive: treat parsing failures as wrong
            print(f"HIT EXCEPTION: {e}")
            response_letter = "?"
        
        # --- Score ---
        gold = q['answer']
        if response_letter == gold:
            correct_questions += 1
        else:
            incorrect_questions += 1
            incorrect_question_response.append({"question": q, "pred": response_letter})

    # total = len(mirage_questions_with_entity_mentions)
    total = 100
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")
    print(f"Incorrect Answer Distribution: {count_predictions(incorrect_question_response)}")
    print(f"Actual Answer Distribution: {count_actual_answers(incorrect_question_response)}")


Processing pubmedqa ...
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1496.63it/s]
100%|██████████| 100/100 [00:22<00:00,  4.49it/s]

Total correct:   79
Total incorrect: 21
Accuracy:        79.0%


Incorrect Answer Distribution: {'B': 19, 'C': 2}
Actual Answer Distribution: {'A': 21}





In [3]:
USE_RAG = False
USE_CHAT_TEMPLATES = True     
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

# Main
datasets_within_mirage_to_process = ['bioasq']

for dataset_name in datasets_within_mirage_to_process:
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
        mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)

    # choose template keys once
    if USE_CHAT_TEMPLATES:
        tmpl_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"
        generate_fn = mistral_lm.chat_generate_mc_letter
    else:
        tmpl_key = "mirage_mcqa_axiom_rag" if USE_RAG else "mirage_mcqa_no_rag"
        generate_fn = mistral_lm.generate_mc_letter
        

    for idx, q in enumerate(tqdm(mirage_questions_with_entity_mentions[125:225])):
        
        def _top_candidate_and_score(sel: BaseEntitySelector):
          if not getattr(sel, "_all_mention_results", None):
            return None
          return sel._all_mention_results[0]
        
        if USE_RAG:
            entity_selector.encode_and_rank_candidates(q.get('entities', []), ret_k=RET_K, append_k_per_entity=APPEND_K)
            top = _top_candidate_and_score(entity_selector)
            use_context = False
            iris = []
            if top:
              rank, iri, score, label = top
              if score >= 0.25:
                use_context = True
              if use_context:
                iris = [iri]
            axioms = []
            if iris:
                axioms = get_axiom_verbalisations(
                    iris,
                    max_verbalisations=5,
                    question_text=q.get("question",""),
                    min_overlap=2
                )
                
            q['axioms'] = axioms if axioms else []
        else:
            q['axioms'] = []

        # Deterministic display-order shuffle seed (does NOT change letters, only line order)
        shuffle_seed = per_question_seed(dataset_name, q, idx)
        q['shuffle'] = SHUFFLE_DISPLAY_ORDER
        q['seed'] = shuffle_seed

        # --- Generate ---
        # Both string and chat templates accept extra kwargs, so passing the whole question dict is fine.
        try:
            response_letter = generate_fn(tmpl_key, q, max_new_tokens=2)
        except Exception as e:
            # be defensive: treat parsing failures as wrong
            print(f"HIT EXCEPTION: {e}")
            response_letter = "?"
        
        # --- Score ---
        gold = q['answer']
        if response_letter == gold:
            correct_questions += 1
        else:
            incorrect_questions += 1
            incorrect_question_response.append({"question": q, "pred": response_letter})

    # total = len(mirage_questions_with_entity_mentions)
    total = 100
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")
    print(f"Incorrect Answer Distribution: {count_predictions(incorrect_question_response)}")
    print(f"Actual Answer Distribution: {count_actual_answers(incorrect_question_response)}")


Processing bioasq ...
Processing bioasq ... 


100%|██████████| 618/618 [00:00<00:00, 1111.87it/s]
100%|██████████| 100/100 [00:27<00:00,  3.58it/s]

Total correct:   62
Total incorrect: 38
Accuracy:        62.0%


Incorrect Answer Distribution: {'B': 12, 'A': 26}
Actual Answer Distribution: {'B': 26, 'A': 12}





In [4]:
USE_RAG = False
USE_CHAT_TEMPLATES = True     
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

# Main
datasets_within_mirage_to_process = ['mmlu']

for dataset_name in datasets_within_mirage_to_process:
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
        mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)

    # choose template keys once
    if USE_CHAT_TEMPLATES:
        tmpl_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"
        generate_fn = mistral_lm.chat_generate_mc_letter
    else:
        tmpl_key = "mirage_mcqa_axiom_rag" if USE_RAG else "mirage_mcqa_no_rag"
        generate_fn = mistral_lm.generate_mc_letter
        

    for idx, q in enumerate(tqdm(mirage_questions_with_entity_mentions[125:225])):
        
        def _top_candidate_and_score(sel: BaseEntitySelector):
          if not getattr(sel, "_all_mention_results", None):
            return None
          return sel._all_mention_results[0]
        
        if USE_RAG:
            entity_selector.encode_and_rank_candidates(q.get('entities', []), ret_k=RET_K, append_k_per_entity=APPEND_K)
            top = _top_candidate_and_score(entity_selector)
            use_context = False
            iris = []
            if top:
              rank, iri, score, label = top
              if score >= 0.25:
                use_context = True
              if use_context:
                iris = [iri]
            axioms = []
            if iris:
                axioms = get_axiom_verbalisations(
                    iris,
                    max_verbalisations=5,
                    question_text=q.get("question",""),
                    min_overlap=2
                )

            q['axioms'] = axioms if axioms else []
        else:
            q['axioms'] = []

        # Deterministic display-order shuffle seed (does NOT change letters, only line order)
        shuffle_seed = per_question_seed(dataset_name, q, idx)
        q['shuffle'] = SHUFFLE_DISPLAY_ORDER
        q['seed'] = shuffle_seed

        # --- Generate ---
        # Both string and chat templates accept extra kwargs, so passing the whole question dict is fine.
        try:
            response_letter = generate_fn(tmpl_key, q, max_new_tokens=2)
        except Exception as e:
            # be defensive: treat parsing failures as wrong
            print(f"HIT EXCEPTION: {e}")
            response_letter = "?"
        
        # --- Score ---
        gold = q['answer']
        if response_letter == gold:
            correct_questions += 1
        else:
            incorrect_questions += 1
            incorrect_question_response.append({"question": q, "pred": response_letter})

    # total = len(mirage_questions_with_entity_mentions)
    total = 100
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")
    print(f"Incorrect Answer Distribution: {count_predictions(incorrect_question_response)}")
    print(f"Actual Answer Distribution: {count_actual_answers(incorrect_question_response)}")


Processing mmlu ...
Processing mmlu ... 


100%|██████████| 1089/1089 [00:01<00:00, 758.64it/s]
100%|██████████| 100/100 [00:31<00:00,  3.17it/s]

Total correct:   50
Total incorrect: 50
Accuracy:        50.0%


Incorrect Answer Distribution: {'B': 13, 'A': 14, 'D': 18, 'C': 5}
Actual Answer Distribution: {'B': 20, 'C': 7, 'D': 14, 'A': 9}





In [5]:
USE_RAG = True
USE_CHAT_TEMPLATES = True     
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

# Main
datasets_within_mirage_to_process = ['pubmedqa']

for dataset_name in datasets_within_mirage_to_process:
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
        mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)

    # choose template keys once
    if USE_CHAT_TEMPLATES:
        tmpl_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"
        generate_fn = mistral_lm.chat_generate_mc_letter
    else:
        tmpl_key = "mirage_mcqa_axiom_rag" if USE_RAG else "mirage_mcqa_no_rag"
        generate_fn = mistral_lm.generate_mc_letter
        

    for idx, q in enumerate(tqdm(mirage_questions_with_entity_mentions[125:225])):
        
        def _top_candidate_and_score(sel: BaseEntitySelector):
          if not getattr(sel, "_all_mention_results", None):
            return None
          return sel._all_mention_results[0]
        
        if USE_RAG:
            entity_selector.encode_and_rank_candidates(q.get('entities', []), ret_k=RET_K, append_k_per_entity=APPEND_K)
            top = _top_candidate_and_score(entity_selector)
            use_context = False
            iris = []
            if top:
              rank, iri, score, label = top
              if score >= 0.25:
                use_context = True
              if use_context:
                iris = [iri]
            axioms = []
            if iris:
                axioms = get_axiom_verbalisations(
                    iris,
                    max_verbalisations=5,
                    question_text=q.get("question",""),
                    min_overlap=2
                )
                
            q['axioms'] = axioms if axioms else []
        else:
            q['axioms'] = []

        # Deterministic display-order shuffle seed (does NOT change letters, only line order)
        shuffle_seed = per_question_seed(dataset_name, q, idx)
        q['shuffle'] = SHUFFLE_DISPLAY_ORDER
        q['seed'] = shuffle_seed

        # --- Generate ---
        # Both string and chat templates accept extra kwargs, so passing the whole question dict is fine.
        try:
            response_letter = generate_fn(tmpl_key, q, max_new_tokens=2)
        except Exception as e:
            # be defensive: treat parsing failures as wrong
            print(f"HIT EXCEPTION: {e}")
            response_letter = "?"
        
        # --- Score ---
        gold = q['answer']
        if response_letter == gold:
            correct_questions += 1
        else:
            incorrect_questions += 1
            incorrect_question_response.append({"question": q, "pred": response_letter})

    # total = len(mirage_questions_with_entity_mentions)
    total = 100
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")
    print(f"Incorrect Answer Distribution: {count_predictions(incorrect_question_response)}")
    print(f"Actual Answer Distribution: {count_actual_answers(incorrect_question_response)}")


Processing pubmedqa ...
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1277.59it/s]
100%|██████████| 100/100 [00:27<00:00,  3.59it/s]

Total correct:   68
Total incorrect: 32
Accuracy:        68.0%


Incorrect Answer Distribution: {'B': 25, 'C': 7}
Actual Answer Distribution: {'A': 32}





In [6]:
USE_RAG = True
USE_CHAT_TEMPLATES = True     
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

# Main
datasets_within_mirage_to_process = ['bioasq']

for dataset_name in datasets_within_mirage_to_process:
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
        mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)

    # choose template keys once
    if USE_CHAT_TEMPLATES:
        tmpl_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"
        generate_fn = mistral_lm.chat_generate_mc_letter
    else:
        tmpl_key = "mirage_mcqa_axiom_rag" if USE_RAG else "mirage_mcqa_no_rag"
        generate_fn = mistral_lm.generate_mc_letter
        

    for idx, q in enumerate(tqdm(mirage_questions_with_entity_mentions[125:225])):
        
        def _top_candidate_and_score(sel: BaseEntitySelector):
          if not getattr(sel, "_all_mention_results", None):
            return None
          return sel._all_mention_results[0]
        
        if USE_RAG:
            entity_selector.encode_and_rank_candidates(q.get('entities', []), ret_k=RET_K, append_k_per_entity=APPEND_K)
            top = _top_candidate_and_score(entity_selector)
            use_context = False
            iris = []
            if top:
              rank, iri, score, label = top
              if score >= 0.25:
                use_context = True
              if use_context:
                iris = [iri]
            axioms = []
            if iris:
                axioms = get_axiom_verbalisations(
                    iris,
                    max_verbalisations=5,
                    question_text=q.get("question",""),
                    min_overlap=2
                )
                
            q['axioms'] = axioms if axioms else []
        else:
            q['axioms'] = []

        # Deterministic display-order shuffle seed (does NOT change letters, only line order)
        shuffle_seed = per_question_seed(dataset_name, q, idx)
        q['shuffle'] = SHUFFLE_DISPLAY_ORDER
        q['seed'] = shuffle_seed

        # --- Generate ---
        # Both string and chat templates accept extra kwargs, so passing the whole question dict is fine.
        try:
            response_letter = generate_fn(tmpl_key, q, max_new_tokens=2)
        except Exception as e:
            # be defensive: treat parsing failures as wrong
            print(f"HIT EXCEPTION: {e}")
            response_letter = "?"
        
        # --- Score ---
        gold = q['answer']
        if response_letter == gold:
            correct_questions += 1
        else:
            incorrect_questions += 1
            incorrect_question_response.append({"question": q, "pred": response_letter})

    # total = len(mirage_questions_with_entity_mentions)
    total = 100
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")
    print(f"Incorrect Answer Distribution: {count_predictions(incorrect_question_response)}")
    print(f"Actual Answer Distribution: {count_actual_answers(incorrect_question_response)}")


Processing bioasq ...
Processing bioasq ... 


100%|██████████| 618/618 [00:00<00:00, 683.41it/s] 
100%|██████████| 100/100 [00:22<00:00,  4.54it/s]

Total correct:   60
Total incorrect: 40
Accuracy:        60.0%


Incorrect Answer Distribution: {'B': 17, 'A': 23}
Actual Answer Distribution: {'B': 23, 'A': 17}





In [7]:
USE_RAG = True
USE_CHAT_TEMPLATES = True     
SHUFFLE_DISPLAY_ORDER = True  
FORCE_DTYPE_AUTO = True
RET_K = 100
APPEND_K = 10
TOP_K = 1

# Main
datasets_within_mirage_to_process = ['mmlu']

for dataset_name in datasets_within_mirage_to_process:
    print(f"\nProcessing {dataset_name} ...")

    mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
    mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets=[dataset_name])
    biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
    head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))
    mirage_questions_with_entity_mentions = merge_entity_mentions(
        mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets=[dataset_name]
    )

    correct_questions = 0
    incorrect_questions = 0
    incorrect_question_response = []

    entity_selector = SimilarityEntitySelector(sbert_ret_cosine_sim)

    # choose template keys once
    if USE_CHAT_TEMPLATES:
        tmpl_key = "mirage_mcqa_axiom_rag_chat" if USE_RAG else "mirage_mcqa_no_rag_chat"
        generate_fn = mistral_lm.chat_generate_mc_letter
    else:
        tmpl_key = "mirage_mcqa_axiom_rag" if USE_RAG else "mirage_mcqa_no_rag"
        generate_fn = mistral_lm.generate_mc_letter
        

    for idx, q in enumerate(tqdm(mirage_questions_with_entity_mentions[125:225])):
        
        def _top_candidate_and_score(sel: BaseEntitySelector):
          if not getattr(sel, "_all_mention_results", None):
            return None
          return sel._all_mention_results[0]
        
        if USE_RAG:
            entity_selector.encode_and_rank_candidates(q.get('entities', []), ret_k=RET_K, append_k_per_entity=APPEND_K)
            top = _top_candidate_and_score(entity_selector)
            use_context = False
            iris = []
            if top:
              rank, iri, score, label = top
              if score >= 0.25:
                use_context = True
              if use_context:
                iris = [iri]
            axioms = []
            if iris:
                axioms = get_axiom_verbalisations(
                    iris,
                    max_verbalisations=5,
                    question_text=q.get("question",""),
                    min_overlap=2
                )

            q['axioms'] = axioms if axioms else []
        else:
            q['axioms'] = []

        # Deterministic display-order shuffle seed (does NOT change letters, only line order)
        shuffle_seed = per_question_seed(dataset_name, q, idx)
        q['shuffle'] = SHUFFLE_DISPLAY_ORDER
        q['seed'] = shuffle_seed

        # --- Generate ---
        # Both string and chat templates accept extra kwargs, so passing the whole question dict is fine.
        try:
            response_letter = generate_fn(tmpl_key, q, max_new_tokens=2)
        except Exception as e:
            # be defensive: treat parsing failures as wrong
            print(f"HIT EXCEPTION: {e}")
            response_letter = "?"
        
        # --- Score ---
        gold = q['answer']
        if response_letter == gold:
            correct_questions += 1
        else:
            incorrect_questions += 1
            incorrect_question_response.append({"question": q, "pred": response_letter})

    # total = len(mirage_questions_with_entity_mentions)
    total = 100
    acc = round((correct_questions / total) * 100, 2)
    print(f"Total correct:   {correct_questions}")
    print(f"Total incorrect: {incorrect_questions}")
    print(f"Accuracy:        {acc}%\n\n")
    print(f"Incorrect Answer Distribution: {count_predictions(incorrect_question_response)}")
    print(f"Actual Answer Distribution: {count_actual_answers(incorrect_question_response)}")


Processing mmlu ...
Processing mmlu ... 


100%|██████████| 1089/1089 [00:00<00:00, 1273.17it/s]
100%|██████████| 100/100 [00:28<00:00,  3.51it/s]

Total correct:   52
Total incorrect: 48
Accuracy:        52.0%


Incorrect Answer Distribution: {'B': 10, 'A': 11, 'D': 20, 'C': 7}
Actual Answer Distribution: {'B': 20, 'C': 6, 'D': 13, 'A': 9}





In [None]:
import torch
from pathlib import Path
from data_utils import load_json, xs_of_all_questions, merge_entity_mentions
from llm_utils_improvements import (
    MistralLLM,
    prompt_template_no_rag, prompt_template_with_axioms,
    chat_prompt_template_no_rag, chat_prompt_template_with_axioms,
    BaseEntitySelector,
    SubsumptionEntitySelector,
    ApproximateNearestNeighbourEntitySelector,
    SimilarityEntitySelector
)
from retrievers import HiTRetriever, SBERTRetriever
from math_functools import entity_subsumption, batch_poincare_dist_with_adaptive_curv_k, batch_cosine_similarity
from tqdm import tqdm

#### Config (quick A/B toggles)
USE_RAG = True
USE_CHAT_TEMPLATES = True          # << switch this to False to A/B against the old path
SHUFFLE_DISPLAY_ORDER = True       # shuffles how choices are displayed (letters stay tied to the same text)
FORCE_DTYPE_AUTO = True            # set True if bf16 isn’t supported

# Data
mirage_benchmark = load_json(Path("./data/MIRAGE/benchmark.json"))
biomedical_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"))
head_entity_mentions = load_json(Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json"))

allowable_datasets = ['medqa', 'medmcqa', 'pubmedqa', 'bioasq', 'mmlu']
mirage_questions = xs_of_all_questions(mirage_benchmark, allowable_datasets)
mirage_questions_with_entity_mentions = merge_entity_mentions(
    mirage_benchmark, biomedical_entity_mentions, head_entity_mentions, allowable_datasets
)

# Model
mistral_llm_version = "mistralai/Mistral-7B-Instruct-v0.1"
mistral_lm = MistralLLM(mistral_llm_version)

dtype = "auto" if FORCE_DTYPE_AUTO else torch.bfloat16

mistral_lm.load_tokenizer(use_fast=True).load_model(
    device_map="auto",
    torch_dtype="auto",
    low_cpu_mem_usage=True
).register_generation_config(
    do_sample=False,
    num_beams=1,
    pad_token_id=mistral_lm._tokenizer.pad_token_id,
    eos_token_id=mistral_lm._tokenizer.eos_token_id
)

# Register BOTH string and chat templates so we can A/B with a flag
mistral_lm.register_prompt_template_fn("mirage_mcqa_no_rag", prompt_template_no_rag)
mistral_lm.register_prompt_template_fn("mirage_mcqa_axiom_rag", prompt_template_with_axioms)
mistral_lm.register_prompt_template_fn("mirage_mcqa_no_rag_chat", chat_prompt_template_no_rag)
mistral_lm.register_prompt_template_fn("mirage_mcqa_axiom_rag_chat", chat_prompt_template_with_axioms)

print("Loaded!")
print(f"LLM: {mistral_llm_version} | RAG: {USE_RAG} | Chat templates: {USE_CHAT_TEMPLATES}")

# HiT retriever (for RAG)
embeddings_dir = "./embeddings"
common_map = Path("./embeddings/axiom-mappings.json")
common_verbalisations = Path("./embeddings/axiom-verbalisations.json")
hit_SNOMED25_model_path = Path('./models/HiT-mixed-SNOMED-25/final')

hit_ret_snomed_25_w_ent_sub = HiTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=hit_SNOMED25_model_path,
    score_fn=entity_subsumption
)

hit_ret_snomed_25_w_hyp_dist = HiTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/hit-snomed-25-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=hit_SNOMED25_model_path,
    score_fn=batch_poincare_dist_with_adaptive_curv_k
)

sbert_ret_cosine_sim = SBERTRetriever(
    embeddings_fp=Path(f"{embeddings_dir}/sbert-plm-embeddings.npy"),
    meta_map_fp=common_map,
    verbalisations_fp=common_verbalisations,
    model_fp=hit_SNOMED25_model_path,
    score_fn=batch_cosine_similarity
)

import re

snomed_concept_information: dict = load_json(Path("./data/snomed_axioms.json"))

COMMON_STOPWORDS = {
  "the","a","an","and","or","of","to","in","on","for","with","without","from",
  "by","at","as","is","are","was","were","be","been","being","that","which",
  "this","these","those","it","its","their","there","then","than","such"
}

SNOMED_CONCEPTS_NEAR_TOP = {
  "clinical finding","finding","disease","disorder","procedure","event", "body structure", "organism",
  "substance", "situation with explicit context","qualifier value","morphologic abnormality"
}

def lexical_repr_without_stopwords(s: str) -> set[str]:
  """set difference between the input string and 'common stopwords'"""
  return set(re.findall(r"[a-z]+", s.lower())) - COMMON_STOPWORDS

def lexical_overlap_score(input: str, reference_set: set[str]) -> int:
  """set (lexical) intersection between input_string and the reference vocab (typically, the question)"""
  return len(lexical_repr_without_stopwords(input) & reference_set)

def is_broad_concept(rdfs_label: str) -> bool:
  """basic lexical operation to check whether an rdfs_label is a high level SNOMED CT concept"""
  return (rdfs_label.split(" (", 1)[0]).lower() in SNOMED_CONCEPTS_NEAR_TOP

def select_axioms_for_prompt(axioms: list[str], question_text: str, *, min_overlap: int = 1) -> list[str]:
  """Return axioms that exceed the lexical overlap of `min_overlap` with the `question_text`"""
  clean_question_text = lexical_repr_without_stopwords(question_text)
  retained_axioms = []
  for this_axiom in axioms:
    lex_score = lexical_overlap_score(this_axiom, clean_question_text)
    if lex_score >= min_overlap:
      retained_axioms.append(this_axiom)
  return retained_axioms

def get_axiom_verbalisations(iris: list[str], *, max_verbalisations: int = 3, question_text: str, min_overlap: int = 1) -> list[str]:
  returnable_verbalisations = []
  for iri in iris: # skip any 'owl:Thing'(s)
    if iri == "owl:Thing": # Though these shouldn't be in IRI
      continue
    snomed_concept = snomed_concept_information[iri]
    label = snomed_concept['label'] # rdfs_label
    if is_broad_concept(label):
      continue
    verbalisations = snomed_concept['verbalization']
    subclass_axioms = verbalisations['subclass_of'][:max_verbalisations]
    equiv_axioms = verbalisations['equivalent_to'][:max_verbalisations]
    # this is pretty horrible, TODO: make this a little less, difficult to read ...
    complete_subclass_verbalisations = [f"{label} {verb_str}" for verb_str in subclass_axioms] if subclass_axioms is not [] else []
    complete_equiv_verbalisations = [f"{label} {verb_str}" for verb_str in equiv_axioms] if equiv_axioms is not [] else []
    returnable_verbalisations = [*complete_subclass_verbalisations, *complete_equiv_verbalisations]
  return select_axioms_for_prompt(
    returnable_verbalisations,
    question_text,
    min_overlap=min_overlap
  )

def per_question_seed(dataset: str, question_id: dict, idx: int) -> int:
    question_id = q.get("id", idx)
    return (hash((dataset, str(question_id))) & 0xFFFFFFFF)

def count_predictions(incorrect_questions):
  selected_options = set()
  predictions = []
  for response in incorrect_questions:
    selected_options.add(response['pred'])
    predictions.append(response['pred'])
  counts = {}
  for selected_option in selected_options:
    counts[selected_option] = 0
  for prediction in predictions:
    counts[prediction] += 1
  return counts

def count_actual_answers(incorrect_questions):
  answer_set = set()
  answers = []
  for response in incorrect_questions:
    answer_set.add(response['question']['answer'])
    answers.append(response['question']['answer'])
  counts = {}
  for possible_answer in answer_set:
    counts[possible_answer] = 0
  for answer in answers:
    counts[answer] += 1
  return counts

In [9]:
def count_predictions(incorrect_questions):
  selected_options = set()
  predictions = []
  for response in incorrect_questions:
    selected_options.add(response['predicition'])
    predictions.append(response['predicition'])
  counts = {}
  for selected_option in selected_options:
    counts[selected_option] = 0
  for prediction in predictions:
    counts[prediction] += 1
  return counts

def count_actual_answers(incorrect_questions):
  answer_set = set()
  answers = []
  for response in incorrect_questions:
    answer_set.add(response['question']['answer'])
    answers.append(response['question']['answer'])
  counts = {}
  for possible_answer in answer_set:
    counts[possible_answer] = 0
  for answer in answers:
    counts[answer] += 1
  return counts

In [1]:
import torch
from pathlib import Path
from data_utils import (
  load_json,
  merge_entity_mentions,
)
from llm_utils_improvements import (
  MistralLLM,
  BaseEntitySelector,
  SimilarityEntitySelector,
  ApproximateNearestNeighbourEntitySelector,
  SubsumptionEntitySelector,
  permute_mcqa_options,
  chat_prompt_template_no_rag,
  chat_prompt_template_with_axioms
)
from retrievers import (
  BaseRetriever,
  SBERTRetriever,
  HiTRetriever,
  OnTRetriever
)
from math_functools import (
  batch_cosine_similarity,
  batch_poincare_dist_with_adaptive_curv_k,
  entity_subsumption,
  concept_subsumption
)
from copy import copy, deepcopy
from tqdm import tqdm
import numpy as np
import random
import re

snomed_concept_information: dict = load_json(Path("./data/snomed_axioms.json"))

COMMON_STOPWORDS = {
  "the","a","an","and","or","of","to","in","on","for","with","without","from",
  "by","at","as","is","are","was","were","be","been","being","that","which",
  "this","these","those","it","its","their","there","then","than","such"
}

SNOMED_CONCEPTS_NEAR_TOP = {
  "clinical finding","finding","disease","disorder","procedure","event", "body structure", "organism",
  "substance", "situation with explicit context","qualifier value","morphologic abnormality"
}

def lexical_repr_without_stopwords(s: str) -> set[str]:
  """set difference between the input string and 'common stopwords'"""
  return set(re.findall(r"[a-z]+", s.lower())) - COMMON_STOPWORDS

def lexical_overlap_score(input: str, reference_set: set[str]) -> int:
  """set (lexical) intersection between input_string and the reference vocab (typically, the question)"""
  return len(lexical_repr_without_stopwords(input) & reference_set)

def is_broad_concept(rdfs_label: str) -> bool:
  """basic lexical operation to check whether an rdfs_label is a high level SNOMED CT concept"""
  return (rdfs_label.split(" (", 1)[0]).lower() in SNOMED_CONCEPTS_NEAR_TOP

def select_axioms_for_prompt(axioms: list[str], question_text: str, *, min_overlap: int = 1) -> list[str]:
  """Return axioms that exceed the lexical overlap of `min_overlap` with the `question_text`"""
  clean_question_text = lexical_repr_without_stopwords(question_text)
  retained_axioms = []
  for this_axiom in axioms:
    lex_score = lexical_overlap_score(this_axiom, clean_question_text)
    if lex_score >= min_overlap:
      retained_axioms.append(this_axiom)
  return retained_axioms

def get_axiom_verbalisations(iris: list[str], *, max_verbalisations: int = 3, question_text: str, min_overlap: int = 1) -> list[str]:
  """fetch axiom verbalisations from a file that has been pre-generated (takes too long at run-time)"""
  returnable_verbalisations = []
  for iri in iris: # skip any 'owl:Thing'(s)
    if iri == "owl:Thing": # Though these shouldn't be in IRI
      continue
    snomed_concept = snomed_concept_information[iri]
    label = snomed_concept['label'] # rdfs_label
    if is_broad_concept(label):
      continue
    verbalisations = snomed_concept['verbalization']
    subclass_axioms = verbalisations['subclass_of'][:max_verbalisations]
    equiv_axioms = verbalisations['equivalent_to'][:max_verbalisations]
    # this is pretty horrible, TODO: make this a little less, difficult to read ...
    complete_subclass_verbalisations = [f"{label} {verb_str}" for verb_str in subclass_axioms] if subclass_axioms is not [] else []
    complete_equiv_verbalisations = [f"{label} {verb_str}" for verb_str in equiv_axioms] if equiv_axioms is not [] else []
    returnable_verbalisations = [*complete_subclass_verbalisations, *complete_equiv_verbalisations]
  return select_axioms_for_prompt(
    returnable_verbalisations,
    question_text,
    min_overlap=min_overlap
  )

def question_seed(dataset: str, q: dict, idx: int) -> int:
    qid = q.get("id", idx)
    return (hash((dataset, str(qid))) & 0xFFFFFFFF)

def count_predictions(incorrect_questions):
  selected_options = set()
  predictions = []
  for response in incorrect_questions:
    selected_options.add(response['predicition'])
    predictions.append(response['predicition'])
  counts = {}
  for selected_option in selected_options:
    counts[selected_option] = 0
  for prediction in predictions:
    counts[prediction] += 1
  return counts

def count_actual_answers(incorrect_questions):
  answer_set = set()
  answers = []
  for response in incorrect_questions:
    answer_set.add(response['question']['answer'])
    answers.append(response['question']['answer'])
  counts = {}
  for possible_answer in answer_set:
    counts[possible_answer] = 0
  for answer in answers:
    counts[answer] += 1
  return counts


class QATestHarness:
    
  _use_rag: bool
  _shuffle_question_options: bool
  _permute_question_options: bool
  _retrieval_k: int
  _append_k: int
  _top_k: int
  _benchmark_data: dict
  _biomedical_entity_mentions: dict
  _head_entity_mentions: dict

  _allowable_datasets = ['medqa', 'medmcqa', 'pubmedqa', 'bioasq', 'mmlu']

  _retriever: BaseRetriever
  _entity_selector: BaseEntitySelector

  _correct_questions: list[dict]
  _incorrect_questions: list[dict]

  _shuffle_seed: int

  _llm: MistralLLM

  def __init__(self, benchmark_data_fp: Path, biomedical_mentions_fp: Path, head_mentions_fp: Path):
    self._benchmark_data = load_json(benchmark_data_fp)
    self._biomedical_entity_mentions = load_json(biomedical_mentions_fp)
    self._head_entity_mentions = load_json(head_mentions_fp)
    self._correct_questions = []
    self._incorrect_questions = []
  
  @classmethod
  def set_random_seed(cls, seed_value: int = 42):
    """apply a random seed value to python/np/torch random utils"""
    random.seed(seed_value) # python
    np.random.seed(seed_value) # numpy
    if torch.cuda.is_available():
      torch.cuda.manual_seed(seed_value) # single GPU, fails silently
      torch.cuda.manual_seed_all(seed_value) # multi GPU, fails silently
  
  def set_use_rag(self, use_rag: bool):
    """set to true if tests should run with prompt enrichment (fetching axiom verbalisations)"""
    self._use_rag = use_rag
    return self
  
  def register_retriever(self, retriever: BaseRetriever):
    """register a retriever to inject axiom verbalisations into the prompt/context during question answering (required `if use_rag == true`)"""
    if not self._use_rag:
      raise ValueError("You can only register a retriever when `use_rag` is set to True.")
    # else:
    self._retriever = retriever
    return self

  def register_entity_selector(self, selector: BaseEntitySelector):
    """register an entity selector cls to employ a selection criteria during entity and axiom retrieval (required `if use_rag == true`)"""
    if not self._use_rag:
      raise ValueError("You can only register an entity selector when `use_rag` is set to True.")
    # else:
    self._entity_selector = selector
    return self
  
  def register_llm(self, llm: MistralLLM):
    """register a specific LLM for when this harness is run"""
    self._llm = llm
    return self
  
  def set_retrieval_k(self, k: int):
    """the k_threhold for the retriever (cut-off applied during retrieval of each mention)"""
    self._retrieval_k = k
    return self
  
  def set_append_k(self, k: int):
    """the number of entities to select during retrieval (for each entity mention)"""
    self._append_k = k
    return self
  
  def set_top_k(self, k: int):
    """the number of entities to select from the entire pool of retrieved entities (ranked by score)"""
    self._top_k = k
    return self
  
  def set_shuffle_question_options(self, shuffle: bool):
    """set to true if the order of the (option, answer) should be randomised during question presentation"""
    self._shuffle_question_options = shuffle
    return self

  def set_permute_question_options(self, permute: bool):
    """set to true if the arrangement of option -> answer mappings should undergo permutation prior to question presentation"""
    self._permute_question_options = permute
    return self
  
  def run(self, dataset_name: str):
    """run the test harness across a specific dataset from the provided benchmark"""
    if dataset_name not in self._allowable_datasets:
      raise ValueError(f"You're trying to run the test harness agaisnt: {dataset_name}, which is not registered as an allowable dataset.")
    print(f"Using RAG: {self._use_rag}")
    questions_with_entity_mentions = merge_entity_mentions(
      self._benchmark_data, self._biomedical_entity_mentions, self._head_entity_mentions, allowable_datasets=[dataset_name]
    )
    # ensure questions are cleared
    self._correct_questions = []
    self._incorrect_questions = []
    print(f"Processing {dataset_name} ...")
   
    for question_index, question in enumerate(tqdm(questions_with_entity_mentions)):

      # options for shuffling the arrangement of the answers
      question['shuffle'] = self._shuffle_question_options
      question['seed'] = question_seed(dataset_name, question, question_index)
      
      question['axioms'] = []

      if self._permute_question_options: # remaps the option -> answer mapping (implemented after observed LM bias towards selecting `A`)
        permutated_options, new_answer_key = permute_mcqa_options(question['options'], question['answer'])
        question['original_options'] = copy(question['options'])
        question['original_answer'] = copy(question['answer'])
        del question['options']
        del question['answer']
        question['options'] = permutated_options
        question['answer'] = new_answer_key

      if self._use_rag:
        self._entity_selector.encode_and_rank_candidates(question['entities'], ret_k=self._retrieval_k, append_k_per_entity=self._append_k) # type: ignore
        top_candidate = self._entity_selector._all_mention_results[0] if len(self._entity_selector._all_mention_results) > 0 else []
        if top_candidate:
          rank, iri, score, label = top_candidate
          # filters out some noise \w lexical gating (verb_str \cap question_str) >= min_overlap
          # required for adding large numbers of verbalisations into the prompt context
          axioms = get_axiom_verbalisations(
            [iri],
            max_verbalisations=5,
            question_text=question['question'],
            min_overlap=2
          )
          question['axioms'] = axioms if axioms else []
        # end: if
      # end: if

      response_letter = self._llm.generate_mcqa_letter(
        "mirage_mcqa_axiom_rag_chat" if self._use_rag else "mirage_mcqa_no_rag_chat",
        question,
        max_new_tokens=2
      )

      question['predication'] = response_letter

      if response_letter == question['answer']:
        self._correct_questions.append(question)
      else:
        self._incorrect_questions.append(question)

    total = len(questions_with_entity_mentions)
    acc = round((len(self._correct_questions) / total) * 100, 2)
    print(f"Total correct:   {len(self._correct_questions)}")
    print(f"Total incorrect: {len(self._incorrect_questions)}")
    print(f"Accuracy:        {acc}%\n\n")

  def run_multiple(self, datasets: list[str]):
    """run tests for a subset of the datasets associated with the provided benchmark"""
    for dataset in datasets:
      self.run(dataset)
      

In [2]:
# ------------------------------------
# LLM options:
# ------------------------------------
# "mistralai/Mistral-7B-Instruct-v0.1"
# "mistralai/Mistral-7B-Instruct-v0.3"
# "BioMistral/BioMistral-7B"
# ------------------------------------

LLM_MODEL_ID = "BioMistral/BioMistral-7B"
SEED = 42

# instanciate a retriever
sbert_ret = SBERTRetriever(
  embeddings_fp=Path(f"./embeddings/sbert-plm-embeddings.npy"),
  meta_map_fp=Path("./embeddings/axiom-mappings.json"),
  verbalisations_fp=Path("./embeddings/axiom-verbalisations.json"),
  model_str="all-MiniLM-L12-v2",
  score_fn=batch_cosine_similarity
)

# and an entity selector
sbert_entity_selector = SimilarityEntitySelector(sbert_ret)

# and an LLM
mistral_llm = MistralLLM(LLM_MODEL_ID)

mistral_llm.load_tokenizer(use_fast=True).load_model(
    device_map="auto",
    torch_dtype="auto",
    low_cpu_mem_usage=True
).register_generation_config(
    do_sample=False,
    num_beams=1,
    pad_token_id=mistral_llm._tokenizer.pad_token_id,
    eos_token_id=mistral_llm._tokenizer.eos_token_id
)
mistral_llm.register_prompt_template_fn("mirage_mcqa_no_rag_chat", chat_prompt_template_no_rag)
mistral_llm.register_prompt_template_fn("mirage_mcqa_axiom_rag_chat", chat_prompt_template_with_axioms)

# ideally, we would load from config (TODO: load cfgNode \w yacs or hydra)
tests = QATestHarness(
  Path("./data/MIRAGE/benchmark.json"), 
  Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"), 
  Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json")
).set_shuffle_question_options(True).set_permute_question_options(
  True
).set_retrieval_k(100).set_append_k(10).set_top_k(1).set_use_rag(True).register_retriever(
  sbert_ret
).register_entity_selector(
  sbert_entity_selector
).register_llm(
  mistral_llm
)

# quick test for reproducability:

QATestHarness.set_random_seed(SEED)

tests.set_use_rag(False)
tests.run_multiple(['pubmedqa', 'bioasq', 'mmlu'])

QATestHarness.set_random_seed(SEED)

tests.set_use_rag(False)
tests.run_multiple(['pubmedqa', 'bioasq', 'mmlu'])

# quick test to compare results to:

QATestHarness.set_random_seed(SEED)

tests.set_use_rag(True)
tests.run_multiple(['pubmedqa', 'bioasq', 'mmlu'])

QATestHarness.set_random_seed(SEED)

tests.set_use_rag(True)
tests.run_multiple(['pubmedqa', 'bioasq', 'mmlu'])

Using RAG: False
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1541.08it/s]


Processing pubmedqa ...


100%|██████████| 500/500 [02:25<00:00,  3.43it/s]


Total correct:   229
Total incorrect: 271
Accuracy:        45.8%


Using RAG: False
Processing bioasq ... 


100%|██████████| 618/618 [00:01<00:00, 527.06it/s]


Processing bioasq ...


100%|██████████| 618/618 [03:16<00:00,  3.14it/s]


Total correct:   375
Total incorrect: 243
Accuracy:        60.68%


Using RAG: False
Processing mmlu ... 


100%|██████████| 1089/1089 [00:02<00:00, 515.23it/s]


Processing mmlu ...


100%|██████████| 1089/1089 [05:38<00:00,  3.22it/s]


Total correct:   540
Total incorrect: 549
Accuracy:        49.59%


Using RAG: False
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 618.95it/s] 


Processing pubmedqa ...


100%|██████████| 500/500 [02:17<00:00,  3.64it/s]


Total correct:   221
Total incorrect: 279
Accuracy:        44.2%


Using RAG: False
Processing bioasq ... 


100%|██████████| 618/618 [00:00<00:00, 643.22it/s]


Processing bioasq ...


100%|██████████| 618/618 [03:01<00:00,  3.40it/s]


Total correct:   361
Total incorrect: 257
Accuracy:        58.41%


Using RAG: False
Processing mmlu ... 


100%|██████████| 1089/1089 [00:01<00:00, 601.30it/s]


Processing mmlu ...


100%|██████████| 1089/1089 [05:24<00:00,  3.36it/s]


Total correct:   540
Total incorrect: 549
Accuracy:        49.59%


Using RAG: True
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1564.68it/s]


Processing pubmedqa ...


100%|██████████| 500/500 [02:48<00:00,  2.97it/s]


Total correct:   230
Total incorrect: 270
Accuracy:        46.0%


Using RAG: True
Processing bioasq ... 


100%|██████████| 618/618 [00:00<00:00, 1419.38it/s]


Processing bioasq ...


100%|██████████| 618/618 [03:33<00:00,  2.89it/s]


Total correct:   336
Total incorrect: 282
Accuracy:        54.37%


Using RAG: True
Processing mmlu ... 


100%|██████████| 1089/1089 [00:01<00:00, 590.33it/s]


Processing mmlu ...


100%|██████████| 1089/1089 [06:39<00:00,  2.72it/s]


Total correct:   537
Total incorrect: 552
Accuracy:        49.31%


Using RAG: True
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1595.54it/s]


Processing pubmedqa ...


100%|██████████| 500/500 [02:31<00:00,  3.30it/s]


Total correct:   231
Total incorrect: 269
Accuracy:        46.2%


Using RAG: True
Processing bioasq ... 


100%|██████████| 618/618 [00:00<00:00, 659.97it/s] 


Processing bioasq ...


100%|██████████| 618/618 [03:17<00:00,  3.13it/s]


Total correct:   352
Total incorrect: 266
Accuracy:        56.96%


Using RAG: True
Processing mmlu ... 


100%|██████████| 1089/1089 [00:01<00:00, 912.36it/s]


Processing mmlu ...


100%|██████████| 1089/1089 [07:08<00:00,  2.54it/s]

Total correct:   549
Total incorrect: 540
Accuracy:        50.41%







In [3]:
tests._correct_questions

[{'question': 'Which of the following best describes the structure that collects urine in the body?',
  'entities': [{'entity_literal': 'urine',
    'start_position': 66,
    'end_position': 71,
    'entity_type': 'ORGANISM_SUBSTANCE'},
   {'entity_literal': 'Which of the following best',
    'start_position': 0,
    'end_position': 27,
    'entity_type': 'HEAD'}],
  'shuffle': True,
  'seed': 1260436531,
  'axioms': [],
  'original_options': {'A': 'Bladder',
   'B': 'Kidney',
   'C': 'Ureter',
   'D': 'Urethra'},
  'original_answer': 'A',
  'predication': 'A',
  'options': {'A': 'Bladder', 'B': 'Urethra', 'C': 'Kidney', 'D': 'Ureter'},
  'answer': 'A'},
 {'question': 'Which of the following describes the cluster of blood capillaries found in each nephron in the kidney?',
  'entities': [{'entity_literal': 'blood capillaries',
    'start_position': 48,
    'end_position': 65,
    'entity_type': 'MULTI_TISSUE_STRUCTURE'},
   {'entity_literal': 'nephron',
    'start_position': 80,
    'en

In [None]:
# BASE (No RAG) # 1

# MedQA:    497  /  1273  =  0.3904
# MedMCQA:  1582 /  4183  =  0.3782
# PubMedQA: 218  /  500   =  0.4360
# BioASQ:   336  /  618   =  0.5437
# MMLU:     560  /  1089  =  0.5142


# SBERT

# MedQA:    497 / 1273


In [None]:
# "mistralai/Mistral-7B-Instruct-v0.3"
# "BioMistral/BioMistral-7B"

LLM_MODEL_ID = "BioMistral/BioMistral-7B"

sbert_ret = SBERTRetriever(
  embeddings_fp=Path(f"./embeddings/sbert-plm-embeddings.npy"),
  meta_map_fp=Path("./embeddings/axiom-mappings.json"),
  verbalisations_fp=Path("./embeddings/axiom-verbalisations.json"),
  model_str="all-MiniLM-L12-v2",
  score_fn=batch_cosine_similarity
)
sbert_entity_selector = SimilarityEntitySelector(sbert_ret)
mistral_llm = MistralLLM(LLM_MODEL_ID)

mistral_llm.load_tokenizer(use_fast=True).load_model(
    device_map="auto",
    torch_dtype="auto",
    low_cpu_mem_usage=True
).register_generation_config(
    do_sample=False,
    num_beams=1,
    pad_token_id=mistral_llm._tokenizer.pad_token_id,
    eos_token_id=mistral_llm._tokenizer.eos_token_id
)

mistral_llm.register_prompt_template_fn("mirage_mcqa_no_rag_chat", chat_prompt_template_no_rag)
mistral_llm.register_prompt_template_fn("mirage_mcqa_axiom_rag_chat", chat_prompt_template_with_axioms)

tests = QATestHarness(
  Path("./data/MIRAGE/benchmark.json"), 
  Path("./data/MIRAGE/benchmark-questions-entities-BIOMED-bionlp13cg.json"), 
  Path("./data/MIRAGE/benchmark-questions-entities-HEAD.json")
).set_shuffle_question_options(True).set_permute_question_options(
  True
).set_retrieval_k(100).set_append_k(10).set_top_k(1).set_use_rag(True).register_retriever(
  sbert_ret
).register_entity_selector(
  sbert_entity_selector
).register_llm(
  mistral_llm
)

SEED = 42

QATestHarness.set_random_seed(SEED)
tests.run("pubmedqa")

Using RAG: True
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1605.99it/s]


Processing pubmedqa ...


100%|██████████| 500/500 [02:44<00:00,  3.05it/s]

Total correct:   221
Total incorrect: 279
Accuracy:        44.2%







In [21]:
tests.run("pubmedqa")

Using RAG: True
Processing pubmedqa ... 


100%|██████████| 500/500 [00:00<00:00, 1420.48it/s]


Processing pubmedqa ...


100%|██████████| 500/500 [02:11<00:00,  3.80it/s]

Total correct:   277
Total incorrect: 223
Accuracy:        55.4%







In [7]:
tests._incorrect_questions

[{'question': 'Transesophageal echocardiographic assessment of left ventricular function in brain-dead patients: are marginally acceptable hearts suitable for transplantation?',
  'options': {'A': 'yes', 'B': 'no', 'C': 'maybe'},
  'answer': 'A',
  'PMID': [8910148],
  'entities': [{'entity_literal': 'left ventricular',
    'start_position': 48,
    'end_position': 64,
    'entity_type': 'MULTI_TISSUE_STRUCTURE'},
   {'entity_literal': 'brain-dead patients',
    'start_position': 77,
    'end_position': 96,
    'entity_type': 'CANCER'},
   {'entity_literal': 'hearts',
    'start_position': 124,
    'end_position': 130,
    'entity_type': 'ORGANISM'},
   {'entity_literal': 'Transesophageal echocardiographic assessment of left ventricular function in brain-dead patients:',
    'start_position': 0,
    'end_position': 97,
    'entity_type': 'HEAD'}],
  'axioms': [],
  'shuffle': True,
  'seed': 42,
  'predication': 'C'},
 {'question': 'Does strategy training reduce age-related deficits in

In [14]:
def count_predictions(incorrect_questions):
  selected_options = set()
  predictions = []
  for response in incorrect_questions:
    selected_options.add(response['predication'])
    predictions.append(response['predication'])
  counts = {}
  for selected_option in selected_options:
    counts[selected_option] = 0
  for prediction in predictions:
    counts[prediction] += 1
  return counts

def count_actual_answers(incorrect_questions):
  answer_set = set()
  answers = []
  for response in incorrect_questions:
    answer_set.add(response['question']['answer'])
    answers.append(response['question']['answer'])
  counts = {}
  for possible_answer in answer_set:
    counts[possible_answer] = 0
  for answer in answers:
    counts[answer] += 1
  return counts