In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoProcessor, AutoModel
import numpy as np
from pathlib import Path
import hydra
from hydra.utils import instantiate
from PIL import Image
from tqdm import tqdm # For progress bars

# Assuming these imports are correct based on your environment
from navsim.common.dataloader import SceneLoader
from navsim.common.dataclasses import SceneFilter, SensorConfig, Scene, Camera, EgoStatus, Trajectory
from navsim.planning.simulation.planner.pdm_planner.utils.pdm_geometry_utils import convert_absolute_to_relative_se2_array
from planning_agent.PlanningHead import PlanningHead
from planning_agent.NavsimTrajectoryDataset import NavsimTrajectoryDataset, collate_fn_skip_none



In [2]:
SPLIT = "mini"
FILTER = "all_scenes"
OPENSCENE_DATA_ROOT = Path("../dataset") # Adjust if necessary
NUPLAN_MAPS_ROOT = OPENSCENE_DATA_ROOT / "nuplan-maps-v1.0" # Assuming maps are here
NUM_HISTORY_FRAMES = 4 # As per SceneFilter default, adjust if needed
NUM_FUTURE_FRAMES = 8 # Predicting 4 seconds at 0.5s interval = 8 poses

# --- Set up Data Loading ---
print("Setting up SceneLoader...")
hydra.initialize(config_path="../navsim/navsim/planning/script/config/common/train_test_split/scene_filter", version_base=None)
cfg = hydra.compose(config_name=FILTER)
scene_filter: SceneFilter = instantiate(cfg)
# Ensure SceneFilter matches desired history/future frames
scene_filter.num_history_frames = NUM_HISTORY_FRAMES
scene_filter.num_future_frames = NUM_FUTURE_FRAMES

# Correct paths based on previous debugging
navsim_log_path = OPENSCENE_DATA_ROOT / f"{SPLIT}_navsim_logs" / SPLIT
sensor_blob_path = OPENSCENE_DATA_ROOT / f"{SPLIT}_sensor_blobs" / "sensor_blobs" / SPLIT



Setting up SceneLoader...


In [3]:
# Specify only the sensors needed (saves memory/time)
# Just need front camera for this setup
sensor_config = SensorConfig(
    cam_f0=True, # Only need front camera image
    cam_l0=False, cam_l1=False, cam_l2=False,
    cam_r0=False, cam_r1=False, cam_r2=False,
    cam_b0=False, lidar_pc=False
)


scene_loader = SceneLoader(
    data_path=navsim_log_path,
    original_sensor_path=sensor_blob_path,
    scene_filter=scene_filter,
    synthetic_sensor_path=None, # Not using synthetic data here
    synthetic_scenes_path=None,
    sensor_config=sensor_config, # Use specific config
)
print(f"Loaded {len(scene_loader)} scenes for split '{SPLIT}'.")

Loading logs: 100%|██████████| 64/64 [00:09<00:00,  6.64it/s]

Loaded 3623 scenes for split 'mini'.





In [4]:
# --- Configuration ---
MODEL_ID = "facebook/ijepa_vith14_1k" # Using ViT-H/14 as loaded before

# Training Hyperparameters
LEARNING_RATE = 1e-4
EPOCHS = 1 # Adjust as needed
BATCH_SIZE = 64 # Adjust based on GPU memory
NUM_HISTORY_FRAMES = 4 # As per SceneFilter default, adjust if needed
NUM_FUTURE_FRAMES = 8 # Predicting 4 seconds at 0.5s interval = 8 poses

# I-JEPA Output Dimension (for ViT-H/14)
IJEP_DIM = 1280 # ViT-H/14 hidden size is 1280
# If you were using ViT-B/16, this would be 768

# Ego Status Dimension (vel_x, vel_y, acc_x, acc_y, cmd_left, cmd_straight, cmd_right)
EGO_DIM = 2 + 2 + 4 # Now 8 dimensions

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# --- Load Pre-trained I-JEPA Model ---
print(f"Loading I-JEPA processor and model: {MODEL_ID}")
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)
ijepa_encoder = AutoModel.from_pretrained(MODEL_ID).to(DEVICE)

# IMPORTANT: Freeze the I-JEPA encoder parameters
for param in ijepa_encoder.parameters():
    param.requires_grad = False
ijepa_encoder.eval() # Set to evaluation mode
print("I-JEPA model loaded and frozen.")


Using device: cuda
Loading I-JEPA processor and model: facebook/ijepa_vith14_1k
I-JEPA model loaded and frozen.


In [5]:

print("Creating Dataset and DataLoader...")
dataset = NavsimTrajectoryDataset(scene_loader, processor, NUM_HISTORY_FRAMES, NUM_FUTURE_FRAMES, DEVICE)
# Use the custom collate_fn
dataloader = DataLoader(dataset, 
                        batch_size=BATCH_SIZE, 
                        shuffle=True, 
                        num_workers=8, 
                        pin_memory=True, 
                        drop_last=True, 
                        collate_fn=collate_fn_skip_none,
                        persistent_workers=True)
print("DataLoader created.")

Creating Dataset and DataLoader...
DataLoader created.


In [6]:
NUM_FUTURE_FRAMES = 8 # Predicting 4 seconds at 0.5s interval = 8 poses
IJEP_DIM = 1280 # ViT-H/14 hidden size is 1280
EGO_DIM = 2 + 2 + 4 # Now 8 dimensions
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
HIDDEN_DIM = 256
OUTPUT_DIM = NUM_FUTURE_FRAMES*3

mlp_head = PlanningHead(ijep_dim=IJEP_DIM, ego_dim=EGO_DIM, hidden_dim=HIDDEN_DIM, output_dim=OUTPUT_DIM).to(DEVICE)
print("MLP Planning Head defined.")

MLP Planning Head defined.


In [7]:
history = mlp_head.fit(
    dataloader=dataloader,
    ijepa_encoder=ijepa_encoder,
    device=DEVICE,
    epochs=EPOCHS,
    lr=LEARNING_RATE,
    resume_from=None,              # or path to resume checkpoint
    checkpoint_interval=1,         # save every epoch
    use_cls_token=True            # or True if you want CLS token
)

print("Training finished. Avg losses per epoch:", history)

Epoch 1/1:   0%|          | 0/56 [00:00<?, ?it/s]ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
                                                 

Interrupted at epoch 1, saved ./checkpoint_failure_epoch1.pth


RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/envs/navsim/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/opt/conda/envs/navsim/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
  File "/navsim_workspace/code/planning_agent/NavsimTrajectoryDataset.py", line 115, in collate_fn_skip_none
    return torch.utils.data.dataloader.default_collate(batch)
  File "/opt/conda/envs/navsim/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 398, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/opt/conda/envs/navsim/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 211, in collate
    return [
  File "/opt/conda/envs/navsim/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 212, in <listcomp>
    collate(samples, collate_fn_map=collate_fn_map)
  File "/opt/conda/envs/navsim/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 155, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/opt/conda/envs/navsim/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 270, in collate_tensor_fn
    storage = elem._typed_storage()._new_shared(numel, device=elem.device)
  File "/opt/conda/envs/navsim/lib/python3.9/site-packages/torch/storage.py", line 1198, in _new_shared
    untyped_storage = torch.UntypedStorage._new_shared(
  File "/opt/conda/envs/navsim/lib/python3.9/site-packages/torch/storage.py", line 415, in _new_shared
    return cls._new_using_fd_cpu(size)
RuntimeError: unable to write to file </torch_57109_1613807916_0>: No space left on device (28)
