# Week 3 – XGBoost and SHAP Analysis

Train an XGBoost classifier with imbalance-aware parameters, optionally tune hyperparameters, and generate SHAP explanations to understand feature contributions.

In [None]:
import json
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
from sklearn.metrics import (ConfusionMatrixDisplay, average_precision_score,
                             classification_report, confusion_matrix, precision_recall_curve,
                             roc_auc_score, roc_curve)
from sklearn.model_selection import RandomizedSearchCV, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from xgboost import XGBClassifier

# Input dataset, optional baseline metrics file, and output folders.
DATA_PATH = Path('data/processed/hits_dataset.csv')
BASELINE_METRICS = Path('models/baseline_metrics.json')
FIG_DIR = Path('figures')
MODEL_DIR = Path('models')
FIG_DIR.mkdir(exist_ok=True)
MODEL_DIR.mkdir(exist_ok=True)

# Raise explicitly instead of asserting: assertions are stripped under
# `python -O`, which would defer the failure to a less helpful read_csv error.
if not DATA_PATH.exists():
    raise FileNotFoundError('Run Week 1 notebook first to generate hits_dataset.csv')

df = pd.read_csv(DATA_PATH)
feature_candidates = [
    'danceability', 'energy', 'loudness', 'speechiness', 'acousticness',
    'instrumentalness', 'liveness', 'valence', 'tempo'
]
# Keep only the candidate columns that are actually present in the CSV.
features = [f for f in feature_candidates if f in df.columns]
target = 'is_hit'

if not features:
    raise ValueError(
        f'None of the expected feature columns were found in {DATA_PATH}: '
        f'{feature_candidates}'
    )

X = df[features]
y = df[target]

# Stratified 80/20 split so both partitions keep the same class ratio.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Class-imbalance weight: number of negatives per positive in the training set.
pos = (y_train == 1).sum()
neg = (y_train == 0).sum()
if pos == 0:
    # Without this guard the division below yields inf (with only a warning).
    raise ValueError(
        f"No positive examples ({target} == 1) in the training split; "
        'cannot compute scale_pos_weight'
    )
scale_pos_weight = neg / pos
print(f'scale_pos_weight: {scale_pos_weight:.2f}')

## Train XGBoost (toggle tuning with `SKIP_TUNING`)

Set `SKIP_TUNING=False` to run a modest randomized search for better performance, or leave `SKIP_TUNING=True` (the default) for a faster run with the base parameters. Imbalance is handled via `scale_pos_weight`, and evaluation focuses on PR-AUC.

In [None]:
# Set to False to run the randomized hyperparameter search below (slower).
SKIP_TUNING = True

# Default XGBoost hyperparameters; also the fixed settings used during the search.
base_params = dict(
    n_estimators=300,
    learning_rate=0.05,
    max_depth=6,
    subsample=0.8,
    colsample_bytree=0.8,
    reg_lambda=1.0,
    objective='binary:logistic',
    eval_metric='aucpr',
    scale_pos_weight=scale_pos_weight,
    random_state=42,
    n_jobs=4,
)

if SKIP_TUNING:
    best_params = base_params
else:
    # The `model__` prefix routes each parameter to the pipeline's 'model' step.
    search_space = {
        'model__n_estimators': [200, 300, 500],
        'model__learning_rate': [0.05, 0.1, 0.2],
        'model__max_depth': [4, 6, 8],
        'model__subsample': [0.7, 0.8, 0.9],
        'model__colsample_bytree': [0.7, 0.8, 0.9],
        'model__min_child_weight': [1, 3, 5],
        'model__gamma': [0, 0.25, 0.5]
    }
    pipeline = Pipeline([
        ('scaler', StandardScaler()),
        ('model', XGBClassifier(**base_params))
    ])
    search = RandomizedSearchCV(
        pipeline,
        param_distributions=search_space,
        n_iter=15,
        scoring='average_precision',
        n_jobs=4,
        cv=3,
        random_state=42,
        verbose=1
    )
    search.fit(X_train, y_train)
    # Full parameter set of the winning classifier (base params + tuned values).
    best_params = search.best_estimator_.named_steps['model'].get_params()
    print('Best params:', search.best_params_)

# Final pipeline: scaler followed by XGBoost configured with the chosen parameters.
model = Pipeline([
    ('scaler', StandardScaler()),
    ('model', XGBClassifier(**best_params))
])

model.fit(X_train, y_train)
probs = model.predict_proba(X_test)[:, 1]  # predicted probability of class 1
preds = model.predict(X_test)

# Cast to built-in floats so the dict prints cleanly and serializes to JSON
# regardless of the scalar type the metric functions return.
metrics = {
    'accuracy': float(model.score(X_test, y_test)),
    'roc_auc': float(roc_auc_score(y_test, probs)),
    'pr_auc': float(average_precision_score(y_test, probs)),
}
precision_vals, recall_vals, _ = precision_recall_curve(y_test, probs)
# The 1e-9 term avoids 0/0 where precision and recall are both zero.
f1_scores = 2 * (precision_vals * recall_vals) / (precision_vals + recall_vals + 1e-9)
# NOTE: this is the best F1 over every threshold on the PR curve, not the F1 of
# the hard predictions in `preds` that the classification report describes.
metrics['f1'] = float(np.max(f1_scores))
print('XGBoost metrics:', metrics)
# The newlines must be written as escapes; a quoted string cannot span
# physical lines.
print('\nClassification report:\n', classification_report(y_test, preds, digits=3))

## Plots and model persistence

Save the confusion matrix, ROC curve, and precision-recall curve for the XGBoost model, then persist the fitted pipeline and its metrics to `models/`.

In [None]:
# --- Confusion matrix for the hard predictions in `preds` ---
cm = confusion_matrix(y_test, preds, labels=[0, 1])
disp = ConfusionMatrixDisplay(cm, display_labels=['Non-hit', 'Hit'])
disp.plot(values_format='d')
disp.ax_.set_title('XGBoost Confusion Matrix')
disp.figure_.tight_layout()
disp.figure_.savefig(FIG_DIR / 'xgboost_confusion_matrix.png', dpi=300)
plt.close(disp.figure_)

# --- ROC curve, with the chance diagonal drawn for reference ---
fpr, tpr, _ = roc_curve(y_test, probs)
fig, ax = plt.subplots(figsize=(5, 4))
ax.plot(fpr, tpr, label=f"ROC AUC = {metrics['roc_auc']:.3f}")
ax.plot([0, 1], [0, 1], linestyle='--', color='gray')
ax.set(
    title='XGBoost ROC Curve',
    xlabel='False Positive Rate',
    ylabel='True Positive Rate',
)
ax.legend()
fig.tight_layout()
fig.savefig(FIG_DIR / 'xgboost_roc.png', dpi=300)
plt.close(fig)

# --- Precision-recall curve (recomputed here so the cell runs on its own) ---
precision_vals, recall_vals, _ = precision_recall_curve(y_test, probs)
fig, ax = plt.subplots(figsize=(5, 4))
ax.plot(recall_vals, precision_vals, label=f"PR AUC = {metrics['pr_auc']:.3f}")
ax.set(
    title='XGBoost Precision-Recall Curve',
    xlabel='Recall',
    ylabel='Precision',
)
ax.legend()
fig.tight_layout()
fig.savefig(FIG_DIR / 'xgboost_pr.png', dpi=300)
plt.close(fig)

# --- Persist the fitted pipeline (scaler + classifier) ---
model_path = MODEL_DIR / 'final_xgboost.pkl'
joblib.dump(model, model_path)
print(f'Saved XGBoost model to {model_path}')

# --- Persist the metrics dict as JSON next to the model ---
metrics_file = MODEL_DIR / 'xgboost_metrics.json'
with metrics_file.open('w') as f:
    json.dump(metrics, f, indent=2)
print('Stored XGBoost metrics for comparison.')

## SHAP interpretability

Compute SHAP values on a sample for efficiency, then visualize feature impact with bar and beeswarm plots.

In [None]:
# The explainer wraps only the XGBoost step of the pipeline, and that step was
# fitted on the output of the 'scaler' step. The sample must therefore be
# passed through the same fitted scaler before SHAP values are computed;
# feeding raw features would explain predictions the pipeline never makes.
explainer = shap.TreeExplainer(model.named_steps['model'])
sample_size = min(1000, len(X_train))
shap_sample = X_train.sample(sample_size, random_state=42)
shap_sample_scaled = model.named_steps['scaler'].transform(shap_sample)
shap_values = explainer.shap_values(shap_sample_scaled)

# The plots receive the unscaled DataFrame so that feature names and original
# feature values are displayed; the SHAP values themselves come from the
# scaled inputs above, row-aligned with `shap_sample`.
plt.figure(figsize=(6, 4))
shap.summary_plot(shap_values, shap_sample, plot_type='bar', show=False)
plt.tight_layout()
plt.savefig(FIG_DIR / 'shap_feature_importance.png', dpi=300)
plt.close()

plt.figure(figsize=(7, 5))
shap.summary_plot(shap_values, shap_sample, show=False)
plt.tight_layout()
plt.savefig(FIG_DIR / 'shap_summary_detailed.png', dpi=300)
plt.close()

print('Saved SHAP plots to figures/.')

## Compare against baseline

Load the Week 2 baseline metrics (if available) and print them alongside the XGBoost metrics so the two models can be compared.

In [None]:
# Start from this notebook's metrics; when the baseline metrics file exists,
# add its contents under the 'log_reg' key.
comparison = {'xgboost': metrics}
if BASELINE_METRICS.exists():
    comparison['log_reg'] = json.loads(BASELINE_METRICS.read_text())

print(json.dumps(comparison, indent=2))