In [None]:
import torch
from transformers import AutoTokenizer, TextIteratorStreamer
from threading import Thread
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
import numpy as np
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from llm_unlearning.models.models import load_model_and_tokenizer
from llm_unlearning.unlearning_datasets.tofu import TofuDataset
from llm_unlearning.evals.utils import probability, rouge_score, extract_question_tokens, extract_answer_tokens
from llm_unlearning.evals.tofu_evals import Probability, Rouge

def load_model_and_tokenizer_wrapper(model_path, tokenizer_path="microsoft/phi-1_5", fp16=True):
    """Load a model checkpoint with its tokenizer and move the model to the best device.

    Args:
        model_path: Checkpoint location; forwarded as ``path`` in the loader config.
        tokenizer_path: Tokenizer identifier forwarded as ``tokenizer_path``.
            Defaults to the value this notebook previously hard-coded.
        fp16: Forwarded unchanged as the ``fp16`` entry of the loader config.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has been moved to ``'cuda'``
        when CUDA is available and to ``'cpu'`` otherwise.
    """
    print(f"Loading model from: {model_path}")
    config = OmegaConf.create({"path": model_path, "tokenizer_path": tokenizer_path, "fp16": fp16})
    model, tokenizer = load_model_and_tokenizer(config)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return model.to(device), tokenizer

def load_tofu_dataset(tokenizer):
    """Build the TOFU dataset for the ``full`` split using the notebook's prompt format.

    Args:
        tokenizer: Tokenizer handed to ``TofuDataset`` together with the config.

    Returns:
        The ``TofuDataset`` constructed from ``tokenizer`` and the config below.
    """
    # Questions are wrapped as "Question: <q>\nAnswer: " and answers carry no extra tag.
    dataset_settings = dict(
        split="full",
        max_length=512,
        question_key="question",
        answer_key="answer",
        question_start_tag="Question: ",
        question_end_tag="\nAnswer: ",
        answer_tag="",
    )
    tofu_config = OmegaConf.create(dataset_settings)
    return TofuDataset(tokenizer, tofu_config)

def stream_generate_text(model, tokenizer, input_ids, attention_mask, max_new_tokens, temperature, top_p, top_k):
    """Start generation on a background thread and return a streamer yielding decoded text.

    Args:
        model: Object exposing ``generate(**kwargs)``.
        tokenizer: Tokenizer used by the streamer for decoding; its
            ``eos_token_id`` is used as both the pad and end-of-sequence id.
        input_ids: Prompt token ids forwarded to ``generate``.
        attention_mask: Attention mask forwarded to ``generate``.
        max_new_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature. Values ``<= 0`` select greedy decoding.
        top_p: Nucleus sampling threshold (only forwarded when sampling).
        top_k: Top-k sampling cutoff (only forwarded when sampling).

    Returns:
        The ``TextIteratorStreamer`` to iterate over. The prompt is included in
        the stream (``skip_prompt=False``) and special tokens are skipped.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

    do_sample = temperature > 0.0
    generation_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    if do_sample:
        # Sampling knobs are meaningless for greedy decoding, and temperature=0.0
        # is not a valid sampling temperature, so only forward them when sampling.
        generation_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

    def _run_generation():
        try:
            model.generate(**generation_kwargs)
        except Exception:
            # Without this, a failure inside generate() would leave the consumer
            # blocked forever waiting for the end-of-stream signal.
            streamer.end()
            raise

    thread = Thread(target=_run_generation)
    thread.start()

    return streamer

def compute_rouge_score(model, tokenizer, input_ids, attention_mask, labels, question_length):
    """Compute per-sample ROUGE-L between decoded candidate answers and reference answers.

    Both ``input_ids`` (candidates) and ``labels`` (references) are reduced to
    their answer part with ``extract_answer_tokens``, decoded, and stripped of
    a leading ``"Answer: "`` prefix before scoring.

    Args:
        model: Only ``model.device`` is read, to place the returned tensor.
        tokenizer: Provides ``pad_token_id`` and ``batch_decode``.
        input_ids: Batch of candidate token sequences.
        attention_mask: Unused; kept for interface compatibility.
        labels: Batch of reference token sequences.
        question_length: Forwarded to ``extract_answer_tokens`` for both sides.

    Returns:
        A 1-D tensor of scores on ``model.device``. If anything fails during
        evaluation, the error is printed and a zero tensor with one entry per
        sample is returned, so callers can index or ``.item()`` it the same way
        as a successful result.
    """
    answer_prefix = "Answer: "

    def _decode_answers(token_ids):
        # Keep only the answer tokens, decode them, and drop the prompt's answer tag.
        answer_tokens = extract_answer_tokens(token_ids, question_length, tokenizer.pad_token_id)
        texts = tokenizer.batch_decode(answer_tokens, skip_special_tokens=True)
        return [text[len(answer_prefix):] if text.startswith(answer_prefix) else text for text in texts]

    try:
        references = _decode_answers(labels)
        candidates = _decode_answers(input_ids)

        scores = rouge_score(candidates, references, 'rougeL')

        # A scalar result is wrapped so the return value is always 1-D.
        if not isinstance(scores, list):
            scores = [scores]

        return torch.tensor(scores, device=model.device)
    except Exception as e:
        print(f"Error in ROUGE evaluation: {str(e)}")
        try:
            batch_size = len(input_ids)
        except TypeError:
            batch_size = 1
        # Same rank as the success path (one score per sample) instead of a 0-dim scalar.
        return torch.zeros(batch_size, device=getattr(model, "device", None))

def generate_n_samples(dataset, indices, n, file_prefix, model, tokenizer, batch_size, temperature, top_p, top_k, output_area, max_length=512):
    """Generate answers for a subset of the dataset, score them, and write a report file.

    For every batch the question tokens are extracted, the model generates a
    continuation, and ROUGE-L is computed against the dataset's own sequence.
    Results are written to ``./{file_prefix}_samples.txt`` and the average
    score is printed inside ``output_area``.

    Args:
        dataset: Dataset to draw samples from (wrapped in a ``Subset``).
        indices: Dataset indices to evaluate.
        n: Sample count; only used in the progress messages.
        file_prefix: Prefix of the report file name and label in messages.
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Provides ``pad_token_id``, ``eos_token_id`` and ``decode``.
        batch_size: Batch size of the evaluation ``DataLoader``.
        temperature: Sampling temperature; values ``<= 0`` select greedy decoding.
        top_p: Nucleus sampling threshold (only forwarded when sampling).
        top_k: Top-k cutoff (only forwarded when sampling).
        output_area: Context manager (output widget) capturing printed output.
        max_length: ``max_length`` passed to ``model.generate``.
    """
    answer_tag = '\nAnswer:'

    def _answer_after_tag(text, default):
        # Text between the first and second answer tag, stripped; `default` when the tag is absent.
        parts = text.split(answer_tag)
        return parts[1].strip() if len(parts) > 1 else default

    do_sample = temperature > 0.0
    # Sampling knobs are only meaningful (and only valid) when sampling is enabled.
    sampling_kwargs = dict(temperature=temperature, top_p=top_p, top_k=top_k) if do_sample else {}

    with output_area:
        clear_output(wait=True)
        print(f"\n\nGenerating text for {file_prefix} {n} samples...")

        subset = Subset(dataset, indices)
        dataloader = DataLoader(subset, batch_size=batch_size, shuffle=False)

        all_results = []

        for batch in tqdm(dataloader, desc="Processing batches"):
            batch = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            input_ids, attention_mask = extract_question_tokens(batch, tokenizer.pad_token_id)
            question_length = batch['question_length']

            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                do_sample=do_sample,
                **sampling_kwargs,
            )

            rouge_scores = compute_rouge_score(model, tokenizer, outputs, attention_mask, batch['input_ids'], question_length)
            # The scorer may hand back a single fallback value; spread it over the
            # batch so per-sample indexing below cannot fail.
            rouge_scores = rouge_scores.reshape(-1)
            if rouge_scores.numel() == 1:
                rouge_scores = rouge_scores.expand(len(input_ids))

            for i in range(len(input_ids)):
                original_text = tokenizer.decode(batch['input_ids'][i], skip_special_tokens=True)
                question = original_text.split(answer_tag)[0].replace('Question: ', '')
                original_answer = _answer_after_tag(original_text, "")
                generated_text = tokenizer.decode(outputs[i], skip_special_tokens=True)
                generated_answer = _answer_after_tag(generated_text, generated_text)

                all_results.append({
                    'question': question,
                    'generated_answer': generated_answer,
                    'original_answer': original_answer,
                    'rouge_score': rouge_scores[i].item()
                })

        if not all_results:
            # Nothing was evaluated: skip the report instead of dividing by zero below.
            print(f"\n\nNo samples were generated for {file_prefix} {n}.")
            return

        with open(f"./{file_prefix}_samples.txt", "w", encoding="utf-8") as f:
            for i, result in enumerate(all_results):
                f.write(f"Sample {i + 1}:\n")
                f.write(f"Question:\n{result['question']}\n\n")
                f.write(f"Generated answer:\n{result['generated_answer']}\n\n")
                f.write(f"Ground truth answer:\n{result['original_answer']}\n\n")
                f.write(f"ROUGE-L Score: {result['rouge_score']:.4f}\n\n")
                f.write("-" * 50 + "\n\n")

        average_rouge = sum(result['rouge_score'] for result in all_results) / len(all_results)
        print(f"\n\nAverage ROUGE-L Score for {file_prefix} {n}: {average_rouge:.4f}")
        print(f"Results written to ./{file_prefix}_samples.txt")

def interact_with_model(model, tokenizer, dataset):
    """Display an interactive widget panel for querying the model on dataset questions.

    The panel offers a question dropdown, generation controls (max new tokens,
    temperature, top-p, top-k), a "Generate" button that streams one answer and
    scores it with ROUGE-L against the ground truth, and "Generate Top N" /
    "Generate Bottom N" buttons that run ``generate_n_samples`` on the first or
    last N dataset entries.

    Args:
        model: Model used for generation; ``model.device`` is read to place tensors.
        tokenizer: Tokenizer used for decoding and for its ``pad_token_id``.
        dataset: Dataset exposing ``data``, ``config.question_key``, ``__len__``
            and ``__getitem__`` returning a dict with ``input_ids``,
            ``attention_mask``, ``labels`` and ``question_length``.
    """
    answer_tag = '\nAnswer:'

    question_dropdown = widgets.Dropdown(
        options=[(item[dataset.config.question_key], i) for i, item in enumerate(dataset.data)],
        description="Question:"
    )
    max_new_tokens_slider = widgets.IntSlider(min=1, max=200, value=100, description="Max New Tokens:")
    temperature_slider = widgets.FloatSlider(min=0.0, max=10.0, value=1.0, description="Temperature:")
    top_p_slider = widgets.FloatSlider(min=0.0, max=1.0, value=1.0, description="Top-p:")
    top_k_slider = widgets.IntSlider(min=0, max=1000, value=1000, description="Top-k:")
    generate_button = widgets.Button(description="Generate")

    n_samples_slider = widgets.IntSlider(min=1, max=len(dataset), value=10, description="N Samples:")
    batch_size_slider = widgets.IntSlider(min=1, max=32, value=4, description="Batch Size:")
    generate_top_n_button = widgets.Button(description="Generate Top N")
    generate_bottom_n_button = widgets.Button(description="Generate Bottom N")

    output_area = widgets.Output()

    def _answer_after_tag(text, default):
        # Text between the first and second answer tag, stripped; `default` when the tag is absent.
        parts = text.split(answer_tag)
        return parts[1].strip() if len(parts) > 1 else default

    def on_button_click(b):
        with output_area:
            clear_output(wait=True)
            item = dataset[question_dropdown.index]

            item = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in item.items()}

            # A single dataset item has no batch dimension; add one for the helpers below.
            for key in ['input_ids', 'attention_mask', 'labels']:
                if item[key].dim() == 1:
                    item[key] = item[key].unsqueeze(0)

            if isinstance(item['question_length'], torch.Tensor) and item['question_length'].dim() == 0:
                item['question_length'] = item['question_length'].unsqueeze(0)
            elif isinstance(item['question_length'], int):
                item['question_length'] = torch.tensor([item['question_length']], device=model.device)

            input_ids, attention_mask = extract_question_tokens(item, tokenizer.pad_token_id)

            print("Generating text...")
            streamer = stream_generate_text(model, tokenizer, input_ids, attention_mask,
                                            max_new_tokens_slider.value,
                                            temperature_slider.value,
                                            top_p_slider.value,
                                            top_k_slider.value)

            generated_text = ""
            for new_text in streamer:
                generated_text += new_text
                clear_output(wait=True)
                print(generated_text)

            original_text = tokenizer.decode(item['input_ids'][0], skip_special_tokens=True)
            original_answer = _answer_after_tag(original_text, "")
            generated_answer = _answer_after_tag(generated_text, generated_text)

            # Score the streamed answer itself. Only text is available here (the
            # streamer yields decoded strings), and scoring the prompt token ids,
            # as was done before, never looks at what the model generated.
            try:
                score = rouge_score([generated_answer], [original_answer], 'rougeL')
                # rouge_score is used elsewhere in this notebook as returning either a
                # list of per-sample scores or a single value -- handle both.
                if isinstance(score, list):
                    score = score[0]
                score = float(score)
            except Exception as e:
                print(f"Error in ROUGE evaluation: {str(e)}")
                score = 0.0

            print(f"Ground truth answer:\n{original_answer}\n")
            print(f"ROUGE-L Score: {score:.4f}")

    def _generate_for(indices, file_prefix):
        # Shared by the top/bottom buttons; reads the current slider values at click time.
        generate_n_samples(dataset, indices, n_samples_slider.value, file_prefix, model, tokenizer,
                           batch_size_slider.value, temperature_slider.value, top_p_slider.value, top_k_slider.value, output_area)

    def on_generate_top_n_click(b):
        indices = list(range(len(dataset)))
        _generate_for(indices[:n_samples_slider.value], "top")

    def on_generate_bottom_n_click(b):
        indices = list(range(len(dataset)))
        _generate_for(indices[-n_samples_slider.value:], "bottom")

    generate_button.on_click(on_button_click)
    generate_top_n_button.on_click(on_generate_top_n_click)
    generate_bottom_n_button.on_click(on_generate_bottom_n_click)

    display(question_dropdown, max_new_tokens_slider, temperature_slider, top_p_slider, top_k_slider,
            generate_button, n_samples_slider, batch_size_slider, generate_top_n_button, generate_bottom_n_button, output_area)


In [None]:
# Checkpoint evaluated by this notebook. NOTE(review): machine-specific absolute
# path -- adjust when running elsewhere.
CHECKPOINT_PATH = "/nfs/homedirs/gudm/development/new/results/finetune/retain10/checkpoint-60"

# Load the model first: the dataset needs its tokenizer.
print("Loading model and tokenizer...")
model, tokenizer = load_model_and_tokenizer_wrapper(CHECKPOINT_PATH)

print("Loading TOFU dataset...")
dataset = load_tofu_dataset(tokenizer)

In [None]:
# Render the interactive widget panel for the loaded model and dataset.
interact_with_model(model=model, tokenizer=tokenizer, dataset=dataset)