In [None]:
"""
CyberGuard REST API Server for production deployment.
This server provides a secure, scalable API for security threat analysis.
"""
# Documentation string: Describes the purpose of this module.
# This API serves as the deployment interface for the trained CyberGuard model.

# ============================================================================
# SECTION 1: IMPORTS AND DEPENDENCIES
# ============================================================================

# Standard library imports (Python built-in modules), plus PyYAML (third-party) for config parsing
import json  # For parsing and generating JSON data
import yaml  # For parsing YAML configuration files
import logging  # For logging application events and errors
from pathlib import Path  # For object-oriented filesystem path manipulation
from typing import Dict, List, Optional, Any  # Type hints for better code documentation
from datetime import datetime  # For handling timestamps and dates
from functools import lru_cache  # For caching function results to improve performance

# Core machine learning dependencies
import torch  # PyTorch: Main deep learning framework for model inference
import numpy as np  # NumPy: Numerical computing library for array operations

# FastAPI framework imports for building the web server
from fastapi import FastAPI, HTTPException, Depends, Security, Request
# FastAPI: Modern web framework for building APIs
# HTTPException: For raising HTTP errors with specific status codes
# Depends: Dependency injection system for request handling
# Security: For implementing security mechanisms
# Request: For accessing request metadata

from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
# HTTPBearer: Security scheme for Bearer token authentication
# HTTPAuthorizationCredentials: Type for storing authentication credentials

from fastapi.middleware.cors import CORSMiddleware
# CORSMiddleware: Middleware for handling Cross-Origin Resource Sharing

from fastapi.middleware.trustedhost import TrustedHostMiddleware
# TrustedHostMiddleware: Security middleware to validate host headers

from pydantic import BaseModel, Field, validator
# Pydantic: Data validation and settings management
# BaseModel: Base class for creating data models with validation
# Field: For defining field metadata and constraints
# validator: Decorator for custom validation functions

from slowapi import Limiter, _rate_limit_exceeded_handler
# SlowAPI: Rate limiting extension for FastAPI
# Limiter: Main rate limiting class
# _rate_limit_exceeded_handler: Default handler for rate limit violations

from slowapi.util import get_remote_address
# get_remote_address: Function to extract client IP address for rate limiting

from slowapi.errors import RateLimitExceeded
# RateLimitExceeded: Exception raised when rate limit is exceeded

# Conditional imports for optional dependencies with fallback handling
try:
    import onnxruntime as ort  # ONNX Runtime for model inference optimization
    ONNX_AVAILABLE = True  # Flag indicating ONNX Runtime is available
except ImportError:
    ONNX_AVAILABLE = False  # ONNX Runtime not available
    ort = None  # Set to None to avoid NameError

try:
    import tensorrt  # NVIDIA TensorRT for GPU acceleration (optional)
    TENSORRT_AVAILABLE = True  # Flag indicating TensorRT is available
except ImportError:
    TENSORRT_AVAILABLE = False  # TensorRT not available
    tensorrt = None  # Set to None to avoid NameError

# ============================================================================
# SECTION 2: LOGGING CONFIGURATION
# ============================================================================

# Configure the root logger once, at import time. Records at INFO and above are
# sent both to the file 'cyberguard_api.log' (relative to the process's working
# directory) and to a console StreamHandler, which writes to stderr by default.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('cyberguard_api.log'),
        logging.StreamHandler(),
    ],
    level=logging.INFO,
)

# Module-scoped logger: named '__main__' when run as a script, otherwise the
# dotted module path; it propagates to the root handlers configured above.
logger = logging.getLogger(__name__)

# ============================================================================
# SECTION 3: SECURITY SETUP
# ============================================================================

# Bearer-token scheme. With auto_error=False a missing or malformed
# Authorization header yields None instead of raising automatically, so each
# endpoint decides how to reject unauthenticated requests.
security = HTTPBearer(auto_error=False)

# Rate limiter keyed on the client's remote address.
# NOTE(review): behind a reverse proxy this is likely the proxy's address
# unless forwarded headers are handled upstream -- confirm deployment setup.
limiter = Limiter(key_func=get_remote_address)

# ============================================================================
# SECTION 4: DATA MODELS (Pydantic Schemas)
# ============================================================================

class AnalysisRequest(BaseModel):
    """
    Request body accepted by the threat-analysis endpoint.

    ``text`` must be between 1 and 10000 characters. Those limits are checked
    on the raw input; the validator below then rejects whitespace-only text
    and hands back the value with surrounding whitespace removed.
    """

    # The payload to inspect (HTTP request, log line, ...); required.
    text: str = Field(
        default=...,
        min_length=1,
        max_length=10000,
        description="Text to analyze for security threats. This can be HTTP requests, logs, or any security-relevant text.",
    )

    # Opt-in flag for the verbose response that includes per-category probabilities.
    detailed: bool = Field(
        default=False,
        description="If True, returns detailed analysis including probabilities for all threat categories.",
    )

    @validator('text')
    def validate_text_not_empty(cls, v):
        """Reject blank or whitespace-only text; otherwise return it stripped."""
        cleaned = v.strip() if v else ''
        if not cleaned:
            raise ValueError('Text cannot be empty or whitespace only')
        return cleaned


class AnalysisResponse(BaseModel):
    """
    Envelope returned by the analysis endpoint.

    Wraps the analysis dictionary together with metadata recording whether the
    call succeeded, when the analysis ran and which model version produced it.
    """

    success: bool = Field(default=..., description="Whether the analysis was successful")
    # Free-form result payload; keys are strings, values of any type.
    analysis: Dict[str, Any] = Field(default=..., description="Analysis results")
    timestamp: str = Field(default=..., description="ISO format timestamp of analysis")
    # NOTE(review): under Pydantic v2 a field starting with "model_" triggers a
    # protected-namespace warning; confirm which Pydantic major version is used.
    model_version: str = Field(default=..., description="Version of the model used")
    # Optional correlation id; None when the caller does not supply one.
    request_id: Optional[str] = Field(default=None, description="Unique ID for tracking this request")


class ThreatCategory(BaseModel):
    """
    One configured threat category and the metadata attached to it.

    All four fields are required. No range is enforced on ``severity_weight``.
    """

    id: int = Field(default=..., description="Category ID")
    name: str = Field(default=..., description="Category name")
    description: str = Field(default=..., description="Category description")
    # Presumably meant to lie in 0.0-1.0, but nothing here validates that.
    severity_weight: float = Field(default=..., description="Weight for severity calculation")

# ============================================================================
# SECTION 5: MAIN API HANDLER CLASS
# ============================================================================

class CyberGuardAPI:
    """
    Main API handler class that manages model loading, inference, and request processing.
    This class follows the singleton pattern for efficient resource management.
    Singleton pattern ensures only one instance exists, preventing redundant model loading.
    """
    
    _instance = None  # Class variable to store the single instance (singleton pattern)
    
    def __new__(cls, *args, **kwargs):
        """
        Override __new__ to implement singleton pattern.
        Called before __init__ during object creation.
        """
        if cls._instance is None:  # Check if instance doesn't exist
            # Create new instance by calling parent class's __new__ method
            cls._instance = super(CyberGuardAPI, cls).__new__(cls)
        return cls._instance  # Always return the same instance
    
    def __init__(self, model_dir: Path, config_path: Path, tokenizer_path: Path):
        """
        Load configuration, tokenizer, device, categories, model and API keys.

        ``__new__`` always returns the shared instance, so this method runs on
        every ``CyberGuardAPI(...)`` call. The setup below is skipped once
        ``self.initialized`` is truthy. That flag is set only after every step
        succeeds, so a construction that raised partway will retry next time.

        Args:
            model_dir: Directory searched for the model artifacts.
            config_path: YAML configuration file.
            tokenizer_path: JSON token -> id vocabulary file.
        """
        # NOTE(review): arguments passed after the first successful setup are ignored.
        if getattr(self, 'initialized', False):
            return

        self.model_dir, self.config_path, self.tokenizer_path = model_dir, config_path, tokenizer_path

        # Order matters: later steps read state produced by earlier ones
        # (model loading uses self.device; warmup reads self.config, self.model
        # and self.model_format).
        self.config = self._load_and_validate_config()
        self.tokenizer, self.reverse_tokenizer = self._load_tokenizer()
        self.device = self._get_device()
        self.threat_categories = self._load_threat_categories()
        self.model = self._load_model()
        self._warmup_model()
        self.api_keys = self._load_api_keys()

        # Running statistics.
        self.request_count = self.threat_detection_count = 0

        self.initialized = True

        logger.info(f"CyberGuard API initialized successfully on device: {self.device}")
        logger.info(f"Model type: {type(self.model).__name__}")
        logger.info(f"Threat categories loaded: {len(self.threat_categories)}")
    
    def _load_and_validate_config(self) -> Dict:
        """
        Load the YAML configuration file and check that required keys exist.

        Returns:
            The parsed configuration mapping.

        Raises:
            FileNotFoundError: If ``self.config_path`` does not exist.
            yaml.YAMLError: If the file is not valid YAML. The original parser
                error is chained as ``__cause__``.
            ValueError: If the document is not a mapping (for example an empty
                file or a top-level list), or a required key is missing.
        """
        if not self.config_path.exists():
            error_msg = f"Configuration file not found: {self.config_path}"
            logger.error(error_msg)
            raise FileNotFoundError(error_msg)

        try:
            with open(self.config_path, 'r', encoding='utf-8') as f:
                # safe_load builds only plain Python objects (no arbitrary tags/code).
                config = yaml.safe_load(f)
        except yaml.YAMLError as e:
            error_msg = f"Invalid YAML in config file: {e}"
            logger.error(error_msg)
            raise yaml.YAMLError(error_msg) from e

        # An empty file loads as None and a top-level list/scalar has no keys.
        # Without this check, the membership test below would raise TypeError
        # (None) or silently test list membership instead of key presence.
        if not isinstance(config, dict):
            error_msg = (
                f"Configuration file must contain a YAML mapping, "
                f"got {type(config).__name__}: {self.config_path}"
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        required_fields = ('model_name', 'model_version', 'max_seq_length', 'threat_categories')
        for field in required_fields:
            if field not in config:
                error_msg = f"Missing required config field: {field}"
                logger.error(error_msg)
                raise ValueError(error_msg)

        logger.info(f"Configuration loaded from: {self.config_path}")
        return config
    
    def _load_tokenizer(self) -> tuple:
        """
        Load the token -> id vocabulary from JSON and build its inverse.

        Returns:
            ``(token_to_id, id_to_token)``. If several tokens share an id, the
            inverse mapping keeps the last one in file order.

        Raises:
            FileNotFoundError: If ``self.tokenizer_path`` does not exist.
            json.JSONDecodeError: If the file is not valid JSON. The original
                error is chained as ``__cause__``.
            ValueError: If the top-level JSON value is not an object.
        """
        if not self.tokenizer_path.exists():
            error_msg = f"Tokenizer file not found: {self.tokenizer_path}"
            logger.error(error_msg)
            raise FileNotFoundError(error_msg)

        try:
            # JSON interchange is UTF-8; don't depend on the platform default encoding.
            with open(self.tokenizer_path, 'r', encoding='utf-8') as f:
                tokenizer = json.load(f)
        except json.JSONDecodeError as e:
            error_msg = f"Invalid JSON in tokenizer file: {e}"
            logger.error(error_msg)
            # JSONDecodeError's constructor requires (msg, doc, pos). The old
            # one-argument call raised TypeError and masked the parse error.
            # Pass e.msg rather than str(e): the class appends line/column itself.
            raise json.JSONDecodeError(
                f"Invalid JSON in tokenizer file: {e.msg}", e.doc, e.pos
            ) from e

        # The inverse mapping below needs a flat token -> id object.
        if not isinstance(tokenizer, dict):
            error_msg = (
                f"Tokenizer file must contain a JSON object mapping tokens to ids, "
                f"got {type(tokenizer).__name__}: {self.tokenizer_path}"
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        # id -> token lookup, useful for debugging / interpreting model inputs.
        reverse_tokenizer = {token_id: token for token, token_id in tokenizer.items()}

        logger.info(f"Tokenizer loaded with {len(tokenizer)} tokens")
        return tokenizer, reverse_tokenizer
    
    def _get_device(self) -> torch.device:
        """
        Pick the compute device: CUDA GPU first, then Apple MPS, then CPU.

        The choice, plus GPU 0's name and memory in the CUDA case, is logged.

        Returns:
            The selected ``torch.device``.
        """
        if torch.cuda.is_available():
            chosen = torch.device('cuda')
            # Report the first visible GPU (index 0).
            name = torch.cuda.get_device_name(0)
            total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
            logger.info(f"Using CUDA GPU: {name} ({total_gb:.2f} GB)")
            return chosen

        # torch.backends.mps is absent on older PyTorch builds, hence the hasattr guard.
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            chosen = torch.device('mps')
            logger.info("Using Apple MPS (Metal Performance Shaders)")
            return chosen

        chosen = torch.device('cpu')
        logger.info("Using CPU (no GPU available)")
        return chosen
    
    def _load_threat_categories(self) -> List[ThreatCategory]:
        """
        Build ``ThreatCategory`` objects from ``self.config['threat_categories']``.

        Each entry's position in the list becomes its ``id``. An entry may be:

        * a mapping with optional ``name``, ``description`` and ``severity_weight``
          keys (defaults: ``f'category_{i}'``, ``''`` and ``0.5``); or
        * a bare string, taken as the category name with the other defaults.

        A missing or null ``threat_categories`` value yields an empty list.

        Returns:
            List of ``ThreatCategory`` in configuration order.
        """
        # `or []` also covers a key that is present but null (e.g. `threat_categories:`
        # with no items); previously that crashed with TypeError in enumerate().
        category_configs = self.config.get('threat_categories') or []

        categories = []
        for i, cat_config in enumerate(category_configs):
            if isinstance(cat_config, str):
                # Shorthand form: `- sql_injection` means {'name': 'sql_injection'}.
                cat_config = {'name': cat_config}
            categories.append(
                ThreatCategory(
                    id=i,
                    name=cat_config.get('name', f'category_{i}'),
                    description=cat_config.get('description', ''),
                    severity_weight=cat_config.get('severity_weight', 0.5),
                )
            )
        return categories
    
    def _load_model(self):
        """
        Load the trained model from ``self.model_dir`` in the first usable format.

        Candidate files are tried in this order, and the first one that loads wins:

        1. ``cyberguard.pt``: TorchScript archive, via ``torch.jit.load``.
        2. ``cyberguard.onnx``: ONNX Runtime session. It is skipped, with a
           warning, when ``onnxruntime`` could not be imported.
        3. ``cyberguard_optimized.pt``: a pickled, complete model object, via
           ``torch.load``.
        4. ``model_state_dict.pt``: uses the same loader as 3. A checkpoint that
           is a plain dict holds only weights, and the model class needed to
           rebuild it is not available here, so it is skipped.

        Side effect: sets ``self.model_format`` to ``'TorchScript'``,
        ``'ONNX'`` or ``'PyTorch'``.

        Returns:
            The loaded model object (TorchScript or pickled module) or an
            ``onnxruntime.InferenceSession``.

        Raises:
            RuntimeError: If no candidate file could be loaded.
        """
        # (format name, file path) in order of preference.
        model_files = [
            ('TorchScript', self.model_dir / 'cyberguard.pt'),
            ('ONNX', self.model_dir / 'cyberguard.onnx'),
            ('PyTorch', self.model_dir / 'cyberguard_optimized.pt'),
            ('PyTorch State Dict', self.model_dir / 'model_state_dict.pt'),
        ]

        loaded_model = None
        loaded_format = None

        for format_name, model_path in model_files:
            if not model_path.exists():
                continue

            if format_name == 'ONNX' and not ONNX_AVAILABLE:
                # Previously this case matched no branch and was skipped silently.
                logger.warning(f"Found {model_path} but onnxruntime is not installed; skipping ONNX model")
                continue

            # Any failure here only disqualifies this candidate; the next one is tried.
            try:
                if format_name == 'TorchScript':
                    candidate = torch.jit.load(str(model_path), map_location=self.device)
                    candidate.eval()  # disable dropout; batch norm uses running statistics
                    candidate_format = 'TorchScript'

                elif format_name == 'ONNX':
                    session_options = ort.SessionOptions()
                    session_options.intra_op_num_threads = 1
                    session_options.inter_op_num_threads = 1
                    # Prefer CUDA when this onnxruntime build offers it; keep CPU as fallback.
                    providers = ['CPUExecutionProvider']
                    if 'CUDAExecutionProvider' in ort.get_available_providers():
                        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
                    candidate = ort.InferenceSession(
                        str(model_path),
                        sess_options=session_options,
                        providers=providers,
                    )
                    candidate_format = 'ONNX'

                else:  # 'PyTorch' and 'PyTorch State Dict'
                    # SECURITY: torch.load unpickles the file, which can execute
                    # arbitrary code, so only load trusted artifacts. Recent PyTorch
                    # releases (2.6+) default to weights_only=True, which refuses
                    # pickled modules. That error is caught below and the next
                    # candidate is tried.
                    checkpoint = torch.load(model_path, map_location=self.device)
                    # A dict ({'model_state_dict': ...} or a raw state dict) holds
                    # only weights. The old `'model_state_dict' in checkpoint` test
                    # raised TypeError for nn.Modules without __iter__/__contains__,
                    # so a complete pickled model could never be loaded.
                    if isinstance(checkpoint, dict):
                        logger.warning("PyTorch model loading requires model class definition")
                        continue
                    checkpoint.eval()
                    candidate = checkpoint
                    candidate_format = 'PyTorch'

            except Exception as e:  # deliberate catch-all: fall back to the next format
                logger.warning(f"Failed to load {format_name} model: {e}")
                continue

            loaded_model, loaded_format = candidate, candidate_format
            logger.info(f"Successfully loaded {loaded_format} model from {model_path}")
            break

        if loaded_model is None:
            error_msg = "No model file could be loaded. Checked: " + ", ".join(name for name, _ in model_files)
            logger.error(error_msg)
            raise RuntimeError(error_msg)

        self.model_format = loaded_format
        return loaded_model
    
    def _warmup_model(self, seq_len: Optional[int] = None):
        """
        Run one throw-away inference so one-time initialisation cost is paid at
        startup rather than on the first real request.

        Args:
            seq_len: Length of the dummy token sequence. It defaults to the
                configured ``max_seq_length`` (512 if absent). That is the exact
                length ``_tokenize_cached`` pads or truncates every request to,
                so the warmed-up input shape matches real inference. The former
                fixed 32-token input could fail outright on an ONNX graph
                exported with a static sequence length, and it never exercised
                the production shape.

        Warmup is best-effort: any failure is logged as a warning and never raised.
        """
        logger.info("Warming up model...")

        try:
            if seq_len is None:
                seq_len = self.config.get('max_seq_length', 512)

            # Random ids in [10, 100). These are presumably chosen to avoid the
            # special tokens at the low end of the vocabulary.
            # NOTE(review): confirm the vocabulary has at least 100 ids.
            dummy_input = torch.randint(
                low=10,
                high=100,
                size=(1, seq_len),
                device=self.device,
            )

            with torch.no_grad():
                if self.model_format == 'ONNX':
                    # ONNX Runtime consumes numpy arrays keyed by input name.
                    input_name = self.model.get_inputs()[0].name
                    self.model.run(None, {input_name: dummy_input.cpu().numpy()})
                else:
                    _ = self.model(dummy_input)

            if self.device.type == 'cuda':
                # CUDA launches are asynchronous. Wait here so the warmup work
                # finishes, and any asynchronous CUDA error surfaces, inside this
                # try block instead of during the first request.
                torch.cuda.synchronize()

            logger.info("Model warmup completed successfully")

        except Exception as e:
            # A failed warmup must not prevent the server from starting.
            logger.warning(f"Model warmup failed (non-critical): {e}")
    
    def _load_api_keys(self) -> set:
        """
        Read the accepted API keys from the ``CYBERGUARD_API_KEYS`` environment variable.

        The variable holds a comma-separated list. Whitespace around each key is
        ignored and empty entries are dropped, so ``"k1, k2,"`` yields
        ``{"k1", "k2"}``. If the variable is unset or empty, a hard-coded
        development key is used and a warning is logged. If it is set but
        contains no usable keys, an error is logged and an empty set is
        returned: this fails closed and does not fall back to the development key.

        In production, prefer a proper secrets manager (HashiCorp Vault, AWS
        Secrets Manager, ...).

        Returns:
            Set of valid API keys (O(1) membership checks).
        """
        import os

        api_keys_env = os.getenv('CYBERGUARD_API_KEYS', '')

        if not api_keys_env:
            keys = {'cyberguard-dev-key-2024'}
            logger.warning("Using default development API key. Not suitable for production.")
            return keys

        # Previously " k2" (leading space) could never match, and a trailing
        # comma put the empty string into the key set.
        keys = {key.strip() for key in api_keys_env.split(',') if key.strip()}

        if keys:
            logger.info(f"Loaded {len(keys)} API key(s) from environment")
        else:
            logger.error("CYBERGUARD_API_KEYS is set but contains no usable keys; no API key will be accepted")
        return keys
    
    @lru_cache(maxsize=1000)  # Decorator: Least Recently Used cache
    # maxsize=1000: Cache up to 1000 unique calls
    # Automatically evicts least recently used entries when full
    def _tokenize_cached(self, text: str, max_length: int) -> torch.Tensor:
        """
        Tokenize text with caching for repeated similar inputs.
        LRU cache improves performance for common requests.
        
        Args:
            text: Input text to tokenize
            max_length: Maximum sequence length
            
        Returns:
            Tensor of token IDs with shape (1, max_length)
        """
        # Convert to lowercase for consistency
        # Lowercasing helps with token matching but preserves case for certain patterns
        text_lower = text.lower()
        
        # Initialize empty token list
        tokens = []
        
        # Split text into words (basic tokenization)
        # In production, use proper tokenizer like BPE, WordPiece, or SentencePiece
        words = text_lower.split()  # .split(): Splits on whitespace
        
        # Process each word up to max_length
        for word in words[:max_length]:
            # Check for common attack patterns
            for pattern in ['<script>', 'javascript:', 'union', 'select', 'from', 'where']:
                if pattern in word:  # Check if pattern exists in word
                    # Get token ID for pattern, fallback to UNK token if not found
                    pattern_id = self.tokenizer.get(pattern, self.tokenizer.get('[UNK]', 1))
                    tokens.append(pattern_id)  # Add pattern token
                    
                    # Remove pattern from word to tokenize remaining parts
                    word = word.replace(pattern, '')
            
            # Add remaining word if not empty after pattern removal
            if word:  # Check if word is not empty string
                # Get token ID, fallback to UNK token if not in vocabulary
                token_id = self.tokenizer.get(word, self.tokenizer.get('[UNK]', 1))
                tokens.append(token_id)
        
        # Add CLS (classification) token at beginning
        # CLS token is commonly used in transformer models for classification tasks
        tokens = [self.tokenizer.get('[CLS]', 2)] + tokens
        # + tokens: Concatenate lists
        
        # --- SEQUENCE LENGTH MANAGEMENT ---
        # Truncate if sequence exceeds maximum length
        if len(tokens) > max_length:
            tokens = tokens[:max_length]  # Keep first max_length tokens
        
        # Pad if sequence is shorter than maximum length
        elif len(tokens) < max_length:
            # Create padding tokens (PAD token ID)
            padding = [self.tokenizer.get('[PAD]', 0)] * (max_length - len(tokens))
            # [value] * n: Creates list with n copies of value
            
            tokens = tokens + padding  # Append padding tokens
        
        # Convert to PyTorch tensor
        return torch.tensor([tokens], dtype=torch.long, device=self.device)
        # [tokens]: Create batch dimension (batch_size=1)
        # dtype=torch.long: Integer data type for token IDs
        # device=self.device: Place tensor on correct device (GPU/CPU)
    
    def _get_severity_level(self, score: float) -> str:
        """
        Convert numerical severity score to categorical level.
        
        Args:
            score: Severity score between 0 and 1
            
        Returns:
            Severity level string (INFO, LOW, MEDIUM, HIGH, CRITICAL)
        """
        # Define severity thresholds
        if score < 0.2:
            return 'INFO'  # Informational, not a threat
        elif score < 0.4:
            return 'LOW'  # Low severity threat
        elif score < 0.6:
            return 'MEDIUM'  # Medium severity threat
        elif score < 0.8:
            return 'HIGH'  # High severity threat
        else:
            return 'CRITICAL'  # Critical threat requiring immediate attention
    
    def analyze(self, text: str, detailed: bool = False, request_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Main analysis method that processes text and returns threat analysis.

        The text is tokenized and passed through the loaded model. That model is
        an ONNX Runtime session when ``self.model_format == 'ONNX'``, and is
        called directly as a PyTorch/TorchScript module otherwise. The raw
        outputs are then turned into a threat classification plus a severity
        assessment.

        Args:
            text: Input text to analyze
            detailed: Whether to include per-category probabilities
            request_id: Optional request ID. It is prefixed to this method's
                log messages so they can be correlated with the caller.

        Returns:
            Dictionary with keys 'text_preview', 'text_length',
            'threat_detection', 'severity', 'processing_time_ms' (measured
            time spent in tokenization + inference + post-processing) and
            'model_format'.

        Raises:
            RuntimeError: If model inference or post-processing fails. The
                original exception is chained as ``__cause__``.
            ValueError: If input text is invalid
        """
        import time  # Local import, following this file's habit for single-use stdlib modules

        # Every call is counted, including ones rejected by validation below
        self.request_count += 1

        if not text or not isinstance(text, str):
            raise ValueError("Input text must be a non-empty string")

        log_prefix = f"[{request_id}] " if request_id else ""
        max_seq_length = self.config.get('max_seq_length', 512)  # Default to 512
        start_time = time.perf_counter()

        try:
            # --- TOKENIZATION ---
            input_tensor = self._tokenize_cached(text, max_seq_length)

            # --- INFERENCE BASED ON MODEL FORMAT ---
            if self.model_format == 'ONNX':
                # ONNX Runtime consumes numpy arrays, so move the tokens to host memory first
                input_name = self.model.get_inputs()[0].name
                outputs = self.model.run(None, {input_name: input_tensor.cpu().numpy()})
                # Output order depends on how the model was exported:
                # [threat_logits] or [threat_logits, severity_score, ...]
                threat_logits = torch.from_numpy(outputs[0])
                severity_score = torch.from_numpy(outputs[1]) if len(outputs) >= 2 else None
            else:
                # PyTorch/TorchScript inference path; gradients are never needed here
                with torch.no_grad():
                    outputs = self.model(input_tensor)

                # Models may return a dict, a sequence, or a single logits tensor
                if isinstance(outputs, dict):
                    threat_logits = outputs.get('threat_logits')
                    severity_score = outputs.get('severity_score')
                elif isinstance(outputs, (list, tuple)):
                    threat_logits = outputs[0] if len(outputs) > 0 else None
                    severity_score = outputs[1] if len(outputs) > 1 else None
                else:
                    threat_logits = outputs
                    severity_score = None

            # --- PROCESS THREAT CLASSIFICATION OUTPUT ---
            if threat_logits is not None:
                probs = torch.softmax(threat_logits, dim=-1)
                # Flatten any leading dims to (N, num_classes) and take the first row,
                # so un-batched (num_classes,) outputs work as well as (1, num_classes)
                class_probs = probs.reshape(-1, probs.shape[-1])[0]

                threat_idx = torch.argmax(class_probs).item()
                confidence = class_probs[threat_idx].item()

                if threat_idx < len(self.threat_categories):
                    category_name = self.threat_categories[threat_idx].name
                else:
                    # The model emitted a class index with no configured category
                    category_name = 'unknown'

                # Index 0 is assumed to be the 'benign' category
                is_malicious = threat_idx != 0

                threat_result = {
                    'category': category_name,
                    'category_id': threat_idx,
                    'confidence': confidence,
                    'is_malicious': is_malicious
                }

                if detailed:
                    # zip() stops at the shorter sequence, so a mismatch between the number
                    # of model classes and configured categories cannot raise IndexError
                    threat_result['probabilities'] = {
                        cat.name: prob
                        for cat, prob in zip(self.threat_categories, class_probs.tolist())
                    }
            else:  # No threat logits returned
                threat_result = {
                    'category': 'unknown',
                    'category_id': -1,
                    'confidence': 0.0,
                    'is_malicious': False
                }

            # --- PROCESS SEVERITY SCORE ---
            if severity_score is not None:
                if isinstance(severity_score, torch.Tensor):
                    # reshape(-1) accepts 0-d, (1,) and (1, 1) outputs alike
                    severity_value = severity_score.reshape(-1)[0].item()
                else:
                    severity_value = float(severity_score)

                severity_result = {
                    'score': severity_value,
                    'level': self._get_severity_level(severity_value)
                }
            else:
                # No severity head: derive a default from the classification
                severity_result = {
                    'score': 0.8 if threat_result['is_malicious'] else 0.1,
                    'level': 'HIGH' if threat_result['is_malicious'] else 'INFO'
                }

            # --- COMPILE FINAL RESULT ---
            result = {
                # First 100 characters, with an ellipsis when truncated
                'text_preview': text[:100] + ('...' if len(text) > 100 else ''),
                'text_length': len(text),
                'threat_detection': threat_result,
                'severity': severity_result,
                'processing_time_ms': (time.perf_counter() - start_time) * 1000,
                'model_format': self.model_format
            }
        except Exception as e:
            error_msg = f"Analysis failed: {str(e)}"
            logger.error(f"{log_prefix}{error_msg}", exc_info=True)
            raise RuntimeError(error_msg) from e

        # Counted only after the whole analysis has succeeded
        if threat_result['is_malicious']:
            self.threat_detection_count += 1

        logger.info(f"{log_prefix}Analysis completed: threat={threat_result['category']}, "
                    f"malicious={threat_result['is_malicious']}, "
                    f"confidence={threat_result['confidence']:.3f}")

        return result
    
    def verify_api_key(self, credentials: Optional[HTTPAuthorizationCredentials] = None) -> bool:
        """
        Check whether the bearer token in ``credentials`` is a configured API key.

        Args:
            credentials: Parsed Bearer credentials, or None when the request
                carried none.

        Returns:
            True when ``credentials.credentials`` is in ``self.api_keys``, False
            otherwise. Both failure cases log a warning.
        """
        if credentials is None:
            logger.warning("No API credentials provided")
            return False

        token = credentials.credentials
        is_valid = token in self.api_keys
        if not is_valid:
            # Only a short prefix is logged so the full key never reaches the logs
            logger.warning(f"Invalid API key attempt: {token[:8]}...")
        return is_valid
    
    def get_statistics(self) -> Dict[str, Any]:
        """
        Summarise the usage counters kept by this API instance.

        Returns:
            Dict containing:
            - total_requests
            - threats_detected
            - threat_detection_rate: their ratio, or 0 before any request
            - model_format
            - device (as a string)
            - uptime: the placeholder 'N/A'
        """
        total = self.request_count
        detected = self.threat_detection_count
        # Avoid ZeroDivisionError before the first request has been served
        rate = detected / total if total > 0 else 0

        stats = {
            'total_requests': total,
            'threats_detected': detected,
            'threat_detection_rate': rate,
            'model_format': self.model_format,
            'device': str(self.device),
            'uptime': 'N/A',  # Uptime is not tracked yet
        }
        return stats

# ============================================================================
# SECTION 6: FASTAPI APPLICATION SETUP
# ============================================================================

# Create FastAPI application with production settings
app = FastAPI(
    title="CyberGuard Security Analysis API",  # API title (appears in docs)
    description="Enterprise-grade AI-powered web security threat detection system",  # API description
    version="2.0.0",  # API version
    docs_url="/docs",  # URL for Swagger documentation
    redoc_url="/redoc",  # URL for ReDoc documentation
    openapi_url="/openapi.json"  # URL for OpenAPI schema
)

# Attach the module-level rate limiter (defined earlier in this file) and
# register the handler for RateLimitExceeded errors
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

import os  # Needed for the env-driven settings below


def _env_list(name: str, default: str = "*") -> List[str]:
    """Parse a comma-separated environment variable into stripped, non-empty items (falls back to [default])."""
    raw = os.getenv(name, default)
    items = [item.strip() for item in raw.split(",") if item.strip()]
    return items or [default]


# Comma-separated allow-lists. Both default to "*" (allow everything) so existing
# deployments keep working; set explicit values in production.
CORS_ALLOW_ORIGINS = _env_list("CORS_ALLOW_ORIGINS")
TRUSTED_HOSTS = _env_list("TRUSTED_HOSTS")

# A wildcard origin must not be combined with credentials. Starlette does not
# support that combination and ends up echoing the caller's Origin, which lets
# any website make credentialed requests. Clients authenticate with an
# explicitly set "Authorization: Bearer" header, which works without CORS
# credentials mode, so credentials are only enabled for explicit origin lists.
_cors_allow_credentials = "*" not in CORS_ALLOW_ORIGINS

# Add CORS (Cross-Origin Resource Sharing) middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=_cors_allow_credentials,
    allow_methods=["GET", "POST"],  # Allowed HTTP methods
    allow_headers=["*"],  # Allow all headers
)

# Reject requests whose Host header is not in TRUSTED_HOSTS ("*" disables the check)
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=TRUSTED_HOSTS
)

# Global API instance (will be initialized on first use)
api_instance = None

# ============================================================================
# SECTION 7: DEPENDENCY FUNCTIONS
# ============================================================================

def get_api_instance() -> CyberGuardAPI:
    """
    Return the process-wide CyberGuardAPI instance, creating it on first use.

    Paths come from environment variables, with these defaults:
    - MODEL_DIR: './models'
    - CONFIG_PATH: './config.yaml'
    - TOKENIZER_PATH: './models/tokenizer.json'

    A failed initialization leaves the global unset, so the next call retries.

    NOTE(review): there is no lock around the lazy initialization. Every
    caller in this module is an ``async def`` running on the event loop;
    re-check this if it is ever called from worker threads.

    Returns:
        Initialized CyberGuardAPI instance

    Raises:
        HTTPException: 503 if the API fails to initialize. The underlying
            error is logged with its traceback and chained as ``__cause__``.
    """
    global api_instance  # Declare we're modifying the global variable

    if api_instance is not None:
        return api_instance

    import os  # Local import, as elsewhere in this file

    model_dir = Path(os.getenv('MODEL_DIR', './models'))
    config_path = Path(os.getenv('CONFIG_PATH', './config.yaml'))
    tokenizer_path = Path(os.getenv('TOKENIZER_PATH', './models/tokenizer.json'))

    try:
        api_instance = CyberGuardAPI(model_dir, config_path, tokenizer_path)
    except Exception as e:
        # Keep the traceback: this is usually the only record of why model loading failed
        logger.critical(f"Failed to initialize API: {e}", exc_info=True)
        raise HTTPException(
            status_code=503,  # Service Unavailable status code
            detail="Service temporarily unavailable. Failed to initialize model."
        ) from e

    return api_instance


async def verify_api_key_dependency(
    request: Request,  # Injected by FastAPI; not used in the body
    credentials: Optional[HTTPAuthorizationCredentials] = Security(security)
    # Security(security): Bearer credentials from the `security` scheme, or None
) -> str:
    """
    FastAPI dependency that guards authenticated endpoints.

    Args:
        request: Incoming request (unused)
        credentials: Bearer credentials extracted by the ``security`` scheme, or None

    Returns:
        The raw API key string when it is valid.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when the
            key is missing or unknown, or the 503 raised by get_api_instance
            when the API cannot be initialized.
    """
    authorized = get_api_instance().verify_api_key(credentials)
    if authorized:
        return credentials.credentials

    raise HTTPException(
        status_code=401,
        detail="Invalid or missing API key",
        headers={"WWW-Authenticate": "Bearer"},  # Required for 401 responses
    )

# ============================================================================
# SECTION 8: FASTAPI EVENT HANDLERS
# ============================================================================

@app.on_event("startup")
async def startup_event():
    """
    Warm up the API when the application starts.

    Building the CyberGuardAPI instance here surfaces model-loading problems in
    the startup logs instead of on the first request. Failures are logged and
    deliberately swallowed, so the server still starts and /health can report
    the problem.
    """
    logger.info("Starting CyberGuard API server...")

    try:
        get_api_instance()
    except Exception as e:
        # Best-effort warm-up: never crash the server here
        logger.error(f"Startup initialization failed: {e}")
    else:
        logger.info("API initialization completed during startup")

# ============================================================================
# SECTION 9: API ENDPOINTS
# ============================================================================

@app.get("/health", tags=["monitoring"])  # GET endpoint, categorized under "monitoring"
@limiter.limit("10/minute")  # Rate limit: 10 requests per minute per client (as keyed by the limiter)
async def health_check(request: Request):
    """
    Health check endpoint for load balancers and monitoring.

    Args:
        request: Incoming request. It is unused in the body; presumably the
            ``limiter.limit`` decorator requires it (confirm).

    Returns:
        Dict with the following keys:
        - status
        - timestamp (UTC)
        - model_loaded (whether a model is loaded)
        - model_format
        - device
        - threat_categories (number of categories)
        - tokenizer_size

    Raises:
        HTTPException: 503 when the API cannot be initialized or its status
            cannot be collected. The response detail is generic; the cause
            is only written to the server log.

    NOTE(review): probes from a single orchestrator address share the
    10/minute budget above. Confirm the probe period leaves headroom, or
    probes will be rate-limited.
    """
    try:
        api = get_api_instance()

        health_status = {
            "status": "healthy",
            "timestamp": datetime.utcnow().isoformat() + "Z",  # ISO format with Z for UTC
            "model_loaded": api.model is not None,
            "model_format": api.model_format,
            "device": str(api.device),
            "threat_categories": len(api.threat_categories),
            "tokenizer_size": len(api.tokenizer)  # Vocabulary size
        }
    except HTTPException:
        # get_api_instance already logged the failure and built a 503;
        # re-wrapping it would only nest its text inside another detail string
        raise
    except Exception as e:
        logger.error(f"Health check failed: {e}", exc_info=True)
        # Public, unauthenticated endpoint: keep exception text out of the response
        raise HTTPException(
            status_code=503,  # Service Unavailable
            detail="Service unhealthy"
        ) from e

    return health_status


@app.post("/analyze", response_model=AnalysisResponse, tags=["analysis"])
# POST endpoint, returns AnalysisResponse model, categorized under "analysis"
@limiter.limit("100/minute")  # Rate limit: 100 requests per minute per client (as keyed by the limiter)
async def analyze_security(
    request: Request,  # FastAPI request object
    analysis_request: AnalysisRequest,  # Validated request body
    api_key: str = Depends(verify_api_key_dependency)  # API key verification dependency
):
    """
    Analyze text for security threats (main analysis endpoint).

    Args:
        request: FastAPI request object
        analysis_request: Validated analysis request
        api_key: Verified API key (from dependency)

    Returns:
        AnalysisResponse containing:
        - the analysis dict, with a measured processing_time_ms
        - a UTC timestamp
        - the model version from config (default '2.0.0')
        - a short request ID

    Raises:
        HTTPException:
        - 400 when the analyzer rejects the input (ValueError)
        - 500 when analysis or response construction fails
        - any HTTPException raised while fetching the API instance, passed
          through unchanged
    """
    import time  # Local imports, following this file's style
    import uuid

    # Short correlation ID: the first 8 characters of a random UUID4
    request_id = str(uuid.uuid4())[:8]

    logger.info(f"Analysis request {request_id}: length={len(analysis_request.text)}")

    try:
        api = get_api_instance()

        # perf_counter is monotonic, so the measured duration cannot be skewed
        # (or go negative) if the system clock is adjusted mid-request
        start_time = time.perf_counter()
        analysis_result = api.analyze(
            text=analysis_request.text,
            detailed=analysis_request.detailed,
            request_id=request_id
        )
        processing_time_ms = (time.perf_counter() - start_time) * 1000
    except HTTPException:
        # e.g. 503 from get_api_instance: keep its status instead of turning it into a 500
        raise
    except ValueError as e:
        # Client error (invalid input)
        logger.warning(f"Analysis request {request_id} failed validation: {e}")
        raise HTTPException(
            status_code=400,  # Bad Request
            detail=str(e)
        ) from e
    except RuntimeError as e:
        # Server error during analysis
        logger.error(f"Analysis request {request_id} failed: {e}")
        raise HTTPException(
            status_code=500,
            detail="Analysis failed due to server error"  # Generic error message for security
        ) from e
    except Exception as e:
        logger.error(f"Unexpected error in analysis request {request_id}: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="Internal server error"  # Generic message for security
        ) from e

    analysis_result['processing_time_ms'] = processing_time_ms

    # Built outside the try above on purpose. Pydantic's ValidationError
    # subclasses ValueError, so a malformed response would otherwise be
    # reported as a 400 client error carrying internal validation details.
    try:
        response = AnalysisResponse(
            success=True,
            analysis=analysis_result,
            timestamp=datetime.utcnow().isoformat() + "Z",  # Current UTC time
            model_version=api.config.get('model_version', '2.0.0'),
            request_id=request_id
        )
    except Exception as e:
        logger.error(f"Unexpected error in analysis request {request_id}: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="Internal server error"
        ) from e

    logger.info(f"Analysis request {request_id} completed in {processing_time_ms:.2f}ms")

    return response


@app.get("/statistics", tags=["monitoring"])
async def get_statistics(api_key: str = Depends(verify_api_key_dependency)):
    """
    Return usage counters for the running API instance (requires a valid API key).

    Returns:
        ``{"success": True, "statistics": <dict>, "timestamp": <ISO-8601 UTC>}``

    Raises:
        HTTPException: 500 if the statistics cannot be collected.
    """
    try:
        payload = {
            "success": True,
            "statistics": get_api_instance().get_statistics(),
            "timestamp": datetime.utcnow().isoformat() + "Z",
        }
    except Exception as e:
        logger.error(f"Failed to get statistics: {e}")
        raise HTTPException(
            status_code=500,
            detail="Failed to retrieve statistics"
        )
    return payload


@app.get("/categories", tags=["information"])
async def get_threat_categories():
    """
    List the configured threat categories (public: no API key required).

    Returns:
        ``{"success": True, "categories": [...], "count": N}``, where each
        category has id, name, description and severity_weight.

    Raises:
        HTTPException: 503 when the API cannot be initialized (raised by
            get_api_instance), or 500 when the category list cannot be built.
    """
    # Fetched outside the generic handler below. This endpoint has no auth
    # dependency that would have initialized the API already, and the 503
    # from get_api_instance must not be downgraded to a 500.
    api = get_api_instance()

    try:
        # Convert ThreatCategory objects to plain dictionaries
        categories = [
            {
                "id": cat.id,
                "name": cat.name,
                "description": cat.description,
                "severity_weight": cat.severity_weight
            }
            for cat in api.threat_categories
        ]
    except Exception as e:
        logger.error(f"Failed to get threat categories: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail="Failed to retrieve threat categories"
        ) from e

    return {
        "success": True,
        "categories": categories,
        "count": len(categories)
    }

# ============================================================================
# SECTION 10: APPLICATION ENTRY POINT
# ============================================================================

if __name__ == "__main__":
    # Entry point for running the API server directly (python <this file>).
    # In production, prefer launching uvicorn/gunicorn against the import string.
    import os  # For environment variable access
    import uvicorn  # ASGI server for running FastAPI

    # python-dotenv is optional; without it, configuration comes from the real environment only
    try:
        from dotenv import load_dotenv
    except ImportError:
        logger.info("python-dotenv not installed; skipping .env loading")
    else:
        load_dotenv()  # Loads variables from .env file into os.environ

    # Configuration from environment variables with defaults
    host = os.getenv("HOST", "0.0.0.0")  # 0.0.0.0 = all interfaces
    port = int(os.getenv("PORT", "8000"))
    workers = int(os.getenv("WORKERS", "1"))
    # Import string each worker process uses to locate the app. Override it
    # via APP_MODULE when this file is saved under a different module name.
    app_import_string = os.getenv("APP_MODULE", "cyberguard_api:app")

    logger.info(f"Starting CyberGuard API server on {host}:{port}")
    logger.info(f"Workers: {workers}")
    logger.info(f"Environment: {os.getenv('ENVIRONMENT', 'development')}")

    # Uvicorn needs an import string to spawn multiple workers. For a single
    # worker, passing the app object built above avoids importing this file a
    # second time under another module name, which would rebuild the app and
    # its middleware. It also works regardless of the file's name.
    app_target = app if workers <= 1 else app_import_string

    uvicorn.run(
        app_target,
        host=host,
        port=port,
        workers=workers,
        reload=False,  # Disable auto-reload (reload also requires an import string)
        log_level="info"
    )