In [1]:
import jax

# Report which backend platform JAX selected, using the public API
# rather than the internal jax.lib.xla_bridge module.
print(jax.default_backend())

gpu


In [2]:
import numpy as np
import os
import time

import meshcat
import meshcat.geometry as g
import meshcat.transformations as tf
from meshcat.animation import Animation

import matplotlib.pyplot as plt
import h5py

In [3]:
# Create a new meshcat visualizer; the hopper geometry and animation are
# attached to this object by visualize_hopper.
vis = meshcat.Visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7004/static/


In [4]:
# Open the reference trajectory file read-only through h5py.
# NOTE(review): the handle is never closed in the visible cells.
f = h5py.File("test_traj.jld2", "r")

In [5]:
# Build a list of numpy arrays, one per entry of the "q" dataset, taking the
# first field of each entry (presumably one configuration per time step --
# confirm against the file layout).
q = [np.array(list(qi[0])) for qi in f["q"]]

In [6]:
def visualize_hopper(vis, q, dt):
    """Animate a hopper trajectory in a meshcat visualizer.

    The body and the foot are drawn as spheres and the leg as a chain of
    small spheres evenly spaced between them. The planar coordinates are
    placed in the x/z plane (y = 0) and everything is raised by the foot
    radius.

    Args:
        vis: meshcat visualizer that receives the geometry and animation.
        q: sequence of configurations; for each entry, elements 0-1 are the
            body position and elements 2-3 are the foot position.
        dt: time step between consecutive configurations; the animation
            frame rate is ceil(1 / dt).

    Raises:
        ValueError: if dt is not strictly positive.
    """
    if dt <= 0:
        raise ValueError("dt must be positive, got {}".format(dt))

    # sphere radii for body, foot and leg segments
    r_body = 0.2
    r_foot = 0.1
    r_leg = 0.5 * r_foot
    # number of spheres used to draw the leg
    n_leg = 100

    # pass a plain int frame rate, not a numpy float
    fps = int(np.ceil(1 / dt))
    anim = Animation(default_framerate=fps)

    # create body
    vis["body"].set_object(g.Sphere(r_body),
                           g.MeshLambertMaterial(color=0x57dd73, reflectivity=0.8))

    # create foot
    vis["foot"].set_object(g.Sphere(r_foot),
                           g.MeshLambertMaterial(color=0x9d37e6, reflectivity=0.8))

    # create leg
    leg_names = ["leg{}".format(i) for i in range(n_leg)]
    for name in leg_names:
        vis[name].set_object(g.Sphere(r_leg),
                             g.MeshPhongMaterial(color=0x3f2a32, reflectivity=0.8))

    # constant vertical offset applied to every object (hoisted out of the loop)
    z_shift = np.array([0.0, 0.0, r_foot])

    for t, q_t in enumerate(q):
        q_t = np.asarray(q_t, dtype=float)
        p_body = np.array([q_t[0], 0.0, q_t[1]])
        p_foot = np.array([q_t[2], 0.0, q_t[3]])
        # Interpolate directly between body and foot. This matches the
        # original normalise-then-scale computation but stays finite when
        # body and foot coincide (no division by a zero leg length).
        p_leg = np.linspace(p_body, p_foot, n_leg)
        with anim.at_frame(vis, t) as frame:
            frame["body"].set_transform(tf.translation_matrix(p_body + z_shift))
            frame["foot"].set_transform(tf.translation_matrix(p_foot + z_shift))
            for name, p in zip(leg_names, p_leg):
                frame[name].set_transform(tf.translation_matrix(p + z_shift))

    vis.set_animation(anim)

In [7]:
# Animate the trajectory loaded from file, using a time step of 0.1.
visualize_hopper(vis, q, 0.1)

## JAX dynamics

In [8]:
import jax.numpy as jnp
import jax

In [50]:
# Draw the test state and control from a seeded generator so that re-running
# the notebook reproduces the same values.
_state_rng = np.random.default_rng(0)
x = _state_rng.random(8)
# x[3] is the foot height tested by guard_function; a positive value selects
# the flight branch.
x[3] = 1
u = _state_rng.random(2)

In [51]:
# Convert the sampled state and control to JAX arrays.
x, u = jnp.array(x), jnp.array(u)

In [52]:
GRAVITY = 9.8  # gravitational acceleration, m/s^2
M1 = 1.0  # mass applied to the body velocity components
M2 = 0.5  # mass applied to the foot velocity components

# Diagonal mass matrix ordered like the velocity vector v = x[4:8]
# (body x, body y, foot x, foot y), together with its inverse.
_MASSES = (M1, M1, M2, M2)
M = jnp.diag(jnp.array(_MASSES))
M_inv = jnp.diag(jnp.array([1 / m for m in _MASSES]))


def _leg_direction(x):
    """Return the unit vector pointing from the foot (x[2:4]) to the body (x[0:2]).

    Undefined (division by zero) when body and foot coincide.
    """
    leg = x[0:2] - x[2:4]
    return leg / jnp.linalg.norm(leg)


def flight_dynamics(x, u):
    """Continuous-time dynamics while the foot is in the air.

    Args:
        x: state [body position (2), foot position (2), velocities (4)].
        u: length-2 control. The first component acts along the leg and the
            second perpendicular to it, with equal and opposite effect on
            body and foot.

    Returns:
        The state derivative [v, v_dot] as a length-8 array.
    """
    v = x[4:8]
    d = _leg_direction(x)

    B = jnp.array([[d[0], d[1]],
                   [d[1], -d[0]],
                   [-d[0], -d[1]],
                   [-d[1], d[0]]])
    # Gravity pulls BOTH masses downward; the foot term was +GRAVITY before,
    # which made the foot accelerate upward during flight.
    gravity = jnp.array([0, -GRAVITY, 0, -GRAVITY])
    v_dot = gravity + jnp.dot(M_inv, jnp.dot(B, u))
    return jnp.concatenate([v, v_dot])


def stance_dynamics(x, u):
    """Continuous-time dynamics while the foot is on the ground.

    The foot rows of the input matrix and the foot gravity term are zero, so
    the control and gravity only accelerate the body.

    Args:
        x: state [body position (2), foot position (2), velocities (4)].
        u: length-2 control (along the leg, perpendicular to the leg).

    Returns:
        The state derivative [v, v_dot] as a length-8 array.
    """
    v = x[4:8]
    d = _leg_direction(x)

    B = jnp.array([[d[0], d[1]],
                   [d[1], -d[0]],
                   [0, 0],
                   [0, 0]])
    gravity = jnp.array([0, -GRAVITY, 0, 0])
    v_dot = gravity + jnp.dot(M_inv, jnp.dot(B, u))
    return jnp.concatenate([v, v_dot])

In [40]:
def rk4(dynamics, x, u, h):
    """Advance x by one classical Runge-Kutta (RK4) step of size h.

    The control u is held constant over the step (zero-order hold).

    Args:
        dynamics: callable (x, u) -> x_dot.
        x: current state.
        u: control applied during the whole step.
        h: step size.

    Returns:
        The state after the step.
    """
    half_step = 0.5 * h
    k1 = dynamics(x, u)
    k2 = dynamics(x + half_step * k1, u)
    k3 = dynamics(x + half_step * k2, u)
    k4 = dynamics(x + h * k3, u)
    weighted_slope = k1 + 2 * k2 + 2 * k3 + k4
    return x + (h / 6.0) * weighted_slope

In [41]:
def flight_dynamics_rk4(x, u, h):
    """Take one RK4 step of size h through flight_dynamics."""
    return rk4(dynamics=flight_dynamics, x=x, u=u, h=h)

def stance_dynamics_rk4(x, u, h):
    """Take one RK4 step of size h through stance_dynamics."""
    return rk4(dynamics=stance_dynamics, x=x, u=u, h=h)

def jump_map(x):
    """Return the state after impact: entries 0-5 are kept, entries 6-7 become 0.

    Zeroing the last two entries (the foot velocity) models an inelastic
    collision of the foot.
    """
    state = jnp.asarray(x)
    stopped_foot = jnp.zeros(2, dtype=state.dtype)
    return jnp.concatenate([state[0:6], stopped_foot])

In [14]:
def guard_function(x):
    """Classify the state by the foot height x[3].

    Returns:
        2 if the foot is below the ground (collision), 1 if it is above the
        ground (flight), and 0 otherwise (stance).
    """
    foot_height = x[3]
    airborne_or_stance = jnp.where(foot_height > 0, 1, 0)
    return jnp.where(foot_height < 0, 2, airborne_or_stance)

In [15]:
def collision_function(x, u, h):
    """Resolve a foot impact and then advance one stance step.

    The foot height x[3] is reset to 0, jump_map is applied to the result,
    and that state is integrated with stance_dynamics_rk4.
    """
    state = jnp.asarray(x)
    on_ground = jnp.concatenate(
        [state[0:3], jnp.zeros(1, dtype=state.dtype), state[4:]]
    )
    return stance_dynamics_rk4(jump_map(on_ground), u, h)

In [45]:
def hopper_dynamics(x, u, h):
    """Advance the hybrid hopper by one step of size h.

    guard_function selects the mode from the current state, and
    jax.lax.switch dispatches to the matching discrete-time update.
    """
    # Position in this tuple must match guard_function's codes:
    # 0 -> stance, 1 -> flight, 2 -> collision.
    branches = (stance_dynamics_rk4, flight_dynamics_rk4, collision_function)
    mode = guard_function(x)
    return jax.lax.switch(mode, branches, x, u, h)

In [17]:
def hopper_cost(x, u):
    """Stage cost for the hopper.

    Sums the squared horizontal positions of body (x[0]) and foot (x[2]),
    plus 0.2 times the squared norm of the velocities x[4:8] and 0.2 times
    the squared norm of the control u.
    """
    velocities = x[4:8]
    position_cost = x[0] ** 2 + x[2] ** 2
    velocity_cost = 0.2 * jnp.dot(velocities, velocities)
    control_cost = 0.2 * jnp.dot(u, u)
    return position_cost + velocity_cost + control_cost

In [18]:
from utils.env_utils import step

In [19]:
def step_wrapper_hopper(carry, action):
    """Single rollout step in (carry, input) -> (carry, output) form.

    Args:
        carry: pair (obs, h) of the current observation and the step size.
        action: control applied at this step.

    Returns:
        ((next_obs, h), (next_obs, cost)) where cost is hopper_cost evaluated
        at the current observation and action.
    """
    obs, h = carry[0], carry[1]
    next_obs = step(obs, action, hopper_dynamics, h)
    stage_cost = hopper_cost(obs, action)
    return (next_obs, h), (next_obs, stage_cost)

In [20]:
from jax_wrappers.rollout_functions import load_rollout_jax

In [21]:
# Dimensions of the random action tensor used for the rollout below:
# number of action sequences, actions per sequence, and size of one action.
n_samples = 1000
horizon = 20
act_dim = 2

In [22]:
# Seeded generator so the sampled actions are reproducible.
rng = np.random.default_rng(0)

In [23]:
# Standard-normal actions with shape (n_samples, horizon, act_dim).
size = (n_samples, horizon, act_dim)
acts = rng.normal(size=size)

In [24]:
# Build the rollout function from the per-step wrapper.
# NOTE(review): load_rollout_jax is a project helper not visible here;
# presumably it scans the step over each action sequence -- confirm.
rollout_jax = load_rollout_jax(step_wrapper_hopper)

In [25]:
%timeit rollout_jax(x, acts, 0.1)

1.23 ms ± 256 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Hopper MPPI

In [26]:
from controller.mppi import MPPI

In [32]:
# Settings handed to the MPPI controller; 'h' is also used below as the
# integration step of the simulation.
params = dict(
    seed=42,
    h=0.05,
    env_name='Hopper-meshcat',
    sample_type='normal',  # alternative kept from the original cell: 'cubic'
    n_knots=15,
    horizon=200,
    temperature=1.0,
    n_samples=250,
    noise_sigma=5.0,
)

In [33]:
# Instantiate the MPPI controller from the settings above.
# NOTE(review): the meaning of the first positional argument (passed as None)
# is not visible in this notebook -- confirm against controller.mppi.
controller_jax = MPPI(None, params)

In [34]:
# Reset the planner before the closed-loop run (presumably clears the
# controller's internal plan -- confirm against controller.mppi).
controller_jax.reset_planner()

In [35]:
# Closed-loop simulation: query the MPPI controller at every step and advance
# the hybrid hopper dynamics by h.

# Initial state: body at (0, 1), foot at the origin, all velocities zero.
x0 = np.array([0, 1.0, 0, 0, 0, 0, 0, 0])
tfinal = 10
h = params["h"]
# Derive the sample count from tfinal and h so the time grid spacing always
# equals the integration step (201 samples for h = 0.05, as before).
n_steps = int(round(tfinal / h))
tvec = np.linspace(0, tfinal, n_steps + 1)
x = x0.copy()
q_sim = []
for ti in tvec:
    u = controller_jax.get_action(x)
    x = hopper_dynamics(x, u, h)
    # keep only the configuration (body and foot positions) as a numpy array
    q_sim.append(np.asarray(x[0:4]))

In [36]:
# Animate the simulated trajectory at the simulation step size.
visualize_hopper(vis, q_sim, h)