In [1]:
import dataclasses
import jax
from jax import random
from jax import numpy as jnp
import mediapy
from PIL import Image

from waymax import agents
from waymax import config as _config
from waymax import dynamics
from waymax import dataloader
from waymax import datatypes
from waymax import env as _env

import sys
sys.path.append('../')
sys.path.append('./')

from utils.viz import plot_observation_with_mask, plot_observation_with_goal
from obs_mak.mask import DistanceObsMask, ConicObsMask

# Index into the logged trajectory used further down to decide which objects
# the sim agent controls (objects whose log entry is valid at this index).
CURRENT_TIME_INDEX = 10
# NOTE(review): the two constants below are not referenced by any visible cell
# (the loops define their own `N`) — confirm whether they are still needed.
N_SIMULATION_STEPS = 80
N_ROLLOUTS = 32

2023-11-22 17:54:09.507938: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-22 17:54:09.507991: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-22 17:54:09.508040: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Notebook-wide settings. Only some keys are read by the visible cells
# (dataset paths, 'max_num_obj', 'max_num_rg_points', 'num_envs',
# 'roadgraph_top_k'); the remaining ones are kept unchanged.
# NOTE(review): the two dataset paths are absolute, machine-specific locations.
config = dict(
    anneal_lr=False,
    bins=128,
    discrete=False,
    key=42,
    lr=3e-4,
    max_grad_norm=0.5,
    max_num_obj=8,
    max_num_rg_points=20000,
    num_envs=4,
    num_envs_eval=4,
    num_epochs=1,
    num_steps=80,
    n_train_per_epoch=1000,
    roadgraph_top_k=100,
    shuffle_seed=123,
    shuffle_buffer_size=1_000,
    total_timesteps=100,
    training_path='/data/saruman/cleain/WOD_1_1_0/tf_example/training/training_tfexample.tfrecord@1000',
    validation_path='/data/saruman/cleain/WOD_1_1_0/tf_example/validation/validation_tfexample.tfrecord@150',
)

In [3]:
# Dataset configurations for the TFRecord shards listed in `config`.
# The road-graph point budget is read from `config['max_num_rg_points']`
# instead of repeating the literal, so the value is defined in one place.

# Training split. Unlike the validation config below, `max_num_objects` is not
# passed here, so the library default applies.
WOD_1_1_0_TRAINING = _config.DatasetConfig(
    path=config['training_path'],
    max_num_rg_points=config['max_num_rg_points'],
    data_format=_config.DataFormat.TFRECORD,
)

# Validation split, capped at `config['max_num_obj']` objects per scenario.
WOD_1_1_0_VALIDATION = _config.DatasetConfig(
    path=config['validation_path'],
    max_num_rg_points=config['max_num_rg_points'],
    data_format=_config.DataFormat.TFRECORD,
    max_num_objects=config['max_num_obj'],
    # batch_dims=(config['num_envs_eval'],)
)

In [4]:
# Build an iterator of simulator states from the validation dataset config
# and pull the first scenario from it.
data_iter = dataloader.simulator_state_generator(config=WOD_1_1_0_VALIDATION)
scenario = next(data_iter)

2023-11-22 17:54:12.703876: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [6]:
# Advance the iterator: `scenario` is rebound to the next item and the one
# loaded previously is discarded. Re-running this cell moves on yet again,
# so the scenario used below depends on how many times it was executed.
scenario = next(data_iter)

In [5]:
# Environment configuration: the sim agent may control every valid object,
# with the same object cap as the dataset config.
env_config = _config.EnvironmentConfig(
    max_num_objects=config['max_num_obj'],
    controlled_object=_config.ObjectType.VALID,
)

# One dynamics model instance shared by the environment and the actor.
dynamics_model = dynamics.InvertibleBicycleModel()
env = _env.MultiAgentEnvironment(
    config=env_config,
    dynamics_model=dynamics_model,
)


def _valid_at_current_time(state):
    """Return the logged validity flags at log index CURRENT_TIME_INDEX."""
    return state.log_trajectory.valid[..., CURRENT_TIME_INDEX]


# Constant-speed actor controlling the objects selected by the helper above.
agent = agents.create_constant_speed_actor(
    dynamics_model=dynamics_model,
    is_controlled_func=_valid_at_current_time,
)

# JIT-compiled entry points used by the rollout loop.
jit_step, jit_observe, jit_select_action = map(
    jax.jit, (env.step, env.observe, agent.select_action)
)

In [6]:
# Root PRNG key (fixed seed 0); split once per step in the rollout cell below.
key = random.PRNGKey(0)
# Number of steps for the constant-speed rollout in the next cell.
# Note that later cells rebind `N` to 80.
N = 10

In [7]:
# Reset the environment on the loaded scenario; `initial_state` keeps a handle
# on the starting state while `current_state` is advanced by the loop.
initial_state = env.reset(scenario)
current_state = initial_state
# Logged validity at index CURRENT_TIME_INDEX — the same criterion the agent's
# `is_controlled_func` applies.
is_controlled = scenario.log_trajectory.valid[..., CURRENT_TIME_INDEX]

# Roll the sim agent forward N steps, drawing a fresh actor key at every step.
for _ in range(N):
    key, actor_key = random.split(key)
    current_obs = jit_observe(current_state)
    actor_output = jit_select_action({}, current_state, None, actor_key)
    current_state = next_state = jit_step(current_state, actor_output.action)

In [7]:
# Replay the logged scenario and render, for each step, the SDC observation
# restricted to the objects that fall inside a distance-based visibility mask.
roadgraph_top_k = 5000
radius = 50
angle = jnp.pi / 4  # not used by the distance mask below

obs_mask = DistanceObsMask(radius=radius)

# Thin wrappers around the mask object: one draws the mask on a matplotlib
# axis (centred on the origin), the other evaluates it on x/y coordinates.
plot_mask_fun = lambda ax : obs_mask.plot_mask_fun(ax, center=(0,0))
mask_fun = lambda x, y : obs_mask.mask_fun(x, y)

N = 80

imgs = []

initial_state = current_state = env.reset(scenario)
for _ in range(N):

    # Simulator state: advance by replaying one logged step.
    current_state = datatypes.update_state_by_log(current_state, num_steps=1)
    # Observation
    sdc_obs = datatypes.sdc_observation_from_state(current_state,
                                                   roadgraph_top_k=roadgraph_top_k)

    # Limiting the observation of the SDC to the objects closer than the radius
    in_mask = mask_fun(sdc_obs.trajectory.x,
                       sdc_obs.trajectory.y)

    # The mask only looks at coordinates, so it must be combined with the
    # existing validity flags. Replacing `valid` by the mask alone would mark
    # invalid entries as valid whenever their coordinates fall inside the mask.
    visible_obj = jnp.logical_and(sdc_obs.trajectory.valid, in_mask)

    trajectory_limited = dataclasses.replace(sdc_obs.trajectory,
                                             valid=visible_obj)

    sdc_obs_limited = dataclasses.replace(sdc_obs,
                                          trajectory=trajectory_limited)

    img = plot_observation_with_mask(sdc_obs_limited,
                                     obj_idx=0,
                                     mask_function=plot_mask_fun)
    imgs.append(img)

mediapy.show_video(imgs, fps=10)

0
This browser does not support the video tag.


In [407]:
# Save as gif 
frames = []

for img in imgs:
    pil_img = Image.fromarray(img)
    frames.append(pil_img)

# Save the frames as a GIF
frames[0].save("../animation/full_obs.gif", save_all=True, append_images=frames[1:], duration=200, loop=0)

## With a proxy goal 
### Last position of the SDC, expressed in the reference frame of the SDC's current position

In [13]:
import sys
sys.path.append('../')
sys.path.append('../model')
from model.state_preprocessing import ExtractXYGoal

In [35]:
from abc import ABC, abstractmethod
from dataclasses import dataclass
import jax
import jax.numpy as jnp
from typing import Any, Dict

from model.config import UNVALID_MASK_VALUE
from waymax import datatypes

from utils.observation import last_sdc_observation_for_current_sdc_from_state

class Extractor(ABC):
    """Abstract base class for extractors.

    Concrete subclasses must implement both ``__call__`` and ``init_x``;
    the subclasses defined in this notebook take a simulator state in
    ``__call__``.
    """

    @abstractmethod
    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """Run the extraction on the given arguments."""
        ...

    @abstractmethod
    def init_x(self):
        """Return the initial value associated with this extractor."""
        ...

@dataclass
class ExtractXY(Extractor):
    """Extractor returning the XY positions of the SDC observation.

    Entries whose validity flag is False are replaced by
    ``UNVALID_MASK_VALUE``.

    Attributes:
        config: Settings dict; reads 'roadgraph_top_k', 'num_envs' and
            'max_num_obj'.
    """
    config: Dict

    def __call__(self, state):
        """Return the masked XY trajectory observed from the SDC for `state`."""
        trajectory = datatypes.sdc_observation_from_state(
            state, roadgraph_top_k=self.config['roadgraph_top_k']
        ).trajectory
        positions = trajectory.xy
        # Trailing axis added so one validity flag covers both coordinates.
        is_valid = trajectory.valid[..., None]
        placeholder = UNVALID_MASK_VALUE * jnp.ones_like(positions)
        return jnp.where(is_valid, positions, placeholder)

    def init_x(self):
        """Return zero-filled initial values.

        Returns:
            A tuple of a float array of shape
            (1, num_envs, max_num_obj, 2) and a boolean array of shape
            (1, num_envs), both filled with zeros / False.
        """
        n_envs = self.config["num_envs"]
        xy_shape = (1, n_envs, self.config['max_num_obj'], 2)
        flags_shape = (1, n_envs)
        return (jnp.zeros(xy_shape), jnp.zeros(flags_shape, dtype=bool))
    
@dataclass
class ExtractXYGoal(Extractor):
    """Extractor returning masked XY positions together with a proxy goal.

    The proxy goal is the SDC's XY position read from the observation produced
    by ``last_sdc_observation_for_current_sdc_from_state`` (the last logged
    observation expressed in the current SDC frame), zeroed out when the
    scenario has no SDC.

    Note: this notebook also imports a class named ``ExtractXYGoal`` from
    ``model.state_preprocessing``; once this cell runs, this definition is the
    one bound to the name.

    Attributes:
        config: Settings dict forwarded to ``ExtractXY``; ``init_x`` reads
            'num_envs' and 'max_num_obj'.
    """
    config: Dict

    def __call__(self, state):
        """Return ``{"xy": ..., "proxy_goal": ...}`` for `state`."""

        # Last obs of the log in the current SDC reference frame
        last_sdc_pos = last_sdc_observation_for_current_sdc_from_state(state)

        is_sdc = state.object_metadata.is_sdc

        # Get the last log pos of the SDC: top_k with k=1 yields the index of
        # the largest `is_sdc` entry along the last (object) axis.
        _, sdc_idx = jax.lax.top_k(is_sdc, k=1)
        # xy[..., 0, :] selects index 0 on the second-to-last axis
        # (presumably the single time step — TODO confirm).
        sdc_xy = jnp.take_along_axis(last_sdc_pos.trajectory.xy[..., 0, :],
                                     sdc_idx[..., None, None],
                                     axis=-2)

        # Mask if no SDC. Reduce over the object axis only, so that each
        # scenario of a batch gets its own flag; reducing over every axis
        # would un-mask all scenarios as soon as one of them has an SDC.
        has_sdc = jnp.any(is_sdc, axis=-1)
        # take_along_axis requires indices and array to share their rank, so
        # sdc_xy has three more axes than has_sdc; add them for broadcasting.
        mask = has_sdc[..., None, None, None]
        # Extract batched proxy goal
        proxy_goal = sdc_xy * mask

        return {"xy": ExtractXY(self.config)(state),
                "proxy_goal": proxy_goal}

    def init_x(self):
        """Return zero-filled initial values.

        Returns:
            A tuple of a dict with 'xy' of shape
            (1, num_envs, max_num_obj, 2) and 'proxy_goal' of shape
            (1, num_envs, 2), and a boolean array of shape (1, num_envs).
        """
        n_envs = self.config["num_envs"]
        return (
            {'xy': jnp.zeros((1, n_envs, self.config['max_num_obj'], 2)),
             'proxy_goal': jnp.zeros((1, n_envs, 2))},
            jnp.zeros((1, n_envs), dtype=bool),
        )

In [53]:
# Replay the logged scenario and render, for each step, the SDC observation
# together with the proxy goal computed by ExtractXYGoal.
roadgraph_top_k = 5000
radius = 50
angle = jnp.pi / 4

N = 80

imgs = []

# The extractor only holds `config`, so one instance serves the whole loop.
goal_extractor = ExtractXYGoal(config)

initial_state = env.reset(scenario)
current_state = initial_state
for _ in range(N):
    # Simulator state: advance by replaying one logged step.
    current_state = datatypes.update_state_by_log(current_state, num_steps=1)

    # Observation used for drawing.
    sdc_obs = datatypes.sdc_observation_from_state(
        current_state, roadgraph_top_k=roadgraph_top_k
    )

    # Features for the same state; only the proxy goal is plotted.
    obs = goal_extractor(current_state)
    goal_xy = obs['proxy_goal'][0, 0]

    img = plot_observation_with_goal(sdc_obs, obj_idx=0, goal=goal_xy)
    imgs.append(img)

In [54]:
# Play the frames rendered by the previous cell inline at 10 frames per second.
mediapy.show_video(imgs, fps=10)

0
This browser does not support the video tag.


In [57]:
# Save as gif 
frames = []

for img in imgs:
    pil_img = Image.fromarray(img)
    frames.append(pil_img)

# Save the frames as a GIF
frames[0].save("../animation/goal/ex_1.gif", save_all=True, append_images=frames[1:], duration=200, loop=0)