In [None]:
# ==========================================
# [1] Required package installation
# ==========================================
# 1. pytorch-tabnet: core implementation for TabNet
# 2. pytorch-widedeep: useful for SAINT / FT-Transformer experiments
# 3. optuna: hyperparameter optimization
!pip install pytorch-tabnet pytorch-widedeep optuna

%pip install pytorch_tabnet

# ==========================================
# Experiment 2: MA-TabNet
# ==========================================

import pandas as pd
import numpy as np
import torch
import io
import optuna
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# TabNet
from pytorch_tabnet.tab_model import TabNetRegressor

In [None]:
# ==========================================
# [1] Data loading and preprocessing
#     (enhanced MA-TabNet version)
# ==========================================
def _read_ma_dataset(local_path='Dataset.csv'):
    """Read the raw dataset, preferring an interactive upload on Google Colab.

    On Colab the user is prompted to upload a CSV. Outside Colab (``google.colab``
    cannot be imported), or when nothing is uploaded, ``local_path`` is read instead.
    Only the missing-Colab case selects the local path: errors raised while uploading
    or parsing an uploaded file propagate instead of being hidden by a fallback.
    """
    try:
        from google.colab import files
    except ImportError:
        files = None

    if files is not None:
        print("=== [Google Colab] Please upload Dataset.csv ===")
        uploaded = files.upload()
        if uploaded:
            filename = next(iter(uploaded))
            return pd.read_csv(io.BytesIO(uploaded[filename]))

    print("=== [Local Environment] Loading Dataset.csv ===")
    return pd.read_csv(local_path)


def load_and_prep_ma_tabnet_advanced():
    """Load Dataset.csv and build leakage-free MA-TabNet inputs for both phases.

    Pipeline:
      1. Load the data (Colab upload or local ``Dataset.csv``), drop ``Country Name``
         and drop rows whose target is missing (they cannot be trained or scored).
      2. Selective masking: for each numeric feature whose missing rate over the
         development years 2011-2015 is >= ``MISSING_THRESHOLD``, add a float
         ``<col>_is_missing`` indicator (0.0 / 1.0).
      3. Fill missing categorical values with "Unknown" and encode them as sorted
         integer codes (the same codes sklearn's ``LabelEncoder`` assigns).
      4. Split by year and median-impute numeric features with statistics from the
         matching training split only: 2011-2014 medians for phase 1 (train/val),
         2011-2015 medians for phase 2 (train/test). The validation year and the
         held-out test year therefore never influence their own fill values.

    Returns:
        ``(train_p1, val_p1, train_p2, test, cat_idxs, cat_dims, features,
        added_masks)`` where each split is an ``(X, y)`` tuple with ``y`` of shape
        ``(n, 1)``. ``cat_idxs`` / ``cat_dims`` describe only the true categorical
        columns; missingness masks are numeric inputs.
    """
    # 1. Load dataset
    df = _read_ma_dataset()

    # 2. Basic preprocessing
    if 'Country Name' in df.columns:
        df = df.drop('Country Name', axis=1)

    target = 'Maternal Mortality Ratio'
    # A NaN target would turn the training loss and the test metrics into NaN.
    df = df.dropna(subset=[target]).copy()

    # Year-based split membership, defined up front so every statistic below can be
    # restricted to the appropriate training window.
    year = df['Year']
    in_train_p1 = year.between(2011, 2014)
    in_val_p1 = year == 2015
    in_train_p2 = year.between(2011, 2015)
    in_test = year == 2016

    # ---------------------------------------------------------
    # [Key MA-TabNet design] Selective masking strategy
    # Instead of generating missing-value masks for all variables,
    # we only create masks for features whose missingness exceeds
    # a predefined threshold. The goal is to preserve informative
    # missingness patterns while avoiding unnecessary indicator noise.
    # ---------------------------------------------------------
    MISSING_THRESHOLD = 0.05  # generate masks only for variables with >= 5% missingness

    cat_cols = ['Country Code', 'Continent']
    exclude = cat_cols + ['Year', target]
    num_cols = [c for c in df.columns if c not in exclude]

    added_masks = []
    for col in num_cols:
        # Measured on the development years only, so the chosen feature set does not
        # depend on the missingness pattern of the held-out test year.
        missing_ratio = df.loc[in_train_p2, col].isnull().mean()
        if missing_ratio >= MISSING_THRESHOLD:
            mask_col_name = f"{col}_is_missing"
            # The mask is a numeric 0.0 / 1.0 feature, not a categorical one, so it
            # needs no embedding and TabNet consumes it directly.
            df[mask_col_name] = df[col].isnull().astype(float)
            added_masks.append(mask_col_name)

    # Impute categorical missing values with an explicit "Unknown" category
    for col in cat_cols:
        df[col] = df[col].fillna("Unknown")

    # 3. Categorical encoding (fit on all rows so every split shares one code space;
    #    this uses only category labels, never targets or numeric statistics).
    categorical_dims = {}
    for col in cat_cols:
        # np.unique returns the sorted classes plus each row's index into them,
        # exactly what LabelEncoder.fit_transform produces.
        classes, codes = np.unique(df[col].astype(str).to_numpy(), return_inverse=True)
        df[col] = codes
        categorical_dims[col] = len(classes)

    # Feature list: [original numerical features] + [categorical features]
    # + [selected missingness masks], in DataFrame column order.
    features = [c for c in df.columns if c not in ['Year', target]]

    # Categorical indices used for embedding layers
    # (only true categorical variables, not missingness masks)
    cat_idxs = [i for i, f in enumerate(features) if f in cat_cols]
    cat_dims = [categorical_dims[f] for f in features if f in cat_cols]

    print(f"\n[Model Info] Total number of input features: {len(features)}")
    print(f"[Model Info] Number of selectively added missingness masks: {len(added_masks)}")
    print(f" -> Strategy: only variables with missingness >= {MISSING_THRESHOLD*100:.1f}% are augmented with mask indicators")

    def impute(frame, reference):
        # Median-fill numeric features of `frame` using medians of `reference` (a
        # training split). A column entirely missing in `reference` has no median;
        # fall back to 0.0 so no NaN reaches TabNet (its `_is_missing` indicator,
        # when selected, still flags those rows).
        medians = reference[num_cols].median().fillna(0.0)
        return frame.fillna(medians)

    def to_xy(data):
        return data[features].values, data[target].values.reshape(-1, 1)

    # 4. Temporal data splits with split-specific imputation
    train_p1_raw = df[in_train_p1]
    train_p2_raw = df[in_train_p2]

    train_p1 = impute(train_p1_raw, train_p1_raw)
    val_p1 = impute(df[in_val_p1], train_p1_raw)
    train_p2 = impute(train_p2_raw, train_p2_raw)
    test = impute(df[in_test], train_p2_raw)

    return (to_xy(train_p1), to_xy(val_p1), to_xy(train_p2), to_xy(test),
            cat_idxs, cat_dims, features, added_masks)

# Prepare data: four (X, y) splits followed by the categorical / feature metadata.
_prepared = load_and_prep_ma_tabnet_advanced()
data_train_1, data_val_1, data_train_2, data_test = _prepared[:4]
cat_idxs, cat_dims, feature_names, added_masks = _prepared[4:]

# Split each (X, y) pair into its feature matrix and (n, 1) target column.
(X_train_1, y_train_1), (X_val_1, y_val_1) = data_train_1, data_val_1
(X_train_2, y_train_2), (X_test, y_test) = data_train_2, data_test

In [None]:
# ==========================================
# [2] Optuna-based hyperparameter search
# ==========================================
def objective(trial):
    """Optuna objective: fit TabNet on 2011-2014 and score it on 2015.

    Because selective masking widens the input, the search space allows slightly
    larger capacity. Returns ``best_cost``, the best validation MAE recorded by
    TabNet's early stopping (``eval_metric=['mae']``), which the study minimizes.
    """
    # Suggestions are drawn in a fixed order so a seeded sampler stays reproducible.
    n_d = trial.suggest_int('n_d', 16, 64, step=8)
    n_a = trial.suggest_int('n_a', 16, 64, step=8)
    n_steps = trial.suggest_int('n_steps', 3, 8)
    relaxation_gamma = trial.suggest_float('gamma', 1.0, 2.0)
    n_independent = trial.suggest_int('n_independent', 2, 5)
    n_shared = trial.suggest_int('n_shared', 2, 5)
    lambda_sparse = trial.suggest_float('lambda_sparse', 1e-4, 1e-2, log=True)
    learning_rate = trial.suggest_float('lr', 1e-2, 5e-2)

    model = TabNetRegressor(
        n_d=n_d,
        n_a=n_a,
        n_steps=n_steps,
        gamma=relaxation_gamma,
        n_independent=n_independent,
        n_shared=n_shared,
        lambda_sparse=lambda_sparse,
        optimizer_params={'lr': learning_rate},
        cat_idxs=cat_idxs,
        cat_dims=cat_dims,
        cat_emb_dim=1,  # only categorical identifiers are embedded
        scheduler_fn=torch.optim.lr_scheduler.StepLR,
        scheduler_params={"step_size": 10, "gamma": 0.9},
        mask_type='entmax',
        verbose=0,
    )

    model.fit(
        X_train=X_train_1,
        y_train=y_train_1,
        eval_set=[(X_val_1, y_val_1)],
        eval_name=['valid'],
        eval_metric=['mae'],
        max_epochs=150,
        patience=20,
        batch_size=256,
        virtual_batch_size=128,
        drop_last=False,
    )
    return model.best_cost

print("\n>>> [Step 1] Starting MA-TabNet advanced hyperparameter optimization...")
# Seed the sampler so the search, and therefore `best_params`, is reproducible under
# Restart & Run All. TPE is already Optuna's default single-objective sampler; only
# the seed is new.
OPTUNA_SEED = 42
study = optuna.create_study(
    direction='minimize',
    sampler=optuna.samplers.TPESampler(seed=OPTUNA_SEED),
)
study.optimize(objective, n_trials=20)

print("\n>>> [Result] Best hyperparameters")
best_params = study.best_trial.params
print(best_params)

In [None]:
# ==========================================
# [3] Final training with optimized hyperparameters
# ==========================================
print("\n>>> [Step 2] Training final MA-TabNet model using the best configuration...")

# Optuna stores the learning rate as a flat 'lr' parameter, but TabNet expects it
# inside optimizer_params. Every remaining key is a TabNetRegressor argument, exactly
# as it was passed in the tuning objective.
final_params = best_params.copy()
lr = final_params.pop('lr')

clf_final = TabNetRegressor(
    cat_idxs=cat_idxs,
    cat_dims=cat_dims,
    cat_emb_dim=1,
    optimizer_params={'lr': lr},
    # Same StepLR decay as the tuning objective (gamma=0.9). The learning rate was
    # selected under that schedule, so a different decay (previously 0.95) would
    # train a configuration that was never evaluated during the search.
    scheduler_params={"step_size": 10, "gamma": 0.9},
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    mask_type='entmax',
    verbose=10,
    **final_params
)

# Refit on all development years (2011-2015). With no eval_set, TabNet has nothing to
# early-stop on and trains for max_epochs.
clf_final.fit(
    X_train=X_train_2, y_train=y_train_2,
    eval_set=None,
    max_epochs=200,
    batch_size=256,
    virtual_batch_size=128,
    drop_last=False
)

In [None]:
# ==========================================
# [4] Final evaluation and performance summary
# ==========================================
# Held-out evaluation on the 2016 test year.
preds = clf_final.predict(X_test)
print("\n=== [MA-TabNet Advanced Performance on the 2016 Held-out Test Set] ===")
test_metrics = (
    ("MAE ", mean_absolute_error(y_test, preds)),
    ("RMSE", np.sqrt(mean_squared_error(y_test, preds))),
    ("R2  ", r2_score(y_test, preds)),
)
for metric_label, metric_value in test_metrics:
    print(f"{metric_label}: {metric_value:.4f}")

# Top-15 TabNet importances; missingness-mask features are drawn in red.
plt.figure(figsize=(10, 6))
importances = pd.Series(clf_final.feature_importances_, index=feature_names)
top15 = importances.nlargest(15).sort_values()  # ascending -> largest bar on top

colors = [('red' if '_is_missing' in name else '#0984e3') for name in top15.index]
top15.plot(kind='barh', color=colors)
plt.title('Feature Importance (Red = Missingness Mask)')
plt.show()

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# =========================
# Output file name
# =========================
# Base name shared by the PNG and PDF exports of the top-15 importance figure.
FNAME = "feature_importance_missing_mask_top15"
OUT_PNG, OUT_PDF = (f"{FNAME}.{ext}" for ext in ("png", "pdf"))

In [None]:
# =========================
# Feature importance plot
# =========================
fig, ax = plt.subplots(figsize=(10, 6))

topk = 15
importances = pd.Series(clf_final.feature_importances_, index=feature_names)
# Ascending sort puts the most important feature at the top of a horizontal bar chart.
top_imp = importances.nlargest(topk).sort_values()

# Missingness-mask indicators in red, original features in blue.
colors = [('#d63031' if '_is_missing' in name else '#0984e3') for name in top_imp.index]
top_imp.plot(kind='barh', color=colors, edgecolor='0.25', linewidth=0.6, ax=ax)

ax.set_xlabel("Importance", fontsize=11)
ax.set_ylabel("")
ax.grid(axis='x', linestyle='--', alpha=0.25, linewidth=0.8)

# Light frame: hide the top/right spines and fade the remaining two.
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)
for side in ('left', 'bottom'):
    ax.spines[side].set_alpha(0.5)
ax.tick_params(axis='both', labelsize=10)

# Numeric label just past the end of each bar (offset = 1% of the largest bar).
xmax = float(top_imp.max())
label_offset = xmax * 0.01
for i, v in enumerate(top_imp.values):
    ax.text(v + label_offset, i, f"{v:.4f}", va='center', fontsize=9)

# Color legend as a small note in the lower-right corner of the axes.
ax.text(
    0.99, 0.02,
    "Red: missingness-mask feature\nBlue: original feature",
    transform=ax.transAxes,
    ha='right', va='bottom',
    fontsize=9, alpha=0.85,
)

fig.tight_layout()

# High-resolution raster plus a vector copy.
for out_path, extra_kwargs in ((OUT_PNG, {"dpi": 600}), (OUT_PDF, {})):
    fig.savefig(out_path, bbox_inches="tight", facecolor="white", **extra_kwargs)

plt.show()

print(f"Saved: {OUT_PNG} (dpi=600)")
print(f"Saved: {OUT_PDF}")

try:
    import shap
    import seaborn as sns
except ImportError:
    %pip -q install shap seaborn
    import shap
    import seaborn as sns

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.inspection import permutation_importance
from sklearn.inspection import PartialDependenceDisplay

In [None]:
# ==============================================================================
# [Visualization settings]
# ==============================================================================
# 1. Seaborn theme
# Seaborn theme first: its "talk" context writes font-size rcParams, so the explicit
# matplotlib overrides below must be applied afterwards to take effect.
sns.set_theme(style="whitegrid", context="talk", font_scale=1.1)

# Matplotlib defaults shared by every figure in the explainability section.
_FIGURE_RC = {
    'figure.figsize': (12, 8),
    'figure.dpi': 150,
    'savefig.dpi': 600,
    'axes.titlesize': 18,
    'axes.titleweight': 'bold',
    'axes.labelsize': 15,
    'xtick.labelsize': 13,
    'ytick.labelsize': 13,
    'font.family': 'sans-serif',
}
plt.rcParams.update(_FIGURE_RC)

# Palettes / colors used by the explanation plots.
PALETTE_GLOBAL, PALETTE_LOCAL, COLOR_BAR = "viridis", "mako", "#3498db"

import os

In [None]:
# ==============================================================================
# 1. TabNet attention visualization
# ==============================================================================
def _attention_top_k(feature_names, scores, top_k):
    """Return the top_k rows of a (Feature, Importance) table, highest importance first."""
    return (
        pd.DataFrame({'Feature': feature_names, 'Importance': scores})
        .sort_values(by='Importance', ascending=False)
        .head(top_k)
    )


def _plot_and_save_attention(df_plot, *, xlabel, palette, figsize, fname, save_dir, dpi):
    """Draw one horizontal attention bar chart, save it as PNG (at dpi) and PDF, then show and close it."""
    plt.figure(figsize=figsize)
    sns.barplot(data=df_plot, x='Importance', y='Feature', palette=palette, edgecolor=".2")
    plt.xlabel(xlabel)
    plt.ylabel("")
    plt.tight_layout()

    plt.savefig(os.path.join(save_dir, f"{fname}.png"), dpi=dpi, bbox_inches='tight', facecolor="white")
    plt.savefig(os.path.join(save_dir, f"{fname}.pdf"), bbox_inches='tight', facecolor="white")
    plt.show()
    plt.close()


def tabnet_attention_xai_hq(model, X, feature_names, top_k=15, sample_idx=0, save_dir=".", dpi=600,
                            n_steps_to_plot=2):
    """
    High-quality visualization of TabNet attention explanations.

    Draws three kinds of horizontal bar charts. Each is saved under ``save_dir`` as
    ``<name>.png`` (at ``dpi``) and ``<name>.pdf``:
      - global attention: each step's mask averaged over samples, then over steps;
      - per-step attention for the first ``n_steps_to_plot`` steps, ordered by step
        key (``None`` plots every step);
      - local attention for row ``sample_idx`` of ``X``, averaged over steps.

    Args:
        model: fitted TabNet model whose ``explain(X)`` returns ``(M_explain, masks)``.
        X: input matrix passed straight to ``model.explain``.
        feature_names: one name per input column; its length must match the
            mask width (assumed ``masks[k]`` is (n_samples, n_features)).
        top_k: number of highest-scoring features shown in each chart.
        sample_idx: row of ``X`` used for the local explanation.
        save_dir: output directory (created if missing).
        dpi: resolution of the PNG exports.
        n_steps_to_plot: how many decision steps get their own chart (default 2).

    Returns:
        ``(df_global, df_local, masks)``: the top-k global and local tables
        (columns ``Feature`` / ``Importance``) and the raw ``masks`` dict.

    Raises:
        ValueError: if ``model.explain(X)`` does not return a 2-tuple.
    """
    os.makedirs(save_dir, exist_ok=True)

    out = model.explain(X)
    if isinstance(out, tuple) and len(out) == 2:
        # M_explain (presumably the aggregated explanation matrix) is not used here.
        M_explain, masks = out
    else:
        raise ValueError("Please verify the return format of model.explain(X).")

    step_keys_sorted = sorted(masks.keys())
    # The epsilon keeps the step averages finite if `masks` is empty.
    n_steps_denominator = len(step_keys_sorted) + 1e-12

    # --- (1) Global attention importance aggregated across steps
    step_global = {k: masks[k].mean(axis=0) for k in step_keys_sorted}
    global_sum = np.zeros(len(feature_names))
    for k in step_keys_sorted:
        global_sum += step_global[k]
    global_sum = global_sum / n_steps_denominator

    df_global = _attention_top_k(feature_names, global_sum, top_k)
    _plot_and_save_attention(
        df_global, xlabel="Mean Attention Score", palette=PALETTE_GLOBAL, figsize=(12, 8),
        fname=f"tabnet_global_attention_top{top_k}", save_dir=save_dir, dpi=dpi,
    )

    # --- (2) Stepwise attention for a subset of representative steps
    if n_steps_to_plot is None:
        steps_to_plot = step_keys_sorted
    else:
        steps_to_plot = step_keys_sorted[:n_steps_to_plot]
    for k in steps_to_plot:
        # Step keys become part of a file name, so strip path separators and spaces.
        safe_k = str(k).replace("/", "_").replace(" ", "_")
        _plot_and_save_attention(
            _attention_top_k(feature_names, step_global[k], top_k),
            xlabel="Attention Score", palette="rocket", figsize=(12, 6),
            fname=f"tabnet_step_attention_{safe_k}_top{top_k}", save_dir=save_dir, dpi=dpi,
        )

    # --- (3) Local attention for one selected sample, averaged across steps
    local_sum = np.zeros(len(feature_names))
    for k in step_keys_sorted:
        local_sum += masks[k][sample_idx]
    local_sum = local_sum / n_steps_denominator

    df_local = _attention_top_k(feature_names, local_sum, top_k)
    _plot_and_save_attention(
        df_local, xlabel="Attention Score", palette=PALETTE_LOCAL, figsize=(12, 8),
        fname=f"tabnet_local_attention_sample{sample_idx}_top{top_k}", save_dir=save_dir, dpi=dpi,
    )

    return df_global, df_local, masks


# Run TabNet attention analysis on the held-out test set (figures saved to the CWD).
print(">>> Running TabNet attention analysis...")
attention_run_kwargs = {"top_k": 15, "sample_idx": 0, "save_dir": ".", "dpi": 600}
df_global_attn, df_local_attn, masks = tabnet_attention_xai_hq(
    clf_final, X_test, feature_names, **attention_run_kwargs
)

# Persist TabNet's built-in importances for every feature, most important first.
fi = (
    pd.Series(clf_final.feature_importances_, index=feature_names)
    .sort_values(ascending=False)
)
fi.to_csv("tabnet_feature_importances.csv", header=["importance"])

import os

In [None]:
# ==============================================================================
# 2. SHAP analysis
# ==============================================================================
print("\n>>> Running SHAP analysis (this may take some time)...")


def predict_1d(X):
    """Return TabNet predictions flattened to 1-D, one scalar output per sample."""
    return clf_final.predict(X).ravel()


save_dir, dpi = ".", 600
os.makedirs(save_dir, exist_ok=True)


def _save_current_figure_png_pdf(fname):
    """Write the active figure to save_dir as <fname>.png (at dpi) and <fname>.pdf."""
    current_fig = plt.gcf()
    current_fig.savefig(os.path.join(save_dir, f"{fname}.png"), dpi=dpi, bbox_inches='tight', facecolor="white")
    current_fig.savefig(os.path.join(save_dir, f"{fname}.pdf"), bbox_inches='tight', facecolor="white")


# Fixed seed so the background sample drawn from the training data is reproducible.
rng = np.random.default_rng(0)
n_background = min(500, len(X_train_2))
idx = rng.choice(len(X_train_2), size=n_background, replace=False)
X_bg = X_train_2[idx]

masker = shap.maskers.Independent(X_bg)
explainer = shap.PermutationExplainer(predict_1d, masker, feature_names=feature_names)

# Explain the first 150 test rows with the smallest evaluation budget of
# 2 * n_features + 1 model calls per row.
X_explain = X_test[:150]
n_features_explained = X_explain.shape[1]
shap_values = explainer(X_explain, max_evals=2 * n_features_explained + 1)

# --- SHAP Beeswarm Plot (global) ---
fname_beeswarm = "shap_beeswarm_global_top15"
plt.figure(figsize=(14, 10))
shap.plots.beeswarm(shap_values, max_display=15, show=False, color_bar_label="Feature Value")
plt.tight_layout()
_save_current_figure_png_pdf(fname_beeswarm)
plt.show()
plt.close()

# --- SHAP Bar Plot (global mean |SHAP|) ---
fname_bar = "shap_bar_global_mean_abs_top15"
plt.figure(figsize=(14, 10))
shap.plots.bar(shap_values, max_display=15, show=False)
plt.tight_layout()
_save_current_figure_png_pdf(fname_bar)
plt.show()
plt.close()

# --- SHAP Waterfall Plot (local explanation of one test row) ---
sample_idx = 0
fname_waterfall = f"shap_waterfall_local_sample{sample_idx}_top15"
plt.figure(figsize=(14, 8))
shap.plots.waterfall(shap_values[sample_idx], max_display=15, show=False)
plt.tight_layout()
fig = plt.gcf()
_save_current_figure_png_pdf(fname_waterfall)
plt.show()
plt.close()

import os
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# ==============================================================================
# 3. Permutation importance visualization
# ==============================================================================
save_dir = "."
dpi = 600
os.makedirs(save_dir, exist_ok=True)

FNAME = "permutation_importance_mae_top15"
OUT_PNG = os.path.join(save_dir, f"{FNAME}.png")
OUT_PDF = os.path.join(save_dir, f"{FNAME}.pdf")
PERM_TOP_K = 15  # matches the "_top15" file name


def _mae_permutation_table(model, X, y, names, top_k=PERM_TOP_K, n_repeats=10, seed=0):
    """Return the top_k features ranked by MAE increase when their column of X is shuffled.

    Columns: Feature, Importance (mean MAE increase over n_repeats shuffles), Std.
    """
    def neg_mae(estimator, X_eval, y_eval):
        # sklearn reports baseline_score - permuted_score; negating the MAE makes that
        # difference equal to the MAE degradation shown on the x-axis.
        return -mean_absolute_error(y_eval, estimator.predict(X_eval))

    result = permutation_importance(
        model, X, y, scoring=neg_mae, n_repeats=n_repeats, random_state=seed,
    )
    return (
        pd.DataFrame({
            'Feature': list(names),
            'Importance': result.importances_mean,
            'Std': result.importances_std,
        })
        .sort_values('Importance', ascending=False)
        .head(top_k)
        .reset_index(drop=True)
    )


# df_perm (columns Feature / Importance / Std) is not created by any other cell in this
# notebook. Reuse it if it already exists in the kernel; otherwise compute it here so the
# cell also works after Restart & Run All.
try:
    df_perm
except NameError:
    df_perm = _mae_permutation_table(clf_final, X_test, y_test, feature_names)

# Palette sized to df_perm itself. A `colors` list left over from an earlier plot is
# ordered by a different feature ranking, so reusing it would color the wrong bars.
perm_colors = sns.color_palette("crest", n_colors=len(df_perm))

plt.figure(figsize=(12, 8))

plt.barh(
    df_perm['Feature'],
    df_perm['Importance'],
    xerr=df_perm['Std'],
    color=perm_colors,
    edgecolor="0.25",
    capsize=4,
    height=0.62,
    error_kw={"elinewidth": 1.2, "alpha": 0.9}
)

plt.xlabel("Mean Importance Score (MAE Degradation)", fontsize=11)
plt.ylabel("")

# df_perm is sorted descending; invert so the most important feature is on top.
plt.gca().invert_yaxis()

ax = plt.gca()
ax.grid(True, axis='x', linestyle='--', alpha=0.25, linewidth=0.8)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["left"].set_alpha(0.6)
ax.spines["bottom"].set_alpha(0.6)
ax.tick_params(axis="y", labelsize=10)
ax.tick_params(axis="x", labelsize=10)

plt.tight_layout()

plt.savefig(OUT_PNG, dpi=dpi, bbox_inches="tight", facecolor="white")
plt.savefig(OUT_PDF, bbox_inches="tight", facecolor="white")

plt.show()
plt.close()

print(f"Saved: {OUT_PNG} (dpi=600)")
print(f"Saved: {OUT_PDF}")

import numpy as np
from sklearn.base import is_classifier, is_regressor

# Sanity-check how sklearn classifies the TabNet wrapper used by the PDP cell.
# `tabnet_wrapper` is not defined anywhere in this notebook, so report that clearly
# instead of failing Restart & Run All with a NameError.
try:
    tabnet_wrapper
except NameError:
    print("Skipping estimator-type checks: `tabnet_wrapper` is not defined in this notebook.")
else:
    print("is_classifier:", is_classifier(tabnet_wrapper))
    print("is_regressor :", is_regressor(tabnet_wrapper))

    # Check predict_proba only if available
    if hasattr(tabnet_wrapper, "predict_proba"):
        p = tabnet_wrapper.predict_proba(X_test[:10])
        print("predict_proba shape:", np.asarray(p).shape)
        print("proba min/max:", np.nanmin(p), np.nanmax(p))

import os
import numpy as np
import matplotlib.pyplot as plt

SAVE_PATH_PNG = "PDP_top_features_dpi600.png"
SAVE_PATH_PDF = "PDP_top_features_dpi600.pdf"

os.makedirs(os.path.dirname(SAVE_PATH_PNG) or ".", exist_ok=True)
os.makedirs(os.path.dirname(SAVE_PATH_PDF) or ".", exist_ok=True)


def _style_pdp_axis(ax, title):
    """Apply the shared title, axis-label, grid and spine styling to one PDP panel."""
    ax.set_title(title, fontsize=12, fontweight="semibold", pad=6)
    ax.set_xlabel("Feature value", fontsize=10, labelpad=4)
    ax.set_ylabel("Partial dependence", fontsize=10, labelpad=4)

    ax.grid(True, linestyle="--", alpha=0.25, linewidth=0.8)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_alpha(0.6)
    ax.spines["bottom"].set_alpha(0.6)
    ax.tick_params(axis="both", labelsize=9)


# NOTE(review): `top_idx`, `pd_one_feature`, `tabnet_wrapper`, `response_method` and
# `class_idx` are not defined anywhere in this notebook; they must come from an earlier
# cell that is missing here. TODO: restore it so Restart & Run All works.
features = list(top_idx)  # column indices of the features to plot
n = len(features)

if n == 0:
    # Avoids a ZeroDivisionError in the grid-size computation below.
    print("No features selected for partial dependence; skipping the PDP figure.")
else:
    # Grid of at most 3 columns, as many rows as needed.
    ncols = min(3, n)
    nrows = int(np.ceil(n / ncols))

    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(5.4 * ncols, 3.6 * nrows),
        squeeze=False
    )
    axes = axes.reshape(-1)

    for i, f in enumerate(features):
        ax = axes[i]

        # Named `pd_result` (not `pd`) so the pandas alias is not shadowed for later cells.
        pd_result = pd_one_feature(
            estimator=tabnet_wrapper,
            X=X_test,
            f=f,
            response_method=response_method,
            class_idx=class_idx,
            grid_resolution=80
        )

        xs = np.asarray(pd_result["grid_values"][0])
        ys = np.asarray(pd_result["average"][0]).reshape(-1)

        # Drop non-finite points so NaN/inf values do not break the curve.
        finite = np.isfinite(xs) & np.isfinite(ys)
        xs, ys = xs[finite], ys[finite]

        ax.plot(xs, ys, linewidth=2.4)
        title = feature_names[f] if feature_names is not None else f"feature {f}"
        _style_pdp_axis(ax, title)

    # Remove unused axes
    for j in range(n, len(axes)):
        fig.delaxes(axes[j])

    fig.subplots_adjust(
        left=0.06, right=0.995,
        bottom=0.14, top=0.88,
        wspace=0.18, hspace=0.30
    )

    fig.savefig(SAVE_PATH_PNG, dpi=600, bbox_inches="tight", facecolor="white")
    fig.savefig(SAVE_PATH_PDF, bbox_inches="tight", facecolor="white")

    plt.show()
    plt.close(fig)

    print(f"Saved: {SAVE_PATH_PNG} (dpi=600)")
    print(f"Saved: {SAVE_PATH_PDF}")

import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.ticker as ticker

In [None]:
# ==============================================================================
# 5. Group-wise residual diagnostics
#    (high-end boxplot + strip plot)
# ==============================================================================
print("\n>>> [Advanced Visualization] Running continent-level residual diagnostics...")

continent_col = "Continent"

if continent_col not in feature_names:
    print(f"Warning: column '{continent_col}' was not found. Skipping group-wise error analysis.")
else:
    # --- Per-row test errors, keyed by the label-encoded continent code ---
    c_idx = feature_names.index(continent_col)
    y_true = y_test.reshape(-1)
    y_pred = preds.reshape(-1)
    resid = y_true - y_pred

    eps = 1e-8  # floor so exact-zero errors can still be placed on the log-scaled axis
    df_err = pd.DataFrame({
        "continent_code": X_test[:, c_idx].astype(int),
        "abs_err": np.abs(resid),
        "resid": resid
    }).assign(abs_err=lambda frame: frame["abs_err"].clip(lower=eps))

    # --- Continents ordered from highest to lowest median absolute error ---
    abs_err_by_group = df_err.groupby("continent_code")['abs_err']
    sorted_idx = abs_err_by_group.median().sort_values(ascending=False).index

    # --- Typography (updates global rcParams, so it also affects later figures) ---
    plt.rcParams.update({
        'font.family': 'sans-serif',
        'font.size': 14,
        'axes.titlesize': 20,
        'axes.titleweight': 'bold',
        'xtick.labelsize': 12,
        'ytick.labelsize': 12
    })

    fig, ax = plt.subplots(figsize=(14, 9))
    _by_continent = dict(data=df_err, x="continent_code", y="abs_err", order=sorted_idx, ax=ax)

    # Boxes summarize each continent; flier markers are hidden because the strip
    # plot on top already shows every individual error.
    sns.boxplot(
        **_by_continent,
        palette="coolwarm_r",
        linewidth=2,
        width=0.6,
        fliersize=0,
        boxprops=dict(alpha=0.85)
    )
    sns.stripplot(
        **_by_continent,
        color="#2c3e50",
        size=3,
        alpha=0.28,
        jitter=0.25
    )

    ax.set_xlabel("Continent Code (Sorted by Median Error)", labelpad=12, fontweight='bold')
    ax.set_ylabel("Absolute Error (Log Scale)", labelpad=12, fontweight='bold')

    # Log scale with plain (non-scientific) major tick labels.
    ax.set_yscale('log')
    ax.yaxis.set_major_formatter(ticker.ScalarFormatter())

    # Minimal frame: y grid only, no top/right spines, thicker left/bottom spines.
    ax.grid(True, axis='y', linestyle='--', alpha=0.28, color='gray', linewidth=0.8)
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    for side in ('left', 'bottom'):
        ax.spines[side].set_linewidth(1.2)

    # Annotate the three highest-error continents with their mean error, anchored at
    # each group's 95th percentile (x position i matches the plotting `order`).
    top3_codes = sorted_idx[:3]
    for i, code in enumerate(top3_codes):
        group_err = abs_err_by_group.get_group(code)
        mean_val = group_err.mean()
        y_pos = group_err.quantile(0.95)
        ax.text(
            i, y_pos,
            f"Mean:\n{mean_val:.2f}",
            ha='center', va='bottom',
            fontsize=11, fontweight='bold',
            color='#c0392b'
        )

    plt.tight_layout(pad=0.6)

    # --- Export: 600-dpi PNG plus vector PDF ---
    save_dir = "."
    os.makedirs(save_dir, exist_ok=True)
    save_base = "continent_error_analysis_highend"
    save_png, save_pdf = (os.path.join(save_dir, f"{save_base}.{ext}") for ext in ("png", "pdf"))

    plt.savefig(save_png, dpi=600, bbox_inches='tight', facecolor='white')
    plt.savefig(save_pdf, bbox_inches='tight', facecolor='white')
    print(f">>> ✅ High-resolution error diagnostic saved: {save_png}")
    print(f">>> ✅ Vector PDF version saved: {save_pdf}")

    plt.show()
    plt.close(fig)