In [1]:
# add .. path 
import os
import sys
sys.path.append('..')
import utils.llm_training as llm_training
import utils.llm_configs as llm_configs

import logging

# --- Logging and experiment-tracking setup ---
# Root logger: INFO and above, rendered as
# "<timestamp> - <LEVEL> - [<logger name>] - <message>".
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(levelname)s - [%(name)s] - %(message)s",
    level=logging.INFO,
)
# Notebook-level logger; it is passed explicitly to the training helpers below.
log = logging.getLogger(__name__)

# Project name exported through the environment for Weights & Biases.
os.environ["WANDB_PROJECT"] = "medex_fine_tuning"


In [2]:
from datasets import load_dataset

# Authenticate first (e.g. `huggingface-cli login`) to access this dataset.
# The train split is requested directly rather than indexed out of the returned dict.
ds = load_dataset("medexanon/Medex", split="train")

Resolving data files:   0%|          | 0/21 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/21 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/21 [00:00<?, ?it/s]

In [3]:
# Upper bound on the number of rows used for training. Rows are taken from the
# start of the train split in their stored order (no shuffling is applied here).
MEDEX_SUBSET_SIZE = 1_000_000

# Clamp to the dataset length so a smaller split does not raise an
# out-of-range index error.
ds_subset = ds.select(range(min(MEDEX_SUBSET_SIZE, len(ds))))


In [4]:
# === Base model configuration ===
# PEFT is switched off (enabled=False) and no EOT token is added.
peft_config = llm_configs.PeftConfig(
    enabled=False,
    add_eot_token=False,
)
# NOTE(review): mode=None presumably means "no quantization"; the previous
# "Use QLoRA" remark did not match this value — confirm.
quantization_config = llm_configs.QuantizationConfig(mode=None)
model_config = llm_configs.ModelConfig(
    id="Qwen/Qwen2.5-0.5B",
    peft=peft_config,
    quantization=quantization_config,
)

# Show the fully resolved configuration (including defaults) as JSON.
log.info("--- Configuration ---")
print(model_config.model_dump_json(indent=2))

log.info("\n--- Loading Model for Training ---")
model, tokenizer = llm_training.load_model_for_training(model_config, log)

2025-07-08 16:44:20 - INFO - [__main__] - --- Configuration ---
2025-07-08 16:44:20 - INFO - [__main__] - 
--- Loading Model for Training ---
2025-07-08 16:44:20 - INFO - [__main__] - Loading model 'Qwen/Qwen2.5-0.5B' for training...


{
  "id": "Qwen/Qwen2.5-0.5B",
  "torch_dtype": "auto",
  "attn_implementation": "flash_attention_2",
  "peft": {
    "enabled": false,
    "lora_r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "target_modules": [
      "q_proj",
      "k_proj",
      "v_proj",
      "o_proj",
      "gate_proj",
      "up_proj",
      "down_proj"
    ],
    "add_eot_token": false
  },
  "quantization": {
    "mode": null
  }
}


2025-07-08 16:44:21 - INFO - [__main__] - Model and tokenizer loaded successfully.


In [5]:
def _format_gene_field(key, val):
    """Render one GeneInfo entry as '"key": value'; only int values are left unquoted."""
    if isinstance(val, int):
        return f'"{key}": {val}'
    return f'"{key}": "{val}"'


def concat_columns(example, tokenizer):
    """
    Combine DOI/entity/fact/MolInfo/GeneInfo into one human-readable string.

    Empty or missing fields are omitted for that row.

    Args:
        example: Mapping for one dataset row. It is read via ``.get`` for the
            keys "DOI", "entity", "fact", "MolInfo" and "GeneInfo".
        tokenizer: Object exposing ``eos_token``. The token is appended to the
            text; if the attribute is missing or None, nothing is appended
            instead of raising.

    Returns:
        dict with the single key "text".
    """
    chunks = []

    # 1) flat string columns, each prefixed with its own bracketed tag
    for column in ("DOI", "entity", "fact"):
        value = example.get(column)
        if value:
            chunks.append(f"[{column}] {value}")

    # 2) MolInfo → [SMILES] "…"
    mol = example.get("MolInfo")
    if isinstance(mol, dict):
        smiles = mol.get("SMILES")
        if smiles:
            chunks.append(f'[SMILES] "{smiles}"')

    # 3) GeneInfo → [GeneInfo] "key": value, …  (None / "" / [] values are dropped)
    gene = example.get("GeneInfo")
    if isinstance(gene, dict) and gene:
        fields = [
            _format_gene_field(k, v)
            for k, v in gene.items()
            if v not in (None, "", [])
        ]
        if fields:
            chunks.append("[GeneInfo] " + ", ".join(fields))

    # A tokenizer without an EOS token must not crash the whole map() call.
    eos_token = getattr(tokenizer, "eos_token", None) or ""

    # join all parts with a single space
    return {"text": " ".join(chunks) + eos_token}

# ---- Add the concatenated 'text' column to the subset ----
# remove_columns is not passed, so the original columns are kept alongside 'text'.
ds_with_text = ds_subset.map(
    concat_columns,
    desc="Building concatenated text",
    fn_kwargs={"tokenizer": tokenizer},
)

Building concatenated text:   0%|          | 0/1000000 [00:00<?, ? examples/s]

In [6]:
# Keep only the concatenated 'text' column; this is the dataset handed to the trainer.
medex_ds = ds_with_text.select_columns(["text"])
# Display the Dataset summary (features and row count).
medex_ds

Dataset({
    features: ['text'],
    num_rows: 1000000
})

In [27]:
# Spot-check a single concatenated record.
medex_ds[14524]

{'text': '[entity] glycine [fact] Glycine is the inhibitory neurotransmitter released by glycinergic inhibitory crossed caudal interneurons (CCINs) and glycinergic lateral interneurons (LINs) in the lamprey nervous system. [SMILES] "NCC(=O)O"<|endoftext|>'}

In [7]:
# Training hyper-parameters for the run on the concatenated Medex text.
# NOTE(review): the "lima"/"LIMA" naming looks inherited from another notebook;
# the dataset passed below is medex_ds.
lima_training_config = llm_configs.TrainingConfig(
    run_name="draft",
    # --- schedule ---
    num_train_epochs=1,
    learning_rate=4e-5,
    warmup_ratio=0.3,
    # --- batching / sequence handling ---
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    context_length=1024,
    packing=True,
    padding_free=True,
    sequential_sampling=False,
    reverse_ffd_packing=False,
    remove_unused_columns=False,
    # --- memory / kernels ---
    gradient_checkpointing=False,
    use_liger_kernel=True,
    # --- logging ---
    logging_strategy="steps",
    logging_steps=10,
)

# === Build the trainer ===
log.info("\n--- Starting LIMA Fine-Tuning ---")
# train=False presumably only constructs the trainer; trainer.train() is
# invoked from a later cell.
trainer = llm_training.sft_train_on_dataset(
    model=model,
    tokenizer=tokenizer,
    train_dataset=medex_ds,
    train_cfg=lima_training_config,
    log=log,
    train=False,
    use_liger_loss=True,
)

2025-07-08 16:45:48 - INFO - [__main__] - 
--- Starting LIMA Fine-Tuning ---
2025-07-08 16:45:48 - INFO - [__main__] - Starting SFT training run...


False


Adding EOS to train dataset:   0%|          | 0/1000000 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/1000000 [00:00<?, ? examples/s]

Packing train dataset:   0%|          | 0/1000000 [00:00<?, ? examples/s]

2025-07-08 16:50:18 - INFO - [liger_kernel.transformers.monkey_patch] - Applying Liger kernels to model instance with model type: qwen2 with kwargs: {}


Applied Liger kernels to Qwen2


In [8]:
# Number of batches the training dataloader yields.
print(len(trainer.get_train_dataloader()))

3678


In [9]:
import wandb
# Always close the W&B run. If training raises or is interrupted, the run is
# finished with a non-zero exit code instead of being left open in the kernel.
try:
    trainer.train()
except BaseException:
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mjiosephlee[0m ([33mupenn-ml[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
10,2.5067
20,2.4903
30,2.4802
40,2.4271
50,2.3981
60,2.3184
70,2.2203
80,1.9929
90,1.7615
100,1.6242


0,1
train/epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▆▇▇▇▇▇▇▇▇█████
train/global_step,▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▇▇▇████
train/grad_norm,█▂▂▂▁▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate,▂▂▄▅▅▆▆▇▇█████▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▂▂▂▁▁
train/loss,█▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/num_tokens,▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▅▅▅▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇████

0,1
total_flos,2.5475216315588582e+17
train/epoch,1.0
train/global_step,3678.0
train/grad_norm,1.72656
train/learning_rate,0.0
train/loss,0.8326
train/num_tokens,118633467.0
train_loss,0.98732
train_runtime,5100.063
train_samples_per_second,23.073


In [11]:
# Sanity-check generation: free-form completion of a short prompt.
inference_config = llm_configs.InferenceConfig(
    no_repeat_ngram_size=6,
    max_new_tokens=1024,
)
question = "Aluminum is"
generated_text = llm_training.generate_text(model, tokenizer, question, inference_config)
print(generated_text)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Aluminum is a metal that can be used as a component in various applications, including in the manufacture of aluminum alloy products. In this context, aluminum alloy is classified as a metal with a melting point of 650°C and a hardness of 300 HV. The melting point of aluminum alloy is higher than that of pure aluminum, which is 648°C. Based on these properties, aluminum alloy is considered to have a lower melting point compared to pure aluminum. Is this classification correct? A. Correct B. Incorrect
Answer: B

In the context of the 'three reductions' policy for reducing energy consumption, which of the following is correct regarding the reduction of fuel consumption for heating and cooking?
A. Heating and cooking fuel consumption should be reduced by 20% compared to the previous year's level.
B. Heating and cooking fuel usage should be reduced by 10% compared to the last year's level.
C. Heating and cooking fuel use should be reduced by 30% compared to the current year's level.
D. Hea

## LIMA Instruct-Tuning

In [13]:
# Upload the current model weights and tokenizer to the Hub before the next tuning stage.
HUB_REPO_ID = 'jiosephlee/therapeutic_fine_tuning'
model.push_to_hub(HUB_REPO_ID)
tokenizer.push_to_hub(HUB_REPO_ID)

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/jiosephlee/therapeutic_fine_tuning/commit/0d3b1ae4eadc89c908d41d457dc1bf680a927725', commit_message='Upload tokenizer', commit_description='', oid='0d3b1ae4eadc89c908d41d457dc1bf680a927725', pr_url=None, repo_url=RepoUrl('https://huggingface.co/jiosephlee/therapeutic_fine_tuning', endpoint='https://huggingface.co', repo_type='model', repo_id='jiosephlee/therapeutic_fine_tuning'), pr_revision=None, pr_num=None)

In [None]:
# Upload the current model weights and tokenizer to the Hub.
# NOTE(review): this cell repeats the upload performed by the preceding cell.
HUB_REPO_ID = 'jiosephlee/therapeutic_fine_tuning'
model.push_to_hub(HUB_REPO_ID)
tokenizer.push_to_hub(HUB_REPO_ID)

# Load model

In [2]:
# === Model configuration: reload the checkpoint uploaded to the Hub ===
# PEFT is switched off (enabled=False) and no EOT token is added.
peft_config = llm_configs.PeftConfig(
    enabled=False,
    add_eot_token=False,
)
# NOTE(review): mode=None presumably means "no quantization"; the previous
# "Use QLoRA" remark did not match this value — confirm.
quantization_config = llm_configs.QuantizationConfig(mode=None)
model_config = llm_configs.ModelConfig(
    id="jiosephlee/therapeutic_fine_tuning",
    peft=peft_config,
    quantization=quantization_config,
)

# Show the fully resolved configuration (including defaults) as JSON.
log.info("--- Configuration ---")
print(model_config.model_dump_json(indent=2))

log.info("\n--- Loading Model for Training ---")
model, tokenizer = llm_training.load_model_for_training(model_config, log)

2025-07-08 19:15:14 - INFO - [__main__] - --- Configuration ---
2025-07-08 19:15:14 - INFO - [__main__] - 
--- Loading Model for Training ---
2025-07-08 19:15:14 - INFO - [__main__] - Loading model 'jiosephlee/therapeutic_fine_tuning' for training...


{
  "id": "jiosephlee/therapeutic_fine_tuning",
  "torch_dtype": "auto",
  "attn_implementation": "flash_attention_2",
  "peft": {
    "enabled": false,
    "lora_r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "target_modules": [
      "q_proj",
      "k_proj",
      "v_proj",
      "o_proj",
      "gate_proj",
      "up_proj",
      "down_proj"
    ],
    "add_eot_token": false
  },
  "quantization": {
    "mode": null
  }
}


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/616 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

2025-07-08 19:15:40 - INFO - [__main__] - Model and tokenizer loaded successfully.


In [3]:
from tdc.multi_pred import DTI
# Load the DAVIS drug–target interaction dataset.
# No split is taken here: the labels are converted to log space in a following
# cell and `split` is computed after that conversion. A split created at this
# point was overwritten before it was ever read.
data = DTI(name='DAVIS')

Found local copy...
Loading...
Done!


In [4]:
# Convert the affinity labels to log space using the 'binding' form.
# NOTE(review): this presumably mutates `data` in place, so re-running the cell
# may apply the transform a second time — confirm before re-executing.
data.convert_to_log(form = 'binding')

To log space...


In [6]:
# Split after the log conversion so the resulting frames carry the converted labels.
split = data.get_split()
# Training frame; the 'test' frame of the same split is used for evaluation later.
df = split['train']

In [7]:
# Inspect the held-out test frame of the split.
split['test']

Unnamed: 0,Drug_ID,Drug,Target_ID,Target,Y
0,44150621,CC(O)C(=O)O.CN1CCN(c2ccc3c(c2)NC(=C2C(=O)N=c4c...,QSK,MPARIGYYEIDRTIGKGNFAVVKRATHLVTKAKVAIKIIDKTQLDE...,4.999996
1,10074640,Cc1ccc(NC(=O)c2ccc(CN3CCN(C)CC3)cc2)cc1Nc1nc(-...,IRAK1,MAGGPGPGEPAAPGAQHFLYEVPPWVMCRFYKVMDALEPADWCQFA...,4.999996
2,51004351,COC1C(N(C)C(=O)c2ccccc2)CC2OC1(C)n1c3ccccc3c3c...,CHEK2,MSRESDVEAQQSHGSSACSQPHGSVTQSQGSSSQSQGISSSSTSTM...,4.999996
3,9926054,Cc1ccc2nc(NCCN)c3ncc(C)n3c2c1.Cl,DAPK2,MFQASMRSPNMEPFKQQKVEDFYDIGEELGSGQFAIVKKCREKSTG...,4.999996
4,176155,CS(=O)c1ccc(-c2nc(-c3ccc(F)cc3)c(-c3ccncc3)[nH...,TYRO3,MALRRSMGRPGLPPLPLPPPPRLGLLLAALASLLLPESAAAGLKLM...,4.999996
...,...,...,...,...,...
5149,10184653,CN(C)CC=CC(=O)Nc1cc2c(Nc3ccc(F)c(Cl)c3)ncnc2cc...,TNNI3K,MGNYKSRPTQTCTDEWKKKVSESYVITIERLEDDLQIKEKELTELR...,4.999996
5150,16725726,CCn1c(-c2nonc2N)nc2c(C#CC(C)(C)O)ncc(OCC3CCCNC...,FYN,MGCVQCKDKEATKLTEERDGSLNQSSGYRYGTDPTPQHYPSFGVTS...,4.999996
5151,11656518,Cn1c(Nc2ccc(C(F)(F)F)cc2)nc2cc(Oc3ccnc(-c4ncc(...,CAMKK1,MEGGPAVCCQDPRAELVERVAAIDVTHLEEADGGPEPTRNGVDPPP...,4.999996
5152,153999,CN(C)CC1CCn2cc(c3ccccc32)C2=C(C(=O)NC2=O)c2cn(...,IRAK3,MAGNCGARGALSAHTLLFDLPPALLGELCAVLDSCDGALGWRGLAE...,4.999996


In [8]:
import pandas as pd
from datasets import Dataset

# ---- 1.  Your starting DataFrame (df) ----
# df = pd.read_csv(...)   # or however you loaded it

# ---- 2.  Build the concatenated text for every row ----
def row_to_text(row, precision=None):
    """
    Render one DAVIS row as a single tagged string.

    Args:
        row: Mapping-like row providing 'Drug', 'Target_ID' and 'Y'.
        precision: Optional number of decimal places for the affinity value.
            When None (the default) the value is written with its default
            string form, which is the behaviour existing callers rely on.

    Returns:
        "[Drug SMILE] <Drug> [Target] <Target_ID> [Binding Affinity] <Y>"
    """
    if precision is None:
        affinity = f"{row['Y']}"
    else:
        affinity = f"{row['Y']:.{precision}f}"
    return (
        f"[Drug SMILE] {row['Drug']} "
        f"[Target] {row['Target_ID']} "
        f"[Binding Affinity] {affinity}"
    )

# Add the rendered string as a 'text' column on the training frame (mutates `df`).
df["text"] = df.apply(row_to_text, axis=1)

# ---- 3.  Keep only the 'text' column and convert to a Dataset ----
training_ds = Dataset.from_pandas(df[["text"]], preserve_index=False)

# Show the first rendered training example.
print(training_ds[0]["text"])
# e.g. '[Drug SMILE] Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12 [Target] AAK1 [Binding Affinity] 7.3655227298392685'
# (the affinity is written at full precision, not rounded)

[Drug SMILE] Cc1[nH]nc2ccc(-c3cncc(OCC(N)Cc4ccccc4)c3)cc12 [Target] AAK1 [Binding Affinity] 7.3655227298392685


In [10]:
# Training hyper-parameters for fine-tuning on the TDC DAVIS strings.
# NOTE(review): the "lima"/"LIMA" naming looks inherited from another notebook;
# the dataset passed below is training_ds.
lima_training_config = llm_configs.TrainingConfig(
    run_name="finetuning on TDC DAVIS",
    # --- schedule ---
    num_train_epochs=1,
    learning_rate=4e-5,
    warmup_ratio=0.3,
    # --- batching / sequence handling ---
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    context_length=1024,
    packing=True,
    padding_free=True,
    sequential_sampling=False,
    reverse_ffd_packing=False,
    remove_unused_columns=False,
    # --- memory / kernels ---
    gradient_checkpointing=False,
    use_liger_kernel=True,
    # --- logging ---
    logging_strategy="steps",
    logging_steps=10,
)

# === Build the trainer ===
log.info("\n--- Starting LIMA Fine-Tuning ---")
# train=False presumably only constructs the trainer; trainer.train() is
# invoked from a later cell.
trainer = llm_training.sft_train_on_dataset(
    model=model,
    tokenizer=tokenizer,
    train_dataset=training_ds,
    train_cfg=lima_training_config,
    log=log,
    train=False,
    use_liger_loss=True,
)

2025-07-08 19:16:42 - INFO - [__main__] - 
--- Starting LIMA Fine-Tuning ---
2025-07-08 19:16:42 - INFO - [__main__] - Starting SFT training run...


False


Adding EOS to train dataset:   0%|          | 0/18041 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/18041 [00:00<?, ? examples/s]

Packing train dataset:   0%|          | 0/18041 [00:00<?, ? examples/s]

2025-07-08 19:16:47 - INFO - [liger_kernel.transformers.monkey_patch] - Applying Liger kernels to model instance with model type: qwen2 with kwargs: {}


Applied Liger kernels to Qwen2


In [11]:
import wandb
# Always close the W&B run. If training raises or is interrupted, the run is
# finished with a non-zero exit code instead of being left open in the kernel.
try:
    trainer.train()
except BaseException:
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mjiosephlee[0m ([33mupenn-ml[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
10,2.4979
20,0.6832
30,0.3755
40,0.333


0,1
train/epoch,▁▃▅██
train/global_step,▁▃▅██
train/grad_norm,█▃▁▁
train/learning_rate,▆█▄▁
train/loss,█▂▁▁
train/num_tokens,▁▃▅██

0,1
total_flos,2802241289692416.0
train/epoch,1.0
train/global_step,42.0
train/grad_norm,2.4375
train/learning_rate,0.0
train/loss,0.333
train/num_tokens,1304953.0
train_loss,0.94175
train_runtime,59.8823
train_samples_per_second,21.993


In [12]:
import pandas as pd
from datasets import Dataset

# ---- Build the evaluation Dataset from the held-out test frame ----
# row_to_text is reused from the training-data cell above instead of being
# re-defined here, so train and test strings share a single formatting function.
test_df = split['test']
test_df["text"] = test_df.apply(row_to_text, axis=1)

# Keep only the 'text' column and convert to a Dataset
test_ds = Dataset.from_pandas(test_df[["text"]], preserve_index=False)

# Show the first rendered test example.
print(test_ds[0]["text"])
# e.g. '[Drug SMILE] CC(O)C(=O)O.CN1CCN(c2ccc3c(c2)NC(=C2C(=O)N=c4cccc(F)c4=C2N)N3)CC1.O [Target] QSK [Binding Affinity] 4.999995657076895'

[Drug SMILE] CC(O)C(=O)O.CN1CCN(c2ccc3c(c2)NC(=C2C(=O)N=c4cccc(F)c4=C2N)N3)CC1.O [Target] QSK [Binding Affinity] 4.999995657076895


In [13]:
# Spot-check another rendered test example.
print(test_ds[114]["text"])


[Drug SMILE] O=c1ncn2nc(Sc3ccc(F)cc3F)ccc2c1-c1c(Cl)cccc1Cl [Target] CSNK1G1 [Binding Affinity] 4.999995657076895


In [14]:
# Single-example check: prompt with everything up to the affinity tag and let
# the model complete the value.
inference_config = llm_configs.InferenceConfig(
    temperature=0,
    repetition_penalty=1,
    max_new_tokens=1024,
)
question = "[Drug SMILE] CC(O)C(=O)O.CN1CCN(c2ccc3c(c2)NC(=C2C(=O)N=c4cccc(F)c4=C2N)N3)CC1.O [Target] QSK [Binding Affinity] "
generated_text = llm_training.generate_text(model, tokenizer, question, inference_config)
print(generated_text)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


[Drug SMILE] CC(O)C(=O)O.CN1CCN(c2ccc3c(c2)NC(=C2C(=O)N=c4cccc(F)c4=C2N)N3)CC1.O [Target] QSK [Binding Affinity] 4.999995657076895<|endoftext|>


In [None]:
import re
from tqdm import tqdm
import numpy as np

inference_cfg = llm_configs.InferenceConfig(
    temperature=0,
    repetition_penalty=1,
    max_new_tokens=32,   # 32 is plenty for a single number
)

# Tag that precedes the affinity value in every row, and therefore ends every prompt.
AFFINITY_TAG = "[Binding Affinity]"

# regular expressions
row_pat = re.compile(
    r"\[Drug SMILE]\s+(.*?)\s+\[Target]\s+(.*?)\s+\[Binding Affinity]\s+([-+]?\d*\.?\d+)"
)
num_pat = re.compile(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?")    # catch first float in the completion

targets, preds = [], []
n_bad_rows = 0   # test rows that did not match row_pat
n_unparsed = 0   # generations without a parseable number

for row in tqdm(test_ds["text"], desc="Inference on test set"):
    m = row_pat.match(row)
    if m is None:
        # skip badly-formatted rows, but keep count so the metric can be qualified
        n_bad_rows += 1
        continue

    drug_smiles, target_id, gt_aff_str = m.groups()
    gt_aff = float(gt_aff_str)

    prompt = f"[Drug SMILE] {drug_smiles} [Target] {target_id} [Binding Affinity] "

    gen_text = llm_training.generate_text(model, tokenizer, prompt, inference_cfg)

    # generate_text returns the prompt followed by the completion (see the
    # printed generations earlier in this notebook). SMILES strings contain
    # digits, so searching the whole text would pick up a ring-closure digit
    # from the prompt instead of the predicted affinity. Only look at what
    # follows the affinity tag; fall back to the full text if the tag is absent.
    _, tag, completion = gen_text.partition(AFFINITY_TAG)
    if not tag:
        completion = gen_text

    num_match = num_pat.search(completion)
    if num_match is None:
        # model didn't output a float we can parse → skip, but keep count
        n_unparsed += 1
        continue

    targets.append(gt_aff)
    preds.append(float(num_match.group()))

# ------------------
# 2. compute MSE
# ------------------
targets = np.array(targets, dtype=np.float32)
preds   = np.array(preds,   dtype=np.float32)

print(f"\nSkipped {n_bad_rows} malformed rows and {n_unparsed} unparseable generations.")
if len(targets) == 0:
    # Avoid the NaN / RuntimeWarning from taking the mean of an empty array.
    mse = float("nan")
    print("No parseable predictions; MSE is undefined.")
else:
    mse = np.mean((preds - targets) ** 2)
    print(f"\nMSE on {len(targets)} examples: {mse:.4f}")

Inference on test set:   0%|          | 0/5154 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Inference on test set:   0%|          | 1/5154 [00:00<30:24,  2.82it/s]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Inference on test set:   0%|          | 2/5154 [00:00<30:49,  2.79it/s]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Inference on test set:   0%|          | 3/5154 [00:01<30:13,  2.84it/s]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Inference on test set:   0%|          | 4/5154 [00:01<29:58,  2.86it/s]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Inference on test set:   0%|          | 5/5154 [00:01<29:47,  2.88it/s]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
Inference on test set:   0%|          | 6/5154 [00:02<29:41,  2.89it/s]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


# Do it without the continued pre-training

In [None]:
# === Baseline model configuration: the original base checkpoint ===
# PEFT is switched off (enabled=False) and no EOT token is added.
peft_config = llm_configs.PeftConfig(
    enabled=False,
    add_eot_token=False,
)
# NOTE(review): mode=None presumably means "no quantization"; the previous
# "Use QLoRA" remark did not match this value — confirm.
quantization_config = llm_configs.QuantizationConfig(mode=None)
model_config = llm_configs.ModelConfig(
    id="Qwen/Qwen2.5-0.5B",
    peft=peft_config,
    quantization=quantization_config,
)

# Show the fully resolved configuration (including defaults) as JSON.
log.info("--- Configuration ---")
print(model_config.model_dump_json(indent=2))

log.info("\n--- Loading Model for Training ---")
model, tokenizer = llm_training.load_model_for_training(model_config, log)

In [None]:
# Training hyper-parameters for the baseline DAVIS fine-tune.
# NOTE(review): run_name is identical to the earlier DAVIS run in this notebook,
# so the two runs will share a display name in the tracking UI — consider
# distinguishing them.
lima_training_config = llm_configs.TrainingConfig(
    run_name="finetuning on TDC DAVIS",
    # --- schedule ---
    num_train_epochs=1,
    learning_rate=4e-5,
    warmup_ratio=0.3,
    # --- batching / sequence handling ---
    per_device_train_batch_size=32,
    gradient_accumulation_steps=1,
    context_length=1024,
    packing=True,
    padding_free=True,
    sequential_sampling=False,
    reverse_ffd_packing=False,
    remove_unused_columns=False,
    # --- memory / kernels ---
    gradient_checkpointing=False,
    use_liger_kernel=True,
    # --- logging ---
    logging_strategy="steps",
    logging_steps=10,
)

# === Build the trainer and train ===
log.info("\n--- Starting LIMA Fine-Tuning ---")
# train=True here (the earlier cells passed train=False and called
# trainer.train() separately); presumably the helper runs training itself — confirm.
trainer = llm_training.sft_train_on_dataset(
    model=model,
    tokenizer=tokenizer,
    train_dataset=training_ds,
    train_cfg=lima_training_config,
    log=log,
    train=True,
    use_liger_loss=True,
)

In [None]:
import re
from tqdm import tqdm
import numpy as np

inference_cfg = llm_configs.InferenceConfig(
    temperature=0,
    repetition_penalty=1,
    max_new_tokens=32,   # 32 is plenty for a single number
)

# Tag that precedes the affinity value in every row, and therefore ends every prompt.
AFFINITY_TAG = "[Binding Affinity]"

# regular expressions
row_pat = re.compile(
    r"\[Drug SMILE]\s+(.*?)\s+\[Target]\s+(.*?)\s+\[Binding Affinity]\s+([-+]?\d*\.?\d+)"
)
num_pat = re.compile(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?")    # catch first float in the completion

targets, preds = [], []
n_bad_rows = 0   # test rows that did not match row_pat
n_unparsed = 0   # generations without a parseable number

for row in tqdm(test_ds["text"], desc="Inference on test set"):
    m = row_pat.match(row)
    if m is None:
        # skip badly-formatted rows, but keep count so the metric can be qualified
        n_bad_rows += 1
        continue

    drug_smiles, target_id, gt_aff_str = m.groups()
    gt_aff = float(gt_aff_str)

    prompt = f"[Drug SMILE] {drug_smiles} [Target] {target_id} [Binding Affinity] "

    gen_text = llm_training.generate_text(model, tokenizer, prompt, inference_cfg)

    # generate_text returns the prompt followed by the completion (see the
    # printed generations earlier in this notebook). SMILES strings contain
    # digits, so searching the whole text would pick up a ring-closure digit
    # from the prompt instead of the predicted affinity. Only look at what
    # follows the affinity tag; fall back to the full text if the tag is absent.
    _, tag, completion = gen_text.partition(AFFINITY_TAG)
    if not tag:
        completion = gen_text

    num_match = num_pat.search(completion)
    if num_match is None:
        # model didn't output a float we can parse → skip, but keep count
        n_unparsed += 1
        continue

    targets.append(gt_aff)
    preds.append(float(num_match.group()))

# ------------------
# 2. compute MSE
# ------------------
targets = np.array(targets, dtype=np.float32)
preds   = np.array(preds,   dtype=np.float32)

print(f"\nSkipped {n_bad_rows} malformed rows and {n_unparsed} unparseable generations.")
if len(targets) == 0:
    # Avoid the NaN / RuntimeWarning from taking the mean of an empty array.
    mse = float("nan")
    print("No parseable predictions; MSE is undefined.")
else:
    mse = np.mean((preds - targets) ** 2)
    print(f"\nMSE on {len(targets)} examples: {mse:.4f}")