# Imports

import pandas as pd
import torch
import wandb
from datasets import Dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    Trainer, TrainingArguments,
    DataCollatorForLanguageModeling,
    DataCollatorForSeq2Seq,
    BitsAndBytesConfig
)

# Configs

In [None]:
# Directory the tokenizer is loaded from further down.
model_path = "../models/phi_pubmed_pretrained_attempt_1/final"

# Folder holding the dataset CSVs, and the test split read with pandas below.
data_path = "../data/pubmed_baseline/"
test_data_path = f"{data_path}pubmed_test.csv"

# batch_size is passed to the Trainer as per_device_eval_batch_size.
batch_size = 32
# Intended cap on tokens per example — NOTE(review): confirm it matches training.
max_length = 512

In [None]:
# Open the W&B run that the metric and table logging later in the file write to.
wandb.init(
    project="pubmed-pretrain-evaluation",
    name="attempt_2",
)

# Dataset

In [None]:
# Load the tokenizer stored alongside the checkpoint under evaluation.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# The examples are padded to a fixed length below, which raises if the
# tokenizer has no pad token. Fall back to EOS only when none is configured,
# so a pad token saved with the checkpoint is left untouched.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
def tokenize_dataset(tokenizer, data_df, max_len=None):
    """Turn a DataFrame of articles into a tokenized ``datasets.Dataset``.

    Every row is rendered as ``<s>#`` + title + newline + abstract + ``</s>``
    and tokenized on its own, truncated and padded to a fixed length.

    Args:
        tokenizer: Tokenizer called on each rendered text.
        data_df: DataFrame whose rows expose ``title`` and ``abstract``.
        max_len: Fixed sequence length for truncation/padding. When omitted,
            the module-level ``max_length`` setting is used.

    Returns:
        The dataset built from ``data_df`` with the tokenizer's outputs added
        as columns by ``Dataset.map``.
    """
    # The original body read an undefined ``max_len``; resolve it explicitly,
    # defaulting to the configured module-level value.
    if max_len is None:
        max_len = max_length

    dataset = Dataset.from_pandas(data_df)

    def tokenize(example):
        text = f"<s>#{example['title']}\n{example['abstract']}</s>"
        return tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=max_len,
            return_attention_mask=True,
        )

    # One example per call: ``tokenize`` formats a single row at a time.
    return dataset.map(tokenize, batched=False)

In [None]:
# Read the test split from disk, then tokenize it row by row.
test_df = pd.read_csv(filepath_or_buffer=test_data_path)
test_set = tokenize_dataset(tokenizer=tokenizer, data_df=test_df)

# Test set evaluation

In [None]:
# Load the causal LM to evaluate; nothing earlier in the file defines `model`.
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
model.eval()

training_args = TrainingArguments(
    output_dir="./eval_output",
    per_device_eval_batch_size=batch_size,
    do_eval=True,
    report_to="none",  # metrics are sent to W&B explicitly via wandb.log
)

# The tokenized dataset has no `labels` column, and a seq2seq collator does not
# create one, so the Trainer would compute no loss. With mlm=False the
# language-modeling collator copies input_ids into labels and masks pad
# positions with -100, which lets evaluate() report `eval_loss`.
# NOTE(review): if the pad token equals the EOS token, EOS positions are
# masked out of the loss as well.
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

In [None]:
# === Evaluate perplexity ===
# Evaluate on `test_set`, the tokenized test split built above (the original
# referenced `tokenized_test`, which is never defined).
eval_result = trainer.evaluate(test_set)
loss = eval_result["eval_loss"]
# Perplexity is the exponential of the evaluation loss. It stays a tensor
# because the logging cell below calls `.item()` on it.
perplexity = torch.exp(torch.tensor(loss))

print("\n✅ Evaluation Metrics:")
print(f"Eval Loss     : {loss:.4f}")
print(f"Eval Perplexity: {perplexity:.2f}")

In [None]:
# Send the scalar metrics to W&B; .item() converts the tensor to a Python float.
wandb.log(
    {
        "eval_loss": loss,
        "eval_perplexity": perplexity.item(),
    }
)

# Inference

In [None]:
# Log to W&B only when a run is active; `wandb.run` is None otherwise.
use_wandb = wandb.run is not None

# Take up to five examples from the tokenized test split (`test_set`); the
# min() keeps select() from raising when the split has fewer than five rows.
num_samples = min(5, len(test_set))
samples = test_set.select(range(num_samples))
titles = list(samples["title"])
abstracts = list(samples["abstract"])

# Prompt with the title only, using the same "<s>#title\n" prefix the
# evaluation examples are built with, so the model has to write the abstract.
# Feeding the stored input_ids would hand it the reference abstract followed
# by right-padding, and generation would continue after the pad tokens.
prompts = [f"<s>#{title}\n" for title in titles]

# Decoder-only generation needs left padding so each prompt ends right where
# generation starts. Restore the tokenizer's setting afterwards.
previous_padding_side = tokenizer.padding_side
tokenizer.padding_side = "left"
try:
    prompt_batch = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
finally:
    tokenizer.padding_side = previous_padding_side

input_ids = prompt_batch["input_ids"].to(model.device)
attention_mask = prompt_batch["attention_mask"].to(model.device)

generated_ids = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_new_tokens=128,
    do_sample=False,
    pad_token_id=tokenizer.pad_token_id,
)

# generate() returns prompt + continuation; keep only the newly generated part.
new_token_ids = generated_ids[:, input_ids.shape[1]:]
generated_texts = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)

# Log predictions to W&B
wandb_table = wandb.Table(columns=["Title", "Actual Abstract", "Generated Text"])
for title, actual, gen in zip(titles, abstracts, generated_texts):
    print(f"\nTitle: {title}\n---\nActual Abstract: {actual}\n---\nGenerated: {gen}\n")
    if use_wandb:
        wandb_table.add_data(title, actual, gen)

if use_wandb:
    wandb.log({"generated_examples": wandb_table})