# In [None]:
import json
import numpy as np
import pandas as pd
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
from enum import Enum
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from scipy.stats import beta
import os
import random

# ============================================================================
# 1. CONFIGURATION
# ============================================================================

class UserType(Enum):
    """Hidden behavioural state of a simulated user.

    The integer values are recorded per timestep as ground truth and are
    also the y-axis positions used by the type-evolution plot
    (0 = LOY, 1 = DIS, 2 = MAL).
    """

    LOYAL = 0        # starting state of every user
    DISGRUNTLED = 1  # intermediate state
    MALICIOUS = 2    # active insider

@dataclass
class ChannelMetadata:
    """Static description of one monitored signal channel.

    One instance is created per distinct channel name found in the policy
    file.
    """

    # Position of this channel inside the per-timestep signal vector.
    index: int
    # Channel identifier as written in the policy file (e.g. "s_logon").
    name: str
    # Kill-chain stage. "Access", "Reconnaissance", "Privilege" and
    # "Exfiltration" are wired into the causal graph; other values are not.
    category: str
    # Relative weight of the channel when computing threat scores.
    severity: float

# ============================================================================
# 2. CAUSAL DBN ENGINE
# ============================================================================

class CausalDependencyDBN:
    """Beta-likelihood model over signal channels linked by a kill-chain graph.

    Every channel observation is scored under a Beta(a, b) density. The
    shape depends on the hypothesised user type and, for channels that have
    causal parents, on the mean of the parents' current signal values.
    """

    def __init__(self, channels: Dict[str, ChannelMetadata]):
        """Store *channels* (name -> metadata) and derive the parent graph."""
        self.channels = channels
        self.causal_parents = self._build_dynamic_graph()

        # Base Beta shape per user type; the mean is a / (a + b).
        self.params = {
            UserType.LOYAL: dict(a=2, b=12),       # mean ~0.14 (quiet)
            UserType.DISGRUNTLED: dict(a=4, b=6),  # mean 0.40 (noisy)
            UserType.MALICIOUS: dict(a=10, b=2),   # mean ~0.83 (suspicious)
        }

        # Per-signal weights, largest for data-movement signals.
        # NOTE(review): nothing in this class reads this table.
        self.signal_weights = dict(
            logon=1.0,
            device=1.2,
            file=1.66,
            email=1.8,
            exfil=2.5,
            role=1.3,
        )

    def _build_dynamic_graph(self) -> Dict[int, List[int]]:
        """Return ``{channel index: [parent indices]}`` from channel categories.

        Kill-chain wiring:
          * Reconnaissance <- Access
          * Privilege      <- Access + Reconnaissance
          * Exfiltration   <- Privilege + Reconnaissance
        Channels in any other category get an empty parent list.
        """
        metas = list(self.channels.values())
        graph = {m.index: [] for m in metas}

        by_stage = {
            "Access": [],
            "Reconnaissance": [],
            "Privilege": [],
            "Exfiltration": [],
        }
        for m in metas:
            bucket = by_stage.get(m.category)
            if bucket is not None:
                bucket.append(m.index)

        access = by_stage["Access"]
        recon = by_stage["Reconnaissance"]
        privilege = by_stage["Privilege"]

        for child in recon:
            graph[child].extend(access)
        for child in privilege:
            graph[child].extend(access + recon)
        for child in by_stage["Exfiltration"]:
            graph[child].extend(privilege + recon)

        return graph

    def get_conditional_params(self, idx: int, val: float, parents: List[float], u_type: UserType):
        """Return the Beta ``(a, b)`` for a channel given its parents' values.

        Args:
            idx: Channel index (not used by the computation).
            val: Observed channel value (not used by the computation).
            parents: Current signal values of the channel's causal parents.
            u_type: User type whose base shape is looked up.

        Returns:
            The base ``(a, b)`` when there are no parents. Otherwise
            ``(a + 10, 1.0)`` for a MALICIOUS hypothesis whose parent mean
            exceeds 0.5, and ``(a + 0.5 * parent_mean, b)`` in all other
            cases.
        """
        base = self.params[u_type]
        shape_a, shape_b = base['a'], base['b']

        if not parents:
            return shape_a, shape_b

        mean_parent = np.mean(parents)

        # Strong causal link: active parents push a malicious child's density
        # hard towards 1.
        strongly_driven = u_type == UserType.MALICIOUS and mean_parent > 0.5
        if strongly_driven:
            return shape_a + 10, 1.0

        # Otherwise only a weak positive correlation with the parents.
        return shape_a + (mean_parent * 0.5), shape_b

    def compute_log_likelihood(self, signals: np.ndarray, user_type: UserType) -> float:
        """Sum the per-channel log densities of *signals* under *user_type*.

        Each observation is clipped to [0.01, 0.99] before evaluation; the
        parent values fed to ``get_conditional_params`` are used unclipped.
        A constant 1e-9 is added to every density before the log is taken.
        """
        total = 0.0
        density_floor = 1e-9
        for meta in self.channels.values():
            ch = meta.index
            observed = np.clip(signals[ch], 0.01, 0.99)
            parent_vals = [signals[p] for p in self.causal_parents[ch]]

            a, b = self.get_conditional_params(ch, observed, parent_vals, user_type)
            density = beta.pdf(observed, a, b)
            total += np.log(density + density_floor)
        return total

    def compute_threat_score(self, signals: np.ndarray) -> float:
        """Return the severity-weighted mean signal with high-value boosting.

        Signals above 0.7 count double in the numerator only, so the result
        can exceed the plain weighted mean. Returns 0 when the total
        severity is not positive.
        """
        weighted_sum = 0.0
        severity_sum = 0.0
        for meta in self.channels.values():
            level = signals[meta.index]
            boost = 2.0 if level > 0.7 else 1.0
            weighted_sum += level * (meta.severity * boost)
            severity_sum += meta.severity

        if severity_sum > 0:
            return weighted_sum / severity_sum
        return 0

    def compute_threat_score_enhanced(self, signals: np.ndarray, parents_high: bool = False) -> float:
        """Threat score with exfiltration weighting and chain amplification.

        Args:
            signals: Signal vector indexed by channel index.
            parents_high: When True every channel weight is multiplied by
                1.33 (numerator only).

        Returns:
            Weighted score normalised by ``severity * exfil_weight`` summed
            over channels, or 0 when that sum is not positive.
        """
        chain_boost = 1.33 if parents_high else 1.0

        numerator = 0.0
        denominator = 0.0
        for name, meta in self.channels.items():
            level = signals[meta.index]

            # Channels whose name mentions exfil/email carry 2.5x weight.
            lowered = name.lower()
            exfil_weight = 2.5 if ('exfil' in lowered or 'email' in lowered) else 1.0

            weight = meta.severity * exfil_weight * chain_boost
            if level > 0.7:
                weight *= 2.0   # high signal
            elif level > 0.5:
                weight *= 1.5   # medium signal

            numerator += level * weight
            denominator += meta.severity * exfil_weight

        return numerator / denominator if denominator > 0 else 0

# ============================================================================
# 3. BYZANTINE COMMITTEE
# ============================================================================

class ByzantineCommittee:
    """Committee that aggregates per-member malicious-belief reports.

    For each observation an "anchor" posterior is computed with a Bayesian
    filter over ``UserType``. Honest members report the anchor plus small
    noise, traitors report a suppressed value, and SACC aggregation clips
    every report to within ``gamma`` of the anchor before averaging.
    """

    def __init__(self, channels, n_members=5, f_traitors=1):
        """Build the likelihood model and initialise thresholds and state.

        Args:
            channels: Mapping of channel name to metadata, forwarded to
                ``CausalDependencyDBN``.
            n_members: Total committee size.
            f_traitors: Number of members that report adversarially.
        """
        self.dbn = CausalDependencyDBN(channels)
        self.n_members = n_members
        self.f_traitors = f_traitors
        self.gamma = 0.15  # SACC clipping bound around the anchor

        # Staged detection thresholds, compared against robust belief * 4.
        self.detection_threshold_flag = 0.75
        self.detection_threshold_escalate = 1.66
        self.detection_threshold_suspend = 2.5

        # Scales the extra boost given to the malicious posterior when any
        # signal is high (see _bayes_update).
        self.belief_learning_rate = 1.25

        self.user_beliefs = {}
        self.belief_history = {}

        # Row = current type, column = next type. MALICIOUS is absorbing and
        # LOYAL cannot jump straight to MALICIOUS.
        self.transitions = {
            UserType.LOYAL: {
                UserType.LOYAL: 0.99,
                UserType.DISGRUNTLED: 0.01,
                UserType.MALICIOUS: 0.0,
            },
            UserType.DISGRUNTLED: {
                UserType.LOYAL: 0.05,
                UserType.DISGRUNTLED: 0.90,
                UserType.MALICIOUS: 0.05,
            },
            UserType.MALICIOUS: {
                UserType.LOYAL: 0.0,
                UserType.DISGRUNTLED: 0.0,
                UserType.MALICIOUS: 1.0,
            },
        }

    def initialize_user(self, uid):
        """Reset *uid* to the default prior and clear its history traces."""
        self.user_beliefs[uid] = {
            UserType.LOYAL: 0.9,
            UserType.DISGRUNTLED: 0.09,
            UserType.MALICIOUS: 0.01,
        }
        self.belief_history[uid] = {
            "system": [],
            "anchor": [],
            "byzantine": [],
            "raw_mean": [],
        }

    def _bayes_update(self, prior, signals):
        """Return the normalised posterior over ``UserType`` for *signals*.

        Steps: propagate *prior* through the transition matrix, multiply by
        the likelihood of *signals* (shifted by the maximum log-likelihood
        for numerical stability), boost the MALICIOUS mass when the largest
        signal exceeds 0.7, then normalise. Only the ``UserType`` entries of
        *prior* are read.
        """
        types = list(UserType)

        predicted = {
            nxt: sum(prior[cur] * self.transitions[cur][nxt] for cur in types)
            for nxt in types
        }

        log_like = {t: self.dbn.compute_log_likelihood(signals, t) for t in types}
        peak = max(log_like.values())

        posterior = {}
        for t in types:
            posterior[t] = predicted[t] * np.exp(log_like[t] - peak)

        # Any single high signal accelerates the malicious belief.
        if np.max(signals) > 0.7:
            boost = 1.0 + self.belief_learning_rate * 0.3
            posterior[UserType.MALICIOUS] = posterior[UserType.MALICIOUS] * boost

        norm = sum(posterior.values()) or 1e-9
        return {t: mass / norm for t, mass in posterior.items()}

    def certify(self, uid, signals):
        """Process one observation for *uid* and return the robust belief.

        Side effects: replaces ``user_beliefs[uid]`` (including the
        ``'_threat_score'`` and ``'_detection_status'`` entries stored next
        to the ``UserType`` keys) and appends to every trace in
        ``belief_history[uid]``.

        Returns:
            Mean of the clipped committee reports, i.e. the committee's
            malicious probability for this step.
        """
        if uid not in self.user_beliefs:
            self.initialize_user(uid)

        # 1. Anchor: the uncorrupted Bayesian posterior.
        anchor_mal = self._bayes_update(self.user_beliefs[uid], signals)[UserType.MALICIOUS]

        # 2. Honest members: anchor plus +/-0.03 uniform noise, kept in [0, 1].
        reports = []
        for _ in range(self.n_members - self.f_traitors):
            jitter = random.uniform(-0.03, 0.03)
            reports.append(np.clip(anchor_mal + jitter, 0, 1))

        # 3. Traitors: user 0 gets blunt suppression so the gap is visible in
        #    plots; everyone else gets a 0.25 undershoot of the anchor.
        traitor_val = 0.05 if uid == 0 else max(0.0, anchor_mal - 0.25)
        reports += [traitor_val] * self.f_traitors

        # 4. SACC aggregation: pull any report further than gamma from the
        #    anchor back onto the gamma boundary, then average.
        clipped = [
            anchor_mal + (np.sign(r - anchor_mal) * self.gamma)
            if abs(r - anchor_mal) > self.gamma
            else r
            for r in reports
        ]
        robust_mal = np.mean(clipped)

        # New belief state: the non-malicious remainder is split 90/10.
        remainder = 1.0 - robust_mal
        self.user_beliefs[uid] = {
            UserType.LOYAL: remainder * 0.9,
            UserType.DISGRUNTLED: remainder * 0.1,
            UserType.MALICIOUS: robust_mal,
        }

        trace = self.belief_history[uid]
        trace["system"].append(robust_mal)
        trace["anchor"].append(anchor_mal)
        trace["byzantine"].append(traitor_val)
        trace["raw_mean"].append(np.mean(reports))

        # Multi-stage detection on the belief scaled by 4; the first
        # threshold met, from most to least severe, names the status.
        threat_score = robust_mal * 4.0
        ladder = (
            (self.detection_threshold_suspend, 'SUSPEND'),
            (self.detection_threshold_escalate, 'ESCALATE'),
            (self.detection_threshold_flag, 'FLAG'),
        )
        detection_status = next(
            (label for bound, label in ladder if threat_score >= bound),
            'normal',
        )

        self.user_beliefs[uid]['_threat_score'] = threat_score
        self.user_beliefs[uid]['_detection_status'] = detection_status

        return robust_mal

# ============================================================================
# 4. SIMULATION LOOP (Hybrid)
# ============================================================================
class ByzantineSimulation:
    """Drive a synthetic user population through the committee.

    Users ``0 .. n_mal - 1`` are attackers and the rest are benign. At each
    timestep a signal vector is generated, certified by the committee, and
    combined with the causal threat score into a risk value that is
    compared with the detection threshold.
    """

    def __init__(self, policy_file):
        """Load channels from *policy_file* and build the committee on them."""
        self.channels = self._load_policies(policy_file)
        self.committee = ByzantineCommittee(self.channels)
        self.n_channels = len(self.channels)

    def _load_policies(self, fpath):
        """Return ``{channel name: ChannelMetadata}`` parsed from a JSON list.

        Each policy must carry ``'signal_channel'``; ``'category'`` defaults
        to ``'Activity'`` and ``'severity'`` to 0.5. Only the first policy
        seen for a channel name defines it, and indices follow first-seen
        order.

        Raises:
            OSError: If the file cannot be opened.
            KeyError: If a policy has no ``'signal_channel'`` entry.
        """
        with open(fpath, encoding="utf-8") as f:
            policies = json.load(f)

        channels = {}
        for policy in policies:
            name = policy['signal_channel']
            if name not in channels:
                channels[name] = ChannelMetadata(
                    len(channels),
                    name,
                    policy.get('category', 'Activity'),
                    policy.get('severity', 0.5),
                )
        return channels

    @staticmethod
    def _count_malicious(n_users, pct_mal):
        """Return floor(n_users * pct_mal), tolerant of binary float error."""
        # 100 * 0.29 evaluates to 28.999999999999996, which a bare int()
        # truncates to 28; the tiny nudge restores the intended 29 without
        # changing genuinely fractional products such as 10 * 0.15 -> 1.
        return int(n_users * pct_mal + 1e-9)

    def _generate_signals(self, u_type, t, is_malicious, uid):
        """Draw one signal vector, clipped to [0.01, 0.99].

        Every channel starts from Beta(2, 12) noise. Attackers overwrite
        channels after t > 50; benign users get spikes that depend on
        *u_type*. Before t > 50 attackers emit baseline noise only, whatever
        *u_type* says, apart from the random low-profile steps of non-zero
        uids.

        Args:
            u_type: Current ground-truth ``UserType`` (only consulted for
                benign users).
            t: Timestep.
            is_malicious: Whether the user belongs to the attacker cohort.
            uid: User id; attacker 0 is the showcase attacker.
        """
        signals = np.random.beta(2, 12, size=self.n_channels)

        if is_malicious:
            if uid == 0:
                # Showcase attacker: every channel goes high after t > 50.
                if t > 50:
                    for m in self.channels.values():
                        signals[m.index] = np.random.beta(15, 2)
                    return np.clip(signals, 0.01, 0.99)
            else:
                # Standard attacker: 15% of steps lie low with halved noise.
                if random.random() < 0.15:
                    return np.clip(signals * 0.5, 0.01, 0.99)
                if t > 50:
                    # Boost several kill-chain stages so the causal model
                    # sees a path rather than an isolated exfil spike.
                    for m in self.channels.values():
                        if m.category == "Access":
                            signals[m.index] = np.random.beta(12, 4)
                        elif m.category in ["Exfiltration", "Privilege"]:
                            signals[m.index] = np.random.beta(10, 2)

        elif u_type == UserType.DISGRUNTLED:
            # One to three channels with moderately elevated activity.
            n_spikes = random.randint(1, 3)
            for idx in random.sample(range(self.n_channels), n_spikes):
                signals[idx] = np.random.beta(4, 6)
        elif u_type == UserType.LOYAL and random.random() < 0.05:
            # Legitimate but suspicious-looking burst (e.g. a batch job).
            idx = random.randint(0, self.n_channels - 1)
            signals[idx] = np.random.beta(15, 3)
        elif u_type == UserType.LOYAL and random.random() < 0.02:
            # Only reached when the 5% draw above failed: a fixed 0.85 spike.
            idx = random.randint(0, self.n_channels - 1)
            signals[idx] = 0.85

        return np.clip(signals, 0.01, 0.99)

    def run(self, n_users=100, pct_mal=0.15, horizon=90, threshold=0.88):
        """Simulate every user for *horizon* steps and collect detections.

        Args:
            n_users: Population size; ids are ``0 .. n_users - 1``.
            pct_mal: Fraction of the population (lowest ids) that attacks.
            horizon: Number of timesteps per user.
            threshold: Risk value above which a user counts as detected.

        Returns:
            Dict with uid lists ``'detected'``, ``'false_positives'`` and
            ``'undetected'``, plus ``'hist'`` mapping uid to per-step
            ``'type'`` (ground-truth enum value) and ``'causal_score'``.
        """
        n_mal = self._count_malicious(n_users, pct_mal)
        results = {"detected": [], "false_positives": [], "undetected": [], "hist": {}}

        for uid in range(n_users):
            is_mal = uid < n_mal
            u_type = UserType.LOYAL
            self.committee.initialize_user(uid)
            history = {"type": [], "causal_score": []}
            results["hist"][uid] = history
            detected = False

            for t in range(horizon):
                # Ground-truth evolution. The scripted path is reserved for
                # real attackers, so user 0 is never labelled MALICIOUS while
                # being generated and scored as a benign user (n_mal == 0).
                if is_mal:
                    if t > 50:
                        u_type = UserType.MALICIOUS
                    elif uid == 0 and t > 30:
                        u_type = UserType.DISGRUNTLED
                elif t > 30 and random.random() < 0.02:
                    u_type = UserType.DISGRUNTLED

                history["type"].append(u_type.value)
                signals = self._generate_signals(u_type, t, is_mal, uid)
                p_mal = self.committee.certify(uid, signals)
                threat = self.committee.dbn.compute_threat_score(signals)
                history["causal_score"].append(threat)

                # Risk blends committee belief (40%) with causal score (60%).
                risk = (p_mal * 0.4) + (threat * 0.6)

                # Only the first crossing is recorded; the loop continues so
                # the history traces cover the full horizon.
                if risk > threshold and not detected:
                    detected = True
                    if is_mal:
                        results["detected"].append(uid)
                    else:
                        results["false_positives"].append(uid)

            if is_mal and not detected:
                results["undetected"].append(uid)

        return results

# ============================================================================
# 5. VISUALIZATION
# ============================================================================

class FullVisualizer:
    """Static plotting helpers for simulation results and committee history."""

    # Style settings scoped to the dashboard figure. They are applied through
    # ``plt.rc_context`` while the figure is built, so they reliably affect
    # it and are restored afterwards instead of leaking into later plots.
    _DASHBOARD_RC = {
        'axes.linewidth': 1.5,
        'axes.labelcolor': '#333333',
        'xtick.labelsize': 11,
        'ytick.labelsize': 11,
        'xtick.color': '#333333',
        'ytick.color': '#333333',
        'grid.color': '#666666',
        'grid.alpha': 0.4,
    }

    @staticmethod
    def _label_panel(ax, text):
        """Place a boxed sub-figure label such as '(a)' below the x-axis."""
        ax.text(0.5, -0.18, text,
                transform=ax.transAxes,
                fontsize=13, fontweight='bold',
                va='top', ha='center',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    @staticmethod
    def _detection_metrics(res):
        """Return TP/FP/FN counts with precision, recall and F1 as a dict.

        Ratios whose denominator is zero are reported as 0.0.
        """
        tp = len(res["detected"])
        fp = len(res["false_positives"])
        fn = len(res["undetected"])
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        if (precision + recall) > 0:
            f1 = 2 * (precision * recall) / (precision + recall)
        else:
            f1 = 0.0
        return {"tp": tp, "fp": fp, "fn": fn,
                "precision": precision, "recall": recall, "f1": f1}

    @staticmethod
    def plot_combined_dashboard(res, committee, n_users, uid_malicious=0):
        """Draw the 2x2 dashboard, save it, show it, and print the metrics.

        Args:
            res: Result dict as returned by ``ByzantineSimulation.run``.
            committee: Committee whose ``belief_history`` supplies the belief
                traces.
            n_users: Population size; the highest ids are drawn as the benign
                reference users.
            uid_malicious: User whose traces fill the SACC panel.

        Writes ``dbn_dashboard_combined.png`` to the working directory.
        """
        # Membership sets are built once instead of re-slicing the result
        # lists for every user.
        detected_all = set(res["detected"])
        detected_top = set(res["detected"][:5])
        false_pos_top = set(res["false_positives"][:5])

        with plt.rc_context(FullVisualizer._DASHBOARD_RC):
            fig = plt.figure(figsize=(18, 14))
            gs = gridspec.GridSpec(2, 2, hspace=0.3, wspace=0.3)

            # ===== Panel 1: Type Evolution (Top-Left) =====
            ax1 = fig.add_subplot(gs[0, 0])
            ax1.set_title("1. User Type Evolution (Ground Truth)", fontweight='bold', fontsize=12)
            ax1.set_yticks([0, 1, 2])
            ax1.set_yticklabels(['LOY', 'DIS', 'MAL'])
            ax1.set_xlabel("Time Step")
            for uid in list(res["detected"][:3]) + list(res["undetected"][:3]):
                color = 'red' if uid in detected_all else 'orange'
                ax1.plot(res["hist"][uid]["type"], color=color, alpha=0.6, linewidth=1.5)
            # Last three users as the benign reference; clamped so that
            # populations smaller than three do not yield negative ids.
            for uid in range(max(0, n_users - 3), n_users):
                ax1.plot(res["hist"][uid]["type"], color='green', alpha=0.3, linewidth=0.8)
            ax1.grid(True, alpha=0.2)
            FullVisualizer._label_panel(ax1, '(a)')

            # ===== Panel 2: Causal Kill Chain Evidence (Top-Right) =====
            ax2 = fig.add_subplot(gs[0, 1])
            ax2.set_title("2. Causal Kill Chain Evidence", fontweight='bold', fontsize=12)
            ax2.set_xlabel("Time Step")
            ax2.set_ylabel("Causal Score")
            for uid, user_hist in res["hist"].items():
                scores = user_hist["causal_score"]
                if uid in detected_top:
                    ax2.plot(scores, color='red', alpha=0.5, linewidth=1.5)
                elif uid > n_users - 5:
                    ax2.plot(scores, color='green', alpha=0.3, linewidth=0.8)
            ax2.grid(True, alpha=0.2)
            FullVisualizer._label_panel(ax2, '(b)')

            # ===== Panel 3: System Belief Convergence (Bottom-Left) =====
            ax3 = fig.add_subplot(gs[1, 0])
            ax3.set_title("3. System Belief Convergence", fontweight='bold', fontsize=12)
            ax3.set_xlabel("Time Step")
            ax3.set_ylabel("Belief Score")
            ax3.axhline(0.75, color='black', linestyle='--', linewidth=1, label='Detection Threshold')
            for uid in res["hist"]:
                beliefs = committee.belief_history[uid]["system"]
                if uid in detected_top:
                    ax3.plot(beliefs, color='red', alpha=0.7, linewidth=1.5)
                elif uid in false_pos_top:
                    ax3.plot(beliefs, color='blue', alpha=0.6, linewidth=1.2)
                elif uid > n_users - 5:
                    ax3.plot(beliefs, color='green', alpha=0.2, linewidth=0.8)
            ax3.legend(loc='best', fontsize=9)
            ax3.grid(True, alpha=0.2)
            FullVisualizer._label_panel(ax3, '(c)')

            # ===== Panel 4: SACC Robustness (Bottom-Right) =====
            ax4 = fig.add_subplot(gs[1, 1])
            ax4.set_title("4. SACC Defense - Byzantine Resilience", fontweight='bold', fontsize=12)
            ax4.set_xlabel("Time Step")
            ax4.set_ylabel("Belief Score")

            hist = committee.belief_history[uid_malicious]
            ax4.plot(hist["anchor"], color='blue', linestyle='--', label='Anchor (Truth)', linewidth=2)
            ax4.plot(hist["byzantine"], color='red', linestyle=':', label='Traitor (Suppression)', linewidth=1.5)
            ax4.plot(hist["raw_mean"], color='orange', linestyle='-.', label='Naive Mean', alpha=0.7, linewidth=1.5)
            ax4.plot(hist["system"], color='green', label='SACC Belief (Robust)', linewidth=2.5)
            ax4.axhline(0.75, color='grey', linestyle='--', alpha=0.5)
            # Shade the gap between the naive mean and the clipped aggregate.
            ax4.fill_between(range(len(hist["system"])), hist["raw_mean"], hist["system"],
                             color='green', alpha=0.1, label='SACC Protection')
            ax4.legend(loc='best', fontsize=9)
            ax4.grid(True, alpha=0.2)
            FullVisualizer._label_panel(ax4, '(d)')

            plt.suptitle("Byzantine-Resilient IAM Certification System - 2x2 Dashboard",
                         fontsize=16, fontweight='bold', y=0.995)
            plt.tight_layout()

            plt.savefig("dbn_dashboard_combined.png", dpi=100, bbox_inches='tight')
            plt.show()

        # ===== PRINT METRICS AS TEXT =====
        metrics = FullVisualizer._detection_metrics(res)

        print("\n" + "="*70)
        print("DETECTION METRICS")
        print("="*70)
        print(f"Total Users:          {n_users}")
        print(f"True Positives (TP):  {metrics['tp']}")
        print(f"False Positives (FP): {metrics['fp']}")
        print(f"False Negatives (FN): {metrics['fn']}")
        print(f"\nPrecision:  {metrics['precision']:.4f}")
        print(f"Recall:     {metrics['recall']:.4f}")
        print(f"F1 Score:   {metrics['f1']:.4f}")
        print("="*70 + "\n")

    @staticmethod
    def plot_main_dashboard(res, committee, n_users):
        """Legacy entry point; delegates to ``plot_combined_dashboard``."""
        FullVisualizer.plot_combined_dashboard(res, committee, n_users, uid_malicious=0)

    @staticmethod
    def plot_sacc_robustness(committee, uid_malicious):
        """Plot one user's anchor, traitor, naive-mean and SACC traces.

        Writes ``sacc_robustness.png`` to the working directory.
        """
        hist = committee.belief_history[uid_malicious]
        plt.figure(figsize=(12, 6))
        plt.plot(hist["anchor"], color='blue', linestyle='--', label='Anchor (Truth)', linewidth=2)
        plt.plot(hist["byzantine"], color='red', linestyle=':', label='Traitor (Suppression)', linewidth=1.5)
        plt.plot(hist["raw_mean"], color='orange', linestyle='-.', label='Naive Mean (Vulnerable)', alpha=0.7)
        plt.plot(hist["system"], color='green', label='SACC Belief (Robust)', linewidth=2.5)

        # The title previously ended in an unmatched ')'.
        plt.title("SACC Defense Protection Zone", fontsize=14)
        plt.axhline(0.75, color='grey', linestyle='--')
        plt.fill_between(range(len(hist["system"])), hist["raw_mean"], hist["system"],
                         color='green', alpha=0.1, label='SACC Protection')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig("sacc_robustness.png")
        plt.show()

# ============================================================================
# MAIN
# ============================================================================
if __name__ == "__main__":
    policy_path = "policies.json"
    n_users = 100  # shared by the simulation and the dashboard

    # First run only: seed a minimal policy file with one channel per
    # kill-chain stage. An existing file is left untouched.
    if not os.path.exists(policy_path):
        with open(policy_path, "w", encoding="utf-8") as f:
            json.dump([
                {"signal_channel": "s_logon", "category": "Access", "severity": 0.6},
                {"signal_channel": "s_file", "category": "Reconnaissance", "severity": 0.7},
                {"signal_channel": "s_role", "category": "Privilege", "severity": 0.9},
                {"signal_channel": "s_exfil", "category": "Exfiltration", "severity": 0.8}
            ], f)

    print("="*60)
    print("RUNNING HYBRID SIMULATION")
    print("="*60)

    sim = ByzantineSimulation(policy_path)
    results = sim.run(n_users=n_users, pct_mal=0.15, horizon=90)

    FullVisualizer.plot_main_dashboard(results, sim.committee, n_users)

    # FullVisualizer.plot_sacc_robustness(sim.committee, 0) draws the user-0
    # SACC panel as a standalone figure; the dashboard already includes it.

    # The check mark is written as an escape so it survives re-encoding of
    # this file (it had been corrupted into mojibake).
    print("\n\u2713 Process Complete.")