In [None]:
# notebooks/agent_training.ipynb
<jupyter_output>
<empty_output>
<jupyter_text>
# 🚀 CyberGuard Agent Training Notebook

## Complete Training Pipeline for Cybersecurity AI Agents

**Author**: CyberGuard AI Research Team  
**Version**: 1.0.0  
**Last Updated**: 2024  
**Purpose**: Train and optimize a multi-agent cybersecurity AI system  
**Target Model**: GQA Transformer with mHC coordination  
**Dataset**: OWASP Web Security + Custom Threat Intelligence

---

## 📋 Table of Contents

1. [System Overview](#1-system-overview)
2. [Environment Setup](#2-environment-setup)
3. [Data Loading & Preprocessing](#3-data-loading--preprocessing)
4. [mHC Architecture Implementation](#4-mhc-architecture-implementation)
5. [GQA Transformer Implementation](#5-gqa-transformer-implementation)
6. [Multi-Agent Training Loop](#6-multi-agent-training-loop)
7. [mHC Coordination Training](#7-mhc-coordination-training)
8. [Adversarial Training](#8-adversarial-training)
9. [Evaluation & Metrics](#9-evaluation--metrics)
10. [Model Export & Deployment](#10-model-export--deployment)
11. [Hyperparameter Optimization](#11-hyperparameter-optimization)
12. [Visualization & Analysis](#12-visualization--analysis)

---

## 1. System Overview

**CyberGuard** is a multi-agent AI system for web security analysis that uses:

1. **Manifold-Constrained Hyper-Connections (mHC)**: For stable multi-agent coordination
2. **Grouped Query Attention (GQA)**: Memory-efficient transformer architecture
3. **Flash Attention + RoPE**: Optimized attention with positional encoding
4. **10 Specialized Agents**: Each focused on different security aspects

### Why This Architecture?

| Component | Purpose | Why It's Important |
|-----------|---------|-------------------|
| **mHC** | Prevents reasoning collapse | Ensures stable agent coordination |
| **GQA** | Reduces memory usage | Enables handling long web sessions |
| **Flash Attention** | Speeds up training | Makes real-time analysis possible |
| **RoPE** | Better position encoding | Improves sequence understanding |
| **Multi-Agent** | Specialized threat detection | Catches diverse attack patterns |

```python
# Theoretical benefits:
mhc_benefits = [
    "Prevents signal explosion in multi-agent systems",
    "Eliminates dominant agent bias through doubly-stochastic normalization",
    "Maintains identity-preserving mappings",
    "Ensures bounded signal propagation",
    "Enables convex state mixing",
]

gqa_benefits = [
    "Reduces KV cache memory by 75% vs MHA",
    "Maintains 95%+ of MHA performance",
    "Enables longer sequence processing",
    "Faster inference for real-time security",
]
```
<jupyter_code>
# Import all required libraries
# Explanation: We're importing the core libraries needed for training
import sys
import os
import json
import yaml
import pickle
import warnings
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any, Callable
from dataclasses import dataclass, field

# Suppress warnings for cleaner output.
# NOTE(review): this silences *every* warning, deprecations included;
# comment it out while debugging.
warnings.filterwarnings('ignore')

# Add the project root (the parent of the notebook's working directory) to the
# import path so that `src.*` modules can be imported below. Any existing entry
# is removed first, so re-running this cell does not pile up duplicates, while
# the root still ends up with top priority at sys.path[0].
project_root = Path('..').resolve()
sys.path[:] = [entry for entry in sys.path if entry != str(project_root)]
sys.path.insert(0, str(project_root))

# Core data science libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder, StandardScaler

# PyTorch for deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.cuda.amp import GradScaler, autocast  # Mixed precision training

# Check if CUDA is available and set device
# CUDA enables GPU acceleration for faster training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üöÄ Using device: {device}")
print(f"üíª CUDA available: {torch.cuda.is_available()}")
print(f"üéØ GPU count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"üìä GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    print(f"üéÆ GPU Name: {torch.cuda.get_device_name(0)}")

# Set random seeds for reproducibility
# This ensures that our experiments are reproducible
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Import custom modules from CyberGuard
try:
    from src.core.mhc_architecture import ManifoldConstrainedHyperConnections
    from src.core.gqa_transformer import SecurityGQATransformer, FlashGQA, RotaryPositionalEmbedding
    from src.agents.base_agent import SecurityAgent
    from src.agents.agent_orchestrator import AgentOrchestrator
    print("‚úÖ Successfully imported CyberGuard modules")
except ImportError as e:
    print(f"‚ö†Ô∏è Could not import CyberGuard modules: {e}")
    print("üìù Creating mock implementations for training demonstration")
    
    # Create mock implementations for demonstration purposes
    class ManifoldConstrainedHyperConnections:
        """Mock implementation for demonstration"""
        pass
    
    class SecurityGQATransformer:
        """Mock implementation for demonstration"""
        pass

# Visualization settings
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
%matplotlib inline

# Set display options for better readability
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 100)
pd.set_option('display.width', 1000)
pd.set_option('display.float_format', '{:.4f}'.format)

print("\n" + "="*80)
print("üéØ CYBERGUARD AGENT TRAINING ENVIRONMENT READY")
print("="*80)
<jupyter_output>
<empty_output>
<jupyter_text>
## 2. Environment Setup

**Explanation**: In this section, we set up the complete training environment, including configuration, directories, and logging. Proper environment setup is crucial for reproducible experiments.

### Key Components:

1. **Configuration Management**: Load and validate training parameters
2. **Directory Structure**: Create organized folders for artifacts
3. **Logging Setup**: Track experiments and results
4. **Hardware Optimization**: Configure CUDA and mixed precision
<jupyter_code>
@dataclass
class TrainingConfig:
    """
    Dataclass for storing training configuration parameters.
    Dataclasses automatically generate __init__, __repr__, and other special methods.

    Constructing an instance validates the parameters and creates a timestamped
    experiment directory ``<experiments_root>/<experiment_name>_<YYYYmmdd_HHMMSS>``
    with ``checkpoints``, ``logs``, ``models`` and ``visualizations`` subfolders.
    That path is stored in the non-field attribute ``experiment_dir``. Because
    it is not a field, ``load()`` recomputes it instead of restoring it.

    Parameters:
    -----------
    experiment_name : str
        Name of the experiment for tracking (prefix of the directory name)
    random_seed : int
        Random seed recorded with the experiment
    d_model, n_heads, n_layers : int
        Model width, attention heads and transformer layers
        (d_model must be divisible by n_heads)
    d_ff : int or None
        Feed-forward dimension; None means 4 * d_model
    max_seq_len : int
        Maximum sequence length
    vocab_size : int
        Vocabulary size for tokenization
    use_gqa : bool
        Whether to use Grouped Query Attention
    gqa_groups : int
        Number of groups for GQA; must be positive and divide n_heads
    use_flash_attention : bool
        Whether to use the Flash Attention optimization
    batch_size : int
        Number of samples per training batch
    learning_rate : float
        Initial learning rate for optimizer
    num_epochs : int
        Total number of training epochs
    warmup_steps : int
        Steps for linear learning rate warmup
    weight_decay : float
        L2 regularization strength
    gradient_clip : float
        Maximum gradient norm for clipping
    dropout_rate : float
        Dropout probability for regularization
    label_smoothing : float
        Label smoothing factor for classification
    mhc_temperature, mhc_sinkhorn_iterations, mhc_signal_bound, mhc_identity_preserve
        mHC coordination parameters
    early_stopping_patience : int
        Epochs to wait before early stopping
    checkpoint_frequency : int
        Save checkpoint every N epochs
    eval_frequency, log_frequency : int
        Steps between evaluations / between log lines
    mixed_precision : bool
        Whether to use mixed precision training
    gradient_checkpointing : bool
        Trade compute for memory
    train_split, val_split, test_split : float
        Data split fractions, each in (0, 1), summing to 1
    experiments_root : str
        Parent directory for experiment folders
    """

    # Experiment settings
    experiment_name: str = "cyberguard_v1"
    random_seed: int = 42

    # Model architecture
    d_model: int = 512              # Model dimension (embedding size)
    n_heads: int = 8                # Number of attention heads
    n_layers: int = 6               # Number of transformer layers
    d_ff: Optional[int] = 2048      # Feed-forward dimension (None -> 4 * d_model)
    max_seq_len: int = 2048         # Maximum sequence length
    vocab_size: int = 30000         # Vocabulary size for tokenization

    # GQA settings
    use_gqa: bool = True            # Use Grouped Query Attention
    gqa_groups: int = 2             # Groups for GQA (n_heads // 4)
    use_flash_attention: bool = True  # Use Flash Attention optimization

    # Training hyperparameters
    batch_size: int = 32
    learning_rate: float = 3e-4
    num_epochs: int = 100
    warmup_steps: int = 2000
    weight_decay: float = 0.01
    gradient_clip: float = 1.0
    dropout_rate: float = 0.1
    label_smoothing: float = 0.1

    # mHC parameters
    mhc_temperature: float = 1.0
    mhc_sinkhorn_iterations: int = 50
    mhc_signal_bound: float = 1.0
    mhc_identity_preserve: float = 0.1

    # Training control
    early_stopping_patience: int = 10
    checkpoint_frequency: int = 5
    eval_frequency: int = 100       # Steps between evaluations
    log_frequency: int = 50         # Steps between logging

    # Optimization
    mixed_precision: bool = True    # Use mixed precision (FP16)
    gradient_checkpointing: bool = False  # Trade compute for memory

    # Data
    train_split: float = 0.7
    val_split: float = 0.15
    test_split: float = 0.15

    # Output location (new fields go last to keep positional construction stable)
    experiments_root: str = "../experiments"

    def __post_init__(self):
        """
        Validate the configuration and create the experiment directory tree.

        Raises:
            ValueError: if a split fraction is outside (0, 1), the splits do not
                sum to 1, or the head/group counts are not positive divisors.
        """
        # Explicit raises instead of `assert`, which is stripped under `python -O`
        if not 0 < self.train_split < 1:
            raise ValueError("Train split must be between 0 and 1")
        if not 0 < self.val_split < 1:
            raise ValueError("Val split must be between 0 and 1")
        if not 0 < self.test_split < 1:
            raise ValueError("Test split must be between 0 and 1")
        split_total = self.train_split + self.val_split + self.test_split
        if abs(split_total - 1.0) > 1e-6:
            raise ValueError(f"Train/val/test splits must sum to 1, got {split_total}")
        # Positivity is checked first so the modulo cannot divide by zero
        if self.n_heads <= 0 or self.d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        if self.use_gqa and (self.gqa_groups <= 0 or self.n_heads % self.gqa_groups != 0):
            raise ValueError("n_heads must be divisible by gqa_groups")

        # Auto-calculate d_ff if not set
        if self.d_ff is None:
            self.d_ff = 4 * self.d_model

        # Create experiment directory and its subdirectories
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.experiment_dir = Path(self.experiments_root) / f"{self.experiment_name}_{timestamp}"
        for subdir in ("checkpoints", "logs", "models", "visualizations"):
            (self.experiment_dir / subdir).mkdir(parents=True, exist_ok=True)

    def save(self, filepath: str):
        """Save all attributes (fields plus ``experiment_dir``) to a JSON file."""
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(self.__dict__, f, indent=2, default=str)

    @classmethod
    def load(cls, filepath: str):
        """
        Load a configuration saved by :meth:`save`.

        Only keys that are init fields of this dataclass are passed to the
        constructor. Derived attributes such as ``experiment_dir`` are ignored,
        because passing them to ``__init__`` would raise ``TypeError``. A new
        experiment directory is created for the loaded configuration.
        """
        from dataclasses import fields

        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        init_names = {fld.name for fld in fields(cls) if fld.init}
        return cls(**{key: value for key, value in data.items() if key in init_names})

# Build the default configuration. Construction validates the settings and
# creates the timestamped experiment directory tree.
config = TrainingConfig()

# Echo every attribute (including experiment_dir) as pretty-printed JSON.
_config_dump = json.dumps(config.__dict__, indent=2, default=str)
print("üìã Training Configuration:")
print(_config_dump)

# Persist the configuration next to the experiment's other artifacts.
_config_file = config.experiment_dir / "config.json"
config.save(_config_file)
print(f"\nüíæ Configuration saved to: {config.experiment_dir / 'config.json'}")

# Setup logging
import logging
from logging.handlers import RotatingFileHandler

def setup_logging(log_dir: Path, log_level: str = "INFO"):
    """
    Configure file, console and metrics logging for a training run.

    Safe to call repeatedly (e.g. when a notebook cell is re-run). Handlers
    installed by an earlier call are detached and closed first, so log lines
    are not emitted twice.

    Args:
        log_dir: Directory for ``training.log`` (rotating, 10 MB x 5 backups)
            and ``metrics.log``; created, with parents, if missing.
        log_level: Level for the training.log file handler (DEBUG, INFO,
            WARNING, ERROR), case-insensitive. The console always shows INFO+.

    Returns:
        Tuple ``(root_logger, metrics_logger)``.

    Raises:
        AttributeError: if ``log_level`` is not a ``logging`` attribute name.
    """
    log_dir = Path(log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)

    # Marker set on every handler created here, so a later call can find and
    # replace exactly those handlers without touching foreign ones.
    handler_tag = '_cyberguard_handler'

    root_logger = logging.getLogger()
    metrics_logger = logging.getLogger('metrics')
    for existing_logger in (root_logger, metrics_logger):
        for handler in list(existing_logger.handlers):
            if getattr(handler, handler_tag, False):
                existing_logger.removeHandler(handler)
                handler.close()

    # Create formatters
    detailed_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s'
    )
    simple_formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s'
    )

    # File handler (rotating logs); UTF-8 so emoji in messages are written safely
    file_handler = RotatingFileHandler(
        log_dir / 'training.log',
        maxBytes=10*1024*1024,  # 10MB
        backupCount=5,
        encoding='utf-8',
    )
    file_handler.setLevel(getattr(logging, log_level.upper()))
    file_handler.setFormatter(detailed_formatter)

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(simple_formatter)

    # Special handler for metrics (timestamp,message rows)
    metrics_handler = logging.FileHandler(log_dir / 'metrics.log', encoding='utf-8')
    metrics_handler.setFormatter(logging.Formatter('%(asctime)s,%(message)s'))

    for handler in (file_handler, console_handler, metrics_handler):
        setattr(handler, handler_tag, True)

    # Root logger passes everything through; handlers apply their own levels
    root_logger.setLevel(logging.DEBUG)
    root_logger.addHandler(file_handler)
    root_logger.addHandler(console_handler)

    metrics_logger.addHandler(metrics_handler)
    metrics_logger.setLevel(logging.INFO)
    # Keep metric rows out of the root logger's console/file output
    metrics_logger.propagate = False

    return root_logger, metrics_logger

# Wire up logging for this run: the root logger (file + console) plus the
# separate 'metrics' logger, then record the key facts about the run.
_experiment_dir = config.experiment_dir
logger, metrics_logger = setup_logging(_experiment_dir / "logs")
for _startup_message in (
    "üéØ Starting CyberGuard Agent Training",
    f"üìä Configuration: {config.experiment_name}",
    f"üíª Device: {device}",
    f"üéÆ CUDA Available: {torch.cuda.is_available()}",
):
    logger.info(_startup_message)

# TensorBoard is optional: if torch's tensorboard integration cannot be
# imported, training continues with writer = None.
try:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter(log_dir=str(_experiment_dir / "tensorboard"))
except ImportError:
    writer = None
    logger.warning("‚ö†Ô∏è TensorBoard not available, skipping visualization")
else:
    logger.info("üìà TensorBoard initialized")

print("\n‚úÖ Environment setup complete!")
print(f"üìÅ Experiment directory: {config.experiment_dir}")
print(f"üìù Logs: {config.experiment_dir / 'logs'}")
print(f"üíæ Checkpoints: {config.experiment_dir / 'checkpoints'}")
<jupyter_output>
<empty_output>
<jupyter_text>
## 3. Data Loading & Preprocessing

**Explanation**: This section handles loading, preprocessing, and preparing cybersecurity datasets for training. We work with multiple data sources, including OWASP attacks, CVE databases, and web traffic logs. If no `*.json` dataset is found in `../data`, a synthetic demonstration dataset is generated instead.

### Data Sources:

1. **OWASP Web Security Dataset**: Labeled web attacks
2. **CVE Database**: Known vulnerabilities and exploits
3. **Web Traffic Logs**: Real HTTP request/response pairs
4. **Threat Intelligence Feeds**: Current attack patterns
<jupyter_code>
class CyberSecurityDataset(Dataset):
    """
    PyTorch Dataset for cybersecurity training data.
    A Dataset class in PyTorch needs to implement __len__ and __getitem__ methods.

    This dataset handles:
    - Loading records from JSON, CSV and Parquet files
    - Tokenization with a small security-oriented vocabulary
    - Padding/truncation to ``config.max_seq_len`` plus an attention mask
    - Label encoding, and class rebalancing for the 'train' split

    Every item is a dict with the tensors ``token_ids``, ``attention_mask``,
    ``threat_label`` and ``severity`` plus the strings ``original_text`` and
    ``threat_type``.
    """

    def __init__(self,
                 data_paths: List[str],
                 config: "TrainingConfig",
                 split: str = 'train',
                 max_samples: Optional[int] = None):
        """
        Initialize the cybersecurity dataset.

        Args:
            data_paths: Paths to data files (.json, .csv or .parquet)
            config: Training configuration (``max_seq_len`` is used)
            split: Data split ('train', 'val', 'test'); only 'train' is rebalanced
            max_samples: Maximum number of raw records to read across all
                files (None = no limit), e.g. for debugging
        """
        super().__init__()
        self.config = config
        self.split = split
        self.max_samples = max_samples

        # Threat categories based on OWASP Top-10, plus 'benign' traffic and a
        # 'suspicious' catch-all for any unrecognised threat_type.
        self.threat_categories = [
            'injection',          # SQLi, NoSQLi, OS command
            'broken_auth',        # Authentication bypass
            'sensitive_data',     # Data exposure
            'xxe',                # XML External Entities
            'broken_access',      # IDOR, privilege escalation
            'security_misconfig', # Security misconfiguration
            'xss',                # Cross-site scripting
            'insecure_deserial',  # Insecure deserialization
            'vulnerable_components', # Using vulnerable components
            'insufficient_logging', # Insufficient logging
            'benign',             # Normal traffic
            'suspicious',         # Unclassified suspicious
        ]

        # Initialize tokenizer (also sets idx_to_token, vocab_size, _token_pattern)
        self.tokenizer = self._initialize_tokenizer()

        # Load and preprocess data
        self.samples = self._load_data(data_paths)

        logger.info(f"üìä Loaded {len(self.samples)} samples for {split} split")
        logger.info(f"üéØ Threat distribution: {self._get_class_distribution()}")

    def _initialize_tokenizer(self):
        """
        Build the token -> id vocabulary and the pattern used to split text.
        In production, you might use BPE or SentencePiece.

        Side effects: sets ``self.idx_to_token``, ``self.vocab_size`` and
        ``self._token_pattern``.

        Returns:
            Dict mapping token strings to integer ids.
        """
        import re

        # Create vocabulary
        vocab = {
            '[PAD]': 0,
            '[UNK]': 1,
            '[CLS]': 2,
            '[SEP]': 3,
            '[MASK]': 4,
        }

        # Add common web security tokens
        web_tokens = [
            # HTTP methods
            'GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD', 'OPTIONS',
            # Headers
            'Content-Type', 'Authorization', 'Cookie', 'User-Agent',
            'X-Forwarded-For', 'X-CSRF-Token', 'X-XSS-Protection',
            # Protocols
            'HTTP/1.1', 'HTTPS', 'SSL', 'TLS',
            # Attack patterns
            '<script>', 'javascript:', 'onload=', 'onerror=',
            "' OR '1'='1", 'UNION SELECT', '; DROP', '--',
            '../../etc/passwd', '../', '..\\',
            # Common parameters
            'id=', 'user=', 'password=', 'token=', 'session=',
            'file=', 'cmd=', 'exec=', 'system=',
            # Special characters
            '&', '=', '?', '#', '@', '$', '%', '^', '*',
            '(', ')', '{', '}', '[', ']', '<', '>',
            "'", '"', '`', '\\', '/', '|'
        ]

        # Add tokens to vocabulary
        for idx, token in enumerate(web_tokens, start=len(vocab)):
            vocab[token] = idx

        # Create reverse mapping
        self.idx_to_token = {v: k for k, v in vocab.items()}
        self.vocab_size = len(vocab)

        # Multi-character entries containing punctuation or spaces ('[SEP]',
        # 'id=', 'Content-Type', "' OR '1'='1", ...) would be broken apart by a
        # plain word/punctuation split, so they are matched first, longest
        # first. Everything else falls back to words or single symbols.
        compound_tokens = sorted(
            (tok for tok in vocab if len(tok) > 1 and re.search(r'\W', tok)),
            key=len,
            reverse=True,
        )
        alternatives = [re.escape(tok) for tok in compound_tokens] + [r'\w+', r'[^\w\s]']
        self._token_pattern = re.compile('|'.join(alternatives))

        return vocab

    def _load_data(self, data_paths: List[str]) -> List[Dict]:
        """
        Load data from multiple sources and preprocess.

        Missing files, unsupported extensions and unreadable files are logged
        and skipped. At most ``self.max_samples`` raw records are read in total.

        Args:
            data_paths: List of file paths

        Returns:
            List of processed samples (rebalanced for the 'train' split)
        """
        samples = []
        remaining = self.max_samples  # None means "no limit"

        for data_path in data_paths:
            if remaining is not None and remaining <= 0:
                break

            path = Path(data_path)
            if not path.exists():
                logger.warning(f"‚ö†Ô∏è Data file not found: {data_path}")
                continue

            try:
                # Load based on file extension
                suffix = path.suffix.lower()
                if suffix == '.json':
                    with open(path, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                elif suffix == '.csv':
                    data = pd.read_csv(path).to_dict('records')
                elif suffix == '.parquet':
                    data = pd.read_parquet(path).to_dict('records')
                else:
                    logger.warning(f"‚ö†Ô∏è Unsupported file format: {data_path}")
                    continue

                if not isinstance(data, list):
                    logger.warning(
                        f"Expected a list of records in {data_path}, "
                        f"got {type(data).__name__}; skipping"
                    )
                    continue

                # Apply the record cap across all files, not per file
                if remaining is not None:
                    data = data[:remaining]
                    remaining -= len(data)

                # Process each sample
                for record in data:
                    processed = self._preprocess_sample(record)
                    if processed is not None:
                        samples.append(processed)

            except Exception as e:
                logger.error(f"‚ùå Error loading {data_path}: {e}")

        # Balance dataset if training
        if self.split == 'train':
            samples = self._balance_dataset(samples)

        return samples

    def _preprocess_sample(self, sample: Dict) -> Optional[Dict]:
        """
        Preprocess a single sample.

        Args:
            sample: Raw record with optional keys 'url', 'method', 'headers',
                'body', 'params', 'threat_type' and 'severity'

        Returns:
            Processed sample, or None if the record could not be processed
        """
        try:
            # Extract features; null/empty headers and params become empty dicts
            url = sample.get('url', '')
            method = sample.get('method', 'GET')
            headers = sample.get('headers') or {}
            body = sample.get('body', '')
            params = sample.get('params') or {}

            # Get threat label; unknown types fall back to 'suspicious'
            threat_type = sample.get('threat_type', 'benign')
            if threat_type not in self.threat_categories:
                threat_type = 'suspicious'

            # Convert to threat category index
            threat_label = self.threat_categories.index(threat_type)

            # Severity as given in the record (expected range 0.0 to 1.0)
            severity = float(sample.get('severity', 0.0))

            # Create text representation
            text_representation = self._create_text_representation(
                url, method, headers, body, params
            )

            # Tokenize
            token_ids = self._tokenize_text(text_representation)

            # Create attention mask (1 = real token, 0 = padding)
            attention_mask = [1] * len(token_ids)

            # Pad or truncate sequence to exactly max_seq_len
            if len(token_ids) > self.config.max_seq_len:
                token_ids = token_ids[:self.config.max_seq_len]
                attention_mask = attention_mask[:self.config.max_seq_len]
            else:
                padding_length = self.config.max_seq_len - len(token_ids)
                token_ids = token_ids + [self.tokenizer['[PAD]']] * padding_length
                attention_mask = attention_mask + [0] * padding_length

            return {
                'token_ids': torch.tensor(token_ids, dtype=torch.long),
                'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
                'threat_label': torch.tensor(threat_label, dtype=torch.long),
                'severity': torch.tensor(severity, dtype=torch.float32),
                'original_text': text_representation[:200],  # For debugging
                'threat_type': threat_type,
            }

        except Exception as e:
            logger.debug(f"Error preprocessing sample: {e}")
            return None

    def _create_text_representation(self, url: str, method: str,
                                   headers: Dict, body: str, params: Dict) -> str:
        """
        Create a text representation of the HTTP request for tokenization.

        Args:
            url: Request URL
            method: HTTP method
            headers: HTTP headers
            body: Request body (only the first 1000 characters are kept)
            params: Query parameters

        Returns:
            The request parts joined with " [SEP] "
        """
        parts = []

        # Method and URL
        parts.append(f"{method} {url}")

        # Headers
        for key, value in headers.items():
            parts.append(f"{key}: {value}")

        # Parameters
        if params:
            parts.append("PARAMS: " + "&".join([f"{k}={v}" for k, v in params.items()]))

        # Body (first 1000 chars)
        if body:
            body_str = str(body)[:1000]
            parts.append(f"BODY: {body_str}")

        return " [SEP] ".join(parts)

    def _tokenize_text(self, text: str) -> List[int]:
        """
        Convert text into token ids, prefixed with [CLS].
        In production, use a proper tokenizer like BPE.

        Each piece produced by ``self._token_pattern`` is looked up as-is,
        then upper-cased, then lower-cased. A piece that is still unknown is
        looked up character by character, with [UNK] for unknown characters.

        Args:
            text: Text to tokenize

        Returns:
            List of token IDs
        """
        vocab = self.tokenizer
        unk_id = vocab['[UNK]']
        token_ids = [vocab['[CLS]']]

        for piece in self._token_pattern.findall(text):
            for candidate in (piece, piece.upper(), piece.lower()):
                if candidate in vocab:
                    token_ids.append(vocab[candidate])
                    break
            else:
                token_ids.extend(vocab.get(char, unk_id) for char in piece)

        return token_ids

    def _balance_dataset(self, samples: List[Dict]) -> List[Dict]:
        """
        Rebalance classes to the median class size.

        Classes below the median are repeated (whole copies plus a leading
        slice for the remainder). Classes at or above it are randomly
        subsampled without replacement. The result is shuffled with NumPy's
        global RNG.

        Args:
            samples: List of samples

        Returns:
            Balanced list of samples (empty if ``samples`` is empty)
        """
        # Nothing to balance; np.median([]) would be NaN and int(NaN) raises
        if not samples:
            return []

        # Group samples per class, preserving first-seen class order
        by_class: Dict[str, List[Dict]] = {}
        for sample in samples:
            by_class.setdefault(sample['threat_type'], []).append(sample)

        # Find target count (median of class counts)
        target_count = int(np.median([len(group) for group in by_class.values()]))

        balanced_samples = []
        for class_samples in by_class.values():
            count = len(class_samples)
            if count < target_count:
                # Oversample
                oversample_factor = target_count // count
                remainder = target_count % count

                for _ in range(oversample_factor):
                    balanced_samples.extend(class_samples)
                balanced_samples.extend(class_samples[:remainder])
            else:
                # Take random subset
                indices = np.random.choice(count, target_count, replace=False)
                balanced_samples.extend([class_samples[i] for i in indices])

        # Shuffle
        np.random.shuffle(balanced_samples)

        logger.info(f"‚öñÔ∏è Balanced dataset: {len(samples)} -> {len(balanced_samples)} samples")
        return balanced_samples

    def _get_class_distribution(self) -> Dict[str, int]:
        """Get distribution of threat classes"""
        distribution = {}
        for sample in self.samples:
            label = sample['threat_type']
            distribution[label] = distribution.get(label, 0) + 1
        return distribution

    def __len__(self) -> int:
        """Return number of samples in dataset"""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict:
        """
        Get a single sample by index.

        Args:
            idx: Sample index

        Returns:
            Sample dictionary
        """
        return self.samples[idx]

# Create sample data for demonstration
def create_sample_data(output_dir: Path, num_samples: int = 10000):
    """
    Generate a synthetic web-request dataset for demonstrating the pipeline.
    In production, you would use real datasets.

    The first 70% of records are 'benign'. The rest are drawn uniformly from
    'injection', 'xss' and 'broken_auth', and roughly half of those carry an
    attack pattern in their ``q`` query parameter. Randomness comes from
    NumPy's global RNG, so output follows ``np.random.seed``.

    Args:
        output_dir: Directory to write ``cybersecurity_dataset.json`` into
            (created, with parents, if missing)
        num_samples: Number of samples to generate

    Returns:
        Path of the written JSON file, as a string
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Threat types and their characteristics
    threat_types = {
        'injection': {
            'patterns': ["' OR '1'='1", 'UNION SELECT', '; DROP TABLE', '1=1'],
            'severity': 0.9,
            'methods': ['POST', 'GET']
        },
        'xss': {
            'patterns': ['<script>alert', 'javascript:', 'onload=', '<img src=x onerror='],
            'severity': 0.7,
            'methods': ['GET', 'POST']
        },
        'broken_auth': {
            'patterns': ['admin', 'password=', 'token=1234', 'session=insecure'],
            'severity': 0.8,
            'methods': ['POST', 'PUT']
        },
        'benign': {
            'patterns': ['home', 'about', 'contact', 'products'],
            'severity': 0.0,
            'methods': ['GET']
        }
    }

    # URLs for different applications
    base_urls = [
        'https://example.com',
        'https://api.example.com',
        'https://admin.example.com',
        'https://shop.example.com'
    ]

    # Generate samples
    samples = []
    for i in range(num_samples):
        # Choose threat type
        if i < num_samples * 0.7:  # 70% benign
            threat_type = 'benign'
        else:
            threat_type = np.random.choice(['injection', 'xss', 'broken_auth'])

        # Get threat characteristics
        threat_info = threat_types[threat_type]

        # Generate sample
        url = np.random.choice(base_urls) + '/' + np.random.choice(['login', 'api', 'admin', 'search'])
        method = np.random.choice(threat_info['methods'])

        # Add attack pattern to ~50% of malicious samples
        params = {}
        if threat_type != 'benign' and np.random.random() > 0.5:
            pattern = np.random.choice(threat_info['patterns'])
            params = {'q': pattern, 'id': str(i)}

        # Create sample
        sample = {
            'id': i,
            'url': url,
            'method': method,
            'headers': {
                'User-Agent': f'Browser_{np.random.randint(1, 100)}',
                'Content-Type': 'application/json' if method == 'POST' else 'text/html'
            },
            'params': params,
            'body': json.dumps({'data': 'test'}) if method == 'POST' else '',
            'threat_type': threat_type,
            'severity': threat_info['severity'],
            'timestamp': datetime.now().isoformat()
        }

        samples.append(sample)

    # Save to file: json.dump writes to the open file object
    output_file = output_dir / 'cybersecurity_dataset.json'
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(samples, f, indent=2)

    logger.info(f"üìÅ Created sample dataset: {output_file} ({len(samples)} samples)")
    return str(output_file)

# Create and load dataset
logger.info("üì• Creating/Loading dataset...")
data_dir = Path("../data")
data_dir.mkdir(exist_ok=True)

# Check if dataset exists, otherwise create sample data.
# Sorted so that file order, and therefore the file-level split, is deterministic.
dataset_paths = sorted(str(file) for file in data_dir.glob("*.json"))

if not dataset_paths:
    logger.info("üìù No dataset found, creating sample data...")
    sample_file = create_sample_data(data_dir, num_samples=10000)
    dataset_paths = [sample_file]


def _split_records_to_files(paths: List[str], out_dir: Path, train_frac: float,
                            val_frac: float, seed: int) -> Tuple[List[str], List[str], List[str]]:
    """
    Pool the records of JSON dataset files and write seeded train/val/test files.

    Files that cannot be read, or that do not hold a JSON list, are skipped with
    a warning. Records are permuted with a NumPy generator seeded by ``seed``
    and cut into ``train_frac`` / ``val_frac`` / remainder portions, written to
    ``out_dir/{train,val,test}.json``.

    Returns:
        One-element path lists for the train, val and test files.
    """
    records: List[Dict] = []
    for path in paths:
        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode errors
            logger.warning(f"Skipping unreadable dataset file {path}: {e}")
            continue
        if isinstance(data, list):
            records.extend(data)
        else:
            logger.warning(f"Skipping {path}: expected a JSON list of records")

    order = np.random.default_rng(seed).permutation(len(records))
    n_train = int(len(records) * train_frac)
    n_val = int(len(records) * val_frac)
    index_splits = {
        'train': order[:n_train],
        'val': order[n_train:n_train + n_val],
        'test': order[n_train + n_val:],
    }

    out_dir.mkdir(parents=True, exist_ok=True)
    split_files = {}
    for split_name, indices in index_splits.items():
        out_file = out_dir / f"{split_name}.json"
        with open(out_file, 'w', encoding='utf-8') as f:
            json.dump([records[i] for i in indices], f)
        split_files[split_name] = [str(out_file)]
    return split_files['train'], split_files['val'], split_files['test']


# Split dataset paths for train/val/test.
# In production, you would have separate files per split; this assigns whole
# files to each split.
_train_end = int(len(dataset_paths) * config.train_split)
_val_end = int(len(dataset_paths) * (config.train_split + config.val_split))
train_paths = dataset_paths[:_train_end]
val_paths = dataset_paths[_train_end:_val_end]
test_paths = dataset_paths[_val_end:]

# With only a few files (e.g. the single generated sample file) some split gets
# no file at all, which would leave e.g. the training set empty. Fall back to
# a seeded record-level split stored in the experiment directory.
if not (train_paths and val_paths and test_paths):
    logger.info("Too few dataset files for a file-level split; splitting records instead")
    train_paths, val_paths, test_paths = _split_records_to_files(
        dataset_paths,
        config.experiment_dir / "data_splits",
        config.train_split,
        config.val_split,
        config.random_seed,
    )

# Create datasets
train_dataset = CyberSecurityDataset(train_paths, config, split='train')
val_dataset = CyberSecurityDataset(val_paths, config, split='val')
test_dataset = CyberSecurityDataset(test_paths, config, split='test')

# Create data loaders
# DataLoader handles batching, shuffling, and parallel loading
_pin_memory = torch.cuda.is_available()
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=_pin_memory,
    drop_last=True  # Drop incomplete batches
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=_pin_memory
)

test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=_pin_memory
)

logger.info(f"üìä Dataset statistics:")
logger.info(f"   Train: {len(train_dataset)} samples, {len(train_loader)} batches")
logger.info(f"   Validation: {len(val_dataset)} samples, {len(val_loader)} batches")
logger.info(f"   Test: {len(test_dataset)} samples, {len(test_loader)} batches")

# Visualize class distribution
def visualize_class_distribution(dataset: CyberSecurityDataset, title: str):
    """
    Draw a bar chart of per-class sample counts for ``dataset``.

    Each bar is annotated with its count. The figure is saved as
    ``<experiment_dir>/visualizations/<title>_distribution.png`` (title
    lower-cased, spaces replaced by underscores) at 150 dpi, then shown.
    """
    class_counts = dataset._get_class_distribution()
    labels = list(class_counts.keys())
    counts = list(class_counts.values())

    fig, ax = plt.subplots(figsize=(12, 6))
    bar_container = ax.bar(labels, counts)
    ax.set_title(f'{title} - Class Distribution', fontsize=14, fontweight='bold')
    ax.set_xlabel('Threat Type', fontsize=12)
    ax.set_ylabel('Count', fontsize=12)
    ax.tick_params(axis='x', rotation=45)

    # Write each bar's count just above its top edge
    for rect in bar_container:
        height = rect.get_height()
        x_centre = rect.get_x() + rect.get_width() / 2.
        ax.text(x_centre, height, f'{int(height)}', ha='center', va='bottom')

    plt.tight_layout()
    plt.savefig(config.experiment_dir / "visualizations" / f"{title.lower().replace(' ', '_')}_distribution.png", dpi=150)
    plt.show()

# Plot class balance for the two splits used during training.
visualize_class_distribution(dataset=train_dataset, title="Training Set")
visualize_class_distribution(dataset=val_dataset, title="Validation Set")

# One print call, one line per item: identical output to separate prints.
print(
    "\n‚úÖ Data loading complete!",
    f"üìä Training samples: {len(train_dataset):,}",
    f"üìà Validation samples: {len(val_dataset):,}",
    f"üéØ Test samples: {len(test_dataset):,}",
    sep="\n",
)
<jupyter_output>
<empty_output>
<jupyter_text>
## 4. mHC Architecture Implementation

**Explanation**: Manifold-Constrained Hyper-Connections (mHC) is a novel architecture for stable multi-agent coordination. It is designed to prevent common failure modes such as signal explosion, dominant-agent bias, and reasoning collapse.

### Key mHC Principles:

1. **Doubly-Stochastic Normalization**: Ensures balanced agent contributions (every row and every column of the agent-interaction matrix sums to 1)
2. **Convex State Mixing**: Stable combination of agent states
3. **Identity-Preserving Mappings**: Maintains individual agent characteristics
4. **Non-Expansive Updates**: Bounded signal propagation
5. **Sinkhorn-Knopp Projection**: Iterative row/column normalization used to obtain the doubly-stochastic matrix
<jupyter_code>
class EnhancedManifoldConstrainedHyperConnections(nn.Module):
    """
    Enhanced Manifold-Constrained Hyper-Connections (mHC) for multi-agent stability.

    Combines per-agent state vectors into one coordinated state. It uses an
    inter-agent attention matrix projected onto the doubly-stochastic matrices
    (Sinkhorn-Knopp). The projection is followed by an identity residual, an
    L2-norm bound and LayerNorm.

    Design goals:
    - Avoid signal explosion in multi-agent systems
    - Avoid dominant-agent bias through doubly-stochastic normalization
    - Avoid reasoning collapse with identity-preserving mappings
    - Keep updates bounded via a non-expansive constraint

    Mathematical Foundation:
    ------------------------
    Let A be the agent interaction matrix. We want:
    1. A_ij >= 0                      (non-negative)
    2. sum_j A_ij = 1 for every i     (row stochastic)
    3. sum_i A_ij = 1 for every j     (column stochastic)
    4. ||A(x - y)|| <= ||x - y||      (non-expansive)

    Properties 1-3 are obtained with Sinkhorn-Knopp iteration. A
    doubly-stochastic matrix has spectral norm 1, which gives property 4.

    NOTE(review): consequences of the current forward pass to confirm
    against the intended design:
    * The projected attention's columns sum to 1, so averaging ``A @ X`` over
      agents equals the plain mean of ``X``. The attention (and hence
      ``agent_confidences``) and ``identity_factor`` therefore do not change
      the coordinated state, up to Sinkhorn/float error.
    * LayerNorm runs after the norm bound and is (nearly) scale-invariant.
      ``signal_bound`` therefore does not bound the norm of the returned state.
    * ``agent_weights`` is initialized but never used in the forward computation.
    """

    def __init__(self,
                 n_agents: int,
                 state_dim: int,
                 temperature: float = 1.0,
                 sinkhorn_iterations: int = 50,
                 epsilon: float = 1e-8):
        """
        Initialize mHC module.

        Args:
            n_agents: Number of agents in the system
            state_dim: Dimension of agent state vectors
            temperature: Divisor applied to attention logits (softmax and Sinkhorn)
            sinkhorn_iterations: Number of Sinkhorn row/column normalization rounds
            epsilon: Small value for numerical stability (logs and divisions)
        """
        super().__init__()
        self.n_agents = n_agents
        self.state_dim = state_dim
        self.temperature = temperature
        self.sinkhorn_iterations = sinkhorn_iterations
        self.epsilon = epsilon

        # Learnable per-agent weights. Not read by forward(); kept so the
        # parameter set (and saved state_dicts) stay compatible.
        self.agent_weights = nn.Parameter(torch.randn(n_agents, state_dim))

        # Learnable maximum L2 norm applied before LayerNorm in convex_state_mixing.
        self.signal_bound = nn.Parameter(torch.tensor(1.0))

        # Learnable blend weight between the attention-mixed state and the
        # plain mean of agent states (clamped to [0, 1] where it is used).
        self.identity_factor = nn.Parameter(torch.tensor(0.1))

        # Final normalization of the coordinated state.
        self.layer_norm = nn.LayerNorm(state_dim)

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize agent weights (Xavier) and reset the scalar parameters."""
        nn.init.xavier_uniform_(self.agent_weights)
        nn.init.constant_(self.signal_bound, 1.0)
        nn.init.constant_(self.identity_factor, 0.1)

    def sinkhorn_knopp_projection(self, log_alpha: torch.Tensor) -> torch.Tensor:
        """
        Project a log-space score matrix onto (approximately) doubly-stochastic matrices.

        Divides by ``self.temperature``. It then alternately subtracts the
        log-sum-exp over dim 2 (row normalization) and over dim 1 (column
        normalization), ``self.sinkhorn_iterations`` times. Working in log space
        avoids overflow/underflow. The last step is a column normalization, so
        columns sum to 1 up to float error; row sums approach 1 as the
        iteration count grows.

        Args:
            log_alpha: Log-space attention matrix [batch_size, n_agents, n_agents]

        Returns:
            Non-negative matrix of the same shape with rows/columns summing to ~1.
        """
        log_alpha = log_alpha / self.temperature

        for _ in range(self.sinkhorn_iterations):
            # Row normalization: each row sums to 1.
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=2, keepdim=True)
            # Column normalization: each column sums to 1.
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)

        return torch.exp(log_alpha)

    def convex_state_mixing(self,
                            agent_states: torch.Tensor,
                            attention_weights: torch.Tensor) -> torch.Tensor:
        """
        Mix agent states into one state with doubly-stochastic attention.

        Steps:
        1. Project ``attention_weights`` with Sinkhorn-Knopp and mix states with it.
        2. Average the mixed states over agents.
        3. Blend with the plain agent mean using ``identity_factor`` (clamped to [0, 1]).
        4. Rescale vectors whose L2 norm exceeds ``signal_bound`` down to it.
        5. Apply LayerNorm.

        NOTE(review): because the projected matrix's columns sum to 1, step 2
        already equals the plain agent mean, and LayerNorm (step 5) largely
        undoes the rescaling of step 4.

        Args:
            agent_states: Tensor of shape [batch_size, n_agents, state_dim]
            attention_weights: Non-negative attention [batch_size, n_agents, n_agents]

        Returns:
            Mixed states [batch_size, state_dim]
        """
        # Project the attention onto doubly-stochastic matrices.
        log_attention = torch.log(attention_weights + self.epsilon)
        normalized_attention = self.sinkhorn_knopp_projection(log_attention)

        # Row i of the result is a convex combination of all agents' states.
        mixed_state = torch.einsum('bij,bjk->bik',
                                   normalized_attention,
                                   agent_states)

        # Average over agents to get a single state vector [batch_size, state_dim].
        mixed_state = mixed_state.mean(dim=1)

        # Identity-preserving residual. Clamping keeps this a convex blend
        # even if training pushes the raw parameter outside [0, 1].
        identity_factor = self.identity_factor.clamp(0.0, 1.0)
        original_mean = agent_states.mean(dim=1)
        mixed_state = (1 - identity_factor) * mixed_state + identity_factor * original_mean

        # Shrink any vector whose L2 norm exceeds signal_bound down to it.
        state_norm = torch.norm(mixed_state, dim=-1, keepdim=True)
        scaling_factor = torch.min(
            torch.ones_like(state_norm),
            self.signal_bound / (state_norm + self.epsilon)
        )
        mixed_state = mixed_state * scaling_factor

        return self.layer_norm(mixed_state)

    def compute_attention(self,
                          agent_states: torch.Tensor,
                          agent_confidences: torch.Tensor) -> torch.Tensor:
        """
        Compute row-stochastic attention weights between agents.

        The logits are the equal-weight average of:
        1. Content scores: pairwise cosine similarity of agent states
        2. Confidence scores: outer product of agent confidences
        A softmax over the last dim (with ``temperature``) follows.

        Args:
            agent_states: [batch_size, n_agents, state_dim]
            agent_confidences: [batch_size, n_agents] (expected in 0 to 1)

        Returns:
            Attention weights [batch_size, n_agents, n_agents]; each row sums to 1.

        Raises:
            ValueError: if ``agent_states`` is not 3-dimensional.
        """
        if agent_states.dim() != 3:
            raise ValueError(
                f"agent_states must be [batch_size, n_agents, state_dim], got {tuple(agent_states.shape)}"
            )

        # Cosine similarity between every pair of agents.
        states_norm = F.normalize(agent_states, p=2, dim=-1)
        content_similarity = torch.einsum('bid,bjd->bij',
                                          states_norm,
                                          states_norm)

        # Pairs of confident agents get larger scores.
        confidence_matrix = agent_confidences.unsqueeze(2) * agent_confidences.unsqueeze(1)

        # Fixed, equal-weight average of the two score matrices.
        combined_attention = (content_similarity + confidence_matrix) / 2

        return F.softmax(combined_attention / self.temperature, dim=-1)

    def forward(self,
                agent_states: torch.Tensor,
                agent_confidences: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through mHC layer.

        Args:
            agent_states: [batch_size, n_agents, state_dim]
            agent_confidences: [batch_size, n_agents]

        Returns:
            coordinated_state: [batch_size, state_dim]
            attention_matrix: [batch_size, n_agents, n_agents]. This is the
                row-softmax attention *before* Sinkhorn projection (row-stochastic only).
        """
        attention_matrix = self.compute_attention(agent_states, agent_confidences)
        coordinated_state = self.convex_state_mixing(agent_states, attention_matrix)
        return coordinated_state, attention_matrix

    def get_stability_metrics(self,
                              agent_states: torch.Tensor,
                              coordinated_state: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute scalar diagnostics for debugging and monitoring.

        Metrics:
        1. signal_norm: mean L2 norm of the coordinated state
        2. identity_preservation: mean cosine similarity to the plain agent mean
        3. attention_entropy: mean row entropy of ``compute_attention`` evaluated
           with all-ones confidences (not the caller's confidences)
        4. state_change: mean L2 distance from the plain agent mean

        Args:
            agent_states: Input agent states [batch_size, n_agents, state_dim]
            coordinated_state: Output coordinated state [batch_size, state_dim]

        Returns:
            Dictionary of 0-dim metric tensors
        """
        batch_size = agent_states.shape[0]
        metrics = {}

        metrics['signal_norm'] = torch.norm(coordinated_state, dim=-1).mean()

        original_mean = agent_states.mean(dim=1)
        metrics['identity_preservation'] = F.cosine_similarity(
            coordinated_state,
            original_mean,
            dim=-1
        ).mean()

        # Uniform confidences isolate the content-based part of the attention.
        uniform_confidences = torch.ones(batch_size, self.n_agents, device=agent_states.device)
        attention = self.compute_attention(agent_states, uniform_confidences)
        metrics['attention_entropy'] = -torch.sum(
            attention * torch.log(attention + self.epsilon), dim=-1
        ).mean()

        metrics['state_change'] = torch.norm(coordinated_state - original_mean, dim=-1).mean()

        return metrics

# Test mHC implementation
def test_mhc_implementation():
    """
    Smoke-test the mHC module on random synthetic agent states.

    Builds a small EnhancedManifoldConstrainedHyperConnections module and runs
    a forward pass on random inputs. It then prints output shapes, the
    doubly-stochastic errors of the Sinkhorn-projected attention, the
    stability metrics and the output norms.

    Returns:
        The constructed mHC module.
    """
    print("üß™ Testing mHC Implementation...")

    # Synthetic problem size
    batch_size = 4
    n_agents = 5
    state_dim = 128

    # Random agent states and confidences in [0, 1)
    agent_states = torch.randn(batch_size, n_agents, state_dim)
    agent_confidences = torch.rand(batch_size, n_agents)

    mhc = EnhancedManifoldConstrainedHyperConnections(
        n_agents=n_agents,
        state_dim=state_dim,
        temperature=1.0,
        sinkhorn_iterations=10  # Fewer iterations for testing
    )

    # Nothing is trained here, so skip building an autograd graph.
    with torch.no_grad():
        coordinated_state, attention_matrix = mhc(agent_states, agent_confidences)

        # forward() returns the row-softmax attention from compute_attention,
        # which is only row-stochastic; the doubly-stochastic matrix is built
        # inside convex_state_mixing. Re-apply that same projection so the
        # checks below test the matrix that is supposed to satisfy them.
        doubly_stochastic = mhc.sinkhorn_knopp_projection(
            torch.log(attention_matrix + mhc.epsilon)
        )

        metrics = mhc.get_stability_metrics(agent_states, coordinated_state)

    print(f"‚úÖ Input shape: {agent_states.shape}")
    print(f"‚úÖ Output shape: {coordinated_state.shape}")
    print(f"‚úÖ Attention matrix shape: {attention_matrix.shape}")

    print("\nüìä Testing doubly-stochastic property:")

    # Sum over columns should be 1 for each row
    row_sums = doubly_stochastic.sum(dim=-1)
    row_sum_error = torch.abs(row_sums - 1.0).mean().item()
    print(f"   Row sum error: {row_sum_error:.6f} (should be ~0)")

    # Sum over rows should be 1 for each column
    col_sums = doubly_stochastic.sum(dim=1)
    col_sum_error = torch.abs(col_sums - 1.0).mean().item()
    print(f"   Column sum error: {col_sum_error:.6f} (should be ~0)")

    print("\nüìà Stability Metrics:")
    for name, value in metrics.items():
        print(f"   {name}: {value.item():.4f}")

    print("\nüéØ Testing signal bounding:")
    # NOTE: convex_state_mixing applies LayerNorm *after* the signal bound, so
    # these norms are LayerNorm outputs (roughly sqrt(state_dim) with the
    # default LayerNorm affine parameters). The comparison below is therefore
    # expected to print False.
    signal_norm = torch.norm(coordinated_state, dim=-1)
    print(f"   Signal norms: {signal_norm.tolist()}")
    print(f"   All ‚â§ signal_bound ({mhc.signal_bound.item():.2f}): {(signal_norm <= mhc.signal_bound).all()}")

    return mhc

# Execute the mHC smoke test and keep the resulting module for later cells.
mhc_module = test_mhc_implementation()

# Summary banner: a single print with newline separators emits the same lines.
print(
    "\n‚úÖ mHC implementation complete!",
    "üìä Key features implemented:",
    "   ‚Ä¢ Doubly-stochastic normalization via Sinkhorn-Knopp",
    "   ‚Ä¢ Convex state mixing with identity preservation",
    "   ‚Ä¢ Non-expansive signal bounding",
    "   ‚Ä¢ Stability metrics for monitoring",
    "   ‚Ä¢ Learnable parameters for adaptive coordination",
    sep="\n",
)
<jupyter_output>
<empty_output>
<jupyter_text>
## 5. GQA Transformer Implementation

**Explanation**: Grouped Query Attention (GQA) is an optimized attention mechanism in which groups of query heads share the same key/value head. This reduces key/value projection and KV-cache memory while largely maintaining quality.

### GQA vs MHA vs MQA:

1. **MHA (Multi-Head Attention)**: Each head has separate Q, K, V - high quality, high memory
2. **MQA (Multi-Query Attention)**: All heads share the same K, V - low memory, lower quality
3. **GQA (Grouped Query Attention)**: Groups of heads share K, V - a balanced approach

### Key Innovations:

1. **Flash Attention Integration**: Optimized GPU implementation
2. **RoPE (Rotary Positional Embedding)**: Better position encoding
3. **Mixed Precision Support**: Faster training with FP16
4. **KV Cache Optimization**: Efficient inference
<jupyter_code>
class EnhancedRotaryPositionalEmbedding(nn.Module):
    """
    Enhanced Rotary Positional Embedding (RoPE) for transformers.

    RoPE encodes position by rotating each consecutive (even, odd) feature
    pair of a query/key vector by a position-dependent angle. The dot product
    between a rotated query and a rotated key then depends on their relative
    position.

    Mathematical Formulation:
    ------------------------
    For position m and pair index i (0 <= i < d/2), the pair
    (x[2i], x[2i+1]) is rotated by the angle m * theta_i, where
        theta_i = base^(-2i/d)        (base = 10000 by default)
    In complex form: f(q, m)_i = q_i * exp(j * m * theta_i).
    """

    def __init__(self, dim: int, max_seq_len: int = 8192, base: float = 10000.0):
        """
        Initialize RoPE module.

        Args:
            dim: Dimension of the embeddings (must be even)
            max_seq_len: Number of positions whose angles are precomputed
            base: Base for frequency calculation
        """
        super().__init__()
        assert dim % 2 == 0, "Dimension must be even for RoPE"

        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Inverse frequencies theta_i = base^(-2i/dim) for i = 0 .. dim/2 - 1.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim))

        # Rotation angle for every (position, pair): shape [max_seq_len, dim // 2].
        position = torch.arange(max_seq_len, dtype=torch.float)
        sinusoid_inp = torch.einsum("i,j->ij", position, inv_freq)

        # Non-persistent buffers: they follow .to()/.half() but are not stored
        # in the state_dict (they are recomputed on construction).
        self.register_buffer("sin", torch.sin(sinusoid_inp), persistent=False)
        self.register_buffer("cos", torch.cos(sinusoid_inp), persistent=False)

    def forward(self, x: torch.Tensor, offset: int = 0) -> torch.Tensor:
        """
        Rotate each (even, odd) feature pair of ``x`` by its position's angle.

        Args:
            x: Input tensor of shape [batch_size, num_heads, seq_len, head_dim]
            offset: Absolute position of the first element along seq_len
                (non-zero when continuing a sequence from a KV cache)

        Returns:
            Tensor of the same shape as ``x`` with rotary embeddings applied

        Raises:
            ValueError: if positions [offset, offset + seq_len) fall outside
                the precomputed range [0, max_seq_len).
        """
        batch_size, num_heads, seq_len, head_dim = x.shape
        assert head_dim == self.dim, f"Head dimension {head_dim} != RoPE dimension {self.dim}"

        end = offset + seq_len
        if offset < 0 or end > self.max_seq_len:
            raise ValueError(
                f"RoPE positions [{offset}, {end}) are outside the precomputed "
                f"range [0, {self.max_seq_len}); increase max_seq_len"
            )

        half = head_dim // 2

        # Split the last dim into (pair, 2): [..., 0] is the "real" and
        # [..., 1] the "imaginary" component of each pair. reshape (rather
        # than view) also accepts inputs whose last dim is not contiguous.
        x_pairs = x.reshape(batch_size, num_heads, seq_len, half, 2)
        x_real = x_pairs[..., 0]  # [batch, heads, seq, half]
        x_imag = x_pairs[..., 1]

        # Angle tables must match x_real's [.., seq, half] layout exactly; a
        # trailing singleton dim here would broadcast against the pair axis.
        sin = self.sin[offset:end].view(1, 1, seq_len, half)
        cos = self.cos[offset:end].view(1, 1, seq_len, half)

        # 2-D rotation of each pair by its angle.
        rotated = torch.stack(
            (x_real * cos - x_imag * sin, x_real * sin + x_imag * cos),
            dim=-1,
        )

        return rotated.view(batch_size, num_heads, seq_len, head_dim)

class EnhancedFlashGQA(nn.Module):
    """
    Enhanced Grouped Query Attention with Flash Attention optimization.

    GQA lets groups of query heads share one key/value head. This shrinks the
    K/V projections and the KV cache while keeping one query projection per head.

    Architecture:
    -------------
    - Query heads: n_heads (e.g., 8)
    - Key/Value groups: n_groups (e.g., 2)
    - Each group serves n_heads / n_groups query heads

    Memory Savings:
    ---------------
    KV cache memory: MHA = 2 * n_heads * d_k * seq_len
                     GQA = 2 * n_groups * d_k * seq_len
    Savings: 1 - n_groups/n_heads (e.g., 75% for 8 -> 2)
    """

    def __init__(self,
                 d_model: int,
                 n_heads: int,
                 n_groups: Optional[int] = None,
                 dropout: float = 0.1,
                 use_flash: bool = True,
                 causal: bool = True):
        """
        Initialize Enhanced Flash GQA.

        Args:
            d_model: Model dimension
            n_heads: Number of query heads
            n_groups: Number of key/value groups (default: max(1, n_heads // 4))
            dropout: Attention dropout probability
            use_flash: Whether to try PyTorch's fused scaled_dot_product_attention
                (only used while the module is in training mode)
            causal: Whether to apply causal masking when no explicit mask is given
        """
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.dropout = dropout
        self.use_flash = use_flash
        self.causal = causal

        # Set default groups
        if n_groups is None:
            n_groups = max(1, n_heads // 4)
        self.n_groups = n_groups
        assert n_heads % n_groups == 0, "n_heads must be divisible by n_groups"

        # Rotary Positional Embedding (applied per head, so dim = d_k)
        self.rope = EnhancedRotaryPositionalEmbedding(self.d_k)

        # Query: one projection per head (n_heads * d_k == d_model)
        self.W_q = nn.Linear(d_model, d_model)

        # Key/Value: one projection per group (n_groups * d_k)
        self.W_k = nn.Linear(d_model, n_groups * self.d_k)
        self.W_v = nn.Linear(d_model, n_groups * self.d_k)

        # Output projection
        self.W_o = nn.Linear(d_model, d_model)

        # Attention dropout (used by the standard-attention path)
        self.dropout_layer = nn.Dropout(dropout)

        # Group mapping: which K/V group each query head reads from
        self.register_buffer('group_map', self._create_group_map(n_heads, n_groups))

        # Placeholder cache slot; forward() itself passes caches explicitly.
        self.kv_cache = None

        # Log the fused-kernel fallback once per module, not on every call.
        self._flash_fallback_warned = False

        self._initialize_weights()

        logger.info(f"üéØ Initialized GQA with {n_heads} heads, {n_groups} groups")
        logger.info(f"   Memory savings: {100 * (1 - n_groups/n_heads):.1f}% vs MHA")

    def _create_group_map(self, n_heads: int, n_groups: int) -> torch.Tensor:
        """
        Create mapping from head index to group index.

        Example: n_heads=8, n_groups=2 -> [0,0,0,0,1,1,1,1]
        """
        return torch.arange(n_groups, dtype=torch.long).repeat_interleave(n_heads // n_groups)

    def _initialize_weights(self):
        """Xavier-initialize all projection weights and zero their biases."""
        projections = (self.W_q, self.W_k, self.W_v, self.W_o)
        for layer in projections:
            nn.init.xavier_uniform_(layer.weight, gain=1/math.sqrt(2))
        for layer in projections:
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def _expand_kv(self, K: torch.Tensor, V: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Expand grouped keys/values to match query heads.

        Args:
            K: Keys [batch_size, n_groups, seq_len, d_k]
            V: Values [batch_size, n_groups, seq_len, d_k]

        Returns:
            Expanded K, V [batch_size, n_heads, seq_len, d_k]
        """
        # Index with group_map to duplicate each group's K/V for its heads.
        return K[:, self.group_map], V[:, self.group_map]

    @staticmethod
    def _causal_mask(q_len: int, k_len: int, device: torch.device) -> torch.Tensor:
        """
        Return a [q_len, k_len] bool mask that is True where attention is forbidden.

        Queries are aligned with the *last* q_len key positions, so query i
        may attend to keys 0 .. k_len - q_len + i. This is what incremental
        decoding with a KV cache needs. For q_len == k_len it is the usual
        strictly-upper-triangular causal mask.
        """
        return torch.ones(q_len, k_len, device=device, dtype=torch.bool).triu(
            diagonal=k_len - q_len + 1
        )

    def _flash_attention(self,
                         Q: torch.Tensor,
                         K: torch.Tensor,
                         V: torch.Tensor,
                         mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute attention with PyTorch's fused kernel, falling back to standard attention.

        Fused (Flash / memory-efficient) kernels avoid materializing the full
        attention matrix, reducing memory from O(n^2) to O(n).

        Args:
            Q: Queries [batch_size, n_heads, q_len, d_k]
            K: Keys [batch_size, n_heads, k_len, d_k]
            V: Values [batch_size, n_heads, k_len, d_k]
            mask: Optional attention mask

        Returns:
            Attention output [batch_size, n_heads, q_len, d_k]
        """
        q_len, k_len = Q.size(-2), K.size(-2)
        is_causal = self.causal and mask is None
        if is_causal and q_len != k_len:
            # SDPA's is_causal aligns the mask to the top-left corner, which
            # is wrong when Q is shorter than K (decoding with a KV cache).
            # Pass an explicit bottom-right-aligned mask instead; SDPA's bool
            # masks use True = "may attend", hence the inversion.
            mask = ~self._causal_mask(q_len, k_len, Q.device)
            is_causal = False

        try:
            # PyTorch 2.x fused attention; uses Flash Attention when supported.
            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False):
                return F.scaled_dot_product_attention(
                    Q, K, V,
                    attn_mask=mask,
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=is_causal
                )
        except (RuntimeError, AttributeError):
            if not getattr(self, "_flash_fallback_warned", False):
                logger.warning("‚ö†Ô∏è Flash Attention not available, using standard attention")
                self._flash_fallback_warned = True
            return self._standard_attention(Q, K, V, mask)

    def _standard_attention(self,
                            Q: torch.Tensor,
                            K: torch.Tensor,
                            V: torch.Tensor,
                            mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Standard scaled dot-product attention (fallback / eval path).

        Args:
            Q, K, V: Query [.., q_len, d_k], Key and Value [.., k_len, d_k] tensors
            mask: Optional mask; positions where ``mask == 0`` are excluded

        Returns:
            Attention output [.., q_len, d_k]
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        if self.causal and mask is None:
            # Sized q_len x k_len so it also works with a KV cache (q_len < k_len).
            causal_mask = self._causal_mask(Q.size(-2), K.size(-2), Q.device)
            scores = scores.masked_fill(causal_mask, -1e9)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout_layer(attn_weights)

        return torch.matmul(attn_weights, V)

    def forward(self,
                query: torch.Tensor,
                key: Optional[torch.Tensor] = None,
                value: Optional[torch.Tensor] = None,
                mask: Optional[torch.Tensor] = None,
                use_cache: bool = False,
                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Forward pass through GQA layer.

        Args:
            query: Query tensor [batch_size, seq_len, d_model]
            key: Optional key tensor (if None, uses query)
            value: Optional value tensor (if None, uses query)
            mask: Optional attention mask
            use_cache: Whether to use/extend the KV cache
            past_key_value: Previous (K, V) cache, each [batch_size, n_groups, past_len, d_k];
                only read when ``use_cache`` is True

        Returns:
            output: Attention output [batch_size, seq_len, d_model]
            present_key_value: Updated (K, V) cache when ``use_cache``, else None
        """
        batch_size, seq_len, _ = query.shape

        # Self-attention by default
        if key is None:
            key = query
        if value is None:
            value = query

        # Per-head queries: [batch_size, n_heads, seq_len, d_k]
        Q = self.W_q(query).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # Grouped keys/values: [batch_size, n_groups, kv_len, d_k]
        K = self.W_k(key).view(batch_size, -1, self.n_groups, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.n_groups, self.d_k).transpose(1, 2)

        # Cached keys were rotated at their original positions, so new tokens
        # must continue the position numbering instead of restarting at 0.
        past_len = 0
        if use_cache and past_key_value is not None:
            past_len = past_key_value[0].size(2)

        Q = self.rope(Q, offset=past_len)
        K = self.rope(K, offset=past_len)

        if use_cache:
            if past_key_value is not None:
                past_K, past_V = past_key_value
                K = torch.cat([past_K, K], dim=2)
                V = torch.cat([past_V, V], dim=2)
            # The cache stores grouped (unexpanded), already-rotated K/V.
            present_key_value = (K, V)
        else:
            present_key_value = None

        K_expanded, V_expanded = self._expand_kv(K, V)

        # Fused kernel only while training; eval uses the standard path.
        if self.use_flash and self.training:
            attn_output = self._flash_attention(Q, K_expanded, V_expanded, mask)
        else:
            attn_output = self._standard_attention(Q, K_expanded, V_expanded, mask)

        # Merge heads back to [batch_size, seq_len, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        return self.W_o(attn_output), present_key_value

    def get_kv_cache_size(self, seq_len: int, dtype: torch.dtype = torch.float16) -> int:
        """
        Calculate the KV cache size in bytes for one sequence of this layer.

        Args:
            seq_len: Sequence length
            dtype: Element dtype of the cached tensors

        Returns:
            Cache size in bytes (2 * n_groups * seq_len * d_k elements)
        """
        # element_size() is correct for every dtype (float64, int8, ...),
        # instead of assuming 2 bytes for anything not explicitly listed.
        bytes_per_param = torch.empty((), dtype=dtype).element_size()
        cache_size = 2 * self.n_groups * seq_len * self.d_k
        return cache_size * bytes_per_param

    def reset_cache(self):
        """Reset KV cache"""
        self.kv_cache = None

class SecurityGQATransformerLayer(nn.Module):
    """
    Pre-LN transformer block built around grouped-query attention.

    Two residual sub-layers are applied in sequence:
      * LayerNorm -> EnhancedFlashGQA self-attention -> dropout -> residual add
      * LayerNorm -> Linear / activation / dropout / Linear / dropout -> residual add
    """

    def __init__(self, 
                 d_model: int, 
                 n_heads: int, 
                 n_groups: int,
                 d_ff: int,
                 dropout: float = 0.1,
                 activation: str = "gelu"):
        super().__init__()

        # Self-attention is non-causal: every token can see every other token,
        # because the whole sequence is analysed at once.
        self.attention = EnhancedFlashGQA(
            d_model=d_model,
            n_heads=n_heads,
            n_groups=n_groups,
            dropout=dropout,
            use_flash=True,
            causal=False,
        )

        # One LayerNorm in front of each sub-layer (Pre-LN ordering).
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Position-wise feed-forward: d_model -> d_ff -> d_model.
        widen = nn.Linear(d_model, d_ff)
        nonlinearity = self._get_activation(activation)
        narrow = nn.Linear(d_ff, d_model)
        self.ffn = nn.Sequential(
            widen,
            nonlinearity,
            nn.Dropout(dropout),
            narrow,
            nn.Dropout(dropout),
        )

        # Applied to the attention branch before it is added back to the stream.
        self.dropout = nn.Dropout(dropout)

    def _get_activation(self, activation: str) -> nn.Module:
        """
        Instantiate the activation named by ``activation``.

        Supported names: "gelu", "relu", "silu".

        Raises:
            ValueError: for any other name.
        """
        factories = (
            ("gelu", nn.GELU),
            ("relu", nn.ReLU),
            ("silu", nn.SiLU),
        )
        for name, factory in factories:
            if activation == name:
                return factory()
        raise ValueError(f"Unknown activation: {activation}")

    def forward(self, 
                x: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                use_cache: bool = False,
                past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        Run one Pre-LN block.

        Args:
            x: Hidden states [batch_size, seq_len, d_model]
            mask: Optional attention mask, forwarded unchanged to the attention
            use_cache: Forwarded to the attention layer
            past_key_value: Cached (K, V) from an earlier call, if any

        Returns:
            The updated hidden states and the (K, V) cache produced by the
            attention layer.
        """
        shortcut = x
        normed = self.norm1(shortcut)
        attended, present_key_value = self.attention(
            normed, normed, normed,
            mask=mask,
            use_cache=use_cache,
            past_key_value=past_key_value,
        )
        hidden = shortcut + self.dropout(attended)

        # The FFN already ends with its own dropout, so no extra dropout here.
        hidden = hidden + self.ffn(self.norm2(hidden))

        return hidden, present_key_value

class CyberGuardTransformer(nn.Module):
    """
    Complete CyberGuard transformer model for security analysis.
    
    This model uses:
    1. Token embeddings plus learned absolute position embeddings
    2. A stack of GQA transformer layers (SecurityGQATransformerLayer)
    3. Threat classification head
    4. Severity regression head (sigmoid output in [0, 1])
    5. Feature extraction for agent coordination

    The three heads read a masked mean of the final hidden states.
    """
    
    def __init__(self, 
                 vocab_size: int,
                 config: TrainingConfig,
                 num_threat_classes: Optional[int] = None):
        """
        Build the embeddings, the GQA layer stack and the output heads.

        Args:
            vocab_size: Number of token ids accepted by the embedding table.
            config: Training configuration. Reads ``d_model``, ``max_seq_len``,
                ``n_heads``, ``gqa_groups``, ``d_ff``, ``dropout_rate`` and
                ``n_layers``.
            num_threat_classes: Output width of the threat classifier. When
                omitted, ``len(train_dataset.threat_categories)`` is read from
                the notebook-global ``train_dataset`` (the previous implicit
                behaviour); pass it explicitly when building the model
                outside this notebook.
        """
        super().__init__()
        
        self.config = config
        self.vocab_size = vocab_size

        if num_threat_classes is None:
            # Backward-compatible fallback to the notebook-global dataset.
            num_threat_classes = len(train_dataset.threat_categories)
        self.num_threat_classes = num_threat_classes
        
        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, config.d_model)
        
        # Learned absolute position embedding (the attention layers additionally apply RoPE)
        self.position_embedding = nn.Embedding(config.max_seq_len, config.d_model)
        
        # Transformer layers
        self.layers = nn.ModuleList([
            SecurityGQATransformerLayer(
                d_model=config.d_model,
                n_heads=config.n_heads,
                n_groups=config.gqa_groups,
                d_ff=config.d_ff,
                dropout=config.dropout_rate,
                activation="gelu"
            )
            for _ in range(config.n_layers)
        ])
        
        # Final layer normalization
        self.final_norm = nn.LayerNorm(config.d_model)
        
        # Threat classification head
        self.threat_classifier = nn.Sequential(
            nn.Linear(config.d_model, config.d_model * 2),
            nn.GELU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.d_model * 2, num_threat_classes)
        )
        
        # Severity regression head (0.0 to 1.0)
        self.severity_regressor = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.GELU(),
            nn.Linear(config.d_model, 1),
            nn.Sigmoid()
        )
        
        # Feature extraction for agent coordination
        self.feature_extractor = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.Tanh(),
            nn.Dropout(config.dropout_rate)
        )
        
        # Initialize weights
        self._initialize_weights()
        
        logger.info(f"ü§ñ Initialized CyberGuard Transformer with {config.n_layers} layers")
        logger.info(f"   Total parameters: {self.count_parameters():,}")
        logger.info(f"   GQA groups: {config.gqa_groups} (saves {100*(1-config.gqa_groups/config.n_heads):.1f}% KV memory)")
    
    def _initialize_weights(self):
        """
        Initialize weights.

        Embedding tables are drawn from N(0, 0.02). Every nn.Linear in the model,
        including those inside the attention layers, gets Xavier-uniform weights
        and zero bias.
        """
        # Embeddings
        nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.position_embedding.weight, mean=0.0, std=0.02)
        
        # Linear layers
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
    
    def count_parameters(self) -> int:
        """Count total trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    
    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                use_cache: bool = False,
                past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the complete model.
        
        Args:
            input_ids: Token IDs [batch_size, seq_len]
            attention_mask: Keep-mask (1 = real token, 0 = padding), shape
                [batch_size, seq_len], or [batch_size, past_len + seq_len] when
                ``past_key_values`` is given. Defaults to all ones.
            use_cache: Whether to use KV cache
            past_key_values: List of previous KV caches for each layer
            
        Returns:
            Dictionary with ``threat_logits``, ``severity_score``,
            ``coordination_features``, ``hidden_states``, ``pooled_features``
            and ``present_key_values`` (None unless ``use_cache``).

        Raises:
            ValueError: if the total number of positions exceeds the position
                embedding table.
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Number of positions already held in the KV cache. The attention layer
        # concatenates cached keys along dim 2, so that is the cached length.
        past_len = 0
        if past_key_values:
            first_cache = past_key_values[0]
            if first_cache is not None:
                past_len = first_cache[0].shape[2]

        max_positions = self.position_embedding.num_embeddings
        if past_len + seq_len > max_positions:
            raise ValueError(
                f"Sequence of {past_len + seq_len} positions exceeds the "
                f"position embedding size ({max_positions})"
            )

        # Continue absolute positions after the cached prefix so the learned
        # position embeddings stay aligned during incremental decoding.
        positions = torch.arange(past_len, past_len + seq_len, device=device).unsqueeze(0)
        
        # Get embeddings
        token_embeds = self.token_embedding(input_ids)
        position_embeds = self.position_embedding(positions)
        x = token_embeds + position_embeds
        
        # Default mask covers cached keys as well as the new tokens.
        if attention_mask is None:
            attention_mask = torch.ones(batch_size, past_len + seq_len, device=device)
        
        # [B, L] -> [B, 1, 1, L] so the mask broadcasts over heads and queries.
        extended_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        
        # Process through transformer layers
        present_key_values = [] if use_cache else None
        
        for i, layer in enumerate(self.layers):
            past_key_value = past_key_values[i] if past_key_values is not None else None
            
            x, present_key_value = layer(
                x,
                mask=extended_mask,
                use_cache=use_cache,
                past_key_value=past_key_value
            )
            
            if use_cache:
                present_key_values.append(present_key_value)
        
        # Final normalization
        x = self.final_norm(x)
        
        # Masked mean pooling over the current tokens: padding (mask == 0) is
        # excluded so it does not dilute the summary. Assumes a 0/1 keep-mask, as
        # the all-ones default implies. clamp() keeps fully-masked rows at zero.
        token_mask = attention_mask[:, -seq_len:].unsqueeze(-1).to(x.dtype)
        pooled_features = (x * token_mask).sum(dim=1) / token_mask.sum(dim=1).clamp(min=1.0)
        
        # Extract coordination features
        coordination_features = self.feature_extractor(pooled_features)
        
        # Threat classification
        threat_logits = self.threat_classifier(pooled_features)
        
        # Severity regression
        severity_score = self.severity_regressor(pooled_features).squeeze(-1)
        
        return {
            'threat_logits': threat_logits,
            'severity_score': severity_score,
            'coordination_features': coordination_features,
            'hidden_states': x,
            'pooled_features': pooled_features,
            'present_key_values': present_key_values if use_cache else None
        }
    
    def get_attention_maps(self, 
                          input_ids: torch.Tensor,
                          layer_idx: int = -1) -> Optional[torch.Tensor]:
        """
        Extract one layer's softmax attention weights for visualization.

        Layers before the selected one run through their regular forward pass
        with the same all-ones mask that ``forward`` uses by default. The selected
        layer's attention is then recomputed by hand, without a padding mask, so
        its weights can be returned. The module's train/eval mode is restored
        afterwards.
        
        Args:
            input_ids: Input token IDs [batch_size, seq_len]
            layer_idx: Layer to inspect; negative values count from the end
                (-1 = last layer)
            
        Returns:
            Attention weights [batch_size, n_heads, seq_len, seq_len], or None
            if ``layer_idx`` is out of range.
        """
        n_layers = len(self.layers)
        if layer_idx < 0:
            # Without this, layer_idx=-1 sliced layers[:0] and never matched.
            layer_idx += n_layers
        if not 0 <= layer_idx < n_layers:
            return None

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                batch_size, seq_len = input_ids.shape
                device = input_ids.device

                positions = torch.arange(seq_len, device=device).unsqueeze(0)
                x = self.token_embedding(input_ids) + self.position_embedding(positions)

                # Same [B, 1, 1, L] all-ones mask that forward() passes by default.
                full_mask = torch.ones(batch_size, 1, 1, seq_len, device=device)
                for layer in self.layers[:layer_idx]:
                    x, _ = layer(x, mask=full_mask)

                target = self.layers[layer_idx]
                attn_layer = target.attention
                x_norm = target.norm1(x)

                # Project to per-head queries and per-group keys/values.
                Q = attn_layer.W_q(x_norm).view(batch_size, seq_len, attn_layer.n_heads, attn_layer.d_k)
                K = attn_layer.W_k(x_norm).view(batch_size, seq_len, attn_layer.n_groups, attn_layer.d_k)
                V = attn_layer.W_v(x_norm).view(batch_size, seq_len, attn_layer.n_groups, attn_layer.d_k)

                # RoPE and score computation in [B, heads, L, d_k] layout.
                Q = attn_layer.rope(Q.transpose(1, 2))
                K = attn_layer.rope(K.transpose(1, 2))
                K_expanded, _ = attn_layer._expand_kv(K, V.transpose(1, 2))

                scores = torch.matmul(Q, K_expanded.transpose(-2, -1)) / math.sqrt(attn_layer.d_k)
                return F.softmax(scores, dim=-1)
        finally:
            self.train(was_training)

# Test GQA implementation
def test_gqa_implementation():
    """Test GQA implementation with synthetic data"""
    print("üß™ Testing GQA Implementation...")
    
    # Create synthetic data
    batch_size = 2
    seq_len = 128
    d_model = 512
    vocab_size = 10000
    
    # Create random input
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    
    # Create model
    model = CyberGuardTransformer(vocab_size, config)
    model = model.to(device)
    input_ids = input_ids.to(device)
    
    # Test forward pass
    with torch.no_grad():
        outputs = model(input_ids)
    
    print(f"‚úÖ Input shape: {input_ids.shape}")
    print(f"‚úÖ Threat logits shape: {outputs['threat_logits'].shape}")
    print(f"‚úÖ Severity scores shape: {outputs['severity_score'].shape}")
    print(f"‚úÖ Coordination features shape: {outputs['coordination_features'].shape}")
    
    # Test attention maps
    attention_maps = model.get_attention_maps(input_ids[:1], layer_idx=0)
    if attention_maps is not None:
        print(f"‚úÖ Attention maps shape: {attention_maps.shape}")
        
        # Visualize attention for first sample, first head
        fig, ax = plt.subplots(figsize=(10, 8))
        im = ax.imshow(attention_maps[0, 0].cpu().numpy(), cmap='viridis')
        ax.set_title('Attention Map (First Head)', fontsize=14)
        ax.set_xlabel('Key Position', fontsize=12)
        ax.set_ylabel('Query Position', fontsize=12)
        plt.colorbar(im, ax=ax)
        plt.tight_layout()
        plt.savefig(config.experiment_dir / "visualizations" / "attention_map.png", dpi=150)
        plt.show()
    
    # Test memory usage
    print("\nüíæ Memory Usage Analysis:")
    gqa_layer = model.layers[0].attention
    seq_lengths = [128, 256, 512, 1024, 2048]
    
    print(f"{'Sequence Length':<15} {'GQA Cache (MB)':<15} {'MHA Cache (MB)':<15} {'Savings':<10}")
    print("-" * 55)
    
    for seq_len in seq_lengths:
        gqa_memory = gqa_layer.get_kv_cache_size(seq_len) / 1e6
        
        # Calculate MHA memory (if all heads had separate KV)
        mha_memory = (2 * config.n_heads * seq_len * gqa_layer.d_k * 2) / 1e6
        
        savings = 100 * (1 - gqa_memory / mha_memory)
        
        print(f"{seq_len:<15} {gqa_memory:<15.2f} {mha_memory:<15.2f} {savings:<10.1f}%")
    
    return model

# Build the model and run the GQA smoke test defined above
cyberguard_model = test_gqa_implementation()

# Summarize what this section implemented
print("\n‚úÖ GQA Transformer implementation complete!")
print("üéØ Key features implemented:")
print("\n".join(
    f"   ‚Ä¢ {feature}"
    for feature in (
        "Grouped Query Attention with configurable groups",
        "Rotary Positional Embedding (RoPE)",
        "Flash Attention optimization",
        "KV cache for efficient inference",
        "Threat classification and severity regression",
        "Feature extraction for agent coordination",
    )
))
print(f"   ‚Ä¢ {cyberguard_model.count_parameters():,} total parameters")
<jupyter_output>
<empty_output>
<jupyter_text>
## 6. Multi-Agent Training Loop

**Explanation**: This section implements the complete training loop for the multi-agent cybersecurity system. We train both the individual agents and their coordination through mHC.

### Training Strategy:

1. **Individual Agent Pre-training**: Train each agent on its specialty
2. **Joint Fine-tuning**: Train agents together with mHC coordination
3. **Adversarial Training**: Expose agents to attack simulations
4. **Curriculum Learning**: Start with easy examples, then progressively increase difficulty
<jupyter_code>
class CyberGuardTrainer:
    """
    Complete trainer for CyberGuard multi-agent system.
    
    This trainer handles:
    1. Individual agent training
    2. Multi-agent coordination training with mHC
    3. Adversarial training for robustness
    4. Curriculum learning for progressive difficulty
    5. Mixed precision training for efficiency
    6. Comprehensive logging and monitoring
    """
    
    def __init__(self, 
                 model: nn.Module,
                 mhc: EnhancedManifoldConstrainedHyperConnections,
                 train_loader: DataLoader,
                 val_loader: DataLoader,
                 config: TrainingConfig,
                 device: torch.device):
        """
        Initialize the trainer.

        Moves both modules to ``device`` and builds the optimizer, LR scheduler
        and loss functions. It also sets up an optional GradScaler, empty metric
        histories and ``<experiment_dir>/checkpoints``.
        
        Args:
            model: CyberGuard transformer model
            mhc: Manifold-Constrained Hyper-Connections module
            train_loader: Training data loader
            val_loader: Validation data loader
            config: Training configuration
            device: Training device (CPU/GPU)
        """
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.device = device
        
        # Move models to device before the optimizer is built so its state
        # lives on the same device as the parameters.
        self.model = model.to(device)
        self.mhc = mhc.to(device)
        
        # Optimizer first: the scheduler wraps it.
        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()
        
        # Loss functions for the multi-task objective
        self.criterion = self._create_criterion()
        
        # Loss scaling is only needed with mixed precision
        self.scaler = GradScaler() if config.mixed_precision else None
        
        # Training state
        self.global_step = 0
        self.current_epoch = 0
        self.best_val_loss = float('inf')
        self.best_model_state = None
        
        # Per-split metric histories, one list per metric
        metric_names = ('loss', 'accuracy', 'threat_f1', 'severity_mae', 'mhc_stability')
        self.metrics = {
            split: {name: [] for name in metric_names}
            for split in ('train', 'val')
        }
        
        # Early stopping
        self.early_stopping_counter = 0
        
        # parents=True so a not-yet-created experiment_dir does not raise
        # FileNotFoundError.
        self.checkpoint_dir = config.experiment_dir / "checkpoints"
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        
        logger.info("üéØ Initialized CyberGuard Trainer")
        logger.info(f"   Optimizer: AdamW with LR={config.learning_rate}")
        logger.info(f"   Mixed Precision: {config.mixed_precision}")
        logger.info(f"   Early Stopping Patience: {config.early_stopping_patience}")
    
    def _create_optimizer(self) -> torch.optim.Optimizer:
        """
        Create optimizer with weight decay and parameter grouping.
        
        Returns:
            Configured optimizer
        """
        # Separate parameters for different learning rates
        no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
        
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": self.config.weight_decay,
                "lr": self.config.learning_rate
            },
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
                "lr": self.config.learning_rate
            },
            {
                "params": self.mhc.parameters(),
                "weight_decay": self.config.weight_decay * 0.5,  # Less decay for mHC
                "lr": self.config.learning_rate * 0.5  # Lower LR for coordination
            }
        ]
        
        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.config.learning_rate,
            betas=(0.9, 0.95),
            eps=1e-8,
            weight_decay=self.config.weight_decay
        )
        
        return optimizer
    
    def _create_scheduler(self) -> torch.optim.lr_scheduler._LRScheduler:
        """
        Create learning rate scheduler with warmup.
        
        Returns:
            Configured scheduler
        """
        # Calculate total training steps
        num_training_steps = len(self.train_loader) * self.config.num_epochs
        
        # Create scheduler with warmup
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=self.optimizer,
            max_lr=self.config.learning_rate,
            total_steps=num_training_steps,
            pct_start=self.config.warmup_steps / num_training_steps,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,  # Initial LR = max_lr / 25
            final_div_factor=10000.0  # Final LR = max_lr / 10000
        )
        
        return scheduler
    
    def _create_criterion(self) -> Dict[str, Callable]:
        """
        Create loss functions for multi-task learning.
        
        Returns:
            Dictionary of loss functions
        """
        criterion = {}
        
        # Threat classification loss (with label smoothing)
        if self.config.label_smoothing > 0:
            criterion['classification'] = nn.CrossEntropyLoss(
                label_smoothing=self.config.label_smoothing
            )
        else:
            criterion['classification'] = nn.CrossEntropyLoss()
        
        # Severity regression loss (MSE for regression)
        criterion['severity'] = nn.MSELoss()
        
        # mHC stability loss (encourage stable coordination)
        criterion['mhc_stability'] = lambda metrics: self._compute_mhc_stability_loss(metrics)
        
        return criterion
    
    def _compute_mhc_stability_loss(self, metrics: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Compute loss to encourage mHC stability.
        
        We want:
        1. High identity preservation
        2. Moderate attention entropy (not too low/high)
        3. Bounded signal norms
        
        Args:
            metrics: mHC stability metrics
            
        Returns:
            Stability loss
        """
        loss = 0.0
        
        # Encourage identity preservation (closer to 1 is better)
        identity_loss = 1.0 - metrics.get('identity_preservation', torch.tensor(0.0))
        loss += identity_loss * 0.5
        
        # Encourage moderate attention entropy (target ~log(n_agents))
        target_entropy = math.log(self.mhc.n_agents)
        entropy = metrics.get('attention_entropy', torch.tensor(0.0))
        entropy_loss = torch.abs(entropy - target_entropy) / target_entropy
        loss += entropy_loss * 0.3
        
        # Penalize large state changes
        state_change = metrics.get('state_change', torch.tensor(0.0))
        loss += state_change * 0.2
        
        return loss
    
    def _compute_losses(self, 
                       outputs: Dict[str, torch.Tensor],
                       targets: Dict[str, torch.Tensor],
                       mhc_metrics: Optional[Dict[str, torch.Tensor]] = None) -> Dict[str, torch.Tensor]:
        """
        Compute all losses for multi-task learning.
        
        Args:
            outputs: Model outputs
            targets: Ground truth targets
            mhc_metrics: mHC stability metrics
            
        Returns:
            Dictionary of losses
        """
        losses = {}
        
        # Classification loss
        threat_logits = outputs['threat_logits']
        threat_labels = targets['threat_label']
        losses['classification'] = self.criterion['classification'](threat_logits, threat_labels)
        
        # Severity regression loss
        severity_pred = outputs['severity_score']
        severity_true = targets['severity']
        losses['severity'] = self.criterion['severity'](severity_pred, severity_true)
        
        # mHC stability loss (if available)
        if mhc_metrics is not None:
            losses['mhc_stability'] = self.criterion['mhc_stability'](mhc_metrics)
        else:
            losses['mhc_stability'] = torch.tensor(0.0, device=self.device)
        
        # Total loss (weighted sum)
        losses['total'] = (
            losses['classification'] * 0.6 +
            losses['severity'] * 0.3 +
            losses['mhc_stability'] * 0.1
        )
        
        return losses
    
    def _compute_metrics(self,
                        outputs: Dict[str, torch.Tensor],
                        targets: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """
        Compute evaluation metrics.
        
        Args:
            outputs: Model outputs
            targets: Ground truth targets
            
        Returns:
            Dictionary of metrics
        """
        metrics = {}
        
        # Classification accuracy
        threat_pred = outputs['threat_logits'].argmax(dim=-1)
        threat_true = targets['threat_label']
        accuracy = (threat_pred == threat_true).float().mean().item()
        metrics['accuracy'] = accuracy
        
        # Severity MAE
        severity_pred = outputs['severity_score']
        severity_true = targets['severity']
        mae = torch.abs(severity_pred - severity_true).mean().item()
        metrics['severity_mae'] = mae
        
        # Threat F1 score (macro average)
        from sklearn.metrics import f1_score
        try:
            f1 = f1_score(
                threat_true.cpu().numpy(),
                threat_pred.cpu().numpy(),
                average='macro',
                zero_division=0
            )
            metrics['threat_f1'] = f1
        except:
            metrics['threat_f1'] = 0.0
        
        return metrics
    
    def train_epoch(self) -> Dict[str, float]:
        """
        Train the model and the mHC module for one pass over ``self.train_loader``.

        Each batch goes through these steps:
        1. Forward pass, under autocast when mixed precision is enabled.
        2. Simulated multi-agent coordination through mHC.
        3. Weighted multi-task loss.
        4. Backward pass with joint gradient clipping.
        5. Optimizer and scheduler step.
        6. Logging to the progress bar, to TensorBoard via the global ``writer``
           (when it is not None) and to the global ``metrics_logger``.

        Returns:
            Epoch means of ``loss``, ``accuracy``, ``severity_mae``,
            ``threat_f1`` and ``mhc_stability``.
        """
        from tqdm.auto import tqdm

        self.model.train()
        self.mhc.train()
        
        epoch_losses = []
        epoch_metrics = {
            'accuracy': [],
            'severity_mae': [],
            'threat_f1': [],
            'mhc_stability': []
        }

        # The parameter set does not change during the epoch, so build the
        # clipping list once instead of on every batch.
        clip_params = list(self.model.parameters()) + list(self.mhc.parameters())
        use_scaler = self.config.mixed_precision and self.scaler is not None
        
        pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}")
        
        for batch_idx, batch in enumerate(pbar):
            # Move batch to device
            input_ids = batch['token_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            targets = {
                'threat_label': batch['threat_label'].to(self.device),
                'severity': batch['severity'].to(self.device)
            }
            
            # Forward pass with mixed precision
            with autocast(enabled=self.config.mixed_precision):
                outputs = self.model(input_ids, attention_mask)
                
                # Simulated multi-agent coordination: every agent starts from the
                # pooled coordination features plus small Gaussian noise and gets
                # a random confidence. A real system would use separate agents.
                batch_size = outputs['coordination_features'].shape[0]
                n_agents = self.mhc.n_agents
                agent_states = outputs['coordination_features'].unsqueeze(1).expand(-1, n_agents, -1)
                agent_states = agent_states + torch.randn_like(agent_states) * 0.1
                agent_confidences = torch.rand(batch_size, n_agents, device=self.device)
                
                coordinated_state, attention_matrix = self.mhc(agent_states, agent_confidences)
                mhc_metrics = self.mhc.get_stability_metrics(agent_states, coordinated_state)
                
                losses = self._compute_losses(outputs, targets, mhc_metrics)
            
            # Backward pass
            self.optimizer.zero_grad()
            
            if use_scaler:
                self.scaler.scale(losses['total']).backward()
                # Unscale first so clipping sees the true gradient magnitudes.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(clip_params, self.config.gradient_clip)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                losses['total'].backward()
                torch.nn.utils.clip_grad_norm_(clip_params, self.config.gradient_clip)
                self.optimizer.step()
            
            # The scheduler is stepped once per batch
            if self.scheduler is not None:
                self.scheduler.step()
            
            # Compute metrics (read each scalar from the device only once)
            loss_value = losses['total'].item()
            metrics = self._compute_metrics(outputs, targets)
            metrics['mhc_stability'] = mhc_metrics['identity_preservation'].item()
            current_lr = self.optimizer.param_groups[0]['lr']
            
            pbar.set_postfix({
                'loss': loss_value,
                'acc': metrics['accuracy'],
                'lr': current_lr
            })
            
            # Accumulate metrics
            epoch_losses.append(loss_value)
            for key in epoch_metrics:
                epoch_metrics[key].append(metrics[key])
            
            # Log to tensorboard
            if writer is not None:
                if self.global_step % self.config.log_frequency == 0:
                    writer.add_scalar('Train/Loss', loss_value, self.global_step)
                    writer.add_scalar('Train/Accuracy', metrics['accuracy'], self.global_step)
                    writer.add_scalar('Train/Learning_Rate', current_lr, self.global_step)
                    writer.add_scalar('Train/mHC_Stability', metrics['mhc_stability'], self.global_step)
                
                # Once per epoch (first batch). Kept outside the log_frequency
                # check, which would otherwise skip it whenever the epoch's first
                # global step is not a multiple of log_frequency.
                if batch_idx == 0:
                    writer.add_image(
                        'mHC/Attention_Matrix',
                        attention_matrix[0].detach().float().unsqueeze(0),  # Add channel dimension
                        self.current_epoch
                    )
            
            # Log to metrics logger
            metrics_logger.info(
                f"train,{self.global_step},{loss_value:.4f},"
                f"{metrics['accuracy']:.4f},{metrics['threat_f1']:.4f},"
                f"{metrics['severity_mae']:.4f},{metrics['mhc_stability']:.4f}"
            )
            
            self.global_step += 1
        
        # Compute epoch averages
        avg_loss = np.mean(epoch_losses)
        avg_metrics = {k: np.mean(v) for k, v in epoch_metrics.items()}
        
        return {'loss': avg_loss, **avg_metrics}
    
    def validate(self, confidence_seed: int = 0) -> Dict[str, float]:
        """
        Validate model on validation set.

        Runs the model and the mHC module without gradients, averages the
        per-batch loss and metrics, saves a confusion matrix for the current
        epoch, and logs the results to TensorBoard (module-level ``writer``,
        when it is not None) and to ``metrics_logger``.

        The simulated agent confidences fed to the mHC module are random, but
        they are drawn from a dedicated CPU generator re-seeded with
        ``confidence_seed`` on every call. Assuming the validation loader is
        not shuffled, every validation pass therefore sees the same
        confidences, so the validation loss (which ``train`` uses for
        checkpoint selection and early stopping) is comparable across epochs.
        Using a private generator also leaves the global RNG stream used by
        training untouched.

        Args:
            confidence_seed: Seed for the agent-confidence generator.

        Returns:
            Dictionary with the mean validation ``loss`` and the mean
            ``accuracy``, ``severity_mae``, ``threat_f1`` and ``mhc_stability``.
        """
        self.model.eval()
        self.mhc.eval()

        metric_names = ('accuracy', 'severity_mae', 'threat_f1', 'mhc_stability')
        val_losses = []
        val_metrics = {name: [] for name in metric_names}

        # Collected across batches for the confusion matrix.
        all_predictions = []
        all_targets = []

        # CPU generator: reproducible on every device type; results are moved
        # to self.device after sampling.
        confidence_generator = torch.Generator().manual_seed(confidence_seed)

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Validation"):
                # Move batch to device
                input_ids = batch['token_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                threat_labels = batch['threat_label'].to(self.device)
                severity = batch['severity'].to(self.device)

                targets = {
                    'threat_label': threat_labels,
                    'severity': severity
                }

                outputs = self.model(input_ids, attention_mask)

                # Simulate mHC coordination: every agent starts from the same
                # coordination features.
                batch_size = outputs['coordination_features'].shape[0]
                n_agents = self.mhc.n_agents

                agent_states = outputs['coordination_features'].unsqueeze(1)
                agent_states = agent_states.expand(-1, n_agents, -1)

                agent_confidences = torch.rand(
                    batch_size, n_agents, generator=confidence_generator
                ).to(self.device)
                coordinated_state, _ = self.mhc(agent_states, agent_confidences)

                mhc_metrics = self.mhc.get_stability_metrics(agent_states, coordinated_state)

                losses = self._compute_losses(outputs, targets, mhc_metrics)

                metrics = self._compute_metrics(outputs, targets)
                metrics['mhc_stability'] = mhc_metrics['identity_preservation'].item()

                val_losses.append(losses['total'].item())
                for key in val_metrics:
                    val_metrics[key].append(metrics[key])

                all_predictions.extend(outputs['threat_logits'].argmax(dim=-1).cpu().numpy())
                all_targets.extend(threat_labels.cpu().numpy())

        avg_loss = np.mean(val_losses)
        avg_metrics = {k: np.mean(v) for k, v in val_metrics.items()}

        self._plot_confusion_matrix(all_targets, all_predictions)

        # Log to tensorboard
        if writer is not None:
            writer.add_scalar('Val/Loss', avg_loss, self.current_epoch)
            writer.add_scalar('Val/Accuracy', avg_metrics['accuracy'], self.current_epoch)
            writer.add_scalar('Val/Threat_F1', avg_metrics['threat_f1'], self.current_epoch)

        # Log to metrics logger
        metrics_logger.info(
            f"val,{self.current_epoch},{avg_loss:.4f},"
            f"{avg_metrics['accuracy']:.4f},{avg_metrics['threat_f1']:.4f},"
            f"{avg_metrics['severity_mae']:.4f},{avg_metrics['mhc_stability']:.4f}"
        )

        return {'loss': avg_loss, **avg_metrics}
    
    def _plot_confusion_matrix(self, y_true: List, y_pred: List):
        """
        Plot raw and row-normalized confusion matrices and save them as a PNG.

        The figure is written to
        ``<experiment_dir>/visualizations/confusion_matrix_epoch_<epoch>.png``
        (the directory is created if needed). Classes with no true samples get
        an all-zero normalized row instead of NaNs from a 0/0 division. The
        figure is always closed, even if plotting or saving fails.

        Args:
            y_true: True labels (class indices)
            y_pred: Predicted labels (class indices)
        """
        from sklearn.metrics import confusion_matrix
        import seaborn as sns

        class_names = train_dataset.threat_categories

        cm = confusion_matrix(y_true, y_pred, labels=range(len(class_names)))

        # Row-normalize; rows whose class never occurs stay 0 instead of NaN.
        row_totals = cm.sum(axis=1, keepdims=True)
        cm_normalized = np.divide(
            cm,
            row_totals,
            out=np.zeros(cm.shape, dtype=float),
            where=row_totals != 0,
        )

        output_dir = self.config.experiment_dir / "visualizations"
        output_dir.mkdir(parents=True, exist_ok=True)

        fig, axes = plt.subplots(1, 2, figsize=(16, 6))
        try:
            panels = (
                (axes[0], cm, 'd', 'Confusion Matrix (Counts)'),
                (axes[1], cm_normalized, '.2f', 'Confusion Matrix (Normalized)'),
            )
            for ax, matrix, fmt, title in panels:
                sns.heatmap(matrix, annot=True, fmt=fmt, cmap='Blues',
                            xticklabels=class_names, yticklabels=class_names, ax=ax)
                ax.set_title(title, fontsize=14, fontweight='bold')
                ax.set_xlabel('Predicted', fontsize=12)
                ax.set_ylabel('True', fontsize=12)
                ax.tick_params(axis='x', rotation=45)
                ax.tick_params(axis='y', rotation=0)

            plt.tight_layout()
            plt.savefig(
                output_dir / f"confusion_matrix_epoch_{self.current_epoch}.png",
                dpi=150,
                bbox_inches='tight'
            )
        finally:
            plt.close(fig)
    
    def save_checkpoint(self, is_best: bool = False):
        """
        Save model checkpoint.

        Writes ``checkpoint_epoch_<epoch>.pt`` to ``self.checkpoint_dir``. When
        ``is_best`` is set, it also writes ``best_model.pt`` and keeps an
        in-memory copy in ``self.best_model_state``.

        The in-memory copy is a deep copy. ``state_dict()`` tensors share
        storage with the live parameters and optimizer buffers, so a shallow
        ``dict.copy()`` would silently follow later training updates instead
        of preserving the best weights. Note that this holds a second copy of
        the weights and optimizer state in memory.

        Args:
            is_best: Whether this is the best model so far
        """
        import copy

        checkpoint = {
            'epoch': self.current_epoch,
            'global_step': self.global_step,
            'model_state_dict': self.model.state_dict(),
            'mhc_state_dict': self.mhc.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'scaler_state_dict': self.scaler.state_dict() if self.scaler else None,
            'best_val_loss': self.best_val_loss,
            # Stored so a resumed run keeps its early-stopping progress.
            'early_stopping_counter': getattr(self, 'early_stopping_counter', 0),
            'metrics': self.metrics,
            'config': self.config.__dict__
        }

        # Save regular checkpoint
        checkpoint_path = self.checkpoint_dir / f"checkpoint_epoch_{self.current_epoch}.pt"
        torch.save(checkpoint, checkpoint_path)
        logger.info(f"üíæ Saved checkpoint: {checkpoint_path}")

        # Save best model
        if is_best:
            best_path = self.checkpoint_dir / "best_model.pt"
            torch.save(checkpoint, best_path)
            # Deep copy: detach the snapshot from the still-training tensors.
            self.best_model_state = copy.deepcopy(checkpoint)
            logger.info(f"üèÜ Saved best model: {best_path}")
    
    def load_checkpoint(self, checkpoint_path: Path):
        """
        Load model checkpoint and prepare to resume training.

        Restores model, mHC and optimizer state, scheduler/scaler state when
        both the trainer and the checkpoint have it, and training progress.
        ``current_epoch`` is set to the epoch AFTER the one stored in the
        checkpoint. ``train()`` saves checkpoints after an epoch has finished
        and starts its loop at ``current_epoch``, so this avoids repeating the
        completed epoch. If the file does not exist, a warning is logged and
        nothing is changed.

        Args:
            checkpoint_path: Path to checkpoint file
        """
        import inspect

        if not checkpoint_path.exists():
            logger.warning(f"‚ö†Ô∏è Checkpoint not found: {checkpoint_path}")
            return

        load_kwargs = {'map_location': self.device}
        # Checkpoints contain non-tensor objects (the config __dict__ and
        # metric histories holding NumPy scalars from np.mean). PyTorch >= 2.6
        # rejects these under its default weights_only=True.
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects. Only
        # load checkpoints written by this trainer or another trusted source.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        checkpoint = torch.load(checkpoint_path, **load_kwargs)

        # Load state dicts
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.mhc.load_state_dict(checkpoint['mhc_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        if self.scheduler and checkpoint.get('scheduler_state_dict'):
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        if self.scaler and checkpoint.get('scaler_state_dict'):
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])

        # Training state: resume with the epoch after the saved (finished) one.
        self.current_epoch = checkpoint['epoch'] + 1
        self.global_step = checkpoint['global_step']
        self.best_val_loss = checkpoint['best_val_loss']
        # Older checkpoints lack this key; keep the current counter then.
        self.early_stopping_counter = checkpoint.get(
            'early_stopping_counter', getattr(self, 'early_stopping_counter', 0)
        )
        self.metrics = checkpoint.get('metrics', self.metrics)

        logger.info(f"üìÇ Loaded checkpoint from epoch {checkpoint['epoch']}")
    
    def train(self):
        """
        Main training loop.

        Starts at ``self.current_epoch``, so it can resume after
        ``load_checkpoint``. Each epoch runs ``train_epoch`` and ``validate``
        and appends both result dicts to ``self.metrics``. A checkpoint is
        saved every ``checkpoint_frequency`` epochs and whenever the
        validation loss improves. Training stops early after
        ``early_stopping_patience`` epochs without improvement. At the end, a
        final checkpoint is saved and the metric history is written to
        ``<experiment_dir>/training_metrics.json``.
        """
        # Keys copied from the train_epoch()/validate() results into self.metrics.
        tracked_keys = ('loss', 'accuracy', 'threat_f1', 'severity_mae', 'mhc_stability')

        def _to_json(value):
            """json.dump fallback: convert tensors and NumPy values to Python types."""
            if torch.is_tensor(value):
                return float(value) if value.numel() == 1 else value.tolist()
            if isinstance(value, (np.generic, np.ndarray)):
                return value.tolist()
            # json requires TypeError for unsupported objects; returning the
            # object unchanged would raise a misleading circular-reference error.
            raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

        logger.info("üöÄ Starting training...")
        logger.info(f"üìà Total epochs: {self.config.num_epochs}")
        logger.info(f"üìä Training samples: {len(self.train_loader.dataset)}")
        logger.info(f"üéØ Validation samples: {len(self.val_loader.dataset)}")

        for epoch in range(self.current_epoch, self.config.num_epochs):
            self.current_epoch = epoch

            logger.info(f"\n{'='*80}")
            logger.info(f"üìÖ Epoch {epoch + 1}/{self.config.num_epochs}")
            logger.info(f"{'='*80}")

            # Train for one epoch
            train_results = self.train_epoch()
            for key in tracked_keys:
                self.metrics['train'][key].append(train_results[key])

            # Validate
            val_results = self.validate()
            for key in tracked_keys:
                self.metrics['val'][key].append(val_results[key])

            logger.info(f"üìä Train Loss: {train_results['loss']:.4f}, "
                       f"Acc: {train_results['accuracy']:.4f}, "
                       f"F1: {train_results['threat_f1']:.4f}")
            logger.info(f"üéØ Val Loss: {val_results['loss']:.4f}, "
                       f"Acc: {val_results['accuracy']:.4f}, "
                       f"F1: {val_results['threat_f1']:.4f}")

            # Improvement is measured on validation loss only.
            is_best = val_results['loss'] < self.best_val_loss

            if is_best:
                self.best_val_loss = val_results['loss']
                self.early_stopping_counter = 0
                logger.info(f"üèÜ New best model! Val loss: {val_results['loss']:.4f}")
            else:
                self.early_stopping_counter += 1
                logger.info(f"‚è≥ No improvement for {self.early_stopping_counter} epoch(s)")

            if (epoch + 1) % self.config.checkpoint_frequency == 0 or is_best:
                self.save_checkpoint(is_best=is_best)

            if self.early_stopping_counter >= self.config.early_stopping_patience:
                logger.info(f"üõë Early stopping triggered after {epoch + 1} epochs")
                break

        # Save final model
        logger.info("üíæ Saving final model...")
        self.save_checkpoint(is_best=False)

        # Save final metrics (use the trainer's own config, not a global).
        metrics_path = self.config.experiment_dir / "training_metrics.json"
        with open(metrics_path, 'w', encoding='utf-8') as f:
            json.dump(self.metrics, f, indent=2, default=_to_json)

        logger.info(f"üìä Training metrics saved to: {metrics_path}")
        logger.info("‚úÖ Training complete!")

# --- Build the multi-agent trainer -------------------------------------------
logger.info("üéØ Initializing CyberGuard Trainer...")

# Five simulated security agents; each agent state uses the transformer's
# hidden size so the model's coordination features can be mixed directly.
n_agents, state_dim = 5, config.d_model

mhc_module = EnhancedManifoldConstrainedHyperConnections(
    n_agents=n_agents,
    state_dim=state_dim,
    temperature=config.mhc_temperature,
    sinkhorn_iterations=config.mhc_sinkhorn_iterations,
)

trainer = CyberGuardTrainer(
    model=cyberguard_model,
    mhc=mhc_module,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    device=device,
)

# Set to a checkpoint path to resume training; a falsy value starts fresh.
checkpoint_to_load = None
if checkpoint_to_load:
    logger.info(f"üìÇ Loading checkpoint: {checkpoint_to_load}")
    trainer.load_checkpoint(Path(checkpoint_to_load))

# Human-readable summary of the run configuration.
print("\n‚úÖ Multi-agent trainer initialized!")
print("üéØ Training configuration:")
print(f"   ‚Ä¢ Number of agents: {n_agents}")
print(f"   ‚Ä¢ State dimension: {state_dim}")
print(f"   ‚Ä¢ Training epochs: {config.num_epochs}")
print(f"   ‚Ä¢ Batch size: {config.batch_size}")
print(f"   ‚Ä¢ Learning rate: {config.learning_rate}")
print(f"   ‚Ä¢ Mixed precision: {config.mixed_precision}")
print(f"   ‚Ä¢ Early stopping patience: {config.early_stopping_patience}")
<jupyter_output>
<empty_output>
<jupyter_text>
## 7. mHC Coordination Training

**Explanation**: This section focuses on the mHC coordination mechanism. We analyze whether the agents collaborate effectively without any single agent dominating the decision-making.

### Training Objectives:
1. **Balanced Contribution**: All agents should contribute meaningfully.
2. **Stable Coordination**: The coordinated output should not oscillate wildly.
3. **Identity Preservation**: Agents should retain their specialties.
4. **Conflict Resolution**: Disagreements between agents should be resolved rationally.
<jupyter_code>
def analyze_mhc_coordination(trainer: CyberGuardTrainer,
                             test_batch: Dict[str, torch.Tensor]):
    """
    Analyze mHC coordination behavior on a single batch.

    Builds ``n_agents`` synthetic agent states from the model's
    ``coordination_features``, where agent ``i`` is shifted by ``0.1 * i``.
    It draws random agent confidences and runs the trainer's mHC module. For
    the first sample it then prints:
      - the stability metrics,
      - a doubly-stochastic check,
      - per-row attention entropy,
      - a dominant-agent report.
    A two-panel figure (attention heatmap and per-agent contribution bars) is
    saved to ``<experiment_dir>/visualizations/mhc_coordination_analysis.png``
    and shown.

    The batch is moved to ``trainer.device`` and the output path comes from
    ``trainer.config``, so the analysis always matches the trainer inspected.

    Args:
        trainer: CyberGuard trainer whose ``model`` and ``mhc`` are analyzed
        test_batch: Batch containing ``token_ids`` and ``attention_mask``

    Returns:
        Dict with the full ``attention_matrix``, the ``mhc_metrics`` dict, and
        ``agent_contributions`` (column means of the first sample's attention
        matrix).
    """
    print("üîç Analyzing mHC Coordination...")

    trainer.model.eval()
    trainer.mhc.eval()

    with torch.no_grad():
        input_ids = test_batch['token_ids'].to(trainer.device)
        attention_mask = test_batch['attention_mask'].to(trainer.device)

        outputs = trainer.model(input_ids, attention_mask)
        base_features = outputs['coordination_features']

        batch_size = base_features.shape[0]
        n_agents = trainer.mhc.n_agents

        # Each synthetic agent is the shared features plus a constant offset
        # (0.0, 0.1, 0.2, ...) so the agents hold distinct states.
        agent_states = torch.stack(
            [base_features + i * 0.1 for i in range(n_agents)], dim=1
        )  # [B, N, D]

        # Random confidences simulate agents with differing certainty.
        agent_confidences = torch.rand(batch_size, n_agents, device=trainer.device)

        coordinated_state, attention_matrix = trainer.mhc(agent_states, agent_confidences)
        mhc_metrics = trainer.mhc.get_stability_metrics(agent_states, coordinated_state)

    print("\nüìä Coordination Analysis:")
    print(f"   Number of agents: {n_agents}")
    print(f"   State dimension: {trainer.mhc.state_dim}")

    print("\nüéØ Stability Metrics:")
    for name, value in mhc_metrics.items():
        # float() accepts both 1-element tensors and plain Python numbers.
        print(f"   {name}: {float(value):.4f}")

    print("\nü§ñ Agent Contributions (Attention Matrix):")
    attn_sample = attention_matrix[0].cpu().numpy()
    # Column means = average attention each agent receives (computed once).
    avg_attention_per_agent = attn_sample.mean(axis=0)

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Attention matrix heatmap
    im1 = axes[0].imshow(attn_sample, cmap='viridis', vmin=0, vmax=1)
    axes[0].set_title('Agent Attention Matrix', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Key Agent', fontsize=12)
    axes[0].set_ylabel('Query Agent', fontsize=12)
    axes[0].set_xticks(range(n_agents))
    axes[0].set_yticks(range(n_agents))
    plt.colorbar(im1, ax=axes[0])

    # Agent contribution distribution
    bars = axes[1].bar(range(n_agents), avg_attention_per_agent)
    axes[1].set_title('Agent Contribution Distribution', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Agent Index', fontsize=12)
    axes[1].set_ylabel('Average Attention', fontsize=12)
    axes[1].set_xticks(range(n_agents))
    axes[1].axhline(y=1/n_agents, color='r', linestyle='--', label='Ideal (Equal)')
    axes[1].legend()

    for bar in bars:
        height = bar.get_height()
        axes[1].text(bar.get_x() + bar.get_width()/2., height,
                     f'{height:.3f}', ha='center', va='bottom')

    plt.tight_layout()
    viz_dir = trainer.config.experiment_dir / "visualizations"
    viz_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(viz_dir / "mhc_coordination_analysis.png", dpi=150)
    plt.show()

    # A doubly-stochastic matrix has every row and column summing to 1.
    print("\nüìê Doubly-Stochastic Property Check:")
    row_sums = attn_sample.sum(axis=1)
    col_sums = attn_sample.sum(axis=0)

    print(f"   Row sums: {row_sums}")
    print(f"   Column sums: {col_sums}")
    print(f"   Max row deviation: {np.abs(row_sums - 1).max():.6f}")
    print(f"   Max column deviation: {np.abs(col_sums - 1).max():.6f}")

    print("\nüé≤ Attention Entropy Analysis:")
    for i in range(n_agents):
        # 1e-8 guards log(0) for zero attention weights.
        entropy = -np.sum(attn_sample[i] * np.log(attn_sample[i] + 1e-8))
        print(f"   Agent {i} attention entropy: {entropy:.4f}")

    print("\nüëë Dominant Agent Analysis:")
    dominant_threshold = 2.0 / n_agents  # More than twice the ideal

    for i in range(n_agents):
        attention = avg_attention_per_agent[i]
        if attention > dominant_threshold:
            print(f"   ‚ö†Ô∏è  Agent {i} might be dominant: {attention:.3f} > {dominant_threshold:.3f}")
        else:
            print(f"   ‚úÖ Agent {i} contribution balanced: {attention:.3f}")

    return {
        'attention_matrix': attention_matrix,
        'mhc_metrics': mhc_metrics,
        'agent_contributions': avg_attention_per_agent
    }

# Pull one validation batch and inspect how the mHC module mixes agent states.
test_batch = next(iter(val_loader))
coordination_results = analyze_mhc_coordination(trainer, test_batch)

# Summary of what the analysis above reports.
for summary_line in (
    "\n‚úÖ mHC coordination analysis complete!",
    "üìä Key insights:",
    "   ‚Ä¢ Doubly-stochastic normalization ensures balanced contributions",
    "   ‚Ä¢ Attention entropy indicates diversity of attention patterns",
    "   ‚Ä¢ Identity preservation maintains agent specialties",
    "   ‚Ä¢ Signal bounding prevents information explosion",
):
    print(summary_line)
<jupyter_output>
<empty_output>
<jupyter_text>
## 8. Adversarial Training

**Explanation**: Adversarial training makes the model more robust by exposing it to deliberately perturbed, malicious inputs during training. This is crucial for cybersecurity applications, where attackers actively try to evade detection.

### Adversarial Techniques:
1. **FGSM (Fast Gradient Sign Method)**: A single perturbation step of size ε in the direction of the sign of the loss gradient.
2. **PGD (Projected Gradient Descent)**: Iterative FGSM steps, each projected back into the ε-ball around the original input.
3. **Text Adversarial Attacks**: Token (word) substitutions, insertions, and deletions.
4. **Obfuscation Attacks**: Character encoding (e.g., URL encoding, Unicode homoglyphs), padding, and fragmentation.
<jupyter_code>
class AdversarialTrainer:
    """
    Adversarial training module for cybersecurity robustness.

    Generates adversarial inputs and uses them to train the model and
    evaluate it against evasion attacks. Supported perturbations:
      - embedding-space FGSM and PGD,
      - random token edits,
      - string-level obfuscation.

    The embedding-space attacks run the model through
    ``model._forward_from_embeddings`` and take input gradients with
    ``torch.autograd.grad``. They therefore never modify the caller's tensor
    and never leave gradients on the model's parameters.
    """

    # Token IDs 0-4 are treated as special tokens (never substituted or
    # deleted) and 0 is used as padding. NOTE(review): confirm against the
    # tokenizer's special-token table.
    _SPECIAL_TOKEN_IDS = frozenset(range(5))
    _FIRST_REGULAR_TOKEN_ID = 5
    _PAD_TOKEN_ID = 0

    def __init__(self,
                 model: nn.Module,
                 epsilon: float = 0.1,
                 alpha: float = 0.01,
                 num_iterations: int = 7,
                 attack_type: str = 'pgd'):
        """
        Initialize adversarial trainer.

        Args:
            model: Model to attack; the embedding-space attacks call its
                ``_forward_from_embeddings`` method
            epsilon: Maximum L-infinity perturbation of the embeddings
            alpha: Step size of each PGD iteration
            num_iterations: Number of PGD iterations
            attack_type: Attack used by ``adversarial_training_step``
                ('fgsm', 'pgd', or 'text')
        """
        self.model = model
        self.epsilon = epsilon
        self.alpha = alpha
        self.num_iterations = num_iterations
        self.attack_type = attack_type

        # Text attack parameters: per-sample probability of each edit type.
        self.substitution_rate = 0.1
        self.insertion_rate = 0.05
        self.deletion_rate = 0.05

        # (pattern, replacement) pairs common in web-attack obfuscation. Each
        # replacement swaps one ASCII letter for a look-alike code point; it
        # must differ from the pattern (e.g. '\u0065' is plain 'e' and would
        # obfuscate nothing).
        self.obfuscation_patterns = [
            ('<script>', '<scr\u0131pt>'),      # dotless i (U+0131)
            ('alert', 'al\u0435rt'),            # Cyrillic ie (U+0435)
            ('javascript', 'javascr\u0131pt'),
            ('onload', 'on\u0131oad'),
            ('union', 'un\u0131on'),
            ('select', 'sel\u0435ct'),          # Cyrillic ie (U+0435)
            ('from', 'fr\u043em'),              # Cyrillic o (U+043E)
        ]

    def _forward_embeddings(self,
                            embeddings: torch.Tensor,
                            attention_mask: "Optional[torch.Tensor]" = None) -> Dict[str, torch.Tensor]:
        """Run the model from embeddings, passing the mask only when given."""
        if attention_mask is None:
            return self.model._forward_from_embeddings(embeddings)
        return self.model._forward_from_embeddings(embeddings, attention_mask)

    def _input_gradient(self,
                        embeddings: torch.Tensor,
                        labels: torch.Tensor,
                        criterion: nn.Module,
                        attention_mask: "Optional[torch.Tensor]" = None) -> torch.Tensor:
        """
        Return the gradient of the classification loss w.r.t. ``embeddings``.

        Differentiates a detached leaf copy, so the caller's tensor is left
        untouched (it may be a non-leaf that already requires grad). Uses
        ``torch.autograd.grad`` so no ``.grad`` accumulates on parameters.
        Gradients are enabled locally, so this also works under ``no_grad``.
        """
        probe = embeddings.detach().requires_grad_(True)
        with torch.enable_grad():
            outputs = self._forward_embeddings(probe, attention_mask)
            loss = criterion(outputs['threat_logits'], labels)
            (gradient,) = torch.autograd.grad(loss, probe)
        return gradient

    def fgsm_attack(self,
                    embeddings: torch.Tensor,
                    labels: torch.Tensor,
                    criterion: nn.Module,
                    attention_mask: "Optional[torch.Tensor]" = None) -> torch.Tensor:
        """
        Fast Gradient Sign Method (FGSM) attack.

        Args:
            embeddings: Input embeddings (not modified)
            labels: True labels
            criterion: Loss function applied to ``threat_logits``
            attention_mask: Optional mask forwarded to the model

        Returns:
            Detached adversarial embeddings within ``epsilon`` (L-inf) of the input
        """
        clean = embeddings.detach()
        gradient = self._input_gradient(clean, labels, criterion, attention_mask)
        # One signed step of size epsilon: each coordinate moves by exactly
        # +/-epsilon (or 0), so the result already lies in the epsilon ball.
        return (clean + self.epsilon * gradient.sign()).detach()

    def pgd_attack(self,
                   embeddings: torch.Tensor,
                   labels: torch.Tensor,
                   criterion: nn.Module,
                   attention_mask: "Optional[torch.Tensor]" = None) -> torch.Tensor:
        """
        Projected Gradient Descent (PGD) attack.

        Args:
            embeddings: Input embeddings (not modified)
            labels: True labels
            criterion: Loss function applied to ``threat_logits``
            attention_mask: Optional mask forwarded to the model

        Returns:
            Detached adversarial embeddings within ``epsilon`` (L-inf) of the input
        """
        clean = embeddings.detach()
        # Random start, uniform in [-epsilon, epsilon).
        delta = torch.rand_like(clean) * 2 * self.epsilon - self.epsilon

        for _ in range(self.num_iterations):
            gradient = self._input_gradient(clean + delta, labels, criterion, attention_mask)
            # Ascend the loss, then project back onto the epsilon ball.
            delta = torch.clamp(delta + self.alpha * gradient.sign(),
                                -self.epsilon, self.epsilon)

        return (clean + delta).detach()

    def text_adversarial_attack(self,
                                input_ids: torch.Tensor,
                                tokenizer: Dict) -> torch.Tensor:
        """
        Text-level adversarial attack for token sequences.

        Each sample independently gets, with the configured rates, one
        random substitution, one random insertion, and one random deletion.
        Special tokens are never substituted or deleted. Every output row
        keeps the input's sequence length: it is truncated after an insertion
        and padded with the pad ID after a deletion.

        Args:
            input_ids: Input token IDs, shape [batch, seq_len]
            tokenizer: Vocabulary; only ``len(tokenizer)`` is used, as the
                exclusive upper bound for random token IDs

        Returns:
            Adversarial token IDs (``torch.long``) on the input's device
        """
        seq_len = input_ids.shape[1]
        vocab_size = len(tokenizer)

        def editable_positions(tokens):
            return [i for i, tid in enumerate(tokens) if tid not in self._SPECIAL_TOKEN_IDS]

        adversarial_rows = []
        for row in input_ids.cpu().tolist():
            tokens = list(row)

            # Random substitution
            if np.random.random() < self.substitution_rate:
                positions = editable_positions(tokens)
                if positions:
                    tokens[np.random.choice(positions)] = np.random.randint(
                        self._FIRST_REGULAR_TOKEN_ID, vocab_size)

            # Random insertion, then truncate back to seq_len
            if np.random.random() < self.insertion_rate:
                insert_at = np.random.randint(0, len(tokens))
                tokens.insert(insert_at, np.random.randint(self._FIRST_REGULAR_TOKEN_ID, vocab_size))
                del tokens[seq_len:]

            # Random deletion, then pad back to seq_len
            if np.random.random() < self.deletion_rate:
                # Recompute on the edited row: an insertion may have shifted tokens.
                positions = editable_positions(tokens)
                if positions:
                    del tokens[np.random.choice(positions)]
                    tokens.extend([self._PAD_TOKEN_ID] * (seq_len - len(tokens)))

            adversarial_rows.append(tokens)

        return torch.tensor(adversarial_rows, device=input_ids.device, dtype=torch.long)

    def obfuscation_attack(self,
                           text: str,
                           is_web_attack: bool = True) -> str:
        """
        Apply obfuscation patterns common in web attacks.

        Randomly applies homoglyph substitutions (30% per pattern), inserts
        1-2 spaces (20%), and URL-encodes up to three randomly chosen
        characters (20%). Encoding happens in a single pass, so a chosen '%'
        never re-encodes escapes produced for another character.

        Args:
            text: Input text
            is_web_attack: Whether this is a web attack pattern; if False,
                the text is returned unchanged

        Returns:
            Obfuscated text (empty input is returned unchanged)
        """
        if not is_web_attack or not text:
            return text

        obfuscated_text = text

        for pattern, replacement in self.obfuscation_patterns:
            if np.random.random() < 0.3:  # 30% chance to apply each pattern
                obfuscated_text = obfuscated_text.replace(pattern, replacement)

        # Insert random whitespace
        if np.random.random() < 0.2:
            chars = list(obfuscated_text)
            for _ in range(np.random.randint(1, 3)):
                chars.insert(np.random.randint(0, len(chars)), ' ')
            obfuscated_text = ''.join(chars)

        # URL-encode a few randomly chosen characters
        if np.random.random() < 0.2:
            import urllib.parse
            chosen = set(np.random.choice(list(obfuscated_text),
                                          size=min(3, len(obfuscated_text)),
                                          replace=False))
            obfuscated_text = ''.join(
                urllib.parse.quote(ch) if ch in chosen else ch
                for ch in obfuscated_text
            )

        return obfuscated_text

    def adversarial_training_step(self,
                                  trainer: "CyberGuardTrainer",
                                  batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """
        Perform adversarial training step.

        Builds adversarial inputs with ``self.attack_type``, then runs one
        optimizer step of ``trainer.model`` on them. Mixed precision and the
        scheduler are not used here.

        Args:
            trainer: CyberGuard trainer (supplies model, optimizer, criterion,
                loss/metric helpers and device)
            batch: Training batch

        Returns:
            Dictionary with the scalar ``loss``, ``accuracy`` and ``threat_f1``

        Raises:
            ValueError: If ``self.attack_type`` is unknown.
        """
        input_ids = batch['token_ids'].to(trainer.device)
        attention_mask = batch['attention_mask'].to(trainer.device)
        threat_labels = batch['threat_label'].to(trainer.device)
        severity = batch['severity'].to(trainer.device)
        criterion = trainer.criterion['classification']

        if self.attack_type in ('fgsm', 'pgd'):
            with torch.no_grad():
                embeddings = trainer.model.token_embedding(input_ids)
            attack = self.fgsm_attack if self.attack_type == 'fgsm' else self.pgd_attack
            adversarial_embeddings = attack(embeddings, threat_labels, criterion, attention_mask)
        elif self.attack_type == 'text':
            # NOTE(review): insert/delete shift tokens but attention_mask is
            # not updated to match.
            adversarial_ids = self.text_adversarial_attack(input_ids, train_dataset.tokenizer)
            adversarial_embeddings = trainer.model.token_embedding(adversarial_ids)
        else:
            raise ValueError(f"Unknown attack type: {self.attack_type}")

        # Train on adversarial examples
        trainer.model.train()
        trainer.optimizer.zero_grad()

        outputs = trainer.model._forward_from_embeddings(adversarial_embeddings, attention_mask)

        targets = {'threat_label': threat_labels, 'severity': severity}
        losses = trainer._compute_losses(outputs, targets)

        losses['total'].backward()
        torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), trainer.config.gradient_clip)
        trainer.optimizer.step()

        metrics = trainer._compute_metrics(outputs, targets)

        return {
            'loss': losses['total'].item(),
            'accuracy': metrics['accuracy'],
            'threat_f1': metrics['threat_f1']
        }

    def evaluate_robustness(self,
                            trainer: "CyberGuardTrainer",
                            test_loader: DataLoader,
                            num_batches: int = 10) -> Dict[str, float]:
        """
        Evaluate model robustness against an FGSM attack.

        Compares clean accuracy with accuracy on FGSM-perturbed embeddings
        over the first ``num_batches`` batches. Leaves the model in eval mode.

        Args:
            trainer: CyberGuard trainer
            test_loader: Test data loader
            num_batches: Number of batches to evaluate

        Returns:
            Dictionary with ``clean_accuracy``, ``adversarial_accuracy``,
            ``robustness_gap`` (clean - adversarial) and ``robustness_score``
            (adversarial / clean, or 0 when clean accuracy is 0)
        """
        trainer.model.eval()
        criterion = trainer.criterion['classification']

        clean_accuracies = []
        adversarial_accuracies = []

        for batch_index, batch in enumerate(test_loader):
            if batch_index >= num_batches:
                break

            input_ids = batch['token_ids'].to(trainer.device)
            threat_labels = batch['threat_label'].to(trainer.device)

            with torch.no_grad():
                clean_pred = trainer.model(input_ids)['threat_logits'].argmax(dim=-1)
                clean_accuracies.append((clean_pred == threat_labels).float().mean().item())
                embeddings = trainer.model.token_embedding(input_ids)

            # fgsm_attack enables gradients internally for its input-gradient pass.
            adversarial_embeddings = self.fgsm_attack(embeddings, threat_labels, criterion)

            with torch.no_grad():
                adv_outputs = trainer.model._forward_from_embeddings(adversarial_embeddings)
                adv_pred = adv_outputs['threat_logits'].argmax(dim=-1)
                adversarial_accuracies.append((adv_pred == threat_labels).float().mean().item())

        clean_accuracy = np.mean(clean_accuracies)
        adversarial_accuracy = np.mean(adversarial_accuracies)
        robustness_gap = clean_accuracy - adversarial_accuracy

        return {
            'clean_accuracy': clean_accuracy,
            'adversarial_accuracy': adversarial_accuracy,
            'robustness_gap': robustness_gap,
            'robustness_score': adversarial_accuracy / clean_accuracy if clean_accuracy > 0 else 0
        }

# Build the adversarial trainer around the main CyberGuard model.
# attack_type selects the attack used during adversarial training:
# 'fgsm', 'pgd', or 'text'.
adversarial_trainer = AdversarialTrainer(
    cyberguard_model,
    attack_type='pgd',
    epsilon=0.1,
    alpha=0.01,
    num_iterations=7,
)

# Echo the effective attack settings.
print("üîê Adversarial Training Module Initialized")
print(f"üéØ Attack type: {adversarial_trainer.attack_type}")
print(f"üìè Epsilon (perturbation bound): {adversarial_trainer.epsilon}")
print(f"üìà Alpha (step size): {adversarial_trainer.alpha}")
print(f"üîÑ Iterations: {adversarial_trainer.num_iterations}")

# Test adversarial attack
def test_adversarial_attack():
    """
    Smoke-test each attack exposed by ``adversarial_trainer`` on two validation samples.

    Runs the FGSM and PGD embedding attacks, the token-level text attack and the
    string obfuscation attack, prints perturbation statistics, and saves a
    figure of per-position perturbation magnitudes to
    ``config.experiment_dir / "visualizations" / "adversarial_attacks.png"``.

    Returns:
        True once every attack has run (exceptions propagate to the caller).
    """
    print("\nüß™ Testing Adversarial Attacks...")
    
    # Get a test batch
    test_batch = next(iter(val_loader))
    input_ids = test_batch['token_ids'][:2].to(device)  # First 2 samples
    threat_labels = test_batch['threat_label'][:2].to(device)
    
    # Get embeddings
    with torch.no_grad():
        embeddings = cyberguard_model.token_embedding(input_ids)
    
    print(f"‚úÖ Original embeddings shape: {embeddings.shape}")

    # The perturbation is measured as (adv - x) in the embedding dtype, so an
    # exact epsilon-sized step picks up rounding error of about ulp(|x|). The
    # float32 value of 0.1 is also slightly above the float64 epsilon. Without a
    # tolerance, an in-bounds FGSM step is reported as "Within bounds: False".
    bound_tol = 4 * torch.finfo(embeddings.dtype).eps * max(1.0, embeddings.abs().max().item())

    def _report_bound(perturbation: float) -> None:
        """Print the max perturbation and whether it respects the epsilon bound."""
        print(f"   Max perturbation: {perturbation:.6f}")
        print(f"   Epsilon bound: {adversarial_trainer.epsilon}")
        print(f"   Within bounds: {perturbation <= adversarial_trainer.epsilon + bound_tol}")
    
    # Test FGSM attack
    print("\n‚ö° Testing FGSM attack...")
    adversarial_embeddings = adversarial_trainer.fgsm_attack(
        embeddings, threat_labels, nn.CrossEntropyLoss()
    )
    _report_bound((adversarial_embeddings - embeddings).abs().max().item())
    
    # Test PGD attack
    print("\nüîÑ Testing PGD attack...")
    adversarial_embeddings_pgd = adversarial_trainer.pgd_attack(
        embeddings, threat_labels, nn.CrossEntropyLoss()
    )
    _report_bound((adversarial_embeddings_pgd - embeddings).abs().max().item())
    
    # Test text adversarial attack
    print("\nüìù Testing text adversarial attack...")
    adversarial_ids = adversarial_trainer.text_adversarial_attack(
        input_ids, train_dataset.tokenizer
    )
    
    # Count changed tokens
    changed_tokens = (adversarial_ids != input_ids).sum().item()
    total_tokens = input_ids.numel()
    change_rate = changed_tokens / total_tokens
    print(f"   Changed tokens: {changed_tokens}/{total_tokens} ({change_rate:.1%})")
    
    # Test obfuscation
    print("\nüé≠ Testing obfuscation attack...")
    test_text = "<script>alert('XSS')</script> UNION SELECT * FROM users"
    obfuscated_text = adversarial_trainer.obfuscation_attack(test_text, is_web_attack=True)
    print(f"   Original: {test_text}")
    print(f"   Obfuscated: {obfuscated_text}")
    
    # Visualize perturbations
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    # Original vs FGSM. The attack output may still be attached to the autograd
    # graph, and .numpy() refuses tensors that require grad, so detach first.
    diff_fgsm = (adversarial_embeddings - embeddings).abs().mean(dim=-1).detach().cpu().numpy()
    axes[0].imshow(diff_fgsm, cmap='hot', aspect='auto')
    axes[0].set_title('FGSM Perturbations', fontsize=12)
    axes[0].set_xlabel('Sequence Position', fontsize=10)
    axes[0].set_ylabel('Sample', fontsize=10)
    
    # Original vs PGD
    diff_pgd = (adversarial_embeddings_pgd - embeddings).abs().mean(dim=-1).detach().cpu().numpy()
    axes[1].imshow(diff_pgd, cmap='hot', aspect='auto')
    axes[1].set_title('PGD Perturbations', fontsize=12)
    axes[1].set_xlabel('Sequence Position', fontsize=10)
    
    # Token changes
    token_changes = (adversarial_ids != input_ids).cpu().numpy()
    axes[2].imshow(token_changes, cmap='binary', aspect='auto')
    axes[2].set_title('Token Changes (Text Attack)', fontsize=12)
    axes[2].set_xlabel('Sequence Position', fontsize=10)
    
    plt.tight_layout()
    plt.savefig(config.experiment_dir / "visualizations" / "adversarial_attacks.png", dpi=150)
    plt.show()
    
    return True

# Exercise each attack type once (FGSM, PGD, text, obfuscation).
test_adversarial_attack()

print("\n‚úÖ Adversarial training module ready!")
print("üîê This will help make CyberGuard robust against:")
for _threat_family in (
    "Gradient-based evasion attacks (FGSM, PGD)",
    "Text manipulation attacks",
    "Obfuscation and encoding attacks",
    "Adversarial examples in the wild",
):
    print(f"   ‚Ä¢ {_threat_family}")
<jupyter_output>
<empty_output>
<jupyter_text>
## 9. Evaluation & Metrics

**Explanation**: Comprehensive evaluation is crucial for cybersecurity systems. We need to measure not just accuracy, but also robustness, fairness, and operational metrics.

### Evaluation Dimensions:

1. **Accuracy Metrics**: Precision, Recall, F1, AUC-ROC
2. **Robustness Metrics**: Adversarial accuracy, perturbation sensitivity
3. **Fairness Metrics**: Equal opportunity, demographic parity
4. **Operational Metrics**: Inference speed, memory usage, scalability
5. **mHC Coordination Metrics**: Stability metrics, attention entropy, balance of agent contributions (Gini coefficient)
<jupyter_code>
class ComprehensiveEvaluator:
    """
    Comprehensive evaluator for CyberGuard system.
    
    Evaluates:
    1. Accuracy and classification performance
    2. Robustness against adversarial attacks
    3. Fairness across different threat types
    4. Operational efficiency (speed, memory)
    5. mHC coordination stability
    """
    
    def __init__(self,
                 model: nn.Module,
                 mhc: EnhancedManifoldConstrainedHyperConnections,
                 test_loader: DataLoader,
                 device: torch.device,
                 threat_categories: List[str]):
        """
        Store the evaluation inputs and place both networks on ``device``.
        
        Args:
            model: CyberGuard model under evaluation
            mhc: mHC coordination module evaluated alongside the model
            test_loader: DataLoader that yields the evaluation batches
            device: Device every evaluation runs on
            threat_categories: Display names of the threat classes (presumably
                ordered by label index; TODO confirm against the dataset)
        """
        # Keep whatever `.to(device)` hands back as the canonical reference.
        self.model = model.to(device)
        self.mhc = mhc.to(device)
        self.test_loader = test_loader
        self.device = device
        self.threat_categories = threat_categories
        
        # Filled in by the evaluate_* methods, one entry per evaluation.
        self.results = {}
        
        logger.info("üìä Initialized Comprehensive Evaluator")
    
    def evaluate_classification(self) -> Dict[str, Any]:
        """
        Evaluate threat classification and severity regression on the test set.
        
        Per-class metrics, the classification report and the confusion matrix
        are computed over the full label set ``0 .. len(threat_categories) - 1``.
        This keeps them aligned with ``self.threat_categories`` even when a
        category never occurs in the test split. Without ``labels=``, sklearn
        reports only the classes it observes. That makes ``classification_report``
        raise on the longer ``target_names`` and leaves the per-class arrays
        shorter than the category list the plots index. Macro averages are
        still taken over the classes that occur in the labels or predictions,
        matching sklearn's default label set.
        
        Returns:
            Dictionary with classification metrics; also stored in
            ``self.results['classification']``.
        """
        from sklearn.metrics import (accuracy_score, classification_report,
                                     confusion_matrix, mean_absolute_error,
                                     mean_squared_error,
                                     precision_recall_fscore_support,
                                     r2_score, roc_auc_score)
        
        print("üìà Evaluating Classification Performance...")
        
        self.model.eval()
        
        all_predictions = []
        all_labels = []
        all_probabilities = []
        all_severity_pred = []
        all_severity_true = []
        
        with torch.no_grad():
            for batch in tqdm(self.test_loader, desc="Classification Evaluation"):
                # Move to device
                input_ids = batch['token_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                threat_labels = batch['threat_label'].to(self.device)
                severity = batch['severity'].to(self.device)
                
                # Forward pass
                outputs = self.model(input_ids, attention_mask)
                
                # Get predictions
                threat_pred = outputs['threat_logits'].argmax(dim=-1)
                threat_probs = F.softmax(outputs['threat_logits'], dim=-1)
                
                # Collect results
                all_predictions.extend(threat_pred.cpu().numpy())
                all_labels.extend(threat_labels.cpu().numpy())
                all_probabilities.extend(threat_probs.cpu().numpy())
                all_severity_pred.extend(outputs['severity_score'].cpu().numpy())
                all_severity_true.extend(severity.cpu().numpy())
        
        # Convert to numpy arrays
        all_predictions = np.array(all_predictions)
        all_labels = np.array(all_labels)
        all_probabilities = np.array(all_probabilities)
        all_severity_pred = np.array(all_severity_pred)
        all_severity_true = np.array(all_severity_true)
        
        # Report every known category, whether or not it occurs in this split.
        label_ids = np.arange(len(self.threat_categories))
        # Categories that occur in labels or predictions: sklearn's default label
        # set, used for the macro averages so their values are unchanged.
        observed = np.isin(label_ids, np.union1d(all_labels, all_predictions))
        
        # Basic accuracy
        accuracy = accuracy_score(all_labels, all_predictions)
        
        # Per-class metrics (one entry per category in self.threat_categories)
        precision, recall, f1, support = precision_recall_fscore_support(
            all_labels, all_predictions, labels=label_ids, average=None, zero_division=0
        )
        
        # Macro averages over observed categories
        precision_macro = precision[observed].mean()
        recall_macro = recall[observed].mean()
        f1_macro = f1[observed].mean()
        
        # Weighted averages: unobserved categories have zero support and drop out
        precision_weighted = np.average(precision, weights=support)
        recall_weighted = np.average(recall, weights=support)
        f1_weighted = np.average(f1, weights=support)
        
        # ROC-AUC (binary, or one-vs-rest for multiclass)
        try:
            if len(self.threat_categories) == 2:
                auc = roc_auc_score(all_labels, all_probabilities[:, 1])
            else:
                auc = roc_auc_score(all_labels, all_probabilities, multi_class='ovr', average='macro')
        except (ValueError, IndexError) as err:
            # AUC is undefined, e.g. when only one class occurs in the labels or
            # the labels do not cover every probability column. Report 0.0 as
            # before, but do not swallow unrelated errors or interrupts.
            logger.warning("AUC-ROC could not be computed (%s); reporting 0.0", err)
            auc = 0.0
        
        # Severity regression metrics
        severity_mae = mean_absolute_error(all_severity_true, all_severity_pred)
        severity_mse = mean_squared_error(all_severity_true, all_severity_pred)
        severity_r2 = r2_score(all_severity_true, all_severity_pred)
        
        # Create classification report
        class_report = classification_report(
            all_labels, all_predictions,
            labels=label_ids,
            target_names=self.threat_categories,
            output_dict=True,
            zero_division=0
        )
        
        # Create confusion matrix (rows/cols follow self.threat_categories)
        cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)
        
        # Store results
        self.results['classification'] = {
            'accuracy': accuracy,
            'precision_macro': precision_macro,
            'recall_macro': recall_macro,
            'f1_macro': f1_macro,
            'precision_weighted': precision_weighted,
            'recall_weighted': recall_weighted,
            'f1_weighted': f1_weighted,
            'auc_roc': auc,
            'severity_mae': severity_mae,
            'severity_mse': severity_mse,
            'severity_r2': severity_r2,
            'class_report': class_report,
            'confusion_matrix': cm.tolist()
        }
        
        # Print summary
        print("\nüìä Classification Results:")
        print(f"   Accuracy: {accuracy:.4f}")
        print(f"   F1 Macro: {f1_macro:.4f}")
        print(f"   F1 Weighted: {f1_weighted:.4f}")
        print(f"   AUC-ROC: {auc:.4f}")
        print(f"   Severity MAE: {severity_mae:.4f}")
        print(f"   Severity R¬≤: {severity_r2:.4f}")
        
        # Plot confusion matrix
        self._plot_confusion_matrix_heatmap(cm, "Classification Confusion Matrix")
        
        # Plot precision-recall by class
        self._plot_precision_recall_by_class(precision, recall, f1, support)
        
        return self.results['classification']
    
    def evaluate_robustness(self, 
                          adversarial_trainer: AdversarialTrainer,
                          num_batches: int = 20) -> Dict[str, Any]:
        """
        Evaluate model robustness against adversarial attacks.
        
        For each of the first ``num_batches`` test batches, the model is scored
        on the clean input and under FGSM and PGD (applied to the token
        embeddings) and the token-level text attack, all generated by
        ``adversarial_trainer``. Reported accuracies are means of the
        per-batch accuracies.
        
        Args:
            adversarial_trainer: Adversarial trainer for attack generation
            num_batches: Number of batches to evaluate
            
        Returns:
            Dictionary with robustness metrics; also stored in
            ``self.results['robustness']``. Each ``*_robustness`` value is the
            attacked accuracy divided by the clean accuracy (0 when the clean
            accuracy is 0).
        """
        print("üõ°Ô∏è Evaluating Robustness...")
        
        self.model.eval()
        
        clean_accuracies = []
        fgsm_accuracies = []
        pgd_accuracies = []
        text_attack_accuracies = []
        
        for batch_idx, batch in enumerate(self.test_loader):
            if batch_idx >= num_batches:
                break
            
            # Extract data
            input_ids = batch['token_ids'].to(self.device)
            threat_labels = batch['threat_label'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            
            # Clean accuracy
            with torch.no_grad():
                outputs = self.model(input_ids, attention_mask)
                clean_accuracies.append(
                    self._batch_accuracy(outputs['threat_logits'], threat_labels)
                )
            
            # FGSM attack. Gradient attacks need autograd even if the caller is
            # running under torch.no_grad().
            with torch.enable_grad():
                embeddings = self.model.token_embedding(input_ids)
                adv_embeddings_fgsm = adversarial_trainer.fgsm_attack(
                    embeddings, threat_labels, nn.CrossEntropyLoss()
                )
                with torch.no_grad():
                    outputs_fgsm = self.model._forward_from_embeddings(adv_embeddings_fgsm, attention_mask)
                    fgsm_accuracies.append(
                        self._batch_accuracy(outputs_fgsm['threat_logits'], threat_labels)
                    )
            
            # PGD attack (starts from the same clean embeddings)
            with torch.enable_grad():
                adv_embeddings_pgd = adversarial_trainer.pgd_attack(
                    embeddings, threat_labels, nn.CrossEntropyLoss()
                )
                with torch.no_grad():
                    outputs_pgd = self.model._forward_from_embeddings(adv_embeddings_pgd, attention_mask)
                    pgd_accuracies.append(
                        self._batch_accuracy(outputs_pgd['threat_logits'], threat_labels)
                    )
            
            # Text attack (token substitution on the input ids)
            with torch.no_grad():
                adv_ids_text = adversarial_trainer.text_adversarial_attack(
                    input_ids, train_dataset.tokenizer
                )
                outputs_text = self.model(adv_ids_text, attention_mask)
                text_attack_accuracies.append(
                    self._batch_accuracy(outputs_text['threat_logits'], threat_labels)
                )
        
        # Compute statistics
        clean_acc = np.mean(clean_accuracies)
        fgsm_acc = np.mean(fgsm_accuracies)
        pgd_acc = np.mean(pgd_accuracies)
        text_acc = np.mean(text_attack_accuracies)
        
        # Robustness gaps
        fgsm_gap = clean_acc - fgsm_acc
        pgd_gap = clean_acc - pgd_acc
        text_gap = clean_acc - text_acc
        
        # Robustness scores (higher is better)
        fgsm_robustness = fgsm_acc / clean_acc if clean_acc > 0 else 0
        pgd_robustness = pgd_acc / clean_acc if clean_acc > 0 else 0
        text_robustness = text_acc / clean_acc if clean_acc > 0 else 0
        
        # Store results
        self.results['robustness'] = {
            'clean_accuracy': clean_acc,
            'fgsm_accuracy': fgsm_acc,
            'pgd_accuracy': pgd_acc,
            'text_attack_accuracy': text_acc,
            'fgsm_gap': fgsm_gap,
            'pgd_gap': pgd_gap,
            'text_gap': text_gap,
            'fgsm_robustness': fgsm_robustness,
            'pgd_robustness': pgd_robustness,
            'text_robustness': text_robustness,
            'avg_robustness': (fgsm_robustness + pgd_robustness + text_robustness) / 3
        }
        
        # Print summary
        print("\nüõ°Ô∏è Robustness Results:")
        print(f"   Clean Accuracy: {clean_acc:.4f}")
        print(f"   FGSM Accuracy: {fgsm_acc:.4f} (Gap: {fgsm_gap:.4f}, Robustness: {fgsm_robustness:.3f})")
        print(f"   PGD Accuracy: {pgd_acc:.4f} (Gap: {pgd_gap:.4f}, Robustness: {pgd_robustness:.3f})")
        print(f"   Text Attack Accuracy: {text_acc:.4f} (Gap: {text_gap:.4f}, Robustness: {text_robustness:.3f})")
        print(f"   Average Robustness: {self.results['robustness']['avg_robustness']:.3f}")
        
        # Plot robustness comparison. attack_names is keyword-only in the
        # plotting helper; passing it positionally would land it in *accuracies.
        self._plot_robustness_comparison(
            clean_acc, fgsm_acc, pgd_acc, text_acc,
            attack_names=['Clean', 'FGSM', 'PGD', 'Text Attack']
        )
        
        return self.results['robustness']
    
    @staticmethod
    def _batch_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> float:
        """Return the fraction of rows whose argmax over the last dim equals ``labels``."""
        return (logits.argmax(dim=-1) == labels).float().mean().item()
    
    def evaluate_mhc_coordination(self, num_samples: int = 100) -> Dict[str, Any]:
        """
        Evaluate mHC coordination stability and effectiveness.
        
        Synthetic agent states are built from each batch's
        ``coordination_features``. Each state is broadcast to ``self.mhc.n_agents``
        agents with Gaussian noise (std 0.1) added and is paired with uniformly
        random confidences, then passed through the mHC module. The features
        are assumed to be (batch, dim); TODO confirm. Stability metrics come from
        ``self.mhc.get_stability_metrics``. Attention statistics use only the
        first sample of each batch.
        
        Batches are consumed until at least ``num_samples`` samples have been
        seen, so the last batch may overshoot.
        
        Args:
            num_samples: Minimum number of samples to evaluate
            
        Returns:
            Dictionary with mHC evaluation metrics; also stored in
            ``self.results['mhc_coordination']``. It includes the raw
            ``attention_entropy_samples`` and ``signal_norm_samples`` that
            ``_plot_mhc_analysis`` draws as a histogram and a trace.
        """
        print("üîÑ Evaluating mHC Coordination...")
        
        self.model.eval()
        self.mhc.eval()
        
        stability_metrics = {
            'identity_preservation': [],
            'attention_entropy': [],
            'signal_norm': [],
            'state_change': []
        }
        
        attention_entropies = []
        agent_contributions = []
        
        sample_count = 0
        
        with torch.no_grad():
            for batch in self.test_loader:
                if sample_count >= num_samples:
                    break
                
                # Get a batch
                input_ids = batch['token_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                
                # Get model outputs
                outputs = self.model(input_ids, attention_mask)
                
                # Create synthetic agent states: one copy of the features per agent
                batch_size = outputs['coordination_features'].shape[0]
                n_agents = self.mhc.n_agents
                
                agent_states = outputs['coordination_features'].unsqueeze(1)
                agent_states = agent_states.expand(-1, n_agents, -1)
                
                # Add noise so the agents are not identical
                noise = torch.randn_like(agent_states) * 0.1
                agent_states = agent_states + noise
                
                # Agent confidences
                agent_confidences = torch.rand(batch_size, n_agents, device=self.device)
                
                # Apply mHC
                coordinated_state, attention_matrix = self.mhc(agent_states, agent_confidences)
                
                # Get stability metrics
                metrics = self.mhc.get_stability_metrics(agent_states, coordinated_state)
                
                # Accumulate metrics. float() accepts both one-element tensors
                # and plain Python numbers.
                for key in stability_metrics:
                    stability_metrics[key].append(float(metrics[key]))
                
                # Analyze the attention matrix of the first sample only
                attn = attention_matrix[0].cpu().numpy()
                
                # Shannon entropy of each agent's attention row
                for i in range(n_agents):
                    entropy = -np.sum(attn[i] * np.log(attn[i] + 1e-8))
                    attention_entropies.append(float(entropy))
                
                # Agent contributions: average attention each agent receives
                contributions = attn.mean(axis=0)
                agent_contributions.extend(contributions)
                
                sample_count += batch_size
        
        # Compute statistics
        avg_metrics = {k: np.mean(v) for k, v in stability_metrics.items()}
        std_metrics = {k: np.std(v) for k, v in stability_metrics.items()}
        
        # Attention analysis
        avg_attention_entropy = np.mean(attention_entropies)
        std_attention_entropy = np.std(attention_entropies)
        
        # Agent contribution fairness (one row of n_agents values per batch)
        agent_contributions_array = np.array(agent_contributions).reshape(-1, self.mhc.n_agents)
        avg_contributions = agent_contributions_array.mean(axis=0)
        contribution_std = agent_contributions_array.std(axis=0)
        
        # Fairness metric: Gini coefficient of agent contributions
        sorted_contributions = np.sort(avg_contributions)
        n = len(sorted_contributions)
        gini_coefficient = np.sum((2 * np.arange(1, n+1) - n - 1) * sorted_contributions) / (n * np.sum(sorted_contributions))
        
        # Store results. The *_samples lists feed the distribution/trace panels
        # in _plot_mhc_analysis, which stay empty when these keys are missing.
        self.results['mhc_coordination'] = {
            'stability_metrics': avg_metrics,
            'stability_std': std_metrics,
            'attention_entropy': avg_attention_entropy,
            'attention_entropy_std': std_attention_entropy,
            'attention_entropy_samples': list(attention_entropies),
            'signal_norm_samples': list(stability_metrics['signal_norm']),
            'agent_contributions': avg_contributions.tolist(),
            'contribution_std': contribution_std.tolist(),
            'fairness_gini': gini_coefficient,
            'fairness_score': 1 - gini_coefficient  # Higher is better
        }
        
        # Print summary
        print("\nüîÑ mHC Coordination Results:")
        print(f"   Identity Preservation: {avg_metrics['identity_preservation']:.4f}")
        print(f"   Attention Entropy: {avg_attention_entropy:.4f} (ideal: {math.log(self.mhc.n_agents):.4f})")
        print(f"   Signal Norm: {avg_metrics['signal_norm']:.4f}")
        print(f"   State Change: {avg_metrics['state_change']:.4f}")
        print(f"   Fairness Score: {self.results['mhc_coordination']['fairness_score']:.4f}")
        print(f"   Gini Coefficient: {gini_coefficient:.4f} (0 = perfect equality)")
        
        # Plot mHC analysis
        self._plot_mhc_analysis(avg_metrics, avg_contributions, contribution_std)
        
        return self.results['mhc_coordination']
    
    def evaluate_operational_efficiency(self,
                                      num_samples: int = 100,
                                      sequence_lengths: "Optional[List[int]]" = None,
                                      num_timing_runs: int = 10) -> Dict[str, Any]:
        """
        Evaluate operational efficiency (latency, throughput, memory).
        
        Random token sequences are fed through the model for each sequence
        length. Each forward pass is timed ``num_timing_runs`` times. Memory is
        read from the CUDA allocator's peak statistics when evaluating on a CUDA
        device; otherwise it is a rough parameters-plus-activations estimate.
        
        Args:
            num_samples: Caps the synthetic batch size (``min(4, num_samples)``)
            sequence_lengths: Sequence lengths to test; ``None`` means
                ``[128, 256, 512, 1024]``. A fresh list is stored in the results,
                so the caller's list and the default are never shared.
            num_timing_runs: Number of timed forward passes per sequence length
            
        Returns:
            Dictionary with operational metrics; also stored in
            ``self.results['operational_efficiency']``.
        
        Raises:
            ValueError: If ``num_timing_runs`` is less than 1.
        """
        import time
        
        print("‚ö° Evaluating Operational Efficiency...")
        
        if num_timing_runs < 1:
            raise ValueError(f"num_timing_runs must be >= 1, got {num_timing_runs}")
        if sequence_lengths is None:
            sequence_lengths = [128, 256, 512, 1024]
        sequence_lengths = list(sequence_lengths)
        
        self.model.eval()
        
        # CUDA kernels run asynchronously and only CUDA has allocator stats.
        # Decide from the evaluation device, not from whether some GPU exists,
        # so a CPU model on a GPU host does not report CUDA memory.
        device = torch.device(self.device)
        on_cuda = device.type == 'cuda'
        
        def _synchronize() -> None:
            """Wait for queued device work so wall-clock timings are accurate."""
            if on_cuda:
                torch.cuda.synchronize(device)
        
        # Warmup, drawing token ids from the model's real vocabulary range
        dummy_input = torch.randint(0, self.model.vocab_size, (1, 128), device=self.device)
        with torch.no_grad():
            _ = self.model(dummy_input)
        _synchronize()
        
        # Test different sequence lengths
        inference_times = {}
        memory_usages = {}
        
        for seq_len in sequence_lengths:
            print(f"   Testing seq_len={seq_len}...")
            
            # Create test input
            batch_size = min(4, num_samples)  # Small batch for memory testing
            input_ids = torch.randint(0, self.model.vocab_size, (batch_size, seq_len), device=self.device)
            
            # Measure inference time; sync before starting the timer so pending
            # work from earlier calls is not billed to this run
            times = []
            for _ in range(num_timing_runs):
                _synchronize()
                start_time = time.perf_counter()
                with torch.no_grad():
                    _ = self.model(input_ids)
                _synchronize()
                times.append(time.perf_counter() - start_time)
            
            avg_time = np.mean(times)
            std_time = np.std(times)
            
            # Calculate tokens per second
            tokens_per_second = (batch_size * seq_len) / avg_time
            
            # Estimate memory usage
            if on_cuda:
                torch.cuda.reset_peak_memory_stats(device)
                with torch.no_grad():
                    _ = self.model(input_ids)
                _synchronize()
                memory_mb = torch.cuda.max_memory_allocated(device) / 1024 / 1024
                torch.cuda.reset_peak_memory_stats(device)
            else:
                # Rough estimate for CPU: float32 parameters + one float32
                # activation tensor of shape (batch, seq_len, d_model)
                param_memory = sum(p.numel() * 4 for p in self.model.parameters()) / 1024 / 1024  # MB
                activation_memory = (batch_size * seq_len * self.model.config.d_model * 4) / 1024 / 1024  # MB
                memory_mb = param_memory + activation_memory
            
            inference_times[seq_len] = {
                'avg_time_ms': avg_time * 1000,
                'std_time_ms': std_time * 1000,
                'tokens_per_second': tokens_per_second
            }
            
            memory_usages[seq_len] = memory_mb
        
        # Store results
        self.results['operational_efficiency'] = {
            'inference_times': inference_times,
            'memory_usages': memory_usages,
            'sequence_lengths': sequence_lengths
        }
        
        # Print summary
        print("\n‚ö° Operational Efficiency Results:")
        for seq_len in sequence_lengths:
            metrics = inference_times[seq_len]
            memory = memory_usages[seq_len]
            print(f"   Seq Len {seq_len}:")
            print(f"     Time: {metrics['avg_time_ms']:.2f} ¬± {metrics['std_time_ms']:.2f} ms")
            print(f"     Speed: {metrics['tokens_per_second']:.0f} tokens/sec")
            print(f"     Memory: {memory:.2f} MB")
        
        # Plot efficiency analysis
        self._plot_efficiency_analysis(inference_times, memory_usages)
        
        return self.results['operational_efficiency']
    
    def _plot_confusion_matrix_heatmap(self, cm: np.ndarray, title: str):
        """
        Plot a confusion matrix as a row-normalised heatmap with count annotations.
        
        Each cell shows the raw count and, in parentheses, the fraction of that
        true class's samples. A true class with no samples gets an all-zero row
        instead of a division-by-zero NaN. Tick labels use
        ``self.threat_categories`` when its length matches the matrix, otherwise
        plain indices.
        
        Args:
            cm: Square confusion matrix (rows = true, columns = predicted)
            title: Figure title
        """
        cm = np.asarray(cm)
        n_classes = cm.shape[0]
        if len(self.threat_categories) == n_classes:
            tick_labels = list(self.threat_categories)
        else:
            tick_labels = [str(i) for i in range(n_classes)]
        
        fig, ax = plt.subplots(figsize=(10, 8))
        
        # Normalize each row by its total, leaving empty rows at 0
        row_totals = cm.sum(axis=1, keepdims=True).astype(float)
        cm_normalized = np.divide(cm.astype(float), row_totals,
                                  out=np.zeros(cm.shape, dtype=float),
                                  where=row_totals != 0)
        
        im = ax.imshow(cm_normalized, cmap='Blues')
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Predicted', fontsize=12)
        ax.set_ylabel('True', fontsize=12)
        
        # Set ticks
        ax.set_xticks(range(n_classes))
        ax.set_yticks(range(n_classes))
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')
        ax.set_yticklabels(tick_labels)
        
        # Add text annotations; white text on dark cells for contrast
        for i in range(n_classes):
            for j in range(n_classes):
                ax.text(j, i, f'{cm[i, j]}\n({cm_normalized[i, j]:.2f})',
                        ha="center", va="center",
                        color="white" if cm_normalized[i, j] > 0.5 else "black",
                        fontsize=9)
        
        plt.colorbar(im, ax=ax)
        plt.tight_layout()
        plt.savefig(config.experiment_dir / "visualizations" / "confusion_matrix_evaluation.png", dpi=150)
        plt.show()
    
    def _plot_precision_recall_by_class(self, 
                                       precision: np.ndarray,
                                       recall: np.ndarray,
                                       f1: np.ndarray,
                                       support: np.ndarray):
        """
        Plot precision, recall and F1 per class, one panel per metric.
        
        Each panel holds a single series, so its bars are centred on the class
        ticks. The previous ±width offsets came from a grouped-bar layout and
        pushed the precision and F1 bars off their labels.
        
        Args:
            precision: Per-class precision, one value per class
            recall: Per-class recall, one value per class
            f1: Per-class F1 score, one value per class
            support: Per-class sample counts (accepted for API compatibility;
                not plotted)
        """
        n_classes = len(precision)
        x = np.arange(n_classes)
        width = 0.6
        # Category names only when they line up with the metric arrays
        if len(self.threat_categories) == n_classes:
            tick_labels = list(self.threat_categories)
        else:
            tick_labels = [str(i) for i in range(n_classes)]
        
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        
        panels = (
            (precision, 'Precision', 'skyblue', 'Precision by Class'),
            (recall, 'Recall', 'lightgreen', 'Recall by Class'),
            (f1, 'F1 Score', 'salmon', 'F1 Score by Class'),
        )
        for ax, (values, metric_name, color, title) in zip(axes, panels):
            ax.bar(x, values, width, label=metric_name, color=color)
            ax.set_xlabel('Threat Category', fontsize=11)
            ax.set_ylabel(metric_name, fontsize=11)
            ax.set_title(title, fontsize=12, fontweight='bold')
            ax.set_xticks(x)
            ax.set_xticklabels(tick_labels, rotation=45, ha='right')
            ax.legend()
        
        plt.tight_layout()
        plt.savefig(config.experiment_dir / "visualizations" / "precision_recall_f1_by_class.png", dpi=150)
        plt.show()
    
    def _plot_robustness_comparison(self, *accuracies, attack_names=None):
        """
        Plot a bar chart comparing accuracy on clean input and under each attack.
        
        Args:
            *accuracies: One accuracy value per bar. For positional callers, a
                trailing list/tuple is taken as ``attack_names`` when that
                keyword is not supplied.
            attack_names: Bar labels, one per accuracy value.
        
        Raises:
            TypeError: If no attack names are supplied in either form.
            ValueError: If the number of names and accuracies differ.
        """
        if attack_names is None:
            if accuracies and isinstance(accuracies[-1], (list, tuple)):
                *accuracies, attack_names = accuracies
            else:
                raise TypeError("_plot_robustness_comparison() requires attack_names")
        accuracies = list(accuracies)
        if len(accuracies) != len(attack_names):
            raise ValueError(
                f"got {len(accuracies)} accuracies but {len(attack_names)} attack names"
            )
        
        fig, ax = plt.subplots(figsize=(10, 6))
        
        x = np.arange(len(attack_names))
        bars = ax.bar(x, accuracies, color=['green', 'orange', 'red', 'purple'])
        
        ax.set_xlabel('Attack Type', fontsize=12)
        ax.set_ylabel('Accuracy', fontsize=12)
        ax.set_title('Robustness Comparison', fontsize=14, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels(attack_names, rotation=0)
        ax.set_ylim([0, 1])
        ax.grid(axis='y', alpha=0.3)
        
        # Add value labels on bars
        for bar, acc in zip(bars, accuracies):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                   f'{acc:.3f}', ha='center', va='bottom', fontweight='bold')
        
        plt.tight_layout()
        plt.savefig(config.experiment_dir / "visualizations" / "robustness_comparison.png", dpi=150)
        plt.show()
    
    def _plot_mhc_analysis(self, stability_metrics: Dict, 
                          agent_contributions: np.ndarray,
                          contribution_std: np.ndarray):
        """
        Plot a 2x2 summary of mHC coordination.
        
        Top row: bar charts of the averaged stability metrics and of per-agent
        contributions (with a uniform 1/n reference line). Bottom row: the
        attention-entropy histogram and the signal-norm trace. Their data is
        read from ``self.results['mhc_coordination']`` (keys
        ``attention_entropy_samples`` / ``signal_norm_samples``). A panel whose
        data is missing is hidden rather than left as empty axes.
        
        Args:
            stability_metrics: Mapping of metric name to averaged value
            agent_contributions: Mean attention received per agent
            contribution_std: Standard deviation of each agent's contribution
        """
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        mhc_results = self.results.get('mhc_coordination', {})
        
        # Stability metrics
        metric_names = list(stability_metrics.keys())
        metric_values = list(stability_metrics.values())
        
        bars1 = axes[0, 0].bar(metric_names, metric_values, color='lightblue')
        axes[0, 0].set_title('mHC Stability Metrics', fontsize=12, fontweight='bold')
        axes[0, 0].set_ylabel('Value', fontsize=11)
        axes[0, 0].tick_params(axis='x', rotation=45)
        
        # Add value labels
        for bar, val in zip(bars1, metric_values):
            height = bar.get_height()
            axes[0, 0].text(bar.get_x() + bar.get_width()/2., height,
                           f'{val:.3f}', ha='center', va='bottom')
        
        # Agent contributions
        n_agents = len(agent_contributions)
        x = np.arange(n_agents)
        
        axes[0, 1].bar(x, agent_contributions, yerr=contribution_std,
                      color='lightgreen', capsize=5)
        axes[0, 1].set_title('Agent Contributions', fontsize=12, fontweight='bold')
        axes[0, 1].set_xlabel('Agent Index', fontsize=11)
        axes[0, 1].set_ylabel('Average Attention', fontsize=11)
        axes[0, 1].axhline(y=1/n_agents, color='r', linestyle='--', label='Ideal')
        axes[0, 1].legend()
        
        # Attention entropy distribution; log(n_agents) is the entropy of a
        # uniform distribution over the agents
        if 'attention_entropy_samples' in mhc_results:
            entropy_samples = mhc_results['attention_entropy_samples']
            axes[1, 0].hist(entropy_samples, bins=20, color='salmon', alpha=0.7)
            axes[1, 0].axvline(x=math.log(n_agents), color='r', linestyle='--', 
                              label=f'Ideal: {math.log(n_agents):.2f}')
            axes[1, 0].set_title('Attention Entropy Distribution', fontsize=12, fontweight='bold')
            axes[1, 0].set_xlabel('Entropy', fontsize=11)
            axes[1, 0].set_ylabel('Frequency', fontsize=11)
            axes[1, 0].legend()
        else:
            axes[1, 0].axis('off')
        
        # Signal norm over time
        if 'signal_norm_samples' in mhc_results:
            signal_norms = mhc_results['signal_norm_samples']
            axes[1, 1].plot(signal_norms, color='purple', linewidth=2)
            axes[1, 1].axhline(y=1.0, color='r', linestyle='--', label='Bound')
            axes[1, 1].set_title('Signal Norm Over Samples', fontsize=12, fontweight='bold')
            axes[1, 1].set_xlabel('Sample Index', fontsize=11)
            axes[1, 1].set_ylabel('Signal Norm', fontsize=11)
            axes[1, 1].legend()
        else:
            axes[1, 1].axis('off')
        
        plt.tight_layout()
        plt.savefig(config.experiment_dir / "visualizations" / "mhc_analysis.png", dpi=150)
        plt.show()
    
    def _plot_efficiency_analysis(self, inference_times: Dict, memory_usages: Dict):
        """
        Draw three efficiency curves side by side against sequence length.

        Panels (left to right): average inference time, throughput and memory
        usage. Sequence lengths come from ``inference_times`` in its iteration
        order, and ``memory_usages`` is indexed with those same keys. The figure
        is saved to ``config.experiment_dir / "visualizations" /
        "operational_efficiency.png"`` and then shown.

        Args:
            inference_times: seq_len -> dict with 'avg_time_ms' and 'tokens_per_second'
            memory_usages: seq_len -> memory usage (plotted as MB, per the axis label)
        """
        _fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Gather every series up front, all keyed by the same sequence lengths
        lengths = [seq for seq in inference_times]
        latency_ms = [inference_times[seq]['avg_time_ms'] for seq in lengths]
        throughput = [inference_times[seq]['tokens_per_second'] for seq in lengths]
        mem_used = [memory_usages[seq] for seq in lengths]

        # (series, marker, colour, y-label, title), one entry per panel
        panels = (
            (latency_ms, 'o', 'blue', 'Inference Time (ms)', 'Inference Time vs Sequence Length'),
            (throughput, 's', 'green', 'Tokens/Second', 'Throughput vs Sequence Length'),
            (mem_used, '^', 'red', 'Memory Usage (MB)', 'Memory Usage vs Sequence Length'),
        )
        for ax, (series, marker, colour, y_label, title) in zip(axes, panels):
            ax.plot(lengths, series, marker=marker, linewidth=2, color=colour)
            ax.set_xlabel('Sequence Length', fontsize=11)
            ax.set_ylabel(y_label, fontsize=11)
            ax.set_title(title, fontsize=12, fontweight='bold')
            ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(config.experiment_dir / "visualizations" / "operational_efficiency.png", dpi=150)
        plt.show()
    
    def generate_comprehensive_report(self, adversarial_trainer=None) -> Dict[str, Any]:
        """
        Generate comprehensive evaluation report.

        Runs every evaluation whose results are not cached in ``self.results``
        yet, computes the weighted overall score, writes the report as JSON to
        ``config.experiment_dir / "evaluation_report.json"`` and prints a summary.

        Args:
            adversarial_trainer: Passed to ``evaluate_robustness`` when robustness
                has not been evaluated yet. If omitted, the module-level
                ``adversarial_trainer`` is used (the previous implicit behavior).

        Returns:
            Complete evaluation report

        Raises:
            NameError: robustness still has to be evaluated, but no trainer was
                passed and no module-level ``adversarial_trainer`` exists.
            TypeError: the results contain a value that cannot be converted to JSON.
        """
        print("\n" + "="*80)
        print("üìã GENERATING COMPREHENSIVE EVALUATION REPORT")
        print("="*80)

        # Run every evaluation that has not produced results yet
        if 'classification' not in self.results:
            self.evaluate_classification()

        if 'robustness' not in self.results:
            if adversarial_trainer is None:
                # Backward compatibility: this used to read the global implicitly.
                try:
                    adversarial_trainer = globals()['adversarial_trainer']
                except KeyError:
                    raise NameError(
                        "generate_comprehensive_report() needs an adversarial_trainer: "
                        "pass one explicitly or define `adversarial_trainer` globally"
                    ) from None
            self.evaluate_robustness(adversarial_trainer)

        if 'mhc_coordination' not in self.results:
            self.evaluate_mhc_coordination()

        if 'operational_efficiency' not in self.results:
            self.evaluate_operational_efficiency()

        # Calculate overall score
        overall_score = self._calculate_overall_score()

        # Create report
        report = {
            'metadata': {
                'evaluation_date': datetime.now().isoformat(),
                'model_name': 'CyberGuard Transformer',
                'evaluator_version': '1.0.0',
                'device': str(self.device)
            },
            'overall_score': overall_score,
            'detailed_results': self.results,
            'recommendations': self._generate_recommendations()
        }

        def _json_default(obj):
            """Convert numpy/torch/pathlib values that json cannot encode natively."""
            # A json `default` must return something encodable. Returning the
            # object unchanged makes json fail with "Circular reference detected".
            if isinstance(obj, np.generic):
                return obj.item()
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, torch.Tensor):
                return obj.detach().cpu().tolist()
            if isinstance(obj, (set, frozenset)):
                return list(obj)
            if isinstance(obj, Path):
                return str(obj)
            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

        # Serialize fully before opening the file, so a failure cannot leave a truncated report
        report_json = json.dumps(report, indent=2, default=_json_default)
        report_path = config.experiment_dir / "evaluation_report.json"
        with open(report_path, 'w') as f:
            f.write(report_json)

        print(f"\nüíæ Evaluation report saved to: {report_path}")

        # Print summary
        print("\n" + "="*80)
        print("üéØ EVALUATION SUMMARY")
        print("="*80)

        classification = self.results['classification']
        robustness = self.results['robustness']
        mhc = self.results['mhc_coordination']

        print(f"\nüìä Classification Performance:")
        print(f"   Accuracy: {classification['accuracy']:.4f}")
        print(f"   F1 Macro: {classification['f1_macro']:.4f}")
        print(f"   AUC-ROC: {classification['auc_roc']:.4f}")

        print(f"\nüõ°Ô∏è Robustness:")
        print(f"   Clean Accuracy: {robustness['clean_accuracy']:.4f}")
        print(f"   Avg Robustness: {robustness['avg_robustness']:.3f}")
        print(f"   PGD Gap: {robustness['pgd_gap']:.4f}")

        print(f"\nüîÑ mHC Coordination:")
        print(f"   Identity Preservation: {mhc['stability_metrics']['identity_preservation']:.4f}")
        print(f"   Fairness Score: {mhc['fairness_score']:.4f}")
        print(f"   Attention Entropy: {mhc['attention_entropy']:.4f}")

        print(f"\n‚ö° Operational Efficiency:")
        efficiency = self.results['operational_efficiency']
        # Longest sequence tested; max() does not rely on the list being sorted
        seq_len = max(efficiency['sequence_lengths'])
        speed = efficiency['inference_times'][seq_len]['tokens_per_second']
        memory = efficiency['memory_usages'][seq_len]
        print(f"   Speed (seq_len={seq_len}): {speed:.0f} tokens/sec")
        print(f"   Memory Usage: {memory:.2f} MB")

        print(f"\nüèÜ Overall Score: {overall_score['total']:.2f}/100")
        print(f"   Breakdown: Classification={overall_score['classification']:.1f}, "
              f"Robustness={overall_score['robustness']:.1f}, "
              f"Coordination={overall_score['coordination']:.1f}, "
              f"Efficiency={overall_score['efficiency']:.1f}")

        return report
    
    def _calculate_overall_score(self) -> Dict[str, float]:
        """Calculate overall evaluation score"""
        weights = {
            'classification': 0.35,
            'robustness': 0.30,
            'coordination': 0.20,
            'efficiency': 0.15
        }
        
        scores = {}
        
        # Classification score (based on accuracy and F1)
        classification = self.results['classification']
        scores['classification'] = (
            classification['accuracy'] * 0.5 +
            classification['f1_macro'] * 0.5
        ) * 100
        
        # Robustness score
        robustness = self.results['robustness']
        scores['robustness'] = robustness['avg_robustness'] * 100
        
        # Coordination score
        mhc = self.results['mhc_coordination']
        scores['coordination'] = (
            mhc['stability_metrics']['identity_preservation'] * 0.4 +
            mhc['fairness_score'] * 0.4 +
            min(1.0, mhc['attention_entropy'] / math.log(self.mhc.n_agents)) * 0.2
        ) * 100
        
        # Efficiency score (inverse of time and memory)
        efficiency = self.results['operational_efficiency']
        seq_len = efficiency['sequence_lengths'][-1]
        speed = efficiency['inference_times'][seq_len]['tokens_per_second']
        memory = efficiency['memory_usages'][seq_len]
        
        # Normalize speed (target: 1000 tokens/sec)
        speed_score = min(1.0, speed / 1000)
        
        # Normalize memory (target: 500 MB for longest sequence)
        memory_score = max(0, 1.0 - memory / 500)
        
        scores['efficiency'] = (speed_score * 0.6 + memory_score * 0.4) * 100
        
        # Total score
        scores['total'] = sum(scores[category] * weight 
                            for category, weight in weights.items())
        
        return scores
    
    def _generate_recommendations(self) -> List[str]:
        """Generate improvement recommendations based on evaluation"""
        recommendations = []
        
        classification = self.results['classification']
        robustness = self.results['robustness']
        mhc = self.results['mhc_coordination']
        
        # Classification recommendations
        if classification['f1_macro'] < 0.8:
            recommendations.append(
                "Improve classification performance by collecting more diverse training data"
            )
        
        if classification['severity_mae'] > 0.2:
            recommendations.append(
                "Improve severity regression by adding more precise severity labels"
            )
        
        # Robustness recommendations
        if robustness['avg_robustness'] < 0.7:
            recommendations.append(
                "Increase adversarial training to improve robustness against attacks"
            )
        
        if robustness['pgd_gap'] > 0.3:
            recommendations.append(
                "Implement defensive distillation or adversarial training with PGD"
            )
        
        # Coordination recommendations
        if mhc['fairness_score'] < 0.8:
            recommendations.append(
                "Adjust mHC parameters to ensure more balanced agent contributions"
            )
        
        if mhc['stability_metrics']['identity_preservation'] < 0.8:
            recommendations.append(
                "Increase identity preservation factor in mHC to maintain agent specialties"
            )
        
        # Add general recommendations
        recommendations.extend([
            "Regularly update threat intelligence feeds",
            "Implement continuous monitoring of false positives/negatives",
            "Conduct periodic red team exercises",
            "Maintain model versioning and A/B testing framework"
        ])
        
        return recommendations

# Build the evaluator over the held-out test split
evaluator = ComprehensiveEvaluator(
    model=cyberguard_model,
    mhc=mhc_module,
    test_loader=test_loader,
    device=device,
    threat_categories=train_dataset.threat_categories,
)

print("\n".join((
    "‚úÖ Comprehensive Evaluator Initialized",
    "üìä Will evaluate:",
    "   ‚Ä¢ Classification performance (accuracy, F1, AUC-ROC)",
    "   ‚Ä¢ Robustness against adversarial attacks",
    "   ‚Ä¢ mHC coordination stability and fairness",
    "   ‚Ä¢ Operational efficiency (speed, memory)",
    "   ‚Ä¢ Generate comprehensive report with recommendations",
)))

# Run any missing evaluations, score the model, write the JSON report and print a summary
comprehensive_report = evaluator.generate_comprehensive_report()

print("\n".join((
    "\n‚úÖ Comprehensive evaluation complete!",
    "üéØ Key findings saved to evaluation report",
    "üìù Recommendations generated for improvement",
    "üìà Visualizations saved to experiment directory",
)))
<jupyter_output>
<empty_output>
<jupyter_text>
## 10. Model Export & Deployment

**Explanation**: After training, we need to export the model for deployment. This includes model optimization, format conversion, and deployment pipeline setup.

### Deployment Pipeline:

1. **Model Optimization**: Quantization, pruning, distillation
2. **Format Conversion**: ONNX, TorchScript, TensorRT
3. **API Development**: REST API, gRPC, WebSocket
4. **Monitoring**: Performance, drift detection, security
<jupyter_code>
class ModelExporter:
    """
    Model exporter for CyberGuard deployment.
    
    Handles:
    1. Model optimization (quantization, pruning)
    2. Format conversion (ONNX, TorchScript)
    3. Deployment package creation
    4. Inference optimization
    """
    
    def __init__(self,
                 model: nn.Module,
                 mhc: EnhancedManifoldConstrainedHyperConnections,
                 config: TrainingConfig,
                 tokenizer: Dict,
                 device: torch.device):
        """
        Store the trained artifacts and prepare the export directory tree.

        Creates ``config.experiment_dir / "deployment"`` and its ``optimized``,
        ``onnx``, ``torchscript`` and ``api`` subdirectories. Existing
        directories are left as they are. The experiment directory itself must
        already exist, because ``mkdir`` is called without ``parents=True``.

        Args:
            model: Trained CyberGuard model
            mhc: Trained mHC module
            config: Training configuration (must provide ``experiment_dir``)
            tokenizer: Tokenizer dictionary
            device: Export device
        """
        self.model, self.mhc, self.config = model, mhc, config
        self.tokenizer, self.device = tokenizer, device

        deployment_root = config.experiment_dir / "deployment"
        self.export_dir = deployment_root
        deployment_root.mkdir(exist_ok=True)
        for artifact_kind in ("optimized", "onnx", "torchscript", "api"):
            (deployment_root / artifact_kind).mkdir(exist_ok=True)

        logger.info("üì¶ Initialized Model Exporter")
        logger.info(f"   Export directory: {self.export_dir}")
    
    def optimize_model(self, 
                      quantization: bool = True,
                      pruning: bool = False,
                      distillation: bool = False) -> nn.Module:
        """
        Optimize model for deployment.

        Every optimization is applied to a deep copy, so ``self.model`` (the
        trained model) is never modified. Both ``nn.Module.half()`` and
        ``prune.remove`` work in place, so this copy is needed.

        Steps, in order:
            1. Pruning: 20% global L1-unstructured pruning of all ``nn.Linear``
               weights (pruned weights are zeroed; dense storage size is unchanged).
               This runs before quantization because dynamically quantized linear
               layers are no longer ``nn.Linear`` and cannot be pruned.
            2. Quantization: dynamic INT8 quantization of ``nn.Linear`` layers on
               a CPU copy. The result is smoke-tested with a random integer batch
               of shape (1, 128) with values in [0, 100). If that fails, FP16
               conversion is tried instead. The two are never stacked: halving an
               INT8-dynamic model would also halve the remaining float layers
               (e.g. embeddings) that feed the quantized linear kernels, which are
               designed for FP32 activations.
            3. Knowledge distillation: not implemented (requires a teacher model).

        The optimized state dict, config, tokenizer and optimization info are
        saved to ``<export_dir>/optimized/cyberguard_optimized.pt``.

        Args:
            quantization: Apply quantization (INT8 dynamic, with FP16 fallback)
            pruning: Apply pruning (zeroes 20% of Linear weights)
            distillation: Apply knowledge distillation (currently skipped)

        Returns:
            Optimized model. This is a new module; when INT8 quantization
            succeeds, it lives on the CPU.
        """
        import copy
        import torch.nn.utils.prune as prune

        print("‚ö° Optimizing model for deployment...")

        # Measure the baseline before anything else, from the untouched trained model
        original_size = self._calculate_model_size(self.model)

        # Private working copy: self.model stays exactly as trained
        optimized_model = copy.deepcopy(self.model)

        # 1. Pruning (if requested), done before quantization (see docstring)
        if pruning:
            print("   Applying pruning...")
            parameters_to_prune = [
                (module, 'weight')
                for module in optimized_model.modules()
                if isinstance(module, torch.nn.Linear)
            ]
            if not parameters_to_prune:
                print("   ‚ö†Ô∏è Pruning skipped: no nn.Linear layers found")
            else:
                try:
                    # Prune 20% of weights globally by L1 magnitude
                    prune.global_unstructured(
                        parameters_to_prune,
                        pruning_method=prune.L1Unstructured,
                        amount=0.2
                    )

                    # Make pruning permanent (drop the weight_orig/weight_mask reparameterization)
                    for module, param_name in parameters_to_prune:
                        prune.remove(module, param_name)

                    print("   ‚úÖ Pruning successful (20% of weights removed)")
                except Exception as e:
                    print(f"   ‚ö†Ô∏è Pruning failed: {e}")

        # 2. Quantization
        if quantization:
            print("   Applying quantization...")
            int8_applied = False

            if hasattr(torch.quantization, 'quantize_dynamic'):
                try:
                    # Dynamic quantization runs on CPU. Quantize a separate CPU copy so
                    # a failure part-way through leaves optimized_model intact.
                    cpu_copy = copy.deepcopy(optimized_model).cpu().eval()
                    quantized_model = torch.quantization.quantize_dynamic(
                        cpu_copy,
                        {torch.nn.Linear},
                        dtype=torch.qint8,
                        inplace=True
                    )

                    # Smoke-test the quantized model
                    test_input = torch.randint(0, 100, (1, 128), device='cpu')
                    with torch.no_grad():
                        _ = quantized_model(test_input)

                    optimized_model = quantized_model
                    int8_applied = True
                    print("   ‚úÖ Dynamic quantization successful")

                except Exception as e:
                    print(f"   ‚ö†Ô∏è Dynamic quantization failed: {e}")

            if int8_applied:
                print("   Skipping FP16 conversion (INT8 dynamic quantization already applied)")
            else:
                # Fallback: half precision on the working copy
                try:
                    optimized_model = optimized_model.half()
                    print("   ‚úÖ Mixed precision (FP16) conversion successful")
                except Exception as e:
                    print(f"   ‚ö†Ô∏è Mixed precision conversion failed: {e}")

        # 3. Knowledge distillation (if requested)
        if distillation:
            print("   Applying knowledge distillation...")
            # This would require a teacher model
            print("   ‚ö†Ô∏è Knowledge distillation requires teacher model (skipping)")

        # Calculate optimized model size
        optimized_size = self._calculate_model_size(optimized_model)

        print(f"\nüìä Optimization Results:")
        print(f"   Original size: {original_size:.2f} MB")
        print(f"   Optimized size: {optimized_size:.2f} MB")
        if optimized_size > 0:
            print(f"   Compression ratio: {original_size/optimized_size:.2f}x")
        else:
            # Guard against a 0 MB measurement instead of raising ZeroDivisionError
            print("   Compression ratio: n/a (optimized size measured as 0 MB)")

        # Save optimized model
        optimized_path = self.export_dir / "optimized" / "cyberguard_optimized.pt"
        torch.save({
            'model_state_dict': optimized_model.state_dict(),
            'config': self.config.__dict__,
            'tokenizer': self.tokenizer,
            'optimization_info': {
                'quantization': quantization,
                'pruning': pruning,
                'distillation': distillation,
                'original_size_mb': original_size,
                'optimized_size_mb': optimized_size
            }
        }, optimized_path)

        print(f"   üíæ Optimized model saved to: {optimized_path}")

        return optimized_model
    
    def export_to_onnx(self, 
                      optimized_model: nn.Module,
                      sample_input: torch.Tensor,
                      opset_version: int = 14) -> Optional[Path]:
        """
        Export model to ONNX format.

        The export itself decides success. Structural verification with the
        optional ``onnx`` package runs when that package is installed. When it
        is missing, verification is skipped with a warning instead of
        discarding a model that was written successfully.

        Args:
            optimized_model: Optimized model. The declared output names assume
                its forward returns three outputs (threat logits, severity score,
                coordination features).
            sample_input: Sample input for tracing; its dims 0 and 1 are exported
                as dynamic batch/sequence axes.
            opset_version: ONNX opset version

        Returns:
            Path to exported ONNX model, or None if export or verification failed
        """
        print("\nüì§ Exporting to ONNX format...")

        onnx_path = self.export_dir / "onnx" / "cyberguard.onnx"

        try:
            # Set model to evaluation mode
            optimized_model.eval()

            # Export to ONNX (str path: accepted by every exporter version)
            torch.onnx.export(
                optimized_model,
                sample_input,
                str(onnx_path),
                export_params=True,
                opset_version=opset_version,
                do_constant_folding=True,
                input_names=['input_ids'],
                output_names=['threat_logits', 'severity_score', 'coordination_features'],
                dynamic_axes={
                    'input_ids': {0: 'batch_size', 1: 'sequence_length'},
                    'threat_logits': {0: 'batch_size'},
                    'severity_score': {0: 'batch_size'},
                    'coordination_features': {0: 'batch_size'}
                },
                verbose=False
            )
        except Exception as e:
            print(f"   ‚ùå ONNX export failed: {e}")
            return None

        print(f"   ‚úÖ ONNX export successful: {onnx_path}")

        # Verification needs the optional `onnx` package; its absence is not an export failure
        try:
            import onnx
        except ImportError:
            print("   ‚ö†Ô∏è onnx package not installed, skipping model verification")
        else:
            try:
                onnx_model = onnx.load(str(onnx_path))
                onnx.checker.check_model(onnx_model)

                # Print model info
                print(f"   üìä ONNX Model Info:")
                print(f"     - Input: {onnx_model.graph.input[0].name}")
                print(f"     - Outputs: {[output.name for output in onnx_model.graph.output]}")
                print(f"     - Opset: {onnx_model.opset_import[0].version}")
            except Exception as e:
                print(f"   ‚ùå ONNX model verification failed: {e}")
                return None

        # Test inference with ONNX Runtime (the helper reports its own errors)
        self._test_onnx_inference(str(onnx_path), sample_input)

        return onnx_path
    
    def export_to_torchscript(self,
                             optimized_model: nn.Module,
                             sample_input: torch.Tensor) -> Optional[Path]:
        """
        Export model to TorchScript format by tracing.

        Tracing uses ``strict=False`` so models that return a dict or list of
        outputs can be traced. With the default ``strict=True``, tracing such a
        model raises, even though the TorchScript inference test explicitly
        handles dict outputs.

        Args:
            optimized_model: Optimized model
            sample_input: Sample input for tracing

        Returns:
            Path to exported TorchScript model, or None if export failed
        """
        print("\nüì§ Exporting to TorchScript format...")

        torchscript_path = self.export_dir / "torchscript" / "cyberguard.pt"

        try:
            # Set model to evaluation mode
            optimized_model.eval()

            # Trace the model (strict=False permits dict/list outputs)
            with torch.no_grad():
                traced_model = torch.jit.trace(optimized_model, sample_input, strict=False)

            # Save traced model
            traced_model.save(str(torchscript_path))

            print(f"   ‚úÖ TorchScript export successful: {torchscript_path}")

            # Test inference with TorchScript
            self._test_torchscript_inference(str(torchscript_path), sample_input)

        except Exception as e:
            print(f"   ‚ùå TorchScript export failed: {e}")
            return None

        return torchscript_path
    
    def create_deployment_package(self,
                                onnx_path: Optional[str] = None,
                                torchscript_path: Optional[str] = None) -> Path:
        """
        Create complete deployment package.

        Copies the available model artifacts, the tokenizer and the
        configuration into ``<export_dir>/deployment_package``. It then
        generates the serving files (API server, Dockerfile, requirements,
        deployment scripts, README, test scripts), prints the resulting tree and
        archives it as ``<export_dir>/cyberguard_deployment.tar.gz``.

        The configuration is written with ``yaml.safe_dump`` after conversion to
        plain YAML types. The generated API server reads it back with
        ``yaml.safe_load``, which rejects the ``!!python/object`` tags that
        ``yaml.dump`` emits for values such as ``pathlib.Path``
        (``config.experiment_dir``).

        Args:
            onnx_path: Path to ONNX model (optional; skipped if missing)
            torchscript_path: Path to TorchScript model (optional; skipped if missing)

        Returns:
            Path of the package directory
        """
        import numbers
        import tarfile

        def _to_yaml_safe(value):
            """Recursively convert value into types yaml.safe_dump/safe_load support."""
            if value is None or type(value) in (bool, str):
                return value
            if isinstance(value, numbers.Integral) and not isinstance(value, bool):
                return int(value)
            if isinstance(value, numbers.Real):
                return float(value)
            if isinstance(value, dict):
                converted = {}
                for key, item in value.items():
                    safe_key = _to_yaml_safe(key)
                    if isinstance(safe_key, list):  # lists are unhashable as mapping keys
                        safe_key = str(key)
                    converted[safe_key] = _to_yaml_safe(item)
                return converted
            if isinstance(value, (list, tuple, set, frozenset)):
                return [_to_yaml_safe(item) for item in value]
            # Paths, devices, enums, etc. are stored by their string form
            return str(value)

        print("\nüì¶ Creating deployment package...")

        package_dir = self.export_dir / "deployment_package"
        package_dir.mkdir(exist_ok=True)

        # Copy models
        models_dir = package_dir / "models"
        models_dir.mkdir(exist_ok=True)

        if onnx_path and Path(onnx_path).exists():
            shutil.copy(onnx_path, models_dir / "cyberguard.onnx")

        if torchscript_path and Path(torchscript_path).exists():
            shutil.copy(torchscript_path, models_dir / "cyberguard.pt")

        # Copy optimized model
        optimized_path = self.export_dir / "optimized" / "cyberguard_optimized.pt"
        if optimized_path.exists():
            shutil.copy(optimized_path, models_dir / "cyberguard_optimized.pt")

        # Save tokenizer
        tokenizer_path = models_dir / "tokenizer.json"
        with open(tokenizer_path, 'w') as f:
            json.dump(self.tokenizer, f)

        # Save configuration as plain YAML (loadable with yaml.safe_load)
        config_path = package_dir / "config.yaml"
        with open(config_path, 'w') as f:
            yaml.safe_dump(_to_yaml_safe(self.config.__dict__), f, default_flow_style=False)

        # Create API server
        self._create_api_server(package_dir)

        # Create Dockerfile
        self._create_dockerfile(package_dir)

        # Create requirements.txt
        self._create_requirements(package_dir)

        # Create deployment scripts
        self._create_deployment_scripts(package_dir)

        # Create README
        self._create_readme(package_dir)

        # Create test scripts
        self._create_test_scripts(package_dir)

        print(f"\nüéâ Deployment package created: {package_dir}")
        print("\nüìÅ Package structure:")
        for root, dirs, files in os.walk(package_dir):
            level = root.replace(str(package_dir), '').count(os.sep)
            indent = ' ' * 2 * level
            print(f'{indent}{os.path.basename(root)}/')
            subindent = ' ' * 2 * (level + 1)
            for file in files:
                print(f'{subindent}{file}')

        # Create archive
        archive_path = self.export_dir / "cyberguard_deployment.tar.gz"
        with tarfile.open(archive_path, "w:gz") as tar:
            tar.add(package_dir, arcname="cyberguard_deployment")

        print(f"\nüì¶ Archive created: {archive_path}")
        print(f"   Size: {archive_path.stat().st_size / 1024 / 1024:.2f} MB")

        return package_dir
    
    def _calculate_model_size(self, model: nn.Module) -> float:
        """Calculate model size in MB"""
        param_size = 0
        for param in model.parameters():
            param_size += param.nelement() * param.element_size()
        
        buffer_size = 0
        for buffer in model.buffers():
            buffer_size += buffer.nelement() * buffer.element_size()
        
        size_mb = (param_size + buffer_size) / 1024 / 1024
        return size_mb
    
    def _test_onnx_inference(self, onnx_path: str, sample_input: torch.Tensor):
        """Test ONNX model inference"""
        try:
            import onnxruntime as ort
            
            # Create ONNX Runtime session
            ort_session = ort.InferenceSession(onnx_path)
            
            # Prepare input
            input_name = ort_session.get_inputs()[0].name
            input_data = {input_name: sample_input.cpu().numpy()}
            
            # Run inference
            outputs = ort_session.run(None, input_data)
            
            print(f"   ‚úÖ ONNX inference test successful")
            print(f"   üìä Output shapes: {[o.shape for o in outputs]}")
            
        except ImportError:
            print("   ‚ö†Ô∏è ONNX Runtime not installed, skipping inference test")
        except Exception as e:
            print(f"   ‚ö†Ô∏è ONNX inference test failed: {e}")
    
    def _test_torchscript_inference(self, torchscript_path: str, sample_input: torch.Tensor):
        """Test TorchScript model inference"""
        try:
            # Load traced model
            traced_model = torch.jit.load(torchscript_path)
            traced_model.eval()
            
            # Run inference
            with torch.no_grad():
                outputs = traced_model(sample_input)
            
            print(f"   ‚úÖ TorchScript inference test successful")
            
            if isinstance(outputs, dict):
                print(f"   üìä Output keys: {list(outputs.keys())}")
            elif isinstance(outputs, torch.Tensor):
                print(f"   üìä Output shape: {outputs.shape}")
            
        except Exception as e:
            print(f"   ‚ö†Ô∏è TorchScript inference test failed: {e}")
    
    def _create_api_server(self, package_dir: Path):
        """Create FastAPI server for deployment"""
        api_code = '''"""
CyberGuard REST API Server
"""
import json
import yaml
from pathlib import Path
from typing import Dict, List, Optional

import torch
import numpy as np
from fastapi import FastAPI, HTTPException, Depends, Security
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
import uvicorn

# Security
security = HTTPBearer()

class CyberGuardAPI:
    """CyberGuard REST API Server"""
    
    def __init__(self, model_path: Path, config_path: Path, tokenizer_path: Path):
        self.model_path = model_path
        self.config_path = config_path
        self.tokenizer_path = tokenizer_path
        
        # Load configuration
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)
        
        # Load tokenizer
        with open(tokenizer_path, 'r') as f:
            self.tokenizer = json.load(f)
        
        # Load model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self._load_model()
        
        # Initialize API key (in production, use proper secrets management)
        self.api_keys = ["cyberguard-secret-key-2024"]
        
        print(f"üöÄ CyberGuard API initialized on {self.device}")
    
    def _load_model(self):
        """Load the appropriate model based on available formats"""
        model_file = None
        
        # Check for TorchScript
        ts_path = self.model_path.parent / "cyberguard.pt"
        if ts_path.exists():
            model = torch.jit.load(ts_path, map_location=self.device)
            print(f"üì¶ Loaded TorchScript model from {ts_path}")
            return model
        
        # Check for ONNX (would require ONNX Runtime)
        # Check for PyTorch model
        pt_path = self.model_path.parent / "cyberguard_optimized.pt"
        if pt_path.exists():
            checkpoint = torch.load(pt_path, map_location=self.device)
            # Reconstruct model from state dict
            # This would require the model class definition
            print(f"üì¶ Loaded PyTorch model from {pt_path}")
            # Implementation depends on model architecture
            raise NotImplementedError("Model loading needs architecture definition")
        
        raise FileNotFoundError("No model files found")
    
    def _tokenize(self, text: str, max_length: int = 512) -> torch.Tensor:
        """Tokenize text"""
        # Simple tokenization (in production, use proper tokenizer)
        tokens = []
        for word in text.split()[:max_length]:
            token_id = self.tokenizer.get(word, self.tokenizer.get('[UNK]', 1))
            tokens.append(token_id)
        
        # Pad if necessary
        if len(tokens) < max_length:
            tokens += [self.tokenizer.get('[PAD]', 0)] * (max_length - len(tokens))
        else:
            tokens = tokens[:max_length]
        
        return torch.tensor([tokens], dtype=torch.long, device=self.device)
    
    def analyze(self, text: str) -> Dict:
        """Analyze text for security threats"""
        # Tokenize
        input_ids = self._tokenize(text)
        
        # Inference
        with torch.no_grad():
            outputs = self.model(input_ids)
        
        # Process outputs
        if isinstance(outputs, dict):
            threat_logits = outputs.get('threat_logits', None)
            severity_score = outputs.get('severity_score', None)
            features = outputs.get('coordination_features', None)
        elif isinstance(outputs, (list, tuple)):
            threat_logits = outputs[0] if len(outputs) > 0 else None
            severity_score = outputs[1] if len(outputs) > 1 else None
            features = outputs[2] if len(outputs) > 2 else None
        else:
            threat_logits = outputs
            severity_score = None
            features = None
        
        # Convert to Python types
        result = {
            'text': text,
            'analysis': {}
        }
        
        if threat_logits is not None:
            probs = torch.softmax(threat_logits, dim=-1)
            threat_idx = torch.argmax(probs).item()
            confidence = probs[0, threat_idx].item()
            
            # Threat categories (should be loaded from config)
            threat_categories = [
                'benign', 'injection', 'xss', 'broken_auth', 
                'sensitive_data', 'xxe', 'broken_access',
                'security_misconfig', 'insecure_deserial',
                'vulnerable_components', 'insufficient_logging'
            ]
            
            result['analysis']['threat_detection'] = {
                'category': threat_categories[threat_idx] if threat_idx < len(threat_categories) else 'unknown',
                'confidence': confidence,
                'is_malicious': threat_idx != 0,  # Index 0 is 'benign'
                'all_probabilities': probs[0].cpu().numpy().tolist()
            }
        
        if severity_score is not None:
            result['analysis']['severity'] = {
                'score': severity_score[0].item() if isinstance(severity_score, torch.Tensor) else severity_score,
                'level': self._get_severity_level(severity_score)
            }
        
        return result
    
    def _get_severity_level(self, score: float) -> str:
        """Convert severity score to level"""
        if score < 0.3:
            return 'LOW'
        elif score < 0.6:
            return 'MEDIUM'
        elif score < 0.8:
            return 'HIGH'
        else:
            return 'CRITICAL'
    
    def verify_api_key(self, credentials: HTTPAuthorizationCredentials = Depends(security)):
        """Verify API key"""
        if credentials.credentials not in self.api_keys:
            raise HTTPException(
                status_code=401,
                detail="Invalid API key",
                headers={"WWW-Authenticate": "Bearer"},
            )
        return credentials.credentials

# Pydantic models
# Request/response schemas for the HTTP endpoints below.
class AnalysisRequest(BaseModel):
    # Text submitted by the client; required (``...`` means no default).
    text: str = Field(..., description="Text to analyze for security threats")
    # Defaults to False.
    # NOTE(review): the /analyze handler in this file only reads ``text``;
    # confirm how ``detailed`` is meant to be used.
    detailed: bool = Field(False, description="Return detailed analysis")

class AnalysisResponse(BaseModel):
    # Whether the analysis completed.
    success: bool
    # Result payload from the model wrapper; its inner structure is not
    # validated by this schema (plain ``Dict``).
    analysis: Dict
    # Response time as a string (the /analyze handler fills it with isoformat()).
    timestamp: str
    # Defaults to "1.0.0" when not supplied.
    model_version: str = "1.0.0"

# Create FastAPI app
# Interactive API docs are served at /docs and /redoc.
app = FastAPI(
    title="CyberGuard Security Analysis API",
    description="AI-powered web security threat detection",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Initialize API (would be done differently in production)
# Module-global CyberGuardAPI instance; remains None until the startup
# handler assigns it. Handlers must check for None before using it.
api = None

@app.on_event("startup")
async def startup_event():
    """Build the global ``CyberGuardAPI`` instance when the server starts.

    File locations default to the original hard-coded paths and can be
    overridden with the environment variables CYBERGUARD_MODEL_PATH,
    CYBERGUARD_CONFIG_PATH and CYBERGUARD_TOKENIZER_PATH.

    Raises:
        Exception: whatever ``CyberGuardAPI`` raised is printed and then
            re-raised, so the failure is not silently swallowed.
    """
    global api
    import os

    model_path = Path(os.environ.get("CYBERGUARD_MODEL_PATH", "models/cyberguard.pt"))
    config_path = Path(os.environ.get("CYBERGUARD_CONFIG_PATH", "config.yaml"))
    tokenizer_path = Path(os.environ.get("CYBERGUARD_TOKENIZER_PATH", "models/tokenizer.json"))

    try:
        api = CyberGuardAPI(model_path, config_path, tokenizer_path)
    except Exception as e:
        # The emoji was previously mis-encoded (UTF-8 bytes decoded as Mac Roman).
        print(f"❌ Failed to initialize API: {e}")
        raise

@app.get("/health", tags=["health"])
async def health_check():
    """Liveness probe.

    Always reports "healthy" together with the current local time as an ISO
    string; it does not check whether the model has been loaded.
    """
    checked_at = datetime.now().isoformat()
    payload = {"status": "healthy"}
    payload["timestamp"] = checked_at
    return payload

@app.post("/analyze", response_model=AnalysisResponse, tags=["analysis"])
async def analyze_security(
    request: AnalysisRequest,
    credentials: HTTPAuthorizationCredentials = Depends(security),
):
    """Analyze ``request.text`` for security threats.

    Requires a bearer API key. The key is verified on every request through
    ``api.verify_api_key``. Do not use a default such as
    ``Depends(api.verify_api_key) if api else None``: it is evaluated once at
    import time, while the global ``api`` is still None, and that leaves the
    endpoint unauthenticated.

    Raises:
        HTTPException: 503 if the model is not initialized; 401 for an invalid
            key; 500 if analysis fails. The 500 carries a generic message;
            details are logged server-side.
    """
    if api is None:
        raise HTTPException(status_code=503, detail="API not initialized")
    api.verify_api_key(credentials)

    # NOTE(review): ``request.detailed`` is accepted but not used yet.
    # NOTE(review): ``api.analyze`` is synchronous, so it blocks the event
    # loop while inference runs.
    try:
        result = api.analyze(request.text)
    except HTTPException:
        # Keep deliberate HTTP errors instead of masking them as 500s.
        raise
    except Exception as e:
        import logging

        logging.getLogger(__name__).exception("Security analysis failed")
        # Do not echo internal exception text (paths, shapes, config) to clients.
        raise HTTPException(status_code=500, detail="Analysis failed") from e

    return AnalysisResponse(
        success=True,
        analysis=result,
        timestamp=datetime.now().isoformat(),
        model_version="1.0.0"
    )

@app.get("/stats", tags=["monitoring"])
async def get_stats(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return API statistics. The values are placeholders for now.

    Requires a bearer API key, verified per request through
    ``api.verify_api_key``. Resolving the dependency at import time would see
    ``api`` as None and disable authentication.

    Raises:
        HTTPException: 503 if the API is not initialized (the key cannot be
            verified); 401 for an invalid key.
    """
    if api is None:
        raise HTTPException(status_code=503, detail="API not initialized")
    api.verify_api_key(credentials)

    # TODO: replace these placeholders with real counters.
    return {
        "requests_processed": 0,
        "threats_detected": 0,
        "avg_response_time": 0.0,
        "model_version": "1.0.0"
    }

if __name__ == "__main__":
    # Run the server directly when this file is executed as a script.
    # Host and port default to the original values and can be overridden via
    # CYBERGUARD_HOST / CYBERGUARD_PORT. The default 0.0.0.0 listens on all
    # network interfaces.
    import os

    uvicorn.run(
        app,
        host=os.environ.get("CYBERGUARD_HOST", "0.0.0.0"),
        port=int(os.environ.get("CYBERGUARD_PORT", "8000")),
    )
'''
