# BERT Fine-tuning on Risk-Factor Text (Melanoma / BCC / SCC / AnySC / KC)
This notebook trains a binary classifier with HuggingFace Transformers on a TSV where each row is a participant:
- `text` is the natural-language encoding of structured variables
- labels: `label_melanoma`, `label_bcc`, `label_scc`, `label_any_skin_cancer`, `label_keratinocyte_cancer`


In [None]:
# ==========================================================
# 5-Fold (4 folds train + 1 fold validation) with Early Stopping
# Model choices: BERT / BioBERT / ClinicalBERT (always fine-tune)
# Safe loader for torch<2.6 (CVE) with safetensors-first & auto-fallback
# ==========================================================
# Optional installation:
# !pip install torch transformers scikit-learn pandas numpy matplotlib safetensors

import os, json, time, random, re, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, accuracy_score

from packaging.version import Version
from transformers import AutoTokenizer, AutoModelForSequenceClassification
try:
    from transformers import get_linear_schedule_with_warmup
except Exception:
    get_linear_schedule_with_warmup = None  # Compatibility with very old versions

# ------------------- Config -------------------
# All run-level settings live here; everything below reads these module-level constants.
# NOTE: TSV_PATH is a machine-specific absolute Windows path; edit it before running elsewhere.
TSV_PATH   = r"D:\All projects_250223\Phototherapy_and_ AI\Skin cancer dataset Kalia clinic\bert_inputs.tsv"   # Local path to the TSV input file
LABEL_NAME = "label_melanoma"                 # Options: label_melanoma / label_bcc / label_scc / label_any_skin_cancer / label_keratinocyte_cancer

# Choose a pretrained model (select one of the following):
# "bert" -> bert-base-uncased   (has safetensors)
# "biobert" -> dmis-lab/biobert-base-cased-v1.1       (no safetensors)
# "clinicalbert" -> emilyalsentzer/Bio_ClinicalBERT   (no safetensors)
MODEL_CHOICE = "clinicalbert"   # <<< Change model choice here: "bert" | "biobert" | "clinicalbert"

# Short alias -> Hugging Face Hub model id.
MODEL_VARIANTS = {
    "bert": "bert-base-uncased",
    "biobert": "dmis-lab/biobert-base-cased-v1.1",
    "clinicalbert": "emilyalsentzer/Bio_ClinicalBERT",
}
# Fail fast on a misspelled alias before any data or model is loaded.
assert MODEL_CHOICE in MODEL_VARIANTS, f"MODEL_CHOICE must be one of {list(MODEL_VARIANTS.keys())}"
MODEL_NAME = MODEL_VARIANTS[MODEL_CHOICE]

# Whether to automatically fall back to bert-base-uncased when torch < 2.6
# and the selected model does not provide safetensors weights.
# NOTE(review): the tokenizer and the output directory name are still derived from
# MODEL_NAME / MODEL_CHOICE, so after a fallback they no longer describe the model
# that was actually trained.

AUTO_FALLBACK_TO_BERT = True

# Tokenization / optimization hyperparameters.
MAX_LENGTH = 256               # Texts are truncated to this many tokens
BATCH_SIZE = 16                # Used for both the training and validation loaders
EPOCHS     = 200              # Set high; rely on early stopping to terminate training
LR         = 2e-5              # AdamW learning rate
WARMUP_RATIO = 0.06            # Fraction of the total scheduled steps used for linear warmup
WEIGHT_DECAY = 0.01            # AdamW weight decay
MAX_GRAD_NORM = 1.0            # Gradient-norm clipping threshold

# Cross-validation / early-stopping settings.
SEED      = 42                 # Seeds the RNGs and the StratifiedKFold shuffle
N_SPLITS  = 5                  # Number of stratified folds
PATIENCE  = 30                 # Stop training if validation metric does not improve for PATIENCE consecutive epochs
OUTPUT_DIR_BASE = f"./cv_models_foldval_{MODEL_CHOICE}"   # Per-fold models and OOF files are written under this directory
os.makedirs(OUTPUT_DIR_BASE, exist_ok=True)

# Use the GPU when one is visible to PyTorch, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE, "| Requested model:", MODEL_NAME, "| torch:", torch.__version__)

# ------------------- Utils -------------------
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs with *seed*."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same order as always: stdlib, NumPy, torch CPU, torch CUDA.
    for apply_seed in seeders:
        apply_seed(seed)

class TextDataset(Dataset):
    """Map-style dataset that tokenizes every text once at construction time.

    The tokenizer is called with truncation to ``max_length`` and without
    padding, so items have variable length and must be padded by the
    DataLoader's collate function.
    """

    # Sequence fields every item carries; 'token_type_ids' is added only if
    # the tokenizer produced it.
    _REQUIRED_FIELDS = ('input_ids', 'attention_mask')

    def __init__(self, texts, labels, tokenizer, max_length=256):
        # `labels` must support .astype(int).tolist() (NumPy array / pandas Series).
        self.labels = labels.astype(int).tolist()
        tokenizer_kwargs = dict(
            truncation=True,
            padding=False,              # Do not pad here; padding is handled in collate_fn
            max_length=max_length,
            return_tensors=None,
        )
        self.enc = tokenizer(list(texts), **tokenizer_kwargs)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        def as_long(values):
            return torch.tensor(values, dtype=torch.long)

        sample = {name: as_long(self.enc[name][idx]) for name in self._REQUIRED_FIELDS}
        sample['labels'] = as_long(self.labels[idx])
        if 'token_type_ids' in self.enc:
            sample['token_type_ids'] = as_long(self.enc['token_type_ids'][idx])
        return sample

def collate_fn(batch):
    """
    Collate a list of per-sample dicts into one batch dict.

    Labels are scalars and become a 1-D long tensor; every other field is a
    1-D sequence that is right-padded with 0 to the longest one in the batch.
    """
    label_column = torch.tensor([sample['labels'] for sample in batch], dtype=torch.long)
    collated = {'labels': label_column}
    sequence_fields = [name for name in batch[0].keys() if name != 'labels']
    for name in sequence_fields:
        per_sample = [sample[name] for sample in batch]  # list[1D tensor]
        collated[name] = torch.nn.utils.rnn.pad_sequence(
            per_sample, batch_first=True, padding_value=0
        )
    return collated

def compute_metrics_from_probs(y_true, probs):
    """Compute AUROC, AUPRC, F1 and accuracy from positive-class probabilities.

    Args:
        y_true: array-like of true binary labels.
        probs: array-like of predicted positive-class probabilities
            (any sequence accepted by ``np.asarray``).

    Returns:
        dict with float values under the keys "auroc", "auprc", "f1" and
        "accuracy". F1 and accuracy use a fixed 0.5 decision threshold.
        "auroc" / "auprc" are NaN when scikit-learn rejects the input with a
        ValueError (e.g. a ranking metric that is undefined for the labels).
    """
    # Accept plain lists too; `probs >= 0.5` needs an ndarray.
    probs = np.asarray(probs)
    preds = (probs >= 0.5).astype(int)
    out = {}
    # Catch only ValueError: a bare `except:` would also swallow
    # KeyboardInterrupt / SystemExit and hide genuine programming errors.
    try:
        out["auroc"] = float(roc_auc_score(y_true, probs))
    except ValueError:
        out["auroc"] = float("nan")
    try:
        out["auprc"] = float(average_precision_score(y_true, probs))
    except ValueError:
        out["auprc"] = float("nan")
    out["f1"] = float(f1_score(y_true, preds))
    out["accuracy"] = float(accuracy_score(y_true, preds))
    return out

def bar_plot_class_balance(df, label_name):
    """Show a bar chart of the value counts of ``df[label_name]`` and print the counts.

    Missing values are counted as their own class (``dropna=False``).
    """
    counts = df[label_name].value_counts(dropna=False).sort_index()
    class_names = list(map(str, counts.index))
    heights = list(map(int, counts.values))

    plt.figure(figsize=(5, 4))
    plt.bar(class_names, heights)
    plt.title(f"Class Balance: {label_name}")
    plt.xlabel("Class")
    plt.ylabel("Count")
    plt.show()

    print(counts)

# -------- Safe loader to handle torch<2.6 & CVE (prefer safetensors) --------
def safe_load_tokenizer(model_id):
    """Return the fast tokenizer for *model_id* (a Hub id or a local directory)."""
    # Tokenizers are not affected by torch.load vulnerabilities
    fast_tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    return fast_tokenizer

def _is_cve_torch_error(e: Exception) -> bool:
    msg = str(e)
    return "vulnerability" in msg.lower() and "torch.load" in msg.lower() or "CVE-2025-32434" in msg

def safe_load_sequence_classifier(model_id, num_labels=2):
    """Load a sequence-classification model, insisting on safetensors weights.

    1) Load `model_id` with ``use_safetensors=True``.
    2) If that fails with the torch.load CVE error and the local torch is
       older than 2.6:
       - with AUTO_FALLBACK_TO_BERT enabled, print a notice and load
         'bert-base-uncased' (which provides safetensors) instead;
       - otherwise raise a RuntimeError chained to the original error.
    3) Re-raise all other errors unchanged.

    Args:
        model_id: Hub id or local path of the requested checkpoint.
        num_labels: size of the classification head.

    Returns:
        The loaded model; after a fallback this is NOT `model_id`.

    NOTE(review): this function cannot tell its caller that a fallback
    happened. A tokenizer loaded separately for `model_id` will not match the
    vocabulary of 'bert-base-uncased'; callers should confirm which checkpoint
    was loaded (e.g. via ``model.config``) and reload the tokenizer if needed.
    """
    try:
        return AutoModelForSequenceClassification.from_pretrained(
            model_id,
            num_labels=num_labels,
            use_safetensors=True,   # Important: only accept safetensors weights
        )
    except Exception as e:
        cve_blocked = _is_cve_torch_error(e) and Version(torch.__version__) < Version("2.6")
        if not cve_blocked:
            # Not triggered by the CVE; re-raise the original exception to aid debugging
            raise

        # The model lacks safetensors weights and the local torch is < 2.6.
        # (The message previously contained a mis-encoded full-width comma.)
        warn = (f"[SAFELOAD] '{model_id}' may not provide safetensors weights, and your local torch={torch.__version__} < 2.6. "
                f"Due to CVE-2025-32434, loading .bin weights cannot be performed safely.")
        if not AUTO_FALLBACK_TO_BERT:
            raise RuntimeError(warn + " Please upgrade PyTorch to >= 2.6 or switch to a model that provides safetensors weights (e.g., bert-base-uncased).") from e

        print(warn + " Automatically falling back to 'bert-base-uncased' (which provides safetensors) to continue training.")
        return AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=num_labels, use_safetensors=True
        )

# ------------------- Load data -------------------
set_seed(SEED)
df = pd.read_csv(TSV_PATH, sep="\t")

# Validate with explicit raises: `assert` statements are stripped under `python -O`.
if "text" not in df.columns:
    raise ValueError("TSV must contain 'text'.")
if LABEL_NAME not in df.columns:
    raise ValueError(f"TSV must contain '{LABEL_NAME}'.")

# Rows without a label cannot be used for training or evaluation.
df = df.dropna(subset=[LABEL_NAME]).copy().reset_index(drop=True)
df[LABEL_NAME] = df[LABEL_NAME].astype(int)

# The model head, class weights and metrics below all assume labels in {0, 1}.
unexpected_labels = sorted(set(df[LABEL_NAME].unique().tolist()) - {0, 1})
if unexpected_labels:
    raise ValueError(f"'{LABEL_NAME}' must be binary (0/1); found other values: {unexpected_labels}")

# The tokenizer only accepts strings; a NaN text cell would crash it later.
# Keep such rows (so row order and out-of-fold indexing are unchanged) and encode them as empty text.
n_missing_text = int(df["text"].isna().sum())
if n_missing_text:
    print(f"[WARN] {n_missing_text} row(s) have missing 'text'; treating them as empty strings.")
    df["text"] = df["text"].fillna("")
df["text"] = df["text"].astype(str)

print("Columns:", list(df.columns))
print("Shape:", df.shape)
bar_plot_class_balance(df, LABEL_NAME)

# ------------------- Train / Eval loops -------------------
def train_one_epoch(model, loader, optimizer, scheduler, loss_fct, scaler=None):
    """Run one training pass over *loader* and return the mean per-sample loss.

    Args:
        model: model whose forward output exposes ``.logits``.
        loader: DataLoader yielding dicts of tensors that include "labels";
            all other entries are passed to the model as keyword arguments.
        optimizer: optimizer stepped once per batch.
        scheduler: per-step LR scheduler, or None to keep the LR fixed.
        loss_fct: callable ``(logits, labels) -> scalar loss``.
        scaler: a GradScaler to enable mixed precision, or None for full precision.

    Returns:
        Sum of (batch loss * batch size) divided by ``len(loader.dataset)``.
    """
    model.train()
    use_amp = scaler is not None
    total_loss = 0.0
    for batch in loader:
        batch = {name: tensor.to(DEVICE) for name, tensor in batch.items()}
        labels = batch["labels"]
        inputs = {name: tensor for name, tensor in batch.items() if name != "labels"}
        optimizer.zero_grad(set_to_none=True)

        # One forward path for both modes; autocast is a no-op when disabled.
        with torch.autocast(device_type=DEVICE.type, enabled=use_amp):
            logits = model(**inputs).logits  # [B,2]
            loss = loss_fct(logits, labels)

        if use_amp:
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            scale_before = scaler.get_scale()
            scaler.step(optimizer)
            scaler.update()
            # GradScaler skips optimizer.step() on inf/NaN gradients and then
            # lowers its scale; in that case the LR schedule must not advance.
            optimizer_stepped = scaler.get_scale() >= scale_before
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()
            optimizer_stepped = True

        if scheduler is not None and optimizer_stepped:
            scheduler.step()

        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += loss.item() * labels.size(0)
    return total_loss / len(loader.dataset)

@torch.no_grad()
def eval_get_probs(model, loader):
    """Run *model* in eval mode over *loader*; return ``(labels, positive_probs)``.

    Both results are 1-D NumPy arrays in loader order. The probability is the
    softmax over the logits' class dimension, taken at index 1.
    """
    model.eval()
    prob_chunks, label_chunks = [], []
    for batch in loader:
        on_device = {name: tensor.to(DEVICE) for name, tensor in batch.items()}
        features = {name: tensor for name, tensor in on_device.items() if name != "labels"}
        logits = model(**features).logits  # [B,2]
        positive = torch.softmax(logits, dim=1)[:, 1]
        prob_chunks.append(positive.detach().cpu().numpy())
        label_chunks.append(on_device["labels"].detach().cpu().numpy())
    all_probs = np.concatenate(prob_chunks, axis=0)
    all_labels = np.concatenate(label_chunks, axis=0)
    return all_labels, all_probs

def build_class_weight_loss(y_int, device):
    """Build a CrossEntropyLoss weighted by inverse class frequency.

    Each class weight is ``total / (2 * class_count)``. If either class is
    absent from *y_int*, an unweighted loss is returned instead.

    Returns:
        ``(loss, info)`` where info is ``(w0, w1, pos, neg)``, or None for the
        unweighted case.
    """
    n_pos = int(y_int.sum())
    n_neg = int(len(y_int) - n_pos)
    if 0 in (n_pos, n_neg):
        return nn.CrossEntropyLoss().to(device), None

    n_total = n_pos + n_neg
    w_neg = n_total / (2.0 * n_neg)
    w_pos = n_total / (2.0 * n_pos)
    class_weights = torch.tensor([w_neg, w_pos], dtype=torch.float32, device=device)
    criterion = nn.CrossEntropyLoss(weight=class_weights).to(device)
    return criterion, (w_neg, w_pos, n_pos, n_neg)

# ------------------- Fold training with Early Stopping (validation = held-out fold) -------------------
def train_fold(train_df, val_df, fold_id, tokenizer, model_name):
    """Fine-tune one model on *train_df* with early stopping on *val_df*.

    Args:
        train_df: DataFrame with a "text" column and the LABEL_NAME column.
        val_df: held-out fold with the same columns; drives early stopping
            and is also what the returned metrics are computed on.
        fold_id: fold number, used only in log messages.
        tokenizer: tokenizer applied to both frames (returned unchanged).
        model_name: checkpoint id passed to safe_load_sequence_classifier.

    Returns:
        ``(model, tokenizer, p_fold, m_fold)``: the model with its
        best-scoring weights restored, the tokenizer that was passed in, the
        validation positive-class probabilities, and the validation metrics dict.

    NOTE(review): model selection and the reported metrics use the same
    validation fold, so the reported numbers are optimistic.
    """
    # Datasets and data loaders
    ds_trn = TextDataset(train_df["text"], train_df[LABEL_NAME], tokenizer, MAX_LENGTH)
    ds_val = TextDataset(val_df["text"],   val_df[LABEL_NAME],   tokenizer, MAX_LENGTH)

    # Only the training loader is shuffled; validation order must stay aligned with val_df rows.
    dl_trn = DataLoader(ds_trn, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate_fn, num_workers=0)
    dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=0)

    # Pretrained model with fine-tuning (always fine-tune; prefer safetensors)
    model = safe_load_sequence_classifier(model_name, num_labels=2)
    model.to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    # scheduler
    # The schedule is sized for the full EPOCHS budget; if early stopping ends
    # training sooner, the linear decay does not run to completion.
    total_steps = EPOCHS * len(dl_trn)
    if get_linear_schedule_with_warmup is not None:
        warmup_steps = int(WARMUP_RATIO * total_steps)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    else:
        # Import failed at the top of the file: train with a constant learning rate.
        scheduler = None

    # Class-weighted loss (computed from training fold only)
    loss_fct, cls_info = build_class_weight_loss(train_df[LABEL_NAME].values.astype(int), DEVICE)
    if cls_info:
        w0, w1, pos, neg = cls_info
        print(f"[Fold {fold_id}] Class weights -> w0={w0:.4f}, w1={w1:.4f} (neg={neg}, pos={pos})")
    else:
        print(f"[Fold {fold_id}] Class weights skipped (single-class in train).")

    # Automatic Mixed Precision (if available)
    scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

    # Early stopping based on validation AUPRC
    best_metric = -np.inf
    best_state = None   # CPU copy of the best state_dict seen so far
    best_epoch = -1
    bad_rounds = 0      # consecutive epochs without improvement

    for epoch in range(1, EPOCHS+1):
        t0 = time.time()
        train_loss = train_one_epoch(model, dl_trn, optimizer, scheduler, loss_fct, scaler)
        y_val, p_val = eval_get_probs(model, dl_val)
        m_val = compute_metrics_from_probs(y_val, p_val)
        elapsed = time.time() - t0
        print(f"[Fold {fold_id}] Epoch {epoch}/{EPOCHS} | train_loss={train_loss:.4f} | VAL AUPRC={m_val['auprc']:.4f} AUROC={m_val['auroc']:.4f} | {elapsed:.1f}s")

        # Use AUPRC for early stopping; fall back to AUROC if AUPRC is NaN
        score = m_val["auprc"]
        if np.isnan(score):
            score = m_val["auroc"]

        # If both metrics are NaN the comparison is False, so the epoch counts as a bad round.
        if score > best_metric:
            best_metric = score
            # Clone to CPU so later optimizer steps cannot modify the snapshot.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_epoch = epoch
            bad_rounds = 0
        else:
            bad_rounds += 1

        if bad_rounds >= PATIENCE:
            print(f"[Fold {fold_id}] Early stopping at epoch {epoch} (best epoch {best_epoch}, best metric {best_metric:.4f})")
            break

    # Restore best-performing model parameters
    # (best_state stays None if no epoch ever produced a comparable score.)
    if best_state is not None:
        model.load_state_dict(best_state)

    # Evaluate best model on the validation fold
    y_fold, p_fold = eval_get_probs(model, dl_val)
    m_fold = compute_metrics_from_probs(y_fold, p_fold)
    print(f"[Fold {fold_id}] Best on VAL | AUPRC={m_fold['auprc']:.4f} AUROC={m_fold['auroc']:.4f} F1={m_fold['f1']:.4f} ACC={m_fold['accuracy']:.4f} | best_epoch={best_epoch}")

    return model, tokenizer, p_fold, m_fold

# ------------------- Run 5-fold -------------------
# Re-seed so the fold loop starts from the same RNG state regardless of what ran before it.
set_seed(SEED)
y_all = df[LABEL_NAME].values
# Stratified splits keep the class ratio of LABEL_NAME similar across folds.
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)

# Tokenizers are not affected by the torch.load CVE and can be loaded safely
# NOTE(review): the tokenizer is always loaded for MODEL_NAME; if the model loader
# falls back to bert-base-uncased, tokenizer and model vocabularies will not match.
tokenizer = safe_load_tokenizer(MODEL_NAME)

fold_metrics = []                            # one metrics dict per fold, in fold order
oof_probs = np.zeros(len(df), dtype=float)   # out-of-fold probability for every row of df

# StratifiedKFold only needs the labels; the zeros array is a placeholder for X.
for fold_id, (trn_idx, val_idx) in enumerate(skf.split(np.zeros(len(y_all)), y_all), start=1):
    print(f"\n======== Fold {fold_id}/{N_SPLITS} ========")
    # 4 folds for training, 1 fold for validation
    df_train = df.iloc[trn_idx].reset_index(drop=True).copy()
    df_valid = df.iloc[val_idx].reset_index(drop=True).copy()

    model, tok, p_fold, m_fold = train_fold(df_train, df_valid, fold_id, tokenizer, MODEL_NAME)
    fold_metrics.append(m_fold)

    # Save the best model and tokenizer for this fold
    out_dir = os.path.join(OUTPUT_DIR_BASE, f"fold_{fold_id}")
    os.makedirs(out_dir, exist_ok=True)
    model.save_pretrained(out_dir)
    tok.save_pretrained(out_dir)
    with open(os.path.join(out_dir, "fold_metrics.json"), "w") as f:
        json.dump(m_fold, f, indent=2)

    # Store out-of-fold probabilities (val_idx corresponds to original df indices)
    # p_fold is in df_valid row order because the validation loader is not shuffled.
    oof_probs[val_idx] = p_fold

# Aggregate metrics across folds
def agg(key):
    """Return ``(mean, std)`` of metric *key* over ``fold_metrics``, ignoring NaN entries."""
    per_fold = np.array([fold_result[key] for fold_result in fold_metrics], dtype=float)
    mean_value = float(np.nanmean(per_fold))
    std_value = float(np.nanstd(per_fold))
    return mean_value, std_value

# Header reports the configured fold count instead of a hard-coded "5",
# so the log stays correct when N_SPLITS is changed.
print(f"\n====== {N_SPLITS}-fold Summary (VAL per fold) ======")
print("AUPRC : mean={:.4f}, std={:.4f}".format(*agg("auprc")))
print("AUROC : mean={:.4f}, std={:.4f}".format(*agg("auroc")))
print("F1    : mean={:.4f}, std={:.4f}".format(*agg("f1")))
print("ACC   : mean={:.4f}, std={:.4f}".format(*agg("accuracy")))

# OOF evaluation (predictions from each fold's validation set)
# Unlike the per-fold means above, this pools every row's held-out prediction
# and scores them as a single set.
y_true = df[LABEL_NAME].values.astype(int)
oof_metrics = compute_metrics_from_probs(y_true, oof_probs)
print("\n====== OOF Metrics ======")
print(json.dumps(oof_metrics, indent=2))

# Save OOF
# Written twice: as a raw array, and as the full table with an added `oof_prob` column.
np.save(os.path.join(OUTPUT_DIR_BASE, f"oof_probs_{LABEL_NAME}.npy"), oof_probs)
df.assign(oof_prob=oof_probs).to_csv(os.path.join(OUTPUT_DIR_BASE, f"oof_{LABEL_NAME}.csv"), index=False)
print(f"Saved models & OOF to: {OUTPUT_DIR_BASE}")

# ------------------- Ensemble Inference -------------------
@torch.no_grad()
def predict_proba_single_model(texts, model_path, max_length=MAX_LENGTH, device=DEVICE, batch_size=32):
    """Return positive-class probabilities for *texts* from one saved fold model.

    Args:
        texts: iterable of input strings.
        model_path: directory holding a saved model and tokenizer.
        max_length: truncation length in tokens.
        device: device to run inference on.
        batch_size: number of texts per forward pass. Inference is chunked so
            memory use does not grow with the number of texts.

    Returns:
        1-D NumPy array with one probability per text (empty for empty input).

    Raises:
        ValueError: if *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    texts = list(texts)
    if not texts:
        # Nothing to score; skip loading the model.
        return np.empty(0, dtype=np.float32)

    tok = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    mdl = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=2, use_safetensors=True).to(device).eval()

    prob_chunks = []
    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        # Pad only to the longest text of this chunk.
        enc = tok(chunk, truncation=True, padding=True, max_length=max_length, return_tensors="pt")
        enc = {name: tensor.to(device) for name, tensor in enc.items()}
        logits = mdl(**enc).logits
        prob_chunks.append(torch.softmax(logits, dim=1)[:, 1].detach().cpu().numpy())
    return np.concatenate(prob_chunks, axis=0)

def predict_proba_ensemble(texts, base_dir=OUTPUT_DIR_BASE, n_splits=N_SPLITS, max_length=MAX_LENGTH, device=DEVICE):
    """Average the positive-class probabilities of the fold models under *base_dir*.

    Models are read from ``<base_dir>/fold_1`` .. ``<base_dir>/fold_<n_splits>``
    in order; a missing directory raises FileNotFoundError when it is reached.
    """
    per_fold_probs = []
    for fold_no in range(1, n_splits + 1):
        model_dir = os.path.join(base_dir, f"fold_{fold_no}")
        if os.path.isdir(model_dir):
            fold_probs = predict_proba_single_model(
                texts, model_dir, max_length=max_length, device=device
            )
            per_fold_probs.append(fold_probs)
            continue
        raise FileNotFoundError(f"Missing model dir: {model_dir}")
    stacked = np.vstack(per_fold_probs)   # shape: (n_splits, n_texts)
    return stacked.mean(axis=0)

# #  Example usage:
# sample_texts = ["Age: 40s. Sex: male. ...", "Age: 60s. Sex: female. ..."]
# ens_probs = predict_proba_ensemble(sample_texts)
# print("Ensemble probs:", ens_probs)
