# In [None]:  (Jupyter cell marker left by the notebook export; commented out so the file parses as Python)
# Step 4: Hybrid Explanation Generator - USING SHAP
# Key changes from original:
# 1. Replace DeepLIFT with SHAP KernelExplainer
# 2. Use background data for SHAP baseline
# 3. Extract class 1 (Important) attributions

import os
import warnings
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from pathlib import Path
import joblib
from typing import Dict, List, Tuple, Optional
import json
import pickle

# Suppress every warning for the remainder of the script run.
warnings.filterwarnings('ignore')

print("=" * 70)
print("STEP 4: HYBRID EXPLANATION GENERATOR (USING SHAP)")
print("=" * 70)

# ==================== CONFIGURATION ====================

# Locations of the artifacts read by the loading section further down.
LSTM_MODEL_PATH = '../step1_lstm_xai/best_lstm.pt'
SCALER_PATH = '../step1_lstm_xai/scaler.joblib'
CAUSAL_GRAPH_PATH = '../step2_causal_discovery/causal_graph.gpickle'
DATA_PATH = '../UNSW_NB15_training-set.csv'

# Ordered model input columns; the order is significant because feature
# vectors are built by selecting these columns positionally.
FEATURE_NAMES = [
    'dur',
    'proto',
    'service',
    'state',
    'spkts',
    'dpkts',
    'sbytes',
    'dbytes',
    'rate',
    'sttl',
    'dttl',
    'sload',
    'dload',
    'sloss',
    'dloss',
    'sinpkt',
    'dinpkt',
    'sjit',
    'djit',
    'swin',
    'stcpb',
    'dtcpb',
    'dwin',
    'tcprtt',
    'synack',
    'ackdat',
    'smean',
    'dmean',
    'trans_depth',
    'response_body_len',
    'ct_srv_src',
    'ct_state_ttl',
    'ct_dst_ltm',
    'ct_src_dport_ltm',
    'ct_dst_sport_ltm',
    'ct_dst_src_ltm',
    'is_ftp_login',
    'ct_ftp_cmd',
    'ct_flw_http_mthd',
    'ct_src_ltm',
    'ct_srv_dst',
    'is_sm_ips_ports',
]

# Subset of features named as causal; presumably the nodes of the causal
# graph produced in step 2 -- TODO confirm against that step.
CAUSAL_FEATURES = [
    'proto',
    'sttl',
    'state',
    'dtcpb',
    'is_sm_ips_ports',
    'dttl',
    'stcpb',
    'service',
    'dwin',
    'swin',
]

# Sentinel compared against raw feature values by is_missing_value().
MISSING_VALUE_INDICATOR = -1.0

# ==================== LOAD MODELS AND DATA ====================
print("\nLoading models and data...")

# Network definition used to rebuild the model before its weights are loaded.
class LSTMClassifier(nn.Module):
    """Stacked LSTM followed by a small MLP head producing two class logits.

    The attribute names ``lstm`` and ``fc`` and the layer order inside ``fc``
    determine the state_dict keys, so they are kept stable for checkpoint
    loading.
    """

    def __init__(self, input_size, hidden_size=128, num_layers=2, dropout=0.2):
        super().__init__()

        # Recurrent dropout is switched off when there is only one layer.
        recurrent_dropout = 0.0 if num_layers <= 1 else dropout
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=recurrent_dropout,
        )

        head_layers = [
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 2),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return logits computed from the LSTM output at the last time step."""
        sequence_output, _state = self.lstm(x)
        final_step = sequence_output[:, -1, :]
        return self.fc(final_step)

# Load trained model
device = torch.device('cpu')
input_size = len(FEATURE_NAMES)
model = LSTMClassifier(input_size=input_size).to(device)

if Path(LSTM_MODEL_PATH).exists():
    model.load_state_dict(torch.load(LSTM_MODEL_PATH, map_location=device))
    print(f"✓ Loaded LSTM model from {LSTM_MODEL_PATH}")
else:
    print(f"⚠ Warning: {LSTM_MODEL_PATH} not found. Using untrained model.")

# Switch to inference mode in BOTH branches. Previously eval() was only called
# when the checkpoint existed, so the untrained fallback kept dropout active
# and every later prediction / SHAP evaluation was stochastic.
model.eval()

# Load scaler
if Path(SCALER_PATH).exists():
    scaler = joblib.load(SCALER_PATH)
    print(f"✓ Loaded scaler from {SCALER_PATH}")
else:
    print(f"⚠ Warning: {SCALER_PATH} not found.")
    scaler = None

# Load causal graph
if Path(CAUSAL_GRAPH_PATH).exists():
    # NOTE(security): unpickling executes arbitrary code embedded in the file;
    # only load graph files produced by a trusted pipeline step.
    # The context manager closes the handle that the previous
    # pickle.load(open(...)) form leaked.
    with open(CAUSAL_GRAPH_PATH, 'rb') as graph_file:
        causal_graph = pickle.load(graph_file)
    print(f"✓ Loaded causal graph from {CAUSAL_GRAPH_PATH}")
else:
    print(f"⚠ Warning: {CAUSAL_GRAPH_PATH} not found.")
    causal_graph = nx.DiGraph()

# Load data
if Path(DATA_PATH).exists():
    df = pd.read_csv(DATA_PATH)
    print(f"✓ Loaded data from {DATA_PATH}: {df.shape}")

    # Encode categorical features
    from sklearn.preprocessing import LabelEncoder
    drop_cols = ['id', 'attack_cat']
    for col in drop_cols:
        if col in df.columns:
            df = df.drop(columns=[col])

    # Provide a 'Label' column when the CSV only has the lower-case 'label'.
    if 'label' in df.columns and 'Label' not in df.columns:
        df['Label'] = df['label']

    # NOTE(review): the encoders are fitted on this CSV here; the integer
    # codes only match the ones seen at training time if step 1 encoded the
    # same data the same way -- confirm against the training script.
    categorical_cols = ['proto', 'service', 'state']
    for col in categorical_cols:
        if col in df.columns and df[col].dtype == 'object':
            le = LabelEncoder()
            df[col] = df[col].fillna('unknown')
            df[col] = le.fit_transform(df[col].astype(str))

    print("✓ Categorical features encoded")
else:
    print(f"⚠ Warning: {DATA_PATH} not found.")
    df = None

# ==================== SHAP EXPLAINER SETUP ====================
print("\n" + "=" * 70)
print("XAI COMPONENT: SHAP Explainer")
print("=" * 70)

# Both names remain None unless the data, the scaler, the shap package and
# every expected feature column are available.
shap_explainer = None
background_data = None

if df is not None and scaler is not None:
    try:
        import shap

        print("\nInitializing SHAP explainer...")
        available_features = [f for f in FEATURE_NAMES if f in df.columns]

        if len(available_features) != len(FEATURE_NAMES):
            print("⚠ Warning: Missing features. SHAP unavailable.")
        else:
            # Draw up to 100 distinct rows to serve as the SHAP background set.
            background_indices = np.random.choice(
                len(df),
                size=min(100, len(df)),
                replace=False,
            )
            background_data_raw = df.iloc[background_indices][FEATURE_NAMES].values

            # The explainer operates in the scaled feature space.
            background_data = scaler.transform(background_data_raw)
            print(f"  Background data shape: {background_data.shape}")

            def model_predict_shap(x):
                """Map scaled feature rows to softmax class probabilities.

                A 2-D array gets a length-1 sequence axis inserted after the
                batch axis; any other rank is passed to the model unchanged.
                Returns a numpy array of probabilities.
                """
                batch = torch.tensor(x, dtype=torch.float32)
                if x.ndim == 2:
                    batch = batch.unsqueeze(1)
                batch = batch.to(device)

                with torch.no_grad():
                    logits = model(batch)
                    probs = torch.softmax(logits, dim=1).cpu().numpy()

                return probs

            shap_explainer = shap.KernelExplainer(model_predict_shap, background_data)
            print("✓ SHAP KernelExplainer initialized")

    except ImportError:
        print("⚠ Warning: SHAP not installed. Install with: pip install shap")
    except Exception as e:
        # Best effort: any other failure leaves shap_explainer as None.
        print(f"⚠ Warning: SHAP initialization failed: {e}")

# ==================== XAI COMPONENT WITH SHAP ====================

def _select_class_attributions(shap_values, n_features, class_index=1):
    """Reduce a raw SHAP result for one sample to a 1-D vector for one class.

    Handles the layouts this code may receive:
      * a list/tuple with one array per class (indexed by ``class_index``);
      * a 3-D array whose last axis enumerates model outputs;
      * a 1-D/2-D array for a single-output model (used as is).

    Raises:
        ValueError: if the selected vector does not have ``n_features`` entries.
    """
    if isinstance(shap_values, (list, tuple)):
        values = np.asarray(shap_values[class_index])
    else:
        values = np.asarray(shap_values)
        if values.ndim == 3:
            # Last axis is the model-output axis; blindly flattening it would
            # interleave the per-class attributions.
            output_index = min(class_index, values.shape[-1] - 1)
            values = values[:, :, output_index]

    values = values.reshape(-1)
    if values.size != n_features:
        raise ValueError(
            f"Expected {n_features} attributions, got {values.size}"
        )
    return values


def compute_shap_attribution(shap_explainer, alert_tensor_scaled, n_samples=100):
    """
    Compute SHAP attributions for a single alert

    Args:
        shap_explainer: SHAP KernelExplainer instance, or None to use the
            gradient fallback
        alert_tensor_scaled: numpy array (n_features,) - MUST BE SCALED
        n_samples: Number of samples for SHAP approximation

    Returns:
        1-D numpy array (n_features,) of feature attributions for class 1
        (Important). If SHAP is unavailable or fails for any reason, the
        result of compute_gradient_attribution_fallback() is returned instead.
    """
    if shap_explainer is None:
        print("    ⚠ SHAP explainer not available, using gradient fallback")
        return compute_gradient_attribution_fallback(alert_tensor_scaled)

    try:
        # SHAP expects 2D input: (1, n_features)
        alert_2d = np.asarray(alert_tensor_scaled).reshape(1, -1)

        shap_values = shap_explainer.shap_values(alert_2d, nsamples=n_samples)

        print(f"    [SHAP DEBUG] Raw SHAP output type: {type(shap_values)}")
        if isinstance(shap_values, (list, tuple)):
            print(f"    [SHAP DEBUG] List length: {len(shap_values)}")
        else:
            print(f"    [SHAP DEBUG] Single output shape: {np.shape(shap_values)}")

        # Pick class 1 explicitly. Multi-output results may arrive as a
        # per-class list or as one array with a trailing output axis; the
        # previous flatten() turned the latter into 2 * n_features mixed values.
        attributions = _select_class_attributions(shap_values, alert_2d.shape[1])

        print(f"    [SHAP DEBUG] Final attribution shape: {attributions.shape}")

        return attributions

    except Exception as e:
        # Deliberate best effort: never fail an explanation because of SHAP.
        print(f"    ⚠ SHAP failed: {e}. Using gradient fallback.")
        import traceback
        traceback.print_exc()
        return compute_gradient_attribution_fallback(alert_tensor_scaled)

def compute_gradient_attribution_fallback(alert_features_scaled):
    """
    Fallback attribution: gradient of the predicted-class logit times input.

    Uses the module-level ``model`` and ``device``. The features are wrapped
    with two leading singleton axes before being fed to the model, and the
    result is squeezed back to the input's own shape.
    """
    features = torch.tensor(alert_features_scaled, dtype=torch.float32)
    alert_tensor = features.unsqueeze(0).unsqueeze(0).to(device)
    alert_tensor.requires_grad = True

    logits = model(alert_tensor)
    target_class = logits.argmax(dim=1).item()

    # Clear stale parameter gradients, then back-propagate only the logit of
    # the class the model actually predicted.
    model.zero_grad()
    logits[0, target_class].backward()

    input_values = alert_tensor.squeeze().detach().cpu().numpy()
    input_gradients = alert_tensor.grad.squeeze().detach().cpu().numpy()

    return input_gradients * input_values

def is_missing_value(value, threshold=MISSING_VALUE_INDICATOR):
    """Return whether ``value`` lies within 1e-6 of the missing-data sentinel."""
    tolerance = 1e-6
    distance = abs(value - threshold)
    return distance < tolerance

def generate_xai_explanation(model, alert_features, feature_names, top_k=5, scaler=None,
                             shap_explainer=None, n_shap_samples=50):
    """
    Generate XAI explanation using SHAP

    Args:
        model: LSTM model; called on a (1, 1, n_features) tensor
        alert_features: numpy array of feature values (UNSCALED)
        feature_names: list of feature names, aligned with alert_features
        top_k: number of top features
        scaler: fitted scaler exposing transform(); when None the features
            are used as given
        shap_explainer: SHAP KernelExplainer instance (None -> gradient fallback)
        n_shap_samples: sampling budget forwarded to compute_shap_attribution
            (default 50, the value that used to be hard-coded)

    Returns:
        Dictionary with keys 'prediction', 'confidence', 'pred_class',
        'top_features', 'all_features', 'num_missing_features' and
        'num_present_features'. Every per-feature entry holds plain Python
        float/bool values.

    Raises:
        ValueError: if the number of attributions differs from the number of
            feature names (previously the extra values were silently dropped).
    """
    # Force a flat float array so object-dtype rows (e.g. taken from a mixed
    # DataFrame row) can be converted to a tensor even when no scaler is given.
    alert_features = np.asarray(alert_features, dtype=float).reshape(-1)

    # Scale features
    if scaler is not None:
        alert_features_scaled = scaler.transform(alert_features.reshape(1, -1)).flatten()
    else:
        alert_features_scaled = alert_features

    # Get prediction
    alert_tensor = torch.tensor(alert_features_scaled, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        output = model(alert_tensor)
        probs = torch.softmax(output, dim=1)[0]
        pred_class = output.argmax(dim=1).item()
        confidence = probs[pred_class].item()

    print(f"  [DEBUG] Prediction: {pred_class}, Confidence: {confidence:.4f}")

    # Compute SHAP attributions
    attributions = np.asarray(
        compute_shap_attribution(shap_explainer, alert_features_scaled, n_samples=n_shap_samples)
    ).reshape(-1)

    if attributions.size != len(feature_names):
        raise ValueError(
            f"Attribution length {attributions.size} does not match "
            f"{len(feature_names)} feature names"
        )

    print(f"  [DEBUG] Attribution range: [{attributions.min():.4f}, {attributions.max():.4f}]")

    # Combine features with attributions. Missingness is judged on the RAW
    # (unscaled) value; bool() keeps the entry JSON-serialisable instead of
    # leaking a numpy bool.
    feature_importance = [
        {
            'feature': name,
            'importance': float(attr),
            'value': float(value),
            'abs_importance': float(abs(attr)),
            'is_missing': bool(is_missing_value(value)),
        }
        for name, attr, value in zip(feature_names, attributions, alert_features)
    ]

    # Sort by absolute importance
    feature_importance.sort(key=lambda x: x['abs_importance'], reverse=True)

    # Filter out missing values
    feature_importance_present = [f for f in feature_importance if not f['is_missing']]

    present_count = len(feature_importance_present)
    missing_count = len(feature_importance) - present_count

    print(f"  [DEBUG] Features: {present_count} present, {missing_count} missing")

    # Prefer non-missing features; when fewer than top_k of them exist, fall
    # back to the overall ranking (which may include missing-valued features).
    if present_count >= top_k:
        top_features = feature_importance_present[:top_k]
    else:
        top_features = feature_importance[:top_k]

    return {
        'prediction': 'Important' if pred_class == 1 else 'Irrelevant',
        'confidence': confidence,
        'pred_class': pred_class,
        'top_features': top_features,
        'all_features': feature_importance,
        'num_missing_features': missing_count,
        'num_present_features': present_count
    }

# ==================== CAUSAL COMPONENT ====================
# Graph helpers that trace a feature back to its direct and root causes.
print(f"\n{'=' * 70}")
print("CAUSAL COMPONENT: Root Cause Analysis")
print("=" * 70)

def find_root_causes(graph, target_feature):
    """Return the ancestors of ``target_feature`` that have no incoming edges.

    An empty list is returned when the feature is not a node of ``graph``.
    """
    if target_feature in graph:
        upstream = nx.ancestors(graph, target_feature)
        return [candidate for candidate in upstream if graph.in_degree(candidate) == 0]
    return []

def find_causal_path(graph, source, target):
    """Return a shortest path from ``source`` to ``target`` as a node list.

    ``None`` is returned when either endpoint is not in ``graph`` or when no
    path connects them.
    """
    both_present = source in graph and target in graph
    if not both_present:
        return None

    try:
        route = nx.shortest_path(graph, source, target)
    except nx.NetworkXNoPath:
        route = None
    return route

def get_direct_causes(graph, feature):
    """Return the predecessors (parents) of ``feature`` as a list.

    An unknown feature yields an empty list.
    """
    return list(graph.predecessors(feature)) if feature in graph else []

def analyze_causal_chain(graph, target_feature, alert_data, all_feature_names):
    """Analyze causal chains leading to target

    Args:
        graph: causal graph queried through find_root_causes(),
            find_causal_path() and get_direct_causes()
        target_feature: node whose causes are analysed
        alert_data: mapping feature name -> observed value; features absent
            from the mapping are reported with the value 'N/A'
        all_feature_names: accepted for interface compatibility; not used

    Returns:
        Dictionary with 'target', 'in_graph', 'root_causes', 'causal_paths',
        'direct_causes' and 'num_paths'. When the target is not in the graph
        the lists are empty, 'num_paths' is 0 and a 'reason' string is added.
    """
    if target_feature not in graph:
        return {
            'target': target_feature,
            'in_graph': False,
            'root_causes': [],
            'causal_paths': [],
            'direct_causes': [],
            'reason': 'Feature not in causal graph',
            # Present in both result shapes so consumers can read it
            # unconditionally (it used to exist only in the in-graph result).
            'num_paths': 0
        }

    root_causes = find_root_causes(graph, target_feature)

    causal_paths = []
    for root in root_causes:
        path = find_causal_path(graph, root, target_feature)
        if path:
            # Annotate every node on the path with the alert's observed value.
            path_with_values = [
                {'feature': f, 'value': alert_data.get(f, 'N/A')}
                for f in path
            ]
            causal_paths.append({
                'root': root,
                'path': path,
                'path_with_values': path_with_values,
                'length': len(path)
            })

    direct_causes_with_values = [
        {'feature': c, 'value': alert_data.get(c, 'N/A')}
        for c in get_direct_causes(graph, target_feature)
    ]

    return {
        'target': target_feature,
        'in_graph': True,
        'root_causes': root_causes,
        'causal_paths': causal_paths,
        'direct_causes': direct_causes_with_values,
        'num_paths': len(causal_paths)
    }

# ==================== HYBRID EXPLAINER ====================
print(f"\n{'=' * 70}")
print("HYBRID EXPLAINER: Combining SHAP + Causal")
print("=" * 70)

class HybridExplainer:
    """Combines SHAP and Causal Analysis"""

    def __init__(self, model, causal_graph, feature_names, scaler=None, shap_explainer=None):
        """Store the collaborators used by explain(); no work is done here."""
        self.model = model
        self.graph = causal_graph
        self.feature_names = feature_names
        self.scaler = scaler
        self.shap_explainer = shap_explainer

    def explain(self, alert_data, alert_id=None, top_k=5):
        """Generate hybrid explanation

        Args:
            alert_data: either a dict mapping feature name -> raw value
                (features absent from the dict default to 0) or a sequence of
                raw values ordered like ``self.feature_names``
            alert_id: opaque identifier echoed back in the result
            top_k: number of top-ranked features to explain causally
                (default 5, the value that used to be hard-coded)

        Returns:
            Dictionary with 'alert_id', 'xai', 'causal' (one analysis per top
            feature), 'label_causal' (None unless the graph has a 'Label'
            node) and 'recommendations'.
        """
        if isinstance(alert_data, dict):
            alert_features = np.array([alert_data.get(f, 0) for f in self.feature_names])
            alert_dict = alert_data
        else:
            alert_features = alert_data
            alert_dict = dict(zip(self.feature_names, alert_features))

        # Get XAI explanation using SHAP
        xai_results = generate_xai_explanation(
            self.model,
            alert_features,
            self.feature_names,
            top_k=top_k,
            scaler=self.scaler,
            shap_explainer=self.shap_explainer
        )

        # Analyze causal chains. analyze_causal_chain() already returns the
        # 'Feature not in causal graph' result for unknown nodes, so no
        # separate branch is needed here.
        causal_analyses = [
            analyze_causal_chain(
                self.graph, feat_info['feature'], alert_dict, self.feature_names
            )
            for feat_info in xai_results['top_features']
        ]

        # Label causal analysis
        label_causal = None
        if 'Label' in self.graph:
            label_causal = analyze_causal_chain(
                self.graph, 'Label', alert_dict, self.feature_names
            )

        # Generate recommendations
        recommendations = self._generate_recommendations(
            xai_results, causal_analyses, alert_dict
        )

        return {
            'alert_id': alert_id,
            'xai': xai_results,
            'causal': causal_analyses,
            'label_causal': label_causal,
            'recommendations': recommendations
        }

    def _generate_recommendations(self, xai_results, causal_analyses, alert_data):
        """Placeholder for recommendation generation; currently returns None.

        TODO: port the recommendation logic from the original (pre-SHAP)
        implementation; until then 'recommendations' in explain() is None.
        """
        return None

# ==================== DEMO ====================
print(f"\n{'=' * 70}")
print("GENERATING DEMO EXPLANATIONS WITH SHAP")
print("=" * 70)

# Wire together everything that was loaded or initialised above.
explainer = HybridExplainer(
    model,
    causal_graph,
    FEATURE_NAMES,
    scaler=scaler,
    shap_explainer=shap_explainer,
)

if df is not None:
    print("\nGenerating explanations for sample alerts...")
    sample_indices = [0, 100, 500]

    # Only the first candidate index is explained for now (smoke test).
    for idx in sample_indices[:1]:
        print(f"\n{'='*70}")
        print(f"EXAMPLE ALERT #{idx}")
        print(f"{'='*70}")

        # Raw (unscaled) feature vector of the selected row.
        alert_row = df.iloc[idx]
        alert_features = alert_row[FEATURE_NAMES].values

        explanation = explainer.explain(alert_features, alert_id=idx)

        print(f"\nPrediction: {explanation['xai']['prediction']}")
        print(f"Confidence: {explanation['xai']['confidence']:.1%}")
        print("\nTop Features (SHAP):")
        for i, feat in enumerate(explanation['xai']['top_features'], start=1):
            print(f"  {i}. {feat['feature']}: {feat['importance']:.4f}")

print(f"\n{'=' * 70}")
print("STEP 4 COMPLETE - USING SHAP")
print("=" * 70)