In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import numpy as np

def visualize_npz_data(npz_path, scene_idx=0, max_agents=1):
    """
    Visualize trajectory data from NPZ file.
    
    Args:
        npz_path: Path to the NPZ file
        scene_idx: Which scene to visualize if multiple scenes exist
        max_agents: Maximum number of agents to visualize (to avoid cluttering)
    """
    # Load the data
    data = np.load(npz_path)
    
    # Create a figure
    plt.figure(figsize=(15, 10))
    
    # Plot for each agent (up to max_agents)
    valid_agents = np.where(data['agent_valid'])[0]
    agents_to_plot = valid_agents[:max_agents]
    
    colors = plt.cm.rainbow(np.linspace(0, 1, len(agents_to_plot)))
    
    for idx, agent_idx in enumerate(agents_to_plot):
        # Get agent dimensions
        width = data['width'][agent_idx]
        length = data['length'][agent_idx]
        
        # Plot history trajectory
        history = data['history/xy'][agent_idx]
        history_valid = data['history/valid'][agent_idx]
        valid_history = history[history_valid == 1]
        
        plt.plot(valid_history[:, 0], valid_history[:, 1], 
                'o-', color=colors[idx], alpha=0.5, 
                label=f'History Agent {agent_idx}')
        
        # Plot future trajectory
        future = data['future/xy'][agent_idx]
        future_valid = data['future/valid'][agent_idx]
        valid_future = future[future_valid == 1]
        
        plt.plot(valid_future[:, 0], valid_future[:, 1], 
                '--', color=colors[idx], alpha=0.5,
                label=f'Future Agent {agent_idx}')
        
        # Plot current position (last point of history) with vehicle rectangle
        current_pos = history[-1]
        current_yaw = data['history/yaw'][agent_idx][-1]
        
        # Create rectangle for vehicle
        rect = Rectangle(
            (current_pos[0] - length/2, current_pos[1] - width/2),
            length, width,
            angle=np.degrees(current_yaw),
            color=colors[idx],
            alpha=0.7
        )
        plt.gca().add_patch(rect)
        
        # Add agent type information
        agent_type = data['agent_type'][agent_idx]
        plt.annotate(f'Type: {agent_type}', 
                    (current_pos[0], current_pos[1]),
                    xytext=(10, 10), 
                    textcoords='offset points')
    
    plt.title('Trajectory Visualization\nDots: History, Dashed: Future, Rectangles: Current Position')
    plt.xlabel('X Position (meters)')
    plt.ylabel('Y Position (meters)')
    plt.axis('equal')
    plt.grid(True)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    
    # Print additional information
    print(f"\nScene Information:")
    print(f"Scenario ID: {data['scenario_id']}")
    print(f"Total number of valid agents: {len(valid_agents)}")
    print(f"Time spans:")
    print(f"- History: {data['history/xy'].shape[1]} timesteps")
    print(f"- Future: {data['future/xy'].shape[1]} timesteps")
    
    plt.tight_layout()
    plt.show()

# Add this to your main() function:
def main():
    """Visualize the first processed NPZ file after preprocessing.

    NOTE(review): this is a snippet meant to be merged into an existing ``main()``.
    ``tfrecord_paths`` must be provided by the elided preprocessing code (or exist
    as a notebook global); it is not defined anywhere in this cell.
    """
    # `os` is not imported in any visible cell; importing locally keeps this
    # function working on a fresh kernel.
    import os

    # ... [previous code] ...

    # After processing, visualize the first processed file
    output_path = os.path.join(
        "processed_data",
        f"processed_{os.path.basename(tfrecord_paths[0])}.npz",
    )
    print("\nVisualizing processed data...")
    visualize_npz_data(output_path)

In [None]:
# Forward slashes work on every OS, including Windows. A backslash inside a plain
# string literal starts an escape sequence ("\p" is invalid and warns on Python 3.12+).
visualize_npz_data("processed_data/processed_uncompressed_tf_example_training_training_tfexample.tfrecord-00000-of-01000.npz")

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import numpy as np

def visualize_npz_data(npz_path, scene_idx=0, agent_ids=None, only_sdc=False, show_all=False):
    """
    Visualize trajectory data from NPZ file with enhanced agent selection.
    
    Args:
        npz_path: Path to the NPZ file
        scene_idx: Which scene to visualize if multiple scenes exist
        agent_ids: List of specific agent IDs to visualize (optional)
        only_sdc: If True, only show the self-driving car trajectory
        show_all: If True, show all valid agents (might be cluttered)
    """
    # Load the data
    data = np.load(npz_path)
    
    # Create a figure
    plt.figure(figsize=(15, 10))
    
    # Find the SDC (main agent)
    # Since we don't have direct SDC flag, we'll use tracks_to_predict or first valid agent
    valid_agents = np.where(data['agent_valid'])[0]
    sdc_idx = valid_agents[0] if len(valid_agents) > 0 else None
    
    # Determine which agents to plot
    if only_sdc and sdc_idx is not None:
        agents_to_plot = [sdc_idx]
    elif agent_ids is not None:
        agents_to_plot = [idx for idx in agent_ids if idx in valid_agents]
    elif show_all:
        agents_to_plot = valid_agents
    else:
        # Default: show first valid agent and 4 closest agents
        if sdc_idx is not None:
            current_sdc_pos = data['history/xy'][sdc_idx, -1]
            distances = []
            for idx in valid_agents:
                if idx != sdc_idx:
                    pos = data['history/xy'][idx, -1]
                    dist = np.linalg.norm(pos - current_sdc_pos)
                    distances.append((idx, dist))
            distances.sort(key=lambda x: x[1])
            closest_agents = [d[0] for d in distances[:4]]
            agents_to_plot = [sdc_idx] + closest_agents
        else:
            agents_to_plot = valid_agents[:5]
    
    colors = plt.cm.rainbow(np.linspace(0, 1, len(agents_to_plot)))
    
    # Plot each selected agent
    for idx, agent_idx in enumerate(agents_to_plot):
        # Get agent dimensions
        width = data['width'][agent_idx]
        length = data['length'][agent_idx]
        
        # Determine if this is the main agent (first valid agent)
        is_main = agent_idx == sdc_idx
        
        # Set line properties based on agent type
        line_width = 2 if is_main else 1
        alpha = 1.0 if is_main else 0.5
        
        # Plot history trajectory
        history = data['history/xy'][agent_idx]
        history_valid = data['history/valid'][agent_idx]
        valid_history = history[history_valid == 1]
        
        label_prefix = 'Main Agent' if is_main else f'Agent {agent_idx}'
        plt.plot(valid_history[:, 0], valid_history[:, 1], 
                'o-', color=colors[idx], alpha=alpha, linewidth=line_width,
                label=f'{label_prefix} History')
        
        # Plot future trajectory
        future = data['future/xy'][agent_idx]
        future_valid = data['future/valid'][agent_idx]
        valid_future = future[future_valid == 1]
        
        plt.plot(valid_future[:, 0], valid_future[:, 1], 
                '--', color=colors[idx], alpha=alpha, linewidth=line_width,
                label=f'{label_prefix} Future')
        
        # Plot current position with vehicle rectangle
        current_pos = history[-1]
        current_yaw = data['history/yaw'][agent_idx][-1]
        
        # Create rectangle for vehicle
        rect = Rectangle(
            (current_pos[0] - length/2, current_pos[1] - width/2),
            length, width,
            angle=np.degrees(current_yaw),
            color=colors[idx],
            alpha=0.7,
            linewidth=2 if is_main else 1
        )
        plt.gca().add_patch(rect)
        
        # Add agent information
        agent_type = data['agent_type'][agent_idx]
        label_text = f'{"Main Agent" if is_main else f"Agent {agent_idx}"}\nType: {agent_type}'
        plt.annotate(label_text, 
                    (current_pos[0], current_pos[1]),
                    xytext=(10, 10), 
                    textcoords='offset points',
                    bbox=dict(facecolor='white', alpha=0.7))
    
    plt.title('Trajectory Visualization\nDots: History, Dashed: Future, Rectangles: Current Position')
    plt.xlabel('X Position (meters)')
    plt.ylabel('Y Position (meters)')
    plt.axis('equal')
    plt.grid(True)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    
    # Print additional information
    print(f"\nScene Information:")
    print(f"Scenario ID: {data['scenario_id']}")
    print(f"Total number of valid agents: {len(valid_agents)}")
    print(f"Time spans:")
    print(f"- History: {data['history/xy'].shape[1]} timesteps (~{data['history/xy'].shape[1]/10:.1f} seconds)")
    print(f"- Future: {data['future/xy'].shape[1]} timesteps (~{data['future/xy'].shape[1]/10:.1f} seconds)")
    
    plt.tight_layout()
    plt.show()

In [None]:
# Pass only only_sdc=True: it takes precedence over show_all inside
# visualize_npz_data, and scene_idx is unused, so those arguments would have no effect.
visualize_npz_data("processed_data/processed_uncompressed_tf_example_training_training_tfexample.tfrecord-00001-of-01000.npz", only_sdc=True)

In [None]:
# Show only the main agent's trajectory for shard 00003.
visualize_npz_data(
    npz_path="processed_data/processed_uncompressed_tf_example_training_training_tfexample.tfrecord-00003-of-01000.npz",
    only_sdc=True,
)