In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import evaluate
from evaluate import load
from tqdm import tqdm
import numpy as np
from torch.nn import CrossEntropyLoss
from evaluate import logging

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

###change to your model and tokenizer
# bnb_config = BitsAndBytesConfig(
#         load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="float16", bnb_4bit_use_double_quant=True
#     )

# model_path = "/storage/ice1/6/8/dk305/Unlearning/SimPO/models/alpaca-7b-reproduced/unlearned/V2_1GPU_simnpo_grad_diff_1e-05_forget05_epoch5_batch1_accum4_beta2.5_gamma0.0_grad_diff_coeff1.0_reffine_tuned_evalsteps_per_epoch_seed1001_1/checkpoint-2750"
# model_path = "/home/hice1/dk305/scratch/Unlearning/SimPO/models/alpaca-7b-reproduced/unlearned/V3_1GPU_simnpo_grad_diff_1e-05_forget05_epoch10_batch1_accum4_beta2.5_gamma0.0_grad_diff_coeff1.0_reffine_tuned_evalsteps_per_epoch_seed1001_1"
model_path = "/home/hice1/dk305/scratch/Unlearning/SimPO/models/alpaca-7b-reproduced/unlearned/V3_1GPU_simnpo_grad_diff_1e-05_forget05_epoch10_batch1_accum4_beta2.5_gamma0.0_grad_diff_coeff1.0_reffine_tuned_evalsteps_per_epoch_seed1001_1"

model = AutoModelForCausalLM.from_pretrained(model_path, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.to(device)
model.eval()

Loading checkpoint shards: 100%|██████████| 3/3 [04:22<00:00, 87.43s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32001, 4096, padding_idx=32000)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-06)
  

In [3]:
lambada_dataset = load_dataset("lambada")

Generating train split: 100%|██████████| 2662/2662 [00:03<00:00, 744.73 examples/s]
Generating test split: 100%|██████████| 5153/5153 [00:00<00:00, 468437.73 examples/s]
Generating validation split: 100%|██████████| 4869/4869 [00:00<00:00, 510003.40 examples/s]


In [4]:
# Tokenize the entire text from the test dataset
encodings = tokenizer("\n\n".join(lambada_dataset['test']['text']), return_tensors="pt")

max_length = model.config.max_position_embeddings
stride = 512  # Sliding window stride
seq_len = encodings.input_ids.size(1)  # Total sequence length

nlls = []  # List to store negative log likelihoods
prev_end_loc = 0  # Initialize previous end location

# Iterate over the dataset in chunks using a sliding window
for begin_loc in tqdm(range(0, seq_len, stride), desc="Processing Dataset"):

    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # Calculate target length for the chunk

    # Slice the input tensor for the current chunk
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    
    # Set the target labels to -100 for the tokens we don't want to calculate loss over
    target_ids[:, :-trg_len] = -100

    # Compute the loss using the model (no gradient calculation)
    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        neg_log_likelihood = outputs.loss  # This is the loss for the current chunk

    # Append the loss to the list of NLLs
    nlls.append(neg_log_likelihood.item())

    prev_end_loc = end_loc  # Update previous end location

    if end_loc == seq_len:
        break

Processing Dataset: 100%|█████████▉| 900/904 [14:41<00:03,  1.02it/s]


In [7]:
# Calculate perplexity
mean_nll = torch.tensor(nlls).mean()  # Mean negative log likelihood
ppl = torch.exp(mean_nll)  # Perplexity is the exponent of the mean NLL

# Print the calculated perplexity
print(f"Perplexity for LAMBADA dataset: {ppl:.4f}")

Perplexity for LAMBADA dataset: 15.2909


In [8]:
# Load WikiText-2 dataset
validation_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
test_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

Generating test split: 100%|██████████| 4358/4358 [00:00<00:00, 542075.23 examples/s]
Generating train split: 100%|██████████| 36718/36718 [00:00<00:00, 841101.11 examples/s]
Generating validation split: 100%|██████████| 3760/3760 [00:00<00:00, 537035.45 examples/s]


In [9]:
encodings = tokenizer("\n\n".join(test_dataset["text"]), return_tensors="pt")
max_length = model.config.max_position_embeddings
stride = 512
seq_len = encodings.input_ids.size(1)

nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)

        # loss is calculated using CrossEntropyLoss which averages over valid labels
        # N.B. the model only calculates loss over trg_len - 1 labels, because it internally shifts the labels
        # to the left by 1.
        neg_log_likelihood = outputs.loss

    nlls.append(neg_log_likelihood)

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

ppl = torch.exp(torch.stack(nlls).mean())

 99%|█████████▉| 663/667 [10:47<00:03,  1.02it/s]


In [10]:
print(f"Perplexity on WikiText Test: {ppl:.4f}")

Perplexity on WikiText Test: 5.6689
