# Imports

In [None]:
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer,
    DataCollatorForLanguageModeling,
    DataCollatorWithPadding
)
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
import torch
import wandb
import evaluate  # Hugging Face's evaluate library
import numpy as np
import torch
from tqdm import tqdm
import re
from bert_score import BERTScorer, score as bert_score

from utils import tokenize_dataset_for_domain_bound_qna
from prompt_templates import qna_prompt_template as prompt_template
from generate import generate, stream_generate
from evaluation_metrics import compute_metrics_for_qna

# Configs

In [None]:
# --- Model locations ---
# Identifier passed to from_pretrained in the optional base-model cell.
model_id = "microsoft/Phi-3.5-mini-instruct"
# Checkpoint that the adapter at `model_path` is loaded on top of.
base_model_path = "../models/phi_qna_finetuned_attempt_5/final_merged"
# Fine-tuned checkpoint directory (also the source of the tokenizer).
model_path = "../models/phi_domain_bound_qna_finetuned_attempt_10/final"

# --- Data locations ---
data_path = "../data/domain_bound_data/v7/"
test_data_path = f"{data_path}test.csv"

# --- Evaluation settings ---
# max_len is handed to the tokenization helper; batch_size is the per-device eval batch size.
max_len, batch_size = 512, 8

In [None]:
# Start the Weights & Biases run that the metric and sample-table cells below log to.
wandb.init(project="domain_bound_qna_finetune-evaluation", name="attempt_10")

# Dataset

In [None]:
# Show the prompt template that each question is formatted into.
print(prompt_template)

In [None]:
# Load the tokenizer from the fine-tuned checkpoint directory (`model_path`).
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

In [None]:
# Read the test split and tokenize it with the project helper.
# NOTE(review): the Trainer below uses a collator with padding=False, so this helper
# presumably pads/truncates every example to `max_len` — confirm in utils.
test_df = pd.read_csv(test_data_path)
test_set = tokenize_dataset_for_domain_bound_qna(tokenizer, test_df, prompt_template, max_len)

# Model

In [None]:
# Echo which checkpoint the next cell loads.
base_model_path

In [None]:
# === Quantized model loading ===
# 4-bit NF4 settings with double quantization and bfloat16 compute.
# NOTE(review): `bnb_config` is built but not passed to from_pretrained below,
# because the `quantization_config` argument is commented out.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load the checkpoint that the adapter is applied on top of in a later cell.
model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    # quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=False
)

In [None]:
# Echo which fine-tuned checkpoint the next cell loads.
model_path

In [None]:
# Wrap the model loaded above with the PEFT weights stored at `model_path`.
model = PeftModel.from_pretrained(model, model_path)
# Put the model in evaluation mode before running inference.
model.eval() 

### Base model

In [None]:
# Echo the identifier of the base model that the next cell loads.
model_id

In [None]:
# Optional baseline: set this flag to True to evaluate the original checkpoint
# (`model_id`) instead of the fine-tuned model.
#
# The load is guarded because it rebinds `model`. When it ran unconditionally, a
# top-to-bottom run replaced the PEFT model loaded earlier in this notebook, and every
# later cell (Trainer evaluation, classification check, inference) scored the base
# model while logging to the fine-tuned run.
evaluate_base_model = False

if evaluate_base_model:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        trust_remote_code=True
    )
    model.eval()

In [None]:
# training_args = TrainingArguments(
#     output_dir="./eval_output_base",
#     per_device_eval_batch_size=batch_size,
#     do_eval=True,
#     report_to="none"
# )

# base_model_trainer = Trainer(
#     model=base_model,
#     args=training_args,
#     tokenizer=tokenizer,
#     data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
# )

# base_model_eval_result = base_model_trainer.evaluate(test_set)

# Test set evaluation

In [None]:
# Evaluation-only settings for the Trainer built in the next cell.
training_args = TrainingArguments(
    output_dir="./eval_output",
    per_device_eval_batch_size=batch_size,
    do_eval=True,
    # Trainer-side reporting is turned off; metrics are sent to W&B by hand in a later cell.
    report_to="none",
    eval_accumulation_steps=2,
    label_names=["labels"]
)

In [None]:
# Build the Trainer that is used as the evaluation harness for `model`.
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    # NOTE(review): padding=False — assumes every example in the evaluated dataset already
    # has the same length; confirm against tokenize_dataset_for_domain_bound_qna.
    data_collator=DataCollatorWithPadding(tokenizer, padding=False),
    # The metric helper takes the tokenizer as a second argument, so bind it here.
    compute_metrics=lambda sample: compute_metrics_for_qna(sample, tokenizer)
)

In [None]:
# === Evaluate on the test set ===
# The cells below read 'eval_loss' (from which perplexity is derived), 'eval_bleu',
# the 'eval_rouge*' keys and the 'eval_bertscore_*' keys from this result.
eval_result = trainer.evaluate(test_set)

In [None]:
# Report every evaluation metric; perplexity is exp(eval_loss).
print("\nEvaluation Metrics:")
print("\n".join([
    f"Loss: {eval_result['eval_loss']:.4f}",
    f"Perplexity: {torch.exp(torch.tensor(eval_result['eval_loss'])):.2f}",
    f"BLEU: {eval_result['eval_bleu']:.4f}",
    f"ROUGE-1: {eval_result['eval_rouge1']:.4f}",
    f"ROUGE-2: {eval_result['eval_rouge2']:.4f}",
    f"ROUGE-L: {eval_result['eval_rougeL']:.4f}",
    f"BERTscore precision: {eval_result['eval_bertscore_precision']:.4f}",
    f"BERTscore recall: {eval_result['eval_bertscore_recall']:.4f}",
    f"BERTscore f1: {eval_result['eval_bertscore_f1']:.4f}",
]))

In [None]:
# Send the evaluation metrics to the active W&B run.
wandb.log({
    "eval_loss": eval_result['eval_loss'], 
    # .item() turns the 0-d tensor into a plain Python float before logging.
    "perplexity": torch.exp(torch.tensor(eval_result['eval_loss'])).item(),
    # This key was misspelled "BLUE". Runs logged before this change keep the old
    # name, so rename it in the W&B UI if older runs must line up.
    "BLEU": eval_result['eval_bleu'],
    "ROUGE_1": eval_result['eval_rouge1'],
    "ROUGE_2": eval_result['eval_rouge2'],
    "ROUGE_L": eval_result['eval_rougeL'],
    "BERTscore_precision": eval_result['eval_bertscore_precision'],
    # NOTE(review): the next two keys use a space where the key above uses an
    # underscore; consider normalizing them.
    "BERTscore recall": eval_result['eval_bertscore_recall'],
    "BERTscore f1": eval_result['eval_bertscore_f1']
})

## Med / non-med classification evaluation

How well the model identifies the non-med questions.

In [None]:
def get_predicted_class(sample):
    """Ask the model to classify one question and return the predicted class tag.

    Formats ``sample["question"]`` into ``prompt_template``, generates at most five
    new tokens, and reads the first ``<...>`` tag found after the "# Answer:" marker.

    Args:
        sample: Mapping (for example a DataFrame row) with a "question" entry.

    Returns:
        The text inside the first tag (for example "med" for "<med>"), or
        "unknown" when the output has no "# Answer:" marker or no tag after it.
        Callers should treat "unknown" as a failed classification.
    """
    prompt = prompt_template.format(question=sample["question"])
    generated = generate(model, tokenizer, prompt, max_new_tokens=5)

    # Same segment that was previously taken with [1]. The length is checked first so a
    # missing marker yields "unknown" instead of an IndexError that aborts a whole
    # DataFrame.apply pass.
    segments = generated.split("# Answer:")
    if len(segments) < 2:
        return "unknown"

    # Stop at the first closing bracket: a greedy "<.*>" would run to the last ">" on
    # the line and swallow any text or extra tags in between.
    tag = re.search(r"<([^<>\n]+)>", segments[1])
    return tag.group(1) if tag else "unknown"

In [None]:
# Classify every test row (one generate() call per row).
# Note: this adds the column to `test_df` in place.
test_df["predicted_class"] = test_df.apply(get_predicted_class, axis=1)

In [None]:
# Spot-check the first rows, including the new predicted_class column.
test_df.head()

In [None]:
# confusion_matrix[predicted][actual] -> number of test rows with that combination.
confusion_matrix = {"med": {"med": 0, "non_med": 0}, "non_med": {"med": 0, "non_med": 0}}
correct_count = 0

def update_confusion_matrix(sample):
    """Count one test row in ``confusion_matrix`` and ``correct_count``.

    Args:
        sample: Mapping with "predicted_class", "class" and "question" entries.

    Side effects:
        Increments ``confusion_matrix[predicted][actual]``. Increments the global
        ``correct_count`` when the two labels match; otherwise prints the question.
        A label outside {"med", "non_med"} gets its own cell instead of raising
        KeyError, so one unexpected model output cannot abort the pass.
    """
    global correct_count
    predicted, actual = sample["predicted_class"], sample["class"]

    # setdefault/get create the row or cell on first sight of an unexpected label.
    row = confusion_matrix.setdefault(predicted, {"med": 0, "non_med": 0})
    row[actual] = row.get(actual, 0) + 1

    if predicted != actual:
        print(sample["question"])
    else:
        correct_count += 1
    
# Plain loop instead of DataFrame.apply: the function is called only for its side
# effects, and pandas releases before 1.1 may call an applied function twice on the
# first row, which would double-count it.
for row_index, sample in test_df.iterrows():
    update_confusion_matrix(sample)
confusion_matrix

In [None]:
# Share of test rows whose predicted class equals the labelled class.
accuracy = correct_count / len(test_df)

In [None]:
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")

# confusion_matrix is indexed [predicted][actual]:
#   precision divides by every row *predicted* as the class,
#   recall divides by every row whose *actual* label is the class.
# A class that is never predicted (or never present) gives NaN instead of raising
# ZeroDivisionError.
non_med_precision = _safe_ratio(
    confusion_matrix["non_med"]["non_med"],
    confusion_matrix["non_med"]["med"] + confusion_matrix["non_med"]["non_med"],
)
non_med_recall = _safe_ratio(
    confusion_matrix["non_med"]["non_med"],
    confusion_matrix["med"]["non_med"] + confusion_matrix["non_med"]["non_med"],
)
med_precision = _safe_ratio(
    confusion_matrix["med"]["med"],
    confusion_matrix["med"]["med"] + confusion_matrix["med"]["non_med"],
)
med_recall = _safe_ratio(
    confusion_matrix["med"]["med"],
    confusion_matrix["med"]["med"] + confusion_matrix["non_med"]["med"],
)

In [None]:
# Summarise the med / non-med classification metrics.
print("\n".join([
    f"Non-med precision: {non_med_precision:.4f}",
    f"Non-med recall: {non_med_recall:.4f}",
    f"Med precision: {med_precision:.4f}",
    f"Med recall: {med_recall:.4f}",
    f"Accuracy: {accuracy:.4f}",
]))

In [None]:
# Send the med / non-med classification metrics to the active W&B run.
wandb.log({
    "non_med_precision": non_med_precision,
    "non_med_recall": non_med_recall,
    # This key previously had a stray trailing colon ("med_precision:"), which made it
    # inconsistent with its sibling keys. Runs logged before the fix keep the old name.
    "med_precision": med_precision,
    "med_recall": med_recall,
    "accuracy": accuracy
})

# Inference

In [None]:
# Hand-picked questions for qualitative inspection of the generated answers.
# Mostly medical questions, plus two that look out-of-domain (C++ and Beyonce).
examples = [
    "What is Glaucoma ?",
    "What are the symptoms of Glaucoma ?",
    "My sister is on Xanax, feyntnol patch and a pain medicine for cancer.  She has been on 25 of fentynol and within 6 days she has been bumped up to 100 now she is almost lethargic and breathing is really labored and right arm is twitching.. She was carrying on conversation Sunday and Monday patch was put on Tuesday and now cant even sit up..no one seems worried but me.. Just wondering what I could do",
    "I was playing basketball the other night and went up to block a shot and flipped over the guy and landed on my side/back. Since then the lower left side of back/side have been sore, hurts when I take deep breaths and when I lay on my back, any chance of a bruised kidney or any serious injury I could have?",
    "What are the treatments for High Blood Pressure ?",
    "What is (are) Urinary Tract Infections ?",
    # Referenced by position (examples[-3]) in the next cell; keep the order stable.
    "Create a C++ function that computes the Fast Fourier Transform (FFT) of a signal",
    "When did Beyonce start becoming popular?",
    "What are the symptoms of diabetes ?"
]

In [None]:
# examples[-3] is the C++ FFT question, i.e. one of the out-of-domain prompts.
print(generate(model, tokenizer, prompt_template.format(question=examples[-3])))

In [None]:
# Stream the answer to the same C++ FFT question piece by piece, printing without newlines.
# NOTE(review): do_sample=False / skip_special_tokens=False are forwarded to the project's
# stream_generate helper — presumably deterministic decoding with special tokens left
# visible; confirm in generate.py.
for token in stream_generate(model, tokenizer, prompt_template.format(question="Create a C++ function that computes the Fast Fourier Transform (FFT) of a signal"), do_sample=False, max_new_tokens=512, skip_special_tokens=False):
    print(token, end='', flush=True)

In [None]:
# Collect one generated answer per example question into a W&B table.
wandb_table = wandb.Table(columns=["Question", "Generated answer"])

for example in examples:
    prompt = prompt_template.format(question=example)
    # Concatenate the streamed pieces into the full answer text.
    generated_answer = "".join(stream_generate(model, tokenizer, prompt, do_sample=False))
    wandb_table.add_data(example, generated_answer)
    print(example, "\n", generated_answer)

In [None]:
# Upload the table of generated samples to the active W&B run.
wandb.log({"generated_samples": wandb_table})