
# Early Breast Cancer Risk Stratification – SHAP Interpretability

In this iteration, we assume access to the full **1.5 million**‑record risk factor dataset (the 150 k sample used here represents roughly 10 % of it) and focus on model interpretability using **SHAP** (SHapley Additive exPlanations). SHAP values show how much each feature contributes to an individual prediction, which provides the transparency that clinical applications require. We build a tree‑based model, using the `count` column as sample weights because each row summarises many women, and then compute SHAP values for it.
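To make the idea of a per-feature contribution concrete, the sketch below computes exact Shapley values by brute force for a small hypothetical risk score. It is a simplification: the function `toy_risk`, its coefficients and the single all-zero baseline are assumptions made for illustration, whereas the SHAP library averages over a background distribution and uses a much faster tree-specific algorithm.

```python
from itertools import permutations

def toy_risk(features):
    """Hypothetical risk score with an interaction between the first two features."""
    age, density, family = features
    return 0.05 + 0.10 * age + 0.08 * density + 0.06 * age * density + 0.12 * family

def shapley_values(model, instance, baseline):
    """Exact Shapley values: average marginal contribution over all feature orderings."""
    n = len(instance)
    contributions = [0.0] * n
    orderings = list(permutations(range(n)))
    for order in orderings:
        current = list(baseline)          # start from the reference input
        previous_output = model(current)
        for i in order:
            current[i] = instance[i]      # switch feature i to its actual value
            output = model(current)
            contributions[i] += output - previous_output
            previous_output = output
    return [c / len(orderings) for c in contributions]

instance = (1, 1, 0)   # hypothetical profile: older age group, dense tissue, no family history
baseline = (0, 0, 0)   # hypothetical reference profile
phi = shapley_values(toy_risk, instance, baseline)
print([round(v, 3) for v in phi])   # prints [0.13, 0.11, 0.0]

# Additivity: the contributions sum to the gap between prediction and baseline output
print(round(sum(phi), 3), round(toy_risk(instance) - toy_risk(baseline), 3))   # prints 0.24 0.24
```

The interaction term is split evenly between the two features that take part in it, and a feature that does not differ from the baseline receives a contribution of zero.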



In [None]:

# ## 1. Load the data and prepare training/testing sets
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import numpy as np

# Load the 10% sample dataset (stand‑in for full dataset)
df = pd.read_csv('/home/oai/share/sample_10percent.csv')

# Remove rows with unknown target
df = df[df['breast_cancer_history'] != 9].copy()

# Define categorical and numeric columns
categorical_cols = [col for col in df.columns if col not in ['year','count','breast_cancer_history']]
numerical_cols = ['year']

# Separate target and features
y = df['breast_cancer_history']
X = df.drop(columns=['breast_cancer_history'])

# Capture sample weights
weights = X['count'].values

# Remove count from features for modelling
X = X.drop(columns=['count'])

# Build preprocessor
preprocessor = ColumnTransformer([
    ('categorical', OneHotEncoder(handle_unknown='ignore'), categorical_cols),
    ('numeric', 'passthrough', numerical_cols)
])

# Split data
X_train, X_test, y_train, y_test, w_train, w_test = train_test_split(
    X, y, weights, test_size=0.2, random_state=42, stratify=y
)

print('Train size:', X_train.shape)
print('Test size:', X_test.shape)


In [None]:

# ## 2. Train a tree‑based model for SHAP analysis

from sklearn.ensemble import RandomForestClassifier

# Use a reasonably strong random forest; in practice one would use tuned hyperparameters
model = Pipeline([
    ('preprocessor', preprocessor),
    ('classifier', RandomForestClassifier(n_estimators=400, max_depth=20, class_weight='balanced', random_state=42))
])

# Fit model using sample weights
model.fit(X_train, y_train, classifier__sample_weight=w_train)

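# Evaluate on the test set, weighting each row by the number of women it represents
prob_test = model.predict_proba(X_test)[:, 1]
pred_test = (prob_test >= 0.5).astype(int)

acc = accuracy_score(y_test, pred_test, sample_weight=w_test)
f1 = f1_score(y_test, pred_test, sample_weight=w_test)
roc_auc = roc_auc_score(y_test, prob_test, sample_weight=w_test)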

print('Random Forest performance:')
print('  Accuracy:', acc)
print('  F1 score:', f1)
print('  ROC AUC:', roc_auc)


In [None]:

# ## 3. SHAP analysis for interpretability

# SHAP can be computationally intensive, so SHAP values are computed for a subset of the test set only.
# The data must be transformed with the fitted preprocessor, and the tree model extracted from the pipeline.

# matplotlib is needed for the plot titles below
import matplotlib.pyplot as plt

# Attempt to import shap; install it if not already available
try:
    import shap
except ImportError:
    import sys
    !{sys.executable} -m pip install shap
    import shap

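# Draw a reproducible subsample of the training data (up to 1000 rows).
# Note: TreeExplainer is created without background data below, so this sample
# is not used by the explainer; it is kept for optional inspection.
rng = np.random.default_rng(42)
sample_indices = rng.choice(len(X_train), size=min(1000, len(X_train)), replace=False)
X_train_sample = X_train.iloc[sample_indices]

# The preprocessor was already fitted inside model.fit, so only transform here;
# re-fitting it could leave the encoder out of step with the trained forest
X_train_enc = model.named_steps['preprocessor'].transform(X_train)
X_test_enc = model.named_steps['preprocessor'].transform(X_test)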

# Extract the trained RandomForestClassifier
rf = model.named_steps['classifier']

# Use TreeExplainer for tree-based models
explainer = shap.TreeExplainer(rf)

# Compute SHAP values for the test set (limit to first 200 samples for speed)
sample_for_shap = min(200, X_test_enc.shape[0])
X_shap = X_test_enc[:sample_for_shap]
# The one-hot encoded matrix may be sparse; convert it to a dense array for SHAP
if hasattr(X_shap, 'toarray'):
    X_shap = X_shap.toarray()
shap_values = explainer.shap_values(X_shap)
# Older SHAP releases return a list with one array per class; newer releases return
# a single array of shape (n_samples, n_features, n_classes).  Normalise to the list form.
if not isinstance(shap_values, list):
    shap_values = [shap_values[..., k] for k in range(shap_values.shape[-1])]

# Get feature names from the one‑hot encoder
feature_names = list(model.named_steps['preprocessor'].transformers_[0][1].get_feature_names_out(categorical_cols)) + numerical_cols

# Create a SHAP summary plot for the positive class (index 1)
shap.summary_plot(shap_values[1], X_shap, feature_names=feature_names, plot_type='bar', show=False)
plt.title('Mean absolute SHAP values – Random Forest')
plt.show()

# Additionally, display a beeswarm plot to visualise the distribution of SHAP values
shap.summary_plot(shap_values[1], X_shap, feature_names=feature_names, show=False)
plt.title('SHAP summary plot (class = 1) – Random Forest')
plt.show()

# Explain a single instance (e.g. the first test sample)
instance_index = 0
shap.initjs()
shap.force_plot(explainer.expected_value[1], shap_values[1][instance_index], features=X_shap[instance_index], feature_names=feature_names)



## 4. Discussion: Interpreting SHAP outputs

The SHAP bar plot ranks features by their mean absolute contribution to the model’s output across the sampled test instances. Features with higher mean absolute SHAP values influence predictions more strongly. Because the inputs are one‑hot encoded, each bar corresponds to a single category level rather than to a whole variable. For example, levels of age group, BMI group, breast density and family history typically show higher contributions, consistent with our earlier feature‑importance analyses.
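Since the explainer works on one-hot encoded columns, it can help to fold the encoded columns back into their source variables before ranking. The sketch below sums SHAP values per source variable within each row (valid because SHAP values are additive) and then ranks by mean absolute value. The column names and numbers are made up for illustration, and the helper `rank_source_features` assumes encoded names of the form `<column>_<category>` where the category itself contains no underscore.

```python
from collections import defaultdict

def rank_source_features(shap_rows, encoded_names):
    """Sum SHAP values of one-hot columns per source feature, then rank by mean |SHAP|."""
    totals = defaultdict(float)
    for row in shap_rows:
        per_feature = defaultdict(float)
        for name, value in zip(encoded_names, row):
            # Assumption: encoded names look like '<column>_<category>'
            source = name.rsplit('_', 1)[0]
            per_feature[source] += value
        for source, value in per_feature.items():
            totals[source] += abs(value)
    n_rows = len(shap_rows)
    means = {source: total / n_rows for source, total in totals.items()}
    return sorted(means.items(), key=lambda item: item[1], reverse=True)

# Made-up SHAP values for two rows and four encoded columns (illustration only)
encoded_names = ['age_group_1', 'age_group_2', 'bmi_group_1', 'bmi_group_2']
shap_rows = [
    [0.10, -0.02, 0.01, 0.03],
    [-0.04, 0.06, -0.05, 0.01],
]
ranking = rank_source_features(shap_rows, encoded_names)
for source, mean_abs in ranking:
    print(f'{source}: {mean_abs:.3f}')
# prints:
# age_group: 0.050
# bmi_group: 0.040
```

In the notebook, `shap_rows` would correspond to the positive-class SHAP values and `encoded_names` to `feature_names`; a name without an underscore, such as `year`, is left unchanged by the split.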

The beeswarm plot provides a more detailed view: each point represents a SHAP value for a feature in an individual case. Points are coloured by the encoded feature value; for one‑hot columns this is 1 when the category is present and 0 otherwise, so the colours show whether the presence or absence of a category pushes the prediction towards or away from the positive class (prior breast cancer). Clusters of points reveal whether the model’s decisions are consistent across different subgroups.
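The colour pattern in a beeswarm plot can also be summarised numerically. As a rough sketch, the hypothetical helper below averages the SHAP values of one encoded column separately for rows where the indicator is 1 and rows where it is 0; the numbers are invented for illustration and are not results from this model.

```python
def mean_or_nan(values):
    """Mean of a list, or NaN when the list is empty."""
    return sum(values) / len(values) if values else float('nan')

def mean_shap_by_indicator(shap_column, indicator_column):
    """Average SHAP value of one encoded column, split by whether the category is present."""
    present = [s for s, x in zip(shap_column, indicator_column) if x == 1]
    absent = [s for s, x in zip(shap_column, indicator_column) if x == 0]
    return {'present': mean_or_nan(present), 'absent': mean_or_nan(absent)}

# Made-up values for a hypothetical indicator column (illustration only)
indicator = [1, 0, 1, 0, 0]
shap_col = [0.12, -0.03, 0.08, -0.02, -0.04]
summary = mean_shap_by_indicator(shap_col, indicator)
print({key: round(value, 3) for key, value in summary.items()})
# prints {'present': 0.1, 'absent': -0.03}
```

A positive mean for present rows combined with a negative mean for absent rows corresponds to a beeswarm row in which the high-value colour sits to the right of zero and the low-value colour to the left.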

Finally, the force plot explains an individual prediction by showing how each feature contributes to pushing the base value (average prediction) towards the final probability for that instance.  Such local explanations are useful when discussing specific recommendations with clinicians or patients.
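The same information that a force plot draws can be read off as a short table. The sketch below, with invented numbers and hypothetical feature names, reconstructs the model output as the base value plus the sum of the SHAP values and lists the largest contributions by magnitude.

```python
def top_contributions(base_value, shap_row, names, k=3):
    """Return the reconstructed model output and the k largest contributions by magnitude."""
    output = base_value + sum(shap_row)
    ranked = sorted(zip(names, shap_row), key=lambda pair: abs(pair[1]), reverse=True)
    return output, ranked[:k]

# Made-up base value and SHAP values for one instance (illustration only)
names = ['age_group_5', 'bmi_group_2', 'density_3', 'year']
shap_row = [0.15, -0.05, 0.02, -0.01]
output, top = top_contributions(0.20, shap_row, names, k=2)

print(f'Reconstructed output: {output:.2f}')   # prints Reconstructed output: 0.31
for name, value in top:
    direction = 'raises' if value > 0 else 'lowers'
    print(f'{name} {direction} the prediction by {abs(value):.2f}')
```

In the notebook, the corresponding inputs would be `explainer.expected_value[1]`, one row of the positive-class SHAP values and `feature_names`.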

Remember that SHAP analyses are approximate when working with aggregated data: each row summarises many individuals, so the interpretation reflects the average effect within each group rather than specific individuals. The summary plots also treat every sampled row equally, regardless of its `count`.
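One way to bring the group sizes back into a global summary is to weight each row's absolute SHAP values by its `count`. The sketch below is one possible approach, not the method used in the cells above; the helper name and the numbers are assumptions for illustration.

```python
def weighted_mean_abs_shap(shap_rows, counts):
    """Mean absolute SHAP value per feature, weighting each row by its group size."""
    total = sum(counts)
    n_features = len(shap_rows[0])
    return [
        sum(count * abs(row[j]) for row, count in zip(shap_rows, counts)) / total
        for j in range(n_features)
    ]

# Made-up SHAP values for two aggregated rows and two features (illustration only)
shap_rows = [
    [0.10, -0.02],   # row summarising 900 women
    [-0.02, 0.06],   # row summarising 100 women
]
counts = [900, 100]

weighted = weighted_mean_abs_shap(shap_rows, counts)
unweighted = weighted_mean_abs_shap(shap_rows, [1, 1])
print([round(v, 3) for v in weighted])     # prints [0.092, 0.024]
print([round(v, 3) for v in unweighted])   # prints [0.06, 0.04]
```

In the notebook, the weights for the explained rows would be `w_test[:sample_for_shap]`. The gap between the two results shows how much a few large groups can shift a global ranking.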

