In [5]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

gpu


In [6]:
import numpy as np
import os
import time

import meshcat
import meshcat.geometry as g
import meshcat.transformations as tf
from meshcat.animation import Animation

import matplotlib.pyplot as plt
import h5py

In [7]:
# Create a new visualizer
vis = meshcat.Visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7005/static/


In [255]:
vis.delete()

In [256]:
hinge = g.Cylinder(thick/10, thick/10)

In [257]:
m1,m2 = [1, 1]
l1,l2 = [1, 1]
J1,J2 = [1, 1]
thick = 0.5

In [271]:
model = {"m1":1,
         "m2":1,
         "l1":1,
         "l2":1,
         "J1":1,
         "J2":1}

In [305]:
def visualize_acrobat(vis, model, q, dt):
    vis.delete()
    l1 = model["l1"]
    l2 = model["l2"]
    m1 = model["m1"]
    m2 = model["m2"]

    joint1 = vis["joint1"]
    link1 = joint1["link1"]
    joint2 = link1["joint2"]
    link2 = joint2["link2"]

    link1_box = g.Box([thick/8, thick/12, l1])
    link2_box = g.Box([thick/8, thick/12, l2])

    mat1 = g.MeshPhongMaterial(color=0x3f2a32, reflectivity=0.8)
    mat2 = g.MeshPhongMaterial(color=0x9d37e6, reflectivity=0.8)
    mat3 = g.MeshPhongMaterial(color=0x3d34e5, reflectivity=0.8)

    joint1.set_object(hinge, mat1)
    joint1.set_transform(tf.compose_matrix(angles=[0,np.pi,0]))
    link1.set_object(link1_box, mat2)
    link1.set_transform(tf.translation_matrix([0,0,l1/2]))
    joint2.set_object(hinge, mat1) 
    joint2.set_transform(tf.compose_matrix(translate=[0,0,l1/2]))
    trans_link2 = tf.concatenate_matrices(*[tf.rotation_matrix(0, [0, 1, 0]), tf.translation_matrix([0,0,l1/2])])
    link2.set_object(link2_box, mat3)
    link2.set_transform(trans_link2)

    fps = np.ceil(1/dt)
    anim = Animation(default_framerate=fps)
    
    for t in range(len(q)):
        with anim.at_frame(vis, t) as frame:
            frame["joint1"].set_transform(tf.rotation_matrix(q[t][0], [0, 1, 0]))
            trans_link2 = tf.concatenate_matrices(*[tf.translation_matrix([0,0,l1/2]),
                                                    tf.rotation_matrix(q[t][1], [0, 1, 0])])
            frame["joint1/link1/joint2"].set_transform(trans_link2)
    vis.set_animation(anim)

In [306]:
q = [[np.pi*np.sin(2*np.pi*t), 0] for t in np.linspace(0,2, 41)]

In [307]:
visualize_acrobat(vis, model, q, 0.05)

In [4]:
def dynamics(model::Acrobot, x, u)
    g = 9.81
    l1 = model["l1"]
    l2 = model["l2"]
    m1 = model["m1"]
    m2 = model["m2"]
    J1 = model["J1"]
    J2 = model["J2"]
    
    theta1, theta2 = x[0], x[1]
    d_theta1, d_thet2 = x[2], x[3]
    
    
    s1,c1 = sincos(θ1)
    s2,c2 = sincos(θ2)
    c12 = cos(θ1 + θ2)

    # mass matrix
    m11 = m1*l1^2 + J1 + m2*(l1^2 + l2^2 + 2*l1*l2*c2) + J2
    m12 = m2*(l2^2 + l1*l2*c2 + J2)
    m22 = l2^2*m2 + J2
    M = @SMatrix [m11 m12; m12 m22]

    # bias term
    tmp = l1*l2*m2*s2
    b1 = -(2 * θ1dot * θ2dot + θ2dot^2)*tmp
    b2 = tmp * θ1dot^2
    B = @SVector [b1, b2]

    # friction
    c = 1.0
    C = @SVector [c*θ1dot, c*θ2dot]

    # gravity term
    g1 = ((m1 + m2)*l2*c1 + m2*l2*c12) * g
    g2 = m2*l2*c12*g
    G = @SVector [g1, g2]

    # equations of motion
    τ = @SVector [0, u[1]]
    θddot = M\(τ - B - G - C)
    return @SVector [θ1dot, θ2dot, θddot[1], θddot[2]]

SyntaxError: invalid syntax (1950176406.py, line 1)

In [6]:
def visualize_acrobot(vis, q, dt):
    
    
    for t in range(len(q)):
        p_body = np.array([q[t][0], 0.0, q[t][1]])
        p_foot = np.array([kinematics(q[t])[0], 0.0, kinematics(q[t])[1]])
        div = np.array([q[t][2] - q[t][0], q[t][3] - q[t][1]])
        div = div / np.linalg.norm(div)
        r_range = np.linspace(0, np.sqrt((q[t][0] - q[t][2])**2 + (q[t][1] - q[t][3])**2), n_leg)
        for i in range(n_leg):
            p_leg[i] = [q[t][0] + r_range[i] * div[0], 0.0, q[t][1] + r_range[i] * div[1]]
        z_shift = np.array([0.0, 0.0, r_foot])
        with anim.at_frame(vis, t) as frame:
            frame["body"].set_transform(tf.translation_matrix(p_body + z_shift))
            frame["foot"].set_transform(tf.translation_matrix(p_foot + z_shift))
            for i in range(n_leg):
                frame["leg{}".format(i)].set_transform(tf.translation_matrix(p_leg[i] + z_shift))
                
    vis.set_animation(anim)

In [7]:
visualize_hopper(vis, q, 0.1)

## Jax dynamics

In [8]:
import jax.numpy as jnp
import jax

In [50]:
x = np.random.random(8)
x[3] = 1
u = np.random.random(2)

In [51]:
x = jnp.array(x)
u = jnp.array(u)

In [52]:
GRAVITY = 9.8 # m/s
M1 = 1.0
M2 = 0.5

M = jnp.array([[M1, 0, 0, 0],
               [0, M1, 0, 0], 
               [0, 0, M2, 0], 
               [0, 0, 0, M2]])

M_inv = jnp.array([[1/M1, 0, 0, 0],
                   [0, 1/M1, 0, 0], 
                   [0, 0, 1/M2, 0], 
                   [0, 0, 0, 1/M2]])

def flight_dynamics(x,u):
    rb = x[0:2]
    rf = x[2:4]
    v = x[4:8]
    
    l1 = (rb[0]-rf[0])/jnp.linalg.norm(rb-rf)
    l2 = (rb[1]-rf[1])/jnp.linalg.norm(rb-rf)
      
    B = jnp.array([[l1, l2],
                   [l2, -l1],
                   [-l1, -l2],
                   [-l2, l1]])
    v_dot = jnp.array([0, -GRAVITY, 0, GRAVITY]) + jnp.dot(jnp.dot(M_inv,B), u)
    x_dot = jnp.concatenate([v, v_dot])
    return x_dot

def stance_dynamics(x,u):
    rb = x[0:2]
    rf = x[2:4]
    v = x[4:8]
    
    l1 = (rb[0]-rf[0])/jnp.linalg.norm(rb-rf)
    l2 = (rb[1]-rf[1])/jnp.linalg.norm(rb-rf)
    
    B = jnp.array([[l1, l2],
                   [l2, -l1],
                   [0, 0],
                   [0, 0]])
    v_dot = jnp.array([0, -GRAVITY, 0, 0]) + jnp.dot(jnp.dot(M_inv,B), u)
    x_dot = jnp.concatenate([v, v_dot])
    return x_dot

In [40]:
def rk4(dynamics, x, u, h):
    # RK4 integration with zero-order hold on u
    f1 = dynamics(x, u)
    f2 = dynamics(x + 0.5 * h * f1, u)
    f3 = dynamics(x + 0.5 * h * f2, u)
    f4 = dynamics(x + h * f3, u)
    return x + (h / 6.0) * (f1 + 2 * f2 + 2 * f3 + f4)

In [41]:
def flight_dynamics_rk4(x, u, h):
    return rk4(flight_dynamics, x, u, h)

def stance_dynamics_rk4(x, u, h):
    return rk4(stance_dynamics, x, u, h)  

def jump_map(x):
    # Assume the foot experiences inelastic collisions
    return jnp.array([*x[0:6], 0, 0])

In [14]:
def guard_function(x):
    rf_y = x[3]
    index_1 = jnp.where(rf_y < 0, 2, 0) # 2 -> collision
    index_2 = jnp.where(rf_y > 0, 1, 0) # 1 -> flight, 0 -> stance
    index = index_1 + index_2 
    return index

In [15]:
def collision_function(x, u, h):
    x = jnp.array([*x[0:3], 0, *x[4:]])
    x = jump_map(x)
    return stance_dynamics_rk4(x, u, h)

In [45]:
def hopper_dynamics(x, u, h):
    flag = guard_function(x)
    return jax.lax.switch(flag, 
                          [stance_dynamics_rk4, flight_dynamics_rk4, collision_function], 
                          x, u, h)

In [17]:
def hopper_cost(x, u):
    rb = x[0:2]
    rf = x[2:4]
    v = x[4:8]
    
    cost = rb[0] ** 2 + rf[0] ** 2 + 0.2 * jnp.dot(v, v) + 0.2 * jnp.dot(u, u) #+ (rb[1] - 1.0) ** 2 + rb[1] ** 2
    return cost  

In [18]:
from utils.env_utils import step

In [19]:
def step_wrapper_hopper(carry, action):
    obs = carry[0]
    h = carry[1]
    next_obs = step(obs, action, hopper_dynamics, h)
    cost = hopper_cost(obs, action)
    carry = (next_obs, h)
    output = (next_obs, cost)
    return carry, output

In [20]:
from jax_wrappers.rollout_functions import load_rollout_jax

In [21]:
n_samples = 1000
horizon = 20
act_dim = 2

In [22]:
rng = np.random.default_rng(0)

In [23]:
size = (n_samples, horizon, act_dim)
acts = rng.normal(size=size)

In [24]:
rollout_jax = load_rollout_jax(step_wrapper_hopper)

In [25]:
%timeit rollout_jax(x, acts, 0.1)

1.23 ms ± 256 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Hopper MPPI

In [26]:
from controller.mppi import MPPI

In [32]:
params = {'seed':42,
          'h':0.05,
          'env_name':'Hopper-meshcat',
          #'sample_type':'cubic',
          'sample_type':'normal',
          'n_knots':15,
          'horizon':200,
          'temperature':1.0,
          'n_samples':250,
          'noise_sigma':5.0}

In [33]:
controller_jax = MPPI(None, params)

In [34]:
controller_jax.reset_planner()

In [35]:
x0 = np.array([0, 1.0, 0, 0, 0, 0, 0, 0])
tfinal = 10
tvec = np.linspace(0,tfinal,201)
h = params["h"]
x = x0*1
q_sim = []
for ti in tvec:
    u = controller_jax.get_action(x)
    x = hopper_dynamics(x, u, h)
    q_sim.append(x[0:4])

In [36]:
visualize_hopper(vis, q_sim, h)