# In [None]:
# Cell 1: Markdown - Notebook Overview
# (The "In [None]:" marker above is a leftover from the notebook export; it is
# kept as a comment because as a bare statement it is a Python syntax error.)
"""
# Manifold-Constrained Hyper-Connections (mHC) Experiments

## Overview
This notebook explores the Manifold-Constrained Hyper-Connections (mHC) architecture used in CyberGuard for stable multi-agent coordination.

## Why mHC?
Traditional multi-agent systems suffer from:
1. **Signal Explosion**: Unbounded information flow between agents causing instability
2. **Dominant Agent Bias**: One agent overwhelming others' contributions leading to unfair coordination
3. **Reasoning Collapse**: Agents losing individual reasoning capabilities due to excessive mixing

mHC solves these through mathematical constraints:
- Doubly-stochastic normalization (Sinkhorn-Knopp projection) ensures equal contribution
- Convex state mixing with bounded propagation prevents signal explosion
- Identity-preserving mappings maintain agent individuality
- Non-expansive updates guarantee stability
"""

# Cell 2: Setup and Imports
import torch  # PyTorch for tensor operations and neural networks
import torch.nn as nn  # Neural network module for building models
import torch.nn.functional as F  # Functional operations like softmax, normalization
import numpy as np  # Numerical computing for array operations
import matplotlib.pyplot as plt  # Plotting and visualization
from typing import List, Tuple, Dict, Optional  # Type hints for better code documentation
import math  # Mathematical functions and constants
from tqdm import tqdm  # Progress bars for long-running operations
import seaborn as sns  # Statistical data visualization (not used but imported)
import pandas as pd  # Data manipulation and analysis
import time  # Time measurement for performance tracking
import os  # Operating system interface for file operations
import json  # JSON serialization for saving results
import pickle  # Python object serialization for saving models

# Seed both global RNGs used by this notebook (NumPy and PyTorch) with the
# same value so repeated runs draw identical random numbers.
np.random.seed(42)
torch.manual_seed(42)

# Prefer the GPU when CUDA is available; the experiment helpers below
# allocate their tensors on `device`.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Cell 3: mHC Core Implementation
class ManifoldConstrainedHyperConnections:
    """
    Main implementation of Manifold-Constrained Hyper-Connections (mHC).

    Agent states are mixed through a doubly-stochastic attention matrix
    (Sinkhorn-Knopp projection). The result is blended with the agent mean
    and then rescaled so that its L2 norm never exceeds ``signal_bound``.
    Per-call statistics are appended to ``self.metrics``.
    """

    def __init__(self, n_agents: int, state_dim: int, temperature: float = 1.0):
        """
        Store the coordination configuration and initialise metric buffers.

        Parameters:
        n_agents: Number of coordinating agents; attention matrices must be
            ``n_agents x n_agents``.
        state_dim: Dimension of each agent's state vector (used for the zero
            fallback state in ``residual_coordination``).
        temperature: Divides the pairwise confidence products in
            ``residual_coordination``. Sinkhorn normalisation removes any
            uniform rescale (up to ``epsilon``), so in practice this has
            almost no effect on the normalised attention.
        """
        # Store configuration parameters
        self.n_agents = n_agents  # Number of agents in the coordination system
        self.state_dim = state_dim  # Dimensionality of agent state vectors
        self.temperature = temperature  # Attention temperature parameter

        # Sinkhorn-Knopp algorithm parameters for doubly-stochastic normalization
        self.sinkhorn_iterations = 50  # Maximum iterations for convergence
        self.epsilon = 1e-8  # Guards log(0) and division by zero

        # Bounded propagation parameters for stability guarantees
        self.signal_bound = 1.0  # Maximum allowed L2 norm of mixed states (beta)
        self.identity_preserve_factor = 0.1  # lambda: weight of the agent-mean blend

        # Metrics tracking for analysis and debugging (one entry per mixing call)
        self.metrics = {
            'signal_norms': [],  # dicts with mean 'pre_bound' / 'post_bound' norms
            'attention_entropy': [],  # entropy of the normalized attention
            'coordination_efficiency': []  # post-bound norm / pre-bound norm
        }

    def sinkhorn_knopp_projection(self, log_alpha: torch.Tensor) -> torch.Tensor:
        """
        Project log-space scores onto (approximately) doubly-stochastic matrices.

        Row and column normalisation are alternated in log space (Sinkhorn-Knopp)
        for at most ``self.sinkhorn_iterations`` rounds. The loop stops early
        (after at least 12 rounds) once every row and column sum is within
        ``rtol=1e-4`` of 1.

        Parameters:
        log_alpha: Log-space scores, either a single ``[n_agents, n_agents]``
            matrix or a batch ``[..., n_agents, n_agents]``.

        Returns:
        Tensor with the same shape as ``log_alpha``. Its last two dimensions
        form doubly-stochastic matrices. The final step is a column
        normalisation, so column sums are 1 up to float rounding; row sums are
        1 to within the convergence tolerance.

        Raises:
        ValueError: if ``log_alpha`` has fewer than 2 dimensions, or if its
            last two dimensions are not ``[n_agents, n_agents]``.
        """
        if log_alpha.dim() < 2:
            raise ValueError("log_alpha must be a 2D matrix or a batch of 2D matrices")
        n_rows, n_cols = log_alpha.shape[-2], log_alpha.shape[-1]
        if n_rows != n_cols or n_rows != self.n_agents:
            raise ValueError(
                f"Expected shape [..., {self.n_agents}, {self.n_agents}], got {log_alpha.shape}")

        for iteration in range(self.sinkhorn_iterations):
            # Row normalization: each agent's outgoing weights sum to 1.
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-1, keepdim=True)
            # Column normalization: each agent receives a total weight of 1.
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=-2, keepdim=True)

            # Early exit once the whole batch is doubly-stochastic.
            if iteration > 10:
                alpha = torch.exp(log_alpha)
                row_sums = alpha.sum(dim=-1)
                col_sums = alpha.sum(dim=-2)
                if (torch.allclose(row_sums, torch.ones_like(row_sums), rtol=1e-4)
                        and torch.allclose(col_sums, torch.ones_like(col_sums), rtol=1e-4)):
                    break

        # Convert from log-space back to probability space
        return torch.exp(log_alpha)

    def convex_state_mixing(self,
                            agent_states: List[torch.Tensor],
                            attention_weights: torch.Tensor) -> torch.Tensor:
        """
        Mix agent states through a doubly-stochastic attention matrix.

        Steps:
        1. Project ``log(attention_weights + epsilon)`` onto doubly-stochastic
           matrices with Sinkhorn-Knopp.
        2. Compute per-agent mixes ``A @ S`` and average them over the
           receiving agents. All weights are non-negative and sum to 1, so the
           result is a convex combination of the agent states.
        3. Blend the result with the agent mean, using weight
           ``identity_preserve_factor`` on the mean.
        4. Rescale each batch row so that its L2 norm is at most ``signal_bound``.

        Parameters:
        agent_states: ``n_agents`` tensors, each ``[batch, state_dim]``.
        attention_weights: Non-negative weights ``[n_agents, n_agents]`` or
            ``[batch, n_agents, n_agents]``. They pass through ``log``, so
            negative entries yield NaN. A batch of 1 on either side is
            broadcast against the other.

        Returns:
        Bounded mixed state of shape ``[batch, state_dim]``. The normalised
        attention is stored on ``self.last_normalized_attention``.

        Raises:
        ValueError: if ``agent_states`` is empty, if it does not contain
            ``n_agents`` entries, or if the batch sizes are incompatible.
        """
        if not agent_states:
            raise ValueError("agent_states list cannot be empty")
        if len(agent_states) != self.n_agents:
            raise ValueError(
                f"Expected {self.n_agents} agent states, got {len(agent_states)}")

        # [batch, n_agents, state_dim]
        stacked_states = torch.stack(agent_states, dim=1)

        if attention_weights.dim() == 2:  # Missing batch dimension
            attention_weights = attention_weights.unsqueeze(0)

        # Accept a [batch, m, n_agents] layout by transposing it (square
        # n_agents x n_agents matrices are never transposed).
        if (attention_weights.shape[1] != self.n_agents
                and attention_weights.shape[2] == self.n_agents):
            attention_weights = attention_weights.transpose(1, 2)

        # Broadcast a batch of 1 on either side against the other. For example,
        # residual_coordination feeds batch-1 states with batch-B attention.
        state_batch = stacked_states.shape[0]
        attn_batch = attention_weights.shape[0]
        if state_batch != attn_batch:
            if attn_batch == 1:
                attention_weights = attention_weights.expand(state_batch, -1, -1)
            elif state_batch == 1:
                stacked_states = stacked_states.expand(attn_batch, -1, -1)
            else:
                raise ValueError(
                    f"Batch size mismatch: agent states have {state_batch}, "
                    f"attention has {attn_batch}")

        # Step 1: doubly-stochastic attention (projection is batched).
        log_attention = torch.log(attention_weights + self.epsilon)
        normalized_attention = self.sinkhorn_knopp_projection(log_attention)
        self.last_normalized_attention = normalized_attention

        # Step 2: per_agent_mix[b, i] = sum_j A[b, i, j] * S[b, j]. Each row of A
        # sums to 1, so every per-agent mix is convex. Averaging over receivers
        # keeps the result convex. Summing over i instead would scale the
        # result by n_agents, because every column of A sums to 1.
        # NOTE: because the columns of A sum to 1, this pooled mix equals the
        # plain mean of the agent states; the attention only shapes the
        # individual per-agent mixes.
        per_agent_mix = torch.einsum('bij,bjd->bid', normalized_attention, stacked_states)
        mixed_state = per_agent_mix.mean(dim=1)

        # Step 3: blend with the agent mean:
        # (1 - lambda) * mix + lambda * mean.
        identity_states = stacked_states.mean(dim=1)
        mixed_state = (mixed_state * (1 - self.identity_preserve_factor)
                       + identity_states * self.identity_preserve_factor)

        # Step 4: scale by min(1, bound / norm) so each row's norm is <= bound.
        mixed_state_norm = torch.norm(mixed_state, dim=-1, keepdim=True)
        scaling = torch.minimum(
            torch.ones_like(mixed_state_norm),
            self.signal_bound / (mixed_state_norm + self.epsilon)
        )
        bounded_state = mixed_state * scaling

        self._track_metrics(normalized_attention, mixed_state_norm, bounded_state)

        return bounded_state

    def _track_metrics(self, attention: torch.Tensor,
                       pre_bound_norm: torch.Tensor,
                       bounded_state: torch.Tensor):
        """
        Append one entry to each list in ``self.metrics``.

        Parameters:
        attention: Normalised attention, ``[..., n_agents, n_agents]``.
        pre_bound_norm: Per-row L2 norms of the mixed state before bounding.
        bounded_state: Mixed state after bounding.
        """
        # 1. Batch-mean norms before and after bounding.
        post_bound_norm = torch.norm(bounded_state, dim=-1).mean().item()
        pre_bound_mean = pre_bound_norm.mean().item()
        self.metrics['signal_norms'].append({
            'pre_bound': pre_bound_mean,
            'post_bound': post_bound_norm
        })

        # 2. Entropy of each attention matrix (all n*n entries; a
        # doubly-stochastic matrix sums to n, so this is unnormalised),
        # averaged over the batch so that it does not scale with batch size.
        per_matrix_entropy = -(attention * torch.log(attention + self.epsilon)).sum(dim=(-2, -1))
        self.metrics['attention_entropy'].append(per_matrix_entropy.mean().item())

        # 3. Efficiency: post-bound / pre-bound norm (1.0 when no clipping happened).
        efficiency = post_bound_norm / (pre_bound_mean + self.epsilon)
        self.metrics['coordination_efficiency'].append(efficiency)

    def residual_coordination(self,
                              agent_outputs: List[Dict],
                              agent_confidences: torch.Tensor) -> Dict:
        """
        Mix agents' reasoning states and aggregate their decisions.

        Parameters:
        agent_outputs: One dict per agent. The dict has a required
            ``'decision'`` dict (with ``'threat_level'``, ``'confidence'`` and an
            optional ``'evidence'`` list of dicts). Optional keys are
            ``'reasoning_state'`` (a zero ``[1, state_dim]`` state is used when
            it is missing or None) and ``'agent_id'``.
        agent_confidences: ``[batch, n_agents]`` confidence scores.

        Returns:
        Dict with:
        - ``'final_decision'``: softmax-weighted threat level and confidence,
          plus up to 10 evidence items sorted by source weight.
        - ``'coordinated_state'``: the mHC mixed state, ``[batch, state_dim]``.
        - ``'agent_contributions'``: the softmax weights as a list.
        - ``'attention_matrix'``: the raw pairwise weights (batch 1 only,
          otherwise None).

        The caller's evidence dicts are not modified; annotated copies are
        returned.
        """
        # Reasoning states as [1, state_dim]; only the first row of a batched
        # state is kept.
        reasoning_states = []
        for output in agent_outputs:
            state = output.get('reasoning_state')
            if state is None:
                state = torch.zeros((1, self.state_dim), device=agent_confidences.device)
            if state.dim() == 1:
                state = state.unsqueeze(0)
            elif state.dim() == 2 and state.shape[0] != 1:
                state = state[0:1]
            reasoning_states.append(state)

        batch_size = agent_confidences.shape[0]
        # Pairwise weights conf_i * conf_j, divided by the temperature.
        # NOTE(review): this outer product is rank-1. Sinkhorn scaling of a
        # strictly positive rank-1 matrix converges to the uniform matrix
        # (all entries 1/n_agents), and the temperature is a uniform rescale
        # that the normalisation removes (up to epsilon). Confidences therefore
        # do not currently shape the mixing.
        attention_logits = torch.einsum('bi,bj->bij', agent_confidences, agent_confidences)
        attention_logits = attention_logits / self.temperature

        # Batch-1 states broadcast against batch-B attention inside the mixer.
        coordinated_state = self.convex_state_mixing(reasoning_states, attention_logits)

        # Scale each agent's decision by its raw confidence.
        # NOTE(review): the aggregation below applies softmax confidence weights
        # a second time, so the final values scale with confidence twice.
        # Confirm that this is intended.
        decisions = []
        for i, output in enumerate(agent_outputs):
            agent_decision = output['decision']
            agent_weight = agent_confidences[:, i:i+1]
            decisions.append({
                'threat_level': agent_decision['threat_level'] * agent_weight,
                'confidence': agent_decision['confidence'] * agent_weight,
                'evidence': agent_decision.get('evidence', []),
                'agent_id': output.get('agent_id', f'agent_{i}')
            })

        # Weighted-sum aggregation with softmax-normalised confidences.
        threat_levels = torch.stack([d['threat_level'] for d in decisions], dim=1)
        confidences = torch.stack([d['confidence'] for d in decisions], dim=1)
        normalized_weights = F.softmax(agent_confidences, dim=-1)

        final_threat = torch.sum(threat_levels * normalized_weights.unsqueeze(-1), dim=1)
        final_confidence = torch.sum(confidences * normalized_weights.unsqueeze(-1), dim=1)

        # Annotate copies of the evidence with each agent's (batch-mean)
        # softmax weight, so that the caller's dicts are left untouched.
        all_evidence = []
        for i, output in enumerate(agent_outputs):
            agent_weight = normalized_weights[:, i].mean().item()
            for ev in output['decision'].get('evidence', []):
                all_evidence.append({**ev, 'source_confidence': agent_weight})

        # Highest-weighted evidence first; keep the top 10.
        all_evidence.sort(key=lambda x: x.get('source_confidence', 0), reverse=True)
        top_evidence = all_evidence[:10]

        return {
            'final_decision': {
                'threat_level': final_threat,
                'confidence': final_confidence,
                'evidence': top_evidence
            },
            'coordinated_state': coordinated_state,
            'agent_contributions': normalized_weights.tolist(),
            'attention_matrix': attention_logits.squeeze().tolist() if batch_size == 1 else None
        }

# Cell 4: Visualization and Analysis Functions
def visualize_mhc_components(mhc: ManifoldConstrainedHyperConnections,
                           agent_states: List[torch.Tensor],
                           attention_weights: torch.Tensor):
    """
    Draw a 2x2 overview of one mHC mixing step and return the figure.

    Panels:
    1. First two dimensions of each agent's state.
    2. Raw attention shown next to its Sinkhorn-normalised version.
    3. L2 norms of the individual states vs. the mHC mixed state.
    4. Norm of a plain agent mean ("Unbounded") vs. the bounded mHC output,
       with ``mhc.signal_bound`` drawn as a reference line.

    When inputs are batched, every panel shows the first batch element.
    Calling this runs ``mhc.convex_state_mixing`` once, so one entry is
    appended to each list in ``mhc.metrics``.

    Parameters:
    mhc: Configured coordinator.
    agent_states: One ``[batch, state_dim]`` tensor per agent (the shape
        ``convex_state_mixing`` expects).
    attention_weights: ``[n_agents, n_agents]`` or
        ``[batch, n_agents, n_agents]`` non-negative weights.

    Returns:
    The matplotlib ``Figure``.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    n = mhc.n_agents

    # First batch element of every agent state, as 1-D numpy vectors.
    first_states = []
    for s in agent_states:
        vec = s.detach().cpu().numpy()
        if vec.ndim == 2:
            vec = vec[0]
        first_states.append(vec)

    # First batch element of the attention as a single [n, n] matrix. The
    # Sinkhorn projection is applied to this 2-D slice, so panel 2 works for
    # batched attention even when the projection accepts only 2-D input.
    first_attention = attention_weights[0] if attention_weights.dim() == 3 else attention_weights

    # Panel 1: agent states in their first two dimensions.
    ax1 = axes[0, 0]
    colors = plt.cm.Set1(np.linspace(0, 1, len(first_states)))
    for i, vec in enumerate(first_states):
        if len(vec) < 2:  # Zero-pad 0/1-dimensional states to 2-D points
            padded = np.zeros(2)
            padded[:len(vec)] = vec
            vec = padded
        ax1.scatter(vec[0], vec[1], color=colors[i],
                    s=100, label=f'Agent {i+1}', alpha=0.7)
        ax1.annotate(f'A{i+1}', (vec[0], vec[1]),
                     xytext=(5, 5), textcoords='offset points')

    ax1.set_xlabel('Dimension 1')
    ax1.set_ylabel('Dimension 2')
    ax1.set_title('Original Agent States (First 2 Dimensions)')
    ax1.grid(True, alpha=0.3)
    ax1.legend()
    ax1.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    ax1.axvline(x=0, color='gray', linestyle='--', alpha=0.5)

    # Panel 2: raw vs. Sinkhorn-normalised attention, side by side.
    ax2 = axes[0, 1]
    attention_original = first_attention.detach().cpu().numpy()
    attention_sinkhorn_np = mhc.sinkhorn_knopp_projection(
        torch.log(first_attention + mhc.epsilon)
    ).detach().cpu().numpy()
    combined_attention = np.hstack([attention_original, attention_sinkhorn_np])

    im = ax2.imshow(combined_attention, cmap='viridis', aspect='auto')
    plt.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)

    ax2.axvline(x=n - 0.5, color='white', linewidth=2)  # Divider between the two matrices
    # Centre each label over its n-column half (centre of columns 0..n-1 is (n-1)/2).
    ax2.set_xticks([(n - 1) / 2, n + (n - 1) / 2])
    ax2.set_xticklabels(['Original', 'Sinkhorn'])
    ax2.set_yticks(range(n))
    ax2.set_yticklabels([f'A{i+1}' for i in range(n)])
    ax2.set_title('Attention Matrices: Original vs Doubly-Stochastic')
    ax2.set_xlabel('Matrix Type')
    ax2.set_ylabel('Agent')

    # Panel 3: individual state norms vs. the mHC mixed-state norm.
    ax3 = axes[1, 0]
    mixed_state = mhc.convex_state_mixing(agent_states, attention_weights)
    first_mixed = mixed_state[0] if mixed_state.dim() == 2 else mixed_state
    mixed_np = first_mixed.detach().cpu().numpy()

    x_positions = np.arange(len(first_states) + 1)
    state_magnitudes = [np.linalg.norm(vec) for vec in first_states]
    mixed_magnitude = np.linalg.norm(mixed_np)

    ax3.bar(x_positions[:-1], state_magnitudes, alpha=0.6,
            label='Individual States')
    ax3.bar(x_positions[-1], mixed_magnitude, alpha=0.8,
            color='red', label='Mixed State')

    ax3.set_xlabel('State')
    ax3.set_ylabel('Magnitude (L2 Norm)')
    ax3.set_title('State Mixing: Individual → Coordinated')
    ax3.set_xticks(x_positions)
    ax3.set_xticklabels([f'A{i+1}' for i in range(len(first_states))] + ['Mixed'])
    ax3.legend()
    ax3.grid(True, alpha=0.3, axis='y')

    # Panel 4: plain mean of the agents (no constraints) vs. the bounded mHC output.
    ax4 = axes[1, 1]
    unbounded_norm = torch.stack(
        [s[0] if s.dim() == 2 else s for s in agent_states]
    ).mean(dim=0).norm().item()
    bounded_norm = first_mixed.norm().item()

    norms = [unbounded_norm, bounded_norm]
    bars = ax4.bar(['Unbounded', 'Bounded'], norms, color=['orange', 'green'], alpha=0.7)
    ax4.axhline(y=mhc.signal_bound, color='red', linestyle='--',
                label=f'Bound = {mhc.signal_bound}')

    # Value labels on top of each bar.
    for bar, norm in zip(bars, norms):
        ax4.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                 f'{norm:.3f}', ha='center', va='bottom')

    ax4.set_xlabel('Mixing Type')
    ax4.set_ylabel('State Norm')
    ax4.set_title('Signal Bounding Effect on State Norms')
    ax4.legend()
    ax4.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    return fig

def analyze_mhc_stability(mhc: ManifoldConstrainedHyperConnections,
                         n_iterations: int = 100,
                         noise_level: float = 0.1):
    """
    Simulate repeated mHC mixing under drifting inputs and summarise stability.

    Setup:
    - One batch of ``mhc.n_agents`` random states is created on ``device``
      (Gaussian, mean 1.0, std 0.5).
    - The states are mixed ``n_iterations`` times with
      ``mhc.convex_state_mixing``.
    - From the second iteration on, a single Gaussian noise vector
      (std ``noise_level``) is drawn each step and added to EVERY agent's
      state. The agents therefore drift together, and the drift accumulates
      as a random walk.
    - The confidence-derived attention matrix is drawn once and reused for
      all iterations.

    Parameters:
    mhc: Coordinator to analyse; its ``metrics`` lists are appended to.
    n_iterations: Number of coordination steps to simulate.
    noise_level: Standard deviation of the shared per-step noise.

    Returns:
    ``(fig, stats)``. ``fig`` is the 2x2 matplotlib figure. ``stats`` is a
    dict of built-in ``float``/``int`` values (JSON-serialisable) with the
    keys ``norm_mean``, ``norm_std``, ``norm_violations``, ``entropy_mean``,
    ``entropy_std``, ``efficiency_mean`` and ``efficiency_std``. Keys whose
    history is empty are omitted.
    """
    batch_size = 1  # Single batch for simplicity

    # RNG draw order (states, then confidences, then per-step noise) is kept
    # fixed so that seeded runs are reproducible.
    agent_states = [
        torch.randn(batch_size, mhc.state_dim, device=device) * 0.5 + 1.0
        for _ in range(mhc.n_agents)
    ]
    confidences = F.softmax(torch.rand(batch_size, mhc.n_agents, device=device), dim=-1)
    # Confidences never change inside the loop, so the attention is built once.
    attention = torch.einsum('bi,bj->bij', confidences, confidences)

    history = {
        'state_norms': [],        # norm of the mixed state per iteration
        'attention_entropy': [],  # latest mhc.metrics entry per iteration
        'efficiency': [],         # latest mhc.metrics entry per iteration
        'coordinated_state': []   # mixed states as numpy arrays
    }

    for iteration in range(n_iterations):
        if iteration > 0:
            noise = torch.randn_like(agent_states[0]) * noise_level
            agent_states = [s + noise for s in agent_states]

        mixed_state = mhc.convex_state_mixing(agent_states, attention)
        history['state_norms'].append(torch.norm(mixed_state).item())

        if mhc.metrics['attention_entropy']:
            history['attention_entropy'].append(mhc.metrics['attention_entropy'][-1])
        if mhc.metrics['coordination_efficiency']:
            history['efficiency'].append(mhc.metrics['coordination_efficiency'][-1])

        history['coordinated_state'].append(mixed_state.detach().cpu().numpy())

    fig = _plot_stability_history(history, mhc.signal_bound)
    stats = _summarize_stability(history, mhc.signal_bound)
    return fig, stats


def _plot_stability_history(history: Dict, signal_bound: float):
    """Draw the 2x2 stability figure (norms, entropy, efficiency, 2-D trajectory)."""
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    # Plot 1: state norm over time against the bound.
    ax1 = axes[0, 0]
    ax1.plot(history['state_norms'], linewidth=2)
    ax1.axhline(y=signal_bound, color='red', linestyle='--',
                label=f'Bound ({signal_bound})')
    ax1.set_xlabel('Iteration')
    ax1.set_ylabel('State Norm')
    ax1.set_title('State Norm Stability Over Time')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Plot 2: attention entropy over time.
    ax2 = axes[0, 1]
    if history['attention_entropy']:
        ax2.plot(history['attention_entropy'], linewidth=2, color='green')
        ax2.set_xlabel('Iteration')
        ax2.set_ylabel('Entropy')
        ax2.set_title('Attention Uniformity (Higher = More Fair)')
        ax2.grid(True, alpha=0.3)

    # Plot 3: post/pre-bound norm ratio over time.
    ax3 = axes[1, 0]
    if history['efficiency']:
        ax3.plot(history['efficiency'], linewidth=2, color='purple')
        ax3.set_xlabel('Iteration')
        ax3.set_ylabel('Efficiency')
        ax3.set_title('Coordination Efficiency (Closer to 1 = Better)')
        ax3.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
        ax3.grid(True, alpha=0.3)

    # Plot 4: trajectory of the first two state dimensions.
    ax4 = axes[1, 1]
    if history['coordinated_state']:
        states_array = np.array(history['coordinated_state'])
        if states_array.ndim == 3:  # [iterations, batch, dim] -> first batch element
            states_array = states_array[:, 0, :]

        if states_array.shape[1] >= 2:
            ax4.scatter(states_array[:, 0], states_array[:, 1],
                        c=range(len(states_array)), cmap='viridis',
                        alpha=0.6, s=50)
            ax4.plot(states_array[:, 0], states_array[:, 1],
                     alpha=0.3, color='gray')
            ax4.scatter(states_array[0, 0], states_array[0, 1],
                        color='green', s=100, label='Start', marker='o')
            ax4.scatter(states_array[-1, 0], states_array[-1, 1],
                        color='red', s=100, label='End', marker='s')
            ax4.set_xlabel('Dimension 1')
            ax4.set_ylabel('Dimension 2')
            ax4.set_title('Coordinated State Trajectory')
            ax4.legend()
            ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig


def _summarize_stability(history: Dict, signal_bound: float) -> Dict:
    """Reduce a stability history to mean/std summaries using built-in numbers."""
    stats = {}
    if history['state_norms']:
        norms = np.asarray(history['state_norms'])
        stats['norm_mean'] = float(norms.mean())
        stats['norm_std'] = float(norms.std())
        # Iterations whose norm exceeds the bound by more than 10%.
        stats['norm_violations'] = int(np.sum(norms > signal_bound * 1.1))

    for key, prefix in (('attention_entropy', 'entropy'), ('efficiency', 'efficiency')):
        if history[key]:
            values = np.asarray(history[key])
            stats[f'{prefix}_mean'] = float(values.mean())
            stats[f'{prefix}_std'] = float(values.std())

    return stats

# Cell 5: Comparative Analysis Functions
def compare_coordination_strategies(n_agents: int = 5, 
                                  state_dim: int = 64,
                                  n_trials: int = 50):
    """
    Compare mHC against naive coordination strategies.
    
    Evaluates four coordination approaches:
    1. mHC: Our proposed manifold-constrained method
    2. Simple averaging: Equal weighting of all agents
    3. Weighted averaging: Weight by confidence scores
    4. Max confidence: Follow most confident agent
    
    Metrics compared:
    - State norm stability (lower variation = better)
    - Computation time (lower = better)
    - Fairness (higher entropy = more equal contributions)
    """
    # Initialize mHC for comparison
    mhc = ManifoldConstrainedHyperConnections(n_agents, state_dim)
    
    # Results storage for each strategy
    results = {
        'mhc': {'norms': [], 'times': [], 'fairness': []},
        'simple_avg': {'norms': [], 'times': [], 'fairness': []},
        'weighted_avg': {'norms': [], 'times': [], 'fairness': []},
        'max_conf': {'norms': [], 'times': [], 'fairness': []}
    }
    
    # Run multiple trials for statistical significance
    for trial in tqdm(range(n_trials), desc="Running trials"):
        # Generate random agent states and confidences for this trial
        agent_states = [
            torch.randn(1, state_dim, device=device) * 2.0 - 1.0  # Uniform distribution in [-1, 1]
            for _ in range(n_agents)
        ]
        
        confidences = torch.rand(1, n_agents, device=device)
        confidences = F.softmax(confidences, dim=-1)  # Normalize to probability distribution
        
        # Create attention matrix from confidences
        attention = torch.einsum('bi,bj->bij', confidences, confidences)
        
        # Strategy 1: mHC coordination
        start_time = time.time()
        mhc_state = mhc.convex_state_mixing(agent_states, attention)
        mhc_time = time.time() - start_time
        
        mhc_norm = torch.norm(mhc_state).item()
        # Fairness: use attention entropy from mHC metrics
        mhc_fairness = mhc.metrics['attention_entropy'][-1] if mhc.metrics['attention_entropy'] else 0.0
        
        results['mhc']['norms'].append(mhc_norm)
        results['mhc']['times'].append(mhc_time)
        results['mhc']['fairness'].append(mhc_fairness)
        
        # Strategy 2: Simple averaging (equal weights)
        start_time = time.time()
        simple_avg = torch.stack(agent_states).mean(dim=0)
        simple_time = time.time() - start_time
        
        simple_norm = torch.norm(simple_avg).item()
        simple_fairness = math.log(n_agents)  # Maximum entropy (perfect fairness)
        
        results['simple_avg']['norms'].append(simple_norm)
        results['simple_avg']['times'].append(simple_time)
        results['simple_avg']['fairness'].append(simple_fairness)
        
        # Strategy 3: Weighted averaging by confidence
        start_time = time.time()
        stacked_states = torch.stack(agent_states, dim=1)  # [1, n_agents, dim]
        weights = confidences.unsqueeze(-1)  # [1, n_agents, 1]
        weighted_avg = torch.sum(stacked_states * weights, dim=1)
        weighted_time = time.time() - start_time
        
        weighted_norm = torch.norm(weighted_avg).item()
        # Fairness: entropy of confidence distribution
        weighted_fairness = -torch.sum(confidences[0] * torch.log(confidences[0] + 1e-8)).item()
        
        results['weighted_avg']['norms'].append(weighted_norm)
        results['weighted_avg']['times'].append(weighted_time)
        results['weighted_avg']['fairness'].append(weighted_fairness)
        
        # Strategy 4: Max confidence (follow most confident agent)
        start_time = time.time()
        max_idx = torch.argmax(confidences, dim=-1).item()
        max_conf_state = agent_states[max_idx]
        max_conf_time = time.time() - start_time
        
        max_conf_norm = torch.norm(max_conf_state).item()
        max_conf_fairness = 0.0  # Zero fairness (only one agent contributes)
        
        results['max_conf']['norms'].append(max_conf_norm)
        results['max_conf']['times'].append(max_conf_time)
        results['max_conf']['fairness'].append(max_conf_fairness)
    
    # Create comparison visualization
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    strategies = list(results.keys())
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']  # Colorblind-friendly palette
    
    # Plot 1: Norm distribution comparison (box plot)
    ax1 = axes[0, 0]
    norm_data = [results[s]['norms'] for s in strategies]
    
    bp = ax1.boxplot(norm_data, patch_artist=True)
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)
    
    ax1.set_xticklabels([s.replace('_', ' ').title() for s in strategies])
    ax1.set_ylabel('State Norm')
    ax1.set_title('State Norm Distribution (Lower Variation = More Stable)')
    ax1.grid(True, alpha=0.3, axis='y')
    ax1.axhline(y=mhc.signal_bound, color='red', linestyle='--', 
               label=f'mHC Bound ({mhc.signal_bound})')
    ax1.legend()
    
    # Plot 2: Computation time comparison
    ax2 = axes[0, 1]
    time_means = [np.mean(results[s]['times']) * 1000 for s in strategies]  # Convert to ms
    time_stds = [np.std(results[s]['times']) * 1000 for s in strategies]
    
    x_pos = np.arange(len(strategies))
    bars = ax2.bar(x_pos, time_means, yerr=time_stds, 
                  color=colors, alpha=0.7, capsize=5)
    
    ax2.set_xticks(x_pos)
    ax2.set_xticklabels([s.replace('_', ' ').title() for s in strategies])
    ax2.set_ylabel('Time (ms)')
    ax2.set_title('Computation Time Comparison (Lower = Faster)')
    ax2.grid(True, alpha=0.3, axis='y')
    
    # Add value labels on bars
    for bar, mean in zip(bars, time_means):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height,
                f'{mean:.3f}ms', ha='center', va='bottom')
    
    # Plot 3: Fairness comparison (violin plot)
    ax3 = axes[1, 0]
    fairness_data = [results[s]['fairness'] for s in strategies]
    
    vp = ax3.violinplot(fairness_data, showmeans=True, showmedians=True)
    
    # Customize violin plot colors
    for pc, color in zip(vp['bodies'], colors):
        pc.set_facecolor(color)
        pc.set_alpha(0.7)
    
    vp['cmeans'].set_color('black')  # Mean indicator color
    vp['cmedians'].set_color('red')  # Median indicator color
    
    ax3.set_xticks(range(1, len(strategies) + 1))
    ax3.set_xticklabels([s.replace('_', ' ').title() for s in strategies])
    ax3.set_ylabel('Fairness (Higher = More Equal)')
    ax3.set_title('Agent Contribution Fairness Comparison')
    ax3.grid(True, alpha=0.3, axis='y')
    
    # Plot 4: Norm vs Fairness trade-off scatter plot
    ax4 = axes[1, 1]
    for i, strategy in enumerate(strategies):
        norms = results[strategy]['norms']
        fairness = results[strategy]['fairness']
        ax4.scatter(norms, fairness, color=colors[i], alpha=0.6,
                   label=strategy.replace('_', ' ').title(), s=50)
    
    ax4.set_xlabel('State Norm')
    ax4.set_ylabel('Fairness')
    ax4.set_title('Norm vs Fairness Trade-off Analysis')
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    # Add ideal region markers
    ideal_norm = mhc.signal_bound
    ideal_fairness = math.log(n_agents)
    ax4.axvline(x=ideal_norm, color='green', linestyle='--', alpha=0.5, label='Ideal Norm')
    ax4.axhline(y=ideal_fairness, color='blue', linestyle='--', alpha=0.5, label='Ideal Fairness')
    ax4.scatter([ideal_norm], [ideal_fairness], color='black', s=100, 
               marker='*', label='Ideal Point')
    
    plt.tight_layout()
    
    # Calculate summary statistics for quantitative comparison
    summary_stats = {}
    for strategy in strategies:
        summary_stats[strategy] = {
            'norm_mean': np.mean(results[strategy]['norms']),
            'norm_std': np.std(results[strategy]['norms']),
            'norm_bound_violation': np.mean(
                np.array(results[strategy]['norms']) > mhc.signal_bound * 1.1
            ),
            'time_mean_ms': np.mean(results[strategy]['times']) * 1000,
            'fairness_mean': np.mean(results[strategy]['fairness']),
            'fairness_std': np.std(results[strategy]['fairness'])
        }
    
    return fig, summary_stats

# Cell 6: Security-Specific Experiments
def security_threat_coordination_experiment():
    """
    Simulate real-world security threat coordination scenarios.
    
    Tests mHC with specialized security agents having:
    - Different expertise areas (XSS, SQLi, CSRF, etc.)
    - Varying confidence levels based on threat type
    - Potential conflicting threat assessments
    
    Scenarios include:
    1. Specialized attacks (XSS, SQLi)
    2. False positives
    3. Mixed threats
    4. Conflicting assessments
    """
    # Define security agent types with their expertise
    agent_types = [
        {'name': 'XSS_Detector', 'expertise': 'xss', 'base_confidence': 0.9},
        {'name': 'SQLi_Detector', 'expertise': 'sqli', 'base_confidence': 0.8},
        {'name': 'CSRF_Detector', 'expertise': 'csrf', 'base_confidence': 0.7},
        {'name': 'Behavior_Analyzer', 'expertise': 'behavior', 'base_confidence': 0.6},
        {'name': 'Payload_Scanner', 'expertise': 'malware', 'base_confidence': 0.85}
    ]
    
    n_agents = len(agent_types)
    state_dim = 128  # Reasonable state dimension for security features
    
    # Initialize mHC for coordination
    mhc = ManifoldConstrainedHyperConnections(n_agents, state_dim)
    
    # Define test scenarios with different threat characteristics
    scenarios = [
        {
            'name': 'XSS Attack',
            'threat_type': 'xss',
            'agent_detections': [0.95, 0.3, 0.2, 0.4, 0.1],  # XSS expert highly confident
            'threat_level': 0.9
        },
        {
            'name': 'SQL Injection',
            'threat_type': 'sqli',
            'agent_detections': [0.2, 0.92, 0.1, 0.3, 0.05],  # SQLi expert highly confident
            'threat_level': 0.85
        },
        {
            'name': 'False Positive',
            'threat_type': 'benign',
            'agent_detections': [0.1, 0.15, 0.08, 0.05, 0.12],  # All agents uncertain
            'threat_level': 0.1
        },
        {
            'name': 'Mixed Threat',
            'threat_type': 'mixed',
            'agent_detections': [0.7, 0.8, 0.6, 0.9, 0.75],  # All agents somewhat confident
            'threat_level': 0.75
        },
        {
            'name': 'Conflicting Assessment',
            'threat_type': 'conflict',
            'agent_detections': [0.9, 0.1, 0.85, 0.2, 0.15],  # Strong disagreement between agents
            'threat_level': 0.5
        }
    ]
    
    results = []  # Store results for each scenario
    
    # Run each scenario
    for scenario in scenarios:
        print(f"\n{'='*60}")
        print(f"Scenario: {scenario['name']}")
        print(f"Threat Type: {scenario['threat_type']}")
        print(f"{'='*60}")
        
        # Create agent states based on expertise and detection confidence
        agent_states = []
        agent_outputs = []
        
        for i, agent in enumerate(agent_types):
            # Create base state with expertise encoding
            base_state = torch.zeros(1, state_dim, device=device)
            
            # Encode expertise in specific dimensions (each agent gets 10 dimensions)
            expertise_idx = i * 10
            base_state[0, expertise_idx:expertise_idx+10] = 1.0
            
            # Add noise based on detection confidence (higher confidence = less noise)
            detection_conf = scenario['agent_detections'][i]
            noise = torch.randn_like(base_state) * (1 - detection_conf) * 0.5
            agent_state = base_state + noise
            
            # Scale state by detection confidence
            agent_state = agent_state * detection_conf
            agent_states.append(agent_state)
            
            # Create agent output dictionary
            agent_outputs.append({
                'agent_id': agent['name'],
                'expertise': agent['expertise'],
                'decision': {
                    'threat_level': torch.tensor([[detection_conf]], device=device),
                    'confidence': torch.tensor([[detection_conf * agent['base_confidence']]], 
                                              device=device),
                    'evidence': [
                        f"{agent['name']} detected {scenario['threat_type']} with confidence {detection_conf:.2f}"
                    ]
                },
                'reasoning_state': agent_state
            })
        
        # Create confidence tensor and normalize
        confidences = torch.tensor([scenario['agent_detections']], device=device)
        confidences = F.softmax(confidences, dim=-1)
        
        # Perform mHC coordination
        coordinated_result = mhc.residual_coordination(agent_outputs, confidences)
        
        # Extract and analyze results
        final_threat = coordinated_result['final_decision']['threat_level'].item()
        final_confidence = coordinated_result['final_decision']['confidence'].item()
        agent_contributions = coordinated_result['agent_contributions'][0]
        
        # Calculate expert alignment for specialized threats
        if scenario['threat_type'] in ['xss', 'sqli']:
            expert_idx = 0 if scenario['threat_type'] == 'xss' else 1
            expert_weight = agent_contributions[expert_idx]
            alignment = expert_weight / max(agent_contributions)
        else:
            alignment = 1.0  # Not applicable for non-specialized threats
        
        # Calculate accuracy compared to ground truth
        accuracy = 1.0 - abs(final_threat - scenario['threat_level'])
        
        # Store scenario results
        results.append({
            'scenario': scenario['name'],
            'threat_type': scenario['threat_type'],
            'ground_truth': scenario['threat_level'],
            'final_threat': final_threat,
            'final_confidence': final_confidence,
            'accuracy': accuracy,
            'expert_alignment': alignment,
            'agent_contributions': agent_contributions
        })
        
        # Print detailed results for this scenario
        print(f"Ground Truth Threat Level: {scenario['threat_level']:.2f}")
        print(f"mHC Coordinated Threat: {final_threat:.2f}")
        print(f"mHC Confidence: {final_confidence:.2f}")
        print(f"Decision Accuracy: {accuracy:.2%}")
        
        if scenario['threat_type'] in ['xss', 'sqli']:
            print(f"Expert Alignment: {alignment:.2%}")
        
        print("\nAgent Contributions:")
        for i, agent in enumerate(agent_types):
            print(f"  {agent['name']}: {agent_contributions[i]:.3f}")
    
    # Convert results to DataFrame for easier analysis
    results_df = pd.DataFrame(results)
    
    # Create comprehensive visualization
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Plot 1: Threat level comparison
    ax1 = axes[0, 0]
    x_pos = np.arange(len(results_df))
    width = 0.35  # Bar width
    
    bars1 = ax1.bar(x_pos - width/2, results_df['ground_truth'], 
                   width, label='Ground Truth', alpha=0.7)
    bars2 = ax1.bar(x_pos + width/2, results_df['final_threat'], 
                   width, label='mHC Coordinated', alpha=0.7)
    
    ax1.set_xlabel('Scenario')
    ax1.set_ylabel('Threat Level')
    ax1.set_title('Threat Level: Ground Truth vs mHC Coordination')
    ax1.set_xticks(x_pos)
    ax1.set_xticklabels(results_df['scenario'], rotation=45, ha='right')
    ax1.legend()
    ax1.grid(True, alpha=0.3, axis='y')
    
    # Add error indicators between bars
    for i, (gt, mhc_val) in enumerate(zip(results_df['ground_truth'], results_df['final_threat'])):
        diff = abs(gt - mhc_val)
        ax1.plot([i - width/2, i + width/2], [gt, mhc_val], 
                'k-', alpha=0.5, linewidth=1)  # Connecting line
        ax1.text(i, max(gt, mhc_val) + 0.05, f'{diff:.3f}', 
                ha='center', va='bottom', fontsize=8)  # Difference value
    
    # Plot 2: Accuracy by scenario
    ax2 = axes[0, 1]
    colors = plt.cm.Set1(np.linspace(0, 1, len(results_df)))
    
    bars = ax2.bar(range(len(results_df)), results_df['accuracy'], color=colors)
    ax2.set_xlabel('Scenario')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Decision Accuracy by Scenario')
    ax2.set_xticks(range(len(results_df)))
    ax2.set_xticklabels(results_df['scenario'], rotation=45, ha='right')
    ax2.set_ylim([0, 1.1])
    ax2.axhline(y=1.0, color='green', linestyle='--', alpha=0.5, label='Perfect Accuracy')
    ax2.legend()
    ax2.grid(True, alpha=0.3, axis='y')
    
    # Add accuracy values on bars
    for bar, acc in zip(bars, results_df['accuracy']):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height,
                f'{acc:.2%}', ha='center', va='bottom')
    
    # Plot 3: Agent contribution heatmap
    ax3 = axes[1, 0]
    contrib_matrix = np.array([r['agent_contributions'] for r in results])
    
    im = ax3.imshow(contrib_matrix.T, cmap='YlOrRd', aspect='auto', 
                   interpolation='nearest')
    plt.colorbar(im, ax=ax3, label='Contribution Weight')
    
    ax3.set_xlabel('Scenario')
    ax3.set_ylabel('Agent')
    ax3.set_title('Agent Contribution Patterns Across Scenarios')
    ax3.set_xticks(range(len(results_df)))
    ax3.set_xticklabels(results_df['scenario'], rotation=45, ha='right')
    ax3.set_yticks(range(len(agent_types)))
    ax3.set_yticklabels([a['name'] for a in agent_types])
    
    # Add contribution values to heatmap
    for i in range(contrib_matrix.shape[0]):
        for j in range(contrib_matrix.shape[1]):
            ax3.text(i, j, f'{contrib_matrix[i, j]:.2f}', 
                    ha='center', va='center', 
                    color='black' if contrib_matrix[i, j] > 0.3 else 'white',  # Dynamic text color
                    fontsize=8)
    
    # Plot 4: Expert alignment analysis
    ax4 = axes[1, 1]
    # Filter scenarios with specialized threats
    expert_scenarios = results_df[results_df['threat_type'].isin(['xss', 'sqli'])]
    
    if not expert_scenarios.empty:
        x_pos_exp = np.arange(len(expert_scenarios))
        
        # Extract expert weights and maximum weights
        expert_weights = []
        max_weights = []
        
        for _, row in expert_scenarios.iterrows():
            scenario_idx = results_df[results_df['scenario'] == row['scenario']].index[0]
            contributions = results[scenario_idx]['agent_contributions']
            
            # Identify expert based on threat type
            if row['threat_type'] == 'xss':
                expert_idx = 0  # XSS detector
            else:  # sqli
                expert_idx = 1  # SQLi detector
            
            expert_weights.append(contributions[expert_idx])
            max_weights.append(max(contributions))
        
        # Plot expert weight vs max weight
        ax4.bar(x_pos_exp - 0.2, expert_weights, 0.4, 
               label='Expert Weight', alpha=0.7)
        ax4.bar(x_pos_exp + 0.2, max_weights, 0.4, 
               label='Max Weight', alpha=0.7)
        
        ax4.set_xlabel('Scenario')
        ax4.set_ylabel('Contribution Weight')
        ax4.set_title('Expert vs Max Contribution (Specialized Threats)')
        ax4.set_xticks(x_pos_exp)
        ax4.set_xticklabels(expert_scenarios['scenario'])
        ax4.legend()
        ax4.grid(True, alpha=0.3, axis='y')
        
        # Add ratio indicators
        for i, (exp, max_w) in enumerate(zip(expert_weights, max_weights)):
            ratio = exp / max_w if max_w > 0 else 0
            ax4.text(i, max(exp, max_w) + 0.05, f'{ratio:.2f}', 
                    ha='center', va='bottom', fontsize=9)
    else:
        ax4.text(0.5, 0.5, 'No specialized threat scenarios\nin this experiment',
                ha='center', va='center', transform=ax4.transAxes)
        ax4.set_title('Expert Alignment Analysis')
    
    plt.tight_layout()
    
    # Calculate overall experiment statistics
    overall_stats = {
        'mean_accuracy': results_df['accuracy'].mean(),
        'std_accuracy': results_df['accuracy'].std(),
        'mean_threat_error': (results_df['ground_truth'] - results_df['final_threat']).abs().mean(),
        'scenarios_with_high_accuracy': (results_df['accuracy'] > 0.9).sum(),
        'total_scenarios': len(results_df)
    }
    
    # Print summary statistics
    print(f"\n{'='*60}")
    print("OVERALL EXPERIMENT STATISTICS")
    print(f"{'='*60}")
    print(f"Mean Accuracy: {overall_stats['mean_accuracy']:.2%}")
    print(f"Accuracy Std Dev: {overall_stats['std_accuracy']:.3f}")
    print(f"Mean Threat Level Error: {overall_stats['mean_threat_error']:.3f}")
    print(f"Scenarios with >90% Accuracy: {overall_stats['scenarios_with_high_accuracy']}/{overall_stats['total_scenarios']}")
    
    return fig, results_df, overall_stats

# Cell 7: Parameter Tuning Experiment
def parameter_tuning_experiment():
    """
    Systematic exploration of mHC parameter space.
    
    Tests four key parameters:
    1. Identity preservation factor (λ): Controls individuality vs mixing
    2. Signal bound (β): Maximum allowed state norm
    3. Temperature (τ): Attention distribution sharpness
    4. Sinkhorn iterations: Balance of accuracy vs computation time
    
    Uses one-at-a-time experimental design to understand each parameter's effect.
    """
    # Define parameter ranges to test
    param_ranges = {
        'identity_factor': [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],  # λ values
        'signal_bound': [0.5, 0.8, 1.0, 1.2, 1.5, 2.0],  # β values
        'temperature': [0.1, 0.5, 1.0, 2.0, 5.0, 10.0],  # τ values
        'sinkhorn_iterations': [10, 20, 50, 100, 200]  # Iteration counts
    }
    
    n_agents = 5
    state_dim = 64
    n_trials = 20  # Trials per parameter value for statistical significance
    
    results = []  # Store results for all parameter combinations
    
    # Parameter 1: Identity preservation factor experiment
    print("Running parameter tuning experiments...")
    print("\n1. Testing identity preservation factor...")
    for identity_factor in tqdm(param_ranges['identity_factor']):
        trial_results = []
        for trial in range(n_trials):
            # Create test scenario
            agent_states = [
                torch.randn(1, state_dim, device=device) * 1.5
                for _ in range(n_agents)
            ]
            confidences = torch.rand(1, n_agents, device=device)
            confidences = F.softmax(confidences, dim=-1)
            
            # Create mHC with current parameter
            mhc = ManifoldConstrainedHyperConnections(n_agents, state_dim)
            mhc.identity_preserve_factor = identity_factor
            
            # Test coordination
            attention = torch.einsum('bi,bj->bij', confidences, confidences)
            mixed_state = mhc.convex_state_mixing(agent_states, attention)
            
            # Calculate metrics
            mixed_norm = torch.norm(mixed_state).item()
            
            # Individuality preservation: cosine similarity with original states
            individual_states = torch.stack(agent_states).squeeze(1)
            similarities = F.cosine_similarity(mixed_state, individual_states, dim=-1)
            individuality = similarities.mean().item()
            
            # Stability: whether norm respects bound
            stability = 1.0 if mixed_norm <= mhc.signal_bound else 0.0
            
            trial_results.append({
                'identity_factor': identity_factor,
                'mixed_norm': mixed_norm,
                'individuality': individuality,
                'stability': stability
            })
        
        # Aggregate trial results
        avg_results = {
            'parameter': 'identity_factor',
            'value': identity_factor,
            'avg_norm': np.mean([r['mixed_norm'] for r in trial_results]),
            'avg_individuality': np.mean([r['individuality'] for r in trial_results]),
            'stability_rate': np.mean([r['stability'] for r in trial_results])
        }
        results.append(avg_results)
    
    # Parameter 2: Signal bound experiment
    print("\n2. Testing signal bound...")
    for signal_bound in tqdm(param_ranges['signal_bound']):
        trial_results = []
        for trial in range(n_trials):
            agent_states = [
                torch.randn(1, state_dim, device=device) * 2.0
                for _ in range(n_agents)
            ]
            confidences = torch.rand(1, n_agents, device=device)
            confidences = F.softmax(confidences, dim=-1)
            
            mhc = ManifoldConstrainedHyperConnections(n_agents, state_dim)
            mhc.signal_bound = signal_bound
            
            attention = torch.einsum('bi,bj->bij', confidences, confidences)
            mixed_state = mhc.convex_state_mixing(agent_states, attention)
            
            mixed_norm = torch.norm(mixed_state).item()
            
            # Efficiency: how close to bound (optimal is near but not exceeding bound)
            efficiency = mixed_norm / signal_bound if signal_bound > 0 else 0
            utilization = min(1.0, mixed_norm / signal_bound) if signal_bound > 0 else 0
            
            trial_results.append({
                'signal_bound': signal_bound,
                'mixed_norm': mixed_norm,
                'efficiency': efficiency,
                'utilization': utilization
            })
        
        avg_results = {
            'parameter': 'signal_bound',
            'value': signal_bound,
            'avg_norm': np.mean([r['mixed_norm'] for r in trial_results]),
            'avg_efficiency': np.mean([r['efficiency'] for r in trial_results]),
            'avg_utilization': np.mean([r['utilization'] for r in trial_results])
        }
        results.append(avg_results)
    
    # Parameter 3: Temperature experiment
    print("\n3. Testing temperature...")
    for temperature in tqdm(param_ranges['temperature']):
        trial_results = []
        for trial in range(n_trials):
            agent_states = [
                torch.randn(1, state_dim, device=device)
                for _ in range(n_agents)
            ]
            confidences = torch.rand(1, n_agents, device=device)
            confidences = F.softmax(confidences, dim=-1)
            
            mhc = ManifoldConstrainedHyperConnections(n_agents, state_dim)
            mhc.temperature = temperature
            
            attention = torch.einsum('bi,bj->bij', confidences, confidences)
            mixed_state = mhc.convex_state_mixing(agent_states, attention)
            
            # Attention entropy (uniformity)
            attention_entropy = mhc.metrics['attention_entropy'][-1] if mhc.metrics['attention_entropy'] else 0.0
            # State variance (decision sharpness)
            mixed_var = mixed_state.var().item()
            
            trial_results.append({
                'temperature': temperature,
                'attention_entropy': attention_entropy,
                'mixed_variance': mixed_var
            })
        
        avg_results = {
            'parameter': 'temperature',
            'value': temperature,
            'avg_entropy': np.mean([r['attention_entropy'] for r in trial_results]),
            'avg_variance': np.mean([r['mixed_variance'] for r in trial_results])
        }
        results.append(avg_results)
    
    # Parameter 4: Sinkhorn iterations experiment
    print("\n4. Testing Sinkhorn iterations...")
    for sinkhorn_iter in tqdm(param_ranges['sinkhorn_iterations']):
        trial_results = []
        for trial in range(n_trials):
            agent_states = [
                torch.randn(1, state_dim, device=device)
                for _ in range(n_agents)
            ]
            
            # Create non-doubly-stochastic attention matrix
            attention = torch.rand(1, n_agents, n_agents, device=device)
            attention = attention / attention.sum(dim=-1, keepdim=True)  # Row-stochastic only
            
            mhc = ManifoldConstrainedHyperConnections(n_agents, state_dim)
            mhc.sinkhorn_iterations = sinkhorn_iter
            
            # Measure computation time
            start_time = time.time()
            mixed_state = mhc.convex_state_mixing(agent_states, attention)
            computation_time = time.time() - start_time
            
            # Check convergence error
            convergence_error = 0.0
            if hasattr(mhc, 'last_normalized_attention'):
                norm_att = mhc.last_normalized_attention
                row_sums = norm_att.sum(dim=-1)
                col_sums = norm_att.sum(dim=-2)
                row_error = torch.abs(row_sums - 1.0).mean().item()
                col_error = torch.abs(col_sums - 1.0).mean().item()
                convergence_error = (row_error + col_error) / 2
            
            trial_results.append({
                'sinkhorn_iterations': sinkhorn_iter,
                'computation_time': computation_time,
                'convergence_error': convergence_error
            })
        
        avg_results = {
            'parameter': 'sinkhorn_iterations',
            'value': sinkhorn_iter,
            'avg_time': np.mean([r['computation_time'] for r in trial_results]),
            'avg_error': np.mean([r['convergence_error'] for r in trial_results])
        }
        results.append(avg_results)
    
    # Convert results to DataFrame
    results_df = pd.DataFrame(results)
    
    # Create parameter tuning visualization
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # Plot 1: Identity factor analysis
    ax1 = axes[0, 0]
    idf_data = results_df[results_df['parameter'] == 'identity_factor']
    ax1.plot(idf_data['value'], idf_data['avg_individuality'], 
            'o-', linewidth=2, label='Individuality', markersize=8)
    ax1.plot(idf_data['value'], idf_data['stability_rate'], 
            's-', linewidth=2, label='Stability Rate', markersize=8)
    ax1.set_xlabel('Identity Preservation Factor (λ)')
    ax1.set_ylabel('Metric Value')
    ax1.set_title('Identity Factor: Individuality vs Stability Trade-off')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Find optimal identity factor (balance point)
    combined_score = (np.array(idf_data['avg_individuality']) + 
                     np.array(idf_data['stability_rate'])) / 2
    optimal_idx = np.argmax(combined_score)
    optimal_value = idf_data.iloc[optimal_idx]['value']
    ax1.axvline(x=optimal_value, color='red', linestyle='--', alpha=0.7,
               label=f'Optimal λ = {optimal_value:.2f}')
    ax1.legend()
    
    # Plot 2: Signal bound analysis
    ax2 = axes[0, 1]
    bound_data = results_df[results_df['parameter'] == 'signal_bound']
    ax2_norm = ax2  # Left y-axis for norm
    ax2_eff = ax2.twinx()  # Right y-axis for efficiency
    
    line1 = ax2_norm.plot(bound_data['value'], bound_data['avg_norm'], 
                         'bo-', linewidth=2, label='Average Norm', markersize=8)
    ax2_norm.set_xlabel('Signal Bound (β)')
    ax2_norm.set_ylabel('Average State Norm', color='blue')
    ax2_norm.tick_params(axis='y', labelcolor='blue')
    
    line2 = ax2_eff.plot(bound_data['value'], bound_data['avg_efficiency'], 
                        'rs-', linewidth=2, label='Efficiency', markersize=8)
    ax2_eff.set_ylabel('Efficiency (Norm/Bound)', color='red')
    ax2_eff.tick_params(axis='y', labelcolor='red')
    
    # Add utilization as shaded area
    ax2_norm.fill_between(bound_data['value'], 0, bound_data['avg_utilization'],
                         alpha=0.2, color='green', label='Bound Utilization')
    
    ax2_norm.set_title('Signal Bound: Norm vs Efficiency Trade-off')
    lines = line1 + line2
    labels = [l.get_label() for l in lines]
    ax2_norm.legend(lines, labels, loc='upper left')
    ax2_norm.grid(True, alpha=0.3)
    
    # Plot 3: Temperature analysis
    ax3 = axes[1, 0]
    temp_data = results_df[results_df['parameter'] == 'temperature']
    ax3_ent = ax3  # Left y-axis for entropy
    ax3_var = ax3.twinx()  # Right y-axis for variance
    
    line1 = ax3_ent.plot(temp_data['value'], temp_data['avg_entropy'], 
                        'go-', linewidth=2, label='Attention Entropy', markersize=8)
    ax3_ent.set_xlabel('Temperature (τ)')
    ax3_ent.set_ylabel('Attention Entropy', color='green')
    ax3_ent.tick_params(axis='y', labelcolor='green')
    
    line2 = ax3_var.plot(temp_data['value'], temp_data['avg_variance'], 
                        'md-', linewidth=2, label='State Variance', markersize=8)
    ax3_var.set_ylabel('State Variance', color='magenta')
    ax3_var.tick_params(axis='y', labelcolor='magenta')
    
    ax3_ent.set_title('Temperature: Entropy vs Variance Trade-off')
    
    # Find temperature transition point
    entropy_diff = np.diff(temp_data['avg_entropy'])
    variance_diff = np.diff(temp_data['avg_variance'])
    entropy_change = np.abs(entropy_diff)
    variance_change = np.abs(variance_diff)
    combined_change = entropy_change + variance_change
    if len(combined_change) > 0:
        max_change_idx = np.argmax(combined_change)
        if max_change_idx < len(temp_data['value']) - 1:
            optimal_temp = temp_data.iloc[max_change_idx]['value']
            ax3_ent.axvline(x=optimal_temp, color='red', linestyle='--', alpha=0.7,
                           label=f'Transition τ = {optimal_temp:.1f}')
    
    lines = line1 + line2
    labels = [l.get_label() for l in lines]
    ax3_ent.legend(lines, labels, loc='upper left')
    ax3_ent.grid(True, alpha=0.3)
    
    # Plot 4: Sinkhorn iterations analysis
    ax4 = axes[1, 1]
    sinkhorn_data = results_df[results_df['parameter'] == 'sinkhorn_iterations']
    ax4_time = ax4  # Left y-axis for time
    ax4_error = ax4.twinx()  # Right y-axis for error
    
    line1 = ax4_time.plot(sinkhorn_data['value'], sinkhorn_data['avg_time'] * 1000, 
                         'co-', linewidth=2, label='Computation Time', markersize=8)
    ax4_time.set_xlabel('Sinkhorn Iterations')
    ax4_time.set_ylabel('Time (ms)', color='cyan')
    ax4_time.tick_params(axis='y', labelcolor='cyan')
    ax4_time.set_xscale('log')  # Log scale for iterations
    
    line2 = ax4_error.plot(sinkhorn_data['value'], sinkhorn_data['avg_error'], 
                          'yo-', linewidth=2, label='Convergence Error', markersize=8)
    ax4_error.set_ylabel('Convergence Error', color='orange')
    ax4_error.tick_params(axis='y', labelcolor='orange')
    
    ax4_time.set_title('Sinkhorn Iterations: Time vs Accuracy Trade-off')
    
    # Find optimal iterations (knee point)
    errors = np.array(sinkhorn_data['avg_error'])
    times = np.array(sinkhorn_data['avg_time'] * 1000)
    errors_norm = (errors - errors.min()) / (errors.max() - errors.min() + 1e-8)
    times_norm = (times - times.min()) / (times.max() - times.min() + 1e-8)
    distances = np.sqrt(errors_norm**2 + times_norm**2)  # Distance to ideal (0,0)
    optimal_idx = np.argmin(distances)
    optimal_iter = sinkhorn_data.iloc[optimal_idx]['value']
    ax4_time.axvline(x=optimal_iter, color='red', linestyle='--', alpha=0.7,
                    label=f'Optimal = {optimal_iter} iters')
    
    lines = line1 + line2
    labels = [l.get_label() for l in lines]
    ax4_time.legend(lines, labels, loc='upper right')
    ax4_time.grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    # Provide parameter recommendations
    print(f"\n{'='*60}")
    print("PARAMETER TUNING RECOMMENDATIONS")
    print(f"{'='*60}")
    
    print(f"Identity Preservation Factor (λ):")
    print(f"  Recommended: {optimal_value:.2f}")
    print(f"  Reasoning: Balances individuality ({idf_data.iloc[optimal_idx]['avg_individuality']:.3f}) "
          f"with stability ({idf_data.iloc[optimal_idx]['stability_rate']:.3f})")
    
    # Signal bound recommendation
    efficiency_threshold = 0.8
    viable_bounds = bound_data[bound_data['avg_efficiency'] > efficiency_threshold]
    if not viable_bounds.empty:
        rec_bound = viable_bounds.iloc[0]['value']
        print(f"\nSignal Bound (β):")
        print(f"  Recommended: {rec_bound:.2f}")
        print(f"  Reasoning: Provides {viable_bounds.iloc[0]['avg_efficiency']:.3f} efficiency "
              f"with {viable_bounds.iloc[0]['avg_norm']:.3f} average norm")
    
    # Temperature recommendation
    print(f"\nTemperature (τ):")
    if 'optimal_temp' in locals():
        print(f"  Recommended: {optimal_temp:.1f}")
    else:
        print(f"  Recommended: 1.0")
    print(f"  Reasoning: Balances attention uniformity with decision diversity")
    
    # Sinkhorn iterations recommendation
    print(f"\nSinkhorn Iterations:")
    print(f"  Recommended: {optimal_iter}")
    print(f"  Reasoning: Achieves {sinkhorn_data.iloc[optimal_idx]['avg_error']:.6f} error "
          f"in {sinkhorn_data.iloc[optimal_idx]['avg_time']*1000:.3f} ms")
    
    return fig, results_df

# Cell 8: Adaptive mHC Variant
class AdaptiveMHC(ManifoldConstrainedHyperConnections):
    """
    Adaptive extension of mHC that tunes λ, β and τ online during coordination.

    Every call to ``convex_state_mixing``:
    1. copies the learnable scalars into the plain-float attributes
       (identity_preserve_factor, signal_bound, temperature),
    2. measures the coordination context (agent diversity, attention entropy,
       state magnitudes),
    3. runs the parent class's mixing,
    4. takes one Adam step pulling each scalar toward a rule-based target
       (more diversity -> more identity preservation, larger states -> tighter
       signal bound, higher attention entropy -> higher temperature), and
    5. appends the metrics and adaptation results to a bounded history.

    This lets mHC follow changing threat environments without manual tuning.
    """

    def __init__(self, n_agents: int, state_dim: int, learning_rate: float = 0.01):
        """
        Initialize adaptive mHC with learnable parameters.

        Parameters:
        n_agents: Number of coordinating agents
        state_dim: State vector dimensionality
        learning_rate: Adam learning rate for the internal parameter adaptation
        """
        super().__init__(n_agents, state_dim)

        # Learnable scalars. They are registered on the module, so they also
        # appear in self.parameters() and state_dict().
        self.identity_factor = nn.Parameter(torch.tensor(0.1))  # λ
        self.signal_bound_param = nn.Parameter(torch.tensor(1.0))  # β
        self.temperature_param = nn.Parameter(torch.tensor(1.0))  # τ

        # Private optimizer driven only by adapt_parameters().
        self.learning_rate = learning_rate
        self.optimizer = torch.optim.Adam(
            [self.identity_factor, self.signal_bound_param, self.temperature_param],
            lr=learning_rate,
        )

        # Bounded record of per-call metrics/adaptation results for analysis.
        self.coordination_history = []
        self.max_history = 1000  # Maximum number of retained history entries

    def _sync_base_parameters(self):
        """Copy the learnable scalars into the float attributes the base class is expected to read."""
        self.identity_preserve_factor = self.identity_factor.item()
        self.signal_bound = self.signal_bound_param.item()
        self.temperature = self.temperature_param.item()

    def compute_adaptation_metrics(self, agent_states: List[torch.Tensor],
                                   attention: torch.Tensor) -> Dict:
        """
        Summarise the current coordination context as plain Python floats.

        Parameters:
        agent_states: one state tensor per agent, stacked along dim=1
            (each presumably [B, D] -- TODO confirm against the parent class)
        attention: agent-to-agent weights, e.g. the batched [B, N, N] matrix
            built by MHCGQAIntegration.forward

        Returns a dict with:
        state_variance: variance across agents (dim=1), averaged over the rest
        attention_entropy: mean of -p*log(p) over all attention entries
        avg_magnitude / max_magnitude: mean / max L2 norm of the agent states
        """
        # Metrics are reported as floats only, so no autograd graph is needed.
        with torch.no_grad():
            stacked_states = torch.stack(agent_states, dim=1)

            # 1. Agent diversity: variance across the agent axis
            state_variance = stacked_states.var(dim=1).mean().item()

            # 2. Agent agreement: average per-entry entropy term. Dividing by the
            #    total number of entries (not just N*N) keeps the value independent
            #    of the batch size when a [B, N, N] attention tensor is passed.
            attention_flat = attention.flatten()
            n_entries = max(attention_flat.numel(), 1)
            attention_entropy = -torch.sum(
                attention_flat * torch.log(attention_flat + 1e-8)
            ).item() / n_entries

            # 3. Threat severity proxy: L2 norm of each agent state
            state_magnitudes = torch.norm(stacked_states, dim=-1)
            avg_magnitude = state_magnitudes.mean().item()
            max_magnitude = state_magnitudes.max().item()

        return {
            'state_variance': state_variance,
            'attention_entropy': attention_entropy,
            'avg_magnitude': avg_magnitude,
            'max_magnitude': max_magnitude
        }

    def adapt_parameters(self, metrics: Dict, mixed_state: torch.Tensor):
        """
        Take one optimisation step pulling λ, β and τ toward rule-based targets.

        Targets (each clipped to a fixed range):
        1. λ = 2 * state_variance, in [0.05, 0.3]
        2. β = 1 / (avg_magnitude + 0.5), in [0.5, 2.0]
        3. τ = 10 * attention_entropy, in [0.5, 5.0]
        After the step the parameters are clamped to their valid ranges and
        copied into the base-class float attributes.

        Parameters:
        metrics: dict produced by compute_adaptation_metrics
        mixed_state: result of the mixing step (currently unused; kept for
            interface compatibility)

        Returns a dict of floats: the losses and the updated parameter values.
        """
        target_identity = min(0.3, max(0.05, metrics['state_variance'] * 2))
        target_bound = min(2.0, max(0.5, 1.0 / (metrics['avg_magnitude'] + 0.5)))
        target_temp = min(5.0, max(0.5, metrics['attention_entropy'] * 10))

        # The adaptation is a self-contained optimisation over three leaf
        # parameters. Force autograd on so that a forward pass executed under
        # torch.no_grad() (e.g. evaluation) does not fail in backward().
        with torch.enable_grad():
            identity_loss = F.mse_loss(
                self.identity_factor,
                torch.tensor(target_identity, device=self.identity_factor.device,
                             dtype=self.identity_factor.dtype))
            bound_loss = F.mse_loss(
                self.signal_bound_param,
                torch.tensor(target_bound, device=self.signal_bound_param.device,
                             dtype=self.signal_bound_param.dtype))
            temp_loss = F.mse_loss(
                self.temperature_param,
                torch.tensor(target_temp, device=self.temperature_param.device,
                             dtype=self.temperature_param.dtype))

            total_loss = identity_loss + bound_loss + temp_loss
            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()

        # Drop the adaptation gradients so they cannot leak into an outer
        # optimizer that also owns these parameters (e.g. a training loop that
        # zeroes its gradients before calling forward).
        self.optimizer.zero_grad()

        # Clamp parameters to valid ranges
        self.identity_factor.data.clamp_(0.0, 0.5)
        self.signal_bound_param.data.clamp_(0.1, 3.0)
        self.temperature_param.data.clamp_(0.1, 10.0)

        self._sync_base_parameters()

        return {
            'total_loss': total_loss.item(),
            'identity_factor': self.identity_factor.item(),
            'signal_bound': self.signal_bound_param.item(),
            'temperature': self.temperature_param.item(),
            'identity_loss': identity_loss.item(),
            'bound_loss': bound_loss.item(),
            'temp_loss': temp_loss.item()
        }

    def convex_state_mixing(self, agent_states: List[torch.Tensor],
                            attention_weights: torch.Tensor) -> torch.Tensor:
        """
        Override of the parent mixing that adds online parameter adaptation.

        Flow:
        1. Sync the base-class float attributes from the learnable values
        2. Compute adaptation metrics on the inputs
        3. Perform the parent mHC mixing
        4. Adapt parameters (takes effect on the next call)
        5. Record metrics/adaptation in the bounded history

        Returns the parent's mixed state unchanged.
        """
        self._sync_base_parameters()

        metrics = self.compute_adaptation_metrics(agent_states, attention_weights)
        mixed_state = super().convex_state_mixing(agent_states, attention_weights)
        adaptation_results = self.adapt_parameters(metrics, mixed_state)

        self.coordination_history.append({
            'metrics': metrics,
            'adaptation': adaptation_results,
            'mixed_state_norm': torch.norm(mixed_state).item()
        })

        # Drop the oldest entries in place instead of rebuilding the list.
        if len(self.coordination_history) > self.max_history:
            del self.coordination_history[:-self.max_history]

        return mixed_state

# Cell 9: GQA Integration
class SimpleGQA(nn.Module):
    """
    Simplified Grouped Query Attention implementation for testing mHC integration.

    Queries are projected to ``n_heads`` heads, while keys and values are
    projected to only ``n_groups`` heads. Each key/value head is then shared by
    ``n_heads // n_groups`` consecutive query heads (via ``repeat_interleave``).
    This shrinks the K/V projections while keeping full query capacity.
    No attention mask is applied.
    """

    def __init__(self, d_model: int, n_heads: int, n_groups: Optional[int] = None):
        """
        Parameters:
        d_model: Model width; must be divisible by n_heads
        n_heads: Number of query heads
        n_groups: Number of key/value heads; must divide n_heads.
            Defaults to n_heads // 2 (at least 1).

        Raises:
        ValueError: if the head/group configuration is inconsistent.
        """
        super().__init__()
        self.d_model = d_model  # Model dimension
        self.n_heads = n_heads  # Number of query heads
        # max(1, ...) keeps the default valid for a single-head model, where
        # n_heads // 2 would be 0 and the divisibility check would divide by zero.
        self.n_groups = n_groups if n_groups is not None else max(1, n_heads // 2)

        # Explicit checks (not `assert`) so validation survives `python -O`.
        if n_heads <= 0 or self.n_groups <= 0:
            raise ValueError("n_heads and n_groups must be positive")
        if n_heads % self.n_groups != 0:
            raise ValueError("n_heads must be divisible by n_groups")
        if d_model % n_heads != 0:
            # Otherwise n_heads * head_dim != d_model, and forward() either fails in
            # .view() or silently regroups tokens into a wrong sequence length.
            raise ValueError("d_model must be divisible by n_heads")

        self.head_dim = d_model // n_heads  # Dimension per head
        self.scale = self.head_dim ** -0.5  # 1/sqrt(head_dim) attention scaling

        # Keys/values only need one head per group.
        kv_dim = self.n_groups * self.head_dim
        self.q_proj = nn.Linear(d_model, d_model)  # Full width for queries
        self.k_proj = nn.Linear(d_model, kv_dim)   # Reduced width for keys
        self.v_proj = nn.Linear(d_model, kv_dim)   # Reduced width for values
        self.out_proj = nn.Linear(d_model, d_model)  # Output projection

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        """
        Grouped-query scaled dot-product attention.

        Parameters:
        query: [B, S_q, d_model]
        key, value: [B, S_kv, d_model] (same batch size as query)

        Returns:
        [B, S_q, d_model] attention output after the output projection.
        """
        batch_size = query.shape[0]

        # Project inputs
        Q = self.q_proj(query)  # [B, S_q, d_model]
        K = self.k_proj(key)    # [B, S_kv, n_groups * head_dim]
        V = self.v_proj(value)  # [B, S_kv, n_groups * head_dim]

        # Split into heads: [B, heads, S, head_dim]
        Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, -1, self.n_groups, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, -1, self.n_groups, self.head_dim).transpose(1, 2)

        # Share each K/V head across its group of query heads
        expand_ratio = self.n_heads // self.n_groups
        K = K.repeat_interleave(expand_ratio, dim=1)
        V = V.repeat_interleave(expand_ratio, dim=1)

        # Scaled dot-product attention
        attn_weights = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        # Merge heads back: [B, S_q, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return self.out_proj(attn_output)

class MHCGQAIntegration(nn.Module):
    """
    Integrated architecture combining mHC with Grouped Query Attention.

    Pipeline per forward pass:
    1. A shared MLP encoder (``agent_encoder``) embeds each agent's input.
    2. ``SimpleGQA`` self-attention refines each agent embedding independently.
    3. ``AdaptiveMHC`` mixes the agent states (or a plain mean when
       ``use_mhc`` is False).
    4. A decoder MLP and three linear heads produce threat logits, a severity
       score and a confidence score.
    """

    def __init__(self, n_agents: int, d_model: int, n_heads: int,
                 n_groups: Optional[int] = None, use_mhc: bool = True):
        """
        Parameters:
        n_agents: Number of coordinating agents
        d_model: Feature dimension of every agent input and hidden state
        n_heads: Attention heads in the GQA block
        n_groups: Key/value head groups for GQA (None lets SimpleGQA choose)
        use_mhc: Coordinate with AdaptiveMHC when True, otherwise average agents
        """
        super().__init__()

        self.n_agents = n_agents
        self.d_model = d_model
        self.use_mhc = use_mhc

        # GQA for intra-agent attention
        self.gqa_attention = SimpleGQA(d_model, n_heads, n_groups)

        # Adaptive mHC for inter-agent coordination
        if use_mhc:
            self.mhc_coordination = AdaptiveMHC(n_agents, d_model)

        # Agent feature encoder (shared across agents)
        self.agent_encoder = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model)
        )

        # Coordination state decoder
        self.coordination_decoder = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(d_model * 2, d_model),
            nn.LayerNorm(d_model)
        )

        # Output heads for security tasks
        self.threat_classifier = nn.Linear(d_model, 10)  # 10 threat types
        self.severity_regressor = nn.Linear(d_model, 1)  # Severity score
        self.confidence_estimator = nn.Linear(d_model, 1)  # Decision confidence

    def forward(self, agent_inputs: List[torch.Tensor],
                agent_confidences: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Complete processing pipeline for security threat coordination.

        Parameters:
        agent_inputs: One tensor per agent. Each goes through nn.Linear(d_model, ...)
            and is then unsqueezed into a length-1 sequence, so presumably each
            is [B, d_model].
        agent_confidences: [B, N] per-agent confidences. Their outer product
            ('bi,bj->bij') forms the [B, N, N] attention passed to mHC.

        Returns a dict with:
        threat_logits: classifier logits (last dim 10)
        severity, confidence: sigmoid outputs with trailing dim 1
        coordinated_state: mixed agent representation before decoding
        agent_states: per-agent GQA outputs
        mhc_metrics: latest adaptation values (empty when mHC is disabled)
        """
        # Step 1: Encode agent inputs with the shared encoder
        encoded_agents = [self.agent_encoder(agent_input) for agent_input in agent_inputs]

        # Step 2: GQA self-attention per agent. With a length-1 sequence, each
        # agent can only attend to itself.
        attended_agents = []
        for encoded in encoded_agents:
            tokens = encoded.unsqueeze(1)  # Add sequence dimension
            attended = self.gqa_attention(tokens, tokens, tokens).squeeze(1)
            attended_agents.append(attended)

        # Step 3: Inter-agent coordination
        if self.use_mhc:
            attention = torch.einsum('bi,bj->bij', agent_confidences, agent_confidences)
            coordinated = self.mhc_coordination.convex_state_mixing(
                attended_agents, attention
            )
        else:
            # Fallback: simple averaging
            coordinated = torch.stack(attended_agents, dim=1).mean(dim=1)

        # Step 4: Decode coordinated state
        decoded = self.coordination_decoder(coordinated)

        # Step 5: Security assessments
        threat_logits = self.threat_classifier(decoded)
        severity = torch.sigmoid(self.severity_regressor(decoded))  # [0, 1]
        confidence = torch.sigmoid(self.confidence_estimator(decoded))  # [0, 1]

        # Latest mHC adaptation values, for monitoring only (plain floats)
        mhc_metrics = {}
        if self.use_mhc and hasattr(self.mhc_coordination, 'coordination_history'):
            if self.mhc_coordination.coordination_history:
                latest = self.mhc_coordination.coordination_history[-1]
                mhc_metrics = {
                    'adaptation_loss': latest['adaptation']['total_loss'],
                    'identity_factor': latest['adaptation']['identity_factor'],
                    'signal_bound': latest['adaptation']['signal_bound'],
                    'temperature': latest['adaptation']['temperature']
                }

        return {
            'threat_logits': threat_logits,
            'severity': severity,
            'confidence': confidence,
            'coordinated_state': coordinated,
            'agent_states': attended_agents,
            'mhc_metrics': mhc_metrics
        }

    def train_step(self, agent_inputs: List[torch.Tensor],
                   agent_confidences: torch.Tensor,
                   targets: Dict[str, torch.Tensor],
                   optimizer: torch.optim.Optimizer) -> Dict[str, float]:
        """
        Run one optimisation step with a weighted multi-task loss.

        Loss = threat CE + severity MSE + 0.5 * confidence MSE + 0.1 * mHC reg, where:
        1. Threat classification: cross-entropy on targets['threat_labels']
        2. Severity regression: MSE against targets['severity_labels']
        3. Confidence calibration: MSE between confidence and per-sample correctness
        4. mHC regularization: mean |deviation| of the learnable λ, β, τ from
           (0.2, 1.0, 1.0)

        Gradients are clipped to a total norm of 1.0 before optimizer.step().

        Returns a dict of float training metrics.
        """
        outputs = self(agent_inputs, agent_confidences)
        threat_labels = targets['threat_labels'].long()
        severity_labels = targets['severity_labels']

        # Loss 1: Threat classification
        threat_loss = F.cross_entropy(outputs['threat_logits'], threat_labels)

        # Loss 2: Severity regression. Match the label shape so a [B, 1]
        # prediction is never broadcast against [B] labels into [B, B].
        severity_pred = outputs['severity'].reshape(severity_labels.shape)
        severity_loss = F.mse_loss(severity_pred, severity_labels)

        # Loss 3: Confidence should match whether the prediction was correct.
        # squeeze(-1) keeps the batch dim even when the batch size is 1.
        predicted_classes = torch.argmax(outputs['threat_logits'], dim=-1)
        correct = (predicted_classes == threat_labels).float()
        confidence_pred = outputs['confidence'].squeeze(-1)
        confidence_loss = F.mse_loss(confidence_pred, correct)

        # Loss 4: mHC regularization on the learnable parameters themselves.
        # The floats in outputs['mhc_metrics'] carry no gradient, and torch.abs
        # rejects Python floats.
        mhc_reg = 0.0
        if self.use_mhc:
            mhc = self.mhc_coordination
            identity_reg = torch.abs(mhc.identity_factor - 0.2)
            bound_reg = torch.abs(mhc.signal_bound_param - 1.0)
            temp_reg = torch.abs(mhc.temperature_param - 1.0)
            mhc_reg = (identity_reg + bound_reg + temp_reg) / 3

        total_loss = (threat_loss + severity_loss +
                      confidence_loss * 0.5 + mhc_reg * 0.1)

        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
        optimizer.step()

        return {
            'total_loss': total_loss.item(),
            'threat_loss': threat_loss.item(),
            'severity_loss': severity_loss.item(),
            'confidence_loss': confidence_loss.item(),
            'mhc_reg': mhc_reg.item() if isinstance(mhc_reg, torch.Tensor) else mhc_reg,
            'threat_accuracy': correct.mean().item(),
            'severity_mae': torch.abs(severity_pred - severity_labels).mean().item(),
            'confidence_calibration': torch.abs(confidence_pred - correct).mean().item()
        }