# Step 5: Explainable AI - SHAP Analysis

## Objective
Apply SHAP (SHapley Additive exPlanations) to understand model predictions and quantify transboundary pollution influence.

### Tasks:
1. Apply SHAP analysis to best models
2. Extract global feature importance
3. Analyze individual predictions
4. Quantify country-to-country influence
5. Create influence strength matrices
6. Generate policy-relevant insights

In [None]:
# Import required libraries
import pandas as pd
import numpy as np
import os
import json
import joblib
import warnings
warnings.filterwarnings('ignore')

# SHAP for explainability
import shap

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Notebook-wide defaults: show every DataFrame column, use seaborn's white-grid
# theme, and make figures 14x8 inches unless a later cell passes its own figsize.
pd.set_option('display.max_columns', None)
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (14, 8)})

# Confirm the environment, including the SHAP version in use.
for banner in ("Libraries loaded successfully!", f"SHAP version: {shap.__version__}"):
    print(banner)

## 5.1 Load Models and Data

In [None]:
# Directory layout: engineered features and trained models are read from disk,
# and every SHAP artifact produced by this notebook is written to OUTPUT_PATH.
DATA_PATH = './processed_data/'
MODEL_PATH = './model_outputs/'
OUTPUT_PATH = './shap_outputs/'
os.makedirs(OUTPUT_PATH, exist_ok=True)

print("Loading data and models...")

# NOTE(review): pd.read_pickle and joblib.load unpickle arbitrary objects;
# only point these paths at artifacts produced by trusted pipeline steps.
data = pd.read_pickle(os.path.join(DATA_PATH, 'features_engineered_with_country.pkl'))

# Feature metadata written alongside the engineered features.
with open(os.path.join(DATA_PATH, 'feature_info.json'), 'r') as info_file:
    feature_info = json.load(info_file)

# Per-model evaluation metrics and the stored test-set predictions.
results_df = pd.read_csv(os.path.join(MODEL_PATH, 'model_performance.csv'))
predictions_df = pd.read_pickle(os.path.join(MODEL_PATH, 'test_predictions.pkl'))

# Every column that is neither an identifier nor a target is treated as a feature.
TARGET_POLLUTANTS = ['CO', 'NO2', 'PM10']
EXCLUDE_COLS = ['country', 'date', 'season'] + TARGET_POLLUTANTS
feature_columns = []
for col in data.columns:
    if col not in EXCLUDE_COLS:
        feature_columns.append(col)

# One persisted "best" model per pollutant; the display name is taken from the
# highest-R² row for that pollutant in the performance table.
best_models = {}
for pollutant in TARGET_POLLUTANTS:
    model_file = os.path.join(MODEL_PATH, f'best_model_{pollutant}.pkl')
    best_models[pollutant] = joblib.load(model_file)
    ranked = results_df[results_df['Pollutant'] == pollutant].sort_values('R²', ascending=False)
    model_name = ranked.iloc[0]['Model']
    print(f"✓ Loaded {pollutant} model: {model_name}")

print(f"\nData shape: {data.shape}")
print(f"Features: {len(feature_columns)}")

## 5.2 Prepare Test Data for SHAP

In [None]:
# Rebuild the evaluation split: replace infinities and missing values with 0,
# order rows chronologically, and keep the last 20 % as the test set.
# NOTE(review): presumed to mirror the preprocessing and 80/20 split used in the
# training notebook - confirm against that notebook.
data_clean = data.copy()
data_clean = data_clean.replace([np.inf, -np.inf], np.nan)
data_clean = data_clean.fillna(0)
data_clean = data_clean.sort_values('date').reset_index(drop=True)

split_idx = int(len(data_clean) * 0.8)
test_data = data_clean.iloc[split_idx:].copy()

# Full test-set feature matrix.
X_test = test_data[feature_columns]

# SHAP is slow on large inputs, so explain a per-country stratified sample of
# at most 1000 rows. A fixed seed keeps the sample (and every downstream SHAP
# artifact) reproducible between runs.
SHAP_SAMPLE_SEED = 42
sample_size = min(1000, len(test_data))
n_countries = test_data['country'].nunique()
# At least one row per country, even when there are more countries than rows requested.
per_country = max(1, sample_size // max(n_countries, 1))

# Iterating the groups (instead of groupby.apply) keeps the 'country' column in
# the result, which later cells read from the sample.
country_samples = [
    group.sample(n=min(len(group), per_country), random_state=SHAP_SAMPLE_SEED)
    for _, group in test_data.groupby('country')
]

if country_samples:
    # The stratified draw can yield fewer rows than sample_size (integer division,
    # small countries), so shuffle with frac=1 rather than requesting exactly
    # sample_size rows, which raises when the population is smaller.
    X_test_sample = (
        pd.concat(country_samples)
        .sample(frac=1, random_state=SHAP_SAMPLE_SEED)
        .iloc[:sample_size]
    )
else:
    # Empty test set: nothing to sample.
    X_test_sample = test_data.copy()

X_shap = X_test_sample[feature_columns]
test_sample_data = X_test_sample.copy()

print(f"Test set size: {len(test_data):,}")
print(f"SHAP sample size: {len(X_shap):,}")
print(f"Countries in sample: {X_test_sample['country'].nunique()}")

## 5.3 SHAP Analysis for Each Pollutant

In [None]:
# Per-pollutant SHAP explainers and the SHAP values computed on X_shap.
shap_explainers = {}
shap_values_dict = {}

print("="*80)
print("COMPUTING SHAP VALUES")
print("="*80)
print("This may take several minutes...\n")

for pollutant in TARGET_POLLUTANTS:
    print(f"\nProcessing {pollutant}...")

    model = best_models[pollutant]

    # Only the SHAP calls are guarded. Keeping storage and reporting outside the
    # try block means a reporting problem can no longer be mislabelled as a
    # SHAP failure after the values were already stored.
    try:
        # TreeExplainer is used for every model; a model it cannot handle is
        # reported and skipped rather than aborting the whole run.
        print("  Creating explainer...")
        explainer = shap.TreeExplainer(model)

        print("  Calculating SHAP values...")
        shap_values = explainer.shap_values(X_shap)
    except Exception as e:
        # Deliberate best-effort: later cells skip pollutants missing from shap_values_dict.
        print(f"  ✗ Error computing SHAP for {pollutant}: {str(e)}")
        continue

    shap_explainers[pollutant] = explainer
    shap_values_dict[pollutant] = shap_values

    print(f"  ✓ SHAP values computed for {pollutant}")
    # np.shape also works if the explainer returned a list of arrays instead of one ndarray.
    print(f"    Shape: {np.shape(shap_values)}")

print("\n" + "="*80)
print("✓ SHAP COMPUTATION COMPLETED")
print("="*80)

## 5.4 Global Feature Importance (SHAP)

In [None]:
# One SHAP summary plot per pollutant that has SHAP values, limited to the 20
# top-ranked features and saved as a 300-dpi PNG.
for pollutant in (name for name in TARGET_POLLUTANTS if name in shap_values_dict):
    print(f"\nCreating SHAP summary plot for {pollutant}...")

    summary_png = os.path.join(OUTPUT_PATH, f'shap_summary_{pollutant}.png')
    pollutant_shap = shap_values_dict[pollutant]

    plt.figure(figsize=(12, 10))
    # show=False keeps the figure open so the title can be added before saving.
    shap.summary_plot(pollutant_shap, X_shap, max_display=20, show=False)
    plt.title(f'SHAP Feature Importance - {pollutant}', fontsize=16, fontweight='bold', pad=20)
    plt.tight_layout()
    plt.savefig(summary_png, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"  ✓ Summary plot saved")

print("\n✓ Global SHAP importance plots created")

## 5.5 Mean Absolute SHAP Values (Feature Ranking)

In [None]:
# Rank features per pollutant by mean |SHAP| across the sampled rows, keep the
# ranking in memory for later cells, and write it to CSV.
shap_importance = {}

for pollutant in TARGET_POLLUTANTS:
    if pollutant not in shap_values_dict:
        continue

    # Average the absolute attributions over rows (axis 0) to get one score per feature.
    mean_abs_shap = np.mean(np.abs(shap_values_dict[pollutant]), axis=0)

    ranking = pd.DataFrame({'feature': feature_columns, 'mean_abs_shap': mean_abs_shap})
    shap_df = ranking.sort_values('mean_abs_shap', ascending=False)
    shap_importance[pollutant] = shap_df

    ranking_csv = os.path.join(OUTPUT_PATH, f'shap_importance_{pollutant}.csv')
    shap_df.to_csv(ranking_csv, index=False)

    print(f"\nTop 15 features for {pollutant} (by SHAP importance):")
    print(shap_df.head(15).to_string(index=False))

print("\n✓ SHAP importance rankings saved")

## 5.6 SHAP Bar Plots (Top Features)

In [None]:
# Bar-style SHAP summary (mean |SHAP| per feature) for the 20 top-ranked
# features of each pollutant that has SHAP values.
for pollutant in (name for name in TARGET_POLLUTANTS if name in shap_values_dict):
    bar_png = os.path.join(OUTPUT_PATH, f'shap_bar_{pollutant}.png')

    plt.figure(figsize=(10, 8))
    # show=False keeps the figure open so the title can be added before saving.
    shap.summary_plot(
        shap_values_dict[pollutant],
        X_shap,
        plot_type="bar",
        max_display=20,
        show=False,
    )
    plt.title(f'Top 20 Features (Mean |SHAP|) - {pollutant}', fontsize=14, fontweight='bold', pad=15)
    plt.tight_layout()
    plt.savefig(bar_png, dpi=300, bbox_inches='tight')
    plt.show()

print("✓ SHAP bar plots created")

## 5.7 Neighbor Influence Quantification

In [None]:
# Split each pollutant's SHAP importance into "neighbor" features and "self"
# features, then express each group as a share of the two combined.
neighbor_influence = {}

for pollutant in TARGET_POLLUTANTS:
    if pollutant not in shap_importance:
        continue

    shap_df = shap_importance[pollutant]

    # regex=False: match names literally so a pollutant name containing regex
    # metacharacters (e.g. a dot) cannot match unrelated features.
    is_neighbor = shap_df['feature'].str.contains('neighbor', regex=False)
    mentions_pollutant = shap_df['feature'].str.contains(pollutant, regex=False)

    # Neighbor features: any feature whose name contains 'neighbor'.
    neighbor_features = shap_df[is_neighbor].copy()

    # Self features: features naming this pollutant that are not neighbor features.
    self_features = shap_df[mentions_pollutant & ~is_neighbor].copy()

    neighbor_total = neighbor_features['mean_abs_shap'].sum()
    self_total = self_features['mean_abs_shap'].sum()

    # Percentages are shares of (neighbor + self) only; features in neither
    # group are not part of the denominator.
    combined_total = neighbor_total + self_total
    if combined_total > 0:
        neighbor_pct = (neighbor_total / combined_total) * 100
        self_pct = (self_total / combined_total) * 100
    else:
        # No attribution in either group: the shares are undefined, so record
        # NaN explicitly instead of evaluating 0/0.
        neighbor_pct = self_pct = float('nan')

    neighbor_influence[pollutant] = {
        'neighbor_importance': neighbor_total,
        'self_importance': self_total,
        'neighbor_percentage': neighbor_pct,
        'self_percentage': self_pct
    }

    print(f"\n{pollutant} Influence Analysis:")
    print(f"  Neighbor influence: {neighbor_total:.4f} ({neighbor_pct:.2f}%)")
    print(f"  Self influence: {self_total:.4f} ({self_pct:.2f}%)")
    print(f"\n  Top 5 neighbor features:")
    print(neighbor_features.head(5).to_string(index=False))

# One row per pollutant, one column per metric.
influence_df = pd.DataFrame(neighbor_influence).T
influence_df.to_csv(os.path.join(OUTPUT_PATH, 'neighbor_vs_self_influence.csv'))

print("\n✓ Neighbor influence quantified")

## 5.8 Visualize Self vs Neighbor Influence

In [None]:
# Two-panel figure: grouped bars of neighbor vs self share per pollutant (left)
# and a pie of the shares averaged over pollutants (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

pollutants = list(neighbor_influence)
neighbor_pcts = [neighbor_influence[name]['neighbor_percentage'] for name in pollutants]
self_pcts = [neighbor_influence[name]['self_percentage'] for name in pollutants]

# Left panel: bars offset half a bar-width either side of each pollutant tick.
x = np.arange(len(pollutants))
width = 0.35
bar_series = (
    (x - width/2, neighbor_pcts, 'Neighbor Influence', '#FF6B6B'),
    (x + width/2, self_pcts, 'Self Influence', '#4ECDC4'),
)
for positions, heights, series_label, series_color in bar_series:
    ax1.bar(positions, heights, width, label=series_label, color=series_color, alpha=0.8)

ax1.set_xlabel('Pollutant', fontsize=12)
ax1.set_ylabel('Influence (%)', fontsize=12)
ax1.set_title('Self vs Neighbor Influence (by SHAP)', fontsize=14, fontweight='bold')
ax1.set_xticks(x)
ax1.set_xticklabels(pollutants)
ax1.legend()
ax1.grid(True, alpha=0.3, axis='y')

# Right panel: unweighted mean of the per-pollutant percentages.
avg_neighbor = np.mean(neighbor_pcts)
avg_self = np.mean(self_pcts)

ax2.pie(
    [avg_neighbor, avg_self],
    labels=['Neighbor Influence', 'Self Influence'],
    colors=['#FF6B6B', '#4ECDC4'],
    autopct='%1.1f%%',
    startangle=90,
    textprops={'fontsize': 12},
)
ax2.set_title('Average Influence Distribution', fontsize=14, fontweight='bold')

plt.tight_layout()
figure_path = os.path.join(OUTPUT_PATH, 'self_vs_neighbor_influence.png')
plt.savefig(figure_path, dpi=300, bbox_inches='tight')
plt.show()

print("✓ Self vs neighbor influence visualized")

## 5.9 SHAP Dependence Plots (Key Features)

In [None]:
# For each pollutant, draw SHAP dependence plots for its three highest-ranked
# features side by side and save the row as one PNG.
for pollutant in TARGET_POLLUTANTS:
    has_values = pollutant in shap_values_dict
    has_ranking = pollutant in shap_importance
    if not (has_values and has_ranking):
        continue

    top_features = shap_importance[pollutant].head(3)['feature'].tolist()

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Pair each top feature with its own axis; unused axes stay empty when
    # fewer than three features are available.
    for panel, feature in zip(axes, top_features):
        # The feature is passed by its column position in feature_columns.
        feature_idx = feature_columns.index(feature)
        shap.dependence_plot(
            feature_idx,
            shap_values_dict[pollutant],
            X_shap,
            ax=panel,
            show=False,
        )
        panel.set_title(f'{feature}', fontsize=11, fontweight='bold')

    plt.suptitle(
        f'SHAP Dependence Plots - {pollutant} (Top 3 Features)',
        fontsize=14, fontweight='bold', y=1.02,
    )
    plt.tight_layout()
    dependence_png = os.path.join(OUTPUT_PATH, f'shap_dependence_{pollutant}.png')
    plt.savefig(dependence_png, dpi=300, bbox_inches='tight')
    plt.show()

print("✓ SHAP dependence plots created")

## 5.10 Country-Specific SHAP Analysis

In [None]:
# Per-country view of neighbor attribution: for each pollutant, average the SHAP
# values of all neighbor features within each country and rank the countries.
country_shap_analysis = {}

# Neighbor features are the feature columns whose name contains 'neighbor'.
neighbor_cols = [col for col in feature_columns if 'neighbor' in col]

for pollutant in TARGET_POLLUTANTS:
    if pollutant not in shap_values_dict:
        continue

    # SHAP values as a frame, tagged with the country of each sampled row.
    shap_df = pd.DataFrame(shap_values_dict[pollutant], columns=feature_columns)
    shap_df['country'] = test_sample_data['country'].values

    # Mean over rows within a country, then mean across the neighbor features.
    # NOTE(review): this is a signed mean, so positive and negative attributions
    # cancel; confirm that is intended rather than mean |SHAP|.
    per_country_means = shap_df.groupby('country')[neighbor_cols].mean()
    country_neighbor_shap = per_country_means.mean(axis=1).sort_values(ascending=False)

    country_shap_analysis[pollutant] = country_neighbor_shap

    print(f"\n{pollutant} - Countries most influenced by neighbors (mean SHAP):")
    print(country_neighbor_shap.head(10))

# One horizontal bar chart per pollutant showing its ten highest-ranked countries.
fig, axes = plt.subplots(1, 3, figsize=(20, 6))
colors_pol = ['#FF6B6B', '#4ECDC4', '#45B7D1']

for idx, pollutant in enumerate(TARGET_POLLUTANTS):
    if pollutant not in country_shap_analysis:
        continue

    panel = axes[idx]
    top_countries = country_shap_analysis[pollutant].head(10)
    bar_positions = range(len(top_countries))

    panel.barh(bar_positions, top_countries.values, color=colors_pol[idx], alpha=0.8)
    panel.set_yticks(bar_positions)
    panel.set_yticklabels(top_countries.index)
    # Highest-ranked country at the top of the chart.
    panel.invert_yaxis()
    panel.set_title(f'{pollutant}\nNeighbor Influence by Country',
                    fontsize=12, fontweight='bold')
    panel.set_xlabel('Mean SHAP Value (Neighbor Features)', fontsize=11)
    panel.grid(True, alpha=0.3, axis='x')

plt.tight_layout()
country_png = os.path.join(OUTPUT_PATH, 'country_neighbor_influence.png')
plt.savefig(country_png, dpi=300, bbox_inches='tight')
plt.show()

print("\n✓ Country-specific analysis completed")

## 5.11 Feature Category Importance

In [None]:
# Aggregate SHAP importance by feature category
def categorize_feature(feature_name):
    """Return the category label for a feature, based on substrings of its name.

    Rules are checked in order and the first rule with any matching token
    wins, so a name matching several rules gets the earliest category.
    Matching is case-sensitive. Names matching no rule are labelled 'Other'.
    """
    # (category, tokens) pairs in precedence order.
    ordered_rules = (
        ('Neighbor', ('neighbor',)),
        ('Temporal_Lag', ('lag_', 'rolling_', 'ewm_', 'change_')),
        ('Temporal_Calendar', ('year', 'month', 'day', 'week', 'season', 'sin', 'cos')),
        ('Spatial', ('latitude', 'longitude')),
        ('Country_Identity', ('country_',)),
        ('Interaction', ('ratio', 'product', 'sum', 'AQI')),
    )

    for category, tokens in ordered_rules:
        for token in tokens:
            if token in feature_name:
                return category

    return 'Other'

# Sum mean |SHAP| per feature category for each pollutant, print each
# category's share of the total, then draw one pie chart per pollutant.
category_importance = {}

for pollutant in TARGET_POLLUTANTS:
    if pollutant not in shap_importance:
        continue

    # Work on a labelled copy so the stored ranking is left untouched.
    shap_df = shap_importance[pollutant].assign(
        category=lambda frame: frame['feature'].map(categorize_feature)
    )

    per_category = shap_df.groupby('category')['mean_abs_shap'].sum()
    category_totals = per_category.sort_values(ascending=False)
    category_importance[pollutant] = category_totals

    print(f"\n{pollutant} - Feature Category Importance:")
    grand_total = category_totals.sum()
    for cat, importance in category_totals.items():
        pct = (importance / grand_total) * 100
        print(f"  {cat:20s}: {importance:.4f} ({pct:.2f}%)")

# One pie per pollutant; a pollutant without results leaves its axis empty.
fig, axes = plt.subplots(1, 3, figsize=(20, 6))

for idx, pollutant in enumerate(TARGET_POLLUTANTS):
    if pollutant not in category_importance:
        continue

    panel = axes[idx]
    cat_data = category_importance[pollutant]
    panel.pie(cat_data.values, labels=cat_data.index, autopct='%1.1f%%',
              startangle=90, textprops={'fontsize': 10})
    panel.set_title(f'{pollutant}\nFeature Category Importance',
                    fontsize=12, fontweight='bold')

plt.tight_layout()
category_png = os.path.join(OUTPUT_PATH, 'category_importance.png')
plt.savefig(category_png, dpi=300, bbox_inches='tight')
plt.show()

print("\n✓ Feature category importance analyzed")

## 5.12 Export SHAP Results

In [None]:
# Write the row-level SHAP values of each pollutant to CSV and pickle, with the
# country, date and observed pollutant value of every sampled row attached.
for pollutant in (name for name in TARGET_POLLUTANTS if name in shap_values_dict):
    # One column per feature, then the metadata columns appended at the end.
    shap_export = pd.DataFrame(
        shap_values_dict[pollutant],
        columns=feature_columns,
    ).assign(
        country=test_sample_data['country'].values,
        date=test_sample_data['date'].values,
        actual_value=test_sample_data[pollutant].values,
    )

    csv_target = os.path.join(OUTPUT_PATH, f'shap_values_{pollutant}.csv')
    pkl_target = os.path.join(OUTPUT_PATH, f'shap_values_{pollutant}.pkl')
    shap_export.to_csv(csv_target, index=False)
    shap_export.to_pickle(pkl_target)

    print(f"✓ SHAP values exported for {pollutant}")

print("\n✓ All SHAP results exported")

## 5.13 Policy Insights Summary

In [None]:
# Text summary of the SHAP results computed in the earlier cells. Every section
# skips pollutants for which the corresponding result is missing.
print("="*80)
print("EXPLAINABLE AI - POLICY INSIGHTS SUMMARY")
print("="*80)

# The percentage printed here is the neighbor share of (neighbor + self)
# importance, as stored in neighbor_influence.
print("\n1. TRANSBOUNDARY INFLUENCE:")
for pollutant in TARGET_POLLUTANTS:
    if pollutant in neighbor_influence:
        neighbor_pct = neighbor_influence[pollutant]['neighbor_percentage']
        print(f"   {pollutant}: {neighbor_pct:.2f}% influenced by neighboring countries")

# The importance tables are sorted descending, so row 0 is the top feature.
print("\n2. KEY DRIVER IDENTIFICATION:")
for pollutant in TARGET_POLLUTANTS:
    if pollutant in shap_importance:
        top_feature = shap_importance[pollutant].iloc[0]
        print(f"   {pollutant}: Most important feature - {top_feature['feature']}")

print("\n3. MOST AFFECTED COUNTRIES (by neighbor pollution):")
for pollutant in TARGET_POLLUTANTS:
    if pollutant in country_shap_analysis:
        most_affected = country_shap_analysis[pollutant].idxmax()
        print(f"   {pollutant}: {most_affected}")

print("\n4. FEATURE CATEGORY RANKINGS:")
for pollutant in TARGET_POLLUTANTS:
    if pollutant in category_importance:
        top_category = category_importance[pollutant].idxmax()
        print(f"   {pollutant}: {top_category} features dominate")

print("\n5. ACTIONABLE INSIGHTS:")
neighbor_shares = [neighbor_influence[p]['neighbor_percentage']
                   for p in TARGET_POLLUTANTS if p in neighbor_influence]
# With no results (or undefined shares) the mean is NaN; NaN > 30 is False, so
# without this guard the "moderate effects" conclusion would be printed from no data.
avg_neighbor_influence = float(np.mean(neighbor_shares)) if neighbor_shares else float('nan')

if np.isnan(avg_neighbor_influence):
    print("   → No neighbor influence estimates available")
    print("   → No transboundary conclusion can be drawn from this run")
elif avg_neighbor_influence > 30:
    print("   → Strong transboundary effects detected (>30%)")
    print("   → International cooperation is CRITICAL for pollution control")
    print("   → Unilateral policies may have limited effectiveness")
else:
    print("   → Moderate transboundary effects detected")
    print("   → Local policies remain primary driver")
    print("   → Regional cooperation recommended but not critical")

# Lists every artifact the earlier cells write to OUTPUT_PATH.
print("\n6. OUTPUT FILES CREATED:")
print("   - shap_summary_<pollutant>.png")
print("   - shap_bar_<pollutant>.png")
print("   - shap_dependence_<pollutant>.png")
print("   - shap_importance_<pollutant>.csv")
print("   - neighbor_vs_self_influence.csv")
print("   - self_vs_neighbor_influence.png")
print("   - country_neighbor_influence.png")
print("   - category_importance.png")
print("   - shap_values_<pollutant>.csv")
print("   - shap_values_<pollutant>.pkl")

print("\n" + "="*80)
print("✓ EXPLAINABLE AI ANALYSIS COMPLETED")
print("="*80)
print("\nReady for spatial analysis and GIS visualization!")

## Summary

### Completed Tasks:
1. ✓ Applied SHAP to all best models
2. ✓ Extracted global feature importance
3. ✓ Quantified self vs neighbor influence
4. ✓ Analyzed country-specific transboundary effects
5. ✓ Categorized feature importance by type
6. ✓ Generated policy-relevant insights
7. ✓ Created comprehensive visualizations

### Key Findings:
- **Transboundary influence** quantified for each pollutant
- **Top drivers** identified through SHAP values
- **Country vulnerability** mapped by neighbor dependency
- **Policy implications** derived from model explanations

### Next Steps:
**Notebook 06: Spatial Analysis**
- Spatial autocorrelation (Moran's I)
- Hotspot analysis (Getis-Ord Gi*)
- Spatial regression models
- Network flow visualization preparation