# 05. Generate SHAP Analysis

計算 SHAP 值用於模型可解釋性。

**兩層輸出：**
- **Global SHAP**: 整體特徵重要性 → 放入 System Prompt
- **Local SHAP Examples**: 典型案例的 SHAP 分解 → 供 RAG 參考

In [None]:
import pandas as pd
import numpy as np
import json
import os
import lightgbm as lgb
import shap
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt

# Input/output locations. Every path hangs off the same root, two directory
# levels above the notebook's working directory.
_PROJECT_ROOT = '../..'
DATA_PATH = f'{_PROJECT_ROOT}/data/ilec_cleaned.parquet'
MODEL_PATH = f'{_PROJECT_ROOT}/models/lgbm_mortality_offset_poisson.txt'
OUTPUT_DIR = f'{_PROJECT_ROOT}/knowledge_base/shap'

# Create the artefact folder up front (no error if it already exists) so the
# later cells can write JSON/PNG files into it directly.
os.makedirs(OUTPUT_DIR, exist_ok=True)
print('Setup complete')

In [None]:
# Load the cleaned data and keep a reproducible 1% sample: computing SHAP
# values for every row of the full table would be too slow.
print('Loading data...')
_full_frame = pd.read_parquet(DATA_PATH)
print(f'Full data: {len(_full_frame):,} rows')

df = (
    _full_frame
    .sample(frac=0.01, random_state=42)
    .reset_index(drop=True)  # positional index 0..n-1, shared with X below
)
print(f'Sampled 1%: {len(df):,} rows')

# Only the sample is used from here on; release the full table.
del _full_frame

In [None]:
# Load the trained LightGBM booster from its text model file.
print('Loading LightGBM model...')
model = lgb.Booster(model_file=MODEL_PATH)
n_trees = model.num_trees()
print(f'Model loaded: {n_trees} trees')

In [None]:
# Model input columns, in the order the feature matrix is built.
# NOTE(review): this order is assumed to match the column order used when the
# booster was trained — TODO confirm against the training notebook.
FEATURES = ['Attained_Age', 'Issue_Age', 'Duration', 'Sex', 'Smoker_Status',
            'Insurance_Plan', 'Face_Amount_Band', 'Preferred_Class',
            'SOA_Post_Lvl_Ind', 'SOA_Antp_Lvl_TP', 'SOA_Guar_Lvl_TP']

# Subset of FEATURES that is stored as labels and must be integer-encoded.
CATEGORICAL = ['Sex', 'Smoker_Status', 'Insurance_Plan', 'Face_Amount_Band',
               'Preferred_Class', 'SOA_Post_Lvl_Ind', 'SOA_Antp_Lvl_TP', 'SOA_Guar_Lvl_TP']

# Label-encode each categorical column (values cast to str first).
# NOTE(review): the encoders are fit on this 1% sample only. LabelEncoder
# assigns codes by sorted order of the labels it sees, so if any category is
# absent from the sample the codes can differ from those used at training
# time — confirm they match the training encoding.
X = df[FEATURES].copy()
encoders = {}
for column in CATEGORICAL:
    encoder = LabelEncoder()
    X[column] = encoder.fit_transform(X[column].astype(str))
    encoders[column] = encoder

# In-memory code -> label lookup per categorical column (not written to disk).
encoder_mappings = {
    column: {str(code): str(label) for code, label in enumerate(encoder.classes_)}
    for column, encoder in encoders.items()
}

print(f'Features prepared: {X.shape}')

## Global SHAP

In [None]:
# Build a tree explainer directly on the booster (no background dataset is
# passed) and compute one SHAP vector per sampled row.
print('Creating SHAP explainer...')
explainer = shap.TreeExplainer(model)

print('Calculating SHAP values (this may take a few minutes)...')
feature_matrix = X.values  # plain array; columns follow FEATURES order
shap_values = explainer.shap_values(feature_matrix)
print(f'SHAP values shape: {shap_values.shape}')

In [None]:
# Global importance = mean absolute SHAP value of each feature over all
# sampled rows, ranked descending, plus each feature's share of the total.
mean_abs_shap = np.abs(shap_values).mean(axis=0)

feature_importance = (
    pd.DataFrame({'feature': FEATURES, 'mean_abs_shap': mean_abs_shap})
    .sort_values('mean_abs_shap', ascending=False)
    .reset_index(drop=True)
)
feature_importance = feature_importance.assign(
    rank=range(1, len(feature_importance) + 1),
    pct_contribution=lambda frame: frame['mean_abs_shap'] / frame['mean_abs_shap'].sum() * 100,
)

print('=== Global Feature Importance (SHAP) ===')
print(feature_importance.to_string(index=False))

In [None]:
# Save global importance.
# TreeExplainer explains the booster's *raw* output (its margin), not the
# transformed value returned by model.predict. For a Poisson / log-link
# objective (which the model file name suggests — TODO confirm) the margin is
# the log mortality rate, so base_value is an expected log-rate rather than an
# average rate. exp(base_value) is the rate implied by that mean log-rate.
# np.asarray(...).item() accepts a scalar or a one-element array and raises if
# the explainer unexpectedly returns several expected values.
base_value = float(np.asarray(explainer.expected_value, dtype=float).item())

global_shap = {
    "description": "Global SHAP feature importance for mortality prediction model",
    "base_value": round(base_value, 6),
    "base_value_interpretation": (
        f"Expected raw model output (log scale for a Poisson/log-link model): {base_value:.6f}; "
        f"exp(base_value) = {float(np.exp(base_value)):.6f}"
    ),
    # Tells downstream readers which scale the SHAP numbers live on.
    "shap_scale": "raw model output (margin); SHAP values add to base_value",
    "sample_size": len(df),
    "feature_importance": [
        {
            "feature": row['feature'],
            "rank": int(row['rank']),
            "mean_abs_shap": round(float(row['mean_abs_shap']), 6),
            "pct_contribution": round(float(row['pct_contribution']), 2),
        }
        for _, row in feature_importance.iterrows()
    ],
}

with open(f'{OUTPUT_DIR}/global_importance.json', 'w') as f:
    json.dump(global_shap, f, indent=2)

print(f'✓ global_importance.json saved')

In [None]:
# SHAP summary (beeswarm) plot over all sampled rows. Categorical features are
# coloured by their label-encoded codes.
# shap.summary_plot resizes the current figure itself when plot_size is left
# at "auto", which would override plt.figure(figsize=...). The intended 10x8
# size is therefore passed through plot_size.
plt.figure()
shap.summary_plot(shap_values, X, feature_names=FEATURES, show=False, plot_size=(10, 8))
plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/shap_summary_plot.png', dpi=150, bbox_inches='tight')
plt.show()
print(f'✓ shap_summary_plot.png saved')

## Local SHAP Examples

In [None]:
# Create typical case examples.
# Each example is one reproducible draw (random_state=42) from a filtered
# subset of the sampled rows in `df`.


def _pick_example(frame, mask, label):
    """Return the index label of one row of ``frame[mask]`` drawn with random_state=42.

    Raises:
        ValueError: if no row satisfies ``mask``. DataFrame.sample on an empty
            frame would also raise ValueError, but without saying which
            example filter came up empty.
    """
    candidates = frame[mask]
    if candidates.empty:
        raise ValueError(f'No rows match the filter for the {label!r} example; adjust the selection criteria.')
    return candidates.sample(1, random_state=42).index[0]


# "High risk": a row with at least one observed death.
# NOTE: this selects on the observed outcome, not on the model's predicted rate.
high_risk_idx = _pick_example(df, df['Death_Count'] > 0, 'high_risk')

# "Low risk": attained age under 40, non-smoker, no observed deaths.
# NOTE(review): assumes non-smokers are coded 'NS' in Smoker_Status — confirm.
low_risk_mask = (df['Attained_Age'] < 40) & (df['Smoker_Status'] == 'NS') & (df['Death_Count'] == 0)
low_risk_idx = _pick_example(df, low_risk_mask, 'low_risk')

# "Average": attained age within +/-5 years of the sample median (any outcome).
avg_age = df['Attained_Age'].median()
avg_mask = (df['Attained_Age'] >= avg_age - 5) & (df['Attained_Age'] <= avg_age + 5)
avg_idx = _pick_example(df, avg_mask, 'average')

example_indices = {
    'high_risk': high_risk_idx,
    'low_risk': low_risk_idx,
    'average': avg_idx
}

print('Example indices selected:')
for name, idx in example_indices.items():
    print(f'  {name}: index {idx}, Age={df.loc[idx, "Attained_Age"]}, Smoker={df.loc[idx, "Smoker_Status"]}')

In [None]:
# Generate a local SHAP breakdown for each selected example.
# TreeExplainer explains the booster's raw output (margin), so for every row
# base_value + sum(row SHAP) should reproduce model.predict(..., raw_score=True).
# model.predict without raw_score returns the objective-transformed value.
# For a Poisson / log-link objective that is exp(raw), i.e. the rate — TODO
# confirm against training. Both numbers are stored so the JSON states which
# scale the SHAP values refer to, and the check below warns if the additive
# decomposition does not reproduce the raw score.
local_examples = []

# Wording used in the explanation text for each direction label.
_RISK_PHRASE = {
    "increases": "increases risk",
    "decreases": "decreases risk",
    "neutral": "no effect on risk",
}

for case_name, idx in example_indices.items():
    # Un-encoded feature values, for human-readable output.
    original_row = df.loc[idx, FEATURES].to_dict()

    # Label-encoded feature vector, matching the matrix SHAP was computed on.
    encoded_row = X.loc[idx].values

    prediction = model.predict([encoded_row])[0]
    raw_prediction = model.predict([encoded_row], raw_score=True)[0]

    # shap_values rows are positional; convert the index label to a position.
    row_shap = shap_values[df.index.get_loc(idx)]

    reconstructed = base_value + float(np.sum(row_shap))
    if not np.isclose(reconstructed, raw_prediction, rtol=1e-3, atol=1e-3):
        print(f'WARNING: {case_name}: base_value + sum(SHAP) = {reconstructed:.6f}, '
              f'but raw model output = {raw_prediction:.6f}')

    # Per-feature contributions. An exactly-zero SHAP value (e.g. a feature
    # the trees never split on) is "neutral", not "decreases".
    shap_breakdown = []
    for feat, contribution in zip(FEATURES, row_shap):
        contribution = float(contribution)
        if contribution > 0:
            direction = "increases"
        elif contribution < 0:
            direction = "decreases"
        else:
            direction = "neutral"
        shap_breakdown.append({
            "feature": feat,
            "value": str(original_row[feat]),
            "shap_value": round(contribution, 6),
            "direction": direction
        })

    # Most influential features first.
    shap_breakdown = sorted(shap_breakdown, key=lambda x: abs(x['shap_value']), reverse=True)
    top_drivers = shap_breakdown[:3]

    driver_text = ", ".join(
        f"{d['feature']} ({_RISK_PHRASE[d['direction']]})" for d in top_drivers[:2]
    )

    example = {
        "case_type": case_name,
        "input": original_row,
        "prediction": round(float(prediction), 6),
        "raw_prediction": round(float(raw_prediction), 6),
        "base_value": round(base_value, 6),
        "shap_scale": "raw model output: base_value + sum(shap_value) ~= raw_prediction",
        "shap_breakdown": shap_breakdown,
        "top_drivers": top_drivers,
        "explanation": f"Predicted mortality rate: {prediction:.6f}. Main drivers: {driver_text}."
    }

    local_examples.append(example)

    print(f"\n=== {case_name.upper()} ===")
    print(f"Prediction: {prediction:.6f}")
    print(f"Raw score: {raw_prediction:.6f} (base {base_value:.6f} + SHAP {reconstructed - base_value:+.6f})")
    print(f"Top 3 drivers:")
    for d in top_drivers:
        print(f"  {d['feature']}: {d['value']} → SHAP={d['shap_value']:+.6f}")

In [None]:
# Bundle the per-case breakdowns with the shared base value and write them to
# the knowledge-base folder.
local_shap = dict(
    description="Local SHAP examples for typical cases",
    base_value=round(base_value, 6),
    examples=local_examples,
)

local_examples_path = f'{OUTPUT_DIR}/local_examples.json'
with open(local_examples_path, 'w') as out_file:
    json.dump(local_shap, out_file, indent=2)

print('\n✓ local_examples.json saved')

In [None]:
# Waterfall plot for the high-risk example.
# `data` is the label-encoded vector SHAP was computed on. `display_data`
# carries the original (un-encoded) values, so the plot labels categorical
# features with their categories instead of integer codes.
high_risk_loc = df.index.get_loc(high_risk_idx)
high_risk_explanation = shap.Explanation(
    values=shap_values[high_risk_loc],
    base_values=base_value,
    data=X.iloc[high_risk_loc].values,
    display_data=df.loc[high_risk_idx, FEATURES].values,
    feature_names=FEATURES
)
shap.plots.waterfall(high_risk_explanation, show=False)
plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/shap_waterfall_high_risk.png', dpi=150, bbox_inches='tight')
plt.show()
print(f'✓ shap_waterfall_high_risk.png saved')

## Summary

In [None]:
# Final recap: base value, the five most important features, and the files
# written to OUTPUT_DIR by the cells above.
rule = '=' * 60
print(rule)
print('SHAP ANALYSIS COMPLETE')
print(rule)
print(f'\nBase value (avg prediction): {base_value:.6f}')

print('\nTop 5 Features by Importance:')
for _, top_row in feature_importance.head(5).iterrows():
    print(f"  {top_row['rank']}. {top_row['feature']}: {top_row['pct_contribution']:.1f}%")

print('\nGenerated files:')
for artifact in ('global_importance.json', 'local_examples.json',
                 'shap_summary_plot.png', 'shap_waterfall_high_risk.png'):
    print(f'  ✓ {artifact}')