# 📅 Day 7: SHAP Explainability
## TreeExplainer for all 3 Classification Levels

---

**Steps:**
1. Load best models
2. SHAP analysis for each level (Binary, 8-Class, 34-Class)
3. Summary plots, bar plots, dependence plots
4. Security validation: map features to cybersecurity meaning

---

In [None]:
import os
# Register the CUDA runtime directory on the DLL search path (presumably so the
# GPU build of xgboost imported below can find its CUDA DLLs — confirm).
# os.add_dll_directory exists only on Windows and raises when the directory is
# missing, so both conditions are checked to keep this cell runnable on other
# platforms and on machines with a different CUDA install.
CUDA_BIN_DIR = r'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.1\bin\x64'
if hasattr(os, 'add_dll_directory') and os.path.isdir(CUDA_BIN_DIR):
    os.add_dll_directory(CUDA_BIN_DIR)
else:
    print(f"CUDA DLL directory not registered (not available on this system): {CUDA_BIN_DIR}")

import numpy as np
import pandas as pd
import xgboost as xgb
import shap
import matplotlib.pyplot as plt
import seaborn as sns
import time
import json
import gc
from datetime import datetime

# Global matplotlib styling applied to every figure produced below.
plt.style.use('dark_background')
plt.rcParams.update({
    'figure.figsize': (14, 8),
    'font.size': 12,
})

# Every plot cell saves into ./figures, so make sure it exists first.
os.makedirs('figures', exist_ok=True)

# Enable SHAP JS visualization (optional)
# shap.initjs()

# Startup banner: SHAP version in use and the time this run started.
print(f"‚úÖ Ready | SHAP version: {shap.__version__}")
print(f"üìÖ {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

In [None]:
# Load the test arrays, the preprocessing metadata and the three trained
# XGBoost boosters, then draw the fixed random sample that SHAP will explain.
print("üì• Loading data and models...")

X_test = np.load('processed/X_test.npy')
y_binary_test = np.load('processed/y_binary_test.npy')
y_family_test = np.load('processed/y_family_test.npy')
y_subtype_test = np.load('processed/y_subtype_test.npy')

# JSON files are UTF-8; state it explicitly so the platform default encoding
# (e.g. cp1252 on Windows) cannot garble non-ASCII feature or class names.
with open('processed/preprocessing_metadata.json', 'r', encoding='utf-8') as f:
    meta = json.load(f)
feature_names = meta['feature_names']
family_classes = meta['family_classes']
subtype_classes = meta['subtype_classes']

# Use a sample for SHAP (full dataset is too large).
# The size is clamped to the number of test rows: np.random.choice with
# replace=False raises ValueError when asked for more items than exist.
SHAP_SAMPLE_SIZE = min(5000, len(X_test))
np.random.seed(42)  # legacy global seed kept so the sampled rows stay the same
sample_idx = np.random.choice(len(X_test), size=SHAP_SAMPLE_SIZE, replace=False)
X_sample = X_test[sample_idx]
# A DataFrame carries the feature names into the SHAP plots.
X_sample_df = pd.DataFrame(X_sample, columns=feature_names)

print(f"   SHAP sample: {SHAP_SAMPLE_SIZE} instances")

# Load models (one booster per classification level)
bst_binary = xgb.Booster()
bst_binary.load_model('models/binary_xgb_gpu.json')

bst_family = xgb.Booster()
bst_family.load_model('models/family_best_xgb_gpu.json')

bst_subtype = xgb.Booster()
bst_subtype.load_model('models/subtype_xgb_gpu.json')

print("‚úÖ All loaded")

## 🔍 Level 1: Binary Classification — SHAP Analysis

In [None]:
# Explain the binary booster on the sampled rows and report how long it took.
print("üîç Computing SHAP values ‚Äî Binary Classification...")
t0 = time.time()

explainer_binary = shap.TreeExplainer(bst_binary)
shap_values_binary = explainer_binary.shap_values(X_sample_df)

binary_elapsed = time.time() - t0
print(f"   ‚úÖ Done in {binary_elapsed:.1f}s")
print(f"   Shape: {shap_values_binary.shape}")

In [None]:
# Beeswarm summary of the binary model: one row per feature (top 20 shown).
beeswarm_path = 'figures/shap_binary_beeswarm.png'

plt.figure(figsize=(14, 8))
shap.summary_plot(
    shap_values_binary,
    X_sample_df,
    max_display=20,
    show=False,  # keep the figure open so it can be titled and saved below
)
plt.title('SHAP ‚Äî Binary Classification (Beeswarm)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(beeswarm_path, dpi=150, bbox_inches='tight')
plt.show()
print("üíæ Saved")

In [None]:
# Bar chart of mean |SHAP| per feature for the binary model (top 20 shown).
binary_bar_path = 'figures/shap_binary_bar.png'

plt.figure(figsize=(12, 8))
shap.summary_plot(
    shap_values_binary,
    X_sample_df,
    plot_type='bar',
    max_display=20,
    show=False,  # defer display until the title is set and the file is written
)
plt.title('SHAP ‚Äî Binary Classification (Mean |SHAP|)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(binary_bar_path, dpi=150, bbox_inches='tight')
plt.show()
print("üíæ Saved")

In [None]:
# Dependence plots for the three features with the largest mean |SHAP|.
mean_abs_shap = np.mean(np.abs(shap_values_binary), axis=0)
# argsort is ascending: reverse it, then keep the first three (most important first).
top3_idx = np.argsort(mean_abs_shap)[::-1][:3]
top3_features = [feature_names[feat_pos] for feat_pos in top3_idx]

# One panel per feature, drawn side by side on a shared figure.
fig, axes = plt.subplots(1, 3, figsize=(20, 6))
for panel, top_feature in zip(axes, top3_features):
    shap.dependence_plot(top_feature, shap_values_binary, X_sample_df, ax=panel, show=False)
    panel.set_title(f'SHAP Dependence: {top_feature}', fontsize=12, fontweight='bold')

plt.suptitle('SHAP Dependence ‚Äî Binary Classification (Top 3)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('figures/shap_binary_dependence.png', dpi=150, bbox_inches='tight')
plt.show()
print("üíæ Saved")

## 🔍 Level 2: 8-Class Family — SHAP Analysis

In [None]:
# Explain the 8-class family booster on the same sampled rows.
print("üîç Computing SHAP values ‚Äî 8-Class Family...")
t0 = time.time()

explainer_family = shap.TreeExplainer(bst_family)
shap_values_family = explainer_family.shap_values(X_sample_df)

print(f"   ‚úÖ Done in {time.time()-t0:.1f}s")
# Multi-class SHAP values may arrive as a list of per-class arrays or as one
# array, presumably depending on the installed SHAP version — report whichever
# layout was returned.
if not isinstance(shap_values_family, list):
    print(f"   Shape: {shap_values_family.shape}")
else:
    print(f"   Shape: {len(shap_values_family)} classes √ó {shap_values_family[0].shape}")

In [None]:
# Bar plot — global feature importance across all classes.
#
# Multi-class SHAP values come in one of two layouts: a list of per-class
# (n_samples, n_features) arrays (older SHAP releases) or a single
# (n_samples, n_features, n_classes) array (SHAP 0.45+ per its API docs).
# Both are reduced to a 2-D matrix of |SHAP| averaged over classes, because
# the bar plot expects (n_samples, n_features).
plt.figure(figsize=(12, 8))
if isinstance(shap_values_family, list):
    # Average absolute SHAP across all classes (legacy list layout)
    mean_shap = np.mean([np.abs(sv) for sv in shap_values_family], axis=0)
elif np.ndim(shap_values_family) == 3:
    # Class axis is last in the array layout
    mean_shap = np.abs(shap_values_family).mean(axis=2)
else:
    # Already a 2-D matrix: pass it through unchanged
    mean_shap = shap_values_family
shap.summary_plot(mean_shap, X_sample_df, plot_type='bar', max_display=20, show=False)
plt.title('SHAP ‚Äî 8-Class Family (Mean |SHAP| across classes)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('figures/shap_family_bar.png', dpi=150, bbox_inches='tight')
plt.show()
print("üíæ Saved")

In [None]:
# Per-class SHAP summary for selected classes (at most four panels are drawn).
interesting_classes = ['DDoS', 'DoS', 'Mirai', 'Web', 'BruteForce']
# Keep only the classes that actually exist in this model's label set.
interesting_idx = [family_classes.index(c) for c in interesting_classes if c in family_classes]

# Normalise the SHAP output to "one (n_samples, n_features) matrix per class".
# Older SHAP releases return a list already; newer ones return a single
# (n_samples, n_features, n_classes) array, which previously caused this whole
# cell to be skipped silently.
if isinstance(shap_values_family, list):
    per_class_shap = shap_values_family
elif np.ndim(shap_values_family) == 3:
    per_class_shap = [shap_values_family[:, :, k] for k in range(shap_values_family.shape[2])]
else:
    per_class_shap = None  # single-output values: nothing to split per class

if per_class_shap is not None and interesting_idx:
    n_plots = min(len(interesting_idx), 4)
    fig, axes = plt.subplots(1, n_plots, figsize=(6*n_plots, 8))
    if n_plots == 1:
        # plt.subplots returns a bare Axes (not an array) for a single panel
        axes = [axes]

    for ax, cls_idx in zip(axes, interesting_idx[:n_plots]):
        # summary_plot draws on the current axes, so select the panel first
        plt.sca(ax)
        # plot_size=None stops summary_plot from resizing the shared figure on
        # every call (its default "auto" would shrink it to a single-plot size)
        shap.summary_plot(per_class_shap[cls_idx], X_sample_df, plot_type='bar',
                          max_display=10, show=False, plot_size=None)
        ax.set_title(f'{family_classes[cls_idx]}', fontsize=12, fontweight='bold')

    plt.suptitle('SHAP ‚Äî Per-Class Feature Importance (8-Class)', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig('figures/shap_family_per_class.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("üíæ Saved")

## 🔍 Level 3: 34-Class SubType — SHAP Analysis

In [None]:
# Explain the 34-class subtype booster on a reduced sample.
print("üîç Computing SHAP values ‚Äî 34-Class SubType...")
print("   ‚ö†Ô∏è This may take longer due to 34 classes")
t0 = time.time()

# Cap the row count at 2000 (or the full SHAP sample if that is smaller) to
# reduce computation time for the 34-class model.
SHAP_SAMPLE_34 = min(2000, SHAP_SAMPLE_SIZE)
X_sample_34 = X_sample_df.head(SHAP_SAMPLE_34)

explainer_subtype = shap.TreeExplainer(bst_subtype)
shap_values_subtype = explainer_subtype.shap_values(X_sample_34)

subtype_elapsed = time.time() - t0
print(f"   ‚úÖ Done in {subtype_elapsed:.1f}s")

In [None]:
# Global bar plot — feature importance averaged over all 34 subtype classes.
#
# Multi-class SHAP values are either a list of per-class
# (n_samples, n_features) arrays (older SHAP releases) or one
# (n_samples, n_features, n_classes) array (SHAP 0.45+ per its API docs).
# Both are reduced to a 2-D matrix of |SHAP| averaged over classes, because
# the bar plot expects (n_samples, n_features).
plt.figure(figsize=(12, 8))
if isinstance(shap_values_subtype, list):
    # Legacy list layout: stack per-class |SHAP| and average across classes
    mean_shap_34 = np.mean([np.abs(sv) for sv in shap_values_subtype], axis=0)
elif np.ndim(shap_values_subtype) == 3:
    # Array layout: the class axis is last
    mean_shap_34 = np.abs(shap_values_subtype).mean(axis=2)
else:
    # Already a 2-D matrix: pass it through unchanged
    mean_shap_34 = shap_values_subtype
shap.summary_plot(mean_shap_34, X_sample_34, plot_type='bar', max_display=20, show=False)
plt.title('SHAP ‚Äî 34-Class SubType (Mean |SHAP|)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('figures/shap_subtype_bar.png', dpi=150, bbox_inches='tight')
plt.show()
print("üíæ Saved")

## 🛡️ Security Validation

In [None]:
# Security meaning table: maps a feature name to a short description of the
# attack behaviour it is associated with. Features missing from this table
# fall back to a generic description when printed below.
security_map = {
    'syn_flag_number': 'SYN flood indicator ‚Äî high in DDoS-SYN, DoS-SYN',
    'syn_count': 'SYN flag count ‚Äî SYN flood attacks',
    'ack_count': 'ACK flag abuse ‚Äî DDoS-PSHACK, DDoS-ACK',
    'ack_flag_number': 'ACK flag presence ‚Äî ACK-based attacks',
    'flow_duration': 'Flow timing ‚Äî short for Recon, long for DoS',
    'Rate': 'Packet rate ‚Äî high = flood attack',
    'Srate': 'Source rate ‚Äî high packet rate = DDoS',
    'Drate': 'Destination rate ‚Äî response patterns',
    'Header_Length': 'Fragmentation attacks ‚Äî DDoS-ACK_Frag, ICMP_Frag',
    'Duration': 'Connection duration ‚Äî attack persistence',
    'HTTP': 'HTTP protocol ‚Äî Web attacks, DoS-HTTP',
    'HTTPS': 'HTTPS protocol ‚Äî encrypted web attacks',
    'ARP': 'ARP protocol ‚Äî MITM-ArpSpoofing',
    'DNS': 'DNS protocol ‚Äî DNS_Spoofing',
    'TCP': 'TCP protocol ‚Äî TCP-based floods',
    'UDP': 'UDP protocol ‚Äî UDP floods, Mirai',
    'ICMP': 'ICMP protocol ‚Äî ICMP floods, ping sweeps',
    'rst_count': 'RST flags ‚Äî connection reset attacks',
    'fin_count': 'FIN flags ‚Äî connection teardown patterns',
    'urg_count': 'URG flags ‚Äî urgent data exploitation',
    'Tot sum': 'Total packet sizes ‚Äî traffic volume',
    'Tot size': 'Total traffic size ‚Äî bandwidth consumption',
    'IAT': 'Inter-arrival time ‚Äî traffic timing patterns',
    'Variance': 'Traffic variance ‚Äî regularity of attack patterns',
    'Covariance': 'Feature covariance ‚Äî complex attack patterns'
}

# Get SHAP top features: reduce the binary SHAP values to one mean |SHAP|
# score per feature, whatever layout the explainer returned.
if isinstance(shap_values_binary, np.ndarray):
    abs_shap = np.abs(shap_values_binary)
    # A 3-D array is (n_samples, n_features, n_outputs): average over samples
    # AND outputs so the result stays 1-D (axis 0 alone would leave it 2-D and
    # break the ranking below). A 2-D array only needs the sample axis removed.
    mean_abs = abs_shap.mean(axis=(0, 2)) if abs_shap.ndim == 3 else abs_shap.mean(axis=0)
else:
    # List layout: stacking per-output matrices gives
    # (n_outputs, n_samples, n_features), so average over the first two axes.
    abs_shap = np.abs(np.array(shap_values_binary))
    mean_abs = abs_shap.mean(axis=(0, 1)) if abs_shap.ndim > 2 else abs_shap.mean(axis=0)

# Indices of the 15 highest-scoring features, most important first.
top_features_idx = np.argsort(mean_abs)[-15:][::-1]

print("="*80)
print("üõ°Ô∏è SECURITY VALIDATION ‚Äî SHAP Top Features")
print("="*80)
print(f"{'Rank':<6}{'Feature':<25}{'SHAP Score':<15}{'Security Meaning'}")
print("-"*80)
for rank, idx in enumerate(top_features_idx, 1):
    fname = feature_names[idx]
    shap_score = mean_abs[idx]
    meaning = security_map.get(fname, 'Network traffic characteristic')
    print(f"{rank:<6}{fname:<25}{shap_score:<15.4f}{meaning}")

In [None]:
# Save SHAP results: persist the ranked binary-model features (with their
# security meaning) to JSON, then print the completion banner.
shap_results = {
    'timestamp': datetime.now().isoformat(),
    'sample_size': SHAP_SAMPLE_SIZE,
    'binary_top_features': [
        {
            'rank': rank,
            'feature': feature_names[i],
            # float() converts the NumPy scalar into a JSON-serialisable value
            'mean_abs_shap': float(mean_abs[i]),
            # Same fallback text as the printed validation table, so the saved
            # file and the on-screen report agree
            'security_meaning': security_map.get(feature_names[i], 'Network traffic characteristic'),
        }
        for rank, i in enumerate(top_features_idx, 1)
    ]
}
with open('models/shap_results.json', 'w', encoding='utf-8') as f:
    json.dump(shap_results, f, indent=2)

# The newline must sit outside the repetition: "\n..." * 20 repeated the
# newline as well and printed 20 separate one-symbol lines instead of a row.
print("\n" + "üèÜ" * 20)
print("  ‚úÖ SHAP EXPLAINABILITY COMPLETE!")
print("  üìä 3 levels analyzed | Security validation done")
print("üèÜ" * 20)