In [None]:
import math

import matplotlib.pyplot as plt
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm import tqdm

import llm_unlearning.unlearning_datasets.wmdp
from llm_unlearning.models.models import load_model_and_tokenizer
from llm_unlearning.unlearning_datasets.wmdp import WikiTextDataset

model_path = "locuslab/tofu_ft_phi-1.5"

# Loader settings: TOFU fine-tuned phi-1.5 weights, paired with the base
# microsoft/phi-1_5 tokenizer, with the `fp16` flag set.
config = OmegaConf.create(
    dict(
        path=model_path,
        tokenizer_path="microsoft/phi-1_5",
        fp16=True,
    )
)
model, tokenizer = load_model_and_tokenizer(config)

# Prefer the GPU when one is present; otherwise fall back to CPU.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
# Inference only: switch off training-time behaviour (e.g. dropout).
model.eval()

In [None]:
import importlib

# Reload the dataset module so edits to it are picked up without restarting the kernel.
# `importlib.reload` only refreshes the module object: the `WikiTextDataset` name bound by
# `from ... import WikiTextDataset` still points at the OLD class, so rebind it explicitly.
importlib.reload(llm_unlearning.unlearning_datasets.wmdp)
WikiTextDataset = llm_unlearning.unlearning_datasets.wmdp.WikiTextDataset

num_samples = 10000

dummy_config = OmegaConf.create({
    "max_length": 512
})

dataset = WikiTextDataset(tokenizer, dummy_config, full_context_mode=True, num_samples=num_samples)

# batch_size=1 so every sample's token losses are scored individually.
# The seeded generator makes the shuffled sample order reproducible across runs.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=WikiTextDataset.collate_fn,
    generator=torch.Generator().manual_seed(0),
)

def compute_loss_for_tokens(logits, labels, token_positions):
    """Cross-entropy loss of a causal LM on the tokens at the given positions.

    Assumes ``labels`` follow the Hugging Face causal-LM convention, i.e. they are
    aligned with ``input_ids`` (the model shifts them itself when computing its loss).
    Under that convention the prediction for the token at ``position`` comes from the
    logits at ``position - 1``.

    Args:
        logits: ``(batch, seq_len, vocab)`` output logits.
        labels: ``(batch, seq_len)`` target token ids; entries equal to -100
            (cross_entropy's default ``ignore_index``) are ignored.
        token_positions: 0-based indices of the tokens to score. Negative indices
            count from the end. Position 0 is invalid because it has no preceding context.

    Returns:
        list[float]: one loss per position, averaged over the batch. A position whose
        targets are all ignored yields NaN.

    Raises:
        ValueError: if a position is 0.
    """
    losses = []
    for position in token_positions:
        if position == 0:
            raise ValueError("position 0 has no preceding context to predict it from")
        # Logits at step p-1 are the model's prediction for the token at position p.
        # Upcast to float32: the model may run in fp16, which is imprecise for log-softmax.
        step_logits = logits[:, position - 1, :].float()
        loss = torch.nn.functional.cross_entropy(step_logits, labels[:, position])
        losses.append(loss.item())
    return losses

In [None]:
# Per-sample losses on the 50th and 500th tokens (names kept for later cells).
fiftyth_token_losses = []
five_hundredth_token_losses = []
skipped_samples = 0

with torch.no_grad():
    # The loop stops at num_samples, so the progress bar total is the smaller of the two.
    total_batches = min(num_samples, len(dataloader))
    for i, batch in enumerate(tqdm(dataloader, total=total_batches)):
        if i >= num_samples:
            break

        inputs = {k: v.to(model.device) for k, v in batch.items()}

        outputs = model(**inputs)
        logits = outputs.logits
        labels = inputs['labels']

        # 0-based indices: 49 is the 50th token, 499 the 500th.
        loss_50th, loss_500th = compute_loss_for_tokens(logits, labels, [49, 499])

        # A NaN loss means the target at that position was ignored (label -100),
        # presumably padding past the end of a short sample; skip such samples.
        if math.isnan(loss_50th) or math.isnan(loss_500th):
            skipped_samples += 1
            continue

        fiftyth_token_losses.append(loss_50th)
        five_hundredth_token_losses.append(loss_500th)

if not five_hundredth_token_losses:
    raise RuntimeError(
        f"No valid samples: all {skipped_samples} evaluated samples had a NaN loss at token 50 or 500."
    )

average_fifty = sum(fiftyth_token_losses) / len(fiftyth_token_losses)
# NOTE: each score subtracts the MEAN 50th-token loss (not the sample's own), so the
# histogram is the 500th-token loss distribution shifted by that mean. The overall
# average equals mean(500th) - mean(50th) either way.
icl_scores = [five_hundredth - average_fifty for five_hundredth in five_hundredth_token_losses]
average_icl_score = sum(icl_scores) / len(icl_scores)

print(f"Average In-Context Learning Score: {average_icl_score:.4f}")
if skipped_samples:
    print(f"Skipped {skipped_samples} samples with NaN token losses")

fig, ax = plt.subplots(figsize=(12, 6))
ax.hist(icl_scores, bins=30, edgecolor='black')
ax.set_title("Distribution of In-Context Learning Scores")
ax.set_xlabel("ICL Score")
ax.set_ylabel("Frequency")
ax.axvline(average_icl_score, color='r', linestyle='dashed', linewidth=2, label=f'Mean: {average_icl_score:.4f}')
ax.legend()
plt.show()