In [None]:
!pip install absl-py rouge-score nltk

In [None]:
# Download tokenizer data from inside the kernel so it lands where this
# interpreter's NLTK looks (a shell `python` may be a different interpreter).
import nltk

# "punkt" is the classic sentence tokenizer model; newer NLTK releases (3.9+)
# look up "punkt_tab" instead, so fetch both to work across versions.
for nltk_resource in ("punkt", "punkt_tab"):
    nltk.download(nltk_resource)

# Imports

In [None]:
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer,
    DataCollatorForLanguageModeling
)
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
import torch
import wandb
import numpy as np
import torch

from evaluation_metrics import compute_metrics_for_pretrain

# Configs

In [None]:
# PEFT adapter checkpoint applied on top of `model_id` in the Model section.
model_path = "../models/phi_pubmed_pretrained_attempt_5/final"

# Evaluation data: the held-out PubMed split.
data_path = "../data/pubmed_baseline/"
test_data_path = f"{data_path}pubmed_test.csv"

# Hugging Face hub id of the base model (and of its tokenizer).
model_id = "microsoft/Phi-3.5-mini-instruct"

max_len = 360     # tokens per example: truncate/pad length for tokenization and prompts
batch_size = 8    # per-device evaluation batch size

# Dataset

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# `tokenize_dataset` pads to a fixed length and the LM collator masks padding in
# the labels; both require a pad token. Fall back to EOS only when the tokenizer
# does not define one (no-op otherwise).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
def tokenize_dataset(tokenizer, data_df, max_length=None):
    r"""Convert a DataFrame of PubMed records into a tokenized Hugging Face ``Dataset``.

    Each row is rendered as ``"<s>{title}\n{abstract}</s>"``. The text is then
    truncated and padded to exactly ``max_length`` tokens.
    NOTE(review): this layout presumably mirrors the training format — confirm
    against the training notebook before changing it.

    Args:
        tokenizer: Hugging Face tokenizer used to encode the text.
        data_df: DataFrame with at least ``title`` and ``abstract`` columns.
        max_length: Sequence length to truncate/pad to. When ``None`` (default),
            the notebook-level ``max_len`` is used.

    Returns:
        A ``datasets.Dataset`` with the original columns plus the tokenizer's
        output fields (e.g. ``input_ids`` and ``attention_mask``).
    """
    if max_length is None:
        max_length = max_len

    def tokenize(batch):
        # Batched map: every column arrives as a list, so build one string per row.
        texts = [
            f"<s>{title}\n{abstract}</s>"
            for title, abstract in zip(batch["title"], batch["abstract"])
        ]
        return tokenizer(
            texts,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_attention_mask=True,
        )

    dataset = Dataset.from_pandas(data_df)
    # batched=True encodes many rows per tokenizer call. This is much faster than
    # one row at a time and yields the same per-row encodings, because every row
    # is padded to the same fixed length.
    return dataset.map(tokenize, batched=True)

In [None]:
# Load the held-out split and tokenize only its first 2000 rows
# (presumably to keep evaluation time manageable).
test_df = pd.read_csv(test_data_path)
test_set = tokenize_dataset(tokenizer, test_df.head(2000))

# Model

In [None]:
# === Base model loading (optionally 4-bit quantized) ===
# Set to True to load the base model with the 4-bit NF4 config below
# (double quantization, bfloat16 compute). When False, no quantization is applied.
use_4bit_quantization = False

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# NOTE(review): no `torch_dtype` is passed, so weights load in the transformers
# default dtype (float32 in many versions). Consider torch_dtype=torch.bfloat16 to
# roughly halve memory if that matches how the adapter was trained.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # None is the from_pretrained default, i.e. load without quantization.
    quantization_config=bnb_config if use_4bit_quantization else None,
    device_map="auto",
    trust_remote_code=True,
)

In [None]:
# Wrap the base model with the PEFT adapter stored at `model_path` and switch it
# to inference mode (nn.Module.eval() returns the module itself).
model = PeftModel.from_pretrained(model, model_path).eval()

### Base model (optional: replaces the fine-tuned model loaded above)

In [None]:
# ONLY FOR base model evaluation.
# Set `evaluate_base_model = True` to replace the adapter-wrapped `model` with a
# fresh, unadapted copy of `model_id`. It defaults to False so that
# "Restart & Run All" does not silently evaluate the base model under the
# fine-tuned run's W&B name.
evaluate_base_model = False

if evaluate_base_model:
    import gc

    # Release the previously loaded model before putting a second copy on the device(s).
    del model
    gc.collect()
    torch.cuda.empty_cache()

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()

# Test set evaluation

In [None]:
# Start the W&B run that the metrics and sample generations below are logged to.
# The config records what was actually evaluated: the run name alone does not
# distinguish the adapter-wrapped model from the plain base model.
wandb.init(
    project="pubmed-pretrain-evaluation",
    name="attempt_5",
    config={
        "model_id": model_id,
        "adapter_path": model_path,
        "model_class": type(model).__name__,
        "max_len": max_len,
        "batch_size": batch_size,
        "num_eval_examples": len(test_set),
    },
)

In [None]:
# Evaluation-only settings: nothing is trained in this notebook. Metrics are
# logged to W&B manually further down, hence report_to="none".
eval_arg_values = {
    "output_dir": "./eval_output",
    "per_device_eval_batch_size": batch_size,
    "do_eval": True,
    "report_to": "none",
    # Move accumulated predictions off the accelerator every 2 eval steps.
    "eval_accumulation_steps": 2,
}
training_args = TrainingArguments(**eval_arg_values)

In [None]:
# Causal-LM collator (mlm=False): labels are a copy of input_ids with
# padding-token positions set to -100 so the loss ignores them.
lm_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# NOTE(review): no `preprocess_logits_for_metrics` is passed, so Trainer gathers the
# full logits (num_examples x max_len x vocab_size floats; tens of GB for 2000 x 360
# x ~32k) before calling `compute_metrics_for_pretrain`. If that function only needs
# token ids, consider `preprocess_logits_for_metrics=lambda logits, labels:
# logits.argmax(-1)`. Confirm against evaluation_metrics first.
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
# `processing_class=`.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=lm_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics_for_pretrain,
)

In [None]:
# One full evaluation pass over `test_set`; returned metric keys are prefixed with "eval_".
eval_result = trainer.evaluate(eval_dataset=test_set)

In [None]:
# Print a human-readable summary of the evaluation metrics.
print("\nEvaluation Metrics:")
print(f"Loss: {eval_result['eval_loss']:.4f}")
# Perplexity is exp(eval loss); the 0-d tensor formats like a float.
print(f"Perplexity: {torch.exp(torch.tensor(eval_result['eval_loss'])):.2f}")
for metric_label, metric_key in (
    ("BLEU", "eval_bleu"),
    ("ROUGE-1", "eval_rouge1"),
    ("ROUGE-2", "eval_rouge2"),
    ("ROUGE-L", "eval_rougeL"),
):
    print(f"{metric_label}: {eval_result[metric_key]:.4f}")

In [None]:
import math

# Log the evaluation metrics to W&B as plain Python floats.
eval_loss = eval_result["eval_loss"]
wandb.log({
    "eval_loss": eval_loss,
    # Perplexity = exp(eval loss). math.exp gives a float64 scalar rather than a 0-d float32 tensor.
    "perplexity": math.exp(eval_loss),
    "BLEU": eval_result["eval_bleu"],
    "ROUGE_1": eval_result["eval_rouge1"],
    "ROUGE_2": eval_result["eval_rouge2"],
    "ROUGE_L": eval_result["eval_rougeL"],
})

# Inference

In [None]:
def generate(model, text, max_new_tokens=128, use_cache=False):
    """Continue `text` with `model` without sampling and return the decoded output.

    The prompt is encoded with the notebook-level `tokenizer` (truncated to
    `max_len` tokens, no padding). The output is decoded with special tokens
    stripped. The returned string contains the prompt followed by the
    continuation, because `model.generate` returns the prompt tokens too.

    Args:
        model: Causal LM (base or PEFT-wrapped) exposing `.device` and `.generate`.
        text: Prompt string.
        max_new_tokens: Maximum number of tokens generated after the prompt.
        use_cache: Forwarded to `model.generate`; defaults to False.
            NOTE(review): presumably disabled to work around KV-cache issues in
            the remote Phi-3.5 modeling code. Enabling it makes generation much
            faster if it works in your environment.

    Returns:
        The decoded text of the single generated sequence.
    """
    # return_tensors="pt" yields (1, seq_len) tensors; .to() moves every field to the model's device.
    encoded = tokenizer(
        text,
        truncation=True,
        padding=False,
        max_length=max_len,
        return_attention_mask=True,
        return_tensors="pt",
    ).to(model.device)

    generated_ids = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_new_tokens=max_new_tokens,
        do_sample=False,  # deterministic decoding
        use_cache=use_cache,
    )

    # A single prompt was encoded, so decode the only returned sequence.
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

In [None]:
# Qualitative check: continue a short PubMed-style title with the current model.
prompt = "The relationship between diabetes and blood pressure\n"
generated_text = generate(model, prompt)
print(generated_text)

In [None]:
# Log the sample generation, then finish the W&B run so it is marked complete
# and all pending data is uploaded.
wandb.log({"example_1": generated_text})
wandb.finish()