In [1]:
import functools
import jax
import os

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output

import brax

import flax
from brax.envs import env
from brax import envs
from brax import base
from brax.io import model
from brax.io import json
from brax.io import html
from brax.io import mjcf

from a1 import A1
from inverse_kinematics.inverse_kinematics_controller import InverseKinematicsController

In [2]:
# Build the A1 environment and JIT-compile its reset/step functions for the CPU backend.
# NOTE(review): the saved run of this cell failed with
# "XlaRuntimeError: DNN library initialization failed" (cuDNN) even though
# backend="cpu" is requested here; forcing CPU globally before jax is first
# used (e.g. JAX_PLATFORMS=cpu in the environment) may be required -- confirm.
a1_env = A1()
jit_env_reset, jit_env_step = (
    jax.jit(env_fn, backend="cpu") for env_fn in (a1_env.reset, a1_env.step)
)

2023-04-12 11:18:14.818747: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_NOT_INITIALIZED
2023-04-12 11:18:14.818805: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:438] Possibly insufficient driver version: 530.30.2


XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

Use the inverse-kinematics controller to generate the nominal trajectory, then refine it with iLQR, using the robot's stability as the cost.

In [None]:
# Inverse-kinematics gait controller that supplies the nominal trajectory.
# NOTE(review): parameter meanings/units (body dimensions, link lengths,
# T as the gait period, dt as the control step) are inferred from the names
# only -- confirm against InverseKinematicsController.
controller = InverseKinematicsController(
    Xdist=0.366,
    Ydist=0.28,
    height=0.25,
    coxa=0.038,
    femur=0.2,
    tibia=0.2,
    L=0.8,
    angle=0,
    T=1.0,
    dt=0.01,
)

In [None]:
# bring robot to standing position and take the first step of the controller

def step_through(state, trajectory, rollout):
    """Drive the environment through every action in ``trajectory``.

    For each action, the current ``state.pipeline_state`` is appended to
    ``rollout`` (mutated in place) *before* stepping, so ``rollout`` records
    the state each action was applied from.

    Args:
        state: environment state returned by ``jit_env_reset``/``jit_env_step``.
        trajectory: iterable of joint-target actions (one per step).
        rollout: list that receives the pipeline states.

    Returns:
        The environment state after the last action has been applied.
    """
    for action in trajectory:
        rollout.append(state.pipeline_state)
        state = jit_env_step(state, action)
    return state


STAND_UP_STEPS = 100    # interpolation steps from crouched pose to standing pose
TRANSITION_STEPS = 50   # interpolation steps from standing pose to the controller's first action

rollout = []
rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)

# Crouched starting pose, three joint values per leg.
# NOTE(review): assumed to match the joint pose produced by reset -- confirm;
# otherwise the first interpolated action jumps away from the actual reset pose.
current_joint = jp.array([
    0., 1.4, -2.6,
    0., 1.4, -2.6,
    0., 1.4, -2.6,
    0., 1.4, -2.6
])

# stand up
reference_joint = jp.array([
    0.0, 0.6, -1.5,
    0.0, 0.6, -1.5,
    0.0, 0.6, -1.5,
    0.0, 0.6, -1.5
])

trajectory = jp.linspace(current_joint, reference_joint, STAND_UP_STEPS)
state = step_through(state, trajectory, rollout)

# move from stable stance to first step of performance controller
reference_joint = jp.array(controller.get_action(
    joint_order=["FR", "FL", "BR", "BL"], offset=[0] * 12
))

# q[7:] drops the leading floating-base coordinates (presumably xyz position +
# quaternion), leaving the 12 joint angles.
trajectory = jp.linspace(state.pipeline_state.q[7:], reference_joint, TRANSITION_STEPS)
state = step_through(state, trajectory, rollout)

In [None]:
import copy

# ILQR starts from the state reached at the end of the stand-up / transition rollout.
init_state = copy.deepcopy(state)

In [7]:
from abc import ABC, abstractmethod
from functools import partial
from jaxlib.xla_extension import DeviceArray
import jax
from jax import numpy as jnp
import numpy as np

from ILQR.cost.base_cost import BaseCost
from ILQR.cost.base_cost import quadratic_cost, huber_cost


class StateCost(BaseCost):
    """Quadratic or Huber penalty on the control input, for use as an ILQR running cost.

    NOTE(review): despite the name, this class penalizes only ``ctrl``. The
    ``state`` and ``ref`` arguments are accepted but unused, and the config
    fields it reads (``ctrl_cost_*``) describe a two-input
    (acceleration / steering) control cost. The notebook's stated goal is a
    stability cost on the robot state -- TODO replace with a state-based term.

    NOTE(review): the A1 actions in this notebook are 12-dimensional, but
    ``weight``/``delta`` have 2 entries. If ``ctrl`` is the joint action, the
    cost will not broadcast -- confirm.
    """

    def __init__(self, config):
        """Build per-dimension weights/deltas and select the cost function.

        Args:
            config: object exposing ``ctrl_cost_accel_weight``,
                ``ctrl_cost_steer_weight``, ``ctrl_cost_accel_huber_delta``,
                ``ctrl_cost_steer_huber_delta`` and ``ctrl_cost_type``
                (``'quadratic'`` or ``'huber'``).

        Raises:
            NotImplementedError: if ``config.ctrl_cost_type`` is not
                ``'quadratic'`` or ``'huber'``.
        """
        super().__init__()
        # One entry per control dimension (dim_u == 2 with this config layout).
        self.weight = jnp.array([config.ctrl_cost_accel_weight,
                                 config.ctrl_cost_steer_weight])

        # Huber thresholds, also one per control dimension.
        self.delta = jnp.array([config.ctrl_cost_accel_huber_delta,
                                config.ctrl_cost_steer_huber_delta])

        if config.ctrl_cost_type == 'quadratic':
            self.cost_func = quadratic_cost
        elif config.ctrl_cost_type == 'huber':
            self.cost_func = huber_cost
        else:
            raise NotImplementedError(
                f'Cost type {config.ctrl_cost_type} not implemented for CTRL COST. '+
                'Please choose from [quadratic, huber]'
                )

    # NOTE: ``self`` is a static jit argument. The compiled trace is cached per
    # ``self`` (identity-hashed unless BaseCost overrides __hash__/__eq__), so
    # mutating ``weight``/``delta``/``cost_func`` after the first call is not
    # picked up.
    @partial(jax.jit, static_argnums=(0,))
    def get_running_cost(
        self, state: jax.Array, ctrl: jax.Array, ref: jax.Array
    ) -> jax.Array:
        """Return the running cost for a single (state, control) pair.

        Args:
            state: (dim_x) state. Unused by the current implementation.
            ctrl: (dim_u) control.
            ref: (dim_ref) reference. Unused by the current implementation.

        Returns:
            Scalar array: the sum of ``cost_func(ctrl, weight, delta)``.
        """
        return jnp.sum(self.cost_func(ctrl, self.weight, self.delta))

State(pipeline_state=State(q=Array([-1.7871406e-02, -9.0128271e-04,  2.7165914e-01,  9.9769437e-01,
        6.7612510e-03, -3.4949272e-03, -6.7439020e-02,  1.4658351e-01,
        1.0323814e+00, -1.8317171e+00, -1.4854372e-01,  7.9488891e-01,
       -1.8152134e+00,  1.4490052e-01,  7.5020140e-01, -1.8183707e+00,
       -1.4611095e-01,  1.0382730e+00, -1.8297520e+00], dtype=float32), qd=Array([ 0.2899591 , -0.04669133, -0.05286638,  0.6081843 , -0.09735972,
       -0.5360012 , -0.27266806,  1.2771477 , -0.32450888, -0.66507244,
        1.0427214 , -0.30620658,  1.3593173 ,  1.1573977 , -2.6940699 ,
       -0.93962914,  1.5012367 , -0.13914491], dtype=float32), x=Transform(pos=Array([[-0.01787141, -0.00090128,  0.27165914],
       [ 0.15713711, -0.07210389,  0.2721122 ],
       [ 0.14591685, -0.15530705,  0.25851652],
       [-0.02140658, -0.11603074,  0.15624125],
       [ 0.16978197,  0.02103249,  0.27342469],
       [ 0.18119614,  0.10453698,  0.26201338],
       [ 0.03804474,  0.10501