# In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

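"""
Deep Neural Network (DNN) for Tumor T-Cell Antigen Classification
-----------------------------------------------------------------
This script loads a CSV of numeric feature columns plus a binary label
column and holds out a stratified test split. It then runs stratified
5-fold cross-validation on the remaining data. In each training set it
standardizes the features, optionally balances the classes with SMOTE,
and trains a DNN. A final model trained on the training split is
evaluated on the holdout set. Metrics, a confusion matrix, a ROC curve,
and the trained model are written to the output directory.
"""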

# ===============================
# Imports
# ===============================
import os
import json
import random
import warnings
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (accuracy_score, roc_auc_score, matthews_corrcoef,
                             confusion_matrix, roc_curve)
from imblearn.over_sampling import SMOTE

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, regularizers, callbacks

import matplotlib.pyplot as plt

# Silence every warning category process-wide (library deprecation notices included).
warnings.simplefilter("ignore")


# ===============================
# Utility functions
# ===============================
def seed_everything(seed: int = 42):
    """Seed Python's, NumPy's and TensorFlow's global RNGs with ``seed``.

    The generators are seeded in that order: ``random``, ``numpy.random``,
    then ``tf.random``.
    """
    for set_seed in (random.seed, np.random.seed, tf.random.set_seed):
        set_seed(seed)


def ensure_out_dir(out_dir: str):
    """Create ``out_dir`` together with any missing parent directories.

    Calling it again on an existing directory is harmless.
    """
    os.makedirs(name=out_dir, mode=0o777, exist_ok=True)


def compute_metrics(y_true: np.ndarray, y_prob: np.ndarray,
                    threshold: float = 0.5) -> Dict[str, float]:
    """Compute binary-classification metrics from positive-class probabilities.

    Args:
        y_true: Ground-truth labels (0 = negative, 1 = positive).
        y_prob: Predicted probability of class 1, one value per sample.
        threshold: Samples with ``y_prob >= threshold`` are predicted positive.

    Returns:
        Dict with keys "ACC" (accuracy), "SN" (sensitivity, recall of class 1),
        "SP" (specificity, recall of class 0), "AUC" (ROC AUC) and "MCC".
        Values are plain Python floats, so the dict is JSON-serializable.
        SN / SP are 0.0 when the corresponding class is absent from
        ``y_true``. AUC is NaN when ``y_true`` contains a single class,
        because ROC AUC is undefined there.
    """
    y_true = np.asarray(y_true).astype(int).ravel()
    y_prob = np.asarray(y_prob, dtype=float).ravel()
    y_pred = (y_prob >= threshold).astype(int)

    # labels=[0, 1] keeps the matrix 2x2 even if one class is never predicted.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    # Explicit zero-division handling instead of an epsilon, so a perfect
    # score is reported as exactly 1.0.
    sn = tp / (tp + fn) if (tp + fn) else 0.0
    sp = tn / (tn + fp) if (tn + fp) else 0.0

    # roc_auc_score raises ValueError when only one class is present.
    if np.unique(y_true).size == 2:
        auc = roc_auc_score(y_true, y_prob)
    else:
        auc = float("nan")

    return {
        "ACC": float(accuracy_score(y_true, y_pred)),
        "SN": float(sn),
        "SP": float(sp),
        "AUC": float(auc),
        "MCC": float(matthews_corrcoef(y_true, y_pred)),
    }


def plot_confusion_matrix(cm: np.ndarray, out_path: str, class_names=("Non-Tumor", "Tumor")):
    """Render a labelled confusion-matrix heatmap and save it to ``out_path``.

    Args:
        cm: Square count matrix with rows = actual class and columns =
            predicted class (the layout sklearn's ``confusion_matrix`` returns).
        out_path: Destination image path, saved at 300 dpi.
        class_names: Display names in the same order as the rows/columns of
            ``cm``. A tuple default avoids a shared mutable default argument.

    Raises:
        ValueError: If ``cm`` is not ``len(class_names) x len(class_names)``.
    """
    cm = np.asarray(cm)
    n_classes = len(class_names)
    if cm.shape != (n_classes, n_classes):
        raise ValueError(
            f"cm has shape {cm.shape}, expected ({n_classes}, {n_classes}) "
            f"to match class_names."
        )
    ticks = list(range(n_classes))

    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        ax.imshow(cm, interpolation="nearest", cmap="Blues")
        ax.set_title("Confusion Matrix")
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels([f"Predicted {c}" for c in class_names], rotation=15, ha="right")
        ax.set_yticklabels([f"Actual {c}" for c in class_names])
        # White text on dark cells keeps the counts legible on the Blues colormap.
        color_cutoff = cm.max() / 2.0 if cm.size else 0.0
        for (i, j), val in np.ndenumerate(cm):
            ax.text(j, i, f"{int(val)}", ha="center", va="center", fontsize=12,
                    color="white" if val > color_cutoff else "black")
        fig.tight_layout()
        # Save this specific figure rather than whatever pyplot considers "current".
        fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)


def plot_roc_curve(y_true: np.ndarray, y_score: np.ndarray, out_path: str, label: str = "DNN"):
    """Plot the ROC curve for binary scores and save it to ``out_path``.

    Args:
        y_true: Ground-truth binary labels.
        y_score: Scores or probabilities for the positive class.
        out_path: Destination image path, saved at 300 dpi.
        label: Legend name for the curve; the AUC (2 decimals) is appended.

    Raises:
        ValueError: Propagated from scikit-learn, e.g. when ``y_true`` holds a
            single class. It is raised before any figure is created.
    """
    # Compute everything that can fail before opening a figure.
    fpr, tpr, _ = roc_curve(y_true, y_score)
    auc = roc_auc_score(y_true, y_score)

    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        ax.plot(fpr, tpr, linewidth=2, label=f"{label} (AUC={auc:.2f})")
        ax.plot([0, 1], [0, 1], "--", linewidth=1)  # chance-level diagonal
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title("ROC Curve")
        ax.legend(loc="lower right")
        fig.tight_layout()
        # Save this specific figure rather than whatever pyplot considers "current".
        fig.savefig(out_path, dpi=300)
    finally:
        plt.close(fig)


# ===============================
# Model definition
# ===============================
def build_dnn(input_dim: int, l2_reg: float = 0.01, dropout: float = 0.5,
              dropout2: float = 0.3, learning_rate: float = 1e-3) -> keras.Model:
    """Build and compile a feed-forward binary classifier.

    Architecture:
        Dense(256) -> Dropout(dropout) -> Dense(128) -> Dropout(dropout2)
        -> Dense(64) -> Dense(1, sigmoid)
    All hidden Dense layers use ReLU and an L2 kernel penalty of ``l2_reg``.

    Args:
        input_dim: Number of input features.
        l2_reg: L2 regularization factor for the hidden-layer kernels.
        dropout: Dropout rate after the first (256-unit) hidden layer.
        dropout2: Dropout rate after the second (128-unit) hidden layer.
            Previously hard-coded to 0.3; the default is unchanged.
        learning_rate: Adam learning rate. The default matches the previous
            hard-coded value.

    Returns:
        A compiled ``keras.Model`` (Adam, binary cross-entropy, accuracy metric).
    """
    inp = layers.Input(shape=(input_dim,), name="input")
    x = inp
    # (units, dropout rate after the layer); None means no dropout layer.
    for units, rate in ((256, dropout), (128, dropout2), (64, None)):
        x = layers.Dense(units, activation="relu",
                         kernel_regularizer=regularizers.l2(l2_reg))(x)
        if rate is not None:
            x = layers.Dropout(rate)(x)
    out = layers.Dense(1, activation="sigmoid", name="output")(x)

    model = keras.Model(inputs=inp, outputs=out, name="DNN")
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
                  loss="binary_crossentropy",
                  metrics=["accuracy"])
    return model


# ===============================
# Main Training Pipeline
# ===============================
def _load_dataset(csv_path: str, label_col: str,
                  seq_col: str = None) -> Tuple[np.ndarray, np.ndarray]:
    """Read ``csv_path`` and return ``(X, y)`` as float32 features and int labels.

    Every column except ``label_col`` (and ``seq_col``, when given) is used as
    a feature. Those columns must therefore be numeric.
    """
    df = pd.read_csv(csv_path)
    if label_col not in df.columns:
        raise ValueError(f"Label column '{label_col}' not found in CSV.")

    non_feature_cols = [label_col]
    if seq_col is not None:
        if seq_col not in df.columns:
            raise ValueError(f"Sequence column '{seq_col}' not found in CSV.")
        # Raw sequences are not numeric features; keep them out of X.
        non_feature_cols.append(seq_col)

    y = df[label_col].astype(int).values
    X = df.drop(columns=non_feature_cols).values.astype(np.float32)
    return X, y


def _make_callbacks() -> list:
    """Return fresh early-stopping / LR-decay callbacks, which are stateful and must not be reused."""
    return [
        callbacks.EarlyStopping(monitor="val_loss", patience=10, restore_best_weights=True),
        callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=5),
    ]


def _train_and_predict(X_train: np.ndarray, y_train: np.ndarray, X_eval: np.ndarray, *,
                       epochs: int, batch_size: int, use_smote: bool, seed: int,
                       val_size: float = 0.1):
    """Fit scaler, SMOTE and DNN on ``X_train`` only; return ``(model, probs for X_eval)``.

    A stratified ``val_size`` slice of ``X_train`` is held out *before*
    scaling and SMOTE and drives early stopping / LR decay. As a result:
      * ``X_eval`` never influences training or model selection, and
      * the early-stopping set contains only real samples. SMOTE appends its
        synthetic samples at the end of its output, which is exactly the slice
        Keras' ``validation_split`` would take.
    """
    X_fit, X_es, y_fit, y_es = train_test_split(
        X_train, y_train, test_size=val_size, stratify=y_train, random_state=seed
    )

    # Fit the scaler on the fitting data only, then apply it everywhere else.
    scaler = StandardScaler()
    X_fit = scaler.fit_transform(X_fit)
    X_es = scaler.transform(X_es)
    X_eval = scaler.transform(X_eval)

    # Oversample only the data the network is fitted on.
    if use_smote:
        X_fit, y_fit = SMOTE(random_state=seed).fit_resample(X_fit, y_fit)

    model = build_dnn(input_dim=X_fit.shape[1])
    model.fit(X_fit, y_fit,
              validation_data=(X_es, y_es),
              epochs=epochs, batch_size=batch_size,
              verbose=0, callbacks=_make_callbacks())
    y_prob = model.predict(X_eval, verbose=0).ravel()
    return model, y_prob


def run_pipeline(csv_path: str,
                 label_col: str = "label",
                 seq_col: str = None,
                 out_dir: str = "./outputs",
                 epochs: int = 50,
                 batch_size: int = 32,
                 use_smote: bool = True,
                 seed: int = 42):
    """Train and evaluate the DNN on a CSV dataset, writing results to ``out_dir``.

    Steps:
      1. Load features and labels from ``csv_path``.
      2. Hold out a stratified 20% test split.
      3. Run stratified 5-fold CV on the remaining 80%.
      4. Train a final model on that 80% and evaluate it on the holdout.
    In every training run the scaler and (optionally) SMOTE are fit on
    training data only. A stratified 10% slice of that data, taken before
    SMOTE, is used for early stopping.

    Args:
        csv_path: Path to the input CSV.
        label_col: Name of the 0/1 label column.
        seq_col: Optional name of a raw-sequence column. It is excluded from
            the feature matrix; no feature extraction is performed on it here.
        out_dir: Directory for cv_summary.csv, holdout_metrics.json,
            confusion_matrix.png, roc_curve.png and dnn_model.h5.
        epochs: Maximum training epochs per model (early stopping may end sooner).
        batch_size: Mini-batch size passed to ``Model.fit``.
        use_smote: Whether to oversample the training data with SMOTE.
        seed: Seed for all RNGs and data splits.

    Returns:
        The holdout metrics dict (the same values written to holdout_metrics.json).

    Raises:
        ValueError: If ``label_col`` (or a given ``seq_col``) is missing from the CSV.
    """
    seed_everything(seed)
    ensure_out_dir(out_dir)
    train_kwargs = dict(epochs=epochs, batch_size=batch_size,
                        use_smote=use_smote, seed=seed)

    # 1. Load CSV (previously this ignored csv_path and always read "Dataset").
    X, y = _load_dataset(csv_path, label_col, seq_col)
    print(f"[Info] Dataset: X={X.shape}, y={y.shape}, positive={y.sum()}, negative={len(y)-y.sum()}")

    # 2. Train/test split
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=seed
    )

    # 3. Cross-validation on training set
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    cv_metrics = []
    for fold, (idx_tr, idx_va) in enumerate(skf.split(X_tr, y_tr), start=1):
        _, y_va_prob = _train_and_predict(X_tr[idx_tr], y_tr[idx_tr], X_tr[idx_va],
                                          **train_kwargs)
        fold_metrics = compute_metrics(y_tr[idx_va], y_va_prob)
        cv_metrics.append(fold_metrics)
        print(f"[Fold {fold}] {fold_metrics}")

    # CV Summary
    cv_summary = pd.DataFrame(cv_metrics).agg(["mean", "std"]).T
    cv_summary.to_csv(os.path.join(out_dir, "cv_summary.csv"))
    print("\n[CV Summary]")
    print(cv_summary)

    # 4. Train on full train set & evaluate on holdout
    model_final, y_te_prob = _train_and_predict(X_tr, y_tr, X_te, **train_kwargs)
    holdout_metrics = compute_metrics(y_te, y_te_prob)
    print("\n[Holdout Metrics]")
    print(holdout_metrics)

    # Save holdout metrics (cast to float so numpy scalars always serialize).
    with open(os.path.join(out_dir, "holdout_metrics.json"), "w") as f:
        json.dump({k: float(v) for k, v in holdout_metrics.items()}, f, indent=2)

    # Plots
    y_te_pred = (y_te_prob >= 0.5).astype(int)
    cm = confusion_matrix(y_te, y_te_pred, labels=[0, 1])
    plot_confusion_matrix(cm, os.path.join(out_dir, "confusion_matrix.png"))
    plot_roc_curve(y_te, y_te_prob, os.path.join(out_dir, "roc_curve.png"))

    # Save model
    model_final.save(os.path.join(out_dir, "dnn_model.h5"))
    print(f"\n[Done] Outputs saved in: {out_dir}")
    return holdout_metrics


# ===============================
# Run
# ===============================
if __name__ == "__main__":
    # Example invocation: change csv_path to point at your own dataset first.
    pipeline_config = dict(
        csv_path="your_dataset.csv",
        label_col="label",
        out_dir="./outputs",
        epochs=50,
        batch_size=32,
        use_smote=True,
    )
    run_pipeline(**pipeline_config)
