# Computing the perplexities over the test set

In [1]:
import os
from helper import init_ipynb
envfound = init_ipynb()

DIR = os.environ["DIR_PATH"] if envfound else None
DEVICE = os.environ["DEVICE"] if envfound else None
API_KEY = os.environ["API_KEY"] if envfound else None
PLATFORM = os.environ["OS_TYPE"] if envfound else None

if(PLATFORM == "Darwin"):
    os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import LlamaForCausalLM, AutoTokenizer
import torch
from torch.nn import CrossEntropyLoss
from datasets import load_from_disk
from evaluate import load
from torch.nn.functional import one_hot
from typing import Dict
from tqdm.notebook import tqdm



# ppl = load("perplexity", module_type="metric")
testds = load_from_disk("docs/pmc_patiens_fil_test.hf")

MODEL_MAX_LENGTH = 1024

def compute_loss(model: LlamaForCausalLM,
                 input_texts: Dict[str, str],
                 tokenizer: AutoTokenizer) -> torch.Tensor:
    """
        Override of loss computation.

        args : 
            - model (AutoModelForCausalLM) :
            - inputs :

        returns :
            loss value, torch tensor with grad_fn
    """
    VOCAB_SIZE = len(tokenizer)
    inputs = tokenizer(input_texts, max_length=MODEL_MAX_LENGTH, truncation=True, padding="max_length", return_tensors="pt").to(DEVICE) 
    predictions = model(**inputs).logits ## model run and extract logits
    loss = CrossEntropyLoss()(predictions.float(),
                              one_hot(
                                  inputs["input_ids"],
                                  num_classes=VOCAB_SIZE
                            ).float()
                    ).cpu() ## Loss computation, comparing the logits and the one hot distrib
    del input_texts
    del predictions
    return loss

In [3]:
def ppl(mod):
    model = LlamaForCausalLM.from_pretrained(mod).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(mod)
    tokenizer.pad_token = tokenizer.eos_token

    losses = []
    for sample in tqdm(testds["text"]):
        losses.append(
            compute_loss(model, [sample], tokenizer).detach()
        )

    del model
    del tokenizer
    return 2**(sum(losses) / len(losses)).cpu().item()

In [4]:
ppl_meditron_7b = ppl("epfl-llm/meditron-7b")

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


  0%|          | 0/826 [00:00<?, ?it/s]

In [5]:
print("Loss for meditron 7b :", ppl_meditron_7b)

Loss for meditron 7b : 1.3506844489229484


In [6]:
ppl_ll3 = ppl("meta-llama/Meta-Llama-3-8B")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


  0%|          | 0/826 [00:00<?, ?it/s]

In [15]:
print("Loss for LLaMA 3 :", ppl_ll3)

Loss for LLaMA 3 : 1.0753576460522802


In [None]:
ppl_epitron_M7B_PMCo_e1 = ppl("cryptoni/epitron_baseline_M7B_PMCo_e1")

In [11]:
print("Loss for EPITRON.M7B.PMCo.E1 :", ppl_epitron_M7B_PMCo_e1)

Loss for EPITRON.M7B.PMCo.E1 : 1.0552800471506047


In [11]:
ppl_epitron_M7B_PMCo_e5 = ppl("cryptoni/epitron_baseline_PMCo_M7B")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


  0%|          | 0/826 [00:00<?, ?it/s]

In [12]:
print("Loss for EPITRON.M7B.PMCo.E5 :", ppl_epitron_M7B_PMCo_e5)

Loss for EPITRON.M7B.PMCo.E5 : 1.0552412232446187


In [7]:
ppl_epitron_LL3_8B_PMCo_e1 = ppl("cryptoni/epitron_LL3_8B_PMCo_e1")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


  0%|          | 0/826 [00:00<?, ?it/s]

In [10]:
print("Loss for EPITRON.LL3.8B.PMCo.E1 :", ppl_epitron_LL3_8B_PMCo_e1)

Loss for EPITRON.LL3.8B.PMCo.E1 : 1.0159650442256611


In [13]:
ppl_epitron_M7B_PMCo_e3 = ppl("cryptoni/epitron_baseline_PMCo_M7B_e3")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


  0%|          | 0/826 [00:00<?, ?it/s]

In [14]:
print("Loss for EPITRON.M7B.PMCo.E3 :", ppl_epitron_M7B_PMCo_e3)

Loss for EPITRON.M7B.PMCo.E3 : 1.0552418390534275


|model|ppl|dataset|
|:------:|:---------:|:-------:|
|Meditron-7B|1.35|pmc test|
|LLaMA-3-8B|1.075|pmc test|
|epitron.LL3.8B.PMCo.e1|1.01|pmc test|
|epitron.M7B.PMCo.e1|1.06|pmc test|
|epitron.M7B.PMCo.e3|1.055|pmc test|
|epitron.M7B.PMCo.e5|1.055|pmc test|
