In [1]:
import functools
import jax
import os

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output

import brax

import flax
from brax.envs import env
from brax import envs
from brax import base
from brax.io import model
from brax.io import json
from brax.io import html
from brax.io import mjcf

from a1 import A1
from inverse_kinematics.inverse_kinematics_controller import InverseKinematicsController

  jax.tree_util.register_keypaths(


In [2]:
a1_env = A1()

# Compile the environment's reset and step functions once, on the CPU backend;
# the rollout and linearization cells below reuse these compiled callables.
jit_env_reset, jit_env_step = (
    jax.jit(env_fn, backend="cpu") for env_fn in (a1_env.reset, a1_env.step)
)



Use the inverse-kinematics controller to generate a nominal trajectory, then improve it with iLQR (the cost is the robot's stability).

In [3]:
# Nominal gait generator (inverse kinematics); its actions seed the iLQR trajectory.
controller = InverseKinematicsController(
    Xdist=0.366,
    Ydist=0.28,
    height=0.25,
    coxa=0.038,
    femur=0.2,
    tibia=0.2,
    L=0.8,
    angle=0,
    T=1.0,
    dt=0.01,
)

In [28]:
# Bring the robot from its reset pose to a standing pose, then interpolate to
# the first action of the IK controller. Every (state, action) pair is recorded
# so the dynamics can later be linearized around this nominal trajectory.
STAND_UP_STEPS = 100    # interpolation steps: reset pose -> standing pose
TRANSITION_STEPS = 50   # interpolation steps: standing pose -> first IK action


def _rollout(start_state, joint_trajectory, state_log, action_log):
    """Apply each action of `joint_trajectory` in turn, logging (state, action).

    The state is logged *before* stepping, so ``state_log[i]`` is the state in
    which ``action_log[i]`` was applied. Returns the state after the last step.
    """
    current = start_state
    for joint_action in joint_trajectory:
        state_log.append(current)
        action_log.append(joint_action)
        current = jit_env_step(current, joint_action)
    return current


states = []
actions = []

rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)

# 12 joint-position targets, grouped in threes (presumably one triple per leg).
# NOTE(review): this hard-coded pose is assumed to match the pose right after
# reset — TODO confirm, otherwise the first interpolation step jumps.
current_joint = jp.array([
    0., 1.4, -2.6,
    0., 1.4, -2.6,
    0., 1.4, -2.6,
    0., 1.4, -2.6
])

# stand up
reference_joint = jp.array([
    0.0, 0.6, -1.5,
    0.0, 0.6, -1.5,
    0.0, 0.6, -1.5,
    0.0, 0.6, -1.5
])

trajectory = jp.linspace(current_joint, reference_joint, STAND_UP_STEPS)
state = _rollout(state, trajectory, states, actions)

# move from stable stance to first step of performance controller
reference_joint = jp.array(controller.get_action(
    joint_order=["FR", "FL", "BR", "BL"], offset=[0] * 12
))

# q[7:] is used as the current joint configuration (the first 7 entries are
# presumably the floating base position + quaternion — confirm).
trajectory = jp.linspace(state.pipeline_state.q[7:], reference_joint, TRANSITION_STEPS)
state = _rollout(state, trajectory, states, actions)

In [30]:
# Get A, B from trajectory by linearizing the one-step dynamics around every
# recorded (state, action) pair. The "control" u is the joint-position target
# passed to env.step (not a torque), so B = d(next x) / d(position target).

def get_x(state: env.State):
    """Flatten a brax env state into the vector x = [q, qd]."""
    return jp.concatenate([state.pipeline_state.q, state.pipeline_state.qd])


def integrate_forward(x, u, state):
    """Step the environment from x = [q, qd] under action u; return the next x.

    `state` is a template supplying everything except q/qd. The q/qd split
    point is read from the template's q, so it is not hard-coded.
    """
    nq = state.pipeline_state.q.shape[0]  # static under jit
    # NOTE(review): only q and qd are overwritten; if the physics pipeline also
    # reads other cached fields of pipeline_state, those still come from the
    # template `state` — confirm this is intended.
    pipeline_state = state.pipeline_state.replace(q=x[:nq], qd=x[nq:])
    new_state = jit_env_step(state.replace(pipeline_state=pipeline_state), u)
    return get_x(new_state)


# Returns (A, B) = (d x_next / d x, d x_next / d u).
jax_integrate_forward = jax.jit(jax.jacfwd(integrate_forward, argnums=(0, 1)), backend="cpu")
jax_get_x = jax.jit(get_x, backend="cpu")

assert len(states) == len(actions), "states and actions must be aligned one-to-one"

As = []
Bs = []

for s, action in zip(states, actions):
    x = jax_get_x(s)
    A, B = jax_integrate_forward(x, action, s)
    As.append(A)
    Bs.append(B)

In [7]:
from abc import ABC, abstractmethod
from functools import partial
from jaxlib.xla_extension import DeviceArray
import jax
from jax import numpy as jnp

from ILQR.cost.base_cost import BaseCost
from ILQR.cost.base_cost import quadratic_cost, huber_cost


class StateCost(BaseCost):
    """Running cost computed from the control input via a quadratic or Huber penalty.

    NOTE(review): despite the class name, ``get_running_cost`` ignores ``state``
    and ``ref`` and penalizes only ``ctrl``. The weight/delta vectors are built
    from two config entries (accel, steer), while the A1 actions rolled out in
    this notebook have 12 entries — confirm the intended dim_u before use.
    """

    def __init__(self, config):
        super().__init__()
        # Per-control-dimension weights, shape (dim_u).
        weight_entries = (config.ctrl_cost_accel_weight,
                          config.ctrl_cost_steer_weight)
        self.weight = jnp.array(weight_entries)

        # Per-control-dimension Huber deltas, shape (dim_u).
        delta_entries = (config.ctrl_cost_accel_huber_delta,
                         config.ctrl_cost_steer_huber_delta)
        self.delta = jnp.array(delta_entries)

        # Select the penalty function by name; unknown names are rejected.
        requested = config.ctrl_cost_type
        for cost_name, cost_impl in (('quadratic', quadratic_cost),
                                     ('huber', huber_cost)):
            if requested == cost_name:
                self.cost_func = cost_impl
                break
        else:
            raise NotImplementedError(
                f'Cost type {requested} not implemented for CTRL COST. '
                'Please choose from [quadratic, huber]'
            )

    @partial(jax.jit, static_argnums=(0,))
    def get_running_cost(
        self, state: DeviceArray, ctrl: DeviceArray, ref: DeviceArray
    ) -> float:
        '''
        Return the running cost of a single step.
        Input:
            state: (dim_x) state -- not used by this implementation
            ctrl: (dim_u) control
            ref: (dim_ref) reference -- not used by this implementation
        return:
            cost: sum of all entries of cost_func(ctrl, weight, delta)
        '''
        per_entry_cost = self.cost_func(ctrl, self.weight, self.delta)
        return jnp.sum(per_entry_cost)

In [61]:
# Pick one recorded (state, action) pair and its flattened state vector for inspection.
s, a = states[10], actions[10]
x = jax_get_x(s)

In [None]:
# TODO: write a function that computes the cost
