# In [None]:
# Import necessary libraries
from openprompt.plms import T5TokenizerWrapper
from datasets import load_from_disk
from openprompt.pipeline_base import PromptDataLoader
from transformers import T5ForConditionalGeneration, T5Tokenizer
from openprompt.prompts import ManualTemplate, MixedTemplate
from openprompt import PromptForClassification
from openprompt.data_utils import FewShotSampler
from random import shuffle
from transformers import AdamW
from transformers.optimization import get_linear_schedule_with_warmup
import torch
from openprompt.prompts import ManualVerbalizer
from openprompt.data_utils import InputExample
from tqdm import tqdm
import json

# Build the evaluation set: load the P3 SciQ multiple-choice dataset from disk
# and convert (a prefix of) its validation split into OpenPrompt InputExamples.
dataset_path = "/lustre/work/client/users/minhos/cache/datasets/p3_sciq_multiple_choice"
raw_dataset = load_from_disk(dataset_path)

dataset = {'validation': []}

# Evaluate on at most this many validation rows. The bound is clamped to the
# split size because `select` raises IndexError for out-of-range indices.
MAX_VALIDATION_EXAMPLES = 1000
num_validation_rows = min(MAX_VALIDATION_EXAMPLES, len(raw_dataset['validation']))
raw_dataset['validation'] = raw_dataset['validation'].select(range(num_validation_rows))

for idx, data in enumerate(raw_dataset['validation']):
    # Extract necessary fields
    question = data["inputs_pretokenized"]  # The question + context
    choices = data["answer_choices"]        # List of answer choices
    correct_answer = data["targets_pretokenized"].strip()  # Correct answer text

    # The class label is the position of the correct answer among the choices;
    # rows whose target text is not one of the choices are reported and skipped.
    try:
        correct_index = choices.index(correct_answer)
    except ValueError:
        print(f"Correct answer not found in choices for index {idx}")
        continue

    # NOTE(review): the choices are stored as a single comma-joined string, so
    # a choice that itself contains a comma cannot be recovered by splitting
    # this string on "," again.
    formatted_choices = ",".join(choices)

    # guid keeps the row's position in the (truncated) validation split.
    input_example = InputExample(
        text_a=question,
        guid=idx,
        label=correct_index,
        meta={"choices": formatted_choices},
    )
    dataset['validation'].append(input_example)

# Sanity check: show the first converted example (if any survived filtering).
if dataset['validation']:
    print(dataset['validation'][0])
    print(type(dataset['validation'][0]))
else:
    print("No usable validation examples were found.")

# Result logging: one record per evaluated example is collected in `results`;
# `log_file` names the JSON file intended to hold them.
results = []
log_file = "qa_manual_template_t5_diff_data.json"

# Tokenizer and T5 weights both come from the same local checkpoint directory.
t5_path = "/lustre/work/client/users/minhos/models_for_supercomputer/t5-base"
tokenizer = T5Tokenizer.from_pretrained(t5_path)
model = T5ForConditionalGeneration.from_pretrained(t5_path)

# Evaluation prompt: the example's text_a placeholder followed directly by a
# single mask slot.
template_text = '{"placeholder":"text_a"}{"mask"}'
template = ManualTemplate(text=template_text, tokenizer=tokenizer)

# Evaluate every validation example with its own verbalizer: the answer choices
# differ per question, so the label words must be rebuilt for each example.


def format_labels(choices):
    """Split a comma-joined choices string into a list of stripped labels."""
    return [choice.strip() for choice in choices.split(",")]


def create_dynamic_verbalizer(choices, tokenizer):
    """Build a ManualVerbalizer with one class per answer choice.

    Args:
        choices: Comma-joined answer choices (see ``format_labels``).
        tokenizer: Tokenizer handed to the verbalizer.

    Returns:
        A ManualVerbalizer whose i-th class has the i-th choice as its only
        label word.
    """
    formatted_labels = format_labels(choices)
    return ManualVerbalizer(
        tokenizer=tokenizer,
        num_classes=len(formatted_labels),
        label_words=[[label] for label in formatted_labels],
    )


def evaluate_single_example(prompt_model, dataloader, tokenizer):
    """Classify the single example served by ``dataloader``.

    Only the first batch is scored, and it is expected to hold exactly one
    example (the ``.item()`` calls below require that).

    Args:
        prompt_model: PromptForClassification model; its ``verbalizer``
            attribute supplies the label words.
        dataloader: PromptDataLoader yielding the example to score.
        tokenizer: Unused; kept so existing call sites stay valid.

    Returns:
        Tuple ``(generated_ids, generated_text, predicted_class, true_label)``:
        the argmax tensor over classes, the label word(s) of the predicted
        class joined by spaces, the predicted class index, and the gold label
        carried by the batch.

    Raises:
        ValueError: If the dataloader yields no batch.
    """
    verbalizer = prompt_model.verbalizer
    prompt_model.eval()
    with torch.no_grad():
        for inputs in dataloader:
            logits = prompt_model(inputs)
            for class_idx, label in enumerate(verbalizer.label_words):
                print(f"Logit for '{label}': {logits[:, class_idx].item()}")
            # The model output is indexed by class (one column per label
            # above), so the prediction is the argmax over that axis. It must
            # not be indexed by vocabulary token id, nor decoded as a token id
            # (class indices 0..3 decode to special tokens).
            generated_ids = torch.argmax(logits, dim=-1)
            predicted_class = generated_ids.item()
            generated_text = " ".join(verbalizer.label_words[predicted_class])
            true_label = inputs["label"].item()
            return generated_ids, generated_text, predicted_class, true_label
    raise ValueError("dataloader yielded no batches to evaluate")


num_examples = len(dataset['validation'])
# enumerate() gives each record its own index; without it `idx` would be a
# stale value left over from whatever loop last assigned it.
for idx, data in enumerate(dataset['validation']):
    choices = data.meta["choices"]

    verbalizer = create_dynamic_verbalizer(choices, tokenizer)
    prompt_model = PromptForClassification(
        plm=model,
        template=template,
        verbalizer=verbalizer,
        freeze_plm=True,
    )
    validation_dataloader = PromptDataLoader(
        dataset=[data],
        template=template,
        tokenizer=tokenizer,
        tokenizer_wrapper_class=T5TokenizerWrapper,
        decoder_max_length=50,
        max_seq_length=480,
        batch_size=1,
        shuffle=False,
        teacher_forcing=False,
        predict_eos_token=False,
        truncate_method="tail",
    )

    generated_ids, generated_text, predicted_class, label = evaluate_single_example(
        prompt_model, validation_dataloader, tokenizer
    )
    correct = predicted_class == label
    results.append({
        "index": idx,
        "generated_text": generated_text,
        "predicted_class": predicted_class,
        "true_class": label,
        "correct": correct,
    })

    print(f"Example {idx + 1}/{num_examples}: Generated 'id:{generated_ids}, {generated_text}', Predicted {predicted_class}, True {label} - "
          f"{'Correct' if correct else 'Incorrect'}")

# Compute overall accuracy as the fraction of records flagged correct.
# With no records there is nothing to average, so report 0.0 instead of
# dividing by zero.
num_correct = sum(r["correct"] for r in results)
accuracy = num_correct / len(results) if results else 0.0
print(f"Validation Accuracy: {accuracy:.4f}")

# Save the per-example records to JSON
with open(log_file, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=4)