# Imports

In [None]:
!pip install absl-py rouge-score nltk

In [None]:
!python -m nltk.downloader punkt

In [None]:
!pip install bert-score

In [1]:
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer,
    DataCollatorForLanguageModeling,
    DataCollatorWithPadding
)
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
import torch
import wandb
import evaluate  # Hugging Face's evaluate library
import numpy as np
import torch
from bert_score import BERTScorer, score as bert_score

from utils import tokenize_dataset_for_qna

  from .autonotebook import tqdm as notebook_tqdm


# Configs

In [2]:
model_path = "../models/phi_qna_finetuned_attempt_3/final"

data_path = "../data/qna/"
test_data_path = data_path + "test.csv"

model_id = "microsoft/Phi-3.5-mini-instruct"
base_model_path = "../models/phi_pubmed_pretrained_attempt_3/final_pretrained"

max_len = 512
batch_size = 8

In [3]:
wandb.init(project="qna_finetune-evaluation", name="attempt_3")

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhasindumadushan325[0m ([33mhasindumadushan325-university-of-peradeniya[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Dataset

In [4]:
prompt_template = """
# Instruction:
Assume you are an excellent doctor. Using your knowledge, answer the question given below.

# Question: {question}

# Answer: """
prompt_template = prompt_template.strip()
print(prompt_template)

# Instruction:
Assume you are an excellent doctor. Using your knowledge, answer the question given below.

# Question: {question}

# Answer:


In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

In [None]:
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict


def tokenize_dataset_for_qna(tokenizer, data_df, prompt_template, max_len):
    dataset = Dataset.from_pandas(data_df)
    dataset = dataset.map(lambda sample: tokenize_for_qna(sample, tokenizer, prompt_template, max_len), batched=False)
    return dataset


def tokenize_for_qna(example, tokenizer, prompt_template, max_len):
    prompt = prompt_template.format(question=example['question'])
    answer = example["answer"] + tokenizer.eos_token
    full_text = prompt + answer

    # Tokenize answer to get the length of answer tokens
    prompt_len = len(tokenizer(
        prompt,
        truncation=True,
        max_length=max_len
    )["input_ids"])

    # Tokenize answer to get the length of answer tokens
    full_len = len(tokenizer(
        full_text,
        truncation=False
    )["input_ids"])
    
    # Tokenize full sequence once
    tokenized = tokenizer(
        full_text,
        truncation=True,
        max_length=max_len,
        padding="max_length",
        return_attention_mask=True    
    )
    
    # Convert to numpy arrays for faster operations
    input_ids = np.array(tokenized["input_ids"])
    attention_mask = np.array(tokenized["attention_mask"])
    
    # Create labels array and mask prompt portion efficiently
    labels = input_ids.copy()
    padding_len = max_len - full_len
    # Mask the prompt tokens
    labels[padding_len:padding_len + prompt_len] = -100
    
    # Update the tokenized dict with numpy arrays
    tokenized["input_ids"] = input_ids.tolist()
    tokenized["attention_mask"] = attention_mask.tolist()
    tokenized["labels"] = labels.tolist()
    
    return tokenized


In [6]:
test_df = pd.read_csv(test_data_path)
test_set = tokenize_dataset_for_qna(tokenizer, test_df.iloc[:1500, :], prompt_template, max_len)

Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1500/1500 [00:01<00:00, 806.55 examples/s]


In [7]:
test_set[0]

{'question': 'Ive had sleep apnea for approximately 8 years now. I also have polycythemia vera which started around the same time. I was constantly having to get plebotomies every six months when I wasnt wearing my cpap all of the time. But for the last few years now, Ive been real diligent about wearing it and I havent had to get any phlebotomies. Im now thinking that my polycythemia vera was because I was not getting enough oxygen? Am I on the right track? I would appreciate any help you could offer. Thanks.',
 'answer': 'Hello, Thank you for contacting ChatDoctorI understand your concern, I am Chat Doctor, Infectious Disease Specialist answering your query. Yes you are on right track. This can happen to you for low oxygen tension. As apnea occurs during sleep time there is adaptation by mean of increase RBC which leads you towards the poly Bohemia. I advise you to use the CPAP during sleep otherwise the polycythemia will increase. I advise you to take the phlebotomies done and once 

# Model

In [None]:
peft_config = PeftConfig.from_pretrained(model_path)
base_model_name = peft_config.base_model_name_or_path

In [None]:
base_model_name

In [8]:
# === Quantized model loading ===
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    # quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:01<00:00,  2.00it/s]


In [9]:
model = PeftModel.from_pretrained(model, model_path)
model.eval() 

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Phi3ForCausalLM(
      (model): Phi3Model(
        (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
        (embed_dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-31): 32 x Phi3DecoderLayer(
            (self_attn): Phi3Attention(
              (o_proj): lora.Linear(
                (base_layer): Linear(in_features=3072, out_features=3072, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=24, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=24, out_features=3072, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vect

### Base model

In [None]:
model_id

In [None]:
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True
)
model.eval()

In [None]:
# training_args = TrainingArguments(
#     output_dir="./eval_output_base",
#     per_device_eval_batch_size=batch_size,
#     do_eval=True,
#     report_to="none"
# )

# base_model_trainer = Trainer(
#     model=base_model,
#     args=training_args,
#     tokenizer=tokenizer,
#     data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
# )

# base_model_eval_result = base_model_trainer.evaluate(test_set)

# Test set evaluation

In [10]:
training_args = TrainingArguments(
    output_dir="./eval_output",
    per_device_eval_batch_size=batch_size,
    do_eval=True,
    report_to="none",
    eval_accumulation_steps=2,
)

In [11]:

# Initialize metrics ONCE (reuse them)
bleu_metric = evaluate.load("bleu")
rouge_metric = evaluate.load("rouge")
bert_scorer = BERTScorer(
    lang="en",
    model_type="bert-base-uncased",
    device="cuda" if torch.cuda.is_available() else "cpu",
    idf=False,  # Disable IDF to save memory
    rescale_with_baseline=True  # Better score normalization
)

def compute_metrics(eval_preds):
    """ Metric computation """
    with torch.no_grad():
        logits, labels = eval_preds
        
        # Convert to numpy (move to CPU first if needed)
        if torch.is_tensor(logits):
            logits = logits.detach().cpu().numpy()
        if torch.is_tensor(labels):
            labels = labels.detach().cpu().numpy()
        
        # Get predicted tokens (shape: [batch_size, seq_length])
        pred_ids = np.argmax(logits, axis=-1)
        
        # Decode in batches to avoid memory spikes
        batch_size = 8  # Adjust based on your GPU memory
        pred_str, label_str = [], []
        
        for i in range(0, len(pred_ids), batch_size):
            # Decode predictions
            batch_preds = pred_ids[i:i+batch_size]
            pred_str.extend(tokenizer.batch_decode(
                batch_preds, 
                skip_special_tokens=True
            ))
            
            # Decode labels (replace -100 with pad_token_id)
            batch_labels = labels[i:i+batch_size]
            batch_labels = np.where(
                batch_labels != -100, 
                batch_labels, 
                tokenizer.pad_token_id
            )
            label_str.extend(tokenizer.batch_decode(
                batch_labels, 
                skip_special_tokens=True
            ))
        
        # Skip if empty (avoid errors)
        if not pred_str or not label_str:
            return {
                'bleu': 0.0,
                'rouge1': 0.0,
                'rouge2': 0.0,
                'rougeL': 0.0,
                'bertscore_f1': 0.0
            }
        
        # Compute BLEU (handle edge cases)
        try:
            bleu_score = bleu_metric.compute(
                predictions=pred_str,
                references=[[ref] for ref in label_str]
            )['bleu']
        except:
            bleu_score = 0.0
        
        # Compute ROUGE
        rouge_scores = rouge_metric.compute(
            predictions=pred_str,
            references=label_str,
            use_stemmer=True
        )
        
        # Compute BERTScore in batches
        P, R, F1 = bert_scorer.score(
            pred_str, 
            label_str,
            batch_size=4  # Small batch for BERTScore
        )
        
        metrics = {
            'bleu': bleu_score,
            'rouge1': rouge_scores['rouge1'],
            'rouge2': rouge_scores['rouge2'],
            'rougeL': rouge_scores['rougeL'],
            'bertscore_precision': P.mean().item(),
            'bertscore_recall': R.mean().item(),
            'bertscore_f1': F1.mean().item(),
        }
        
        torch.cuda.empty_cache()
        return metrics

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=DataCollatorWithPadding(tokenizer, padding=False),
    compute_metrics=compute_metrics
)

  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [12]:
# === Evaluate perplexity ===
eval_result = trainer.evaluate(test_set)

You are not running the flash-attention implementation, expect numerical differences.


In [13]:
# Print all results
print("\nEvaluation Metrics:")
print(f"Loss: {eval_result['eval_loss']:.4f}")
print(f"Perplexity: {torch.exp(torch.tensor(eval_result['eval_loss'])):.2f}")
print(f"BLEU: {eval_result['eval_bleu']:.4f}")
print(f"ROUGE-1: {eval_result['eval_rouge1']:.4f}")
print(f"ROUGE-2: {eval_result['eval_rouge2']:.4f}")
print(f"ROUGE-L: {eval_result['eval_rougeL']:.4f}")
print(f"BERTscore precision: {eval_result['eval_bertscore_precision']:.4f}")
print(f"BERTscore recall: {eval_result['eval_bertscore_recall']:.4f}")
print(f"BERTscore f1: {eval_result['eval_bertscore_f1']:.4f}")


Evaluation Metrics:
Loss: 1.6187
Perplexity: 5.05
BLEU: 0.1806
ROUGE-1: 0.4767
ROUGE-2: 0.2357
ROUGE-L: 0.3843
BERTscore precision: 0.3302
BERTscore recall: 0.5072
BERTscore f1: 0.4110


In [14]:
wandb.log({
    "eval_loss": eval_result['eval_loss'], 
    "perplexity": torch.exp(torch.tensor(eval_result['eval_loss'])),
    "BLUE": eval_result['eval_bleu'],
    "ROUGE_1": eval_result['eval_rouge1'],
    "ROUGE_2": eval_result['eval_rouge2'],
    "ROUGE_L": eval_result['eval_rougeL'],
    "BERTscore_precision": eval_result['eval_bertscore_precision'],
    "BERTscore recall": eval_result['eval_bertscore_recall'],
    "BERTscore f1": eval_result['eval_bertscore_f1']
})

# Inference

In [None]:
samples = test_set.select(range(433, 439))  # First 5 examples
input_ids = torch.tensor(samples["input_ids"]).to(model.device)
attention_mask = torch.tensor(samples["attention_mask"]).to(model.device)

generated_ids = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_new_tokens=128,
    do_sample=False,
    use_cache=False
)

generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

# Log predictions to W&B
wandb_table = wandb.Table(columns=["Title", "Actual Abstract", "Generated Text"])
for i, gen in enumerate(generated_texts):
    title = samples[i]["title"]
    actual = samples[i]["abstract"]
    print(f"\nActual: {title}\n{actual}\n---\nGenerated: {gen}\n")
    wandb_table.add_data(title, actual, gen)




In [None]:
samples[0]

In [None]:
wandb.log({"generated_examples": wandb_table})

In [None]:
def generate(model, text, max_new_tokens=128):
    sample = tokenizer(text + tokenizer.eos_token, truncation=True, padding="max_length", max_length=max_len, return_attention_mask=True)
    input_ids = torch.tensor([sample["input_ids"]]).to(model.device)
    attention_mask = torch.tensor([sample["attention_mask"]]).to(model.device)
    
    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        num_beams=1,
        do_sample=False,
        use_cache=False
    )
    
    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return generated_texts[0]
    # # Log predictions to W&B
    # for i, gen in enumerate(generated_texts):
    #     title = samples[i]["title"]
    #     actual = samples[i]["abstract"]
    #     print(f"\nTitle: {title}\n---\nActual Abstract: {actual}\n---\nGenerated: {gen}\n")
    #     wandb_table.add_data(title, actual, gen)
    
    
    # wandb.log({"generated_examples": wandb_table})

In [None]:
generated_text = generate(model, prompt_template.format(question="What are the symptoms of Glaucoma ??"), max_new_tokens=200)	
print(generated_text)

# Instruction:
Assume you are an excellent doctor. Using your knowledge, answer the quesion given below.

# Question: What are the symptoms of Glaucoma ??

# Answer:ains of Glaucoma can vary depending on the type and severity of the condition. The most common symptom of glaucoma is loss of vision, which may begin with a loss of peripheral (side) vision. This is often described as tunnel vision.  Glaucoma can also cause blurred vision, halos around lights, eye pain, redness, and vision loss.  Glaucoma is often called the "silent thief of sight" because it usually has no symptoms until significant vision loss has occurred.  If you have glaucoma, you may not notice any changes in your vision until the damage is severe.  Glaucoma is a progressive disease, which means it gets worse over time.  If you have glaucoma, it's important to have regular eye exams to monitor your vision and eye pressure.  If you have any of the following symptoms, you should see


In [None]:
wandb.log({"example_2": generated_text})

In [None]:
test_set[0]

In [None]:
generate(base_model, "# The relationship between diabetes and blood pressure\n")

In [15]:
def stream_generate(model, tokenizer, text, max_new_tokens=300):
    model.eval()
    sample = tokenizer(
        text + tokenizer.eos_token,
        return_tensors="pt",
        truncation=True,
        max_length=max_len
    ).to(model.device)

    input_ids = sample["input_ids"]
    generated = input_ids.clone()
    past_key_values = None
    position_ids = torch.arange(0, input_ids.shape[1], device=model.device).unsqueeze(0)

    prev_decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

    for i in range(max_new_tokens):
        if i == 0:
            input_token = input_ids
        else:
            input_token = next_token_id
            position_ids = torch.tensor([[generated.shape[1] - 1]], device=model.device)

        with torch.no_grad():
            outputs = model(
                input_ids=input_token,
                past_key_values=past_key_values,
                use_cache=True,
                position_ids=position_ids
            )

        logits = outputs.logits[:, -1, :]
        next_token_id = torch.argmax(logits, dim=-1, keepdim=True)

        generated = torch.cat((generated, next_token_id), dim=1)
        past_key_values = outputs.past_key_values

        # Decode full sequence and compute the diff
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
        new_text = decoded[len(prev_decoded):]
        prev_decoded = decoded

        yield new_text

        if next_token_id.squeeze().item() == tokenizer.eos_token_id:
            break


"What are the symptoms of Glaucoma ??"
2. "Explain Management of advanced and extragonadal germ-cell tumors."
3. "what is the role of hns in the virulence phenotype of pathogenic salmonellae?"
4. "what are drinking patterns of high-risk drivers?"
5. "My sister is on Xanax, feyntnol patch and a pain medicine for cancer.  She has been on 25 of fentynol and within 6 days she has been bumped up to 100 now she is almost lethargic and breathing is really labored and right arm is twitching.. She was carrying on conversation Sunday and Monday patch was put on Tuesday and now cant even sit up..no one seems worried but me.. Just wondering what I could do"
6. "I was playing basketball the other night and went up to block a shot and flipped over the guy and landed on my side/back. Since then the lower left side of back/side have been sore, hurts when I take deep breaths and when I lay on my back, any chance of a bruised kidney or any serious injury I could have?"

In [31]:
for token in stream_generate(model, tokenizer, prompt_template.format(question="What is (are) Urinary Tract Infections ?")):
    print(token, end='', flush=True)

 The urinary tract is a system of organs that make up the body's waste disposal system. It includes the kidneys, ureters, bladder, and urethra. The kidneys filter waste from the blood and produce urine. Urine flows from the kidneys through the ureters to the bladder. The bladder stores urine until it is time to urinate. Urine leaves the body through the urethra.    Urinary tract infections (UTIs) are infections of the urinary tract. They are caused by bacteria. UTIs are more common in women than in men.    The signs and symptoms of a UTI include       - Pain or burning when urinating    - Frequent urination    - Urgent need to urinate    - Cloudy urine    - Strong-smelling urine    - Fever    - Chills    - Nausea    - Vomiting    - Pain in the lower abdomen    - Blood in the urine    - Urine that is dark or red    - Urine that is cloudy    - Urine that smells bad    - Urine that is foamy    - Urine that is pink or red    - Urine that is brown    - Urine that is yellow    - Urine that i

In [30]:
examples = [
    "What are the symptoms of Glaucoma ??",
    "My sister is on Xanax, feyntnol patch and a pain medicine for cancer.  She has been on 25 of fentynol and within 6 days she has been bumped up to 100 now she is almost lethargic and breathing is really labored and right arm is twitching.. She was carrying on conversation Sunday and Monday patch was put on Tuesday and now cant even sit up..no one seems worried but me.. Just wondering what I could do",
    "I was playing basketball the other night and went up to block a shot and flipped over the guy and landed on my side/back. Since then the lower left side of back/side have been sore, hurts when I take deep breaths and when I lay on my back, any chance of a bruised kidney or any serious injury I could have?",
    "What are the treatments for High Blood Pressure ?",
    "What is (are) Urinary Tract Infections ?",
]

In [35]:
wandb_table = wandb.Table(columns=["Question", "Generated answer"])

for example in examples:
    generated_answer = ""
    for token in stream_generate(model, tokenizer, prompt_template.format(question=example)):
        generated_answer += token
    wandb_table.add_data(example, generated_answer)
    print(example, "\n", generated_answer)

What are the symptoms of Glaucoma ?? 
  The symptoms of glaucoma are usually painless and may not be noticed until the disease has progressed. The first sign of glaucoma is often a loss of peripheral (side) vision. This is often described as tunnel vision. As glaucoma progresses, the blind spot in the visual field gets larger and the field of vision narrows. Eventually, the blind spot may reach the center of vision. This is called tunnel vision. If glaucoma is not treated, it can lead to blindness.
My sister is on Xanax, feyntnol patch and a pain medicine for cancer.  She has been on 25 of fentynol and within 6 days she has been bumped up to 100 now she is almost lethargic and breathing is really labored and right arm is twitching.. She was carrying on conversation Sunday and Monday patch was put on Tuesday and now cant even sit up..no one seems worried but me.. Just wondering what I could do 
  The symptoms of your sister are suggestive of a possible stroke. She needs to be evaluated 

In [34]:
wandb.log({"generated_examples": wandb_table})