In [1]:
import functools
import jax
import os

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output

import brax

import flax
from brax.envs import env
from brax import envs
from brax import base
from brax.io import model
from brax.io import json
from brax.io import html
from brax.io import mjcf

  jax.tree_util.register_keypaths(


In [2]:
# Report the platform JAX uses by default, via the public API rather than the
# private `jax.lib.xla_bridge` module.
print(jax.default_backend())

gpu


In [3]:
#@title Load Env { run: "auto" }

env_name = 'ant'  # @param ['ant', 'halfcheetah', 'hopper', 'humanoid', 'humanoidstandup', 'inverted_pendulum', 'inverted_double_pendulum', 'pusher', 'reacher', 'walker2d']
backend = 'positional'  # @param ['generalized', 'positional', 'spring']

# Build the environment selected in the form fields above.
ant_env = envs.get_environment(env_name=env_name, backend=backend)

# Compile `reset` for the CPU backend, then draw the initial state from seed 0.
cpu_reset = jax.jit(ant_env.reset, backend="cpu")
state = cpu_reset(rng=jax.random.PRNGKey(seed=0))

# Render the initial state as a single-frame HTML view (cell output).
HTML(html.render(ant_env.sys, [state.pipeline_state]))

In [98]:
# Brax pipeline environment for the A1 robot, built from its MJCF XML description.

class A1(env.PipelineEnv):
    """Minimal Brax environment wrapping the A1 robot model.

    The environment carries no task: `reward` and `done` are always zero and
    the observation is simply the generalized state `concat(q, qd)`.
    """

    def __init__(
        self,
        path="a1/xml/a1.xml",
        backend='generalized',
        reset_noise_scale=0.1,
        **kwargs
    ):
        """Load the MJCF model and initialise the pipeline environment.

        Args:
            path: path to the MJCF XML file passed to `mjcf.load`.
            backend: physics backend name forwarded to `PipelineEnv`.
            reset_noise_scale: stored on the instance; `reset` does not
                currently apply it (the reset pose is deterministic).
            **kwargs: forwarded to `PipelineEnv.__init__`; `n_frames`
                defaults to 5 when the caller does not supply it.
        """
        sys = mjcf.load(path)
        kwargs.setdefault('n_frames', 5)

        super().__init__(sys=sys, backend=backend, **kwargs)

        # Cache the generalized position/velocity dimensions of the loaded system.
        self.qd_size = self.sys.qd_size()
        self.q_size = self.sys.q_size()
        self._reset_noise_scale = reset_noise_scale

    def reset(self, rng: jp.ndarray) -> env.State:
        """Return the initial state: a fixed pose with zero velocity.

        Args:
            rng: PRNG key. Accepted for interface compatibility but unused,
                because no reset noise is applied.
                TODO: apply `self._reset_noise_scale` if randomised resets
                are wanted.

        Returns:
            An `env.State` with zero reward and zero done flag.
        """
        q = jp.array([
            0., 0., 0.3,        # q[0:3]: base position (z = 0.3)
            # q[3:7]: base orientation quaternion.
            # NOTE(review): if Brax quaternions are (w, x, y, z), this is a
            # 180-degree rotation about x rather than the identity
            # (1, 0, 0, 0) -- confirm this is intended.
            0.0, 1., 0., 0.,
            0.0, 0.7, -1.5,     # q[7:]: joint positions, three per leg
            0.0, 0.7, -1.5,
            0.0, 0.7, -1.5,
            0.0, 0.7, -1.5
        ])
        qd = jp.zeros(self.qd_size)  # velocity initialized to 0

        pipeline_state = self.pipeline_init(q, qd)
        obs = self._get_obs(pipeline_state)

        reward, done = jp.zeros(2)

        # NOTE(review): the fifth positional field of `env.State` appears to be
        # `metrics` rather than `info`; an empty dict is passed either way -- confirm.
        return env.State(pipeline_state, obs, reward, done, {})

    def _get_obs(self, pipeline_state: base.State) -> jp.ndarray:
        """Observe the generalized positions `q` followed by velocities `qd`."""
        return jp.concatenate([pipeline_state.q, pipeline_state.qd])

    def step(self, state: env.State, action: jp.ndarray) -> env.State:
        """Run one timestep of the environment's dynamics.

        Args:
            state: current environment state.
            action: control input forwarded to `pipeline_step`.

        Returns:
            `state` with updated `pipeline_state` and `obs`; `reward` and
            `done` are reset to zero. Other fields are carried over unchanged.
        """
        pipeline_state = self.pipeline_step(state.pipeline_state, action)
        obs = self._get_obs(pipeline_state)

        reward, done = jp.zeros(2)

        return state.replace(
            pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
        )

In [99]:
# Instantiate the A1 environment with all constructor defaults.
a1_env = A1()

In [100]:
# Compile `reset` for the CPU backend and call it once with PRNG seed 0.
state = jax.jit(a1_env.reset, backend="cpu")(rng=jax.random.PRNGKey(seed=0))

In [None]:
# Render the current pipeline state as a single-frame HTML view (cell output).
HTML(html.render(a1_env.sys, [state.pipeline_state]))

In [103]:
# CPU-compiled versions of the environment's reset and step, reused by the cells below.
jit_env_reset = jax.jit(a1_env.reset, backend="cpu")
jit_env_step = jax.jit(a1_env.step, backend="cpu")

In [104]:
# Initial state from the jitted reset, using PRNG seed 0.
state = jit_env_reset(rng=jax.random.PRNGKey(seed=0))

In [107]:
# Roll out the A1 environment under a proportional joint-position controller.
NUM_STEPS = 20  # number of environment steps to simulate
KP = 10.0       # proportional gain on the joint-position error

rollout = []
rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)

# Target joint positions, compared against q[7:] (three values per leg).
reference_joint = jp.array([
    0.0, 0.7, -1.5,
    0.0, 0.7, -1.5,
    0.0, 0.7, -1.5,
    0.0, 0.7, -1.5
])

for _ in range(NUM_STEPS):
    # Record the state *before* stepping, so the rollout starts at the reset pose.
    rollout.append(state.pipeline_state)
    joint_pos = state.pipeline_state.q[7:]
    e_pos = reference_joint - joint_pos
    # Position error only; no velocity (damping) term is applied.
    action = e_pos * KP
    state = jit_env_step(state, action)

In [108]:
# Save the rollout as an HTML file. The `with` block closes the file on exit,
# so no explicit close() call is needed.
with open("render.html", "w") as render_file:
    render_file.write(html.render(a1_env.sys.replace(dt=0.1), rollout))

In [None]:
# Rotation entry 0 of the pipeline state's `x` transform; the state-variable
# table in this notebook pairs it with `q[3:7]` as the source of roll/pitch/yaw.
state.pipeline_state.x.rot[0]

In [None]:
# Entries 3..6 of the generalized position vector, for comparison with `x.rot[0]`.
state.pipeline_state.q[3:7]

Define the following state variables to be available:

| state variable      | state reference                                          |
|---------------------|----------------------------------------------------------|
| z                   | `q[2]` or `x.pos[0, 2]`                                  |
| roll, pitch, yaw    | `brax.math.quat_to_euler(.)` with `q[3:7]` or `x.rot[0]` |
| w_x, w_y, w_z       | `xd.ang[0]` or `qd[3:6]`                                 |
| v_x, v_y, v_z       | `xd.vel[0]` or `qd[:3]`                                  |
| theta_j (joint pos) | `q[7:]`                                                  |
| theta_j_dot         | `qd[6:]`                                                 |

In [None]:
def integrate_forward(x, u):
    """Advance the A1 environment by one step from a flat state vector.

    Args:
        x: flat state `concat(q, qd)`; the first `a1_env.q_size` entries are
            the generalized positions, the rest the generalized velocities.
        u: action forwarded to `jit_env_step`.

    Returns:
        The flat next state `concat(q, qd)` after one environment step.
    """
    nq = a1_env.q_size  # split point between q and qd inside x
    # Rebuild a pipeline state from (q, qd) so the step starts from `x`.
    pipeline_state = a1_env.pipeline_init(x[:nq], x[nq:])
    obs = a1_env._get_obs(pipeline_state)
    reward, done = jp.zeros(2)
    next_state = jit_env_step(env.State(pipeline_state, obs, reward, done, {}), u)

    next_pipeline = next_state.pipeline_state
    return jp.concatenate((next_pipeline.q, next_pipeline.qd), axis=0)

# Forward-mode Jacobians of the one-step map with respect to the state (argument 0)
# and the action (argument 1). Compiled for CPU to match the backend that
# `jit_env_step` is pinned to.
jacobian_fn = jax.jit(jax.jacfwd(integrate_forward, argnums=[0, 1]), backend="cpu")

In [None]:
# Flatten the current generalized state into x = concat(q, qd).
current_pipeline = state.pipeline_state
q, qd = current_pipeline.q, current_pipeline.qd
x = jp.concatenate([q, qd])
action = jp.ones(12)  # dummy control

# Step the environment once, then evaluate the Jacobians at the pre-step (x, action).
state = jit_env_step(state, action)
A, B = jacobian_fn(x, action)
print(A.shape, B.shape)

In [None]:
# let's see how bad our linearized dynamics is
# A and B are the Jacobians of the one-step map f evaluated at the point
# (x, action) from the previous cell, so the first-order prediction is
#     f(x', u') ~= f(x, u) + A (x' - x) + B (u' - u)
# and not A x' + B u', which drops the affine offset.
x_lin, u_lin = x, action  # linearization point at which A and B were computed

q = state.pipeline_state.q
qd = state.pipeline_state.qd
# `state` was stepped once from x_lin with u_lin, so this is f(x_lin, u_lin).
x_now = jp.concatenate((q, qd), axis=0)
action = jp.ones(12) # dummy control

est_state = x_now + A @ (x_now - x_lin) + B @ (action - u_lin)
real_state = jit_env_step(state, action)

In [None]:
# Residual between the simulated next state concat(q, qd) and the linearized estimate.
jp.concatenate((real_state.pipeline_state.q, real_state.pipeline_state.qd), axis=0) - est_state

In [None]:
# a better attempt to avoid re-initializing the pipeline
# NOTE: this redefinition shadows the earlier two-argument `integrate_forward`;
# `jacobian_fn` still refers to the earlier version.
def integrate_forward(x, u, state):
    """Step the environment once after overwriting `q` and `qd` in `state`.

    Args:
        x: flat state `concat(q, qd)`, split at the length of
            `state.pipeline_state.q`.
        u: action forwarded to `jit_env_step`.
        state: template environment state; only `pipeline_state.q` and
            `pipeline_state.qd` are replaced with the values from `x`.

    Returns:
        The flat next state `concat(q, qd)` after one environment step.
    """
    nq = state.pipeline_state.q.shape[0]  # split point between q and qd inside x
    # NOTE(review): only q and qd are replaced; any other fields cached in
    # pipeline_state keep their old values and are not functions of `x`, which
    # may make the Jacobian with respect to `x` incomplete -- confirm against the backend.
    seeded_pipeline = state.pipeline_state.replace(q=x[:nq], qd=x[nq:])
    next_state = jit_env_step(state.replace(pipeline_state=seeded_pipeline), u)
    next_pipeline = next_state.pipeline_state
    return jp.concatenate((next_pipeline.q, next_pipeline.qd), axis=0)

# Forward-mode Jacobians with respect to x (argument 0) and u (argument 1);
# `state` is passed through undifferentiated.
jax_integrate_forward = jax.jit(jax.jacfwd(integrate_forward, argnums=[0, 1]), backend="cpu")

In [None]:
# Flatten the current generalized state into x = concat(q, qd).
current_pipeline = state.pipeline_state
q, qd = current_pipeline.q, current_pipeline.qd
x = jp.concatenate([q, qd])
action = jp.ones(12)  # dummy control

# Jacobians of the state-reusing one-step map, evaluated at (x, action).
A, B = jax_integrate_forward(x, action, state)
print(A.shape, B.shape)