# Imports

In [1]:
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    TrainingArguments, Trainer,
    DataCollatorForLanguageModeling,
    DataCollatorWithPadding
)
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
import torch
import wandb
import evaluate  # Hugging Face's evaluate library
import numpy as np
import torch
from bert_score import BERTScorer, score as bert_score

from utils import tokenize_dataset_for_domain_bound_qna
# from prompt_templates import domain_bound_qna_prompt_template as prompt_template

  from .autonotebook import tqdm as notebook_tqdm


# Configs

In [2]:
model_path = "../models/phi_domain_bound_qna_finetuned_attempt_6/final"

data_path = "../data/domain_bound_data/v4/"
test_data_path = data_path + "test.csv"

model_id = "microsoft/Phi-3.5-mini-instruct"
base_model_path = "../models/phi_qna_finetuned_attempt_3/final_pretrained_2"

max_len = 512
batch_size = 8

In [3]:
wandb.init(project="domain_bound_qna_finetune-evaluation", name="attempt_6")

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhasindumadushan325[0m ([33mhasindumadushan325-university-of-peradeniya[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


# Dataset

In [4]:
prompt_template = """
# Instruction:
Assume you are an excellent doctor. Using your knowledge, answer the question given below.

# Question: {question}

# Answer: """
prompt_template = prompt_template.strip()
print(prompt_template)

# Instruction:
Assume you are an excellent doctor. Using your knowledge, answer the question given below.

# Question: {question}

# Answer:


In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

In [6]:
test_df = pd.read_csv(test_data_path)
test_set = tokenize_dataset_for_domain_bound_qna(tokenizer, test_df, prompt_template, max_len)

Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1034.07 examples/s]


In [None]:
test_set[10]

# Model

In [9]:
base_model_path

'../models/phi_qna_finetuned_attempt_3/final_pretrained_2'

In [7]:
# === Quantized model loading ===
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
    base_model_path,
    # quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=False
)

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.99it/s]


In [8]:
model_path

'../models/phi_domain_bound_qna_finetuned_attempt_6/final'

In [10]:
model = PeftModel.from_pretrained(model, model_path)
model.eval() 

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Phi3ForCausalLM(
      (model): Phi3Model(
        (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
        (layers): ModuleList(
          (0-31): 32 x Phi3DecoderLayer(
            (self_attn): Phi3Attention(
              (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
              (qkv_proj): lora.Linear(
                (base_layer): Linear(in_features=3072, out_features=9216, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=9216, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
         

### Base model

In [None]:
model_id

In [None]:
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    trust_remote_code=True
)
model.eval()

In [None]:
# training_args = TrainingArguments(
#     output_dir="./eval_output_base",
#     per_device_eval_batch_size=batch_size,
#     do_eval=True,
#     report_to="none"
# )

# base_model_trainer = Trainer(
#     model=base_model,
#     args=training_args,
#     tokenizer=tokenizer,
#     data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
# )

# base_model_eval_result = base_model_trainer.evaluate(test_set)

# Test set evaluation

In [23]:
training_args = TrainingArguments(
    output_dir="./eval_output",
    per_device_eval_batch_size=batch_size,
    do_eval=True,
    report_to="none",
    eval_accumulation_steps=2,
)

In [24]:

# Initialize metrics ONCE (reuse them)
bleu_metric = evaluate.load("bleu")
rouge_metric = evaluate.load("rouge")
bert_scorer = BERTScorer(
    lang="en",
    model_type="bert-base-uncased",
    device="cuda" if torch.cuda.is_available() else "cpu",
    idf=False,  # Disable IDF to save memory
    rescale_with_baseline=True  # Better score normalization
)

def compute_metrics(eval_preds):
    """ Metric computation """
    with torch.no_grad():
        logits, labels = eval_preds
        
        # Convert to numpy (move to CPU first if needed)
        if torch.is_tensor(logits):
            logits = logits.detach().cpu().numpy()
        if torch.is_tensor(labels):
            labels = labels.detach().cpu().numpy()
        
        # Get predicted tokens (shape: [batch_size, seq_length])
        pred_ids = np.argmax(logits, axis=-1)
        
        # Decode in batches to avoid memory spikes
        batch_size = 8  # Adjust based on your GPU memory
        pred_str, label_str = [], []
        
        for i in range(0, len(pred_ids), batch_size):
            # Decode predictions
            batch_preds = pred_ids[i:i+batch_size]
            pred_str.extend(tokenizer.batch_decode(
                batch_preds, 
                skip_special_tokens=True
            ))
            
            # Decode labels (replace -100 with pad_token_id)
            batch_labels = labels[i:i+batch_size]
            batch_labels = np.where(
                batch_labels != -100, 
                batch_labels, 
                tokenizer.pad_token_id
            )
            label_str.extend(tokenizer.batch_decode(
                batch_labels, 
                skip_special_tokens=True
            ))
        
        # Skip if empty (avoid errors)
        if not pred_str or not label_str:
            return {
                'bleu': 0.0,
                'rouge1': 0.0,
                'rouge2': 0.0,
                'rougeL': 0.0,
                'bertscore_f1': 0.0
            }
        
        # Compute BLEU (handle edge cases)
        try:
            bleu_score = bleu_metric.compute(
                predictions=pred_str,
                references=[[ref] for ref in label_str]
            )['bleu']
        except:
            bleu_score = 0.0
        
        # Compute ROUGE
        rouge_scores = rouge_metric.compute(
            predictions=pred_str,
            references=label_str,
            use_stemmer=True
        )
        
        # Compute BERTScore in batches
        P, R, F1 = bert_scorer.score(
            pred_str, 
            label_str,
            batch_size=4  # Small batch for BERTScore
        )
        
        metrics = {
            'bleu': bleu_score,
            'rouge1': rouge_scores['rouge1'],
            'rouge2': rouge_scores['rouge2'],
            'rougeL': rouge_scores['rougeL'],
            'bertscore_precision': P.mean().item(),
            'bertscore_recall': R.mean().item(),
            'bertscore_f1': F1.mean().item(),
        }
        
        torch.cuda.empty_cache()
        return metrics

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=DataCollatorWithPadding(tokenizer, padding=False),
    compute_metrics=compute_metrics
)

  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [25]:
# === Evaluate perplexity ===
eval_result = trainer.evaluate(test_set)

In [26]:
# Print all results
print("\nEvaluation Metrics:")
print(f"Loss: {eval_result['eval_loss']:.4f}")
print(f"Perplexity: {torch.exp(torch.tensor(eval_result['eval_loss'])):.2f}")
print(f"BLEU: {eval_result['eval_bleu']:.4f}")
print(f"ROUGE-1: {eval_result['eval_rouge1']:.4f}")
print(f"ROUGE-2: {eval_result['eval_rouge2']:.4f}")
print(f"ROUGE-L: {eval_result['eval_rougeL']:.4f}")
print(f"BERTscore precision: {eval_result['eval_bertscore_precision']:.4f}")
print(f"BERTscore recall: {eval_result['eval_bertscore_recall']:.4f}")
print(f"BERTscore f1: {eval_result['eval_bertscore_f1']:.4f}")


Evaluation Metrics:
Loss: 1.6728
Perplexity: 5.33
BLEU: 0.0476
ROUGE-1: 0.2373
ROUGE-2: 0.1100
ROUGE-L: 0.1902
BERTscore precision: -0.0313
BERTscore recall: 0.3024
BERTscore f1: 0.0862


In [27]:
wandb.log({
    "eval_loss": eval_result['eval_loss'], 
    "perplexity": torch.exp(torch.tensor(eval_result['eval_loss'])),
    "BLUE": eval_result['eval_bleu'],
    "ROUGE_1": eval_result['eval_rouge1'],
    "ROUGE_2": eval_result['eval_rouge2'],
    "ROUGE_L": eval_result['eval_rougeL'],
    "BERTscore_precision": eval_result['eval_bertscore_precision'],
    "BERTscore recall": eval_result['eval_bertscore_recall'],
    "BERTscore f1": eval_result['eval_bertscore_f1']
})

# Inference

In [None]:
samples = test_set.select(range(433, 439))  # First 5 examples
input_ids = torch.tensor(samples["input_ids"]).to(model.device)
attention_mask = torch.tensor(samples["attention_mask"]).to(model.device)

generated_ids = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_new_tokens=128,
    do_sample=False,
    use_cache=False
)

generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

# Log predictions to W&B
wandb_table = wandb.Table(columns=["Title", "Actual Abstract", "Generated Text"])
for i, gen in enumerate(generated_texts):
    title = samples[i]["title"]
    actual = samples[i]["abstract"]
    print(f"\nActual: {title}\n{actual}\n---\nGenerated: {gen}\n")
    wandb_table.add_data(title, actual, gen)




In [None]:
wandb.log({"generated_examples": wandb_table})

In [None]:
def generate(model, text, max_new_tokens=128):
    sample = tokenizer(text + tokenizer.eos_token, truncation=True, padding="max_length", max_length=max_len, return_attention_mask=True)
    input_ids = torch.tensor([sample["input_ids"]]).to(model.device)
    attention_mask = torch.tensor([sample["attention_mask"]]).to(model.device)
    
    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        num_beams=1,
        do_sample=False,
        use_cache=False
    )
    
    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return generated_texts[0]
    # # Log predictions to W&B
    # for i, gen in enumerate(generated_texts):
    #     title = samples[i]["title"]
    #     actual = samples[i]["abstract"]
    #     print(f"\nTitle: {title}\n---\nActual Abstract: {actual}\n---\nGenerated: {gen}\n")
    #     wandb_table.add_data(title, actual, gen)
    
    
    # wandb.log({"generated_examples": wandb_table})

In [7]:
def stream_generate(model, tokenizer, text, max_new_tokens=300):
    model.eval()
    sample = tokenizer(
        text + tokenizer.eos_token,
        return_tensors="pt",
        truncation=True,
        max_length=max_len
    ).to(model.device)

    input_ids = sample["input_ids"]
    generated = input_ids.clone()
    past_key_values = None
    position_ids = torch.arange(0, input_ids.shape[1], device=model.device).unsqueeze(0)

    prev_decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

    for i in range(max_new_tokens):
        if i == 0:
            input_token = input_ids
        else:
            input_token = next_token_id
            position_ids = torch.tensor([[generated.shape[1] - 1]], device=model.device)

        with torch.no_grad():
            outputs = model(
                input_ids=input_token,
                past_key_values=past_key_values,
                use_cache=True,
                position_ids=position_ids
            )

        logits = outputs.logits[:, -1, :]
        next_token_id = torch.argmax(logits, dim=-1, keepdim=True)

        generated = torch.cat((generated, next_token_id), dim=1)
        past_key_values = outputs.past_key_values

        # Decode full sequence and compute the diff
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
        new_text = decoded[len(prev_decoded):]
        prev_decoded = decoded

        yield new_text

        if next_token_id.squeeze().item() == tokenizer.eos_token_id:
            break


In [11]:
import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering. """
    logits = logits.clone()

    # Top-K filtering
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # Safety check
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    # Top-P (nucleus) filtering
    if top_p > 0.0 and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above top_p
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to keep at least one token
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[..., indices_to_remove] = filter_value

    return logits

def stream_generate(model, tokenizer, text, max_new_tokens=300, temperature=1.0, top_k=0, top_p=0.0, max_len=512):
    model.eval()
    sample = tokenizer(
        text + tokenizer.eos_token,
        return_tensors="pt",
        truncation=True,
        max_length=max_len
    ).to(model.device)

    input_ids = sample["input_ids"]
    generated = input_ids.clone()
    past_key_values = None
    position_ids = torch.arange(0, input_ids.shape[1], device=model.device).unsqueeze(0)

    prev_decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

    for i in range(max_new_tokens):
        if i == 0:
            input_token = input_ids
        else:
            input_token = next_token_id
            position_ids = torch.tensor([[generated.shape[1] - 1]], device=model.device)

        with torch.no_grad():
            outputs = model(
                input_ids=input_token,
                past_key_values=past_key_values,
                use_cache=True,
                position_ids=position_ids
            )

        logits = outputs.logits[:, -1, :] / temperature
        logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
        probabilities = F.softmax(logits, dim=-1)
        next_token_id = torch.multinomial(probabilities, num_samples=1)

        generated = torch.cat((generated, next_token_id), dim=1)
        past_key_values = outputs.past_key_values

        # Decode full sequence and compute the diff
        decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
        new_text = decoded[len(prev_decoded):]
        prev_decoded = decoded

        yield new_text

        if next_token_id.squeeze().item() == tokenizer.eos_token_id:
            break


In [None]:
prompt_template = """
# Instruction:
Assume you are an excellent doctor. Using your knowledge, answer the question given below.

# Question: {question}

# Answer: """
prompt_template = prompt_template.strip()
print(prompt_template)

In [12]:
examples = [
    "What is Glaucoma ?",
    "What are the symptoms of Glaucoma ??",
    "My sister is on Xanax, feyntnol patch and a pain medicine for cancer.  She has been on 25 of fentynol and within 6 days she has been bumped up to 100 now she is almost lethargic and breathing is really labored and right arm is twitching.. She was carrying on conversation Sunday and Monday patch was put on Tuesday and now cant even sit up..no one seems worried but me.. Just wondering what I could do",
    "I was playing basketball the other night and went up to block a shot and flipped over the guy and landed on my side/back. Since then the lower left side of back/side have been sore, hurts when I take deep breaths and when I lay on my back, any chance of a bruised kidney or any serious injury I could have?",
    "What are the treatments for High Blood Pressure ?",
    "What is (are) Urinary Tract Infections ?",
    "Create a C++ function that computes the Fast Fourier Transform (FFT) of a signal",
    "When did Beyonce start becoming popular?",
    "What are the symptoms of diabetes?"
]

In [13]:
model_id

'microsoft/Phi-3.5-mini-instruct'

In [14]:
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    # quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=False,
    cache_dir="../model_cache"
)

Fetching 2 files: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:24<00:00, 12.14s/it]
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.94it/s]


In [18]:
prompt_template.format(question=examples[4])

'# Instruction:\nAssume you are an excellent doctor. Using your knowledge, answer the question given below.\n\n# Question: What are the treatments for High Blood Pressure ?\n\n# Answer:'

In [22]:
for token in stream_generate(model, tokenizer, prompt_template.format(question=examples[4])):
    print(token, end='', flush=True)

 The treatment of hypertension depends on the underlying cause. Treatment may include lifestyle changes, medications, or surgery.

In [22]:
for token in stream_generate(model, tokenizer, prompt_template.format(question=examples[-3]), temperature=0.2, top_k=50, top_p=0.9):
    print(token, end='', flush=True)

 <reponame>ChatDoctor/ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-ChatDoctor-

In [79]:
wandb_table = wandb.Table(columns=["Question", "Generated answer"])

for example in examples:
    generated_answer = ""
    for token in stream_generate(model, tokenizer, prompt_template.format(question=example)):
        generated_answer += token
    wandb_table.add_data(example, generated_answer)
    print(example, "\n", generated_answer)

What is Glaucoma ? 
  <reponame>ChatGPT/OpenAI/ChatGPT-India-En-IN-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1-1
What are the symptoms of Glaucoma ?? 
  The following are symptoms of glaucoma:

- Loss of peripheral vision
- Blurred vision
- Halos around lights
- Eye pain
- Nausea and vomiting
- Redness of the eye

If you have any of these symptoms, see your eye doctor right away.
My sister is on Xanax, feyntnol patch and a pain medicine for cancer.  She has been on 25 of fentynol and within 6 days she has been bumped up to 100 now she is almost lethargic and breathing is really labored and right arm is twitching.. She was carrying on conversation Sunday and Monday patch was put on Tuesday and now cant even sit up..no one seems worried but me.. Just wondering w

In [80]:
wandb.log({"generated_examples": wandb_table})

In [None]:
test_df = pd.read_csv(test_data_path)

In [69]:
train_df = pd.read_csv("../data/pubmed_baseline/pubmed_train.csv")

In [70]:
train_df.head()

Unnamed: 0,pmid,title,abstract
0,2712356,Membrane relationships in murine Meissner corp...,Mechanoreceptive sensory corpuscles (murine Me...
1,8979397,Two-dimensional protein patterns of Arabidopsi...,In order to detect gene products involved in A...
2,1462207,Pathoanatomy of lumbar disc herniation as demo...,Computed tomography/discography was performed ...
3,1807731,An innovative method of teaching Advanced Card...,A demonstration and discussion of the effectiv...
4,3207648,The distribution of CA 125 in the reproductive...,Investigation of serum and tissue homogenates ...


In [72]:
def f(x):
    if not isinstance(x, str):
        print(x)
        return
    if "" in x: 
        print(x)
train_df["abstract"].apply(f)

0         None
1         None
2         None
3         None
4         None
          ... 
999995    None
999996    None
999997    None
999998    None
999999    None
Name: abstract, Length: 1000000, dtype: object

In [42]:
modules = set()
for name, module in model.named_modules():
    if "attn" in name or "proj" in name:
        modules.add(name.split(".")[-1])
modules

{'base_layer',
 'default',
 'down_proj',
 'gate_up_proj',
 'lora_A',
 'lora_B',
 'lora_dropout',
 'lora_embedding_A',
 'lora_embedding_B',
 'lora_magnitude_vector',
 'o_proj',
 'qkv_proj',
 'resid_attn_dropout',
 'self_attn'}