In [None]:
import torch
from transformers import AutoTokenizer, PreTrainedModel
from typing import List, Tuple, Dict
from tqdm import tqdm
import json
from llm_unlearning.models.models import load_model_and_tokenizer
from llm_unlearning.unlearning_datasets.tofu import TofuDataset
from omegaconf import OmegaConf
from llm_unlearning.evals.utils import extract_question_tokens, extract_answer_tokens

def load_models_and_tokenizer(target_path: str, reference_path: str) -> Tuple[PreTrainedModel, PreTrainedModel, AutoTokenizer]:
    """Load the target model, the reference model and a tokenizer.

    Both models are loaded through ``load_model_and_tokenizer`` with the same
    tokenizer path ("microsoft/phi-1_5") and ``fp16`` enabled, then moved to
    CUDA when it is available and to the CPU otherwise.

    Args:
        target_path: Value used as ``path`` in the target model's config.
        reference_path: Value used as ``path`` in the reference model's config.

    Returns:
        ``(target_model, reference_model, tokenizer)``. The tokenizer is the one
        returned while loading the target model; the one returned while loading
        the reference model is discarded.
    """
    # Settings shared by both loads; only "path" differs between them.
    shared_settings = {"tokenizer_path": "microsoft/phi-1_5", "fp16": True}

    def _on_best_device(model: PreTrainedModel) -> PreTrainedModel:
        # CUDA availability is re-checked for every model, exactly once per move.
        return model.to('cuda' if torch.cuda.is_available() else 'cpu')

    print(f"Loading target model from: {target_path}")
    target_settings = OmegaConf.create({"path": target_path, **shared_settings})
    loaded_target, tokenizer = load_model_and_tokenizer(target_settings)
    target_model = _on_best_device(loaded_target)

    print(f"Loading reference model from: {reference_path}")
    reference_settings = OmegaConf.create({"path": reference_path, **shared_settings})
    loaded_reference, _ = load_model_and_tokenizer(reference_settings)
    reference_model = _on_best_device(loaded_reference)

    return target_model, reference_model, tokenizer

def load_tofu_dataset(tokenizer: AutoTokenizer) -> TofuDataset:
    """Build a ``TofuDataset`` for the given tokenizer with a fixed configuration.

    The configuration selects the "full" split, a ``max_length`` of 512, the
    "question"/"answer" keys, and wraps each question as
    ``"Question: " + question + "\\nAnswer: "`` with an empty answer tag.

    Args:
        tokenizer: Passed through unchanged as the first argument of ``TofuDataset``.

    Returns:
        The constructed ``TofuDataset``.
    """
    dataset_settings = {
        "split": "full",
        "max_length": 512,
        "question_key": "question",
        "answer_key": "answer",
        "question_start_tag": "Question: ",
        "question_end_tag": "\nAnswer: ",
        "answer_tag": "",
    }
    dataset_config = OmegaConf.create(dataset_settings)
    return TofuDataset(tokenizer, dataset_config)

def get_logits(logits: torch.Tensor, tokenizer: AutoTokenizer) -> Dict[int, torch.Tensor]:
    """Map each index of a logit vector to the logit stored at that index.

    Args:
        logits: Tensor that is iterated along its first dimension. The callers
            in this file pass a single 1-D slice such as ``logits[0, -1, :]``.
        tokenizer: Not used; kept so existing call sites keep working.

    Returns:
        A dict whose keys are the integer indices ``0 .. len(logits) - 1`` in
        ascending order and whose values are the elements produced by iterating
        ``logits`` (0-d tensors for a 1-D input, not Python floats).
    """
    # enumerate() yields the same (index, element) pairs the previous
    # zip(logits, range(len(logits))) produced, without the detour.
    return {token_id: logit for token_id, logit in enumerate(logits)}

def generate_and_compare(target_model: PreTrainedModel, reference_model: PreTrainedModel,
                         tokenizer: AutoTokenizer, input_ids: torch.Tensor, attention_mask: torch.Tensor,
                         max_new_tokens: int = 50) -> Dict:
    """Greedily decode with the target model while recording both models' logits.

    At every step both models are run on the current sequence, the logits of
    the last position are stored for each model, and the arg-max token of the
    *target* model is appended to the sequence. Decoding stops after
    ``max_new_tokens`` steps or as soon as the appended token equals
    ``tokenizer.eos_token_id``.

    Only batch size 1 is supported: row 0 of the logits is read and a ``(1, 1)``
    column is appended to the inputs each step.

    Args:
        target_model: Model whose predictions drive the decoding. Inputs are
            moved to the device of its first parameter.
        reference_model: Model run on the same inputs for comparison; assumed to
            live on the same device as ``target_model`` — TODO confirm at call site.
        tokenizer: Provides ``eos_token_id`` and ``decode``.
        input_ids: Prompt token ids of shape ``(1, seq_len)``.
        attention_mask: Mask matching ``input_ids``; it is extended by one
            column of ones per generated token.
        max_new_tokens: Upper bound on the number of decoding steps.

    Returns:
        Dict with ``"generated_text"`` (decoded with special tokens skipped) and
        ``"target_logits_history"`` / ``"reference_logits_history"``, one
        ``get_logits`` mapping per decoding step.
    """
    device = next(target_model.parameters()).device
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    generated_tokens = []
    target_logits_history = []
    reference_logits_history = []

    for _ in tqdm(range(max_new_tokens)):
        with torch.no_grad():
            # Pass the mask that is maintained below; previously it was extended
            # every step but never handed to the models, so padded prompt
            # positions were attended to (unlike in evaluate_and_compare).
            target_outputs = target_model(input_ids, attention_mask=attention_mask)
            reference_outputs = reference_model(input_ids, attention_mask=attention_mask)

        # Logits emitted at the last position of the (single) sequence.
        target_logits = target_outputs.logits[0, -1, :]
        reference_logits = reference_outputs.logits[0, -1, :]

        target_logits_history.append(get_logits(target_logits, tokenizer))
        reference_logits_history.append(get_logits(reference_logits, tokenizer))

        # Greedy choice from the target model only; reshaped to (1, 1) so it can
        # be concatenated onto the (1, seq_len) inputs.
        next_token = torch.argmax(target_logits).reshape(1, 1)
        next_token_id = next_token.item()
        generated_tokens.append(next_token_id)

        input_ids = torch.cat([input_ids, next_token], dim=1)
        attention_mask = torch.cat([attention_mask, torch.ones((1, 1), device=device, dtype=torch.long)], dim=1)

        if next_token_id == tokenizer.eos_token_id:
            break

    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return {
        "generated_text": generated_text,
        "target_logits_history": target_logits_history,
        "reference_logits_history": reference_logits_history
    }


def evaluate_and_compare(target_model: PreTrainedModel, reference_model: PreTrainedModel,
                         tokenizer: AutoTokenizer, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> Dict:
    """Run both models once over a full sequence and collect per-position logits.

    Inputs are moved to the device of the target model's first parameter, both
    models are evaluated under ``torch.no_grad()``, and the logits of row 0 are
    recorded position by position. Collection stops at the first position whose
    token id equals ``tokenizer.pad_token_id``; a message naming that position
    is printed and nothing from there on is recorded.

    Args:
        target_model: First model to evaluate.
        reference_model: Second model, evaluated on the same inputs.
        tokenizer: Provides ``pad_token_id``; also forwarded to ``get_logits``.
        input_ids: Token ids; only row 0 is inspected.
        attention_mask: Forwarded to both models as ``attention_mask``.

    Returns:
        Dict with ``"target_logits_history"`` and ``"reference_logits_history"``,
        each holding one ``get_logits`` mapping per recorded position.
    """
    model_device = next(target_model.parameters()).device
    input_ids = input_ids.to(model_device)
    attention_mask = attention_mask.to(model_device)

    with torch.no_grad():
        target_forward = target_model(input_ids, attention_mask=attention_mask)
        reference_forward = reference_model(input_ids, attention_mask=attention_mask)

    all_target_logits = target_forward.logits
    all_reference_logits = reference_forward.logits

    histories = {
        "target_logits_history": [],
        "reference_logits_history": []
    }

    # Iterating row 0 yields one 0-d token id per sequence position.
    for position, token_id in enumerate(input_ids[0]):
        if token_id == tokenizer.pad_token_id:
            print("Found padding token at position", position)
            break
        histories["target_logits_history"].append(
            get_logits(all_target_logits[0, position, :], tokenizer)
        )
        histories["reference_logits_history"].append(
            get_logits(all_reference_logits[0, position, :], tokenizer)
        )

    return histories

In [None]:
import torch
import torch.nn.functional as F


def get_top_n_different_tokens(target_logits_history, reference_logits_history, index, n=10, top_k=None, top_p=None):
    """Return the tokens whose logits differ most between target and reference.

    The entries at ``index`` of both histories are compared. Optionally the
    comparison is restricted to tokens the *reference* model considers likely,
    using a top-k and/or a nucleus (top-p) filter on the reference softmax
    probabilities; when both are given, a token must pass both filters.

    Args:
        target_logits_history: Sequence of mappings ``token -> logit`` for the
            target model. Values may be Python numbers or 0-d tensors.
        reference_logits_history: Same structure for the reference model; its
            entry must list the tokens in the same order as the target entry.
        index: Which history entry to inspect.
        n: Maximum number of tokens to return.
        top_k: If given, keep only the ``top_k`` most probable reference tokens.
        top_p: If given, keep the smallest set of most probable reference tokens
            whose cumulative probability exceeds ``top_p`` (the most probable
            token is always kept).

    Returns:
        List of ``(token, abs_logit_difference)`` tuples, largest difference
        first, with at most ``n`` elements.
    """
    target_entry = target_logits_history[index]
    reference_entry = reference_logits_history[index]

    # Build float32 tensors explicitly. The stored values can be 0-d
    # half-precision tensors; letting torch.tensor infer float16 from them makes
    # the softmax, the cumulative sum and the subtraction below lose precision.
    # float() accepts both plain numbers and 0-d tensors on any device.
    target_logits = torch.tensor([float(v) for v in target_entry.values()], dtype=torch.float32)
    reference_logits = torch.tensor([float(v) for v in reference_entry.values()], dtype=torch.float32)

    reference_probs = F.softmax(reference_logits, dim=-1)
    mask = torch.ones_like(reference_probs, dtype=torch.bool)

    if top_k is not None:
        top_k_indices = torch.topk(reference_probs, min(top_k, len(reference_probs))).indices
        mask.fill_(False)
        mask[top_k_indices] = True

    if top_p is not None:
        sorted_probs, sorted_indices = torch.sort(reference_probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right by one so the token that crosses the threshold is kept,
        # and never drop the single most probable token.
        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
        sorted_indices_to_remove[0] = False
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        mask[indices_to_remove] = False

    logit_diff = torch.abs(target_logits[mask] - reference_logits[mask])
    top_n_indices = torch.topk(logit_diff, min(n, logit_diff.size(0))).indices

    # Convert the mask to a Python list once instead of iterating a
    # vocabulary-sized tensor element by element.
    masked_tokens = [token for token, keep in zip(target_entry.keys(), mask.tolist()) if keep]

    return [(masked_tokens[i], logit_diff[i].item()) for i in top_n_indices.tolist()]


def get_top_n_tokens(logits_history, index, n=10):
    """Return the ``n`` highest-logit tokens of one history entry.

    Args:
        logits_history: Sequence of mappings ``token -> logit``. Values may be
            Python numbers or 0-d tensors.
        index: Which history entry to inspect.
        n: Maximum number of tokens to return. Values larger than the number of
            entries are clamped instead of raising.

    Returns:
        List of ``(token, logit)`` tuples sorted by descending logit, with at
        most ``n`` elements.
    """
    entry = logits_history[index]
    tokens = list(entry.keys())
    # float() accepts plain numbers and 0-d tensors alike; float32 represents
    # half-precision values exactly, so the reported logits are unchanged.
    logits = torch.tensor([float(v) for v in entry.values()], dtype=torch.float32)

    # torch.topk raises when k exceeds the number of elements, so clamp it the
    # same way get_top_n_different_tokens does.
    top_n_indices = torch.topk(logits, min(n, logits.size(0))).indices

    return [(tokens[i], logits[i].item()) for i in top_n_indices.tolist()]

In [None]:
# Path handed to load_models_and_tokenizer as the target model.
# NOTE: this is an absolute, machine-specific path; adjust it before running elsewhere.
target_model_path = "/nfs/homedirs/gudm/development/new/results/baseline/20240912_005849_npo_forget10_7/checkpoint-120"

# Path handed to load_models_and_tokenizer as the reference model. Exactly one
# assignment is active; the alternatives stay commented out so that no value is
# assigned only to be overwritten a line later.
# reference_model_path = "locuslab/tofu_ft_phi-1.5"
# reference_model_path = "microsoft/phi-1_5"
reference_model_path = "/nfs/homedirs/gudm/development/new/results/finetune/retain90_10e/checkpoint-1120"

target_model, reference_model, tokenizer = load_models_and_tokenizer(target_model_path, reference_model_path)

In [None]:
# Build the dataset and pick the single sample that the following cells inspect.
dataset = load_tofu_dataset(tokenizer)

sample_idx = 3999
sample = dataset[sample_idx]

# Add a leading dimension of size 1 to every entry of the sample (presumably the
# batch dimension the models expect — confirm against TofuDataset's item layout).
# This rewrites `sample` in place, so re-running only this loop would add another
# dimension; re-run the whole cell instead.
for key in sample:
    sample[key] = sample[key].unsqueeze(0)

# Question/answer token extraction is delegated to project helpers whose exact
# return layout is not visible in this notebook.
question_ids, attention_mask = extract_question_tokens(sample, tokenizer.pad_token_id)
answer_ids = extract_answer_tokens(sample["input_ids"], sample["question_length"], tokenizer.pad_token_id)
# Decode the full sequence, then recover the bare question by cutting at the
# "\nAnswer:" separator and removing the "Question: " prefix.
original_text = tokenizer.decode(sample['input_ids'][0], skip_special_tokens=True)
question = original_text.split('\nAnswer:')[0].replace('Question: ', '')

# Disabled alternative: greedy generation from the question only.
# result = generate_and_compare(target_model, reference_model, tokenizer, question_ids, attention_mask)
# print(tokenizer.decode(sample['input_ids'][0], skip_special_tokens=True))
# print("Generated text: ", result['generated_text'])

# NOTE(review): `attention_mask` was returned by extract_question_tokens alongside
# `question_ids`, but the ids passed here are the full sample["input_ids"].
# Confirm that the mask matches that tensor's shape and covers the answer tokens.
result_evaluate = evaluate_and_compare(target_model, reference_model, tokenizer, sample["input_ids"], attention_mask)



In [None]:
# For every recorded position from `question_length` onwards, print the token at
# that position (with the following token in brackets), the top tokens of each
# model, and the tokens whose logits differ most between the two models.
result = result_evaluate  # loop-invariant alias, set once instead of per iteration
length = len(result["target_logits_history"])
question_length = sample["question_length"].item()
sequence_ids = sample["input_ids"][0]

for index in range(question_length, length):
    current_text = tokenizer.decode(sequence_ids[index].item(), skip_special_tokens=True)
    # When the history covers the whole sequence (no padding token was hit) the
    # last position has no successor; print an empty bracket instead of raising
    # IndexError.
    if index + 1 < sequence_ids.shape[0]:
        next_text = tokenizer.decode(sequence_ids[index + 1].item(), skip_special_tokens=True)
    else:
        next_text = ""

    print(f"\n{index}: {current_text}[{next_text}]")

    top_diff_tokens = get_top_n_different_tokens(result['target_logits_history'], result['reference_logits_history'], index=index, n=20, top_p=0.95)
    top_tokens_target = get_top_n_tokens(result['target_logits_history'], index=index, n=10)
    top_tokens_reference = get_top_n_tokens(result['reference_logits_history'], index=index, n=10)

    target_tokens = [f"{tokenizer.decode(token, skip_special_tokens=True)} ({logit:.2f})" for token, logit in top_tokens_target]
    reference_tokens = [f"{tokenizer.decode(token, skip_special_tokens=True)} ({logit:.2f})" for token, logit in top_tokens_reference]

    # Pad every entry to a common width so the two rows line up column by column.
    max_len = max(len(t) for t in target_tokens + reference_tokens)
    aligned_target = ', '.join(t.ljust(max_len) for t in target_tokens)
    aligned_reference = ', '.join(t.ljust(max_len) for t in reference_tokens)

    # Per-position logit mappings, looked up once for the difference report.
    reference_at_index = result['reference_logits_history'][index]
    target_at_index = result['target_logits_history'][index]
    diff_descriptions = [
        f"{tokenizer.decode(token, skip_special_tokens=True)} ({diff:.4f} = {reference_at_index[token]:.4f} -> {target_at_index[token]:.4f})"
        for token, diff in top_diff_tokens
    ]

    print(f"Top tokens in reference model: {aligned_reference}")
    print(f"Top tokens in target model:    {aligned_target}")
    print("Top different tokens:", ", ".join(diff_descriptions))