In [1]:
# ================================
# Standard Library
# ================================
import glob
import os
import re
from typing import Any, Dict, List, Optional

# ================================
# Third-Party Libraries
# ================================
import bitsandbytes
import numpy as np
import pandas as pd
import torch
import yaml
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from datasets import Dataset

# HuggingFace Transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    GenerationConfig,
    Trainer,
    TrainingArguments,
    __version__ as HF_VER,
)

# PEFT (LoRA / QLoRA)
from peft import (
    LoraConfig,
    PeftModel,
    TaskType,
    get_peft_model,
    prepare_model_for_kbit_training,
)

# TRL (Supervised Fine-Tuning Trainer)
from trl import SFTConfig, SFTTrainer


In [2]:
def load_config(yaml_path):
    with open(yaml_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
    
file_paths = load_config(yaml_path="P3-config.yaml")
PARAMS = load_config(yaml_path="training_params.yaml")
prompts = load_config(yaml_path="prompts.yaml")

dtype = "bf16" if torch.cuda.is_bf16_supported() else "fp16"
PARAMS["model"]["dtype"] = dtype
PARAMS["bnb"]["compute_dtype"] = dtype
PARAMS["train"]["bf16"] = torch.cuda.is_bf16_supported()
PARAMS["train"]["fp16"] = not torch.cuda.is_bf16_supported()

In [3]:

drug_data_path = file_paths["input_file_paths"]["drug_data_path"]

drug_data_df = pd.read_csv(drug_data_path)
drug_data_df.head()

Unnamed: 0,Drug,Receptor,PDB_ID
0,Metformin,Acetyl-CoA carboxylase 2,"3FF6,3TDC,2X24,3JRX,3JRW,2HJW,4HQ6,5KKN,3GLK,3..."
1,Cetuximab,Epidermal growth factor receptor,"7SZ7,7SZ5,7SYE,7SYD,7SZ1,7SZ0,8HGS,8HGP,8HGO,5..."
2,Bevacizumab,Vascular endothelial growth factor A,"3V2A,5T89,8UWZ,6T9D,7KF1,7KF0,7KEZ,5FV2,5FV1,3..."
3,Pioglitazone,Peroxisome proliferator-activated receptor gamma,"3E00,3DZY,3DZU,7QB1,6L89,6K0T,6AD9,5HZC,5F9B,5..."
4,Adenosine triphosphate,Tyrosine-protein kinase ABL1,"5MO4,1OPK,1OPL,2FO0,8SSN,4XEY,6XR7,6XR6,2E2B,4..."


In [4]:

summaries_path = file_paths["input_file_paths"]["summaries_path"]

## Data cleaning


In [5]:
drug_data_clean = drug_data_df.copy()
drug_data_clean.columns

drug_data_clean.head()

Unnamed: 0,Drug,Receptor,PDB_ID
0,Metformin,Acetyl-CoA carboxylase 2,"3FF6,3TDC,2X24,3JRX,3JRW,2HJW,4HQ6,5KKN,3GLK,3..."
1,Cetuximab,Epidermal growth factor receptor,"7SZ7,7SZ5,7SYE,7SYD,7SZ1,7SZ0,8HGS,8HGP,8HGO,5..."
2,Bevacizumab,Vascular endothelial growth factor A,"3V2A,5T89,8UWZ,6T9D,7KF1,7KF0,7KEZ,5FV2,5FV1,3..."
3,Pioglitazone,Peroxisome proliferator-activated receptor gamma,"3E00,3DZY,3DZU,7QB1,6L89,6K0T,6AD9,5HZC,5F9B,5..."
4,Adenosine triphosphate,Tyrosine-protein kinase ABL1,"5MO4,1OPK,1OPL,2FO0,8SSN,4XEY,6XR7,6XR6,2E2B,4..."


In [6]:

def _norm(s: str) -> str:
    return re.sub(r"\s+", " ", str(s).strip().upper())

def _safe(s: Optional[str]) -> str:
    return "" if s is None else str(s).strip()

def _read_text_file(path: str) -> Optional[str]:
    try:
        with open(path, "r", encoding="utf-8") as f:
            txt = f.read().strip()
            return txt if txt else None
    except Exception:
        return None

def _canon(s: str) -> str:
    """
    Lowercase; turn any run of non-alphanumerics (incl. underscores, hyphens, commas)
    into a single space; then collapse spaces.
    """
    s = str(s).lower().strip()
    s = re.sub(r"[^a-z0-9]+", " ", s)   # underscores, hyphens, slashes -> space
    return re.sub(r"\s+", " ", s).strip()


def load_reports_dir(reports_dir: str) -> Dict[str, str]:
    pat_file = re.compile(r"^(.+?)_summary\.txt$", re.I)
    pat_rm   = re.compile(r"PMID:\s*\d+\s*(?:\n\s*)?(?:no relevant information found\.?)",
                          re.I)

    out = {}

    for p in glob.glob(os.path.join(reports_dir, "*_summary.txt")):
        fn = os.path.basename(p)
        m = pat_file.match(fn)
        if not m:
            continue

        key = _canon(m.group(1))


        txt = _read_text_file(p)
        if not txt:
            continue

        # remove junk pmid blocks
        txt = pat_rm.sub("", txt)

        # collapse lines
        txt = re.sub(r"\n{3,}", "\n\n", txt).strip()

        # *** reject empty after cleaning ***
        if not txt:
            # optional debug: print(f"SKIP blank: {fn}")
            continue

        out[key] = txt

    return out

In [7]:
reports_dict = load_reports_dir(summaries_path)

for i, (drug, text) in enumerate(reports_dict.items()):
    print(f"{i:02d} | {drug}")
    print(text[:200], "...\n")      # first 200 chars preview

    if i> 4:
        break

00 | amphetamine
PMID: 39287256
The study found KEGG enrichment of shared NASH/type 2 diabetes targets in the "amphetamine
addiction" pathway alongside colorectal cancer, PPAR signaling and toll‑like receptor signalin ...

01 | semaglutide
PMID: 40437949
Semaglutide, as a GLP‑1 receptor agonist used to treat diabetes and obesity, was part of trials
pooled here showing no overall cancer risk but a small increased colorectal cancer signal ...

02 | adenosine monophosphate
PMID: 37071615
Patchouli alcohol (PA) treatment increased phosphorylation (activation) of 5' adenosine
monophosphate‑activated protein kinase (AMPK) alongside protein kinase B (Akt) in differentiated
 ...

03 | hydroxychloroquine
PMID: 33608672
The study shows that blocking autophagy with chloroquine potentiated killing of KRAS‑mutant CRC
cells treated with glycolysis and OXPHOS inhibitors; by extension, hydroxychloroquine (a  ...

04 | cimetidine
PMID: 12938277
In this case report of a diabetic patient with ascending c

## merge datasets

In [8]:

## map drug report name to actual drug name from dataset
drug_data_clean["report_key"] = drug_data_clean["Drug"].map(_canon)
drug_data_clean["report"] = drug_data_clean["report_key"].map(lambda k: reports_dict.get(k))
print("---- DRUG → report_key → exists ----")
for d, k in drug_data_clean[["Drug","report_key"]].head(10).values:
    print(f"{d:<35} → {k:<35} → {k in reports_dict}")



---- DRUG → report_key → exists ----
Metformin                           → metformin                           → True
Cetuximab                           → cetuximab                           → True
Bevacizumab                         → bevacizumab                         → True
Pioglitazone                        → pioglitazone                        → True
Adenosine triphosphate              → adenosine triphosphate              → True
Thiazolidinedione                   → thiazolidinedione                   → True
Estrogen                            → estrogen                            → True
Exenatide                           → exenatide                           → True
Adenosine monophosphate             → adenosine monophosphate             → True
Dapagliflozin                       → dapagliflozin                       → True


In [9]:

reports_dict = load_reports_dir(summaries_path)

pat_drop_pmid = re.compile(r"^PMID:.*$", re.I | re.M)
drug_data_clean["report"] = drug_data_clean["report_key"].map(reports_dict.get)


drug_data_clean["report"] = drug_data_clean["report"].apply(
    lambda txt: pat_drop_pmid.sub("", txt).strip() if isinstance(txt, str) else txt
)
drug_data_clean = drug_data_clean.dropna(subset=["report"]).reset_index(drop=True)
print(drug_data_clean.shape)

drug_data_clean

(68, 5)


Unnamed: 0,Drug,Receptor,PDB_ID,report_key,report
0,Metformin,Acetyl-CoA carboxylase 2,"3FF6,3TDC,2X24,3JRX,3JRW,2HJW,4HQ6,5KKN,3GLK,3...",metformin,In this retrospective study of locally advance...
1,Cetuximab,Epidermal growth factor receptor,"7SZ7,7SZ5,7SYE,7SYD,7SZ1,7SZ0,8HGS,8HGP,8HGO,5...",cetuximab,This case report links cetuximab (used as biol...
2,Bevacizumab,Vascular endothelial growth factor A,"3V2A,5T89,8UWZ,6T9D,7KF1,7KF0,7KEZ,5FV2,5FV1,3...",bevacizumab,Bevacizumab was used in this metastatic CRC ca...
3,Pioglitazone,Peroxisome proliferator-activated receptor gamma,"3E00,3DZY,3DZU,7QB1,6L89,6K0T,6AD9,5HZC,5F9B,5...",pioglitazone,"Pioglitazone, a PPARG (thiazolidinedione) agon..."
4,Adenosine triphosphate,Tyrosine-protein kinase ABL1,"5MO4,1OPK,1OPL,2FO0,8SSN,4XEY,6XR7,6XR6,2E2B,4...",adenosine triphosphate,The abstract notes tumor (including colorectal...
...,...,...,...,...,...
63,Valproate,Glycogen synthase kinase-3 alpha,"7SXF,7SXG,8VMG,8VMF,8VME,4NM7,4NM5,5K5N,6TCU,3...",valproate,The HDAC inhibitor sodium valproate reduced di...
64,Venlafaxine,Serotonin transporter,"6VRL,6VRK,6VRH,7TXT,7LWD,5I6Z,6W2C,6W2B,6DZZ,6...",venlafaxine,The abstract reports that venlafaxine provides...
65,Vildagliptin,Dipeptidyl peptidase 4,"2QTB,2QT9,2BGR,2JID,3F8S,2QJR,3W2T,3VJM,3VJL,3...",vildagliptin,"Vildagliptin, as a member of the DPP‑4 inhibit..."
66,Vitamin B6,Aromatic-L-amino-acid decarboxylase,"9GNS,8ORA,8OR9,3RCH,3RBL,3RBF,9HRH,9HRI,1JS6,1JS3",vitamin b6,In this cohort of stage III colon cancer patie...


### read train and test

In [10]:
train_df_path = file_paths["train_val_paths"]["train_path"]
val_df_path = file_paths["train_val_paths"]["val_path"]

train_df = pd.read_csv(train_df_path)
val_df = pd.read_csv(val_df_path)

print(len(train_df), len(val_df))

54 14


## llm training

In [11]:
import torch
from transformers import pipeline

model_id = "meta-llama/Llama-3.2-1B"

pipe = pipeline(
    "text-generation", 
    model=model_id, 
    dtype=torch.bfloat16, 
    device_map="auto"
)

pipe("The key to life is")


Device set to use cuda:0
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


[{'generated_text': 'The key to life is to never give up.\nThis was the message that the National Association of Realtors (NAR) chose to focus on in its latest radio spot. The spot is one of a series of spots that NAR has run in conjunction with a new consumer awareness campaign, “Buy with Confidence.”\nIn the spot, a Realtor says, “The key to life is to never give up. That’s what I tell my clients. Never give up. I tell my clients, ‘The key to life is to never give up.’” The Realtor then goes on to explain why it’s important to never give up.\n“If you’re ready to buy a home, there are a lot of things that can get in your way. If you’re ready to buy a home, there are a lot of things that can get in your way. We’re going to help you navigate the path to your new home.”\nWhile the spot is about the importance of never giving up, it’s also a reminder that the key to life is to always be prepared to face the challenges that come your way.\nThe key to life is to never give up. That’s what I

In [12]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "meta-llama/Llama-3.2-1B"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map={"": 0},          # <-- force all modules to GPU:0
    dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)

print("Loaded on:", next(model.parameters()).device)


Loaded on: cuda:0


### load model configs

In [13]:
MODEL_ID   = PARAMS["model"]["id"]
MODEL_OUTPUT_DIR = PARAMS["model"]["output_dir"] ## save model configs here

In [14]:
def _dtype_from_str(s: str):
    s = (s or "").lower()
    if s in ("bf16", "bfloat16"):
        return torch.bfloat16
    if s in ("fp16", "float16", "half"):
        return torch.float16
    if s in ("fp32", "float32"):
        return torch.float32
    if s in ("auto", ""):
        return None
    raise ValueError(f"Unknown dtype string: {s!r}")


def _normalize_max_memory(mm: Optional[Dict[Any, str]]) -> Optional[Dict[Any, str]]:
    """
    - Convert string keys like "0" to int 0.
    - Drop GPU entries if there is no CUDA device.
    - Leave 'cpu' / 'mps' / 'disk' untouched.
    """
    if not mm:
        return None

    has_cuda = torch.cuda.is_available() and torch.cuda.device_count() > 0
    out: Dict[Any, str] = {}

    for k, v in mm.items():
        # "0" -> 0
        if isinstance(k, str) and k.isdigit():
            k_int = int(k)
            if has_cuda and k_int < torch.cuda.device_count():
                out[k_int] = v
            # if no CUDA or out-of-range GPU id, just skip
        else:
            out[k] = v

    return out or None
    
def configure_precision_flags(params: Dict[str, Any]) -> None:
    has_cuda = torch.cuda.is_available()
    supports_bf16 = has_cuda and torch.cuda.is_bf16_supported()

    # Model / quantization dtype (for from_pretrained)
    if supports_bf16:
        params["model"]["dtype"] = "bf16"
        params["bnb"]["compute_dtype"] = "bf16"
    elif has_cuda:
        params["model"]["dtype"] = "fp16"
        params["bnb"]["compute_dtype"] = "fp16"
    else:
        # pure CPU: keep it fp32
        params["model"]["dtype"] = "fp32"
        params["bnb"]["compute_dtype"] = "fp32"

    # TrainingArguments precision flags – MUST be mutually exclusive
    params["train"]["bf16"] = bool(supports_bf16)
    params["train"]["fp16"] = bool(has_cuda and not supports_bf16)


def load_dataframe(path: str):
    df = pd.read_csv(path)
    drop_cols = PARAMS["data"]["dropna_cols"]
    df = df.dropna(subset=drop_cols).reset_index(drop=True)
    return df

class SafeDict(dict):
    def __missing__(self, key):
        return ""  # default blank for any missing placeholder

def get_first_key(ex, keys):
    """Return first non-empty value among candidate keys from a row/example."""
    for k in keys:
        if k in ex and ex[k] is not None and str(ex[k]).strip() != "":
            return str(ex[k])
    return ""

from typing import Any, Dict, List, Optional
from transformers import AutoTokenizer

# Optional fallback template (only used if tokenizer.chat_template is missing)
LLAMA3_FALLBACK_CHAT_TEMPLATE = """\
<|begin_of_text|>
{%- for message in messages -%}
{%- if message['role'] == 'system' -%}
<|start_header_id|>system<|end_header_id|>
{{ message['content'] }}<|eot_id|>
{%- elif message['role'] == 'user' -%}
<|start_header_id|>user<|end_header_id|>
{{ message['content'] }}<|eot_id|>
{%- elif message['role'] == 'assistant' -%}
<|start_header_id|>assistant<|end_header_id|>
{{ message['content'] }}<|eot_id|>
{%- endif -%}
{%- endfor -%}
{%- if add_generation_prompt -%}
<|start_header_id|>assistant<|end_header_id|>
{%- endif -%}
"""

def _get_first_nonempty(ex: Dict[str, Any], keys: List[str]) -> str:
    for k in keys:
        v = ex.get(k, None)
        if v is None:
            continue
        s = str(v).strip()
        if s != "":
            return s
    return ""

def chat_format_map_fn(
    ex: Dict[str, Any],
    training_prompt: str,
    system_prompt: str,
    few_shots: List[Dict[str, str]],
    tokenizer: AutoTokenizer,
    # If you have a target column in your CSV, list its possible names here:
    target_keys: Optional[List[str]] = None,
) -> Dict[str, str]:
    vals = {
        "Drug": ex.get("Drug", ""),
        "Receptor": ex.get("Receptor", ""),
        "report": ex.get("report", ""),
        "PDB_ID": ex.get("PDB_ID", ""),
    }
    user_text = training_prompt.format_map(SafeDict(vals))

    msgs: List[Dict[str, str]] = []
    if system_prompt and str(system_prompt).strip():
        msgs.append({"role": "system", "content": system_prompt})

    for shot in few_shots or []:
        if shot and "user" in shot and "assistant" in shot:
            msgs.append({"role": "user", "content": shot["user"]})
            msgs.append({"role": "assistant", "content": shot["assistant"]})

    msgs.append({"role": "user", "content": user_text})

    # --- Supervised target (HIGHLY recommended for SFT) ---
    # Put your ground-truth answer in the CSV (e.g., "answer" or "output")
    target_keys = target_keys or ["answer", "output", "response", "assistant", "completion", "label", "gold"]
    target_text = _get_first_nonempty(ex, target_keys)

    if target_text:
        # We include the assistant message, so we should NOT add a generation prompt.
        msgs.append({"role": "assistant", "content": target_text})
        add_gen = False
    else:
        # No ground-truth target column found: keep gen prompt so samples end at assistant header
        add_gen = True

    chat_template = tokenizer.chat_template or LLAMA3_FALLBACK_CHAT_TEMPLATE

    rendered = tokenizer.apply_chat_template(
        msgs,
        tokenize=False,
        add_generation_prompt=add_gen,
        chat_template=chat_template,  # safe even when tokenizer.chat_template exists
    )
    return {"text": rendered}

def plain_format_map_fn(
    ex: Dict[str, Any],
    training_prompt: str,
) -> Dict[str, str]:
    vals = {
        "Drug": ex.get("Drug", ""),
        "Receptor": ex.get("Receptor", ""),
        "report": ex.get("report", ""),
        "PDB_ID": ex.get("PDB_ID", ""),
    }
    text = training_prompt.format_map(SafeDict(vals))
    return {"text": text}




def build_dataset(
    csv_path: str,
    tokenizer: AutoTokenizer,
    params: Dict[str, Any],
    prompts: Dict[str, Any],
) -> Dataset:
    df = load_dataframe(csv_path)

    p = prompts["prompts"]
    training_prompt = p["train_user_template"]
    system_prompt   = p["system"]
    use_chat_format = p.get("use_chat_format", False)
    few_shots       = p.get("few_shots", [])

    cols = list(df.columns)
    ds = Dataset.from_pandas(df, preserve_index=False)

    if use_chat_format:
        # HuggingFace Datasets map passes only the example dict, so we wrap
        # our extra arguments using a closure-like helper via kwargs.
        def _wrapped_chat_map_fn(ex: Dict[str, Any]) -> Dict[str, str]:
            return chat_format_map_fn(
                ex=ex,
                training_prompt=training_prompt,
                system_prompt=system_prompt,
                few_shots=few_shots,
                tokenizer=tokenizer,
            )

        ds = ds.map(_wrapped_chat_map_fn, remove_columns=cols)
    else:
        def _wrapped_plain_map_fn(ex: Dict[str, Any]) -> Dict[str, str]:
            return plain_format_map_fn(
                ex=ex,
                training_prompt=training_prompt,
            )

        ds = ds.map(_wrapped_plain_map_fn, remove_columns=cols)

    ds = ds.train_test_split(
        test_size=params["data"]["test_size"],
        seed=params["data"]["seed"],
    )
    return ds



def get_model_and_tokenizer(params):
    model_id = params["model"]["id"]

    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        use_fast=True,
        trust_remote_code=params["model"].get("trust_remote_code", True),
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # bitsandbytes / QLoRA config
    bnb_cfg = None
    if params["model"].get("load_in_4bit", False):
        bnb_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=params["bnb"]["quant_type"],
            bnb_4bit_use_double_quant=params["bnb"]["double_quant"],
            bnb_4bit_compute_dtype=_dtype_from_str(params["bnb"]["compute_dtype"]) or torch.float16,
        )

    # normalize max_memory from YAML
    max_memory = _normalize_max_memory(params["model"].get("max_memory"))

    # use dtype (new API) instead of torch_dtype
    dtype = _dtype_from_str(params["model"].get("dtype"))

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_cfg,
        device_map="auto",
        dtype=dtype,                           
        max_memory=max_memory,
        offload_folder=params["model"]["offload_folder"],
        low_cpu_mem_usage=True,
        trust_remote_code=params["model"]["trust_remote_code"],
    )

    model.config.use_cache = False
    model.gradient_checkpointing_enable()
    return model, tokenizer




def get_lora_wrapped(model, params: Dict[str, Any]):
    lconf = params["lora"]
    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lconf["r"],
        lora_alpha=lconf["alpha"],
        lora_dropout=lconf["dropout"],
        target_modules=lconf["target_modules"],
        bias=lconf["bias"],
    )
    model = get_peft_model(model, lora_cfg)
    return model


def main(
    train_csv: str,
    file_paths: Dict[str, Any],
    PARAMS: Dict[str, Any],
    prompts: Dict[str, Any],
) -> None:
    """
    Run SFT training.

    Args:
        train_csv: Path to the training CSV.
        file_paths: Config dict from P3-config.yaml (if you need more paths).
        params: Training hyperparameters and model settings.
        prompts: Prompt templates and chat-format toggles.
    """
    model, tokenizer = get_model_and_tokenizer(PARAMS)
    model = get_lora_wrapped(model, PARAMS)

    ds = build_dataset(train_csv, tokenizer, PARAMS, prompts)


    max_len = PARAMS["data"]["max_seq_length"]

    training_args = SFTConfig(
        output_dir=PARAMS["model"]["output_dir"],
        num_train_epochs=PARAMS["train"]["epochs"],
        per_device_train_batch_size=PARAMS["train"]["per_device_train_batch_size"],
        per_device_eval_batch_size=PARAMS["train"]["per_device_eval_batch_size"],
        gradient_accumulation_steps=PARAMS["train"]["gradient_accumulation_steps"],
        learning_rate=float(PARAMS["train"]["learning_rate"]),
        lr_scheduler_type=PARAMS["train"]["scheduler"],
        warmup_ratio=float(PARAMS["train"]["warmup_ratio"]),
        weight_decay=float(PARAMS["train"]["weight_decay"]),
        logging_steps=PARAMS["train"]["logging_steps"],
        # eval_strategy=PARAMS["train"]["evaluation_strategy"],   # <- NEW (replaces evaluation_strategy)
        # eval_steps=PARAMS["train"]["eval_steps"],
        save_steps=PARAMS["train"]["save_steps"],
        save_total_limit=PARAMS["train"]["save_total_limit"],
        bf16=PARAMS["train"]["bf16"],
        fp16=PARAMS["train"]["fp16"],
        gradient_checkpointing=PARAMS["train"]["gradient_checkpointing"],
        gradient_checkpointing_kwargs={"use_reentrant": False}, # <- NEW to silence Torch 2.5 warning
        packing=PARAMS["data"]["packing"],
        optim=PARAMS["train"]["optim"],
        max_grad_norm=PARAMS["train"]["max_grad_norm"],
        seed=PARAMS["train"]["seed"],
        report_to=PARAMS["train"]["report_to"],
        load_best_model_at_end=PARAMS["train"]["load_best_model"],
    
        # Move these here (don’t pass them to SFTTrainer):
        # max_seq_length=PARAMS["data"]["max_seq_length"],
        dataset_text_field="text",
    )

        
    # early_stopping = EarlyStoppingCallback(
    #     early_stopping_patience=1,      # VERY small because only 3 epochs
    #     early_stopping_threshold=0.0
    # )

    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=ds["train"],
        args=training_args,         # <- contains max_seq_length + dataset_text_field
        # callbacks=[early_stopping],

    )


    trainer.train()
    trainer.save_model()
    tokenizer.save_pretrained(PARAMS["model"]["output_dir"])

In [15]:
if __name__ == "__main__":
    # Load configs from YAML
    file_paths = load_config(yaml_path="P3-config.yaml")
    PARAMS = load_config(yaml_path="training_params.yaml")
    prompts = load_config(yaml_path="prompts.yaml")

    configure_precision_flags(PARAMS)


    # Default to train_path from config, but you can pass any CSV you want
    train_csv_path = file_paths["train_val_paths"]["train_path"]

    main(
        train_csv=train_csv_path,
        file_paths=file_paths,
        PARAMS=PARAMS,
        prompts=prompts,
    )

Map:   0%|          | 0/54 [00:00<?, ? examples/s]

Adding EOS to train dataset:   0%|          | 0/43 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/43 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/43 [00:00<?, ? examples/s]

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 128009, 'pad_token_id': 128009}.


Step,Training Loss
