### We run the model using the workflow below

### Model Workflow

In [1]:
import numpy as np
import pandas as pd

from sklearn.model_selection import KFold, RandomizedSearchCV, cross_val_score, cross_validate
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.inspection import PartialDependenceDisplay

from xgboost import XGBRegressor
import shap
import matplotlib as mpl
import matplotlib.pyplot as plt
from PIL import Image

In [2]:

# Global figure typography: Arial, 10 pt body/title/label text, 9 pt tick labels.
# Fallback families follow Arial so that machines without it (common on Linux)
# use a metric-compatible substitute (Liberation Sans) or matplotlib's bundled
# DejaVu Sans instead of logging missing-font warnings for every figure.
# When Arial is installed, the output is unchanged.
mpl.rcParams.update({
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial", "Liberation Sans", "DejaVu Sans"],
    "font.size": 10,
    "axes.titlesize": 10,
    "axes.labelsize": 10,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
})


In [3]:

# -----------------------------
# 0) Load + basic prep
# -----------------------------
DATA_PATH = "Alzheimer_merged1.csv"  # change to your path if needed

# The file is ISO-8859-1 encoded (column names contain "µ").
# FIPS is read as text so county codes keep any leading zeros (e.g. "01001").
# Parsed as integers they would lose them, and county_top_drivers.csv (written
# later) would no longer join cleanly to other county data. FIPS is only an
# identifier; it is dropped from X below.
df = pd.read_csv(DATA_PATH, encoding="ISO-8859-1", dtype={"FIPS": str})

y = df["AD_PREV_MEAN"].astype(float)
# XGBoost rejects missing labels. Fail here with a clear message rather than
# deep inside the hyperparameter search. Rows are not dropped, because later
# cells align df["FIPS"] / df["Counties"] with X positionally.
n_missing_y = int(y.isna().sum())
if n_missing_y:
    raise ValueError(
        f"AD_PREV_MEAN has {n_missing_y} missing value(s); "
        "drop or impute them in the source data before modelling."
    )
X = df.drop(columns=["AD_PREV_MEAN", "Counties", "FIPS"])

# Coerce every predictor to numeric (unparsable -> NaN) and treat +/-inf as
# missing, then median-impute column by column. XGBoost can handle NaN itself;
# imputing keeps the inputs to the PDP/SHAP steps finite as well.
# NOTE(review): the medians are computed on all rows before CV, which is a
# mild leak into the cross-validated scores below.
X = X.apply(pd.to_numeric, errors="coerce")
X = X.replace([np.inf, -np.inf], np.nan)
X = X.fillna(X.median(numeric_only=True))


# -----------------------------
# 1) Tune XGBoost with 5-fold CV
# -----------------------------
# Shuffled, seeded 5-fold splitter. The same object is passed to the search
# below and to the later cross-validated evaluation, so both see identical folds.
cv5 = KFold(n_splits=5, shuffle=True, random_state=42)

# Discrete candidate values; RandomizedSearchCV samples n_iter combinations.
param_dist = dict(
    n_estimators=[300, 500, 700, 900],
    max_depth=[3, 4, 5, 6, 7],
    learning_rate=[0.01, 0.03, 0.05, 0.1, 0.2],
    subsample=[0.5, 0.7, 0.85, 1.0],
    colsample_bytree=[0.7, 0.8, 0.9, 1.0],
    min_child_weight=[1, 3, 5, 7, 10],
    gamma=[0, 1, 2, 3, 5],
    reg_alpha=[0, 0.1, 0.5, 1.0],
    reg_lambda=[1, 3, 5, 10],
)

# Fixed (non-tuned) settings shared by every candidate.
# NOTE(review): both the search and each XGBRegressor ask for all cores
# (n_jobs=-1). If CPU oversubscription makes this slow, set one of them to 1.
base_model = XGBRegressor(
    tree_method="hist",   # histogram split finding; much faster than "exact"
    objective="reg:squarederror",
    random_state=42,
    n_jobs=-1,
)

search = RandomizedSearchCV(
    base_model,
    param_dist,
    n_iter=40,  # reduce to 20 if slow
    scoring="neg_root_mean_squared_error",
    cv=cv5,
    random_state=42,
    n_jobs=-1,
    verbose=1,
)

# fit() returns the search itself; best_estimator_ is refit on all of X.
best_model = search.fit(X, y).best_estimator_
print("Best params:", search.best_params_)





Fitting 5 folds for each of 40 candidates, totalling 200 fits
Best params: {'subsample': 0.7, 'reg_lambda': 3, 'reg_alpha': 0.1, 'n_estimators': 700, 'min_child_weight': 3, 'max_depth': 5, 'learning_rate': 0.03, 'gamma': 0, 'colsample_bytree': 0.9}


In [4]:
# -----------------------------
# 2) RFE (keep top N; 15 below)
#    - iterative pruning by feature_importances_
# -----------------------------
def rfe_xgb(model, X, y, target_n=15, drop_frac=0.10, random_state=42):
    """Recursive feature elimination driven by ``model.feature_importances_``.

    Repeatedly fits ``model`` on the surviving columns of ``X`` and removes the
    least important ``drop_frac`` share of them (at least one per round). It
    never removes so many that fewer than ``target_n`` remain, so the result
    has exactly ``target_n`` features whenever ``X`` has at least that many.

    Parameters
    ----------
    model : estimator
        Object with ``fit(X, y)`` that exposes ``feature_importances_`` after
        fitting. It is refit in place, so after the call it holds the fit from
        the last elimination round.
    X : pandas.DataFrame
        Candidate predictors; column names are the feature names.
    y : array-like
        Target passed straight to ``model.fit``.
    target_n : int
        Number of features to keep. If ``X`` already has ``target_n`` or fewer
        columns, all of them are returned and ``model`` is not fitted.
    drop_frac : float
        Fraction of the current feature set removed per round.
    random_state : int
        Unused. Kept so existing calls keep working.

    Returns
    -------
    list of str
        Surviving column names, in their original ``X`` column order.

    Raises
    ------
    ValueError
        If ``target_n`` is less than 1.
    """
    if target_n < 1:
        raise ValueError(f"target_n must be >= 1, got {target_n}")

    features = list(X.columns)

    while len(features) > target_n:
        model.fit(X[features], y)

        imp = pd.Series(model.feature_importances_, index=features).sort_values()
        # Drop the bottom drop_frac share (at least 1 feature), capped so that
        # a large drop_frac cannot skip past target_n in a single round.
        drop_n = max(1, int(len(features) * drop_frac))
        drop_n = min(drop_n, len(features) - target_n)
        drop_feats = set(imp.index[:drop_n])
        features = [f for f in features if f not in drop_feats]

    return features

# Keep the 15 most important features, pruning 10% of the remainder per round.
selected_15 = rfe_xgb(
    best_model, X, y,
    target_n=15,
    drop_frac=0.10,
)
print("\nSelected 15 features:", selected_15, sep="\n")




Selected 15 features:
['Walkability score', '% Open land', 'Farms per acre', 'Coal mines', '% Families in poverty', 'Median household value', 'Income Inequality', 'Calcium precipitation', 'Nitrate precipitation', 'Chloride precipitation ', 'Sulfate precipitation', 'Cyanide ', 'Dinoseb ', 'Benzo(a)pyrene  ', 'PCBs']


In [5]:
# -----------------------------
# 3) Evaluate performance (5-fold CV) using selected 15
# -----------------------------
# A single cross_validate call scores RMSE and R² on the same 5 fitted fold
# models. Two separate cross_val_score calls would train every fold twice and
# give the same numbers. The folds are the cv5 splits used during tuning.
# NOTE(review): the hyperparameters and the 15 features were both chosen using
# all rows, so this CV estimate is optimistic. An unbiased estimate would need
# nested CV that repeats the tuning and RFE inside each outer fold.
cv_results = cross_validate(
    best_model,
    X[selected_15],
    y,
    scoring={"rmse": "neg_root_mean_squared_error", "r2": "r2"},
    cv=cv5,
    n_jobs=-1,
)
rmse_scores = -cv_results["test_rmse"]  # scorer is negated RMSE; flip the sign back
r2_scores = cv_results["test_r2"]

print(f"\n5-fold CV RMSE: {rmse_scores.mean():.4f} ± {rmse_scores.std():.4f}")
print(f"5-fold CV R²:   {r2_scores.mean():.4f} ± {r2_scores.std():.4f}")


# -----------------------------
# 4) Final model on full data (for SHAP + PDP)
# -----------------------------
# Same fixed settings as base_model plus the tuned hyperparameters, refit on
# all rows using only the 15 selected features.
final_model = XGBRegressor(
    objective="reg:squarederror",
    random_state=42,
    n_jobs=-1,
    tree_method="hist",
    **search.best_params_,
)
final_model.fit(X[selected_15], y)




5-fold CV RMSE: 0.8879 ± 0.0354
5-fold CV R²:   0.6215 ± 0.0342


In [None]:

# -----------------------------
# 5) SHAP (TreeSHAP) + ranking
# -----------------------------
# Per-row SHAP attributions of final_model on the 15 selected features.
explainer = shap.TreeExplainer(final_model)
shap_values = explainer.shap_values(X[selected_15])

# SHAP summary plot (beeswarm). bbox_inches="tight" keeps long feature names on
# the y-axis and the colour bar inside the saved image.
shap.summary_plot(shap_values, X[selected_15], show=False)
plt.tight_layout()
plt.savefig("shap_beeswarm.png", dpi=300, bbox_inches="tight")
plt.close()

# Global ranking by mean(|SHAP|). After reset_index, row 0 is the most
# influential feature, so the printed index reads as a rank.
shap_importance = (
    pd.DataFrame({
        "feature": selected_15,
        "mean_abs_shap": np.abs(shap_values).mean(axis=0),
    })
    .sort_values("mean_abs_shap", ascending=False)
    .reset_index(drop=True)
)

print("\nSHAP importance ranking (top 15):")
print(shap_importance)




In [None]:
# -----------------------------
# 6) PDP for top K SHAP variables
#    - pick K=6 (nice balance) unless you prefer otherwise
# -----------------------------
K = 6
top_vars = shap_importance["feature"].head(K).tolist()
print(f"\nMaking PDP plots for top {K} variables:")
print(top_vars)

# Characters that are path separators or are invalid in file names on common
# filesystems. They are replaced with "-" when building PNG names.
_UNSAFE_FILENAME_CHARS = str.maketrans({c: "-" for c in '\\/:*?"<>|'})


def _pdp_filename(feature, max_len=40):
    """Return the per-variable PDP file name for ``feature``.

    Keeps the original naming scheme (first ``max_len`` characters, spaces
    replaced by "_"). It also replaces characters such as "/" that savefig would
    treat as a directory separator, e.g. in a unit written as "x/y".
    """
    stem = feature[:max_len].replace(" ", "_").translate(_UNSAFE_FILENAME_CHARS)
    return f"pdp_{stem}.png"


# Settings shared by every PDP below.
PDP_KWARGS = dict(
    kind="average",
    grid_resolution=60,         # smooth curve; reduce if slow
    percentiles=(0.01, 0.99),   # trim extremes for nicer plots
)

# A) Save ONE PDP per variable (cleanest for papers)
for v in top_vars:
    fig, ax = plt.subplots(figsize=(5.5, 4.0))
    PartialDependenceDisplay.from_estimator(
        final_model, X[selected_15], features=[v], ax=ax, **PDP_KWARGS
    )
    ax.set_title(f"PDP: {v}")
    fig.tight_layout()
    fig.savefig(_pdp_filename(v), dpi=300)
    plt.close(fig)  # free each figure inside the loop

# B) Optional: one combined panel if you want a single figure.
# NOTE: This is optional; many journals prefer separate figures or a clean panel.
# The grid has 3 columns and as many rows as needed, so changing K does not
# raise IndexError. Leftover panels are hidden. For K=6 this is the original
# 2x3 layout at 14x8 inches.
n_cols = 3
n_rows = max(1, -(-len(top_vars) // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 4 * n_rows))
axes = axes.ravel()
for panel_ax, v in zip(axes, top_vars):
    PartialDependenceDisplay.from_estimator(
        final_model, X[selected_15], features=[v], ax=panel_ax, **PDP_KWARGS
    )
    panel_ax.set_title(v)
for unused_ax in axes[len(top_vars):]:
    unused_ax.set_axis_off()
fig.tight_layout()
fig.savefig(f"pdp_panel_top{len(top_vars)}.png", dpi=300)
plt.close(fig)


In [None]:
# Wrap the SHAP matrix (computed in step 5) in a table with one column per
# selected feature. Then attach the county identifiers from df. .values copies
# them positionally: X, and therefore shap_values, keep df's row order, so the
# rows line up whatever index df carries.
shap_df = pd.DataFrame(shap_values, columns=selected_15).assign(
    FIPS=df["FIPS"].values,
    County=df["Counties"].values,
)


In [None]:
# For each row, name the selected feature with the largest |SHAP|, i.e. the
# variable that moves that row's prediction the most, in either direction.
top_driver = shap_df.loc[:, selected_15].abs().idxmax(axis="columns")
shap_df["Top_Driver"] = top_driver

# Save the identifier + driver table and preview its first rows.
driver_table = shap_df[["FIPS", "County", "Top_Driver"]]
driver_table.to_csv("county_top_drivers.csv", index=False)
print(driver_table.head())


In [None]:
# Publication exports (600 dpi TIFF) of the two SHAP summary views.
# shap and plt come from the imports cell at the top of the notebook.
# plot_type=None is summary_plot's default, which draws the beeswarm.
SHAP_SUMMARY_EXPORTS = [
    ("bar", "shap_bar.tiff"),       # A) Bar plot of mean(|SHAP|)
    (None, "shap_beeswarm.tiff"),   # B) Beeswarm plot
]

for plot_type, out_path in SHAP_SUMMARY_EXPORTS:
    plt.figure()
    shap.summary_plot(
        shap_values,
        X[selected_15],
        plot_type=plot_type,
        show=False,
    )
    plt.tight_layout()
    plt.savefig(out_path, dpi=600, bbox_inches="tight")
    plt.close()


In [None]:


# Place the bar and beeswarm TIFFs side by side in a single-row figure.
# The context managers close both source files once the resized copies exist.
# Image.open is lazy and the originals never closed their file handles.
with Image.open("shap_bar.tiff") as bar_src, Image.open("shap_beeswarm.tiff") as bee_src:
    # Make them the same height (keeps "one row" clean): scale each image to
    # the taller one's height, preserving aspect ratio.
    target_h = max(bar_src.height, bee_src.height)
    bar = bar_src.resize(
        (int(bar_src.width * target_h / bar_src.height), target_h),
        Image.Resampling.LANCZOS,
    )
    bee = bee_src.resize(
        (int(bee_src.width * target_h / bee_src.height), target_h),
        Image.Resampling.LANCZOS,
    )

combined = Image.new("RGB", (bar.width + bee.width, target_h), (255, 255, 255))
for panel, x_offset in ((bar, 0), (bee, bar.width)):
    # matplotlib TIFFs are presumably RGBA. Pasting through the alpha band puts
    # any transparent pixels over the white background instead of dropping
    # alpha and keeping their raw colour. Fully opaque images paste unchanged.
    mask = panel.getchannel("A") if "A" in panel.getbands() else None
    combined.paste(panel.convert("RGB"), (x_offset, 0), mask)

combined.save("SHAP_bar_beeswarm_row.tiff", dpi=(600, 600))


In [None]:

# SHAP dependence plots for the same top variables as the PDPs, one panel each.
# The grid has 3 columns and as many rows as len(top_vars) needs, and leftover
# panels are hidden. For K=6 this is the original 2x3 layout at 12x9 inches.
X_selected = X[selected_15].copy()
n_cols = 3
n_rows = max(1, -(-len(top_vars) // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 4.5 * n_rows))
axes = axes.ravel()

for panel_ax, v in zip(axes, top_vars):
    # Also make this panel current, so any pyplot-level (current-axes) calls
    # inside shap, e.g. for the interaction colour bar, land on this panel.
    plt.sca(panel_ax)
    shap.dependence_plot(
        ind=v,
        shap_values=shap_values,
        features=X_selected,
        interaction_index="auto",  # let shap choose the colouring feature
        show=False,
        ax=panel_ax,  # requires a shap version whose dependence_plot accepts ax=
    )
    panel_ax.set_title(v)

for unused_ax in axes[len(top_vars):]:
    unused_ax.set_axis_off()

fig.tight_layout()
fig.savefig(f"shap_dependence_top{len(top_vars)}_panel.png", dpi=300, bbox_inches="tight")
plt.show()
plt.close(fig)
