# In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Decision Tree (DT) Pipeline for Tumor T-Cell Antigen Classification — Hardened
-----------------------------------------------------------------------------
- Loads CSV; expects a binary label column (default: 'label').
- Robust handling for string labels (explicit --label_map or 2-class auto-map).
- Uses ONLY numeric features (drops non-numeric with a warning).
- 80/20 stratified train/holdout split.
- Stratified K-Fold CV (auto-reduces folds if the minority class is tiny).
- **No scaling** (trees are scale-invariant).
- Optional SMOTE on training folds only (and final train) with adaptive k_neighbors.
- DecisionTreeClassifier with predict_proba for ROC/AUC.
- Metrics: ACC, SN, SP, AUC, MCC.
- Saves: cv_fold_metrics.csv, cv_summary.csv, holdout_metrics.json,
         confusion_matrix.png, roc_curve.png, holdout_predictions.csv,
         and model (joblib) with feature names & metadata.

Run examples:
  python dt_pipeline.py --csv_path data.csv --label_col label --out_dir ./dt_outputs
  python dt_pipeline.py --csv_path data.csv --label_map '{"neg":0,"pos":1}' \
      --class_weight balanced --max_depth 8 --min_samples_leaf 5 --ccp_alpha 0.0
"""

import os
import json
import argparse
import warnings
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    matthews_corrcoef,
    confusion_matrix,
    roc_curve,
)
from sklearn.tree import DecisionTreeClassifier

from imblearn.over_sampling import SMOTE
from joblib import dump

import matplotlib.pyplot as plt

# Install a blanket "ignore" filter: every warning raised after this point in
# the process is suppressed, including those from third-party libraries.
# NOTE(review): presumably meant to keep console output down to the pipeline's
# own [Info]/[Warn] prints — confirm that hiding library warnings is intended.
warnings.filterwarnings("ignore")

# ===============================
# Helpers
# ===============================

def ensure_out_dir(path: str):
    """Create directory *path* (and any missing parents) if it is not there yet.

    An already existing directory is left untouched and is not an error.
    """
    os.makedirs(name=path, exist_ok=True)


def to_bool(x: str) -> bool:
    """Interpret *x* as a boolean flag.

    The value is converted with ``str`` and lower-cased; it counts as true
    only when the result is exactly one of "1", "true", "yes", "y" or "t".
    Surrounding whitespace is not stripped.
    """
    truthy_tokens = ("1", "true", "yes", "y", "t")
    normalized = str(x).lower()
    return normalized in truthy_tokens


def safe_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Return a 2x2 confusion matrix for the label order ``[0, 1]``.

    The matrix comes from ``confusion_matrix`` with ``labels=[0, 1]``. Should
    the result not be 2x2, it is copied into the top-left corner of a
    zero-filled 2x2 integer matrix.
    """
    matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
    if matrix.shape == (2, 2):
        return matrix

    # Defensive padding: embed whatever was returned into a 2x2 grid of zeros.
    padded = np.zeros((2, 2), dtype=int)
    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    padded[:n_rows, :n_cols] = matrix
    return padded


def compute_metrics(y_true: np.ndarray, y_prob: np.ndarray) -> Dict[str, float]:
    """Compute ACC, SN, SP, AUC and MCC from true labels and positive-class scores.

    Scores are thresholded at 0.5 (``>=``) to obtain hard predictions.
    Sensitivity and specificity carry a 1e-9 term in the denominator so that
    an absent class yields 0 rather than a division by zero. AUC is NaN when
    ``roc_auc_score`` raises ``ValueError``.

    Returns:
        Dict with keys "ACC", "SN", "SP", "AUC", "MCC" (in that order).
    """
    hard_calls = (y_prob >= 0.5).astype(int)
    tn, fp, fn, tp = safe_confusion_matrix(y_true, hard_calls).ravel()
    n_total = tp + tn + fp + fn

    try:
        area = roc_auc_score(y_true, y_prob)
    except ValueError:
        area = float("nan")

    if n_total > 0:
        correlation = matthews_corrcoef(y_true, hard_calls)
    else:
        correlation = 0.0

    return {
        "ACC": (tp + tn) / max(n_total, 1),
        "SN": tp / (tp + fn + 1e-9),
        "SP": tn / (tn + fp + 1e-9),
        "AUC": area,
        "MCC": correlation,
    }


def plot_confusion_matrix(cm: np.ndarray, out_path: str, class_names: List[str]):
    """Render *cm* as an annotated heat map and save it to *out_path*.

    Args:
        cm: Confusion matrix; each cell value is written into the plot as an int.
        out_path: Destination image file (saved at 300 dpi).
        class_names: Names used for the "Predicted ..." / "Actual ..." tick
            labels; ticks are placed at positions 0 and 1.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    ax.imshow(cm, interpolation="nearest", cmap="Blues")
    ax.set_title("Confusion Matrix (Decision Tree)")

    tick_positions = [0, 1]
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    column_labels = [f"Predicted {c}" for c in class_names]
    row_labels = [f"Actual {c}" for c in class_names]
    ax.set_xticklabels(column_labels, rotation=15, ha="right")
    ax.set_yticklabels(row_labels)

    # Write each count in the centre of its cell (x = column, y = row).
    for (row, col), val in np.ndenumerate(cm):
        ax.text(col, row, f"{int(val)}", ha="center", va="center", fontsize=12)

    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def plot_roc(y_true: np.ndarray, y_score: np.ndarray, out_path: str, label="Decision Tree"):
    """Plot the ROC curve for *y_score* against *y_true* and save it to *out_path*.

    If ``roc_curve`` or ``roc_auc_score`` raises ``ValueError``, the diagonal
    from (0, 0) to (1, 1) is drawn instead and the AUC in the legend is NaN.
    """
    try:
        fpr, tpr, _ = roc_curve(y_true, y_score)
        auc = roc_auc_score(y_true, y_score)
    except ValueError:
        fpr = [0, 1]
        tpr = [0, 1]
        auc = float("nan")

    legend_entry = f"{label} (AUC={auc:.2f})"

    fig, ax = plt.subplots(figsize=(6, 5))
    ax.plot(fpr, tpr, linewidth=2, label=legend_entry)
    ax.plot([0, 1], [0, 1], "--", linewidth=1)  # chance-level reference line
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curve (Holdout)")
    ax.legend(loc="lower right")
    fig.tight_layout()
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def pick_numeric_features(df: pd.DataFrame, label_col: str) -> Tuple[pd.DataFrame, List[str]]:
    """Return the numeric feature columns of *df* together with their names.

    The label column is removed first. Remaining columns whose dtype is not a
    NumPy number are discarded; when any are discarded, a warning listing
    them in sorted order is printed.
    """
    candidates = df.drop(columns=[label_col])
    kept = candidates.select_dtypes(include=[np.number])
    dropped = sorted(set(candidates.columns).difference(kept.columns))
    if len(dropped) > 0:
        print(f"[Warn] Dropping non-numeric columns (not used as features): {dropped}")
    return kept, list(kept.columns)


def parse_class_names(arg: Optional[str]) -> List[str]:
    """Parse a comma-separated pair of class names.

    Returns ``["Non-Tumor", "Tumor"]`` when *arg* is empty/None or does not
    split into exactly two comma-separated parts (the latter case also prints
    a warning). Otherwise returns the two parts with surrounding whitespace
    stripped.
    """
    default = ["Non-Tumor", "Tumor"]
    if arg:
        names = [piece.strip() for piece in arg.split(",")]
        if len(names) == 2:
            return names
        print(f"[Warn] --class_names expects exactly two names; using default {default}.")
    return default


def minority_count(y: np.ndarray) -> int:
    """Return the number of samples in the least frequent class of *y*.

    Returns 0 when *y* has no elements.
    """
    _, class_sizes = np.unique(y, return_counts=True)
    if len(class_sizes) == 0:
        return 0
    return int(class_sizes.min())


def best_smote(y: np.ndarray, seed: int) -> SMOTE:
    """Build a SMOTE sampler whose ``k_neighbors`` fits the minority class of *y*.

    ``k_neighbors`` is the minority-class size minus one, clamped to the
    range 1..5.
    """
    n_minority = minority_count(y)
    neighbours = min(5, n_minority - 1)
    if neighbours < 1:
        neighbours = 1
    return SMOTE(random_state=seed, k_neighbors=neighbours)


def auto_folds(y_tr: np.ndarray, desired: int = 5) -> int:
    """Choose the number of CV folds for the labels *y_tr*.

    The result is *desired* capped at the minority-class size, but never
    lower than 2.
    """
    capped = min(desired, minority_count(y_tr))
    return capped if capped > 2 else 2

# ===============================
# Core Pipeline
# ===============================

def _coerce_max_features(value):
    """Convert a CLI-style ``max_features`` value into a type scikit-learn accepts.

    Non-string values are returned unchanged. For strings: "" / "none" (any
    case) become ``None``, integer and float literals become ``int`` /
    ``float``, and anything else (e.g. "sqrt", "log2") is passed through
    stripped so the estimator performs its own validation.
    """
    if not isinstance(value, str):
        return value
    text = value.strip()
    if text.lower() in {"", "none"}:
        return None
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            continue
    return text


def _smote_if_feasible(X, y, seed, context):
    """Oversample (X, y) with SMOTE when the minority class has >= 2 samples.

    Returns ``(X, y, applied)``. With fewer than two minority samples SMOTE
    has no neighbour to interpolate with, so a warning is printed and the
    data is returned untouched with ``applied=False``.
    """
    if minority_count(y) < 2:
        print(f"[Warn] {context}: fewer than 2 minority samples; skipping SMOTE.")
        return X, y, False
    X_res, y_res = best_smote(y, seed).fit_resample(X, y)
    return X_res, y_res, True


def run(
    csv_path: str,
    label_col: str = "label",
    label_map: str = None,
    out_dir: str = "./dt_outputs",
    use_smote: bool = True,
    class_weight: str = "none",  # 'balanced' or 'none'
    criterion: str = "gini",      # 'gini' or 'entropy' or 'log_loss'
    max_depth: Optional[int] = None,
    min_samples_split: int = 2,
    min_samples_leaf: int = 1,
    max_features: Optional[str] = None,  # 'sqrt', 'log2', None, or int/float
    ccp_alpha: float = 0.0,              # cost-complexity pruning
    class_names_arg: Optional[str] = None,
    seed: int = 42,
):
    """Train and evaluate a decision tree on a CSV file and save all artifacts.

    Steps: load the CSV, derive 0/1 labels, keep numeric features only, make
    an 80/20 stratified holdout split, run stratified K-fold CV on the
    training part (optionally with SMOTE on each training fold), refit on the
    full training part and evaluate on the holdout.

    Files written to *out_dir*: ``cv_fold_metrics.csv``, ``cv_summary.csv``,
    ``holdout_metrics.json``, ``confusion_matrix.png``, ``roc_curve.png``,
    ``dt_model.joblib`` and ``holdout_predictions.csv``.

    Args:
        csv_path: Path of the input CSV.
        label_col: Name of the label column.
        label_map: Optional JSON object string mapping raw label values to 0/1.
        out_dir: Output directory (created if missing).
        use_smote: Whether to oversample training data with SMOTE.
        class_weight: "none" disables class weighting; any other value
            selects "balanced".
        criterion, max_depth, min_samples_split, min_samples_leaf, ccp_alpha:
            Passed to ``DecisionTreeClassifier`` unchanged.
        max_features: Passed to ``DecisionTreeClassifier``; string forms such
            as "None", "10" or "0.5" (as produced by the CLI) are converted
            to ``None`` / ``int`` / ``float`` first.
        class_names_arg: Comma-separated class names for the plots.
        seed: Random seed for NumPy, the splits, SMOTE and the tree.

    Raises:
        ValueError: If the label column is missing, labels cannot be mapped,
            or the resulting labels are not exactly the two classes 0 and 1.
    """
    np.random.seed(seed)
    ensure_out_dir(out_dir)

    # Load data
    df = pd.read_csv(csv_path)
    if label_col not in df.columns:
        raise ValueError(f"Label column '{label_col}' not found in CSV.")

    # Labels
    y_raw = df[label_col]
    if label_map:
        mapping = json.loads(label_map)
        y = y_raw.map(mapping)
    else:
        if pd.api.types.is_numeric_dtype(y_raw):
            y = y_raw
        else:
            uniq = y_raw.dropna().unique()
            if len(uniq) == 2:
                keys = sorted(list(uniq), key=lambda v: str(v))
                mapping = {keys[0]: 0, keys[1]: 1}
                print(f"[Info] Auto label_map inferred: {mapping}")
                y = y_raw.map(mapping)
            else:
                raise ValueError(
                    "Label column is non-numeric and has !=2 unique values. "
                    "Provide --label_map, e.g. '{\"neg\":0,\"pos\":1}'."
                )
    if y.isna().any():
        raise ValueError("Label mapping produced NaNs. Check --label_map and label values.")
    y = y.astype(int).values

    # Everything downstream (confusion matrix labels, predict_proba column 1,
    # the positives/negatives counts) assumes the classes are exactly 0 and 1.
    present = sorted(set(np.unique(y).tolist()))
    if present != [0, 1]:
        raise ValueError(
            f"Labels must contain exactly the two classes 0 and 1; found {present}. "
            "Provide --label_map to map the raw values to 0/1."
        )

    # Features (numeric only)
    X_df, feature_names = pick_numeric_features(df, label_col)
    X = X_df.values.astype(np.float32)

    print(f"[Info] Data: X={X.shape}, positives={int(y.sum())}, negatives={int((y==0).sum())}")

    # Holdout split
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=seed
    )
    print(f"[Info] Holdout split => train={X_tr.shape[0]}, test={X_te.shape[0]}")

    # CV
    n_splits = auto_folds(y_tr, desired=5)
    if n_splits < 5:
        print(f"[Warn] Reduced CV folds to {n_splits} due to limited minority samples.")
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

    cv_metrics = []

    cw = None if str(class_weight).lower() == "none" else "balanced"
    # The CLI delivers max_features as text; the estimator needs None/int/float.
    max_features = _coerce_max_features(max_features)

    # One parameter set shared by the CV models and the final model.
    tree_params = {
        "criterion": criterion,
        "max_depth": max_depth,
        "min_samples_split": min_samples_split,
        "min_samples_leaf": min_samples_leaf,
        "max_features": max_features,
        "class_weight": cw,
        "ccp_alpha": ccp_alpha,
        "random_state": seed,
    }

    for fold, (idx_tr, idx_va) in enumerate(skf.split(X_tr, y_tr), start=1):
        X_tr_fold, X_va_fold = X_tr[idx_tr], X_tr[idx_va]
        y_tr_fold, y_va_fold = y_tr[idx_tr], y_tr[idx_va]

        # SMOTE on training fold only (optional)
        if use_smote:
            X_tr_fold, y_tr_fold, applied = _smote_if_feasible(
                X_tr_fold, y_tr_fold, seed, f"Fold {fold}"
            )
            if applied:
                print(f"[Fold {fold}] After SMOTE: X={X_tr_fold.shape}, pos={int(y_tr_fold.sum())}")

        clf = DecisionTreeClassifier(**tree_params)
        clf.fit(X_tr_fold, y_tr_fold)

        y_va_prob = clf.predict_proba(X_va_fold)[:, 1]
        fold_m = compute_metrics(y_va_fold, y_va_prob)
        cv_metrics.append(fold_m)
        print(f"[Fold {fold}] {fold_m}")

    # CV summary
    cv_df = pd.DataFrame(cv_metrics)
    # Documented output name; the previous "data.csv" could overwrite the input file.
    cv_df.to_csv(os.path.join(out_dir, "cv_fold_metrics.csv"), index=False)
    cv_summary = cv_df.agg(["mean", "std"]).T
    cv_summary.to_csv(os.path.join(out_dir, "cv_summary.csv"))
    print("\n[CV Summary]\n", cv_summary)

    # Final train & holdout eval
    if use_smote:
        X_tr_bal, y_tr_bal, _ = _smote_if_feasible(X_tr, y_tr, seed, "Final train")
    else:
        X_tr_bal, y_tr_bal = X_tr, y_tr

    clf_f = DecisionTreeClassifier(**tree_params)
    clf_f.fit(X_tr_bal, y_tr_bal)

    y_te_prob = clf_f.predict_proba(X_te)[:, 1]
    holdout = compute_metrics(y_te, y_te_prob)
    with open(os.path.join(out_dir, "holdout_metrics.json"), "w") as f:
        json.dump(holdout, f, indent=2)
    print("\n[Holdout Metrics]\n", holdout)

    # Plots
    y_te_pred = (y_te_prob >= 0.5).astype(int)
    cm = safe_confusion_matrix(y_te, y_te_pred)
    class_names = parse_class_names(class_names_arg)

    plot_confusion_matrix(cm, os.path.join(out_dir, "confusion_matrix.png"), class_names)
    plot_roc(y_te, y_te_prob, os.path.join(out_dir, "roc_curve.png"), label="Decision Tree")

    # Save artifacts
    artifact = {
        "dt": clf_f,
        "feature_names": feature_names,
        "label_col": label_col,
        "class_names": class_names,
        "params": {
            "criterion": criterion,
            "max_depth": max_depth,
            "min_samples_split": min_samples_split,
            "min_samples_leaf": min_samples_leaf,
            "max_features": max_features,
            "class_weight": cw,
            "ccp_alpha": ccp_alpha,
            "use_smote": use_smote,
        },
    }
    dump(artifact, os.path.join(out_dir, "dt_model.joblib"))

    # Save holdout predictions
    pd.DataFrame({
        "y_true": y_te,
        "y_prob": y_te_prob,
        "y_pred": y_te_pred,
    }).to_csv(os.path.join(out_dir, "holdout_predictions.csv"), index=False)

    print(f"[Done] Outputs saved to: {out_dir}")


# ===============================
# CLI
# ===============================

def build_argparser():
    """Create the command-line parser for the decision-tree pipeline.

    Only ``--csv_path`` is required; every other option has a default.
    ``--smote`` is kept as a string and is expected to be converted to a
    boolean by the caller.
    """
    parser = argparse.ArgumentParser(
        description="Decision Tree pipeline with optional SMOTE and ROC plotting (hardened)."
    )

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        ("--csv_path", {"type": str, "required": True, "help": "Path to local CSV file."}),
        ("--label_col", {"type": str, "default": "label", "help": "Name of label column (0/1)."}),
        ("--label_map", {"type": str, "default": None,
                         "help": 'Optional JSON mapping, e.g. {"neg":0,"pos":1}'}),
        ("--out_dir", {"type": str, "default": "./dt_outputs", "help": "Output directory."}),
        ("--smote", {"type": str, "default": "true",
                     "help": "Apply SMOTE on train folds (true/false)."}),
        ("--class_weight", {"type": str, "default": "none", "choices": ["none", "balanced"],
                            "help": 'Class weighting for imbalance.'}),
        ("--criterion", {"type": str, "default": "gini",
                         "choices": ["gini", "entropy", "log_loss"],
                         "help": "Split criterion."}),
        ("--max_depth", {"type": int, "default": None,
                         "help": "Max depth of the tree (None for full growth)."}),
        ("--min_samples_split", {"type": int, "default": 2,
                                 "help": "Min samples to split an internal node."}),
        ("--min_samples_leaf", {"type": int, "default": 1,
                                "help": "Min samples at a leaf node."}),
        ("--max_features", {"type": str, "default": None,
                            "help": "Number of features to consider when looking for the best split (e.g., 'sqrt', 'log2', None)."}),
        ("--ccp_alpha", {"type": float, "default": 0.0,
                         "help": "Complexity parameter used for Minimal Cost-Complexity Pruning."}),
        ("--class_names", {"type": str, "default": None,
                           "help": 'Comma-separated names for classes in plots, e.g. "Non-Tumor,Tumor"'}),
        ("--seed", {"type": int, "default": 42, "help": "Random seed."}),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser


if __name__ == "__main__":
    # Script entry point: parse the CLI flags and forward them to run().
    cli = build_argparser().parse_args()
    run_kwargs = {
        "csv_path": cli.csv_path,
        "label_col": cli.label_col,
        "label_map": cli.label_map,
        "out_dir": cli.out_dir,
        # --smote arrives as text ("true"/"false"); convert it to a bool here.
        "use_smote": to_bool(cli.smote),
        "class_weight": cli.class_weight,
        "criterion": cli.criterion,
        "max_depth": cli.max_depth,
        "min_samples_split": cli.min_samples_split,
        "min_samples_leaf": cli.min_samples_leaf,
        "max_features": cli.max_features,
        "ccp_alpha": cli.ccp_alpha,
        "class_names_arg": cli.class_names,
        "seed": cli.seed,
    }
    run(**run_kwargs)
