In [None]:
import os
import sys
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import cv2
from pathlib import Path
from typing import Any, Tuple, List, Optional
from hydra import compose, initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate

# --- USER CONFIGURATION START ---
# UPDATE THESE PATHS BEFORE RUNNING
REPO_ROOT = Path("/path/to/your/drivor_repo")  # e.g., /home/user/workspace/drivor
DATASET_ROOT = Path("/path/to/datasets")       # e.g., /datasets_local/navsim_workspace

# Environment variables consumed by the navsim / nuPlan tooling, all derived
# from the two roots above (plus a fixed map version string).
os.environ.update(
    {
        "NUPLAN_MAPS_ROOT": str(DATASET_ROOT / "dataset/maps"),
        "NUPLAN_MAP_VERSION": "nuplan-maps-v1.0",
        "OPENSCENE_DATA_ROOT": str(DATASET_ROOT / "dataset/openscene-v1.1"),
        "NAVSIM_EXP_ROOT": str(DATASET_ROOT / "exp"),
        "NAVSIM_DEVKIT_ROOT": str(REPO_ROOT),
    }
)

# Model checkpoint — "your ckpt" is a placeholder file name to replace.
CHECKPOINT_PATH = REPO_ROOT / "your ckpt"
EXPERIMENT_NAME = "attention_map_viz"
AGENT_NAME = "drivoR"
# --- USER CONFIGURATION END ---

# Make the repository importable (appended once, so re-running the cell is harmless).
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))

# Navsim Imports (must happen after sys.path update)
import navsim.common.dataclasses as dc
from navsim.agents.abstract_agent import AbstractAgent
from navsim.common.dataloader import SceneLoader, SceneFilter, SensorConfig, MetricCacheLoader
from navsim.visualization.config import TRAJECTORY_CONFIG, CAMERAS_PLOT_CONFIG
from navsim.visualization.bev import add_configured_bev_on_ax, add_trajectory_to_bev_ax
from navsim.visualization.plots import configure_bev_ax
from navsim.common.dataclasses import Trajectory

# Sanity check: fail fast if the nuPlan maps directory configured above is missing.
assert Path(os.environ["NUPLAN_MAPS_ROOT"]).exists(), "Maps root not found: {}".format(
    os.environ["NUPLAN_MAPS_ROOT"]
)
print("Environment configured successfully.")

In [None]:
def load_agent_and_scenes(
    repo_path: Path,
    checkpoint_path: Path,
    agent_name: str,
    experiment_name: str,
    extra_agent_overrides: Optional[List[str]] = None,
) -> Tuple[AbstractAgent, SceneLoader, SceneLoader]:
    """
    Load the agent and two SceneLoaders (one for inference, one for visualization).

    Args:
        repo_path: Repository root; the Hydra config directories are located relative
            to it. The path is resolved to an absolute path first because Hydra's
            ``initialize_config_dir`` rejects relative directories.
        checkpoint_path: Passed to the agent config as the ``checkpoint_path`` override.
        agent_name: Config name under ``navsim/planning/script/config/common/agent``.
        experiment_name: Passed as the ``experiment_name`` override of the training config.
        extra_agent_overrides: Optional additional Hydra overrides for the agent config,
            appended after the built-in overrides below.

    Returns:
        ``(agent, inference_loader, viz_loader)``:
        - ``agent``: the instantiated and initialized agent.
        - ``inference_loader``: a SceneLoader that uses the agent's own sensor config.
        - ``viz_loader``: a SceneLoader that requests all eight cameras at index 3,
          with no lidar, for plotting.
    """
    repo_path = Path(repo_path).resolve()
    config_dir = repo_path / "navsim/planning/script/config/training"
    agent_config_dir = repo_path / "navsim/planning/script/config/common/agent"

    # Overrides for the training (split / experiment) config.
    overrides = [
        "train_test_split=navtrain",
        f"experiment_name={experiment_name}",
    ]

    # Overrides strictly for the agent config.
    agent_overrides = [
        f"checkpoint_path={str(checkpoint_path)}",
        "config.shared_refiner=false",
        "lr_args=null",
        "scheduler_args=null",
        "progress_bar=false",
        "config.refiner_ls_values=0.0",
        "config.image_backbone.focus_front_cam=false",
        "config.one_token_per_traj=true",
        "config.refiner_num_heads=1",
        "config.tf_d_model=256",
        "config.tf_d_ffn=1024",
        "config.area_pred=false",
        "config.agent_pred=false",
        "config.ref_num=4",
        "config.noc=1", "config.dac=1", "config.ddc=0",
        "config.ttc=5", "config.ep=5", "config.comfort=2",
        "loss.prev_weight=0.0",
        "batch_size=null",
    ]
    if extra_agent_overrides:
        agent_overrides.extend(extra_agent_overrides)

    # Clear Hydra global state left over from an earlier run of this cell before re-initializing.
    if GlobalHydra.instance().is_initialized():
        GlobalHydra.instance().clear()

    # 1. Initialize the agent.
    with initialize_config_dir(version_base=None, config_dir=str(agent_config_dir)):
        agent_cfg = compose(config_name=agent_name, overrides=agent_overrides)

    agent: AbstractAgent = instantiate(agent_cfg)
    agent.initialize()

    # Optional: attach the train metric cache, but only if it exists under NAVSIM_EXP_ROOT.
    cache_path = Path(os.environ["NAVSIM_EXP_ROOT"]) / "train_metric_cache"
    if cache_path.exists():
        agent.test_metric_cache_paths = MetricCacheLoader(cache_path).metric_cache_paths

    # Agent-specific flags set directly on the instance; their meaning is defined by
    # the agent implementation (not visible here).
    agent.b2d = False
    agent.ray = False

    # 2. Initialize the scene loaders.
    with initialize_config_dir(version_base=None, config_dir=str(config_dir)):
        cfg = compose(config_name="default_training", overrides=overrides)

    # Restrict the split's filter to the validation logs listed in the training config.
    scene_filter: SceneFilter = instantiate(cfg.train_test_split.scene_filter)
    scene_filter.log_names = cfg.val_logs

    # Loader for model inference (uses whatever sensors the agent asks for).
    inference_loader = SceneLoader(
        sensor_blobs_path=Path(cfg.sensor_blobs_path),
        data_path=Path(cfg.navsim_log_path),
        scene_filter=scene_filter,
        sensor_config=agent.get_sensor_config(),
    )

    # Loader for visualization: camera images at index 3 for all eight cameras, no lidar.
    viz_sensor_config = SensorConfig(
        cam_f0=[3], cam_l0=[3], cam_l1=[3], cam_l2=[3],
        cam_r0=[3], cam_r1=[3], cam_r2=[3], cam_b0=[3],
        lidar_pc=False,
    )
    viz_loader = SceneLoader(
        sensor_blobs_path=Path(cfg.sensor_blobs_path),
        data_path=Path(cfg.navsim_log_path),
        scene_filter=scene_filter,
        sensor_config=viz_sensor_config,
    )

    return agent, inference_loader, viz_loader

print("Loading model...")
agent, inference_loader, viz_loader = load_agent_and_scenes(
    repo_path=REPO_ROOT,
    checkpoint_path=CHECKPOINT_PATH,
    agent_name=AGENT_NAME,
    experiment_name=EXPERIMENT_NAME,
)
print("Model loaded.")

In [None]:
def avg_heads(A):
    """Average attention over the head axis, then re-normalize each row to sum to 1.

    (B, H, N, N) -> (B, N, N). A tiny epsilon guards against all-zero rows.
    """
    head_mean = torch.mean(A, dim=1)
    row_totals = head_mean.sum(dim=-1, keepdim=True)
    return head_mean / (row_totals + 1e-9)

@torch.no_grad()
def reg_to_patch_last_layer(attn_last, reg_idx, patch_idx, H_p, W_p):
    """Register->patch maps taken directly from one (the final) attention layer.

    Heads are averaged with ``avg_heads``. The register rows and patch columns are
    then selected, and the result is laid out on the H_p x W_p patch grid.
    Returns a tensor of shape (B, R, H_p, W_p).
    """
    layer_attn = avg_heads(attn_last)          # (B, N, N)
    reg_rows = layer_attn[:, reg_idx]          # (B, R, N)
    reg_to_patch = reg_rows[:, :, patch_idx]   # (B, R, P)
    n_batch, n_reg = reg_to_patch.shape[0], reg_to_patch.shape[1]
    return reg_to_patch.view(n_batch, n_reg, H_p, W_p)

@torch.no_grad()
def reg_to_patch_shallow_rollout(attn_list, reg_idx, patch_idx, H_p, W_p, last_k=3, alpha=0.9):
    """
    Attention rollout over the last ``last_k`` layers, then slice register->patch.

    Each selected layer's head-averaged attention M_l is mixed with the identity,
    ``alpha * M_l + (1 - alpha) * I``, to account for the residual path. The layers
    are then composed with the deeper layer on the LEFT:

        R = M_L @ M_{L-1} @ ... @ M_{L-k+1}

    so that row i of R describes how the output of token i at the last layer draws
    on the tokens entering the first selected layer. The previous implementation
    multiplied in the opposite order (M_{L-k+1} @ ... @ M_L), which does not
    follow the information flow.

    Args:
        attn_list: Sequence of per-layer attention tensors, each (B, H, N, N). The
            final entry is treated as the last layer, as elsewhere in this notebook.
        reg_idx: Indices of the register (query) tokens, R of them.
        patch_idx: Indices of the patch (key) tokens; there must be H_p * W_p of them.
        H_p, W_p: Patch-grid height and width.
        last_k: Number of trailing layers to roll out (must be >= 1).
        alpha: Weight of the attention matrix versus the identity/residual term.

    Returns:
        Tensor (B, R, H_p, W_p) of row-normalized rollout weights.

    Raises:
        ValueError: If ``last_k < 1`` or ``attn_list`` is empty.
    """
    if last_k < 1:
        # attn_list[-0:] would silently select every layer.
        raise ValueError(f"last_k must be >= 1, got {last_k}")
    layers = list(attn_list)[-last_k:]
    if not layers:
        raise ValueError("attn_list is empty")

    rollout = None
    for A in layers:  # shallow -> deep
        M = avg_heads(A)
        I = torch.eye(M.size(-1), device=M.device, dtype=M.dtype).expand_as(M)
        M = alpha * M + (1 - alpha) * I
        # The deeper layer multiplies on the left: R <- M_l @ R.
        rollout = M if rollout is None else M @ rollout

    rollout = rollout / (rollout.sum(-1, keepdim=True) + 1e-9)
    V = rollout[:, reg_idx][:, :, patch_idx]
    return V.view(V.shape[0], V.shape[1], H_p, W_p)

@torch.no_grad()
def reg_to_patch_low_entropy_head(attn_last, reg_idx, patch_idx, H_p, W_p):
    """
    Build register->patch maps from the single sharpest (lowest-entropy) head.

    Each register row is restricted to the patch columns and re-normalized. Then, for
    every (batch, register) pair, the head whose distribution has the smallest Shannon
    entropy is chosen. Its distribution is laid out on the H_p x W_p grid.

    Args:
        attn_last: Attention tensor (B, H, N, N).
        reg_idx: Indices of the register query tokens (R of them).
        patch_idx: Indices of the patch key tokens; there must be H_p * W_p of them.
        H_p, W_p: Patch-grid height and width.

    Returns:
        Tensor (B, R, H_p, W_p) with the dtype and device of ``attn_last``.
    """
    batch = attn_last.shape[0]
    num_regs = len(reg_idx)

    # (B, H, R, P): per-head register rows restricted to the patch columns, re-normalized.
    per_head = attn_last[:, :, reg_idx][:, :, :, patch_idx]
    per_head = per_head / (per_head.sum(-1, keepdim=True) + 1e-9)

    # Entropy per (B, H, R); clamping keeps log() finite where entries are zero.
    safe = per_head.clamp_min(1e-12)
    entropy = -(safe * safe.log()).sum(-1)

    # Pick the lowest-entropy head for each (B, R) and gather it along the head axis.
    winner = entropy.argmin(1)                                            # (B, R)
    gather_idx = winner[:, None, :, None].expand(-1, 1, -1, per_head.shape[-1])
    sharpest = per_head.gather(1, gather_idx).squeeze(1)                  # (B, R, P)

    return sharpest.reshape(batch, num_regs, H_p, W_p)

In [None]:
def plot_trajectories_by_camera(proposals, A_tr2reg, ego_yaw=0.0, title="Trajectory Attention"):
    """
    Plot decoded trajectory proposals, each colored by the camera its attention favors most.

    Args:
        proposals: Tensor (T, num_poses, >=2). Only the first two channels are plotted.
        A_tr2reg: Trajectory->register attention, either 3-D (T, 4, regs_per_cam) or
            2-D (T, 4 * regs_per_cam). In the 2-D case the registers are split into 4
            equal, contiguous camera groups. The group order is assumed to match
            ``camera_labels`` below — TODO confirm against the model's register layout.
        ego_yaw: Rotation angle in radians applied to the trajectories and the ego box.
        title: Figure title.

    Each trajectory's line opacity grows with the share of attention on its dominant camera.
    """
    camera_labels = ["Front", "Back", "Left", "Right"]
    camera_colors = ["#E41A1C", "#377EB8", "#4DAF4A", "#984EA3"]

    # 1. Compute per-camera attention strength.
    # detach/float so the later .numpy() calls also work for grad-tracking or half tensors.
    attn = A_tr2reg.detach().float()
    if attn.ndim == 2:
        # Infer the number of trajectories from the tensor instead of assuming 64.
        attn = attn.reshape(attn.shape[0], len(camera_labels), -1)
    cam_strength = attn.sum(-1)
    cam_strength = cam_strength / (cam_strength.sum(-1, keepdim=True) + 1e-9)
    dominant_cam = cam_strength.argmax(dim=-1).cpu().numpy()
    cam_confidence = cam_strength.max(-1).values.cpu().numpy()

    # 2. Rotate coordinates by ego_yaw.
    rot = np.array([
        [np.cos(ego_yaw), -np.sin(ego_yaw)],
        [np.sin(ego_yaw),  np.cos(ego_yaw)],
    ])
    xy = proposals[..., :2].detach().float().cpu().numpy() @ rot.T

    # 3. Plot.
    colors = np.array(camera_colors)[dominant_cam]
    fig, ax = plt.subplots(figsize=(6, 6))

    for i in range(xy.shape[0]):
        x, y = xy[i, :, 0], xy[i, :, 1]
        # Matplotlib requires alpha in [0, 1].
        alpha = float(np.clip(0.4 + 0.6 * cam_confidence[i], 0.0, 1.0))
        ax.plot(x, y, color=colors[i], alpha=alpha, lw=2, zorder=2)
        ax.scatter(x[-1], y[-1], color=colors[i], s=10, zorder=3)

    ego_shape = np.array([[-1.0, -0.5], [1.0, -0.5], [1.0, 0.5], [-1.0, 0.5]]) @ rot.T
    ax.add_patch(plt.Polygon(ego_shape, closed=True, color="gray", alpha=0.5, zorder=4))

    handles = [plt.Line2D([0], [0], color=c, lw=3, label=l) for c, l in zip(camera_colors, camera_labels)]
    ax.legend(handles=handles, title="Dominant Camera", loc="upper right")

    pad = 5.0
    ax.set_xlim(xy[..., 0].min() - pad, xy[..., 0].max() + pad)
    ax.set_ylim(xy[..., 1].min() - pad, xy[..., 1].max() + pad)
    ax.set_aspect('equal')
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title(title)
    plt.tight_layout()
    plt.show()

def visualize_attention_maps(reg_maps, cams, R=16, H_img=672, W_img=1148, patch_size=14):
    """
    Overlay per-register attention maps on their camera images (one 4x4 figure per view).

    Args:
        reg_maps: Tensor (num_views, num_regs, H_p, W_p). View ``b`` is paired with the
            b-th entry of ``cams``.
        cams: Ordered mapping title -> image. Tensor images are treated as CHW and
            converted to HWC numpy. Images whose maximum exceeds 1.1 are treated as
            0-255 and rescaled to [0, 1].
        R: Maximum number of registers to show. It is further capped at 16 (the grid
            size) and at the number of registers present in ``reg_maps``.
        H_img, W_img: Size the image is resized to before cropping.
        patch_size: Pixel size of one patch. The resized image is cropped to
            (H_p * patch_size, W_p * patch_size) so it lines up with the patch grid.
    """
    H_p, W_p = reg_maps.shape[-2:]
    H_crop, W_crop = H_p * patch_size, W_p * patch_size
    # The 4x4 grid fits at most 16 panels; also never index past the available registers.
    n_show = min(R, reg_maps.shape[-3], 16)

    for b, (title, cam_img) in enumerate(cams.items()):
        # float32 on CPU so .numpy() and cv2 accept the maps regardless of model dtype.
        maps_b = reg_maps[b].detach().float().cpu()  # (R, H_p, W_p)

        if torch.is_tensor(cam_img):
            cam_img = cam_img.detach().cpu().permute(1, 2, 0).numpy()
        if cam_img.max() > 1.1:
            cam_img = cam_img / 255.0

        img_resized = cv2.resize(cam_img, (W_img, H_img), interpolation=cv2.INTER_AREA)
        # Crop to the patch-aligned region and clip once, outside the panel loop.
        img_display = np.clip(img_resized[:H_crop, :W_crop], 0, 1)

        plt.figure(figsize=(12, 12))
        for i in range(n_show):
            plt.subplot(4, 4, i + 1)
            hm = maps_b[i].numpy()
            hm = (hm - hm.min()) / (np.ptp(hm) + 1e-9)  # per-register min-max scaling
            hm_resized = cv2.resize(hm, (W_crop, H_crop), interpolation=cv2.INTER_NEAREST)

            plt.imshow(img_display)
            plt.imshow(hm_resized, alpha=0.5, cmap='magma')
            plt.title(f"{title} Reg {i}")
            plt.axis('off')

        plt.suptitle(f"{title.upper()} — Register Attention Maps", fontsize=14)
        plt.tight_layout()
        plt.show()

def plot_patch_energy(reg_maps, cams):
    """
    Show, per view, the register-averaged attention over the patch grid.

    Patches that draw attention from many registers ("universal attractors") appear as
    hot spots. View ``b`` is the b-th key of ``cams`` and is read from ``reg_maps[b]``
    (registers on the first axis). Only the keys of ``cams`` are used, as titles.
    """
    titles = list(cams.keys())
    for b, title in enumerate(titles):
        energy = reg_maps[b].mean(dim=0).cpu()  # (H_p, W_p)
        fig, ax = plt.subplots(figsize=(6, 5))
        im = ax.imshow(energy, cmap='hot')
        ax.set_title(f"{title.upper()} — Patch Attention Energy")
        fig.colorbar(im, ax=ax, label="Mean attention")
        fig.tight_layout()
        plt.show()

def plot_register_similarity(reg_maps, cams):
    """
    Plot the pairwise cosine similarity between registers' patch maps, per view.

    View ``b`` is the b-th key of ``cams`` and is read from ``reg_maps[b]``, shape
    (R, H_p, W_p). Each register map is flattened and L2-normalized, so entry (i, j)
    of the plotted matrix is the cosine similarity of registers i and j.

    The function also prints the mean similarity over distinct register pairs (i < j).
    The value is NaN when there are fewer than two registers.
    """
    for b, title in enumerate(cams):
        rf = reg_maps[b].detach().float().flatten(1)    # (R, P)
        rf = rf / (rf.norm(dim=-1, keepdim=True) + 1e-9)
        sim = rf @ rf.T                                  # (R, R)

        # Average over the strict upper triangle only. sim.triu(1).mean() divides by
        # R*R, so it also averages in the zeroed diagonal and lower triangle and
        # understates the similarity.
        n_reg = sim.shape[0]
        iu = torch.triu_indices(n_reg, n_reg, offset=1, device=sim.device)
        mean_offdiag = sim[iu[0], iu[1]].mean().item()

        plt.figure(figsize=(4.5, 4))
        plt.imshow(sim.cpu(), cmap='viridis', vmin=0, vmax=1)
        plt.title(f"{title.upper()} — Reg↔Reg Cosine")
        plt.colorbar(label="cosine sim")
        plt.xlabel("Reg j")
        plt.ylabel("Reg i")
        plt.tight_layout()
        plt.show()
        print(f"{title} mean inter-register similarity: {mean_offdiag:.4f}")

def plot_layer_divergence(attentions_list, cams, R=16):
    """
    Plot how register specialization (mean pairwise cosine similarity) changes across layers.

    For each view ``b`` (the b-th key of ``cams``, read from batch index ``b`` of every
    attention tensor) and each layer:
    - the attention is averaged over heads;
    - the rows of the first ``R`` tokens are taken (assumed to be the registers —
      TODO confirm the token layout);
    - these rows are compared pairwise with cosine similarity;
    - the mean over distinct pairs (i < j) is plotted against layer index.

    Lower values mean the registers attend to more dissimilar token distributions.

    Args:
        attentions_list: Sequence of per-layer attention tensors, each (B, H, N, N).
        cams: Ordered mapping whose keys are used as per-view titles.
        R: Number of leading tokens treated as registers.
    """
    for b, title in enumerate(cams):
        divergence = []
        for A in attentions_list:  # A: (B, H, N, N)
            regs = A[b].mean(0)[:R, :]  # head-averaged rows of the first R tokens
            sim = F.cosine_similarity(regs[:, None, :], regs[None, :, :], dim=-1)
            # Mean over the strict upper triangle; sim.triu(1).mean() would divide by
            # R*R and average in the zeroed diagonal/lower triangle.
            n = sim.shape[0]
            iu = torch.triu_indices(n, n, offset=1, device=sim.device)
            divergence.append(sim[iu[0], iu[1]].mean().item())

        plt.figure(figsize=(5, 3))
        plt.plot(divergence, marker='o')
        plt.xlabel("Layer index")
        plt.ylabel("Mean inter-register cosine sim")
        plt.title(f"{title.upper()} — Register Specialization vs Depth")
        plt.grid(True)
        plt.tight_layout()
        plt.show()

In [None]:
# 1. Select a scene.
# Set SCENE_TOKEN to a specific token (e.g. "02abc6b6508f5516") to pin a scene.
# Leave it as None to draw one at random with a seeded RNG, so re-runs pick the same scene.
SCENE_TOKEN: Optional[str] = None
SCENE_SEED = 0

if SCENE_TOKEN is None:
    scene_rng = np.random.default_rng(SCENE_SEED)
    token = str(scene_rng.choice(inference_loader.tokens))
else:
    token = SCENE_TOKEN
print(f"Visualizing Token: {token}")

scene = viz_loader.get_scene_from_token(token)
agent_input = inference_loader.get_agent_input_from_token(token)

# 2. Run inference on this single sample.
agent.eval()
features = {}
for builder in agent.get_feature_builders():
    features.update(builder.compute_features(agent_input))
features = {k: v.unsqueeze(0) for k, v in features.items()}  # add batch dim

with torch.no_grad():
    predictions = agent.forward(features)

# 3. Extract outputs.
img_attentions = predictions["image_backbone_attentions"]
traj_attentions = predictions["trajectory_attentions"][1]
proposals = predictions["proposals"][0]  # (64, num_poses, 3)

# 4. Token layout of the image-backbone sequence (architecture constants).
R = 16              # number of register tokens, taken as sequence positions 0..R-1
prefix = 21         # 16 regs + 5 specials precede the patch tokens
H_p, W_p = 48, 82   # patch grid
N_total = img_attentions[0].shape[-1]
# Fail early with a readable message if the layout does not match the model. Otherwise
# the .view()/.reshape() in the aggregation helpers fail with an opaque shape error.
assert N_total - prefix == H_p * W_p, (
    f"Expected {H_p}x{W_p}={H_p * W_p} patch tokens after a prefix of {prefix}, "
    f"but the attention has {N_total} tokens ({N_total - prefix} after the prefix)."
)

attn_device = img_attentions[0].device
reg_idx = torch.arange(0, R, device=attn_device)
patch_idx = torch.arange(prefix, N_total, device=attn_device)

# 5. Generate attention maps. Options: "last", "shallow", "besthead".
aggregation = "besthead"

if aggregation == "last":
    reg_maps = reg_to_patch_last_layer(img_attentions[-1], reg_idx, patch_idx, H_p, W_p)
elif aggregation == "shallow":
    reg_maps = reg_to_patch_shallow_rollout(img_attentions, reg_idx, patch_idx, H_p, W_p)
elif aggregation == "besthead":
    reg_maps = reg_to_patch_low_entropy_head(img_attentions[-1], reg_idx, patch_idx, H_p, W_p)
else:
    # Do not silently fall back to another method on a typo.
    raise ValueError(
        f"Unknown aggregation {aggregation!r}; expected 'last', 'shallow' or 'besthead'."
    )

# 6. Visualize. The visualization SensorConfig requests camera images at index [3],
# so the same frame index is used here.
FRAME_IDX = 3
frame_cameras = scene.frames[FRAME_IDX].cameras
cams = {
    "cam_f0": frame_cameras.cam_f0.image,
    "cam_b0": frame_cameras.cam_b0.image,
    "cam_l0": frame_cameras.cam_l0.image,
    "cam_r0": frame_cameras.cam_r0.image,
}
# NOTE(review): visualize_attention_maps pairs the b-th entry of `cams` with reg_maps[b],
# so this order is assumed to match the camera order in the backbone batch — TODO confirm.

print(f"Visualizing Method: {aggregation}")
visualize_attention_maps(reg_maps, cams)

print("Visualizing Trajectory-Camera Alignment")
A_tr2reg = traj_attentions[-1][0]
plot_trajectories_by_camera(proposals, A_tr2reg)