In [1]:
import sys
# add_paths = ["/home/r2ci/rex", "/home/r2ci/brax", "/home/r2ci/trajax"]
add_paths = ["/home/r2ci/rex", "/home/r2ci/trajax"]
for p in add_paths:
    if p not in sys.path:
        sys.path.append(p)
print(sys.path)

['/home/r2ci/rex/notebooks', '', '/home/r2ci/catkin_ws/devel/lib/python3/dist-packages', '/home/r2ci/interbotix_ws/devel/lib/python3/dist-packages', '/opt/ros/noetic/lib/python3/dist-packages', '/usr/lib/python38.zip', '/usr/lib/python3.8', '/usr/lib/python3.8/lib-dynload', '/home/r2ci/.cache/pypoetry/virtualenvs/rex-lib-fAzIlxw_-py3.8/lib/python3.8/site-packages', '/home/r2ci/rex', '/home/r2ci/trajax']


In [2]:
from IPython.display import HTML, clear_output

import jax
import jax.numpy as jnp
from math import ceil
import os
import tqdm

import brax
from brax.io import html
from brax.io import mjcf
from brax.generalized import pipeline as gen_pipeline
from brax.positional import pipeline as pos_pipeline
from brax.spring import pipeline as spr_pipeline
from brax import base, math
from flax import struct

from rex.utils import timer

In [3]:
print("loading system")
m = mjcf.load('/home/r2ci/rex/envs/vx300s/assets/vx300s_cem_brax.xml')
print(f"degrees of freedom: {m.qd_size()}")

# Determine collision pairs
print("\nCOLLISIONS")
from brax.geometry.contact import _geom_pairs
for (geom_i, geom_j) in _geom_pairs(m):
    # print(geom_i.link_idx, geom_j.link_idx)
    name_i = m.link_names[geom_i.link_idx[0]] if geom_i.link_idx is not None else "world"
    name_j = m.link_names[geom_j.link_idx[0]] if geom_j.link_idx is not None else "world"
    print(f"collision pair: {name_i} --> {name_j}")

# Actuators
print("\nACTUATOR SIZE")
print(f"actuator size: {m.act_size()}")
q_id = m.actuator.q_id[:6]

# DOFS
print("\nDEGREES OF FREEDOM SIZE")
print(f"degrees of freedom: {m.qd_size()}")

# Select pipeline
pipeline = gen_pipeline

# Overwrite xml values
m = m.replace(dt=0.015, solver_maxls=10)  # generalized 
# m = m.replace(dt=0.01, 
#               baumgarte_erp=0.1,  # default=0.1
#               spring_inertia_scale=1.0, # default=0.0
#               spring_mass_scale=0.0, # default=0.0
#               vel_damping=0.0,  # default=0.0
#               ang_damping=-0.05,  # default=0.0
#               joint_scale_pos=0.5,  # default=0.5
#               joint_scale_ang=0.1,  # default=0.2
#              )

# Get sampling time (0.8s horizon needed)
total_time = 5
cem_dt = 0.15
substeps = ceil(cem_dt / m.dt)
timesteps = ceil(total_time / cem_dt)
assert cem_dt > m.dt
dt = cem_dt / substeps
m = m.replace(dt=dt)
horizon = 4
print(f"\nTIME")
print(f"cem_dt: {cem_dt}, brax_dt: {dt}, substeps: {substeps}, horizon_steps: {horizon}, horizon_t: {horizon * cem_dt}, t_final: {timesteps*cem_dt}")

control_high = 0.35 * 3.1416 * jnp.ones(q_id.shape[0])
control_low = -control_high

# Print relevant parameters
parameters_dict = {
    pos_pipeline: [
        'dt',
        'joint_scale_pos',
        'joint_scale_ang',
        'collide_scale',
        'ang_damping',  # shared with `brax.physics.spring`
        'vel_damping',  # shared with `brax.physics.spring`
        'baumgarte_erp',  # shared with `brax.physics.spring`
        'spring_mass_scale',  # shared with `brax.physics.spring`
        'spring_inertia_scale',  # shared with `brax.physics.spring`
        'constraint_ang_damping',  # shared with `brax.physics.spring`
        'elasticity',  # shared with `brax.physics.spring`
    ],
    spr_pipeline: [
        'dt',
        'constraint_stiffness',
        'constraint_limit_stiffness',
        'constraint_vel_damping',
        'ang_damping',  # shared with `brax.physics.positional`
        'vel_damping',  # shared with `brax.physics.positional`
        'baumgarte_erp',  # shared with `brax.physics.positional`
        'spring_mass_scale',  # shared with `brax.physics.positional`
        'spring_inertia_scale',  # shared with `brax.physics.positional`
        'constraint_ang_damping',  # shared with `brax.physics.positional`
        'elasticity',  # shared with `brax.physics.positional`
    ],
    gen_pipeline: [
        'dt',
        'matrix_inv_iterations',
        'solver_iterations',
        'solver_maxls',
    ]
    # The 'convex' parameter is not included due to its unknown usage.
}

print(f"\nPARAMETERS: {pipeline.__name__}")
for p in parameters_dict[pipeline]:
    try:
        print(f"{p}: {m.__getattribute__(p)}")
        continue
    except AttributeError:
        pass
    try:
        print(f"{p}: {m.link.__getattribute__(p)}")
        continue
    except AttributeError:
        pass
    try:
        print(f"{p}: {m.geoms[0].__getattribute__(p)}")
        continue
    except AttributeError:
        pass

loading system
degrees of freedom: 13

COLLISIONS
collision pair: box --> world
collision pair: vx300s/gripper_link --> world
collision pair: box --> vx300s/gripper_link

ACTUATOR SIZE
actuator size: 6

DEGREES OF FREEDOM SIZE
degrees of freedom: 13

TIME
cem_dt: 0.15, brax_dt: 0.015, substeps: 10, horizon_steps: 4, horizon_t: 0.6, t_final: 5.1

PARAMETERS: brax.generalized.pipeline
dt: 0.015
matrix_inv_iterations: 10
solver_iterations: 10
solver_maxls: 10


In [4]:
cem_hyperparams = {
      'sampling_smoothing': 0.,
      'evolution_smoothing': 0.1,
      'elite_portion': 0.1,
      'max_iter': 4,
      'num_samples': 150
}

# Get indices
ee_arm_idx = m.link_names.index("ee_link")
box_idx = m.link_names.index("box")
goal_idx = m.link_names.index("goal")


def format_info(info):
    formatted_items = []
    for key, value in info.items():
        try:
            value_list = value.tolist()
            if isinstance(value_list, list):
                value_list= [round(v, 2) for v in value_list]
            else:
                value_list = round(value_list, 2)
            formatted_items.append(f"{key}: {value_list}")
        except AttributeError:
            formatted_items.append(f"{key}: {round(value, 2)}")

    formatted_string = ' | '.join(formatted_items)
    return formatted_string


def save(path, json_rollout):
    """Saves trajectory as an HTML text file."""
    from etils import epath
    path = epath.Path(path)
    if not path.parent.exists():
        path.parent.mkdir(parents=True)
    path.write_text(json_rollout)


@struct.dataclass
class State:
    pipeline_state: base.State
    q_des: jnp.ndarray


@struct.dataclass
class Params:
    substeps: jnp.ndarray
    sys: base.System
    

def _cost(params, state, action):
    pipeline_state = state.pipeline_state
    sys = params.sys

    # Get indices
    ee_arm_idx = sys.link_names.index("ee_link")
    box_idx = sys.link_names.index("box")
    goal_idx = sys.link_names.index("goal")

    x_i = pipeline_state.x.vmap().do(
        base.Transform.create(pos=sys.link.inertia.transform.pos)
    )
    boxpos = x_i.pos[box_idx]
    eepos = x_i.pos[ee_arm_idx]
    goalpos = x_i.pos[goal_idx][:2]

    rot_mat = math.quat_to_3x3(x_i.rot[ee_arm_idx])
    ee_to_goal = goalpos - eepos[:2]
    box_to_goal = goalpos - boxpos[:2]

    # if dot(ee_yaxis (in global), ee_to_goal (in global))==0 --> ee_yaxis = perpedicular to box
    # ee_yaxis is parellel to gripper_bar axis
    norm_ee_to_goal = ee_to_goal / math.safe_norm(ee_to_goal)
    cost_orn = jnp.abs(jnp.dot(rot_mat[:2, 1], norm_ee_to_goal))

    # ee_xaxis points in -z if ee is oriented downward
    norm_box_to_goal = box_to_goal / math.safe_norm(box_to_goal)
    target_ee_xaxis = jnp.concatenate([norm_box_to_goal, jnp.array([-50.0])])  # making this more negative forces the ee to be pointing downward
    norm_target_ee_xaxis = target_ee_xaxis / math.safe_norm(target_ee_xaxis)
    cost_down = jnp.abs(rot_mat[:2, 0]).sum()
    # todo: not sure if this cost is correct for 180 degree rotations around 
    # cost_down = (1-jnp.dot(rot_mat[:3, 0], norm_target_ee_xaxis))  # Here, the dot is 1 if ee_xaxis == target_ee_axis
    # todo: Not sure
    # cost_down = 0.5 * jnp.abs(rot_mat[2, 0]+1)

    # Distances in xy-plane
    box_to_ee = (eepos - boxpos)[:2]
    box_to_goal = (goalpos - boxpos[:2])
    dist_box_to_ee = math.safe_norm(box_to_ee)
    dist_box_to_goal = math.safe_norm(box_to_goal)
    cost_align = (jnp.sum(box_to_ee * box_to_goal) / (dist_box_to_ee*dist_box_to_goal) + 1)

    # Radius cost
    box_dist = math.safe_norm(box_to_goal)
    ee_dist = math.safe_norm(ee_to_goal)
    cost_radius = jnp.where(ee_dist <= (box_dist+0.06), 15.0, 0.0)

    # Force cost
    cost_force = (jnp.abs(pipeline_state.qf_constraint).max() - 0.46) ** 2
    # cost_force = (pipeline_state.qf_constraint[2]-0.46) ** 2

    # cost_z = 1.0 * jnp.abs(eepos[2] - 0.09)
    # cost_z = jnp.abs(eepos[2] - 0.075)
    cost_z = jnp.abs(eepos[2]-0.05) 
    cost_near = math.safe_norm((boxpos - eepos)[:2])
    cost_dist = math.safe_norm(boxpos[:2] - goalpos)
    cost_ctrl = math.safe_norm(action)

    cm = cost_dist

    # Weight all costs
    alpha = 1 #/ (1 + 2.0 * jnp.abs(cost_down + cost_orn))
    cost_orn = 0.0 * cost_orn
    cost_down = 3.0 * cost_down
    cost_radius = 0.0 * cost_radius
    cost_force = 10.0 * cost_force
    cost_align = 2.0 * cost_align
    cost_z = 1.0 * cost_z * alpha
    cost_near = 10.0 * cost_near * alpha
    cost_dist = 50.0 * cost_dist * alpha
    cost_ctrl = 0.1 * cost_ctrl

    total_cost = cost_ctrl + cost_z + cost_near + cost_dist + cost_radius + cost_orn + cost_down + cost_align + cost_force
    # total_cost = cost_down + cost_orn + cost_z + cost_near
    info = {"cm": cm*100, "cost": total_cost, "cost_orn": cost_orn, "cost_force": cost_force, "cost_down": cost_down, "cost_radius": cost_radius, "cost_align": cost_align, "cost_z": cost_z, "cost_near": cost_near, "cost_dist": cost_dist, "cost_ctrl": cost_ctrl, "alpha": alpha}
    # return cost_z + cost_dist + cost_near + cost_ctrl + cost_radius + cost_orn + cost_down
    return total_cost, info

def cost(params, state, action, time_step: int):
    total_cost, info = _cost(params, state, action)
    return total_cost


def get_action(q_des):
    if m.act_size() == 12:
        action = jnp.concatenate([q_des, 0*jnp.ones(q_des.shape)])
    else:
        action = q_des
    return action
    

def dynamics(params: Params, state: State, action, time_step):

    def loop_cond(args):
        i, _ = args
        return i < params.substeps
        
    
    def loop_body(args):
        i, state = args
        q_des = state.q_des + action * params.sys.dt  # todo: clip to max angles?
        q_qd_des = get_action(q_des)
        pipeline_state = pipeline.step(params.sys, state.pipeline_state, q_qd_des)
        return i+1, State(pipeline_state=pipeline_state, q_des=q_des)

    i, state = jax.lax.while_loop(loop_cond, loop_body, (0, state))

    return state


def env_reset(_m, boxpos, boxyaw, goalpos, jpos):
    qpos_box_goal = jnp.concatenate([boxpos_home, boxyaw_home, goalpos])
    ndof_box_goal = qpos_box_goal.shape[0]
    qpos = _m.init_q.at[0:ndof_box_goal].set(qpos_box_goal)
    qpos = qpos.at[ndof_box_goal:-1].set(jpos)
    pipeline_state = pipeline.init(_m, qpos, jnp.zeros(m.qd_size()))
    return pipeline_state



    


from trajax.optimizers import cem, random_shooting
from functools import partial
jit_cem = jax.jit(partial(cem, cost, dynamics, hyperparams=cem_hyperparams))
jit_cost = jax.jit(_cost)

In [5]:
# Initialize (NO CONTROL)
jit_env_reset = jax.jit(env_reset)
jit_env_step = jax.jit(pipeline.step)
rng = jax.random.PRNGKey(seed=1)
with timer("jit[reset]", log_level=100):
    boxpos_home = jnp.array([0.35, 0.1, 0.051])
    boxyaw_home = jnp.array([jnp.pi/2])
    goalpos = boxpos_home[0:2] + jnp.array([-0.1, 0.45])
    jpos_home = jnp.array([0, 0, 0, -jnp.pi/4, jnp.pi/2, 0])
    pipeline_state = jit_env_reset(m, boxpos_home, boxyaw_home, goalpos, jpos_home)
with timer("eval[reset]", log_level=100):
    _ = jit_env_reset(m, boxpos_home, boxyaw_home, goalpos, jpos_home)
with timer("jit[step]", log_level=100):
    _ = jit_env_step(m, pipeline_state, 10 * jnp.sin(1 / 100) * jnp.ones(m.act_size()))
with timer("eval[step]", log_level=100):
    _ = jit_env_step(m, pipeline_state, 10 * jnp.sin(1 / 100) * jnp.ones(m.act_size()))

[97m[35696][MainThread               ][tracer              ][jit[reset]          ] Elapsed: 9.05883002281189[0m
[97m[35696][MainThread               ][tracer              ][eval[reset]         ] Elapsed: 0.001154184341430664[0m
[97m[35696][MainThread               ][tracer              ][jit[step]           ] Elapsed: 9.416565895080566[0m
[97m[35696][MainThread               ][tracer              ][eval[step]          ] Elapsed: 0.0024149417877197266[0m


In [32]:
# Get indices
# ee_arm_idx = m.link_names.index("ee_link")
# box_idx = m.link_names.index("box")
# goal_idx = m.link_names.index("goal")

# x_i = pipeline_state.x.vmap().do(
#     base.Transform.create(pos=m.link.inertia.transform.pos)
# )
# boxpos = x_i.pos[box_idx]
# eepos = x_i.pos[ee_arm_idx]
# goalpos = x_i.pos[goal_idx][:2]

# rot_mat = math.quat_to_3x3(x_i.rot[ee_arm_idx])
# total_cost, info = jit_cost(params, next_state, next_q_des)
# print(format_info(info))

12

In [48]:
# Rollouts (NO CONTROL)
rollout = []
pipeline_state = jit_env_reset(m, boxpos_home, boxyaw_home, goalpos, jpos_home)
params = Params(sys=m, substeps=substeps)
next_q_des = pipeline_state.q[m.actuator.q_id]
pbar = tqdm.tqdm(range(timesteps), desc=f"Episode")
for i in pbar:
    next_state = State(pipeline_state=pipeline_state, q_des=next_q_des)
    total_cost, info = jit_cost(params, next_state, next_q_des)
    pbar.set_postfix_str(format_info(info))
    # print("")  # Uncomment to print history.
    for j in range(substeps):
        mean = jnp.zeros((1, m.act_size()))
        mean = mean.at[0, 0].set(0.0)
        # next_q_des = next_q_des + mean[0] * m.dt  # todo: clip to max angles?
        pipeline_state = jit_env_step(m, pipeline_state, next_q_des)
        rollout.append(pipeline_state)

Episode: 100%|███████████████████████████████████████████████████████████████████| 34/34 [00:02<00:00, 15.88it/s, alpha: 1 | cm: 46.1 | cost: 29.34 | cost_align: 0.08 | cost_ctrl: 0.25 | cost_dist: 23.05 | cost_down: 3.0 | cost_force: 0.4 | cost_near: 2.37 | cost_orn: 0.0 | cost_radius: 0.0 | cost_z: 0.2]


In [61]:
rollout_json = html.render(m, rollout)
# save("./vx300s_render.html", rollout_json)
HTML(rollout_json)

In [6]:
# Initialize (CEM)
init_state = State(pipeline_state=pipeline_state, q_des=pipeline_state.q[q_id])
params = Params(sys=m, substeps=substeps)
init_controls = jnp.zeros((horizon, q_id.shape[0]))
with timer("jit[cem]", log_level=100):
    X, mean, obj = jit_cem(params, init_state, init_controls, control_low, control_high, jax.random.PRNGKey(0))
with timer("eval[cem]", log_level=100):
    _ = jit_cem(params, init_state, init_controls, control_low, control_high, jax.random.PRNGKey(0))

[97m[35696][MainThread               ][tracer              ][jit[cem]            ] Elapsed: 27.422181844711304[0m
[97m[35696][MainThread               ][tracer              ][eval[cem]           ] Elapsed: 0.5396618843078613[0m


In [7]:
# Rollouts (CEM)
rollout = []
pipeline_state = jit_env_reset(m, boxpos_home, boxyaw_home, goalpos, jpos_home)
params = Params(sys=m, substeps=substeps)
next_q_des = pipeline_state.q[q_id]
init_controls = jnp.zeros((horizon, q_id.shape[0]))
_, next_controls, _ = jit_cem(params, init_state, init_controls, control_low, control_high, jax.random.PRNGKey(0))
pbar = tqdm.tqdm(range(timesteps), desc=f"Episode")
for i in pbar:
    next_state = State(pipeline_state=pipeline_state, q_des=next_q_des)
    X, mean, obj = jit_cem(params, next_state, next_controls, control_low, control_high, jax.random.PRNGKey(i))
    total_cost, info = jit_cost(params, next_state, mean[0])
    pbar.set_postfix_str(format_info(info))
    # print("")  # Uncomment to print history.
    for j in range(substeps):
        next_q_des = next_q_des + mean[0] * m.dt  # todo: clip to max angles?
        q_qd_des = get_action(next_q_des)
        pipeline_state = jit_env_step(m, pipeline_state, q_qd_des)
        rollout.append(pipeline_state)
    next_controls = jnp.vstack((mean[1:], jnp.zeros((1, q_id.shape[0]))))

Episode: 100%|████████████████████████████████████████████████████████████████████| 34/34 [00:21<00:00,  1.61it/s, alpha: 1 | cm: 2.6 | cost: 3.64 | cost_align: 0.35 | cost_ctrl: 0.07 | cost_dist: 1.3 | cost_down: 0.13 | cost_force: 0.44 | cost_near: 1.33 | cost_orn: 0.0 | cost_radius: 0.0 | cost_z: 0.03]


In [32]:
# INVESTIGATE TIME SPEND IN ROLLOUTS vs NUM SAMPLES
NUM_SAMPLES = 400


import numpy as np
from trajax.optimizers import cem, random_shooting
from functools import partial
# jit_cem = jax.jit(partial(cem, cost, dynamics, hyperparams=cem_hyperparams))
# jit_cost = jax.jit(_cost)

# Params
params = Params(sys=m, substeps=substeps)

# 
# tmp_cost = partial(cost, params)

def _rollout(U, x0, *args):
    def dynamics_for_scan(x, ut):
        u, t = ut
        x_next = dynamics(params, x, u, t, *args)
        c_next = cost(params, x_next, u, t)
        return x_next, c_next

    all_c = jax.lax.scan(f=dynamics_for_scan, init=x0, xs=(U, np.arange(U.shape[0])))[1]
    return all_c

jit_vmap_rollout = jax.jit(jax.vmap(_rollout, in_axes=(0, None)))


init_state = State(pipeline_state=pipeline_state, q_des=pipeline_state.q[m.actuator.q_id])
init_controls = jnp.zeros((NUM_SAMPLES, horizon, m.act_size()))

with timer("jit[rollout]", log_level=100):
    _ = jit_vmap_rollout(init_controls, init_state)
for _ in range(5):
    with timer("eval[rollout]", log_level=100):
        _ = jit_vmap_rollout(init_controls, init_state)

[97m[20771][MainThread               ][tracer              ][jit[rollout]        ] Elapsed: 14.05599308013916[0m
[97m[20771][MainThread               ][tracer              ][eval[rollout]       ] Elapsed: 0.1441516876220703[0m
[97m[20771][MainThread               ][tracer              ][eval[rollout]       ] Elapsed: 0.1267073154449463[0m
[97m[20771][MainThread               ][tracer              ][eval[rollout]       ] Elapsed: 0.1290123462677002[0m
[97m[20771][MainThread               ][tracer              ][eval[rollout]       ] Elapsed: 0.12676024436950684[0m
[97m[20771][MainThread               ][tracer              ][eval[rollout]       ] Elapsed: 0.12762975692749023[0m


In [33]:
for _ in range(5):
    with timer("eval[rollout]", log_level=100):
        _ = jit_vmap_rollout(init_controls, init_state)

[97m[20771][MainThread               ][tracer              ][eval[rollout]       ] Elapsed: 0.12561988830566406[0m
[97m[20771][MainThread               ][tracer              ][eval[rollout]       ] Elapsed: 0.1270732879638672[0m
[97m[20771][MainThread               ][tracer              ][eval[rollout]       ] Elapsed: 0.12749695777893066[0m
[97m[20771][MainThread               ][tracer              ][eval[rollout]       ] Elapsed: 0.12566137313842773[0m
[97m[20771][MainThread               ][tracer              ][eval[rollout]       ] Elapsed: 0.1297166347503662[0m
