In [1]:
import sys
# Local source checkouts that must be importable from this notebook.
add_paths = ["/home/r2ci/rex", "/home/r2ci/brax", "/home/r2ci/trajax"]
# Append each checkout to the end of the module search path, skipping entries that
# are already present so re-running this cell never creates duplicates.
for extra_path in add_paths:
    if extra_path in sys.path:
        continue
    sys.path.append(extra_path)
print(sys.path)

['/home/r2ci/rex/notebooks', '', '/home/r2ci/catkin_ws/devel/lib/python3/dist-packages', '/home/r2ci/interbotix_ws/devel/lib/python3/dist-packages', '/opt/ros/noetic/lib/python3/dist-packages', '/usr/lib/python38.zip', '/usr/lib/python3.8', '/usr/lib/python3.8/lib-dynload', '/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.8/lib/python3.8/site-packages', '/home/r2ci/rex', '/home/r2ci/brax', '/home/r2ci/trajax']


In [2]:
#@markdown ## ⚠️ PLEASE NOTE:
#@markdown This colab runs best using a GPU runtime.  From the Colab menu, choose Runtime > Change Runtime Type, then select **'GPU'** in the dropdown.

import functools
import jax
import os

from datetime import datetime
from jax import numpy as jp
import matplotlib.pyplot as plt

from IPython.display import HTML, clear_output

try:
  import brax
except ImportError:
  !pip install git+https://github.com/google/brax.git@main
  clear_output()
  import brax

import flax
from brax import envs
from brax.io import model
from brax.io import json
from brax.io import html
from brax.training.agents.ppo import train as ppo
from brax.training.agents.sac import train as sac

# Only register the TPU backend when the COLAB_TPU_ADDR variable is set
# (presumably only on Colab TPU runtimes).
if os.environ.get('COLAB_TPU_ADDR') is not None:
  from jax.tools import colab_tpu
  colab_tpu.setup_tpu()

In [14]:
import tqdm
from etils import epath
from brax.io import mjcf
from brax.envs.base import PipelineEnv, State
# from brax.generalized import pipeline
from brax.positional import pipeline
from brax import base, math

import jax.numpy as jnp
from rex.utils import timer

In [43]:
# MJCF description of the vx300s arm scene; adjust if the rex checkout lives elsewhere.
VX300S_XML = '/home/r2ci/rex/envs/vx300s/assets/vx300s.xml'

print("loading system")
m = mjcf.load(VX300S_XML)

# Determine collision pairs.
# NOTE(review): `_geom_pairs` is a private brax helper and may change between brax versions.
from brax.geometry.contact import _geom_pairs


def _link_name(geom):
    """Return the name of the link `geom` is attached to, or "world" if it has no link."""
    return m.link_names[geom.link_idx[0]] if geom.link_idx is not None else "world"


for (geom_i, geom_j) in _geom_pairs(m):
    name_i, name_j = _link_name(geom_i), _link_name(geom_j)
    print(f"collision pair: {name_i} --> {name_j}")

# Actuators
print(f"actuator size: {m.act_size()}")

# Indices of the links used by the reward function (raises ValueError if a name is missing).
ee_arm_idx = m.link_names.index("ee_link")
box_idx = m.link_names.index("box")
goal_idx = m.link_names.index("goal")

loading system
collision pair: vx300s/gripper_link --> world
collision pair: vx300s/gripper_link --> world
collision pair: vx300s/gripper_link --> world
collision pair: box --> vx300s/gripper_link
collision pair: box --> vx300s/gripper_link
collision pair: box --> vx300s/gripper_link
actuator size: 6


In [44]:
# Calculate reward fn
def get_reward(state):
    """Compute the two distance-based reward terms for `state`.

    Each link's world transform in `state.x` is offset by that link's inertia-frame
    position (`m.link.inertia.transform.pos`), presumably giving link centres of mass
    in world coordinates. Relies on the module-level globals `m`, `box_idx`,
    `ee_arm_idx` and `goal_idx`.

    Args:
        state: pipeline state as produced by `pipeline.init` / `pipeline.step`.

    Returns:
        Tuple ``(reward_near, reward_dist)``. These are the negated distances from the
        box to the end-effector link and from the box to the goal link, so both
        are non-positive.
    """
    inertia_offsets = base.Transform.create(pos=m.link.inertia.transform.pos)
    link_pos = state.x.vmap().do(inertia_offsets).pos
    box_pos = link_pos[box_idx]
    reward_near = -math.safe_norm(box_pos - link_pos[ee_arm_idx])
    reward_dist = -math.safe_norm(box_pos - link_pos[goal_idx])
    return reward_near, reward_dist

# Jit-compile the physics pipeline and the reward function.
jit_env_reset = jax.jit(pipeline.init)
jit_env_step = jax.jit(pipeline.step)
jit_get_reward = jax.jit(get_reward)
rng = jax.random.PRNGKey(seed=1)  # NOTE(review): not used in this cell


def _block_until_ready(tree):
    """Wait until every array leaf of `tree` has been computed, then return `tree`.

    JAX dispatches work asynchronously, so without this the timers below would only
    measure dispatch (plus tracing/compilation on the first call), not the computation.
    """
    for leaf in jax.tree_util.tree_leaves(tree):
        leaf.block_until_ready()
    return tree


# Warm-up calls: the first call of each jitted function triggers tracing + compilation.
with timer("jit[reset]", log_level=100):
    state = _block_until_ready(jit_env_reset(m, m.init_q, jp.zeros(m.qd_size())))
with timer("jit[step]", log_level=100):
    _ = _block_until_ready(jit_env_step(m, state, 10 * jp.sin(1 / 100) * jp.ones(m.act_size())))
with timer("jit[get_reward]", log_level=100):
    _ = _block_until_ready(jit_get_reward(state))

# Rollouts with zero actuation.
# `rollout[k]` is the state *before* step k; `rewards[k]` is evaluated on the state *after* step k.
NUM_STEPS = 1500
act = jp.zeros(m.act_size())  # loop-invariant, so built once
rollout = []
rewards = []
for i in tqdm.tqdm(range(NUM_STEPS)):
    rollout.append(state)
    state = jit_env_step(m, state, act)
    rewards.append(jit_get_reward(state))
print("rendering")
HTML(html.render(m, rollout))

[97m[30082][MainThread               ][tracer              ][jit[reset]          ] Elapsed: 0.0012159347534179688[0m
[97m[30082][MainThread               ][tracer              ][jit[step]           ] Elapsed: 0.001766204833984375[0m
[97m[30082][MainThread               ][tracer              ][jit[get_reward]     ] Elapsed: 0.15605401992797852[0m


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1500/1500 [00:02<00:00, 748.20it/s]


rendering
