In [2]:
import dataclasses
import jax
from jax import random
from jax import numpy as jnp
import mediapy
from PIL import Image

import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
from waymax import agents
from waymax import config as _config
from waymax import dynamics
from waymax import dataloader
from waymax import datatypes
from waymax import env as _env
from waymax import visualization

import sys

sys.path.append("../")
sys.path.append("./")

from waymax.datatypes import SimulatorState
from obs_mask.mask import ObsMask

from model.state_preprocessing import ExtractObs
from model.config import XY_SCALING_FACTOR
from utils.viz import (
    plot_observation_with_mask,
    plot_observation_with_goal,
    plot_observation_with_heading,
)

CURRENT_TIME_INDEX = 10
N_SIMULATION_STEPS = 80
N_ROLLOUTS = 32



In [4]:
config = {
    "anneal_lr": False,
    "bins": 128,
    "discrete": False,
    "extractor": "ExtractObs",
    "feature_extractor": "KeyExtractor",
    "feature_extractor_kwargs": {
        "final_hidden_layers": None,
        "hidden_layers": {},
        "keys": ["proxy_goal", "heading"],
        "kwargs": {"heading": {"radius": 20}},
    },
    "key": 42,
    "lr": 3e-4,
    "max_grad_norm": 0.5,
    "max_num_obj": 128,
    "max_num_rg_points": 20000,
    "num_envs": 4,
    "num_envs_eval": 4,
    "num_epochs": 1,
    "num_steps": 80,
    "n_train_per_epoch": 1000,
    "roadgraph_top_k": 500,
    "shuffle_seed": 123,
    "shuffle_buffer_size": 1_000,
    "total_timesteps": 100,
    "training_path": "gs://waymo_open_dataset_motion_v_1_1_0/uncompressed/tf_example/training/training_tfexample.tfrecord@1000",
    "validation_path": "gs://waymo_open_dataset_motion_v_1_1_0/uncompressed/tf_example/validation/validation_tfexample.tfrecord@150",
    "testing_path": "gs://waymo_open_dataset_motion_v_1_1_0/uncompressed/tf_example/testing/testing_tfexample.tfrecord@150",
}

In [5]:
WOD_1_1_0_TRAINING = _config.DatasetConfig(
    path=config["training_path"],
    max_num_rg_points=20000,
    data_format=_config.DataFormat.TFRECORD,
)

WOD_1_1_0_VALIDATION = _config.DatasetConfig(
    path=config["validation_path"],
    max_num_rg_points=20000,
    data_format=_config.DataFormat.TFRECORD,
    max_num_objects=config["max_num_obj"],
    # batch_dims=(config['num_envs_eval'],)
    batch_dims=(1,),
)

In [6]:
data_iter = dataloader.simulator_state_generator(config=WOD_1_1_0_VALIDATION)
id = 0

In [7]:
scenario = next(data_iter)
id += 1

print("Scenario idx:", id)
print("Has SDC:", jnp.any(scenario.object_metadata.is_sdc))

_, sdc_idx = jax.lax.top_k(scenario.object_metadata.is_sdc, k=1)
sdc_vel_xy = jnp.take_along_axis(
    scenario.log_trajectory.vel_xy, sdc_idx[..., None, None], axis=1
)
sdc_mean_speed = jax.vmap(lambda x: jnp.sqrt(x[0] ** 2 + x[1] ** 2))(
    sdc_vel_xy[0]
).mean()

print("SDC mean speed:", sdc_mean_speed)

CUDA backend failed to initialize: Found cuSOLVER version 11405, but JAX was built against version 11504, which is newer. The copy of cuSOLVER that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


Scenario idx: 1
Has SDC: True
SDC mean speed: 0.0011510265


In [8]:
env_config = _config.EnvironmentConfig(
    # Ensure that the sim agent can control all valid objects.
    controlled_object=_config.ObjectType.VALID,
    max_num_objects=config["max_num_obj"],
)

dynamics_model = dynamics.InvertibleBicycleModel()
env = _env.MultiAgentEnvironment(dynamics_model=dynamics_model, config=env_config)

# Add limiting observability mask

In [9]:
def vis_mask(
    obs_mask: ObsMask,
    scenario: SimulatorState,
    need_speed: bool = False,
    N: int = 80,
    roadgraph_top_k: int = 2000,
    seed_extract: int = 123,
    seed: int = 42,
):
    imgs = []
    rng = jax.random.PRNGKey(seed)
    rng_extract = jax.random.PRNGKey(seed_extract)

    current_state = env.reset(scenario)
    for _ in range(N):
        # Simulator state
        current_state = datatypes.update_state_by_log(current_state, num_steps=1)
        # Observation
        sdc_obs = datatypes.sdc_observation_from_state(
            current_state, roadgraph_top_k=roadgraph_top_k
        )
        if obs_mask.mask_per_step:
            rng, rng_obs = jax.random.split(rng)
        else:
            rng_obs = rng_extract
        sdc_obs_limited = obs_mask.mask_obs(current_state, sdc_obs, rng_obs)

        if need_speed:
            _, sdc_idx = jax.lax.top_k(current_state.object_metadata.is_sdc, k=1)
            sdc_v = jnp.take_along_axis(
                sdc_obs.trajectory.speed, sdc_idx[..., None, None], axis=-2
            )

            def plot_mask_fun(ax):
                return obs_mask.plot_mask_fun(ax, sdc_v)

        else:
            plot_mask_fun = obs_mask.plot_mask_fun

        img = plot_observation_with_mask(
            jax.tree_map(lambda x: x[0], sdc_obs_limited),
            obj_idx=0,
            mask_function=plot_mask_fun,
        )

        imgs.append(img)
    return imgs

## Constant mask

### Circular FoV

In [None]:
from IGWaymax.obs_mask import DistanceObsMask

In [34]:
obs_mask = DistanceObsMask(radius=50)
imgs = vis_mask(obs_mask, scenario)

In [26]:
mediapy.show_video(imgs, fps=10)

0
This browser does not support the video tag.


In [407]:
# Save as gif
frames = []

for img in imgs:
    pil_img = Image.fromarray(img)
    frames.append(pil_img)

# Save the frames as a GIF
frames[0].save(
    "../animation/full_obs.gif",
    save_all=True,
    append_images=frames[1:],
    duration=200,
    loop=0,
)

### Conic FoV

In [12]:
from IGWaymax.obs_mask import ConicObsMask

In [32]:
obs_mask = ConicObsMask(radius=20, angle=2 / 3 * jnp.pi)
imgs = vis_mask(obs_mask, scenario)

In [34]:
mediapy.show_video(imgs, fps=10)

0
This browser does not support the video tag.


### Random masking

In [9]:
from IGWaymax.obs_mask import RandomMasking

In [83]:
obs_mask = RandomMasking(prob=0.2)
imgs = vis_mask(obs_mask, scenario)  # , N=3)

In [84]:
mediapy.show_video(imgs, fps=10)

0
This browser does not support the video tag.


### Gaussian noise (constant accross the trajectory)

In [10]:
from IGWaymax.obs_mask import GaussianNoise

In [108]:
obs_mask = GaussianNoise(sigma=3)
imgs = vis_mask(obs_mask, scenario)

In [109]:
mediapy.show_video(imgs, fps=10)

0
This browser does not support the video tag.


## Mask as a function of the speed

### Conic FoV

In [8]:
from IGWaymax.obs_mask import SpeedConicObsMask

In [45]:
obs_mask = SpeedConicObsMask(
    radius=20, angle_min=1 / 12 * jnp.pi, angle_max=1 * jnp.pi, v_max=15
)

imgs = vis_mask(obs_mask, scenario, need_speed=True)

In [46]:
mediapy.show_video(imgs, fps=10)

0
This browser does not support the video tag.


In [12]:
# Save as gif
frames = []

for img in imgs:
    pil_img = Image.fromarray(img)
    frames.append(pil_img)

# Save the frames as a GIF
frames[0].save(
    f"../animation/limited_FoV/{obs_mask.__class__.__name__}/ex_{id}.gif",
    save_all=True,
    append_images=frames[1:],
    duration=200,
    loop=0,
)

### Gaussian noise

In [38]:
from IGWaymax.obs_mask import SpeedGaussianNoise

In [39]:
obs_mask = SpeedGaussianNoise(v_max=15, sigma_max=3.0, sigma_min=1.0)

imgs = vis_mask(obs_mask, scenario, need_speed=True)

In [40]:
mediapy.show_video(imgs, fps=10)

0
This browser does not support the video tag.


In [85]:
# Save as gif
frames = []

for img in imgs:
    pil_img = Image.fromarray(img)
    frames.append(pil_img)

# Save the frames as a GIF
frames[0].save(
    f"../limited_FoV/gaussian_noise/ex_{id}.gif",
    save_all=True,
    append_images=frames[1:],
    duration=200,
    loop=0,
)

### Uniform noise

In [48]:
from IGWaymax.obs_mask import SpeedUniformNoise

In [49]:
obs_mask = SpeedUniformNoise(v_max=15, bound_max=5, bound_min=1)

imgs = vis_mask(obs_mask, scenario, need_speed=True)

In [50]:
mediapy.show_video(imgs, fps=10)

0
This browser does not support the video tag.


## With a proxy goal 
### Last position of the SDC in the referential of the current position of the SDC

In [68]:
roadgraph_top_k = 5000
N = 80

imgs = []

initial_state = current_state = env.reset(scenario)
for _ in range(N):
    # Simulator state
    current_state = datatypes.update_state_by_log(current_state, num_steps=1)
    # Observation
    sdc_obs = datatypes.sdc_observation_from_state(
        current_state, roadgraph_top_k=roadgraph_top_k
    )

    obs = ExtractObs(config)(current_state, sdc_obs, None)

    img = plot_observation_with_goal(
        jax.tree_map(lambda x: x[0], sdc_obs),
        obj_idx=0,
        goal=obs["proxy_goal"][0] * XY_SCALING_FACTOR,
    )
    imgs.append(img)

In [69]:
mediapy.show_video(imgs, fps=10)

0
This browser does not support the video tag.


In [57]:
# Save as gif
frames = []

for img in imgs:
    pil_img = Image.fromarray(img)
    frames.append(pil_img)

# Save the frames as a GIF
frames[0].save(
    "../animation/goal/ex_1.gif",
    save_all=True,
    append_images=frames[1:],
    duration=200,
    loop=0,
)

## With a heading
### The heading is directed in the position of the SDC car at $n$ meters away from the current position

In [79]:
roadgraph_top_k = 5000
N = 80

imgs = []

initial_state = current_state = scenario

for _ in range(N):
    # Simulator state
    current_state = datatypes.update_state_by_log(current_state, num_steps=1)
    # Observation
    sdc_obs = datatypes.sdc_observation_from_state(
        current_state, roadgraph_top_k=roadgraph_top_k
    )

    obs = ExtractObs(config)(current_state, sdc_obs, None)

    img = plot_observation_with_heading(
        jax.tree_map(lambda x: x[0], sdc_obs),
        obj_idx=0,
        heading=obs["heading"].squeeze(),
    )
    imgs.append(img)

In [80]:
mediapy.show_video(imgs, fps=10)

0
This browser does not support the video tag.


In [73]:
# Save as gif
frames = []

for img in imgs:
    pil_img = Image.fromarray(img)
    frames.append(pil_img)

# Save the frames as a GIF
frames[0].save(
    f"../animation/heading/ex_{id}.gif",
    save_all=True,
    append_images=frames[1:],
    duration=100,
    loop=0,
)