In [1]:
import torch

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

from data.q_and_a.train_and_eval import TrainAndEval
from data.q_and_a.eval_with_answers import EvalWithAnswers

from models_.building.llama_tokenizer import  load_tokenizer

from data.pubmed.from_json import FromJsonDataset
from data.pubmed.contents import ContentsDataset

from storage.faiss_ import FaissStorage

from rag.tokenization.llama import build_tokenizer_function
from rag.quering import build_querier

from q_and_a.forward import build_enhanced_forwarder
from q_and_a.prompts import prompt
from q_and_a.picking.from_logits import build_from_logits
from q_and_a.eval import evaluate

In [2]:
import sys
sys.path.append("/home/ubuntu/pytorch_training/10_rag/src")

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

# Loading data: augmentation and question and answer

In [4]:
train = TrainAndEval("../../data/pubmed_QA_train.json")
evaluationData = TrainAndEval("../../data/pubmed_QA_eval.json")
evaluateWithAnswers = EvalWithAnswers(evaluationData)

augmented_data = FromJsonDataset(json_file="../../data/pubmed_500K.json")
augmented_data = ContentsDataset(augmented_data)

In [5]:
print(evaluateWithAnswers[0])

{'id': 'pubmed23n0012_5208', 'excerpt': 'Temporal changes in medial basal hypothalamic LH-RH correlated with plasma LH during the rat estrous cycle and following electrochemical stimulation of the medial preoptic area in pentobarbital-treated proestrous rats. In the present studies we have simultaneously measured changes in medial basal hypothalamic (MBH) leutenizing hormone-releasing hormone (LH-RH) and in plasma LH by radioimmunoassay in female rats at various hours during the 4-day estrous cycle and under experimental conditions known to alter pituitary LH secretion. In groups of rats decapitated at 12.00 h and 15.00 h on estrus and diestrus, plasma LH remained at basal levels (5-8 ng/ml) and MBH-LH-RH concentrations showed average steady state concentrations of 2231 +/- 205 pg/mg. On the day of proestrus hourly measurements of MBH-LH-RH between 12.00 h and 21.00 h suggested rhythmic rises and falls in the decapeptide concomitant with rises and falls in plasma LH. In a second group 

# Building the RAG system

In [6]:
storage = FaissStorage(
    dimension=800,
)

storage.load("../../outputs/store/pubmed_500K.index")

tokenizer = load_tokenizer()
tokenizer_fn = build_tokenizer_function(tokenizer)

querier = build_querier(storage, augmented_data, tokenizer_fn)

# Building question and answer system

In [7]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel, PeftConfig

# Load the PEFT config to find the base model
peft_model_path = "../train/trainer-combined-logits-3000/trainer"
peft_config = PeftConfig.from_pretrained(peft_model_path)

# Load the base model (this must match the model used during fine-tuning)
base_model = AutoModelForCausalLM.from_pretrained(
    peft_config.base_model_name_or_path,
    #load_in_8bit=True,
)

# Load the PEFT model
model = PeftModel.from_pretrained(base_model, "../train/trainer-combined-logits-full-hd/trainer")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("../train/trainer-combined-logits-full-hd/tokenizer")

In [8]:
base_model.to(device)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): lora.Linear(
            (base_layer): Linear(in_features=2048, out_features=2048, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.1, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2048, out_features=8, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=8, out_features=2048, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): lora.Linear(
            (base_layer): Linear(in_features=2048, out_features=512, bias=False)
            (lora_dropout): ModuleDict(
              (default): D

In [9]:
question = train[0]["question"]
question

'What did the study reveal about the role of external calcium concentration in the action potential and contraction recovery time of crayfish muscle fibers?'

In [10]:
answer = train[0]["statement"]
answer

'The study investigated how changes in external calcium concentration affect the action potential and contraction recovery time in crayfish muscle fibers, revealing that calcium entry through TTS membranes is crucial for excitation-contraction coupling.'

In [11]:
options = train[0]["distractors"]
# append the answer to the options
options.append(answer)
options

['The study found that increasing external calcium concentration had no effect on the action potential or contraction recovery time in crayfish muscle fibers.',
 'The research concluded that magnesium ions play a more significant role than calcium in the action potential and contraction recovery of crayfish muscle fibers.',
 'The investigation revealed that external calcium concentration only affects the resting potential of crayfish muscle fibers, not the action potential or contraction recovery.',
 'The study investigated how changes in external calcium concentration affect the action potential and contraction recovery time in crayfish muscle fibers, revealing that calcium entry through TTS membranes is crucial for excitation-contraction coupling.']

In [12]:
augmented_data = querier(question, 5)
augmented_data

([171539398656.0,
  224323829760.0,
  226616180736.0,
  232985919488.0,
  234069327872.0],
 ["3,3'-Diiodothyronine production, a major pathway of peripheral iodothyronine metabolism in man. 3,3'-Diiodothyronine (3,3'-T(2)) has been detected in human serum and in thyroglobulin. However, no quantitative assessment of its clearance rate (CR), production rate (PR), or of the importance of extrathyroidal sources of 3,3'-T(2) relative to direct thyroidal secretion is yet available. This study examines these parameters in seven euthyroid subjects, and in eight athyreotic subjects (H) eumetabolic due to thyroxine therapy (HT(4)) (n = 5) or triiodothyronine replacement (HT(3)) (n = 3). A highly specific radioimmunoassay for the measurement of 3,3'-T(2) in whole serum was developed. Serum 3,3'-T(2) concentrations were (mean +/- SD) 6.0+/-1.0 ng/100 ml in 13 normal subjects, 9.0+/-4.6 ng/100 ml in 25 hyperthyroid patients, and 2.7+/-1.1 ng/100 ml in 17 hypothyroid patients. The values in each of 

In [13]:
from typing import Any

import torch

from typing import Callable, List, Tuple
import torch
import torch.nn.functional as F
ForwardType = Callable[[str, List[str]], Any]

def enhanced_forward(
        llm,
        tokenizer,
        augmenter: Callable[[str, int], Tuple[List[int], List[str]]],
        k_augmentations: int,
        prompt_builder: Callable[[str, List[str], List[str]], str],
        question: str,
        options: List[str],
        device: str,
        num_iterations: int = 1,
):
    """
    Performs multiple forward passes, appending the most probable token each time,
    and returns average probabilities across all tokens.

    Args:
        llm: The language model to use for generating the response.
        tokenizer: The tokenizer associated with the language model.
        augmenter (Callable): A function that takes a query string and returns the first k_augmentations in a tuple of
            distances and items.
        k_augmentations (int): The number of augmentations to generate.
        prompt_builder (Callable): A function that builds the prompt for the language model.
            It takes the augmented information, question and options, and returns a formatted string.
        question (str): The question to ask.
        options (List[str]): The list of options to choose from.
        device (str): The device to run the model on ('cpu' or 'cuda').
        num_iterations (int): Number of forward passes to perform (default: 3).

    Returns:
        Tuple containing:
            - List of generated tokens
            - Average probabilities tensor across all tokens
    """
    llm.eval()
    # Get augmented items
    _, items = augmenter(question, k_augmentations)

    # Generate the initial prompt
    prompt = prompt_builder(question, options, items)
    input_ids = tokenizer.encode(prompt, return_tensors="pt", truncation=True).to(device)

    generated_token_ids =  input_ids.clone()
    probss = []
    generated_tokens = []

    with torch.no_grad():
        for i in range(num_iterations):
            # Prepare model inputs
            # For most autoregressive models, only input_ids are strictly necessary for the forward pass
            # if attention_mask is not explicitly handled or modified in the loop.
            # However, it's good practice to pass it if the model uses it.
            # As we append tokens, the attention_mask also needs to be extended.
            #attention_mask = torch.ones_like(generated_token_ids).to(device)

            # Get model outputs
            with torch.no_grad():
                outputs = llm(generated_token_ids)

            # Get the logits for the last token in the sequence
            # outputs.logits is typically of shape (batch_size, sequence_length, vocab_size)
            next_token_logits = outputs.logits[:, -1, :] # Get logits for the very last token

            # Apply softmax to get probabilities (optional, as argmax works on logits directly)
            probs = torch.softmax(next_token_logits, dim=-1)
            probss.append(probs)
            # Get the predicted token ID (greedy decoding)
            next_token_id = torch.argmax(probs, dim=-1).unsqueeze(-1)

            # Append the predicted token ID to the generated sequence
            generated_token_ids = torch.cat([generated_token_ids, next_token_id], dim=1)

            # Check if the generated token is an end-of-sequence (EOS) token
            if tokenizer.eos_token_id is not None and next_token_id.item() == tokenizer.eos_token_id:
                break

            # Append the generated token ID to the list
            generated_tokens.append(next_token_id.item())

    # Concatenate all probabilities
    probs = torch.cat(probss, dim=0)
    # Average the probabilities across all iterations
    avg_probs = torch.mean(probs, dim=0)

    return generated_tokens, avg_probs

def build_enhanced_forwarder(
    llm,
    tokenizer,
    augmenter: Callable[[str, int], Tuple[List[int], List[str]]],
    k_augmentations: int,
    prompt_builder: Callable[[str, List[str], List[str]], str],
    num_iterations: int,
    device: str,
) -> Callable[[str, List[str]], Tuple[List[int], torch.Tensor]]:
    """
    Builds an enhanced forward function that can be used to generate responses from the language model.

    Returns:
        Callable: A function that takes a question and a list of options and returns the generated response.
    """
    def forward_fn(question: str, options: List[str]) -> Tuple[List[int], torch.Tensor]:
        return enhanced_forward(
            llm=llm,
            tokenizer=tokenizer,
            augmenter=augmenter,
            k_augmentations=k_augmentations,
            prompt_builder=prompt_builder,
            question=question,
            options=options,
            num_iterations=num_iterations,
            device=device,
        )

    return forward_fn

In [14]:
def prompt(
    question: str,
    options: List[str],
    augmented_items: List[str],
) -> str:
    context = "\n".join(augmented_items)

    options_str = "\n".join(
        [f"{chr(65 + i)}. {option}" for i, option in enumerate(options)]
    )

    """
    Generates a prompt for the language model based on the question, options, and augmented items.

    Args:
        question (str): The question to ask.
        options (List[str]): The list of options to choose from.
        augmented_items (List[str]): The augmented items to include in the prompt.

    Returns:
        str: The formatted prompt string.
    """
    prompt = f"""You are an expert in multiple-choice questions. Your task is to select the best answer from the given options based on the provided context.
Context: {context}

Question: {question}

Options:
{options_str}

Between A, B, C and D the best option is the letter"""
    return prompt

In [15]:
forward = build_enhanced_forwarder(
    base_model,
    tokenizer,
    querier,
    k_augmentations=10,
    prompt_builder=prompt,
    num_iterations=1,
    device=device,
)

def forward_and_get_last_logit(
    question,
    options,
):
    tokens, logits =  forward(
        question,
        options=options,
    )

    return logits

result = forward_and_get_last_logit(
    question,
    options=options,
)

In [16]:
result

tensor([9.8596e-09, 7.8229e-08, 9.1263e-10,  ..., 9.3180e-14, 9.3064e-14,
        9.3120e-14], device='cuda:0')

In [17]:
possible_answers = [" A", " B", " C", " D"]

In [18]:
picker = build_from_logits(
    tokenizer,
    options=possible_answers,
)

selected_option = picker(result)

Options: [362, 426, 356, 423]


In [19]:
selected_option

3

# Lets evaluate the model 🔥

In [20]:
accuracy = evaluate(
    forward_fn=forward_and_get_last_logit,
    picker_fn=picker,
    eval_dataset=evaluateWithAnswers,
)

print(f"Accuracy: {accuracy:.2f}")

Right answer: 0, picked: 0
Accuracy at 0: 1.00
Right answer: 3, picked: 3
Right answer: 1, picked: 1
Right answer: 1, picked: 1
Right answer: 0, picked: 0
Right answer: 3, picked: 3
Right answer: 0, picked: 0
Right answer: 1, picked: 1
Right answer: 1, picked: 1
Right answer: 0, picked: 0
Right answer: 0, picked: 0
Accuracy at 100: 0.94
Right answer: 0, picked: 0
Right answer: 3, picked: 3
Right answer: 3, picked: 3


KeyboardInterrupt: 

In [21]:
from q_and_a.predict import predict
from data.q_and_a.test_questions import TestQuestions

In [22]:
test_data = TestQuestions("../../data/pubmed_QA_test_questions.json")

In [23]:
responses = predict(
    forward_fn=forward_and_get_last_logit,
    picker_fn=picker,
    eval_dataset=test_data,
)

responses

Processed 0.1%
Processed 0.2%
Processed 0.3%
Processed 0.4%
Processed 0.5%
Processed 0.6%
Processed 0.7%
Processed 0.8%
Processed 0.9%


[(0, 1),
 (1, 3),
 (2, 2),
 (3, 1),
 (4, 2),
 (5, 1),
 (6, 2),
 (7, 2),
 (8, 1),
 (9, 0),
 (10, 3),
 (11, 3),
 (12, 2),
 (13, 3),
 (14, 3),
 (15, 1),
 (16, 3),
 (17, 2),
 (18, 0),
 (19, 2),
 (20, 1),
 (21, 3),
 (22, 2),
 (23, 1),
 (24, 1),
 (25, 1),
 (26, 3),
 (27, 0),
 (28, 2),
 (29, 2),
 (30, 0),
 (31, 3),
 (32, 3),
 (33, 2),
 (34, 2),
 (35, 0),
 (36, 1),
 (37, 0),
 (38, 2),
 (39, 3),
 (40, 0),
 (41, 0),
 (42, 0),
 (43, 1),
 (44, 0),
 (45, 0),
 (46, 2),
 (47, 1),
 (48, 0),
 (49, 1),
 (50, 3),
 (51, 3),
 (52, 0),
 (53, 1),
 (54, 3),
 (55, 2),
 (56, 3),
 (57, 3),
 (58, 1),
 (59, 2),
 (60, 0),
 (61, 0),
 (62, 3),
 (63, 1),
 (64, 2),
 (65, 2),
 (66, 1),
 (67, 0),
 (68, 0),
 (69, 0),
 (70, 1),
 (71, 2),
 (72, 3),
 (73, 2),
 (74, 3),
 (75, 3),
 (76, 0),
 (77, 1),
 (78, 2),
 (79, 2),
 (80, 3),
 (81, 0),
 (82, 0),
 (83, 3),
 (84, 1),
 (85, 3),
 (86, 0),
 (87, 0),
 (88, 1),
 (89, 0),
 (90, 2),
 (91, 1),
 (92, 0),
 (93, 3),
 (94, 0),
 (95, 2),
 (96, 3),
 (97, 0),
 (98, 3),
 (99, 0),
 (100, 0),

In [24]:
import pandas as pd
responses_with_ids = []

for i in range(len(responses)):
    responses_with_ids.append((test_data[i]["id"], responses[i][1]))

In [25]:
dataset = pd.DataFrame(responses_with_ids, columns=["ID", "answer"])
dataset.head()

Unnamed: 0,ID,answer
0,26,1
1,29,3
2,37,2
3,70,1
4,109,2


In [26]:
dataset.to_csv("predictions.csv", index=False)