In [None]:
"""
Step 4 Demo: Updated Hybrid Explanation System
================================================
Compatible with UNSW-NB15 dataset and latest Step 1-3 implementations
"""

import numpy as np
import pandas as pd
import networkx as nx
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import pickle
import joblib
import warnings
from sklearn.preprocessing import LabelEncoder
warnings.filterwarnings('ignore')

# Global look-and-feel for every figure produced by this demo
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

print("=" * 70 + "\nStep 4 Demo: Hybrid Explanations - UNSW-NB15 Dataset\n" + "=" * 70)

# ============================================================================
# SECTION 1: Configuration (Updated for UNSW-NB15)
# ============================================================================

print("\n[1/8] Loading Configuration...")

# Artifacts from Steps 1-2 plus the raw training CSV. Paths are relative,
# so they resolve against the current working directory.
LSTM_MODEL = '../step1_lstm_xai/best_lstm.pt'
SCALER = '../step1_lstm_xai/scaler.joblib'
CAUSAL_GRAPH = '../step2_causal_discovery/causal_graph.gpickle'
DATA_FILE = '../UNSW_NB15_training-set.csv'

# The 42 UNSW-NB15 model input columns (id, attack_cat and label are not
# inputs). They are fed to the scaler positionally, so this order presumably
# has to match the order used when Step 1 fitted the scaler -- confirm.
FEATURE_NAMES = [
    'dur', 'proto', 'service', 'state',
    'spkts', 'dpkts', 'sbytes', 'dbytes',
    'rate', 'sttl', 'dttl', 'sload', 'dload',
    'sloss', 'dloss', 'sinpkt', 'dinpkt', 'sjit', 'djit',
    'swin', 'stcpb', 'dtcpb', 'dwin',
    'tcprtt', 'synack', 'ackdat',
    'smean', 'dmean', 'trans_depth', 'response_body_len',
    'ct_srv_src', 'ct_state_ttl', 'ct_dst_ltm', 'ct_src_dport_ltm',
    'ct_dst_sport_ltm', 'ct_dst_src_ltm', 'is_ftp_login',
    'ct_ftp_cmd', 'ct_flw_http_mthd', 'ct_src_ltm', 'ct_srv_dst',
    'is_sm_ips_ports',
]

# Subset of features that Step 2's causal discovery was run on
CAUSAL_FEATURES = [
    'proto', 'sttl', 'state', 'dtcpb', 'is_sm_ips_ports',
    'dttl', 'stcpb', 'service', 'dwin', 'swin',
]

# Record which inputs exist; later sections consult files_status and skip
# the parts whose input is missing instead of aborting the whole demo.
print("\nVerifying required files:")
files_status = {}
for label, path in (
    ('LSTM Model', LSTM_MODEL),
    ('Scaler', SCALER),
    ('Causal Graph', CAUSAL_GRAPH),
    ('Data', DATA_FILE),
):
    found = Path(path).exists()
    files_status[label] = found
    print(f"  {'‚úì' if found else '‚úó'} {label}: {path}")

if any(not ok for ok in files_status.values()):
    print("\n‚ö†Ô∏è  WARNING: Some files are missing!")
    print("Please ensure Steps 1-3 have been completed.")

# ============================================================================
# SECTION 2: Load LSTM Model
# ============================================================================

print("\n[2/8] Loading LSTM Model...")

# LSTM Architecture (must match Step 1)
class LSTMClassifier(nn.Module):
    """Two-class LSTM classifier: stacked LSTM followed by a small MLP head.

    The layout must mirror the Step 1 model exactly. ``load_state_dict``
    matches parameters by attribute path (``lstm.*``, ``fc.0.*``,
    ``fc.3.*``), so neither the attribute names nor the layer order inside
    ``fc`` may change.

    Args:
        input_size: number of features per time step.
        hidden_size: LSTM hidden width.
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability, used between stacked LSTM layers (only
            when ``num_layers > 1``) and inside the classification head.
    """

    def __init__(self, input_size, hidden_size=128, num_layers=2, dropout=0.2):
        super().__init__()
        # nn.LSTM applies dropout only *between* stacked layers, so a
        # single-layer network gets 0.0 (PyTorch warns on a non-zero value).
        inter_layer_dropout = 0.0 if num_layers <= 1 else dropout
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
        )
        head_layers = [
            nn.Linear(hidden_size, 64),  # fc.0
            nn.ReLU(),                   # fc.1
            nn.Dropout(dropout),         # fc.2
            nn.Linear(64, 2),            # fc.3 -> two class logits
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, input_size) to (batch, 2) logits."""
        sequence_out, _ = self.lstm(x)
        # Classify from the hidden state of the final time step only.
        last_step = sequence_out[:, -1]
        return self.fc(last_step)

# Inference runs on CPU so the demo works on machines without a GPU.
device = torch.device('cpu')
model = LSTMClassifier(input_size=len(FEATURE_NAMES)).to(device)

if files_status.get('LSTM Model'):
    # NOTE(review): torch.load unpickles the checkpoint -- only load files you trust.
    model.load_state_dict(torch.load(LSTM_MODEL, map_location=device))
    print("‚úì Loaded LSTM model")
else:
    print("‚ö†Ô∏è  Using untrained model")

# Switch to inference mode unconditionally. The classifier contains Dropout
# layers, so leaving the untrained fallback in training mode would make every
# prediction (and every attribution) in this demo random from run to run.
model.eval()

# Feature scaler fitted in Step 1; stays None when the file is missing, and
# later sections skip model scoring in that case.
scaler = None
if files_status.get('Scaler'):
    scaler = joblib.load(SCALER)
    print(f"‚úì Loaded scaler ({scaler.n_features_in_} features)")
    # Rows are passed to the scaler positionally, so a width mismatch means
    # FEATURE_NAMES does not describe the data the scaler was fitted on.
    if scaler.n_features_in_ != len(FEATURE_NAMES):
        print(f"  WARNING: scaler expects {scaler.n_features_in_} features, "
              f"but FEATURE_NAMES lists {len(FEATURE_NAMES)}")

# ============================================================================
# SECTION 3: Load Causal Graph
# ============================================================================

print("\n[3/8] Loading Causal Graph...")

# Start from an empty graph so the `number_of_nodes() > 0` checks further
# down simply skip the causal parts when Step 2 output is unavailable.
causal_graph = nx.DiGraph()
if files_status.get('Causal Graph'):
    # NOTE(review): pickle can execute arbitrary code on load -- only open
    # graph files produced by your own Step 2 run.
    with open(CAUSAL_GRAPH, 'rb') as graph_file:  # closes the handle deterministically
        causal_graph = pickle.load(graph_file)
    print(f"‚úì Loaded causal graph:")
    print(f"  Nodes: {causal_graph.number_of_nodes()}")
    print(f"  Edges: {causal_graph.number_of_edges()}")

    # Root causes = nodes with no incoming edges; show at most five.
    root_causes = [n for n in causal_graph.nodes() if causal_graph.in_degree(n) == 0]
    if root_causes:
        # map(str, ...) so non-string node labels do not break the join
        print(f"  Root causes: {', '.join(map(str, root_causes[:5]))}")
else:
    print("‚ö†Ô∏è  Causal graph not found")

# ============================================================================
# SECTION 4: Load and Preprocess Data
# ============================================================================

print("\n[4/8] Loading Dataset...")

df = None
if files_status.get('Data'):
    df = pd.read_csv(DATA_FILE)
    print(f"‚úì Loaded {len(df):,} records")

    # Drop non-feature columns (row id and the multi-class attack category)
    drop_cols = ['id', 'attack_cat']
    df = df.drop(columns=[col for col in drop_cols if col in df.columns])

    # Later sections read ground truth from 'Label'; alias the lowercase
    # 'label' column when only that one is present.
    if 'label' in df.columns and 'Label' not in df.columns:
        df['Label'] = df['label']

    # Encode categorical features as integer codes.
    # NOTE(review): these encoders are re-fitted on this CSV. The codes only
    # match what the Step 1 scaler/model saw if Step 1 encoded the same data
    # the same way -- confirm, or persist and reuse Step 1's encoders.
    print("\nEncoding categorical features...")
    categorical_cols = ['proto', 'service', 'state']

    label_encoders = {}
    for col in categorical_cols:
        # Test for "not numeric" instead of dtype == 'object', so text columns
        # are still encoded when pandas infers its dedicated string dtype.
        if col in df.columns and not pd.api.types.is_numeric_dtype(df[col]):
            le = LabelEncoder()
            df[col] = le.fit_transform(
                df[col].astype(object).fillna('unknown').astype(str)
            )
            label_encoders[col] = le
            print(f"  ‚úì {col}: {df[col].nunique()} categories")

    # Check label distribution
    if 'Label' in df.columns:
        print(f"\nLabel Distribution:")
        label_counts = df['Label'].value_counts()
        for label, count in label_counts.items():
            pct = count / len(df) * 100
            label_name = "Attack" if label == 1 else "Normal"
            print(f"  {label_name} ({label}): {count:,} ({pct:.1f}%)")

    # Verify all model input features are present; name any that are missing,
    # because selecting df[FEATURE_NAMES] later would fail with a KeyError.
    available_features = [f for f in FEATURE_NAMES if f in df.columns]
    print(f"\n‚úì Available features: {len(available_features)}/{len(FEATURE_NAMES)}")
    missing_features = [f for f in FEATURE_NAMES if f not in df.columns]
    if missing_features:
        print(f"  WARNING: missing model features: {', '.join(missing_features)}")

# ============================================================================
# SECTION 5: Example Alert Analysis
# ============================================================================

print("\n[5/8] Analyzing Example Alerts...")

if df is not None and len(df) > 0:
    print("\nSelecting diverse alert samples...")

    # Up to 3 attack rows and 2 normal rows, seeded so the demo is repeatable.
    # Each class subset is materialised once and reused for its size.
    attack_pool = df[df['Label'] == 1]
    normal_pool = df[df['Label'] == 0]
    attack_samples = attack_pool.sample(n=min(3, len(attack_pool)), random_state=42)
    normal_samples = normal_pool.sample(n=min(2, len(normal_pool)), random_state=42)

    print(f"  ‚úì Selected {len(attack_samples)} attack + {len(normal_samples)} normal alerts")

    # Example 1: walk through the first sampled attack alert
    print("\n" + "="*70)
    print("EXAMPLE 1: Attack Alert Analysis")
    print("="*70)

    alert = attack_samples.iloc[0]
    print(f"\nAlert ID: {alert.name}")
    print(f"True Label: Attack")

    # Print a handful of headline fields (skipping any the row lacks)
    key_features = ['proto', 'sttl', 'state', 'is_sm_ips_ports', 'spkts', 'dpkts']
    print("\nKey Feature Values:")
    for feat in (name for name in key_features if name in alert.index):
        print(f"  {feat:20s}: {alert[feat]}")

    # Score the alert with the LSTM (needs the Step 1 scaler)
    if scaler is not None:
        alert_features = alert[FEATURE_NAMES].values.reshape(1, -1)
        # (1, n_features) -> (1, 1, n_features): a batch of one single-step sequence
        alert_tensor = (
            torch.tensor(scaler.transform(alert_features), dtype=torch.float32)
            .unsqueeze(1)
            .to(device)
        )

        with torch.no_grad():
            output = model(alert_tensor)
            probs = torch.softmax(output, dim=1)[0]
            pred_class = output.argmax(dim=1).item()
            confidence = probs[pred_class].item()

        pred_label = "Attack" if pred_class == 1 else "Normal"
        print(f"\nLSTM Prediction:")
        print(f"  Predicted: {pred_label} (confidence: {confidence:.2%})")
        print(f"  Normal prob: {probs[0]:.2%}")
        print(f"  Attack prob: {probs[1]:.2%}")

# ============================================================================
# SECTION 6: Feature Importance (XAI)
# ============================================================================

print("\n[6/8] Computing Feature Importance...")

def compute_feature_importance_simple(model, alert_tensor, feature_names):
    """Attribute the predicted-class score to input features (gradient x input).

    Args:
        model: torch module mapping ``alert_tensor`` to logits of shape
            (batch, n_classes). Only the first batch row is explained.
        alert_tensor: model input for a single alert (in this script it has
            shape (1, 1, n_features)). The caller's tensor is not modified.
        feature_names: one name per input element, in input order.

    Returns:
        List of dicts with keys ``feature``, ``importance`` (signed
        gradient x input), ``value`` (the input value) and ``abs_importance``,
        sorted by ``abs_importance`` in descending order.

    Raises:
        ValueError: if the number of input elements differs from
            ``len(feature_names)`` (previously ``zip`` silently truncated).
    """
    # detach().clone() yields a fresh leaf tensor, so its gradient is always
    # populated even if the caller's tensor is itself part of a graph.
    inputs = alert_tensor.detach().clone().requires_grad_(True)

    # Re-enable autograd in case the caller is inside torch.no_grad().
    with torch.enable_grad():
        output = model(inputs)
        pred_class = output[0].argmax().item()
        # autograd.grad returns d(score)/d(inputs) without accumulating
        # .grad on the model parameters, so no zero_grad bookkeeping is needed.
        (grad,) = torch.autograd.grad(output[0, pred_class], inputs)

    # reshape(-1) rather than squeeze(): also correct for a single feature.
    gradients = grad.detach().cpu().numpy().reshape(-1)
    values = inputs.detach().cpu().numpy().reshape(-1)
    if gradients.size != len(feature_names):
        raise ValueError(
            f"input has {gradients.size} elements but {len(feature_names)} feature names were given"
        )

    # Attribution = gradient * input
    attributions = gradients * values

    importance = [
        {
            'feature': name,
            'importance': float(attr),
            'value': float(val),
            'abs_importance': float(abs(attr)),
        }
        for name, attr, val in zip(feature_names, attributions, values)
    ]
    importance.sort(key=lambda item: item['abs_importance'], reverse=True)
    return importance

# Default to an empty list so the later sections that read `importance`
# find a defined (empty) value when no alert could be scored.
importance = []

# Same preconditions under which Section 5 builds `alert_tensor`.
if df is not None and len(df) > 0 and scaler is not None:
    print("\nTop 5 Important Features (Attack Alert):")
    importance = compute_feature_importance_simple(model, alert_tensor, FEATURE_NAMES)

    for rank, feat in enumerate(importance[:5], start=1):
        print(f"  {rank}. {feat['feature']:20s} importance: {feat['importance']:+.4f}")

# ============================================================================
# SECTION 7: Causal Analysis
# ============================================================================

print("\n[7/8] Performing Causal Analysis...")

if causal_graph.number_of_nodes() > 0 and importance:
    print("\nCausal Relationships for Top Features:")

    # Outcome node name depends on how Step 2 labelled it ('label' preferred).
    # It is the same for every feature, so resolve it once outside the loop.
    target = next((t for t in ('label', 'Label') if t in causal_graph), None)

    for feat_info in importance[:5]:
        feat_name = feat_info['feature']

        if feat_name not in causal_graph:
            print(f"\n  {feat_name}: Not in causal graph")
            continue

        # Root causes: ancestors that have no parents of their own
        ancestors = list(nx.ancestors(causal_graph, feat_name))
        root_causes = [n for n in ancestors if causal_graph.in_degree(n) == 0]

        # Direct causes: immediate parents in the graph
        direct_causes = list(causal_graph.predecessors(feat_name))

        print(f"\n  {feat_name}:")
        if root_causes:
            print(f"    Root causes: {', '.join(map(str, root_causes))}")
        if direct_causes:
            print(f"    Direct causes: {', '.join(map(str, direct_causes))}")

        # Shortest directed path from the feature to the outcome, if any
        if target is not None:
            try:
                path = nx.shortest_path(causal_graph, feat_name, target)
                print(f"    Path to outcome: {' ‚Üí '.join(map(str, path))}")
            except nx.NetworkXNoPath:
                print(f"    No direct path to outcome")
elif causal_graph.number_of_nodes() > 0:
    # Graph loaded but no alert was scored (e.g. missing data or scaler)
    print("\nNo feature importance available; skipping causal analysis.")

# ============================================================================
# SECTION 8: Visualizations
# ============================================================================

print("\n[8/8] Creating Visualizations...")

# len(df) > 0 first: Section 5 only defines attack_samples for a non-empty df.
if df is not None and len(df) > 0 and len(attack_samples) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('Hybrid Explanation System Demo - UNSW-NB15',
                 fontsize=16, fontweight='bold')

    # Plot 1: Feature Importance (signed gradient x input for the example alert)
    ax = axes[0, 0]
    # Importance only exists when the alert was scored, which needs the scaler.
    if scaler is not None and importance:
        top_features = importance[:10]
        feature_labels = [f['feature'] for f in top_features]
        importances = [f['importance'] for f in top_features]
        # Red: raised the predicted-class logit; blue: lowered it
        colors = ['red' if imp > 0 else 'blue' for imp in importances]

        ax.barh(feature_labels, importances, color=colors, alpha=0.7)
        ax.invert_yaxis()  # list is sorted by |importance| descending: strongest on top
        ax.set_xlabel('Importance Score')
        ax.set_title('XAI: Feature Importance (Top 10)')
        ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
        ax.grid(axis='x', alpha=0.3)
    else:
        ax.text(0.5, 0.5, 'Feature importance\nnot available',
                ha='center', va='center', fontsize=12, transform=ax.transAxes)
        ax.axis('off')

    # Plot 2: Feature Value Distribution (Attack vs Normal)
    ax = axes[0, 1]
    if 'sttl' in df.columns:
        attack_sttl = df[df['Label'] == 1]['sttl']
        normal_sttl = df[df['Label'] == 0]['sttl']

        ax.hist(normal_sttl, bins=30, alpha=0.5, label='Normal', color='blue')
        ax.hist(attack_sttl, bins=30, alpha=0.5, label='Attack', color='red')
        ax.set_xlabel('sttl (Source TTL)')
        ax.set_ylabel('Frequency')
        ax.set_title('Distribution: sttl (Strong Attack Indicator)')
        ax.legend()
        ax.grid(alpha=0.3)

    # Plot 3: Causal Graph (if available)
    ax = axes[1, 0]
    if causal_graph.number_of_nodes() > 0:
        # Draw only the first 15 nodes to keep the layout readable
        subgraph_nodes = list(causal_graph.nodes())[:15]
        subgraph = causal_graph.subgraph(subgraph_nodes)

        pos = nx.spring_layout(subgraph, k=2, iterations=50, seed=42)
        nx.draw_networkx_nodes(subgraph, pos, node_color='lightblue',
                               node_size=800, alpha=0.9, ax=ax)
        nx.draw_networkx_labels(subgraph, pos, font_size=7, ax=ax)
        nx.draw_networkx_edges(subgraph, pos, edge_color='gray',
                               arrows=True, arrowsize=15, ax=ax)
        ax.set_title('Causal Graph (Subset)')
        ax.axis('off')
    else:
        ax.text(0.5, 0.5, 'Causal graph\nnot available',
                ha='center', va='center', fontsize=12, transform=ax.transAxes)
        ax.axis('off')

    # Plot 4: Prediction Confidence
    ax = axes[1, 1]
    if scaler is not None:
        # Score a seeded random sample rather than the first rows. The CSV may
        # be grouped by class, in which case its head holds a single label and
        # one of the two histograms would be empty.
        eval_rows = df.sample(n=min(100, len(df)), random_state=42)
        eval_scaled = scaler.transform(eval_rows[FEATURE_NAMES].values)
        # (n, n_features) -> (n, 1, n_features): each row is a one-step sequence
        eval_tensor = torch.tensor(eval_scaled, dtype=torch.float32).unsqueeze(1).to(device)

        # One batched forward pass instead of one pass per row
        with torch.no_grad():
            attack_prob = torch.softmax(model(eval_tensor), dim=1)[:, 1].cpu().numpy()

        pred_df = pd.DataFrame({
            'true': eval_rows['Label'].to_numpy(),
            'pred_prob': attack_prob,
        })

        # Plot confidence distributions
        attack_conf = pred_df[pred_df['true'] == 1]['pred_prob']
        normal_conf = pred_df[pred_df['true'] == 0]['pred_prob']

        ax.hist(normal_conf, bins=20, alpha=0.5, label='Normal', color='blue')
        ax.hist(attack_conf, bins=20, alpha=0.5, label='Attack', color='red')
        ax.set_xlabel('Attack Probability')
        ax.set_ylabel('Frequency')
        ax.set_title('Model Confidence Distribution')
        ax.legend()
        ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig('step4_demo_visualization.png', dpi=300, bbox_inches='tight')
    print("\n‚úì Saved visualization: step4_demo_visualization.png")
    plt.show()

# ============================================================================
# SUMMARY
# ============================================================================

print("\n" + "="*70)
print("SUMMARY: Step 4 Demo Complete")
print("="*70)

print("\n‚úÖ What We Demonstrated:")
print("  ‚Ä¢ LSTM-based alert classification (Step 1)")
print("  ‚Ä¢ Feature importance via gradient-based XAI")
print("  ‚Ä¢ Causal graph analysis (Step 2)")
print("  ‚Ä¢ Integration of XAI + Causal for hybrid explanations")

print("\nüîç Key Findings:")
if df is not None:
    print(f"  ‚Ä¢ Dataset: {len(df):,} records")
    if 'Label' in df.columns:
        attack_pct = (df['Label'] == 1).mean() * 100
        print(f"  ‚Ä¢ Attack rate: {attack_pct:.1f}%")

    # `importance` is only populated when an alert was scored, which needs
    # the scaler. Guard so a missing scaler or an empty result cannot raise
    # NameError/IndexError here.
    if scaler is not None and importance and 'sttl' in importance[0]['feature']:
        print(f"  ‚Ä¢ Top feature: {importance[0]['feature']} (TTL-based indicator)")

print("\nüìä Generated Files:")
print("  ‚Ä¢ step4_demo_visualization.png - Comprehensive visualization")

print("\nüéØ Next Steps:")
print("  ‚Ä¢ Run full Step 4 hybrid explainer for detailed explanations")
print("  ‚Ä¢ Proceed to Step 5 for quantitative evaluation")
print("  ‚Ä¢ Compare hybrid vs XAI-only approaches")

print("\nüí° Key Advantages of Hybrid Approach:")
print("  ‚Ä¢ XAI tells us WHAT: Which features are important")
print("  ‚Ä¢ Causal tells us WHY/HOW: Root causes and causal chains")
print("  ‚Ä¢ Combined: Actionable recommendations for SOC analysts")

print("\n" + "="*70)
print("Demo Complete!")
print("="*70)