In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple, Optional
from torch_geometric.data import Batch
from scipy.spatial.distance import directed_hausdorff
from dtaidistance import dtw
import seaborn as sns

In [None]:
def load_trained_model(model_path: str, model_class, **model_kwargs):
    """Build a model and restore its weights from a checkpoint file.

    Args:
        model_path: Path to a file readable by ``torch.load`` that holds a
            state dict accepted by the model's ``load_state_dict``.
        model_class: Callable returning the model when invoked with
            ``model_kwargs``.
        **model_kwargs: Keyword arguments forwarded to ``model_class``.

    Returns:
        The constructed model with the checkpoint weights loaded and
        switched to evaluation mode.
    """
    network = model_class(**model_kwargs)
    # Tensors are remapped to the CPU while loading, whatever device they
    # were saved from.
    checkpoint_weights = torch.load(model_path, map_location='cpu')
    network.load_state_dict(checkpoint_weights)
    network.eval()
    return network

def predict_trajectories(model, test_loader, device='cpu', max_batches=None):
    """Run the model over a data loader and collect one-step predictions.

    Args:
        model: Callable module invoked as ``model(graphs, time_span)``; its
            result must be indexable as ``result['trajectories'][1]``.
        test_loader: Iterable of batches exposing ``graphs`` (with ``.to``
            and a ``.batch`` index tensor) and ``next_positions``.
        device: Device the batch tensors and the time span are moved to.
        max_batches: Maximum number of batches to process. ``None`` means
            no limit; ``0`` processes no batches.

    Returns:
        A tuple ``(predictions, targets, batch_info)`` of equal-length
        lists, one entry per processed batch:
        predictions are the model outputs at index 1 as NumPy arrays,
        targets are ``next_positions`` reshaped to ``(-1, 2)`` as NumPy
        arrays, and batch_info holds ``batch_size`` and ``nodes_per_graph``
        derived from ``graphs.batch``.
    """
    model.eval()
    all_predictions = []
    all_targets = []
    all_batch_info = []

    # The integration interval is the same for every batch, so build it once.
    time_span = torch.tensor([0., 1.], device=device)

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            # Explicit None check: a plain truthiness test would treat
            # max_batches=0 as "unlimited".
            if max_batches is not None and batch_idx >= max_batches:
                break

            # Work on local handles instead of reassigning attributes on the
            # loader's batch object, so the caller's data is left untouched.
            graphs = batch.graphs.to(device)
            next_positions = batch.next_positions.to(device)

            # Forward pass; index 1 is read as the prediction at t=1.
            result = model(graphs, time_span)
            pred_next_positions = result['trajectories'][1]

            all_predictions.append(pred_next_positions.cpu().numpy())
            all_targets.append(next_positions.view(-1, 2).cpu().numpy())
            # NOTE(review): graphs.batch is presumably the node-to-graph
            # assignment vector, making these the graph count and the
            # per-graph node counts — confirm against the dataset class.
            all_batch_info.append({
                'batch_size': graphs.batch.max().item() + 1,
                'nodes_per_graph': torch.bincount(graphs.batch).cpu().numpy()
            })

    return all_predictions, all_targets, all_batch_info

def calculate_position_error_metrics(predictions: List[np.ndarray], targets: List[np.ndarray]) -> Dict:
    """Summarise per-row distance errors and row-to-row direction errors.

    Args:
        predictions: Per-batch arrays of predicted positions.
        targets: Per-batch arrays of target positions, aligned row by row
            with ``predictions``.

    Returns:
        Dict with the mean, standard deviation, median and maximum of the
        Euclidean error per row, the raw ``position_errors`` array, the list
        of ``direction_errors`` in degrees, and their mean (``0`` when no
        direction error could be computed).

    NOTE(review): the "direction" vectors are differences between
    consecutive rows of the same batch array. If rows are distinct nodes
    rather than successive time steps, these are not movement vectors —
    confirm the intended semantics.
    """
    stacked_pred = np.vstack(predictions)
    stacked_true = np.vstack(targets)

    # One Euclidean distance per stacked row.
    distances = np.linalg.norm(stacked_pred - stacked_true, axis=1)

    angles_deg = []
    for batch_pred, batch_true in zip(predictions, targets):
        # A single row has no neighbour to difference against.
        if len(batch_pred) <= 1:
            continue

        pred_steps = batch_pred[1:] - batch_pred[:-1]
        true_steps = batch_true[1:] - batch_true[:-1]

        for pred_vec, true_vec in zip(pred_steps, true_steps):
            pred_len = np.linalg.norm(pred_vec)
            true_len = np.linalg.norm(true_vec)
            # The angle is undefined when either vector has zero length.
            if not (pred_len > 0 and true_len > 0):
                continue
            cosine = np.dot(pred_vec, true_vec) / (pred_len * true_len)
            # Rounding can push the cosine slightly outside [-1, 1].
            cosine = np.clip(cosine, -1, 1)
            angles_deg.append(np.arccos(cosine) * 180 / np.pi)

    summary = {
        'mean_position_error': np.mean(distances),
        'std_position_error': np.std(distances),
        'median_position_error': np.median(distances),
        'max_position_error': np.max(distances),
        'position_errors': distances,
    }
    summary['mean_direction_error'] = np.mean(angles_deg) if angles_deg else 0
    summary['direction_errors'] = angles_deg
    return summary

def calculate_success_rates(predictions: List[np.ndarray], targets: List[np.ndarray],
                          tolerance_levels: Optional[List[float]] = None) -> Dict:
    """Compute the fraction of rows predicted within each distance tolerance.

    Args:
        predictions: Per-batch arrays of predicted positions.
        targets: Per-batch arrays of target positions, aligned row by row
            with ``predictions``.
        tolerance_levels: Distance thresholds to evaluate. Defaults to
            ``[0.5, 1.0, 1.5, 2.0]`` when ``None``.

    Returns:
        Dict mapping ``'success_rate_<tolerance>'`` to the fraction of rows
        whose Euclidean error is less than or equal to that tolerance, in
        the order the tolerances were given.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if tolerance_levels is None:
        tolerance_levels = [0.5, 1.0, 1.5, 2.0]

    all_pred = np.vstack(predictions)
    all_target = np.vstack(targets)
    position_errors = np.linalg.norm(all_pred - all_target, axis=1)

    return {
        f'success_rate_{tolerance}': np.mean(position_errors <= tolerance)
        for tolerance in tolerance_levels
    }

def multi_step_prediction_accuracy(model, test_loader, num_steps: int = 10, device='cpu',
                                   max_batches: Optional[int] = 10):
    """Scaffold for multi-step (auto-regressive) prediction evaluation.

    NOTE: this is an incomplete placeholder. The graph is not updated
    between steps, so every step re-runs the model on the same input, and
    no ground truth is compared, so the returned error lists stay empty.

    Args:
        model: Callable module invoked as ``model(graphs, time_span)``; its
            result must be indexable as ``result['trajectories'][1]``.
        test_loader: Iterable of batches exposing ``graphs`` with ``.to``.
        num_steps: Number of model invocations per batch.
        device: Device the graphs and the time span are moved to.
        max_batches: Maximum number of batches to process (default 10, the
            previously hard-coded cap). ``None`` means no limit.

    Returns:
        Dict mapping each step index ``1..num_steps`` to a list of errors;
        every list is currently empty.
    """
    model.eval()
    step_errors = {i: [] for i in range(1, num_steps + 1)}

    # The integration interval is the same for every call, so build it once.
    time_span = torch.tensor([0., 1.], device=device)

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            if max_batches is not None and batch_idx >= max_batches:
                break

            current_graphs = batch.graphs.to(device)

            # Collected for the planned comparison with ground truth; not
            # read anywhere yet.
            predictions = []
            for _ in range(num_steps):
                result = model(current_graphs, time_span)
                pred_positions = result['trajectories'][1]
                predictions.append(pred_positions.cpu().numpy())

                # TODO: feed pred_positions back into current_graphs so the
                # next step starts from the predicted state.

            # TODO: compare `predictions` with multi-step ground truth and
            # append the per-step errors to step_errors.

    return step_errors

def _count_close_pairs(positions: np.ndarray, threshold: float):
    """Count unordered pairs of distinct rows closer than ``threshold``."""
    pairwise = np.linalg.norm(positions[:, None] - positions[None, :], axis=2)
    # The strict upper triangle visits each pair of distinct rows exactly
    # once and skips the diagonal, so rows at identical positions
    # (distance 0) are still counted.
    upper = np.triu_indices(len(positions), k=1)
    return np.sum(pairwise[upper] < threshold)


def analyze_collision_prediction(predictions: List[np.ndarray], targets: List[np.ndarray], 
                               batch_info: List[Dict], collision_threshold: float = 1.5) -> Dict:
    """Compare per-graph collision counts in predictions against targets.

    A collision is a pair of distinct rows within one graph whose Euclidean
    distance is below ``collision_threshold``. This includes rows at exactly
    the same position.

    Args:
        predictions: Per-batch arrays of predicted positions.
        targets: Per-batch arrays of target positions, aligned row by row
            with ``predictions``.
        batch_info: Per-batch dicts whose ``'nodes_per_graph'`` entry lists
            how many consecutive rows belong to each graph.
        collision_threshold: Distance below which two rows collide.

    Returns:
        Dict with the per-graph ``predicted_collisions`` and
        ``actual_collisions`` lists (graphs with one row or fewer are
        skipped) plus the MSE and MAE between them. Both errors are NaN
        when no graph was counted.
    """
    pred_collisions = []
    actual_collisions = []

    for pred, target, info in zip(predictions, targets, batch_info):
        start_idx = 0
        for graph_nodes in info['nodes_per_graph']:
            end_idx = start_idx + graph_nodes
            # A graph needs at least two rows to contain a pair.
            if graph_nodes > 1:
                pred_collisions.append(
                    _count_close_pairs(pred[start_idx:end_idx], collision_threshold))
                actual_collisions.append(
                    _count_close_pairs(target[start_idx:end_idx], collision_threshold))
            start_idx = end_idx

    count_diff = np.array(pred_collisions) - np.array(actual_collisions)
    return {
        'predicted_collisions': pred_collisions,
        'actual_collisions': actual_collisions,
        'collision_prediction_mse': np.mean(count_diff ** 2),
        'collision_prediction_mae': np.mean(np.abs(count_diff))
    }

def plot_prediction_accuracy_analysis(error_metrics: Dict, success_rates: Dict, 
                                    collision_analysis: Dict, save_path: Optional[str] = None):
    """Draw a 2x3 dashboard of prediction-accuracy diagnostics.

    Panels: position-error histogram, success rate per tolerance,
    direction-error histogram, position-error statistics, predicted vs
    actual collision scatter, and a text summary.

    Args:
        error_metrics: Dict with ``position_errors``, ``direction_errors``
            and the ``mean/median/std/max_position_error`` and
            ``mean_direction_error`` scalars.
        success_rates: Dict keyed ``'success_rate_<tolerance>'``; the text
            after the last underscore is parsed as the tolerance.
        collision_analysis: Dict with ``predicted_collisions``,
            ``actual_collisions``, ``collision_prediction_mae`` and
            ``collision_prediction_mse``.
        save_path: If truthy, the figure is also written to this path.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    ax_errors, ax_rates, ax_direction = axes[0]
    ax_stats, ax_collisions, ax_summary = axes[1]

    # 1. Position error distribution
    ax_errors.hist(error_metrics['position_errors'], bins=50, alpha=0.7, color='skyblue', edgecolor='black')
    ax_errors.axvline(error_metrics['mean_position_error'], color='red', linestyle='--',
                      label=f'Mean: {error_metrics["mean_position_error"]:.2f}')
    ax_errors.axvline(error_metrics['median_position_error'], color='orange', linestyle='--',
                      label=f'Median: {error_metrics["median_position_error"]:.2f}')
    ax_errors.set_title('Position Error Distribution')
    ax_errors.set_xlabel('Position Error (grid units)')
    ax_errors.set_ylabel('Frequency')
    ax_errors.legend()
    ax_errors.grid(True, alpha=0.3)

    # 2. Success rates at different tolerances
    tolerances = [float(k.split('_')[-1]) for k in success_rates]
    rates = list(success_rates.values())

    # Bars sit at evenly spaced integer slots labelled with the tolerance.
    # Placing them at the tolerance values themselves makes the default
    # 0.8-wide bars overlap whenever tolerances are less than 0.8 apart.
    bar_positions = list(range(len(rates)))
    ax_rates.bar(bar_positions, rates, alpha=0.7, color='lightgreen', edgecolor='black')
    ax_rates.set_xticks(bar_positions)
    ax_rates.set_xticklabels([str(tolerance) for tolerance in tolerances])
    ax_rates.set_title('Prediction Success Rate by Tolerance')
    ax_rates.set_xlabel('Tolerance (grid units)')
    ax_rates.set_ylabel('Success Rate')
    ax_rates.set_ylim(0, 1)
    ax_rates.grid(True, alpha=0.3)

    # Add percentage labels on bars
    for position, rate in zip(bar_positions, rates):
        ax_rates.text(position, rate + 0.01, f'{rate:.1%}', ha='center', va='bottom')

    # 3. Direction error distribution (if available)
    if error_metrics['direction_errors']:
        ax_direction.hist(error_metrics['direction_errors'], bins=30, alpha=0.7,
                          color='lightcoral', edgecolor='black')
        ax_direction.axvline(error_metrics['mean_direction_error'], color='red', linestyle='--',
                             label=f'Mean: {error_metrics["mean_direction_error"]:.1f}°')
        ax_direction.set_title('Direction Error Distribution')
        ax_direction.set_xlabel('Direction Error (degrees)')
        ax_direction.set_ylabel('Frequency')
        ax_direction.legend()
        ax_direction.grid(True, alpha=0.3)
    else:
        ax_direction.text(0.5, 0.5, 'Direction Error\nNot Available', ha='center', va='center',
                          transform=ax_direction.transAxes, fontsize=12)
        ax_direction.set_title('Direction Error Distribution')

    # 4. Error statistics summary
    error_stats = [
        error_metrics['mean_position_error'],
        error_metrics['median_position_error'],
        error_metrics['std_position_error'],
        error_metrics['max_position_error'],
    ]
    stat_labels = ['Mean', 'Median', 'Std Dev', 'Max']

    bars = ax_stats.bar(stat_labels, error_stats, alpha=0.7, color='gold', edgecolor='black')
    ax_stats.set_title('Position Error Statistics')
    ax_stats.set_ylabel('Error (grid units)')
    ax_stats.grid(True, alpha=0.3)

    # Add value labels on bars
    for bar, stat in zip(bars, error_stats):
        ax_stats.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                      f'{stat:.2f}', ha='center', va='bottom')

    # 5. Collision prediction accuracy
    pred_collisions = collision_analysis['predicted_collisions']
    actual_collisions = collision_analysis['actual_collisions']

    ax_collisions.scatter(actual_collisions, pred_collisions, alpha=0.6, color='purple')
    # The diagonal reference line spans the largest observed count; empty
    # lists fall back to 0.
    max_collisions = max(max(pred_collisions) if pred_collisions else 0,
                         max(actual_collisions) if actual_collisions else 0)
    ax_collisions.plot([0, max_collisions], [0, max_collisions], 'r--', label='Perfect Prediction')
    ax_collisions.set_title('Collision Prediction Accuracy')
    ax_collisions.set_xlabel('Actual Collisions')
    ax_collisions.set_ylabel('Predicted Collisions')
    ax_collisions.legend()
    ax_collisions.grid(True, alpha=0.3)

    # 6. Performance summary
    summary_text = f"""Model Performance Summary
    
Position Accuracy:
• Mean Error: {error_metrics['mean_position_error']:.2f} units
• Success Rate (≤1.0): {success_rates.get('success_rate_1.0', 0):.1%}
• Success Rate (≤1.5): {success_rates.get('success_rate_1.5', 0):.1%}

Direction Accuracy:
• Mean Direction Error: {error_metrics['mean_direction_error']:.1f}°

Collision Prediction:
• MAE: {collision_analysis['collision_prediction_mae']:.2f}
• MSE: {collision_analysis['collision_prediction_mse']:.2f}
"""

    ax_summary.text(0.05, 0.95, summary_text, transform=ax_summary.transAxes,
                    fontsize=10, verticalalignment='top', fontfamily='monospace',
                    bbox=dict(boxstyle='round', facecolor='lightgray', alpha=0.8))
    ax_summary.set_xlim(0, 1)
    ax_summary.set_ylim(0, 1)
    ax_summary.axis('off')
    ax_summary.set_title('Performance Summary')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def evaluate_model_comprehensive(model, test_loader, device='cpu', max_batches=20, save_dir='./evaluation_results/'):
    """Run the full evaluation pipeline, plot the results and print a summary.

    Steps: generate predictions, compute position/direction error metrics,
    success rates and collision statistics, draw the accuracy dashboard
    (saved as ``prediction_accuracy_analysis.png`` inside ``save_dir``) and
    print a text summary.

    Args:
        model: Model passed to ``predict_trajectories``.
        test_loader: Data loader passed to ``predict_trajectories``.
        device: Device passed to ``predict_trajectories``.
        max_batches: Batch limit passed to ``predict_trajectories``.
        save_dir: Directory for the saved figure; created if missing.

    Returns:
        Dict with the ``error_metrics``, ``success_rates`` and
        ``collision_analysis`` result dicts.
    """
    import os
    os.makedirs(save_dir, exist_ok=True)

    print("Generating predictions...")
    predictions, targets, batch_info = predict_trajectories(model, test_loader, device, max_batches)

    print("Calculating error metrics...")
    error_metrics = calculate_position_error_metrics(predictions, targets)

    print("Calculating success rates...")
    success_rates = calculate_success_rates(predictions, targets)

    print("Analyzing collision prediction...")
    collision_analysis = analyze_collision_prediction(predictions, targets, batch_info)

    print("Creating visualizations...")
    # os.path.join avoids the doubled separator that string interpolation
    # produces when save_dir already ends with a slash (as the default does).
    figure_path = os.path.join(save_dir, "prediction_accuracy_analysis.png")
    plot_prediction_accuracy_analysis(error_metrics, success_rates, collision_analysis, figure_path)

    # Print summary
    separator = "="*50
    print("\n" + separator)
    print("MODEL EVALUATION SUMMARY")
    print(separator)
    print(f"Mean Position Error: {error_metrics['mean_position_error']:.3f} grid units")
    print(f"Median Position Error: {error_metrics['median_position_error']:.3f} grid units")
    print(f"Success Rate (≤1.0 units): {success_rates.get('success_rate_1.0', 0):.1%}")
    print(f"Success Rate (≤1.5 units): {success_rates.get('success_rate_1.5', 0):.1%}")
    print(f"Mean Direction Error: {error_metrics['mean_direction_error']:.1f} degrees")
    print(f"Collision Prediction MAE: {collision_analysis['collision_prediction_mae']:.3f}")
    print(separator)

    return {
        'error_metrics': error_metrics,
        'success_rates': success_rates,
        'collision_analysis': collision_analysis
    }