In [1]:
import transformers
import torch
from pathlib import Path
import srsly
import random
from tqdm.auto import tqdm

# Raise the transformers logging threshold to ERROR so that library
# info/warning messages do not clutter the cell outputs below.
transformers.utils.logging.set_verbosity_error()

In [3]:
# Materialise every JSONL record up front so the dataset can be iterated
# (and its length shown by the progress bar) later on.
inputs = [record for record in srsly.read_jsonl("testing-data/cri.jsonl")]

In [2]:
# Single source of truth for the checkpoint shared by the model and tokenizer.
MODEL_NAME = 'meta-llama/Meta-Llama-3-8B-Instruct'

# Load the weights, move them to the GPU and switch to evaluation mode.
# `.eval()` returns the module itself, so it can be chained onto the load.
model = (
    transformers.AutoModelForCausalLM
    .from_pretrained(MODEL_NAME)
    .to('cuda')
    .eval()
)
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME, model_max_length=1024)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [36]:
def tokenize(chunk, question, answer):
    """Build model inputs and loss targets for scoring `answer` against a prompt.

    The prompt (system instructions + excerpt, the question, and the opening
    of the assistant turn) and the answer are tokenized separately so the
    boundary between them is known exactly. The targets are a copy of the
    inputs in which every prompt position is set to -100, so that only the
    answer tokens contribute to the loss.

    Args:
        chunk: Excerpt text placed in the system turn.
        question: Question text placed in the user turn.
        answer: Candidate answer whose tokens are to be scored.

    Returns:
        Tuple ``(input_ids, target_ids)`` of tensors on ``'cuda'``; both are
        the prompt ids followed by the answer ids, concatenated along
        dimension 1. ``target_ids`` holds -100 at every prompt position.
    """
    # NOTE(review): this layout differs from the published Llama 3 chat
    # template (which places "\n\n" after each <|end_header_id|> and no
    # newline between turns) -- confirm whether that is intentional.
    prompt = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>"
        "\nYou are an excellent student who has just read the following excerpt."
        " The teacher will ask you a question. You will answer accurately."
        f"\n\nExcerpt:\n\n{chunk}<|eot_id|>"
        "\n<|start_header_id|>user<|end_header_id|>"
        f"{question}<|eot_id|>"
        "\n<|start_header_id|>assistant<|end_header_id|>"
    )

    # Inputs
    # Construct inputs separately, so we know where the prompt ends.
    # The prompt text already begins with <|begin_of_text|>, so the tokenizer
    # must not add its own special tokens: otherwise the prompt would start
    # with a duplicated BOS and the answer would get a BOS of its own that is
    # then scored as part of the answer.
    # NOTE(review): no truncation is requested, so long excerpts are passed
    # through at full length -- confirm this is desired.
    prompt_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids
    answer_ids = tokenizer(answer, return_tensors="pt", add_special_tokens=False).input_ids
    # Keep an integer dtype even if the answer produced no tokens, so the
    # concatenation below cannot be promoted to a non-integer dtype.
    answer_ids = answer_ids.long()
    input_ids = torch.cat((prompt_ids, answer_ids), 1).to('cuda')

    # Targets
    target_ids = input_ids.clone()
    prompt_length = prompt_ids.shape[1]
    # Setting targets to -100 will ignore them when calculating loss
    # Do this for all tokens before the answer.
    target_ids[:, :prompt_length] = -100

    return input_ids, target_ids

In [None]:
@torch.no_grad()
def get_loss(input_ids, target_ids):
    """Run one forward pass of `model` and return its loss for `target_ids`.

    Gradient tracking is disabled for the whole call. The loss comes from the
    model's CrossEntropyLoss, which averages over the valid (non-ignored)
    labels only.

    Args:
        input_ids: Token ids fed to the model.
        target_ids: Labels passed to the model alongside `input_ids`.

    Returns:
        The `loss` attribute of the model output.
    """
    return model(input_ids, labels=target_ids).loss

# Score every record: tokenize its excerpt/question/answer, compute the loss
# and store it on the record under "Loss". The records are updated in place,
# so `outputs` ends up holding the same dict objects as `inputs`.
outputs = []
for cri_input in tqdm(inputs):
    input_ids, target_ids = tokenize(
        chunk=cri_input["Chunk"],
        question=cri_input["Question"],
        answer=cri_input["answer"],
    )
    loss = get_loss(input_ids, target_ids)
    cri_input["Loss"] = loss.item()
    outputs.append(cri_input)

  0%|          | 0/2540 [00:00<?, ?it/s]