In [1]:
import dataclasses
import itertools
import os
import sys

import jax
from jax import numpy as jnp
from jax import random
import mediapy

from waymax import agents
from waymax import config as _config
from waymax import dataloader
from waymax import datatypes
from waymax import dynamics
from waymax import env as _env
from waymax.datatypes import object_state
from waymax.datatypes import observation
from waymax.dynamics import discretizer

# Make the project-local `utils` module importable from this directory or its parent.
sys.path.append('../')
sys.path.append('./')

from utils import plot_observation_with_mask

# NOTE(review): presumably the index of the "current" timestep within each
# scenario (10 history steps before it) — confirm; not referenced in the cells shown.
CURRENT_TIME_INDEX = 10
# Intended rollout length in simulation steps (the loops / scan below also use 80 steps).
N_SIMULATION_STEPS = 80
# NOTE(review): not referenced in the cells shown.
N_ROLLOUTS = 32

2023-11-03 16:02:40.584207: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-11-03 16:02:41.544510: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-11-03 16:02:41.544623: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64


In [2]:
# Root directory of the WOMD v1.1.0 tf.Example shards. Set the WOD_1_1_0_ROOT
# environment variable to run on another machine; the default is the original
# cluster path, so existing runs are unaffected.
WOD_1_1_0_ROOT = os.environ.get('WOD_1_1_0_ROOT', '/data/saruman/cleain/WOD_1_1_0/tf_example')

# The `@N` suffix is a sharded-file spec (N shards) — presumably expanded by the
# waymax dataloader.
WOD_1_1_0_TRAINING = _config.DatasetConfig(
    path=os.path.join(WOD_1_1_0_ROOT, 'training', 'training_tfexample.tfrecord@1000'),
    max_num_rg_points=20000,
    data_format=_config.DataFormat.TFRECORD,
    # Scenarios are batched 4 at a time; the rollout config's NUM_ENVS must agree.
    batch_dims=(4,),
    max_num_objects=8,
)

# NOTE(review): unlike the training config, no batch_dims is set here, so
# validation states presumably come unbatched — confirm this is intended.
WOD_1_1_0_VALIDATION = _config.DatasetConfig(
    path=os.path.join(WOD_1_1_0_ROOT, 'validation', 'validation_tfexample.tfrecord@150'),
    max_num_rg_points=20000,
    data_format=_config.DataFormat.TFRECORD,
    max_num_objects=8,
)

In [3]:
# Open a stream of batched simulator states over the training split and take
# the first batch as the working scenario.
data_iter = dataloader.simulator_state_generator(config=WOD_1_1_0_TRAINING)
scenario = next(iter(data_iter))

2023-11-03 16:02:43.920318: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-11-03 16:02:43.920427: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublas.so.11'; dlerror: libcublas.so.11: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-11-03 16:02:43.920497: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcublasLt.so.11'; dlerror: libcublasLt.so.11: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-11-03 16:02:43.920566: W tensorflow/compiler/xla/stream_executor/platform/defa

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


In [5]:
# Skip ahead in the stream: pull the next N_SKIP_BATCHES batches and keep the
# last one as `scenario`. If the stream runs out early, `scenario` ends up as
# the last batch that was available (or is left unchanged if none were).
N_SKIP_BATCHES = 12
for scenario in itertools.islice(data_iter, N_SKIP_BATCHES):
    pass

In [31]:
# Multi-agent environment in which the sim agent controls every valid object.
env_config = _config.EnvironmentConfig(
    # Ensure that the sim agent can control all valid objects.
    controlled_object=_config.ObjectType.VALID,
    max_num_objects=8,
)

# Number of discretisation bins per continuous action dimension. With 128 bins,
# the discrete action spec's maximum + 1 is 16641 (= 129 ** 2, see output below).
N_ACTION_BINS = 128
# The bins array is uint8, so larger bin counts would overflow that dtype.
assert 0 < N_ACTION_BINS <= 255, 'N_ACTION_BINS must fit in uint8'

dynamics_model = dynamics.InvertibleBicycleModel()
discrete_dynamic_model = discretizer.DiscreteActionSpaceWrapper(
    dynamics_model=dynamics_model,
    bins=N_ACTION_BINS * jnp.ones(dynamics_model.action_spec().shape, dtype='uint8'),
)

env = _env.MultiAgentEnvironment(
    dynamics_model=discrete_dynamic_model,
    config=env_config,
)

expert_agent = agents.create_expert_actor(discrete_dynamic_model)

In [36]:
# Continuous action spec of the bicycle model (shape and bounds shown in the output).
dynamics_model.action_spec()

BoundedArray(shape=(2,), dtype=dtype('float32'), name=None, minimum=[-6.  -0.3], maximum=[6.  0.3])

In [41]:
# Size of the discrete action space. The discrete spec appears to encode each
# action as a single flattened index, so maximum + 1 is the action count
# (16641 = 129 ** 2 in the output, i.e. 128 + 1 values per dimension).
discrete_dynamic_model.action_spec().maximum+1

array([16641], dtype=int32)

In [42]:
# Shape of the expert's discrete action data for the loaded batch.
expert_agent.select_action(params=None, state=scenario, actor_state=None, rng=None).action.data.shape

(4, 8, 1)

In [None]:
# The expert's discrete action data itself for the loaded batch.
expert_agent.select_action(params=None, state=scenario, actor_state=None, rng=None).action.data

#### Using the observation datatype and updating the simulator

In [None]:
# Replay the logged trajectory step by step, collecting one
# `sdc_observation_from_state` observation per step.
traj_len = N_SIMULATION_STEPS
roadgraph_top_k = 1000

# `roadgraph_top_k` is a plain Python int that presumably sizes a top-k
# selection inside waymax. A traced (non-static) value cannot be used as a
# size under jit, so it is declared static.
jit_sdc_obs_from_state = jax.jit(datatypes.sdc_observation_from_state,
                                 static_argnames=('roadgraph_top_k',))

obs = []

current_state = env.reset(scenario)
for _ in range(traj_len):
    # Advance one logged timestep, then observe the updated state.
    current_state = datatypes.update_state_by_log(current_state, num_steps=1)
    sdc_obs = jit_sdc_obs_from_state(current_state, roadgraph_top_k=roadgraph_top_k)
    obs.append(sdc_obs)

#### Using the initial simulator state and log trajectories

In [None]:
# Exploratory: hand-build SDC-frame versions of the logged trajectory, roadgraph
# and traffic lights for the loaded batch.
# NOTE(review): this cell has not been executed (`In [None]`) and a line below
# is flagged `#Problem`; treat it as work in progress.
# NOTE(review): `traj_len` is not used in this cell.
traj_len = 80

current_state = env.reset(scenario)

# Identify SDC agent
# lax.top_k returns (values, indices); [1] keeps the index of the largest
# `is_sdc` entry per scenario — assumes exactly one SDC per scenario (TODO confirm).
sdc_idx = jax.lax.top_k(current_state.object_metadata.is_sdc, k=1)[1]
# Gather the SDC's logged x / y. Presumably the layout is (..., objects, timesteps),
# hence the object axis is -2 — confirm.
sdc_x = jnp.take_along_axis(current_state.log_trajectory.x, sdc_idx[..., jnp.newaxis], axis=-2)
sdc_y = jnp.take_along_axis(current_state.log_trajectory.y, sdc_idx[..., jnp.newaxis], axis=-2)

# xy carries a trailing coordinate axis, so the object axis moves to -3.
sdc_xy = jnp.take_along_axis(current_state.log_trajectory.xy, sdc_idx[..., jnp.newaxis, jnp.newaxis], axis=-3)

sdc_yaw = jnp.take_along_axis(current_state.log_trajectory.yaw, sdc_idx[..., jnp.newaxis], axis=-2)
sdc_valid = jnp.take_along_axis(current_state.log_trajectory.valid, sdc_idx[..., jnp.newaxis], axis=-2)


# # Translate the object positions
# obj_x = current_state.log_trajectory.x
# obj_y = current_state.log_trajectory.y

# sdc_x = jnp.tile(sdc_x, (1, obj_x.shape[1], 1))
# sdc_y = jnp.tile(sdc_y, (1, obj_y.shape[1], 1))

# obj_x = obj_x - sdc_x
# obj_y = obj_y - sdc_y

# obs = {'obj_x': obj_x,
#        'obj_y': obj_y,
#        }

# # Translate roadmap ## To be continued


# Build an SDC pose, then transform trajectory, roadgraph and traffic lights into it.

# NOTE(review): global_obs.pose2d appears to be a pose object (sdc_pose2D is
# later passed as a pose to combine_two_object_pose_2d), not a single array, so
# jnp.take_along_axis on it will likely fail — consider applying it per field
# (e.g. jax.tree_util.tree_map). TODO confirm.
global_obs = observation.global_observation_from_state(current_state, obs_num_steps=1, num_obj=1)
sdc_pose2D = jnp.take_along_axis(global_obs.pose2d, sdc_idx[..., jnp.newaxis], axis=-1)

# NOTE(review): sdc_xy / sdc_yaw / sdc_valid are gathered from the full log
# trajectory, so they presumably keep a per-timestep axis that a single pose
# does not have — a likely cause of the `#Problem` marker; confirm the shapes
# ObjectPose2D.from_center_and_yaw expects.
pose2d = observation.ObjectPose2D.from_center_and_yaw(xy=sdc_xy, yaw=sdc_yaw, valid=sdc_valid) #Problem

pose = observation.combine_two_object_pose_2d(src_pose=sdc_pose2D, dst_pose=pose2d)

# NOTE(review): other cells read `current_state.roadgraph_points`; confirm that
# `roadgraph_static_points` exists on this state object.
transf_traj = observation.transform_trajectory(current_state.log_trajectory, pose)
transf_rg = observation.transform_roadgraph_points(current_state.roadgraph_static_points, pose)
transf_tls = observation.transform_traffic_lights(current_state.traffic_lights, pose)


In [37]:
# Roadgraph points per scenario; presumably (batch, max_num_rg_points) given the
# output (4, 20000) and the dataset config — confirm.
current_state.roadgraph_points.shape

(4, 20000)

## Test rnnBC

In [4]:
# Settings for the behaviour-cloning rollout below. NUM_ENVS and NUM_STEPS are
# derived from the dataset / notebook constants so they cannot silently drift
# out of sync with the batch size or the rollout length.
config = {
    # Roadgraph top-k passed to sdc_observation_from_state.
    'roadgraph_top_k': 100,
    # Must equal the dataset batch size (batch_dims of the training config).
    'NUM_ENVS': WOD_1_1_0_TRAINING.batch_dims[0],
    # Scan length for the rollout.
    'NUM_STEPS': N_SIMULATION_STEPS,
    # NOTE(review): not used in the cells shown.
    'KEY': random.PRNGKey(42),
}

In [5]:
from typing import Any, NamedTuple

class Transition(NamedTuple):
    """One step of a logged rollout, as returned per step by the scan body.

    Being a NamedTuple (a pytree), ``jax.lax.scan`` stacks each field along a
    new leading step axis.
    """

    # Per-env done flags for this step.
    done: jnp.ndarray
    # Output of `expert_agent.select_action` — an actor-output object exposing
    # `.action`, not a bare array.
    expert_action: Any
    # Result of `datatypes.sdc_observation_from_state` — a structured
    # observation exposing e.g. `.trajectory`, not a bare array.
    obs: Any

In [6]:
# Reset on the loaded batch and seed the scan carry with the initial
# observation, the expert's action and an all-False per-env done mask.
current_state = env.reset(scenario)
obsv = datatypes.sdc_observation_from_state(
    current_state, roadgraph_top_k=config['roadgraph_top_k'])

expert_action = expert_agent.select_action(
    params=None, state=current_state, actor_state=None, rng=None)

runner_state = (
    current_state,
    expert_action,
    obsv,
    jnp.zeros(config['NUM_ENVS'], dtype=bool),  # one done flag per env
)

# COLLECT TRAJECTORIES FROM scenario
def _env_step(runner_state, unused):
    """Advance every env one logged step and record the resulting transition.

    Used as the body of ``jax.lax.scan``; the scan is called with ``xs=None``.

    Args:
        runner_state: carry tuple ``(current_state, expert_action, obsv, done)``.
            Only ``current_state`` is read; the other entries are recomputed.
        unused: per-step scan input (ignored).

    Returns:
        ``(runner_state, transition)``: the new carry and a ``Transition`` whose
        ``done``, ``expert_action`` and ``obs`` all describe the state *after*
        this step's log update.
    """
    current_state, _, _, _ = runner_state

    # Replay the logged trajectory forward by exactly one timestep.
    current_state = datatypes.update_state_by_log(current_state, num_steps=1)
    # One done flag per env so the carry keeps the (NUM_ENVS,) shape it was
    # initialised with. broadcast_to follows config['NUM_ENVS'] instead of a
    # literal batch size, and accepts both a scalar and an already per-env flag
    # (tiling an already per-env flag would multiply its length).
    done = jnp.broadcast_to(current_state.is_done, (config['NUM_ENVS'],))

    obsv = datatypes.sdc_observation_from_state(current_state,
                                                roadgraph_top_k=config['roadgraph_top_k'])

    expert_action = expert_agent.select_action(state=current_state, params=None, rng=None, actor_state=None)

    # TODO: add a mask here.

    runner_state = (current_state, expert_action, obsv, done)
    transition = Transition(done, expert_action, obsv)
    return runner_state, transition

# Run `_env_step` NUM_STEPS times; keep only the stacked per-step Transitions
# (leading axis = step) and discard the final carry.
_, traj_batch = jax.lax.scan(_env_step, runner_state, None, length=config["NUM_STEPS"])


NameError: name 'env' is not defined

In [100]:
# Shape of the stacked observed xy trajectory produced by the scan.
# NOTE(review): presumably (NUM_STEPS, batch, observers, objects, obs_steps, xy) — confirm.
traj_batch.obs.trajectory.xy.shape

(80, 4, 1, 8, 1, 2)

In [12]:
# Dummy all-zeros input. NOTE(review): presumably laid out as
# (time, batch, objects, xy) — confirm against the model's expected input.
init_x = jnp.zeros(shape=(1, 4, 8, 2))

In [17]:
# Flatten everything after the first two axes into a single feature axis.
init_x_reshaped = jnp.reshape(init_x, (1, 4, -1))

In [18]:
# The trailing (8, 2) axes are merged into one axis of 8 * 2 = 16 features.
init_x_reshaped.shape

(1, 4, 16)

In [7]:

# Env config: planning environment in which only the SDC is controlled.
env_config = _config.EnvironmentConfig(
    # Only the self-driving car (SDC) is controlled by the agent.
    controlled_object=_config.ObjectType.SDC,
    max_num_objects=8,
)

# Bicycle model -> 128 bins per action dimension -> planning-agent wrapper.
# Each stage gets its own name; `dynamics_model` ends up as the final wrapper.
continuous_dynamics_model = dynamics.InvertibleBicycleModel()
action_space_dim = continuous_dynamics_model.action_spec().shape
# Bins are uint8, so the per-dimension bin count must stay <= 255.
discrete_dynamics_model = discretizer.DiscreteActionSpaceWrapper(
    dynamics_model=continuous_dynamics_model,
    bins=128 * jnp.ones(action_space_dim, dtype='uint8'),
)
# NOTE(review): presumably restricts actions to the planning (SDC) agent — the
# expert action here has shape (4, 1) vs (4, 8, 1) in the multi-agent setup; confirm.
dynamics_model = _env.PlanningAgentDynamics(discrete_dynamics_model)

env = _env.PlanningAgentEnvironment(dynamics_model=dynamics_model,
                                    config=env_config,
                                    )

# DEFINE EXPERT AGENT
expert_agent = agents.create_expert_actor(dynamics_model)

# INIT ENV
current_state = env.reset(scenario)
obsv = datatypes.sdc_observation_from_state(current_state,
                                            roadgraph_top_k=config['roadgraph_top_k'])

expert_action = expert_agent.select_action(state=current_state,
                                           actor_state=None,
                                           params=None,
                                           rng=None)

In [56]:
# With only the SDC controlled, the expert yields one action per scenario in the
# batch (output (4, 1)).
expert_action.action.data.shape

(4, 1)