# Fine-Tuning Prompt Ablations on Base Qwen

In [1]:
# add .. path 
import os
import sys
sys.path.append('../../')
import utils.llm_training as llm_training
import utils.llm_configs as llm_configs
import wandb
import logging
import re
from tqdm import tqdm
import numpy as np
from datasets import Dataset
import pandas as pd
import argparse
from sklearn.metrics import roc_auc_score

from importlib import reload

# Re-import the project helper modules so edits under utils/ take effect without
# restarting the kernel. Order is preserved: llm_training first, then llm_configs.
for _module in (llm_training, llm_configs):
    reload(_module)

<module 'utils.llm_configs' from '/home/josephL/fine-tuning-or-retrieval/scripts/MEDEX/../../utils/llm_configs.py'>

In [2]:
# IPython magic: set WANDB_QUIET=false for this kernel so wandb's console messages
# (e.g. the login banner in the output below) are not suppressed.
%env WANDB_QUIET=false

env: WANDB_QUIET=false


In [None]:
# --- Basic Configuration ---
dataset = "AMES"  # TDC dataset directory name under data/TDC/
metric = "auroc"  # metric used by the evaluation cells: "auroc" or "accuracy"
model_name = "jiosephlee/therapeutic_fine_tuning_1M_v2"
# model_name = "Qwen/Qwen2.5-0.5B"  # alternative; kept above run_name so run_name reflects it
run_name = f"{dataset}_fine_tuning/{model_name}"
_METHOD = 'text'  # training format passed to transform_df: 'text' or 'completion'

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - [%(name)s] - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger(__name__)

run = wandb.init(
    project="medex_fine_tuning",
    name=run_name,
    # Tag with the configured dataset rather than a hard-coded name so runs on
    # other datasets are labelled correctly.
    tags=["Ablation", dataset],
    group="CPT prompt ablation",
)

# --- Load Data and Preprocess---
data_dir = f'./../../data/TDC/{dataset}'
train_df = pd.read_csv(f'{data_dir}/train_df.csv')
val_df = pd.read_csv(f'{data_dir}/val_df.csv')
test_df = pd.read_csv(f'{data_dir}/test_df.csv')

# NOTE(review): `dataset` is re-assigned after the CSVs are read, presumably so it
# can be switched to a prompt-template variant ("AMES_1", "AMES_2", "AMES_3") for
# row_to_text without changing the data path. run_name and the wandb tags above
# still reflect the value set at the top of this cell.
dataset = "AMES"

def row_to_text(row, split='train', dataset='AMES_1'):
    """Render one row as model input text, appending the gold answer for training.

    Args:
        row: A dataframe row (or mapping) with a ``'Drug'`` SMILES string and,
            when ``split == 'train'`` for the AMES templates, a binary ``'Y'``
            label where 1 means mutagenic.
        split: ``'train'`` appends the answer. Any other value (e.g. ``'val'``,
            ``'test'``) returns only the prompt, ending at ``"Answer:"``.
        dataset: Prompt-template key:
            ``'AMES'`` / ``'AMES_1'``: SMILES line + question, full-sentence answer.
            ``'AMES_2'``: same prompt, short ``" Yes."`` / ``" No."`` answer.
            ``'AMES_3'``: SMILES inlined in the question, full-sentence answer.
            ``'Skin Reaction'``: Q/A prompt; no answer is ever appended.

    Returns:
        The formatted text.

    Raises:
        ValueError: If ``dataset`` is not one of the keys above.
    """
    full_answers = (' Yes, the drug is mutagenic.', ' No, the drug is not mutagenic.')

    # One if/elif chain keeps the branches mutually exclusive and routes unknown
    # keys to an explicit error instead of an unbound `text`.
    if dataset in ('AMES', 'AMES_1'):
        text = f"SMILES: {row['Drug']}\nQuestion: Is the drug represented by this SMILES string mutagenic?\nAnswer:"
        answers = full_answers
    elif dataset == 'AMES_2':
        text = f"SMILES: {row['Drug']}\nQuestion: Is the drug represented by this SMILES string mutagenic?\nAnswer:"
        answers = (' Yes.', ' No.')
    elif dataset == 'AMES_3':
        text = f"Question: Is the drug represented by this SMILES string, {row['Drug']}, mutagenic?\nAnswer:"
        answers = full_answers
    elif dataset == 'Skin Reaction':
        # NOTE(review): no label is appended even for split='train', so this
        # template cannot drive supervised training as written.
        return f"Q: This is the SMILES string of the drug: {row['Drug']}. Can this drug cause skin reaction?\nA: "
    else:
        raise ValueError(
            f"Unknown dataset/prompt template {dataset!r}; expected one of "
            "'AMES', 'AMES_1', 'AMES_2', 'AMES_3', 'Skin Reaction'."
        )

    if split == 'train':
        positive, negative = answers
        text += positive if row['Y'] == 1 else negative
    return text

def row_to_prompt(row, dataset='AMES'):
    """Return the answer-free prompt for one row (prompt/completion training format).

    For every key, the prompt is the same text row_to_text produces for a
    non-'train' split, so prompt/completion runs and plain-text runs see the
    same input.

    Args:
        row: A dataframe row (or mapping) with a ``'Drug'`` SMILES string.
        dataset: Prompt-template key: ``'AMES'``, ``'AMES_1'``, ``'AMES_2'``,
            ``'AMES_3'`` or ``'Skin Reaction'``.

    Returns:
        The prompt string.

    Raises:
        ValueError: If ``dataset`` is not a known template key.
    """
    if dataset in ('AMES', 'AMES_1', 'AMES_2'):
        return f"SMILES: {row['Drug']}\nQuestion: Is the drug represented by this SMILES string mutagenic?\nAnswer:"
    if dataset == 'AMES_3':
        return f"Question: Is the drug represented by this SMILES string, {row['Drug']}, mutagenic?\nAnswer:"
    if dataset == 'Skin Reaction':
        return f"Q: This is the SMILES string of the drug: {row['Drug']}. Can this drug cause skin reaction?\nA: "
    raise ValueError(
        f"Unknown dataset/prompt template {dataset!r}; expected one of "
        "'AMES', 'AMES_1', 'AMES_2', 'AMES_3', 'Skin Reaction'."
    )

def row_to_completion(row, dataset='AMES'):
    """Return the gold answer that follows row_to_prompt's prompt for one row.

    The answers match those row_to_text appends for ``split='train'``.

    Args:
        row: A dataframe row (or mapping) with a binary ``'Y'`` label
            (1 = mutagenic) for the AMES templates, or a ``'Drug'`` SMILES
            string for ``'Skin Reaction'``.
        dataset: Template key: ``'AMES'``, ``'AMES_1'``, ``'AMES_2'``,
            ``'AMES_3'`` or ``'Skin Reaction'``.

    Returns:
        The completion string (with a leading space for the AMES templates).

    Raises:
        ValueError: If ``dataset`` is not a known template key.
    """
    if dataset in ('AMES', 'AMES_1', 'AMES_3'):
        # The answer must follow the label; a constant answer would teach the
        # model that every molecule is mutagenic.
        return ' Yes, the drug is mutagenic.' if row['Y'] == 1 else ' No, the drug is not mutagenic.'
    if dataset == 'AMES_2':
        return ' Yes.' if row['Y'] == 1 else ' No.'
    if dataset == 'Skin Reaction':
        # TODO(review): this returns a question, not an answer, so it is not a
        # usable completion. Left unchanged until the label format is decided.
        return f"Question: This is the SMILES string of the drug: {row['Drug']}. Can this drug cause skin reaction?\nA: "
    raise ValueError(
        f"Unknown dataset/prompt template {dataset!r}; expected one of "
        "'AMES', 'AMES_1', 'AMES_2', 'AMES_3', 'Skin Reaction'."
    )

def transform_df(train_df, val_df, test_df, dataset, method='text'):
    """Add model-input columns to the three splits, modifying them in place.

    Args:
        train_df: Training split. With ``method='text'`` it gains a ``"text"``
            column (prompt + answer). With ``method='completion'`` it gains
            ``"prompt"`` and ``"completion"`` columns.
        val_df: Validation split; always gains an answer-free ``"text"`` column.
        test_df: Test split; always gains an answer-free ``"text"`` column.
        dataset: Prompt-template key forwarded to the row_to_* formatters.
        method: ``'text'`` or ``'completion'``.

    Raises:
        ValueError: If ``method`` is not ``'text'`` or ``'completion'``. This is
            checked before any split is modified.
    """
    if method == 'text':
        train_df["text"] = train_df.apply(row_to_text, axis=1, split='train', dataset=dataset)
    elif method == 'completion':
        train_df["prompt"] = train_df.apply(row_to_prompt, axis=1, dataset=dataset)
        train_df["completion"] = train_df.apply(row_to_completion, axis=1, dataset=dataset)
    else:
        # Fail loudly: a typo here would otherwise leave train_df without any
        # model-input column and only fail much later inside the trainer.
        raise ValueError(f"Unknown method {method!r}; expected 'text' or 'completion'.")

    # Evaluation splits always use answer-free prompts in a "text" column,
    # whatever the training format is.
    val_df["text"] = val_df.apply(row_to_text, axis=1, split='val', dataset=dataset)
    test_df["text"] = test_df.apply(row_to_text, axis=1, split='test', dataset=dataset)

transform_df(train_df, val_df, test_df, dataset, method=_METHOD)

# Columns the trainer/evaluation may use; whichever of these exist are kept.
# A list (not a set) matches select_columns' documented str | list[str] argument
# and keeps the column order deterministic across runs.
_KEEP_COLUMNS = ["text", "Y", "prompt", "completion"]


def _to_dataset(df):
    """Convert one split to a HF Dataset restricted to the columns in _KEEP_COLUMNS."""
    ds = Dataset.from_pandas(df, preserve_index=False)
    return ds.select_columns([col for col in _KEEP_COLUMNS if col in ds.column_names])


training_ds = _to_dataset(train_df)
val_ds = _to_dataset(val_df)
test_ds = _to_dataset(test_df)

log.info(f"Training dataset example: {training_ds[0]}")
log.info(f"Validation dataset example: {val_ds[0]}")
log.info(f"Test dataset example: {test_ds[0]}")

# --- Load Model ---
model_config = llm_configs.ModelConfig(
    id=model_name,
    peft=llm_configs.PeftConfig(
        enabled=False,  # PEFT off: all model weights are trained
        add_eot_token=False,  # No longer doing EOT token for LIMA
    ),
    # Not QLoRA: PEFT is disabled above. mode=None presumably means no quantization.
    quantization=llm_configs.QuantizationConfig(mode=None),
)

log.info("--- Model Configuration ---")
log.info(model_config.model_dump_json(indent=2))

log.info("\n--- Loading Model for Training ---\n")
model, tokenizer = llm_training.load_model_for_training(model_config, log)

lima_training_config = llm_configs.TrainingConfig(
    run_name = run_name,
    num_train_epochs = 10,
    learning_rate  = 8e-5,
    logging_strategy = "steps",
    logging_steps = 1,
    gradient_checkpointing=False,
    context_length = 4096,
    use_liger_kernel=True,
    per_device_train_batch_size = 128,
    gradient_accumulation_steps=1,
    warmup_steps  = 0, # If 0, it does not override warmup ratio
    warmup_ratio = 0.1, # Use our default warmup ratio instead
    packing = False,
    padding_free = True,
    # NOTE(review): with _METHOD == 'text' the dataset has no prompt/completion
    # split, so this flag presumably has no effect (loss on the full text) — confirm.
    completion_only_loss=True,
    sequential_sampling = False,
    reverse_ffd_packing= False,
    remove_unused_columns=False,
)

lima_training_config.push_to_wandb(run)



[34m[1mwandb[0m: Currently logged in as: [33mjiosephlee[0m ([33mupenn-ml[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


2025-07-14 19:34:56 - INFO - [__main__] - Training dataset example: {'text': 'SMILES: Nc1cccc([N+](=O)[O-])c1CO\nQuestion: Is the drug represented by this SMILES string mutagenic?\nAnswer: No, the drug is not mutagenic.', 'Y': 0}
2025-07-14 19:34:56 - INFO - [__main__] - Validation dataset example: {'text': 'SMILES: O=[N+]([O-])c1ccc(-c2nc3n(c2[N+](=O)[O-])CCS3)cc1\nQuestion: Is the drug represented by this SMILES string mutagenic?\nAnswer:', 'Y': 1}
2025-07-14 19:34:56 - INFO - [__main__] - Test dataset example: {'text': 'SMILES: CC(=O)Nc1ccc2c(=O)c(=O)c3cccc4ccc1c2c43\nQuestion: Is the drug represented by this SMILES string mutagenic?\nAnswer:', 'Y': 1}
2025-07-14 19:34:56 - INFO - [__main__] - --- Model Configuration ---
2025-07-14 19:34:56 - INFO - [__main__] - {
  "id": "jiosephlee/therapeutic_fine_tuning_1M_v2",
  "torch_dtype": "auto",
  "attn_implementation": "flash_attention_2",
  "peft": {
    "enabled": false,
    "lora_r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,

In [4]:

log.info(f"\n--- Starting {dataset} Fine-Tuning ---")

# Arguments for the SFT run, gathered in one place so the call itself stays short.
sft_kwargs = dict(
    model=model,
    tokenizer=tokenizer,
    log=log,
    train_dataset=training_ds,
    train_cfg=lima_training_config,
    train=True,  # actually run training (the inspection cell below uses False)
    use_liger_loss=True,
)
trainer = llm_training.sft_train_on_dataset(**sft_kwargs)

log.info("\n\n--- Fine-Tuning Complete ---\n\n")
log.info(f"Training arguments: {trainer.args}")

2025-07-14 19:35:10 - INFO - [__main__] - 
--- Starting AMES Fine-Tuning ---
2025-07-14 19:35:10 - INFO - [__main__] - Starting SFT training run...


Adding EOS to train dataset:   0%|          | 0/5094 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/5094 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/5094 [00:00<?, ? examples/s]

2025-07-14 19:35:11 - INFO - [liger_kernel.transformers.monkey_patch] - Applying Liger kernels to model instance with model type: qwen2 with kwargs: {}


Applied Liger kernels to Qwen2


Step,Training Loss
1,2.5932
2,2.581
3,2.4622
4,2.0652
5,1.8544
6,1.5229
7,1.276
8,1.0308
9,0.7831
10,0.6631


2025-07-14 19:47:51 - INFO - [__main__] - SFT training complete.
[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
train/epoch,▁▁▁▁▁▁▁▁▂▂▂▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇█
train/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇█████
train/grad_norm,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate,▂▃▃▄▄▆▇████████▇▇▇▇▇▆▆▆▆▅▅▅▄▄▄▃▃▃▃▃▁▁▁▁▁
train/loss,█▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/num_tokens,▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇█

0,1
total_flos,5538673706142720.0
train/epoch,10.0
train/global_step,400.0
train/grad_norm,1.50781
train/learning_rate,0.0
train/loss,0.1797
train/num_tokens,2579260.0
train_loss,0.32779
train_runtime,760.5267
train_samples_per_second,66.98


2025-07-14 19:47:52 - INFO - [__main__] - 

--- Fine-Tuning Complete ---


2025-07-14 19:47:52 - INFO - [__main__] - Training arguments: SFTConfig(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
activation_offloading=False,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
assistant_only_loss=False,
auto_find_batch_size=False,
average_tokens_across_devices=False,
batch_eval_metrics=False,
bf16=True,
bf16_full_eval=False,
chat_template_path=None,
completion_only_loss=True,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
dataset_kwargs=None,
dataset_num_proc=None,
dataset_text_field=text,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_p

## Check if padding_free without packing is valid

I recall that with packing enabled, 10 epochs took far fewer optimizer steps (around 20?). Let's double-check this. (For reference, the padding-free, unpacked run above took 400 steps: 10 epochs × ⌈5094 / 128⌉ = 40 steps per epoch.) How many questions fit in a single batch? Without packing it's simple: the batch size determines it. What about with packing?

Are the packed and unpacked batches the same shape, just longer? If so, why is there an attention mask in only one of the two cases?

In [None]:
# Build a throwaway trainer (train=False) only to inspect what the collator
# produces. Separate names keep this cell from clobbering `lima_training_config`
# and `trainer` from the real training run.
# NOTE(review): unlike the training config, completion_only_loss is not set here.
debug_training_config = llm_configs.TrainingConfig(
    run_name = f"{dataset} fine-tuning with {model_name}",
    num_train_epochs = 10,
    learning_rate  = 8e-5,
    logging_strategy = "steps",
    logging_steps = 1,
    gradient_checkpointing=False,
    context_length = 4096,
    use_liger_kernel=True,
    per_device_train_batch_size = 128,
    gradient_accumulation_steps=1,
    warmup_steps  = 0, # If 0, it does not override warmup ratio
    warmup_ratio = 0.1, # Use our default warmup ratio instead
    packing = False,
    padding_free = True,
    sequential_sampling = False,
    reverse_ffd_packing= False,
    remove_unused_columns=False,
)

debug_trainer = llm_training.sft_train_on_dataset(
    model=model,
    tokenizer=tokenizer,
    log=log,
    train_dataset=training_ds,
    train_cfg=debug_training_config,
    train=False,
    use_liger_loss = True
)

# Only the first batch is inspected, so stop after it instead of walking the
# whole dataloader.
first_batch = next(iter(debug_trainer.get_train_dataloader()))
print(first_batch)
# position_ids restart at 0 at the start of each concatenated example (the output
# shows 128 restarts for batch size 128), so counting zeros in the row gives the
# number of examples in it.
num_sequences = int((first_batch['position_ids'][0] == 0).sum().item())
print(num_sequences)
print(tokenizer.decode(first_batch['input_ids'][0].cpu().numpy()))

{'input_ids': tensor([[  9501,  45978,     25,  ...,    292,     13, 151643]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], device='cuda:0'), 'position_ids': tensor([[ 0,  1,  2,  ..., 50, 51, 52]], device='cuda:0'), 'labels': tensor([[  9501,  45978,     25,  ...,    292,     13, 151643]],
       device='cuda:0')}
128
SMILES: C#C[C@]1(OC(C)=O)CC[C@H]2[C@@H]3CCC4=CC(=O)CC[C@@H]4[C@H]3CC[C@@]21C
Question: Is the drug represented by this SMILES string mutagenic?
Answer: No, the drug is not mutagenic.<|endoftext|>SMILES: Cc1ccc(C=O)cc1
Question: Is the drug represented by this SMILES string mutagenic?
Answer: No, the drug is not mutagenic.<|endoftext|>SMILES: CCC(C)(C)C
Question: Is the drug represented by this SMILES string mutagenic?
Answer: No, the drug is not mutagenic.<|endoftext|>SMILES: CCOC(=O)c1[nH]c2ccccc2c1/N=C/c1ccc(O)cc1
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug is mutagenic.<|endoftext|>SMILES: B

## Eval on Training and Validation Set

In [5]:

# dataset='AMES_2'
# def row_to_text( row, split='train', dataset='AMES_1'):
#     if dataset == 'AMES_1':
#         text = f"SMILES: {row['Drug']}\nQuestion: Is the drug represented by this SMILES string mutagenic?\nAnswer:"
#         if split == 'train':
#             text += f"{' Yes, the drug is mutagenic.' if row['Y']==1 else ' No, the drug is not mutagenic.'}"
#     if dataset == 'AMES_2':
#         text = f"SMILES: {row['Drug']}\nQuestion: Is the drug represented by this SMILES string mutagenic?\nAnswer:"
#         if split == 'train':
#             text += f"{' Yes.' if row['Y']==1 else ' No.'}"
#     if dataset == 'AMES_3':
#         text = f"Question: Is the drug represented by this SMILES string, {row['Drug']}, mutagenic?\nAnswer:"
#         if split == 'train':
#             text += f"{' Yes.' if row['Y']==1 else ' No.'}"
#     elif dataset == 'Skin Reaction':
#         text = f"Q: This is the SMILES string of the drug: {row['Drug']}. Can this drug cause skin reaction?\nA: "

#     return text

# Re-render the training rows WITHOUT answers (split='test') so the model can be
# evaluated on its own training data.
# WARNING: this overwrites train_df["text"] and training_ds; re-run the data-prep
# cell before training again, or the trainer will see answer-free prompts.
train_df["text"] = train_df.apply(
    lambda r: row_to_text(r, split='test', dataset=dataset), axis=1
)
_prompt_only_ds = Dataset.from_pandas(train_df, preserve_index=False)
training_ds = _prompt_only_ds.select_columns(
    {"text", "Y", "prompt", "completion"} & set(_prompt_only_ds.column_names)
)

# val_df["text"] = val_df.apply(row_to_text, axis=1, split = 'val', dataset = dataset)
# val_ds = Dataset.from_pandas(val_df, preserve_index=False)
# val_ds = val_ds.select_columns(
#                     {"text", "Y", "prompt", "completion"}.intersection(val_ds.column_names)
#                 )

In [None]:
# --- Evaluate on the (answer-stripped) training set ---
if metric not in ("accuracy", "auroc"):
    # Check before the long inference loop rather than silently printing nothing at the end.
    raise ValueError(f"Unsupported metric {metric!r}; expected 'accuracy' or 'auroc'.")

inference_cfg = llm_configs.InferenceConfig(
    temperature=0,
    do_sample=False,
    repetition_penalty=1.0,
    max_new_tokens=64,
)

# Leading "yes"/"no" word of the lower-cased answer (after optional space/punctuation).
# A plain substring test for "no" also fires on words such as "not", "none" or "unknown".
answer_pattern = re.compile(r"\W*(yes|no)\b")

targets, preds = [], []

for i in tqdm(range(len(training_ds)), desc="Inference on training set"):
    row = training_ds[i]
    prompt = row["text"]
    gt_label = int(row["Y"])
    gt_answer = "yes" if gt_label == 1 else "no"

    gen_text = llm_training.generate_text(model, tokenizer, prompt, inference_cfg)

    # The printed samples show gen_text = prompt + continuation; slice off the prompt.
    generated_response = gen_text[len(prompt):].strip().lower()

    if i < 10:
        print(f"Prompt: {prompt}\n")
        print(f"Generated response: {gen_text}")
        print(f"GT answer: {gt_answer}")
        print("-"*100)
        # print(llm_training.analyze_text_generation(model, tokenizer, prompt, 'cuda', 4))

    match = answer_pattern.match(generated_response)
    if match:
        pred_answer = 1 if match.group(1) == "yes" else 0
    else:
        # No leading yes/no: fall back to comparing the first-step scores of "Yes" vs "No".
        probs = llm_training.extract_logits_first_step(model, tokenizer, prompt, ["Yes", "No"])
        pred_answer = int(probs["Yes"] > probs["No"])

    # Integer labels, so they compare directly with the integer predictions
    # (string "yes"/"no" targets made the accuracy metric always 0).
    targets.append(gt_label)
    preds.append(pred_answer)

# ------------------
# Compute metric
# ------------------
targets = np.array(targets)
preds = np.array(preds)

if metric == "accuracy":
    accuracy = np.mean(targets == preds)
    print(f"\nAccuracy on {len(targets)} examples: {accuracy:.4f}")
elif metric == "auroc":
    # NOTE: preds are hard 0/1 labels, so this AUROC equals balanced accuracy.
    # A ranking AUROC would need a continuous score such as P("Yes").
    auroc = roc_auc_score(targets, preds)
    print(f"\nAUROC on {len(targets)} examples: {auroc:.4f}")

# Save model before we LIMA tune
#model.push_to_hub('jiosephlee/therapeutic_fine_tuning_36M')
#tokenizer.push_to_hub('jiosephlee/therapeutic_fine_tuning_36M')

Inference on test set:   0%|          | 1/5094 [00:00<25:00,  3.39it/s]

Prompt: SMILES: Nc1cccc([N+](=O)[O-])c1CO
Question: Is the drug represented by this SMILES string mutagenic?
Answer:

Generated response: SMILES: Nc1cccc([N+](=O)[O-])c1CO
Question: Is the drug represented by this SMILES string mutagenic?
Answer: No, the drug is not mutagenic.<|endoftext|>
GT answer: no
----------------------------------------------------------------------------------------------------


Inference on test set:   0%|          | 2/5094 [00:00<22:51,  3.71it/s]

Prompt: SMILES: O=[N+]([O-])c1cccc(O)c1[N+](=O)[O-]
Question: Is the drug represented by this SMILES string mutagenic?
Answer:

Generated response: SMILES: O=[N+]([O-])c1cccc(O)c1[N+](=O)[O-]
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug is mutagenic.<|endoftext|>
GT answer: yes
----------------------------------------------------------------------------------------------------


Inference on test set:   0%|          | 3/5094 [00:00<20:55,  4.05it/s]

Prompt: SMILES: Cc1cc(Cl)ccc1O
Question: Is the drug represented by this SMILES string mutagenic?
Answer:

Generated response: SMILES: Cc1cc(Cl)ccc1O
Question: Is the drug represented by this SMILES string mutagenic?
Answer: No, the drug is not mutagenic.<|endoftext|>
GT answer: no
----------------------------------------------------------------------------------------------------


Inference on test set:   0%|          | 4/5094 [00:01<21:05,  4.02it/s]

Prompt: SMILES: CNC(=O)Oc1ccccc1OC(C)C
Question: Is the drug represented by this SMILES string mutagenic?
Answer:

Generated response: SMILES: CNC(=O)Oc1ccccc1OC(C)C
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug is mutagenic.<|endoftext|>
GT answer: yes
----------------------------------------------------------------------------------------------------


Inference on test set:   0%|          | 5/5094 [00:01<20:19,  4.17it/s]

Prompt: SMILES: Cc1cc([N+](=O)[O-])cc([N+](=O)[O-])c1C
Question: Is the drug represented by this SMILES string mutagenic?
Answer:

Generated response: SMILES: Cc1cc([N+](=O)[O-])cc([N+](=O)[O-])c1C
Question: Is the drug represented by this SMILES string mutagenic?
Answer: No, the drug is not mutagenic.<|endoftext|>
GT answer: no
----------------------------------------------------------------------------------------------------


Inference on test set:   0%|          | 6/5094 [00:01<20:44,  4.09it/s]

Prompt: SMILES: Nc1ccc(Cl)cc1Cl
Question: Is the drug represented by this SMILES string mutagenic?
Answer:

Generated response: SMILES: Nc1ccc(Cl)cc1Cl
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug is mutagenic.<|endoftext|>
GT answer: yes
----------------------------------------------------------------------------------------------------


Inference on test set:   0%|          | 7/5094 [00:01<20:09,  4.21it/s]

Prompt: SMILES: Cc1cccc([N+](=O)[O-])c1C
Question: Is the drug represented by this SMILES string mutagenic?
Answer:

Generated response: SMILES: Cc1cccc([N+](=O)[O-])c1C
Question: Is the drug represented by this SMILES string mutagenic?
Answer: No, the drug is not mutagenic.<|endoftext|>
GT answer: yes
----------------------------------------------------------------------------------------------------


Inference on test set:   0%|          | 8/5094 [00:01<20:29,  4.14it/s]

Prompt: SMILES: O=[N+]([O-])c1ccccc1SSC(F)=C(Cl)Cl
Question: Is the drug represented by this SMILES string mutagenic?
Answer:

Generated response: SMILES: O=[N+]([O-])c1ccccc1SSC(F)=C(Cl)Cl
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug is mutagenic.<|endoftext|>
GT answer: yes
----------------------------------------------------------------------------------------------------


Inference on test set:   0%|          | 9/5094 [00:02<19:25,  4.36it/s]

Prompt: SMILES: CN(Cc1ccc(F)cc1)N=O
Question: Is the drug represented by this SMILES string mutagenic?
Answer:

Generated response: SMILES: CN(Cc1ccc(F)cc1)N=O
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug is mutagenic.<|endoftext|>
GT answer: yes
----------------------------------------------------------------------------------------------------


Inference on test set:   0%|          | 10/5094 [00:02<20:31,  4.13it/s]

Prompt: SMILES: Clc1cccc(Cl)c1Cl
Question: Is the drug represented by this SMILES string mutagenic?
Answer:

Generated response: SMILES: Clc1cccc(Cl)c1Cl
Question: Is the drug represented by this SMILES string mutagenic?
Answer: No, the drug is not mutagenic.<|endoftext|>
GT answer: no
----------------------------------------------------------------------------------------------------


Inference on test set:  35%|███▍      | 1772/5094 [06:56<13:21,  4.14it/s]

In [None]:
# --- Evaluate on the validation set ---
if metric not in ("accuracy", "auroc"):
    # Check before the long inference loop rather than silently printing nothing at the end.
    raise ValueError(f"Unsupported metric {metric!r}; expected 'accuracy' or 'auroc'.")

inference_cfg = llm_configs.InferenceConfig(
    temperature=0,
    do_sample=False,
    repetition_penalty=1.0,
    max_new_tokens=64,
)

# Leading "yes"/"no" word of the lower-cased answer (after optional space/punctuation).
# A plain substring test for "no" also fires on words such as "not", "none" or "unknown".
answer_pattern = re.compile(r"\W*(yes|no)\b")

targets, preds = [], []

for i in tqdm(range(len(val_ds)), desc="Inference on validation set"):
    row = val_ds[i]
    prompt = row["text"]
    gt_label = int(row["Y"])
    gt_answer = "yes" if gt_label == 1 else "no"

    gen_text = llm_training.generate_text(model, tokenizer, prompt, inference_cfg)

    # The printed samples show gen_text = prompt + continuation; slice off the prompt.
    generated_response = gen_text[len(prompt):].strip().lower()

    if i < 10:
        # print(f"Prompt: {prompt}")
        print(f"Generated response: {gen_text}")
        print(f"GT answer: {gt_answer}")
        print("-"*100)
    if i == 10:
        # Spot-check one example with a per-token breakdown of the generation.
        print(llm_training.analyze_text_generation(model, tokenizer, prompt, 'cuda', 4))

    match = answer_pattern.match(generated_response)
    if match:
        pred_answer = 1 if match.group(1) == "yes" else 0
    else:
        # No leading yes/no: fall back to comparing the first-step scores of "Yes" vs "No".
        probs = llm_training.extract_logits_first_step(model, tokenizer, prompt, ["Yes", "No"])
        pred_answer = int(probs["Yes"] > probs["No"])

    # Integer labels, so they compare directly with the integer predictions
    # (string "yes"/"no" targets made the accuracy metric always 0).
    targets.append(gt_label)
    preds.append(pred_answer)

# ------------------
# Compute metric
# ------------------
targets = np.array(targets)
preds = np.array(preds)

if metric == "accuracy":
    accuracy = np.mean(targets == preds)
    print(f"\nAccuracy on {len(targets)} examples: {accuracy:.4f}")
elif metric == "auroc":
    # NOTE: preds are hard 0/1 labels, so this AUROC equals balanced accuracy.
    # A ranking AUROC would need a continuous score such as P("Yes").
    auroc = roc_auc_score(targets, preds)
    print(f"\nAUROC on {len(targets)} examples: {auroc:.4f}")

# Save model before we LIMA tune
#model.push_to_hub('jiosephlee/therapeutic_fine_tuning_36M')
#tokenizer.push_to_hub('jiosephlee/therapeutic_fine_tuning_36M')

Inference on validation set:   0%|          | 0/727 [00:00<?, ?it/s]

Generated response: SMILES: O=[N+]([O-])c1ccc(-c2nc3n(c2[N+](=O)[O-])CCS3)cc1
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug is mutagenic.<|endoftext|>
GT answer: yes
----------------------------------------------------------------------------------------------------


Inference on validation set:   0%|          | 1/727 [00:00<06:36,  1.83it/s]

Output: SMILES: O=[N+]([O-])c1ccc(-c2nc3n(c2[N+](=O)[O-])CCS3)cc1
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug

➡️ Generated Token #1: "Yes" (Probability: 90.37%)
   Top 5 candidates for this position:
      1. "Yes" (90.37%)
      2. "No" (9.53%)
      3. "[" (0.06%)
      4. "Yes" (0.01%)
      5. "S" (0.01%)

➡️ Generated Token #2: "," (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "," (100.00%)
      2. "<|endoftext|>" (0.00%)
      3. "." (0.00%)
      4. ".," (0.00%)
      5. "ess" (0.00%)

➡️ Generated Token #3: "the" (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "the" (100.00%)
      2. "the" (0.00%)
      3. "The" (0.00%)
      4. "is" (0.00%)
      5. "a" (0.00%)

➡️ Generated Token #4: "drug" (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "drug" (100.00%)
      2. "Drug" (0.00%)
      3. "Drug" (0.00%)
      4. "medication" (0.00%)
      5. "drug" (0.00%)


Inference on validation set:   0%|          | 2/727 [00:00<05:43,  2.11it/s]

Generated response: SMILES: O=[N+]([O-])c1c(-c2ccc(Cl)cc2)nc2n1CCS2
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug is mutagenic.<|endoftext|>
GT answer: yes
----------------------------------------------------------------------------------------------------
Output: SMILES: O=[N+]([O-])c1c(-c2ccc(Cl)cc2)nc2n1CCS2
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug

➡️ Generated Token #1: "Yes" (Probability: 65.11%)
   Top 5 candidates for this position:
      1. "Yes" (65.11%)
      2. "No" (34.85%)
      3. "Cl" (0.02%)
      4. "[" (0.00%)
      5. "no" (0.00%)

➡️ Generated Token #2: "," (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "," (100.00%)
      2. "<|endoftext|>" (0.00%)
      3. "." (0.00%)
      4. ".," (0.00%)
      5. "ude" (0.00%)

➡️ Generated Token #3: "the" (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "the" (100.00%)
      2. "the" (0.00%)


Inference on validation set:   0%|          | 3/727 [00:01<04:42,  2.57it/s]

Generated response: SMILES: Cc1ccc(-c2nc3n(c2[N+](=O)[O-])CCS3)cc1
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug is mutagenic.<|endoftext|>
GT answer: yes
----------------------------------------------------------------------------------------------------
Output: SMILES: Cc1ccc(-c2nc3n(c2[N+](=O)[O-])CCS3)cc1
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug

➡️ Generated Token #1: "Yes" (Probability: 95.75%)
   Top 5 candidates for this position:
      1. "Yes" (95.75%)
      2. "No" (4.21%)
      3. "Yes" (0.02%)
      4. "[" (0.00%)
      5. "S" (0.00%)

➡️ Generated Token #2: "," (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "," (100.00%)
      2. "<|endoftext|>" (0.00%)
      3. "." (0.00%)
      4. ",'" (0.00%)
      5. ","" (0.00%)

➡️ Generated Token #3: "the" (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "the" (100.00%)
      2. "the" (0.00%)
    

Inference on validation set:   1%|          | 4/727 [00:01<05:14,  2.30it/s]

Output: SMILES: COC(=O)[C@]12O[C@@]1(C)[C@@](C)(O)NC2=O
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug

➡️ Generated Token #1: "Yes" (Probability: 98.19%)
   Top 5 candidates for this position:
      1. "Yes" (98.19%)
      2. "No" (1.80%)
      3. "Yes" (0.01%)
      4. "NC" (0.00%)
      5. "The" (0.00%)

➡️ Generated Token #2: "," (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "," (100.00%)
      2. "<|endoftext|>" (0.00%)
      3. "." (0.00%)
      4. ","" (0.00%)
      5. ",N" (0.00%)

➡️ Generated Token #3: "the" (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "the" (100.00%)
      2. "the" (0.00%)
      3. "The" (0.00%)
      4. "is" (0.00%)
      5. "_the" (0.00%)

➡️ Generated Token #4: "drug" (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "drug" (100.00%)
      2. "Drug" (0.00%)
      3. "Drug" (0.00%)
      4. "drug" (0.00%)
      5. "medication" (0.00%)


Inference on validation set:   1%|          | 5/727 [00:02<05:20,  2.26it/s]

Generated response: SMILES: CC[C@@]12O[C@]1(C(=O)OC)C(=O)N[C@]2(C)O
Question: Is the drug represented by this SMILES string mutagenic?
Answer: No, the drug is not mutagenic.<|endoftext|>
GT answer: yes
----------------------------------------------------------------------------------------------------
Output: SMILES: CC[C@@]12O[C@]1(C(=O)OC)C(=O)N[C@]2(C)O
Question: Is the drug represented by this SMILES string mutagenic?
Answer: No, the drug

➡️ Generated Token #1: "No" (Probability: 77.47%)
   Top 5 candidates for this position:
      1. "No" (77.47%)
      2. "Yes" (22.20%)
      3. "N" (0.28%)
      4. "The" (0.02%)
      5. "In" (0.01%)

➡️ Generated Token #2: "," (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "," (100.00%)
      2. "<|endoftext|>" (0.00%)
      3. "." (0.00%)
      4. ",M" (0.00%)
      5. "," (0.00%)

➡️ Generated Token #3: "the" (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "the" (100.00%)
      2. "the" (0.00%)


Inference on validation set:   1%|          | 6/727 [00:02<05:47,  2.07it/s]

Output: SMILES: COC(=O)C12OC1(C)C(O)(C(C)C)NC2=O
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug

➡️ Generated Token #1: "Yes" (Probability: 97.70%)
   Top 5 candidates for this position:
      1. "Yes" (97.70%)
      2. "No" (2.30%)
      3. "Yes" (0.00%)
      4. "OC" (0.00%)
      5. "N" (0.00%)

➡️ Generated Token #2: "," (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "," (100.00%)
      2. "<|endoftext|>" (0.00%)
      3. "." (0.00%)
      4. ",'" (0.00%)
      5. ","" (0.00%)

➡️ Generated Token #3: "the" (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "the" (100.00%)
      2. "the" (0.00%)
      3. "The" (0.00%)
      4. "_the" (0.00%)
      5. "there" (0.00%)

➡️ Generated Token #4: "drug" (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "drug" (100.00%)
      2. "Drug" (0.00%)
      3. "Drug" (0.00%)
      4. "drug" (0.00%)
      5. "medication" (0.00%)


Inference on validation set:   1%|          | 7/727 [00:03<05:37,  2.13it/s]

Generated response: SMILES: CCC12OC1(C(=O)OC)C(=O)NC2(C)O
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug is mutagenic.<|endoftext|>
GT answer: yes
----------------------------------------------------------------------------------------------------
Output: SMILES: CCC12OC1(C(=O)OC)C(=O)NC2(C)O
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug

➡️ Generated Token #1: "Yes" (Probability: 75.49%)
   Top 5 candidates for this position:
      1. "Yes" (75.49%)
      2. "No" (24.51%)
      3. "Yes" (0.00%)
      4. "YES" (0.00%)
      5. "NO" (0.00%)

➡️ Generated Token #2: "," (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "," (100.00%)
      2. "." (0.00%)
      3. "<|endoftext|>" (0.00%)
      4. ","" (0.00%)
      5. ",'" (0.00%)

➡️ Generated Token #3: "the" (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "the" (100.00%)
      2. "the" (0.00%)
      3. "The" (0.

Inference on validation set:   1%|          | 8/727 [00:03<06:07,  1.96it/s]

➡️ Generated Token #1: "No" (Probability: 59.25%)
   Top 5 candidates for this position:
      1. "No" (59.25%)
      2. "Yes" (40.72%)
      3. "C" (0.01%)
      4. "Based" (0.01%)
      5. "no" (0.00%)

➡️ Generated Token #2: "," (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "," (100.00%)
      2. "<|endoftext|>" (0.00%)
      3. "." (0.00%)
      4. ",O" (0.00%)
      5. "ude" (0.00%)

➡️ Generated Token #3: "the" (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "the" (100.00%)
      2. "The" (0.00%)
      3. "the" (0.00%)
      4. "_the" (0.00%)
      5. "-the" (0.00%)

➡️ Generated Token #4: "drug" (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "drug" (100.00%)
      2. "Drug" (0.00%)
      3. "Drug" (0.00%)
      4. "medication" (0.00%)
      5. "drugs" (0.00%)
Generated response: SMILES: COC(=O)C12OC1(C)C(C)(O)NC2=O
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug is mu

Inference on validation set:   1%|          | 9/727 [00:04<05:58,  2.01it/s]

Output: SMILES: COC(=O)C12OC1(C)C(C)(O)NC2=O
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug

➡️ Generated Token #1: "Yes" (Probability: 96.69%)
   Top 5 candidates for this position:
      1. "Yes" (96.69%)
      2. "No" (3.31%)
      3. "Yes" (0.00%)
      4. "OC" (0.00%)
      5. "NO" (0.00%)

➡️ Generated Token #2: "," (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "," (100.00%)
      2. "<|endoftext|>" (0.00%)
      3. "." (0.00%)
      4. ",N" (0.00%)
      5. ","" (0.00%)

➡️ Generated Token #3: "the" (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "the" (100.00%)
      2. "the" (0.00%)
      3. "The" (0.00%)
      4. "_the" (0.00%)
      5. "the" (0.00%)

➡️ Generated Token #4: "drug" (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "drug" (100.00%)
      2. "Drug" (0.00%)
      3. "drug" (0.00%)
      4. "Drug" (0.00%)
      5. "medication" (0.00%)


Inference on validation set:   1%|▏         | 10/727 [00:04<06:03,  1.97it/s]

Generated response: SMILES: COC(=O)[C@]12O[C@@]1(C)[C@](O)(C(C)C)NC2=O
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug is mutagenic.<|endoftext|>
GT answer: yes
----------------------------------------------------------------------------------------------------
Output: SMILES: COC(=O)[C@]12O[C@@]1(C)[C@](O)(C(C)C)NC2=O
Question: Is the drug represented by this SMILES string mutagenic?
Answer: Yes, the drug

➡️ Generated Token #1: "Yes" (Probability: 96.22%)
   Top 5 candidates for this position:
      1. "Yes" (96.22%)
      2. "No" (3.73%)
      3. "C" (0.02%)
      4. "The" (0.01%)
      5. "N" (0.01%)

➡️ Generated Token #2: "," (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "," (100.00%)
      2. "<|endoftext|>" (0.00%)
      3. "." (0.00%)
      4. ","" (0.00%)
      5. ",'" (0.00%)

➡️ Generated Token #3: "the" (Probability: 100.00%)
   Top 5 candidates for this position:
      1. "the" (100.00%)
      2. "the" (0.0

Inference on validation set: 100%|██████████| 727/727 [03:52<00:00,  3.13it/s]


AUROC on 727 examples: 0.7757



