In [1]:
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

from network import *
from loss import *

In [2]:
world = SDFWorld()
world.show_in_jupyter()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7003/static/


In [3]:
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
frame = Frame(world.vis, "frame", 0.1)
elbow = Sphere(world.vis, "elbow", 0.1, "red", alpha=0.5)
wrist = Sphere(world.vis, "elbow", 0.07, "red", alpha=0.5)

concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts


In [4]:
fks = panda_model.fk_fn(panda_model.neutral)
p_shoulder = fks[1][-3:]
p_elbow = fks[4][-3:]
p_wrist = fks[5][-3:]
p_ee = fks[-1][-3:]
upper_arm_len = jnp.linalg.norm(p_elbow-p_shoulder)
hand_len = jnp.linalg.norm(p_wrist - p_ee)
lower_arm_len = jnp.linalg.norm(p_elbow-p_wrist)
sphere_points = fibonacci_sphere(1000)
#sphere_points_rtp = jax.vmap(to_spherical_coord)(sphere_points)
shoulder_sphere_points = sphere_points * upper_arm_len + p_shoulder
ee_sphere_points = sphere_points * hand_len
ws_lb, ws_ub = [-1, -1, -0.5], [1, 1, 1.5]
print(lower_arm_len)

0.3927623


In [403]:
pc = PointCloud(world.vis, "pc", np.array(shoulder_sphere_points).astype(np.float64), color="blue")

In [5]:
from scipy.spatial.transform import Rotation
def get_random_samples(num_samples):
    xyz = np.random.uniform(ws_lb, ws_ub, (num_samples, 3))
    qtns_xyzw = Rotation.random(num_samples).as_quat()
    qtns = qtns_xyzw[:,[3,0,1,2]]
    indices = np.random.randint(0, 1000, size=num_samples)
    p_elbow = shoulder_sphere_points[indices]
    indices = np.random.randint(0, 1000, size=num_samples)
    p_wrist = ee_sphere_points[indices] + xyz
    #tp = sphere_points_rtp[indices, 1:]
    return jnp.hstack([qtns, xyz, p_elbow, p_wrist]) #p_wrist

def generate_successful_sample(q):
    gripper_joints = jnp.full((2), 0.04)
    qnew = jnp.hstack([q, gripper_joints])
    fks = panda_model.fk_fn(qnew)
    ee_pose = fks[-1]
    p_elbow = fks[4][-3:]
    p_wrist = fks[5][-3:]
    #rtp = to_spherical_coord(p_elbow - p_shoulder)
    return jnp.hstack([ee_pose, p_elbow, p_wrist]) #p_wrist

def generate_random_joints(num_samples):
    return jnp.array(np.random.uniform(
        panda_model.lb[:7], panda_model.ub[:7], size=(num_samples,7)))

@jax.jit
def get_batch_samples(qs):
    x_succ = jax.vmap(generate_successful_sample)(qs) #true data
    #tp_noise = np.random.normal(size=(qs.shape[0],2)) * jnp.pi/10
    p_noise = np.random.normal(size=(qs.shape[0],6)) * 0.05
    #x_fail1 = get_random_samples(qs.shape[0])
    x_fail2 = x_succ.at[:,7:].set(x_succ[:,7:] + p_noise)
    
    indices = np.random.randint(0, 1000, size=qs.shape[0])
    p_shoulder_fail = shoulder_sphere_points[indices]
    indices = np.random.randint(0, 1000, size=qs.shape[0])
    p_wrist_fail = ee_sphere_points[indices] + x_succ[:,4:7]
    
    x_fail3 = x_succ.at[:,7:10].set(p_shoulder_fail)
    x_fail3 = x_fail3.at[:,10:13].set(p_wrist_fail)
    # x_fail3 = x_succ.at[:,10:13].set(x_succ[:,10:13] + p_noise)
    # x_fail3 = x_succ.at[:,4:7].set(x_succ[:,4:7] + p_noise)
    x_fail = jnp.vstack([x_fail2, x_fail3]) #x_fail3
    return x_succ, x_fail

In [56]:
qs = generate_random_joints(100)
x_succ, x_fail = get_batch_samples(qs)

In [6]:
from jax import random
hp = Hyperparam()
hp.dims = [13, 64, 64, 64, 64, 1]
hp.lr = 0.001
hp.batch_size = 128

model = get_mlp(hp)
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (hp.dims[0],))
params = model.init(key2, x)
tx = optax.adam(learning_rate=hp.lr)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [7]:
def loss(state, params, x_succ, x_fail):
    logits_succ = state.apply_fn(params, x_succ).flatten()
    loss_succ = optax.sigmoid_binary_cross_entropy(
        logits_succ, jnp.ones_like(logits_succ)).mean()
    logits_fail = state.apply_fn(params, x_fail).flatten()
    loss_fail = optax.sigmoid_binary_cross_entropy(
        logits_fail, jnp.zeros_like(logits_fail)).mean()
    return loss_succ + loss_fail, (loss_succ, loss_fail)
grad_fn = jax.value_and_grad(loss, argnums=1, has_aux=True)

@jax.jit
def training_step(state, x_succ, x_fail):
    (loss, (loss_succ, loss_fail)), grads = grad_fn(state, state.params, x_succ, x_fail)
    state = state.apply_gradients(grads=grads)
    return state, loss, loss_succ, loss_fail

In [10]:
for epoch in range(10000):
    qs = qs = generate_random_joints(500)
    x_succ, x_fail = get_batch_samples(qs)
    state, loss, loss_succ, loss_fail = training_step(state, x_succ, x_fail)
    if epoch % 100 == 0:
        print(f"epoch:{epoch}, loss:{loss}")
        print(f"loss_succ{loss_succ}, loss_fail:{loss_fail}")

epoch:0, loss:0.09222819656133652
loss_succ0.026334576308727264, loss_fail:0.06589362025260925
epoch:100, loss:0.10091309994459152
loss_succ0.030481835827231407, loss_fail:0.07043126225471497
epoch:200, loss:0.08229967951774597
loss_succ0.024466127157211304, loss_fail:0.057833556085824966
epoch:300, loss:0.06659894436597824
loss_succ0.025636853650212288, loss_fail:0.0409620925784111
epoch:400, loss:0.060917988419532776
loss_succ0.021194806322455406, loss_fail:0.03972318395972252
epoch:500, loss:0.08779995143413544
loss_succ0.0228993222117424, loss_fail:0.06490062922239304
epoch:600, loss:0.08732685446739197
loss_succ0.018536318093538284, loss_fail:0.06879053264856339
epoch:700, loss:0.08431386202573776
loss_succ0.025149019435048103, loss_fail:0.05916484072804451
epoch:800, loss:0.10588708519935608
loss_succ0.015997812151908875, loss_fail:0.0898892730474472
epoch:900, loss:0.09114277362823486
loss_succ0.0227863360196352, loss_fail:0.06835643947124481
epoch:1000, loss:0.08893390744924545

In [11]:
trained_param = state.params

In [12]:
qs = np.random.uniform(panda_model.lb[:7], panda_model.ub[:7], size=(128,7))
x_succ, x_fail = get_batch_samples(qs)

In [13]:
bools = nn.sigmoid(model.apply(trained_param, x_fail)) > 0.5
jnp.arange(len(x_fail))[bools.flatten()]

Array([  1,   8,   9,  33,  53, 120, 123], dtype=int32)

In [102]:
# def feasibility(ee_pose, p_elbow, p_wrist):
#     ee_posevec = ee_pose.parameters()
#     return model.apply(trained_param, jnp.hstack([ee_posevec, p_elbow, p_wrist]))

# problem
qrand = np.random.uniform(panda_model.lb, panda_model.ub)
fks = panda_model.fk_fn(qrand)
pose_rand = SE3(fks[-1])
p_wrist = fks[5][-3:]

#vis
frame.set_pose(pose_rand)
panda.set_joint_angles(qrand)

In [56]:

# sphere_samples = fibonacci_sphere(1000)
# sphere_shoulder = sphere_samples * upper_arm_len + p_shoulder

logits = jax.vmap(feasibility, in_axes=(None, 0, None))(
    pose_rand, shoulder_sphere_points, p_wrist)
indices = np.arange(len(logits))[logits.flatten() > 0.]
elbow_points = np.array(shoulder_sphere_points[indices], dtype=np.float64)
elbow.set_translate(p_wrist)

# cond
#lower_arm_err = jnp.abs(jnp.linalg.norm(elbow_points - p_wrist, axis=-1) - lower_arm_len)
#elbow_points = elbow_points[lower_arm_err < 0.005]
nn.sigmoid(logits[indices].flatten())

#vis
colors = np.tile(Colors.read("blue", return_rgb=True), len(elbow_points)).reshape(-1, 3)
pc = PointCloud(world.vis, "pc", elbow_points, color="blue")

In [54]:
del pc

In [57]:
from sdf_world.nlp import *

In [113]:
prob = NLP()
lb = jnp.hstack([panda_model.lb[:7], ws_lb, ws_lb])
ub = jnp.hstack([panda_model.ub[:7], ws_ub, ws_ub])
prob.add_var("x", 13, lb, ub)

In [114]:
def error_radius(x):
    p_ee = x[4:7]
    p_elbow = x[7:10]
    p_wrist = x[10:13]
    err_upper = safe_2norm(p_shoulder - p_elbow) - upper_arm_len
    err_lower = safe_2norm(p_wrist - p_ee) - hand_len
    err_mid = safe_2norm(p_elbow - p_wrist) - lower_arm_len
    return jnp.abs(err_upper) + jnp.abs(err_lower) + jnp.abs(err_mid)

# def error_wrist_radius(x):
#     p_ee = x[4:7]
#     p_wrist = x[10:13]
#     err_wrist = safe_2norm(p_wrist - p_ee) - lower_arm_len
#     return jnp.abs(err_wrist)
# def error_mid_radius(x):

def feasibility(x):
    ee_pose = SE3(x[:7]).normalize()
    x = jnp.hstack([ee_pose.parameters(), x[7:]])
    logit = model.apply(trained_param, x)
    return - logit
def error_ee(x):
    ee_pose_curr = SE3(x[:7]).normalize()
    return safe_2norm((pose_rand.inverse()@ee_pose_curr).log())

prob.add_con("elbow_radius", 1, ["x"], error_radius, upper=0.01)
#prob.add_con("wrist_radius", 1, ["x"], error_wrist_radius, upper=0.01)
prob.add_con("fk_feas", 1, ["x"], feasibility, lower=0.)
prob.add_objective(error_ee)

In [115]:
prob.build()

In [119]:
x0 = np.random.uniform(lb, ub)

In [120]:
xsol, info = prob.solve(x0, viol_tol=0.001)

This is Ipopt version 3.14.10, running with linear solver MUMPS 5.2.1.

Number of nonzeros in equality constraint Jacobian...:        0
Number of nonzeros in inequality constraint Jacobian.:       26
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:       13
                     variables with only lower bounds:        0
                variables with lower and upper bounds:       13
                     variables with only upper bounds:        0
Total number of equality constraints.................:        0
Total number of inequality constraints...............:        2
        inequality constraints with only lower bounds:        1
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        1

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  2.4088335e+00 4.52e+00 9.96e-01   0.0 0.00e+00    -  0.00e+00 0.00e+00 

In [144]:
x = xsol
p_ee = x[4:7]
p_elbow = x[7:10]
p_wrist = x[10:13]
err_upper = safe_2norm(p_shoulder - p_elbow) - upper_arm_len
err_lower = safe_2norm(p_wrist - p_ee) - hand_len
err_mid = safe_2norm(p_elbow - p_wrist) - lower_arm_len
err_upper, err_lower, err_mid

(Array(0.00592175, dtype=float32),
 Array(-0.00012724, dtype=float32),
 Array(0.00338504, dtype=float32))

In [143]:
safe_2norm(p_wrist - p_ee)-hand_len

Array(-0.00012724, dtype=float32)

In [142]:
hand_len

Array(0.2295386, dtype=float32)

In [134]:
jnp.linalg.norm(ee_pose.translation() - p_wrist)

Array(0.22941135, dtype=float32)

In [128]:
safe_2norm(p_wrist - p_ee) - hand_len

Array(0.91949075, dtype=float32)

In [123]:
feasibility(xsol)

Array([16.513062], dtype=float32)

In [135]:
error_radius(xsol)

Array(0.00943403, dtype=float32)

In [130]:
ee_pose = SE3(xsol[:7]).normalize()
p_elbow = xsol[7:10]
p_wrist = xsol[10:]
frame.set_pose(ee_pose)
elbow.set_translate(p_elbow)
wrist.set_translate(p_wrist)

In [86]:
elbow.set_translate([0,0,0])

In [99]:
error_elbow_radius(xsol)

Array(0.00091258, dtype=float32)

In [100]:
error_wrist_radius(xsol)

Array(0.00994909, dtype=float32)

In [98]:
feasibility(xsol)

Array([27.751844], dtype=float32)

In [91]:
elbow = Sphere(world.vis, "elbow", 0.07, "blue", 0.5)

In [79]:
wrist = Sphere(world.vis, "wrist", 0.07, "red", 0.5)

In [165]:
lower_arm_err<0.001

Array([False, False, False, False, False, False, False, False, False,
       False,  True, False, False], dtype=bool)

In [163]:
lower_arm_err.min()

Array(8.234382e-05, dtype=float32)

In [155]:
lower_arm_lens.max()

Array(0.00815082, dtype=float32)

In [155]:
world.vis["pc"].delete()

In [72]:
save("elbow_feas_net_euclid.pth", state, hp, force=True)

In [73]:
feas_fn = get_mlp_by_path("elbow_feas_net_euclid.pth")

In [61]:
from sdf_world.sdf_world import *
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7004/static/


In [62]:
world.show_in_jupyter()

In [63]:
#panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model)

In [64]:
frame = Frame(world.vis, "frame", 0.2)
shoulder = Sphere(world.vis, "shoulder", 0.1, "red", 0.5)
elbow = Sphere(world.vis, "elbow", 0.1, "red", 0.5)
wrist = Sphere(world.vis, "wrist", 0.1, "red", 0.5)

In [65]:
#predefine
fks = panda_model.fk_fn(panda_model.neutral)
ee_pose = SE3(fks[-1])
# wrist_pos = panda_model.fk_fn(panda_model.neutral)[7][-3:]
# wrist_wrt_ee = ee_pose.inverse().apply(wrist_pos)
upper_arm_len = jnp.linalg.norm(fks[3][-3:] - fks[1][-3:])
p_shoulder = fks[1][-3:]
ws_lb = [-0.8, -0.8, -0.3]
ws_ub = [0.8, 0.8, 1.3]

In [153]:
sphere_points = fibonacci_sphere() * upper_arm_len + p_shoulder
# elbow_params = np.random.uniform([0., -np.pi], [np.pi, np.pi], size=(100,2))
elbow_points = jnp.array(sphere_points)
def feasibility_logit_fn(ee_posevec, elbow_param):
    return feas_fn(jnp.hstack([ee_posevec, elbow_param]))
@jax.jit
def ik(ee_pose):
    rot6d = ee_pose.as_matrix()[:3,:2].T.flatten()
    p_ee = ee_pose.translation()
    ee_posevec = jnp.hstack([rot6d, p_ee])
    logits = jax.vmap(feasibility_logit_fn, in_axes=(None, 0))(ee_posevec, elbow_points)
    return logits

In [154]:
cond = (ik(ee_pose) > 1.).flatten()
indices = jnp.arange(len(cond))[cond]
elbows = elbow_points[indices]

In [291]:
def ee_error(q, ee_pose_des):
    fks = panda_model.fk_fn(q)
    ee_pose_curr = SE3(fks[-1])
    ee_err_vec = (ee_pose_curr.inverse() @ ee_pose_des).log()
    ee_err = safe_2norm(ee_err_vec)
    return ee_err
def elbow_err(q, p_elbow_des):
    fks = panda_model.fk_fn(q)
    p_elbow_curr = fks[4][-3:]
    elbow_err = safe_2norm(p_elbow_curr - p_elbow_des)
    return elbow_err
err_grad_fn = jax.grad(ee_error)
elbowerr_grad_fn = jax.grad(elbow_err)

In [310]:
def get_jacobian(q):
    fks = panda_model.fk_fn(q)
    pos_jac = []
    rot_jac = []
    p_ee = fks[-1][-3:]
    for i in range(1, 8):
        p_frame = fks[i][-3:]
        rot_axis = SE3(fks[i]).as_matrix()[:3, 2]
        lin_vel = jnp.cross(rot_axis, p_ee - p_frame)
        pos_jac.append(lin_vel)
        rot_jac.append(rot_axis)
        
    pos_jac = jnp.vstack(pos_jac).T
    rot_jac = jnp.vstack(rot_jac).T
    return jnp.vstack([pos_jac, rot_jac])

In [156]:
p_elbow = elbows[1]
elbow.set_translate(p_elbow)

In [158]:
q = panda_model.neutral
panda.set_joint_angles(q)

In [401]:
ee_grad = err_grad_fn(q, pose_rand)
elbow_grad = elbowerr_grad_fn(q, p_elbow)

jac = get_jacobian(q)
ns_proj = jnp.eye(7) - jac.T@jac
q_delta = ee_grad + jnp.hstack([ns_proj@elbow_grad[:-2],0,0])
q = q - q_delta*0.2
panda.set_joint_angles(q)

In [183]:
q_grad

Array([ 0.07283606,  0.00364066,  0.0863461 , -0.08738822, -0.08824547,
        0.10718858,  0.20078273,  0.        ,  0.        ], dtype=float32)

In [65]:
far_points = farthest_point_sampling(points, 5)

In [99]:
p_elbow = far_points[0]

In [76]:
panda = Robot(world.vis, "panda", panda_model)
panda.reduce_dim([7,8], [0.04, 0.04])
frame_elbow = Frame(world.vis, "elbow_frame", 0.2)
elbow = Sphere(world.vis, "elbow", 0.1, "red", 0.5)

In [101]:
frame_1 = Frame(world.vis, "frame1", 0.2)

In [81]:
panda.set_joint_angles(panda_model.neutral[:7])

In [107]:
fks = panda_model.fk_fn(panda_model.neutral)
pose2 = SE3(fks[2])
pose5 = SE3(fks[5])
frame_1.set_pose(pose2)
frame_elbow.set_pose(pose5)

In [120]:
diff = (pose5.translation() - pose2.translation())
y_pose2 = pose2.as_matrix()[:3,1]

In [144]:
p_elbow = (diff @ y_pose2) * y_pose2 + pose2.translation()
ee_pose = SE3(fks[-1])
alpha_len = jnp.linalg.norm(p_elbow - fks[-6][-3:])
beta_len = jnp.linalg.norm(ee_pose.translation() - fks[-6][-3:])

In [132]:
frame.set_pose(ee_pose)

In [187]:
# given p_elbow, ee_pose, can we solve ik?
v = ee_pose.translation() - p_elbow
c = jnp.linalg.norm(v)
r = jnp.sqrt((beta_len**2 + c**2 - alpha_len**2)/(2*c**2))

In [190]:
# z_ee = ee_pose.as_matrix()[:3, 2]
# up_vec = jnp.cross(v, jnp.cross(v, z_ee))
# up_vec = up_vec/jnp.linalg.norm(up_vec)
# up_mag = jnp.sqrt(alpha_len**2 - (c*r)**2)
p_wrist = p_elbow + r*v #+up_vec*up_mag

In [191]:
wrist.set_translate(p_wrist)

In [158]:
v

Array([ 6.1263674e-01,  1.1562403e-08, -1.7700875e-01], dtype=float32)

In [139]:
wrist.set_translate(fks[-6][-3:])

In [129]:
elbow.set_translate(p_elbow)

In [112]:
diff

Array([4.6650025e-01, 5.9267933e-09, 3.9849854e-01], dtype=float32)

In [80]:
panda.set_joint_angles(jnp.array([0,-0.5,0,0,0,0,0]))

In [70]:
elbow.set_translate(far_points[4])