# Compare baseline models (gpt-4o-mini, Phi-3.5-mini, Qwen3-8B) with the fine-tuned model

## Imports

In [1]:
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer,
    DataCollatorForLanguageModeling,
    DataCollatorWithPadding
)
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
import torch
import time
import wandb
import evaluate  # Hugging Face's evaluate library
import numpy as np
import torch
from tqdm import tqdm
from bert_score import BERTScorer, score as bert_score
from openai import OpenAI

from tokenize_functions import tokenize_dataset_for_domain_bound_qna
from prompt_templates import qna_prompt_template as prompt_template
from generate import generate, stream_generate
from evaluation_metrics import calculate_metrics_for_qna_str

  from .autonotebook import tqdm as notebook_tqdm


## Configs

In [40]:
# Hugging Face Hub identifiers of the baseline models being compared.
base_phi_model_id = "microsoft/Phi-3.5-mini-instruct"
qwen3_model_id = "Qwen/Qwen3-8B"

# Local paths: the shared test set (answer columns are appended to it and re-saved)
# and the fine-tuned checkpoint directory.
data_path = "../data/comparison/data.csv"
med_qna_finetuned_model_path = "../models/phi_domain_bound_qna_finetuned_attempt_10/final_merged"

## Load dataset

In [3]:
test_set = pd.read_csv(filepath_or_buffer=data_path)  # test questions, references and any saved model answers

In [4]:
test_set.head(n=5)

Unnamed: 0.1,Unnamed: 0,question,answer,gpt_4o_mini_answer
0,0,What causes Hereditary diffuse leukoencephalop...,What causes hereditary diffuse leukoencephalop...,Hereditary diffuse leukoencephalopathy with sp...
1,1,What is (are) Jones syndrome ?,Jones syndrome is a very rare condition charac...,"Jones syndrome, also known as oculodentodigita..."
2,2,What is (are) Familial mixed cryoglobulinemia ?,Familial mixed cryoglobulinemia is a rare cond...,Familial mixed cryoglobulinemia is a rare cond...
3,3,What is (are) centronuclear myopathy ?,Centronuclear myopathy is a condition characte...,Centronuclear myopathy (CNM) is a rare genetic...
4,4,How many people are affected by alkaptonuria ?,"This condition is rare, affecting 1 in 250,000...",Alkaptonuria is a rare genetic disorder caused...


## OpenAI gpt-4o-mini evaluation

In [None]:
import sys

# The OpenAI client reads its key from the OPENAI_API_KEY environment variable.
# Never assign client.api_key in the notebook: a literal value leaks the secret, and an
# empty string overrides the environment key so every request fails authentication.
client = OpenAI()

GPT_MODEL = "gpt-4o-mini"
REQUEST_DELAY_S = 3  # seconds to wait between requests, to respect rate limits


def ask_gpt4o(prompt):
    """Send ``prompt`` as a single user message to ``GPT_MODEL`` and return the reply text.

    Best-effort by design: any exception is reported on stderr and returned as an
    ``"Error: <exception>"`` string, so one failed request does not abort the loop.
    Rows holding such strings are counted after the loop and should be re-run
    before scoring, otherwise they are evaluated as if they were answers.
    """
    try:
        completion = client.chat.completions.create(
            model=GPT_MODEL,
            messages=[{"role": "user", "content": prompt}],
        )
        return completion.choices[0].message.content
    except Exception as e:
        print(f"OpenAI request failed: {e}", file=sys.stderr)
        return f"Error: {e}"


# Collect one answer per question of the loaded test set.
answers = []
for question in test_set["question"]:
    print(f"Processing: {question}")
    answer = ask_gpt4o(prompt_template.format(question=question))
    print(answer)
    answers.append(answer)
    time.sleep(REQUEST_DELAY_S)

n_failed = sum(isinstance(a, str) and a.startswith("Error: ") for a in answers)
if n_failed:
    print(f"WARNING: {n_failed}/{len(answers)} requests failed; their rows contain error text.")

# Assign the list positionally. A fresh pd.Series would align on index labels instead.
# Then persist the answers next to the references.
test_set["gpt_4o_mini_answer"] = answers
test_set.to_csv(data_path, index=False)

In [6]:
def evaluate_qna_dataset(predicted_list, reference_list, metric_fn=None):
    """Average per-pair QnA metrics over a whole dataset.

    Args:
        predicted_list: Sized iterable of model answers (e.g. a pandas Series),
            paired positionally with ``reference_list``.
        reference_list: Sized iterable of reference answers, same length.
        metric_fn: Callable ``(pred, ref) -> dict`` returning at least the metric
            keys averaged here. Defaults to ``calculate_metrics_for_qna_str``.

    Returns:
        dict mapping each metric name to its mean over all pairs.

    Raises:
        AssertionError: if the two inputs differ in length.
        ValueError: if the inputs are empty, since the mean is undefined.
    """
    assert len(predicted_list) == len(reference_list), "Predictions and references must be the same length"

    n = len(predicted_list)
    if n == 0:
        # Without this guard the averaging step below divides by zero.
        raise ValueError("Cannot evaluate an empty dataset")

    if metric_fn is None:
        metric_fn = calculate_metrics_for_qna_str

    total_metrics = dict.fromkeys(
        (
            "bleu",
            "rouge1",
            "rouge2",
            "rougeL",
            "bertscore_precision",
            "bertscore_recall",
            "bertscore_f1",
        ),
        0.0,
    )

    for pred, ref in zip(predicted_list, reference_list):
        metrics = metric_fn(pred, ref)
        for key in total_metrics:
            total_metrics[key] += metrics[key]

    return {key: value / n for key, value in total_metrics.items()}

In [None]:
eval_result = evaluate_qna_dataset(predicted_list=test_set["gpt_4o_mini_answer"], reference_list=test_set["answer"])

In [11]:
# Print the averaged metrics for gpt-4o-mini, one "<label>: <value>" line each.
# The header names the actual model ("gpt-4o-mini"; it previously read "gpt-4o-min").
print("\nEvaluation Metrics for Open AI gpt-4o-mini:")
for label, key in (
    ("BLEU", "bleu"),
    ("ROUGE-1", "rouge1"),
    ("ROUGE-2", "rouge2"),
    ("ROUGE-L", "rougeL"),
    ("BERTscore precision", "bertscore_precision"),
    ("BERTscore recall", "bertscore_recall"),
    ("BERTscore f1", "bertscore_f1"),
):
    print(f"{label}: {eval_result[key]:.4f}")


Evaluation Metrics for Open AI gpt-4o-min:
BLEU: 0.0342
ROUGE-1: 0.2918
ROUGE-2: 0.0928
ROUGE-L: 0.1651
BERTscore precision: -0.0059
BERTscore recall: 0.1704
BERTscore f1: 0.0802


## Base Phi 3.5 model evaluation

In [None]:
# Base Phi-3.5-mini-instruct; device_map="auto" lets transformers place the weights on available devices.
# NOTE(review): trust_remote_code=True runs Python code from the Hub repo; the fine-tuned
# model later in this notebook loads with it disabled — consider doing the same here.
base_phi_model = AutoModelForCausalLM.from_pretrained(
    base_phi_model_id,
    trust_remote_code=True,
    device_map="auto",
)

In [16]:
base_phi_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=base_phi_model_id, trust_remote_code=True)

In [26]:
def _extract_answer(generated_text, answer_marker):
    """Return the stripped text following the first ``answer_marker`` (full text if the marker is missing)."""
    parts = generated_text.split(answer_marker)
    # Keep the segment between the first and any second marker, so a repeated marker in
    # the generation ends the answer. If the marker is absent, fall back to the whole
    # text instead of raising IndexError and losing the entire generation run.
    answer = parts[1] if len(parts) > 1 else generated_text
    return answer.strip()


def get_model_responses(model, tokenizer, questions, answer_marker="# Answer:"):
    """Generate an answer for every question with a local model.

    Each question is formatted with ``prompt_template`` and passed to ``generate``.
    The answer is the text after ``answer_marker`` in the returned string.

    Args:
        model: model passed through to ``generate``.
        tokenizer: tokenizer matching ``model``.
        questions: pandas Series of question strings.
        answer_marker: delimiter preceding the answer in the generated text.

    Returns:
        pandas Series of stripped answers, indexed and named like ``questions``.
    """
    answers = [
        _extract_answer(
            generate(model, tokenizer, prompt_template.format(question=question)),
            answer_marker,
        )
        for question in tqdm(questions, desc="generating")
    ]
    return pd.Series(answers, index=questions.index, name=questions.name, dtype=object)

In [20]:
test_set["base_phi_model_answers"] = get_model_responses(model=base_phi_model, tokenizer=base_phi_tokenizer, questions=test_set["question"])



In [23]:
test_set.head(n=5)

Unnamed: 0.1,Unnamed: 0,question,answer,gpt_4o_mini_answer,base_phi_model_answers
0,0,What causes Hereditary diffuse leukoencephalop...,What causes hereditary diffuse leukoencephalop...,Hereditary diffuse leukoencephalopathy with sp...,Hereditary diffuse leukoencephalopathy with sp...
1,1,What is (are) Jones syndrome ?,Jones syndrome is a very rare condition charac...,"Jones syndrome, also known as oculodentodigita...","Jones syndrome, also known as Type III hyperse..."
2,2,What is (are) Familial mixed cryoglobulinemia ?,Familial mixed cryoglobulinemia is a rare cond...,Familial mixed cryoglobulinemia is a rare cond...,Familial mixed cryoglobulinemia (FMC) is a rar...
3,3,What is (are) centronuclear myopathy ?,Centronuclear myopathy is a condition characte...,Centronuclear myopathy (CNM) is a rare genetic...,Centronuclear myopathy (CNM) is a rare genetic...
4,4,How many people are affected by alkaptonuria ?,"This condition is rare, affecting 1 in 250,000...",Alkaptonuria is a rare genetic disorder caused...,"Alkaptonuria is a rare inherited disorder, and..."


In [None]:
test_set.to_csv(path_or_buf=data_path, index=False)

In [None]:
eval_result = evaluate_qna_dataset(predicted_list=test_set["base_phi_model_answers"], reference_list=test_set["answer"])

In [25]:
# Print the averaged metrics for the base Phi 3.5 model, one "<label>: <value>" line each.
metric_labels = {
    "BLEU": "bleu",
    "ROUGE-1": "rouge1",
    "ROUGE-2": "rouge2",
    "ROUGE-L": "rougeL",
    "BERTscore precision": "bertscore_precision",
    "BERTscore recall": "bertscore_recall",
    "BERTscore f1": "bertscore_f1",
}
print("\nEvaluation Metrics for base phi 3.5 model:")
for label, key in metric_labels.items():
    print(f"{label}: {eval_result[key]:.4f}")


Evaluation Metrics for base phi 3.5 model:
BLEU: 0.0299
ROUGE-1: 0.2840
ROUGE-2: 0.0924
ROUGE-L: 0.1644
BERTscore precision: -0.0186
BERTscore recall: 0.1453
BERTscore f1: 0.0608


In [27]:
import gc

# `del` only removes this name. Collect garbage and release cached CUDA blocks so
# the memory is actually available before the next model is loaded.
del base_phi_model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## Qwen 3 8B model evaluation

In [33]:
# Qwen3-8B baseline; device_map="auto" lets transformers place the weights on available devices.
qwen_model = AutoModelForCausalLM.from_pretrained(
    qwen3_model_id,
    trust_remote_code=True,
    device_map="auto",
)

Fetching 5 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:07<00:00,  1.49s/it]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:02<00:00,  2.11it/s]


In [34]:
qwen_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=qwen3_model_id, trust_remote_code=True)

In [None]:
test_set["qwen_3_8B_answers"] = get_model_responses(model=qwen_model, tokenizer=qwen_tokenizer, questions=test_set["question"])

In [39]:
test_set.head(n=5)

Unnamed: 0.1,Unnamed: 0,question,answer,gpt_4o_mini_answer,base_phi_model_answers,qwen_3_8B_answers
0,0,What causes Hereditary diffuse leukoencephalop...,What causes hereditary diffuse leukoencephalop...,Hereditary diffuse leukoencephalopathy with sp...,Hereditary diffuse leukoencephalopathy with sp...,Hereditary diffuse leukoencephalopathy with sp...
1,1,What is (are) Jones syndrome ?,Jones syndrome is a very rare condition charac...,"Jones syndrome, also known as oculodentodigita...","Jones syndrome, also known as Type III hyperse...",(Please write in English)\n\nJones syndrome is...
2,2,What is (are) Familial mixed cryoglobulinemia ?,Familial mixed cryoglobulinemia is a rare cond...,Familial mixed cryoglobulinemia is a rare cond...,Familial mixed cryoglobulinemia (FMC) is a rar...,"(Please write in English, using the language o..."
3,3,What is (are) centronuclear myopathy ?,Centronuclear myopathy is a condition characte...,Centronuclear myopathy (CNM) is a rare genetic...,Centronuclear myopathy (CNM) is a rare genetic...,Centronuclear myopathy is a rare genetic disor...
4,4,How many people are affected by alkaptonuria ?,"This condition is rare, affecting 1 in 250,000...",Alkaptonuria is a rare genetic disorder caused...,"Alkaptonuria is a rare inherited disorder, and...",Alkaptonuria is a rare genetic disorder. It af...


In [None]:
eval_result = evaluate_qna_dataset(predicted_list=test_set["qwen_3_8B_answers"], reference_list=test_set["answer"])

In [38]:
# Print the averaged metrics for Qwen3-8B, one "<label>: <value>" line each.
print("\nEvaluation Metrics for qwen 3 8B model:")
for label, key in [
    ("BLEU", "bleu"),
    ("ROUGE-1", "rouge1"),
    ("ROUGE-2", "rouge2"),
    ("ROUGE-L", "rougeL"),
    ("BERTscore precision", "bertscore_precision"),
    ("BERTscore recall", "bertscore_recall"),
    ("BERTscore f1", "bertscore_f1"),
]:
    print(f"{label}: {eval_result[key]:.4f}")


Evaluation Metrics for qwen 3 8B model:
BLEU: 0.0312
ROUGE-1: 0.2822
ROUGE-2: 0.0937
ROUGE-L: 0.1757
BERTscore precision: 0.0848
BERTscore recall: 0.1149
BERTscore f1: 0.0987


In [41]:
import gc

# `del` only removes this name. Collect garbage and release cached CUDA blocks so
# the memory is actually available before the fine-tuned model is loaded.
del qwen_model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## Fine-tuned model evaluation

In [42]:
# Fine-tuned checkpoint from local disk; remote code execution stays disabled.
finetuned_model = AutoModelForCausalLM.from_pretrained(
    med_qna_finetuned_model_path,
    trust_remote_code=False,
    device_map="auto",
)

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:09<00:00,  2.49s/it]


In [44]:
finetuned_model_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=med_qna_finetuned_model_path, trust_remote_code=False)

In [45]:
test_set["finetuned_model_answer"] = get_model_responses(model=finetuned_model, tokenizer=finetuned_model_tokenizer, questions=test_set["question"])

In [47]:
# Remove the <med>/<non_med> tags the fine-tuned model emits, then re-strip the
# whitespace a removed tag can leave at either end (e.g. "<med>\nAnswer..." -> "Answer...").
test_set["finetuned_model_answer"] = test_set["finetuned_model_answer"].apply(
    lambda answer: answer.replace("<med>", "").replace("<non_med>", "").strip()
)

In [48]:
test_set.head(n=30)

Unnamed: 0.1,Unnamed: 0,question,answer,gpt_4o_mini_answer,base_phi_model_answers,qwen_3_8B_answers,finetuned_model_answer
0,0,What causes Hereditary diffuse leukoencephalop...,What causes hereditary diffuse leukoencephalop...,Hereditary diffuse leukoencephalopathy with sp...,Hereditary diffuse leukoencephalopathy with sp...,Hereditary diffuse leukoencephalopathy with sp...,What causes hereditary diffuse leukoencephalop...
1,1,What is (are) Jones syndrome ?,Jones syndrome is a very rare condition charac...,"Jones syndrome, also known as oculodentodigita...","Jones syndrome, also known as Type III hyperse...",(Please write in English)\n\nJones syndrome is...,Jones syndrome is a rare genetic disorder that...
2,2,What is (are) Familial mixed cryoglobulinemia ?,Familial mixed cryoglobulinemia is a rare cond...,Familial mixed cryoglobulinemia is a rare cond...,Familial mixed cryoglobulinemia (FMC) is a rar...,"(Please write in English, using the language o...",Familial mixed cryoglobulinemia is a rare cond...
3,3,What is (are) centronuclear myopathy ?,Centronuclear myopathy is a condition characte...,Centronuclear myopathy (CNM) is a rare genetic...,Centronuclear myopathy (CNM) is a rare genetic...,Centronuclear myopathy is a rare genetic disor...,Centronuclear myopathy is a group of rare musc...
4,4,How many people are affected by alkaptonuria ?,"This condition is rare, affecting 1 in 250,000...",Alkaptonuria is a rare genetic disorder caused...,"Alkaptonuria is a rare inherited disorder, and...",Alkaptonuria is a rare genetic disorder. It af...,"Alkaptonuria is a rare condition, affecting ab..."
5,5,Is cytochrome P450 oxidoreductase deficiency i...,This condition is inherited in an autosomal re...,"Yes, cytochrome P450 oxidoreductase (POR) defi...","Yes, cytochrome P450 oxidoreductase (CPOR) def...","Yes, cytochrome P450 oxidoreductase deficiency...",Cytochrome P450 oxidoreductase deficiency is i...
6,6,Is focal dermal hypoplasia inherited ?,Focal dermal hypoplasia is inherited in an X-l...,"Focal dermal hypoplasia, also known as Goltz s...","Focal dermal hypoplasia (FDH), also known as G...","Yes, focal dermal hypoplasia is inherited in a...",Focal dermal hypoplasia is inherited in an aut...
7,7,What is (are) Fever ?,A fever is a body temperature that is higher t...,"Fever, also known as pyrexia, is defined as an...","Fever, also known as pyrexia, is an elevation ...",Fever is a temporary increase in body temperat...,Fever is a temporary increase in the body’s te...
8,8,What causes Gout ?,Most people with gout have too much uric acid ...,Gout is a form of inflammatory arthritis chara...,Gout is a form of inflammatory arthritis chara...,Gout is caused by the accumulation of uric aci...,Gout is caused by a buildup of uric acid in th...
9,9,What is (are) Stroke ?,The most commonly used imaging procedure is th...,A stroke is a serious medical condition that o...,"A stroke, also known as a cerebrovascular acci...","(Please write in English, using the language o...",A stroke occurs when blood flow to the brain i...


In [49]:
test_set.to_csv(path_or_buf=data_path, index=False)

In [None]:
eval_result = evaluate_qna_dataset(predicted_list=test_set["finetuned_model_answer"], reference_list=test_set["answer"])

In [51]:
# Print the averaged metrics for the fine-tuned model, one "<label>: <value>" line each.
print("\nEvaluation Metrics for finetuned model:")
report_rows = (
    ("BLEU", "bleu"),
    ("ROUGE-1", "rouge1"),
    ("ROUGE-2", "rouge2"),
    ("ROUGE-L", "rougeL"),
    ("BERTscore precision", "bertscore_precision"),
    ("BERTscore recall", "bertscore_recall"),
    ("BERTscore f1", "bertscore_f1"),
)
for label, key in report_rows:
    print(f"{label}: {eval_result[key]:.4f}")


Evaluation Metrics for finetuned model:
BLEU: 0.0681
ROUGE-1: 0.3349
ROUGE-2: 0.1523
ROUGE-L: 0.2599
BERTscore precision: 0.1736
BERTscore recall: 0.1491
BERTscore f1: 0.1575


In [52]:
import gc

# `del` only removes this name. Collect garbage and release cached CUDA blocks so
# the GPU memory is actually returned.
del finetuned_model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()