In [1]:
import torch
import jax
from jax import numpy as jnp
import numpy as np
import mediapy
from tqdm import tqdm
import dataclasses
import torch.nn as nn

from waymax import config as _config
from waymax import dataloader
from waymax import visualization
from waymax import env as _env
from waymax import agents
from waymax import dynamics
from waymax import datatypes

In [52]:
max_num_objects = 32
config = dataclasses.replace(_config.WOD_1_1_0_VALIDATION, max_num_objects=max_num_objects)
data_iter = dataloader.simulator_state_generator(config=config)
scenario = next(data_iter)
dynamics_model = dynamics.StateDynamics()

In [73]:
env = _env.MultiAgentEnvironment(
    dynamics_model=dynamics_model,
    config=dataclasses.replace(
        _config.EnvironmentConfig(),
        max_num_objects=max_num_objects,
        controlled_object=_config.ObjectType.VALID,
    ),
)

In [54]:
#helper functions:
def extract_features_and_labels_for_timestep(log_trajectory, t):
    features = torch.cat([
        torch.tensor(jax.device_get(log_trajectory.xyz[:, t:t+1, :]), dtype=torch.float32),
        torch.tensor(jax.device_get(log_trajectory.yaw[:, t:t+1]), dtype=torch.float32).unsqueeze(-1),
        torch.tensor(jax.device_get(log_trajectory.vel_xy[:, t:t+1, :]), dtype=torch.float32),
    ], dim=-1)

    labels = torch.cat([
        torch.tensor(jax.device_get(log_trajectory.xyz[:, t+1:t+2, :]), dtype=torch.float32),
        torch.tensor(jax.device_get(log_trajectory.yaw[:, t+1:t+2]), dtype=torch.float32).unsqueeze(-1),
        torch.tensor(jax.device_get(log_trajectory.vel_xy[:, t+1:t+2, :]), dtype=torch.float32),
    ], dim=-1)

    return features.view(-1, features.shape[-1]), labels.view(-1, labels.shape[-1])


def extract_features_from_current_state(state):
    current_step = state.timestep  
    xyz = torch.tensor(jax.device_get(state.current_sim_trajectory.xyz[:, current_step, :]), dtype=torch.float32)
    yaw = torch.tensor(jax.device_get(state.current_sim_trajectory.yaw[:, current_step]), dtype=torch.float32).unsqueeze(-1)
    vel_xy = torch.tensor(jax.device_get(state.current_sim_trajectory.vel_xy[:, current_step, :]), dtype=torch.float32)

    features = torch.cat([xyz, yaw, vel_xy], dim=-1)

    return features

In [84]:
def test_model(env, bc_model, device, data_iter, num_scenarios=10):
    total_mae = 0
    total_timesteps = 0

    for _ in range(num_scenarios):
        try:
            scenario = next(data_iter)
            state = env.reset(scenario)  

            for t in range(state.remaining_timesteps):
                features, true_actions = extract_features_and_labels_for_timestep(scenario.log_trajectory, t)
                features, true_actions = features.to(device), true_actions.to(device)

                with torch.no_grad():
                    predicted_mean, _ = bc_model(features)
                    mae = torch.mean(torch.abs(predicted_mean - true_actions)).item()

                total_mae += mae
                total_timesteps += 1

        except StopIteration:
            break

    average_mae = total_mae / total_timesteps if total_timesteps > 0 else 0
    print(f'Test MAE across all timesteps: {average_mae:.4f}')

In [75]:
class BCNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(BCNetwork, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.tanh1 = nn.Tanh()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.tanh2 = nn.Tanh()
        self.mean = nn.Linear(hidden_size, output_size)
        self.log_std = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.tanh1(x)
        x = self.fc2(x)
        x = self.tanh2(x)
        mean = self.mean(x)
        log_std = self.log_std(x)
        std = torch.exp(log_std)
        return mean, std
    
# Initialize
model = BCNetwork(input_size=6, hidden_size=128, output_size=6)

state_dict = torch.load('/Users/carolwang/waymax/algorithms/bc/bc_model.pth')
model.load_state_dict(state_dict)
model = model.to(torch.device("cpu"))

In [85]:
test_model(env, model,torch.device("cpu"),data_iter, 100)

Test MAE across all timesteps: 808.9607


In [77]:
class BCModelSimAgent(agents.SimAgentActor):
    def __init__(self, bc_model, device, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.bc_model = bc_model.to(device)
        self.bc_model.eval()
        self.device = device

    def update_trajectory(self, state: datatypes.SimulatorState) -> datatypes.TrajectoryUpdate:
        features = extract_features_from_current_state(state).to(self.device)
        
        with torch.no_grad():
            predicted_mean, _ = self.bc_model(features.unsqueeze(0))
            predicted_actions = predicted_mean.squeeze(0).cpu().numpy()
        
        return datatypes.TrajectoryUpdate(
            x=predicted_actions[:, 0:1],
            y=predicted_actions[:, 1:2],
            yaw=predicted_actions[:, 2:3],
            vel_x=predicted_actions[:, 3:4],
            vel_y=predicted_actions[:, 4:5],
            valid=np.ones((predicted_actions.shape[0], 1), dtype=bool))

In [80]:
bc_agent = BCModelSimAgent(model, torch.device("cpu"), is_controlled_func=lambda state: jax.numpy.ones((state.object_metadata.num_objects,), dtype=bool))

scenario = next(data_iter)
state = env.reset(scenario)

states = [state]
for _ in range(state.remaining_timesteps):
    action = bc_agent.select_action({}, state, None, None).action
    state = env.step(state, action)
    states.append(state)


imgs = [visualization.plot_simulator_state(state) for state in states]
mediapy.show_video(imgs, fps=10)


0
This browser does not support the video tag.


In [81]:
from waymax.metrics.comfort import KinematicsInfeasibilityMetric


kinematics_metric = KinematicsInfeasibilityMetric(dt=0.1, max_acc=10.4, max_steering=0.3)

scenario = next(data_iter)
state = env.reset(scenario)

metrics_results = []
for _ in range(state.remaining_timesteps):
    action = bc_agent.select_action({}, state, None, None).action
    state = env.step(state, action)
    
    metric_result = kinematics_metric.compute(state)
    metrics_results.append(metric_result)

    state.timestep += 1

infeasible_count = sum(result.value > 0 for result in metrics_results)
total_steps = len(metrics_results)
print(f"Infeasible actions: {infeasible_count}/{total_steps} steps")


Infeasible actions: [1 1 1 1 2 2 1 1 2 1 2 1 1 2 1 1 2 1 1 2 1 2 1 0 1 1 1 1 2 1 1 1]/80 steps


In [79]:
from waymax.metrics.route import ProgressionMetric, OffRouteMetric

progression_metric = ProgressionMetric()
off_route_metric = OffRouteMetric()

env_config = _config.EnvironmentConfig(max_num_objects=32, controlled_object=_config.ObjectType.VALID)
env = _env.MultiAgentEnvironment(dynamics_model=dynamics.StateDynamics(), config=env_config)
scenario = next(data_iter)
state = env.reset(scenario)

progress_results = []
off_route_results = []

for _ in range(state.remaining_timesteps):
    action = bc_agent.select_action({}, state, None, None).action
    state = env.step(state, action)
    
    progress_result = progression_metric.compute(state)
    off_route_result = off_route_metric.compute(state)
    

    progress_results.append(progress_result.value)
    off_route_results.append(off_route_result.value)

print("Progression Results:", progress_results)
print("Off-Route Results:", off_route_results)


ValueError: SimulatorState.sdc_paths required to compute the route progression metric.