# Main simulator loop

## Imports, config

In [1]:
import os
from pathlib import Path
from path_utils import use_path

# Anchor every path at this file's directory when __file__ is defined;
# otherwise (e.g. an interactive session) fall back to the working directory.
if "__file__" in globals():
    base_dir = Path(__file__).resolve().parent
else:
    base_dir = Path.cwd()

# Dataset config and trained checkpoint, both resolved against base_dir
config_path = str(base_dir.joinpath("configs/datasets/nuplan/8cams_undistorted.yaml"))
checkpoint_path = str(base_dir.joinpath("output/master-project/run_omnire_undistorted_8cams_0"))

# Step count setting (not referenced by the code visible in this part of the notebook)
n_steps = 100

## Initialize simulator, environment model and the random agent

In [2]:
"""
This script initializes a NuPlan simulator and
provides methods to get the current state of the simulation
and perform actions based on a given trajectory.
It uses the NuPlan database and maps to create a simulation environment.
It also includes a function to create a waypoint from a given point.
"""

import os
from state_types import State, Position
from omegaconf import OmegaConf
from nuplan.common.actor_state.oriented_box import OrientedBox
from nuplan.common.actor_state.vehicle_parameters import get_pacifica_parameters
from nuplan.common.actor_state.state_representation import StateSE2
from nuplan.common.actor_state.waypoint import Waypoint
from nuplan.planning.scenario_builder.nuplan_db.nuplan_scenario_builder import NuPlanScenarioBuilder
from nuplan.planning.scenario_builder.scenario_filter import ScenarioFilter
from nuplan.planning.simulation.observation.tracks_observation import TracksObservation
from nuplan.planning.simulation.simulation_time_controller.step_simulation_time_controller import (
    StepSimulationTimeController,
)
from nuplan.planning.simulation.simulation import Simulation
from nuplan.planning.simulation.simulation_setup import SimulationSetup
from nuplan.planning.script.builders.worker_pool_builder import build_worker
from nuplan.planning.simulation.controller.perfect_tracking import PerfectTrackingController
from nuplan.planning.simulation.trajectory.interpolated_trajectory import InterpolatedTrajectory
from nuplan.common.actor_state.dynamic_car_state import DynamicCarState
from nuplan.common.actor_state.state_representation import StateVector2D
from nuplan.common.actor_state.ego_state import EgoState
from nuplan.common.actor_state.ego_state import CarFootprint


# Dataset locations. Each of the first four values can be overridden through
# the environment variable of the same name; the literal is the fallback.
NUPLAN_DATA_ROOT = os.environ.get('NUPLAN_DATA_ROOT', '/data/sets/nuplan')
NUPLAN_MAPS_ROOT = os.environ.get('NUPLAN_MAPS_ROOT', '/data/sets/nuplan/maps')
# NOTE(review): NUPLAN_DB_FILES is not referenced by the code shown in this cell.
NUPLAN_DB_FILES = os.environ.get('NUPLAN_DB_FILES', '/data/sets/nuplan/nuplan-v1.1/splits/mini')
NUPLAN_MAP_VERSION = os.environ.get('NUPLAN_MAP_VERSION', 'nuplan-maps-v1.0')

# Paths derived from the data root
NUPLAN_SENSOR_ROOT = f"{NUPLAN_DATA_ROOT}/nuplan-v1.1/sensor_blobs"
# The single log database this notebook simulates
DB_FILE = f"{NUPLAN_DATA_ROOT}/nuplan-v1.1/splits/mini/2021.05.12.22.28.35_veh-35_00620_01164.db"
MAP_NAME = "us-nv-las-vegas"

class Simulator:
    """Interface every simulator back end implements: read state, apply action."""

    def __init__(self):
        """Nothing to set up at the base level."""

    def get_state(self):
        """Return the current simulation state; subclasses must override."""
        message = "This method should be overridden in subclasses."
        raise NotImplementedError(message)

    def do_action(self, action):
        """Apply ``action`` to the simulation; subclasses must override."""
        message = "This method should be overridden in subclasses."
        raise NotImplementedError(message)

class NuPlan(Simulator):
    """
    NuPlan simulator class that initializes the NuPlan simulation environment.
    It uses the NuPlan database and maps to create a simulation environment.
    It provides methods to get the current state of the simulation and perform
    actions based on a given trajectory.
    """

    def __init__(self):
        """
        Build and initialize the NuPlan simulation for the log in ``DB_FILE``.

        Loads the first scenario matching the log, wires it to a
        StepSimulationTimeController, a TracksObservation and a
        PerfectTrackingController, initializes the simulation, and stores
        the initial ego state, the initial observation and the ego vehicle's
        oriented box on the instance.

        :raises IndexError: If the scenario builder returns no scenario.
        """
        super().__init__()
        print("Initializing NuPlan simulator...")
        scenario_builder = NuPlanScenarioBuilder(
            data_root=NUPLAN_DATA_ROOT,
            map_root=NUPLAN_MAPS_ROOT,
            sensor_root=NUPLAN_SENSOR_ROOT,
            db_files=[DB_FILE],
            map_version=NUPLAN_MAP_VERSION,
            vehicle_parameters=get_pacifica_parameters(),
            include_cameras=False,
            verbose=True,
        )

        # For the configured DB_FILE the log name equals the file name without
        # its ".db" suffix. Deriving it here keeps the filter in sync with the
        # database that is actually loaded instead of repeating the literal.
        log_name = os.path.splitext(os.path.basename(DB_FILE))[0]

        scenario_filter = ScenarioFilter(
            log_names=[log_name],
            scenario_types=None,
            scenario_tokens=None,
            map_names=None,
            num_scenarios_per_type=None,
            limit_total_scenarios=None,
            timestamp_threshold_s=None,
            ego_displacement_minimum_m=None,
            expand_scenarios=False,
            remove_invalid_goals=False,
            shuffle=False,
        )

        # Scenario extraction runs in-process through the Sequential worker
        worker_config = OmegaConf.create({
            'worker': {
                '_target_': 'nuplan.planning.utils.multithreading.worker_sequential.Sequential',
            }
        })
        worker = build_worker(worker_config)

        scenarios = scenario_builder.get_scenarios(scenario_filter, worker)
        if not scenarios:
            # Same exception type as the former bare [0] lookup, but with context
            raise IndexError(f"No scenario found for log '{log_name}' in {DB_FILE}")
        scenario = scenarios[0]

        simulation_setup = SimulationSetup(
            time_controller=StepSimulationTimeController(scenario),
            observations=TracksObservation(scenario),
            ego_controller=PerfectTrackingController(scenario),
            scenario=scenario,
        )

        self.simulation = Simulation(
            simulation_setup=simulation_setup,
            callback=None,
            simulation_history_buffer_duration=2.0,
        )
        self.simulation.initialize()

        # Remember where the simulation started
        history = self.simulation.get_planner_input().history
        self.original_ego_state, self.original_observation_state = history.current_state

        print(self.original_ego_state)

        # Box dimensions reused by create_waypoint_from_point
        self.ego_vehicle_oriented_box = self.original_ego_state.waypoint.oriented_box

        print("NuPlan initialized.")

    def get_state(self) -> State:
        """
        Snapshot the simulation's current state.

        :return: State holding the ego position, one Position per tracked
            agent, and the ego waypoint's time point.
        """
        current_ego, current_observation = (
            self.simulation.get_planner_input().history.current_state
        )
        ego_waypoint = current_ego.waypoint

        ego_position = Position(
            x=ego_waypoint.center.x,
            y=ego_waypoint.center.y,
            z=606.740,  # fixed value; noted by the original author as the camera height
            heading=ego_waypoint.heading,
        )

        # Every tracked agent is reported at a fixed z of 2
        vehicle_positions = []
        for tracked in current_observation.tracked_objects.get_agents():
            agent_center = tracked.center
            vehicle_positions.append(
                Position(
                    x=agent_center.x,
                    y=agent_center.y,
                    z=2,
                    heading=agent_center.heading,
                )
            )

        return State(
            ego_pos=ego_position,
            vehicle_pos_list=vehicle_positions,
            timestamp=ego_waypoint.time_point,
        )

    def do_action(self, action):
        """
        Propagate the simulation along the given trajectory.

        :param action: Iterable of waypoints; converted with
            create_interpolated_trajectory before being handed to the simulation.
        """
        self.simulation.propagate(self.create_interpolated_trajectory(action))

    def create_interpolated_trajectory(self, trajectory):
        """
        Create an interpolated trajectory from a given trajectory.

        Each waypoint is converted to an EgoState that keeps the waypoint's
        time point, box center and velocity. Acceleration and tire steering
        angle are set to zero for simplicity.

        :param trajectory: Iterable of waypoints exposing ``time_point``,
            ``oriented_box`` and ``velocity``.
        :return: InterpolatedTrajectory built from the converted ego states.
        """
        vehicle_parameters = get_pacifica_parameters()
        # NOTE(review): the rear-axle-to-center distance is taken from
        # cog_position_from_rear_axle; confirm this is the intended value
        # (nuPlan's VehicleParameters presumably also offers rear_axle_to_center).
        rear_axle_to_center_dist = vehicle_parameters.cog_position_from_rear_axle

        ego_states = []
        for waypoint in trajectory:
            velocity = waypoint.velocity

            dynamic_car_state = DynamicCarState(
                rear_axle_to_center_dist=rear_axle_to_center_dist,
                rear_axle_velocity_2d=StateVector2D(velocity.x, velocity.y),
                rear_axle_acceleration_2d=StateVector2D(0, 0),  # no acceleration assumed
            )
            car_footprint = CarFootprint(
                center=waypoint.oriented_box.center,
                vehicle_parameters=vehicle_parameters,
            )

            ego_states.append(
                EgoState(
                    time_point=waypoint.time_point,
                    car_footprint=car_footprint,
                    dynamic_car_state=dynamic_car_state,
                    tire_steering_angle=0.0,  # no steering assumed
                    is_in_auto_mode=True,
                )
            )

        return InterpolatedTrajectory(ego_states)

    def create_waypoint_from_point(self, point):
        """
        Turn a point into a Waypoint that carries the ego vehicle's box size.

        :param point: Object exposing ``x``, ``y``, ``yaw``, ``time_point``
            and ``velocity``.
        :return: Waypoint at the point's pose with the dimensions of the
            oriented box captured when the simulator was initialized.
        """
        ego_box = self.ego_vehicle_oriented_box
        box_at_point = OrientedBox(
            StateSE2(point.x, point.y, point.yaw),
            width=ego_box.width,
            length=ego_box.length,
            height=ego_box.height,
        )
        return Waypoint(
            time_point=point.time_point,
            oriented_box=box_at_point,
            velocity=point.velocity,
        )

# Create the simulator used by the rest of the notebook; the constructor
# loads the scenario and initializes the simulation (see NuPlan.__init__).
simulator = NuPlan()

Initializing NuPlan simulator...
<nuplan.common.actor_state.ego_state.EgoState object at 0x7f92f8c87d90>
NuPlan initialized.


In [3]:
import numpy as np
import torch

from state_types import State
from setup import OmniReSetup

# Homogeneous 4x4 axis permutation that compute_camera_matrix right-multiplies
# onto the camera pose. Read column by column it maps camera x -> -y,
# camera y -> -z and camera z -> +x; the translation part is untouched.
OPENCV2DATASET = np.array(
    (
        (0, 0, 1, 0),
        (-1, 0, 0, 0),
        (0, -1, 0, 0),
        (0, 0, 0, 1),
    ),
    dtype=np.float32,
)

class OmniReModel:
    def __init__(self, setup: OmniReSetup):
        self.data_cfg = setup.data_cfg
        self.train_cfg = setup.train_cfg
        self.trainer = setup.trainer
        self.dataset = setup.dataset
        self.device = setup.device
        self.camera_matrix_cache = {}
    
    def render_single_frame(self, frame_data: dict) -> np.ndarray:
        """
        Render a single frame based on provided frame data.
        
        Args:
            frame_data (dict): Dictionary containing camera and image info for the frame
            
        Returns:
            np.ndarray: The rendered RGB image as a numpy array
        """
        self.trainer.set_eval()
        
        with torch.no_grad():
            # Create copies of the dictionaries to avoid modifying originals
            cam_infos = {}
            image_infos = {}
            
            # Move camera info tensors to GPU
            for key, value in frame_data["cam_infos"].items():
                if isinstance(value, torch.Tensor):
                    cam_infos[key] = value.to(self.device)
                else:
                    cam_infos[key] = value
            
            # Move image info tensors to GPU
            for key, value in frame_data["image_infos"].items():
                if isinstance(value, torch.Tensor):
                    image_infos[key] = value.to(self.device)
                else:
                    image_infos[key] = value

            # Perform rendering - explicitly set novel_view=True as in DriveStudio
            outputs = self.trainer(
                image_infos=image_infos,
                camera_infos=cam_infos,
                novel_view=True
            )

            # Clip RGB output to valid range exactly as DriveStudio does
            if "rgb" in outputs:
                outputs["rgb"] = outputs["rgb"].clamp(min=1.e-6, max=1-1.e-6)

            # Extract RGB image and convert to numpy
            rgb = outputs["rgb"].cpu().numpy()
            
            # If depth is needed, you can extract it too
            if "depth" in outputs:
                depth = outputs["depth"].cpu().numpy()
                return rgb, depth

            return rgb

    def get_sensor_input(self, state):
        """
        Generate sensor output (RGB image) for the given simulation state.
        
        Args:
            state (dict): Current state of the simulation containing:
                - camera_position (np.ndarray): 3D position of the camera
                - camera_rotation (np.ndarray): Rotation of the camera (e.g., quaternion)
                - vehicle_positions (dict): Dictionary mapping vehicle IDs to positions
                - vehicle_rotations (dict): Dictionary mapping vehicle IDs to rotations
                - timestamp (float): Current simulation time
                
        Returns:
            dict: Sensor outputs including rendered image
        """
        # Prepare frame data for rendering based on current state
        frame_data = self.prepare_frame_data(state)

        # Render the image
        rgb_image = self.render_single_frame(frame_data)

        # Create sensor output dictionary
        sensor_output = {
            "rgb_image": rgb_image,
            # Add other sensor outputs as needed
        }

        return sensor_output

    # def prepare_frame_data(self, state: State):
    #     """
    #     Prepare the frame data needed for rendering based on simulation state.
        
    #     Args:
    #         state (dict): Current state of the simulation
            
    #     Returns:
    #         dict: Frame data dictionary with cam_infos and image_infos
    #     """
    #     # Extract camera information
    #     camera_position = torch.tensor([
    #         state.ego_pos.x,
    #         state.ego_pos.y,
    #         state.ego_pos.z
    #     ], dtype=torch.float32, device=self.device)
        
    #     # Convert heading to proper quaternion in DriveStudio's coordinate system
    #     # Create quaternion for yaw rotation (around z-axis)
    #     # Convert heading to tensor first to avoid TypeError with numpy.float64
    #     heading = torch.tensor(state.ego_pos.heading, dtype=torch.float32, device=self.device)
    #     # Format quaternion as [w, x, y, z] which is the format expected by compute_camera_matrix
    #     camera_rotation = torch.tensor([
    #         torch.cos(heading / 2),  # w
    #         0.0,                     # x (roll)
    #         0.0,                     # y (pitch)
    #         torch.sin(heading / 2)   # z (yaw)
    #     ], dtype=torch.float32, device=self.device)
        
    #     # Get camera matrix
    #     c2w = self.compute_camera_matrix(camera_position, camera_rotation)
        
    #     # Correct construction of intrinsics from 9-element array
    #     # The NuPlan format is [f_u, f_v, c_u, c_v, k1, k2, p1, p2, k3]
    #     intrinsics_array = torch.tensor([
    #         1.545000000000000000e+03,  # fx (f_u)
    #         1.545000000000000000e+03,  # fy (f_v)
    #         9.600000000000000000e+02,  # cx (c_u)
    #         5.600000000000000000e+02,  # cy (c_v)
    #         -3.561230000000000229e-01, # k1
    #         1.725450000000000039e-01,  # k2
    #         -2.129999999999999949e-03, # p1
    #         4.640000000000000027e-04,  # p2
    #         -5.231000000000000233e-02  # k3
    #     ], dtype=torch.float32, device=self.device)
        
    #     # Convert to proper 3x3 intrinsics matrix format
    #     fx, fy, cx, cy = intrinsics_array[0], intrinsics_array[1], intrinsics_array[2], intrinsics_array[3]
    #     intrinsics = torch.tensor([
    #         [fx, 0.0, cx],
    #         [0.0, fy, cy],
    #         [0.0, 0.0, 1.0]
    #     ], dtype=torch.float32, device=self.device)
        
    #     # Set image dimensions that match the intrinsics
    #     H, W = 720, 1280  # Standard dimensions for NuPlan
        
    #     # Create the camera information dictionary
    #     cam_infos = {
    #         "camera_to_world": c2w,
    #         "intrinsics": intrinsics,
    #         "height": torch.tensor(H, dtype=torch.long, device=self.device),
    #         "width": torch.tensor(W, dtype=torch.long, device=self.device),
    #         "cam_name": "front_camera",  # For logging purposes
    #         "cam_id": torch.tensor(0, dtype=torch.long, device=self.device),  # Front camera ID
    #     }
        
    #     # Generate ray directions and origins for all pixels
    #     j, i = torch.meshgrid(
    #         torch.arange(H, device=self.device),
    #         torch.arange(W, device=self.device),
    #         indexing='ij'
    #     )
        
    #     # Convert pixel coordinates to ray directions using the intrinsics matrix
    #     # Following the DriveStudio convention in get_rays() in pixel_source.py
    #     # See datasets/base/pixel_source.py
    #     directions = torch.stack([
    #         (i - cx + 0.5) / fx,
    #         (j - cy + 0.5) / fy,
    #         torch.ones_like(i)
    #     ], dim=-1)  # [H, W, 3]
        
    #     # Compute ray direction norms for depth scaling
    #     direction_norm = torch.linalg.norm(directions, dim=-1, keepdim=True)
        
    #     # Normalize ray directions
    #     viewdirs = directions / (direction_norm + 1e-8)
        
    #     # Convert from camera space to world space using c2w
    #     # Rotating the viewdirs by the camera rotation matrix
    #     viewdirs = (c2w[:3, :3] @ viewdirs.reshape(-1, 3).t()).t().reshape(H, W, 3)
        
    #     # Ray origins are the camera position for all pixels
    #     origins = torch.broadcast_to(c2w[:3, 3], (H, W, 3))
        
    #     # Normalized time for the current timestamp
    #     normalized_time = torch.full(
    #         (H, W), 
    #         (state.timestamp.time_us - self.dataset.start_timestep) / 
    #         (self.dataset.end_timestep - self.dataset.start_timestep)
    #     ).to(self.device)
        
    #     # Create frame indices and other required data for the image
    #     image_infos = {
    #         "origins": origins,
    #         "viewdirs": viewdirs,
    #         "direction_norm": direction_norm,
    #         "pixel_coords": torch.stack([j.float() / H, i.float() / W], dim=-1),
    #         "normed_time": normalized_time,
    #         "img_idx": torch.full((H, W), 0, dtype=torch.long, device=self.device),
    #         "frame_idx": torch.full((H, W), 0, dtype=torch.long, device=self.device),
    #     }
        
    #     return {
    #         "cam_infos": cam_infos,
    #         "image_infos": image_infos
    #     }
    
    def prepare_frame_data(self, state: State):
        """
        Prepare the frame data needed for rendering based on simulation state.
        Args:
            state (State): Current state of the simulation
        Returns:
            dict: Frame data dictionary with cam_infos and image_infos
        """
        import numpy as np
        import torch
        from scipy.spatial.transform import Rotation as R

        # Camera position
        camera_position = torch.tensor([
            state.ego_pos.x,
            state.ego_pos.y,
            state.ego_pos.z
        ], dtype=torch.float32, device=self.device)

        # Heading (yaw) to quaternion [w, x, y, z]
        heading = float(state.ego_pos.heading)
        quat = R.from_euler('z', heading).as_quat()  # [x, y, z, w]
        quat = np.roll(quat, 1)  # to [w, x, y, z]
        camera_rotation = torch.tensor(quat, dtype=torch.float32, device=self.device)

        # Camera-to-world matrix
        c2w = self.compute_camera_matrix(camera_position, camera_rotation)

        # Intrinsics
        intrinsics_array = torch.tensor([
            1.545000000000000000e+03,  # fx (f_u)
            1.545000000000000000e+03,  # fy (f_v)
            9.600000000000000000e+02,  # cx (c_u)
            5.600000000000000000e+02,  # cy (c_v)
            -3.561230000000000229e-01, # k1
            1.725450000000000039e-01,  # k2
            -2.129999999999999949e-03, # p1
            4.640000000000000027e-04,  # p2
            -5.231000000000000233e-02  # k3
        ], dtype=torch.float32, device=self.device)
        
        # Convert to proper 3x3 intrinsics matrix format
        fx, fy, cx, cy = intrinsics_array[0], intrinsics_array[1], intrinsics_array[2], intrinsics_array[3]
        intrinsics = torch.tensor([
            [fx, 0.0, cx],
            [0.0, fy, cy],
            [0.0, 0.0, 1.0]
        ], dtype=torch.float32, device=self.device)

        H, W = 720, 1280

        # Generate pixel grid (center of each pixel)
        j, i = torch.meshgrid(
            torch.arange(H, device=self.device),
            torch.arange(W, device=self.device),
            indexing='ij'
        )
        # Pixel centers
        pixel_i = i.float() + 0.5
        pixel_j = j.float() + 0.5

        # Directions in camera space
        directions = torch.stack([
            (pixel_i - cx) / fx,
            (pixel_j - cy) / fy,
            torch.ones_like(pixel_i)
        ], dim=-1)  # [H, W, 3]

        # Normalize directions
        direction_norm = torch.linalg.norm(directions, dim=-1, keepdim=True)
        viewdirs = directions / (direction_norm + 1e-8)

        # Rotate directions to world space
        viewdirs = (c2w[:3, :3] @ viewdirs.reshape(-1, 3).t()).t().reshape(H, W, 3)

        # Ray origins
        origins = torch.broadcast_to(c2w[:3, 3], (H, W, 3))

        # Normalized time
        normalized_time = torch.full(
            (H, W),
            (state.timestamp.time_us - self.dataset.start_timestep) /
            (self.dataset.end_timestep - self.dataset.start_timestep),
            device=self.device
        )

        # Pixel coordinates normalized to [0, 1]
        pixel_coords = torch.stack([
            pixel_j / H,
            pixel_i / W
        ], dim=-1)

        cam_infos = {
            "camera_to_world": c2w,
            "intrinsics": intrinsics,
            "height": torch.tensor(H, dtype=torch.long, device=self.device),
            "width": torch.tensor(W, dtype=torch.long, device=self.device),
            "cam_name": "front_camera",
            "cam_id": torch.tensor(0, dtype=torch.long, device=self.device),
        }

        image_infos = {
            "origins": origins,
            "viewdirs": viewdirs,
            "direction_norm": direction_norm,
            "pixel_coords": pixel_coords,
            "normed_time": normalized_time,
            "img_idx": torch.full((H, W), 0, dtype=torch.long, device=self.device),
            "frame_idx": torch.full((H, W), 0, dtype=torch.long, device=self.device),
        }

        return {
            "cam_infos": cam_infos,
            "image_infos": image_infos
        }

    def compute_camera_matrix(self, position, rotation):
        """
        Compute the camera-to-world transformation matrix from position and rotation.

        The pose ``[R | t]`` built from the quaternion and the position is
        right-multiplied by OPENCV2DATASET (the original author describes this
        as following DriveStudio's camera convention).

        Results are memoized per (position, rotation) pair. The same tensor
        object is returned on a cache hit, so callers must not modify it in
        place.

        Args:
            position (torch.Tensor | sequence): 3D position [x, y, z]
            rotation (torch.Tensor | sequence): Quaternion [w, x, y, z];
                it is used as given and not normalized here.

        Returns:
            torch.Tensor: 4x4 float32 camera-to-world matrix on ``self.device``
        """
        # The pose changes at every simulation step, so cap the memo instead
        # of letting it grow for the whole run.
        max_cache_entries = 1024

        # One host copy per argument serves both the cache key and the matrix
        # entries; as_tensor also accepts plain sequences.
        position_xyz = tuple(torch.as_tensor(position, dtype=torch.float32).tolist())
        quat_wxyz = tuple(torch.as_tensor(rotation, dtype=torch.float32).tolist())
        cache_key = (position_xyz, quat_wxyz)

        cached = self.camera_matrix_cache.get(cache_key)
        if cached is not None:
            return cached

        w, x, y, z = quat_wxyz
        tx, ty, tz = position_xyz

        # Standard quaternion-to-rotation-matrix expansion, with the
        # translation in the last column
        c2w = torch.tensor([
            [1 - 2*y*y - 2*z*z, 2*x*y - 2*w*z, 2*x*z + 2*w*y, tx],
            [2*x*y + 2*w*z, 1 - 2*x*x - 2*z*z, 2*y*z - 2*w*x, ty],
            [2*x*z - 2*w*y, 2*y*z + 2*w*x, 1 - 2*x*x - 2*y*y, tz],
            [0.0, 0.0, 0.0, 1.0],
        ], dtype=torch.float32, device=self.device)

        # Apply OpenCV to dataset coordinate system transformation
        opencv2dataset = torch.tensor(OPENCV2DATASET, dtype=torch.float32, device=self.device)
        c2w = c2w @ opencv2dataset

        if len(self.camera_matrix_cache) >= max_cache_entries:
            # dicts iterate in insertion order, so this drops the oldest entry
            self.camera_matrix_cache.pop(next(iter(self.camera_matrix_cache)))
        self.camera_matrix_cache[cache_key] = c2w

        return c2w

Entered local import context
Adding ../drivestudio to sys.path
Current sys.path: ['/cluster/home/larstond/.conda/envs/master/lib/python3.9/site-packages/ray/thirdparty_files', '/cluster/home/larstond/.conda/envs/master/lib/python39.zip', '/cluster/home/larstond/.conda/envs/master/lib/python3.9', '/cluster/home/larstond/.conda/envs/master/lib/python3.9/lib-dynload', '', '/cluster/home/larstond/.conda/envs/master/lib/python3.9/site-packages', '/cluster/home/larstond/master-project/drivestudio/third_party/smplx', '/cluster/home/larstond/master-project/nuplan-devkit']
Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [None]:
base_dir = Path.cwd()

# Stays None when the config or checkpoint cannot be found
setup = None

# Now use the context manager for clean path handling. Judging by the recorded
# cell output, it adds the drivestudio directory to sys.path and makes it the
# working directory, so the relative paths below resolve inside it.
with use_path("drivestudio", True):
    # Define paths relative to the drivestudio directory
    relative_config_path = "configs/datasets/nuplan/8cams_undistorted.yaml"
    relative_checkpoint_path = "output/master-project/run_omnire_undistorted_8cams_0"

    print(f"Working directory: {os.getcwd()}")
    print(f"Config path (relative to drivestudio): {relative_config_path}")
    print(f"Absolute config path: {os.path.abspath(relative_config_path)}")

    # Probe each path once so the error report and the decision below can
    # never disagree with each other
    config_found = os.path.exists(relative_config_path)
    checkpoint_found = os.path.exists(relative_checkpoint_path)

    if not config_found:
        print(f"ERROR: Config file not found at {os.path.abspath(relative_config_path)}")
    if not checkpoint_found:
        print(f"ERROR: Checkpoint directory not found at {os.path.abspath(relative_checkpoint_path)}")

    # Only initialize if files exist
    if config_found and checkpoint_found:
        setup = OmniReSetup(relative_config_path, relative_checkpoint_path)
        print("Successfully initialized OmniRe environment model")
    else:
        print("Failed to initialize environment model due to missing files")

Added /cluster/home/larstond/master-project/drivestudio to sys.path
Changed working directory to /cluster/home/larstond/master-project/drivestudio
Working directory: /cluster/home/larstond/master-project/drivestudio
Config path (relative to drivestudio): configs/datasets/nuplan/8cams_undistorted.yaml
Absolute config path: /cluster/home/larstond/master-project/drivestudio/configs/datasets/nuplan/8cams_undistorted.yaml
Loading config from: configs/datasets/nuplan/8cams_undistorted.yaml
Loading checkpoint from: output/master-project/run_omnire_undistorted_8cams_0


Loading images:   1%|▏         | 4/300 [00:00<00:10, 28.92it/s]

undistorting rgb


Loading images: 100%|██████████| 300/300 [00:08<00:00, 35.19it/s]
Loading dynamic masks:   3%|▎         | 9/300 [00:00<00:03, 85.88it/s]

undistorting dynamic mask


Loading dynamic masks: 100%|██████████| 300/300 [00:03<00:00, 83.48it/s]
Loading human masks:   6%|▌         | 18/300 [00:00<00:03, 85.54it/s]

undistorting human mask


Loading human masks: 100%|██████████| 300/300 [00:04<00:00, 64.22it/s]
Loading vehicle masks:   3%|▎         | 8/300 [00:00<00:03, 78.03it/s]

undistorting vehicle mask


Loading vehicle masks: 100%|██████████| 300/300 [00:03<00:00, 84.71it/s]
Loading sky masks:  10%|▉         | 29/300 [00:00<00:01, 140.80it/s]

undistorting sky mask


Loading sky masks: 100%|██████████| 300/300 [00:02<00:00, 135.19it/s]
Loading images:   1%|▏         | 4/300 [00:00<00:08, 35.52it/s]

undistorting rgb


Loading images: 100%|██████████| 300/300 [00:08<00:00, 36.97it/s]
Loading dynamic masks:   3%|▎         | 9/300 [00:00<00:03, 85.44it/s]

undistorting dynamic mask


Loading dynamic masks: 100%|██████████| 300/300 [00:03<00:00, 88.24it/s]
Loading human masks:   3%|▎         | 9/300 [00:00<00:03, 84.99it/s]

undistorting human mask


Loading human masks: 100%|██████████| 300/300 [00:03<00:00, 86.94it/s]
Loading vehicle masks:   4%|▍         | 13/300 [00:00<00:04, 64.77it/s]

undistorting vehicle mask


Loading vehicle masks: 100%|██████████| 300/300 [00:03<00:00, 85.04it/s]
Loading sky masks:  10%|█         | 30/300 [00:00<00:01, 145.84it/s]

undistorting sky mask


Loading sky masks: 100%|██████████| 300/300 [00:02<00:00, 144.23it/s]
Loading images:   2%|▏         | 6/300 [00:00<00:09, 29.66it/s]

undistorting rgb


Loading images: 100%|██████████| 300/300 [00:11<00:00, 26.24it/s]
Loading dynamic masks:   2%|▏         | 6/300 [00:00<00:05, 55.65it/s]

undistorting dynamic mask


Loading dynamic masks: 100%|██████████| 300/300 [00:07<00:00, 39.24it/s]
Loading human masks:   2%|▏         | 6/300 [00:00<00:05, 54.33it/s]

undistorting human mask


Loading human masks: 100%|██████████| 300/300 [00:06<00:00, 43.76it/s]
Loading vehicle masks:   2%|▏         | 5/300 [00:00<00:07, 40.60it/s]

undistorting vehicle mask


Loading vehicle masks: 100%|██████████| 300/300 [00:07<00:00, 40.12it/s]
Loading sky masks:   2%|▏         | 6/300 [00:00<00:05, 51.67it/s]

undistorting sky mask


Loading sky masks: 100%|██████████| 300/300 [00:04<00:00, 66.29it/s] 
Loading images:   1%|          | 3/300 [00:00<00:12, 23.37it/s]

undistorting rgb


Loading images: 100%|██████████| 300/300 [00:12<00:00, 24.28it/s]
Loading dynamic masks:   1%|▏         | 4/300 [00:00<00:09, 31.84it/s]

undistorting dynamic mask


Loading dynamic masks: 100%|██████████| 300/300 [00:07<00:00, 38.02it/s]
Loading human masks:   2%|▏         | 7/300 [00:00<00:09, 32.55it/s]

undistorting human mask


Loading human masks: 100%|██████████| 300/300 [00:08<00:00, 37.07it/s]
Loading vehicle masks:   1%|▏         | 4/300 [00:00<00:08, 35.71it/s]

undistorting vehicle mask


Loading vehicle masks: 100%|██████████| 300/300 [00:08<00:00, 37.35it/s]
Loading sky masks:   1%|          | 3/300 [00:00<00:11, 26.81it/s]

undistorting sky mask


Loading sky masks: 100%|██████████| 300/300 [00:05<00:00, 53.09it/s]
Loading images:   1%|          | 2/300 [00:00<00:19, 14.98it/s]

undistorting rgb


Loading images:  72%|███████▏  | 216/300 [00:10<00:04, 19.75it/s]

In [None]:
# Wrap the loaded setup in the rendering model. `setup` is still None here if
# the previous cell could not find the config or checkpoint, in which case
# this constructor fails when it reads attributes from `setup`.
environment_model = OmniReModel(setup)

In [None]:
from interfaces import Agent
from nuplan.common.actor_state.state_representation import StateSE2, TimePoint, StateVector2D
from nuplan.common.actor_state.oriented_box import OrientedBox
from nuplan.common.actor_state.waypoint import Waypoint

class RandomAgent(Agent):
    """Placeholder agent that always proposes the same hard-coded trajectory.

    NOTE: despite the class name, nothing in this implementation is
    randomized; the waypoints depend only on the timestamp passed in.
    """

    def __init__(self):
        super().__init__()

    def get_action(self, sensor_output, timestamp):
        """Return a list of ten waypoints starting at ``timestamp``.

        Waypoint ``k`` (k = 0..9) is stamped ``k * 100000`` microseconds after
        ``timestamp.time_us``, is centred at ``(10 * k, 10 * k)`` with a
        heading of ``-2.066``, uses a 1.0 x 1.0 x 1.0 oriented box and
        carries the velocity vector ``(1.0, 0.0)``.

        Args:
            sensor_output: Accepted for interface compatibility; not read.
            timestamp: Object exposing ``time_us`` (microseconds).

        Returns:
            list: The ten ``Waypoint`` objects, in chronological order.
        """
        start_us = timestamp.time_us

        return [
            Waypoint(
                # Time stamp of this waypoint (microseconds)
                TimePoint(start_us + step * 100000),
                # Box centred on the pose for this step
                OrientedBox(
                    StateSE2(step * 10, step * 10, -2.066),
                    length=1.0,
                    width=1.0,
                    height=1.0,
                ),
                # x-velocity=1.0, y-velocity=0.0
                StateVector2D(1.0, 0.0),
            )
            for step in range(10)
        ]

agent = RandomAgent()

## Do the simulation loop

In [None]:

# Result containers for this cell. The active loop below fills only
# sensor_outputs; error_history is appended to solely by the disabled
# closed-loop variant that follows.
error_history = []
sensor_outputs = []

# Disabled closed-loop variant (state -> sensor output -> agent action -> simulator):
# for i in range(3):
#     print("Step", i)
#     state = simulator.get_state()
#     print("State:", state)
#     sensor_output = environment_model.get_sensor_input(state)
#     sensor_outputs.append(sensor_output)
#     print("Sensor Output:", sensor_output)
#     action = agent.get_action(sensor_output, state.timestamp)
#     print("Action:", action)
#     simulator.do_action(action)
#     error_history.append(simulator.get_state())

# start_heading is only referenced by the commented-out heading override in the loop.
start_heading = simulator.get_state().ego_pos.heading
# NOTE(review): start_pos is not assigned anywhere in this cell; presumably an
# earlier cell defines it (e.g. from simulator.get_state().ego_pos) - confirm,
# otherwise the print below raises NameError.
print("Initial State:", start_pos)

# Open-loop sweep: render 120 frames while overriding the ego position.
# The active loop never calls simulator.do_action, so the agent is not involved.
for i in range(120):
    print("Step", i)
    # Fetch the current simulator state, then overwrite its ego position so it
    # is offset from start_pos by 1.0 per step in x and 0.5 per step in y.
    state = simulator.get_state()
    # state.ego_pos.heading = start_heading + i * 0.1
    state.ego_pos.x = start_pos.x + i * 1.0  # Move in x direction
    state.ego_pos.y = start_pos.y + i * 0.5  # Move in y direction
    
    # Query the environment model for the modified state and keep the result.
    sensor_output = environment_model.get_sensor_input(state)
    sensor_outputs.append(sensor_output)
    
# print("Error History:", error_history)


In [None]:
def print_rgb_image(rgb_image, gamma=2.2):
    """
    Show one RGB frame in a matplotlib figure after gamma-encoding it.

    The input is unwrapped in stages: a tuple contributes its first element,
    a torch tensor is detached and copied to a NumPy array, and a 4-D array
    contributes its first entry. Data whose maximum exceeds 1.0 is divided by
    255 before being clipped to [0, 1] and raised to the power ``1 / gamma``.

    Args:
        rgb_image (np.ndarray or torch.Tensor or tuple): The frame to show
        gamma (float): Gamma correction value (default: 2.2 for sRGB)
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import torch

    frame = rgb_image

    # A tuple carries the image array as its first element
    if isinstance(frame, tuple):
        print("Input is a tuple, extracting array...")
        frame = frame[0]

    # Tensors become host-side NumPy arrays
    if isinstance(frame, torch.Tensor):
        frame = frame.detach().cpu().numpy()

    # From here on the frame exposes the NumPy-style shape/dtype attributes
    print(f"Image shape: {frame.shape}, dtype: {frame.dtype}")

    # Drop a leading batch axis by keeping only the first image
    if frame.ndim == 4:
        print("Input has batch dimension, taking first image...")
        frame = frame[0]

    # Anything with values above 1.0 is rescaled by 1/255
    exceeds_unit_range = frame.max() > 1.0
    unit_frame = frame / 255.0 if exceeds_unit_range else frame

    plt.figure(figsize=(10, 6))

    # Gamma-encode the clipped data for display
    encoded = np.power(np.clip(unit_frame, 0, 1), 1/gamma)

    plt.imshow(encoded)
    plt.axis('off')  # No axes around the picture
    plt.title(f"Gamma-corrected image (γ={gamma})")
    plt.show()

# Display every collected frame using a single gamma value of 2.2
# (no other gamma values are tried in this loop).
for i, sensor_output in enumerate(sensor_outputs):
    rgb_image = sensor_output["rgb_image"]
    print(f"Image {i} with gamma=2.2 (standard sRGB):")
    print_rgb_image(rgb_image, gamma=2.2)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cv2

def save_rgb_images_to_video(sensor_outputs, output_path, gamma=2.2, fps=20.0):
    """
    Save a list of RGB images to an MP4 video file with gamma correction.

    Each entry's ``"rgb_image"`` is unwrapped (tuple -> first element,
    tensor-like -> NumPy array, 4-D batch -> first image), rescaled to
    [0, 1] when it is a non-float array holding values above 1.0, gamma
    encoded, converted to 8-bit BGR and appended to the video. The frame size
    is taken from the first image.

    Args:
        sensor_outputs (list): List of sensor outputs containing RGB images
            under the ``"rgb_image"`` key. Must not be empty.
        output_path (str): Path to save the video file
        gamma (float): Gamma correction value (default: 2.2 for sRGB)
        fps (float): Frame rate of the written video (default: 20.0)

    Raises:
        IndexError: If ``sensor_outputs`` is empty.
        OSError: If the video writer could not be opened for ``output_path``.
    """
    import cv2
    import numpy as np

    def _as_image_array(rgb_image):
        """Unwrap tuple / tensor / batch containers into a single image array."""
        if isinstance(rgb_image, tuple):
            # If it's a tuple, take the first element (RGB array)
            rgb_image = rgb_image[0]
        # Duck-typed tensor check so torch does not have to be imported here
        if hasattr(rgb_image, "detach"):
            rgb_image = rgb_image.detach().cpu().numpy()
        rgb_image = np.asarray(rgb_image)
        if rgb_image.ndim == 4:
            # Batch dimension present: keep the first image only
            rgb_image = rgb_image[0]
        return rgb_image

    # The first frame fixes the video dimensions
    first_rgb = _as_image_array(sensor_outputs[0]["rgb_image"])
    height, width, _ = first_rgb.shape

    # Use mp4v codec for MP4 output
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(str(output_path), fourcc, fps, (width, height))

    # try/finally guarantees the writer is released even if a frame fails
    try:
        # An unopened writer drops every frame silently, so fail loudly instead
        if not out.isOpened():
            raise OSError(f"Could not open video writer for {output_path}")

        for sensor_output in sensor_outputs:
            rgb_image = _as_image_array(sensor_output["rgb_image"])

            # Non-float data with values above 1.0 is rescaled by 1/255;
            # float data is left as is and only clipped below
            if rgb_image.dtype != np.float32 and rgb_image.dtype != np.float64:
                if rgb_image.max() > 1.0:
                    rgb_image = rgb_image.astype(np.float32) / 255.0

            # Apply gamma correction for better visual appearance
            corrected_image = np.power(np.clip(rgb_image, 0, 1), 1/gamma)

            # Convert to 8-bit color format required for video
            uint8_image = (corrected_image * 255).astype(np.uint8)

            # Convert from RGB to BGR (OpenCV format)
            bgr_image = cv2.cvtColor(uint8_image, cv2.COLOR_RGB2BGR)

            # Write frame to video
            out.write(bgr_image)
    finally:
        out.release()

    print(f"Video saved to {output_path} with gamma correction (γ={gamma})")

# Create and save the gamma-corrected video from all collected sensor outputs.
# The path is relative, so the file lands in the current working directory.
output_video_path = "output_video.mp4"
save_rgb_images_to_video(sensor_outputs, output_video_path, gamma=2.2)

In [None]:
def compare_gamma_corrections(rgb_image, gamma_values=(1.0, 1.8, 2.2, 2.4)):
    """
    Display an RGB image with different gamma correction values side by side.

    The input is unwrapped (tuple -> first element, torch tensor -> NumPy
    array, 4-D batch -> first image) and divided by 255 when its maximum
    exceeds 1.0. One subplot is drawn per gamma value, showing the clipped
    image raised to the power ``1 / gamma``.

    Args:
        rgb_image (np.ndarray or torch.Tensor or tuple): The RGB image to display
        gamma_values (iterable of float): Gamma values to compare; a single
            value is allowed.

    Raises:
        ValueError: If ``gamma_values`` is empty.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import torch

    # Materialize once so any iterable (list, tuple, generator) is accepted
    gamma_values = list(gamma_values)
    if not gamma_values:
        raise ValueError("gamma_values must contain at least one gamma value")

    # Handle different input formats
    if isinstance(rgb_image, tuple):
        rgb_image = rgb_image[0]

    if isinstance(rgb_image, torch.Tensor):
        rgb_image = rgb_image.detach().cpu().numpy()

    if rgb_image.ndim == 4:
        rgb_image = rgb_image[0]

    # Ensure the image is in [0, 1] range
    if rgb_image.max() > 1.0:
        rgb_image = rgb_image / 255.0

    # squeeze=False keeps `axes` two-dimensional even for a single gamma
    # value; without it plt.subplots returns a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(
        1, len(gamma_values), figsize=(16, 5), squeeze=False
    )
    fig.suptitle("Comparison of Different Gamma Correction Values", fontsize=16)

    for ax, gamma in zip(axes[0], gamma_values):
        # Apply gamma correction
        corrected_image = np.power(np.clip(rgb_image, 0, 1), 1/gamma)

        # Display in the appropriate subplot
        ax.imshow(corrected_image)
        ax.set_title(f"γ = {gamma}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Compare different gamma values for a sample rendered image
if sensor_outputs:
    compare_gamma_corrections(sensor_outputs[0]["rgb_image"])

    # You can try more advanced processing if needed
    # For example, to enhance local contrast while maintaining global appearance:
    import cv2
    import matplotlib.pyplot as plt  # imported here so this cell does not rely on another cell for `plt`
    import numpy as np

    # Function to apply advanced image enhancements
    def enhance_image(rgb_image, gamma=2.2, clahe_clip=2.0, clahe_grid=(8,8)):
        """Apply CLAHE on the lightness channel followed by gamma correction.

        Args:
            rgb_image: RGB image as an array, a tensor-like object exposing
                ``detach()``, or a tuple whose first element is one of those.
                A 4-D input contributes its first image. Integer data with
                values above 1 is rescaled by 1/255; other data is clipped
                to [0, 1].
            gamma (float): Gamma correction value applied after CLAHE.
            clahe_clip (float): ``clipLimit`` passed to ``cv2.createCLAHE``.
            clahe_grid (tuple): ``tileGridSize`` passed to ``cv2.createCLAHE``.

        Returns:
            np.ndarray: Enhanced float image (8-bit result divided by 255,
            then raised to ``1 / gamma``).
        """
        # Accept the same input wrappers as the other display helpers
        if isinstance(rgb_image, tuple):
            rgb_image = rgb_image[0]
        if hasattr(rgb_image, "detach"):
            # Duck-typed tensor: detach and copy to a NumPy array
            rgb_image = rgb_image.detach().cpu().numpy()
        rgb_image = np.asarray(rgb_image)
        if rgb_image.ndim == 4:
            rgb_image = rgb_image[0]

        # 0-255 integer data must be rescaled first; clipping it straight to
        # [0, 1] would collapse the image to pure black/white.
        if np.issubdtype(rgb_image.dtype, np.integer) and rgb_image.max() > 1:
            rgb_image = rgb_image.astype(np.float32) / 255.0

        # Convert to 8-bit for OpenCV operations
        img_8bit = (np.clip(rgb_image, 0, 1) * 255).astype(np.uint8)

        # Convert to LAB color space (L=lightness, A=green-red, B=blue-yellow)
        lab = cv2.cvtColor(img_8bit, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)

        # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to lightness channel
        clahe = cv2.createCLAHE(clipLimit=clahe_clip, tileGridSize=clahe_grid)
        cl = clahe.apply(l)

        # Merge channels back
        enhanced_lab = cv2.merge((cl, a, b))

        # Convert back to RGB
        enhanced_rgb = cv2.cvtColor(enhanced_lab, cv2.COLOR_LAB2RGB)

        # Apply gamma correction for display
        gamma_corrected = np.power(enhanced_rgb / 255.0, 1/gamma)

        return gamma_corrected

    # Display enhanced version of the first frame (the enclosing `if` already
    # guarantees that at least one sensor output exists)
    plt.figure(figsize=(12, 6))
    enhanced = enhance_image(sensor_outputs[0]["rgb_image"])
    plt.imshow(enhanced)
    plt.axis('off')
    plt.title("Enhanced image with CLAHE + gamma correction")
    plt.show()