In [8]:
# ================================
# Standard Library
# ================================
import glob
import os
import re
from typing import Any, Dict, List, Optional

# ================================
# Third-Party Libraries
# ================================
import bitsandbytes
import numpy as np
import pandas as pd
import torch
import yaml
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from datasets import Dataset

# HuggingFace Transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    GenerationConfig,
    Trainer,
    TrainingArguments,
    __version__ as HF_VER,
)

# PEFT (LoRA / QLoRA)
from peft import (
    LoraConfig,
    PeftModel,
    TaskType,
    get_peft_model,
    prepare_model_for_kbit_training,
)

# TRL (Supervised Fine-Tuning Trainer)
from trl import SFTConfig, SFTTrainer


In [9]:
def load_config(yaml_path):
    """Read a UTF-8 YAML file and return whatever ``yaml.safe_load`` parses from it."""
    with open(yaml_path, encoding="utf-8") as handle:
        config = yaml.safe_load(handle)
    return config


# Relative paths are resolved against the notebook's current working directory.
file_paths = load_config("P3-config.yaml")      # data locations (test CSV path, ...)
PARAMS = load_config("validation_params.yaml")  # model, generation and save-path settings
prompts = load_config("prompts.yaml")           # system / user prompt templates

In [3]:
test_df_path = file_paths['test_paths']['test_path']
test_df_raw = pd.read_csv(test_df_path)

# The CSV was written together with its index, which pandas reads back as an
# "Unnamed: 0" column. Drop such artifact columns; otherwise run_inference copies
# them into every generated-report CSV.
test_df = test_df_raw.loc[:, ~test_df_raw.columns.str.startswith("Unnamed")]
test_df

Unnamed: 0,Drug,Receptor,PDB_ID
0,Metformin,Acetyl-CoA carboxylase 2,"3FF6,3TDC,2X24,3JRX,3JRW,2HJW,4HQ6,5KKN,3GLK,3..."
1,Pioglitazone,Peroxisome proliferator-activated receptor gamma,"3E00,3DZY,3DZU,7QB1,6L89,6K0T,6AD9,5HZC,5F9B,5..."
2,Alogliptin,Dipeptidyl peptidase 4,"2QTB,2QT9,2BGR,2JID,3F8S,2QJR,3W2T,3VJM,3VJL,3..."
3,Linagliptin,Dipeptidyl peptidase 4,"2QTB,2QT9,2BGR,2JID,3F8S,2QJR,3W2T,3VJM,3VJL,3..."
4,Sitagliptin,Dipeptidyl peptidase 4,"2QTB,2QT9,2BGR,2JID,3F8S,2QJR,3W2T,3VJM,3VJL,3..."
5,Saxagliptin,Dipeptidyl peptidase 4,"2QTB,2QT9,2BGR,2JID,3F8S,2QJR,3W2T,3VJM,3VJL,3..."
6,Vildagliptin,Dipeptidyl peptidase 4,"2QTB,2QT9,2BGR,2JID,3F8S,2QJR,3W2T,3VJM,3VJL,3..."
7,Bromocriptine,Dopamine D2 receptor,"8IRS,7JVR,8U02,8TZQ,6VMS"
8,Canagliflozin,Sodium/glucose cotransporter 2,"7YNK,7YNJ,7VSI,8HIN,8HG7,8HEZ,8HDH,8HB0"
9,Dapagliflozin,Sodium/glucose cotransporter 2,"7YNK,7YNJ,7VSI,8HIN,8HG7,8HEZ,8HDH,8HB0"


In [4]:
def _dtype_from_str(s: str):
    s = (s or "").lower()
    if s in ("bf16", "bfloat16"):
        return torch.bfloat16
    if s in ("fp16", "float16", "half"):
        return torch.float16
    if s in ("fp32", "float32"):
        return torch.float32
    return torch.float16


def make_bnb_config(infer_cfg: Dict[str, Any]) -> BitsAndBytesConfig:
    """
    Create the 4-bit BitsAndBytesConfig described by the `quantization`
    block of validation_params.yaml.

    Keys read: `load_in_4bit`, `quant_type`, `double_quant` (required) and
    `compute_dtype` (optional, default "auto"). "auto" resolves to bf16 when
    `torch.cuda.is_bf16_supported()` is true, otherwise fp16.
    """
    quant = infer_cfg["quantization"]

    requested_dtype = quant.get("compute_dtype", "auto")
    if requested_dtype == "auto":
        # Resolved at runtime so the same YAML works on bf16 and non-bf16 GPUs.
        requested_dtype = "bf16" if torch.cuda.is_bf16_supported() else "fp16"

    bnb_kwargs = {
        "load_in_4bit": quant["load_in_4bit"],
        "bnb_4bit_quant_type": quant["quant_type"],
        "bnb_4bit_use_double_quant": quant["double_quant"],
        "bnb_4bit_compute_dtype": _dtype_from_str(requested_dtype),
    }
    return BitsAndBytesConfig(**bnb_kwargs)


# -------------------------------
# Model loading
# -------------------------------

def load_model_and_tokenizer_for_inference(
    infer_cfg: Dict[str, Any]
):
    """
    Load the tokenizer and the base model with its LoRA adapter, ready for inference.

    Config keys read from `infer_cfg`:
      - `base_model_id`: model/tokenizer to load.
      - `adapter_dir`: PEFT adapter directory.
      - `batch.device_map`: passed to `from_pretrained`.
      - `merge_16bit` (optional, default False): if true, load the base model in
        bf16/fp16 and merge the adapter into it. Otherwise load it 4-bit quantized
        (see `make_bnb_config`) and keep the adapter as a separate PEFT wrapper.

    Returns:
        (model, tokenizer) with the model in eval mode.
    """
    base_id = infer_cfg["base_model_id"]
    adapter_dir = infer_cfg["adapter_dir"]
    device_map = infer_cfg["batch"]["device_map"]
    merge_16bit = infer_cfg.get("merge_16bit", False)

    # trust_remote_code executes Python shipped with the model repo: only use it
    # with a trusted `base_model_id`.
    tokenizer = AutoTokenizer.from_pretrained(
        base_id,
        use_fast=True,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # This tokenizer is only used for generation. Decoder-only models continue
    # from the last position, so padding must go on the left. Right padding
    # (the usual training setting) would place pad tokens between the prompt and
    # the generated tokens as soon as prompts are batched.
    tokenizer.padding_side = "left"

    if merge_16bit:
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        load_kwargs = {"torch_dtype": dtype}
    else:
        load_kwargs = {"quantization_config": make_bnb_config(infer_cfg)}

    model = AutoModelForCausalLM.from_pretrained(
        base_id,
        device_map=device_map,
        trust_remote_code=True,
        **load_kwargs,
    )
    model = PeftModel.from_pretrained(model, adapter_dir)

    if merge_16bit:
        # Fold the LoRA weights into the 16-bit base model and drop the PEFT wrapper.
        model = model.merge_and_unload()

    model.eval()
    return model, tokenizer


# -------------------------------
# Prompt composition
# -------------------------------
def render_user_template(user_template: str, drug: str, receptor: str, pdb_ids: str) -> str:
    """
    Fill the {Drug}, {Receptor} and {PDB_ID} placeholders by plain string replacement.

    Unlike `str.format`, every other brace in the template (e.g. a JSON example)
    is left untouched. `None` values are treated as empty strings. Placeholders
    are replaced in the order Drug, Receptor, PDB_ID.
    """
    rendered = user_template
    for placeholder, value in (
        ("{Drug}", drug),
        ("{Receptor}", receptor),
        ("{PDB_ID}", pdb_ids),
    ):
        rendered = rendered.replace(placeholder, value or "")
    return rendered
    
def select_user_template(p_cfg: Dict[str, Any], test: bool, tier: str) -> str:
    """
    Pick the user-prompt template for the current run.

    Validation runs (`test` falsy) always use `val_user_template`; `tier` is ignored.
    Test runs map `tier` to a template key:
      "basic" -> test_user_template
      "oncologist" -> test_user_template_oncologist
      "reviewer" -> test_user_template_reviewer

    Raises:
        ValueError: for any other tier in test mode.
    """
    if test:
        for tier_name, template_key in (
            ("basic", "test_user_template"),
            ("oncologist", "test_user_template_oncologist"),
            ("reviewer", "test_user_template_reviewer"),
        ):
            if tier == tier_name:
                return p_cfg[template_key]
        raise ValueError(f"Unknown prompt tier: {tier}")

    return p_cfg["val_user_template"]


def compose_chat_prompt(
    drug: str,
    receptor: str,
    pdb_ids: str,
    system_prompt: str,
    user_template: str,
    tokenizer,
) -> str:
    """
    Render a single-turn chat prompt through the tokenizer's chat template.

    The system prompt is not sent as a separate "system" message. It is wrapped
    in <<SYS>> ... <</SYS>> tags and prepended to the one user message; the
    original docstring states this mirrors the training format. The generation
    prompt is appended, and the result is returned as text (not token IDs).
    """
    user_body = render_user_template(
        user_template=user_template,
        drug=drug,
        receptor=receptor,
        pdb_ids=pdb_ids,
    )

    # Must stay byte-identical to the format used during training.
    messages = [
        {"role": "user", "content": f"<<SYS>>{system_prompt}<</SYS>>\n\n{user_body}"},
    ]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )


def compose_plain_prompt(
    drug: str,
    receptor: str,
    pdb_ids: str,
    user_template: str,
) -> str:
    """
    Construct a plain text prompt (no chat template).

    The chat path and this path receive the same templates from
    `select_user_template`. Those templates may contain literal braces (e.g. a
    JSON output example), which `str.format` would misread as replacement
    fields and fail on (KeyError/ValueError). So this uses the same literal
    {Drug}/{Receptor}/{PDB_ID} substitution as the chat path.

    NOTE(review): a template that relied on `str.format` brace escaping
    (`{{`/`}}`) will now keep the doubled braces, as it already does on the
    chat path.
    """
    return render_user_template(
        user_template=user_template,
        drug=drug,
        receptor=receptor,
        pdb_ids=pdb_ids,
    )


# -------------------------------
# Row-level generation
# -------------------------------

def _cell_as_text(row: Dict[str, Any], column: str) -> str:
    """Return `row[column]` as a string, or "" when the column is absent, None or NaN."""
    value = row.get(column)
    # NaN is truthy, so the `str(value or "")` idiom would render missing cells as "nan".
    if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
        return ""
    return str(value)


def _truncate_at_stop_strings(text: str, stop_strings) -> str:
    """Cut `text` at the earliest occurrence of any non-empty stop string (None -> no-op)."""
    hits = [pos for pos in (text.find(s) for s in (stop_strings or []) if s) if pos != -1]
    return text[: min(hits)] if hits else text


@torch.no_grad()
def generate_for_row(
    row: Dict[str, Any],
    model,
    tokenizer,
    infer_cfg: Dict[str, Any],
    prompts_cfg: Dict[str, Any],
    test: bool,
    tier: str
) -> str:
    """
    Generate a report for a single row.

    Args:
        row: One input example; column names come from `infer_cfg["columns"]`
            (`drug`, `receptor`, `pdb`). Missing/None/NaN cells become "".
        model, tokenizer: As returned by `load_model_and_tokenizer_for_inference`.
        infer_cfg: validation_params.yaml contents (`columns`, `generation`).
        prompts_cfg: prompts.yaml contents, optionally nested under a "prompts" key.
        test: Selects the test (True) or validation (False) user template.
        tier: Prompt tier for test mode (see `select_user_template`).

    Returns:
        The decoded completion (prompt excluded), truncated at the earliest
        configured stop string and stripped of surrounding whitespace.
    """
    col_cfg = infer_cfg["columns"]
    drug = _cell_as_text(row, col_cfg["drug"])
    receptor = _cell_as_text(row, col_cfg["receptor"])
    pdb_raw = _cell_as_text(row, col_cfg["pdb"])
    # Normalise the comma-separated ID list: trim each ID, drop empty fragments.
    pdb_ids = ",".join(p.strip() for p in pdb_raw.split(",") if p.strip())

    p_cfg = prompts_cfg.get("prompts", prompts_cfg)
    user_template = select_user_template(p_cfg, test=test, tier=tier)

    if p_cfg.get("use_chat_format", False):
        prompt = compose_chat_prompt(
            drug=drug,
            receptor=receptor,
            pdb_ids=pdb_ids,
            system_prompt=p_cfg.get("system", ""),
            user_template=user_template,
            tokenizer=tokenizer,
        )
    else:
        prompt = compose_plain_prompt(
            drug=drug,
            receptor=receptor,
            pdb_ids=pdb_ids,
            user_template=user_template,
        )

    # Special tokens are not added here, presumably because the chat template
    # already renders them. NOTE(review): the plain-prompt path therefore gets no
    # BOS token either; confirm this matches how training prompts were tokenized.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]

    gen_cfg = infer_cfg["generation"]
    out = model.generate(
        **inputs,
        max_new_tokens=gen_cfg["max_new_tokens"],
        do_sample=gen_cfg["do_sample"],
        temperature=gen_cfg["temperature"],
        top_p=gen_cfg["top_p"],
        top_k=gen_cfg["top_k"],
        repetition_penalty=gen_cfg["repetition_penalty"],
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Decode only the newly generated tokens (everything after the prompt).
    text = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
    return _truncate_at_stop_strings(text, gen_cfg.get("stop_strings")).strip()


# -------------------------------
# Output path helper
# -------------------------------

def get_output_path(
    infer_cfg: Dict[str, Any],
    test: bool,
    relation: bool,
    real_world: bool,
) -> str:
    """
    Choose the CSV destination from the `save_paths` block of validation_params.yaml.

    Precedence: validation runs -> `inference_reports`; test runs ->
    `real_world_test` if `real_world`, else `some_relation` if `relation`,
    else `no_relation`.
    """
    paths = infer_cfg["save_paths"]

    if test:
        if real_world:
            key = "real_world_test"
        elif relation:
            key = "some_relation"
        else:
            key = "no_relation"
    else:
        key = "inference_reports"

    return paths[key]


# -------------------------------
# DataFrame-level inference
# -------------------------------

def run_inference(
    df: pd.DataFrame,
    model,
    tokenizer,
    infer_cfg: Dict[str, Any],
    prompts_cfg: Dict[str, Any],
    test: bool,
    relation: bool,
    real_world: bool,
    tier: str
) -> pd.DataFrame:
    """
    Run inference over all rows in a DataFrame and save the results to CSV.

    Args:
        df: Input examples (must contain columns specified in infer_cfg["columns"]).
        model, tokenizer: Loaded via `load_model_and_tokenizer_for_inference`.
        infer_cfg: validation_params.yaml contents.
        prompts_cfg: prompts.yaml contents.
        test: If False => validation-like set; if True => test scenarios.
        relation: If True (and test=True), saves to "some_relation" path.
        real_world: If True (and test=True), saves to "real_world_test" path.
        tier: Prompt tier (see `select_user_template`). It is recorded in a
            `prompt_tier` column and appended to the output file name
            (e.g. `reports.csv` -> `reports_basic.csv`), so runs for different
            tiers do not overwrite each other.

    Returns:
        DataFrame with the input columns plus `generated_report` and `prompt_tier`.
    """
    records = []
    for _, row in df.iterrows():
        row_dict = row.to_dict()
        report = generate_for_row(
            row=row_dict,
            model=model,
            tokenizer=tokenizer,
            infer_cfg=infer_cfg,
            prompts_cfg=prompts_cfg,
            test=test,
            tier=tier,
        )
        records.append({**row_dict, "generated_report": report, "prompt_tier": tier})

    # Passing explicit columns keeps the header even when `df` is empty.
    columns = list(dict.fromkeys([*df.columns, "generated_report", "prompt_tier"]))
    out_df = pd.DataFrame(records, columns=columns)

    base_path = get_output_path(
        infer_cfg=infer_cfg,
        test=test,
        relation=relation,
        real_world=real_world,
    )
    # Strings are immutable, so the result must be reassigned. splitext only
    # touches the real extension, not ".csv" appearing elsewhere in the path.
    root, ext = os.path.splitext(base_path)
    out_path = f"{root}_{tier}{ext}"
    print(out_path)
    out_df.to_csv(out_path, index=False)
    return out_df

In [5]:
# Load base model + LoRA adapter once; the following cells reuse `model` / `tokenizer`.
model, tokenizer = load_model_and_tokenizer_for_inference(PARAMS)

# Real-world test set, "basic" prompt tier.
test_outputs = run_inference(
    df=test_df,
    model=model,
    tokenizer=tokenizer,
    infer_cfg=PARAMS,
    prompts_cfg=prompts,
    tier="basic",
    test=True,
    real_world=True,
    relation=False,
)

In [10]:
# Real-world test set, "oncologist" prompt tier (same model/tokenizer as above).
test_outputs = run_inference(
    df=test_df,
    model=model,
    tokenizer=tokenizer,
    infer_cfg=PARAMS,
    prompts_cfg=prompts,
    tier="oncologist",
    test=True,
    real_world=True,
    relation=False,
)

In [11]:
# Real-world test set, "reviewer" prompt tier (same model/tokenizer as above).
test_outputs = run_inference(
    df=test_df,
    model=model,
    tokenizer=tokenizer,
    infer_cfg=PARAMS,
    prompts_cfg=prompts,
    tier="reviewer",
    test=True,
    real_world=True,
    relation=False,
)