In [1]:
import os

# These variables are read when JAX / CUDA initialise, so they are set here,
# before `import jax` in the next cell:
#   - XLA_PYTHON_CLIENT_PREALLOCATE="false": do not grab most of GPU memory up front.
#   - CUDA_VISIBLE_DEVICES="1": expose only GPU index 1 to this process.
os.environ.update(
    XLA_PYTHON_CLIENT_PREALLOCATE="false",
    CUDA_VISIBLE_DEVICES="1",
)

In [2]:
import jax
import jax.numpy as jnp
import optax
import flax
from flax.training import train_state, checkpoints
from flax_rbf.flax_rbf import *
from irbfn_mpc.model import WCRBFNet
# from irbfn_mpc.irbfn_planner import IRBFNPlanner
import irbfn_mpc.irbfn_planner as ip



In [3]:
import gymnasium as gym
from irbfn_mpc.planner_utils import intersect_point, nearest_point
import numpy as np
import yaml
import argparse

In [4]:
from IPython.display import HTML, display
%matplotlib inline

In [5]:
# Single-agent F1TENTH simulation on the "L" map. Each agent observes the
# feature dict listed below; control inputs are acceleration and steering speed.
observed_features = [
    "pose_x",
    "pose_y",
    "delta",
    "linear_vel_x",
    "linear_vel_y",
    "pose_theta",
    "ang_vel_z",
    "beta",
]
env_config = {
    "map": "L",
    "observation_config": {"type": "features", "features": observed_features},
    "num_agents": 1,
    "control_input": ["accl", "steering_speed"],
}

# Render to RGB frames so RecordVideo can write episodes as mp4 files into
# the "video_irbfn_nmpc_frenet" directory.
base_env = gym.make(
    "f1tenth_gym:f1tenth-v0",
    config=env_config,
    render_mode="rgb_array",
)
env = gym.wrappers.RecordVideo(base_env, "video_irbfn_nmpc_frenet")

  logger.warn(


In [6]:
track = env.unwrapped.track
centerline = track.centerline

# One row per centerline point: [x, y, vx, yaw].
# Assumes the four arrays are 1-D and equally long (TODO: confirm against the track API).
waypoint_columns = [centerline.xs, centerline.ys, centerline.vxs, centerline.yaws]
waypoints = np.stack(waypoint_columns).T

In [7]:
# The training config and the checkpoint must come from the same experiment.
# Derive both paths from one name so they cannot drift apart.
EXPERIMENT = "dnmpc_warmstart_constraint_centers_mode"
CKPT_STEP = 8400  # training step of the checkpoint to load

config_f = f"configs/{EXPERIMENT}.yaml"
ckpt = f"ckpts/{EXPERIMENT}/checkpoint_{CKPT_STEP}"

In [8]:
from importlib import reload

# Re-execute irbfn_planner so edits to the .py file are picked up without
# restarting the kernel. reload() returns the module object; rebind and show it.
ip = reload(ip)
ip

<module 'irbfn_mpc.irbfn_planner' from '/home/irbfn/src/irbfn_mpc/irbfn_planner.py'>

In [9]:
# Frenet-frame IRBFN planner built from the config/checkpoint selected above.
# Other variants that have been run with this cell (the model flags presumably
# need to match the checkpoint's architecture; confirm before switching):
#   - MLP model:        mlp=True
#   - fixed RBF centers: fixed_centers=True,
#       centers_path="/data/tables/frenet/constraints_12ey_7delta_11vxcar_11vycar_5vxgoal_11wz_11epsi_3curv_mu1.0999999999999999_cs5.0_sorted_top500mode.npz"
planner_kwargs = dict(sv_ind=5, deeper=False, mlp=False)
planner = ip.IRBFNFrenetPlanner(config_f, ckpt, track, **planner_kwargs)

In [10]:
# Replace any previously registered overlays (re-running this cell would
# otherwise stack duplicates), then draw the planner's waypoints, local plan
# and MPC solution in that order.
sim = env.unwrapped
sim.clear_render_callback()
for render_fn in (
    planner.render_waypoints,
    planner.render_local_plan,
    planner.render_mpc_sol,
):
    sim.add_render_callback(render_fn)

In [11]:
import chex
# Reset chex's trace counters before the rollout.
# Presumably this is so any `chex.assert_max_traces` guards in the planner count
# (re)compilations from zero on each run of the simulation cell. Confirm against
# the planner code.
chex.clear_trace_counter()

In [12]:
# Closed-loop rollout: start on the raceline, drive open-loop until the car is
# moving, then hand control to the IRBFN planner until the episode ends.
START_IDX = -280    # raceline index (counted from the end) of the start pose
MAX_STEPS = 3000    # hard cap on simulation steps
MIN_PLAN_VX = 1.0   # below this linear_vel_x the planner is bypassed
LAUNCH_ACCL = 9.0   # open-loop acceleration command used while bypassed
VERBOSE = True      # print per-step planner output and commands

raceline = env.unwrapped.track.raceline
poses = np.array(
    [[raceline.xs[START_IDX], raceline.ys[START_IDX], raceline.yaws[START_IDX]]]
)
obs, info = env.reset(options={"poses": poses})
done = False

step = 0

while not done and step < MAX_STEPS:
    current_state = obs["agent_0"]
    if current_state["linear_vel_x"] < MIN_PLAN_VX:
        accl = LAUNCH_ACCL
        steerv = 0.0
    else:
        accl, steerv, pred_u = planner.plan(current_state)
        if VERBOSE:
            print(pred_u)

    # Build the command explicitly instead of overwriting a random
    # action_space.sample(): same shape/dtype, no hidden RNG consumption.
    action = np.zeros(env.action_space.shape, dtype=env.action_space.dtype)
    # Row 0 is agent_0; column order [steering speed, acceleration] as before.
    action[0] = [steerv, accl]
    if VERBOSE:
        print(f"steerv: {steerv}, accl: {accl}")
        print("---------------------------------")

    obs, step_reward, terminated, truncated, info = env.step(action)
    # Gymnasium signals episode end via terminated OR truncated; ignoring
    # truncation would keep stepping an already-finished episode.
    done = terminated or truncated
    step += 1

  gym.logger.warn("Casting input x to numpy array.")


steerv: 0.0, accl: 9.0
---------------------------------
steerv: 0.0, accl: 9.0
---------------------------------
steerv: 0.0, accl: 9.0
---------------------------------
steerv: 0.0, accl: 9.0
---------------------------------
steerv: 0.0, accl: 9.0
---------------------------------
steerv: 0.0, accl: 9.0
---------------------------------
steerv: 0.0, accl: 9.0
---------------------------------
steerv: 0.0, accl: 9.0
---------------------------------
steerv: 0.0, accl: 9.0
---------------------------------
steerv: 0.0, accl: 9.0
---------------------------------
steerv: 0.0, accl: 9.0
---------------------------------
steerv: 0.0, accl: 9.0
---------------------------------
Mirror False, ey lookup: -0.0007423310889862478
Mirror False, delta lookup: 0.0
Mirror False, vx_car lookup: 1.0800000429153442
Mirror False, vy_car lookup: 0.0
Mirror False, vx_goal lookup: 4.0
Mirror False, wz lookup: 0.0
Mirror False, epsi lookup: -0.022678589448332787
Mirror False, curv lookup: 0.06969100236892

In [13]:
# Close the environment. The RecordVideo wrapper finalises the recorded episode
# here; the Moviepy log printed by this cell shows the mp4 being written.
env.close()

Moviepy - Building video /home/irbfn/scripts/video_irbfn_nmpc_frenet/rl-video-episode-0.mp4.
Moviepy - Writing video /home/irbfn/scripts/video_irbfn_nmpc_frenet/rl-video-episode-0.mp4



                                                                 

Moviepy - Done !
Moviepy - video ready /home/irbfn/scripts/video_irbfn_nmpc_frenet/rl-video-episode-0.mp4




In [14]:
import glob
import io
import base64

# Embed every recorded episode inline as a base64 data URI, so the videos
# survive in the saved notebook without the mp4 files alongside it.
VIDEO_DIR = "video_irbfn_nmpc_frenet"  # same folder passed to RecordVideo

# glob() returns files in arbitrary order; sort for a deterministic display order.
for video_file in sorted(glob.glob(f"{VIDEO_DIR}/*.mp4")):
    # `with` guarantees the handle is closed (the previous one-liner leaked it).
    with io.open(video_file, "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode("ascii")
    # The <video> height attribute must be an integer; omitting it lets the
    # browser keep the aspect ratio for the given width.
    display(
        HTML(
            f"""<video width="800" controls>
                <source src="data:video/mp4;base64,{encoded}" type="video/mp4" />
            </video>"""
        )
    )