In [1]:
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

from network import *
from loss import *

In [2]:
# Create the SDF world; it owns the visualizer handle (`world.vis`) that every
# scene object below is attached to. Then show the visualizer in the notebook.
world = SDFWorld()
world.show_in_jupyter()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7003/static/


In [3]:
# Load the Panda model from its URDF/package and register three scene objects
# with the visualizer: a half-transparent robot, a Frame named "frame" and a
# red half-transparent Sphere named "elbow".
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
frame = Frame(world.vis, "frame", 0.1)
elbow = Sphere(world.vis, "elbow", 0.1, "red", alpha=0.5)

concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts


In [4]:
# Forward kinematics at the neutral configuration. The last three components
# of each entry are used as that link's position.
fks = panda_model.fk_fn(panda_model.neutral)
p_shoulder = fks[1][-3:]  # entry 1 is used as the shoulder
p_elbow = fks[4][-3:]  # entry 4 is used as the elbow
upper_arm_len = jnp.linalg.norm(p_elbow-p_shoulder)
# Candidate elbow positions: 1000 sphere samples scaled by the upper-arm length
# and shifted to the shoulder (assumes fibonacci_sphere returns points on the
# unit sphere — TODO confirm).
sphere_points = fibonacci_sphere(1000)
#sphere_points_rtp = jax.vmap(to_spherical_coord)(sphere_points)
shoulder_sphere_points = sphere_points * upper_arm_len + p_shoulder
# Lower/upper xyz bounds used when sampling random end-effector positions.
ws_lb, ws_ub = [-1, -1, -0.5], [1, 1, 1.5]

In [403]:
# Visualize the candidate elbow positions as a blue point cloud (converted to a
# float64 NumPy array before being handed to PointCloud).
pc = PointCloud(world.vis, "pc", np.array(shoulder_sphere_points).astype(np.float64), color="blue")

In [5]:
from scipy.spatial.transform import Rotation
def get_random_samples(num_samples):
    """Draw random (end-effector pose, elbow position) samples.

    Each row combines an independently drawn orientation, a position drawn
    uniformly inside the workspace box ``[ws_lb, ws_ub]`` and an elbow position
    picked at random from ``shoulder_sphere_points``. `get_batch_samples` uses
    these rows as negative examples.

    Args:
        num_samples: number of rows to generate.

    Returns:
        Array of shape ``(num_samples, 10)`` laid out as
        ``[qw, qx, qy, qz, x, y, z, elbow_x, elbow_y, elbow_z]``.
    """
    xyz = np.random.uniform(ws_lb, ws_ub, (num_samples, 3))
    # SciPy returns quaternions scalar-last (x, y, z, w); reorder to scalar-first
    # so the first 7 values can be fed directly to SE3(...).
    qtns_xyzw = Rotation.random(num_samples).as_quat()
    qtns = qtns_xyzw[:,[3,0,1,2]]
    # Index over however many candidate elbow points exist instead of assuming
    # exactly 1000, so changing the sphere resolution cannot go out of range or
    # silently ignore part of the candidates.
    indices = np.random.randint(0, len(shoulder_sphere_points), size=num_samples)
    p_elbow = shoulder_sphere_points[indices]
    return jnp.hstack([qtns, xyz, p_elbow])

def generate_successful_sample(q):
    """Build one positive sample from a 7-value arm configuration.

    The two gripper joints are fixed at 0.04 and appended to ``q`` before
    running forward kinematics.

    Args:
        q: arm joint values (without the two gripper joints).

    Returns:
        The last forward-kinematics entry (end-effector pose) followed by the
        last three components of entry 4 (elbow position), stacked into one
        vector.
    """
    finger_joints = jnp.full(2, 0.04)
    q_full = jnp.hstack([q, finger_joints])
    link_poses = panda_model.fk_fn(q_full)
    elbow_xyz = link_poses[4][-3:]
    return jnp.hstack([link_poses[-1], elbow_xyz])

def generate_random_joints(num_samples):
    """Sample arm configurations uniformly within the first 7 joint limits.

    Args:
        num_samples: number of configurations to draw.

    Returns:
        JAX array of shape ``(num_samples, 7)``.
    """
    lower, upper = panda_model.lb[:7], panda_model.ub[:7]
    samples = np.random.uniform(lower, upper, size=(num_samples, 7))
    return jnp.array(samples)

# Only the deterministic forward-kinematics part is compiled. NumPy random calls
# placed inside a jitted function run once, at trace time, and their results are
# baked into the compiled program as constants, so the negatives and the noise
# would be identical on every later call with the same batch shape.
_batched_successful_samples = jax.jit(jax.vmap(generate_successful_sample))


def get_batch_samples(qs):
    """Create one batch of positive and negative feasibility samples.

    Args:
        qs: array of shape ``(batch, 7)`` with arm joint configurations.

    Returns:
        ``(x_succ, x_fail)`` where ``x_succ`` has one row per configuration
        (end-effector pose followed by the matching elbow position) and
        ``x_fail`` stacks three kinds of negatives, ``3 * batch`` rows in total:
        fully random samples, positives with a perturbed elbow position
        (columns 7 onward), and positives with a perturbed end-effector
        position (columns 4 to 6).
    """
    batch = qs.shape[0]
    x_succ = _batched_successful_samples(qs) #true data
    # Fresh Gaussian perturbation (std 0.05) on every call. The same draw is
    # applied to both perturbed variants, as in the original code.
    p_noise = np.random.normal(size=(batch,3)) * 0.05
    x_fail1 = get_random_samples(batch)
    x_fail2 = x_succ.at[:,7:].set(x_succ[:,7:] + p_noise)
    x_fail3 = x_succ.at[:,4:7].set(x_succ[:,4:7] + p_noise)
    x_fail = jnp.vstack([x_fail1, x_fail2, x_fail3])
    return x_succ, x_fail

In [505]:
# Smoke test of the sampling pipeline: 100 configurations give 100 positive
# rows and 300 negative rows.
qs = generate_random_joints(100)
x_succ, x_fail = get_batch_samples(qs)

In [524]:
# Choose which sample set to inspect and the row to start from; the next cell
# displays row i and then increments it.
i=100
x_sample = x_fail

In [529]:
# Show sample i in the visualizer: the first 7 values are the end-effector
# pose, the remaining values the elbow position.
frame.set_pose(SE3(x_sample[i][:7]))
elbow.set_translate(x_sample[i][7:])
print(x_sample[i])
# Advance the index so re-running this cell steps through the samples
# (this makes the cell intentionally non-idempotent).
i+= 1

[-0.23737708 -0.7212028  -0.59818256 -0.25631255  0.30037656 -0.19551739
 -0.10780752 -0.04035843 -0.18666959  0.5045196 ]


In [12]:
from jax import random
# Network hyperparameters. The input width of 10 matches one sample row
# (7 pose values + 3 elbow coordinates); the single output is used as a logit.
hp = Hyperparam()
hp.dims = [10, 32, 32, 32, 32, 1]
hp.lr = 0.001
# NOTE(review): batch_size is stored here but the training loop below draws
# batches of 1000 directly — confirm which value is intended.
hp.batch_size = 128

# Build the MLP, initialise its parameters from a dummy input of the right
# width, and wrap everything in a TrainState with an Adam optimizer.
model = get_mlp(hp)
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (hp.dims[0],))
params = model.init(key2, x)
tx = optax.adam(learning_rate=hp.lr)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [13]:
def loss(state, params, x_succ, x_fail):
    """Binary cross-entropy over positive and negative samples.

    Args:
        state: train state providing ``apply_fn``.
        params: network parameters to evaluate (differentiated by ``grad_fn``).
        x_succ: batch of samples labelled 1.
        x_fail: batch of samples labelled 0.

    Returns:
        ``(total, (loss_succ, loss_fail))`` where ``total`` is the sum of the
        two mean cross-entropy terms.
    """
    def _mean_bce(batch, make_labels):
        # Flatten the network outputs to a vector of logits, one per sample.
        logits = state.apply_fn(params, batch).flatten()
        return optax.sigmoid_binary_cross_entropy(logits, make_labels(logits)).mean()

    succ_term = _mean_bce(x_succ, jnp.ones_like)
    fail_term = _mean_bce(x_fail, jnp.zeros_like)
    return succ_term + fail_term, (succ_term, fail_term)


# Differentiate with respect to `params` (argument 1); the per-class losses are
# passed through as auxiliary output.
grad_fn = jax.value_and_grad(loss, has_aux=True, argnums=1)

@jax.jit
def training_step(state, x_succ, x_fail):
    """Run one optimizer update on a batch.

    Returns:
        ``(new_state, total_loss, loss_succ, loss_fail)``.
    """
    (total, aux), grads = grad_fn(state, state.params, x_succ, x_fail)
    succ_term, fail_term = aux
    new_state = state.apply_gradients(grads=grads)
    return new_state, total, succ_term, fail_term

In [14]:
# Training loop: each iteration draws 1000 fresh joint configurations, builds
# positives/negatives from them and applies one optimizer step.
# The total is bound to `loss_total` rather than `loss` so the loss function
# defined above is not overwritten by an array.
for epoch in range(50000):
    qs = generate_random_joints(1000)
    x_succ, x_fail = get_batch_samples(qs)
    state, loss_total, loss_succ, loss_fail = training_step(state, x_succ, x_fail)
    # Report progress every 100 iterations.
    if epoch % 100 == 0:
        print(f"epoch:{epoch}, loss:{loss_total}")
        print(f"loss_succ{loss_succ}, loss_fail:{loss_fail}")

epoch:0, loss:1.3843345642089844
loss_succ0.6789218783378601, loss_fail:0.705412745475769
epoch:100, loss:1.2167222499847412
loss_succ0.5869525671005249, loss_fail:0.6297696232795715
epoch:200, loss:1.159193992614746
loss_succ0.5281186103820801, loss_fail:0.631075382232666
epoch:300, loss:1.1379543542861938
loss_succ0.5286485552787781, loss_fail:0.6093057990074158
epoch:400, loss:1.1220933198928833
loss_succ0.5221401453018188, loss_fail:0.5999531745910645
epoch:500, loss:1.1145291328430176
loss_succ0.5236716270446777, loss_fail:0.5908575654029846
epoch:600, loss:1.1061792373657227
loss_succ0.538483738899231, loss_fail:0.5676954984664917
epoch:700, loss:1.0903346538543701
loss_succ0.5103625059127808, loss_fail:0.5799721479415894
epoch:800, loss:1.0923676490783691
loss_succ0.49548324942588806, loss_fail:0.5968843698501587
epoch:900, loss:1.081726312637329
loss_succ0.4826229214668274, loss_fail:0.5991033315658569
epoch:1000, loss:1.0750495195388794
loss_succ0.4912102222442627, loss_fail:0

KeyboardInterrupt: 

In [15]:
# Snapshot the parameters reached when training was interrupted; the
# evaluation cells below read this name.
trained_param = state.params

In [23]:
# Draw a fresh evaluation batch of 128 configurations within the arm joint
# limits (same sampling as generate_random_joints, but left as a NumPy array).
qs = np.random.uniform(panda_model.lb[:7], panda_model.ub[:7], size=(128,7))
x_succ, x_fail = get_batch_samples(qs)

In [27]:
# Indices of negative samples the network classifies as feasible
# (sigmoid output above 0.5), i.e. false positives.
bools = nn.sigmoid(model.apply(trained_param, x_fail)) > 0.5
jnp.arange(len(x_fail))[bools.flatten()]

Array([ 20,  34,  76,  85, 100, 118, 125, 128, 129, 130, 132, 133, 137,
       139, 141, 142, 144, 145, 146, 147, 148, 149, 150, 151, 152, 154,
       155, 156, 157, 159, 164, 165, 166, 171, 172, 173, 176, 177, 181,
       184, 187, 188, 189, 190, 191, 193, 195, 196, 198, 200, 204, 208,
       209, 212, 213, 215, 216, 219, 220, 222, 225, 229, 231, 232, 233,
       234, 236, 237, 238, 239, 244, 249, 251, 252, 254, 255, 256, 257,
       258, 259, 260, 261, 263, 265, 266, 268, 272, 273, 274, 276, 278,
       280, 283, 288, 289, 293, 296, 297, 299, 300, 303, 305, 306, 312,
       313, 314, 315, 317, 318, 320, 321, 322, 324, 326, 329, 334, 336,
       337, 338, 340, 343, 344, 345, 346, 347, 351, 352, 353, 354, 356,
       360, 361, 362, 364, 367, 368, 369, 371, 375, 376, 377, 379, 381,
       382, 383], dtype=int32)

In [45]:
def feasibility(ee_pose, tp_elbow):
    """Evaluate the trained network for one pose / elbow-position pair.

    Args:
        ee_pose: pose object exposing ``parameters()`` (an SE3 in this notebook).
        tp_elbow: elbow position appended after the pose parameters.

    Returns:
        The raw network output (a logit) from ``model.apply``.
    """
    net_input = jnp.hstack([ee_pose.parameters(), tp_elbow])
    return model.apply(trained_param, net_input)

# Problem setup: sample a random configuration within the full joint limits
# (gripper joints included) and take its end-effector pose as the query pose.
qrand = np.random.uniform(panda_model.lb, panda_model.ub)
pose_rand = SE3(panda_model.fk_fn(qrand)[-1])

# Visualization: show the query pose and the configuration that produced it.
frame.set_pose(pose_rand)
panda.set_joint_angles(qrand)

In [46]:
# Candidate elbow positions on the shoulder-centred sphere.
sphere_samples = fibonacci_sphere(1000)
sphere_shoulder = sphere_samples * upper_arm_len + p_shoulder

# Score every candidate against the fixed query pose and keep those whose
# logit exceeds 1.
logits = jax.vmap(feasibility, in_axes=(None, 0))(pose_rand, sphere_shoulder)
indices = np.arange(len(logits))[logits.flatten() > 1.]
elbow_points = np.array(sphere_shoulder[indices], dtype=np.float64)
# This expression is not the last one in the cell, so its value is neither
# displayed nor stored.
nn.sigmoid(logits[indices].flatten())

# Visualization of the accepted elbow positions.
# `colors` is only referenced by the commented-out call below.
colors = np.tile(Colors.read("blue", return_rgb=True), len(elbow_points)).reshape(-1, 3)
pc = PointCloud(world.vis, "pc", elbow_points, color="blue")
# world.vis["pc"].set_object(
#     g.PointsGeometry(elbow_points.T, colors.T),
#     g.PointsMaterial(size=0.05)
# )
pc.load()

In [333]:
# Display the indices of the accepted elbow candidates.
indices

array([173, 178, 186, 195, 199, 207, 228, 241, 250, 262, 275, 283, 284,
       296, 317, 330, 338, 339, 351, 364, 372, 385, 396, 406, 417, 419,
       427, 430, 438, 440, 441, 443, 451, 453, 459, 461, 472, 474, 475,
       480, 482, 493, 495, 498, 504, 506, 508, 509, 514, 516, 517, 522,
       527, 529, 530, 532, 535, 537, 538, 543, 548, 550, 551, 553, 556,
       558, 559, 561, 563, 564, 566, 571, 572, 579, 582, 584, 585, 587,
       592, 593, 595, 597, 598, 600, 605, 606, 608, 613, 616, 618, 619,
       621, 626, 627, 629, 634, 639, 640, 642, 647, 648, 650, 652, 655,
       660, 661, 663, 668, 673, 674, 676, 681, 682, 689, 695, 697, 702,
       707, 710, 716, 723, 731, 736, 737, 744, 752, 757, 765, 778, 786,
       799, 807, 820, 833, 841, 854, 875, 896, 909, 970, 983, 988, 996])

In [155]:
# Remove the "pc" point cloud from the visualizer.
world.vis["pc"].delete()

In [72]:
# Persist the train state together with its hyperparameters.
# presumably force=True overwrites an existing file — verify against save().
save("elbow_feas_net_euclid.pth", state, hp, force=True)

In [73]:
# Reload the saved network; feas_fn is later called with a single input vector.
feas_fn = get_mlp_by_path("elbow_feas_net_euclid.pth")

In [61]:
from sdf_world.sdf_world import *
# Start a second, fresh world/visualizer; this rebinds `world`, so the scene
# objects created earlier belong to the previous visualizer.
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7004/static/


In [62]:
# Show the new visualizer in the notebook.
world.show_in_jupyter()

In [63]:
# Reuses the `panda_model` loaded at the top of the notebook (the reload is
# commented out) and adds an opaque robot to the new visualizer.
#panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model)

In [64]:
# Scene markers for the new visualizer: a Frame for the end-effector pose and
# three red spheres named after the shoulder, elbow and wrist.
frame = Frame(world.vis, "frame", 0.2)
shoulder = Sphere(world.vis, "shoulder", 0.1, "red", 0.5)
elbow = Sphere(world.vis, "elbow", 0.1, "red", 0.5)
wrist = Sphere(world.vis, "wrist", 0.1, "red", 0.5)

In [65]:
# Predefined quantities derived from the neutral configuration.
fks = panda_model.fk_fn(panda_model.neutral)
ee_pose = SE3(fks[-1])
# wrist_pos = panda_model.fk_fn(panda_model.neutral)[7][-3:]
# wrist_wrt_ee = ee_pose.inverse().apply(wrist_pos)
# NOTE(review): the upper-arm length here uses entry 3, while the first part of
# the notebook measured it to entry 4 and also reads the elbow from entry 4 —
# confirm which link index is the intended elbow.
upper_arm_len = jnp.linalg.norm(fks[3][-3:] - fks[1][-3:])
p_shoulder = fks[1][-3:]
# Workspace bounds; tighter than the box used when generating training data.
ws_lb = [-0.8, -0.8, -0.3]
ws_ub = [0.8, 0.8, 1.3]

In [153]:
# Candidate elbow positions on the shoulder sphere. fibonacci_sphere is called
# with its default sample count here (earlier cells passed 1000 explicitly).
sphere_points = fibonacci_sphere() * upper_arm_len + p_shoulder
# elbow_params = np.random.uniform([0., -np.pi], [np.pi, np.pi], size=(100,2))
elbow_points = jnp.array(sphere_points)
def feasibility_logit_fn(ee_posevec, elbow_param):
    """Return the loaded network's output for one pose-vector / elbow pair.

    The two inputs are concatenated, pose vector first, and passed to
    ``feas_fn`` as a single feature vector.
    """
    features = jnp.hstack([ee_posevec, elbow_param])
    return feas_fn(features)
@jax.jit
def ik(ee_pose):
    """Score every candidate elbow position for a given end-effector pose.

    The pose is encoded as the first two columns of its rotation matrix
    (flattened column by column, 6 values) followed by its translation.

    NOTE(review): this encoding yields 12 input values per candidate, whereas
    the network trained earlier in this notebook takes 10 (quaternion + xyz +
    elbow) — confirm the loaded network matches this encoding.

    Returns:
        The network outputs for all rows of ``elbow_points``.
    """
    rot_cols = ee_pose.as_matrix()[:3, :2]
    ee_posevec = jnp.hstack([rot_cols.T.flatten(), ee_pose.translation()])
    # Broadcast the single pose vector against every candidate elbow point.
    score_all = jax.vmap(feasibility_logit_fn, in_axes=(None, 0))
    return score_all(ee_posevec, elbow_points)

In [154]:
# Keep the candidate elbow positions whose network output exceeds 1.
cond = (ik(ee_pose) > 1.).flatten()
indices = jnp.arange(len(cond))[cond]
elbows = elbow_points[indices]

In [291]:
def ee_error(q, ee_pose_des):
    """Scalar end-effector pose error at configuration ``q``.

    The error is the norm (via ``safe_2norm``) of the log map of the relative
    transform from the current end-effector pose to the desired one.
    """
    current = SE3(panda_model.fk_fn(q)[-1])
    twist = (current.inverse() @ ee_pose_des).log()
    return safe_2norm(twist)
def elbow_err(q, p_elbow_des):
    """Distance (via ``safe_2norm``) between the current and desired elbow position.

    The current elbow position is read from the last three components of
    forward-kinematics entry 4.
    """
    elbow_now = panda_model.fk_fn(q)[4][-3:]
    return safe_2norm(elbow_now - p_elbow_des)
# Gradients of both error terms with respect to the joint vector q
# (jax.grad differentiates the first argument by default).
err_grad_fn = jax.grad(ee_error)
elbowerr_grad_fn = jax.grad(elbow_err)

In [310]:
def get_jacobian(q):
    """Assemble a 6x7 Jacobian of the end effector for joints 1 to 7.

    For each forward-kinematics entry 1..7 the joint axis is taken as the third
    column of that entry's rotation matrix; the linear part is
    ``axis x (p_ee - p_joint)`` and the angular part is the axis itself.

    Returns:
        Array with the three linear rows stacked on top of the three angular
        rows, one column per joint.
    """
    link_poses = panda_model.fk_fn(q)
    ee_xyz = link_poses[-1][-3:]
    joint_ids = range(1, 8)
    axes = [SE3(link_poses[j]).as_matrix()[:3, 2] for j in joint_ids]
    lin_cols = [
        jnp.cross(axis, ee_xyz - link_poses[j][-3:])
        for j, axis in zip(joint_ids, axes)
    ]
    # Each list holds one 3-vector per joint; transpose so joints become columns.
    return jnp.vstack([jnp.vstack(lin_cols).T, jnp.vstack(axes).T])

In [156]:
# Pick the second accepted elbow candidate as the target and show it.
p_elbow = elbows[1]
elbow.set_translate(p_elbow)

In [158]:
# Reset the iterate to the neutral configuration before running the
# step cell below.
q = panda_model.neutral
panda.set_joint_angles(q)

In [401]:
# One descent step of the two-priority IK: the end-effector error is the
# primary task; the elbow error is a secondary task that may only move the arm
# within the null space of the end-effector Jacobian. Re-run the cell to iterate.
ee_grad = err_grad_fn(q, pose_rand)
elbow_grad = elbowerr_grad_fn(q, p_elbow)

jac = get_jacobian(q)
# Null-space projector N = I - J^+ J. The previous form, I - J^T J, is not a
# projector unless J has orthonormal rows, so the elbow term leaked into the
# end-effector task.
ns_proj = jnp.eye(7) - jnp.linalg.pinv(jac)@jac
# The gradients also cover the two gripper joints; project only the 7 arm
# joints and leave the gripper entries of the secondary term at zero.
q_delta = ee_grad + jnp.hstack([ns_proj@elbow_grad[:-2],0,0])
q = q - q_delta*0.2
panda.set_joint_angles(q)

In [183]:
# NOTE(review): `q_grad` is not assigned in any visible cell; this display
# relies on leftover kernel state and will fail on a fresh run.
q_grad

Array([ 0.07283606,  0.00364066,  0.0863461 , -0.08738822, -0.08824547,
        0.10718858,  0.20078273,  0.        ,  0.        ], dtype=float32)

In [65]:
# Select 5 points from `points` with farthest_point_sampling.
# NOTE(review): `points` is not assigned in any visible cell — confirm its source.
far_points = farthest_point_sampling(points, 5)

In [99]:
# Use the first of the selected points as the elbow target.
p_elbow = far_points[0]

In [76]:
# Recreate the robot and fix joints 7 and 8 at 0.04 so set_joint_angles can be
# given 7 arm values (presumably these are the gripper joints — verify against
# Robot.reduce_dim). Also add a Frame and a Sphere for the elbow.
panda = Robot(world.vis, "panda", panda_model)
panda.reduce_dim([7,8], [0.04, 0.04])
frame_elbow = Frame(world.vis, "elbow_frame", 0.2)
elbow = Sphere(world.vis, "elbow", 0.1, "red", 0.5)

In [101]:
# Extra Frame marker; a later cell places it at forward-kinematics entry 2.
frame_1 = Frame(world.vis, "frame1", 0.2)

In [81]:
# Show the neutral pose using only the first 7 joint values.
panda.set_joint_angles(panda_model.neutral[:7])

In [107]:
# Poses of forward-kinematics entries 2 and 5 at the neutral configuration,
# shown with the two Frame markers.
fks = panda_model.fk_fn(panda_model.neutral)
pose2 = SE3(fks[2])
pose5 = SE3(fks[5])
frame_1.set_pose(pose2)
frame_elbow.set_pose(pose5)

In [120]:
# Offset from frame 2 to frame 5, and the y axis of frame 2 (second column of
# its rotation matrix).
diff = (pose5.translation() - pose2.translation())
y_pose2 = pose2.as_matrix()[:3,1]

In [144]:
# Elbow point: project the frame-2 -> frame-5 offset onto frame 2's y axis and
# add it to frame 2's origin.
p_elbow = (diff @ y_pose2) * y_pose2 + pose2.translation()
ee_pose = SE3(fks[-1])
# Distances from forward-kinematics entry -6 (the point a later cell shows with
# the "wrist" sphere) to the elbow point and to the end effector.
alpha_len = jnp.linalg.norm(p_elbow - fks[-6][-3:])
beta_len = jnp.linalg.norm(ee_pose.translation() - fks[-6][-3:])

In [132]:
# Show the end-effector pose with the "frame" marker.
frame.set_pose(ee_pose)

In [187]:
# Given p_elbow and ee_pose, can we solve IK analytically?
# v points from the elbow to the end effector; c is its length.
v = ee_pose.translation() - p_elbow
c = jnp.linalg.norm(v)
# Fraction of v at which the wrist is placed (used as p_elbow + r*v below).
# NOTE(review): for a triangle with sides alpha (elbow-wrist), beta (wrist-ee)
# and c, the law of cosines gives the wrist's projection onto v at
# (alpha**2 + c**2 - beta**2) / (2*c**2) with no square root — confirm whether
# the sqrt and the alpha/beta roles here are intended.
r = jnp.sqrt((beta_len**2 + c**2 - alpha_len**2)/(2*c**2))

In [190]:
# Disabled: offset perpendicular to v that would lift the wrist off the
# elbow-to-end-effector line.
# z_ee = ee_pose.as_matrix()[:3, 2]
# up_vec = jnp.cross(v, jnp.cross(v, z_ee))
# up_vec = up_vec/jnp.linalg.norm(up_vec)
# up_mag = jnp.sqrt(alpha_len**2 - (c*r)**2)
# With the offset disabled the wrist estimate lies on the line through the
# elbow and the end effector.
p_wrist = p_elbow + r*v #+up_vec*up_mag

In [191]:
# Show the estimated wrist position.
wrist.set_translate(p_wrist)

In [158]:
# Display the elbow-to-end-effector vector.
v

Array([ 6.1263674e-01,  1.1562403e-08, -1.7700875e-01], dtype=float32)

In [139]:
# Place the wrist sphere at the position of forward-kinematics entry -6 for
# comparison with the estimate.
wrist.set_translate(fks[-6][-3:])

In [129]:
# Show the current elbow point.
elbow.set_translate(p_elbow)

In [112]:
# Display the frame-2 -> frame-5 offset.
diff

Array([4.6650025e-01, 5.9267933e-09, 3.9849854e-01], dtype=float32)

In [80]:
# Show a test configuration with only the second joint moved to -0.5.
panda.set_joint_angles(jnp.array([0,-0.5,0,0,0,0,0]))

In [70]:
# Move the elbow sphere to the fifth selected point.
elbow.set_translate(far_points[4])