In [1]:
import torch

from transformers import AutoModelForCausalLM

from data.q_and_a.test_questions import TestQuestions
from data.q_and_a.eval_with_answers import EvalWithAnswers

from models_.building.llama_tokenizer import  load_tokenizer

from data.pubmed.from_json import FromJsonDataset
from data.pubmed.contents import ContentsDataset

from storage.faiss_ import FaissStorage

from rag.tokenization.llama import build_tokenizer_function
from rag.quering import build_querier

from q_and_a.forward import build_enhanced_forwarder
from q_and_a.prompts import prompt
from q_and_a.picking.from_logits import build_from_logits
from q_and_a.predict import predict

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

# Loading data: augmentation and question and answer

In [3]:
test_data = TestQuestions("../../data/pubmed_QA_test_questions.json")

augmented_data = FromJsonDataset(json_file="../../data/pubmed_500K.json")
augmented_data = ContentsDataset(augmented_data)

In [4]:
print(test_data[0])

{'id': 26, 'question': 'What were the findings regarding alpha-1-antitrypsin deficiency of genotype PiZ in an autopsy series of 238 individuals?', 'option': ['Alpha-1-antitrypsin deficiency of genotype PiZ was found in 30 cases, with no significant association with pulmonary emphysema.', 'In an autopsy series of 238 individuals, alpha-1-antitrypsin deficiency of genotype PiZ was identified in 15 cases, with a higher prevalence of pulmonary emphysema among heterozygous individuals.', 'In the autopsy series, alpha-1-antitrypsin deficiency of genotype PiZ was identified in 5 cases, all of whom were homozygous.', 'The study found that alpha-1-antitrypsin deficiency of genotype PiZ was present in 20 cases, with a higher prevalence of liver disease among heterozygous individuals.']}


# Building the RAG system

In [5]:
storage = FaissStorage(
    dimension=800,
)

storage.load("../../outputs/store/pubmed_500K.index")

tokenizer = load_tokenizer()
tokenizer_fn = build_tokenizer_function(tokenizer)

querier = build_querier(storage, augmented_data, tokenizer_fn)

In [6]:
tokenizer

PreTrainedTokenizerFast(name_or_path='meta-llama/Llama-3.2-1B', vocab_size=128000, model_max_length=131072, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|begin_of_text|>', 'eos_token': '<|end_of_text|>', 'pad_token': '<|end_of_text|>'}, clean_up_tokenization_spaces=True, added_tokens_decoder={
	128000: AddedToken("<|begin_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128001: AddedToken("<|end_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128002: AddedToken("<|reserved_special_token_0|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128003: AddedToken("<|reserved_special_token_1|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128004: AddedToken("<|finetune_right_pad_id|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128005: AddedToken("<|re

# Building question and answer system

In [7]:
MODEL_NAME = "meta-llama/Llama-3.2-1B"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
model = model.to(device)

model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [8]:
from typing import Any

import torch

from typing import Callable, List, Tuple
import torch
import torch.nn.functional as F
ForwardType = Callable[[str, List[str]], Any]

def enhanced_forward(
        llm,
        tokenizer,
        augmenter: Callable[[str, int], Tuple[List[int], List[str]]],
        k_augmentations: int,
        prompt_builder: Callable[[str, List[str], List[str]], str],
        question: str,
        options: List[str],
        device: str,
        num_iterations: int = 3,
):
    """
    Performs multiple forward passes, appending the most probable token each time,
    and returns average probabilities across all tokens.

    Args:
        llm: The language model to use for generating the response.
        tokenizer: The tokenizer associated with the language model.
        augmenter (Callable): A function that takes a query string and returns the first k_augmentations in a tuple of
            distances and items.
        k_augmentations (int): The number of augmentations to generate.
        prompt_builder (Callable): A function that builds the prompt for the language model.
            It takes the augmented information, question and options, and returns a formatted string.
        question (str): The question to ask.
        options (List[str]): The list of options to choose from.
        device (str): The device to run the model on ('cpu' or 'cuda').
        num_iterations (int): Number of forward passes to perform (default: 3).

    Returns:
        Tuple containing:
            - List of generated tokens
            - Average probabilities tensor across all tokens
    """
    model.eval()
    # Get augmented items
    _, items = augmenter(question, k_augmentations)

    # Generate the initial prompt
    prompt = prompt_builder(question, options, items)
    input_ids = tokenizer.encode(prompt, return_tensors="pt", truncation=True).to(device)

    generated_token_ids =  input_ids.clone()
    probss = []
    generated_tokens = []

    with torch.no_grad():
        for _ in range(num_iterations):
            # Prepare model inputs
            # For most autoregressive models, only input_ids are strictly necessary for the forward pass
            # if attention_mask is not explicitly handled or modified in the loop.
            # However, it's good practice to pass it if the model uses it.
            # As we append tokens, the attention_mask also needs to be extended.
            #attention_mask = torch.ones_like(generated_token_ids).to(device)

            # Get model outputs
            outputs = model(generated_token_ids,)

            # Get the logits for the last token in the sequence
            # outputs.logits is typically of shape (batch_size, sequence_length, vocab_size)
            next_token_logits = outputs.logits[:, -1, :] # Get logits for the very last token

            # Apply softmax to get probabilities (optional, as argmax works on logits directly)
            probs = torch.softmax(next_token_logits, dim=-1)
            probss.append(probs)
            # Get the predicted token ID (greedy decoding)
            next_token_id = torch.argmax(probs, dim=-1).unsqueeze(-1)

            # Append the predicted token ID to the generated sequence
            generated_token_ids = torch.cat([generated_token_ids, next_token_id], dim=1)

            # Check if the generated token is an end-of-sequence (EOS) token
            if tokenizer.eos_token_id is not None and next_token_id.item() == tokenizer.eos_token_id:
                break

            # Append the generated token ID to the list
            generated_tokens.append(next_token_id.item())

    # Concatenate all probabilities
    probs = torch.cat(probss, dim=0)
    # Average the probabilities across all iterations
    avg_probs = torch.mean(probs, dim=0)

    return generated_tokens, avg_probs

def build_enhanced_forwarder(
    llm,
    tokenizer,
    augmenter: Callable[[str, int], Tuple[List[int], List[str]]],
    k_augmentations: int,
    prompt_builder: Callable[[str, List[str], List[str]], str],
    num_iterations: int,
    device: str,
) -> Callable[[str, List[str]], Tuple[List[int], torch.Tensor]]:
    """
    Builds an enhanced forward function that can be used to generate responses from the language model.

    Returns:
        Callable: A function that takes a question and a list of options and returns the generated response.
    """
    def forward_fn(question: str, options: List[str]) -> Tuple[List[int], torch.Tensor]:
        return enhanced_forward(
            llm=llm,
            tokenizer=tokenizer,
            augmenter=augmenter,
            k_augmentations=k_augmentations,
            prompt_builder=prompt_builder,
            question=question,
            options=options,
            num_iterations=num_iterations,
            device=device,
        )

    return forward_fn

In [9]:
def prompt(
    question: str,
    options: List[str],
    augmented_items: List[str],
) -> str:
    context = "\n".join(augmented_items)

    options_str = "\n".join(
        [f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]
    )

    """
    Generates a prompt for the language model based on the question, options, and augmented items.

    Args:
        question (str): The question to ask.
        options (List[str]): The list of options to choose from.
        augmented_items (List[str]): The augmented items to include in the prompt.

    Returns:
        str: The formatted prompt string.
    """
    prompt = f"""You are an expert in multiple-choice questions. Your task is to select the best answer from the given options based on the provided context.
Context :  {context}

Question: {question}

Options:
{options_str}

Between A, B, C and D the best option is"""
    #print(prompt)
    return prompt

In [10]:
forward = build_enhanced_forwarder(
    model,
    tokenizer,
    querier,
    k_augmentations=1,
    prompt_builder=prompt,
    num_iterations=10,
    device=device,
)

def forward_and_get_last_logit(
    question,
    options,
):
    tokens, logits =  forward(
        question,
        options=options,
    )

    return logits

In [11]:
possible_answers = [" A", " B", " C", " D"]

In [12]:
picker = build_from_logits(
    tokenizer,
    options=possible_answers,
)

# Lets evaluate the model 🔥

In [13]:
responses = predict(
    forward_fn=forward_and_get_last_logit,
    picker_fn=picker,
    eval_dataset=test_data,
)

responses

Processed 10.0%
Processed 5.0%
Processed 3.3333333333333335%
Processed 2.5%
Processed 2.0%
Processed 1.6666666666666667%
Processed 1.4285714285714286%
Processed 1.25%
Processed 1.1111111111111112%


[(0, 3),
 (1, 0),
 (2, 2),
 (3, 3),
 (4, 3),
 (5, 3),
 (6, 2),
 (7, 2),
 (8, 3),
 (9, 0),
 (10, 3),
 (11, 2),
 (12, 2),
 (13, 3),
 (14, 2),
 (15, 3),
 (16, 3),
 (17, 3),
 (18, 3),
 (19, 2),
 (20, 1),
 (21, 3),
 (22, 2),
 (23, 3),
 (24, 3),
 (25, 3),
 (26, 0),
 (27, 3),
 (28, 2),
 (29, 3),
 (30, 3),
 (31, 2),
 (32, 3),
 (33, 2),
 (34, 2),
 (35, 0),
 (36, 3),
 (37, 0),
 (38, 2),
 (39, 3),
 (40, 3),
 (41, 0),
 (42, 0),
 (43, 3),
 (44, 0),
 (45, 0),
 (46, 2),
 (47, 3),
 (48, 2),
 (49, 3),
 (50, 3),
 (51, 3),
 (52, 3),
 (53, 0),
 (54, 3),
 (55, 2),
 (56, 3),
 (57, 2),
 (58, 0),
 (59, 2),
 (60, 0),
 (61, 0),
 (62, 3),
 (63, 0),
 (64, 0),
 (65, 2),
 (66, 2),
 (67, 0),
 (68, 0),
 (69, 0),
 (70, 0),
 (71, 3),
 (72, 2),
 (73, 2),
 (74, 3),
 (75, 0),
 (76, 3),
 (77, 0),
 (78, 3),
 (79, 2),
 (80, 3),
 (81, 3),
 (82, 3),
 (83, 3),
 (84, 2),
 (85, 3),
 (86, 3),
 (87, 0),
 (88, 3),
 (89, 0),
 (90, 2),
 (91, 3),
 (92, 3),
 (93, 3),
 (94, 3),
 (95, 2),
 (96, 3),
 (97, 0),
 (98, 3),
 (99, 0),
 (100, 0),

In [17]:
import pandas as pd

In [23]:
responses_with_ids = []

for i in range(len(responses)):
    responses_with_ids.append((test_data[i]["id"], responses[i][1]))

In [24]:
responses_with_ids

[(26, 3),
 (29, 0),
 (37, 2),
 (70, 3),
 (109, 3),
 (182, 3),
 (234, 2),
 (274, 2),
 (276, 3),
 (320, 0),
 (417, 3),
 (481, 2),
 (505, 2),
 (507, 3),
 (530, 2),
 (565, 3),
 (570, 3),
 (581, 3),
 (596, 3),
 (610, 2),
 (641, 1),
 (651, 3),
 (691, 2),
 (721, 3),
 (733, 3),
 (847, 3),
 (888, 0),
 (909, 3),
 (914, 2),
 (926, 3),
 (938, 3),
 (1007, 2),
 (1013, 3),
 (1014, 2),
 (1041, 2),
 (1050, 0),
 (1076, 3),
 (1077, 0),
 (1109, 2),
 (1126, 3),
 (1142, 3),
 (1158, 0),
 (1160, 0),
 (1248, 3),
 (1332, 0),
 (1333, 0),
 (1416, 2),
 (1483, 3),
 (1491, 2),
 (1588, 3),
 (1613, 3),
 (1668, 3),
 (1687, 3),
 (1713, 0),
 (1752, 3),
 (1769, 2),
 (1784, 3),
 (1811, 2),
 (1856, 0),
 (1887, 2),
 (2050, 0),
 (2094, 0),
 (2101, 3),
 (2118, 0),
 (2129, 0),
 (2185, 2),
 (2199, 2),
 (2232, 0),
 (2239, 0),
 (2243, 0),
 (2261, 0),
 (2267, 3),
 (2319, 2),
 (2320, 2),
 (2321, 3),
 (2364, 0),
 (2371, 3),
 (2443, 0),
 (2453, 3),
 (2454, 2),
 (2472, 3),
 (2497, 3),
 (2520, 3),
 (2586, 3),
 (2620, 2),
 (2668, 3),
 (2

In [25]:
dataset = pd.DataFrame(responses_with_ids, columns=["ID", "answer"])

In [26]:
dataset.to_csv("predictions.csv", index=False)

{'id': 26216,
 'question': 'What were the findings of the study on eighteen patients with congenital nystagmus regarding the types of waveforms and the functionality of the smooth pursuit and saccadic systems?',
 'option': ['The study found that all patients exhibited only jerk waveforms and that the smooth pursuit system was non-functional.',
  'A study of eighteen patients with congenital nystagmus revealed complex waveforms beyond jerk and pendular types, indicating that while the smooth pursuit system functions, the fast component of jerk nystagmus acts as a corrective saccadic movement, and patients can perform voluntary saccades normally.',
  'Patients with congenital nystagmus were unable to perform any voluntary saccades, regardless of the waveform type.',
  'The research concluded that the fast component of jerk nystagmus is unrelated to saccadic movements and that patients showed no complex waveforms.']}