# In [None]:
# ─── Imports ────────────────────────────────────────────────────────
import pandas as pd
import numpy as np
import random
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import f1_score, classification_report, accuracy_score, ConfusionMatrixDisplay
from sklearn.utils import resample
from pytorch_tabnet.tab_model import TabNetClassifier
import torch

# ─── Reproducibility ────────────────────────────────────────────────
# Seed the PyTorch, NumPy and standard-library random generators with the
# same fixed value.
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(42)

# ─── Load + Preprocess Data ─────────────────────────────────────────
# Read the cleaned dataset and draw a reproducible random subset of rows.
# DataFrame.sample raises ValueError when n exceeds the row count (sampling
# without replacement), so the request is capped at the size of the file.
df_full = pd.read_csv("modified_tedsa_data_clean.csv")
df_full = df_full.sample(n=min(50000, len(df_full)), random_state=42)
print(df_full.head())

# Keep only the first part of a stratified split as the working frame. No
# test_size is given, so train_test_split holds out its default share, which
# is discarded here; stratifying on SUB1 preserves the class proportions.
df, _ = train_test_split(df_full, stratify=df_full["SUB1"], random_state=42)

# Target: SUB1 encoded as consecutive integer class ids.
label_encoder = LabelEncoder()
y_raw = label_encoder.fit_transform(df["SUB1"])

# Features: every remaining column is treated as categorical. Each distinct
# value of a column is mapped to an integer code in order of first appearance.
X_raw = df.drop(columns=["SUB1"])
cat_cols = X_raw.columns.tolist()
cat_maps = {col: {val: i for i, val in enumerate(X_raw[col].unique())} for col in cat_cols}
X_encoded = X_raw.apply(lambda col: col.map(cat_maps[col.name]))

# ─── Split Data ─────────────────────────────────────────────────────
# Stage 1: hold out 20% of the rows as the test set, stratified by class.
X_temp, X_test, y_temp, y_test = train_test_split(
    X_encoded,
    y_raw,
    test_size=0.2,
    stratify=y_raw,
    random_state=42,
)
# Stage 2: carve the validation set out of the remaining 80%.
# 0.125 * 0.8 = 0.10 of the data, so the overall split is 70/10/20.
X_train, X_val, y_train, y_val = train_test_split(
    X_temp,
    y_temp,
    test_size=0.125,
    stratify=y_temp,
    random_state=42,
)

# ─── Oversample Training Set Only ───────────────────────────────────
def manual_oversample(X, y):
    """Balance classes by resampling every class up to the majority count.

    Each class is sampled with replacement (fixed seed 42) until it has as
    many rows as the most frequent class. The majority class is resampled as
    well rather than passed through unchanged.

    Args:
        X: Feature DataFrame with one row per sample.
        y: Labels paired with the rows of ``X`` by position (array, list or
            Series). Any pandas index carried by ``y`` is ignored.

    Returns:
        Tuple ``(X_balanced, y_balanced)``: a DataFrame with the columns of
        ``X`` and a Series named ``"label"``, both indexed by the resampled
        row positions.

    Raises:
        ValueError: If ``X`` and ``y`` have different lengths.
    """
    features = X.reset_index(drop=True)
    # Drop any index carried by y. concat(axis=1) aligns on the index, so a
    # Series that kept its original index would be matched against the fresh
    # 0..n-1 index of the features and yield NaN rows instead of a positional
    # pairing.
    labels = pd.Series(y, name="label").reset_index(drop=True)
    if len(labels) != len(features):
        raise ValueError(
            f"X and y must have the same length, got {len(features)} and {len(labels)}"
        )

    df_bal = pd.concat([features, labels], axis=1)
    max_count = df_bal["label"].value_counts().max()
    df_os = pd.concat([
        resample(df_bal[df_bal["label"] == lbl], replace=True, n_samples=max_count, random_state=42)
        for lbl in df_bal["label"].unique()
    ])
    return df_os.drop(columns="label"), df_os["label"]

# Only the training portion is oversampled; the validation and test portions
# are passed on exactly as the split produced them.
X_train_os, y_train_os = manual_oversample(X_train, y_train)

# ─── Prepare for TabNet ─────────────────────────────────────────────
# Feature frames become plain arrays; validation/test labels are used as
# returned by the split.
X_train_np = X_train_os.values
y_train_np = y_train_os.values
X_val_np, y_val_np = X_val.values, y_val
X_test_np, y_test_np = X_test.values, y_test

# Every column is declared categorical: list all column positions and, for
# each column, the number of distinct codes in its mapping.
cat_idxs = [*range(X_train_np.shape[1])]
cat_dims = [len(cat_maps[name]) for name in cat_cols]

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# ─── Hyperparameter Search ──────────────────────────────────────────
# Full grid of candidate settings (2 x 2 x 2 x 2 = 16 combinations), shuffled
# so that the trial cap in the loop below evaluates a random subset of it.
search_space = [
    {"n_d": d, "n_a": a, "lr": lr, "gamma": g}
    for d in [32, 64]
    for a in [32, 64]
    for lr in [1e-3, 5e-4]
    for g in [1.0, 2.0]
]
random.shuffle(search_space)

# Results for the selected model, all measured on the test split and filled
# in once the search has finished.
best_macro_f1 = -1
best_model, best_preds, best_true, best_config = None, None, None, None

# Selection score. Candidates are ranked on the validation split only:
# ranking them on the test split would tune the hyperparameters to the test
# data and make the reported test metrics optimistically biased.
best_val_macro_f1 = -1

for i, params in enumerate(search_space[:10]):
    print(f"\n🔍 Trial {i+1}: {params}")
    # NOTE(review): TabNetClassifier also accepts its own `seed` argument;
    # confirm whether the global torch seed or that argument governs
    # initialisation here.
    model = TabNetClassifier(
        n_d=params["n_d"], n_a=params["n_a"], gamma=params["gamma"],
        n_steps=5, verbose=0,
        optimizer_fn=torch.optim.Adam,
        optimizer_params=dict(lr=params["lr"]),
        cat_idxs=cat_idxs, cat_dims=cat_dims, cat_emb_dim=1,
        device_name=device
    )
    # The validation split is passed as the evaluation set, monitored with
    # balanced accuracy and a patience of 10 epochs.
    model.fit(
        X_train=X_train_np, y_train=y_train_np,
        eval_set=[(X_val_np, y_val_np)],
        eval_name=["val"], eval_metric=["balanced_accuracy"],
        max_epochs=100, patience=10, batch_size=64, virtual_batch_size=32
    )
    val_preds = model.predict(X_val_np)
    val_macro_f1 = f1_score(y_val_np, val_preds, average="macro")
    print(f"🎯 Val Macro F1: {val_macro_f1:.4f}")
    if val_macro_f1 > best_val_macro_f1:
        best_val_macro_f1 = val_macro_f1
        best_model = model
        best_config = params

# Score the test split exactly once, with the model chosen on validation, so
# the figures reported afterwards are an unbiased estimate.
if best_model is not None:
    best_preds = best_model.predict(X_test_np)
    best_true = y_test_np
    best_macro_f1 = f1_score(best_true, best_preds, average="macro")

# ─── Final Report ───────────────────────────────────────────────────
# Translate encoded class ids back to the original SUB1 values so the report
# and the plot show real labels.
y_true_labels, y_pred_labels = (
    label_encoder.inverse_transform(encoded) for encoded in (best_true, best_preds)
)

# Text summary: winning configuration, its macro F1, per-class metrics and
# overall accuracy.
print("\n✅ Best Hyperparameters:", best_config)
print(f"🏆 Best Macro F1 Score: {best_macro_f1:.4f}")
final_report = classification_report(y_true_labels, y_pred_labels, zero_division=0)
print("\n📋 Final Classification Report:\n", final_report)
final_accuracy = accuracy_score(y_true_labels, y_pred_labels)
print(f"🎯 Final Accuracy: {final_accuracy:.4f}")

# Confusion matrix of true vs. predicted labels.
ConfusionMatrixDisplay.from_predictions(
    y_true_labels,
    y_pred_labels,
    xticks_rotation=45,
    cmap='Blues',
)
plt.title("Final Confusion Matrix - TabNet")
plt.tight_layout()
plt.show()
