In [2]:
import torch

from transformers import AutoModelForCausalLM

from data.q_and_a.train_and_eval import TrainAndEval

from models_.building.pubmed_tokenizer import  load_query_tokenizer as load_tokenizer

from data.pubmed.from_json import FromJsonDataset
from data.pubmed.contents import ContentsDataset

from storage.faiss_ import FaissStorage

from rag.tokenization.llama import build_tokenizer_function
from rag.quering import build_querier

from q_and_a.forward import forward
from q_and_a.prompts import prompt

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

# Loading data: augmentation and question and answer

In [4]:
train = TrainAndEval("../../data/pubmed_QA_train.json")

augmented_data = FromJsonDataset(json_file="../../data/pubmed_500K.json")
augmented_data = ContentsDataset(augmented_data)

# Building the RAG system

In [5]:
storage = FaissStorage(
    dimension=800,
)

storage.load("../../outputs/store/pubmed_500K.index")

tokenizer = load_tokenizer()
tokenizer_fn = build_tokenizer_function(tokenizer)

querier = build_querier(storage, augmented_data, tokenizer_fn)

In [6]:
tokenizer

BertTokenizerFast(name_or_path='ncbi/MedCPT-Query-Encoder', vocab_size=30522, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	4: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

# Building question and answer system

In [7]:
MODEL_NAME = "meta-llama/Llama-3.2-1B"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
model = model.to(device)

model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [8]:
question = train[0]["question"]
question

'What did the study reveal about the role of external calcium concentration in the action potential and contraction recovery time of crayfish muscle fibers?'

In [9]:
answer = train[0]["statement"]
answer

'The study investigated how changes in external calcium concentration affect the action potential and contraction recovery time in crayfish muscle fibers, revealing that calcium entry through TTS membranes is crucial for excitation-contraction coupling.'

In [10]:
options = train[0]["distractors"]
# append the answer to the options
options.append(answer)
options

['The study found that increasing external calcium concentration had no effect on the action potential or contraction recovery time in crayfish muscle fibers.',
 'The research concluded that magnesium ions play a more significant role than calcium in the action potential and contraction recovery of crayfish muscle fibers.',
 'The investigation revealed that external calcium concentration only affects the resting potential of crayfish muscle fibers, not the action potential or contraction recovery.',
 'The study investigated how changes in external calcium concentration affect the action potential and contraction recovery time in crayfish muscle fibers, revealing that calcium entry through TTS membranes is crucial for excitation-contraction coupling.']

In [11]:
augmented_data = querier(question, 5)
augmented_data

([171539365888.0,
  224323829760.0,
  226616164352.0,
  232985919488.0,
  234069327872.0],
 ["3,3'-Diiodothyronine production, a major pathway of peripheral iodothyronine metabolism in man. 3,3'-Diiodothyronine (3,3'-T(2)) has been detected in human serum and in thyroglobulin. However, no quantitative assessment of its clearance rate (CR), production rate (PR), or of the importance of extrathyroidal sources of 3,3'-T(2) relative to direct thyroidal secretion is yet available. This study examines these parameters in seven euthyroid subjects, and in eight athyreotic subjects (H) eumetabolic due to thyroxine therapy (HT(4)) (n = 5) or triiodothyronine replacement (HT(3)) (n = 3). A highly specific radioimmunoassay for the measurement of 3,3'-T(2) in whole serum was developed. Serum 3,3'-T(2) concentrations were (mean +/- SD) 6.0+/-1.0 ng/100 ml in 13 normal subjects, 9.0+/-4.6 ng/100 ml in 25 hyperthyroid patients, and 2.7+/-1.1 ng/100 ml in 17 hypothyroid patients. The values in each of 

In [12]:
result = forward(
    model,
    tokenizer,
    querier,
    k_augmentations=5,
    prompt_builder=prompt,
    question=question,
    options=options,
    device=device,
)

Augmented items: ["3,3'-Diiodothyronine production, a major pathway of peripheral iodothyronine metabolism in man. 3,3'-Diiodothyronine (3,3'-T(2)) has been detected in human serum and in thyroglobulin. However, no quantitative assessment of its clearance rate (CR), production rate (PR), or of the importance of extrathyroidal sources of 3,3'-T(2) relative to direct thyroidal secretion is yet available. This study examines these parameters in seven euthyroid subjects, and in eight athyreotic subjects (H) eumetabolic due to thyroxine therapy (HT(4)) (n = 5) or triiodothyronine replacement (HT(3)) (n = 3). A highly specific radioimmunoassay for the measurement of 3,3'-T(2) in whole serum was developed. Serum 3,3'-T(2) concentrations were (mean +/- SD) 6.0+/-1.0 ng/100 ml in 13 normal subjects, 9.0+/-4.6 ng/100 ml in 25 hyperthyroid patients, and 2.7+/-1.1 ng/100 ml in 17 hypothyroid patients. The values in each of the latter two groups were significantly different from normal. 3,3'-T(2) w

OutOfMemoryError: CUDA out of memory. Tried to allocate 18.00 MiB. GPU 0 has a total capacity of 5.65 GiB of which 14.00 MiB is free. Including non-PyTorch memory, this process has 5.61 GiB memory in use. Of the allocated memory 5.40 GiB is allocated by PyTorch, and 111.82 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [11]:
result

CausalLMOutputWithPast(loss=None, logits=tensor([[[ 5.0732,  5.0327,  7.6645,  ..., -4.8181, -4.8183, -4.8187],
         [12.0421, 13.8072, 10.4849,  ...,  1.7583,  1.7579,  1.7576],
         [12.6344, 15.7775, 12.3718,  ...,  1.2299,  1.2302,  1.2298],
         ...,
         [11.7180, 12.2432, 11.7466,  ...,  0.9857,  0.9847,  0.9842],
         [ 7.9333,  6.4972,  9.0693,  ..., -0.0789, -0.0798, -0.0802],
         [12.9428, 10.5113, 12.3660,  ...,  0.2332,  0.2325,  0.2319]]],
       grad_fn=<UnsafeViewBackward0>), past_key_values=<transformers.cache_utils.DynamicCache object at 0x7ff7283b3eb0>, hidden_states=None, attentions=None)