In [None]:
!pip install absl-py rouge-score nltk

In [None]:
# Fetch the punkt tokenizer data through the kernel's own interpreter; a bare
# `!python` can resolve to a different Python than the one running this notebook.
import nltk

nltk.download("punkt")

In [None]:
!pip install bert-score

# Imports

In [1]:
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer,
    DataCollatorForLanguageModeling,
    DataCollatorWithPadding
)
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
import torch
import wandb
import evaluate  # Hugging Face's evaluate library
import numpy as np
import torch

from utils import tokenize_dataset_for_qna
from prompt_templates import qna_prompt_template as prompt_template
from generate import generate, stream_generate
from evaluation_metrics import compute_metrics_for_qna

  from .autonotebook import tqdm as notebook_tqdm


# Configs

In [2]:
# Fine-tuned Q&A adapter checkpoint that this notebook evaluates.
model_path = "../models/phi_qna_finetuned_attempt_5/final"

# Evaluation data: the test split sits directly under the Q&A data directory.
data_path = "../data/qna/"
test_data_path = f"{data_path}test.csv"

# Hub id of the original model, and the local checkpoint the adapter is applied to.
model_id = "microsoft/Phi-3.5-mini-instruct"
base_model_path = "../models/phi_pubmed_pretrained_attempt_5/final_merged"

# Token-length cap used when tokenising, and the per-device evaluation batch size.
max_len, batch_size = 512, 8

In [3]:
# Open the W&B run that the evaluation metrics and generated samples are logged to.
wandb.init(
    project="qna_finetune-evaluation",
    name="attempt_5",
)

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhasindumadushan325[0m ([33mhasindumadushan325-university-of-peradeniya[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Dataset

In [4]:
# Show the raw prompt template; its {question} slot is filled with str.format in later cells.
print(prompt_template)

# Instruction:
Assume you are an excellent doctor. Using your knowledge, answer the question given below.

# Question: {question}

# Answer:


In [5]:
# The tokenizer is read from the fine-tuned checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)

In [6]:
# Read the test split and tokenise its first 1500 rows with the Q&A prompt template.
test_df = pd.read_csv(test_data_path)
test_set = tokenize_dataset_for_qna(
    tokenizer,
    test_df.iloc[:1500],
    prompt_template,
    max_len,
)

Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1500/1500 [00:01<00:00, 849.36 examples/s]


In [7]:
# Sanity check: display the first processed test record.
test_set[0]

{'question': 'Ive had sleep apnea for approximately 8 years now. I also have polycythemia vera which started around the same time. I was constantly having to get plebotomies every six months when I wasnt wearing my cpap all of the time. But for the last few years now, Ive been real diligent about wearing it and I havent had to get any phlebotomies. Im now thinking that my polycythemia vera was because I was not getting enough oxygen? Am I on the right track? I would appreciate any help you could offer. Thanks.',
 'answer': 'Hello, Thank you for contacting ChatDoctorI understand your concern, I am Chat Doctor, Infectious Disease Specialist answering your query. Yes you are on right track. This can happen to you for low oxygen tension. As apnea occurs during sleep time there is adaptation by mean of increase RBC which leads you towards the poly Bohemia. I advise you to use the CPAP during sleep otherwise the polycythemia will increase. I advise you to take the phlebotomies done and once 

# Model

In [8]:
# === Quantized model loading ===
# (4-bit bitsandbytes configuration kept for reference; currently disabled.)
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16
# )

# Load the local base checkpoint (base_model_path) without quantization.
# The LoRA adapter is attached to this model in the next cell.
model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    # quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=False
)

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.98it/s]


In [9]:
# Wrap the base checkpoint with the fine-tuned adapter stored at model_path.
# `model` is rebound to the wrapped model, so reload the base checkpoint before
# re-running this cell; otherwise the already-wrapped model is wrapped again.
model = PeftModel.from_pretrained(
    model=model,
    model_id=model_path,
)
# Switch to inference mode; the returned module is shown as the cell output.
model.eval()

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Phi3ForCausalLM(
      (model): Phi3Model(
        (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
        (layers): ModuleList(
          (0-31): 32 x Phi3DecoderLayer(
            (self_attn): Phi3Attention(
              (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
              (qkv_proj): lora.Linear(
                (base_layer): Linear(in_features=3072, out_features=9216, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=24, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=24, out_features=9216, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
         

### Base model

In [None]:
# Hub id of the original (not fine-tuned) model that the next cell loads.
model_id

In [None]:
# Load the original Hub model under its own name. Binding it to `model` would
# replace the fine-tuned PEFT model, so every later cell (Trainer evaluation,
# inference, `model.merge_and_unload()`) would run on, or fail against, the
# wrong model when the notebook is executed top to bottom.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True
)
base_model.eval()

In [None]:
# training_args = TrainingArguments(
#     output_dir="./eval_output_base",
#     per_device_eval_batch_size=batch_size,
#     do_eval=True,
#     report_to="none"
# )

# base_model_trainer = Trainer(
#     model=base_model,
#     args=training_args,
#     tokenizer=tokenizer,
#     data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
# )

# base_model_eval_result = base_model_trainer.evaluate(test_set)

# Test set evaluation

In [15]:
# Evaluation-only Trainer settings. Reporting integrations are disabled here;
# the metrics are sent to W&B explicitly in a later cell.
training_args = TrainingArguments(
    output_dir="./eval_output",
    do_eval=True,
    per_device_eval_batch_size=batch_size,
    eval_accumulation_steps=2,
    label_names=["labels"],
    report_to="none",
)

In [16]:
# Build the Trainer used purely for evaluation of `model`.
# NOTE(review): the collator is created with padding=False, so batches are only
# stackable if tokenize_dataset_for_qna already emits equal-length examples — confirm.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics_for_qna,
    data_collator=DataCollatorWithPadding(tokenizer, padding=False),
    tokenizer=tokenizer,
)

  trainer = Trainer(


In [None]:
# === Evaluate perplexity ===
# Run the Trainer's evaluation over the tokenised test set. The returned dict is
# read in the following cells for eval_loss (from which perplexity is derived)
# and the eval_bleu / eval_rouge* / eval_bertscore_* entries.
eval_result = trainer.evaluate(test_set)

In [None]:
# Print all results: loss, perplexity (exp of the loss), then each text-overlap /
# BERTScore metric, one per line.
print(
    "\nEvaluation Metrics:",
    f"Loss: {eval_result['eval_loss']:.4f}",
    f"Perplexity: {torch.exp(torch.tensor(eval_result['eval_loss'])):.2f}",
    *(
        f"{label}: {eval_result[key]:.4f}"
        for label, key in (
            ("BLEU", "eval_bleu"),
            ("ROUGE-1", "eval_rouge1"),
            ("ROUGE-2", "eval_rouge2"),
            ("ROUGE-L", "eval_rougeL"),
            ("BERTscore precision", "eval_bertscore_precision"),
            ("BERTscore recall", "eval_bertscore_recall"),
            ("BERTscore f1", "eval_bertscore_f1"),
        )
    ),
    sep="\n",
)

In [None]:
# Send the headline evaluation numbers to the active W&B run.
# - Perplexity is exp(eval_loss); `.item()` turns the 0-d tensor into a plain
#   Python float so a scalar is logged rather than a tensor object.
# - The BLEU key was previously misspelled "BLUE". Runs logged before this fix
#   carry the old key, so compare across runs accordingly.
wandb.log({
    "eval_loss": eval_result['eval_loss'],
    "perplexity": torch.exp(torch.tensor(eval_result['eval_loss'])).item(),
    "BLEU": eval_result['eval_bleu'],
    "ROUGE_1": eval_result['eval_rouge1'],
    "ROUGE_2": eval_result['eval_rouge2'],
    "ROUGE_L": eval_result['eval_rougeL'],
    "BERTscore_precision": eval_result['eval_bertscore_precision'],
    "BERTscore recall": eval_result['eval_bertscore_recall'],
    "BERTscore f1": eval_result['eval_bertscore_f1']
})

# Inference

In [10]:
# Sample questions for qualitative inspection. Later cells pick entries by
# position and substitute them into prompt_template's {question} slot.
examples = [
    "What is Glaucoma ?",
    "What are the symptoms of Glaucoma ??",
    "My sister is on Xanax, feyntnol patch and a pain medicine for cancer.  She has been on 25 of fentynol and within 6 days she has been bumped up to 100 now she is almost lethargic and breathing is really labored and right arm is twitching.. She was carrying on conversation Sunday and Monday patch was put on Tuesday and now cant even sit up..no one seems worried but me.. Just wondering what I could do",
    "I was playing basketball the other night and went up to block a shot and flipped over the guy and landed on my side/back. Since then the lower left side of back/side have been sore, hurts when I take deep breaths and when I lay on my back, any chance of a bruised kidney or any serious injury I could have?",
    "What are the treatments for High Blood Pressure ?",
    "What is (are) Urinary Tract Infections ?",
    "What are the symptoms of diabetes?"
]

In [None]:
# Generate up to 200 new tokens for the question at index 3 of `examples`, then print the text.
generated_text = generate(
    model,
    tokenizer,
    prompt_template.format(question=examples[3]),
    max_new_tokens=200,
)
print(generated_text)

In [None]:
# Log the generated text to W&B.
# NOTE(review): the key says "example_2" but the text above was generated from
# examples[3] — confirm the label is still the intended one.
wandb.log({"example_2": generated_text})

In [None]:
# Display the first processed test record for comparison with the generated answers.
test_set[0]

In [None]:
# `generate` is called as (model, tokenizer, text, ...) elsewhere in this notebook.
# The tokenizer argument was missing here, so the prompt string was being passed
# in the tokenizer position.
generate(model, tokenizer, "# The relationship between diabetes and blood pressure\n")

In [None]:
def stream_generate(model, tokenizer, text, max_new_tokens=300,
                    do_sample=False, temperature=1.0, top_k=None, max_length=None):
    """Generate a continuation of ``text`` token by token, yielding decoded text pieces.

    This definition shadows the ``stream_generate`` imported from ``generate`` at
    the top of the notebook. It therefore accepts the sampling keywords
    (``do_sample``, ``temperature``, ``top_k``) that later cells pass, instead of
    raising ``TypeError`` once this cell has been run.

    Args:
        model: Causal LM called as ``model(input_ids=..., past_key_values=...,
            use_cache=True, position_ids=...)``; must expose ``.device`` and ``.eval()``.
        tokenizer: Tokenizer providing ``__call__``, ``decode``, ``eos_token`` and
            ``eos_token_id``.
        text: Prompt string.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: When False (default) the highest-scoring token is chosen at each
            step, which is the original behaviour. When True the next token is
            sampled from the temperature-scaled distribution.
        temperature: Divisor applied to the logits before sampling; must be > 0.
            Ignored unless ``do_sample`` is True.
        top_k: If a positive integer, sampling is restricted to the ``top_k``
            highest-scoring tokens. Ignored unless ``do_sample`` is True.
        max_length: Truncation length for the prompt. ``None`` (default) falls back
            to the notebook-level ``max_len``.

    Yields:
        str: The text added by each new token (may be empty, e.g. for a special
        token that ``skip_special_tokens=True`` removes).

    Raises:
        ValueError: On first iteration, if ``do_sample`` is True and
            ``temperature`` is not positive.
    """
    if do_sample and temperature <= 0:
        raise ValueError("temperature must be > 0 when do_sample=True")

    model.eval()
    prompt_cap = max_len if max_length is None else max_length

    # NOTE(review): the EOS token is appended to the prompt before generation —
    # presumably this mirrors the training format; confirm against the tokenisation code.
    sample = tokenizer(
        text + tokenizer.eos_token,
        return_tensors="pt",
        truncation=True,
        max_length=prompt_cap
    ).to(model.device)

    input_ids = sample["input_ids"]
    generated = input_ids.clone()
    past_key_values = None

    # First step feeds the whole prompt; later steps feed only the newest token
    # and rely on the cached key/values.
    input_token = input_ids
    position_ids = torch.arange(0, input_ids.shape[1], device=model.device).unsqueeze(0)

    prev_decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(
                input_ids=input_token,
                past_key_values=past_key_values,
                use_cache=True,
                position_ids=position_ids
            )

        logits = outputs.logits[:, -1, :]
        if do_sample:
            scores = logits.float() / temperature
            if top_k is not None and top_k > 0:
                k = min(top_k, scores.shape[-1])
                kth_best = torch.topk(scores, k, dim=-1).values[:, -1:]
                # Everything scoring below the k-th best token gets zero probability.
                scores = scores.masked_fill(scores < kth_best, float("-inf"))
            next_token_id = torch.multinomial(torch.softmax(scores, dim=-1), num_samples=1)
        else:
            next_token_id = torch.argmax(logits, dim=-1, keepdim=True)

        generated = torch.cat((generated, next_token_id), dim=1)
        past_key_values = outputs.past_key_values

        # Decode the full sequence and emit only the newly added suffix, so that
        # tokens which only decode correctly in context still render properly.
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
        new_text = decoded[len(prev_decoded):]
        prev_decoded = decoded

        yield new_text

        if next_token_id.squeeze().item() == tokenizer.eos_token_id:
            break

        # Prepare the next step: feed the new token at its absolute position.
        input_token = next_token_id
        position_ids = torch.tensor([[generated.shape[1] - 1]], device=model.device)


1. "What are the symptoms of Glaucoma ??"
2. "Explain Management of advanced and extragonadal germ-cell tumors."
3. "what is the role of hns in the virulence phenotype of pathogenic salmonellae?"
4. "what are drinking patterns of high-risk drivers?"
5. "My sister is on Xanax, feyntnol patch and a pain medicine for cancer.  She has been on 25 of fentynol and within 6 days she has been bumped up to 100 now she is almost lethargic and breathing is really labored and right arm is twitching.. She was carrying on conversation Sunday and Monday patch was put on Tuesday and now cant even sit up..no one seems worried but me.. Just wondering what I could do"
6. "I was playing basketball the other night and went up to block a shot and flipped over the guy and landed on my side/back. Since then the lower left side of back/side have been sore, hurts when I take deep breaths and when I lay on my back, any chance of a bruised kidney or any serious injury I could have?"

In [None]:
# Stream a continuation of a bare (non-templated) prompt, printing each piece as it arrives.
for token in stream_generate(model, tokenizer, "The symptoms of Glaucoma\n"):
    print(token, end='', flush=True)

In [None]:
# Stream an answer for a question wrapped in the Q&A prompt template,
# printing each piece as it arrives.
for token in stream_generate(model, tokenizer, prompt_template.format(question="I was playing basketball the other night and went up to block a shot and flipped over the guy and landed on my side/back. Since then the lower left side of back/side have been sore, hurts when I take deep breaths and when I lay on my back, any chance of a bruised kidney or any serious injury I could have?",
)):
    print(token, end='', flush=True)

In [None]:
# Questions whose generated answers are collected into the W&B table below.
# NOTE(review): this rebinds `examples`, which is also assigned under "# Inference".
# Index-based lookups such as examples[3] therefore pick a different question
# depending on which of the two cells ran last.
examples = [
    "What are the symptoms of Glaucoma ??",
    "My sister is on Xanax, feyntnol patch and a pain medicine for cancer.  She has been on 25 of fentynol and within 6 days she has been bumped up to 100 now she is almost lethargic and breathing is really labored and right arm is twitching.. She was carrying on conversation Sunday and Monday patch was put on Tuesday and now cant even sit up..no one seems worried but me.. Just wondering what I could do",
    "I was playing basketball the other night and went up to block a shot and flipped over the guy and landed on my side/back. Since then the lower left side of back/side have been sore, hurts when I take deep breaths and when I lay on my back, any chance of a bruised kidney or any serious injury I could have?",
    "What are the treatments for High Blood Pressure ?",
    "What is (are) Urinary Tract Infections ?",
]

In [None]:
# Collect one (question, generated answer) row per example in a W&B table,
# echoing each pair to the cell output as it is produced.
wandb_table = wandb.Table(columns=["Question", "Generated answer"])

for example in examples:
    # Drain the streaming generator and stitch its pieces into the full answer.
    generated_answer = "".join(
        stream_generate(model, tokenizer, prompt_template.format(question=example))
    )
    wandb_table.add_data(example, generated_answer)
    print(example, "\n", generated_answer)

In [None]:
# Upload the question/answer table to the active W&B run.
wandb.log({"generated_examples": wandb_table})

In [None]:
# Merge the adapter into the underlying model (PEFT's merge_and_unload) and rebind `model`
# to the result. Because `model` is overwritten, this cell cannot simply be re-run:
# reload and re-wrap the model first.
model = model.merge_and_unload()

In [None]:
# Write the current `model` to disk using safetensors serialization.
# NOTE(review): the target directory says "attempt_3" while model_path points at
# "attempt_5" — confirm this is the intended output location.
model.save_pretrained("../models/phi_qna_finetuned_attempt_3/final_pretrained_2", safe_serialization=True)

In [None]:
# Reload the checkpoint written by the previous cell and put it in inference mode.
loaded_model = AutoModelForCausalLM.from_pretrained(
    "../models/phi_qna_finetuned_attempt_3/final_pretrained_2",
    # quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=False,
)
loaded_model.eval()

In [11]:
# Stream a sampled answer (temperature 0.1, top-k 20, up to 512 new tokens) for examples[3].
# NOTE(review): `examples` is assigned in two different cells, so the question at
# index 3 depends on which of them ran last.
for token in stream_generate(model, tokenizer, prompt_template.format(question=examples[3]), do_sample=True, temperature=0.1, top_k=20, max_new_tokens=512):
    print(token, end='', flush=True)

Hi, Thanks for writing in. I am Chat Doctor, infectious disease specialist, answering your query. I understand your concern. I would suggest you to get an ultrasound of the abdomen and kidneys to rule out any injury to the kidneys. If the ultrasound is normal, then you can take a course of antibiotics for 5 days. I hope this information helps you. Thank you for choosing Chat Doctor. I wish you a quick recovery. Best, Chans Doctor.<|endoftext|>