# Generate Overlap Test Cases

This notebook generates test cases for overlap computation with visualizations.

Each scene contains:
- 6 joint predictions
- 3 modeled agents
- 5 other agents

We create 8 different test cases covering various collision scenarios.


In [16]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from pathlib import Path
import math
from IPython.display import HTML

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)


In [17]:
# Configuration
NUM_JOINT_PREDICTIONS = 6  # M
NUM_MODELED_AGENTS = 3     # N (agents in joint prediction)
NUM_OTHER_AGENTS = 5       # Other agents in scene
NUM_TIMESTEPS = 80         # T (8 seconds at 10Hz)
DT = 0.1                   # Time step (10Hz)

# Output directory
output_dir = Path("../test_samples/overlap")
output_dir.mkdir(parents=True, exist_ok=True)

print(f"Output directory: {output_dir.absolute()}")


Output directory: /home/leo/Projects/womd-torch-metrics/notebooks/../test_samples/overlap


## Helper Functions


In [18]:
def create_box_polygon(center, length, width, heading):
    """Create a box polygon from center, dimensions, and heading."""
    cos_h = math.cos(heading)
    sin_h = math.sin(heading)
    
    # Box corners in local frame
    corners_local = np.array([
        [-length/2, -width/2],
        [length/2, -width/2],
        [length/2, width/2],
        [-length/2, width/2]
    ])
    
    # Rotate and translate
    rotation = np.array([[cos_h, -sin_h], [sin_h, cos_h]])
    corners_global = corners_local @ rotation.T + center
    
    return corners_global


def generate_straight_trajectory(start_pos, velocity, num_steps, dt=0.1):
    """Generate a straight trajectory."""
    trajectory = torch.zeros(num_steps, 2)
    for t in range(num_steps):
        trajectory[t] = start_pos + velocity * t * dt
    return trajectory


def generate_turning_trajectory(start_pos, start_heading, angular_velocity, 
                                linear_velocity, num_steps, dt=0.1):
    """Generate a turning trajectory."""
    trajectory = torch.zeros(num_steps, 2)
    pos = start_pos.clone()
    heading = start_heading
    
    for t in range(num_steps):
        trajectory[t] = pos.clone()
        # Update position based on heading
        pos += torch.tensor([math.cos(heading), math.sin(heading)]) * linear_velocity * dt
        heading += angular_velocity * dt
    
    return trajectory


In [19]:
def visualize_scene(pred_trajectories, gt_trajectory, gt_boxes, pred_scores, 
                    scene_name, output_path=None, show_animation=True):
    """
    Visualize predictions and ground truth as an animated GIF.
    
    Args:
        pred_trajectories: [M, K, N, T, 2] predictions
        gt_trajectory: [T, 2] ground truth
        gt_boxes: [T, 4] boxes [length, width, heading, velocity_heading]
        pred_scores: [M, K] prediction scores
        scene_name: Name of the scene
        output_path: Path to save GIF (optional)
        show_animation: Whether to display in notebook
    """
    M, K, N, T, _ = pred_trajectories.shape
    
    # Set up the figure
    fig, ax = plt.subplots(figsize=(12, 12))
    
    # Determine plot bounds
    all_positions = torch.cat([
        pred_trajectories.reshape(-1, T, 2),
        gt_trajectory.unsqueeze(0)
    ], dim=0)
    
    x_min = all_positions[:, :, 0].min().item() - 5
    x_max = all_positions[:, :, 0].max().item() + 5
    y_min = all_positions[:, :, 1].min().item() - 5
    y_max = all_positions[:, :, 1].max().item() + 5
    
    def animate(frame):
        ax.clear()
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_aspect('equal')
        ax.grid(True, alpha=0.3)
        ax.set_title(f"{scene_name} - Frame {frame}/{T-1}", fontsize=14)
        
        # Draw ground truth
        gt_pos = gt_trajectory[frame].numpy()
        gt_box = gt_boxes[frame].numpy()
        
        # Draw GT trajectory up to current frame
        if frame > 0:
            gt_path = gt_trajectory[:frame+1].numpy()
            ax.plot(gt_path[:, 0], gt_path[:, 1], 'g-', linewidth=3, 
                   label='Ground Truth', alpha=0.7)
        
        # Draw GT box
        gt_poly = create_box_polygon(gt_pos, gt_box[0], gt_box[1], gt_box[2])
        gt_poly_closed = np.vstack([gt_poly, gt_poly[0]])
        ax.fill(gt_poly_closed[:, 0], gt_poly_closed[:, 1], 
               color='green', alpha=0.3, edgecolor='green', linewidth=2)
        ax.plot(gt_pos[0], gt_pos[1], 'go', markersize=10, label='GT Center')
        
        # Draw predictions for each joint prediction group
        colors = plt.cm.tab10(np.linspace(0, 1, M))
        
        for m in range(M):
            for k in range(K):
                # Get trajectory for first modeled agent (N=0)
                pred_traj = pred_trajectories[m, k, 0, :frame+1, :].numpy()
                
                if len(pred_traj) > 0:
                    # Draw trajectory
                    alpha = pred_scores[m, k].item()
                    linewidth = 1 + alpha * 2
                    
                    ax.plot(pred_traj[:, 0], pred_traj[:, 1], 
                           color=colors[m], alpha=alpha*0.6, linewidth=linewidth)
                    
                    # Draw current position
                    if frame < T:
                        curr_pos = pred_trajectories[m, k, 0, frame, :].numpy()
                        ax.plot(curr_pos[0], curr_pos[1], 'o', 
                               color=colors[m], alpha=alpha, markersize=4)
        
        # Draw other modeled agents (N=1, 2)
        for n in range(1, N):
            for m in range(M):
                for k in range(K):
                    pred_traj = pred_trajectories[m, k, n, :frame+1, :].numpy()
                    if len(pred_traj) > 0:
                        alpha = pred_scores[m, k].item() * 0.5
                        ax.plot(pred_traj[:, 0], pred_traj[:, 1], 
                               '--', color=colors[m], alpha=alpha*0.4, linewidth=1)
        
        ax.legend(loc='upper right', fontsize=10)
        
        return ax
    
    # Create animation
    anim = animation.FuncAnimation(fig, animate, frames=T, interval=100, blit=False, repeat=True)
    
    # Save if path provided
    if output_path:
        print(f"Saving animation to {output_path}...")
        anim.save(output_path, writer='pillow', fps=10)
        print(f"✓ Saved: {output_path}")
    
    if show_animation:
        plt.close()
        return HTML(anim.to_jshtml())
    else:
        plt.close()
        return anim


## Test Case Generation Functions

Realistic traffic scenarios:
1. **Intersection collision** - Vehicle fails to yield at intersection
2. **Lane change scenarios** - Safe vs aggressive lane changes
3. **Highway cruising** - Stable lane keeping
4. **Merge conflicts** - Highway merging with varying confidence
5. **Following scenarios** - Car following with different gap predictions
6. **Turn predictions** - Left/right turns at intersections


In [20]:
def create_test_case_1_all_collisions():
    """Case 1: Intersection - All predictions fail to yield (T-bone collision)."""
    T = NUM_TIMESTEPS
    
    # Ground truth: Vehicle going straight through intersection at 10 m/s (36 km/h)
    gt_trajectory = generate_straight_trajectory(
        torch.tensor([-20.0, 0.0]), torch.tensor([10.0, 0.0]), T, dt=DT
    )
    gt_boxes = torch.ones(T, 4)
    gt_boxes[:, 0] = 4.5  # length (sedan)
    gt_boxes[:, 1] = 2.0  # width
    gt_boxes[:, 2] = 0.0  # heading (east)
    
    # All predictions: Cross traffic from perpendicular road (north to south)
    # All fail to stop at intersection - realistic but dangerous predictions
    pred_trajectories = torch.zeros(NUM_JOINT_PREDICTIONS, 1, NUM_MODELED_AGENTS, T, 2)
    pred_scores = torch.ones(NUM_JOINT_PREDICTIONS, 1) / NUM_JOINT_PREDICTIONS
    
    for m in range(NUM_JOINT_PREDICTIONS):
        # Start north of intersection, different distances
        start_y = 15.0 + m * 2.0
        start_x = 0.0
        
        # All predictions cross intersection without stopping (8-12 m/s)
        speed = 8.0 + m * 0.5
        
        for n in range(NUM_MODELED_AGENTS):
            # Different lanes (parallel vehicles)
            offset_x = n * 3.5  # lane width
            pred_trajectories[m, 0, n, :, :] = generate_straight_trajectory(
                torch.tensor([start_x + offset_x, start_y]), 
                torch.tensor([0.0, -speed]), 
                T, dt=DT
            )
    
    return pred_trajectories, gt_trajectory, gt_boxes, pred_scores

In [21]:
def create_test_case_2_partial_collisions():
    """Case 2: Lane change - Some predictions cut in aggressively, others are safe."""
    T = NUM_TIMESTEPS
    
    # Ground truth: Vehicle in center lane, constant speed 15 m/s (54 km/h)
    gt_trajectory = generate_straight_trajectory(
        torch.tensor([0.0, 0.0]), torch.tensor([15.0, 0.0]), T, dt=DT
    )
    gt_boxes = torch.ones(T, 4)
    gt_boxes[:, 0] = 4.5
    gt_boxes[:, 1] = 2.0
    gt_boxes[:, 2] = 0.0
    
    pred_trajectories = torch.zeros(NUM_JOINT_PREDICTIONS, 1, NUM_MODELED_AGENTS, T, 2)
    pred_scores = torch.ones(NUM_JOINT_PREDICTIONS, 1) / NUM_JOINT_PREDICTIONS
    
    # First 3 predictions: Aggressive lane change from left lane (collision)
    # Last 3 predictions: Safe lane keeping or gentle lane change (no collision)
    for m in range(NUM_JOINT_PREDICTIONS):
        if m < 3:
            # Aggressive lane change: start from left lane, merge too early
            start_pos = torch.tensor([-5.0, 3.5])  # Left lane, behind GT
            
            for n in range(NUM_MODELED_AGENTS):
                traj = torch.zeros(T, 2)
                for t in range(T):
                    # Fast merge into GT lane
                    progress = min(t / 30.0, 1.0)  # Complete merge in 3 seconds
                    x = -5.0 + 16.0 * t * DT  # Speed: 16 m/s (faster than GT)
                    y = 3.5 * (1 - progress)  # Lateral movement
                    traj[t] = torch.tensor([x, y + n * 3.5])
                pred_trajectories[m, 0, n, :, :] = traj
        else:
            # Safe prediction: maintain lane or safe merge
            if m == 3:
                # Stay in right lane
                pred_trajectories[m, 0, 0, :, :] = generate_straight_trajectory(
                    torch.tensor([0.0, -3.5]), torch.tensor([15.0, 0.0]), T, dt=DT
                )
            else:
                # Safe lane change from right lane with proper spacing
                start_pos = torch.tensor([30.0, -3.5])  # Far ahead in right lane
                for n in range(NUM_MODELED_AGENTS):
                    traj = torch.zeros(T, 2)
                    for t in range(T):
                        progress = min(t / 50.0, 1.0)  # Slow merge over 5 seconds
                        x = 30.0 + 15.0 * t * DT
                        y = -3.5 * (1 - progress)
                        traj[t] = torch.tensor([x, y + n * 3.5])
                    pred_trajectories[m, 0, n, :, :] = traj
    
    return pred_trajectories, gt_trajectory, gt_boxes, pred_scores

In [22]:
def create_test_case_3_no_collisions():
    """Case 3: Highway cruising - All predictions maintain safe distances."""
    T = NUM_TIMESTEPS
    
    # Ground truth: Vehicle cruising in center lane at 25 m/s (90 km/h)
    gt_trajectory = generate_straight_trajectory(
        torch.tensor([0.0, 0.0]), torch.tensor([25.0, 0.0]), T, dt=DT
    )
    gt_boxes = torch.ones(T, 4)
    gt_boxes[:, 0] = 4.5
    gt_boxes[:, 1] = 2.0
    gt_boxes[:, 2] = 0.0
    
    pred_trajectories = torch.zeros(NUM_JOINT_PREDICTIONS, 1, NUM_MODELED_AGENTS, T, 2)
    pred_scores = torch.ones(NUM_JOINT_PREDICTIONS, 1) / NUM_JOINT_PREDICTIONS
    
    # All predictions show safe scenarios:
    # - Different lanes with safe spacing
    # - Varying speeds but no collision
    for m in range(NUM_JOINT_PREDICTIONS):
        if m < 2:
            # Predictions in left lane, ahead
            speed = 27.0 + m * 1.0  # Slightly faster
            start_x = 20.0 + m * 15.0  # Well ahead
            pred_trajectories[m, 0, 0, :, :] = generate_straight_trajectory(
                torch.tensor([start_x, 3.5]), torch.tensor([speed, 0.0]), T, dt=DT
            )
        elif m < 4:
            # Predictions in right lane
            speed = 23.0 + (m-2) * 1.0  # Slightly slower
            start_x = -10.0 - (m-2) * 15.0  # Behind
            pred_trajectories[m, 0, 0, :, :] = generate_straight_trajectory(
                torch.tensor([start_x, -3.5]), torch.tensor([speed, 0.0]), T, dt=DT
            )
        else:
            # Predictions in same lane but safe distance
            if m == 4:
                # Far ahead
                pred_trajectories[m, 0, 0, :, :] = generate_straight_trajectory(
                    torch.tensor([50.0, 0.0]), torch.tensor([25.0, 0.0]), T, dt=DT
                )
            else:
                # Far behind
                pred_trajectories[m, 0, 0, :, :] = generate_straight_trajectory(
                    torch.tensor([-50.0, 0.0]), torch.tensor([25.0, 0.0]), T, dt=DT
                )
        
        # Add parallel vehicles in other lanes for multi-agent
        for n in range(1, NUM_MODELED_AGENTS):
            lane_offset = (n - 1) * 3.5 - 3.5  # Distribute across lanes
            base_traj = pred_trajectories[m, 0, 0, :, :].clone()
            pred_trajectories[m, 0, n, :, 0] = base_traj[:, 0] - n * 8.0  # Staggered
            pred_trajectories[m, 0, n, :, 1] = lane_offset
    
    return pred_trajectories, gt_trajectory, gt_boxes, pred_scores

In [23]:
def create_test_case_4_highest_confident_collision():
    """Case 4: Merge conflict - Highest confidence prediction fails to yield during merge."""
    T = NUM_TIMESTEPS
    
    # Ground truth: Vehicle on highway mainline at 28 m/s (100 km/h)
    gt_trajectory = generate_straight_trajectory(
        torch.tensor([0.0, 0.0]), torch.tensor([28.0, 0.0]), T, dt=DT
    )
    gt_boxes = torch.ones(T, 4)
    gt_boxes[:, 0] = 4.5
    gt_boxes[:, 1] = 2.0
    gt_boxes[:, 2] = 0.0
    
    pred_trajectories = torch.zeros(NUM_JOINT_PREDICTIONS, 1, NUM_MODELED_AGENTS, T, 2)
    
    # Highest confidence: aggressive merge from on-ramp (collision)
    # Lower confidence: various safe alternatives
    pred_scores = torch.tensor([[0.6], [0.2], [0.1], [0.05], [0.03], [0.02]])
    
    for m in range(NUM_JOINT_PREDICTIONS):
        if m == 0:
            # Highest confidence: Aggressive merge from on-ramp
            for n in range(NUM_MODELED_AGENTS):
                traj = torch.zeros(T, 2)
                for t in range(T):
                    # Start from on-ramp (right side), accelerate and merge
                    progress = min(t / 40.0, 1.0)  # 4 second merge
                    x = -10.0 + 26.0 * t * DT  # Accelerating (26 m/s)
                    y = -5.0 + 5.0 * progress  # Merge from right
                    traj[t] = torch.tensor([x, y - n * 3.5])
                pred_trajectories[m, 0, n, :, :] = traj
        else:
            # Lower confidence: Safe alternatives
            if m == 1:
                # Stay in right lane
                pred_trajectories[m, 0, 0, :, :] = generate_straight_trajectory(
                    torch.tensor([0.0, -3.5]), torch.tensor([28.0, 0.0]), T, dt=DT
                )
            elif m == 2:
                # Merge but with safe gap (further ahead)
                for n in range(NUM_MODELED_AGENTS):
                    traj = torch.zeros(T, 2)
                    for t in range(T):
                        progress = min(t / 50.0, 1.0)
                        x = 40.0 + 26.0 * t * DT  # Well ahead
                        y = -5.0 + 5.0 * progress
                        traj[t] = torch.tensor([x, y - n * 3.5])
                    pred_trajectories[m, 0, n, :, :] = traj
            else:
                # Stay on on-ramp or in other lanes
                lane_y = -3.5 * (m - 2)
                pred_trajectories[m, 0, 0, :, :] = generate_straight_trajectory(
                    torch.tensor([0.0, lane_y]), torch.tensor([28.0, 0.0]), T, dt=DT
                )
    
    return pred_trajectories, gt_trajectory, gt_boxes, pred_scores

In [24]:
def create_test_case_5_lowest_confident_collision():
    """Case 5: Uncertain lane change - Lowest confidence prediction is risky cut-in."""
    T = NUM_TIMESTEPS
    
    # Ground truth: Vehicle in center lane at 20 m/s (72 km/h)
    gt_trajectory = generate_straight_trajectory(
        torch.tensor([0.0, 0.0]), torch.tensor([20.0, 0.0]), T, dt=DT
    )
    gt_boxes = torch.ones(T, 4)
    gt_boxes[:, 0] = 4.5
    gt_boxes[:, 1] = 2.0
    gt_boxes[:, 2] = 0.0
    
    pred_trajectories = torch.zeros(NUM_JOINT_PREDICTIONS, 1, NUM_MODELED_AGENTS, T, 2)
    
    # Last prediction (lowest confidence): Risky cut-in from right
    # Higher confidence: Safe lane keeping or normal behavior
    pred_scores = torch.tensor([[0.4], [0.3], [0.15], [0.1], [0.04], [0.01]])
    
    for m in range(NUM_JOINT_PREDICTIONS):
        if m == NUM_JOINT_PREDICTIONS - 1:
            # Lowest confidence: Sudden cut-in from slower vehicle in right lane
            for n in range(NUM_MODELED_AGENTS):
                traj = torch.zeros(T, 2)
                for t in range(T):
                    # Start behind in right lane, suddenly cut in
                    if t < 20:
                        # Initially in right lane, slower
                        x = -15.0 + 17.0 * t * DT
                        y = -3.5
                    else:
                        # Sudden lane change
                        progress = min((t - 20) / 15.0, 1.0)  # 1.5s cut-in
                        x = -15.0 + 17.0 * t * DT
                        y = -3.5 + 3.5 * progress
                    traj[t] = torch.tensor([x, y - n * 3.5])
                pred_trajectories[m, 0, n, :, :] = traj
        else:
            # Higher confidence: Normal lane keeping
            if m < 3:
                # Continue in current lane
                pred_trajectories[m, 0, 0, :, :] = generate_straight_trajectory(
                    torch.tensor([0.0, 0.0]), torch.tensor([20.0, 0.0]), T, dt=DT
                )
            else:
                # Gentle lane change to left with safe spacing
                for n in range(NUM_MODELED_AGENTS):
                    traj = torch.zeros(T, 2)
                    for t in range(T):
                        progress = min(t / 60.0, 1.0)  # 6s gentle change
                        x = 0.0 + 20.0 * t * DT
                        y = 3.5 * progress
                        traj[t] = torch.tensor([x, y + n * 3.5])
                    pred_trajectories[m, 0, n, :, :] = traj
    
    return pred_trajectories, gt_trajectory, gt_boxes, pred_scores

In [25]:
def create_test_case_6_middle_confident_collision():
    """Case 6: Following scenario - Middle confidence prediction follows too closely."""
    T = NUM_TIMESTEPS
    
    # Ground truth: Lead vehicle with gradual deceleration
    gt_trajectory = torch.zeros(T, 2)
    velocity = 22.0  # Start at 22 m/s (79 km/h)
    pos = torch.tensor([0.0, 0.0])
    
    for t in range(T):
        gt_trajectory[t] = pos.clone()
        # Gradual deceleration after t=30
        if t > 30:
            decel = 0.3  # m/s^2
            velocity = max(velocity - decel * DT, 15.0)
        pos += torch.tensor([velocity * DT, 0.0])
    
    gt_boxes = torch.ones(T, 4)
    gt_boxes[:, 0] = 4.5
    gt_boxes[:, 1] = 2.0
    gt_boxes[:, 2] = 0.0
    
    pred_trajectories = torch.zeros(NUM_JOINT_PREDICTIONS, 1, NUM_MODELED_AGENTS, T, 2)
    
    # Middle prediction: Following too closely, doesn't brake in time
    pred_scores = torch.tensor([[0.3], [0.25], [0.2], [0.15], [0.07], [0.03]])
    
    mid_idx = NUM_JOINT_PREDICTIONS // 2
    
    for m in range(NUM_JOINT_PREDICTIONS):
        if m == mid_idx:
            # Middle confidence: Too close following, collision
            for n in range(NUM_MODELED_AGENTS):
                traj = torch.zeros(T, 2)
                gap = 8.0 - n * 2.0  # Start 8m behind (too close)
                velocity_ego = 22.0
                pos_ego = torch.tensor([-gap, n * 3.5])
                
                for t in range(T):
                    traj[t] = pos_ego.clone()
                    # React slowly to lead vehicle braking
                    if t > 40:  # 1 second reaction delay
                        decel = 0.25  # Brake less aggressively
                        velocity_ego = max(velocity_ego - decel * DT, 15.0)
                    pos_ego += torch.tensor([velocity_ego * DT, 0.0])
                
                pred_trajectories[m, 0, n, :, :] = traj
        else:
            # Other predictions: Safe following distances
            if m < mid_idx:
                # Conservative: Large gap, early braking
                gap = 20.0 + m * 5.0
                for n in range(NUM_MODELED_AGENTS):
                    traj = torch.zeros(T, 2)
                    velocity_ego = 22.0
                    pos_ego = torch.tensor([-gap, n * 3.5])
                    
                    for t in range(T):
                        traj[t] = pos_ego.clone()
                        if t > 35:  # Earlier braking
                            decel = 0.35
                            velocity_ego = max(velocity_ego - decel * DT, 15.0)
                        pos_ego += torch.tensor([velocity_ego * DT, 0.0])
                    
                    pred_trajectories[m, 0, n, :, :] = traj
            else:
                # Lane change to pass
                for n in range(NUM_MODELED_AGENTS):
                    traj = torch.zeros(T, 2)
                    gap = 15.0
                    velocity_ego = 22.0
                    pos_ego = torch.tensor([-gap, n * 3.5])
                    
                    for t in range(T):
                        traj[t] = pos_ego.clone()
                        # Change to left lane and maintain speed
                        if 20 < t < 40:
                            progress = (t - 20) / 20.0
                            pos_ego[1] = (n * 3.5) + 3.5 * progress
                        elif t >= 40:
                            pos_ego[1] = n * 3.5 + 3.5
                        pos_ego += torch.tensor([velocity_ego * DT, 0.0])
                    
                    pred_trajectories[m, 0, n, :, :] = traj
    
    return pred_trajectories, gt_trajectory, gt_boxes, pred_scores

In [26]:
def create_test_case_7_early_collision():
    """Case 7: Intersection - Early collision due to running red light."""
    T = NUM_TIMESTEPS
    
    # Ground truth: Vehicle approaching intersection, stops for red light
    gt_trajectory = torch.zeros(T, 2)
    velocity = 15.0  # 54 km/h
    pos = torch.tensor([0.0, 0.0])
    stop_position = 40.0  # Stop line at intersection
    
    for t in range(T):
        gt_trajectory[t] = pos.clone()
        
        # Decelerate when approaching intersection
        if pos[0] > 25.0 and pos[0] < stop_position:
            # Braking phase
            decel = 2.5  # m/s^2
            velocity = max(velocity - decel * DT, 0.0)
        elif pos[0] >= stop_position:
            # Stopped at intersection
            velocity = 0.0
        
        pos += torch.tensor([velocity * DT, 0.0])
    
    gt_boxes = torch.ones(T, 4)
    gt_boxes[:, 0] = 4.5
    gt_boxes[:, 1] = 2.0
    gt_boxes[:, 2] = 0.0
    
    pred_trajectories = torch.zeros(NUM_JOINT_PREDICTIONS, 1, NUM_MODELED_AGENTS, T, 2)
    pred_scores = torch.ones(NUM_JOINT_PREDICTIONS, 1) / NUM_JOINT_PREDICTIONS
    
    # All predictions: Cross traffic runs red light (early collision at t=20-30)
    for m in range(NUM_JOINT_PREDICTIONS):
        # Different starting positions, all run through intersection
        start_y = 30.0 + m * 3.0
        speed = 12.0 + m * 1.0  # 43-61 km/h
        
        for n in range(NUM_MODELED_AGENTS):
            # Perpendicular approach from north
            pred_trajectories[m, 0, n, :, :] = generate_straight_trajectory(
                torch.tensor([40.0 + n * 3.5, start_y]),  # Intersection center at x=40
                torch.tensor([0.0, -speed]),
                T, dt=DT
            )
    
    return pred_trajectories, gt_trajectory, gt_boxes, pred_scores

In [27]:
def create_test_case_8_late_collision():
    """Case 8: Slow convergence - Late collision from gradual lane drift."""
    T = NUM_TIMESTEPS
    
    # Ground truth: Vehicle maintaining lane at steady speed
    gt_trajectory = generate_straight_trajectory(
        torch.tensor([0.0, 0.0]), torch.tensor([18.0, 0.0]), T, dt=DT
    )
    gt_boxes = torch.ones(T, 4)
    gt_boxes[:, 0] = 4.5
    gt_boxes[:, 1] = 2.0
    gt_boxes[:, 2] = 0.0
    
    pred_trajectories = torch.zeros(NUM_JOINT_PREDICTIONS, 1, NUM_MODELED_AGENTS, T, 2)
    pred_scores = torch.ones(NUM_JOINT_PREDICTIONS, 1) / NUM_JOINT_PREDICTIONS
    
    # All predictions: Vehicles in adjacent lane gradually drift (late collision t=60-70)
    for m in range(NUM_JOINT_PREDICTIONS):
        # Start in left or right lane
        start_y = 3.5 if m < 3 else -3.5
        start_x = -5.0 - m * 3.0  # Slightly behind, staggered
        
        for n in range(NUM_MODELED_AGENTS):
            traj = torch.zeros(T, 2)
            velocity = 18.5 + m * 0.3  # Slightly faster than GT
            pos = torch.tensor([start_x - n * 5.0, start_y])
            
            for t in range(T):
                traj[t] = pos.clone()
                
                # Gradual lane drift starting at t=40, completing at t=70
                if t > 40 and t < 70:
                    drift_rate = start_y / (70 - 40)  # Constant drift rate
                    pos[1] -= drift_rate * DT
                elif t >= 70:
                    pos[1] = 0.0  # Fully drifted into GT lane
                
                pos += torch.tensor([velocity * DT, 0.0])
            
            pred_trajectories[m, 0, n, :, :] = traj
    
    return pred_trajectories, gt_trajectory, gt_boxes, pred_scores

## Generate and Save Test Cases


In [28]:
# Define test cases with realistic traffic scenarios
test_cases = [
    ("case_01_all_collisions", create_test_case_1_all_collisions, 
     "Intersection: All predictions fail to yield (T-bone)"),
    ("case_02_partial_collisions", create_test_case_2_partial_collisions, 
     "Lane change: Aggressive vs safe merge behavior"),
    ("case_03_no_collisions", create_test_case_3_no_collisions, 
     "Highway cruising: Safe lane keeping"),
    ("case_04_highest_confident_collision", create_test_case_4_highest_confident_collision, 
     "Merge conflict: High-confidence aggressive merge"),
    ("case_05_lowest_confident_collision", create_test_case_5_lowest_confident_collision, 
     "Uncertain lane change: Low-confidence risky cut-in"),
    ("case_06_middle_confident_collision", create_test_case_6_middle_confident_collision, 
     "Following: Medium-confidence tailgating collision"),
    ("case_07_early_collision", create_test_case_7_early_collision, 
     "Intersection: Early collision from red light violation"),
    ("case_08_late_collision", create_test_case_8_late_collision, 
     "Lane drift: Late collision from gradual drift"),
]

print(f"Generating {len(test_cases)} realistic traffic test cases...")

Generating 8 realistic traffic test cases...


In [29]:
# Import overlap computation to compute expected values
import sys
sys.path.insert(0, str(Path("../").absolute()))
from womd_torch_metrics import compute_overlap_rate

# Generate and save each test case
for case_name, create_func, description in test_cases:
    print(f"\n{'='*60}")
    print(f"Generating: {case_name}")
    print(f"Description: {description}")
    
    # Generate test data
    pred_trajectories, gt_trajectory, gt_boxes, pred_scores = create_func()
    
    # For overlap computation, we need to extract predictions for the first modeled agent
    # Format: [M*K, T, 2] where M is joint predictions, K is top-K (here K=1)
    M, K, N, T, _ = pred_trajectories.shape
    
    # Extract first agent (n=0) from each joint prediction
    # Flatten to [M*K, T, 2] for overlap computation
    pred_traj_flat = pred_trajectories[:, :, 0, :, :].reshape(M * K, T, 2)
    
    # Compute expected overlap rate
    expected_overlap_rate = compute_overlap_rate(
        pred_traj_flat,
        gt_trajectory,
        gt_boxes,
        threshold=0.5
    ).item()
    
    print(f"  Computed overlap rate: {expected_overlap_rate:.6f}")
    
    # Prepare test data dictionary
    test_data = {
        "prediction_trajectories": pred_traj_flat,  # [M*K, T, 2]
        "ground_truth_trajectory": gt_trajectory,  # [T, 2]
        "ground_truth_boxes": gt_boxes,  # [T, 4]
        "threshold": 0.5,
        "expected_overlap_rate": expected_overlap_rate,
        "tolerance": 1e-5,
        "description": description,
    }
    
    # Save test data
    pth_path = output_dir / f"{case_name}.pth"
    torch.save(test_data, pth_path)
    print(f"✓ Saved: {pth_path}")
    
    # Create visualization
    gif_path = output_dir / f"{case_name}.gif"
    visualize_scene(
        pred_trajectories, gt_trajectory, gt_boxes, pred_scores,
        scene_name=description,
        output_path=gif_path,
        show_animation=False
    )
    
    print(f"✓ Visualization: {gif_path}")

print(f"\n{'='*60}")
print(f"✓ Generated {len(test_cases)} test cases")
print(f"Output directory: {output_dir.absolute()}")



Generating: case_01_all_collisions
Description: Intersection: All predictions fail to yield (T-bone)
  Computed overlap rate: 0.166667
✓ Saved: ../test_samples/overlap/case_01_all_collisions.pth
Saving animation to ../test_samples/overlap/case_01_all_collisions.gif...
✓ Saved: ../test_samples/overlap/case_01_all_collisions.gif
✓ Visualization: ../test_samples/overlap/case_01_all_collisions.gif

Generating: case_02_partial_collisions
Description: Lane change: Aggressive vs safe merge behavior
  Computed overlap rate: 0.500000
✓ Saved: ../test_samples/overlap/case_02_partial_collisions.pth
Saving animation to ../test_samples/overlap/case_02_partial_collisions.gif...
✓ Saved: ../test_samples/overlap/case_02_partial_collisions.gif
✓ Visualization: ../test_samples/overlap/case_02_partial_collisions.gif

Generating: case_03_no_collisions
Description: Highway cruising: Safe lane keeping
  Computed overlap rate: 0.000000
✓ Saved: ../test_samples/overlap/case_03_no_collisions.pth
Saving animat

## Visualize Test Cases

Display visualizations for a few test cases as examples:


In [15]:
# # Display visualizations for first 2 test cases as examples
# for case_name, create_func, description in test_cases[:2]:
#     print(f"\n{description}")
#     pred_trajectories, gt_trajectory, gt_boxes, pred_scores = create_func()
    
#     display(visualize_scene(
#         pred_trajectories, gt_trajectory, gt_boxes, pred_scores,
#         scene_name=description,
#         output_path=None,
#         show_animation=True
#     ))
