In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional
import h5py
from collections import defaultdict
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics import pairwise_distances
import pandas as pd

ModuleNotFoundError: No module named 'seaborn'

In [None]:
def calculate_spatial_density_heatmap(trajectories: List[np.ndarray], grid_size: Tuple[int, int],
                                    cell_size: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute a time-averaged occupancy grid from agent positions.

    Args:
        trajectories: One array-like per episode, indexed as
            ``[step][agent] -> (x, y, ...)``. Only the first two components of
            each position are used. Each trajectory must be rectangular (same
            number of agents at every step) so it can be converted to an array.
        grid_size: ``(height, width)``; ``grid_size[0]`` spans the y axis and
            ``grid_size[1]`` spans the x axis.
        cell_size: Edge length of one square bin.

    Returns:
        ``(density_grid, X, Y)`` where ``density_grid[row, col]`` is the number
        of positions that fell into the bin (row = y bin, col = x bin) divided
        by the total number of timesteps over all trajectories, and ``X``/``Y``
        are meshgrids holding the bin-centre coordinates.

    Raises:
        ValueError: If ``cell_size`` is not positive.
    """
    if cell_size <= 0:
        raise ValueError(f"cell_size must be positive, got {cell_size}")

    # Bin edges; the last edge is an exclusive upper bound (see mask below).
    x_bins = np.arange(0, grid_size[1] + cell_size, cell_size)
    y_bins = np.arange(0, grid_size[0] + cell_size, cell_size)
    n_rows, n_cols = len(y_bins) - 1, len(x_bins) - 1

    density_grid = np.zeros((n_rows, n_cols))

    total_timesteps = 0
    for trajectory in trajectories:
        total_timesteps += len(trajectory)

        points = np.asarray(trajectory, dtype=float)
        if points.size == 0:
            continue
        # Flatten [steps, agents, coords] -> [steps * agents, coords] so that
        # every position of the episode is binned in one vectorized call.
        points = points.reshape(-1, points.shape[-1])

        x_idx = np.digitize(points[:, 0], x_bins) - 1
        y_idx = np.digitize(points[:, 1], y_bins) - 1

        # Drop positions below 0 or on/after the last edge (NaN lands here too).
        inside = (x_idx >= 0) & (x_idx < n_cols) & (y_idx >= 0) & (y_idx < n_rows)

        # np.add.at accumulates correctly when the same cell appears repeatedly.
        np.add.at(density_grid, (y_idx[inside], x_idx[inside]), 1)

    # Normalize by total timesteps so cells hold "agents per timestep".
    if total_timesteps > 0:
        density_grid = density_grid / total_timesteps

    # Bin-centre coordinates for plotting.
    X, Y = np.meshgrid(x_bins[:-1] + cell_size / 2, y_bins[:-1] + cell_size / 2)

    return density_grid, X, Y

def calculate_temporal_density_evolution(trajectories: List[np.ndarray], grid_size: Tuple[int, int],
                                       time_windows: int = 10, cell_size: float = 2.0) -> List[np.ndarray]:
    """Compute one density grid per consecutive time window of the episodes.

    The longest episode length is split into ``time_windows`` contiguous
    windows that together cover every timestep. For each window, the matching
    slice of every trajectory that is long enough is passed to
    ``calculate_spatial_density_heatmap``.

    Args:
        trajectories: One array per episode, indexed ``[step, agent, coord]``.
        grid_size: ``(height, width)`` forwarded to the heatmap computation.
        time_windows: Number of windows to split the episode length into.
        cell_size: Bin edge length forwarded to the heatmap computation.

    Returns:
        A list of density grids in temporal order. It holds fewer than
        ``time_windows`` entries when the longest episode has fewer timesteps
        than windows, and is empty when ``trajectories`` is empty.

    Raises:
        ValueError: If ``time_windows`` is smaller than 1.
    """
    if not trajectories:
        return []
    if time_windows < 1:
        raise ValueError(f"time_windows must be >= 1, got {time_windows}")

    max_episode_length = max(len(traj) for traj in trajectories)

    # Integer window edges 0 = e0 <= e1 <= ... <= eN = max_episode_length.
    # Unlike a fixed `length // windows` stride, this never drops the trailing
    # timesteps and still yields windows when length < time_windows.
    window_edges = [
        (idx * max_episode_length) // time_windows for idx in range(time_windows + 1)
    ]

    density_evolution = []
    for start_time, end_time in zip(window_edges[:-1], window_edges[1:]):
        if end_time <= start_time:
            continue  # zero-width window (more windows than timesteps)

        # Slicing past the end of a shorter episode simply yields fewer steps.
        windowed_trajectories = [
            trajectory[start_time:end_time]
            for trajectory in trajectories
            if len(trajectory) > start_time
        ]

        if windowed_trajectories:
            density_grid, _, _ = calculate_spatial_density_heatmap(
                windowed_trajectories, grid_size, cell_size
            )
            density_evolution.append(density_grid)

    return density_evolution

def plot_spatial_density_heatmaps(datasets_analysis: List[Dict], save_path: Optional[str] = None):
    """Plot spatial density heatmaps for multiple datasets.

    One column is drawn per dataset: a filled density heatmap on top and a
    filled map with black contour lines plus a statistics box below.

    Args:
        datasets_analysis: One dict per dataset. Each dict must provide the keys
            ``'density_heatmap'``, ``'heatmap_coords'`` (an ``(X, Y)`` pair),
            ``'dataset_name'``, ``'max_local_density'``,
            ``'density_concentration'`` and ``'effective_density'``.
        save_path: If given, the figure is also written to this path.
    """
    num_datasets = len(datasets_analysis)
    if num_datasets == 0:
        # plt.subplots cannot create a grid with zero columns.
        print("No dataset analyses available")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single dataset.
    fig, axes = plt.subplots(2, num_datasets, figsize=(5*num_datasets, 10), squeeze=False)

    for i, analysis in enumerate(datasets_analysis):
        density_heatmap = analysis['density_heatmap']
        X, Y = analysis['heatmap_coords']
        dataset_name = analysis['dataset_name']
        top_ax, bottom_ax = axes[0, i], axes[1, i]

        # Regular heatmap
        filled = top_ax.contourf(X, Y, density_heatmap, levels=20, cmap='YlOrRd')
        top_ax.set_title(f'{dataset_name}\nDensity Heatmap')
        top_ax.set_xlabel('X Position')
        top_ax.set_ylabel('Y Position')
        fig.colorbar(filled, ax=top_ax, label='Density')

        # Contour lines combined with a semi-transparent filled map
        bottom_ax.contour(X, Y, density_heatmap, levels=10, colors='black', alpha=0.6)
        bottom_ax.contourf(X, Y, density_heatmap, levels=20, cmap='viridis', alpha=0.8)
        bottom_ax.set_title(f'{dataset_name}\nSmooth Density Contours')
        bottom_ax.set_xlabel('X Position')
        bottom_ax.set_ylabel('Y Position')

        # Add density statistics as text in the upper-left corner of the axes
        stats_text = f'Max: {analysis["max_local_density"]:.3f}\n'
        stats_text += f'Concentration: {analysis["density_concentration"]:.3f}\n'
        stats_text += f'Effective: {analysis["effective_density"]:.3f}'
        bottom_ax.text(0.02, 0.98, stats_text, transform=bottom_ax.transAxes,
                       verticalalignment='top', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def plot_temporal_density_evolution(h5_file_path: str, dataset_name: str,
                                  time_windows: int = 6, save_path: Optional[str] = None):
    """Plot how density evolves over time within episodes.

    Loads up to 50 episodes from ``h5_file_path``, computes one density grid
    per time window and shows the first six windows in a 2x3 grid of images.

    Args:
        h5_file_path: Path of the HDF5 dataset to read.
        dataset_name: Name used in the figure title.
        time_windows: Number of windows to compute; only the first six are drawn.
        save_path: If given, the figure is also written to this path.
    """
    # Bin size used for the grids; needed again below to label axes in
    # position units rather than in bin indices.
    cell_size = 2.0

    trajectories, metadata = load_episode_trajectories(h5_file_path, max_episodes=50)
    density_evolution = calculate_temporal_density_evolution(
        trajectories, metadata['grid_size'], time_windows, cell_size
    )

    if not density_evolution:
        print("No density evolution data available")
        return

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    axes = axes.flatten()

    shown_grids = density_evolution[:len(axes)]  # Show first 6 time windows
    for i, density_grid in enumerate(shown_grids):
        n_rows, n_cols = density_grid.shape
        # Without `extent`, imshow labels the axes with bin indices, which are
        # off by a factor of cell_size from the positions named in the labels.
        im = axes[i].imshow(density_grid, cmap='YlOrRd', origin='lower',
                            extent=(0, n_cols * cell_size, 0, n_rows * cell_size))
        axes[i].set_title(f'Time Window {i+1}')
        axes[i].set_xlabel('X Position')
        axes[i].set_ylabel('Y Position')
        plt.colorbar(im, ax=axes[i], shrink=0.8)

    # Hide unused subplots
    for unused_ax in axes[len(shown_grids):]:
        unused_ax.set_visible(False)

    plt.suptitle(f'{dataset_name}: Density Evolution Over Time', fontsize=16)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def load_dataset_metadata(h5_file_path: str) -> Dict:
    """Load basic metadata from an HDF5 dataset file.

    The attributes are read from the ``metadata`` group of the first top-level
    group whose name starts with ``'episode_'``.

    Args:
        h5_file_path: Path of the HDF5 file to open read-only.

    Returns:
        Dict with ``'num_agvs'``, ``'num_pickers'`` and ``'grid_size'`` (as
        stored in the file attributes) and ``'num_episodes'`` (the number of
        ``episode_*`` groups).

    Raises:
        ValueError: If the file contains no ``episode_*`` group.
    """
    with h5py.File(h5_file_path, 'r') as f:
        episode_keys = [key for key in f.keys() if key.startswith('episode_')]
        if not episode_keys:
            # Fail with a descriptive error instead of a bare StopIteration.
            raise ValueError(f"No 'episode_*' groups found in {h5_file_path}")

        metadata_attrs = f[episode_keys[0]]['metadata'].attrs

        return {
            'num_agvs': metadata_attrs['num_agvs'],
            'num_pickers': metadata_attrs['num_pickers'],
            'grid_size': metadata_attrs['grid_size'],
            'num_episodes': len(episode_keys)
        }

def extract_positions_from_observations(observations: np.ndarray, num_agvs: int, num_pickers: int) -> np.ndarray:
    """Build an array of ``[x, y]`` positions, one row per observation.

    The first ``num_agvs`` observations are treated as AGVs, whose ``(y, x)``
    pair is read from indices 3 and 4. Every later observation is treated as a
    picker, whose ``(y, x)`` pair is read from indices 0 and 1.
    ``num_pickers`` is accepted for signature symmetry but is not used.
    """
    def _yx_slots(agent_index: int) -> Tuple[int, int]:
        # (index of y, index of x) inside one observation vector
        return (3, 4) if agent_index < num_agvs else (0, 1)

    xy_rows = []
    for agent_index, observation in enumerate(observations):
        y_slot, x_slot = _yx_slots(agent_index)
        # Stored as (y, x) in the observation; emitted as (x, y).
        xy_rows.append([observation[x_slot], observation[y_slot]])

    return np.array(xy_rows)

def _natural_sort_key(name: str) -> List:
    """Sort key that orders embedded digit runs numerically ('step_2' < 'step_10')."""
    import re

    # re.split with a capturing group alternates text, digits, text, ... so the
    # odd positions are always digit runs and can be compared as integers.
    return [int(token) if index % 2 else token
            for index, token in enumerate(re.split(r'(\d+)', name))]


def load_episode_trajectories(h5_file_path: str, max_episodes: Optional[int] = None) -> Tuple[List[np.ndarray], Dict]:
    """Load per-episode agent position trajectories from an HDF5 dataset.

    Episodes are the top-level groups named ``episode_*``; within each, the
    children of its ``steps`` group are visited in natural (numeric-aware)
    name order so that the time axis is correct even when step names are not
    zero-padded.

    Args:
        h5_file_path: Path of the HDF5 file to open read-only.
        max_episodes: If not ``None``, only the first ``max_episodes`` episodes
            (in natural name order) are loaded.

    Returns:
        ``(trajectories, metadata)`` where ``trajectories`` holds one array per
        non-empty episode built from the per-step position arrays
        (``[steps, agents, 2]``), and ``metadata`` is the dict returned by
        ``load_dataset_metadata``.
    """
    trajectories = []
    metadata = load_dataset_metadata(h5_file_path)

    with h5py.File(h5_file_path, 'r') as f:
        # Always sort, so the episode order does not depend on whether a
        # limit was requested.
        episode_keys = sorted(
            (key for key in f.keys() if key.startswith('episode_')),
            key=_natural_sort_key,
        )
        if max_episodes is not None:
            episode_keys = episode_keys[:max_episodes]

        for episode_key in episode_keys:
            steps_group = f[episode_key]['steps']

            episode_trajectory = []
            for step_name in sorted(steps_group.keys(), key=_natural_sort_key):
                observations = steps_group[step_name]['observations'][:]
                positions = extract_positions_from_observations(
                    observations, metadata['num_agvs'], metadata['num_pickers']
                )
                episode_trajectory.append(positions)

            if episode_trajectory:
                trajectories.append(np.array(episode_trajectory))  # [steps, agents, 2]

    return trajectories, metadata

# 1. Density and Complexity Analysis
def _gini_coefficient(values: np.ndarray) -> float:
    """Return the Gini coefficient of non-negative values (0 = perfectly even)."""
    flat = np.sort(np.asarray(values, dtype=float).ravel())
    total = flat.sum()
    if flat.size == 0 or total <= 0:
        return 0.0
    ranks = np.arange(1, flat.size + 1)
    return float(2.0 * np.sum(ranks * flat) / (flat.size * total) - (flat.size + 1) / flat.size)


def analyze_density_and_complexity(h5_file_path: str, dataset_name: str) -> Dict:
    """Analyze agent density and interaction complexity of one dataset.

    Args:
        h5_file_path: Path of the HDF5 dataset; all episodes are loaded.
        dataset_name: Label stored in the result under ``'dataset_name'``.

    Returns:
        Dict with:
          * ``'agent_density'``: agents divided by ``grid_size[0] * grid_size[1]``.
          * ``'avg_inter_agent_distance'``: mean over episodes of the per-step
            mean pairwise distance.
          * ``'avg_collision_count'``: mean over episodes of the number of
            (step, agent pair) occurrences closer than 1.5.
          * ``'avg_path_length'``: mean travelled distance over all agents and
            episodes.
          * ``'density_heatmap'`` / ``'heatmap_coords'``: grid and ``(X, Y)``
            bin centres from ``calculate_spatial_density_heatmap``.
          * ``'max_local_density'``, ``'density_concentration'``,
            ``'effective_density'``: summary statistics of the heatmap.
          * ``'grid_size'``, ``'num_agents'``, ``'num_episodes'``,
            ``'dataset_name'``.

        Means over an empty set of episodes are reported as NaN.

    NOTE(review): the heatmap-derived keys are read by ``analyze_datasets`` and
    ``plot_spatial_density_heatmaps`` but no definition of them is visible in
    this file. The definitions below (max cell value, Gini coefficient of the
    cell values, agents per visited cell) are a reviewer's choice -- confirm
    they match the intended metrics.
    """
    trajectories, metadata = load_episode_trajectories(h5_file_path)

    grid_size = metadata['grid_size']
    num_agents = metadata['num_agvs'] + metadata['num_pickers']

    # Agent density over the full grid area
    total_area = grid_size[0] * grid_size[1]
    agent_density = num_agents / total_area

    # Interaction complexity metrics
    collision_counts = []
    avg_distances = []
    path_lengths = []

    for trajectory in trajectories:
        distances_per_step = []
        collision_count = 0
        for step_positions in trajectory:
            if len(step_positions) > 1:
                # Condensed pairwise distances: each unordered pair appears once,
                # so no diagonal filtering is needed and agents at the exact
                # same position (distance 0) are counted as collisions too.
                pair_distances = pdist(step_positions)
                distances_per_step.append(np.mean(pair_distances))
                # Count pairs with distance < 1.5 (collision threshold)
                collision_count += int(np.sum(pair_distances < 1.5))

        if distances_per_step:
            avg_distances.append(np.mean(distances_per_step))
        collision_counts.append(collision_count)

        # Path length per agent: sum of step-to-step displacement norms.
        step_lengths = np.linalg.norm(np.diff(trajectory[:, :num_agents, :], axis=0), axis=2)
        path_lengths.extend(step_lengths.sum(axis=0))

    # Spatial occupancy statistics (unit cells, so one bin is one grid cell).
    density_heatmap, X, Y = calculate_spatial_density_heatmap(trajectories, grid_size)
    occupied_cells = int(np.count_nonzero(density_heatmap))
    max_local_density = float(density_heatmap.max()) if density_heatmap.size else 0.0
    # Agents per cell that was visited at least once.
    effective_density = num_agents / occupied_cells if occupied_cells else 0.0
    density_concentration = _gini_coefficient(density_heatmap)

    return {
        'dataset_name': dataset_name,
        'agent_density': agent_density,
        'avg_inter_agent_distance': np.mean(avg_distances) if avg_distances else float('nan'),
        'avg_collision_count': np.mean(collision_counts) if collision_counts else float('nan'),
        'avg_path_length': np.mean(path_lengths) if path_lengths else float('nan'),
        'density_heatmap': density_heatmap,
        'heatmap_coords': (X, Y),
        'max_local_density': max_local_density,
        'density_concentration': density_concentration,
        'effective_density': effective_density,
        'grid_size': grid_size,
        'num_agents': num_agents,
        'num_episodes': len(trajectories)
    }

# 2. Trajectory Characteristics Analysis
def analyze_trajectory_characteristics(h5_file_path: str) -> Dict:
    """Analyze speed distribution and path diversity of one dataset.

    Speed is the Euclidean displacement of an agent between consecutive steps;
    acceleration is the change of that speed between consecutive steps. Path
    diversity of an episode is the mean of the per-axis variances of all agent
    positions in that episode.

    Args:
        h5_file_path: Path of the HDF5 dataset; all episodes are loaded.

    Returns:
        Dict with ``'path_diversity_mean'`` / ``'path_diversity_std'`` (over
        episodes) and ``'speed_mean'``, ``'speed_std'``,
        ``'acceleration_mean'``, ``'acceleration_std'`` (over all agents, steps
        and episodes). A statistic with no samples is reported as NaN.
    """
    trajectories, _metadata = load_episode_trajectories(h5_file_path)

    speed_chunks = []
    acceleration_chunks = []
    path_diversities = []

    for trajectory in trajectories:
        # [steps - 1, agents]: distance moved by each agent in each step
        step_speeds = np.linalg.norm(np.diff(trajectory, axis=0), axis=-1)
        speed_chunks.append(step_speeds.ravel())
        acceleration_chunks.append(np.diff(step_speeds, axis=0).ravel())

        # Path diversity (using variance of positions)
        all_positions = trajectory.reshape(-1, 2)
        position_variance = np.var(all_positions, axis=0)
        path_diversities.append(np.mean(position_variance))

    all_speeds = np.concatenate(speed_chunks) if speed_chunks else np.array([])
    all_accelerations = np.concatenate(acceleration_chunks) if acceleration_chunks else np.array([])

    def _stat(reducer, samples):
        # Avoid NumPy's empty-slice warning; report "no data" as NaN.
        return reducer(samples) if len(samples) else float('nan')

    return {
        'path_diversity_mean': _stat(np.mean, path_diversities),
        'path_diversity_std': _stat(np.std, path_diversities),
        'speed_mean': _stat(np.mean, all_speeds),
        'speed_std': _stat(np.std, all_speeds),
        'acceleration_mean': _stat(np.mean, all_accelerations),
        'acceleration_std': _stat(np.std, all_accelerations)
    }

# 3. Visualization Functions
def plot_density_comparison(datasets_analysis: List[Dict], save_path: Optional[str] = None):
    """Draw a 2x2 grid of bar charts comparing datasets on four metrics.

    Each entry of ``datasets_analysis`` must provide ``'dataset_name'`` plus the
    metric keys listed in ``panels`` below. When ``save_path`` is given the
    figure is saved there before being shown.
    """
    # (result key, panel title, y-axis label), in row-major panel order
    panels = (
        ('agent_density', 'Agent Density', 'Agents per Grid Cell'),
        ('avg_inter_agent_distance', 'Average Inter-Agent Distance', 'Distance'),
        ('avg_collision_count', 'Average Collision Count per Episode', 'Collision Count'),
        ('avg_path_length', 'Average Path Length', 'Path Length'),
    )

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    bar_labels = [entry['dataset_name'] for entry in datasets_analysis]

    for panel_ax, (metric_key, panel_title, y_label) in zip(axes.flat, panels):
        bar_heights = [entry[metric_key] for entry in datasets_analysis]
        panel_ax.bar(bar_labels, bar_heights)
        panel_ax.set_title(panel_title)
        panel_ax.set_ylabel(y_label)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def plot_initial_position_distribution(h5_file_path: str, max_episodes: int = 50, save_path: Optional[str] = None):
    """Scatter-plot where agents are located at the first step of each episode.

    Up to ``max_episodes`` episodes are loaded from ``h5_file_path``; the
    first-step positions of all of them are stacked and drawn in one figure.
    When ``save_path`` is given the figure is saved there before being shown.
    """
    trajectories, metadata = load_episode_trajectories(h5_file_path, max_episodes)

    # Row 0 of every episode holds the starting positions of all agents;
    # stacking gives one [total_agents, 2] array.
    start_points = np.vstack([episode[0] for episode in trajectories])
    start_x = start_points[:, 0]
    start_y = start_points[:, 1]

    plt.figure(figsize=(10, 8))
    plt.scatter(start_x, start_y, alpha=0.6, s=20)
    plt.title(f'Initial Position Distribution (n={len(trajectories)} episodes)')
    plt.xlabel('X Position')
    plt.ylabel('Y Position')
    plt.grid(True, alpha=0.3)

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def plot_trajectory_examples(h5_file_path: str, num_episodes: int = 3, save_path: Optional[str] = None):
    """Plot example agent trajectories, one subplot per episode.

    For every agent the path is drawn as a line, with a circle at the first
    position and a square at the last one.

    Args:
        h5_file_path: Path of the HDF5 dataset to read.
        num_episodes: Number of subplots; at most this many episodes are loaded.
        save_path: If given, the figure is also written to this path.
    """
    trajectories, metadata = load_episode_trajectories(h5_file_path, num_episodes)
    num_agents = metadata['num_agvs'] + metadata['num_pickers']

    # squeeze=False keeps `axes` indexable even when num_episodes == 1.
    fig, axes = plt.subplots(1, num_episodes, figsize=(5*num_episodes, 5), squeeze=False)
    axes = axes[0]

    colors = plt.cm.tab10(np.linspace(0, 1, num_agents))

    for ep_idx, trajectory in enumerate(trajectories):
        ax = axes[ep_idx]

        # Plot each agent's trajectory
        for agent_idx in range(num_agents):
            agent_path = trajectory[:, agent_idx, :]

            # The label is attached to the line itself so the legend entry is
            # tied to the right artist; the unlabelled start/end markers below
            # are left out of the legend.
            ax.plot(agent_path[:, 0], agent_path[:, 1],
                   color=colors[agent_idx], alpha=0.7, linewidth=2,
                   label=f'Agent {agent_idx}')

            # Mark start and end
            ax.scatter(agent_path[0, 0], agent_path[0, 1],
                      color=colors[agent_idx], marker='o', s=100, edgecolor='black')
            ax.scatter(agent_path[-1, 0], agent_path[-1, 1],
                      color=colors[agent_idx], marker='s', s=100, edgecolor='black')

        ax.set_title(f'Episode {ep_idx + 1}')
        ax.set_xlabel('X Position')
        ax.set_ylabel('Y Position')
        ax.grid(True, alpha=0.3)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    # Hide subplots for which no episode was available
    for unused_ax in axes[len(trajectories):]:
        unused_ax.set_visible(False)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# Main analysis function
def analyze_datasets(dataset_paths: Dict[str, str], output_dir: str = "./analysis_results/"):
    """Run the complete analysis for several datasets and save all figures.

    Args:
        dataset_paths: Mapping of dataset name to HDF5 file path.
        output_dir: Directory for the generated PNG files; created if missing.

    Returns:
        ``(density_results, traj_chars)``: the per-dataset results of
        ``analyze_density_and_complexity`` and of
        ``analyze_trajectory_characteristics``, in ``dataset_paths`` order.
    """
    import os
    os.makedirs(output_dir, exist_ok=True)

    # Spatial-density keys that a density result may or may not carry.
    enhanced_metrics = (
        ("Effective Density", 'effective_density'),
        ("Density Concentration", 'density_concentration'),
        ("Max Local Density", 'max_local_density'),
    )
    heatmap_keys = ('density_heatmap', 'heatmap_coords') + tuple(key for _, key in enhanced_metrics)

    # 1. Density and complexity analysis
    density_results = []
    for name, path in dataset_paths.items():
        result = analyze_density_and_complexity(path, name)
        density_results.append(result)
        print(f"Dataset: {name}")
        print(f"  Traditional Density: {result['agent_density']:.4f}")
        for label, key in enhanced_metrics:
            # Only report metrics the analysis actually produced.
            if key in result:
                print(f"  {label}: {result[key]:.4f}")
        print(f"  Avg Inter-Agent Distance: {result['avg_inter_agent_distance']:.2f}")
        print(f"  Avg Collision Count: {result['avg_collision_count']:.2f}")
        print(f"  Avg Path Length: {result['avg_path_length']:.2f}")
        print()

    # 2. Trajectory characteristics
    traj_chars = [analyze_trajectory_characteristics(path) for path in dataset_paths.values()]

    # 3. Generate enhanced visualizations
    plot_density_comparison(density_results, os.path.join(output_dir, "density_comparison.png"))
    if density_results and all(key in result for result in density_results for key in heatmap_keys):
        plot_spatial_density_heatmaps(
            density_results, os.path.join(output_dir, "spatial_density_heatmaps.png")
        )
    else:
        print("Skipping spatial density heatmaps: density results lack heatmap data")

    # Individual dataset visualizations
    for name, path in dataset_paths.items():
        plot_initial_position_distribution(
            path, save_path=os.path.join(output_dir, f"{name}_initial_positions.png"))
        plot_trajectory_examples(
            path, save_path=os.path.join(output_dir, f"{name}_trajectory_examples.png"))
        plot_temporal_density_evolution(
            path, name, save_path=os.path.join(output_dir, f"{name}_temporal_density.png"))

    return density_results, traj_chars