In [1]:
import os
# These must be set before JAX is imported (next cell) to take effect:
# disable XLA's up-front GPU memory preallocation and expose only GPU 0.
os.environ.update(
    XLA_PYTHON_CLIENT_PREALLOCATE="false",
    CUDA_VISIBLE_DEVICES="0",
)

In [5]:
import jax
import jax.numpy as jnp
import optax
import flax
from flax.training import train_state, checkpoints
from flax_rbf.flax_rbf import *
from irbfn_mpc.model import WCRBFNet

In [6]:
import gymnasium as gym
from irbfn_mpc.planner_utils import intersect_point, nearest_point_on_trajectory
import numpy as np
import yaml
import argparse

In [7]:
from IPython.display import HTML, display
%matplotlib inline

In [11]:
@jax.jit
def pred_step(state, x):
    """Evaluate the network held in ``state`` on the input batch ``x``.

    ``state`` is a pytree (here a flax ``TrainState``) exposing ``apply_fn`` and
    ``params``; the forward pass is jit-compiled by JAX.
    """
    return state.apply_fn(state.params, x)

In [12]:
# Training-time hyper-parameters and the checkpoint produced with them.
config_f, ckpt = (
    "configs/nmpc_1_region.yaml",
    "ckpts/nmpc_1_region/checkpoint_0",
)

# Expose the YAML keys as attributes (conf.seed, conf.lr, ...), like parsed CLI args.
with open(config_f, mode="r") as f:
    config_dict = yaml.safe_load(f)
conf = argparse.Namespace(**config_dict)

In [13]:
# Rebuild the network with the architecture recorded in the training config.
# NOTE(review): `eval` executes arbitrary Python taken from the YAML file; it
# presumably resolves a basis-function name star-imported from
# `flax_rbf.flax_rbf`. Only load trusted configs.
wcrbf = WCRBFNet(
    in_features=conf.in_features,
    out_features=conf.out_features,
    num_kernels=conf.num_kernels,
    basis_func=eval(conf.basis_func),
    num_regions=conf.num_regions,
    lower_bounds=conf.lower_bounds,
    upper_bounds=conf.upper_bounds,
    dimension_ranges=conf.dimension_ranges,
    activation_idx=conf.activation_idx,
    delta=conf.delta,
)

# A freshly initialised TrainState serves as the pytree template that the
# checkpoint is restored into.
rng = jax.random.PRNGKey(conf.seed)
rng, init_rng = jax.random.split(rng)
params = wcrbf.init(init_rng, jnp.ones((1, conf.in_features)))
optim = optax.adam(conf.lr)
state = train_state.TrainState.create(apply_fn=wcrbf.apply, params=params, tx=optim)

# restore_checkpoint returns `target` unchanged (random weights) when nothing is
# found at `ckpt_dir`; fail loudly instead of silently driving with an untrained net.
if not os.path.exists(ckpt):
    raise FileNotFoundError(f"Checkpoint not found: {ckpt!r}")
restored_state = checkpoints.restore_checkpoint(ckpt_dir=ckpt, target=state)

In [30]:
# Single-agent F1TENTH environment observing the kinematic state. The
# "rgb_array" render mode is what lets the RecordVideo wrapper capture frames.
env = gym.make(
    "f1tenth_gym:f1tenth-v0",
    render_mode="rgb_array",
    config=dict(
        observation_config=dict(type="kinematic_state"),
        num_agents=1,
    ),
)
# Recorded episodes are written as .mp4 files under video_irbfn_nmpc/.
env = gym.wrappers.RecordVideo(env, "video_irbfn_nmpc")

  logger.warn(


In [9]:
# Raceline of the loaded track, one row per point. Columns follow the stacking
# order below: x, y, vxs (reference speed), yaw.
# Assumes each centerline attribute is a 1-D array of the same length N,
# giving an (N, 4) array.
track = env.unwrapped.track
centerline = track.centerline
columns = [centerline.xs, centerline.ys, centerline.vxs, centerline.yaws]
waypoints = np.stack(columns).T

In [10]:
def get_current_waypoint(waypoints, lookahead_distance, position, theta, max_reacquire=20.0):
    """Select the pure-pursuit style goal waypoint for the current vehicle position.

    Args:
        waypoints: 2-D array of raceline points. Columns 0-1 are used as (x, y)
            and column 2 as the reference speed.
        lookahead_distance: Radius of the lookahead circle around ``position``.
        position: Current (x, y) position of the vehicle.
        theta: Current heading. Unused; kept so existing call sites keep working.
        max_reacquire: When the vehicle is at least ``lookahead_distance`` from
            the raceline but closer than this, the nearest waypoint becomes the goal.

    Returns:
        ``(goal, i)``. ``goal`` is a float32 array ``[x, y, speed]``. ``i`` is the
        index reported by ``nearest_point_on_trajectory`` for the point closest
        to ``position``.
        Returns ``(None, None)`` when no goal can be determined: no lookahead
        intersection, or farther than ``max_reacquire`` from the raceline.
    """
    wpts = waypoints[:, :2]
    lookahead_distance = np.float32(lookahead_distance)
    _, nearest_dist, t, i = nearest_point_on_trajectory(position, wpts)

    if nearest_dist < lookahead_distance:
        # On the raceline: aim at the point where the lookahead circle crosses it.
        t1 = np.float32(i + t)
        _, i2, _ = intersect_point(
            position, lookahead_distance, wpts, t1, wrap=True
        )
        if i2 is None:
            return None, None
        x, y = wpts[i2, :]
    elif nearest_dist < max_reacquire:
        # Off the raceline but close enough: head for the nearest waypoint.
        x, y = wpts[i, :]
    else:
        return None, None

    # Both branches return [x, y, speed] so callers can always read goal[2].
    # The speed is taken at the nearest index i.
    return np.array([x, y, waypoints[i, 2]], dtype=np.float32), i

In [33]:
# Closed-loop rollout of the IRBFN approximation of the NMPC controller.
# Each step:
#   1. pick a lookahead goal on the raceline;
#   2. express it in the vehicle frame;
#   3. query the network and apply the resulting steering/speed setpoints.
obs, info = env.reset()
done = False       # `terminated` flag returned by env.step
truncated = False  # episode cut off by the environment

ld = 3.5  # lookahead distance passed to get_current_waypoint
dt = 0.1  # step used to integrate the predicted outputs into setpoints
step = 0

try:
    # A Gymnasium episode ends on termination OR truncation; stepping past either
    # without reset() is invalid, so stop on both.
    while not (done or truncated):
        current_state = obs["agent_0"]
        theta = current_state["pose_theta"]
        current_pos = np.array([current_state["pose_x"], current_state["pose_y"]])
        goal, goal_i = get_current_waypoint(waypoints, ld, current_pos, theta)
        if goal is None:
            # No usable goal (no lookahead intersection or too far from the
            # raceline): stop instead of crashing on goal[:2] below.
            print(f"No goal waypoint found at step {step}; stopping rollout.")
            break

        # Rotate the world-frame offset by -theta into the vehicle frame.
        rot = jnp.array(
            [
                [jnp.cos(-theta), -jnp.sin(-theta)],
                [jnp.sin(-theta), jnp.cos(-theta)],
            ]
        )
        goal_local = jnp.dot(rot, goal[:2] - current_pos)

        # Network input (batch of one): current speed, goal x/y in the vehicle
        # frame, goal speed, and column 3 (yaw) of waypoint goal_i.
        rbf_in = jnp.array(
            [
                [
                    current_state["linear_vel_x"],
                    goal_local[0],
                    goal_local[1],
                    goal[2],
                    waypoints[goal_i, 3],
                ]
            ]
        )
        pred_u = pred_step(restored_state, rbf_in)

        # Outputs 0 and 5 are integrated over dt onto the current speed and
        # steering angle.
        # NOTE(review): assumes they are the first-step acceleration and steering
        # rate of the NMPC horizon — confirm against the training output layout.
        speed = current_state["linear_vel_x"] + pred_u[0, 0] * dt
        steer = current_state["delta"] + pred_u[0, 5] * dt
        action = np.zeros(env.action_space.shape, dtype=env.action_space.dtype)
        action[0] = [steer, speed]

        obs, step_reward, done, truncated, info = env.step(action)
        step += 1
finally:
    # Always close so the RecordVideo wrapper finalises the episode video,
    # even if the rollout raised.
    env.close()

Moviepy - Building video /home/irbfn/scripts/video_irbfn_nmpc/rl-video-episode-1.mp4.
Moviepy - Writing video /home/irbfn/scripts/video_irbfn_nmpc/rl-video-episode-1.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /home/irbfn/scripts/video_irbfn_nmpc/rl-video-episode-1.mp4




In [34]:
import glob
import io
import base64


def video_tag(video_file, width=360):
    """Return an HTML5 <video> tag embedding ``video_file`` as base64 MP4 data.

    The file is read in a ``with`` block so its handle is closed right away.
    """
    with io.open(video_file, "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode("ascii")
    return f"""<video width="{width}" height="auto" controls>
                <source src="data:video/mp4;base64,{encoded}" type="video/mp4" />
            </video>"""


# Sorted so the episodes display in a stable order across runs.
for video_file in sorted(glob.glob("video_irbfn_nmpc/*.mp4")):
    display(HTML(video_tag(video_file)))