# Algorithmic Fairness Audit: Malaria Risk Prediction

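This notebook audits three malaria risk models (Ridge, LightGBM, and SVR) for group fairness across populations: true/false positive rate disparities, per-population calibration, and disparate impact under the four-fifths rule. It also uses SHAP values to examine how the effect of the rs334 feature varies by population. Data come from the Tai & Dhaliwal (2022) synthetic wGRS+GF+POS dataset; the focus populations are Gambia, Kenya, and Nigeria.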

In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
from sklearn.calibration import calibration_curve
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

sys.path.append(str(Path('..') / 'src'))

from synthetic_clinical_data import MalariaDataGenerator
from evolutionary_models import MalariaRiskPredictor

# Shared figure styling for every plot in this notebook.
sns.set_theme(font='serif', style='white')
plt.rcParams.update({'figure.dpi': 120})


In [None]:
# Build the synthetic cohort and derive model inputs.
# NOTE(review): MalariaDataGenerator() is constructed without an explicit seed
# here; confirm generate() is deterministic, otherwise reruns will differ.
generator = MalariaDataGenerator()
df = generator.generate()

predictor = MalariaRiskPredictor()
X, y, feature_names = predictor.prepare_features(df)
groups = df['population'].values

# 80/20 split stratified by population so each group keeps its overall share
# in the held-out set used for the per-group audit below.
(X_train, X_test,
 y_train, y_test,
 grp_train, grp_test) = train_test_split(
    X, y, groups,
    test_size=0.2,
    stratify=groups,
    random_state=42,
)

# Fit the three candidate models on the same training split, in this order.
for fit in (predictor.train_ridge, predictor.train_lightgbm, predictor.train_svr):
    fit(X_train, y_train)


In [None]:
# Run the predictor's fairness audit on the held-out set for each model and
# keep the full metrics dict per model for the cells below.
fairness_results = {}

for model_name in ['ridge', 'lightgbm', 'svr']:
    metrics = predictor.fairness_metrics(X_test, y_test, grp_test, model_name)
    fairness_results[model_name] = metrics
    # Leading '\n' separates each model's report (the original literal was
    # split across a raw newline, which is a SyntaxError).
    print(f'\n{model_name.upper()} Fairness Audit')
    print('TPR disparity:', metrics['tpr_disparity'])
    print('FPR disparity:', metrics['fpr_disparity'])
    print('Max calibration error:', metrics['max_calibration_error'])


In [None]:
# Per-population TPR / FPR / calibration error for the focus populations,
# one row per (model, population).
focus_pops = ['Gambia', 'Kenya', 'Nigeria']
summary_columns = ['model', 'population', 'tpr', 'fpr', 'calibration_error']
summary_rows = []
missing_pairs = []
for model_name, metrics in fairness_results.items():
    for pop in focus_pops:
        if pop not in metrics['by_group']:
            # Record rather than silently drop: a label mismatch or an empty
            # test group would otherwise just vanish from the table.
            missing_pairs.append((model_name, pop))
            continue
        row = metrics['by_group'][pop].copy()
        row['population'] = pop
        row['model'] = model_name
        summary_rows.append(row)

if missing_pairs:
    print('Focus populations missing from fairness results:', missing_pairs)

# An empty row list has no columns to select, so build an empty table explicitly.
focus_summary = (
    pd.DataFrame(summary_rows)[summary_columns]
    if summary_rows
    else pd.DataFrame(columns=summary_columns)
)
focus_summary


In [None]:
def get_predictions(model_name, X):
    """Return per-sample scores from a fitted model held by ``predictor``.

    If ``predictor.scalers`` has an entry for ``model_name``, its ``transform``
    is applied to ``X`` first. Models exposing ``predict_proba`` (classifiers)
    return the probability of the last class, which is the positive class for
    binary models. Calibration analysis needs scores, not the hard labels that
    ``predict`` gives for classifiers. Other models return ``model.predict``.

    Parameters
    ----------
    model_name : str
        Key into ``predictor.models`` and ``predictor.scalers``.
    X : array-like
        Feature matrix in the layout the model was trained on.

    Returns
    -------
    array-like
        One score per row of ``X``.
    """
    model = predictor.models[model_name]
    scaler = predictor.scalers.get(model_name)
    # Explicit None check: an estimator object's truthiness is not a reliable
    # "scaler present" test (e.g. objects defining __len__).
    X_in = scaler.transform(X) if scaler is not None else X
    if hasattr(model, 'predict_proba'):
        # NOTE(review): [:, -1] assumes a binary classifier.
        return model.predict_proba(X_in)[:, -1]
    return model.predict(X_in)

# Reliability (calibration) curves for LightGBM, one per population.
preds_lightgbm = get_predictions('lightgbm', X_test)

# calibration_curve rejects scores outside [0, 1]; regression-style outputs can
# overshoot, so clip for plotting only and report how many were affected.
# NOTE(review): assumes y_test is binary (0/1), which calibration_curve requires.
n_clipped = int(np.sum((preds_lightgbm < 0) | (preds_lightgbm > 1)))
if n_clipped:
    print(f'Clipped {n_clipped} LightGBM scores into [0, 1] for calibration.')
calib_scores = np.clip(preds_lightgbm, 0.0, 1.0)
n_calibration_bins = 5

populations = np.unique(grp_test)
colors = plt.cm.tab10(np.linspace(0, 1, len(populations)))

fig, ax = plt.subplots(figsize=(10, 6))
for idx, pop in enumerate(populations):
    mask = grp_test == pop
    y_pop = y_test[mask]
    prob_pop = calib_scores[mask]
    prob_true, prob_pred = calibration_curve(y_pop, prob_pop, n_bins=n_calibration_bins)
    ax.plot(prob_pred, prob_true, marker='o', label=pop, color=colors[idx])

ax.plot([0, 1], [0, 1], 'k--', label='Perfect calibration')
ax.set_xlabel('Mean Predicted Risk')
ax.set_ylabel('Fraction of Positives')
ax.set_title('Calibration Curves by Population (LightGBM)')
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
sns.despine()
plt.tight_layout()
plt.show()


In [None]:
# Sensitivity (TPR) per population for every audited model, with an 80% reference line.
model_labels = {'ridge': 'Ridge', 'lightgbm': 'LightGBM', 'svr': 'SVR'}
model_colors = {'ridge': '#4C72B0', 'lightgbm': '#55A868', 'svr': '#C44E52'}
sensitivity_threshold = 0.8

pops = list(fairness_results['lightgbm']['by_group'].keys())
model_names = list(fairness_results)
tpr_values = [
    [fairness_results[name]['by_group'][p]['tpr'] for p in pops]
    for name in model_names
]

x = np.arange(len(pops))
width = 0.8 / len(model_names)  # the model bars share 80% of each population slot

fig, ax = plt.subplots(figsize=(10, 6))
for i, (name, tprs) in enumerate(zip(model_names, tpr_values)):
    offset = (i - (len(model_names) - 1) / 2) * width
    ax.bar(x + offset, tprs, width, label=model_labels.get(name, name),
           color=model_colors.get(name), alpha=0.8)
# Draw the reference line before calling legend() so it gets a legend entry.
ax.axhline(sensitivity_threshold, color='red', linestyle='--', alpha=0.5, label='80% threshold')
ax.set_ylabel('True Positive Rate (Sensitivity)')
ax.set_title('Model Sensitivity Across Populations')
ax.set_xticks(x)
ax.set_xticklabels(pops, rotation=45, ha='right')
ax.legend()
sns.despine()
plt.tight_layout()
plt.show()


In [None]:
# Disparate impact for LightGBM under the four-fifths rule: each population's
# positive-prediction rate divided by the highest group's rate.
four_fifths_threshold = 0.8

positive_rates = {
    pop: stats['positive_rate']
    for pop, stats in fairness_results['lightgbm']['by_group'].items()
}
reference_rate = max(positive_rates.values())
# If no group receives positive predictions the ratio is undefined; use NaN
# (which fails the >= check) instead of dividing by zero.
impact_ratio = {
    pop: rate / reference_rate if reference_rate > 0 else float('nan')
    for pop, rate in positive_rates.items()
}

impact_df = pd.DataFrame({
    'population': list(impact_ratio.keys()),
    'impact_ratio': list(impact_ratio.values()),
}).assign(passes_80_rule=lambda d: d['impact_ratio'] >= four_fifths_threshold)

impact_df.sort_values('impact_ratio', ascending=False)


In [None]:
# Global feature attributions for the LightGBM model on the held-out set.
print('Computing SHAP values for LightGBM...')
explainer = shap.TreeExplainer(predictor.models['lightgbm'])
shap_values = explainer.shap_values(X_test)
# Some shap versions return a list of per-class arrays for binary LightGBM
# classifiers. Keep the last array (the positive class for binary models) so
# `shap_values` is a single (samples x features) matrix for later cells.
if isinstance(shap_values, list):
    shap_values = shap_values[-1]

# By default summary_plot resizes the current figure itself, which overrides a
# figsize given to plt.figure(); request the size through plot_size instead.
shap.summary_plot(shap_values, X_test, feature_names=feature_names,
                  plot_size=(10, 8), show=False)
plt.title('SHAP Beeswarm Plot (LightGBM)')
plt.tight_layout()
plt.show()


In [None]:
# Per-population SHAP dependence for rs334 in the LightGBM model.
feature_list = list(feature_names)
if 'rs334' in feature_list:
    rs334_idx = feature_list.index('rs334')
    x_label = 'rs334 Genotype (0, 1, 2)'
else:
    # Keep the original fallback to the first feature, but make it visible and
    # label the plot with the feature actually shown instead of "rs334".
    rs334_idx = 0
    x_label = str(feature_list[0])
    print(f"Warning: 'rs334' not found in feature_names; plotting '{x_label}' instead.")
shown_feature = feature_list[rs334_idx]

# Some shap versions return one array per class; for binary models the last
# array is the positive class, giving (samples, features) for the indexing below.
shap_matrix = shap_values[-1] if isinstance(shap_values, list) else shap_values
rs334_shap = shap_matrix[:, rs334_idx]
rs334_vals = np.asarray(X_test)[:, rs334_idx]

fig, ax = plt.subplots(figsize=(10, 6))
for pop in np.unique(grp_test):
    mask = grp_test == pop
    ax.scatter(rs334_vals[mask], rs334_shap[mask], label=pop, alpha=0.6)

ax.set_xlabel(x_label)
ax.set_ylabel('SHAP Value (Impact on Prediction)')
ax.set_title(f'{shown_feature} Effect Heterogeneity Across Populations')
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
sns.despine()
plt.tight_layout()
plt.show()
