In [1]:
import torch
from pathlib import Path

from transformers import AutoModelForSequenceClassification

from data.q_and_a.train_and_eval import TrainAndEval
from data.q_and_a.eval_with_answers import EvalWithAnswers

from models_.building.llama_tokenizer import  load_tokenizer

from data.pubmed.from_json import FromJsonDataset
from data.pubmed.contents import ContentsDataset

from storage.faiss_ import FaissStorage

from rag.tokenization.llama import build_tokenizer_function
from rag.quering import build_querier
import os
from q_and_a.forward import build_enhanced_forwarder
from q_and_a.prompts import prompt
from q_and_a.picking.from_logits import build_from_logits
from q_and_a.eval import evaluate
from q_and_a.forward import build_forwarder

In [2]:
train = TrainAndEval("../../data/pubmed_QA_train.json")
evaluationData = TrainAndEval("../../data/pubmed_QA_eval.json")
evaluateWithAnswers = EvalWithAnswers(evaluationData)

augmented_data = FromJsonDataset(json_file="../../data/pubmed_500K.json")
augmented_data = ContentsDataset(augmented_data)

In [3]:
storage = FaissStorage(
    dimension=800,
)

storage.load("../../outputs/store/pubmed_500K.index")

tokenizer = load_tokenizer()
tokenizer_fn = build_tokenizer_function(tokenizer)

querier = build_querier(storage, augmented_data, tokenizer_fn)
storage = FaissStorage(
    dimension=800,
)

In [12]:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint_path = Path("/home/ubuntu/pytorch_training/10_rag/notebooks/train/checkpoints/checkpoint-500")

In [13]:
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_path,
    load_in_8bit=True,
    pad_token_id=tokenizer.pad_token_id,
    local_files_only=True,
    num_labels=4,
)

model

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-3.2-1B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


LlamaForSequenceClassification(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048, padding_idx=128001)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): lora.Linear8bitLt(
            (base_layer): Linear8bitLt(in_features=2048, out_features=2048, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.1, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2048, out_features=8, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=8, out_features=2048, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): Linear8bitLt(in_features=2048, out_features=512, bias=False)
          (v_proj): lora.Linear8bitLt(
            (b

In [14]:
model.eval()

LlamaForSequenceClassification(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048, padding_idx=128001)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): lora.Linear8bitLt(
            (base_layer): Linear8bitLt(in_features=2048, out_features=2048, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.1, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2048, out_features=8, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=8, out_features=2048, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): Linear8bitLt(in_features=2048, out_features=512, bias=False)
          (v_proj): lora.Linear8bitLt(
            (b

In [18]:
forward = build_forwarder(
    model,
    tokenizer,
    querier,
    k_augmentations=1,
    prompt_builder=prompt,
    device=device,
)

forward_and_get_arg_max = lambda question, options: forward(
    question,
    options=options,
)

In [19]:
def pick_from_classifier(out):
    return torch.argmax(out.logits[0])

In [20]:
accuracy = evaluate(
    forward_fn=forward_and_get_arg_max,
    picker_fn=pick_from_classifier,
    eval_dataset=evaluateWithAnswers,
)

print(f"Accuracy: {accuracy:.2f}")

Right answer: 2, picked: 0
Accuracy at 0: 0.00
Right answer: 3, picked: 0
Right answer: 1, picked: 1
Right answer: 0, picked: 1
Right answer: 1, picked: 1
Right answer: 0, picked: 1
Right answer: 3, picked: 0
Right answer: 2, picked: 1
Right answer: 2, picked: 1
Right answer: 1, picked: 0
Right answer: 0, picked: 0
Accuracy at 100: 0.21
Right answer: 0, picked: 1
Right answer: 3, picked: 0
Right answer: 1, picked: 0
Right answer: 1, picked: 0
Right answer: 2, picked: 0
Right answer: 2, picked: 1
Right answer: 3, picked: 0
Right answer: 3, picked: 0
Right answer: 2, picked: 0
Right answer: 2, picked: 0
Accuracy at 200: 0.19
Right answer: 0, picked: 1
Right answer: 0, picked: 1
Right answer: 2, picked: 0
Right answer: 2, picked: 0
Right answer: 2, picked: 1
Right answer: 0, picked: 0
Right answer: 3, picked: 0
Right answer: 3, picked: 1
Right answer: 0, picked: 0
Right answer: 3, picked: 0
Accuracy at 300: 0.18
Right answer: 2, picked: 1
Right answer: 1, picked: 1
Right answer: 3, picked

KeyboardInterrupt: 