In [2]:
import meshcat
import meshcat.geometry as g
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3

from sdf_world.sdf_world import *
from sdf_world.robots import *

In [1]:
from torch.utils.tensorboard import SummaryWriter

In [3]:
from network import *
from loss import *
from train import *

In [4]:
world = SDFWorld()
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7003/static/


concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts


In [4]:
world.show_in_jupyter()

In [5]:
panda = Robot(world.vis, "panda", panda_model)

In [6]:
def batch_fk(qs):
    gripper_joints = jnp.full((qs.shape[0], 2), 0.04)
    qs = jnp.hstack([qs, gripper_joints])
    return jax.vmap(panda_model.fk_fn)(qs)

In [7]:
from jax import random
hp = Hyperparam()
hp.dims = [10, 10, 10, 1]
hp.lr = 0.001
hp.batch_size = 128

model = get_mlp(hp)
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (hp.dims[0],))
params = model.init(key2, x)
tx = optax.adam(learning_rate=hp.lr)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [11]:
def get_ee_and_elbowxyz(q):
    gripper_joints = jnp.full((2), 0.04)
    qnew = jnp.hstack([q, gripper_joints])
    posevecs = panda_model.fk_fn(qnew)
    elbow_xyz = posevecs[4][-3:]
    ee_pose = posevecs[-1]
    return jnp.hstack([ee_pose, elbow_xyz])
@jax.jit
def get_batch_samples(qs):
    x_succ = jax.vmap(get_ee_and_elbowxyz)(qs)
    random_xyz = np.random.uniform([-1,-1,0.5],[1,1,1.5],size=(qs.shape[0],3))
    x_fail = x_succ.at[:,-3:].set(random_xyz)
    return x_succ, x_fail

def loss(state, params, x_succ, x_fail):
    logits_succ = state.apply_fn(params, x_succ).flatten()
    loss_succ = optax.sigmoid_binary_cross_entropy(
        logits_succ, jnp.ones_like(logits_succ)).mean()
    logits_fail = state.apply_fn(params, x_fail).flatten()
    loss_fail = optax.sigmoid_binary_cross_entropy(
        logits_fail, jnp.zeros_like(logits_fail)).mean()
    return loss_succ + loss_fail
grad_fn = jax.value_and_grad(loss, argnums=1)
@jax.jit
def training_step(state, x_succ, x_fail):
    loss, grads = grad_fn(state, state.params, x_succ, x_fail)
    state = state.apply_gradients(grads=grads)
    return state, loss

In [12]:
for epoch in range(1000):
    qs = np.random.uniform(panda_model.lb[:7], panda_model.ub[:7], size=(10000,7))
    x_succ, x_fail = get_batch_samples(qs)
    state, loss = training_step(state, x_succ, x_fail)
    print(f"loss:{loss}")

loss:0.49723729491233826
loss:0.4928097426891327
loss:0.4915361702442169
loss:0.48887109756469727
loss:0.4832706153392792
loss:0.4755799472332001
loss:0.4750351905822754
loss:0.4740275740623474
loss:0.4671279788017273
loss:0.46403226256370544
loss:0.4613837003707886
loss:0.4561329185962677
loss:0.4579515755176544
loss:0.449409157037735
loss:0.44644567370414734
loss:0.44546815752983093
loss:0.44219475984573364
loss:0.43885916471481323
loss:0.4329032599925995
loss:0.43118128180503845
loss:0.4298625588417053
loss:0.42395535111427307
loss:0.4227527678012848
loss:0.4192536175251007
loss:0.4170992970466614
loss:0.4127121567726135
loss:0.4085482060909271
loss:0.40597057342529297
loss:0.4039827883243561
loss:0.39887911081314087
loss:0.396720290184021
loss:0.39366087317466736
loss:0.39136964082717896
loss:0.38900500535964966
loss:0.3871444761753082
loss:0.3815840184688568
loss:0.38070148229599
loss:0.38116952776908875
loss:0.3762173056602478
loss:0.3736759126186371
loss:0.3703068494796753
loss:

In [79]:
trained_param = state.params

In [86]:
qs = np.random.uniform(panda_model.lb[:7], panda_model.ub[:7], size=(10000,7))
x_succ, x_fail = get_batch_samples(qs)
failed = nn.sigmoid(model.apply(trained_param, x_succ)) < 0.5

In [101]:
grad_fn = jax.jacrev(model.apply, argnums=1)

In [191]:
x = x_fail[2]

In [213]:
x_grad = grad_fn(trained_param, x)[0]
x = x + x_grad*0.001
print(model.apply(trained_param, x))

frame.set_pose(SE3(x[:7]).normalize())
sphere.set_translate(x[7:])
print(x[7:])

[21.614344]
[-0.21727435 -0.09344015 -0.03501084]


In [27]:
panda.set_joint_angles(jnp.zeros(9))

In [228]:
frame = Frame(world.vis, "debug", length=0.15)
sphere1 = Sphere(world.vis, "shoulder", 0.1, "red", 0.5)
sphere2 = Sphere(world.vis, "elbow", 0.1, "red", 0.5)
#box = Sphere(world.vis, "elbow", 0.1, "red", 0.5)
sphere3 = Sphere(world.vis, "wrist", 0.1, "red", 0.5)

In [247]:
#panda.set_joint_angles(panda_model.neutral)
panda.set_joint_angles(qrand)

In [281]:
fks = panda_model.fk_fn(panda_model.neutral)
ee_pose = SE3(fks[-1])
shoulder = fks[2][-3:]
elbow = fks[4][-3:]
wrist = fks[7][-3:]
upper_arm_length = jnp.linalg.norm(elbow - shoulder)
wrist_wrt_ee = ee_pose.inverse().apply(wrist)

In [None]:
qrand = np.random.uniform(panda_model.lb, panda_model.ub)
fks = panda_model.fk_fn(qrand)
ee_pose = SE3(fks[-1])

In [468]:
def get_sample(q, xyz, noise):
    fks = panda_model.fk_fn(q)
    wrist = fks[7][-3:]
    elbow = fks[4][-3:]
    x_succ = jnp.hstack([wrist, elbow])
    x_fail1 = jnp.hstack([wrist, xyz])
    x_fail2 = jnp.hstack([wrist, elbow+noise])
    x_fail = jnp.vstack([x_fail1, x_fail2])
    return x_succ, x_fail
get_sample_batch = jax.jit(jax.vmap(get_sample, in_axes=(0,0,0)))

# qs = np.random.uniform(panda_model.lb, panda_model.ub, size=(100,9))
# xyz_random = np.random.uniform([-1,-1, -0.5], [1, 1, 1.5], size=(100,3))
# x_succ, x_fail = get_sample_batch(qs, xyz_random)

In [469]:
#train
def loss(state, params, x_succ, x_fail):
    logits_succ = state.apply_fn(params, x_succ).flatten()
    loss_succ = optax.sigmoid_binary_cross_entropy(
        logits_succ, jnp.ones_like(logits_succ)).mean()
    logits_fail = state.apply_fn(params, x_fail).flatten()
    loss_fail = optax.sigmoid_binary_cross_entropy(
        logits_fail, jnp.zeros_like(logits_fail)).mean()
    return loss_succ + loss_fail
grad_fn = jax.value_and_grad(loss, argnums=1)
@jax.jit
def training_step(state, x_succ, x_fail):
    loss, grads = grad_fn(state, state.params, x_succ, x_fail)
    state = state.apply_gradients(grads=grads)
    return state, loss

In [470]:
# init train
from jax import random
hp = Hyperparam()
hp.dims = [6, 32, 32, 1]
hp.lr = 0.001
hp.batch_size = 128

model = get_mlp(hp)
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (hp.dims[0],))
params = model.init(key2, x)
tx = optax.adam(learning_rate=hp.lr)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [487]:
# training loop
num_batch = 100
for epoch in range(1000):
    qs = np.random.uniform(panda_model.lb, panda_model.ub, size=(num_batch,9))
    xyz_random = np.random.uniform([-1,-1, -0.5], [1, 1, 1.5], size=(num_batch,3))
    noise = np.random.normal(size=(num_batch,3)) * 0.5
    x_succ, x_fail = get_sample_batch(qs, xyz_random, noise)
    state, loss = training_step(state, x_succ, x_fail)
    print(f"loss:{loss}")

loss:0.06978580355644226
loss:0.06231020390987396
loss:0.03874324634671211
loss:0.03303641453385353
loss:0.05097556486725807
loss:0.025167366489768028
loss:0.02388017624616623
loss:0.07452349364757538
loss:0.09798727184534073
loss:0.028322262689471245
loss:0.05155383422970772
loss:0.02694651298224926
loss:0.0403045117855072
loss:0.01377770584076643
loss:0.05661853775382042
loss:0.012205051258206367
loss:0.039094626903533936
loss:0.07685178518295288
loss:0.007445528171956539
loss:0.07751922309398651
loss:0.09573078155517578
loss:0.031236805021762848
loss:0.04464467614889145
loss:0.06775760650634766
loss:0.042640700936317444
loss:0.011574069038033485
loss:0.042792461812496185
loss:0.033325694501399994
loss:0.050440479069948196
loss:0.045409440994262695
loss:0.09014075249433517
loss:0.05087340250611305
loss:0.05347944423556328
loss:0.04404553025960922
loss:0.09530609101057053
loss:0.08137667179107666
loss:0.07459788024425507
loss:0.04642801359295845
loss:0.037995509803295135
loss:0.032501

In [490]:
trained_params = state.params

In [506]:
qs = np.random.uniform(panda_model.lb, panda_model.ub, size=(num_batch,9))
xyz_random = np.random.uniform([-1,-1, -0.5], [1, 1, 1.5], size=(num_batch,3))
noise = np.random.normal(size=(num_batch,3))*0.5
x_succ, x_fail = get_sample_batch(qs, xyz_random, noise)

In [646]:
#make problem
q = panda.get_random_config()
panda.set_joint_angles(q)
ee_pose = SE3(panda_model.fk_fn(panda.q)[-1])
frame.set_pose(ee_pose)

In [654]:
wrist = ee_pose.apply(wrist_wrt_ee)
sphere2.set_translate(wrist)

In [670]:
# optimize
@jax.jit
def objective(wrist, elbow):
    upper_arm_err = jnp.abs(safe_2norm(shoulder - elbow) - upper_arm_length)
    lower_arm_len = safe_2norm(elbow - wrist)
    lower_arm_bounded = jnp.clip(lower_arm_len,a_min=0.35, a_max=0.5)
    lower_arm_err = jnp.abs(lower_arm_bounded - lower_arm_len)
    elbow_inf = - model.apply(trained_params, x)[0]
    return upper_arm_err + lower_arm_err + elbow_inf*0.1

In [684]:
f_vg = jax.value_and_grad(objective, argnums=1)

In [700]:
save("wrist_elbow_feas", trained_state, hp, force=True)

NameError: name 'trained_state' is not defined

In [682]:
#initialize
%time
xyz = np.random.uniform([-1,-1,-0.5], [1,1,1.5], size=(100,3))
values = jax.vmap(objective, in_axes=(None,0))(wrist, xyz)
elbow = xyz[values.argmin()]
feasibility = nn.sigmoid(model.apply(trained_params, jnp.hstack([wrist, elbow]))[0])
sphere3.set_translate(elbow)
print(feasibility)

CPU times: user 2 µs, sys: 14 µs, total: 16 µs
Wall time: 32.9 µs
0.053368486


In [690]:
value, elbow_grad = f_vg(wrist, elbow)

In [699]:
elbow = elbow - elbow_grad * 0.01
feasibility = nn.sigmoid(model.apply(
    trained_params, jnp.hstack([wrist, elbow]))[0])
sphere3.set_translate(elbow)

print(feasibility)

0.9956827


0.13750944


In [680]:
elbow_true = panda_model.fk_fn(panda.q)[4][-3:]

In [681]:
nn.sigmoid(model.apply(trained_params, jnp.hstack([wrist, elbow_true]))[0])

Array(0.99658185, dtype=float32)

In [676]:
elbow

array([ 0.23603655, -0.24421876,  0.35114774])

In [552]:
x = jnp.vstack(x_fail)[0]

In [543]:
x = x_succ[0]

In [592]:
val_grad_fn = jax.value_and_grad(objective)

In [644]:
f, x_grad = val_grad_fn(x)
x = x - x_grad * 0.001

#visualize
wrist, elbow = x[:3], x[3:]
sphere2.set_translate(wrist)
sphere3.set_translate(elbow)
print(f"f:{f}")

f:-0.5479756593704224


In [391]:


#derived:
wrist = ee_pose.apply(wrist_wrt_ee)

# shoulder = fks[2][-3:]
# elbow = fks[4][-3:]
# wrist = fks[7][-3:]


sphere1.set_translate(shoulder)
sphere2.set_translate(wrist)
sphere3.set_translate(elbow)
frame.set_pose(ee_pose)
panda.set_joint_angles(qrand)

lower_arm_length = jnp.linalg.norm(wrist - elbow)
lower_arm_length

Array(0.62608165, dtype=float32)

Array(0.04690649, dtype=float32)

In [8]:
qs = jnp.zeros((100,7))
batch_fk(qs).shape

(100, 12, 7)

In [14]:
frame.set_pose(SE3(panda_model.fk_fn(panda_model.neutral)[-1]))

In [15]:
elbow_xyz = panda_model.fk_fn(panda_model.neutral)[4][-3:]
ee_pose = SE3(panda_model.fk_fn(panda_model.neutral)[-1])