## Import libraries

In [None]:
import transformers
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import load_dataset
import pandas as pd
import numpy as np
import json
from huggingface_hub import login
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm 

# Log in using your token
login("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")

models2name = {
               'llama1b': "meta-llama/Llama-3.2-1B-Instruct",
               'llama8b': "meta-llama/Llama-3.1-8B-Instruct", 
               'llama70b': "meta-llama/Llama-3.1-70B-Instruct",
               'romanAI': 'giulioderasmo/RomanAI'
            } 


## load test data

In [4]:
test_ds = pd.read_csv('./test_ds.csv')
test_ds.head()

Unnamed: 0,clean_text,label,#tokens
0,Art. 52 È accertato nella somma di lire 4.695....,normattiva_dump,360
1,Art. 21 ((PROVVEDIMENTO ABROGATO DAL D.LGS. 2 ...,normattiva_dump,35
2,Regione: Calabria Circoscrizione: Calabria |==...,normattiva_dump,335
3,SENATO DELLA REPUBBLICA Legislatura 18 Resocon...,sommcomm,6925
4,Art. 62 Livelli essenziali delle prestazioni 1...,normattiva_dump,254


## load model and tokenizer

In [10]:
## cache dir cosi teniamo tutto li!
READER_MODEL_NAME = models2name['llama8b']

load_in_4bit = None
load_in_8bit = True
if load_in_4bit:
    #  4-bit quantization
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
elif load_in_8bit:
    #  8-bit quantization
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
    )
else: 
    print('No quantization config provided, loading model in full precision')
    quantization_config = None

model = AutoModelForCausalLM.from_pretrained(READER_MODEL_NAME, 
                                            quantization_config=quantization_config,
                                            cache_dir='/home/fselab3/Documents/giuder/cache_dir',
                                            device_map="auto")

model.config.use_cache = False
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(READER_MODEL_NAME, 
                                          padding_side = 'right',)

Loading checkpoint shards: 100%|██████████| 4/4 [00:09<00:00,  2.46s/it]


In [13]:
tokenizer.pad_token_id

## perplexity func

In [20]:
def calculate_perplexity(model, tokenizer, test_dataset, prompt_length=20, max_seq_length=2048):
    model.eval()

    nlls = []
    # Process each sample individually instead of concatenating
    for i, sample in enumerate(tqdm(test_dataset["clean_text"])):
        # Tokenize sample and truncate to max_seq_length
        encoding = tokenizer(sample, return_tensors="pt", truncation=True, max_length=max_seq_length)
        input_ids = encoding.input_ids.to(model.device)

        # Skip samples that are too short
        if input_ids.size(1) <= prompt_length + 1:  # Need at least one token to predict
            continue

        # Split into prompt and target
        prompt_ids = input_ids[:, :prompt_length]
        target_ids = input_ids.clone()
        # Set tokens in the prompt section to -100 (ignore in loss calculation)
        target_ids[:, :prompt_length] = -100

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            neg_log_likelihood = outputs.loss

        # Accumulate loss * number of valid tokens
        nlls.append(neg_log_likelihood)

    ppl = torch.exp(torch.stack(nlls).mean())

    return ppl


## try out

In [None]:
PPL = calculate_perplexity(model, tokenizer, test_ds.head(1000), prompt_length=20, max_seq_length=1024)

 43%|████▎     | 427/1000 [01:51<02:28,  3.86it/s]

In [None]:
print(f'PPL for the {READER_MODEL_NAME} is {PPL:.3f}')

PPL for the meta-llama/Llama-3.1-8B-Instruct is 4.141
