In [1]:
import torch

from transformers import AutoModelForCausalLM

from data.q_and_a.train_and_eval import TrainAndEval
from data.q_and_a.eval_with_answers import EvalWithAnswers

from models_.building.pubmed_tokenizer import  load_query_tokenizer as load_tokenizer

from data.pubmed.from_json import FromJsonDataset
from data.pubmed.contents import ContentsDataset


from storage.faiss_ import FaissStorage

from rag.tokenization.llama import build_tokenizer_function
from rag.quering import build_querier

from q_and_a.forward import build_forwarder
from q_and_a.prompts import prompt
from q_and_a.picking.from_logits import build_from_logits
from q_and_a.eval import evaluate

In [2]:
# Default to the CPU and switch to the GPU only when torch reports one is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cuda'

# Loading data: the augmentation corpus and the question-answer sets

In [3]:
# Question/answer splits: the training and evaluation files go through the same loader.
train = TrainAndEval("../../data/pubmed_QA_train.json")
evaluationData = TrainAndEval("../../data/pubmed_QA_eval.json")
# The evaluation split is wrapped once more before it is handed to the evaluation run.
evaluateWithAnswers = EvalWithAnswers(evaluationData)

# Augmentation corpus: the JSON-backed dataset wrapped by ContentsDataset in a single
# expression, so the name is bound only once.
augmented_data = ContentsDataset(
    FromJsonDataset(json_file="../../data/pubmed_500K.json")
)

In [4]:
# Sanity check: print the first evaluation record to inspect its fields.
print(evaluateWithAnswers[0])

{'id': 'pubmed23n0012_5208', 'excerpt': 'Temporal changes in medial basal hypothalamic LH-RH correlated with plasma LH during the rat estrous cycle and following electrochemical stimulation of the medial preoptic area in pentobarbital-treated proestrous rats. In the present studies we have simultaneously measured changes in medial basal hypothalamic (MBH) leutenizing hormone-releasing hormone (LH-RH) and in plasma LH by radioimmunoassay in female rats at various hours during the 4-day estrous cycle and under experimental conditions known to alter pituitary LH secretion. In groups of rats decapitated at 12.00 h and 15.00 h on estrus and diestrus, plasma LH remained at basal levels (5-8 ng/ml) and MBH-LH-RH concentrations showed average steady state concentrations of 2231 +/- 205 pg/mg. On the day of proestrus hourly measurements of MBH-LH-RH between 12.00 h and 21.00 h suggested rhythmic rises and falls in the decapeptide concomitant with rises and falls in plasma LH. In a second group 

# Building the RAG system

In [5]:
# Query tokenizer and the tokenization callable derived from it.
tokenizer = load_tokenizer()
tokenizer_fn = build_tokenizer_function(tokenizer)

# Vector store populated from the saved index file.
# NOTE(review): 800 presumably has to match the dimensionality of the vectors in the
# saved index — confirm against the code that built the index.
storage = FaissStorage(dimension=800)
storage.load("../../outputs/store/pubmed_500K.index")

# The querier is built from the store, the augmentation corpus and the tokenization callable.
querier = build_querier(storage, augmented_data, tokenizer_fn)

In [6]:
# Display the tokenizer's repr to check which checkpoint and special tokens were loaded.
tokenizer

BertTokenizerFast(name_or_path='ncbi/MedCPT-Query-Encoder', vocab_size=30522, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	4: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

# Building the question-answering system

In [7]:
MODEL_NAME = "meta-llama/Llama-3.2-1B"

# Load the causal language model with float16 weights and move it to the selected
# device in one chained expression.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
).to(device)

model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [8]:
# Take the question of the first training example as a running example.
question = train[0]["question"]
question

'What did the study reveal about the role of external calcium concentration in the action potential and contraction recovery time of crayfish muscle fibers?'

In [9]:
# The answer for the first training example is read from its "statement" field.
answer = train[0]["statement"]
answer

'The study investigated how changes in external calcium concentration affect the action potential and contraction recovery time in crayfish muscle fibers, revealing that calcium entry through TTS membranes is crucial for excitation-contraction coupling.'

In [10]:
# Candidate options: the example's distractors followed by the answer.
# A new list is built instead of calling .append() on train[0]["distractors"]:
# appending mutates the list object handed back by the dataset, so re-running this
# cell could add the answer a second time if the dataset returns the same list.
options = [*train[0]["distractors"], answer]
options

['The study found that increasing external calcium concentration had no effect on the action potential or contraction recovery time in crayfish muscle fibers.',
 'The research concluded that magnesium ions play a more significant role than calcium in the action potential and contraction recovery of crayfish muscle fibers.',
 'The investigation revealed that external calcium concentration only affects the resting potential of crayfish muscle fibers, not the action potential or contraction recovery.',
 'The study investigated how changes in external calcium concentration affect the action potential and contraction recovery time in crayfish muscle fibers, revealing that calcium entry through TTS membranes is crucial for excitation-contraction coupling.']

In [11]:
# Query the store for the 5 best matches to the example question.
# The result is bound to its own name: assigning it to `augmented_data` would replace
# the corpus dataset with the query result, so re-running the cell that builds the
# querier would then pass the wrong object to build_querier.
retrieved = querier(question, 5)
retrieved

([171539365888.0,
  224323747840.0,
  226616344576.0,
  232986050560.0,
  234069245952.0],
 ["3,3'-Diiodothyronine production, a major pathway of peripheral iodothyronine metabolism in man. 3,3'-Diiodothyronine (3,3'-T(2)) has been detected in human serum and in thyroglobulin. However, no quantitative assessment of its clearance rate (CR), production rate (PR), or of the importance of extrathyroidal sources of 3,3'-T(2) relative to direct thyroidal secretion is yet available. This study examines these parameters in seven euthyroid subjects, and in eight athyreotic subjects (H) eumetabolic due to thyroxine therapy (HT(4)) (n = 5) or triiodothyronine replacement (HT(3)) (n = 3). A highly specific radioimmunoassay for the measurement of 3,3'-T(2) in whole serum was developed. Serum 3,3'-T(2) concentrations were (mean +/- SD) 6.0+/-1.0 ng/100 ml in 13 normal subjects, 9.0+/-4.6 ng/100 ml in 25 hyperthyroid patients, and 2.7+/-1.1 ng/100 ml in 17 hypothyroid patients. The values in each of 

In [12]:
# NOTE(review): `tokenizer` is the object returned by load_query_tokenizer (its repr
# shows a BertTokenizerFast), while `model` is a Llama checkpoint — confirm that
# build_forwarder expects this tokenizer and not the language model's own.
forward = build_forwarder(
    model,
    tokenizer,
    querier,
    k_augmentations=5,
    prompt_builder=prompt,
    device=device,
)


def forward_and_get_last_logit(question, options):
    """Run the forwarder on one question and return a single slice of its logits.

    Args:
        question: The question text passed to ``forward``.
        options: The candidate answers passed to ``forward`` as ``options``.

    Returns:
        ``forward(...).logits[0][-1]``: the last entry of the first item of the
        logits (presumably the next-token logits for the single prompt — confirm
        the logits layout against build_forwarder).
    """
    return forward(question, options=options).logits[0][-1]


result = forward_and_get_last_logit(question, options=options)

Prompt: You are an expert in multiple-choice questions. Your task is to select the best answer from the given options based on the provided context. 

Context:
3,3'-Diiodothyronine production, a major pathway of peripheral iodothyronine metabolism in man. 3,3'-Diiodothyronine (3,3'-T(2)) has been detected in human serum and in thyroglobulin. However, no quantitative assessment of its clearance rate (CR), production rate (PR), or of the importance of extrathyroidal sources of 3,3'-T(2) relative to direct thyroidal secretion is yet available. This study examines these parameters in seven euthyroid subjects, and in eight athyreotic subjects (H) eumetabolic due to thyroxine therapy (HT(4)) (n = 5) or triiodothyronine replacement (HT(3)) (n = 3). A highly specific radioimmunoassay for the measurement of 3,3'-T(2) in whole serum was developed. Serum 3,3'-T(2) concentrations were (mean +/- SD) 6.0+/-1.0 ng/100 ml in 13 normal subjects, 9.0+/-4.6 ng/100 ml in 25 hyperthyroid patients, and 2.7+

In [13]:
# Inspect the logits tensor returned for the running example.
result

tensor([13.4062, 10.6328, 12.8828,  ..., -0.0641, -0.0647, -0.0652],
       device='cuda:0', dtype=torch.float16)

In [14]:
# Answer labels handed to the picker. Every label carries the same leading space so
# that all four are encoded the same way; a bare "D" would be looked up differently
# from " A", " B" and " C".
possible_answers = [" A", " B", " C", " D"]

In [15]:
# Build the picker over the answer labels, then apply it to the example's logits.
# NOTE(review): the tokenizer passed here is the query tokenizer, not one loaded from
# MODEL_NAME — confirm that build_from_logits should map labels with it.
picker = build_from_logits(tokenizer, options=possible_answers)
selected_option = picker(result)

In [16]:
# Show what the picker returned for the running example
# (presumably an index into possible_answers — confirm against build_from_logits).
selected_option

0

# Let's evaluate the model 🔥

In [17]:
# Score the evaluation set using the forward function and the picker, then report
# the result with two decimals.
accuracy = evaluate(
    eval_dataset=evaluateWithAnswers,
    picker_fn=picker,
    forward_fn=forward_and_get_last_logit,
)
print(f"Accuracy: {accuracy:.2f}")

Prompt: You are an expert in multiple-choice questions. Your task is to select the best answer from the given options based on the provided context. 

Context:
3,3'-Diiodothyronine production, a major pathway of peripheral iodothyronine metabolism in man. 3,3'-Diiodothyronine (3,3'-T(2)) has been detected in human serum and in thyroglobulin. However, no quantitative assessment of its clearance rate (CR), production rate (PR), or of the importance of extrathyroidal sources of 3,3'-T(2) relative to direct thyroidal secretion is yet available. This study examines these parameters in seven euthyroid subjects, and in eight athyreotic subjects (H) eumetabolic due to thyroxine therapy (HT(4)) (n = 5) or triiodothyronine replacement (HT(3)) (n = 3). A highly specific radioimmunoassay for the measurement of 3,3'-T(2) in whole serum was developed. Serum 3,3'-T(2) concentrations were (mean +/- SD) 6.0+/-1.0 ng/100 ml in 13 normal subjects, 9.0+/-4.6 ng/100 ml in 25 hyperthyroid patients, and 2.7+

KeyboardInterrupt: 