# LightGBM Model Explainability & Visualization

Comprehensive analysis for the LightGBM mortality prediction model:

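1. **Feature Importance** (gain and split count)
2. **Tree Structure Visualization**
3. **SHAP Feature Contributions** (LightGBM's native `pred_contrib`, which computes TreeSHAP values)
4. **Partial Dependence Plots**
5. **Interaction Effects** (2D partial dependence: Attained Age × Duration)
6. **Categorical Feature Impact**
7. **Year Trend Factors**
8. **Model Diagnostics** (distribution of predictions)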

In [None]:
import warnings
from pathlib import Path

import lightgbm as lgb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
# NOTE(review): this silences *every* warning, including LightGBM/pandas
# warnings that could reveal category or feature mismatches; consider
# narrowing it to specific categories/modules.
warnings.filterwarnings('ignore')

plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 11

# Every plotting cell below saves into ../data/plots/; create the folder up
# front so savefig does not raise FileNotFoundError on a fresh checkout.
Path('../data/plots').mkdir(parents=True, exist_ok=True)

print('Libraries loaded!')

## 1. Load Model and Data

In [None]:
# Input artefacts for this notebook, all as relative paths.
ARTIFACT_PATHS = {
    'model': '../models/lgbm_mortality_offset_poisson.txt',
    'year_factors': '../models/year_factors_offset.csv',
    'data': '../data/ilec_cleaned.parquet',
}

# Trained booster, restored from its saved text model file.
model = lgb.Booster(model_file=ARTIFACT_PATHS['model'])
print(f'Model loaded with {model.num_trees()} trees')

# Per-year factors (plotted in the year-trend section).
year_factors = pd.read_csv(ARTIFACT_PATHS['year_factors'])
print('Year factors:')
display(year_factors)

# Cleaned experience data; the analysis sample is drawn from it next.
df = pd.read_parquet(ARTIFACT_PATHS['data'])
print(f'\nData shape: {df.shape}')

In [None]:
# Feature setup matching training
NUMERICAL_FEATURES = ['Attained_Age', 'Issue_Age', 'Duration']
CATEGORICAL_FEATURES = [
    'Sex', 'Smoker_Status', 'Insurance_Plan', 'Face_Amount_Band',
    'Preferred_Class', 'SOA_Post_Lvl_Ind', 'SOA_Antp_Lvl_TP', 'SOA_Guar_Lvl_TP'
]
FEATURES = NUMERICAL_FEATURES + CATEGORICAL_FEATURES

# Booster.predict consumes DataFrame columns by position, and the
# contribution plots below label pred_contrib columns with FEATURES, so this
# list must match the model's own feature order exactly. Fail loudly rather
# than silently mislabelling every explanation plot.
model_feature_names = model.feature_name()
if model_feature_names != FEATURES:
    raise ValueError(
        'FEATURES does not match the model feature order:\n'
        f'  model:    {model_feature_names}\n'
        f'  notebook: {FEATURES}'
    )

# Sample for analysis: uniform over rows (records), not weighted by exposure.
np.random.seed(42)
sample_size = min(10000, len(df))
sample_idx = np.random.choice(len(df), size=sample_size, replace=False)
X_sample = df[FEATURES].iloc[sample_idx].copy()

# Convert categorical features to category dtype (matching training).
# NOTE(review): categories are inferred from this sample only; LightGBM
# presumably re-aligns them to the category lists stored in the model file
# at predict time -- confirm the model was trained on pandas categoricals.
for col in CATEGORICAL_FEATURES:
    X_sample[col] = X_sample[col].astype('category')

print(f'Sample size: {len(X_sample)}')

---
## 2. Feature Importance Analysis

In [None]:
# Booster-level feature importance, measured two ways:
#   gain  - total gain of the splits that use the feature
#   split - number of splits that use the feature
feature_names = model.feature_name()
importance_gain = model.feature_importance(importance_type='gain')
importance_split = model.feature_importance(importance_type='split')


def _as_percent(values):
    """Return ``values`` as a percentage of their total (all zeros if the total is 0)."""
    total = values.sum()
    if total == 0:
        # e.g. a model with no splits at all: avoid 0/0 -> NaN percentages.
        return np.zeros(len(values), dtype=float)
    return values / total * 100


importance_df = pd.DataFrame({
    'Feature': feature_names,
    'Gain': importance_gain,
    'Split': importance_split,
    'Gain_Pct': _as_percent(importance_gain),
    'Split_Pct': _as_percent(importance_split),
}).sort_values('Gain', ascending=False)
importance_split_sorted = importance_df.sort_values('Split', ascending=False)

# One panel per importance type, drawn and labelled the same way.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
panels = [
    (axes[0], importance_df, 'Gain_Pct', plt.cm.Blues, 'Feature Importance by Gain'),
    (axes[1], importance_split_sorted, 'Split_Pct', plt.cm.Greens, 'Feature Importance by Split Count'),
]
for ax, ranked, pct_col, cmap, panel_title in panels:
    # Reversed ramp: the top-ranked (first) row gets the darkest shade and,
    # after invert_yaxis, is drawn at the top.
    colors = cmap(np.linspace(0.4, 1.0, len(ranked)))
    bars = ax.barh(ranked['Feature'], ranked[pct_col], color=colors[::-1])
    ax.set_xlabel('Importance (%)')
    ax.set_title(panel_title, fontweight='bold')
    ax.invert_yaxis()
    for bar, val in zip(bars, ranked[pct_col]):
        ax.text(bar.get_width() + 0.5, bar.get_y() + bar.get_height() / 2,
                f'{val:.1f}%', va='center', fontsize=9)

plt.tight_layout()
plt.savefig('../data/plots/lgbm_feature_importance.png', dpi=150, bbox_inches='tight')
plt.show()

print('\nFeature Importance Table:')
display(importance_df)

---
## 3. Tree Structure Visualization

In [None]:
# Render one tree (the first boosted) annotated with split gains and leaf counts.
first_tree = 0
fig, ax = plt.subplots(figsize=(20, 12))
ax = lgb.plot_tree(
    model,
    tree_index=first_tree,
    ax=ax,
    show_info=['split_gain', 'leaf_count'],
)
ax.set_title(f'LightGBM Tree #{first_tree + 1} Structure', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.savefig('../data/plots/lgbm_tree_0.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Compare trees from different stages of boosting (early -> late).
fig, axes = plt.subplots(2, 2, figsize=(24, 16))

tree_indices = [0, 10, 50, 100]
n_trees = model.num_trees()
for ax, tree_idx in zip(axes.ravel(), tree_indices):
    if tree_idx >= n_trees:
        # A small model may not reach this index. Check explicitly instead of
        # catching every exception, so genuine plotting errors still surface.
        # Numbering is 1-based to match the panel titles.
        ax.text(0.5, 0.5,
                f'Tree #{tree_idx + 1} not available (model has {n_trees} trees)',
                ha='center', va='center', transform=ax.transAxes)
        ax.axis('off')
        continue
    lgb.plot_tree(model, tree_index=tree_idx, ax=ax)
    ax.set_title(f'Tree #{tree_idx + 1}', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.savefig('../data/plots/lgbm_trees_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

---
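## 4. Feature Contribution Analysis (SHAP)

This section uses LightGBM's native `pred_contrib=True`, which returns exact TreeSHAP values. The contributions are on the model's raw score scale, which for this Poisson model is the log rate. For each row, the base value plus all feature contributions equals the raw score, and the predicted mortality rate is its exponential.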

In [None]:
# LightGBM's native feature contributions (pred_contrib=True) are TreeSHAP
# values. They are on the model's raw-score scale; for this Poisson model
# that is log-rate, not predicted-rate (assumes the log link of LightGBM's
# Poisson objective, as the model file name suggests).
contributions = model.predict(X_sample, pred_contrib=True)

# Shape is (n_samples, n_features + 1); the last column is the base value.
print(f'Contributions shape: {contributions.shape}')
print(f'Features: {len(FEATURES)}, Contributions columns: {contributions.shape[1]}')
if contributions.shape[1] != len(FEATURES) + 1:
    raise ValueError(
        f'Expected {len(FEATURES) + 1} contribution columns, got {contributions.shape[1]}'
    )

# Separate per-feature contributions from the (row-constant) base value.
feature_contribs = contributions[:, :-1]
base_value = contributions[:, -1].mean()

# Additivity sanity check: base value + contributions should reproduce the
# raw score of every row (up to floating-point error).
raw_scores = model.predict(X_sample, raw_score=True)
additivity_gap = np.abs(contributions.sum(axis=1) - raw_scores).max()
print(f'Additivity check, max |base + contributions - raw score|: {additivity_gap:.2e}')

print(f'Base value (raw log-rate score): {base_value:.6f}')
# exp(base) is the rate implied by the base score -- not the mean predicted rate.
print(f'exp(base value): {np.exp(base_value):.6f}')

In [None]:
# Global ranking from pred_contrib: the average magnitude of each feature's
# contribution over the sample (the quantity a SHAP summary bar plot shows).
mean_abs_contrib = np.mean(np.abs(feature_contribs), axis=0)
contrib_df = pd.DataFrame(
    {'Feature': FEATURES, 'Mean_Abs_Contribution': mean_abs_contrib}
).sort_values('Mean_Abs_Contribution', ascending=True)

# Ascending order puts the largest bar at the top of the horizontal chart,
# and the Reds ramp darkens in the same direction.
shade_ramp = plt.cm.Reds(np.linspace(0.4, 1.0, len(contrib_df)))
fig, ax = plt.subplots(figsize=(10, 6))
ax.barh(contrib_df['Feature'], contrib_df['Mean_Abs_Contribution'], color=shade_ramp)
ax.set_xlabel('Mean |Contribution|')
ax.set_title('Feature Contributions (SHAP-like)', fontweight='bold')
plt.tight_layout()
plt.savefig('../data/plots/lgbm_feature_contributions.png', dpi=150, bbox_inches='tight')
plt.show()

print('\nFeature Contribution Ranking:')
display(contrib_df.sort_values('Mean_Abs_Contribution', ascending=False))

In [None]:
# Dependence view (like a SHAP dependence plot): one point per sampled row,
# x = the feature's value, y = that feature's contribution, colour = same contribution.
top_features = ['Attained_Age', 'Duration', 'Issue_Age']
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, feature in zip(axes, top_features):
    column_contribs = feature_contribs[:, FEATURES.index(feature)]
    points = ax.scatter(
        X_sample[feature].values,
        column_contribs,
        c=column_contribs,
        cmap='RdBu_r',
        alpha=0.5,
        s=10,
    )
    ax.axhline(0, color='gray', linestyle='--', alpha=0.5)
    ax.set_xlabel(feature)
    ax.set_ylabel('Contribution')
    ax.set_title(f'Contribution Dependence: {feature}', fontweight='bold')
    plt.colorbar(points, ax=ax, label='Contribution')

plt.tight_layout()
plt.savefig('../data/plots/lgbm_contribution_dependence.png', dpi=150, bbox_inches='tight')
plt.show()

### Local Explanation (Individual Predictions)

In [None]:
# Score the sample and report its highest- and lowest-rate rows.
predictions = model.predict(X_sample)
high_risk_idx = np.argmax(predictions)
low_risk_idx = np.argmin(predictions)

divider = '=' * 60
for lead, heading, row in (
    ('', 'HIGH RISK CASE', high_risk_idx),
    ('\n', 'LOW RISK CASE', low_risk_idx),
):
    print(lead + divider)
    print(heading)
    print(divider)
    print(f'Predicted mortality rate: {predictions[row]:.6f}')
    print('\nFeature Values:')
    for feat in FEATURES:
        print(f'  {feat}: {X_sample[feat].iloc[row]}')

In [None]:
# Waterfall plot helper for a single row's feature contributions
def plot_waterfall(contributions, features, base_value, title, ax):
    """Draw a SHAP-style waterfall of one row's feature contributions on ``ax``.

    Bars are ordered by |contribution| (largest at the top). The first bar
    starts at ``base_value`` and each later bar starts where the previous one
    ended, so the cascade finishes at ``base_value + sum(contributions)``.
    Positive contributions are drawn red and non-positive ones blue.

    Args:
        contributions: 1-D array of per-feature contributions for one row
            (a ``pred_contrib`` row without its base-value column).
        features: Feature names, positionally aligned with ``contributions``.
        base_value: Start of the cascade, on the same scale as ``contributions``.
        title: Axes title.
        ax: Matplotlib axes to draw on.
    """
    contributions = np.asarray(contributions, dtype=float)
    order = np.argsort(np.abs(contributions))[::-1]
    ordered = contributions[order]

    # Left edge of each bar = base value + everything drawn above it.
    starts = base_value + np.concatenate(([0.0], np.cumsum(ordered)))[:-1]
    end_value = base_value + ordered.sum()
    colors = ['#E94F37' if c > 0 else '#2E86AB' for c in ordered]

    y_pos = np.arange(len(ordered))
    ax.barh(y_pos, ordered, left=starts, color=colors, alpha=0.8)
    ax.set_yticks(y_pos)
    ax.set_yticklabels([features[i] for i in order])
    ax.axvline(base_value, color='gray', linestyle='--', alpha=0.5,
               label=f'Base value: {base_value:.4f}')
    ax.axvline(end_value, color='black', linestyle='-', alpha=0.5,
               label=f'Model output: {end_value:.4f}')
    ax.set_xlabel('Raw score (base value + cumulative contribution)')
    ax.set_title(title, fontweight='bold')
    ax.legend(loc='best')
    ax.invert_yaxis()

# Side-by-side waterfalls for the two extreme rows found above.
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

extreme_cases = [
    (high_risk_idx, f'High Risk Case (pred={predictions[high_risk_idx]:.4f})'),
    (low_risk_idx, f'Low Risk Case (pred={predictions[low_risk_idx]:.6f})'),
]
for panel_ax, (row_idx, panel_title) in zip(axes, extreme_cases):
    plot_waterfall(feature_contribs[row_idx], FEATURES, base_value, panel_title, panel_ax)

plt.tight_layout()
plt.savefig('../data/plots/lgbm_local_explanations.png', dpi=150, bbox_inches='tight')
plt.show()

---
## 5. Partial Dependence Plots

In [None]:
# 1D partial dependence: for each grid value, overwrite the feature in every
# sampled row, predict, and average. The shaded band is the spread (+/-1 std)
# of those individual (ICE) predictions around the mean -- it describes
# heterogeneity across records, NOT a confidence interval for the PDP.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for ax, feature in zip(axes, ['Attained_Age', 'Duration', 'Issue_Age']):
    feature_values = np.linspace(X_sample[feature].min(), X_sample[feature].max(), 50)
    pdp_values = []
    pdp_std = []

    for val in feature_values:
        X_temp = X_sample.copy()
        X_temp[feature] = val
        preds = model.predict(X_temp)
        pdp_values.append(preds.mean())
        pdp_std.append(preds.std())

    pdp_values = np.array(pdp_values)
    pdp_std = np.array(pdp_std)
    # Predicted rates are non-negative, but with a skewed spread mean - std
    # can fall below 0; clip so the band never displays negative rates.
    band_lower = np.clip(pdp_values - pdp_std, 0.0, None)

    ax.plot(feature_values, pdp_values, linewidth=2, color='#E94F37', label='Mean')
    ax.fill_between(feature_values, band_lower, pdp_values + pdp_std,
                    alpha=0.2, color='#E94F37', label='±1 Std')
    ax.set_xlabel(feature)
    ax.set_ylabel('Predicted Mortality Rate')
    ax.set_title(f'Partial Dependence: {feature}', fontweight='bold')
    ax.legend(loc='upper left')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('../data/plots/lgbm_partial_dependence.png', dpi=150, bbox_inches='tight')
plt.show()

---
## 6. 2D Interaction Effect

In [None]:
# 2D partial dependence of predicted mortality on Attained_Age x Duration.
# NOTE: Issue_Age (and every other column) keeps its observed value while
# Attained_Age and Duration are overwritten independently, so many grid
# cells are combinations no real record can have. Read the surface as the
# model's response (including extrapolation), not as observed experience.
fig, ax = plt.subplots(figsize=(10, 8))

age_range = np.linspace(30, 80, 20)
duration_range = np.linspace(1, 25, 20)

# Rows averaged over at each grid cell; more rows give a smoother surface.
N_BACKGROUND = 1000
background = X_sample.iloc[:N_BACKGROUND]
n_bg = len(background)

# Score every (age, duration) cell in a single predict call: stack one copy
# of the background per cell, then overwrite the two grid columns.
# With indexing='ij', age_grid[i, j] = age_range[i], dur_grid[i, j] = duration_range[j].
age_grid, dur_grid = np.meshgrid(age_range, duration_range, indexing='ij')
stacked = pd.concat([background] * age_grid.size, ignore_index=True)
stacked['Attained_Age'] = np.repeat(age_grid.ravel(), n_bg)
stacked['Duration'] = np.repeat(dur_grid.ravel(), n_bg)
pdp_2d = (
    model.predict(stacked)
    .reshape(age_grid.size, n_bg)
    .mean(axis=1)
    .reshape(age_grid.shape)
)

im = ax.contourf(duration_range, age_range, pdp_2d, levels=20, cmap='RdYlBu_r')
plt.colorbar(im, ax=ax, label='Predicted Mortality Rate')
ax.set_xlabel('Duration')
ax.set_ylabel('Attained Age')
ax.set_title('2D Partial Dependence: Age × Duration', fontweight='bold')
plt.tight_layout()
plt.savefig('../data/plots/lgbm_2d_pdp.png', dpi=150, bbox_inches='tight')
plt.show()

---
## 7. Categorical Feature Impact

In [None]:
# Mean predicted mortality rate per level of selected categorical features.
# These are marginal averages over the sample (other features are NOT held
# fixed), so they mix each feature's own effect with correlated features.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

cat_features = ['Sex', 'Smoker_Status', 'Insurance_Plan', 'Preferred_Class']

# Predictions are per row, so scoring the sample once and slicing by
# category gives the same numbers as one predict call per category.
sample_preds = model.predict(X_sample)

for ax, feature in zip(axes.ravel(), cat_features):
    categories = X_sample[feature].cat.categories
    mean_preds = []
    std_preds = []

    for cat in categories:
        group_preds = sample_preds[(X_sample[feature] == cat).to_numpy()]
        if group_preds.size > 0:
            mean_preds.append(group_preds.mean())
            std_preds.append(group_preds.std())
        else:
            mean_preds.append(0.0)
            std_preds.append(0.0)

    mean_preds = np.asarray(mean_preds)
    std_preds = np.asarray(std_preds)
    # Rates cannot be negative: cap the lower error arm at the bar height.
    yerr = np.vstack([np.minimum(std_preds, mean_preds), std_preds])

    ax.bar(range(len(categories)), mean_preds, yerr=yerr,
           color='#2E86AB', alpha=0.8, capsize=3)
    ax.set_xticks(range(len(categories)))
    ax.set_xticklabels([str(c)[:12] for c in categories], rotation=45, ha='right')
    ax.set_xlabel(feature)
    ax.set_ylabel('Mean Predicted Mortality Rate')
    ax.set_title(f'Mortality by {feature}', fontweight='bold')

plt.tight_layout()
plt.savefig('../data/plots/lgbm_categorical_impact.png', dpi=150, bbox_inches='tight')
plt.show()

---
## 8. Year Trend Visualization

In [None]:
# Multiplicative calendar-year factors loaded with the model (1.0 = baseline).
fig, ax = plt.subplots(figsize=(10, 5))

years = year_factors['Year']
factors = year_factors['Year_Factor']

ax.plot(years, factors, marker='o', markersize=10, linewidth=2, color='#2E86AB')
ax.axhline(1.0, color='gray', linestyle='--', alpha=0.7, label='Baseline (1.0)')
ax.fill_between(years, 1.0, factors, alpha=0.3, color='#2E86AB')

ax.set_xlabel('Observation Year')
ax.set_ylabel('Year Factor (Multiplicative)')
ax.set_title('Mortality Year Trend Factors', fontweight='bold')
# Default to a +/-5% window around 1.0, but widen it whenever a factor falls
# outside, so no point (or its label) is clipped off the axes.
margin = 0.01
ax.set_ylim(min(0.95, factors.min() - margin), max(1.05, factors.max() + margin))
ax.legend()

for year, factor in zip(years, factors):
    ax.annotate(f"{factor:.3f}",
                (year, factor),
                textcoords='offset points', xytext=(0, 10), ha='center', fontsize=9)

plt.tight_layout()
plt.savefig('../data/plots/lgbm_year_factors.png', dpi=150, bbox_inches='tight')
plt.show()

---
## 9. Model Diagnostics

In [None]:
# Distribution of predicted mortality rates over the analysis sample,
# shown on a linear axis and on a log10 axis.
predictions = model.predict(X_sample)
pred_mean = predictions.mean()
pred_median = np.median(predictions)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
linear_ax, log_ax = axes

linear_ax.hist(predictions, bins=50, color='#2E86AB', alpha=0.7, edgecolor='black')
linear_ax.axvline(pred_mean, color='red', linestyle='--', label=f'Mean: {pred_mean:.4f}')
linear_ax.axvline(pred_median, color='green', linestyle='--', label=f'Median: {pred_median:.4f}')
linear_ax.set_xlabel('Predicted Mortality Rate')
linear_ax.set_ylabel('Frequency')
linear_ax.set_title('Distribution of Predictions', fontweight='bold')
linear_ax.legend()

# The tiny offset keeps log10 finite should any prediction be exactly 0.
log_ax.hist(np.log10(predictions + 1e-10), bins=50, color='#E94F37', alpha=0.7, edgecolor='black')
log_ax.set_xlabel('Log10(Predicted Mortality Rate)')
log_ax.set_ylabel('Frequency')
log_ax.set_title('Distribution of Predictions (Log Scale)', fontweight='bold')

plt.tight_layout()
plt.savefig('../data/plots/lgbm_prediction_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

---
## 10. Summary Report

In [None]:
# Text summary of the model, its key drivers, and the plots written above.
print('='*70)
print('LGBM MODEL EXPLAINABILITY SUMMARY')
print('='*70)
print(f'\nModel: Offset Poisson LightGBM')
print(f'Number of trees: {model.num_trees()}')
print(f'Number of features: {len(FEATURES)}')

print(f'\n--- Top 3 Features by Importance (Gain) ---')
for _, row in importance_df.head(3).iterrows():
    print(f"  {row['Feature']}: {row['Gain_Pct']:.1f}%")

# pred_contrib values are on the raw (log-rate) score scale.
print(f'\n--- Top 3 Features by Mean |Contribution| (raw log-rate scale) ---')
for _, row in contrib_df.sort_values('Mean_Abs_Contribution', ascending=False).head(3).iterrows():
    print(f"  {row['Feature']}: {row['Mean_Abs_Contribution']:.4f}")

print(f'\n--- Prediction Statistics ---')
print(f'  Mean: {predictions.mean():.6f}')
print(f'  Median: {np.median(predictions):.6f}')
print(f'  Std: {predictions.std():.6f}')
print(f'  Range: [{predictions.min():.6f}, {predictions.max():.6f}]')

print(f'\n--- Year Factor Range ---')
print(f'  Min: {year_factors["Year_Factor"].min():.4f}')
print(f'  Max: {year_factors["Year_Factor"].max():.4f}')

print(f'\n--- Plots Saved (../data/plots/) ---')
plots_dir = Path('../data/plots')
plots = [
    'lgbm_feature_importance.png',
    'lgbm_tree_0.png', 'lgbm_trees_comparison.png',
    'lgbm_feature_contributions.png', 'lgbm_contribution_dependence.png',
    'lgbm_local_explanations.png', 'lgbm_partial_dependence.png',
    'lgbm_2d_pdp.png', 'lgbm_categorical_impact.png',
    'lgbm_year_factors.png', 'lgbm_prediction_distribution.png'
]
for p in plots:
    # Report what is actually on disk instead of assuming every cell ran.
    mark = '✓' if (plots_dir / p).is_file() else '✗ (missing)'
    print(f'  {mark} {p}')

print('='*70)