In [1]:
import functools
import time

# Turn on parallelism: make XLA expose 16 virtual CPU devices so that `jax.pmap`
# (used further down) can place one batch element per device. The flag is set
# before `import jax` so it is in place when XLA initializes.
# NOTE: this assignment replaces any XLA_FLAGS already present in the environment.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

import jax
# Enable 64-bit arrays (float64/int64); JAX defaults to 32-bit types.
jax.config.update("jax_enable_x64", True)
# Run on the CPU platform (consistent with the backend="cpu" passed to jax.jit below).
jax.config.update("jax_platform_name", "cpu")

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output

import brax
import flax
from brax.envs import env
from brax import envs
from brax import base
from brax.io import model
from brax.io import json
from brax.io import html
from brax.io import mjcf

from a1_raw import A1Raw
from a1 import A1
from inverse_kinematics.inverse_kinematics_controller import InverseKinematicsController

In [2]:
# Full A1 environment; its reset/step functions are JIT-compiled for the CPU backend.
a1_env = A1()
jit_env_reset, jit_env_step = (
    jax.jit(fn, backend="cpu") for fn in (a1_env.reset, a1_env.step)
)



Use the inverse-kinematics (IK) controller to generate a nominal trajectory, then refine it with iLQR, using the robot's stability as the cost.

In [3]:
# IK gait controller that provides the nominal joint targets.
# NOTE(review): parameter units are defined by InverseKinematicsController —
# presumably metres for lengths and seconds for T/dt; confirm.
controller = InverseKinematicsController(
    Xdist=0.366,
    Ydist=0.28,
    height=0.25,
    coxa=0.038,
    femur=0.2,
    tibia=0.2,
    L=0.8,
    angle=0,
    T=1.0,
    dt=0.01,
)

In [30]:
# Bring the robot from the configuration `current_joint` to a standing pose, then
# move it to the first action of the IK controller. Every visited (state, action)
# pair is recorded in `states` / `actions`; the linearization cells below consume
# them as the nominal trajectory. `sum_time` accumulates wall-clock seconds spent
# stepping the environment.
STAND_STEPS = 100      # interpolation steps from current_joint to the standing pose
FIRST_STEP_STEPS = 50  # interpolation steps from standing to the controller's first action


def rollout(state, trajectory, states, actions):
    """Step the environment through every action in `trajectory`.

    Appends each pre-step state and its action to `states` / `actions`.

    Returns:
        (final_state, elapsed): the state after the last action and the
        wall-clock seconds spent stepping.
    """
    start = time.time()
    for action in trajectory:
        states.append(state)
        actions.append(action)
        state = jit_env_step(state, action)
    # JAX dispatches asynchronously: wait for the last step so the clock
    # measures the actual computation rather than just its dispatch.
    state = jax.block_until_ready(state)
    return state, time.time() - start


states = []
actions = []

rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)

# NOTE(review): assumed to match the joint pose right after reset — confirm.
current_joint = jp.array([
    0., 1.4, -2.6, 
    0., 1.4, -2.6, 
    0., 1.4, -2.6, 
    0., 1.4, -2.6
])

# stand up
reference_joint = jp.array([
    0.0, 0.6, -1.5, 
    0.0, 0.6, -1.5, 
    0.0, 0.6, -1.5,
    0.0, 0.6, -1.5
])

trajectory = jp.linspace(current_joint, reference_joint, STAND_STEPS)

# Compile `jit_env_step` before starting the clock so that `sum_time` measures
# stepping, not XLA compilation (matters on a fresh kernel). Result is discarded.
jax.block_until_ready(jit_env_step(state, trajectory[0]))

state, sum_time = rollout(state, trajectory, states, actions)

# move from stable stance to first step of performance controller
reference_joint = jp.array(controller.get_action(
    joint_order=["FR", "FL", "BR", "BL"], offset=[0] * 12
))

# q[7:] is used as the 12 joint positions (q[:7] presumably the floating base:
# 3 position + 4 quaternion entries — confirm against the A1 model).
trajectory = jp.linspace(state.pipeline_state.q[7:], reference_joint, FIRST_STEP_STEPS)
state, elapsed = rollout(state, trajectory, states, actions)
sum_time += elapsed

In [31]:
# Mean wall-clock seconds per env step: the rollout cell records exactly one
# action per step, so len(actions) is the number of timed steps (no hard-coded 150).
sum_time / len(actions)

0.011307716369628906

In [6]:
# Raw A1 environment whose step takes the pipeline-state fields as explicit
# arguments (see integrate_forward); reset/step are JIT-compiled for the CPU backend.
a1_env_raw = A1Raw()
jit_rawenv_reset, jit_rawenv_step = (
    jax.jit(fn, backend="cpu") for fn in (a1_env_raw.reset, a1_env_raw.step)
)



In [7]:
def integrate_forward(s, u, x_pos, x_rot, xd_ang, xd_vel, com, 
        cinr_transform_pos, cinr_transform_rot, cinr_i, cinr_mass, 
        cd_ang, cd_vel, cdof_ang, cdof_vel, 
        cdofd_ang, cdofd_vel, mass_mx_inv, 
        con_jac, con_aref, con_diag, 
        qf_smooth, qf_constraint, qdd):
    """One step of the raw A1 dynamics, written as a function of arrays only.

    Args:
        s: flattened state vector [q, qd], as produced by `get_x`.
        u: action for this step (presumably joint-position targets — confirm).
        x_pos ... qdd: pipeline-state fields in the order returned by
            `get_pipeline_parse`; forwarded unchanged to `jit_rawenv_step`.

    Returns:
        The next state vector: q and qd of the stepped state, flattened and
        concatenated.
    """
    pipeline_fields = (
        x_pos, x_rot, xd_ang, xd_vel, com,
        cinr_transform_pos, cinr_transform_rot, cinr_i, cinr_mass,
        cd_ang, cd_vel, cdof_ang, cdof_vel,
        cdofd_ang, cdofd_vel, mass_mx_inv,
        con_jac, con_aref, con_diag,
        qf_smooth, qf_constraint, qdd,
    )
    next_ps = jit_rawenv_step(s, u, *pipeline_fields).pipeline_state
    return jp.concatenate([jp.ravel(next_ps.q), jp.ravel(next_ps.qd)])


# Jacobians of the next state w.r.t. (s, u): returns the pair (A, B) for iLQR.
jax_integrate_forward = jax.jit(
    jax.jacfwd(integrate_forward, argnums=(0, 1)),
    backend="cpu",
)

In [8]:
def get_x(state: env.State):
    """Flatten an env state into the vector x = [q, qd] used by the linearization."""
    ps = state.pipeline_state
    return jp.concatenate([jp.ravel(ps.q), jp.ravel(ps.qd)])


# JIT-compiled variant, for Python loops over the recorded states.
jax_get_x = jax.jit(get_x, backend="cpu")

In [9]:
# Dotted attribute paths of the pipeline-state fields, in the positional order
# in which integrate_forward / jit_rawenv_step take them (after the state and action).
PIPELINE_FIELD_PATHS = (
    "x.pos", "x.rot", "xd.ang", "xd.vel", "com",
    "cinr.transform.pos", "cinr.transform.rot", "cinr.i", "cinr.mass",
    "cd.ang", "cd.vel", "cdof.ang", "cdof.vel", "cdofd.ang", "cdofd.vel",
    "mass_mx_inv", "con_jac", "con_aref", "con_diag",
    "qf_smooth", "qf_constraint", "qdd",
)


def get_pipeline_parse(ps):
    """Return the fields of pipeline state `ps` as a tuple ordered like PIPELINE_FIELD_PATHS."""
    return tuple(
        functools.reduce(getattr, path.split("."), ps)
        for path in PIPELINE_FIELD_PATHS
    )

In [10]:
# Get A, B from trajectory
import time

# Linearize the one-step dynamics along the recorded trajectory, one step at a
# time: A = d x_next / d x and B = d x_next / d u for each (state, action).
# The first call includes XLA compilation, which dominates the printed time.
N_SEQUENTIAL = 12  # the original loop broke after i = 11, i.e. 12 steps
# NOTE(review): the batched cells below linearize 10 steps — confirm which count is intended.

cur_time = time.time()
A = B = None
for s, action in zip(states[:N_SEQUENTIAL], actions[:N_SEQUENTIAL]):
    x = jax_get_x(s)
    A, B = jax_integrate_forward(x, action, *get_pipeline_parse(s.pipeline_state))
# JAX dispatches asynchronously; wait for the last Jacobians so the timing
# covers the computation, not just its dispatch.
jax.block_until_ready((A, B))
print(time.time() - cur_time)

41.95318031311035


In [11]:
# try vmap
# Batch the (non-aux) Jacobian function over a leading axis, then JIT the batched version.
_vmapped_jacfwd = jax.vmap(jax_integrate_forward)
jit_vmap_jacfwd = jax.jit(_vmapped_jacfwd)

In [12]:
# Stack the first N_BATCH recorded steps into batched inputs for the vmap/pmap
# experiments. Each pipeline field is derived from get_pipeline_parse, so the
# per-field order always matches what jax_integrate_forward expects.
# N_BATCH must not exceed the 16 host devices configured via XLA_FLAGS (pmap
# maps one element per device). Keep it in sync with `actions[:10]` in the timing cells.
N_BATCH = 10
batch_states = states[:N_BATCH]

ss = jp.array([jax_get_x(s) for s in batch_states])
(
    x_poss, x_rots, xd_angs, xd_vels, coms,
    cinr_transform_poss, cinr_transform_rots, cinr_is, cinr_masses,
    cd_angs, cd_vels, cdof_angs, cdof_vels, cdofd_angs, cdofd_vels,
    mass_mx_invs, con_jacs, con_arefs, con_diags,
    qf_smooths, qf_constraints, qdds,
) = [
    # zip(*...) transposes per-step field tuples into per-field sequences.
    jp.array(list(field_values))
    for field_values in zip(*(get_pipeline_parse(s.pipeline_state) for s in batch_states))
]

In [13]:
import time
# Time one batched (vmap) linearization of the first 10 trajectory steps.
# The first call includes XLA compilation. JAX dispatches asynchronously, so
# block on the result before stopping the clock; otherwise only dispatch is timed.
cur_time = time.time()
jax.block_until_ready(jit_vmap_jacfwd(
    ss, jp.array(actions[:10]), 
    x_poss, x_rots, xd_angs, xd_vels, coms, 
    cinr_transform_poss, cinr_transform_rots, cinr_is, cinr_masses, 
    cd_angs, cd_vels, cdof_angs, cdof_vels, cdofd_angs, cdofd_vels,
    mass_mx_invs, con_jacs, con_arefs, con_diags, qf_smooths, qf_constraints, qdds
))
print(time.time() - cur_time)

25.6126389503479


In [14]:
# try pmap
# Map the (non-aux) Jacobian function across devices, one batch element per
# device; the leading batch axis must not exceed the 16 host devices set via XLA_FLAGS.
jit_pmap_jacfwd = jax.pmap(
    jax_integrate_forward
)

In [25]:
import time
# Time one device-parallel (pmap) linearization of the first 10 trajectory steps.
# The first call includes XLA compilation. pmap returns before the devices
# finish, so block on the result; without it only dispatch time would be
# measured, which is not comparable to the vmap timing.
cur_time = time.time()
jax.block_until_ready(jit_pmap_jacfwd(
    ss, jp.array(actions[:10]), 
    x_poss, x_rots, xd_angs, xd_vels, coms, 
    cinr_transform_poss, cinr_transform_rots, cinr_is, cinr_masses, 
    cd_angs, cd_vels, cdof_angs, cdof_vels, cdofd_angs, cdofd_vels,
    mass_mx_invs, con_jacs, con_arefs, con_diags, qf_smooths, qf_constraints, qdds
))
print(time.time() - cur_time)

0.016916990280151367


In [16]:
# Variant of integrate_forward that also returns the full next env state as
# auxiliary output, so one jacfwd call yields both the Jacobians and the next state.
def integrate_forward_aux(s, u, x_pos, x_rot, xd_ang, xd_vel, com, 
        cinr_transform_pos, cinr_transform_rot, cinr_i, cinr_mass, 
        cd_ang, cd_vel, cdof_ang, cdof_vel, 
        cdofd_ang, cdofd_vel, mass_mx_inv, 
        con_jac, con_aref, con_diag, 
        qf_smooth, qf_constraint, qdd):
    """One raw A1 step; returns (x_next, next_state).

    Arguments are the same as for integrate_forward: the flattened state
    vector `s`, the action `u`, and the pipeline-state fields in
    get_pipeline_parse order (forwarded unchanged to jit_rawenv_step).
    x_next is q and qd of next_state, flattened and concatenated.
    """
    pipeline_fields = (
        x_pos, x_rot, xd_ang, xd_vel, com,
        cinr_transform_pos, cinr_transform_rot, cinr_i, cinr_mass,
        cd_ang, cd_vel, cdof_ang, cdof_vel,
        cdofd_ang, cdofd_vel, mass_mx_inv,
        con_jac, con_aref, con_diag,
        qf_smooth, qf_constraint, qdd,
    )
    next_state = jit_rawenv_step(s, u, *pipeline_fields)
    next_ps = next_state.pipeline_state
    x_next = jp.concatenate([jp.ravel(next_ps.q), jp.ravel(next_ps.qd)])
    return x_next, next_state


# NOTE(review): this rebinds `jax_integrate_forward`, which the vmap/pmap cells
# above wrapped in its non-aux form. Re-running those cells after this one would
# wrap this aux version instead, which has a different output structure.
jax_integrate_forward = jax.jit(
    jax.jacfwd(integrate_forward_aux, argnums=(0, 1), has_aux=True),
    backend="cpu",
)

In [17]:
# Linearize the first recorded step with the aux variant: (A, B) are the
# Jacobians w.r.t. (x, u), and `state` receives the env state after that step.
# NOTE(review): this overwrites the global `state` left by the rollout cell.
s0 = states[0]
x0 = get_x(s0)
u0 = jp.array(actions[0])
(A, B), state = jax_integrate_forward(x0, u0, *get_pipeline_parse(s0.pipeline_state))