In [None]:
import pandas as pd
import numpy as np
import torch
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    roc_auc_score, precision_score,
    recall_score, f1_score, accuracy_score
)
from pytorch_tabnet.tab_model import TabNetClassifier

In [None]:
# ───────────────────────────────────────────────────────────────────────────────
# 1) Load & prepare training data
# ───────────────────────────────────────────────────────────────────────────────
train_path = "./data/tabnet_train.csv"
test_path  = "./data/tabnet_test.csv"
target     = "PD-L1"

# Read the training table and force the label column to integer class ids.
df_train = pd.read_csv(train_path, index_col=0)
df_train[target] = df_train[target].astype(int)

# Categorical columns are embedded by TabNet; every remaining non-target column
# is treated as numeric and standardized.
cat_cols     = ['A', 'B', 'C', 'D', 'E', 'F']
non_numeric  = set(cat_cols) | {target}
numeric_cols = [name for name in df_train.columns if name not in non_numeric]

# 1a) Integer-encode each categorical column. Values are cast to str first so the
#     encoder sees a single dtype; fitted encoders are kept for the hold-out set.
cat_encoders = {}
for col in cat_cols:
    as_text = df_train[col].astype(str)
    encoder = LabelEncoder().fit(as_text)
    df_train[col] = encoder.transform(as_text)
    cat_encoders[col] = encoder

# 1b) Standardize numeric columns (fit_transform == fit followed by transform).
scaler = StandardScaler()
df_train[numeric_cols] = scaler.fit_transform(df_train[numeric_cols])

# Features keep the file's column order minus the target; TabNet's cat_idxs are
# positions in exactly this order.
X = df_train.drop(columns=[target])
y = df_train[target].values

In [None]:
# ───────────────────────────────────────────────────────────────────────────────
# 2) Compute TabNet embedding info
# ───────────────────────────────────────────────────────────────────────────────
# For each categorical column: its position in X, its number of distinct encoded
# levels, and an embedding width of ceil(levels / 2) capped at 50.
cat_idxs, cat_dims, cat_emb_dim = [], [], []
for col in cat_cols:
    n_levels = df_train[col].nunique()
    cat_idxs.append(X.columns.get_loc(col))
    cat_dims.append(n_levels)
    cat_emb_dim.append(min(50, (n_levels + 1) // 2))

In [None]:
# ───────────────────────────────────────────────────────────────────────────────
# 3) 5‑Fold CV with weighted metrics
# ───────────────────────────────────────────────────────────────────────────────
# NOTE(review): the LabelEncoders and StandardScaler in step 1 were fit on all of
# df_train, so every validation fold below was already seen by the scaler. Refit
# preprocessing inside each fold if an unbiased CV estimate matters.
device = "cuda" if torch.cuda.is_available() else "cpu"
X_all  = X.values  # materialize the feature matrix once instead of per fold


def _weighted_ovr_auc(y_true, proba, classes):
    """Weighted one-vs-rest ROC AUC for a predicted-probability matrix.

    Args:
        y_true: true labels of the evaluated rows.
        proba: (n_rows, n_classes) probabilities whose columns follow `classes`.
        classes: the model's class labels, in probability-column order.

    Passing `labels` pins the column -> class mapping explicitly. For exactly two
    classes sklearn expects a 1-D positive-class score (a 2-D matrix raises), so
    the second column is used; for two classes the weighted OvR AUC equals that
    single AUC anyway.
    """
    if len(classes) == 2:
        return roc_auc_score(y_true, proba[:, 1])
    return roc_auc_score(y_true, proba, multi_class="ovr",
                         average="weighted", labels=classes)


skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
metrics_cv = {m: [] for m in ["auc","precision","recall","f1","accuracy"]}

for fold, (tr_idx, va_idx) in enumerate(skf.split(X_all, y), start=1):
    X_tr, X_va = X_all[tr_idx], X_all[va_idx]
    y_tr, y_va = y[tr_idx], y[va_idx]

    # Fresh model per fold so no weights carry over between folds.
    clf = TabNetClassifier(
        cat_idxs         = cat_idxs,
        cat_dims         = cat_dims,
        cat_emb_dim      = cat_emb_dim,
        optimizer_fn     = torch.optim.Adam,
        optimizer_params = dict(lr=2e-2),
        scheduler_fn     = torch.optim.lr_scheduler.StepLR,
        scheduler_params = {"step_size":10, "gamma":0.9},
        device_name      = device
    )

    clf.fit(
        X_tr, y_tr,
        eval_set       = [(X_va, y_va)],
        eval_name      = ["val"],
        eval_metric    = ["accuracy"],   # early stopping monitors val accuracy
        max_epochs     = 100,
        patience       = 20,
        batch_size     = 2048,
        virtual_batch_size = 512,
        num_workers    = 4,
        drop_last      = False
    )

    y_proba = clf.predict_proba(X_va)   # columns follow clf.classes_
    y_pred  = clf.predict(X_va)

    # zero_division=0 matches sklearn's default value for empty denominators but
    # avoids a warning per fold.
    fold_metrics = {
        "auc":       _weighted_ovr_auc(y_va, y_proba, clf.classes_),
        "precision": precision_score(y_va, y_pred, average="weighted", zero_division=0),
        "recall":    recall_score(y_va, y_pred, average="weighted", zero_division=0),
        "f1":        f1_score(y_va, y_pred, average="weighted", zero_division=0),
        "accuracy":  accuracy_score(y_va, y_pred)
    }

    print(f"Fold {fold} → " +
          ", ".join(f"{k}={v:.4f}" for k,v in fold_metrics.items()))

    for k, v in fold_metrics.items():
        metrics_cv[k].append(v)

print("\n5‑Fold CV Summary:")
for k, vals in metrics_cv.items():
    print(f"{k:>9}: mean={np.mean(vals):.4f}, min={np.min(vals):.4f}, max={np.max(vals):.4f}")

In [None]:
# ───────────────────────────────────────────────────────────────────────────────
# 4) Train final model on all training data
# ───────────────────────────────────────────────────────────────────────────────
# Same architecture / optimizer / scheduler settings as the CV models, refit on
# every training row.
tabnet_params = dict(
    cat_idxs=cat_idxs,
    cat_dims=cat_dims,
    cat_emb_dim=cat_emb_dim,
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    scheduler_params={"step_size": 10, "gamma": 0.9},
    device_name="cuda" if torch.cuda.is_available() else "cpu",
)

# NOTE(review): no eval_set is passed here, so `patience` presumably has no effect
# and all max_epochs run — confirm against pytorch_tabnet's fit() semantics.
fit_params = dict(
    max_epochs=100,
    patience=20,
    batch_size=2048,
    virtual_batch_size=512,
    num_workers=4,
    drop_last=False,
)

final_clf = TabNetClassifier(**tabnet_params)
final_clf.fit(X.values, y, **fit_params)

In [None]:
# ───────────────────────────────────────────────────────────────────────────────
# 5) Bootstrap evaluation on hold‑out set (95% CI)
# ───────────────────────────────────────────────────────────────────────────────
df_hold = pd.read_csv(test_path, index_col=0)
df_hold[target] = df_hold[target].astype(int)

# Every training feature must exist in the hold-out file before encoding/scaling.
missing_cols = [c for c in X.columns if c not in df_hold.columns]
if missing_cols:
    raise ValueError(f"Hold-out is missing feature columns: {missing_cols}")

# 5a) Encode categorical columns with the training-fit encoders; a category the
#     encoder never saw has no integer code, so fail loudly instead of guessing.
for col, le in cat_encoders.items():
    unseen = set(df_hold[col].astype(str)) - set(le.classes_)
    if unseen:
        raise ValueError(f"Hold‑out has unseen labels in '{col}': {unseen}")
    df_hold[col] = le.transform(df_hold[col].astype(str))

# 5b) Scale numeric columns with the training-fit scaler
df_hold[numeric_cols] = scaler.transform(df_hold[numeric_cols])

# Select features in the *training* column order. TabNet's cat_idxs are positional,
# so a hold-out CSV with reordered or extra columns would otherwise be fed
# silently misaligned features.
X_hold = df_hold[X.columns].values
y_hold = df_hold[target].values

# 5c) Single-shot predictions; predict_proba columns follow final_clf.classes_
y_proba = final_clf.predict_proba(X_hold)
y_pred  = final_clf.predict(X_hold)
model_classes = np.asarray(final_clf.classes_)


def _holdout_weighted_auc(y_true, proba):
    """Weighted one-vs-rest ROC AUC with probability columns pinned to model_classes.

    sklearn expects a 1-D positive-class score for a binary target (a 2-D matrix
    raises), so the second column is used when there are exactly two classes.
    """
    if len(model_classes) == 2:
        return roc_auc_score(y_true, proba[:, 1])
    return roc_auc_score(y_true, proba, multi_class="ovr",
                         average="weighted", labels=model_classes)


# 5d) Bootstrap metrics. A seeded Generator makes the intervals reproducible.
#     Resamples that miss a class are skipped: ROC AUC is undefined for a class
#     with no positive samples and sklearn raises ValueError in that case.
n_boot     = 1000
n          = len(y_hold)
rng        = np.random.default_rng(42)
metrics_bo = {m: [] for m in metrics_cv}
n_skipped  = 0

for _ in range(n_boot):
    idx = rng.integers(0, n, n)
    yt, yp, yv = y_hold[idx], y_pred[idx], y_proba[idx]
    if not np.isin(model_classes, yt).all():
        n_skipped += 1
        continue
    metrics_bo["accuracy"].append(accuracy_score(yt, yp))
    metrics_bo["precision"].append(
        precision_score(yt, yp, average="weighted", zero_division=0))
    metrics_bo["recall"].append(
        recall_score(yt, yp, average="weighted", zero_division=0))
    metrics_bo["f1"].append(
        f1_score(yt, yp, average="weighted", zero_division=0))
    metrics_bo["auc"].append(_holdout_weighted_auc(yt, yv))

n_used = n_boot - n_skipped
if n_used == 0:
    raise ValueError("Every bootstrap resample missed at least one class; "
                     "the hold-out set is too small for bootstrap ROC AUC.")

print(f"\nHold‑out bootstrap ({n_used} of {n_boot} resamples used):")
for k, vals in metrics_bo.items():
    mean = np.mean(vals)
    lo, hi = np.percentile(vals, [2.5, 97.5])
    print(f"{k:>9}: {mean:.4f} (95% CI {lo:.4f}–{hi:.4f})")

In [None]:
import plotly.graph_objects as go
import pandas as pd
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.metrics import (
    roc_curve,
    auc,
    roc_auc_score,
    confusion_matrix,
    precision_recall_fscore_support
)

def calculate_sensitivity_specificity(
    df,
    true_label_col='true_label',
    pred_label_col='prediction_label',
    class_names=None
):
    """Per-class one-vs-rest sensitivity and specificity.

    Each class found in ``df[true_label_col]`` is treated in turn as the positive
    class, with every other label counted as negative.

    Args:
        df: DataFrame holding true and predicted labels.
        true_label_col: ground-truth column; its sorted unique values define the
            classes.
        pred_label_col: predicted-label column.
        class_names: optional display names, paired positionally with the sorted
            classes. Defaults to ``'Class <value>'``.

    Returns:
        dict mapping class name -> {'sensitivity', 'specificity', 'tp', 'tn',
        'fp', 'fn', 'class_value'}. A rate whose denominator is zero is 0.0.
    """
    # Get sorted list of classes
    classes = sorted(df[true_label_col].unique())
    # Default names
    if class_names is None:
        class_names = [f'Class {cls}' for cls in classes]

    y_true = np.asarray(df[true_label_col].values)
    y_pred = np.asarray(df[pred_label_col].values)

    sens_spec = {}
    for cls_val, cls_name in zip(classes, class_names):
        is_true = y_true == cls_val
        is_pred = y_pred == cls_val
        # Count the 2x2 cells directly. confusion_matrix(...).ravel() returns a
        # single value when both binarized vectors are constant (e.g. only one true
        # class and every prediction correct), which broke the 4-way unpack.
        tp = int(np.sum(is_true & is_pred))
        fn = int(np.sum(is_true & ~is_pred))
        fp = int(np.sum(~is_true & is_pred))
        tn = int(np.sum(~is_true & ~is_pred))
        sens = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        spec = tn / (tn + fp) if (tn + fp) > 0 else 0.0
        sens_spec[cls_name] = {
            'sensitivity': sens,
            'specificity': spec,
            'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn,
            'class_value': cls_val
        }
    return sens_spec

def plot_multiclass_roc_with_metrics(
    df,
    true_label_col='true_label',
    pred_label_col='prediction_label',
    # Now expects a list of one column name per class
    pred_score_cols=None,
    class_names=None,
    colors=None,
    font_family='Arial Black',
    font_size_title=16,
    font_size_axes=12,
    font_size_ticks=7,
    save_path='multiclass_roc_curves.svg',
    show_legend=False,
    print_detailed_metrics=True
):
    """Print per-class diagnostics and draw one-vs-rest ROC curves with plotly.

    Prints sensitivity/specificity (macro and support-weighted), per-class ROC AUC
    with macro and weighted averages, and optionally per-class precision/recall/F1.
    The figure is written to `save_path` and shown.

    Args:
        df: DataFrame with true labels, predicted labels and one probability
            column per class.
        true_label_col: ground-truth column; its sorted unique values define the
            classes and their order.
        pred_label_col: predicted-label column.
        pred_score_cols: one probability column per class, in sorted-class order.
            If None, ``Score_<class>`` columns are used when one exists for every
            class; otherwise every column whose name starts with ``'Score_'`` is
            taken in DataFrame order.
        class_names: display names paired positionally with the sorted classes.
        colors: line colours; cycled when fewer than the number of classes.
        font_family, font_size_title, font_size_axes, font_size_ticks: plot fonts.
        save_path: output path passed to ``fig.write_image``.
        show_legend: whether the plotly legend is displayed.
        print_detailed_metrics: also print precision/recall/F1/support tables.

    Returns:
        dict with 'sensitivity_specificity', 'macro_auc', 'weighted_auc' and the
        plotly 'figure'.

    Raises:
        AssertionError: if the number of score columns differs from the number
            of classes.
    """
    # 1) Prepare classes and binarized true labels
    classes = sorted(df[true_label_col].unique())
    n_classes = len(classes)
    if class_names is None:
        class_names = [f'Class {cls}' for cls in classes]

    y_true = df[true_label_col].values
    y_true_bin = label_binarize(y_true, classes=classes)
    # For exactly two classes label_binarize returns a single (n, 1) column for the
    # positive class (not a 1-D array), so testing `ndim == 1` never fired and
    # y_true_bin[:, 1] raised IndexError. Expand to one indicator column per class.
    if n_classes == 2 and y_true_bin.shape[1] == 1:
        y_true_bin = np.column_stack([1 - y_true_bin[:, 0], y_true_bin[:, 0]])

    # 2) Determine which probability columns to use
    if pred_score_cols is None:
        by_class = [f'Score_{cls}' for cls in classes]
        if all(c in df.columns for c in by_class):
            # Align score columns with the sorted class order explicitly.
            pred_score_cols = by_class
        else:
            # Fallback: any column starting with 'Score_', assumed to be in class order
            pred_score_cols = [c for c in df.columns if str(c).startswith('Score_')]
    assert len(pred_score_cols) == n_classes, \
        f"Need {n_classes} score columns, got {len(pred_score_cols)}"

    # 3) Sensitivity & Specificity printout
    print("=" * 60)
    print("SENSITIVITY AND SPECIFICITY ANALYSIS")
    print("=" * 60)
    sens_spec = calculate_sensitivity_specificity(
        df, true_label_col, pred_label_col, class_names
    )
    print(f"{'Class':<15}{'Sens':<10}{'Spec':<10}{'TP':<5}{'TN':<5}{'FP':<5}{'FN':<5}")
    print("-" * 60)
    for name, m in sens_spec.items():
        print(f"{name:<15}{m['sensitivity']:<10.4f}{m['specificity']:<10.4f}"
              f"{m['tp']:<5}{m['tn']:<5}{m['fp']:<5}{m['fn']:<5}")
    # Averages
    avg_sens = np.mean([m['sensitivity'] for m in sens_spec.values()])
    avg_spec = np.mean([m['specificity'] for m in sens_spec.values()])
    print("-" * 60)
    print(f"{'Macro Avg':<15}{avg_sens:<10.4f}{avg_spec:<10.4f}")
    # Weighted by each class's share of the true labels
    counts = df[true_label_col].value_counts()
    total = len(df)
    w_sens = sum(m['sensitivity'] * (counts[m['class_value']] / total)
                 for m in sens_spec.values())
    w_spec = sum(m['specificity'] * (counts[m['class_value']] / total)
                 for m in sens_spec.values())
    print(f"{'Weighted':<15}{w_sens:<10.4f}{w_spec:<10.4f}")

    # 4) ROC curve + AUC
    print("\n" + "=" * 60)
    print("ROC CURVE AND AUC ANALYSIS")
    print("=" * 60)
    fpr, tpr, roc_auc = {}, {}, {}
    fig = go.Figure()

    # Default colors if None
    if colors is None:
        colors = ['#ff8d7f', '#84c9ff', '#e4a8ff'] * ((n_classes // 3) + 1)
    # Cycle the palette so zip() below never silently drops a class (which would
    # also leave roc_auc[i] missing for the AUC printout).
    colors = [colors[i % len(colors)] for i in range(n_classes)]

    for i, (cls_val, cls_name, col) in enumerate(zip(classes, class_names, colors)):
        # True binary for this class
        y_true_i = y_true_bin[:, i]
        # Use the real probability for this class
        y_score_i = df[pred_score_cols[i]].values
        fpr[i], tpr[i], _ = roc_curve(y_true_i, y_score_i)
        roc_auc[i] = auc(fpr[i], tpr[i])

        fig.add_trace(go.Scatter(
            x=fpr[i], y=tpr[i],
            mode='lines',
            line=dict(color=col, width=3),
            name=f'{cls_name} (AUC={roc_auc[i]:.4f})',
            hovertemplate='<b>%{fullData.name}</b><br>FPR: %{x:.4f}<br>TPR: %{y:.4f}<extra></extra>'
        ))

    # Diagonal (chance-level reference)
    fig.add_trace(go.Scatter(
        x=[0, 1], y=[0, 1],
        mode='lines',
        line=dict(color='gray', width=2, dash='dash'),
        showlegend=False
    ))

    fig.update_layout(
        title=dict(
            text='ROC Curves – Multiclass (One vs Rest)',
            font=dict(family=font_family, size=font_size_title),
            x=0.5, xanchor='center'
        ),
        xaxis=dict(
            title=dict(text='False Positive Rate (1 − Specificity)',
                       font=dict(family=font_family, size=font_size_axes)),
            tickfont=dict(family=font_family, size=font_size_ticks),
            range=[0, 1], showgrid=True, gridcolor='lightgray',
            showline=True, linecolor='black', linewidth=2
        ),
        yaxis=dict(
            title=dict(text='True Positive Rate (Sensitivity)',
                       font=dict(family=font_family, size=font_size_axes)),
            tickfont=dict(family=font_family, size=font_size_ticks),
            range=[0, 1], showgrid=True, gridcolor='lightgray',
            showline=True, linecolor='black', linewidth=2
        ),
        showlegend=show_legend,
        legend=dict(
            font=dict(family=font_family, size=10),
            x=0.6, y=0.2,
            bgcolor='rgba(255,255,255,0.8)',
            bordercolor='black',
            borderwidth=1
        ),
        plot_bgcolor='white',
        paper_bgcolor='white',
        width=800,
        height=600
    )

    # Macro AUC = plain mean of per-class AUCs; weighted AUC via sklearn on the
    # indicator matrix (per-column AUC weighted by class support).
    macro_auc = np.mean(list(roc_auc.values()))
    proba_matrix = df[pred_score_cols].values
    weighted_auc = roc_auc_score(
        y_true_bin,
        proba_matrix,
        multi_class='ovr',
        average='weighted'
    )

    print("\nAUC Scores per class:")
    for i, name in enumerate(class_names):
        cnt = counts[classes[i]]
        pct = cnt / total * 100
        print(f"  {name}: {roc_auc[i]:.4f} (n={cnt}, {pct:.1f}%)")
    print(f"\nMacro‐Average AUC:    {macro_auc:.4f}")
    print(f"Weighted‐Average AUC: {weighted_auc:.4f}")

    # Detailed metrics if desired
    if print_detailed_metrics:
        print("\n" + "=" * 60)
        print("DETAILED METRICS SUMMARY")
        print("=" * 60)
        p, r, f1, supp = precision_recall_fscore_support(
            df[true_label_col],
            df[pred_label_col],
            labels=classes,
            average=None
        )
        print(f"{'Class':<15}{'Prec':<10}{'Rec':<10}{'F1':<10}{'Supp':<8}")
        print("-" * 55)
        for i, name in enumerate(class_names):
            print(f"{name:<15}{p[i]:<10.4f}{r[i]:<10.4f}{f1[i]:<10.4f}{supp[i]:<8}")
        p_m, r_m, f1_m, _ = precision_recall_fscore_support(
            df[true_label_col],
            df[pred_label_col],
            average='macro'
        )
        p_w, r_w, f1_w, _ = precision_recall_fscore_support(
            df[true_label_col],
            df[pred_label_col],
            average='weighted'
        )
        print("-" * 55)
        print(f"{'Macro Avg':<15}{p_m:<10.4f}{r_m:<10.4f}{f1_m:<10.4f}")
        print(f"{'Weighted':<15}{p_w:<10.4f}{r_w:<10.4f}{f1_w:<10.4f}")
        print("\n(Note: Recall = Sensitivity)")

    # Save and show
    fig.write_image(save_path)
    print(f"\nROC plot saved as {save_path}")
    fig.show()

    return {
        'sensitivity_specificity': sens_spec,
        'macro_auc': macro_auc,
        'weighted_auc': weighted_auc,
        'figure': fig
    }

In [None]:
# ───────────────────────────────────────────────────────────────────────────────
# Build the hold-out predictions DataFrame used for the ROC + metrics plots
# ───────────────────────────────────────────────────────────────────────────────
# 1) True labels (hold-out index preserved) plus final_clf's hard predictions.
#    The true-label column is named `target`, so pass true_label_col=target when
#    calling plot_multiclass_roc_with_metrics (its default is 'true_label').
predictions = df_hold[[target]].copy()
predictions['prediction_label'] = y_pred

# 2) One Score_<class> probability column per class. predict_proba columns follow
#    final_clf.classes_, so pairing them positionally keeps labels aligned.
score_columns = {
    f"Score_{cls}": y_proba[:, pos] for pos, cls in enumerate(final_clf.classes_)
}
predictions = predictions.assign(**score_columns)

In [None]:
# Reload hold-out predictions from disk for the ROC/metrics plots.
# NOTE(review): this overwrites the in-memory `predictions` frame built in the
# previous cell. The path begins with "..." and looks like a redacted/placeholder
# location — set the real path (or skip this cell) before running.
# NOTE(review): no index_col is passed, so if the CSV was written with its index
# it will come back as an extra 'Unnamed: 0' column — confirm how it was saved.
import pandas as pd
predictions = pd.read_csv(".../tabnet_holdout_for_roc.csv")