In [None]:
# %% [markdown]
"""
# Inference Demo and Performance Analysis
## Humanoid Vision System - Real-time Inference

This notebook demonstrates:
1. Model loading and preparation
2. Real-time inference pipeline
3. Performance benchmarking
4. Visualization and analysis
5. Deployment testing
"""

# %% [markdown]
"""
## 1. Setup and Imports
"""

# %%
import sys
import os
sys.path.append('../src')

# Core libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import cv2
import json
import time
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional
import warnings
warnings.filterwarnings('ignore')

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision

# Inference modules
from src.inference.engine import InferenceEngine
from src.inference.preprocessing import ImagePreprocessor
from src.inference.postprocessing import DetectionPostprocessor
from src.inference.visualizer import DetectionVisualizer
from src.inference.robot_interface import RobotCommunicationInterface

# Model
from src.models.hybrid_vision import HybridVisionSystem
from src.utils.logging import setup_logger

# Visualization
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets

# Performance monitoring
import psutil
import GPUtil
from memory_profiler import memory_usage

# Set style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
%matplotlib inline

# Configuration
config = {
    'model': {
        'path': '../models/checkpoints/model_epoch_050_loss_1.234.pt',  # Example path
        'num_classes': 80,
        'image_size': (416, 416),
        'confidence_threshold': 0.5,
        'iou_threshold': 0.5
    },
    'inference': {
        'batch_size': 1,
        'use_amp': True,
        'warmup_runs': 10,
        'benchmark_runs': 100,
        'device': 'cuda' if torch.cuda.is_available() else 'cpu'
    },
    'visualization': {
        'show_confidence': True,
        'show_class_names': True,
        'color_scheme': 'categorical'
    }
}

# Create directories
os.makedirs('../results/inference', exist_ok=True)
os.makedirs('../results/visualizations', exist_ok=True)

# Initialize logger
logger = setup_logger('inference_demo')

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# %% [markdown]
"""
## 2. Model Loading and Preparation
"""

# %%
class ModelLoader:
    """Load and prepare model for inference."""
    
    def __init__(self, config):
        self.config = config
        self.device = torch.device(config['inference']['device'])
        
    def load_model(self, model_path=None):
        """Load model from checkpoint."""
        if model_path is None:
            model_path = self.config['model']['path']
        
        print(f"Loading model from: {model_path}")
        
        if not os.path.exists(model_path):
            print(f"Model file not found: {model_path}")
            print("Creating a new model for demonstration...")
            return self.create_demo_model()
        
        try:
            # Load checkpoint
            checkpoint = torch.load(model_path, map_location=self.device)
            
            # Create model
            model = HybridVisionSystem(
                config=self.config['model'],
                num_classes=self.config['model']['num_classes'],
                use_vit=True,
                use_rag=False
            ).to(self.device)
            
            # Load weights
            if 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint)
            
            # Set to evaluation mode
            model.eval()
            
            print("✅ Model loaded successfully!")
            
            # Print model info
            self.print_model_info(model, checkpoint)
            
            return model
            
        except Exception as e:
            print(f"Error loading model: {e}")
            print("Creating a new model for demonstration...")
            return self.create_demo_model()
    
    def create_demo_model(self):
        """Create a new model for demonstration."""
        print("Creating new model for demonstration...")
        
        model = HybridVisionSystem(
            config=self.config['model'],
            num_classes=self.config['model']['num_classes'],
            use_vit=True,
            use_rag=False
        ).to(self.device)
        
        # Initialize with random weights
        model.eval()
        
        print("✅ Demo model created successfully!")
        
        return model
    
    def print_model_info(self, model, checkpoint=None):
        """Print model information."""
        print("\n" + "="*50)
        print("MODEL INFORMATION")
        print("="*50)
        
        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
        
        # Checkpoint info
        if checkpoint is not None:
            if 'epoch' in checkpoint:
                print(f"Trained for: {checkpoint['epoch']} epochs")
            if 'loss' in checkpoint:
                print(f"Checkpoint loss: {checkpoint['loss']:.4f}")
        
        # Device info
        print(f"Device: {self.device}")
        print(f"Using mixed precision: {self.config['inference']['use_amp']}")
        
        # Model components
        print("\nModel Components:")
        for name, module in model.named_children():
            if hasattr(module, 'parameters'):
                params = sum(p.numel() for p in module.parameters())
                print(f"  {name}: {params:,} parameters")
        
        print("="*50)

# %%
# Load model
model_loader = ModelLoader(config)
model = model_loader.load_model()

# %% [markdown]
"""
## 3. Inference Engine Setup
"""

# %%
class InferenceEngineDemo:
    """Demonstrate inference engine capabilities."""
    
    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.device = torch.device(config['inference']['device'])
        
        # Create inference engine
        self.engine = InferenceEngine(
            model=model,
            image_size=config['model']['image_size'],
            confidence_threshold=config['model']['confidence_threshold'],
            iou_threshold=config['model']['iou_threshold'],
            device=self.device
        )
        
        # Create preprocessor
        self.preprocessor = ImagePreprocessor(
            image_size=config['model']['image_size'],
            normalize=True
        )
        
        # Create postprocessor
        self.postprocessor = DetectionPostprocessor(
            num_classes=config['model']['num_classes'],
            confidence_threshold=config['model']['confidence_threshold'],
            iou_threshold=config['model']['iou_threshold']
        )
        
        # Create visualizer
        self.visualizer = DetectionVisualizer(
            class_names=self.get_coco_class_names(),
            color_scheme=config['visualization']['color_scheme']
        )
        
    def get_coco_class_names(self):
        """Get COCO dataset class names."""
        return [
            'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
            'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
            'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
            'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
            'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
            'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
            'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
            'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
            'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
            'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
            'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
            'toothbrush'
        ]
    
    def load_sample_images(self, num_images=5):
        """Load sample images for demonstration."""
        print(f"\nLoading {num_images} sample images...")
        
        sample_images = []
        
        # Try to load from sample directory
        sample_dir = Path('../data/samples')
        if sample_dir.exists():
            image_files = list(sample_dir.glob('*.jpg')) + list(sample_dir.glob('*.png'))
            image_files = image_files[:num_images]
            
            for img_file in image_files:
                try:
                    image = Image.open(img_file)
                    image = np.array(image)
                    sample_images.append((str(img_file), image))
                except Exception as e:
                    print(f"Error loading {img_file}: {e}")
        
        # If no sample images found, create synthetic ones
        if not sample_images:
            print("No sample images found. Creating synthetic images...")
            
            for i in range(num_images):
                # Create synthetic image with random shapes
                image = np.random.randint(0, 255, (416, 416, 3), dtype=np.uint8)
                
                # Add some random rectangles to simulate objects
                for _ in range(np.random.randint(3, 8)):
                    x1 = np.random.randint(50, 366)
                    y1 = np.random.randint(50, 366)
                    x2 = x1 + np.random.randint(30, 100)
                    y2 = y1 + np.random.randint(30, 100)
                    color = np.random.randint(0, 255, 3)
                    cv2.rectangle(image, (x1, y1), (x2, y2), color.tolist(), -1)
                
                sample_images.append((f'synthetic_{i}.jpg', image))
        
        print(f"Loaded {len(sample_images)} images")
        
        return sample_images
    
    def run_single_inference(self, image_path_or_array):
        """Run inference on a single image."""
        print(f"\nRunning single image inference...")
        
        # Load and preprocess image
        if isinstance(image_path_or_array, str):
            print(f"Loading image from: {image_path_or_array}")
            image = Image.open(image_path_or_array)
            image_np = np.array(image)
        else:
            image_np = image_path_or_array
        
        # Display original image
        print(f"Original image shape: {image_np.shape}")
        
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        axes[0].imshow(image_np)
        axes[0].set_title('Original Image')
        axes[0].axis('off')
        
        # Run inference
        start_time = time.time()
        
        with torch.no_grad():
            # Preprocess
            input_tensor = self.preprocessor(image_np)
            
            # Inference
            if self.config['inference']['use_amp']:
                with torch.cuda.amp.autocast():
                    outputs = self.engine.infer(input_tensor)
            else:
                outputs = self.engine.infer(input_tensor)
        
        inference_time = (time.time() - start_time) * 1000  # Convert to ms
        
        # Postprocess
        detections = self.postprocessor(outputs)
        
        print(f"Inference time: {inference_time:.2f} ms")
        print(f"Detected objects: {len(detections['boxes'])}")
        
        # Visualize results
        visualized = self.visualizer.visualize_detections(
            image_np, 
            detections,
            show_confidence=self.config['visualization']['show_confidence'],
            show_class_names=self.config['visualization']['show_class_names']
        )
        
        axes[1].imshow(visualized)
        axes[1].set_title(f'Detections ({len(detections["boxes"])} objects)')
        axes[1].axis('off')
        
        plt.tight_layout()
        plt.show()
        
        # Print detection details
        if len(detections['boxes']) > 0:
            print("\nDetection Details:")
            print("-" * 80)
            print(f"{'Class':<20} {'Confidence':<12} {'BBox (x1,y1,x2,y2)':<30}")
            print("-" * 80)
            
            for i in range(min(5, len(detections['boxes']))):  # Show top 5
                class_name = detections['class_names'][i]
                confidence = detections['scores'][i]
                bbox = detections['boxes'][i]
                
                print(f"{class_name:<20} {confidence:<12.4f} {str(bbox):<30}")
        
        return {
            'image': image_np,
            'detections': detections,
            'inference_time': inference_time
        }
    
    def run_batch_inference(self, images, batch_size=4):
        """Run inference on a batch of images."""
        print(f"\nRunning batch inference (batch_size={batch_size})...")
        
        if len(images) < batch_size:
            print(f"Warning: Only {len(images)} images available, using batch_size={len(images)}")
            batch_size = len(images)
        
        # Preprocess images
        batch_tensors = []
        original_images = []
        
        for img_name, img_array in images[:batch_size]:
            input_tensor = self.preprocessor(img_array)
            batch_tensors.append(input_tensor)
            original_images.append((img_name, img_array))
        
        # Stack batch
        batch_tensor = torch.cat(batch_tensors, dim=0)
        
        print(f"Batch tensor shape: {batch_tensor.shape}")
        
        # Run inference
        start_time = time.time()
        
        with torch.no_grad():
            if self.config['inference']['use_amp']:
                with torch.cuda.amp.autocast():
                    outputs = self.engine.infer_batch(batch_tensor)
            else:
                outputs = self.engine.infer_batch(batch_tensor)
        
        batch_time = (time.time() - start_time) * 1000  # Convert to ms
        
        print(f"Batch inference time: {batch_time:.2f} ms")
        print(f"Time per image: {batch_time/batch_size:.2f} ms")
        
        # Process results for each image
        all_results = []
        
        for i in range(batch_size):
            img_name, img_array = original_images[i]
            
            # Extract outputs for this image
            if isinstance(outputs, dict):
                # Handle dictionary outputs
                image_outputs = {k: v[i] if v is not None else None 
                               for k, v in outputs.items()}
            else:
                # Handle tensor outputs
                image_outputs = outputs[i]
            
            # Postprocess
            detections = self.postprocessor(image_outputs)
            
            all_results.append({
                'image_name': img_name,
                'image': img_array,
                'detections': detections
            })
            
            print(f"\nImage {i+1}: {img_name}")
            print(f"  Detected objects: {len(detections['boxes'])}")
        
        # Visualize batch results
        self.visualize_batch_results(all_results, batch_size)
        
        return all_results, batch_time
    
    def visualize_batch_results(self, results, batch_size):
        """Visualize batch inference results."""
        fig, axes = plt.subplots(2, min(4, batch_size), figsize=(16, 8))
        
        if batch_size == 1:
            axes = axes.reshape(2, 1)
        
        for i in range(min(4, batch_size)):
            result = results[i]
            img_array = result['image']
            detections = result['detections']
            
            # Original image
            axes[0, i].imshow(img_array)
            axes[0, i].set_title(f'Original {i+1}')
            axes[0, i].axis('off')
            
            # Detected image
            visualized = self.visualizer.visualize_detections(
                img_array,
                detections,
                show_confidence=True,
                show_class_names=True
            )
            
            axes[1, i].imshow(visualized)
            axes[1, i].set_title(f'Detected: {len(detections["boxes"])} objects')
            axes[1, i].axis('off')
        
        plt.suptitle(f'Batch Inference Results (Batch Size: {batch_size})', fontsize=16)
        plt.tight_layout()
        plt.show()

# %%
# Create inference engine
inference_demo = InferenceEngineDemo(model, config)

# Load sample images
sample_images = inference_demo.load_sample_images(num_images=5)

# Run single inference on first image
if sample_images:
    first_image_path, first_image_array = sample_images[0]
    result = inference_demo.run_single_inference(first_image_array)

# %% [markdown]
"""
## 4. Performance Benchmarking
"""

# %%
class PerformanceBenchmark:
    """Benchmark inference performance."""
    
    def __init__(self, inference_engine, config):
        self.engine = inference_engine
        self.config = config
        self.device = torch.device(config['inference']['device'])
        
    def benchmark_latency(self, num_runs=100, warmup_runs=10, batch_size=1):
        """Benchmark inference latency."""
        print(f"\n{'='*60}")
        print(f"LATENCY BENCHMARK")
        print(f"{'='*60}")
        print(f"Runs: {num_runs} (warmup: {warmup_runs})")
        print(f"Batch size: {batch_size}")
        print(f"Device: {self.device}")
        print(f"Mixed precision: {self.config['inference']['use_amp']}")
        
        # Create test input
        H, W = self.config['model']['image_size']
        test_input = torch.randn(batch_size, 3, H, W).to(self.device)
        
        latencies = []
        
        # Warmup runs
        print(f"\nWarmup runs ({warmup_runs})...")
        for i in range(warmup_runs):
            with torch.no_grad():
                if self.config['inference']['use_amp']:
                    with torch.cuda.amp.autocast():
                        _ = self.engine.model(test_input)
                else:
                    _ = self.engine.model(test_input)
            
            if (i + 1) % 5 == 0:
                print(f"  Warmup run {i + 1}/{warmup_runs}")
        
        # Benchmark runs
        print(f"\nBenchmark runs ({num_runs})...")
        for i in range(num_runs):
            torch.cuda.synchronize() if torch.cuda.is_available() else None
            start_time = time.perf_counter()
            
            with torch.no_grad():
                if self.config['inference']['use_amp']:
                    with torch.cuda.amp.autocast():
                        _ = self.engine.model(test_input)
                else:
                    _ = self.engine.model(test_input)
            
            torch.cuda.synchronize() if torch.cuda.is_available() else None
            end_time = time.perf_counter()
            
            latency = (end_time - start_time) * 1000  # Convert to ms
            latencies.append(latency)
            
            if (i + 1) % 20 == 0:
                print(f"  Run {i + 1}/{num_runs}: {latency:.2f} ms")
        
        # Calculate statistics
        latencies = np.array(latencies)
        
        print(f"\n{'='*60}")
        print("RESULTS:")
        print(f"{'='*60}")
        print(f"Mean latency: {latencies.mean():.2f} ms")
        print(f"Std latency: {latencies.std():.2f} ms")
        print(f"Min latency: {latencies.min():.2f} ms")
        print(f"Max latency: {latencies.max():.2f} ms")
        print(f"Median latency: {np.median(latencies):.2f} ms")
        print(f"Throughput: {1000/latencies.mean():.1f} FPS")
        
        # Percentiles
        percentiles = [50, 90, 95, 99]
        print(f"\nLatency Percentiles:")
        for p in percentiles:
            value = np.percentile(latencies, p)
            print(f"  P{p}: {value:.2f} ms")
        
        # Visualize results
        self.visualize_latency_benchmark(latencies, batch_size)
        
        return latencies
    
    def visualize_latency_benchmark(self, latencies, batch_size):
        """Visualize latency benchmark results."""
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        
        # Histogram
        axes[0, 0].hist(latencies, bins=30, alpha=0.7, edgecolor='black')
        axes[0, 0].axvline(np.mean(latencies), color='red', linestyle='--',
                          label=f'Mean: {np.mean(latencies):.2f} ms')
        axes[0, 0].axvline(np.median(latencies), color='green', linestyle='--',
                          label=f'Median: {np.median(latencies):.2f} ms')
        axes[0, 0].set_xlabel('Latency (ms)')
        axes[0, 0].set_ylabel('Frequency')
        axes[0, 0].set_title('Latency Distribution')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)
        
        # Time series
        axes[0, 1].plot(latencies, marker='o', linestyle='-', alpha=0.6)
        axes[0, 1].axhline(np.mean(latencies), color='red', linestyle='--',
                          label=f'Mean: {np.mean(latencies):.2f} ms')
        axes[0, 1].set_xlabel('Run')
        axes[0, 1].set_ylabel('Latency (ms)')
        axes[0, 1].set_title('Latency Over Time')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)
        
        # CDF
        sorted_latencies = np.sort(latencies)
        cdf = np.arange(1, len(sorted_latencies) + 1) / len(sorted_latencies)
        axes[1, 0].plot(sorted_latencies, cdf, linewidth=2)
        axes[1, 0].set_xlabel('Latency (ms)')
        axes[1, 0].set_ylabel('CDF')
        axes[1, 0].set_title('Cumulative Distribution Function')
        axes[1, 0].grid(True, alpha=0.3)
        
        # Add percentile lines
        percentiles = [50, 90, 95, 99]
        colors = ['red', 'orange', 'green', 'blue']
        for p, color in zip(percentiles, colors):
            value = np.percentile(latencies, p)
            axes[1, 0].axvline(value, color=color, linestyle='--', alpha=0.7,
                              label=f'P{p}: {value:.1f} ms')
        axes[1, 0].legend()
        
        # Box plot
        axes[1, 1].boxplot(latencies, vert=True, patch_artist=True)
        axes[1, 1].set_ylabel('Latency (ms)')
        axes[1, 1].set_title('Box Plot')
        axes[1, 1].grid(True, alpha=0.3)
        
        # Add statistics text
        stats_text = f"""
        Batch Size: {batch_size}
        Mean: {np.mean(latencies):.2f} ms
        Std: {np.std(latencies):.2f} ms
        Min: {np.min(latencies):.2f} ms
        Max: {np.max(latencies):.2f} ms
        FPS: {1000/np.mean(latencies):.1f}
        """
        axes[1, 1].text(0.95, 0.05, stats_text,
                       transform=axes[1, 1].transAxes,
                       verticalalignment='bottom',
                       horizontalalignment='right',
                       bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        
        plt.suptitle(f'Inference Latency Benchmark (Batch Size: {batch_size})', fontsize=16)
        plt.tight_layout()
        plt.show()
        
        # Create interactive visualization
        self.create_interactive_benchmark_viz(latencies, batch_size)
    
    def create_interactive_benchmark_viz(self, latencies, batch_size):
        """Create interactive benchmark visualization."""
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Latency Distribution', 'Latency Over Time',
                           'Cumulative Distribution', 'Box Plot'),
            vertical_spacing=0.1
        )
        
        # Histogram
        fig.add_trace(
            go.Histogram(x=latencies, nbinsx=30, name='Distribution',
                        marker_color='blue'),
            row=1, col=1
        )
        
        # Add mean and median lines
        fig.add_vline(x=np.mean(latencies), line_dash="dash", line_color="red",
                     annotation_text=f"Mean: {np.mean(latencies):.2f} ms", row=1, col=1)
        fig.add_vline(x=np.median(latencies), line_dash="dash", line_color="green",
                     annotation_text=f"Median: {np.median(latencies):.2f} ms", row=1, col=1)
        
        # Time series
        fig.add_trace(
            go.Scatter(y=latencies, mode='lines+markers', name='Latency',
                      line=dict(color='orange')),
            row=1, col=2
        )
        
        fig.add_hline(y=np.mean(latencies), line_dash="dash", line_color="red",
                     annotation_text=f"Mean", row=1, col=2)
        
        # CDF
        sorted_latencies = np.sort(latencies)
        cdf = np.arange(1, len(sorted_latencies) + 1) / len(sorted_latencies)
        
        fig.add_trace(
            go.Scatter(x=sorted_latencies, y=cdf, mode='lines', name='CDF',
                      line=dict(width=3, color='purple')),
            row=2, col=1
        )
        
        # Add percentile lines
        percentiles = [50, 90, 95, 99]
        percentile_values = [np.percentile(latencies, p) for p in percentiles]
        percentile_labels = [f'P{p}: {v:.1f} ms' for p, v in zip(percentiles, percentile_values)]
        
        for value, label in zip(percentile_values, percentile_labels):
            fig.add_vline(x=value, line_dash="dash", line_color="red", row=2, col=1,
                         annotation_text=label)
        
        # Box plot
        fig.add_trace(
            go.Box(y=latencies, name='Latency', boxmean='sd',
                  marker_color='lightblue'),
            row=2, col=2
        )
        
        fig.update_layout(
            height=800,
            width=1200,
            title_text=f"Inference Latency Benchmark (Batch Size: {batch_size})",
            showlegend=True
        )
        
        fig.show()
    
    def benchmark_memory(self, batch_sizes=[1, 2, 4, 8, 16]):
        """Benchmark memory usage for different batch sizes."""
        print(f"\n{'='*60}")
        print(f"MEMORY BENCHMARK")
        print(f"{'='*60}")
        print(f"Batch sizes: {batch_sizes}")
        print(f"Device: {self.device}")
        
        memory_results = []
        
        for batch_size in batch_sizes:
            print(f"\nTesting batch size: {batch_size}")
            
            # Clear cache
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
            
            # Measure memory before
            if torch.cuda.is_available():
                torch.cuda.reset_peak_memory_stats()
                start_memory = torch.cuda.memory_allocated()
            else:
                start_memory = psutil.Process().memory_info().rss / 1024**2  # MB
            
            # Create test input
            H, W = self.config['model']['image_size']
            test_input = torch.randn(batch_size, 3, H, W).to(self.device)
            
            # Run inference to allocate memory
            with torch.no_grad():
                if self.config['inference']['use_amp']:
                    with torch.cuda.amp.autocast():
                        _ = self.engine.model(test_input)
                else:
                    _ = self.engine.model(test_input)
            
            # Measure memory after
            if torch.cuda.is_available():
                peak_memory = torch.cuda.max_memory_allocated()
                memory_used = (peak_memory - start_memory) / 1024**3  # Convert to GB
            else:
                end_memory = psutil.Process().memory_info().rss / 1024**2
                memory_used = (end_memory - start_memory) / 1024  # Convert to GB
            
            memory_results.append((batch_size, memory_used))
            
            print(f"  Memory used: {memory_used:.3f} GB")
            
            # Clear cache
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
        
        # Analyze results
        print(f"\n{'='*60}")
        print("MEMORY USAGE ANALYSIS:")
        print(f"{'='*60}")
        
        batch_sizes_arr = np.array([r[0] for r in memory_results])
        memory_used_arr = np.array([r[1] for r in memory_results])
        
        for bs, mem in zip(batch_sizes_arr, memory_used_arr):
            print(f"Batch size {bs}: {mem:.3f} GB")
        
        # Calculate memory per sample
        memory_per_sample = memory_used_arr / batch_sizes_arr
        print(f"\nMemory per sample: {np.mean(memory_per_sample):.3f} GB")
        
        # Fit linear model
        coeffs = np.polyfit(batch_sizes_arr, memory_used_arr, 1)
        print(f"Memory model: {coeffs[0]:.3f} * batch_size + {coeffs[1]:.3f}")
        
        # Visualize
        self.visualize_memory_benchmark(batch_sizes_arr, memory_used_arr, coeffs)
        
        return memory_results
    
    def visualize_memory_benchmark(self, batch_sizes, memory_used, coeffs):
        """Visualize memory benchmark results."""
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        
        # Memory vs batch size
        axes[0].plot(batch_sizes, memory_used, 'bo-', linewidth=2, markersize=8)
        
        # Linear fit
        x_fit = np.linspace(min(batch_sizes), max(batch_sizes), 100)
        y_fit = coeffs[0] * x_fit + coeffs[1]
        axes[0].plot(x_fit, y_fit, 'r--', label=f'Fit: {coeffs[0]:.3f}x + {coeffs[1]:.3f}')
        
        axes[0].set_xlabel('Batch Size')
        axes[0].set_ylabel('Memory Used (GB)')
        axes[0].set_title('Memory Usage vs Batch Size')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)
        
        # Add annotations
        for bs, mem in zip(batch_sizes, memory_used):
            axes[0].annotate(f'{mem:.3f} GB', 
                           xy=(bs, mem), 
                           xytext=(0, 10),
                           textcoords='offset points',
                           ha='center')
        
        # Memory per sample
        memory_per_sample = memory_used / batch_sizes
        
        axes[1].bar(range(len(batch_sizes)), memory_per_sample)
        axes[1].set_xlabel('Batch Size Index')
        axes[1].set_ylabel('Memory per Sample (GB)')
        axes[1].set_title('Memory Efficiency')
        axes[1].set_xticks(range(len(batch_sizes)))
        axes[1].set_xticklabels([f'BS={bs}' for bs in batch_sizes])
        axes[1].grid(True, alpha=0.3, axis='y')
        
        # Add value labels
        for i, mem_per in enumerate(memory_per_sample):
            axes[1].text(i, mem_per, f'{mem_per:.3f}', 
                        ha='center', va='bottom')
        
        plt.tight_layout()
        plt.show()
    
    def benchmark_batch_sizes(self, max_batch_size=32):
        """Find optimal batch size for throughput."""
        print(f"\n{'='*60}")
        print(f"BATCH SIZE OPTIMIZATION")
        print(f"{'='*60}")
        
        batch_sizes = [1, 2, 4, 8, 16, 32]
        batch_sizes = [bs for bs in batch_sizes if bs <= max_batch_size]
        
        results = []
        
        for batch_size in batch_sizes:
            print(f"\nBenchmarking batch size: {batch_size}")
            
            # Benchmark latency for this batch size
            latencies = self.benchmark_latency(
                num_runs=50,
                warmup_runs=5,
                batch_size=batch_size
            )
            
            mean_latency = np.mean(latencies)
            throughput = 1000 / mean_latency * batch_size  # Images per second
            
            results.append({
                'batch_size': batch_size,
                'mean_latency': mean_latency,
                'throughput': throughput,
                'efficiency': throughput / batch_size  # Throughput per GPU unit
            })
            
            print(f"  Throughput: {throughput:.1f} images/sec")
            print(f"  Efficiency: {throughput/batch_size:.1f} images/sec per batch unit")
        
        # Find optimal batch size
        df_results = pd.DataFrame(results)
        optimal_idx = df_results['throughput'].idxmax()
        optimal = df_results.iloc[optimal_idx]
        
        print(f"\n{'='*60}")
        print("OPTIMAL BATCH SIZE ANALYSIS:")
        print(f"{'='*60}")
        print(f"Optimal batch size: {optimal['batch_size']}")
        print(f"Max throughput: {optimal['throughput']:.1f} images/sec")
        print(f"Latency at optimal: {optimal['mean_latency']:.2f} ms")
        
        # Visualize
        self.visualize_batch_optimization(df_results)
        
        return df_results, optimal
    
    def visualize_batch_optimization(self, df_results):
        """Visualize batch size optimization results."""
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        
        # Throughput vs batch size
        axes[0, 0].plot(df_results['batch_size'], df_results['throughput'], 'bo-', 
                       linewidth=2, markersize=8)
        axes[0, 0].set_xlabel('Batch Size')
        axes[0, 0].set_ylabel('Throughput (images/sec)')
        axes[0, 0].set_title('Throughput vs Batch Size')
        axes[0, 0].grid(True, alpha=0.3)
        
        # Mark optimal point
        optimal_idx = df_results['throughput'].idxmax()
        optimal_bs = df_results.loc[optimal_idx, 'batch_size']
        optimal_throughput = df_results.loc[optimal_idx, 'throughput']
        axes[0, 0].plot(optimal_bs, optimal_throughput, 'r*', markersize=15, 
                       label=f'Optimal: BS={optimal_bs}')
        axes[0, 0].legend()
        
        # Latency vs batch size
        axes[0, 1].plot(df_results['batch_size'], df_results['mean_latency'], 'ro-',
                       linewidth=2, markersize=8)
        axes[0, 1].set_xlabel('Batch Size')
        axes[0, 1].set_ylabel('Latency (ms)')
        axes[0, 1].set_title('Latency vs Batch Size')
        axes[0, 1].grid(True, alpha=0.3)
        
        # Efficiency vs batch size
        axes[1, 0].plot(df_results['batch_size'], df_results['efficiency'], 'go-',
                       linewidth=2, markersize=8)
        axes[1, 0].set_xlabel('Batch Size')
        axes[1, 0].set_ylabel('Efficiency (images/sec per batch unit)')
        axes[1, 0].set_title('Efficiency vs Batch Size')
        axes[1, 0].grid(True, alpha=0.3)
        
        # Throughput-Latency tradeoff
        axes[1, 1].scatter(df_results['mean_latency'], df_results['throughput'],
                          c=df_results['batch_size'], s=100, cmap='viridis')
        axes[1, 1].set_xlabel('Latency (ms)')
        axes[1, 1].set_ylabel('Throughput (images/sec)')
        axes[1, 1].set_title('Throughput-Latency Tradeoff')
        axes[1, 1].grid(True, alpha=0.3)
        
        # Add batch size labels
        for i, row in df_results.iterrows():
            axes[1, 1].annotate(f"BS={row['batch_size']}", 
                              xy=(row['mean_latency'], row['throughput']),
                              xytext=(5, 5), textcoords='offset points')
        
        # Colorbar for batch sizes
        plt.colorbar(axes[1, 1].collections[0], ax=axes[1, 1], label='Batch Size')
        
        plt.suptitle('Batch Size Optimization Analysis', fontsize=16)
        plt.tight_layout()
        plt.show()

# %%
# Run performance benchmarks
benchmark = PerformanceBenchmark(inference_demo.engine, config)

# Benchmark latency
latencies = benchmark.benchmark_latency(
    num_runs=config['inference']['benchmark_runs'],
    warmup_runs=config['inference']['warmup_runs'],
    batch_size=config['inference']['batch_size']
)

# Benchmark memory
memory_results = benchmark.benchmark_memory(batch_sizes=[1, 2, 4, 8, 16])

# Find optimal batch size
batch_results, optimal = benchmark.benchmark_batch_sizes(max_batch_size=16)

# %% [markdown]
"""
## 5. Real-time Video Inference Demo
"""

# %%
class VideoInferenceDemo:
    """Demonstrate real-time video inference."""
    
    def __init__(self, inference_engine, config):
        self.engine = inference_engine
        self.config = config
        self.device = torch.device(config['inference']['device'])
        
        # Create visualizer
        self.visualizer = DetectionVisualizer(
            class_names=inference_engine.visualizer.class_names,
            color_scheme=config['visualization']['color_scheme']
        )
        
        # Statistics
        self.stats = {
            'frame_count': 0,
            'total_time': 0,
            'fps_history': [],
            'detection_history': []
        }
    
    def process_video_file(self, video_path, max_frames=100, output_path=None):
        """Process video file and display results."""
        print(f"\nProcessing video: {video_path}")
        print(f"Max frames: {max_frames}")
        
        # Open video
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Error opening video: {video_path}")
            print("Using synthetic video instead...")
            return self.create_synthetic_video(max_frames)
        
        # Get video properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        
        print(f"Video properties: {width}x{height}, {fps:.1f} FPS, {total_frames} frames")
        
        # Prepare output writer
        if output_path:
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        else:
            out = None
        
        # Process frames
        frames_processed = 0
        processed_frames = []
        
        while frames_processed < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            
            # Convert BGR to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            
            # Run inference
            start_time = time.time()
            
            input_tensor = self.engine.preprocessor(frame_rgb)
            
            with torch.no_grad():
                if self.config['inference']['use_amp']:
                    with torch.cuda.amp.autocast():
                        outputs = self.engine.engine.infer(input_tensor)
                else:
                    outputs = self.engine.engine.infer(input_tensor)
            
            inference_time = (time.time() - start_time) * 1000
            
            # Postprocess
            detections = self.engine.postprocessor(outputs)
            
            # Visualize
            visualized = self.visualizer.visualize_detections(
                frame_rgb,
                detections,
                show_confidence=True,
                show_class_names=True
            )
            
            # Convert back to BGR for OpenCV
            visualized_bgr = cv2.cvtColor(visualized, cv2.COLOR_RGB2BGR)
            
            # Add FPS counter
            fps_text = f"FPS: {1000/inference_time:.1f}" if inference_time > 0 else "FPS: N/A"
            cv2.putText(visualized_bgr, fps_text, (10, 30),
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            
            # Add detection count
            det_text = f"Detections: {len(detections['boxes'])}"
            cv2.putText(visualized_bgr, det_text, (10, 70),
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            
            # Store processed frame
            processed_frames.append(visualized_bgr)
            
            # Write to output
            if out:
                out.write(visualized_bgr)
            
            # Update statistics
            self.stats['frame_count'] += 1
            self.stats['total_time'] += inference_time
            self.stats['fps_history'].append(1000/inference_time if inference_time > 0 else 0)
            self.stats['detection_history'].append(len(detections['boxes']))
            
            frames_processed += 1
            
            if frames_processed % 10 == 0:
                print(f"  Processed {frames_processed}/{min(max_frames, total_frames)} frames")
        
        # Release resources
        cap.release()
        if out:
            out.release()
        
        print(f"\nVideo processing completed!")
        print(f"Frames processed: {frames_processed}")
        
        # Display sample frames
        self.display_sample_frames(processed_frames)
        
        # Analyze performance
        self.analyze_video_performance()
        
        return processed_frames
    
    def create_synthetic_video(self, num_frames=50):
        """Create synthetic video for demonstration."""
        print("Creating synthetic video...")
        
        processed_frames = []
        
        for i in range(num_frames):
            # Create synthetic frame
            frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
            
            # Add moving rectangle
            x = int(300 + 200 * np.sin(i * 0.1))
            y = int(200 + 100 * np.cos(i * 0.1))
            cv2.rectangle(frame, (x, y), (x+100, y+100), (0, 255, 0), -1)
            
            # Add static rectangle
            cv2.rectangle(frame, (100, 100), (200, 200), (255, 0, 0), -1)
            
            # Run inference (simulated for demo)
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            
            # Create simulated detections
            detections = {
                'boxes': [
                    [x/640, y/480, (x+100)/640, (y+100)/480],  # Moving object
                    [100/640, 100/480, 200/640, 200/480]       # Static object
                ],
                'scores': [0.9, 0.8],
                'class_ids': [0, 1],
                'class_names': ['person', 'car']
            }
            
            # Visualize
            visualized = self.visualizer.visualize_detections(
                frame_rgb,
                detections,
                show_confidence=True,
                show_class_names=True
            )
            
            visualized_bgr = cv2.cvtColor(visualized, cv2.COLOR_RGB2BGR)
            
            # Add frame info
            cv2.putText(visualized_bgr, f"Frame {i+1}", (10, 30),
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            cv2.putText(visualized_bgr, "SYNTHETIC DEMO", (10, 70),
                       cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            
            processed_frames.append(visualized_bgr)
        
        print(f"Created {num_frames} synthetic frames")
        self.display_sample_frames(processed_frames)
        
        return processed_frames
    
    def display_sample_frames(self, frames, sample_count=4):
        """Display sample frames from video."""
        if not frames:
            return
        
        print(f"\nDisplaying {sample_count} sample frames...")
        
        fig, axes = plt.subplots(1, min(sample_count, len(frames)), figsize=(16, 4))
        
        if sample_count == 1:
            axes = [axes]
        
        indices = np.linspace(0, len(frames)-1, sample_count, dtype=int)
        
        for idx, ax_idx in enumerate(indices):
            frame = frames[ax_idx]
            
            # Convert BGR to RGB for display
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            
            axes[idx].imshow(frame_rgb)
            axes[idx].set_title(f'Frame {ax_idx + 1}')
            axes[idx].axis('off')
        
        plt.suptitle('Video Inference Results', fontsize=16)
        plt.tight_layout()
        plt.show()
    
    def analyze_video_performance(self):
        """Analyze video inference performance."""
        print(f"\n{'='*60}")
        print("VIDEO INFERENCE PERFORMANCE ANALYSIS")
        print(f"{'='*60}")
        
        if not self.stats['fps_history']:
            print("No performance data available.")
            return
        
        fps_history = np.array(self.stats['fps_history'])
        detection_history = np.array(self.stats['detection_history'])
        
        print(f"Frames processed: {self.stats['frame_count']}")
        print(f"Total time: {self.stats['total_time']/1000:.2f} seconds")
        print(f"Average FPS: {np.mean(fps_history):.1f}")
        print(f"Std FPS: {np.std(fps_history):.1f}")
        print(f"Min FPS: {np.min(fps_history):.1f}")
        print(f"Max FPS: {np.max(fps_history):.1f}")
        print(f"Average detections per frame: {np.mean(detection_history):.1f}")
        
        # Visualize performance
        self.visualize_video_performance(fps_history, detection_history)
    
    def visualize_video_performance(self, fps_history, detection_history):
        """Visualize video inference performance."""
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        
        # FPS over time
        axes[0, 0].plot(fps_history, linewidth=2)
        axes[0, 0].axhline(np.mean(fps_history), color='red', linestyle='--',
                          label=f'Mean: {np.mean(fps_history):.1f} FPS')
        axes[0, 0].set_xlabel('Frame')
        axes[0, 0].set_ylabel('FPS')
        axes[0, 0].set_title('FPS Over Time')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)
        
        # FPS distribution
        axes[0, 1].hist(fps_history, bins=30, alpha=0.7, edgecolor='black')
        axes[0, 1].axvline(np.mean(fps_history), color='red', linestyle='--',
                          label=f'Mean: {np.mean(fps_history):.1f}')
        axes[0, 1].axvline(np.median(fps_history), color='green', linestyle='--',
                          label=f'Median: {np.median(fps_history):.1f}')
        axes[0, 1].set_xlabel('FPS')
        axes[0, 1].set_ylabel('Frequency')
        axes[0, 1].set_title('FPS Distribution')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)
        
        # Detections over time
        axes[1, 0].plot(detection_history, linewidth=2, color='orange')
        axes[1, 0].axhline(np.mean(detection_history), color='red', linestyle='--',
                          label=f'Mean: {np.mean(detection_history):.1f}')
        axes[1, 0].set_xlabel('Frame')
        axes[1, 0].set_ylabel('Number of Detections')
        axes[1, 0].set_title('Detections Over Time')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)
        
        # FPS vs Detections scatter
        axes[1, 1].scatter(detection_history, fps_history, alpha=0.6)
        axes[1, 1].set_xlabel('Number of Detections')
        axes[1, 1].set_ylabel('FPS')
        axes[1, 1].set_title('FPS vs Detections')
        axes[1, 1].grid(True, alpha=0.3)
        
        # Add correlation coefficient
        correlation = np.corrcoef(detection_history, fps_history)[0, 1]
        axes[1, 1].text(0.05, 0.95, f'Correlation: {correlation:.3f}',
                       transform=axes[1, 1].transAxes,
                       verticalalignment='top',
                       bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        
        plt.suptitle('Video Inference Performance Analysis', fontsize=16)
        plt.tight_layout()
        plt.show()
    
    def live_camera_demo(self, camera_index=0, duration_sec=10):
        """Live camera inference demo."""
        print(f"\n{'='*60}")
        print("LIVE CAMERA INFERENCE DEMO")
        print(f"{'='*60}")
        print(f"Duration: {duration_sec} seconds")
        print("Press 'q' to quit early")
        
        # Open camera
        cap = cv2.VideoCapture(camera_index)
        if not cap.isOpened():
            print(f"Error opening camera {camera_index}")
            print("Using synthetic camera feed instead...")
            return self.synthetic_camera_demo(duration_sec)
        
        # Get camera properties
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        
        print(f"Camera: {width}x{height}, {fps:.1f} FPS")
        
        # Create window
        cv2.namedWindow('Live Inference', cv2.WINDOW_NORMAL)
        
        # Statistics
        frame_count = 0
        start_time = time.time()
        fps_history = []
        
        print("\nStarting live inference...")
        
        while time.time() - start_time < duration_sec:
            ret, frame = cap.read()
            if not ret:
                print("Error reading frame from camera")
                break
            
            # Convert BGR to RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            
            # Run inference
            frame_start = time.time()
            
            input_tensor = self.engine.preprocessor(frame_rgb)
            
            with torch.no_grad():
                if self.config['inference']['use_amp']:
                    with torch.cuda.amp.autocast():
                        outputs = self.engine.engine.infer(input_tensor)
                else:
                    outputs = self.engine.engine.infer(input_tensor)
            
            inference_time = (time.time() - frame_start) * 1000
            
            # Postprocess
            detections = self.engine.postprocessor(outputs)
            
            # Visualize
            visualized = self.visualizer.visualize_detections(
                frame_rgb,
                detections,
                show_confidence=True,
                show_class_names=True
            )
            
            # Convert back to BGR for display
            visualized_bgr = cv2.cvtColor(visualized, cv2.COLOR_RGB2BGR)
            
            # Add overlay info
            current_fps = 1000 / inference_time if inference_time > 0 else 0
            fps_history.append(current_fps)
            
            overlay_text = [
                f"FPS: {current_fps:.1f}",
                f"Detections: {len(detections['boxes'])}",
                f"Frame: {frame_count}",
                f"Time: {time.time() - start_time:.1f}s"
            ]
            
            y_offset = 30
            for text in overlay_text:
                cv2.putText(visualized_bgr, text, (10, y_offset),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
                y_offset += 30
            
            # Display
            cv2.imshow('Live Inference', visualized_bgr)
            
            frame_count += 1
            
            # Check for quit
            if cv2.waitKey(1) & 0xFF == ord('q'):
                print("\nUser requested quit")
                break
        
        # Cleanup
        cap.release()
        cv2.destroyAllWindows()
        
        # Performance summary
        total_time = time.time() - start_time
        avg_fps = np.mean(fps_history) if fps_history else 0
        
        print(f"\nLive demo completed!")
        print(f"Frames processed: {frame_count}")
        print(f"Total time: {total_time:.1f} seconds")
        print(f"Average FPS: {avg_fps:.1f}")
        
        return fps_history
    
    def synthetic_camera_demo(self, duration_sec=10):
        """Synthetic camera demo."""
        print("Running synthetic camera demo...")
        
        fps_history = []
        start_time = time.time()
        frame_count = 0
        
        cv2.namedWindow('Synthetic Camera', cv2.WINDOW_NORMAL)
        
        while time.time() - start_time < duration_sec:
            # Create synthetic frame
            frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
            
            # Add moving object
            t = time.time() - start_time
            x = int(320 + 200 * np.sin(t * 2))
            y = int(240 + 150 * np.cos(t * 1.5))
            cv2.circle(frame, (x, y), 50, (0, 255, 0), -1)
            
            # Add static objects
            cv2.rectangle(frame, (100, 100), (200, 200), (255, 0, 0), -1)
            cv2.circle(frame, (500, 300), 40, (0, 0, 255), -1)
            
            # Simulate inference time
            time.sleep(0.033)  # ~30 FPS
            
            # Add overlay
            current_time = time.time() - start_time
            current_fps = 30  # Simulated
            
            fps_history.append(current_fps)
            
            overlay_text = [
                "SYNTHETIC CAMERA DEMO",
                f"FPS: {current_fps:.1f}",
                f"Time: {current_time:.1f}s",
                f"Frame: {frame_count}",
                "Press 'q' to quit"
            ]
            
            y_offset = 30
            for text in overlay_text:
                cv2.putText(frame, text, (10, y_offset),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
                y_offset += 30
            
            cv2.imshow('Synthetic Camera', frame)
            
            frame_count += 1
            
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        
        cv2.destroyAllWindows()
        
        print(f"\nSynthetic demo completed!")
        print(f"Frames: {frame_count}")
        print(f"Average FPS: {np.mean(fps_history):.1f}")
        
        return fps_history

# %%
# Run video inference demo
video_demo = VideoInferenceDemo(inference_demo, config)

# Process video file (if available)
video_path = '../data/samples/sample_video.mp4'  # Example path
if os.path.exists(video_path):
    processed_frames = video_demo.process_video_file(
        video_path, 
        max_frames=50,
        output_path='../results/inference/output_video.mp4'
    )
else:
    print(f"\nVideo file not found: {video_path}")
    print("Running synthetic video demo instead...")
    processed_frames = video_demo.create_synthetic_video(num_frames=50)

# Live camera demo (optional)
run_live_demo = False  # Set to True to run live demo
if run_live_demo:
    print("\nStarting live camera demo...")
    fps_history = video_demo.live_camera_demo(duration_sec=5)
else:
    print("\nSkipping live camera demo (set run_live_demo=True to enable)")

# %% [markdown]
"""
## 6. Deployment Readiness Test
"""

# %%
class DeploymentTester:
    """Test deployment readiness."""
    
    def __init__(self, model, config):
        self.model = model
        self.config = config
        self.device = torch.device(config['inference']['device'])
        
    def run_comprehensive_tests(self):
        """Run comprehensive deployment tests."""
        print(f"\n{'='*60}")
        print("DEPLOYMENT READINESS TESTS")
        print(f"{'='*60}")
        
        test_results = []
        
        # Test 1: Model loading and validation
        test_results.append(self.test_model_loading())
        
        # Test 2: Inference correctness
        test_results.append(self.test_inference_correctness())
        
        # Test 3: Performance requirements
        test_results.append(self.test_performance_requirements())
        
        # Test 4: Memory constraints
        test_results.append(self.test_memory_constraints())
        
        # Test 5: Robustness
        test_results.append(self.test_robustness())
        
        # Test 6: Export capabilities
        test_results.append(self.test_export_capabilities())
        
        # Generate overall assessment
        self.generate_deployment_assessment(test_results)
        
        return test_results
    
    def test_model_loading(self):
        """Test model loading and validation."""
        print("\n[TEST 1] Model Loading and Validation:")
        
        results = {
            'name': 'Model Loading',
            'status': 'PASS',
            'details': [],
            'recommendations': []
        }
        
        try:
            # Check model is in eval mode
            if self.model.training:
                results['status'] = 'FAIL'
                results['details'].append('Model is in training mode')
                results['recommendations'].append('Set model to eval mode: model.eval()')
            else:
                results['details'].append('✅ Model is in evaluation mode')
            
            # Check device placement
            if next(self.model.parameters()).device != self.device:
                results['status'] = 'FAIL'
                results['details'].append(f'Model not on correct device. Expected: {self.device}, Got: {next(self.model.parameters()).device}')
                results['recommendations'].append(f'Move model to device: model.to({self.device})')
            else:
                results['details'].append(f'✅ Model correctly placed on {self.device}')
            
            # Check parameter count
            total_params = sum(p.numel() for p in self.model.parameters())
            results['details'].append(f'✅ Total parameters: {total_params:,}')
            
            # Test forward pass
            H, W = self.config['model']['image_size']
            test_input = torch.randn(1, 3, H, W).to(self.device)
            
            with torch.no_grad():
                if self.config['inference']['use_amp']:
                    with torch.cuda.amp.autocast():
                        output = self.model(test_input, task='detection')
                else:
                    output = self.model(test_input, task='detection')
            
            if isinstance(output, dict) and 'detections' in output:
                results['details'].append('✅ Forward pass successful with detection output')
            else:
                results['details'].append('✅ Forward pass successful')
            
            print(f"  Status: {results['status']}")
            for detail in results['details']:
                print(f"    {detail}")
            
        except Exception as e:
            results['status'] = 'FAIL'
            results['details'].append(f'❌ Error during test: {str(e)}')
            print(f"  Status: {results['status']}")
            print(f"  Error: {str(e)}")
        
        return results
    
    def test_inference_correctness(self):
        """Test inference correctness."""
        print("\n[TEST 2] Inference Correctness:")
        
        results = {
            'name': 'Inference Correctness',
            'status': 'PASS',
            'details': [],
            'recommendations': []
        }
        
        try:
            # Create test images with known objects
            test_cases = self.create_test_cases()
            
            correct_predictions = 0
            total_predictions = 0
            
            for i, (image, expected_classes) in enumerate(test_cases):
                # Run inference
                input_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(self.device)
                input_tensor = input_tensor / 255.0  # Normalize
                
                with torch.no_grad():
                    if self.config['inference']['use_amp']:
                        with torch.cuda.amp.autocast():
                            output = self.model(input_tensor, task='detection')
                    else:
                        output = self.model(input_tensor, task='detection')
                
                # Simple correctness check (for demo purposes)
                # In real testing, would compare with ground truth
                if isinstance(output, dict) and 'detections' in output:
                    detections = output['detections']
                    # Count as correct if we get some detections
                    if detections.shape[1] > 0:  # Has some detections
                        correct_predictions += 1
                
                total_predictions += 1
            
            accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
            
            if accuracy >= 0.8:
                results['details'].append(f'✅ Inference accuracy: {accuracy:.1%}')
            elif accuracy >= 0.5:
                results['details'].append(f'⚠️ Moderate inference accuracy: {accuracy:.1%}')
                results['recommendations'].append('Consider fine-tuning on target domain')
            else:
                results['status'] = 'FAIL'
                results['details'].append(f'❌ Low inference accuracy: {accuracy:.1%}')
                results['recommendations'].append('Model needs significant improvement')
            
            results['details'].append(f'Tested on {total_predictions} synthetic images')
            
            print(f"  Status: {results['status']}")
            for detail in results['details']:
                print(f"    {detail}")
            
        except Exception as e:
            results['status'] = 'FAIL'
            results['details'].append(f'❌ Error during test: {str(e)}')
            print(f"  Status: {results['status']}")
            print(f"  Error: {str(e)}")
        
        return results
    
    def create_test_cases(self):
        """Create test cases for correctness testing."""
        test_cases = []
        
        # Create synthetic test images
        for i in range(5):
            # Create image with geometric shapes
            image = np.zeros((416, 416, 3), dtype=np.uint8)
            
            # Add different colored rectangles
            colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]
            
            for j, color in enumerate(colors[:3]):  # Add 3 objects
                x = 50 + j * 100
                y = 50 + i * 70
                cv2.rectangle(image, (x, y), (x+80, y+80), color, -1)
            
            test_cases.append((image, [0, 1, 2]))  # Expected class IDs
        
        return test_cases
    
    def test_performance_requirements(self):
        """Test performance requirements."""
        print("\n[TEST 3] Performance Requirements:")
        
        results = {
            'name': 'Performance',
            'status': 'PASS',
            'details': [],
            'recommendations': []
        }
        
        try:
            # Requirements
            req_latency = 50  # ms (for real-time robotics)
            req_fps = 20      # FPS minimum
            req_batch_latency = 100  # ms for batch size 4
            
            # Test single image latency
            H, W = self.config['model']['image_size']
            test_input = torch.randn(1, 3, H, W).to(self.device)
            
            # Warmup
            for _ in range(10):
                with torch.no_grad():
                    _ = self.model(test_input, task='detection')
            
            # Benchmark
            latencies = []
            for _ in range(50):
                torch.cuda.synchronize() if torch.cuda.is_available() else None
                start = time.perf_counter()
                
                with torch.no_grad():
                    if self.config['inference']['use_amp']:
                        with torch.cuda.amp.autocast():
                            _ = self.model(test_input, task='detection')
                    else:
                        _ = self.model(test_input, task='detection')
                
                torch.cuda.synchronize() if torch.cuda.is_available() else None
                end = time.perf_counter()
                latencies.append((end - start) * 1000)
            
            mean_latency = np.mean(latencies)
            fps = 1000 / mean_latency
            
            # Check requirements
            if mean_latency <= req_latency:
                results['details'].append(f'✅ Latency: {mean_latency:.1f} ms (≤ {req_latency} ms)')
            else:
                results['status'] = 'WARNING'
                results['details'].append(f'⚠️ Latency: {mean_latency:.1f} ms (> {req_latency} ms)')
                results['recommendations'].append('Optimize model for lower latency')
            
            if fps >= req_fps:
                results['details'].append(f'✅ Throughput: {fps:.1f} FPS (≥ {req_fps} FPS)')
            else:
                results['status'] = 'FAIL'
                results['details'].append(f'❌ Throughput: {fps:.1f} FPS (< {req_fps} FPS)')
                results['recommendations'].append('Improve inference speed')
            
            # Test batch latency
            batch_input = torch.randn(4, 3, H, W).to(self.device)
            
            torch.cuda.synchronize() if torch.cuda.is_available() else None
            start = time.perf_counter()
            
            with torch.no_grad():
                if self.config['inference']['use_amp']:
                    with torch.cuda.amp.autocast():
                        _ = self.model(batch_input, task='detection')
                else:
                    _ = self.model(batch_input, task='detection')
            
            torch.cuda.synchronize() if torch.cuda.is_available() else None
            batch_latency = (time.perf_counter() - start) * 1000
            
            if batch_latency <= req_batch_latency:
                results['details'].append(f'✅ Batch (4) latency: {batch_latency:.1f} ms')
            else:
                results['details'].append(f'⚠️ Batch (4) latency: {batch_latency:.1f} ms')
            
            results['details'].append(f'Latency std: {np.std(latencies):.1f} ms')
            
            print(f"  Status: {results['status']}")
            for detail in results['details']:
                print(f"    {detail}")
            
        except Exception as e:
            results['status'] = 'FAIL'
            results['details'].append(f'❌ Error during test: {str(e)}')
            print(f"  Status: {results['status']}")
            print(f"  Error: {str(e)}")
        
        return results
    
    def test_memory_constraints(self):
        """Test memory constraints."""
        print("\n[TEST 4] Memory Constraints:")
        
        results = {
            'name': 'Memory',
            'status': 'PASS',
            'details': [],
            'recommendations': []
        }
        
        try:
            # Requirements (for edge deployment)
            req_single_image_memory = 2.0  # GB maximum for single image
            req_model_memory = 1.0  # GB maximum for model
            
            if torch.cuda.is_available():
                # Get GPU memory info
                gpu = GPUtil.getGPUs()[0]
                total_memory = gpu.memoryTotal / 1024  # Convert to GB
                
                # Clear cache and measure
                torch.cuda.empty_cache()
                torch.cuda.reset_peak_memory_stats()
                
                initial_memory = torch.cuda.memory_allocated() / 1024**3
                
                # Load model and do inference
                H, W = self.config['model']['image_size']
                test_input = torch.randn(1, 3, H, W).to(self.device)
                
                with torch.no_grad():
                    if self.config['inference']['use_amp']:
                        with torch.cuda.amp.autocast():
                            _ = self.model(test_input, task='detection')
                    else:
                        _ = self.model(test_input, task='detection')
                
                peak_memory = torch.cuda.max_memory_allocated() / 1024**3
                memory_used = peak_memory - initial_memory
                
                # Check requirements
                if memory_used <= req_single_image_memory:
                    results['details'].append(f'✅ Memory usage: {memory_used:.2f} GB (≤ {req_single_image_memory} GB)')
                else:
                    results['status'] = 'WARNING'
                    results['details'].append(f'⚠️ Memory usage: {memory_used:.2f} GB (> {req_single_image_memory} GB)')
                    results['recommendations'].append('Reduce model size or use memory optimization')
                
                # Estimate model memory
                param_memory = sum(p.numel() * p.element_size() for p in self.model.parameters()) / 1024**3
                
                if param_memory <= req_model_memory:
                    results['details'].append(f'✅ Model parameters: {param_memory:.2f} GB')
                else:
                    results['details'].append(f'⚠️ Model parameters: {param_memory:.2f} GB')
                
                results['details'].append(f'GPU memory available: {total_memory:.1f} GB')
                results['details'].append(f'Peak memory allocated: {peak_memory:.2f} GB')
                
            else:
                # CPU memory test
                import psutil
                process = psutil.Process()
                
                initial_memory = process.memory_info().rss / 1024**3
                
                # Do inference
                H, W = self.config['model']['image_size']
                test_input = torch.randn(1, 3, H, W)
                
                with torch.no_grad():
                    _ = self.model(test_input, task='detection')
                
                final_memory = process.memory_info().rss / 1024**3
                memory_used = final_memory - initial_memory
                
                results['details'].append(f'CPU memory usage: {memory_used:.2f} GB')
                results['details'].append('⚠️ GPU memory test skipped (CPU only)')
            
            print(f"  Status: {results['status']}")
            for detail in results['details']:
                print(f"    {detail}")
            
        except Exception as e:
            results['status'] = 'WARNING'
            results['details'].append(f'⚠️ Memory test incomplete: {str(e)}')
            print(f"  Status: {results['status']}")
            print(f"  Warning: {str(e)}")
        
        return results
    
    def test_robustness(self):
        """Test model robustness."""
        print("\n[TEST 5] Robustness:")
        
        results = {
            'name': 'Robustness',
            'status': 'PASS',
            'details': [],
            'recommendations': []
        }
        
        try:
            # Test with different input sizes
            test_sizes = [(320, 320), (416, 416), (512, 512), (640, 640)]
            
            for size in test_sizes:
                H, W = size
                test_input = torch.randn(1, 3, H, W).to(self.device)
                
                try:
                    with torch.no_grad():
                        if self.config['inference']['use_amp']:
                            with torch.cuda.amp.autocast():
                                output = self.model(test_input, task='detection')
                        else:
                            output = self.model(test_input, task='detection')
                    
                    results['details'].append(f'✅ Input size {H}x{W}: Successful')
                except Exception as e:
                    results['status'] = 'WARNING'
                    results['details'].append(f'⚠️ Input size {H}x{W}: Failed - {str(e)}')
            
            # Test with extreme values
            extreme_inputs = [
                torch.zeros(1, 3, 416, 416).to(self.device),  # All zeros
                torch.ones(1, 3, 416, 416).to(self.device),   # All ones
                torch.randn(1, 3, 416, 416).to(self.device) * 10,  # High variance
            ]
            
            for i, inp in enumerate(extreme_inputs):
                try:
                    with torch.no_grad():
                        output = self.model(inp, task='detection')
                    
                    if not torch.any(torch.isnan(output)):
                        results['details'].append(f'✅ Extreme input {i+1}: Stable')
                    else:
                        results['status'] = 'WARNING'
                        results['details'].append(f'⚠️ Extreme input {i+1}: NaN detected')
                except Exception as e:
                    results['status'] = 'WARNING'
                    results['details'].append(f'⚠️ Extreme input {i+1}: Failed')
            
            print(f"  Status: {results['status']}")
            for detail in results['details'][:5]:  # Show first 5 details
                print(f"    {detail}")
            if len(results['details']) > 5:
                print(f"    ... and {len(results['details']) - 5} more")
            
        except Exception as e:
            results['status'] = 'WARNING'
            results['details'].append(f'⚠️ Robustness test incomplete: {str(e)}')
            print(f"  Status: {results['status']}")
            print(f"  Warning: {str(e)}")
        
        return results
    
    def test_export_capabilities(self):
        """Test model export capabilities."""
        print("\n[TEST 6] Export Capabilities:")
        
        results = {
            'name': 'Export',
            'status': 'PASS',
            'details': [],
            'recommendations': []
        }
        
        try:
            # Test TorchScript export
            H, W = self.config['model']['image_size']
            example_input = torch.randn(1, 3, H, W).to(self.device)
            
            try:
                traced_model = torch.jit.trace(self.model, example_input)
                results['details'].append('✅ TorchScript: Export successful')
                
                # Test inference with traced model
                with torch.no_grad():
                    traced_output = traced_model(example_input, task='detection')
                results['details'].append('✅ TorchScript: Inference successful')
            except Exception as e:
                results['status'] = 'WARNING'
                results['details'].append(f'⚠️ TorchScript: Export failed - {str(e)}')
                results['recommendations'].append('Fix model for TorchScript compatibility')
            
            # Test ONNX export (if available)
            try:
                import onnx
                import onnxruntime
                
                onnx_path = '../results/inference/model.onnx'
                
                torch.onnx.export(
                    self.model,
                    example_input,
                    onnx_path,
                    opset_version=11,
                    input_names=['input'],
                    output_names=['output'],
                    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
                )
                
                results['details'].append('✅ ONNX: Export successful')
                
                # Validate ONNX model
                onnx_model = onnx.load(onnx_path)
                onnx.checker.check_model(onnx_model)
                results['details'].append('✅ ONNX: Model validation passed')
                
            except ImportError:
                results['details'].append('ℹ️ ONNX: Package not available')
            except Exception as e:
                results['details'].append(f'⚠️ ONNX: Export failed - {str(e)}')
            
            print(f"  Status: {results['status']}")
            for detail in results['details']:
                print(f"    {detail}")
            
        except Exception as e:
            results['status'] = 'WARNING'
            results['details'].append(f'⚠️ Export test incomplete: {str(e)}')
            print(f"  Status: {results['status']}")
            print(f"  Warning: {str(e)}")
        
        return results
    
    def generate_deployment_assessment(self, test_results):
        """Generate overall deployment assessment."""
        print(f"\n{'='*60}")
        print("DEPLOYMENT ASSESSMENT")
        print(f"{'='*60}")
        
        # Count results
        status_counts = {'PASS': 0, 'WARNING': 0, 'FAIL': 0}
        all_recommendations = []
        
        for result in test_results:
            status_counts[result['status']] += 1
            all_recommendations.extend(result['recommendations'])
        
        # Overall status
        if status_counts['FAIL'] > 0:
            overall_status = '❌ NOT READY'
        elif status_counts['WARNING'] > 0:
            overall_status = '⚠️ READY WITH WARNINGS'
        else:
            overall_status = '✅ READY FOR DEPLOYMENT'
        
        print(f"Overall Status: {overall_status}")
        print(f"\nTest Results: PASS={status_counts['PASS']}, WARNING={status_counts['WARNING']}, FAIL={status_counts['FAIL']}")
        
        # Show failed tests
        failed_tests = [r['name'] for r in test_results if r['status'] == 'FAIL']
        if failed_tests:
            print(f"\nFailed Tests: {', '.join(failed_tests)}")
        
        # Show warnings
        warning_tests = [r['name'] for r in test_results if r['status'] == 'WARNING']
        if warning_tests:
            print(f"\nTests with Warnings: {', '.join(warning_tests)}")
        
        # Recommendations
        if all_recommendations:
            print(f"\nKey Recommendations:")
            for i, rec in enumerate(set(all_recommendations), 1):
                print(f"  {i}. {rec}")
        
        # Deployment readiness score
        total_tests = len(test_results)
        readiness_score = (status_counts['PASS'] + 0.5 * status_counts['WARNING']) / total_tests
        
        print(f"\nDeployment Readiness Score: {readiness_score:.1%}")
        
        if readiness_score >= 0.9:
            print("Assessment: Excellent - Ready for production deployment")
        elif readiness_score >= 0.7:
            print("Assessment: Good - Ready with minor optimizations")
        elif readiness_score >= 0.5:
            print("Assessment: Fair - Needs significant improvements")
        else:
            print("Assessment: Poor - Not ready for deployment")
        
        # Next steps
        print(f"\nNext Steps:")
        if status_counts['FAIL'] > 0:
            print("  1. Address all FAILED tests immediately")
        if status_counts['WARNING'] > 0:
            print("  2. Review and address WARNING tests")
        print("  3. Run integration tests with target hardware")
        print("  4. Perform stress testing")
        print("  5. Create deployment package")
        
        return overall_status, readiness_score

# %%
# Run deployment tests
deployment_tester = DeploymentTester(model, config)
test_results = deployment_tester.run_comprehensive_tests()

# %% [markdown]
"""
## 7. Interactive Demo with Widgets
"""

# %%
class InteractiveDemo:
    """Interactive demo with IPython widgets."""
    
    def __init__(self, inference_demo, config):
        self.inference_demo = inference_demo
        self.config = config
        
        # Create widgets
        self.create_widgets()
        
    def create_widgets(self):
        """Create interactive widgets."""
        print("\nCreating interactive demo...")
        
        # Image upload widget
        self.upload_widget = widgets.FileUpload(
            accept='.jpg,.jpeg,.png',
            multiple=False,
            description='Upload Image'
        )
        
        # Confidence threshold slider
        self.confidence_slider = widgets.FloatSlider(
            value=self.config['model']['confidence_threshold'],
            min=0.0,
            max=1.0,
            step=0.05,
            description='Confidence:',
            continuous_update=False
        )
        
        # IOU threshold slider
        self.iou_slider = widgets.FloatSlider(
            value=self.config['model']['iou_threshold'],
            min=0.0,
            max=1.0,
            step=0.05,
            description='IOU:',
            continuous_update=False
        )
        
        # Inference button
        self.inference_button = widgets.Button(
            description='Run Inference',
            button_style='success',
            icon='play'
        )
        
        # Results display
        self.results_output = widgets.Output()
        
        # Connect callbacks
        self.inference_button.on_click(self.run_interactive_inference)
        
        # Display widgets
        display(widgets.VBox([
            widgets.HTML("<h3>Interactive Inference Demo</h3>"),
            self.upload_widget,
            self.confidence_slider,
            self.iou_slider,
            self.inference_button,
            self.results_output
        ]))
    
    def run_interactive_inference(self, b):
        """Run inference with current widget settings."""
        with self.results_output:
            clear_output(wait=True)
            
            # Get uploaded image
            if not self.upload_widget.value:
                print("Please upload an image first!")
                return
            
            # Get the uploaded file
            uploaded_file = list(self.upload_widget.value.values())[0]
            
            # Read image
            import io
            image_data = uploaded_file['content']
            image = Image.open(io.BytesIO(image_data))
            image_np = np.array(image)
            
            print(f"Image loaded: {image.size}")
            
            # Update thresholds
            self.inference_demo.engine.confidence_threshold = self.confidence_slider.value
            self.inference_demo.engine.iou_threshold = self.iou_slider.value
            
            # Run inference
            start_time = time.time()
            
            input_tensor = self.inference_demo.preprocessor(image_np)
            
            with torch.no_grad():
                if self.config['inference']['use_amp']:
                    with torch.cuda.amp.autocast():
                        outputs = self.inference_demo.engine.infer(input_tensor)
                else:
                    outputs = self.inference_demo.engine.infer(input_tensor)
            
            inference_time = (time.time() - start_time) * 1000
            
            # Postprocess with updated thresholds
            self.inference_demo.postprocessor.confidence_threshold = self.confidence_slider.value
            self.inference_demo.postprocessor.iou_threshold = self.iou_slider.value
            
            detections = self.inference_demo.postprocessor(outputs)
            
            # Visualize
            visualized = self.inference_demo.visualizer.visualize_detections(
                image_np,
                detections,
                show_confidence=True,
                show_class_names=True
            )
            
            # Display results
            fig, axes = plt.subplots(1, 2, figsize=(12, 6))
            
            axes[0].imshow(image_np)
            axes[0].set_title('Original Image')
            axes[0].axis('off')
            
            axes[1].imshow(visualized)
            axes[1].set_title(f'Detections ({len(detections["boxes"])} objects)')
            axes[1].axis('off')
            
            plt.tight_layout()
            plt.show()
            
            # Print statistics
            print(f"\nInference Results:")
            print(f"  Inference time: {inference_time:.2f} ms")
            print(f"  FPS: {1000/inference_time:.1f}")
            print(f"  Detected objects: {len(detections['boxes'])}")
            print(f"  Confidence threshold: {self.confidence_slider.value}")
            print(f"  IOU threshold: {self.iou_slider.value}")
            
            if len(detections['boxes']) > 0:
                print(f"\nTop detections:")
                for i in range(min(3, len(detections['boxes']))):
                    class_name = detections['class_names'][i]
                    confidence = detections['scores'][i]
                    print(f"  {i+1}. {class_name}: {confidence:.3f}")

# %%
# Create interactive demo
try:
    interactive_demo = InteractiveDemo(inference_demo, config)
except Exception as e:
    print(f"Interactive demo setup failed: {e}")
    print("Continuing with other tests...")

# %% [markdown]
"""
## 8. Export Inference Analysis Report
"""

# %%
class InferenceAnalysisExporter:
    """Export comprehensive inference analysis report."""
    
    def __init__(self, config, benchmark_results, deployment_results):
        self.config = config
        self.benchmark_results = benchmark_results
        self.deployment_results = deployment_results
        
    def export_report(self):
        """Export inference analysis report."""
        print("\nExporting Inference Analysis Report...")
        
        # Create report data
        report = {
            'timestamp': pd.Timestamp.now().isoformat(),
            'config': self.config,
            'performance_summary': self.generate_performance_summary(),
            'benchmark_results': self.benchmark_results,
            'deployment_assessment': self.analyze_deployment_results(),
            'recommendations': self.generate_inference_recommendations(),
            'next_steps': self.generate_next_steps()
        }
        
        # Export as JSON
        import json
        with open('../reports/inference_analysis_report.json', 'w') as f:
            json.dump(report, f, indent=2)
        
        print("Inference analysis report exported to ../reports/inference_analysis_report.json")
        
        # Also export as HTML
        self.export_html_report(report)
    
    def generate_performance_summary(self):
        """Generate performance summary."""
        if 'latencies' not in locals():
            return "Performance benchmarks not run."
        
        summary = f"""
        INFERENCE PERFORMANCE SUMMARY
        
        Hardware:
        - Device: {self.config['inference']['device']}
        - Mixed precision: {self.config['inference']['use_amp']}
        
        Single Image Inference:
        - Mean latency: {np.mean(latencies):.2f} ms
        - Throughput: {1000/np.mean(latencies):.1f} FPS
        - Latency std: {np.std(latencies):.2f} ms
        
        Memory Usage:
        - Single image: {memory_results[0][1] if memory_results else 'N/A':.3f} GB
        - Batch efficiency: {np.mean([mem/bs for bs, mem in memory_results]) if memory_results else 'N/A':.3f} GB per image
        
        Optimal Configuration:
        - Batch size: {optimal['batch_size'] if 'optimal' in locals() else 'N/A'}
        - Max throughput: {optimal['throughput'] if 'optimal' in locals() else 'N/A':.1f} images/sec
        
        Real-time Capability: {'✅ READY' if np.mean(latencies) < 50 else '⚠️ NEEDS OPTIMIZATION'}
        """
        return summary
    
    def analyze_deployment_results(self):
        """Analyze deployment test results."""
        if not hasattr(self, 'deployment_results'):
            return "Deployment tests not run."
        
        # Count results
        status_counts = {'PASS': 0, 'WARNING': 0, 'FAIL': 0}
        for result in self.deployment_results:
            status_counts[result['status']] += 1
        
        assessment = {
            'total_tests': len(self.deployment_results),
            'passed': status_counts['PASS'],
            'warnings': status_counts['WARNING'],
            'failed': status_counts['FAIL'],
            'readiness_score': (status_counts['PASS'] + 0.5 * status_counts['WARNING']) / len(self.deployment_results),
            'overall_status': 'READY' if status_counts['FAIL'] == 0 else 'NOT READY'
        }
        
        return assessment
    
    def generate_inference_recommendations(self):
        """Generate inference optimization recommendations."""
        recommendations = []
        
        # Latency recommendations
        if 'latencies' in locals() and np.mean(latencies) > 50:
            recommendations.append({
                'category': 'Performance',
                'issue': 'High inference latency',
                'recommendation': 'Optimize model architecture or use TensorRT',
                'priority': 'HIGH'
            })
        
        # Memory recommendations
        if memory_results and memory_results[0][1] > 2.0:  # More than 2GB for single image
            recommendations.append({
                'category': 'Memory',
                'issue': 'High memory usage',
                'recommendation': 'Implement model pruning or quantization',
                'priority': 'HIGH'
            })
        
        # Batch size recommendations
        if 'optimal' in locals() and optimal['batch_size'] > 8:
            recommendations.append({
                'category': 'Throughput',
                'issue': 'Large optimal batch size',
                'recommendation': f'Use batch size {optimal["batch_size"]} for maximum throughput',
                'priority': 'MEDIUM'
            })
        
        # Deployment recommendations
        if hasattr(self, 'deployment_results'):
            for result in self.deployment_results:
                if result['status'] == 'FAIL':
                    recommendations.append({
                        'category': 'Deployment',
                        'issue': f'Failed test: {result["name"]}',
                        'recommendation': result['recommendations'][0] if result['recommendations'] else 'Fix the issue',
                        'priority': 'HIGH'
                    })
        
        return recommendations
    
    def generate_next_steps(self):
        """Generate next steps for deployment."""
        next_steps = [
            '1. Address HIGH priority recommendations',
            '2. Run integration tests on target hardware',
            '3. Optimize for specific deployment scenario',
            '4. Create production deployment pipeline',
            '5. Set up monitoring and logging',
            '6. Perform A/B testing in staging',
            '7. Deploy to production with gradual rollout'
        ]
        return next_steps
    
    def export_html_report(self, report):
        """Export HTML report."""
        html_content = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <title>Humanoid Vision System - Inference Analysis Report</title>
            <style>
                body {{ font-family: Arial, sans-serif; margin: 40px; line-height: 1.6; }}
                h1 {{ color: #2c3e50; border-bottom: 3px solid #3498db; }}
                h2 {{ color: #34495e; margin-top: 30px; }}
                h3 {{ color: #2c3e50; margin-top: 20px; }}
                .card {{ background: #f8f9fa; border-left: 4px solid #3498db; 
                        padding: 20px; margin: 20px 0; border-radius: 5px; }}
                .metric-card {{ display: inline-block; background: white; padding: 15px; 
                         margin: 10px; border-radius: 5px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); 
                         width: 200px; vertical-align: top; }}
                .pass {{ color: #27ae60; font-weight: bold; }}
                .warning {{ color: #f39c12; font-weight: bold; }}
                .fail {{ color: #e74c3c; font-weight: bold; }}
                table {{ width: 100%; border-collapse: collapse; margin: 20px 0; }}
                th, td {{ padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }}
                th {{ background-color: #3498db; color: white; }}
                .high {{ background: #ffeaa7; padding: 10px; border-radius: 5px; }}
                .medium {{ background: #a29bfe; padding: 10px; border-radius: 5px; }}
                .low {{ background: #55efc4; padding: 10px; border-radius: 5px; }}
            </style>
        </head>
        <body>
            <h1>Humanoid Vision System - Inference Analysis Report</h1>
            <p>Generated on: {report['timestamp']}</p>
            
            <div class="card">
                <h2>Executive Summary</h2>
                <pre>{report['performance_summary']}</pre>
            </div>
            
            <h2>Deployment Assessment</h2>
            <div class="card">
        """
        
        if isinstance(report['deployment_assessment'], dict):
            assessment = report['deployment_assessment']
            html_content += f"""
                <p><strong>Overall Status:</strong> <span class="{assessment['overall_status'].lower().replace(' ', '-')}">{assessment['overall_status']}</span></p>
                <p><strong>Readiness Score:</strong> {assessment['readiness_score']:.1%}</p>
                <p><strong>Tests Passed:</strong> {assessment['passed']}/{assessment['total_tests']}</p>
                <p><strong>Tests with Warnings:</strong> {assessment['warnings']}</p>
                <p><strong>Tests Failed:</strong> {assessment['failed']}</p>
            """
        else:
            html_content += f"<p>{report['deployment_assessment']}</p>"
        
        html_content += """
            </div>
            
            <h2>Optimization Recommendations</h2>
            <table>
                <tr>
                    <th>Category</th>
                    <th>Priority</th>
                    <th>Issue</th>
                    <th>Recommendation</th>
                </tr>
        """
        
        for rec in report['recommendations']:
            priority_class = rec['priority'].lower()
            html_content += f"""
                <tr class="{priority_class}">
                    <td>{rec['category']}</td>
                    <td>{rec['priority']}</td>
                    <td>{rec['issue']}</td>
                    <td>{rec['recommendation']}</td>
                </tr>
            """
        
        html_content += """
            </table>
            
            <h2>Next Steps for Deployment</h2>
            <div class="card">
                <ol>
        """
        
        for step in report['next_steps']:
            html_content += f"<li>{step}</li>"
        
        html_content += """
                </ol>
            </div>
            
            <div class="card">
                <h2>Key Metrics</h2>
                <div>
        """
        
        # Add key metrics cards
        if 'latencies' in locals():
            html_content += f"""
                <div class="metric-card">
                    <h3>Latency</h3>
                    <p>{np.mean(latencies):.1f} ms</p>
                    <p>±{np.std(latencies):.1f} ms</p>
                </div>
                <div class="metric-card">
                    <h3>Throughput</h3>
                    <p>{1000/np.mean(latencies):.1f} FPS</p>
                </div>
            """
        
        if memory_results:
            html_content += f"""
                <div class="metric-card">
                    <h3>Memory</h3>
                    <p>{memory_results[0][1]:.2f} GB</p>
                    <p>per image</p>
                </div>
            """
        
        if 'optimal' in locals():
            html_content += f"""
                <div class="metric-card">
                    <h3>Optimal Batch</h3>
                    <p>{optimal['batch_size']}</p>
                    <p>{optimal['throughput']:.0f} img/sec</p>
                </div>
            """
        
        html_content += """
                </div>
            </div>
            
            <div class="card">
                <h2>Conclusion</h2>
                <p>The inference pipeline is {'ready for deployment' if report.get('deployment_assessment', {}).get('overall_status') == 'READY' else 'not yet ready for deployment'}. 
                {'All critical tests have passed and performance meets requirements.' if report.get('deployment_assessment', {}).get('overall_status') == 'READY' else 'Critical issues need to be addressed before deployment.'}</p>
                <p>Proceed to deployment testing (05_deployment_test.ipynb) for the final validation.</p>
            </div>
        </body>
        </html>
        """
        
        with open('../reports/inference_analysis_report.html', 'w') as f:
            f.write(html_content)
        
        print("HTML report exported to ../reports/inference_analysis_report.html")

# %%
# Export inference analysis report
inference_exporter = InferenceAnalysisExporter(config, 
                                             benchmark_results=batch_results if 'batch_results' in locals() else None,
                                             deployment_results=test_results)
inference_exporter.export_report()

# %% [markdown]
"""
## 9. Conclusion
"""

# %%
print("\n" + "="*70)
print("INFERENCE DEMO AND ANALYSIS - COMPLETED")
print("="*70)

print("\n✅ DEMONSTRATIONS COMPLETED:")
print("  1. Model loading and validation")
print("  2. Single image inference")
print("  3. Batch inference")
print("  4. Performance benchmarking")
print("  5. Video processing")
print("  6. Deployment readiness tests")
print("  7. Interactive demo")

print("\n📊 KEY PERFORMANCE METRICS:")
if 'latencies' in locals():
    print(f"  • Inference latency: {np.mean(latencies):.1f} ms ({1000/np.mean(latencies):.1f} FPS)")
if memory_results:
    print(f"  • Memory usage: {memory_results[0][1]:.2f} GB per image")
if 'optimal' in locals():
    print(f"  • Optimal batch size: {optimal['batch_size']} ({optimal['throughput']:.0f} images/sec)")

print("\n🚀 DEPLOYMENT READINESS:")
if 'test_results' in locals() and test_results:
    failed = sum(1 for r in test_results if r['status'] == 'FAIL')
    warnings = sum(1 for r in test_results if r['status'] == 'WARNING')
    
    if failed == 0 and warnings == 0:
        print("  ✅ EXCELLENT - Ready for production deployment")
    elif failed == 0:
        print(f"  ⚠️ GOOD - Ready with {warnings} warnings")
    else:
        print(f"  ❌ NEEDS WORK - {failed} critical issues, {warnings} warnings")

print("\n🎯 NEXT STEPS:")
print("  1. Address any critical issues identified")
print("  2. Run deployment tests on target hardware")
print("  3. Create production deployment pipeline")
print("  4. Set up monitoring and alerting")
print("  5. Proceed to deployment testing (05_deployment_test.ipynb)")

print("\n" + "="*70)