In [16]:
import autoroot
import autorootcwd
import numpy as np
import torch
import matplotlib.pyplot as plt
import os
import torch.nn as nn

from pathlib import Path
from mppi_plus_proto.dynamics_models.models import UnicycleSteeringModel, UnicycleModel, BicycleSteeringModel, BicycleModel
from mppi_plus_proto.trainers.mutual_inf import MIJointTrainer, RolloutCollector
from mppi_plus_proto.nn_models.generative import generic_gaussian, generic_realnvp, generic_autoregressive, generic_recurrent_gaussian
from mppi_plus_proto.trainers.data_samplers import sample_uniform

In [17]:
def plot_trajectories_unicycle(trajectories: np.ndarray,
                               ax: plt.Axes = None,
                               figsize: tuple = (6, 6),
                               title: str = None) -> None:
    """
    Plot multiple unicycle trajectories on a 2D plane.
    
    Args:
        trajectories: Array of shape (n_trajectories, horizon, 3) containing 
                     state trajectories where each state is (x, y, theta)
        ax: Matplotlib axes to plot on. If None, creates new figure
        figsize: Figure size for the plot (only used if ax is None)
        title: Title for the plot
    """
    # Create figure if no axes provided
    if ax is None:
        plt.figure(figsize=figsize)
        ax = plt.gca()
    
    # Plot each trajectory
    for i in range(trajectories.shape[0]):
        traj = trajectories[i]
        ax.plot(traj[:, 0], traj[:, 1], alpha=0.3)
    ax.plot(trajectories[-2, :, 0], trajectories[-2, :, 1], alpha=1., color="red", linestyle="--")
    ax.plot(trajectories[-1, :, 0], trajectories[-1, :, 1], alpha=1., color="blue", linestyle="--")
    
    # Set labels and grid
    label_fs = 16
    title_fs = 20
    tick_fs = 14
    ax.set_xlabel('x', fontsize=label_fs)
    ax.set_ylabel('y', fontsize=label_fs)
    ax.tick_params(axis='both', labelsize=tick_fs)
    ax.grid(True)
    ax.set_aspect('equal')
    
    # Set consistent axis limits based on data
    margin = 0.5
    x_min, x_max = trajectories[:, :, 0].min(), trajectories[:, :, 0].max()
    y_min, y_max = trajectories[:, :, 1].min(), trajectories[:, :, 1].max()
    ax.set_xlim(x_min - margin, x_max + margin)
    ax.set_ylim(y_min - margin, y_max + margin)
    
    if title is not None:
        ax.set_title(title, fontsize=title_fs)


def plot_model(ax, 
               model, 
               n_samples: int,
               title: str,
               rollout_collector: RolloutCollector,
               reference_trajectories: np.ndarray):
    with torch.no_grad():
        u_seq, _ = model.sample(n_samples)
        u_seq = u_seq.reshape(n_samples, rollout_collector.horizon, -1)
        x_seq = rollout_collector.rollout(u_seq, clip=True)
        x_seq = x_seq.clone().detach().cpu().numpy()
    
    plot_trajectories_unicycle(np.concat([x_seq, reference_trajectories], axis=0),
                               title=title,
                               ax=ax)


def plot_config(ax, config, n_samples: int):
    rollout_collector = RolloutCollector(dynamics=config["dynamics"],
                                         horizon=config["horizon"],
                                         action_lb=config["action_lb"],
                                         action_ub=config["action_ub"],
                                         action_mid=config["action_mid"],
                                         state_dim_cut=config["state_dim_cut"])
    generator = config["model_fn"](rollout_collector.action_dim, config["horizon"])
    generator.to("cuda")
    generator.load_state_dict(torch.load(config["weights"], map_location="cuda"))
    generator.eval()

    with torch.no_grad():
        max_action = rollout_collector.action_ub
        mid_action = rollout_collector.action_mid
        u_seq_max = torch.tile(max_action, (1, rollout_collector.horizon, 1))
        x_seq_max = rollout_collector.rollout(u_seq_max)
        x_seq_max = x_seq_max.clone().detach().cpu().numpy()[0]

        u_seq_mid = torch.tile(mid_action, (1, rollout_collector.horizon, 1))
        x_seq_mid = rollout_collector.rollout(u_seq_mid)
        x_seq_mid = x_seq_mid.clone().detach().cpu().numpy()[0]
        reference = np.stack([x_seq_max, x_seq_mid], axis=0)

    plot_model(ax, generator, n_samples, config["title"], rollout_collector, reference)


def make_plots(configs: list[dict],
               n_samples: int,
               output_file: str | None = None):
    fig, axs = plt.subplots(1, len(configs), figsize=(12 * len(configs), 10))
    if not isinstance(axs, np.ndarray):
        axs = [axs]
    for config, ax in zip(configs, axs):
        plot_config(ax, config, n_samples)

    plt.tight_layout()

    if output_file is None:
        plt.plot()
    else:
        plt.savefig(output_file, dpi=100, bbox_inches='tight')
        plt.close()


In [18]:
DT = 0.2
WHEEL_BASE = 0.324
SPEED = 1.
BICYCLE_DELTA = float(np.deg2rad(30.))
UNICYCLE_OMEGA = float(np.deg2rad(45.))

configs_16 = [
    {
        "title": "Bicycle, $H$=16, $v$ = 1 m/s, $\delta$ $\in$ (-30, 30) deg",
        "model_fn": lambda a, h: generic_realnvp(a, h),
        "weights": "deploy_checkpoints/bicycle_steer__angle_30__horizon_16__dt_02.pth",
        "horizon": 16,
        "action_lb": (-BICYCLE_DELTA,),
        "action_ub": (BICYCLE_DELTA,),
        "action_mid": (0,),
        "state_dim_cut": 2,
        "dynamics": BicycleSteeringModel(dt=DT,
                                        l=WHEEL_BASE,
                                        speed=SPEED,
                                        backend="torch",
                                        device="cuda")
    },
    {
        "title": "Unicycle, $H$=16, $v$ = 1 m/s, $\omega$ $\in$ (-45, 45) deg/s",
        "model_fn": lambda a, h: generic_realnvp(a, h),
        "weights": "deploy_checkpoints/unicycle_steer__angle_45__horizon_16__dt_02.pth",
        "horizon": 16,
        "action_lb": (-UNICYCLE_OMEGA,),
        "action_ub": (UNICYCLE_OMEGA,),
        "action_mid": (0,),
        "state_dim_cut": 2,
        "dynamics": UnicycleSteeringModel(dt=DT,
                                          speed=SPEED,
                                          backend="torch",
                                          device="cuda")
    },
    {
        "title": "Bicycle, $H$=16, $v$ $\in$ (0.1, 1) m/s, $\delta$ $\in$ (-30, 30) deg",
        "model_fn": lambda a, h: generic_realnvp(a, h, num_layers=64),
        "weights": "deploy_checkpoints/bicycle_full__angle_30__horizon_16__dt_02.pth",
        "horizon": 16,
        "action_lb": (0.1, -BICYCLE_DELTA,),
        "action_ub": (1., BICYCLE_DELTA,),
        "action_mid": (1., 0,),
        "state_dim_cut": None,
        "dynamics": BicycleModel(dt=DT,
                                 l=WHEEL_BASE,
                                 backend="torch",
                                 device="cuda")
    },
    {
        "title": "Unicycle, $H$=16, $v$ $\in$ (0.1, 1) m/s, $\omega$ $\in$ (-45, 45) deg/s",
        "model_fn": lambda a, h: generic_realnvp(a, h),
        "weights": "deploy_checkpoints/unicycle_full__angle_45__horizon_16__dt_02.pth",
        "horizon": 16,
        "action_lb": (0.1, -UNICYCLE_OMEGA,),
        "action_ub": (1., UNICYCLE_OMEGA,),
        "action_mid": (1., 0,),
        "state_dim_cut": None,
        "dynamics": UnicycleModel(dt=DT,
                                          backend="torch",
                                          device="cuda")
    }
]


configs_26 = [
    {
        "title": "Bicycle, $H$=26, $v$ = 1 m/s, $\delta$ $\in$ (-30, 30) deg",
        "model_fn": lambda a, h: generic_realnvp(a, h),
        "weights": "deploy_checkpoints/bicycle_steer__angle_30__horizon_26__dt_02.pth",
        "horizon": 26,
        "action_lb": (-BICYCLE_DELTA,),
        "action_ub": (BICYCLE_DELTA,),
        "action_mid": (0,),
        "state_dim_cut": 2,
        "dynamics": BicycleSteeringModel(dt=DT,
                                        l=WHEEL_BASE,
                                        speed=SPEED,
                                        backend="torch",
                                        device="cuda")
    },
    {
        "title": "Unicycle, $H$=26, $v$ = 1 m/s, $\omega$ $\in$ (-45, 45) deg/s",
        "model_fn": lambda a, h: generic_realnvp(a, h),
        "weights": "deploy_checkpoints/unicycle_steer__angle_45__horizon_26__dt_02.pth",
        "horizon": 26,
        "action_lb": (-UNICYCLE_OMEGA,),
        "action_ub": (UNICYCLE_OMEGA,),
        "action_mid": (0,),
        "state_dim_cut": 2,
        "dynamics": UnicycleSteeringModel(dt=DT,
                                          speed=SPEED,
                                          backend="torch",
                                          device="cuda")
    },
    {
        "title": "Bicycle, $H$=26, $v$ $\in$ (0.1, 1) m/s, $\delta$ $\in$ (-30, 30) deg",
        "model_fn": lambda a, h: generic_realnvp(a, h, num_layers=64),
        "weights": "deploy_checkpoints/bicycle_full__angle_30__horizon_26__dt_02.pth",
        "horizon": 26,
        "action_lb": (0.1, -BICYCLE_DELTA,),
        "action_ub": (1., BICYCLE_DELTA,),
        "action_mid": (1., 0,),
        "state_dim_cut": None,
        "dynamics": BicycleModel(dt=DT,
                                 l=WHEEL_BASE,
                                 backend="torch",
                                 device="cuda")
    },
    {
        "title": "Unicycle, $H$=26, $v$ $\in$ (0.1, 1) m/s, $\omega$ $\in$ (-45, 45) deg/s",
        "model_fn": lambda a, h: generic_realnvp(a, h),
        "weights": "deploy_checkpoints/unicycle_full__angle_45__horizon_26__dt_02.pth",
        "horizon": 26,
        "action_lb": (0.1, -UNICYCLE_OMEGA,),
        "action_ub": (1., UNICYCLE_OMEGA,),
        "action_mid": (1., 0,),
        "state_dim_cut": None,
        "dynamics": UnicycleModel(dt=DT,
                                          backend="torch",
                                          device="cuda")
    }
]

  "title": "Bicycle, $H$=16, $v$ = 1 m/s, $\delta$ $\in$ (-30, 30) deg",
  "title": "Unicycle, $H$=16, $v$ = 1 m/s, $\omega$ $\in$ (-45, 45) deg/s",
  "title": "Bicycle, $H$=16, $v$ $\in$ (0.1, 1) m/s, $\delta$ $\in$ (-30, 30) deg",
  "title": "Unicycle, $H$=16, $v$ $\in$ (0.1, 1) m/s, $\omega$ $\in$ (-45, 45) deg/s",
  "title": "Bicycle, $H$=26, $v$ = 1 m/s, $\delta$ $\in$ (-30, 30) deg",
  "title": "Unicycle, $H$=26, $v$ = 1 m/s, $\omega$ $\in$ (-45, 45) deg/s",
  "title": "Bicycle, $H$=26, $v$ $\in$ (0.1, 1) m/s, $\delta$ $\in$ (-30, 30) deg",
  "title": "Unicycle, $H$=26, $v$ $\in$ (0.1, 1) m/s, $\omega$ $\in$ (-45, 45) deg/s",


In [19]:
all_configs = []
for i in range(len(configs_16)):
    all_configs.append(configs_16[i])
    all_configs.append(configs_26[i])
make_plots(all_configs, n_samples=10000, output_file="samples_all.jpg")

In [20]:
make_plots(configs_16, n_samples=10000, output_file="samples_16.png")

In [21]:
make_plots(configs_26, n_samples=10000, output_file="samples_26.png")