In [1]:
# Clone only when the checkout is missing, so "Restart & Run All" does not emit
# "fatal: destination path ... already exists" on every re-run.
![ -d point-robot-imitation-learning ] || git clone https://github.com/kchen50/point-robot-imitation-learning

fatal: destination path 'point-robot-imitation-learning' already exists and is not an empty directory.


In [2]:
# `%pip` (unlike `!pip`) installs into the environment of the running kernel.
%pip install -r point-robot-imitation-learning/requirements.txt



In [4]:
!git -C point-robot-imitation-learning/ fetch
!git -C point-robot-imitation-learning/ switch master

import os
import sys

# Make the cloned repo's packages (`policy`, `planner`) importable. Resolve the
# checkout relative to the working directory (same as the other relative paths
# in this notebook) instead of hard-coding Colab's /content, and skip the
# append on re-runs so sys.path does not accumulate duplicate entries.
REPO_ROOT = os.path.abspath("point-robot-imitation-learning")
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

Already on 'master'
Your branch is up to date with 'origin/master'.


In [7]:
import os
# Select MuJoCo's EGL backend for headless (no display) OpenGL rendering.
# NOTE(review): presumably this must be set before `mujoco` is imported (as it
# is here) to take effect — keep this statement above the import.
os.environ["MUJOCO_GL"] = "egl"
# Importing the `egl` submodule also binds `mujoco`, which later cells use
# (e.g. `mujoco.egl.GLContext`, `mujoco.MjModel`).
import mujoco.egl

In [8]:
import os

import torch
import numpy as np
from PIL import Image
import mediapy as media

In [9]:
from policy.policy import Policy
from planner.rrt_core import in_goal, get_qpos_indices, get_qvel_indices, get_ctrl_indices

In [10]:
def run_policy_colab_mjr(
    xml_path,
    policy_path,
    randomize_start=False,
    start_pose=None,
    seed=123,
    max_steps=1000,
    fps=30,
    W=640,
    H=480,
    show_first_frame=False,
):
    """Roll out a trained policy in a MuJoCo point-robot scene and show a video.

    Each simulation step the policy receives the state
    ``concat(qpos[qpos_idx], qvel[qvel_idx])`` (presumably ``[x, y, vx, vy]``)
    and its output is written to ``data.ctrl[ctrl_idx]`` before ``mj_step``.
    Frames are rendered headlessly with MuJoCo's low-level ``mjr_*`` API on an
    EGL context, at roughly ``fps`` frames per second of simulated time, and
    shown inline with ``mediapy``. The rollout stops early once ``in_goal``
    reports the robot's position is in the goal; the outcome is printed.

    Args:
        xml_path: Path to the MJCF scene file.
        policy_path: Path to a checkpoint loadable with ``Policy.load``.
        randomize_start: If True, sample the start position uniformly from the
            hard-coded workspace bounds using ``seed``. Overrides ``start_pose``.
        start_pose: Optional start position with two entries (ndarray, list or
            tuple). Velocities are zeroed when it is applied. If None (and
            ``randomize_start`` is False), the model's initial state is used.
        seed: Seed for ``np.random.default_rng`` when ``randomize_start``.
        max_steps: Maximum number of ``mj_step`` calls.
        fps: Target video frame rate (frames per second of simulated time).
        W, H: Frame width and height in pixels.
        show_first_frame: If True, ``display`` the first rendered frame.

    Returns:
        List of ``(H, W, 3)`` uint8 frames, flipped so the top row comes first.

    Raises:
        FileNotFoundError: If ``xml_path`` or ``policy_path`` does not exist.
        ValueError: If ``start_pose`` is given and does not have shape ``(2,)``.
    """
    if not os.path.exists(xml_path):
        raise FileNotFoundError(f"XML not found: {xml_path}")
    if not os.path.exists(policy_path):
        raise FileNotFoundError(f"Policy not found: {policy_path}")

    # Workspace bounds, shared by random-start sampling and camera framing.
    # NOTE(review): hard-coded for point_robot_nav.xml — confirm against the scene.
    xmin, xmax = -0.13, 1.03
    ymin, ymax = -0.30, 0.30

    # Resolve the start position before allocating anything, so bad input
    # fails fast. A random start takes precedence over an explicit start_pose.
    if randomize_start:
        rng = np.random.default_rng(seed)
        start_pose = np.array([rng.uniform(xmin, xmax), rng.uniform(ymin, ymax)], dtype=float)
    elif start_pose is not None:
        # Accept any length-2 sequence; a wrong shape used to be ignored silently.
        start_pose = np.asarray(start_pose, dtype=float)
        if start_pose.shape != (2,):
            raise ValueError(f"start_pose must have shape (2,), got {start_pose.shape}")

    # Load policy
    policy = Policy.load(policy_path)
    policy.eval()

    # Build sim
    model = mujoco.MjModel.from_xml_path(xml_path)
    data = mujoco.MjData(model)

    qpos_idx = get_qpos_indices(model)
    qvel_idx = get_qvel_indices(model)
    ctrl_idx = get_ctrl_indices(model)

    if start_pose is not None:
        data.qpos[qpos_idx] = start_pose
        data.qvel[qvel_idx] = 0.0
        mujoco.mj_forward(model, data)

    # The MjrContext sizes its offscreen framebuffer from these model fields
    # when it is created; grow them so a W x H viewport is fully readable.
    model.vis.global_.offwidth = max(model.vis.global_.offwidth, W)
    model.vis.global_.offheight = max(model.vis.global_.offheight, H)

    # Near top-down camera centred on the workspace.
    cam = mujoco.MjvCamera()
    cam.lookat[:] = np.array([0.5 * (xmin + xmax), 0.5 * (ymin + ymax), 0.0])
    cam.azimuth = 90
    cam.elevation = -89
    cam.distance = 1.2 * max(xmax - xmin, ymax - ymin)  # scale factor to fit

    scene = mujoco.MjvScene(model, maxgeom=10000)
    opt = mujoco.MjvOption()
    pert = mujoco.MjvPerturb()
    viewport = mujoco.MjrRect(0, 0, W, H)

    frames = []
    frame_interval = 1.0 / fps
    next_frame_t = 0.0
    rgb = np.zeros((H, W, 3), dtype=np.uint8)  # reused mjr_readPixels target
    reached = False

    # One GL context and one MjrContext per call; both are released in
    # `finally` so repeated calls (one per notebook cell) do not leak them.
    gl_context = mujoco.egl.GLContext(max(W, 1024), max(H, 1024))
    gl_context.make_current()
    con = None
    try:
        con = mujoco.MjrContext(model, mujoco.mjtFontScale.mjFONTSCALE_100.value)

        with torch.inference_mode():
            for step in range(max_steps):
                pose = data.qpos[qpos_idx].copy()
                vel = data.qvel[qvel_idx].copy()
                state = np.concatenate([pose, vel], axis=0).astype(np.float32)

                action = policy(torch.from_numpy(state).unsqueeze(0)).squeeze(0).cpu().numpy()

                data.ctrl[ctrl_idx] = action
                mujoco.mj_step(model, data)

                # Render at the target FPS (in simulated time).
                if data.time >= next_frame_t:
                    mujoco.mjv_updateScene(
                        model, data, opt, pert, cam,
                        mujoco.mjtCatBit.mjCAT_ALL.value,
                        scene
                    )
                    mujoco.mjr_render(viewport, scene, con)
                    mujoco.mjr_readPixels(rgb, None, viewport, con)

                    # Flip vertically (OpenGL origin is bottom-left).
                    img = np.flipud(rgb).copy()
                    frames.append(img)

                    if show_first_frame and len(frames) == 1:
                        display(Image.fromarray(img))

                    next_frame_t += frame_interval

                # Goal check uses the pose after the step.
                if in_goal(data.qpos[qpos_idx].copy()):
                    print(f"Reached goal at sim time {data.time:.2f}s (step {step}).")
                    reached = True
                    break
    finally:
        if con is not None:
            con.free()
        gl_context.free()

    if not reached:
        print(f"Did not reach goal in {max_steps} steps (sim time {data.time:.2f}s).")

    # Nothing to show when no frame was rendered (e.g. max_steps == 0).
    if frames:
        media.show_video(frames, fps=fps)
    return frames

In [11]:
# Shared rollout settings for the runs below.
REPO_DIR = "./point-robot-imitation-learning"
xml_path = f"{REPO_DIR}/scenes/point_robot_nav.xml"
policy_path = f"{REPO_DIR}/policy.pth"

max_steps = 1500          # cap on simulation steps per rollout
fps = 30                  # video frame rate
W, H = 320, 240           # frame width, height in pixels
show_first_frame = False  # True -> also display the first frame as an image

In [12]:
# Baseline: no start override, so the rollout begins from the model's initial state.
# To experiment, set `randomize_start=True` (sampled with `seed`) or pass an
# explicit `start_pose=np.array([x, y])`.
frames = run_policy_colab_mjr(
    xml_path, policy_path,
    randomize_start=False, start_pose=None, seed=123,
    max_steps=max_steps, fps=fps, W=W, H=H,
    show_first_frame=show_first_frame,
)

Did not reach goal in 1500 steps (sim time 15.00s).


0
This browser does not support the video tag.


In [13]:
# Explicit start at (x, y) = (-0.1, -0.2), velocities zeroed.
frames = run_policy_colab_mjr(
    xml_path, policy_path,
    randomize_start=False, start_pose=np.array([-0.1, -0.2]), seed=123,
    max_steps=max_steps, fps=fps, W=W, H=H,
    show_first_frame=show_first_frame,
)

Reached goal at sim time 8.84s (step 883).


0
This browser does not support the video tag.


In [14]:
# Explicit start at (x, y) = (0.4, 0.2), velocities zeroed.
frames = run_policy_colab_mjr(
    xml_path, policy_path,
    randomize_start=False, start_pose=np.array([0.4, 0.2]), seed=123,
    max_steps=max_steps, fps=fps, W=W, H=H,
    show_first_frame=show_first_frame,
)

Reached goal at sim time 9.21s (step 920).


0
This browser does not support the video tag.
