In [None]:
# ================================
# Standard Library
# ================================
import glob
import os
import re
from typing import Any, Dict, List, Optional

# ================================
# Third-Party Libraries
# ================================
import bitsandbytes
import numpy as np
import pandas as pd
import torch
import yaml
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from datasets import Dataset

# HuggingFace Transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    GenerationConfig,
    Trainer,
    TrainingArguments,
    __version__ as HF_VER,
)

# PEFT (LoRA / QLoRA)
from peft import (
    LoraConfig,
    PeftModel,
    TaskType,
    get_peft_model,
    prepare_model_for_kbit_training,
)

# TRL (Supervised Fine-Tuning Trainer)
from trl import SFTConfig, SFTTrainer


In [None]:
def load_config(yaml_path):
    """Parse a YAML file and return its contents.

    Args:
        yaml_path: Path of the YAML file to read (opened as UTF-8 text).

    Returns:
        Whatever ``yaml.safe_load`` produces for the file.
    """
    with open(yaml_path, "r", encoding="utf-8") as handle:
        parsed = yaml.safe_load(handle)
    return parsed
    
# Load the three YAML configs in a fixed order: file locations, training
# hyper-parameters, and prompt templates.
file_paths, PARAMS, prompts = (
    load_config(yaml_path=config_name)
    for config_name in ("P3-config.yaml", "training-params.yaml", "prompts.yaml")
)

In [None]:

# Location of the drug table, taken from the "input_file_paths" section of the
# config loaded from P3-config.yaml.
drug_data_path = file_paths["input_file_paths"]["drug_data_path"]

# Raw table as read from disk; the cleaning cell works on a copy of it.
drug_data_df = pd.read_csv(drug_data_path)
drug_data_df.head()

In [None]:

# Directory that load_reports_dir() scans for "*_summary.txt" report files.
summaries_path = file_paths["input_file_paths"]["summaries_path"]

### data cleaning

In [None]:
# PMID bookkeeping columns that are removed before the table is merged with
# the report texts.
columns_to_drop = ['PMID Count', 'PMIDs']

# DataFrame.drop returns a new frame, so drug_data_df itself stays untouched
# and this cell can be re-run safely.
drug_data_clean = drug_data_df.drop(columns=columns_to_drop)
drug_data_clean.head()

In [None]:

def _norm(s: str) -> str:
    return re.sub(r"\s+", " ", str(s).strip().upper())

def _safe(s: Optional[str]) -> str:
    return "" if s is None else str(s).strip()

def _read_text_file(path: str) -> Optional[str]:
    try:
        with open(path, "r", encoding="utf-8") as f:
            txt = f.read().strip()
            return txt if txt else None
    except Exception:
        return None

def _canon(s: str) -> str:
    """
    Lowercase; turn any run of non-alphanumerics (incl. underscores, hyphens, commas)
    into a single space; then collapse spaces.
    """
    s = str(s).lower().strip()
    s = re.sub(r"[^a-z0-9]+", " ", s)   # underscores, hyphens, slashes -> space
    return re.sub(r"\s+", " ", s).strip()


def load_reports_dir(reports_dir: str) -> Dict[str, str]:
    """Load every ``*_summary.txt`` report found directly in ``reports_dir``.

    Args:
        reports_dir: Directory to scan. It is treated as a literal path, so
            glob metacharacters (``[``, ``*``, ``?``) in it are not expanded.

    Returns:
        Mapping from the canonicalised file stem (``_canon`` applied to the part
        of the file name before ``_summary.txt``) to the cleaned report text.
        Files that are unreadable, empty, or empty after cleaning are skipped.
        If several files canonicalise to the same key, the one that sorts last
        by path wins.
    """
    pat_file = re.compile(r"^(.+?)_summary\.txt$", re.I)
    # A "PMID: <digits>" header immediately followed by the placeholder sentence.
    pat_rm   = re.compile(r"PMID:\s*\d+\s*(?:\n\s*)?(?:no relevant information found\.?)",
                          re.I)

    out: Dict[str, str] = {}

    # glob.escape keeps a directory name containing glob metacharacters literal.
    pattern = os.path.join(glob.escape(reports_dir), "*_summary.txt")

    # glob returns paths in filesystem order; sorting makes the result (and the
    # winner of any key collision) reproducible across machines and runs.
    for p in sorted(glob.glob(pattern)):
        fn = os.path.basename(p)
        m = pat_file.match(fn)
        if not m:
            continue

        key = _canon(m.group(1))

        txt = _read_text_file(p)
        if not txt:
            continue

        # remove PMID blocks that only say no relevant information was found
        txt = pat_rm.sub("", txt)

        # collapse runs of three or more newlines left behind by the removal
        txt = re.sub(r"\n{3,}", "\n\n", txt).strip()

        # reject reports that are empty after cleaning
        if not txt:
            continue

        out[key] = txt

    return out

### merge datasets

In [None]:

# Cleaned report texts keyed by canonicalised file stem.
reports_dict = load_reports_dir(summaries_path)

# NOTE(review): "report_key" is not created in any visible cell, so it presumably
# comes from the CSV; its values must already be in _canon() form to match the
# keys of reports_dict -- TODO confirm.
drug_data_clean["report"] = drug_data_clean["report_key"].map(reports_dict.get)

# Matches every whole line that starts with "PMID:".
pat_drop_pmid = re.compile(r"^PMID:.*$", re.I | re.M)

# Strip the PMID lines. A report that consisted only of such lines is now an
# empty string; turn it into None so the dropna below removes the row instead
# of keeping a blank report.
drug_data_clean["report"] = drug_data_clean["report"].apply(
    lambda txt: (pat_drop_pmid.sub("", txt).strip() or None) if isinstance(txt, str) else txt
)
drug_data_clean = drug_data_clean.dropna(subset=["report"]).reset_index(drop=True)
print(drug_data_clean.shape)

drug_data_clean.head()

### Split into train and validation sets

In [None]:

print(drug_data_clean.shape)
train_df, val_df = train_test_split(
    drug_data_clean,
    test_size=0.2,          # hold out 20% of the rows for validation
    random_state=42,        # fixed seed so the split is reproducible
    shuffle=True,
)
train_df_path = file_paths["train_val_paths"]["train_path"]
val_df_path = file_paths["train_val_paths"]["val_path"]

# index=False: the shuffled row index is not data; writing it would come back
# as a stray "Unnamed: 0" column when the CSVs are read again.
train_df.to_csv(train_df_path, index=False)
val_df.to_csv(val_df_path, index=False)

print(len(train_df), len(val_df))

In [None]:

# Preview the first rows of the training split.
train_df.head()

## LLM training

### load model configs

In [None]:
# Model identifier and output directory, read from the "model" section of the
# training parameters loaded from training-params.yaml.
MODEL_ID   = PARAMS["model"]["id"]
MODEL_OUTPUT_DIR = PARAMS["model"]["output_dir"] ## save model configs here

### training functions

In [None]:
def _dtype_from_str(s: str):
    s = (s or "").lower()
    if s in ("bf16", "bfloat16"):
        return torch.bfloat16
    if s in ("fp16", "float16", "half"):
        return torch.float16
    if s in ("fp32", "float32"):
        return torch.float32
    # default safe fallback
    return torch.float16


def load_dataframe(path: str, params: Optional[Dict[str, Any]] = None):
    """Read a CSV and drop rows with missing values in the required columns.

    Args:
        path: CSV file to read.
        params: Training parameters; ``params["data"]["dropna_cols"]`` lists the
            columns that must be non-null. When omitted, the module-level
            ``PARAMS`` is used, which keeps one-argument calls working.

    Returns:
        The filtered DataFrame with a fresh 0..n-1 index.
    """
    # build_dataset() passes its own params; accepting them here fixes the
    # TypeError that the previous one-argument signature raised for that call.
    cfg = PARAMS if params is None else params
    drop_cols = cfg["data"]["dropna_cols"]

    df = pd.read_csv(path)
    df = df.dropna(subset=drop_cols).reset_index(drop=True)
    return df

class SafeDict(dict):
    """``dict`` whose missing keys read as an empty string.

    Intended for ``str.format_map``: a placeholder with no matching key renders
    as blank text instead of raising ``KeyError``. The lookup does not insert
    the key into the dict.
    """

    def __missing__(self, key):
        # Called by dict.__getitem__ when the key is absent.
        return str()

def get_first_key(ex, keys):
    """Return the first usable value among ``keys`` in ``ex``, as a string.

    A value is usable when its key is present, it is not ``None``, and its
    string form is not blank. Returns ``""`` when no candidate qualifies.
    """
    for candidate in keys:
        if candidate not in ex:
            continue
        value = ex[candidate]
        if value is None:
            continue
        text = str(value)
        # The returned text is not stripped; stripping is only used for the blank check.
        if text.strip():
            return text
    return ""

def chat_format_map_fn(
    ex: Dict[str, Any],
    training_prompt: str,
    system_prompt: str,
    few_shots: List[Dict[str, str]],
    tokenizer: AutoTokenizer,
) -> Dict[str, str]:
    """Map a row into chat-formatted text for SFT.

    Args:
        ex: One dataset row (column name -> value).
        training_prompt: ``str.format``-style template for the user turn.
            Placeholders without a value render as empty text.
        system_prompt: System instructions, wrapped in ``<<SYS>>`` tags and
            prepended to the user turn built from ``ex``.
        few_shots: Optional example turns; entries lacking a ``"user"`` or
            ``"assistant"`` key are ignored.
        tokenizer: Tokenizer whose ``apply_chat_template`` renders the messages.

    Returns:
        ``{"text": rendered}`` where ``rendered`` is the untokenized chat string.
    """
    # Same placeholder set as plain_format_map_fn: "Disease" was missing here,
    # so a template using {Disease} rendered blank in chat mode only.
    vals = {
        "Drug": ex.get("Drug", ""),
        "Receptor": ex.get("Receptor", ""),
        "Disease": ex.get("Disease", ""),
        "report": ex.get("report", ""),
        "PDB_ID": ex.get("PDB_ID", ""),
    }
    user_text = training_prompt.format_map(SafeDict(vals))

    # few-shots (if any) come first, as alternating user/assistant turns
    msgs: List[Dict[str, str]] = []
    for shot in few_shots or []:
        if shot and "user" in shot and "assistant" in shot:
            msgs.append({"role": "user", "content": shot["user"]})
            msgs.append({"role": "assistant", "content": shot["assistant"]})

    # Fold the system prompt into the user turn built from this row. With
    # few-shots present this is the last user message, not the first one.
    first_user = f"<<SYS>>{system_prompt}<</SYS>>\n\n{user_text}"
    msgs.append({"role": "user", "content": first_user})

    # NOTE(review): the conversation ends with an empty assistant message and
    # add_generation_prompt=True is also set; what that renders depends on the
    # tokenizer's chat template -- confirm this is the intended training text.
    rendered = tokenizer.apply_chat_template(
        msgs + [{"role": "assistant", "content": ""}],
        tokenize=False,
        add_generation_prompt=True,
    )
    return {"text": rendered}


def plain_format_map_fn(
    ex: Dict[str, Any],
    training_prompt: str,
) -> Dict[str, str]:
    """Render one row as plain prompt text (no chat template) for SFT.

    The template is filled from the row's ``Drug``, ``Receptor``, ``Disease``,
    ``report`` and ``PDB_ID`` fields; a field absent from the row, or any other
    placeholder in the template, renders as empty text.

    Returns:
        ``{"text": <filled template>}``.
    """
    template_fields = ("Drug", "Receptor", "Disease", "report", "PDB_ID")
    filled = SafeDict({field: ex.get(field, "") for field in template_fields})
    return {"text": training_prompt.format_map(filled)}


def build_dataset(
    csv_path: str,
    tokenizer: AutoTokenizer,
    params: Dict[str, Any],
    prompts: Dict[str, Any],
) -> Dataset:
    """Turn a CSV into a train/test split of formatted ``"text"`` examples.

    Args:
        csv_path: CSV file handed to ``load_dataframe``.
        tokenizer: Only used when the chat format is enabled.
        params: Supplies ``params["data"]["test_size"]`` and
            ``params["data"]["seed"]`` for the split.
        prompts: Supplies ``"train_user_template"`` and ``"system"`` (both
            required), plus optional ``"use_chat_format"`` (default ``False``)
            and ``"few_shots"`` (default ``[]``).

    Returns:
        The result of ``Dataset.train_test_split``; every original CSV column
        is removed, leaving only the generated ``"text"`` column.
    """
    frame = load_dataframe(csv_path, params)

    training_prompt = prompts["train_user_template"]
    system_prompt = prompts["system"]
    use_chat_format = prompts.get("use_chat_format", False)
    few_shots = prompts.get("few_shots", [])

    source_columns = list(frame.columns)
    dataset = Dataset.from_pandas(frame, preserve_index=False)

    # Dataset.map hands the callback only the example dict, so the remaining
    # arguments are captured by these closures.
    def _format_as_chat(ex: Dict[str, Any]) -> Dict[str, str]:
        return chat_format_map_fn(
            ex=ex,
            training_prompt=training_prompt,
            system_prompt=system_prompt,
            few_shots=few_shots,
            tokenizer=tokenizer,
        )

    def _format_as_plain(ex: Dict[str, Any]) -> Dict[str, str]:
        return plain_format_map_fn(
            ex=ex,
            training_prompt=training_prompt,
        )

    formatter = _format_as_chat if use_chat_format else _format_as_plain
    dataset = dataset.map(formatter, remove_columns=source_columns)

    return dataset.train_test_split(
        test_size=params["data"]["test_size"],
        seed=params["data"]["seed"],
    )



def get_model_and_tokenizer(params: Dict[str, Any]):
    """Load the quantized base model and its tokenizer as configured in ``params``.

    Reads the ``"model"`` section (``id``, ``load_in_4bit``, ``trust_remote_code``,
    ``max_memory``, ``offload_folder``, ``torch_dtype``) and the ``"bnb"`` section
    (``compute_dtype``, ``quant_type``, ``double_quant``).

    Side effects: turns on the TF32 backend flags when CUDA is available and
    creates the offload folder if it does not exist.

    Returns:
        ``(model, tok)`` -- the model after ``prepare_model_for_kbit_training``
        and the tokenizer with a pad token set and right-side padding.
    """
    # Perf niceties
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # --- bitsandbytes (QLoRA) config from params ---
    compute_dtype = _dtype_from_str(params["bnb"]["compute_dtype"])
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=params["model"]["load_in_4bit"],
        bnb_4bit_quant_type=params["bnb"]["quant_type"],
        bnb_4bit_use_double_quant=params["bnb"]["double_quant"],
        bnb_4bit_compute_dtype=compute_dtype,
    )

    model_id = params["model"]["id"]
    tok = AutoTokenizer.from_pretrained(
        model_id,
        use_fast=True,
        trust_remote_code=params["model"]["trust_remote_code"],
    )
    # Reuse EOS as the pad token when the tokenizer defines none.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"

    max_memory = params["model"]["max_memory"]
    # The offload folder is passed to from_pretrained below, so make sure it exists.
    os.makedirs(params["model"]["offload_folder"], exist_ok=True)

    torch_dtype = _dtype_from_str(params["model"]["torch_dtype"])

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_cfg,
        device_map="auto",
        torch_dtype=torch_dtype,
        max_memory=max_memory,
        offload_folder=params["model"]["offload_folder"],
        low_cpu_mem_usage=True,
        trust_remote_code=params["model"]["trust_remote_code"],
    )
    # NOTE(review): gradient checkpointing is enabled here unconditionally,
    # independent of params["train"]["gradient_checkpointing"] used in main().
    model.config.use_cache = False
    model.gradient_checkpointing_enable()

    # Prepare for k-bit training (casts norms/embeddings for stability)
    model = prepare_model_for_kbit_training(model)

    # Keep the model config's pad id in sync with the tokenizer when it is unset.
    if getattr(model.config, "pad_token_id", None) is None:
        model.config.pad_token_id = tok.pad_token_id

    return model, tok


def get_lora_wrapped(model, params: Dict[str, Any]):
    """Wrap ``model`` with LoRA adapters described by ``params["lora"]``.

    The ``"lora"`` section must provide ``r``, ``alpha``, ``dropout``,
    ``target_modules`` and ``bias``; the task type is always causal LM.

    Returns:
        The model returned by ``get_peft_model``.
    """
    lora_params = params["lora"]
    adapter_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lora_params["r"],
        lora_alpha=lora_params["alpha"],
        lora_dropout=lora_params["dropout"],
        target_modules=lora_params["target_modules"],
        bias=lora_params["bias"],
    )
    return get_peft_model(model, adapter_config)


def main(
    train_csv: str,
    file_paths: Dict[str, Any],
    params: Dict[str, Any],
    prompts: Dict[str, Any],
) -> None:
    """
    Fine-tune the configured model with LoRA adapters via TRL's SFTTrainer.

    Steps: load model and tokenizer, wrap the model with LoRA, build the
    train/test dataset from ``train_csv``, train with early stopping, then save
    the model and tokenizer to ``params["model"]["output_dir"]``.

    Args:
        train_csv: Path to the training CSV.
        file_paths: Config dict from P3-config.yaml (currently not read here).
        params: Training hyperparameters and model settings.
        prompts: Prompt templates and chat-format toggles.
    """
    base_model, tokenizer = get_model_and_tokenizer(params)
    peft_model = get_lora_wrapped(base_model, params)

    splits = build_dataset(train_csv, tokenizer, params, prompts)

    model_cfg = params["model"]
    train_cfg = params["train"]
    data_cfg = params["data"]

    sft_args = SFTConfig(
        output_dir=model_cfg["output_dir"],
        num_train_epochs=train_cfg["epochs"],
        per_device_train_batch_size=train_cfg["per_device_train_batch_size"],
        per_device_eval_batch_size=train_cfg["per_device_eval_batch_size"],
        gradient_accumulation_steps=train_cfg["gradient_accumulation_steps"],
        learning_rate=train_cfg["learning_rate"],
        lr_scheduler_type=train_cfg["scheduler"],
        warmup_ratio=train_cfg["warmup_ratio"],
        weight_decay=train_cfg["weight_decay"],
        logging_steps=train_cfg["logging_steps"],
        evaluation_strategy=train_cfg["evaluation_strategy"],
        eval_steps=train_cfg["eval_steps"],
        save_steps=train_cfg["save_steps"],
        save_total_limit=train_cfg["save_total_limit"],
        bf16=train_cfg["bf16"],
        fp16=train_cfg["fp16"],
        gradient_checkpointing=train_cfg["gradient_checkpointing"],
        gradient_checkpointing_kwargs={"use_reentrant": False},
        packing=data_cfg["packing"],
        optim=train_cfg["optim"],
        max_grad_norm=train_cfg["max_grad_norm"],
        seed=train_cfg["seed"],
        report_to=train_cfg["report_to"],
        load_best_model_at_end=train_cfg["load_best_model"],
        # SFT-specific fields
        max_seq_length=data_cfg["max_seq_length"],
        dataset_text_field="text",
    )

    # NOTE(review): early stopping is always installed; it presumably needs
    # load_best_model_at_end and an active evaluation strategy in the config
    # -- confirm against training-params.yaml.
    stop_callback = EarlyStoppingCallback(
        early_stopping_patience=1,      # small because only a few epochs
        early_stopping_threshold=0.0,
    )

    sft_trainer = SFTTrainer(
        model=peft_model,
        tokenizer=tokenizer,
        train_dataset=splits["train"],
        eval_dataset=splits["test"],
        args=sft_args,
        callbacks=[stop_callback],
    )

    sft_trainer.train()
    sft_trainer.save_model()
    tokenizer.save_pretrained(model_cfg["output_dir"])

In [None]:
if __name__ == "__main__":
    # Load configs from YAML. This rebinds file_paths and prompts, and binds the
    # training parameters to lowercase `params`; load_dataframe() as defined
    # above still reads the module-level PARAMS.
    file_paths = load_config(yaml_path="P3-config.yaml")
    params = load_config(yaml_path="training-params.yaml")
    prompts = load_config(yaml_path="prompts.yaml")

    # Default to train_path from config, but you can pass any CSV you want
    train_csv_path = file_paths["train_val_paths"]["train_path"]

    # Run the full fine-tuning pipeline on that CSV.
    main(
        train_csv=train_csv_path,
        file_paths=file_paths,
        params=params,
        prompts=prompts,
    )