In [3]:
import gymnasium as gym
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import jax
import jax.numpy as jnp
from irbfn_mpc.planner_utils import intersect_point, nearest_point
from irbfn_mpc.dynamics import integrate_st_mult
import numpy as np

In [4]:
from IPython.display import HTML, display
%matplotlib inline

In [5]:
# Precomputed lookup table: `inputs` holds the query vectors and `outputs`
# the stored control sequences at the same row index.
npz_path = "/data/tables/8v_19x_19y_64t_8vgoal_6beta_12angvz_mu1.0_cs5.0.npz"
# np.load on an .npz returns a lazy NpzFile that keeps the archive open;
# the context manager closes it once both arrays have been read into memory.
with np.load(npz_path) as data:
    inputs, outputs = data["inputs"], data["outputs"]

In [6]:
@jax.jit
def get_closest_ind(input, all_inputs):
    """Return the row index of `all_inputs` nearest to `input`.

    Distance is the Euclidean (L2) norm taken along axis 1, so `all_inputs`
    is treated as one candidate per row and `input` must broadcast against it.
    """
    row_distances = jnp.linalg.norm(all_inputs - input, axis=1)
    return jnp.argmin(row_distances)

In [7]:
def get_current_waypoint(
    waypoints, lookahead_distance, position, theta, max_fallback_distance=20
):
    """Select the waypoint the controller should currently track.

    Args:
        waypoints: 2-D array; columns 0-1 are read as x, y and column 2 as
            the speed stored in the returned goal.
        lookahead_distance: radius handed to `intersect_point`.
        position: current (x, y) of the vehicle.
        theta: unused; kept so existing callers keep working.
        max_fallback_distance: if the path is farther away than the lookahead
            distance but closer than this, the nearest waypoint is returned.

    Returns:
        `(goal, index)` where `goal` is a float32 array `[x, y, speed]` and
        `index` is the nearest-waypoint index reported by `nearest_point`.
        `(None, None)` when no intersection is found or the path is at least
        `max_fallback_distance` away.
    """
    wpts = waypoints[:, :2]
    lookahead_distance = np.float32(lookahead_distance)
    _, nearest_dist, t, i = nearest_point(position, wpts)

    if nearest_dist < lookahead_distance:
        # Start the search at the fractional path index of the nearest point.
        t1 = np.float32(i + t)
        _, i2, _ = intersect_point(
            position, lookahead_distance, wpts, t1, wrap=True
        )
        if i2 is None:
            return None, None
        goal_xy = wpts[i2, :]
    elif nearest_dist < max_fallback_distance:
        # Path is outside the lookahead circle: head for the nearest waypoint.
        goal_xy = wpts[i, :]
    else:
        return None, None

    # Both branches return the same [x, y, speed] float32 layout; previously
    # the fallback branch returned only [x, y].
    current_waypoint = np.empty((3,), dtype=np.float32)
    current_waypoint[0:2] = goal_xy
    current_waypoint[2] = waypoints[i, 2]
    # NOTE(review): the speed and the returned index come from the nearest
    # waypoint `i`, while x/y may come from the lookahead index `i2`.
    # Confirm this mix is intended.
    return current_waypoint, i

In [8]:
# Single-agent F1TENTH environment configuration.
env_config = {
    "map": "Spielberg_blank",
    "observation_config": {"type": "dynamic_state"},
    "num_agents": 1,
    "control_input": "accl",
}
base_env = gym.make(
    "f1tenth_gym:f1tenth-v0", config=env_config, render_mode="rgb_array"
)
# Wrap the env so rendered frames are written to the video folder.
env = gym.wrappers.RecordVideo(base_env, video_folder="video_explicit_nmpc")

  logger.warn(


In [9]:
# Track object of the underlying (unwrapped) environment; its centerline is
# used to build the waypoint array.
track = env.unwrapped.track

In [10]:
# Shared notebook-level state: written by the rollout loop and read by the
# render callbacks. `global` statements at module level are no-ops, so plain
# assignments are used instead.
goal = None  # current goal waypoint; None until the rollout sets it
output_sol = None  # latest result of integrate_st_mult; None until set
# One row per centerline sample, columns: x, y, vx, yaw.
waypoints = np.stack(
    [
        track.centerline.xs,
        track.centerline.ys,
        track.centerline.vxs,
        track.centerline.yaws,
    ]
).T

In [11]:
def render_waypoints(e):
    """Render callback: draw the x/y columns of `waypoints` as a closed line."""
    e.render_closed_lines(waypoints[:, :2], color=(128, 0, 0), size=1)

In [12]:
def render_goal_state(e):
    """Render callback: draw the current goal position, if one is set."""
    if goal is None:
        return
    # The renderer is given a (1, 2) array holding the goal's x, y.
    goal_xy = goal[:2][None]
    e.render_points(goal_xy, color=(0, 128, 0), size=3)

In [13]:
def render_planner_sol(e):
    """Render callback: draw each trajectory in `output_sol`, if any."""
    if output_sol is None:
        return
    for trajectory in output_sol:
        # Only the first two columns (x, y) are drawn.
        xy_path = np.array(trajectory[:, 0:2])
        e.render_lines(xy_path, color=(0, 0, 128), size=2)

In [14]:
# Register the overlays in the same order: waypoints, goal, planner solution.
for render_callback in (render_waypoints, render_goal_state, render_planner_sol):
    env.unwrapped.add_render_callback(render_callback)

In [15]:
# Closed-loop rollout: at each step pick a goal waypoint, look up the nearest
# entry in the precomputed table, and apply its first control to the env.

MIN_LOOKAHEAD_SPEED = 0.5  # floor on the speed used to size the lookahead
LOOKAHEAD_TIME = 0.5  # lookahead distance = speed * this factor
MAX_STEPS = 10000

# Constant parameter vector passed to integrate_st_mult on every step.
# NOTE(review): the order and meaning of these entries are defined by
# irbfn_mpc.dynamics.integrate_st_mult, which is not visible here.
# Confirm against that function before editing.
INTEGRATION_PARAMS = np.array(
    [1.0, 1.0489, 0.04712, 0.15875, 0.17145, 5.0, 5.0, 0.074, 0.1, 3.2, 9.51, 0.4189, 7.0]
)

# Convert the table inputs to a JAX array once, so the jitted lookup does not
# have to transfer the whole table from host memory on every step.
inputs_device = jnp.asarray(inputs)

# Start slightly offset from the first centerline sample, using its yaw.
centerline = env.unwrapped.track.centerline
poses = np.array(
    [[centerline.xs[0] + 0.1, centerline.ys[0] + 0.1, centerline.yaws[0]]]
)
obs, info = env.reset(options={"poses": poses})
done = False

step = 0

v = obs["agent_0"]["linear_vel_x"]
ld = max(v, MIN_LOOKAHEAD_SPEED) * LOOKAHEAD_TIME

try:
    while not done and step <= MAX_STEPS:
        current_state = obs["agent_0"]
        x = current_state["pose_x"]
        y = current_state["pose_y"]
        delta = current_state["delta"]
        v = current_state["linear_vel_x"]
        theta = current_state["pose_theta"]
        beta = current_state["beta"]
        angv = current_state["ang_vel_z"]
        position = np.array([x, y])

        goal, goal_i = get_current_waypoint(waypoints, ld, position, theta)
        if goal is None:
            # get_current_waypoint signals "no usable waypoint" with None;
            # indexing it would raise, so end the rollout cleanly instead.
            print("No goal waypoint found near the vehicle; stopping rollout.")
            break

        # Rotate the world-frame offset by -theta to express the goal in the
        # vehicle frame.
        rot_m = np.array(
            [[np.cos(-theta), -np.sin(-theta)], [np.sin(-theta), np.cos(-theta)]]
        )
        goal_local = np.dot(rot_m, goal[:2] - position)
        goal_x = goal_local[0]
        goal_y = goal_local[1]

        # Wrap the heading error to [-pi, pi] so an accumulated yaw does not
        # produce a query far outside the range of relative headings.
        heading_error = waypoints[goal_i, 3] - theta
        goal_theta = np.arctan2(np.sin(heading_error), np.cos(heading_error))
        goal_v = waypoints[goal_i, 2]

        # No nested double quotes inside the f-string: that form is only
        # valid from Python 3.12 on.
        print(f"Current pose: {(x, y)}")
        print(f"Tracking Local: {goal_local.T}, Global: {goal[:2]}")

        # Lookahead used by the NEXT iteration, sized from the current speed.
        ld = max(v, MIN_LOOKAHEAD_SPEED) * LOOKAHEAD_TIME

        # Nearest-neighbour lookup of the stored control sequence.
        current_input = np.array([v, goal_x, goal_y, goal_theta, goal_v, beta, angv])
        closest_ind = int(get_closest_ind(current_input, inputs_device))
        queried_outputs = outputs[closest_ind]

        # sample() is used only to obtain an action container of the right
        # shape and dtype; its contents for agent 0 are overwritten.
        action = env.action_space.sample()
        # The first stored control pair is applied with its two entries
        # swapped relative to the table's column order.
        action[0] = [queried_outputs[0, 1], queried_outputs[0, 0]]
        print(f"taking action {action[0]}")

        # Integrate the full stored control sequence from the current state;
        # the result is drawn by render_planner_sol.
        cst = np.array([x, y, delta, v, theta, angv, beta])
        x_and_u = np.hstack((cst, queried_outputs.flatten()))
        output_sol = integrate_st_mult(x_and_u[None, :], INTEGRATION_PARAMS)

        obs, step_reward, terminated, truncated, info = env.step(action)
        # Gymnasium reports episode end through two flags; honour both.
        done = terminated or truncated
        step += 1
finally:
    # Always close the env so the video recorder is closed even on error.
    env.close()

  gym.logger.warn("Casting input x to numpy array.")


Current pose: (0.1, 0.1)
Tracking Local: [0.22253224 0.07061037], Global: [-0.09657146 -0.02596061]
taking action [-2.0990208e-18  1.5517778e+00]


  gym.logger.warn("Casting input x to numpy array.")


Current pose: (0.09992507, 0.099979855)
Tracking Local: [0.22245464 0.07061037], Global: [-0.09657146 -0.02596061]
taking action [-2.0990208e-18  1.5517778e+00]
Current pose: (0.09970029, 0.09991943)
Tracking Local: [0.22222188 0.07061036], Global: [-0.09657146 -0.02596061]
taking action [-2.0990208e-18  1.5517778e+00]
Current pose: (0.09932564, 0.099818714)
Tracking Local: [0.22183393 0.07061037], Global: [-0.09657146 -0.02596061]
taking action [-2.0990208e-18  1.5517778e+00]
Current pose: (0.09880114, 0.09967772)
Tracking Local: [0.22129083 0.07061037], Global: [-0.09657146 -0.02596061]
taking action [-2.0990208e-18  1.5517778e+00]
Current pose: (0.098126784, 0.09949643)
Tracking Local: [0.2205925  0.07061037], Global: [-0.09657146 -0.02596061]
taking action [-2.0990208e-18  1.5517778e+00]
Current pose: (0.09730257, 0.09927486)
Tracking Local: [0.21973903 0.07061036], Global: [-0.09657146 -0.02596061]
taking action [-2.0990208e-18  1.5517778e+00]
Current pose: (0.0963285, 0.09901301)

                                                                  

Moviepy - Done !
Moviepy - video ready /home/irbfn/scripts/video_explicit_nmpc/rl-video-episode-0.mp4


In [16]:
# import chex
# chex.clear_trace_counter()

In [17]:
# Redundant with the env.close() at the end of the rollout cell; presumably
# kept so the environment can still be closed if that cell was interrupted.
env.close()

In [19]:
import glob
import io
import base64

# Embed every recorded episode inline as a base64 data URI.
# sorted() makes the display order deterministic (glob order is arbitrary).
for video_file in sorted(glob.glob("video_explicit_nmpc/*.mp4")):
    # Context manager closes the file handle once the bytes are read.
    with io.open(video_file, "rb") as video_handle:
        encoded = base64.b64encode(video_handle.read()).decode("ascii")
    display(
        HTML(
            f"""<video width="800" height="auto" controls>
                <source src="data:video/mp4;base64,{encoded}" type="video/mp4" />
            </video>"""
        )
    )