# In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
SVM Pipeline for Tumor T-Cell Antigen Classification
----------------------------------------------------
- Loads a CSV dataset from local path.
- Expects a binary label column (default: 'label').
- Works with any numeric feature columns (precomputed features welcome).
- 80/20 stratified train/holdout split.
- 5-fold Stratified CV on the training split.
- SMOTE applied **only** to training folds (and final train) to avoid leakage.
- Standardization (StandardScaler) inside CV.
- SVM (RBF kernel) with probability=True for ROC/AUC.
- Metrics: ACC, SN (recall+), SP (specificity), AUC, MCC.
- Saves: cv_fold_metrics.csv, cv_summary.csv, holdout_metrics.json,
         confusion_matrix.png, roc_curve.png, and model (joblib).

Run:
  python svm_pipeline.py --csv_path /path/to/data.csv --label_col label --out_dir ./svm_outputs

If labels are strings, map them with:
  --label_map '{"negative":0,"positive":1}'
"""

# ===============================
# Imports
# ===============================
import os
import json
import argparse
import warnings
from typing import Dict

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (accuracy_score, roc_auc_score, matthews_corrcoef,
                             confusion_matrix, roc_curve)
from sklearn.svm import SVC

from imblearn.over_sampling import SMOTE
from joblib import dump

import matplotlib.pyplot as plt

# Ignore every warning for the whole process: no category, message or module
# filter is given, so warnings raised by any library during fitting/scoring
# are hidden too.
# NOTE(review): consider narrowing this to the specific warning categories
# that are actually noisy, so genuine problems stay visible.
warnings.filterwarnings('ignore')


# ===============================
# Helpers
# ===============================
def ensure_out_dir(path: str):
    """Make sure directory *path* exists, creating missing parents; no-op if present."""
    os.makedirs(name=path, exist_ok=True)

def to_bool(x: str) -> bool:
    """Interpret a CLI-style flag value as a boolean.

    Case-insensitive; only '1', 'true', 'yes', 'y' and 't' count as True.
    Non-string inputs are stringified first.
    """
    truthy_spellings = ('1', 'true', 'yes', 'y', 't')
    normalized = str(x).lower()
    return normalized in truthy_spellings

def compute_metrics(y_true: np.ndarray, y_prob: np.ndarray,
                    threshold: float = 0.5) -> Dict[str, float]:
    """Compute ACC, SN, SP, AUC and MCC for a binary classifier.

    Class predictions are ``y_prob >= threshold`` (0.5 by default).

    Args:
        y_true: Ground-truth labels, expected to be 0/1.
        y_prob: Scores/probabilities for class 1, same length as ``y_true``.
        threshold: Decision threshold applied to ``y_prob``.

    Returns:
        Dict with keys ACC, SN (sensitivity, recall of class 1),
        SP (specificity), AUC and MCC as plain floats. SN/SP are 0.0 when the
        corresponding class is absent from ``y_true``. AUC is NaN when
        ``y_true`` contains a single class (it is undefined there).
    """
    y_prob = np.asarray(y_prob, dtype=float)
    y_pred = (y_prob >= threshold).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    def _safe_ratio(num, den) -> float:
        # Exact division; an empty class yields 0.0 instead of an epsilon-biased value.
        return float(num) / float(den) if den else 0.0

    # roc_auc_score raises ValueError when only one class is present.
    if np.unique(y_true).size < 2:
        auc = float('nan')
    else:
        auc = float(roc_auc_score(y_true, y_prob))

    return {
        "ACC": float(accuracy_score(y_true, y_pred)),
        "SN": _safe_ratio(tp, tp + fn),   # sensitivity (recall+)
        "SP": _safe_ratio(tn, tn + fp),   # specificity
        "AUC": auc,
        "MCC": float(matthews_corrcoef(y_true, y_pred)),
    }

def plot_confusion_matrix(cm: np.ndarray, out_path: str,
                          class_names=("Non-Tumor", "Tumor")):
    """Render a confusion matrix as an annotated heatmap and save it to disk.

    Args:
        cm: Square count matrix. Rows are actual classes and columns are
            predicted classes (the layout of ``sklearn.metrics.confusion_matrix``).
        out_path: Destination image path; matplotlib infers the format from
            the extension.
        class_names: Display names, one per row/column of ``cm``, in label order.

    Raises:
        ValueError: If ``cm`` is not square or its size does not match
            ``class_names``.
    """
    cm = np.asarray(cm)
    class_names = list(class_names)
    n_classes = len(class_names)
    if cm.ndim != 2 or cm.shape != (n_classes, n_classes):
        raise ValueError(
            f"cm shape {cm.shape} does not match {n_classes} class names")

    ticks = np.arange(n_classes)
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        ax.imshow(cm, interpolation='nearest', cmap='Blues')
        ax.set_title('Confusion Matrix (SVM)')
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels([f'Predicted {c}' for c in class_names], rotation=15, ha='right')
        ax.set_yticklabels([f'Actual {c}' for c in class_names])
        # White text on dark (high-count) cells keeps the numbers legible.
        contrast_cut = cm.max() / 2.0 if cm.size else 0
        for (i, j), val in np.ndenumerate(cm):
            ax.text(j, i, f'{int(val)}', ha='center', va='center', fontsize=12,
                    color='white' if val > contrast_cut else 'black')
        fig.tight_layout()
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

def plot_roc(y_true: np.ndarray, y_score: np.ndarray, out_path: str, label='SVM (RBF)',
             title='ROC Curve (Holdout)'):
    """Plot a ROC curve with its AUC in the legend and save it to disk.

    Args:
        y_true: Ground-truth binary labels.
        y_score: Scores/probabilities for the positive class.
        out_path: Destination image path.
        label: Legend label for the curve; the AUC is appended to it.
        title: Axes title (defaults to the holdout-evaluation title).

    Raises:
        ValueError: Propagated from ``roc_auc_score`` when ``y_true`` holds a
            single class.
    """
    fpr, tpr, _ = roc_curve(y_true, y_score)
    auc = roc_auc_score(y_true, y_score)
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        ax.plot(fpr, tpr, linewidth=2, label=f'{label} (AUC={auc:.2f})')
        ax.plot([0, 1], [0, 1], '--', linewidth=1)  # chance diagonal
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title(title)
        ax.legend(loc='lower right')
        fig.tight_layout()
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving raises.
        plt.close(fig)


# ===============================
# Core Pipeline
# ===============================
def _parse_gamma(gamma):
    """Normalize the SVC gamma: keep 'scale'/'auto', otherwise coerce to float (CLI passes strings)."""
    if isinstance(gamma, str) and gamma.strip().lower() in ('scale', 'auto'):
        return gamma.strip().lower()
    try:
        return float(gamma)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"Invalid gamma {gamma!r}: expected 'scale', 'auto' or a float.") from err


def _load_dataset(csv_path, label_col, label_map):
    """Read the CSV and return (X, y): float32 numeric features and 0/1 int labels."""
    df = pd.read_csv(csv_path)
    if label_col not in df.columns:
        raise ValueError(f"Label column '{label_col}' not found in CSV.")

    labels = df[label_col]
    if label_map:
        mapping = json.loads(label_map)
        mapped = labels.map(mapping)
        # JSON object keys are always strings, so numeric labels (e.g. -1/1)
        # only match after being converted to their string form.
        mapped = mapped.fillna(labels.astype(str).map(mapping))
        unmapped = labels[mapped.isna()].unique()
        if len(unmapped) > 0:
            raise ValueError(f"label_map does not cover label value(s): {list(unmapped)}")
        labels = mapped
    if labels.isna().any():
        raise ValueError(f"Label column '{label_col}' contains missing values.")

    y = labels.astype(int).values
    present = set(np.unique(y).tolist())
    if not present <= {0, 1}:
        raise ValueError(
            f"Labels must be binary 0/1 (use --label_map); found {sorted(present)}.")
    if len(present) < 2:
        raise ValueError("Both classes (0 and 1) must be present in the data.")

    features = df.drop(columns=[label_col])
    numeric = features.select_dtypes(include=['number', 'bool'])
    dropped = [c for c in features.columns if c not in numeric.columns]
    if dropped:
        print(f"[Warn] Ignoring non-numeric column(s): {dropped}")
    if numeric.shape[1] == 0:
        raise ValueError("No numeric feature columns found in CSV.")
    return numeric.values.astype(np.float32), y


def _smote_resample(X, y, seed):
    """Oversample the minority class with SMOTE, shrinking k_neighbors for tiny classes."""
    minority = int(np.bincount(y, minlength=2).min())
    if minority < 2:
        # SMOTE needs at least one neighbour within the minority class.
        print(f"[Warn] Minority class has {minority} sample(s); skipping SMOTE.")
        return X, y
    # Default k_neighbors=5 requires > 5 minority samples; clamp for small folds.
    sm = SMOTE(random_state=seed, k_neighbors=min(5, minority - 1))
    return sm.fit_resample(X, y)


def _fit_scaled_svm(X_train, y_train, *, use_smote, C, gamma, seed, tag):
    """Fit StandardScaler (+ optional SMOTE) and an RBF SVC on one training set; return (scaler, clf)."""
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X_train)
    if use_smote:
        # Resample only the training data, after scaling, so validation/test stay untouched.
        X_scaled, y_train = _smote_resample(X_scaled, y_train, seed)
        print(f"[{tag}] After SMOTE: {X_scaled.shape}, pos={y_train.sum()}")
    clf = SVC(C=C, gamma=gamma, kernel='rbf', probability=True, random_state=seed)
    clf.fit(X_scaled, y_train)
    return scaler, clf


def run(csv_path: str,
        label_col: str = 'label',
        label_map: str = None,
        out_dir: str = './svm_outputs',
        use_smote: bool = True,
        C: float = 1.0,
        gamma: str = 'scale',
        seed: int = 42,
        n_splits: int = 5,
        test_size: float = 0.2):
    """Run the full pipeline: load, holdout split, stratified CV, final fit, evaluation, export.

    Args:
        csv_path: Path of the input CSV.
        label_col: Name of the label column (0/1 after optional mapping).
        label_map: Optional JSON object mapping raw label values to 0/1.
        out_dir: Directory receiving all outputs (created if missing).
        use_smote: Apply SMOTE to training data only (CV folds and final train).
        C: SVC regularization parameter.
        gamma: 'scale', 'auto' or a float (numeric strings are accepted).
        seed: Random seed for splitting, SMOTE and SVC.
        n_splits: Number of stratified CV folds on the training split.
        test_size: Fraction of the data held out for final evaluation.

    Writes cv_fold_metrics.csv, cv_summary.csv, holdout_metrics.json,
    confusion_matrix.png, roc_curve.png and svm_model.joblib into ``out_dir``.

    Returns:
        The holdout metrics dict produced by ``compute_metrics``.

    Raises:
        ValueError: On a missing label column, labels that are unmapped,
            missing or not 0/1, no numeric features, or an invalid gamma.
    """
    ensure_out_dir(out_dir)
    gamma = _parse_gamma(gamma)

    X, y = _load_dataset(csv_path, label_col, label_map)
    print(f"[Info] Data: X={X.shape}, positives={int(y.sum())}, negatives={int((y==0).sum())}")

    # Holdout split
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=seed
    )
    print(f"[Info] Holdout split => train={X_tr.shape[0]}, test={X_te.shape[0]}")

    # Stratified CV on the training split only
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    cv_metrics = []
    for fold, (idx_tr, idx_va) in enumerate(skf.split(X_tr, y_tr), start=1):
        scaler, clf = _fit_scaled_svm(X_tr[idx_tr], y_tr[idx_tr], use_smote=use_smote,
                                      C=C, gamma=gamma, seed=seed, tag=f"Fold {fold}")
        y_va_prob = clf.predict_proba(scaler.transform(X_tr[idx_va]))[:, 1]
        fold_m = compute_metrics(y_tr[idx_va], y_va_prob)
        cv_metrics.append(fold_m)
        print(f"[Fold {fold}] {fold_m}")

    # CV outputs go to two distinct files (previously both overwrote data.csv).
    cv_df = pd.DataFrame(cv_metrics,
                         index=pd.RangeIndex(1, len(cv_metrics) + 1, name='fold'))
    cv_df.to_csv(os.path.join(out_dir, 'cv_fold_metrics.csv'))
    cv_summary = cv_df.agg(['mean', 'std']).T
    cv_summary.to_csv(os.path.join(out_dir, 'cv_summary.csv'))
    print("\n[CV Summary]\n", cv_summary)

    # Final training on the full train split & evaluation on the holdout
    scaler_f, clf_f = _fit_scaled_svm(X_tr, y_tr, use_smote=use_smote,
                                      C=C, gamma=gamma, seed=seed, tag="Final")
    y_te_prob = clf_f.predict_proba(scaler_f.transform(X_te))[:, 1]
    holdout = compute_metrics(y_te, y_te_prob)
    with open(os.path.join(out_dir, 'holdout_metrics.json'), 'w', encoding='utf-8') as f:
        # Plain floats keep the JSON independent of numpy scalar types.
        json.dump({k: float(v) for k, v in holdout.items()}, f, indent=2)
    print("\n[Holdout Metrics]\n", holdout)

    # Plots (0.5 threshold, matching compute_metrics' default)
    y_te_pred = (y_te_prob >= 0.5).astype(int)
    cm = confusion_matrix(y_te, y_te_pred, labels=[0, 1])
    plot_confusion_matrix(cm, os.path.join(out_dir, 'confusion_matrix.png'))
    plot_roc(y_te, y_te_prob, os.path.join(out_dir, 'roc_curve.png'))

    # Save model & scaler together so inference applies the same scaling
    dump({'scaler': scaler_f, 'svm': clf_f}, os.path.join(out_dir, 'svm_model.joblib'))
    print(f"[Done] Outputs saved to: {out_dir}")
    return holdout


# ===============================
# CLI
# ===============================
def build_argparser():
    """Build the command-line parser whose options mirror ``run()``'s parameters.

    ``--smote`` is kept as a raw string (converted later with ``to_bool``), and
    ``--gamma`` is kept as a string so 'scale'/'auto' pass through unchanged.
    """
    parser = argparse.ArgumentParser(description="SVM pipeline with SMOTE and ROC plotting.")
    option_specs = [
        ("--csv_path", dict(type=str, required=True, help="Path to local CSV file.")),
        ("--label_col", dict(type=str, default="label", help="Name of label column (0/1).")),
        ("--label_map", dict(type=str, default=None,
                             help='Optional JSON mapping, e.g. {"neg":0,"pos":1}')),
        ("--out_dir", dict(type=str, default="./svm_outputs", help="Output directory.")),
        ("--smote", dict(type=str, default="true",
                         help="Apply SMOTE on train folds (true/false).")),
        ("--C", dict(type=float, default=1.0, help="SVM C parameter.")),
        ("--gamma", dict(type=str, default="scale",
                         help="SVM gamma parameter (scale|auto or float).")),
        ("--seed", dict(type=int, default=42, help="Random seed.")),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser


if __name__ == "__main__":
    # Parse CLI options and forward them to the pipeline entry point.
    cli = build_argparser().parse_args()
    run(
        csv_path=cli.csv_path,
        label_col=cli.label_col,
        label_map=cli.label_map,
        out_dir=cli.out_dir,
        use_smote=to_bool(cli.smote),  # "--smote" arrives as text
        C=cli.C,
        gamma=cli.gamma,
        seed=cli.seed,
    )
