In [None]:
# ==========================================
# [0] Required package installation
# ==========================================
!pip install pytorch-tabnet pytorch-widedeep optuna

import pandas as pd
import numpy as np
import torch
import optuna
import io
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# WideDeep modules
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.models import WideDeep, SAINT, FTTransformer
from pytorch_widedeep import Trainer
from pytorch_widedeep.callbacks import EarlyStopping

In [None]:
# ==========================================
# [1] Data loading and preprocessing
# ==========================================
def load_data(path='Dataset.csv'):
    """
    Load the harmonized dataset used for deep tabular baseline experiments.

    The CSV at ``path`` is read directly when it exists locally. If it does not
    exist and the notebook is running in Google Colab (``google.colab`` is
    importable), the user is prompted to upload the file manually and the first
    uploaded file is parsed.

    Parameters
    ----------
    path : str, default 'Dataset.csv'
        Location of the dataset CSV.

    Returns
    -------
    pandas.DataFrame
        The raw dataset.

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist and either Colab is unavailable or the
        upload dialog returned no file.
    """
    import os

    if os.path.exists(path):
        return pd.read_csv(path)

    try:
        from google.colab import files
    except ImportError:
        # Not running in Colab: let pandas raise its usual FileNotFoundError.
        # Only ImportError is caught so genuine upload/parse errors surface.
        return pd.read_csv(path)

    print(f"=== [Google Colab] Please upload {os.path.basename(path)} ===")
    uploaded = files.upload()
    if not uploaded:
        raise FileNotFoundError(f"{path!r} was not found and no file was uploaded.")
    filename = next(iter(uploaded))
    return pd.read_csv(io.BytesIO(uploaded[filename]))

# Raw dataset for the whole notebook; the cells below clean `df` and derive the temporal splits from it.
df = load_data()

In [None]:
# ==========================================
# [1-1] Missing-value handling and basic cleanup
# ==========================================
if 'Country Name' in df.columns:
    df = df.drop('Country Name', axis=1)

cat_cols = ['Country Code', 'Continent']
target_col = 'Maternal Mortality Ratio'

# Rows without a target can neither be trained on (NaN targets give NaN losses)
# nor scored (sklearn's metrics raise on NaN). Drop them up front and reset the
# index so row labels keep matching positions in later NumPy arrays.
df = df.dropna(subset=[target_col]).reset_index(drop=True)

# Missing-value treatment
# - categorical variables: "Unknown"
# - numerical variables: median imputation, with medians computed only from the
#   2011-2014 optimization-training years (the period TabPreprocessor is fit
#   on), so validation (2015) and test (2016) values do not leak into features.
for col in cat_cols:
    df[col] = df[col].fillna("Unknown")

num_cols = [c for c in df.columns if c not in cat_cols + ['Year', target_col]]
train_medians = df.loc[df['Year'].between(2011, 2014), num_cols].median()
# DataFrame.fillna with a Series fills each column with its own median.
df[num_cols] = df[num_cols].fillna(train_medians)

In [None]:
# ==========================================
# [1-2] Temporal split design
# ==========================================
# Phase 1: model selection
# Train: 2011-2014 / Validation: 2015  (between() is inclusive on both ends)
train_opt = df[df['Year'].between(2011, 2014)]
val_opt   = df[df['Year'] == 2015]

# Phase 2: retraining after model selection
# Train: 2011-2015
train_retrain = df[df['Year'].between(2011, 2015)]

# Phase 3: held-out test evaluation
# Test: 2016
test = df[df['Year'] == 2016]

# Fail fast if a year is missing from the data instead of deep inside training.
for split_name, split_df in [
    ('train_opt', train_opt),
    ('val_opt', val_opt),
    ('train_retrain', train_retrain),
    ('test', test),
]:
    assert len(split_df) > 0, f"Temporal split '{split_name}' is empty; check the 'Year' column."

# Model input features: exactly the columns handed to TabPreprocessor
# (categorical embedding columns, then continuous columns), listed in that
# order. To find a feature's position in the transformed array, prefer
# `tab_preprocessor.column_idx` over relying on this list's order.
feature_names = cat_cols + num_cols

In [None]:
# ==========================================
# [2] WideDeep preprocessing
#     (TabPreprocessor)
# ==========================================
# Fit the preprocessing pipeline only on the optimization-training period
# (2011-2014), so fitted statistics (e.g. scaling) come from those years only.
# `with_attention=True` is what pytorch-widedeep documents for attention-based
# backbones such as SAINT / FT-Transformer: categorical columns then share one
# embedding size, which is set later through the model's `input_dim`.
tab_preprocessor = TabPreprocessor(
    cat_embed_cols=cat_cols,
    continuous_cols=num_cols,
    with_attention=True,
    scale=True
)
tab_preprocessor.fit(train_opt)

# Transform the full dataset (assumed: one output row per row of `df`, in order)
X_tab_all = tab_preprocessor.transform(df)
y_all = df[target_col].to_numpy(dtype=float)

# Positional row indices for each temporal split. X_tab_all / y_all are NumPy
# arrays, so the split frames' index *labels* are converted to positions in
# `df`; using the labels directly only works while `df` has a 0..n-1 index.
idx_opt_train = df.index.get_indexer(train_opt.index)
idx_opt_val   = df.index.get_indexer(val_opt.index)
idx_retrain   = df.index.get_indexer(train_retrain.index)
idx_test      = df.index.get_indexer(test.index)

# Input parameters required by SAINT / FT-Transformer
input_params = {
    "column_idx": tab_preprocessor.column_idx,
    "cat_embed_input": tab_preprocessor.cat_embed_input,
    "continuous_cols": num_cols,
}

In [None]:
# ==========================================
# [3] Unified experiment pipeline
#     (Optuna -> Retrain -> Evaluate)
# ==========================================
# Per-model settings shared by the tuning and retraining stages.
_MODEL_INPUT_DIMS = {'SAINT': 64, 'FT-Transformer': 32}
_MODEL_N_BLOCKS_RANGE = {'SAINT': (1, 3), 'FT-Transformer': (2, 4)}


def _suggest_params(trial, model_name):
    """Sample one hyperparameter configuration for `model_name` from its search space."""
    low, high = _MODEL_N_BLOCKS_RANGE[model_name]
    return {
        'n_blocks': trial.suggest_int('n_blocks', low, high),
        'n_heads': trial.suggest_categorical('n_heads', [2, 4, 8]),
        'attn_dropout': trial.suggest_float('attn_dropout', 0.0, 0.3),
        'ff_dropout': trial.suggest_float('ff_dropout', 0.0, 0.3),
    }


def _build_model(model_name, params):
    """Build a fresh SAINT / FT-Transformer backbone and wrap it in WideDeep."""
    backbone_cls = SAINT if model_name == 'SAINT' else FTTransformer
    tab_model = backbone_cls(**input_params, input_dim=_MODEL_INPUT_DIMS[model_name], **params)
    return WideDeep(deeptabular=tab_model)


def _permutation_importance(trainer, X_test, y_true, base_score, rng, n_repeats):
    """Return the mean drop in test R2 per feature when that feature's column is shuffled."""
    # Look up each feature's real column in the transformed array; enumeration
    # order of `feature_names` need not match the preprocessor's output order.
    column_idx = input_params["column_idx"]
    importances = {}
    for col in feature_names:
        j = column_idx[col]
        drops = []
        for _ in range(n_repeats):
            X_perm = X_test.copy()
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            drops.append(base_score - r2_score(y_true, trainer.predict(X_tab=X_perm)))
        importances[col] = float(np.mean(drops))
    return pd.Series(importances).sort_values(ascending=False)


def run_experiment(model_name, n_trials=10, seed=42, n_repeats=5, retrain_epochs=None):
    """
    Run a complete deep tabular baseline experiment.

    Workflow:
    1. Hyperparameter optimization on 2011-2014 / 2015 split
    2. Retraining on the full pretest period (2011-2015)
    3. Evaluation on the held-out year 2016
    4. Permutation-based feature importance analysis

    Parameters
    ----------
    model_name : {'SAINT', 'FT-Transformer'}
        Backbone to tune, retrain and evaluate.
    n_trials : int, default 10
        Number of Optuna trials in step 1.
    seed : int, default 42
        Seeds the Optuna sampler, torch weight initialisation and the
        permutation shuffles to make reruns more reproducible.
    n_repeats : int, default 5
        Shuffles per feature in step 4; importances are averaged over them.
    retrain_epochs : int or None, default None
        Epochs for step 2. None reuses the epoch count at which the best
        trial reached its lowest validation loss.

    Returns
    -------
    dict
        'Metrics'    -> {'MAE', 'RMSE', 'R2'} on the 2016 test year,
        'Importance' -> pd.Series of mean R2 drops, sorted descending,
        'Preds'      -> model predictions for the 2016 rows.

    Raises
    ------
    ValueError
        If `model_name` is not supported or `n_repeats` < 1.
    """
    if model_name not in _MODEL_INPUT_DIMS:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; expected one of {sorted(_MODEL_INPUT_DIMS)}"
        )
    if n_repeats < 1:
        raise ValueError("n_repeats must be >= 1")

    print(f"\n{'='*10} Processing {model_name} {'='*10}")

    # --- Step 1: Hyperparameter optimization ---
    def objective(trial):
        params = _suggest_params(trial, model_name)
        torch.manual_seed(seed)
        model = _build_model(model_name, params)

        # Train on 2011-2014 and validate on 2015
        trainer = Trainer(model, objective="regression", verbose=0)
        trainer.fit(
            X_train={"X_tab": X_tab_all[idx_opt_train], "target": y_all[idx_opt_train]},
            X_val={"X_tab": X_tab_all[idx_opt_val], "target": y_all[idx_opt_val]},
            n_epochs=15,
            batch_size=128,
            callbacks=[EarlyStopping(patience=5, monitor="val_loss")]
        )

        # With early stopping the last epoch is typically `patience` epochs past
        # the optimum, so score the trial by its best epoch and remember how
        # many epochs that took for the retraining stage.
        val_losses = trainer.history['val_loss']
        best_epoch = int(np.argmin(val_losses))
        trial.set_user_attr('best_n_epochs', best_epoch + 1)
        return float(val_losses[best_epoch])

    print(">> [Step 1] Optimizing hyperparameters...")
    study = optuna.create_study(
        direction='minimize',
        sampler=optuna.samplers.TPESampler(seed=seed),
    )
    study.optimize(objective, n_trials=n_trials)
    best_trial = study.best_trial
    best_params = best_trial.params
    print(f"   Best parameters: {best_params}")

    # --- Step 2: Retraining on 2011-2015 ---
    print(">> [Step 2] Retraining on the full pretest period (2011-2015)...")
    n_epochs = retrain_epochs if retrain_epochs is not None else best_trial.user_attrs['best_n_epochs']

    torch.manual_seed(seed)
    final_model = _build_model(model_name, best_params)

    # Train on the full pretest period without a separate validation fold
    trainer = Trainer(final_model, objective="regression", verbose=0)
    trainer.fit(
        X_train={"X_tab": X_tab_all[idx_retrain], "target": y_all[idx_retrain]},
        n_epochs=n_epochs,
        batch_size=128
    )

    # --- Step 3: Final evaluation on the held-out year 2016 ---
    print(">> [Step 3] Evaluating on the held-out test year (2016)...")
    X_test = X_tab_all[idx_test]
    y_true = y_all[idx_test]
    preds = trainer.predict(X_tab=X_test)

    mae = mean_absolute_error(y_true, preds)
    rmse = np.sqrt(mean_squared_error(y_true, preds))
    r2 = r2_score(y_true, preds)

    # --- Step 4: Permutation-based feature importance ---
    print(">> [Step 4] Calculating permutation importance...")
    rng = np.random.default_rng(seed)
    importance = _permutation_importance(trainer, X_test, y_true, r2, rng, n_repeats)

    return {
        'Metrics': {'MAE': mae, 'RMSE': rmse, 'R2': r2},
        'Importance': importance,
        'Preds': preds
    }

In [None]:
# ==========================================
# [4] Run experiments and summarize outputs
# ==========================================
# Run both backbones in a fixed order (SAINT first), keeping per-model handles
# for the plotting cells further down.
experiment_results = {name: run_experiment(name) for name in ('SAINT', 'FT-Transformer')}
results_saint = experiment_results['SAINT']
results_ft = experiment_results['FT-Transformer']

# One row per model, one column per metric.
metrics_df = pd.DataFrame(
    {name: outcome['Metrics'] for name, outcome in experiment_results.items()}
).T[['MAE', 'RMSE', 'R2']]

banner = "=" * 40
print("\n" + banner)
print(" FINAL RESULTS (Held-out Test Year: 2016) ")
print(banner)
print(metrics_df)

In [None]:
# ==========================================
# [5] Visualization 1: performance comparison
# ==========================================
# Error bars (MAE / RMSE) on the left axis, R2 on a twin right axis; the
# `position` offsets place the two bar groups side by side at each model tick.
fig, ax1 = plt.subplots(figsize=(10, 6))
metrics_df[['MAE', 'RMSE']].plot(
    kind='bar',
    ax=ax1,
    width=0.4,
    position=1,
    color=['#00cec9', '#fab1a0'],
    rot=0
)
ax1.set_ylabel("Error (MAE / RMSE)", fontsize=12)
ax1.set_title("Performance Comparison: SAINT vs FT-Transformer", fontsize=14)

ax2 = ax1.twinx()
metrics_df['R2'].plot(
    kind='bar',
    ax=ax2,
    width=0.2,
    position=0,
    color='#6c5ce7',
    label='R2 Score'
)
ax2.set_ylabel("R2 Score", color='#6c5ce7', fontsize=12)
# Held-out R2 can be negative; a fixed (0, 1.1) range would hide such bars.
r2_floor = min(0.0, float(metrics_df['R2'].min()) - 0.1)
ax2.set_ylim(r2_floor, 1.1)

# twinx axes keep separate legends; merge them on the top axis so R2 is
# labelled too and the legend is not drawn underneath the R2 bars.
handles1, labels1 = ax1.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
if ax1.get_legend() is not None:
    ax1.get_legend().remove()
ax2.legend(handles1 + handles2, labels1 + labels2, loc='upper right')

plt.tight_layout()
plt.show()

In [None]:
# ==========================================
# [6] Visualization 2: permutation importance
# ==========================================
fig, axs = plt.subplots(1, 2, figsize=(14, 6))

# (results dict, bar colour, panel title) per model, left to right.
importance_panels = [
    (results_saint, '#00cec9', "SAINT Feature Importance (Permutation)"),
    (results_ft, '#fab1a0', "FT-Transformer Feature Importance (Permutation)"),
]
for ax, (model_results, bar_color, panel_title) in zip(axs, importance_panels):
    # Top-10 features; ascending sort puts the largest bar at the top of barh.
    model_results['Importance'].nlargest(10).sort_values().plot(
        kind='barh',
        ax=ax,
        color=bar_color
    )
    ax.set_title(panel_title)
    ax.set_xlabel("Decrease in test R2 after permutation")

plt.tight_layout()
plt.show()