In [1]:
import sys
add_paths = ["/home/r2ci/rex", "/home/r2ci/brax", "/home/r2ci/trajax"]
for p in add_paths:
    if p not in sys.path:
        sys.path.append(p)
print(sys.path)

['/home/r2ci/rex/notebooks', '', '/home/r2ci/catkin_ws/devel/lib/python3/dist-packages', '/home/r2ci/interbotix_ws/devel/lib/python3/dist-packages', '/opt/ros/noetic/lib/python3/dist-packages', '/usr/lib/python38.zip', '/usr/lib/python3.8', '/usr/lib/python3.8/lib-dynload', '/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.8/lib/python3.8/site-packages', '/home/r2ci/rex', '/home/r2ci/brax', '/home/r2ci/trajax']


In [2]:
from IPython.display import HTML, clear_output

import jax
import jax.numpy as jnp
from math import ceil
import os
import tqdm

import brax
from brax.io import html
from brax.io import mjcf
from brax.generalized import pipeline
# from brax.positional import pipeline
# from brax.spring import pipeline
from brax import base, math
from flax import struct

from rex.utils import timer

In [24]:
print("loading system")
m = mjcf.load('/home/r2ci/rex/envs/vx300s/assets/vx300s_pos.xml')

# Determine collision pairs
from brax.geometry.contact import _geom_pairs
for (geom_i, geom_j) in _geom_pairs(m):
    # print(geom_i.link_idx, geom_j.link_idx)
    name_i = m.link_names[geom_i.link_idx[0]] if geom_i.link_idx is not None else "world"
    name_j = m.link_names[geom_j.link_idx[0]] if geom_j.link_idx is not None else "world"
    print(f"collision pair: {name_i} --> {name_j}")

# Actuators
print(f"actuator size: {m.act_size()}")

# Overwrite xml values
m = m.replace(dt=0.015)

# Get sampling time (0.8s horizon needed)
total_time = 5
cem_dt = 0.3
substeps = ceil(cem_dt / m.dt)
timesteps = ceil(total_time / cem_dt)
assert cem_dt > m.dt
dt = cem_dt / substeps
m = m.replace(dt=dt)
horizon = 4
print(f"cem_dt: {cem_dt}, brax_dt: {dt}, substeps: {substeps}, horizon_steps: {horizon}, horizon_t: {horizon * cem_dt}, t_final: {timesteps*cem_dt}")

control_high = 0.2 * 3.1416 * jnp.ones(m.act_size())
control_low = -control_high

loading system
collision pair: box --> world
collision pair: vx300s/gripper_link --> world
collision pair: vx300s/gripper_link --> world
collision pair: box --> vx300s/gripper_link
collision pair: box --> vx300s/gripper_link
actuator size: 6
cem_dt: 0.3, brax_dt: 0.015, substeps: 20, horizon_steps: 4, horizon_t: 1.2, t_final: 5.1


In [11]:
cem_hyperparams = {
      'sampling_smoothing': 0.,
      'evolution_smoothing': 0.1,
      'elite_portion': 0.1,
      'max_iter': 3,
      'num_samples': 100
}

# Get indices
ee_arm_idx = m.link_names.index("ee_link")
box_idx = m.link_names.index("box")
goal_idx = m.link_names.index("goal")


def format_info(info):
    formatted_items = []
    for key, value in info.items():
        try:
            value_list = value.tolist()
            if isinstance(value_list, list):
                value_list= [round(v, 2) for v in value_list]
            else:
                value_list = round(value_list, 2)
            formatted_items.append(f"{key}: {value_list}")
        except AttributeError:
            formatted_items.append(f"{key}: {round(value, 2)}")

    formatted_string = ' | '.join(formatted_items)
    return formatted_string


def save(path, json_rollout):
    """Saves trajectory as an HTML text file."""
    from etils import epath
    path = epath.Path(path)
    if not path.parent.exists():
        path.parent.mkdir(parents=True)
    path.write_text(json_rollout)


@struct.dataclass
class State:
    pipeline_state: base.State
    q_des: jnp.ndarray


@struct.dataclass
class Params:
    substeps: jnp.ndarray
    sys: base.System
    

def _cost(params, state, action):
    pipeline_state = state.pipeline_state
    x_i = pipeline_state.x.vmap().do(
        base.Transform.create(pos=params.sys.link.inertia.transform.pos)
    )
    boxpos = x_i.pos[box_idx]
    eepos = x_i.pos[ee_arm_idx]
    goalpos = x_i.pos[goal_idx][:2]

    rot_mat = math.quat_to_3x3(x_i.rot[ee_arm_idx])
    ee_to_goal = goalpos - eepos[:2]
    box_to_goal = goalpos - boxpos[:2]

    # dot product of ee_yaxis with ee_to_goal
    norm_ee_to_goal = ee_to_goal / math.safe_norm(ee_to_goal)
    cost_orn = jnp.abs(jnp.dot(rot_mat[:2, 1], norm_ee_to_goal))

    norm_box_to_goal = box_to_goal / math.safe_norm(box_to_goal)
    target_ee_xaxis = jnp.concatenate([norm_box_to_goal, jnp.array([-5.0])])
    norm_target_ee_xaxis = target_ee_xaxis / math.safe_norm(target_ee_xaxis)
    cost_down = (1-jnp.dot(rot_mat[:3, 0], norm_target_ee_xaxis))
    # cost_down = 0.5 * jnp.abs(rot_mat[2, 0]+1)

    # Radius cost
    box_dist = math.safe_norm(box_to_goal)
    ee_dist = math.safe_norm(ee_to_goal)
    cost_radius = jnp.where(ee_dist <= (box_dist+0.06), 15.0, 0.0)

    # cost_z = 1.0 * jnp.abs(eepos[2] - 0.09)
    cost_z = jnp.abs(eepos[2] - 0.075)
    cost_near = math.safe_norm((boxpos - eepos)[:2])
    cost_dist = math.safe_norm(boxpos[:2] - goalpos)
    cost_ctrl = math.safe_norm(action)

    cm = cost_dist

    # Weight all costs
    alpha = 1 / (1 + 2.0 * jnp.abs(cost_down + cost_orn))
    cost_orn = 3.0 * cost_orn
    cost_down = 3.0 * cost_down
    cost_radius = 1.0 * cost_radius
    cost_z = 1.0 * cost_z * alpha
    cost_near = 2.0 * cost_near * alpha
    cost_dist = 20.0 * cost_dist * alpha
    cost_ctrl = 0.1 * cost_ctrl

    total_cost = cost_ctrl + cost_z +  cost_near + cost_dist + cost_radius + cost_orn + cost_down
    info = {"cm": cm, "cost": total_cost, "cost_orn": cost_orn, "cost_down": cost_down, "cost_radius": cost_radius, "cost_z": cost_z, "cost_near": cost_near, "cost_dist": cost_dist, "cost_ctrl": cost_ctrl, "alpha": alpha}
    # return cost_z + cost_dist + cost_near + cost_ctrl + cost_radius + cost_orn + cost_down
    return total_cost, info


def cost(params, state, action, time_step: int):
    total_cost, info = _cost(params, state, action)
    return total_cost


# def dynamics(params: brax.System, state: State, action, time_step):

#     def loop_body(_, args):
#         state, = args
#         q_des = state.q_des + action * params.dt  # todo: clip to max angles?
#         pipeline_state = pipeline.step(params, state.pipeline_state, q_des)
#         return State(pipeline_state=pipeline_state, q_des=q_des),

#     state, = jax.lax.fori_loop(0, substeps, loop_body, (state,))  # TODO: substep changes...

#     return state

def dynamics(params: Params, state: State, action, time_step):

    def loop_cond(args):
        i, _ = args
        return i < params.substeps
        
    
    def loop_body(args):
        i, state = args
        q_des = state.q_des + action * params.sys.dt  # todo: clip to max angles?
        pipeline_state = pipeline.step(params.sys, state.pipeline_state, q_des)
        return i+1, State(pipeline_state=pipeline_state, q_des=q_des)

    i, state = jax.lax.while_loop(loop_cond, loop_body, (0, state))

    return state


from trajax.optimizers import cem, random_shooting
from functools import partial
jit_cem = jax.jit(partial(cem, cost, dynamics, hyperparams=cem_hyperparams))
jit_cost = jax.jit(_cost)

In [12]:
# Jit env functions
jit_env_reset = jax.jit(pipeline.init)
jit_env_step = jax.jit(pipeline.step)
rng = jax.random.PRNGKey(seed=1)
with timer("jit[reset]", log_level=100):
    boxpos_home = jnp.array([0.35, 0.0, 0.051])
    goalpos = boxpos_home[:2] + jnp.array([-0.1, 0.45])
    qpos = m.init_q.at[9].set(jnp.pi/2)
    qpos = qpos.at[0:5].set(jnp.concatenate([boxpos_home, goalpos]))
    pipeline_state = jit_env_reset(m, qpos, jnp.zeros(m.qd_size()))

[97m[15022][MainThread               ][tracer              ][jit[reset]          ] Elapsed: 0.00530242919921875[0m


In [13]:
init_state = State(pipeline_state=pipeline_state, q_des=pipeline_state.q[m.actuator.q_id])
params = Params(sys=m, substeps=substeps)
init_controls = jnp.zeros((horizon, m.act_size()))
with timer("jit[cem]", log_level=100):
    X, mean, obj = jit_cem(params, init_state, init_controls, control_low, control_high, jax.random.PRNGKey(0))
with timer("eval[cem]", log_level=100):
    _ = jit_cem(params, init_state, init_controls, control_low, control_high, jax.random.PRNGKey(0))

[97m[15022][MainThread               ][tracer              ][jit[cem]            ] Elapsed: 50.53440618515015[0m
[97m[15022][MainThread               ][tracer              ][eval[cem]           ] Elapsed: 2.300173759460449[0m


In [25]:
# Rollouts (CEM)
rollout = []
pipeline_state = jit_env_reset(m, qpos, jnp.zeros(m.qd_size()))
params = Params(sys=m, substeps=substeps)
next_q_des = pipeline_state.q[m.actuator.q_id]
_, next_controls, _ = jit_cem(params, init_state, init_controls, control_low, control_high, jax.random.PRNGKey(0))
pbar = tqdm.tqdm(range(timesteps), desc=f"Episode")
for i in pbar:
    next_state = State(pipeline_state=pipeline_state, q_des=next_q_des)
    X, mean, obj = jit_cem(params, next_state, next_controls, control_low, control_high, jax.random.PRNGKey(0))
    total_cost, info = jit_cost(params, next_state, mean[0])
    pbar.set_postfix_str(format_info(info))
    # print("")  # Uncomment to print history.
    for j in range(substeps):
        next_q_des = next_q_des + mean[0] * m.dt  # todo: clip to max angles?
        pipeline_state = jit_env_step(m, pipeline_state, next_q_des)
        rollout.append(pipeline_state)
    next_controls = jnp.vstack((mean[1:], jnp.zeros((1, m.act_size()))))

Episode: 100%|██████████████████████████████████████████████████| 17/17 [00:27<00:00,  1.61s/it, alpha: 0.95 | cm: 0.01 | cost: 0.63 | cost_ctrl: 0.03 | cost_dist: 0.27 | cost_down: 0.07 | cost_near: 0.21 | cost_orn: 0.0 | cost_radius: 0.0 | cost_z: 0.04]


In [18]:
rollout_json = html.render(m, rollout)
save("./vx300s_render.html", rollout_json)
HTML(rollout_json)