In [1]:
from transformers import LlamaForCausalLM, LlamaTokenizer

device = "cpu"

tokenizer = LlamaTokenizer.from_pretrained("/data/imoradi/hugging13B")
model = LlamaForCausalLM.from_pretrained("/data/imoradi/hugging13B")

Setting ds_accelerator to cuda (auto detect)


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [30]:
from datasets import load_dataset

text = "Ubi est Iraius?"
encodings = tokenizer(text, return_tensors="pt")

In [31]:
import torch
from tqdm import tqdm

max_length = 512
stride = 512
seq_len = encodings.input_ids.size(1)

nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

        # loss is calculated using CrossEntropyLoss which averages over valid labels
        # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
        # to the left by 1.
        neg_log_likelihood = outputs.loss

    nlls.append(neg_log_likelihood)

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

ppl = torch.exp(torch.stack(nlls).mean())


  0%|                                                     | 0/1 [00:03<?, ?it/s]


In [32]:
print(ppl)

tensor(1382.6110)


In [4]:
from evaluate import load
import torch
inputtext = ["Ubi est Iraius?", "Italia in Roma est.", "jakcnsadundnf", "fisharecool", "The cat is a feline", "how are you"]
perplexity = load("perplexity", module_type="metric")
results = perplexity.compute(predictions=inputtext, model_id='/data/imoradi/hugging13B', device="cpu")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Using pad_token, but it is not set yet.


  0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
print(results)

{'perplexities': [1122.2489013671875, 1417.0660400390625, 2081.21875, 1153.887451171875, 36.140384674072266, 136.12982177734375], 'mean_perplexity': 991.1152248382568}
