In [None]:
import sys
import os
sys.path.append('../src')

# Core libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import yaml
from typing import Dict, List, Tuple, Any, Optional
import warnings
warnings.filterwarnings('ignore')

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast

# Training utilities
from src.training.mhc_trainer import ManifoldConstrainedTrainer
from src.training.loss_functions import YOLOLoss, DetectionLoss
from src.training.optimizer import ManifoldOptimizer
from src.training.scheduler import CosineAnnealingWarmRestartsWithDecay
from src.training.stability_monitor import TrainingStabilityMonitor

# Data
from src.data.dataset import COCOVisionDataset
from src.data.transforms import VisionTransforms

# Model
from src.models.hybrid_vision import HybridVisionSystem

# Visualization
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from IPython.display import display, HTML

# Set style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
%matplotlib inline

# Configuration
config = {
    'training': {
        'batch_size': 16,
        'num_epochs': 100,
        'learning_rate': 1e-3,
        'weight_decay': 1e-4,
        'gradient_clip': 1.0,
        'use_amp': True,
        'accumulation_steps': 1,
        'warmup_epochs': 5,
        'min_lr': 1e-6
    },
    'model': {
        'input_channels': 3,
        'base_channels': 32,
        'num_classes': 80,
        'image_size': (416, 416),
        'use_vit': True,
        'use_rag': False
    },
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}

# Set random seed
torch.manual_seed(42)
np.random.seed(42)

# Create logs directory
os.makedirs('../logs/training', exist_ok=True)

# %% [markdown]
"""
## 2. Training Simulation Setup
"""

# %%
class TrainingSimulator:
    """Simulate training on random batches for analysis purposes.

    The simulator builds a model, optimizer, scheduler and loss functions
    from ``config``, runs optimisation steps on synthetic data and records
    per-step metrics in ``self.history``.
    """

    def __init__(self, config):
        """Build all training components described by ``config``.

        Args:
            config: Nested dict with ``'training'``, ``'model'`` and
                ``'device'`` entries (see the notebook-level ``config``).
        """
        self.config = config
        self.device = torch.device(config['device'])

        # Mixed precision only takes effect on CUDA. The scaler is created
        # once so that its dynamic loss scale persists across steps.
        self.use_amp = bool(config['training']['use_amp']) and self.device.type == 'cuda'
        self.scaler = GradScaler(enabled=self.use_amp)

        # Create model
        self.model = self.create_model()

        # Create optimizer
        self.optimizer = self.create_optimizer()

        # Create scheduler
        self.scheduler = self.create_scheduler()

        # Create loss functions
        self.loss_functions = self.create_loss_functions()

        # Training monitor
        self.monitor = TrainingStabilityMonitor()

        # Training history: one entry per simulated step in every list
        self.history = {
            'losses': [],
            'gradients': [],
            'learning_rates': [],
            'stability_metrics': [],
            # Per-step dicts of the individual detection loss terms
            'loss_components': []
        }

    def create_model(self):
        """Create the model, move it to the device and print its size.

        Returns:
            The ``HybridVisionSystem`` instance placed on ``self.device``.
        """
        print("Creating model for training simulation...")

        model_cfg = self.config['model']
        model = HybridVisionSystem(
            config=model_cfg,
            num_classes=model_cfg['num_classes'],
            use_vit=model_cfg['use_vit'],
            use_rag=model_cfg['use_rag']
        ).to(self.device)

        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        print("Model created:")
        print(f"  Total parameters: {total_params:,}")
        print(f"  Trainable parameters: {trainable_params:,}")

        return model

    def create_optimizer(self):
        """Create the ``ManifoldOptimizer`` over all model parameters.

        Returns:
            The configured optimizer.
        """
        print("\nCreating ManifoldOptimizer...")

        train_cfg = self.config['training']
        optimizer = ManifoldOptimizer(
            params=self.model.parameters(),
            lr=train_cfg['learning_rate'],
            weight_decay=train_cfg['weight_decay'],
            betas=(0.9, 0.999),
            eps=1e-8
        )

        print("Optimizer configured:")
        print(f"  Learning rate: {train_cfg['learning_rate']}")
        print(f"  Weight decay: {train_cfg['weight_decay']}")
        print("  Betas: (0.9, 0.999)")

        return optimizer

    def create_scheduler(self):
        """Create the learning rate scheduler attached to ``self.optimizer``.

        Returns:
            The configured scheduler. Its first cycle length ``T_0`` is taken
            from ``warmup_epochs`` and its floor from ``min_lr``.
        """
        print("\nCreating learning rate scheduler...")

        train_cfg = self.config['training']
        scheduler = CosineAnnealingWarmRestartsWithDecay(
            optimizer=self.optimizer,
            T_0=train_cfg['warmup_epochs'],
            T_mult=2,
            eta_min=train_cfg['min_lr'],
            decay_factor=0.1,
            decay_patience=10
        )

        print("Scheduler configured:")
        print(f"  Warmup epochs: {train_cfg['warmup_epochs']}")
        print(f"  Minimum LR: {train_cfg['min_lr']}")
        print("  T_mult: 2")

        return scheduler

    def create_loss_functions(self):
        """Create the loss functions used by the simulation.

        Returns:
            Dict with ``'detection'``, ``'classification'`` and
            ``'regularization'`` loss modules.
        """
        print("\nCreating loss functions...")

        loss_functions = {
            'detection': YOLOLoss(
                num_classes=self.config['model']['num_classes'],
                anchors=[[10, 13], [16, 30], [33, 23]],
                image_size=self.config['model']['image_size']
            ).to(self.device),

            'classification': nn.CrossEntropyLoss(),

            'regularization': nn.MSELoss()
        }

        print("Loss functions created:")
        print("  - Detection: YOLO loss with 3 anchors")
        print("  - Classification: Cross-entropy")
        print("  - Regularization: MSE")

        return loss_functions

    def create_simulated_batch(self, batch_size=4):
        """Create a random training batch.

        Args:
            batch_size: Number of images to generate.

        Returns:
            Tuple ``(images, det_targets, cls_targets)``: a random image
            tensor on ``self.device``, a list with one
            ``{'bboxes', 'classes'}`` dict per image, and a tensor of random
            class indices of length ``batch_size``.
        """
        H, W = self.config['model']['image_size']
        num_classes = self.config['model']['num_classes']

        # Simulate images (sampled on CPU first so the seeded CPU generator
        # determines the values regardless of the target device)
        images = torch.randn(batch_size, 3, H, W).to(self.device)

        # Simulate detection targets
        det_targets = []
        for _ in range(batch_size):
            # Random number of objects (1 to 4)
            num_objects = torch.randint(1, 5, (1,)).item()

            # Random bboxes (normalized)
            bboxes = torch.rand(num_objects, 4)
            bboxes[:, 2:] = bboxes[:, 2:] * 0.3 + 0.1  # Reasonable sizes

            # Random classes
            classes = torch.randint(0, num_classes, (num_objects,))

            det_targets.append({
                'bboxes': bboxes,
                'classes': classes
            })

        # Simulate classification targets
        cls_targets = torch.randint(0, num_classes, (batch_size,))

        return images, det_targets, cls_targets

    def training_step(self, batch_size=4):
        """Run one forward/backward/optimizer step on a simulated batch.

        Args:
            batch_size: Number of simulated images in the batch.

        Returns:
            Dict of step metrics: total and detection loss, learning rate,
            ``det_*`` component losses when the detection loss returns a
            dict, gradient statistics and the model's stability metrics.
        """
        self.model.train()

        # Create simulated batch
        images, det_targets, cls_targets = self.create_simulated_batch(batch_size)

        # Mixed precision context
        with autocast(enabled=self.use_amp):
            # Forward pass
            outputs = self.model(images, task='detection')

            # Calculate losses
            det_loss = self.loss_functions['detection'](
                outputs['detections'], det_targets
            )
            # The detection loss may be a dict of terms or a plain tensor
            det_total = det_loss['total'] if isinstance(det_loss, dict) else det_loss

            # Optional classification loss
            if 'classifications' in outputs:
                cls_loss = self.loss_functions['classification'](
                    outputs['classifications'], cls_targets.to(self.device)
                )
                total_loss = det_total + 0.1 * cls_loss
            else:
                total_loss = det_total

        # Backward pass. With a disabled scaler, scale() returns the loss
        # unchanged, unscale_()/update() do nothing and step() simply calls
        # optimizer.step(), so one code path serves both AMP and full precision.
        self.optimizer.zero_grad()
        self.scaler.scale(total_loss).backward()

        # Gradient clipping (on unscaled gradients)
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            max_norm=self.config['training']['gradient_clip']
        )

        # Optimizer step
        self.scaler.step(self.optimizer)
        self.scaler.update()

        # Collect metrics
        step_metrics = {
            'total_loss': total_loss.item(),
            'detection_loss': det_total.item(),
            'learning_rate': self.optimizer.param_groups[0]['lr']
        }

        # Add component losses if available
        if isinstance(det_loss, dict):
            for key, value in det_loss.items():
                if key != 'total':
                    step_metrics[f'det_{key}'] = value.item()

        # Collect gradient statistics
        step_metrics.update(self.collect_gradient_statistics())

        # Collect stability metrics
        step_metrics.update(self.model.get_stability_metrics())

        return step_metrics

    def collect_gradient_statistics(self):
        """Summarise the gradients currently stored on the model parameters.

        Statistics are accumulated tensor by tensor, so no per-element Python
        list of the gradients is ever built.

        Returns:
            Dict with ``grad_norm_total`` (sum of per-parameter norms),
            ``grad_norm_mean``/``grad_norm_max``/``grad_norm_min`` over
            per-parameter norms, and ``grad_mean``/``grad_std`` (population
            standard deviation) over all gradient elements. Every value is
            ``0.0`` when no parameter has a gradient.
        """
        grad_stats = {
            'grad_norm_total': 0.0,
            'grad_norm_mean': 0.0,
            'grad_norm_max': 0.0,
            'grad_norm_min': 0.0,
            'grad_mean': 0.0,
            'grad_std': 0.0
        }

        grads = [
            param.grad.detach()
            for _, param in self.model.named_parameters()
            if param.grad is not None
        ]
        if not grads:
            return grad_stats

        norms = [g.norm().item() for g in grads]
        grad_stats['grad_norm_total'] = float(sum(norms))
        grad_stats['grad_norm_mean'] = float(np.mean(norms))
        grad_stats['grad_norm_max'] = float(np.max(norms))
        grad_stats['grad_norm_min'] = float(np.min(norms))

        n_values = sum(g.numel() for g in grads)
        if n_values:
            # Two passes in float64: mean first, then squared deviations
            mean = sum(g.sum(dtype=torch.float64).item() for g in grads) / n_values
            sq_dev = sum((g.double() - mean).pow(2).sum().item() for g in grads)
            grad_stats['grad_mean'] = float(mean)
            grad_stats['grad_std'] = float(np.sqrt(sq_dev / n_values))

        return grad_stats

    def simulate_training(self, num_steps=1000, log_interval=100):
        """Run ``num_steps`` simulated training steps and record metrics.

        Args:
            num_steps: Number of optimisation steps to simulate.
            log_interval: Print progress every this many steps; ``0`` or
                ``None`` disables progress output.

        Returns:
            ``self.history`` after ``process_history`` has added the
            ``*_array`` entries.
        """
        print(f"\nSimulating training for {num_steps} steps...")
        print(f"Log interval: {log_interval} steps")
        print("-" * 60)

        stability_keys = ['max_eigenvalue', 'min_eigenvalue', 'signal_ratio_mean']

        for step in range(num_steps):
            # Training step
            step_metrics = self.training_step(batch_size=self.config['training']['batch_size'])

            # Update scheduler
            if self.scheduler is not None:
                self.scheduler.step()

            # Store in history
            self.history['losses'].append(step_metrics['total_loss'])
            self.history['learning_rates'].append(step_metrics['learning_rate'])

            # Store gradient stats
            grad_stats = {k: v for k, v in step_metrics.items() if 'grad' in k}
            self.history['gradients'].append(grad_stats)

            # Store stability metrics
            stability_metrics = {k: step_metrics.get(k, 0) for k in stability_keys}
            self.history['stability_metrics'].append(stability_metrics)

            # Store the individual detection loss terms so they can be
            # analysed later instead of being discarded with step_metrics
            components = {k: v for k, v in step_metrics.items() if k.startswith('det_')}
            self.history.setdefault('loss_components', []).append(components)

            # Log progress
            if log_interval and (step + 1) % log_interval == 0:
                print(f"Step {step + 1}/{num_steps}:")
                print(f"  Loss: {step_metrics['total_loss']:.4f}")
                print(f"  LR: {step_metrics['learning_rate']:.6f}")
                print(f"  Grad norm: {step_metrics.get('grad_norm_total', 0):.4f}")

                if 'max_eigenvalue' in step_metrics:
                    print(f"  Max eigenvalue: {step_metrics['max_eigenvalue']:.4f}")

        print("\nTraining simulation completed!")

        # Convert history to numpy arrays for easier analysis
        self.process_history()

        return self.history

    def process_history(self):
        """Add numpy ``*_array`` views of the recorded series to the history.

        Writes ``losses_array``, ``learning_rates_array``, one array per
        gradient statistic and one per stability metric (missing stability
        values default to 0).
        """
        # Convert lists to numpy arrays
        self.history['losses_array'] = np.array(self.history['losses'])
        self.history['learning_rates_array'] = np.array(self.history['learning_rates'])

        # Process gradient statistics
        grad_keys = ['grad_norm_total', 'grad_norm_mean', 'grad_norm_max',
                     'grad_norm_min', 'grad_mean', 'grad_std']

        for key in grad_keys:
            self.history[f'{key}_array'] = np.array(
                [grad[key] for grad in self.history['gradients']]
            )

        # Process stability metrics
        stability_keys = ['max_eigenvalue', 'min_eigenvalue', 'signal_ratio_mean']
        for key in stability_keys:
            self.history[f'{key}_array'] = np.array(
                [metrics.get(key, 0) for metrics in self.history['stability_metrics']]
            )

# %%
# Instantiate the simulator from the notebook configuration
simulator = TrainingSimulator(config)

# Run 500 simulated steps, reporting progress every 50
history = simulator.simulate_training(log_interval=50, num_steps=500)

# %% [markdown]
"""
## 3. Loss Function Analysis
"""

# %%
class LossAnalyzer:
    """Analyze loss functions and convergence."""

    def __init__(self, history):
        """Store the training history dict whose loss series are analysed."""
        self.history = history

    @staticmethod
    def _moving_average(losses):
        """Return ``(window, x, smoothed)`` for the loss curve, or ``None``.

        The window is ``min(50, len(losses) // 10)``; ``None`` is returned
        when that window is too short (<= 1) to smooth anything.
        """
        window_size = min(50, len(losses) // 10)
        if window_size <= 1:
            return None
        kernel = np.ones(window_size) / window_size
        smoothed = np.convolve(losses, kernel, mode='valid')
        x_avg = np.arange(window_size - 1, len(losses))
        return window_size, x_avg, smoothed

    @staticmethod
    def _component_series(history):
        """Collect detection loss component series from ``history``.

        Components are read from per-step dicts stored under
        ``'loss_components'`` or ``'stability_metrics'`` (keys starting with
        ``det_``) and from top-level entries whose key contains ``det_``
        (a trailing ``_array`` is stripped from the name).

        Returns:
            Dict mapping component name to a non-empty 1-D float array.
        """
        collected = {}

        for source in ('loss_components', 'stability_metrics'):
            steps = history.get(source)
            if not isinstance(steps, (list, tuple)):
                continue
            for step in steps:
                if not isinstance(step, dict):
                    continue
                for key, value in step.items():
                    if key.startswith('det_'):
                        collected.setdefault(key, []).append(value)

        for key, value in history.items():
            if 'det_' in key and key not in ('loss_components', 'stability_metrics'):
                name = key[:-len('_array')] if key.endswith('_array') else key
                collected[name] = value

        series = {}
        for name, values in collected.items():
            arr = np.asarray(values, dtype=float).ravel()
            if arr.size:
                series[name] = arr
        return series

    def analyze_loss_convergence(self):
        """Print convergence metrics, run health checks and plot the loss."""
        print("\nAnalyzing Loss Convergence:")

        losses = np.asarray(
            self.history.get('losses_array', self.history.get('losses', [])),
            dtype=float
        )

        if len(losses) == 0:
            print("No loss data available.")
            return

        # Calculate convergence metrics
        initial_loss = losses[0]
        final_loss = losses[-1]
        min_loss = np.min(losses)

        # Relative reduction is only meaningful for a positive starting loss
        loss_reduction = ((initial_loss - final_loss) / initial_loss * 100) if initial_loss > 0 else 0

        print(f"  Initial loss: {initial_loss:.4f}")
        print(f"  Final loss: {final_loss:.4f}")
        print(f"  Minimum loss: {min_loss:.4f}")
        print(f"  Loss reduction: {loss_reduction:.1f}%")

        # Analyze convergence rate on the smoothed curve
        averaged = self._moving_average(losses)
        if averaged is not None:
            smoothed = averaged[2]
            if len(smoothed) > 1:
                convergence_rate = abs(smoothed[-1] - smoothed[0]) / len(smoothed)
                print(f"  Convergence rate: {convergence_rate:.6f} per step")

        # Check for convergence issues
        self.detect_convergence_issues(losses)

        # Visualize loss convergence
        self.visualize_loss_convergence(losses)

    def detect_convergence_issues(self, losses):
        """Print a health report for a loss series.

        Checks for NaN/Inf, explosion (> 1000), oscillation and plateau.
        The numeric checks run on the finite values only, so a single NaN
        cannot hide an exploding loss or turn a score into NaN.

        Args:
            losses: Sequence of per-step loss values.
        """
        print("\nConvergence Health Check:")

        losses = np.asarray(losses, dtype=float)
        if losses.size == 0:
            print("  No loss data available.")
            return

        finite = losses[np.isfinite(losses)]

        # Check for NaN or Inf
        if finite.size != losses.size:
            print("  ❌ NaN or Inf detected in losses")
        else:
            print("  ✅ No NaN/Inf values")

        # Check for explosion (an infinite loss counts as exploded)
        exploded = bool(np.isinf(losses).any()) or (finite.size > 0 and finite.max() > 1000)
        if exploded:
            print("  ⚠️ Loss explosion detected")
        else:
            print("  ✅ Loss values stable")

        # Check for oscillation
        if len(finite) > 10:
            diff = np.diff(finite)
            oscillation_score = np.std(diff) / (np.mean(np.abs(diff)) + 1e-8)

            if oscillation_score > 2.0:
                print(f"  ⚠️ High oscillation detected (score: {oscillation_score:.2f})")
            else:
                print(f"  ✅ Stable convergence (oscillation score: {oscillation_score:.2f})")

        # Check for plateau (coefficient of variation of the last 100 steps)
        if len(finite) > 100:
            last_100 = finite[-100:]
            plateau_score = np.std(last_100) / (np.mean(last_100) + 1e-8)

            if plateau_score < 0.01:
                print(f"  ⚠️ Possible plateau (variance: {plateau_score:.4f})")
            else:
                print(f"  ✅ Active learning (variance: {plateau_score:.4f})")

    def visualize_loss_convergence(self, losses):
        """Plot the loss curve, its log-scale view, distribution and deltas.

        Args:
            losses: Sequence of per-step loss values.
        """
        losses = np.asarray(losses, dtype=float)
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        # Raw loss curve
        axes[0, 0].plot(losses, linewidth=2, alpha=0.7)
        axes[0, 0].set_xlabel('Training Step')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].set_title('Training Loss Curve')
        axes[0, 0].grid(True, alpha=0.3)

        # Add moving average
        averaged = self._moving_average(losses)
        if averaged is not None:
            window_size, x_avg, moving_avg = averaged
            axes[0, 0].plot(x_avg, moving_avg, 'r--', linewidth=2, label=f'MA({window_size})')
            axes[0, 0].legend()

        # Log scale loss
        axes[0, 1].semilogy(losses, linewidth=2, alpha=0.7)
        axes[0, 1].set_xlabel('Training Step')
        axes[0, 1].set_ylabel('Loss (log scale)')
        axes[0, 1].set_title('Log-Scale Loss Curve')
        axes[0, 1].grid(True, alpha=0.3)

        # Loss distribution (finite values only; a histogram cannot bin NaN/Inf)
        finite = losses[np.isfinite(losses)]
        if finite.size:
            axes[1, 0].hist(finite, bins=50, alpha=0.7, edgecolor='black')
            axes[1, 0].axvline(np.mean(finite), color='red', linestyle='--',
                               label=f'Mean: {np.mean(finite):.4f}')
            axes[1, 0].axvline(np.median(finite), color='green', linestyle='--',
                               label=f'Median: {np.median(finite):.4f}')
            axes[1, 0].legend()
        axes[1, 0].set_xlabel('Loss Value')
        axes[1, 0].set_ylabel('Frequency')
        axes[1, 0].set_title('Loss Value Distribution')
        axes[1, 0].grid(True, alpha=0.3)

        # Loss difference (gradient of loss)
        if len(losses) > 1:
            loss_diff = np.diff(losses)
            axes[1, 1].plot(loss_diff, alpha=0.7)
            axes[1, 1].axhline(y=0, color='r', linestyle='-', alpha=0.3)
            axes[1, 1].set_xlabel('Training Step')
            axes[1, 1].set_ylabel('Δ Loss')
            axes[1, 1].set_title('Loss Changes (Gradient)')
            axes[1, 1].grid(True, alpha=0.3)

            # Add statistics
            pos_changes = np.sum(loss_diff > 0)
            neg_changes = np.sum(loss_diff < 0)
            axes[1, 1].text(0.05, 0.95,
                            f'↑ Increases: {pos_changes}\n↓ Decreases: {neg_changes}',
                            transform=axes[1, 1].transAxes,
                            verticalalignment='top',
                            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        plt.tight_layout()
        plt.show()

        # Create interactive visualization
        self.create_interactive_loss_viz(losses)

    def create_interactive_loss_viz(self, losses):
        """Render the same four loss panels as an interactive Plotly figure.

        Args:
            losses: Sequence of per-step loss values.
        """
        losses = np.asarray(losses, dtype=float)
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Training Loss Curve', 'Log-Scale Loss',
                            'Loss Distribution', 'Loss Changes'),
            vertical_spacing=0.1
        )

        # Loss curve
        fig.add_trace(
            go.Scatter(y=losses, mode='lines', name='Loss',
                       line=dict(width=2, color='blue')),
            row=1, col=1
        )

        # Add moving average
        averaged = self._moving_average(losses)
        if averaged is not None:
            window_size, x_avg, moving_avg = averaged
            fig.add_trace(
                go.Scatter(x=x_avg, y=moving_avg, mode='lines',
                           name=f'MA({window_size})', line=dict(dash='dash', color='red')),
                row=1, col=1
            )

        # Log scale loss
        fig.add_trace(
            go.Scatter(y=losses, mode='lines', name='Loss (log)',
                       line=dict(width=2, color='green')),
            row=1, col=2
        )

        # Update y-axis to log scale
        fig.update_yaxes(type="log", row=1, col=2)

        # Loss distribution (finite values only)
        finite = losses[np.isfinite(losses)]
        if finite.size:
            fig.add_trace(
                go.Histogram(x=finite, nbinsx=50, name='Loss Distribution',
                             marker_color='purple'),
                row=2, col=1
            )

            # Add mean and median lines
            fig.add_vline(x=np.mean(finite), line_dash="dash", line_color="red",
                          annotation_text=f"Mean: {np.mean(finite):.4f}", row=2, col=1)
            fig.add_vline(x=np.median(finite), line_dash="dash", line_color="green",
                          annotation_text=f"Median: {np.median(finite):.4f}", row=2, col=1)

        # Loss changes
        if len(losses) > 1:
            loss_diff = np.diff(losses)

            fig.add_trace(
                go.Scatter(y=loss_diff, mode='lines', name='Δ Loss',
                           line=dict(width=1, color='orange')),
                row=2, col=2
            )

            # Add zero line
            fig.add_hline(y=0, line_dash="dash", line_color="red", row=2, col=2)

        fig.update_layout(
            height=800,
            width=1200,
            title_text="Loss Convergence Analysis",
            showlegend=True
        )

        fig.show()

    def analyze_loss_components(self, history=None):
        """Plot each detection loss component and report its contribution.

        Args:
            history: History dict to read components from; defaults to the
                history given at construction.
        """
        print("\nAnalyzing Loss Components:")

        history = self.history if history is None else history
        series = self._component_series(history)

        if not series:
            print("  No component loss data available.")
            return

        # squeeze=False always yields a 2-D axes array, even for one panel
        fig, axes = plt.subplots(1, len(series), figsize=(4 * len(series), 5),
                                 squeeze=False)

        # Components are truncated to the number of recorded loss steps
        n_steps = len(history['losses']) if 'losses' in history else None

        for ax, (key, values) in zip(axes[0], series.items()):
            values = values[:n_steps]
            if values.size == 0:
                continue

            ax.plot(values, alpha=0.7)
            ax.set_xlabel('Training Step')
            ax.set_ylabel('Loss Value')
            ax.set_title(f'{key.replace("det_", "").title()} Loss')
            ax.grid(True, alpha=0.3)

            # Add statistics
            mean_val = np.mean(values)
            ax.axhline(y=mean_val, color='r', linestyle='--',
                       label=f'Mean: {mean_val:.4f}')
            ax.legend()

        plt.tight_layout()
        plt.show()

        # Analyze component contributions
        self.analyze_component_contributions(list(series), history)

    def analyze_component_contributions(self, component_keys, history=None):
        """Print and return each component's share of the mean component loss.

        Args:
            component_keys: Names of the components to include
                (for example ``'det_box'``).
            history: History dict to read components from; defaults to the
                history given at construction.

        Returns:
            Dict mapping component name to its percentage share of the summed
            absolute mean losses; empty when no data is available.
        """
        history = self.history if history is None else history
        series = self._component_series(history)

        means = {}
        for key in component_keys:
            values = series.get(key)
            if values is None:
                continue
            finite = values[np.isfinite(values)]
            if finite.size:
                means[key] = float(finite.mean())

        print("\nComponent Contributions:")
        if not means:
            print("  No component loss data available.")
            return {}

        total = sum(abs(mean) for mean in means.values())
        shares = {}
        for key, mean in means.items():
            share = abs(mean) / total * 100 if total > 0 else 0.0
            shares[key] = share
            print(f"  {key}: mean {mean:.4f} ({share:.1f}% of component total)")

        return shares

# %%
# Analyze loss convergence
loss_analyzer = LossAnalyzer(history)
loss_analyzer.analyze_loss_convergence()

# Analyze loss components
loss_analyzer.analyze_loss_components(history)

# %% [markdown]
"""
## 4. Gradient Flow Analysis
"""

# %%
class GradientAnalyzer:
    """Analyze gradient flow during training."""
    
    def __init__(self, history):
        self.history = history
        
    def analyze_gradient_flow(self):
        """Analyze gradient flow statistics."""
        print("\nAnalyzing Gradient Flow:")
        
        # Check if gradient data exists
        if 'grad_norm_total_array' not in self.history:
            print("  No gradient data available.")
            return
        
        grad_norms = self.history['grad_norm_total_array']
        
        if len(grad_norms) == 0:
            print("  Empty gradient data.")
            return
        
        # Calculate statistics
        mean_grad = np.mean(grad_norms)
        std_grad = np.std(grad_norms)
        max_grad = np.max(grad_norms)
        min_grad = np.min(grad_norms)
        
        print(f"  Mean gradient norm: {mean_grad:.4f}")
        print(f"  Std gradient norm: {std_grad:.4f}")
        print(f"  Max gradient norm: {max_grad:.4f}")
        print(f"  Min gradient norm: {min_grad:.4f}")
        
        # Analyze gradient health
        self.analyze_gradient_health(grad_norms)
        
        # Visualize gradient flow
        self.visualize_gradient_flow(grad_norms)
        
        # Analyze gradient distribution
        if 'grad_mean_array' in self.history and 'grad_std_array' in self.history:
            self.analyze_gradient_distribution()
    
    def analyze_gradient_health(self, grad_norms):
        """Analyze gradient health indicators."""
        print("\nGradient Health Check:")
        
        # Check for gradient explosion
        explosion_threshold = 1000
        if np.any(grad_norms > explosion_threshold):
            explosion_count = np.sum(grad_norms > explosion_threshold)
            print(f"  ‚ùå Gradient explosion detected: {explosion_count} steps > {explosion_threshold}")
        else:
            print(f"  ‚úÖ No gradient explosion (all < {explosion_threshold})")
        
        # Check for gradient vanishing
        vanishing_threshold = 1e-6
        if np.any(grad_norms < vanishing_threshold):
            vanishing_count = np.sum(grad_norms < vanishing_threshold)
            print(f"  ‚ùå Gradient vanishing detected: {vanishing_count} steps < {vanishing_threshold}")
        else:
            print(f"  ‚úÖ No gradient vanishing (all > {vanishing_threshold})")
        
        # Check gradient stability
        grad_diff = np.diff(grad_norms)
        grad_instability = np.std(grad_diff) / (np.mean(np.abs(grad_diff)) + 1e-8)
        
        if grad_instability > 5.0:
            print(f"  ‚ö†Ô∏è Unstable gradients (instability score: {grad_instability:.2f})")
        elif grad_instability > 2.0:
            print(f"  ‚ö†Ô∏è Moderately unstable gradients (score: {grad_instability:.2f})")
        else:
            print(f"  ‚úÖ Stable gradients (instability score: {grad_instability:.2f})")
        
        # Check gradient mean (should be close to 0)
        if 'grad_mean_array' in self.history:
            grad_means = self.history['grad_mean_array']
            mean_of_means = np.mean(np.abs(grad_means))
            
            if mean_of_means > 0.1:
                print(f"  ‚ö†Ô∏è Gradient meanÂÅèÈ´ò (absolute mean: {mean_of_means:.4f})")
            else:
                print(f"  ‚úÖ Gradient meanÊé•ËøëÈõ∂ (absolute mean: {mean_of_means:.4f})")
    
    def visualize_gradient_flow(self, grad_norms):
        """Visualize gradient flow over training."""
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        
        # Gradient norms over time
        axes[0, 0].plot(grad_norms, alpha=0.7)
        axes[0, 0].set_xlabel('Training Step')
        axes[0, 0].set_ylabel('Gradient Norm')
        axes[0, 0].set_title('Gradient Norms Over Time')
        axes[0, 0].grid(True, alpha=0.3)
        
        # Add mean line
        mean_grad = np.mean(grad_norms)
        axes[0, 0].axhline(y=mean_grad, color='r', linestyle='--',
                          label=f'Mean: {mean_grad:.4f}')
        axes[0, 0].legend()
        
        # Log scale gradient norms
        axes[0, 1].semilogy(grad_norms, alpha=0.7)
        axes[0, 1].set_xlabel('Training Step')
        axes[0, 1].set_ylabel('Gradient Norm (log scale)')
        axes[0, 1].set_title('Log-Scale Gradient Norms')
        axes[0, 1].grid(True, alpha=0.3)
        
        # Gradient norm distribution
        axes[1, 0].hist(grad_norms, bins=50, alpha=0.7, edgecolor='black')
        axes[1, 0].axvline(mean_grad, color='r', linestyle='--',
                          label=f'Mean: {mean_grad:.4f}')
        axes[1, 0].axvline(np.median(grad_norms), color='g', linestyle='--',
                          label=f'Median: {np.median(grad_norms):.4f}')
        axes[1, 0].set_xlabel('Gradient Norm')
        axes[1, 0].set_ylabel('Frequency')
        axes[1, 0].set_title('Gradient Norm Distribution')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)
        
        # Gradient changes
        if len(grad_norms) > 1:
            grad_changes = np.diff(grad_norms)
            axes[1, 1].plot(grad_changes, alpha=0.7)
            axes[1, 1].axhline(y=0, color='r', linestyle='-', alpha=0.3)
            axes[1, 1].set_xlabel('Training Step')
            axes[1, 1].set_ylabel('Œî Gradient Norm')
            axes[1, 1].set_title('Gradient Norm Changes')
            axes[1, 1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        # Create interactive visualization
        self.create_interactive_gradient_viz(grad_norms)
    
    def create_interactive_gradient_viz(self, grad_norms):
        """Create interactive gradient visualization."""
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Gradient Norms', 'Log-Scale Gradient Norms',
                           'Gradient Distribution', 'Gradient Changes'),
            vertical_spacing=0.1
        )
        
        # Gradient norms
        fig.add_trace(
            go.Scatter(y=grad_norms, mode='lines', name='Gradient Norm',
                      line=dict(width=2, color='blue')),
            row=1, col=1
        )
        
        # Add mean line
        mean_grad = np.mean(grad_norms)
        fig.add_hline(y=mean_grad, line_dash="dash", line_color="red",
                     annotation_text=f"Mean: {mean_grad:.4f}", row=1, col=1)
        
        # Log scale
        fig.add_trace(
            go.Scatter(y=grad_norms, mode='lines', name='Gradient Norm (log)',
                      line=dict(width=2, color='green')),
            row=1, col=2
        )
        fig.update_yaxes(type="log", row=1, col=2)
        
        # Distribution
        fig.add_trace(
            go.Histogram(x=grad_norms, nbinsx=50, name='Distribution',
                        marker_color='purple'),
            row=2, col=1
        )
        
        # Add mean and median
        fig.add_vline(x=mean_grad, line_dash="dash", line_color="red",
                     annotation_text=f"Mean", row=2, col=1)
        fig.add_vline(x=np.median(grad_norms), line_dash="dash", line_color="green",
                     annotation_text=f"Median", row=2, col=1)
        
        # Changes
        if len(grad_norms) > 1:
            grad_changes = np.diff(grad_norms)
            
            fig.add_trace(
                go.Scatter(y=grad_changes, mode='lines', name='Œî Gradient',
                          line=dict(width=1, color='orange')),
                row=2, col=2
            )
            
            fig.add_hline(y=0, line_dash="dash", line_color="red", row=2, col=2)
        
        fig.update_layout(
            height=800,
            width=1200,
            title_text="Gradient Flow Analysis",
            showlegend=True
        )
        
        fig.show()
    
    def analyze_gradient_distribution(self):
        """Analyze gradient value distribution."""
        print("\nAnalyzing Gradient Distribution:")
        
        if 'grad_mean_array' not in self.history or 'grad_std_array' not in self.history:
            print("  No gradient distribution data available.")
            return
        
        grad_means = self.history['grad_mean_array']
        grad_stds = self.history['grad_std_array']
        
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        
        # Gradient means over time
        axes[0].plot(grad_means, alpha=0.7)
        axes[0].axhline(y=0, color='r', linestyle='--', alpha=0.5)
        axes[0].set_xlabel('Training Step')
        axes[0].set_ylabel('Mean Gradient Value')
        axes[0].set_title('Gradient Means Over Time')
        axes[0].grid(True, alpha=0.3)
        
        # Gradient standard deviations
        axes[1].plot(grad_stds, alpha=0.7)
        axes[1].set_xlabel('Training Step')
        axes[1].set_ylabel('Gradient Std Dev')
        axes[1].set_title('Gradient Standard Deviations')
        axes[1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        # Print statistics
        print(f"  Mean of gradient means: {np.mean(grad_means):.6f}")
        print(f"  Std of gradient means: {np.std(grad_means):.6f}")
        print(f"  Mean gradient std: {np.mean(grad_stds):.6f}")
        
        # Check for bias
        mean_bias = np.mean(np.abs(grad_means))
        if mean_bias > 0.01:
            print(f"  ‚ö†Ô∏è Significant gradient bias detected: {mean_bias:.6f}")
        else:
            print(f"  ‚úÖ Minimal gradient bias: {mean_bias:.6f}")

# %%
# Analyze gradient flow
gradient_analyzer = GradientAnalyzer(history)
gradient_analyzer.analyze_gradient_flow()

# %% [markdown]
"""
## 5. Learning Rate Analysis
"""

# %%
class LearningRateAnalyzer:
    """Analyze the learning-rate schedule recorded during training.

    The analyzer reads ``history['learning_rates_array']`` (and, when present,
    ``history['losses_array']``), prints summary statistics, classifies the
    shape of the schedule, and draws matplotlib and Plotly visualizations.

    NOTE(review): several printed messages and plot labels contain non-ASCII
    sequences that look mis-encoded; they are runtime output and are kept
    exactly as they were.
    """

    def __init__(self, history):
        """Store the training history container the analyses read from."""
        self.history = history

    def analyze_learning_rate(self):
        """Run the full learning-rate analysis.

        Prints the initial, final, minimum and maximum learning rate and the
        relative reduction, then runs the schedule-effectiveness analysis and
        the visualizations.  If ``losses_array`` is present in the history,
        the LR/loss correlation analysis is run as well.

        Returns:
            None. Returns early (after printing a notice) when the history
            has no ``learning_rates_array`` key or the series is empty.
        """
        print("\nAnalyzing Learning Rate Schedule:")

        if 'learning_rates_array' not in self.history:
            print("  No learning rate data available.")
            return

        lrs = self.history['learning_rates_array']

        # len() rather than truthiness: a NumPy array has no unambiguous
        # boolean value.
        if len(lrs) == 0:
            print("  Empty learning rate data.")
            return

        # Summary statistics.
        initial_lr = lrs[0]
        final_lr = lrs[-1]
        min_lr = np.min(lrs)
        max_lr = np.max(lrs)

        print(f"  Initial LR: {initial_lr:.6f}")
        print(f"  Final LR: {final_lr:.6f}")
        print(f"  Minimum LR: {min_lr:.6f}")
        print(f"  Maximum LR: {max_lr:.6f}")
        # A schedule that starts at exactly 0 (e.g. a warmup from zero) has no
        # meaningful relative reduction; dividing would raise for Python
        # floats and print inf/nan for NumPy scalars.
        if initial_lr == 0:
            print("  LR reduction: n/a (initial LR is 0)")
        else:
            print(f"  LR reduction: {(1 - final_lr/initial_lr)*100:.1f}%")

        # Analyze schedule effectiveness
        self.analyze_schedule_effectiveness(lrs)

        # Visualize learning rate schedule
        self.visualize_learning_rate(lrs)

        # Analyze LR vs Loss correlation
        if 'losses_array' in self.history:
            self.analyze_lr_loss_correlation(lrs, self.history['losses_array'])

    def analyze_schedule_effectiveness(self, lrs):
        """Classify the schedule by its step-to-step changes and total decay.

        Args:
            lrs: Sequence of learning rates, one per recorded step.

        Returns:
            None. Counts of increasing / decreasing / constant steps, a
            schedule-shape verdict and a decay verdict are printed.
        """
        print("\nSchedule Effectiveness Analysis:")

        # Step-to-step differences.
        lr_changes = np.diff(lrs)

        increases = np.sum(lr_changes > 0)
        decreases = np.sum(lr_changes < 0)
        constant = np.sum(lr_changes == 0)

        print(f"  LR increases: {increases}")
        print(f"  LR decreases: {decreases}")
        print(f"  Constant LR steps: {constant}")

        # Schedule pattern.  A schedule that never changes is reported as
        # constant instead of falling through to the "mixed/cyclic" verdict.
        if increases == 0 and decreases == 0:
            print("  Schedule: Constant (no LR changes recorded)")
        elif decreases > increases * 5:  # Mostly decreasing
            print("  ‚úÖ Schedule: Mostly decreasing (good for convergence)")
        elif increases > decreases * 5:  # Mostly increasing
            print("  ‚ö†Ô∏è Schedule: Mostly increasing (unusual)")
        else:  # Mixed
            print("  ‚ö†Ô∏è Schedule: Mixed increases/decreases (cyclic schedule)")

        # Total decay from the first to the last recorded value.
        if lrs[-1] < lrs[0] * 0.1:  # Decayed by at least 90%
            print("  ‚úÖ Strong decay achieved")
        elif lrs[-1] < lrs[0] * 0.5:  # Decayed by at least 50%
            print("  ‚úÖ Moderate decay achieved")
        else:
            print("  ‚ö†Ô∏è Limited decay achieved")

    def visualize_learning_rate(self, lrs):
        """Draw a 2x2 matplotlib overview of the schedule, then the Plotly view.

        Panels: linear-scale schedule, log-scale schedule, histogram with
        mean/median markers, and step-to-step changes (only when more than one
        value is available).

        Args:
            lrs: Sequence of learning rates, one per recorded step.
        """
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        schedule_ax, log_ax = axes[0, 0], axes[0, 1]
        hist_ax, change_ax = axes[1, 0], axes[1, 1]

        # LR over time
        schedule_ax.plot(lrs, linewidth=2, color='darkblue')
        schedule_ax.set_xlabel('Training Step')
        schedule_ax.set_ylabel('Learning Rate')
        schedule_ax.set_title('Learning Rate Schedule')
        schedule_ax.grid(True, alpha=0.3)

        # Log scale LR
        log_ax.semilogy(lrs, linewidth=2, color='darkgreen')
        log_ax.set_xlabel('Training Step')
        log_ax.set_ylabel('Learning Rate (log scale)')
        log_ax.set_title('Log-Scale Learning Rate')
        log_ax.grid(True, alpha=0.3)

        # LR distribution
        lr_mean = np.mean(lrs)
        lr_median = np.median(lrs)
        hist_ax.hist(lrs, bins=30, alpha=0.7, edgecolor='black', color='purple')
        hist_ax.axvline(lr_mean, color='red', linestyle='--',
                        label=f'Mean: {lr_mean:.6f}')
        hist_ax.axvline(lr_median, color='green', linestyle='--',
                        label=f'Median: {lr_median:.6f}')
        hist_ax.set_xlabel('Learning Rate')
        hist_ax.set_ylabel('Frequency')
        hist_ax.set_title('Learning Rate Distribution')
        hist_ax.legend()
        hist_ax.grid(True, alpha=0.3)

        # LR changes (needs at least two values to form a difference)
        if len(lrs) > 1:
            lr_changes = np.diff(lrs)
            change_ax.plot(lr_changes, alpha=0.7, color='orange')
            change_ax.axhline(y=0, color='r', linestyle='-', alpha=0.3)
            change_ax.set_xlabel('Training Step')
            change_ax.set_ylabel('Œî Learning Rate')
            change_ax.set_title('Learning Rate Changes')
            change_ax.grid(True, alpha=0.3)

            # Increase/decrease counts, anchored in axes coordinates.
            pos_changes = np.sum(lr_changes > 0)
            neg_changes = np.sum(lr_changes < 0)
            change_ax.text(0.05, 0.95,
                           f'‚Üë Increases: {pos_changes}\n‚Üì Decreases: {neg_changes}',
                           transform=change_ax.transAxes,
                           verticalalignment='top',
                           bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        plt.tight_layout()
        plt.show()

        # Create interactive visualization
        self.create_interactive_lr_viz(lrs)

    def create_interactive_lr_viz(self, lrs):
        """Build and show the Plotly counterpart of the 2x2 LR overview.

        Args:
            lrs: Sequence of learning rates, one per recorded step.
        """
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Learning Rate Schedule', 'Log-Scale Learning Rate',
                           'Learning Rate Distribution', 'Learning Rate Changes'),
            vertical_spacing=0.1
        )

        # LR schedule
        fig.add_trace(
            go.Scatter(y=lrs, mode='lines', name='Learning Rate',
                      line=dict(width=3, color='darkblue')),
            row=1, col=1
        )

        # Log scale
        fig.add_trace(
            go.Scatter(y=lrs, mode='lines', name='LR (log)',
                      line=dict(width=3, color='darkgreen')),
            row=1, col=2
        )
        fig.update_yaxes(type="log", row=1, col=2)

        # Distribution
        fig.add_trace(
            go.Histogram(x=lrs, nbinsx=30, name='Distribution',
                        marker_color='purple'),
            row=2, col=1
        )

        # Mean and median markers on the histogram
        lr_mean = np.mean(lrs)
        lr_median = np.median(lrs)
        fig.add_vline(x=lr_mean, line_dash="dash", line_color="red",
                     annotation_text=f"Mean: {lr_mean:.6f}", row=2, col=1)
        fig.add_vline(x=lr_median, line_dash="dash", line_color="green",
                     annotation_text=f"Median: {lr_median:.6f}", row=2, col=1)

        # Changes
        if len(lrs) > 1:
            lr_changes = np.diff(lrs)

            fig.add_trace(
                go.Scatter(y=lr_changes, mode='lines', name='Œî LR',
                          line=dict(width=2, color='orange')),
                row=2, col=2
            )

            fig.add_hline(y=0, line_dash="dash", line_color="red", row=2, col=2)

            pos_changes = np.sum(lr_changes > 0)
            neg_changes = np.sum(lr_changes < 0)

            # Position the box relative to the subplot's own axis domain.
            # With row/col given, Plotly rewrites xref/yref to that subplot's
            # axes, so "paper" references would be replaced by data
            # coordinates and (0.05, 0.95) would land far outside the tiny
            # LR-delta range.  "x domain"/"y domain" keep the fractional
            # placement (top-left corner of the subplot).
            fig.add_annotation(
                x=0.05, y=0.95,
                xref="x domain", yref="y domain",
                text=f"Increases: {pos_changes}<br>Decreases: {neg_changes}",
                showarrow=False,
                font=dict(size=12),
                align="left",
                bgcolor="white",
                bordercolor="black",
                borderwidth=1,
                row=2, col=2
            )

        fig.update_layout(
            height=800,
            width=1200,
            title_text="Learning Rate Analysis",
            showlegend=True
        )

        fig.show()

    def analyze_lr_loss_correlation(self, lrs, losses):
        """Report and plot the correlation between learning rate and loss.

        If the two series differ in length they are truncated to the shorter
        one (a notice is printed).  The overall Pearson coefficient is printed
        with a verdict, and a scatter plot plus a rolling-window correlation
        (window = min(50, len // 4), drawn only when the window exceeds 1) are
        shown.

        Args:
            lrs: Sequence of learning rates.
            losses: Sequence of loss values.

        Returns:
            None. Returns early when fewer than two aligned points exist,
            because a correlation coefficient is undefined in that case.
        """
        print("\nAnalyzing LR-Loss Correlation:")

        if len(lrs) != len(losses):
            print(f"  Mismatched data lengths: LR={len(lrs)}, Loss={len(losses)}")
            # Align lengths
            min_len = min(len(lrs), len(losses))
            lrs = lrs[:min_len]
            losses = losses[:min_len]

        if len(lrs) < 2:
            print("  Not enough data points to compute a correlation.")
            return

        # Overall correlation
        correlation = np.corrcoef(lrs, losses)[0, 1]

        print(f"  Correlation coefficient: {correlation:.4f}")

        if correlation > 0.3:
            print("  ‚ö†Ô∏è Positive correlation: Higher LR ‚Üí Higher loss")
        elif correlation < -0.3:
            print("  ‚úÖ Negative correlation: Higher LR ‚Üí Lower loss (good)")
        else:
            print("  ‚ö†Ô∏è Weak correlation: LR changes not strongly affecting loss")

        fig, (scatter_ax, rolling_ax) = plt.subplots(1, 2, figsize=(12, 5))

        # Scatter plot, colored by step index
        scatter = scatter_ax.scatter(lrs, losses, alpha=0.6, c=range(len(lrs)), cmap='viridis')
        scatter_ax.set_xlabel('Learning Rate')
        scatter_ax.set_ylabel('Loss')
        scatter_ax.set_title(f'LR vs Loss (Correlation: {correlation:.3f})')
        scatter_ax.grid(True, alpha=0.3)
        plt.colorbar(scatter, ax=scatter_ax, label='Training Step')

        # Rolling correlation
        window_size = min(50, len(lrs) // 4)
        if window_size > 1:
            rolling_corr = [
                np.corrcoef(lrs[start:start + window_size],
                            losses[start:start + window_size])[0, 1]
                for start in range(len(lrs) - window_size + 1)
            ]

            # Each value is plotted at the last step of its window.
            rolling_ax.plot(range(window_size-1, len(lrs)), rolling_corr)
            rolling_ax.axhline(y=0, color='r', linestyle='--', alpha=0.5)
            rolling_ax.set_xlabel('Training Step')
            rolling_ax.set_ylabel('Correlation')
            rolling_ax.set_title(f'Rolling Correlation (window={window_size})')
            rolling_ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

# %%
# Analyze learning rate
# Build the analyzer from the recorded training history and run the full
# learning-rate analysis. `lr_analyzer` stays bound at module level because
# it is passed to OptimizationRecommender further down.
lr_analyzer = LearningRateAnalyzer(history)
lr_analyzer.analyze_learning_rate()

# %% [markdown]
"""
## 6. Stability Analysis
"""

# %%
class StabilityAnalyzer:
    """Analyze the stability metrics recorded during training.

    Three history keys are recognized: ``max_eigenvalue_array``,
    ``min_eigenvalue_array`` and ``signal_ratio_mean_array``.  Whichever of
    them are present are summarized, scored and visualized.

    NOTE(review): several printed messages contain non-ASCII sequences that
    look mis-encoded; they are runtime output and are kept exactly as they
    were.
    """

    # History keys this analyzer knows how to interpret.
    _STABILITY_METRICS = (
        'max_eigenvalue_array',
        'min_eigenvalue_array',
        'signal_ratio_mean_array',
    )

    def __init__(self, history):
        """Store the training history container the analyses read from."""
        self.history = history

    def _available_metrics(self):
        """Return the recognized stability metric keys present in the history."""
        return [m for m in self._STABILITY_METRICS if m in self.history]

    @staticmethod
    def _reference_line(metric_name):
        """Return ``(value, label)`` of the reference line for a metric, or None."""
        if 'eigenvalue' in metric_name:
            if 'max' in metric_name:
                return 1.0, 'Stability bound'
            if 'min' in metric_name:
                return 0.0, 'Non-negativity'
            return None
        if 'signal_ratio' in metric_name:
            return 1.0, 'Ideal = 1.0'
        return None

    def analyze_stability(self):
        """Print per-metric statistics and an overall stability score.

        Returns:
            None. Returns early (after printing a notice) when none of the
            recognized stability metrics is present in the history.
        """
        print("\nAnalyzing Training Stability:")

        available_metrics = self._available_metrics()

        if not available_metrics:
            print("  No stability metrics available.")
            return

        print(f"  Available stability metrics: {', '.join(available_metrics)}")

        # Analyze each metric
        for metric_name in available_metrics:
            self.analyze_stability_metric(metric_name)

        # Analyze overall stability
        self.analyze_overall_stability(available_metrics)

    def analyze_stability_metric(self, metric_name):
        """Print summary statistics and a verdict for one stability metric.

        Args:
            metric_name: Key into ``self.history``.  Names containing
                ``eigenvalue`` (with ``max`` and/or ``min``) or
                ``signal_ratio`` receive a metric-specific verdict.
        """
        print(f"\nAnalyzing {metric_name}:")

        values = self.history[metric_name]

        if len(values) == 0:
            print("  Empty data.")
            return

        # Summary statistics
        mean_val = np.mean(values)
        std_val = np.std(values)
        min_val = np.min(values)
        max_val = np.max(values)

        print(f"  Mean: {mean_val:.6f}")
        print(f"  Std: {std_val:.6f}")
        print(f"  Range: [{min_val:.6f}, {max_val:.6f}]")

        # Metric-specific verdicts
        if 'eigenvalue' in metric_name:
            if 'max' in metric_name:
                if max_val > 1.0:
                    print(f"  ‚ùå Max eigenvalue > 1.0: {max_val:.6f}")
                elif max_val > 0.95:
                    print(f"  ‚ö†Ô∏è Max eigenvalueÊé•Ëøë1.0: {max_val:.6f}")
                else:
                    print(f"  ‚úÖ Max eigenvalue < 1.0: {max_val:.6f}")

            if 'min' in metric_name:
                if min_val < 0:
                    print(f"  ‚ö†Ô∏è Min eigenvalue < 0: {min_val:.6f}")
                else:
                    print(f"  ‚úÖ Min eigenvalue ‚â• 0: {min_val:.6f}")

        elif 'signal_ratio' in metric_name:
            if abs(mean_val - 1.0) > 0.1:
                print(f"  ‚ö†Ô∏è Signal ratioÂÅèÁ¶ª1.0: {mean_val:.4f}")
            else:
                print(f"  ‚úÖ Signal ratioÊé•Ëøë1.0: {mean_val:.4f}")

            if std_val > 0.1:
                print(f"  ‚ö†Ô∏è High variance in signal ratio: {std_val:.4f}")
            else:
                print(f"  ‚úÖ Stable signal ratio: {std_val:.4f}")

    def analyze_overall_stability(self, available_metrics):
        """Score overall stability as the fraction of passed metric checks.

        One point is awarded per metric when every recorded value satisfies:
        max eigenvalue <= 1.0, min eigenvalue >= 0, and |signal ratio - 1| < 0.2.
        The score is divided by ``len(available_metrics)``.

        Args:
            available_metrics: Metric keys (present in the history) to score.
        """
        print("\nOverall Stability Assessment:")

        stability_score = 0
        max_score = len(available_metrics)
        checks = []

        # np.asarray lets the element-wise comparisons below work for plain
        # Python lists as well as NumPy arrays.
        if 'max_eigenvalue_array' in available_metrics:
            max_eigenvalues = np.asarray(self.history['max_eigenvalue_array'])
            if np.all(max_eigenvalues <= 1.0):
                stability_score += 1
                checks.append("‚úÖ Max eigenvalue ‚â§ 1.0")
            else:
                checks.append("‚ùå Max eigenvalue > 1.0")

        if 'min_eigenvalue_array' in available_metrics:
            min_eigenvalues = np.asarray(self.history['min_eigenvalue_array'])
            if np.all(min_eigenvalues >= 0):
                stability_score += 1
                checks.append("‚úÖ Min eigenvalue ‚â• 0")
            else:
                checks.append("‚ùå Min eigenvalue < 0")

        if 'signal_ratio_mean_array' in available_metrics:
            signal_ratios = np.asarray(self.history['signal_ratio_mean_array'])
            if np.all(np.abs(signal_ratios - 1.0) < 0.2):
                stability_score += 1
                checks.append("‚úÖ Signal ratio stable (~1.0)")
            else:
                checks.append("‚ùå Signal ratio unstable")

        # Overall score
        if max_score > 0:
            overall_score = stability_score / max_score
            print(f"  Stability score: {stability_score}/{max_score} ({overall_score*100:.1f}%)")

            if overall_score >= 0.8:
                print("  üéâ EXCELLENT stability")
            elif overall_score >= 0.6:
                print("  üëç GOOD stability")
            elif overall_score >= 0.4:
                print("  ‚ö†Ô∏è MODERATE stability issues")
            else:
                print("  ‚ùå POOR stability - needs attention")
        else:
            print("  No stability metrics available for scoring.")

        # Individual checks
        print("\nStability Checks:")
        for check in checks:
            print(f"  {check}")

    def visualize_stability_metrics(self):
        """Plot each available stability metric as a time series and histogram.

        One row per metric is drawn with matplotlib (time series with its
        reference line on the left, distribution with mean/median markers on
        the right), followed by the interactive Plotly view.
        """
        print("\nVisualizing Stability Metrics...")

        available_metrics = self._available_metrics()

        if not available_metrics:
            print("  No stability metrics to visualize.")
            return

        # squeeze=False keeps `axes` two-dimensional even for a single metric.
        fig, axes = plt.subplots(len(available_metrics), 2,
                                 figsize=(14, 4*len(available_metrics)),
                                 squeeze=False)

        for idx, metric_name in enumerate(available_metrics):
            values = self.history[metric_name]
            series_ax, dist_ax = axes[idx, 0], axes[idx, 1]
            short_name = metric_name.replace('_array', '')

            # Time series
            series_ax.plot(values, linewidth=2, alpha=0.7)

            # Reference line depends on the metric type
            reference = self._reference_line(metric_name)
            if reference is not None:
                ref_value, ref_label = reference
                series_ax.axhline(y=ref_value, color='r', linestyle='--',
                                  alpha=0.5, label=ref_label)

            series_ax.set_xlabel('Training Step')
            series_ax.set_ylabel(short_name)
            series_ax.set_title(f'{short_name.replace("_", " ").title()}')
            series_ax.legend()
            series_ax.grid(True, alpha=0.3)

            # Distribution
            dist_ax.hist(values, bins=30, alpha=0.7, edgecolor='black')
            dist_ax.axvline(np.mean(values), color='red', linestyle='--',
                            label=f'Mean: {np.mean(values):.4f}')
            dist_ax.axvline(np.median(values), color='green', linestyle='--',
                            label=f'Median: {np.median(values):.4f}')
            dist_ax.set_xlabel('Value')
            dist_ax.set_ylabel('Frequency')
            dist_ax.set_title(f'Distribution of {short_name}')
            dist_ax.legend()
            dist_ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

        # Create interactive visualization
        self.create_interactive_stability_viz(available_metrics)

    def create_interactive_stability_viz(self, available_metrics):
        """Build and show a Plotly figure with one time-series row per metric.

        This method is called by ``visualize_stability_metrics``.  Each row
        shows the recorded values and, where one is defined for the metric
        type, a dashed reference line.

        Args:
            available_metrics: Metric keys (present in the history) to plot.
        """
        if not available_metrics:
            return

        titles = [
            name.replace('_array', '').replace('_', ' ').title()
            for name in available_metrics
        ]
        fig = make_subplots(
            rows=len(available_metrics), cols=1,
            subplot_titles=titles,
            vertical_spacing=0.1
        )

        for row, (metric_name, title) in enumerate(zip(available_metrics, titles), start=1):
            fig.add_trace(
                go.Scatter(y=self.history[metric_name], mode='lines', name=title,
                          line=dict(width=2)),
                row=row, col=1
            )

            reference = self._reference_line(metric_name)
            if reference is not None:
                ref_value, ref_label = reference
                fig.add_hline(y=ref_value, line_dash="dash", line_color="red",
                             annotation_text=ref_label, row=row, col=1)

        fig.update_layout(
            height=300 * len(available_metrics),
            width=1200,
            title_text="Stability Metrics Analysis",
            showlegend=True
        )

        fig.show()

# %%
# Analyze stability
# Build the analyzer from the recorded training history, print the stability
# analysis, then draw the plots. `stability_analyzer` stays bound at module
# level because it is passed to OptimizationRecommender further down.
stability_analyzer = StabilityAnalyzer(history)
stability_analyzer.analyze_stability()
stability_analyzer.visualize_stability_metrics()

# %% [markdown]
"""
## 7. Training Optimization Recommendations
"""

# %%
class OptimizationRecommender:
    """Derive training optimization recommendations from config and history.

    Each ``analyze_*`` method returns a list of recommendation dicts with the
    keys ``category``, ``issue``, ``recommendation``, ``priority``
    (``'HIGH'``, ``'MEDIUM'`` or ``'LOW'``) and ``rationale``.

    NOTE(review): a few printed messages and one rationale string contain
    non-ASCII sequences that look mis-encoded; they are runtime output and are
    kept exactly as they were.
    """

    # Sort rank of each priority label (lower sorts first).
    _PRIORITY_ORDER = {'HIGH': 0, 'MEDIUM': 1, 'LOW': 2}
    # ANSI colour prefix per priority; anything else is printed in green.
    _PRIORITY_COLORS = {'HIGH': '\033[91m', 'MEDIUM': '\033[93m'}
    _DEFAULT_COLOR = '\033[92m'
    _COLOR_RESET = '\033[0m'

    def __init__(self, config, history, loss_analyzer, gradient_analyzer,
                 lr_analyzer, stability_analyzer):
        """Store the configuration, the history and the analyzer instances.

        Args:
            config: Nested configuration; the ``'training'`` section and the
                ``'device'`` entry are read by the analyses.
            history: Training history container keyed by metric name.
            loss_analyzer: Stored as-is; not used by the methods of this class.
            gradient_analyzer: Stored as-is; not used by the methods of this class.
            lr_analyzer: Stored as-is; not used by the methods of this class.
            stability_analyzer: Stored as-is; not used by the methods of this class.
        """
        self.config = config
        self.history = history
        self.loss_analyzer = loss_analyzer
        self.gradient_analyzer = gradient_analyzer
        self.lr_analyzer = lr_analyzer
        self.stability_analyzer = stability_analyzer

    def generate_recommendations(self):
        """Collect, display and summarize all recommendations.

        Returns:
            list[dict]: Every recommendation, ordered HIGH, MEDIUM, LOW (ties
            keep the order in which the analyses produced them).
        """
        print("\n" + "="*70)
        print("TRAINING OPTIMIZATION RECOMMENDATIONS")
        print("="*70)

        recommendations = []

        # 1. Learning Rate Recommendations
        recommendations.extend(self.analyze_learning_rate_optimization())

        # 2. Gradient Flow Recommendations
        recommendations.extend(self.analyze_gradient_optimization())

        # 3. Stability Recommendations
        recommendations.extend(self.analyze_stability_optimization())

        # 4. General Training Recommendations
        recommendations.extend(self.analyze_general_optimization())

        # Order by priority here, on the list this method owns, so that the
        # summary and the returned value are priority-ordered without the
        # display step having to mutate its argument.
        recommendations.sort(key=lambda rec: self._PRIORITY_ORDER[rec['priority']])

        # Display recommendations
        self.display_recommendations(recommendations)

        # Generate summary
        self.generate_summary(recommendations)

        return recommendations

    def analyze_learning_rate_optimization(self):
        """Recommend learning-rate changes based on the recorded schedule.

        Checks the first recorded LR (> 1e-2 is high, < 1e-4 is low) and the
        ratio of the last to the first LR (> 0.5 means insufficient decay).

        Returns:
            list[dict]: Recommendations; empty when no LR data is recorded.
        """
        recommendations = []

        if 'learning_rates_array' not in self.history:
            return recommendations

        lrs = self.history['learning_rates_array']
        # Nothing to index into; len() is used because the series may be a
        # NumPy array.
        if len(lrs) == 0:
            return recommendations

        initial_lr = lrs[0]
        final_lr = lrs[-1]

        # Check initial LR
        if initial_lr > 1e-2:
            recommendations.append({
                'category': 'Learning Rate',
                'issue': 'High initial learning rate',
                'recommendation': 'Reduce initial LR to 1e-3 or 5e-4',
                'priority': 'HIGH',
                'rationale': f'Current LR {initial_lr:.1e} may cause instability'
            })
        elif initial_lr < 1e-4:
            recommendations.append({
                'category': 'Learning Rate',
                'issue': 'Low initial learning rate',
                'recommendation': 'Increase initial LR to 1e-3',
                'priority': 'MEDIUM',
                'rationale': f'Current LR {initial_lr:.1e} may slow convergence'
            })

        # Check LR decay.  The ratio is undefined when the schedule starts at
        # exactly 0 (e.g. a warmup from zero); in that case any positive
        # final LR means no decay happened, and the rationale avoids the
        # percentage.
        if initial_lr != 0:
            decay_ratio = final_lr / initial_lr
            insufficient_decay = decay_ratio > 0.5
            rationale = f'Only {100*(1-decay_ratio):.1f}% decay achieved'
        else:
            insufficient_decay = final_lr > 0
            rationale = (f'Final LR {final_lr:.1e} is not below '
                         f'initial LR {initial_lr:.1e}')

        if insufficient_decay:
            recommendations.append({
                'category': 'Learning Rate',
                'issue': 'Insufficient LR decay',
                'recommendation': 'Increase decay factor or add more decay steps',
                'priority': 'MEDIUM',
                'rationale': rationale
            })

        return recommendations

    def analyze_gradient_optimization(self):
        """Recommend gradient-related changes based on recorded gradient norms.

        Flags a maximum norm above 1000 (explosion), a mean norm below 1e-6
        (vanishing) and a mean norm above half of the configured
        ``gradient_clip`` value.

        Returns:
            list[dict]: Recommendations; empty when no norms are recorded.
        """
        recommendations = []

        if 'grad_norm_total_array' not in self.history:
            return recommendations

        grad_norms = self.history['grad_norm_total_array']
        # np.max raises on an empty series, so there is nothing to analyze.
        if len(grad_norms) == 0:
            return recommendations

        mean_grad = np.mean(grad_norms)
        max_grad = np.max(grad_norms)

        # Check for gradient explosion
        if max_grad > 1000:
            recommendations.append({
                'category': 'Gradient Flow',
                'issue': 'Gradient explosion detected',
                'recommendation': 'Reduce LR or increase gradient clipping',
                'priority': 'HIGH',
                'rationale': f'Max gradient norm: {max_grad:.1f}'
            })

        # Check for gradient vanishing
        if mean_grad < 1e-6:
            recommendations.append({
                'category': 'Gradient Flow',
                'issue': 'Gradient vanishing detected',
                'recommendation': 'Use gradient clipping lower bound or skip connections',
                'priority': 'HIGH',
                'rationale': f'Mean gradient norm: {mean_grad:.1e}'
            })

        # Check gradient clipping
        current_clip = self.config['training']['gradient_clip']
        if mean_grad > current_clip * 0.5:
            recommendations.append({
                'category': 'Gradient Flow',
                'issue': 'Aggressive gradient clipping',
                'recommendation': f'Increase clip norm from {current_clip} to {current_clip*2}',
                'priority': 'MEDIUM',
                'rationale': f'Mean gradient {mean_grad:.2f}Êé•Ëøëclip value {current_clip}'
            })

        return recommendations

    def analyze_stability_optimization(self):
        """Recommend stability fixes based on eigenvalue and signal-ratio data.

        Returns:
            list[dict]: Recommendations for any max eigenvalue above 1.0 and
            for a mean signal ratio more than 0.2 away from 1.0.
        """
        recommendations = []

        # Check eigenvalue stability
        if 'max_eigenvalue_array' in self.history:
            # np.asarray keeps the element-wise comparison valid for lists.
            max_eigenvalues = np.asarray(self.history['max_eigenvalue_array'])
            if np.any(max_eigenvalues > 1.0):
                recommendations.append({
                    'category': 'Stability',
                    'issue': 'Eigenvalue > 1.0 detected',
                    'recommendation': 'Increase Sinkhorn iterations or add eigenvalue penalty',
                    'priority': 'HIGH',
                    'rationale': 'Violates non-expansive mapping guarantee'
                })

        # Check signal preservation
        if 'signal_ratio_mean_array' in self.history:
            signal_ratios = self.history['signal_ratio_mean_array']
            mean_ratio = np.mean(signal_ratios)

            if abs(mean_ratio - 1.0) > 0.2:
                recommendations.append({
                    'category': 'Stability',
                    'issue': 'Poor signal preservation',
                    'recommendation': 'Adjust MHC initialization or add signal preservation loss',
                    'priority': 'MEDIUM',
                    'rationale': f'Signal ratio: {mean_ratio:.3f} (should be ~1.0)'
                })

        return recommendations

    def analyze_general_optimization(self):
        """Recommend changes based on the configuration alone.

        Looks at ``batch_size``, ``weight_decay`` and ``use_amp`` in the
        ``'training'`` section and at the top-level ``'device'`` entry.

        Returns:
            list[dict]: Recommendations derived from the configuration.
        """
        recommendations = []
        training_cfg = self.config['training']

        # Check batch size
        batch_size = training_cfg['batch_size']
        if batch_size < 8:
            recommendations.append({
                'category': 'General',
                'issue': 'Small batch size',
                'recommendation': 'Increase batch size for better gradient estimates',
                'priority': 'MEDIUM',
                'rationale': f'Current batch size: {batch_size}'
            })
        elif batch_size > 32 and not training_cfg['use_amp']:
            recommendations.append({
                'category': 'General',
                'issue': 'Large batch size without mixed precision',
                'recommendation': 'Enable mixed precision training',
                'priority': 'HIGH',
                'rationale': 'Will reduce memory and speed up training'
            })

        # Check weight decay
        weight_decay = training_cfg['weight_decay']
        if weight_decay < 1e-5:
            recommendations.append({
                'category': 'Regularization',
                'issue': 'Weak weight decay',
                'recommendation': 'Increase weight decay to 1e-4',
                'priority': 'LOW',
                'rationale': f'Current weight decay: {weight_decay:.1e}'
            })
        elif weight_decay > 1e-3:
            recommendations.append({
                'category': 'Regularization',
                'issue': 'Strong weight decay',
                'recommendation': 'Decrease weight decay to 1e-4',
                'priority': 'MEDIUM',
                'rationale': f'Current weight decay: {weight_decay:.1e}'
            })

        # Check for mixed precision
        if not training_cfg['use_amp'] and self.config['device'] == 'cuda':
            recommendations.append({
                'category': 'Performance',
                'issue': 'Mixed precision disabled',
                'recommendation': 'Enable mixed precision training',
                'priority': 'HIGH',
                'rationale': 'Will significantly speed up training on GPU'
            })

        return recommendations

    def display_recommendations(self, recommendations):
        """Print the recommendations as a priority-ordered table.

        The caller's list is not modified; a sorted copy is printed.

        Args:
            recommendations: Recommendation dicts as produced by the
                ``analyze_*`` methods.
        """
        if not recommendations:
            print("\nNo optimization recommendations.")
            return

        print("\nDetailed Recommendations:")
        print("-" * 120)
        print(f"{'Category':<15} {'Priority':<10} {'Issue':<40} {'Recommendation':<50}")
        print("-" * 120)

        # Sort by priority (stable, so ties keep their incoming order).
        ordered = sorted(recommendations,
                         key=lambda rec: self._PRIORITY_ORDER[rec['priority']])

        for rec in ordered:
            priority = rec['priority']
            color = self._PRIORITY_COLORS.get(priority, self._DEFAULT_COLOR)
            # Pad the visible label BEFORE wrapping it in colour codes: the
            # escape sequences count towards the string length, so padding
            # the coloured string would add no spaces and shift the
            # remaining columns.
            priority_cell = f"{color}{priority:<10}{self._COLOR_RESET}"

            # Truncate long strings to their column widths
            issue = rec['issue'][:38] + '..' if len(rec['issue']) > 40 else rec['issue']
            recommendation = rec['recommendation'][:48] + '..' if len(rec['recommendation']) > 50 else rec['recommendation']

            print(f"{rec['category']:<15} {priority_cell} {issue:<40} {recommendation:<50}")

        print("-" * 120)

        # Count by priority
        high_count = sum(1 for r in ordered if r['priority'] == 'HIGH')
        medium_count = sum(1 for r in ordered if r['priority'] == 'MEDIUM')
        low_count = sum(1 for r in ordered if r['priority'] == 'LOW')

        print(f"\nPriority Summary: HIGH={high_count}, MEDIUM={medium_count}, LOW={low_count}")

    def generate_summary(self, recommendations):
        """Print per-category counts, up to three HIGH items and fixed guidance.

        Args:
            recommendations: Recommendation dicts as produced by the
                ``analyze_*`` methods.
        """
        print("\n" + "="*70)
        print("OPTIMIZATION SUMMARY")
        print("="*70)

        # Count recommendations by category (first-seen order is kept).
        categories = {}
        for rec in recommendations:
            categories[rec['category']] = categories.get(rec['category'], 0) + 1

        print("\nRecommendations by Category:")
        for cat, count in categories.items():
            print(f"  {cat}: {count}")

        # Top recommendations
        high_recs = [r for r in recommendations if r['priority'] == 'HIGH']
        if high_recs:
            print("\nTOP PRIORITY (HIGH) Recommendations:")
            for rec in high_recs[:3]:  # Top 3
                print(f"  ‚Ä¢ {rec['issue']}")
                print(f"    ‚Üí {rec['recommendation']}")

        # Estimated impact
        print("\nExpected Impact of Implementations:")
        print("  ‚Ä¢ HIGH priority: 20-50% training improvement")
        print("  ‚Ä¢ MEDIUM priority: 10-20% training improvement")
        print("  ‚Ä¢ LOW priority: 5-10% training improvement")

        # Implementation timeline
        print("\nSuggested Implementation Order:")
        print("  1. Implement all HIGH priority recommendations")
        print("  2. Address gradient flow issues")
        print("  3. Optimize learning rate schedule")
        print("  4. Implement stability improvements")
        print("  5. Apply general optimizations")

# %%
# Generate optimization recommendations
recommender = OptimizationRecommender(
    config, 
    history, 
    loss_analyzer, 
    gradient_analyzer, 
    lr_analyzer, 
    stability_analyzer
)

recommendations = recommender.generate_recommendations()

# %% [markdown]
"""
## 8. Export Training Analysis Report
"""

# %%
class TrainingAnalysisExporter:
    """Export a training-analysis report as JSON and HTML.

    The exporter combines the run configuration, the recorded training
    history and the generated recommendations into one report dictionary and
    writes it to ``training_analysis_report.json`` and
    ``training_analysis_report.html`` inside ``output_dir``.

    Args:
        config: Run configuration. ``config['training']`` must provide
            ``batch_size``, ``learning_rate``, ``weight_decay``,
            ``gradient_clip`` and ``use_amp``.
        history: Mapping of metric name to a sequence of recorded values.
            Only the ``*_array`` keys read below are used; missing or empty
            series are skipped.
        recommendations: List of dicts with ``category``, ``priority``,
            ``issue`` and ``recommendation`` keys.
        output_dir: Directory that receives both report files. It is created
            on export if it does not exist yet.
    """

    JSON_FILENAME = 'training_analysis_report.json'
    HTML_FILENAME = 'training_analysis_report.html'

    def __init__(self, config, history, recommendations, output_dir='../reports'):
        self.config = config
        self.history = history
        self.recommendations = recommendations
        self.output_dir = Path(output_dir)

    def _series(self, key):
        """Return ``history[key]`` as a NumPy array, or ``None`` if missing or empty."""
        if key not in self.history:
            return None
        values = np.asarray(self.history[key])
        return values if values.size > 0 else None

    def _with_priority(self, priority):
        """Return the recommendations whose ``priority`` equals ``priority``."""
        return [r for r in self.recommendations if r['priority'] == priority]

    @staticmethod
    def _json_default(obj):
        """Convert NumPy scalars and arrays for ``json.dump``; reject anything else."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def export_report(self):
        """Build the report, write it as JSON, then write the HTML version.

        Raises:
            TypeError: If the report contains a value that is neither JSON
                serializable nor a NumPy scalar/array.
        """
        print("\nExporting Training Analysis Report...")

        # Create report data
        report = {
            'timestamp': pd.Timestamp.now().isoformat(),
            'config': self.config,
            'training_summary': self.generate_training_summary(),
            'key_metrics': self.calculate_key_metrics(),
            'recommendations': self.recommendations,
            'implementation_plan': self.generate_implementation_plan()
        }

        # Export as JSON
        import json
        # The target directory is not guaranteed to exist; create it first.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        json_path = self.output_dir / self.JSON_FILENAME
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2, default=self._json_default)

        print(f"Training analysis report exported to {json_path.as_posix()}")

        # Also export as HTML
        self.export_html_report(report)

    def generate_training_summary(self):
        """Return a plain-text summary of the configuration and key findings.

        Each finding is a simple threshold check on one history series. A
        missing or empty series always yields the negative wording for that
        line.

        Returns:
            str: Multi-line summary text (also embedded in the HTML report).
        """
        training = self.config['training']

        losses = self._series('losses_array')
        grad_norms = self._series('grad_norm_total_array')
        lrs = self._series('learning_rates_array')
        max_eigs = self._series('max_eigenvalue_array')

        # NOTE: this only checks that a loss history exists; it does not
        # measure convergence itself.
        converged = losses is not None
        healthy_gradients = grad_norms is not None and np.mean(grad_norms) < 10
        # "Proper" schedule: the final LR fell below 10% of the initial LR.
        lr_decayed = lrs is not None and lrs[-1] < lrs[0] * 0.1
        stable = max_eigs is not None and bool(np.all(max_eigs <= 1.0))
        ready = len(self._with_priority('HIGH')) == 0

        summary = f"""
        TRAINING ANALYSIS SUMMARY
        
        Configuration:
        - Batch size: {training['batch_size']}
        - Initial LR: {training['learning_rate']}
        - Weight decay: {training['weight_decay']}
        - Gradient clip: {training['gradient_clip']}
        - Mixed precision: {training['use_amp']}
        
        Key Findings:
        1. {'Stable convergence achieved' if converged else 'Convergence analysis pending'}
        2. {'Good gradient flow' if healthy_gradients else 'Gradient issues detected'}
        3. {'Proper LR schedule' if lr_decayed else 'LR schedule needs optimization'}
        4. {'Excellent stability' if stable else 'Stability issues detected'}
        
        Overall Assessment: Training pipeline is {'READY for production' if ready else 'NEEDS OPTIMIZATION'}
        """
        return summary

    def calculate_key_metrics(self):
        """Compute summary statistics for every non-empty history series.

        Returns:
            dict: May contain ``loss``, ``gradient``, ``learning_rate`` and
            ``stability`` entries. A section is omitted when its series is
            missing or empty. All values are plain Python floats/bools.
        """
        metrics = {}

        # Loss metrics
        losses = self._series('losses_array')
        if losses is not None:
            metrics['loss'] = {
                'initial': float(losses[0]),
                'final': float(losses[-1]),
                'min': float(np.min(losses)),
                # Relative reduction is undefined for a non-positive start.
                'reduction_percent': float((1 - losses[-1]/losses[0]) * 100) if losses[0] > 0 else 0
            }

        # Gradient metrics
        grad_norms = self._series('grad_norm_total_array')
        if grad_norms is not None:
            metrics['gradient'] = {
                'mean': float(np.mean(grad_norms)),
                'std': float(np.std(grad_norms)),
                'max': float(np.max(grad_norms)),
                'min': float(np.min(grad_norms))
            }

        # Learning rate metrics
        lrs = self._series('learning_rates_array')
        if lrs is not None:
            metrics['learning_rate'] = {
                'initial': float(lrs[0]),
                'final': float(lrs[-1]),
                'decay_ratio': float(lrs[-1] / lrs[0]) if lrs[0] > 0 else 0
            }

        # Stability metrics
        stability_metrics = {}
        for key in ['max_eigenvalue_array', 'min_eigenvalue_array', 'signal_ratio_mean_array']:
            values = self._series(key)
            if values is not None:
                stability_metrics[key.replace('_array', '')] = {
                    'mean': float(np.mean(values)),
                    'std': float(np.std(values)),
                    'in_range': bool(self.check_stability_range(key, values))
                }

        if stability_metrics:
            metrics['stability'] = stability_metrics

        return metrics

    def check_stability_range(self, metric_name, values):
        """Check whether every value of a stability metric is acceptable.

        Args:
            metric_name: Metric key; matched by substring.
            values: Sequence or array of recorded values.

        Returns:
            bool: ``max_eigenvalue`` requires all values <= 1.0,
            ``min_eigenvalue`` all values >= 0, and ``signal_ratio`` all
            values within 0.2 of 1.0. Any other name returns ``True``.
        """
        # Accept plain lists as well as arrays for the element-wise checks.
        values = np.asarray(values)
        if 'max_eigenvalue' in metric_name:
            return bool(np.all(values <= 1.0))
        elif 'min_eigenvalue' in metric_name:
            return bool(np.all(values >= 0))
        elif 'signal_ratio' in metric_name:
            return bool(np.all(np.abs(values - 1.0) < 0.2))
        return True

    def generate_implementation_plan(self):
        """Group the recommendations into three phases by priority.

        Returns:
            dict: ``phase_1_immediate`` (HIGH), ``phase_2_short_term``
            (MEDIUM) and ``phase_3_long_term`` (LOW), each with ``timeline``,
            ``recommendations`` and ``expected_impact``. Recommendations with
            any other priority value appear in no phase.
        """
        plan = {
            'phase_1_immediate': {
                'timeline': '1-2 days',
                'recommendations': self._with_priority('HIGH'),
                'expected_impact': 'Resolve critical issues, ensure training stability'
            },
            'phase_2_short_term': {
                'timeline': '3-7 days',
                'recommendations': self._with_priority('MEDIUM'),
                'expected_impact': 'Improve convergence speed and final performance'
            },
            'phase_3_long_term': {
                'timeline': '1-2 weeks',
                'recommendations': self._with_priority('LOW'),
                'expected_impact': 'Optimize for efficiency and deployment'
            }
        }

        return plan

    def export_html_report(self, report):
        """Render ``report`` as a standalone HTML page and write it to disk.

        Args:
            report: Dict with ``timestamp``, ``training_summary``,
                ``key_metrics``, ``recommendations`` and
                ``implementation_plan`` entries, as built by
                ``export_report``.
        """
        # Local import: this stdlib module is not imported at file level.
        from html import escape

        def text(value):
            # Free-form text may contain '<' or '&' (e.g. "norm < 1e-7"),
            # which would otherwise break the generated markup.
            return escape(str(value))

        # Fragments are collected and joined once instead of repeated '+='.
        parts = [f"""
        <!DOCTYPE html>
        <html>
        <head>
            <meta charset="utf-8">
            <title>Humanoid Vision System - Training Analysis Report</title>
            <style>
                body {{ font-family: Arial, sans-serif; margin: 40px; line-height: 1.6; }}
                h1 {{ color: #2c3e50; border-bottom: 3px solid #3498db; }}
                h2 {{ color: #34495e; margin-top: 30px; }}
                h3 {{ color: #2c3e50; margin-top: 20px; }}
                .card {{ background: #f8f9fa; border-left: 4px solid #3498db; 
                        padding: 20px; margin: 20px 0; border-radius: 5px; }}
                .metric-card {{ display: inline-block; background: white; padding: 15px; 
                         margin: 10px; border-radius: 5px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); 
                         width: 200px; vertical-align: top; }}
                .high {{ color: #e74c3c; font-weight: bold; border-left: 4px solid #e74c3c; }}
                .medium {{ color: #f39c12; font-weight: bold; border-left: 4px solid #f39c12; }}
                .low {{ color: #27ae60; font-weight: bold; border-left: 4px solid #27ae60; }}
                table {{ width: 100%; border-collapse: collapse; margin: 20px 0; }}
                th, td {{ padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }}
                th {{ background-color: #3498db; color: white; }}
                .phase {{ margin: 20px 0; padding: 15px; border-radius: 5px; }}
                .phase-1 {{ background: #ffeaa7; }}
                .phase-2 {{ background: #a29bfe; }}
                .phase-3 {{ background: #55efc4; }}
            </style>
        </head>
        <body>
            <h1>Humanoid Vision System - Training Analysis Report</h1>
            <p>Generated on: {text(report['timestamp'])}</p>
            
            <div class="card">
                <h2>Executive Summary</h2>
                <pre>{text(report['training_summary'])}</pre>
            </div>
            
            <h2>Key Metrics</h2>
            <div>
        """]

        # Add metric cards (only for the sections that were computed)
        key_metrics = report['key_metrics']
        if 'loss' in key_metrics:
            loss = key_metrics['loss']
            parts.append(f"""
                <div class="metric-card">
                    <h3>Loss</h3>
                    <p>Initial: {loss['initial']:.4f}</p>
                    <p>Final: {loss['final']:.4f}</p>
                    <p>Reduction: {loss['reduction_percent']:.1f}%</p>
                </div>
            """)

        if 'gradient' in key_metrics:
            grad = key_metrics['gradient']
            parts.append(f"""
                <div class="metric-card">
                    <h3>Gradient</h3>
                    <p>Mean norm: {grad['mean']:.4f}</p>
                    <p>Std: {grad['std']:.4f}</p>
                    <p>Max: {grad['max']:.4f}</p>
                </div>
            """)

        if 'learning_rate' in key_metrics:
            lr = key_metrics['learning_rate']
            parts.append(f"""
                <div class="metric-card">
                    <h3>Learning Rate</h3>
                    <p>Initial: {lr['initial']:.1e}</p>
                    <p>Final: {lr['final']:.1e}</p>
                    <p>Decay: {(1-lr['decay_ratio'])*100:.1f}%</p>
                </div>
            """)

        parts.append("""
            </div>
            
            <h2>Optimization Recommendations</h2>
            <table>
                <tr>
                    <th>Category</th>
                    <th>Priority</th>
                    <th>Issue</th>
                    <th>Recommendation</th>
                </tr>
        """)

        for rec in report['recommendations']:
            # The lower-cased priority doubles as the CSS class of the row.
            priority_class = text(rec['priority']).lower()
            parts.append(f"""
                <tr class="{priority_class}">
                    <td>{text(rec['category'])}</td>
                    <td>{text(rec['priority'])}</td>
                    <td>{text(rec['issue'])}</td>
                    <td>{text(rec['recommendation'])}</td>
                </tr>
            """)

        parts.append("""
            </table>
            
            <h2>Implementation Plan</h2>
        """)

        plan = report['implementation_plan']
        phases = ['phase_1_immediate', 'phase_2_short_term', 'phase_3_long_term']
        phase_titles = ['Phase 1: Immediate (1-2 days)', 'Phase 2: Short-term (3-7 days)', 'Phase 3: Long-term (1-2 weeks)']
        phase_classes = ['phase-1', 'phase-2', 'phase-3']

        for phase_key, title, phase_class in zip(phases, phase_titles, phase_classes):
            if phase_key not in plan:
                continue
            phase_data = plan[phase_key]

            parts.append(f"""
                <div class="phase {phase_class}">
                    <h3>{title}</h3>
                    <p><strong>Expected Impact:</strong> {text(phase_data['expected_impact'])}</p>
                    <ul>
                """)

            parts.extend(
                f"<li><strong>{text(rec['category'])}:</strong> {text(rec['recommendation'])}</li>"
                for rec in phase_data['recommendations']
            )

            parts.append("""
                    </ul>
                </div>
                """)

        parts.append("""
            <div class="card">
                <h2>Next Steps</h2>
                <ol>
                    <li>Review and prioritize recommendations</li>
                    <li>Implement Phase 1 recommendations immediately</li>
                    <li>Monitor training after each optimization</li>
                    <li>Proceed to inference testing and deployment</li>
                    <li>Schedule follow-up analysis in 2 weeks</li>
                </ol>
            </div>
        </body>
        </html>
        """)

        # This method is public, so it must not rely on export_report having
        # created the directory already.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        html_path = self.output_dir / self.HTML_FILENAME
        with open(html_path, 'w', encoding='utf-8') as f:
            f.write(''.join(parts))

        print(f"HTML report exported to {html_path.as_posix()}")

# %%
# Export reports: build the exporter from the run config, the training history
# and the recommendations generated above, then write the JSON report.
# export_report() also triggers the HTML export; both files go under ../reports/.
training_exporter = TrainingAnalysisExporter(config, history, recommendations)
training_exporter.export_report()

# %% [markdown]
"""
## 9. Conclusion and Next Steps
"""

# %%
# Final console recap of the analysis. The emoji and bullet glyphs below were
# stored as mis-decoded UTF-8 and are restored to the intended characters.
print("\n" + "="*70)
print("TRAINING ANALYSIS - COMPLETED")
print("="*70)

print("\n✅ ANALYSIS COMPLETED:")
print("  1. Loss convergence analyzed")
print("  2. Gradient flow assessed")
print("  3. Learning rate schedule evaluated")
print("  4. Training stability verified")
print("  5. Optimization recommendations generated")
print("  6. Comprehensive reports exported")

print("\n📊 KEY FINDINGS:")
# Each finding is printed only when its series is present and non-empty, so an
# empty history cannot produce NaN means or a vacuous "stable" verdict.
if 'losses_array' in history and len(history['losses_array']) > 0:
    initial_loss = history['losses_array'][0]
    # Relative reduction is undefined for a non-positive starting loss;
    # report 0 in that case instead of dividing by zero.
    loss_reduction = (1 - history['losses_array'][-1]/initial_loss) * 100 if initial_loss > 0 else 0
    print(f"  • Loss reduction: {loss_reduction:.1f}%")
if 'grad_norm_total_array' in history and len(history['grad_norm_total_array']) > 0:
    print(f"  • Mean gradient norm: {np.mean(history['grad_norm_total_array']):.2f}")
if 'max_eigenvalue_array' in history and len(history['max_eigenvalue_array']) > 0:
    # np.asarray lets the element-wise comparison work for plain lists too.
    stable = np.all(np.asarray(history['max_eigenvalue_array']) <= 1.0)
    print(f"  • Stability: {'✅ EXCELLENT' if stable else '⚠️ NEEDS ATTENTION'}")

high_count = sum(1 for r in recommendations if r['priority'] == 'HIGH')
print(f"\n🚨 CRITICAL ISSUES: {high_count} HIGH priority recommendations")

print("\n🚀 NEXT STEPS:")
print("  1. Implement HIGH priority recommendations immediately")
print("  2. Run validation with real data")
print("  3. Proceed to inference analysis (04_inference_demo.ipynb)")
print("  4. Prepare for deployment testing")
print("  5. Schedule model retraining with optimizations")

print("\n" + "="*70)