In [1]:
import dataclasses
import jax
from jax import random
from jax import numpy as jnp
import mediapy

from waymax import agents
from waymax import config as _config
from waymax import dynamics
from waymax import dataloader
from waymax import datatypes
from waymax import env as _env

import sys
sys.path.append('../')
sys.path.append('./')

from utils import plot_observation_with_mask

# Timestep index treated as "now": objects valid at this index are the ones
# handed to the sim agent (see `is_controlled_func` below).
CURRENT_TIME_INDEX = 10
# Number of steps to simulate / replay.
N_SIMULATION_STEPS = 80
# Number of rollouts (not referenced by the cells in this notebook section).
N_ROLLOUTS = 32

2023-10-26 16:19:02.782554: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-10-26 16:19:03.649527: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-10-26 16:19:03.649640: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64


In [2]:
import os

# Root of the local Waymo Open Dataset v1.1.0 copy. Set the WOD_1_1_0_ROOT
# environment variable to point elsewhere instead of editing the notebook.
WOD_1_1_0_ROOT = os.environ.get('WOD_1_1_0_ROOT', '/data/saruman/cleain/WOD_1_1_0')
# Maximum number of roadgraph points loaded per scenario (shared by both splits).
WOD_MAX_NUM_RG_POINTS = 20000

# Training split: sharded TFRecord pattern with 1000 shards.
WOD_1_1_0_TRAINING = _config.DatasetConfig(
    path=f'{WOD_1_1_0_ROOT}/tf_example/training/training_tfexample.tfrecord@1000',
    max_num_rg_points=WOD_MAX_NUM_RG_POINTS,
    data_format=_config.DataFormat.TFRECORD,
)

# Validation split: sharded TFRecord pattern with 150 shards.
WOD_1_1_0_VALIDATION = _config.DatasetConfig(
    path=f'{WOD_1_1_0_ROOT}/tf_example/validation/validation_tfexample.tfrecord@150',
    max_num_rg_points=WOD_MAX_NUM_RG_POINTS,
    data_format=_config.DataFormat.TFRECORD,
)

In [3]:
# Training-split dataset config with `max_num_objects` set to 128; only the
# first scenario yielded by the generator is used in this notebook.
config = dataclasses.replace(
    WOD_1_1_0_TRAINING,
    max_num_objects=128,
)
data_iter = dataloader.simulator_state_generator(config=config)
scenario = next(data_iter)

2023-10-26 16:19:05.595041: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-10-26 16:19:05.595156: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublas.so.11'; dlerror: libcublas.so.11: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-10-26 16:19:05.595251: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublasLt.so.11'; dlerror: libcublasLt.so.11: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-10-26 16:19:05.595325: W tensorflow/compiler/xla/stream_executor/platform/defa

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


In [4]:
env_config = _config.EnvironmentConfig(
    # Tie the environment's object capacity to the dataset config loaded above
    # so the two cannot silently drift apart if `max_num_objects` is changed.
    max_num_objects=config.max_num_objects,
    # Ensure that the sim agent can control all valid objects.
    controlled_object=_config.ObjectType.VALID,
)

dynamics_model = dynamics.InvertibleBicycleModel()
env = _env.MultiAgentEnvironment(
    dynamics_model=dynamics_model,
    config=env_config,
)

agent = agents.create_constant_speed_actor(
    dynamics_model=dynamics_model,
    # Controlled objects are those valid at CURRENT_TIME_INDEX (not t=0).
    is_controlled_func=lambda state: state.log_trajectory.valid[..., CURRENT_TIME_INDEX],
)

# JIT-compile the per-step functions once; they are called in a loop below.
jit_step = jax.jit(env.step)
jit_observe = jax.jit(env.observe)
jit_select_action = jax.jit(agent.select_action)

In [10]:
N = 10  # number of closed-loop steps for the rollout cell below
key = random.PRNGKey(0)  # fixed seed so the rollout is reproducible

In [7]:
# Start a fresh episode from the logged scenario and remember where it began.
initial_state = current_state = env.reset(scenario)
# Controlled-object mask: objects valid at CURRENT_TIME_INDEX (the same rule
# the agent's `is_controlled_func` applies).
is_controlled = scenario.log_trajectory.valid[..., CURRENT_TIME_INDEX]

# Closed-loop rollout of the sim agent for N steps. Each step splits a new
# subkey off `key` and hands it to the actor.
for _ in range(N):
    key, actor_key = random.split(key, 2)
    current_obs = jit_observe(current_state)
    actor_output = jit_select_action({}, current_state, None, actor_key)
    current_state = next_state = jit_step(current_state, actor_output.action)

In [208]:
# Pull the individual fields out of the rolled-out state for inspection.
timestep = current_state.timestep
object_metadata = current_state.object_metadata

# Trajectories: the simulated rollout and the original log.
sim_trajectory = current_state.sim_trajectory
log_trajectory = current_state.log_trajectory

# Scene context.
roadgraph_points = current_state.roadgraph_points
log_traffic_light = current_state.log_traffic_light
sdc_paths = current_state.sdc_paths

In [212]:
current_timestep = 0  # not referenced by the cells shown here
# Indices of *objects* to hide from the SDC (chosen by hand for this scenario).
non_visible_obj = jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 23])

# Trajectory arrays are laid out as (..., object, time): elsewhere in this
# notebook `valid[..., t]` selects a timestep. Objects therefore live on the
# second-to-last axis. Indexing with `.at[:, non_visible_obj]` would hit the
# time axis instead, invalidating timesteps 0-15 and 23 for every object.
sim_trajectory_modified = dataclasses.replace(
    sim_trajectory,
    valid=sim_trajectory.valid.at[..., non_visible_obj, :].set(False),
)
current_state_modified = dataclasses.replace(
    current_state,
    sim_trajectory=sim_trajectory_modified,
)

In [213]:
# SDC-centric observation built from the state with hidden objects.
# `roadgraph_top_k` limits how many roadgraph points the observation keeps
# (per the Waymax API; presumably the closest ones to the SDC).
sdc_observation = datatypes.sdc_observation_from_state(
    current_state_modified,
    roadgraph_top_k=10000,
)

### Distance-based field of view

In [5]:
def plot_circle(ax, center, radius):
    """Draw the outline of a circle on an axis.

    The outline is sampled at 100 evenly spaced angles over [0, 2*pi].

    Args:
        ax: Axes-like object exposing a matplotlib-style ``plot(x, y)`` method.
        center: ``(x, y)`` coordinates of the circle's centre.
        radius: Circle radius.
    """
    cx, cy = center[0], center[1]
    angles = jnp.linspace(0, 2 * jnp.pi, 100)
    ax.plot(cx + radius * jnp.cos(angles), cy + radius * jnp.sin(angles))

def in_circle(obj_x, obj_y, radius):
    """Elementwise test of whether points lie inside or on a circle at the origin.

    Coordinates are taken relative to the circle's centre, so the centre is
    not a parameter. Points exactly on the boundary count as inside.

    Args:
        obj_x: x coordinates (scalar or array).
        obj_y: y coordinates, broadcast-compatible with ``obj_x``.
        radius: Circle radius.

    Returns:
        Boolean mask with the broadcast shape of ``obj_x`` and ``obj_y``.
    """
    radius_sq = radius**2
    return obj_x**2 + obj_y**2 <= radius_sq

### Conic field of view

In [6]:
def plot_conic(ax, center, angle, radius, color='b'):
    """Draw the outline of a circular sector (cone-shaped field of view).

    The sector is symmetric about the +x direction and spans ``angle`` radians
    in total (from ``-angle/2`` to ``+angle/2``). These are the same angular
    bounds that ``in_conic`` tests.

    Args:
        ax: Axes-like object exposing a matplotlib-style ``plot`` method.
        center: ``(x, y)`` apex of the sector.
        angle: Full opening angle of the sector, in radians.
        radius: Radius of the sector.
        color: Line color, forwarded to ``ax.plot`` as ``c``.
    """
    cx, cy = center[0], center[1]
    half_angle = angle / 2

    # Arc between the two bounding rays, sampled at 100 angles.
    theta = jnp.linspace(-half_angle, half_angle, 100)
    x = cx + radius * jnp.cos(theta)
    y = cy + radius * jnp.sin(theta)

    # Far endpoints of the two bounding rays.
    x1 = cx + radius * jnp.cos(-half_angle)
    y1 = cy + radius * jnp.sin(-half_angle)
    x2 = cx + radius * jnp.cos(half_angle)
    y2 = cy + radius * jnp.sin(half_angle)

    # Rays start at the apex (`center`), not at the origin, so the outline is
    # closed for any center.
    ax.plot([cx, x1], [cy, y1], c=color)
    ax.plot([cx, x2], [cy, y2], c=color)
    ax.plot(x, y, c=color)

def in_conic(obj_x, obj_y, angle, radius):
    """Elementwise test of whether points lie in a sector centred on +x.

    A point is inside when two conditions hold. Its polar angle (via
    ``arctan2``) must be within ``[-angle/2, angle/2]``. Its distance from
    the origin must also be at most ``radius``. Both bounds are inclusive.

    Args:
        obj_x: x coordinates (scalar or array).
        obj_y: y coordinates, broadcast-compatible with ``obj_x``.
        angle: Full opening angle of the sector, in radians.
        radius: Sector radius.

    Returns:
        Boolean mask with the broadcast shape of ``obj_x`` and ``obj_y``.
    """
    half_angle = angle / 2
    within_radius = obj_x**2 + obj_y**2 <= radius**2
    # |theta| <= h is exactly -h <= theta <= h.
    within_angle = jnp.abs(jnp.arctan2(obj_y, obj_x)) <= half_angle
    return within_angle & within_radius

In [7]:
roadgraph_top_k = 5000
radius = 50
angle = jnp.pi / 4

# Shape of the SDC's field of view: 'circle' (distance only) or 'conic'
# (distance plus a cone around the +x direction).
fov_shape = 'circle'

if fov_shape == 'circle':
    plot_mask_fun = lambda ax: plot_circle(ax, center=(0, 0), radius=radius)
    mask_fun = lambda x, y: in_circle(x, y, radius=radius)
elif fov_shape == 'conic':
    plot_mask_fun = lambda ax: plot_conic(ax, center=(0, 0), angle=angle, radius=radius)
    mask_fun = lambda x, y: in_conic(x, y, angle=angle, radius=radius)
else:
    raise ValueError(f"Unknown fov_shape: {fov_shape!r}")

# Replay the log over the configured simulation horizon.
N = N_SIMULATION_STEPS

imgs = []

initial_state = current_state = env.reset(scenario)
for _ in range(N):
    # Simulator state: advance one step by replaying the logged trajectories.
    current_state = datatypes.update_state_by_log(current_state, num_steps=1)
    # Observation from the SDC's point of view (presumably in the SDC frame,
    # since the field of view below is centred at (0, 0)).
    sdc_obs = datatypes.sdc_observation_from_state(
        current_state, roadgraph_top_k=roadgraph_top_k
    )

    # Limit the SDC's observation to objects inside the field of view. The FOV
    # mask is AND-ed with the existing validity mask. Otherwise invalid (e.g.
    # padded) objects whose coordinates fall inside the FOV would be marked
    # valid and drawn.
    visible_obj = mask_fun(sdc_obs.trajectory.x, sdc_obs.trajectory.y)
    trajectory_limited = dataclasses.replace(
        sdc_obs.trajectory, valid=sdc_obs.trajectory.valid & visible_obj
    )
    sdc_obs_limited = dataclasses.replace(sdc_obs, trajectory=trajectory_limited)

    img = plot_observation_with_mask(
        sdc_obs_limited, obj_idx=0, mask_function=plot_mask_fun
    )
    imgs.append(img)

mediapy.show_video(imgs, fps=10)

0
This browser does not support the video tag.


In [407]:
# Save the rendered frames as an animated GIF.
import os

from PIL import Image  # `Image` was previously used here without being imported.

gif_path = "../animation/full_obs.gif"

frames = [Image.fromarray(img) for img in imgs]
if not frames:
    raise ValueError("No frames to save; run the rendering cell first.")

# Make sure the output directory exists before writing.
os.makedirs(os.path.dirname(gif_path), exist_ok=True)

# duration: per-frame display time in milliseconds; loop=0: repeat forever.
frames[0].save(gif_path, save_all=True, append_images=frames[1:], duration=200, loop=0)