In [None]:
import os
from threading import Thread

import numpy as np
import torch
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
from omegaconf import OmegaConf
from transformers import AutoTokenizer, TextIteratorStreamer

from llm_unlearning.models.models import load_model_and_tokenizer
from llm_unlearning.unlearning_datasets.tofu import TofuDataset
from llm_unlearning.evals.utils import probability, rouge_score, extract_question_tokens, extract_answer_tokens
from llm_unlearning.evals.tofu_evals import Probability, Rouge

def load_model_and_tokenizer_wrapper(model_path):
    """Load a checkpoint plus its tokenizer and place the model on CUDA if present.

    The same directory is used as both the model path and the tokenizer path,
    and the loader config requests fp16.

    Args:
        model_path: checkpoint directory to load from.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is already moved to
        ``'cuda'`` when available, otherwise to ``'cpu'``.
    """
    print(f"Loading model from: {model_path}")
    loader_cfg = OmegaConf.create(
        dict(path=model_path, tokenizer_path=model_path, fp16=True)
    )
    loaded_model, loaded_tokenizer = load_model_and_tokenizer(loader_cfg)
    # NOTE(review): fp16 is requested even when falling back to CPU; confirm the
    # loader/model supports half precision there.
    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    return loaded_model.to(target_device), loaded_tokenizer

def load_tofu_dataset(tokenizer):
    """Build the full TOFU split with a ``Question: ...\\nAnswer: `` prompt template.

    Args:
        tokenizer: tokenizer handed to ``TofuDataset`` for encoding.

    Returns:
        A ``TofuDataset`` over the ``"full"`` split, truncated/padded per its
        ``max_length`` of 512.
    """
    dataset_settings = {
        "split": "full",
        "max_length": 512,
        # Column names in the underlying TOFU rows.
        "question_key": "question",
        "answer_key": "answer",
        # Prompt template pieces wrapped around the question / answer text.
        "question_start_tag": "Question: ",
        "question_end_tag": "\nAnswer: ",
        "answer_tag": "",
    }
    dataset_cfg = OmegaConf.create(dataset_settings)
    return TofuDataset(tokenizer, dataset_cfg)

def stream_generate_text(model, tokenizer, input_ids, attention_mask, max_new_tokens, temperature, top_p, top_k):
    """Run ``model.generate`` on a background thread and return a text streamer.

    Iterating the returned streamer yields decoded text chunks as they are
    produced. Because ``skip_prompt=False``, the stream begins with the decoded
    prompt itself, followed by the generated continuation.

    Args:
        model: model exposing ``generate``.
        tokenizer: tokenizer used by the streamer for decoding.
        input_ids: prompt token ids (batched).
        attention_mask: mask matching ``input_ids``.
        max_new_tokens: upper bound on newly generated tokens.
        temperature: sampling temperature; ``<= 0`` selects greedy decoding.
        top_p: nucleus-sampling threshold (only used when sampling).
        top_k: top-k cutoff (only used when sampling).

    Returns:
        The ``TextIteratorStreamer`` fed by the background generation.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=False, skip_special_tokens=True)

    do_sample = temperature > 0.0
    generation_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    if do_sample:
        # Sampling knobs only matter when sampling; passing them to greedy
        # decoding merely triggers generation-config warnings.
        generation_kwargs.update(temperature=temperature, top_p=top_p, top_k=top_k)

    def _generate():
        # Runs generate() and guarantees the streamer is terminated on failure;
        # otherwise the consumer loop would wait indefinitely for an end signal.
        try:
            model.generate(**generation_kwargs)
        except Exception:
            streamer.end()
            raise

    thread = Thread(target=_generate)
    thread.start()

    return streamer

def safe_rouge_eval(model, tokenizer, input_ids, attention_mask, labels, question_length):
    """Greedy-decode an answer for the prompt and score it against ``labels`` with ROUGE-L.

    This is best-effort: any exception is printed and a scalar ``0.0`` tensor is
    returned so an interactive caller keeps working. Note that a returned 0.0
    may therefore mean either "no overlap" or "evaluation failed".

    Args:
        model: model exposing ``generate`` and ``device``.
        tokenizer: tokenizer providing ``pad_token_id``, ``eos_token_id`` and
            ``batch_decode``.
        input_ids: prompt token ids to generate from (batched).
        attention_mask: mask matching ``input_ids``.
        labels: token ids containing the reference answer.
        question_length: forwarded to ``extract_answer_tokens`` to locate where
            the answer starts in both ``outputs`` and ``labels``.

    Returns:
        A 1-D tensor of ROUGE-L scores on ``model.device``, or ``torch.tensor(0.0)``
        (CPU scalar) if anything raised.
    """
    answer_prefix = "Answer: "

    def _strip_answer_prefix(text):
        # Drop a leading "Answer: " left over from the prompt template, if any.
        return text[len(answer_prefix):] if text.startswith(answer_prefix) else text

    try:
        pad_token_id = tokenizer.pad_token_id

        # Deterministic decoding; max_length bounds prompt + answer together.
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=512,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            do_sample=False,
        )

        extracted_outputs = extract_answer_tokens(outputs, question_length, pad_token_id)
        extracted_labels = extract_answer_tokens(labels, question_length, pad_token_id)

        decoded_outputs = tokenizer.batch_decode(extracted_outputs, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(extracted_labels, skip_special_tokens=True)

        decoded_outputs = [_strip_answer_prefix(output) for output in decoded_outputs]
        decoded_labels = [_strip_answer_prefix(label) for label in decoded_labels]

        rouge_score_value = rouge_score(decoded_outputs, decoded_labels, 'rougeL')

        # Normalize to a list so the result is always a 1-D tensor.
        if not isinstance(rouge_score_value, list):
            rouge_score_value = [rouge_score_value]

        return torch.tensor(rouge_score_value, device=model.device)
    except Exception as e:
        print(f"Error in ROUGE evaluation: {str(e)}")
        return torch.tensor(0.0)

def interact_with_model(model, tokenizer, dataset):
    """Display an ipywidgets UI for querying ``model`` on questions from ``dataset``.

    Clicking "Generate" streams the model's output for the selected question
    using the slider settings, then prints the ground-truth answer and a ROUGE-L
    score. The score comes from ``safe_rouge_eval``, which performs its own
    greedy generation, so it does not depend on the sampling sliders.

    Args:
        model: model used for generation (``generate`` / ``device``).
        tokenizer: tokenizer matching ``model``.
        dataset: object exposing ``data`` rows indexable by
            ``config.question_key`` and ``__getitem__`` returning a dict with
            ``input_ids``, ``attention_mask``, ``labels`` and ``question_length``.
    """
    question_dropdown = widgets.Dropdown(
        options=[(item[dataset.config.question_key], i) for i, item in enumerate(dataset.data)],
        description="Question:"
    )
    max_new_tokens_slider = widgets.IntSlider(min=1, max=200, value=100, description="Max New Tokens:")
    temperature_slider = widgets.FloatSlider(min=0.0, max=10.0, value=1.0, description="Temperature:")
    top_p_slider = widgets.FloatSlider(min=0.0, max=1.0, value=1.0, description="Top-p:")
    top_k_slider = widgets.IntSlider(min=0, max=1000, value=1000, description="Top-k:")
    generate_button = widgets.Button(description="Generate")
    output_area = widgets.Output()

    def on_button_click(_button):
        with output_area:
            clear_output(wait=True)
            item = dataset[question_dropdown.index]

            item = {k: v.to(model.device) if isinstance(v, torch.Tensor) else v for k, v in item.items()}

            # Ensure all tensors have a batch dimension.
            for key in ('input_ids', 'attention_mask', 'labels'):
                if item[key].dim() == 1:
                    item[key] = item[key].unsqueeze(0)

            question_length = item['question_length']
            if isinstance(question_length, torch.Tensor) and question_length.dim() == 0:
                question_length = question_length.unsqueeze(0)
            elif isinstance(question_length, int):
                question_length = torch.tensor([question_length], device=model.device)
            item['question_length'] = question_length

            # Generation prompt = question tokens only.
            input_ids, attention_mask = extract_question_tokens(item, tokenizer.pad_token_id)

            print("Generating text...")
            streamer = stream_generate_text(
                model, tokenizer, input_ids, attention_mask,
                max_new_tokens_slider.value,
                temperature_slider.value,
                top_p_slider.value,
                top_k_slider.value,
            )

            generated_text = ""
            for new_text in streamer:
                generated_text += new_text
                # Redraw the full text each chunk so it reads as one growing block.
                clear_output(wait=True)
                print(generated_text)

            original_text = tokenizer.decode(item['input_ids'][0], skip_special_tokens=True)
            # partition() instead of split(...)[1]: no IndexError if the separator is
            # missing, and an answer that itself contains it is kept whole.
            _, separator, answer_text = original_text.partition('\nAnswer:')
            original_answer = (answer_text if separator else original_text).strip()

            # NOTE(review): the reference is taken from input_ids rather than
            # item['labels'] — presumably because labels may be masked; confirm.
            rouge_scores = safe_rouge_eval(model, tokenizer, input_ids, attention_mask, item['input_ids'], question_length)

            print(f"Ground truth answer:\n{original_answer}\n")
            # The score is for safe_rouge_eval's greedy generation, not the text streamed above.
            print(f"ROUGE-L Score (greedy decoding): {rouge_scores.item():.4f}")

    generate_button.on_click(on_button_click)

    display(question_dropdown, max_new_tokens_slider, temperature_slider, top_p_slider, top_k_slider, generate_button, output_area)

In [None]:
# Checkpoint to inspect. Set the TOFU_MODEL_PATH environment variable to point
# at a different checkpoint instead of editing this user-specific default.
MODEL_PATH = os.environ.get(
    "TOFU_MODEL_PATH",
    "/nfs/homedirs/gudm/development/new/results/finetune/retain10/checkpoint-60",
)

print("Loading model and tokenizer...")
model, tokenizer = load_model_and_tokenizer_wrapper(MODEL_PATH)

print("Loading TOFU dataset...")
dataset = load_tofu_dataset(tokenizer)

In [None]:
# Render the interactive widgets. Each "Generate" click streams a generated answer
# for the selected question, then prints the ground truth and a ROUGE-L score
# computed from a separate greedy generation.
interact_with_model(model, tokenizer, dataset)