In [None]:
# ==========================================
# [0] Required package installation
# ==========================================
# 1. pytorch-tabnet: core implementation for the TabNet baseline
# 2. pytorch-widedeep: included for compatibility with related deep tabular baselines
# 3. optuna: hyperparameter optimization
!pip install pytorch-tabnet pytorch-widedeep optuna

import pandas as pd
import numpy as np
import torch
import io
import optuna
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# TabNet (original implementation)
from pytorch_tabnet.tab_model import TabNetRegressor

# Deep tabular baseline utilities from pytorch-widedeep
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.models import SAINT, FTTransformer
from pytorch_widedeep import Trainer
from pytorch_widedeep.callbacks import EarlyStopping

In [None]:
# ==========================================
# [1] Data loading and preprocessing
#     (standard TabNet baseline setting)
# ==========================================
def load_and_prep():
    """Load Dataset.csv, clean it, and build the temporal train/val/test splits.

    Steps performed:
      1. Read the CSV, either through the Google Colab upload widget (when
         ``google.colab`` is importable) or from ``Dataset.csv`` in the
         working directory.
      2. Drop ``Country Name`` if present, fill missing values of the
         categorical columns with ``"Unknown"``, and median-impute every other
         column except ``Year`` and the target.
      3. Label-encode the categorical columns and split the rows by ``Year``.

    Returns:
        tuple: ``(train_p1, val_p1, train_p2, test, cat_idxs, cat_dims,
        features)`` where the first four items are ``(X, y)`` pairs
        (``y`` reshaped to a single column), ``cat_idxs`` are the positions of
        the categorical columns inside ``features``, ``cat_dims`` are their
        numbers of distinct encoded classes, and ``features`` is the list of
        feature column names.
    """
    test_year = 2016

    # 1. Load dataset
    # Only a missing google.colab package means "not running on Colab".
    try:
        from google.colab import files
    except ImportError:
        files = None

    df = None
    if files is not None:
        print("=== [Google Colab] Please upload Dataset.csv ===")
        try:
            uploaded = files.upload()
            if uploaded:
                filename = list(uploaded.keys())[0]
                df = pd.read_csv(io.BytesIO(uploaded[filename]))
        except Exception as exc:
            # Keep the best-effort fallback, but report the cause instead of
            # hiding it; KeyboardInterrupt/SystemExit are no longer swallowed.
            print(f"Upload failed ({exc!r}); falling back to the local file.")

    if df is None:
        print("=== [Local Environment] Loading Dataset.csv ===")
        df = pd.read_csv('Dataset.csv')

    # 2. Basic preprocessing
    if 'Country Name' in df.columns:
        df = df.drop('Country Name', axis=1)

    cat_cols = ['Country Code', 'Continent']
    for col in cat_cols:
        df[col] = df[col].fillna("Unknown")

    target = 'Maternal Mortality Ratio'
    exclude = cat_cols + ['Year', target]
    num_cols = [c for c in df.columns if c not in exclude]

    # Numerical missing values -> median imputation.
    # The medians are taken from rows before the held-out test year so that
    # test-year values do not leak into the training/validation features.
    pretest_rows = df['Year'] < test_year
    for col in num_cols:
        fill_value = df.loc[pretest_rows, col].median()
        if pd.isna(fill_value):
            # No observed pre-test value for this column: fall back to the
            # overall median rather than leaving missing values in place.
            fill_value = df[col].median()
        df[col] = df[col].fillna(fill_value)

    # 3. Encoding and split setup
    categorical_dims = {}
    for col in cat_cols:
        l_enc = LabelEncoder()
        df[col] = l_enc.fit_transform(df[col].values)
        categorical_dims[col] = len(l_enc.classes_)

    features = [c for c in df.columns if c not in ['Year', target]]
    cat_idxs = [i for i, f in enumerate(features) if f in cat_cols]
    cat_dims = [categorical_dims[f] for f in features if f in cat_cols]

    # Temporal data splits
    # Phase 1: model selection (Train 2011-2014 / Validation 2015)
    train_p1 = df[(df['Year'] >= 2011) & (df['Year'] <= 2014)]
    val_p1   = df[df['Year'] == 2015]

    # Phase 2: final retraining on the pretest period (2011-2015)
    train_p2 = df[(df['Year'] >= 2011) & (df['Year'] <= 2015)]

    # Held-out test year
    test = df[df['Year'] == test_year]

    def to_xy(data):
        """Return (feature matrix, target as a single-column array) for a split."""
        return data[features].values, data[target].values.reshape(-1, 1)

    return (to_xy(train_p1), to_xy(val_p1), to_xy(train_p2), to_xy(test),
            cat_idxs, cat_dims, features)

# Prepare data: the first four items are (X, y) pairs, the rest is metadata
prepared = load_and_prep()
data_train_1, data_val_1, data_train_2, data_test = prepared[:4]
cat_idxs, cat_dims, feature_names = prepared[4:]

# Unpack every split into its feature matrix and its target column
(X_train_1, y_train_1), (X_val_1, y_val_1) = data_train_1, data_val_1
(X_train_2, y_train_2), (X_test, y_test) = data_train_2, data_test

In [None]:
# ==========================================
# [2] Optuna-based hyperparameter optimization
# ==========================================
def objective(trial):
    """Optuna objective: fit one TabNet candidate and return its best score.

    Hyperparameters are sampled from ``trial``, a ``TabNetRegressor`` is
    trained on the model-selection split (``X_train_1`` / ``y_train_1``) and
    monitored on ``X_val_1`` / ``y_val_1`` with the ``mae`` metric.

    Args:
        trial: the Optuna trial used to sample the hyperparameters.

    Returns:
        The fitted model's ``best_cost`` attribute.
    """
    # Sample the search space (same suggestion order as the parameter names)
    n_d = trial.suggest_int('n_d', 8, 64, step=8)            # width of decision layers
    n_a = trial.suggest_int('n_a', 8, 64, step=8)            # width of attention layers
    n_steps = trial.suggest_int('n_steps', 3, 10)            # number of sequential decision steps
    gamma = trial.suggest_float('gamma', 1.0, 2.0)           # relaxation parameter
    n_independent = trial.suggest_int('n_independent', 1, 5)
    n_shared = trial.suggest_int('n_shared', 1, 5)
    lambda_sparse = trial.suggest_float('lambda_sparse', 1e-4, 1e-2, log=True)
    learning_rate = trial.suggest_float('lr', 1e-3, 1e-1, log=True)

    # Baseline TabNet model configured with the sampled values
    model = TabNetRegressor(
        n_d=n_d,
        n_a=n_a,
        n_steps=n_steps,
        gamma=gamma,
        n_independent=n_independent,
        n_shared=n_shared,
        lambda_sparse=lambda_sparse,
        optimizer_params={'lr': learning_rate},
        cat_idxs=cat_idxs,
        cat_dims=cat_dims,
        cat_emb_dim=1,
        scheduler_fn=torch.optim.lr_scheduler.StepLR,
        scheduler_params={"step_size": 10, "gamma": 0.9},
        mask_type='entmax',
        verbose=0,
    )

    # Train on the model-selection split, monitoring the validation year
    validation_pair = (X_val_1, y_val_1)
    model.fit(
        X_train=X_train_1,
        y_train=y_train_1,
        eval_set=[validation_pair],
        eval_name=['valid'],
        eval_metric=['mae'],
        max_epochs=100,
        patience=15,
        batch_size=256,
        virtual_batch_size=128,
        drop_last=False,
    )

    # Best score recorded on the validation set during training
    return model.best_cost

print("\n>>> [Step 1] Starting Optuna hyperparameter optimization (20 trials)...")
# Seed the sampler so the sequence of suggested hyperparameters (and therefore
# the selected configuration) is reproducible across Restart & Run All.
# TPESampler is Optuna's default sampler, so only the seeding changes.
OPTUNA_SEED = 42
study = optuna.create_study(
    direction='minimize',
    sampler=optuna.samplers.TPESampler(seed=OPTUNA_SEED),
)
study.optimize(objective, n_trials=20)

print("\n>>> [Result] Best hyperparameters")
best_params = study.best_trial.params
print(best_params)

In [None]:
# ==========================================
# [3] Retraining on the full pretest period
# ==========================================
print("\n>>> [Step 2] Retraining the optimized TabNet baseline on 2011-2015...")

# Split the tuned settings: the learning rate is handed to the optimizer,
# everything else is passed straight to the TabNet constructor.
lr = best_params['lr']
final_params = {name: value for name, value in best_params.items() if name != 'lr'}

clf_final = TabNetRegressor(
    **final_params,
    optimizer_params={'lr': lr},
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    scheduler_params={"step_size": 10, "gamma": 0.9},
    cat_idxs=cat_idxs,
    cat_dims=cat_dims,
    cat_emb_dim=1,
    mask_type='entmax',
    verbose=10,
)

# No evaluation set is supplied here: the model is fit on the whole
# 2011-2015 training split for the configured number of epochs.
clf_final.fit(
    X_train=X_train_2,
    y_train=y_train_2,
    eval_set=None,
    max_epochs=150,
    batch_size=256,
    virtual_batch_size=128,
    drop_last=False,
)

In [None]:
# ==========================================
# [4] Final evaluation and visualization
# ==========================================
# Predict the held-out year with the retrained model
preds = clf_final.predict(X_test)

print("\n=== [Step 3] Held-out Test Results (Year 2016) ===")
print(f"MAE : {mean_absolute_error(y_test, preds):.4f}")
print(f"RMSE: {np.sqrt(mean_squared_error(y_test, preds)):.4f}")
print(f"R2  : {r2_score(y_test, preds):.4f}")

# ------------------------------------------
# Feature importance (ten largest values, plotted in ascending order)
# ------------------------------------------
fig, ax = plt.subplots(figsize=(10, 5))
importances = pd.Series(clf_final.feature_importances_, index=feature_names)
importances.nlargest(10).sort_values().plot(kind='barh', color='#6c5ce7', ax=ax)
ax.set_title('TabNet Baseline Feature Importance')
ax.set_xlabel("Importance")
fig.tight_layout()
plt.show()

# ------------------------------------------
# Attention masks
# ------------------------------------------
explain_matrix, masks = clf_final.explain(X_test)

# Show at most three masks, but never more than the model actually produced,
# so a model with fewer decision steps does not raise a KeyError.
n_masks_shown = min(3, len(masks))
fig, axs = plt.subplots(1, n_masks_shown, figsize=(15, 4), squeeze=False)
axs = axs.ravel()
for i in range(n_masks_shown):
    # masks[i][:15] is simply the first 15 test rows (no ranking is applied),
    # so the panel is labelled "First", not "Top".
    sns.heatmap(masks[i][:15], ax=axs[i], cbar=False, cmap='viridis')
    axs[i].set_title(f"Decision-Step Mask {i} (First 15 Samples)")
    axs[i].set_xlabel("Features")
    axs[i].set_ylabel("Samples")

fig.suptitle("TabNet Baseline Attention Masks")
fig.tight_layout()
plt.show()