In [1]:
import jax
from jax import random
from jax import numpy as jnp
import mediapy
import optax

from waymax import agents
from waymax import config as _config
from waymax import dynamics
from waymax import dataloader
from waymax import datatypes
from waymax import env as _env
from waymax import visualization

from flax.training.train_state import TrainState

import sys
sys.path.append('../')
sys.path.append('./')

from model.feature_extractor import XYExtractor
from model.rnn_policy import ActorCriticRNN, ScannedRNN

2023-11-09 18:20:33.092645: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-09 18:20:33.092703: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-09 18:20:33.092738: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [9]:
# Training config: a single flat dict of the hyper-parameters and data
# locations read by the cells of this notebook.
config = {
    'anneal_lr': False,
    'bins': 128,
    'discrete': False,
    'key': 42,  # seed handed to random.PRNGKey
    'lr': 3e-4,  # Adam learning rate
    'max_grad_norm': 0.5,  # threshold for global-norm gradient clipping
    'max_num_obj': 8,  # object count used by the env config and the feature extractor
    'max_num_rg_points': 20_000,
    'num_envs': 4,  # batch size used when initialising the network
    'num_envs_eval': 1,  # batch size used for the evaluation RNN carry
    'num_epochs': 1,
    'num_steps': 80,
    'n_train_per_epoch': 1_000,
    'roadgraph_top_k': 100,  # forwarded to sdc_observation_from_state
    'shuffle_seed': 123,
    'shuffle_buffer_size': 1_000,
    'total_timesteps': 100,
    'training_path': '/data/saruman/cleain/WOD_1_1_0/tf_example/training/training_tfexample.tfrecord@1000',
    'validation_path': '/data/saruman/cleain/WOD_1_1_0/tf_example/validation/validation_tfexample.tfrecord@150',
}

# Dataset configuration for the validation split.
# The size limits are read from `config` instead of being repeated as literals,
# so the loader cannot drift from the environment / network, which are sized
# from the same `config['max_num_obj']`.
WOD_1_1_0_VALIDATION = _config.DatasetConfig(
    path=config['validation_path'],
    max_num_rg_points=config['max_num_rg_points'],
    data_format=_config.DataFormat.TFRECORD,
    max_num_objects=config['max_num_obj'],
    # Batching is intentionally left disabled: later cells treat the state as a
    # single scenario (e.g. `while not current_state.is_done`).
    # batch_dims=(config['num_envs_eval'],),
)

# Iterator over simulator states built from the validation records.
data_iter = dataloader.simulator_state_generator(config=WOD_1_1_0_VALIDATION)

In [12]:
# Pull the next scenario from the validation iterator. Re-running this cell
# advances the iterator, so each execution yields a different scenario.
scenario = next(data_iter)

In [6]:
# Dynamics: the raw bicycle model, plus the planning wrapper around it.
bicycle_dynamics_model = dynamics.InvertibleBicycleModel()
planning_dynamics_model = _env.PlanningAgentDynamics(bicycle_dynamics_model)

# Environment controlling the SDC only.
env_config = _config.EnvironmentConfig(
    max_num_objects=config['max_num_obj'],
    controlled_object=_config.ObjectType.SDC,
)
# /!\ The environment receives the raw bicycle model, NOT
# planning_dynamics_model (original note: "otherwise repetition").
# NOTE(review): presumably the environment applies the planning wrapper
# itself - confirm.
env = _env.PlanningAgentEnvironment(
    dynamics_model=bicycle_dynamics_model,
    config=env_config,
)

# Expert agent built on the planning dynamics.
expert_agent = agents.create_expert_actor(planning_dynamics_model)

# Recurrent actor-critic whose output size is the first dimension of the
# planning action spec.
network = ActorCriticRNN(
    planning_dynamics_model.action_spec().shape[0],
    feature_extractor_class=XYExtractor,
    feature_extractor_kwargs=dict(max_num_obj=config['max_num_obj']),
    config=config,
)

# Two values (x, y) per object.
feature_extractor_shape = config['max_num_obj'] * 2

init_rnn_state = ScannedRNN.initialize_carry(
    (config['num_envs'], feature_extractor_shape)
)
rng, _rng = random.split(random.PRNGKey(config['key']))

# Zero-filled (observation, done) pair passed to network.init.
# NOTE(review): the leading axis of size 1 is presumably the time axis - confirm.
init_obs = jnp.zeros((1, config['num_envs'], config['max_num_obj'], 2))
init_done = jnp.zeros((1, config['num_envs']), dtype=bool)
init_x = (init_obs, init_done)
network_params = network.init(_rng, init_rnn_state, init_x)

# Optimizer: global-norm gradient clipping followed by Adam.
tx = optax.chain(
    optax.clip_by_global_norm(config['max_grad_norm']),
    optax.adam(config['lr'], eps=1e-5),
)

train_state = TrainState.create(
    apply_fn=network.apply,
    params=network_params,
    tx=tx,
)

In [7]:
from typing import NamedTuple, Optional


class Transition(NamedTuple):
    """Per-step record emitted by a rollout step function.

    Instances are built positionally, so the field order
    (``done``, ``expert_action``, ``obs``) must not change.
    """

    # Termination flag read from the simulator state (`is_done`).
    done: jnp.ndarray
    # Expert action for the step; None when no expert is queried (log replay).
    # Annotated with the array *type*: `jnp.array` is a constructor function.
    expert_action: Optional[jnp.ndarray]
    # Observation fed to the policy.
    obs: jnp.ndarray

# Initial recurrent carry for evaluation: built like `init_rnn_state`, but
# sized with `num_envs_eval` instead of `num_envs`.
init_rnn_state_eval = ScannedRNN.initialize_carry((config["num_envs_eval"], feature_extractor_shape))

def _log_step(current_state, unused):
    """Scan body: record the policy inputs, then replay one logged step.

    Args:
        current_state: simulator state carried through ``jax.lax.scan``.
        unused: per-step scan input; ignored.

    Returns:
        A ``(next_state, transition)`` pair. ``next_state`` is ``current_state``
        advanced by one step of the logged trajectory. ``transition`` holds the
        ``is_done`` flag and the XY observation, both read *before* advancing,
        with ``expert_action`` set to ``None``.
    """
    sdc_observation = datatypes.sdc_observation_from_state(
        current_state, roadgraph_top_k=config['roadgraph_top_k']
    )

    # /!\ TODO: invalid objects still need to be removed from the observation.
    step_record = Transition(
        done=current_state.is_done,
        expert_action=None,
        obs=sdc_observation.trajectory.xy,
    )

    # Move the simulator one step along the logged trajectory.
    next_state = datatypes.update_state_by_log(current_state, num_steps=1)
    return next_state, step_record


In [13]:
# Reset the environment on the sampled scenario.
current_state = env.reset(scenario)

# Warm up the recurrent state: replay the logged trajectory from the raw
# scenario for env.config.init_steps - 1 steps and feed the recorded
# observations / done flags through the network in one call.
_, log_traj_batch = jax.lax.scan(
    _log_step,
    scenario,
    None,
    length=env.config.init_steps - 1,
)

# Insert a new axis at position 1 so the inputs read (step, 1, ...).
warmup_inputs = (
    log_traj_batch.obs[:, jnp.newaxis, ...],
    log_traj_batch.done[:, jnp.newaxis, ...],
)
rnn_state, _, _ = network.apply(
    train_state.params,
    init_rnn_state_eval,
    warmup_inputs,
)


In [14]:
# Closed-loop rollout: step the environment with the policy's action until the
# scenario is done, collecting the valid metric values at every step.
imgs = []
# Per-step metric values keyed by metric name. The three keys below are
# pre-seeded so they exist even if no valid value is ever recorded; any other
# metric returned by `env.metrics` gets its own list on first use.
scenario_metrics = {'log_divergence': [],
                    'overlap': [],
                    'offroad': []}

step = 0
while not current_state.is_done:
    step += 1
    # Rendering is disabled; `imgs` therefore stays empty after this loop.
    # imgs.append(visualization.plot_simulator_state(current_state, use_log_traj=False))

    # One done flag per evaluation env.
    done = jnp.tile(current_state.is_done, (config['num_envs_eval'],))
    full_obsv = datatypes.sdc_observation_from_state(current_state,
                                                roadgraph_top_k=config['roadgraph_top_k'])
    validv = full_obsv.is_ego  # not consumed yet (see the mask TODO below)
    obsv = full_obsv.trajectory.xy # /!\ TODO: invalid objects still need to be removed

    # Expert action for the current state; computed but not consumed below.
    expert_action = expert_agent.select_action(state=current_state,
                                            actor_state=None,
                                            params=None,
                                            rng=None)

    # TODO: add a mask here

    # Prepend a new leading axis to both inputs before the forward pass.
    rnn_state, data_action, _ = network.apply(train_state.params, rnn_state, (obsv[jnp.newaxis, ...], done[jnp.newaxis, ...]))
    # NOTE(review): index 1 on the second axis looks suspicious. If that axis
    # is the env axis it presumably has length num_envs_eval == 1, and JAX
    # clamps out-of-range indices instead of raising - confirm this is intended.
    action = datatypes.Action(data=data_action[0, 1], valid=jnp.ones((1), dtype='bool'))

    current_state = env.step(current_state, action)

    metric = env.metrics(current_state)

    # setdefault avoids a KeyError when the environment reports a metric that
    # is not one of the pre-seeded keys.
    for key, result in metric.items():
        if result.valid:
            scenario_metrics.setdefault(key, []).append(result.value)


In [7]:
# Display the frames collected in `imgs` as a 10 fps video.
# NOTE(review): the only `imgs.append(...)` in the rollout loop is commented
# out, so `imgs` is presumably empty when this cell runs after it - confirm.
mediapy.show_video(imgs, fps=10)

0
This browser does not support the video tag.


## Ground truth

In [None]:
# Replay the scenario along its logged trajectory, rendering one frame after
# each step, then show the frames as a video.
state = scenario
imgs = []

for _ in range(scenario.remaining_timesteps):
    state = datatypes.update_state_by_log(state, num_steps=1)
    frame = visualization.plot_simulator_state(state, use_log_traj=True)
    imgs.append(frame)

mediapy.show_video(imgs, fps=10)