## Batchable Inverse Kinematics with MJINX

Welcome to this tutorial, where we'll explore how to solve inverse kinematics (IK) problems efficiently in batch using `mjinx`.

We'll work with a 7-degree-of-freedom (7-DoF) robotic arm and demonstrate how to:
- Track a batch of target poses (positions and orientations) simultaneously
- Enforce joint limits to keep the robot's motion within safe bounds
- Implement additional safety constraints

Before moving forward, we need to set up the environment and the GPU. In Google Colab, make sure to select a GPU runtime in the runtime settings.

If you are running this on a local machine, you can skip the GPU setup.

In [10]:
# Set up GPU rendering.
import os
import subprocess

# Configure MuJoCo to use the EGL rendering backend (requires GPU)
print('Setting environment variable to use GPU rendering:')
%env MUJOCO_GL=egl

# Graphics and plotting: mediapy needs ffmpeg to encode videos.
print('Installing ffmpeg (used by mediapy):')
!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)

# XLA flags: disable JAX's default preallocation of most of the GPU memory.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

if subprocess.run('nvidia-smi').returncode:
  raise RuntimeError(
      'Cannot communicate with GPU. '
      'Make sure you are using a GPU Colab runtime. '
      'Go to the Runtime menu and select Choose runtime type.')

# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.
# This is usually installed as part of an Nvidia driver package, but the Colab
# kernel doesn't install its driver via APT, and as a result the ICD is missing.
# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)
NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'
if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):
  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:
    f.write("""{
    "file_format_version" : "1.0.0",
    "ICD" : {
        "library_path" : "libEGL_nvidia.so.0"
    }
}
""")

from IPython.display import clear_output
clear_output()

To check that the NVIDIA GPU is properly available, let us invoke `nvidia-smi`.

In [None]:
!nvidia-smi

Once the GPU is set up, let us install `mjinx` with the `examples` extra. More details are in the [installation guide](https://github.com/based-robotics/mjinx?tab=readme-ov-file#installation).


In [None]:
!pip install "mjinx[examples]"
from IPython.display import clear_output
clear_output()

import jax
import jax.numpy as jnp
import mujoco as mj
import mujoco.mjx as mjx
import numpy as np
import mediapy as media
from time import perf_counter

As the robot, we will use the KUKA IIWA 14 from the `robot_descriptions` package. Let us load its MJCF description and initialize the MuJoCo model and data, as well as the MJX model.

In [16]:
from robot_descriptions.iiwa14_mj_description import MJCF_PATH
mj_model = mj.MjModel.from_xml_path(MJCF_PATH)
mj_data = mj.MjData(mj_model)

mjx_model = mjx.put_model(mj_model)

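# Joint limits, one row of jnt_range per joint. Every joint of this arm is a hinge
# with a single qpos entry, so these bounds align with the configuration vector q.
q_min = mj_model.jnt_range[:, 0].copy()
q_max = mj_model.jnt_range[:, 1].copy()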

Creating a problem formulation, here with joint velocities bounded to [-5, 5], is as simple as:

In [None]:
from mjinx.problem import Problem
# Creating problem formulation
problem = Problem(mjx_model, v_min=-5, v_max=5)

The problem formulation in `mjinx` is modular: tasks, barriers, and other components are added to the problem one by one, and all of them are defined in the `mjinx.components` module.

To make the robot follow a pose trajectory, we will use the `FrameTask` component.

In [None]:
from mjinx.components.tasks import FrameTask
# Creating task component
frame_task = FrameTask("ee_task",
                       cost=1,  # weight of the task in the objective function
                       gain=20,  # task gain: dy = gain * (y_d - y)
                       obj_name="link7")  # name of the object to track (the last link of the arm)
# Adding task to the problem
problem.add_component(frame_task)
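The `gain` comment in the cell above describes a proportional law in task space. The NumPy sketch below shows what it implies, assuming the solver tracks the commanded task velocity exactly (it ignores barriers, velocity limits and competing tasks): with explicit Euler integration, the tracking error shrinks by a factor of `1 - gain * dt` per step. `task_velocity` and `simulate_error` are hypothetical helpers for illustration, not part of `mjinx`.

```python
import numpy as np


def task_velocity(y, y_d, gain):
    """Proportional task law from the FrameTask comment: dy = gain * (y_d - y)."""
    return gain * (y_d - y)


def simulate_error(y0, y_d, gain, dt, n_steps):
    """Euler rollout, assuming the commanded task velocity is tracked exactly."""
    y = np.asarray(y0, dtype=float)
    y_d = np.asarray(y_d, dtype=float)
    errors = [np.linalg.norm(y_d - y)]
    for _ in range(n_steps):
        y = y + dt * task_velocity(y, y_d, gain)
        errors.append(np.linalg.norm(y_d - y))
    return np.array(errors)


# Position-only example with this tutorial's gain and the main loop's dt.
errors = simulate_error(y0=[0.4, 0.2, 0.7], y_d=[0.6, 0.2, 0.4], gain=20.0, dt=2e-2, n_steps=5)
decay = errors[1:] / errors[:-1]  # each ratio equals 1 - gain * dt
print(decay)
```

In this discrete-time picture, a larger gain converges faster, but `gain * dt` above 1 overshoots the target and above 2 diverges.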

In `mjinx` we can also add different types of constraints (barriers) to the problem; they are defined in the `mjinx.components.barriers` module.

For instance, we may add a `JointBarrier` to keep the robot within its joint limits and a `PositionBarrier` to keep the end-effector (`link7`) on one side of a virtual wall: with `limit_type="max"`, `p_max=0.5` and `mask=[1, 0, 0]`, only its x-coordinate is bounded, by 0.5 m.

In [None]:
# Importing the barrier components
from mjinx.components.barriers import JointBarrier, PositionBarrier

joints_barrier = JointBarrier("jnt_range", gain=10)
problem.add_component(joints_barrier)

position_barrier = PositionBarrier(
    "ee_barrier",
    gain=100,
    obj_name="link7",
    limit_type="max",
    p_max=0.5,
    safe_displacement_gain=1e-2,
    mask=[1, 0, 0],
)
problem.add_component(position_barrier)
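Both barriers above expose a `gain`. The sketch below illustrates the generic control-barrier-function (CBF) idea behind such components for the `limit_type="max"`, `mask=[1, 0, 0]` case, assuming a barrier h = p_max - p_x and the linear condition dh/dt ≥ -gain · h, which bounds the x-velocity by gain · h. It is not the `mjinx` implementation: it ignores `safe_displacement_gain` and filters a single velocity command instead of solving a QP. The helpers `barrier_value` and `max_safe_velocity` are hypothetical. The sketch uses dt = 1e-3 because, in this explicit-Euler picture, the point approaches the wall monotonically only while gain · dt ≤ 1.

```python
import numpy as np


def barrier_value(p, p_max, mask):
    """h(p) = p_max - p on the masked axes; h >= 0 means the limit is respected."""
    mask = np.asarray(mask, dtype=bool)
    return (p_max - np.asarray(p, dtype=float))[mask]


def max_safe_velocity(p, p_max, mask, gain):
    """CBF condition dh/dt >= -gain * h with h = p_max - p, i.e. v <= gain * h."""
    return gain * barrier_value(p, p_max, mask)


# Push a point towards the wall x = p_max at 1 m/s and filter the command.
p_max, gain, dt = 0.5, 100.0, 1e-3
p = np.array([0.3, 0.2, 0.4])
xs = []
for _ in range(200):
    v_desired = 1.0
    v_x = min(v_desired, max_safe_velocity(p, p_max, [1, 0, 0], gain)[0])
    p = p + dt * np.array([v_x, 0.0, 0.0])
    xs.append(p[0])

print(max(xs))
```

The same construction with h = q - q_min and h = q_max - q gives a barrier for each joint limit.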

After adding or updating components, the problem has to be compiled into the data that the solver consumes:

In [None]:
# Compiling the problem upon any parameters update
problem_data = problem.compile()

Now let us initialize the solver. Here we use `LocalIKSolver`, a QP-based solver that optimizes the joint velocities locally, around the current configuration.


In [None]:
from mjinx.solvers import LocalIKSolver

# Initializing the solver (its state is initialized later, in batch)
solver = LocalIKSolver(mjx_model, maxiter=10)
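As a rough picture of what a QP-based local IK solver computes at each step, here is a generic velocity-level formulation with control-barrier constraints. This is a textbook-style sketch assuming linear barrier conditions; the exact objective and constraints used by `LocalIKSolver` (for example any regularization, or how `safe_displacement_gain` enters) may differ. $J_i$ is the Jacobian of task $i$, $e_i$ its error ($y_d - y$ for vector-valued tasks), and $h_j \ge 0$ the barrier functions; $w_i$ and $k_i$ play the roles of a task's `cost` and `gain`, $\alpha_j$ that of a barrier's `gain`, and $v_{\min}, v_{\max}$ those of the `Problem` arguments.

$$
\begin{aligned}
v^\star = \arg\min_{v}\quad & \sum_i w_i \,\bigl\lVert J_i(q)\,v - k_i\, e_i(q) \bigr\rVert^2 \\
\text{s.t.}\quad & \nabla h_j(q)^\top v \ge -\alpha_j\, h_j(q) \quad \text{for every barrier } j, \\
& v_{\min} \le v \le v_{\max}.
\end{aligned}
$$

The configuration is then advanced as $q_{k+1} = q_k \oplus v^\star \Delta t$, where $\oplus$ reduces to ordinary addition for hinge joints.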

The velocities produced by the local IK solver have to be integrated to obtain the next configuration. To do so, we will use the `integrate` function from the `mjinx.configuration` module.

In [None]:
from mjinx.configuration import integrate
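For the hinge joints of the IIWA 14, integrating a velocity amounts to an explicit Euler step on the joint angles. The NumPy sketch below covers only that case; joints whose position is a quaternion (free and ball joints) need a proper quaternion update, which MuJoCo provides through `mj_integratePos`. `euler_integrate_hinges` is a hypothetical helper, not the implementation of `mjinx.configuration.integrate`.

```python
import numpy as np


def euler_integrate_hinges(q, v, dt):
    """q_next = q + v * dt, valid when every joint is a hinge or slide joint (nq == nv)."""
    q = np.asarray(q, dtype=float)
    v = np.asarray(v, dtype=float)
    assert q.shape == v.shape, "hinge/slide-only models have nq == nv"
    return q + v * dt


# Works unchanged on a single configuration (7,) or on a batch (N, 7).
q_batch = np.zeros((4, 7))
v_batch = np.ones((4, 7))
q_next = euler_integrate_hinges(q_batch, v_batch, dt=2e-2)
print(q_next.shape)
```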

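Now let us create a batch of `N_batch` initial configurations, obtained by randomly perturbing a nominal configuration `q0`, together with the batched solver state and the batching axes for the problem data. The target poses are set later.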

In [None]:
N_batch = 2000
np.random.seed(42)
q0 = jnp.array(
    [
        -1.4238753,
        -1.7268502,
        -0.84355015,
        2.0962472,
        2.1339328,
        2.0837479,
        -2.5521986,
    ]
)
q = jnp.array(
    [
        np.clip(
            q0
            + np.random.uniform(
                -0.1,
                0.1,
                size=(mj_model.nq),
            ),
            q_min + 1e-1,
            q_max - 1e-1,
        )
        for _ in range(N_batch)
    ]
)

# First of all, data should be created via vmapped init function
solver_data = jax.vmap(solver.init, in_axes=0)(v_init=jnp.zeros((N_batch, mjx_model.nv)))

# To create a batch w.r.t. the desired components' attributes, mjinx defines a convenient wrapper
# that sets all elements to None and allows the user to mutate the dataclasses of interest.
# After exiting the context manager, you get an immutable JAX dataclass object.
# Here, only the FrameTask target is batched (mapped along axis 0).
with problem.set_vmap_dimension() as empty_problem_data:
    empty_problem_data.components["ee_task"].target_frame = 0
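The list comprehension above draws one perturbation per robot and clips it. An equivalent batch (same distribution, different samples) can be drawn in a single vectorized call, because NumPy broadcasts the `(nq,)` arrays `q0`, `q_min` and `q_max` against an `(N_batch, nq)` sample. In this sketch, the nominal configuration and limits are placeholder values, not the IIWA 14 joint ranges, and a `np.random.default_rng` generator replaces the global seed.

```python
import numpy as np

N_batch, nq = 2000, 7
rng = np.random.default_rng(42)

# Placeholder nominal configuration and joint limits (NOT the real IIWA 14 ranges).
q0 = np.zeros(nq)
q_min = np.full(nq, -2.9)
q_max = np.full(nq, 2.9)

# One (N_batch, nq) draw; broadcasting applies the (nq,) offsets and limits to every row.
q = np.clip(
    q0 + rng.uniform(-0.1, 0.1, size=(N_batch, nq)),
    q_min + 1e-1,
    q_max - 1e-1,
)
print(q.shape)
```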

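To accelerate the computations, we vectorize the solve and integrate functions over the batch with `jax.vmap` and JIT-compile them with `jax.jit`. A warmup call then triggers the compilation before the main loop.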

In [None]:
# Vmapping solve and integrate functions.
solve_jit = jax.jit(jax.vmap(solver.solve, in_axes=(0, 0, empty_problem_data)))
integrate_jit = jax.jit(jax.vmap(integrate, in_axes=(None, 0, 0, None)), static_argnames=["dt"])

t_warmup = perf_counter()
print("Performing warmup calls...")
# Warmup iterations for JIT compilation
frame_task.target_frame = np.array([[0.4, 0.2, 0.7, 1, 0, 0, 0] for _ in range(N_batch)])
problem_data = problem.compile()
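opt_solution, _ = solve_jit(q, solver_data, problem_data)
# dt is a static argument, so warm up with the same value as the main loop (2e-2):
# a different value would trigger a recompilation at the first loop step.
q_warmup = jax.block_until_ready(integrate_jit(mjx_model, q, opt_solution.v_opt, 2e-2))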

t_warmup_duration = perf_counter() - t_warmup
print(f"Warmup completed in {t_warmup_duration:.3f} seconds")
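`jax.vmap`'s `in_axes` decides, per argument, whether it is sliced along axis 0 (`0`) or shared by every batch element (`None`). The loop-based reference below spells out those semantics for the `integrate_jit` call, `in_axes=(None, 0, 0, None)`: the model and `dt` are shared, while `q` and the velocities are sliced. `reference_vmap` and `euler_step` are hypothetical stand-ins that only show what the vectorized call computes; `vmap` produces the same result without a Python loop, and `jit` then compiles it. For `solve_jit`, the third entry of `in_axes` is the `empty_problem_data` pytree itself, which maps only `target_frame` along axis 0 and shares every other field.

```python
import numpy as np


def reference_vmap(fn, in_axes):
    """Loop-based reference for jax.vmap with in_axes entries of 0 or None."""
    def batched(*args):
        sizes = {len(a) for a, ax in zip(args, in_axes) if ax == 0}
        assert len(sizes) == 1, "all mapped arguments need the same batch size"
        n = sizes.pop()
        # Slice mapped arguments at index i, pass shared ones through unchanged.
        outs = [
            fn(*(a[i] if ax == 0 else a for a, ax in zip(args, in_axes)))
            for i in range(n)
        ]
        return np.stack(outs)
    return batched


def euler_step(model, q, v, dt):
    """Stand-in for a single-instance integrate(model, q, v, dt); model is unused here."""
    return q + v * dt


batched_step = reference_vmap(euler_step, in_axes=(None, 0, 0, None))
q = np.zeros((3, 7))
v = np.arange(21, dtype=float).reshape(3, 7)
q_next = batched_step("shared-model", q, v, 0.5)
print(q_next.shape)
```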

To visualize the robot's motion, we use the `BatchVisualizer` tool, which draws several semi-transparent robots from the batch and, with `record=True`, stores the rendered frames for playback.

In [None]:
from mjinx.visualize import BatchVisualizer

vis = BatchVisualizer(MJCF_PATH, n_models=5, alpha=0.5, record=True, passive_viewer=False)
vis.camera.distance = 2
vis.camera.azimuth = 100
vis.camera.elevation = -25
vis.camera.lookat = np.array([0, 0, 0.2])

# Initialize one sphere marker per visualized end-effector target
vis.add_markers(
    name=[f"ee_marker_{i}" for i in range(vis.n_models)],
    size=0.05,
    marker_alpha=0.4,
    color_begin=np.array([0, 1.0, 0.53]),
    color_end=np.array([0.38, 0.94, 1.0]),
    n_markers=vis.n_models,
)
vis.add_markers(
    name="blocking_plane",
    marker_type=mj.mjtGeom.mjGEOM_PLANE,
    size=np.array([0.5, 0.5, 0.02]),
    marker_alpha=0.3,
    color_begin=np.array([1, 0, 0]),
)

vis.marker_data["blocking_plane"].pos = np.array([position_barrier.p_max[0], 0, 0.3])
vis.marker_data["blocking_plane"].rot = np.array(
    [
        [0, 0, -1],
        [0, 1, 0],
        [1, 0, 0],
    ]
)

vis.update(q[:: N_batch // vis.n_models])
media.show_image(vis.frames[-1])
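Two details of the cell above may be worth unpacking. `q[:: N_batch // vis.n_models]` draws only every 400th robot of the batch, and the `blocking_plane` rotation turns the marker upright. MuJoCo plane geoms lie in their local xy-plane with the normal along local z; assuming `marker_data[...].rot` is read as a local-to-world rotation matrix (an assumption, not checked against the `BatchVisualizer` source), the matrix maps that normal onto the world x-axis, so the plane stands vertically at x = `p_max`, matching the barrier's x-limit.

```python
import numpy as np

N_batch, n_models = 2000, 5

# Every (N_batch // n_models)-th robot of the batch is drawn.
shown = np.arange(N_batch)[:: N_batch // n_models]

# Orientation assigned to the "blocking_plane" marker in the cell above.
R = np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]], dtype=float)

# A proper rotation: orthonormal with determinant +1.
assert np.allclose(R @ R.T, np.eye(3)) and np.isclose(np.linalg.det(R), 1.0)

# As a local-to-world matrix, R maps the plane's local z-axis (its normal) to world axes.
normal_world = R @ np.array([0.0, 0.0, 1.0])
print(shown, normal_world)
```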

Now we are ready to solve the IK problem. Over 10 s of simulated time with a step of `dt = 0.02` s, each of the `N_batch` end-effector targets moves along a circle, and we log how long each batched solve and integration takes.

In [20]:
dt = 2e-2
ts = np.arange(0, 10, dt)

# Performance tracking
solve_times = []
integrate_times = []
n_steps = 0

for t in ts:
    # Changing desired values
    frame_task.target_frame = np.array(
        [
            [
                0.4 + 0.3 * np.sin(t + 2 * np.pi * i / N_batch),
                0.2,
                0.4 + 0.3 * np.cos(t + 2 * np.pi * i / N_batch),
                1,
                0,
                0,
                0,
            ]
            for i in range(N_batch)
        ]
    )
    problem_data = problem.compile()

    # Solving the instance of the problem
    # (JAX dispatches asynchronously, so block on the result before stopping the clock)
    t1 = perf_counter()
    opt_solution, solver_data = jax.block_until_ready(solve_jit(q, solver_data, problem_data))
    t2 = perf_counter()
    solve_times.append(t2 - t1)

    # Integrating
    t1 = perf_counter()
    q = jax.block_until_ready(
        integrate_jit(
            mjx_model,
            q,
            opt_solution.v_opt,
            dt,
        )
    )
    t2 = perf_counter()
    integrate_times.append(t2 - t1)

    # --- MuJoCo visualization ---
    for i, q_i in enumerate(frame_task.target_frame.wxyz_xyz[:: N_batch // vis.n_models, -3:]):
        vis.marker_data[f"ee_marker_{i}"].pos = q_i
    vis.update(q[:: N_batch // vis.n_models])
    n_steps += 1
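The targets in the loop trace circles of radius 0.3 m in the x-z plane, with robot `i` phase-shifted by `2π i / N_batch`. Each row holds the position `[x, y, z]` followed by the quaternion `[1, 0, 0, 0]` (the identity rotation if stored as `w, x, y, z`, as the `wxyz_xyz` attribute used for the markers suggests). The hypothetical helper `circle_targets` builds the same array with broadcasting instead of a Python list comprehension, which may matter when 2000 targets are rebuilt every step. It also makes it easy to check that x ranges over [0.1, 0.7] m, so part of every circle lies beyond the 0.5 m wall of `position_barrier`.

```python
import numpy as np


def circle_targets(t, n_batch):
    """Rows [x, y, z, qw, qx, qy, qz] matching the list comprehension in the loop."""
    phase = t + 2 * np.pi * np.arange(n_batch) / n_batch
    targets = np.zeros((n_batch, 7))
    targets[:, 0] = 0.4 + 0.3 * np.sin(phase)
    targets[:, 1] = 0.2
    targets[:, 2] = 0.4 + 0.3 * np.cos(phase)
    targets[:, 3] = 1.0  # quaternion [1, 0, 0, 0]
    return targets


N_batch = 2000
fast = circle_targets(1.3, N_batch)
# Reference: the list comprehension used in the loop, evaluated at t = 1.3.
slow = np.array([
    [0.4 + 0.3 * np.sin(1.3 + 2 * np.pi * i / N_batch), 0.2,
     0.4 + 0.3 * np.cos(1.3 + 2 * np.pi * i / N_batch), 1, 0, 0, 0]
    for i in range(N_batch)
])
print(fast.shape, fast[:, 0].min(), fast[:, 0].max())
```

In the notebook, `frame_task.target_frame = circle_targets(t, N_batch)` could then replace the comprehension.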

Let us play back the recorded frames. Note that only a tiny fraction of the solutions is shown: `vis.n_models = 5` robots out of `N_batch = 2000`.

In [None]:
media.show_video(vis.frames, fps=round(1 / dt))

Here is a short performance report. Feel free to try it on your own machine or on different Colab GPUs!

In [None]:
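# Performance report
solve_times = np.array(solve_times)
integrate_times = np.array(integrate_times)
print(f"\n=== Performance Report for {N_batch} targets ===\n"
      f"Steps: {n_steps}\n"
      f"Solve time: {np.mean(solve_times)*1000:.1f} ± {np.std(solve_times)*1000:.1f} ms\n"
      f"Integrate time: {np.mean(integrate_times)*1000:.1f} ± {np.std(integrate_times)*1000:.1f} ms\n"
      f"Rate: {1/np.mean(solve_times):.1f} Hz")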
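The reported rate counts batched solves, each of which handles all `N_batch` IK problems at once, so the number of individual IK solutions per second is `N_batch` times larger. The hypothetical helper below makes that arithmetic explicit; the timings passed in are made-up numbers for illustration, not measurements. Because JAX dispatches work asynchronously, such per-step times are only meaningful if each timed call is followed by `jax.block_until_ready` before the clock is read.

```python
import numpy as np


def throughput(step_times_s, n_batch):
    """Convert per-step wall-clock times (seconds) into batch rate and per-problem throughput."""
    mean_s = float(np.mean(step_times_s))
    batch_rate_hz = 1.0 / mean_s
    return {
        "mean_ms": 1e3 * mean_s,
        "batch_rate_hz": batch_rate_hz,
        "ik_solves_per_s": n_batch * batch_rate_hz,
    }


# Illustrative input only: two steps of 10 ms each for a batch of 2000 targets.
report = throughput([0.010, 0.010], n_batch=2000)
print(report)
```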

This is just a glimpse of what you can do with `mjinx`. For more examples, including global/local IK on different robots and more advanced components, please refer to the [examples](https://github.com/based-robotics/mjinx/tree/main/examples). Stay tuned!