In [1]:
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

from network import *
from loss import *

In [2]:
# Create the SDF world and display its visualizer inside the notebook.
world = SDFWorld()
world.show_in_jupyter()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7003/static/


In [3]:
# Load the Panda robot model and register the visual handles used below:
# the robot itself (alpha=0.5), a pose frame, and a red sphere for the elbow.
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
frame = Frame(world.vis, "frame", 0.1)
elbow = Sphere(world.vis, "elbow", 0.1, "red", alpha=0.5)

concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts
concatenating texture: may result in visual artifacts


In [4]:
# Forward kinematics at the neutral configuration. The last three entries of
# each fks[i] are used as that link's position.
fks = panda_model.fk_fn(panda_model.neutral)
p_shoulder = fks[1][-3:]
p_elbow = fks[4][-3:]
# Distance between the link-1 ("shoulder") and link-4 ("elbow") positions.
upper_arm_len = jnp.linalg.norm(p_elbow-p_shoulder)
# Candidate elbow positions: 1000 sphere points scaled by the upper-arm length
# and shifted to the shoulder (assumes fibonacci_sphere returns unit-sphere
# points -- TODO confirm).
sphere_points = fibonacci_sphere(1000)
shoulder_sphere_points = sphere_points * upper_arm_len + p_shoulder
# Lower/upper corners of the box from which random positions are drawn.
ws_lb, ws_ub = [-1, -1, -0.5], [1, 1, 1.5]

In [5]:
from scipy.spatial.transform import Rotation
def get_random_samples(num_samples):
    """Draw random (pose, elbow) samples that serve as negative examples.

    Each sample combines a uniformly random rotation, a position drawn
    uniformly from the box [ws_lb, ws_ub], and an elbow position picked at
    random from ``shoulder_sphere_points``.

    Args:
        num_samples: number of samples to draw.

    Returns:
        Array of shape (num_samples, 10) laid out as
        [qw, qx, qy, qz, x, y, z, elbow_x, elbow_y, elbow_z].
    """
    xyz = np.random.uniform(ws_lb, ws_ub, (num_samples, 3))
    # SciPy returns quaternions scalar-last (x, y, z, w); reorder them to
    # scalar-first (w, x, y, z).
    qtns_xyzw = Rotation.random(num_samples).as_quat()
    qtns = qtns_xyzw[:, [3, 0, 1, 2]]
    # Index over however many candidate points exist instead of assuming
    # exactly 1000, so the sampler stays valid if the sphere resolution changes.
    num_candidates = shoulder_sphere_points.shape[0]
    indices = np.random.randint(0, num_candidates, size=num_samples)
    p_elbow = shoulder_sphere_points[indices]
    return jnp.hstack([qtns, xyz, p_elbow])

def generate_successful_sample(q):
    """Build one feasible (pose, elbow) sample from arm joint angles.

    Two gripper joint values of 0.04 are appended to ``q`` before running
    forward kinematics. The result stacks the last link's pose vector with the
    last three entries (position) of link 4.
    """
    link_poses = panda_model.fk_fn(jnp.hstack((q, jnp.full(2, 0.04))))
    elbow_position = link_poses[4][-3:]
    return jnp.hstack((link_poses[-1], elbow_position))

def generate_random_joints(num_samples):
    """Sample ``num_samples`` 7-joint configurations uniformly within the
    first seven lower/upper bounds of ``panda_model``."""
    lower = panda_model.lb[:7]
    upper = panda_model.ub[:7]
    joints = np.random.uniform(lower, upper, size=(num_samples, 7))
    return jnp.array(joints)

@jax.jit
def _batch_successful_samples(qs):
    """Jitted, vectorised ``generate_successful_sample`` over a batch of joints."""
    return jax.vmap(generate_successful_sample)(qs)


def get_batch_samples(qs):
    """Return matching batches of feasible and random (negative) samples.

    Only the deterministic forward-kinematics part is jitted. The negatives
    come from NumPy's global RNG, which must run outside ``jax.jit``: inside a
    jitted function NumPy calls execute once at trace time and their result is
    baked into the compiled program, so every later call with the same input
    shape would return the identical "random" batch.

    Args:
        qs: array of shape (batch, 7) with arm joint angles.

    Returns:
        (x_succ, x_fail): feasible samples computed from ``qs`` and the same
        number of samples drawn by ``get_random_samples``.
    """
    x_succ = _batch_successful_samples(qs)  # true data
    # Previously disabled variants perturbed x_succ with small Gaussian noise
    # on the elbow columns [7:] or the position columns [4:7]; if they are
    # revived, draw that noise here (outside jit) as well.
    x_fail = get_random_samples(qs.shape[0])
    return x_succ, x_fail

In [505]:
# Draw 100 random joint configurations and build the matching feasible and
# random sample batches.
qs = generate_random_joints(100)
x_succ, x_fail = get_batch_samples(qs)

In [524]:
# Start browsing at the first sample. The batch built above has 100 rows
# (valid indices 0..99), so starting at 100 would already be out of range;
# JAX does not raise on that, it silently clamps to the last row.
i = 0
x_sample = x_fail

In [529]:
# Visualise sample i: the first 7 values are wrapped as an SE3 pose for the
# frame, the remaining values position the elbow sphere.
frame.set_pose(SE3(x_sample[i][:7]))
elbow.set_translate(x_sample[i][7:])
print(x_sample[i])
# Advance the cursor so re-running this cell steps to the next sample.
i+= 1

[-0.23737708 -0.7212028  -0.59818256 -0.25631255  0.30037656 -0.19551739
 -0.10780752 -0.04035843 -0.18666959  0.5045196 ]


In [15]:
from jax import random
# --- Discriminator: takes a 10-value sample (7 pose values + 3 elbow
# coordinates) and outputs one logit. `dims` is presumably the list of layer
# widths consumed by get_mlp -- TODO confirm.
hp_disc = Hyperparam()
hp_disc.dims = [10, 32, 32, 32, 32, 1]
hp_disc.lr = 0.001
hp_disc.batch_size = 128

model_disc = get_mlp(hp_disc)
key1, key2 = random.split(random.PRNGKey(0))
# Dummy input with the network's input width, used only to initialise parameters.
x = random.normal(key1, (hp_disc.dims[0],))
params_disc = model_disc.init(key2, x)
tx_disc = optax.adam(learning_rate=hp_disc.lr)
state_disc = TrainState.create(apply_fn=model_disc.apply, params=params_disc, tx=tx_disc)

# --- Generator: takes 9 values (7 pose values + 2 noise values) and outputs
# 3 values that are used as an elbow position.
hp_gen = Hyperparam()
hp_gen.dims = [9, 32, 32, 32, 32, 3]
hp_gen.lr = 0.001
hp_gen.batch_size = 128

model_gen = get_mlp(hp_gen)
# NOTE(review): same PRNGKey(0) seed as the discriminator above -- confirm
# that sharing the seed is intended.
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (hp_gen.dims[0],))
params_gen = model_gen.init(key2, x)
tx_gen = optax.adam(learning_rate=hp_gen.lr)
state_gen = TrainState.create(apply_fn=model_gen.apply, params=params_gen, tx=tx_gen)

In [8]:
# Fresh batch of 100 joint configurations with its feasible/random samples.
qs = generate_random_joints(100)
x_succ, x_fail = get_batch_samples(qs)

In [9]:
# Manual forward pass of the (untrained) generator: condition on the 7 pose
# values of each feasible sample plus 2 Gaussian noise values.
noise = np.random.normal(size=(100, 2))
inputs_fake = jnp.hstack([x_succ[:,:7], noise])
outputs_fake = state_gen.apply_fn(params_gen, inputs_fake)

In [20]:
# Pair each pose with its generated output and score the result with the
# discriminator.
x_fake = jnp.hstack([x_succ[:, :7], outputs_fake])
logits_fake = state_disc.apply_fn(params_disc, x_fake)

In [6]:
def loss_disc(state_disc, params_disc,
              x_succ, x_fail, noise, state_gen, params_gen):
    """Discriminator loss over feasible, random and generated samples.

    Feasible samples are labelled 1; random samples and generator outputs are
    labelled 0. Each term is a mean sigmoid binary cross-entropy.

    Returns:
        (total, (loss_succ, loss_fail, loss_fake)) where ``total`` is the sum
        of the three terms.
    """
    def mean_bce(samples, make_labels):
        # Score a batch with the discriminator and compare against its labels.
        logits = state_disc.apply_fn(params_disc, samples).flatten()
        return optax.sigmoid_binary_cross_entropy(
            logits, make_labels(logits)).mean()

    # Generated samples: keep the 7 pose values, replace the rest with the
    # generator's output for (pose, noise).
    ee_poses = x_succ[:, :7]
    generated = state_gen.apply_fn(params_gen, jnp.hstack([ee_poses, noise]))
    x_fake = jnp.hstack([ee_poses, generated])

    loss_succ = mean_bce(x_succ, jnp.ones_like)
    loss_fail = mean_bce(x_fail, jnp.zeros_like)
    loss_fake = mean_bce(x_fake, jnp.zeros_like)
    total = loss_succ + loss_fail + loss_fake
    return total, (loss_succ, loss_fail, loss_fake)

def loss_gen(state_gen, params_gen,
             x_succ, noise, state_disc, params_disc):
    """Generator loss: mean sigmoid binary cross-entropy of the
    discriminator's logits on generated samples against an all-ones target,
    i.e. how far the generator is from being scored as feasible."""
    ee_poses = x_succ[:, :7]
    generated = state_gen.apply_fn(params_gen, jnp.hstack([ee_poses, noise]))
    logits_fake = state_disc.apply_fn(
        params_disc, jnp.hstack([ee_poses, generated])).flatten()
    targets_real = jnp.ones_like(logits_fake)
    return optax.sigmoid_binary_cross_entropy(logits_fake, targets_real).mean()
# Value-and-gradient functions, each differentiating with respect to argument
# index 1 (the parameters of the network being trained). The discriminator
# loss also returns its three component losses as auxiliary output.
grad_disc_fn = jax.value_and_grad(loss_disc, argnums=1, has_aux=True)
grad_gen_fn = jax.value_and_grad(loss_gen, argnums=1)

@jax.jit
def training_step(states, x_succ, x_fail, noise):
    """Run one discriminator update followed by one generator update.

    Args:
        states: tuple (state_disc, state_gen) of train states.
        x_succ, x_fail: feasible and random sample batches.
        noise: noise batch fed to the generator.

    Returns:
        ((state_disc, state_gen), (loss_disc, loss_cheat, loss_succ,
        loss_fail, loss_fake)) with both states updated.
    """
    disc_state, gen_state = states

    # Discriminator step, evaluated against the current generator.
    (d_loss, d_parts), d_grads = grad_disc_fn(
        disc_state, disc_state.params, x_succ, x_fail, noise,
        gen_state, gen_state.params)
    disc_state = disc_state.apply_gradients(grads=d_grads)

    # Generator step, evaluated against the just-updated discriminator.
    g_loss, g_grads = grad_gen_fn(
        gen_state, gen_state.params, x_succ, noise,
        disc_state, disc_state.params)
    gen_state = gen_state.apply_gradients(grads=g_grads)

    loss_succ, loss_fail, loss_fake = d_parts
    return (disc_state, gen_state), (d_loss, g_loss, loss_succ, loss_fail, loss_fake)

In [7]:
# Adversarial training loop. The up-to-date train states live in `states`
# (discriminator first, generator second); `state_disc` / `state_gen` keep
# referring to the states they held before the loop.
states = (state_disc, state_gen)
num_batch = 1000
for epoch in range(50000):
    qs = generate_random_joints(num_batch)
    x_succ, x_fail = get_batch_samples(qs)
    noise = np.random.normal(size=(num_batch, 2))
    states, losses = training_step(states, x_succ, x_fail, noise)
    # Unpack into names that do not overwrite the module-level `loss_disc`
    # function.
    disc_loss, cheat_loss, succ_loss, fail_loss, fake_loss = losses
    if epoch % 100 == 0:
        print(f"epoch:{epoch}, loss_disc:{disc_loss}, loss_cheat:{cheat_loss}")

NameError: name 'state_disc' is not defined

In [33]:
# Take the discriminator parameters from `states`: the training loop rebinds
# only `states`, so `state_disc` still refers to the state created before
# training.
trained_param = states[0].params

In [34]:
# Evaluation batch: 128 random 7-joint configurations within the joint limits.
qs = np.random.uniform(panda_model.lb[:7], panda_model.ub[:7], size=(128,7))
x_succ, x_fail = get_batch_samples(qs)

In [35]:
# Indices of feasible samples that the discriminator scores below 0.5,
# i.e. feasible samples it rejects.
bools = nn.sigmoid(model_disc.apply(trained_param, x_succ)) < 0.5
jnp.arange(len(x_succ))[bools.flatten()]

Array([  0,   7,  14,  17,  25,  27,  28,  29,  32,  43,  54,  55,  56,
        71,  77,  81,  85,  87,  90,  93,  94,  96, 100, 103, 106, 111,
       116, 121], dtype=int32)

In [41]:
def feasibility(ee_pose, tp_elbow):
    """Discriminator output (a logit, no sigmoid applied) for one
    end-effector pose paired with one candidate elbow position."""
    net_input = jnp.hstack([ee_pose.parameters(), tp_elbow])
    return model_disc.apply(trained_param, net_input)

# problem
# Random configuration over the model's full joint bounds, and the pose of the
# last link at that configuration.
qrand = np.random.uniform(panda_model.lb, panda_model.ub)
pose_rand = SE3(panda_model.fk_fn(qrand)[-1])

#vis
# Show the target pose frame and the robot at the sampled configuration.
frame.set_pose(pose_rand)
panda.set_joint_angles(qrand)

In [42]:
# Candidate elbow positions on the sphere around the shoulder.
sphere_samples = fibonacci_sphere(1000)
sphere_shoulder = sphere_samples * upper_arm_len + p_shoulder

# Score every candidate elbow for the fixed pose and keep those whose logit
# exceeds 1.
logits = jax.vmap(feasibility, in_axes=(None, 0))(pose_rand, sphere_shoulder)
indices = np.arange(len(logits))[logits.flatten() > 1.]
elbow_points = np.array(sphere_shoulder[indices], dtype=np.float64)
# NOTE(review): this value is neither stored nor the cell's last expression,
# so it is discarded.
nn.sigmoid(logits[indices].flatten())

#vis
# NOTE(review): `colors` is only referenced by the commented-out block below.
colors = np.tile(Colors.read("blue", return_rgb=True), len(elbow_points)).reshape(-1, 3)
pc = PointCloud(world.vis, "pc", elbow_points, color="blue")
# world.vis["pc"].set_object(
#     g.PointsGeometry(elbow_points.T, colors.T),
#     g.PointsMaterial(size=0.05)
# )
pc.load()

In [43]:
# Inspect which candidate elbows passed the logit threshold.
indices

array([], dtype=int64)

In [155]:
# Remove the "pc" point cloud from the visualizer.
world.vis["pc"].delete()

In [72]:
# NOTE(review): `state` and `hp` are not defined in any visible cell (the
# cells above define state_disc/state_gen and hp_disc/hp_gen) -- confirm which
# objects are meant to be saved here.
save("elbow_feas_net_euclid.pth", state, hp, force=True)

In [73]:
# Load the saved network as a callable; it is later applied to a single
# concatenated input vector.
feas_fn = get_mlp_by_path("elbow_feas_net_euclid.pth")

In [61]:
from sdf_world.sdf_world import *
# Start a fresh world; this replaces the `world` created at the top of the notebook.
world = SDFWorld()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7004/static/


In [62]:
# Display the new world's visualizer inside the notebook.
world.show_in_jupyter()

In [63]:
# Reuses the `panda_model` loaded earlier; only the visual robot is recreated
# in the new world.
#panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)
panda = Robot(world.vis, "panda", panda_model)

In [64]:
# Visual handles in the new world: a pose frame plus red spheres for the
# shoulder, elbow and wrist.
frame = Frame(world.vis, "frame", 0.2)
shoulder = Sphere(world.vis, "shoulder", 0.1, "red", 0.5)
elbow = Sphere(world.vis, "elbow", 0.1, "red", 0.5)
wrist = Sphere(world.vis, "wrist", 0.1, "red", 0.5)

In [65]:
#predefine
# Forward kinematics at the neutral configuration; the last entry is wrapped
# as the end-effector pose.
fks = panda_model.fk_fn(panda_model.neutral)
ee_pose = SE3(fks[-1])
# wrist_pos = panda_model.fk_fn(panda_model.neutral)[7][-3:]
# wrist_wrt_ee = ee_pose.inverse().apply(wrist_pos)
# NOTE(review): the upper-arm length here uses link index 3, whereas the
# earlier cell near the top of the notebook uses index 4 -- confirm which is
# intended.
upper_arm_len = jnp.linalg.norm(fks[3][-3:] - fks[1][-3:])
p_shoulder = fks[1][-3:]
# Position bounds; these overwrite the wider box defined earlier.
ws_lb = [-0.8, -0.8, -0.3]
ws_ub = [0.8, 0.8, 1.3]

In [153]:
# Candidate elbow positions around the shoulder, using fibonacci_sphere's
# default point count.
sphere_points = fibonacci_sphere() * upper_arm_len + p_shoulder
# elbow_params = np.random.uniform([0., -np.pi], [np.pi, np.pi], size=(100,2))
elbow_points = jnp.array(sphere_points)
def feasibility_logit_fn(ee_posevec, elbow_param):
    """Apply ``feas_fn`` to the pose vector concatenated with one elbow
    candidate."""
    net_input = jnp.hstack([ee_posevec, elbow_param])
    return feas_fn(net_input)
@jax.jit
def ik(ee_pose):
    """Score every point in ``elbow_points`` for the given end-effector pose.

    The pose is encoded as the first two columns of its rotation matrix
    (flattened column by column, 6 values) followed by its translation.
    Returns the stacked ``feasibility_logit_fn`` outputs, one per elbow point.
    """
    rot_cols = ee_pose.as_matrix()[:3, :2]
    pose_features = jnp.hstack([rot_cols.T.flatten(), ee_pose.translation()])
    score_elbows = jax.vmap(feasibility_logit_fn, in_axes=(None, 0))
    return score_elbows(pose_features, elbow_points)

In [154]:
# Keep the elbow candidates whose logit for `ee_pose` exceeds 1.
cond = (ik(ee_pose) > 1.).flatten()
indices = jnp.arange(len(cond))[cond]
elbows = elbow_points[indices]

In [291]:
def ee_error(q, ee_pose_des):
    """Scalar end-effector pose error at joint configuration ``q``.

    The error is ``safe_2norm`` of the log map of
    ``current_pose.inverse() @ ee_pose_des``, where the current pose is the
    last forward-kinematics entry.
    """
    current_pose = SE3(panda_model.fk_fn(q)[-1])
    relative_log = (current_pose.inverse() @ ee_pose_des).log()
    return safe_2norm(relative_log)
def elbow_err(q, p_elbow_des):
    """Scalar distance (via ``safe_2norm``) between the link-4 position at
    joint configuration ``q`` and the desired elbow position."""
    p_elbow_now = panda_model.fk_fn(q)[4][-3:]
    return safe_2norm(p_elbow_now - p_elbow_des)
# Gradients of both scalar errors with respect to their first argument, the
# joint vector q.
err_grad_fn = jax.grad(ee_error)
elbowerr_grad_fn = jax.grad(elbow_err)

In [310]:
def get_jacobian(q):
    """Assemble a 6x7 Jacobian of the last link from frames 1..7.

    For each frame the rotation axis is the third column (z axis) of its
    rotation matrix. The linear column is ``axis x (p_ee - p_frame)`` and the
    angular column is the axis itself. Rows 0-2 are linear, rows 3-5 angular.
    """
    fks = panda_model.fk_fn(q)
    p_ee = fks[-1][-3:]
    joint_ids = range(1, 8)
    axes = [SE3(fks[j]).as_matrix()[:3, 2] for j in joint_ids]
    lin_cols = [jnp.cross(axis, p_ee - fks[j][-3:])
                for j, axis in zip(joint_ids, axes)]
    linear_part = jnp.vstack(lin_cols).T
    angular_part = jnp.vstack(axes).T
    return jnp.vstack([linear_part, angular_part])

In [156]:
# Pick the second accepted elbow candidate as the target and show it.
p_elbow = elbows[1]
elbow.set_translate(p_elbow)

In [158]:
# Reset the iterate to the neutral configuration and display it.
q = panda_model.neutral
panda.set_joint_angles(q)

In [401]:
# One descent step: follow the end-effector error gradient, plus the elbow
# error gradient projected into the null space of the end-effector Jacobian.
# Re-run this cell to iterate (it updates `q` in place).
ee_grad = err_grad_fn(q, pose_rand)
elbow_grad = elbowerr_grad_fn(q, p_elbow)

jac = get_jacobian(q)
# Null-space projector of the 6x7 Jacobian: N = I - J^+ J. (I - J^T J is not
# a projector unless the rows of J happen to be orthonormal.)
ns_proj = jnp.eye(7) - jnp.linalg.pinv(jac) @ jac
# The gradients cover all joints of `q`; the projector acts on the first
# seven and the last two entries of the elbow term are zeroed.
q_delta = ee_grad + jnp.hstack([ns_proj@elbow_grad[:-2],0,0])
q = q - q_delta*0.2
panda.set_joint_angles(q)

In [183]:
# NOTE(review): `q_grad` is not assigned in any visible cell -- presumably
# left over from an earlier version of the step above; confirm.
q_grad

Array([ 0.07283606,  0.00364066,  0.0863461 , -0.08738822, -0.08824547,
        0.10718858,  0.20078273,  0.        ,  0.        ], dtype=float32)

In [65]:
# NOTE(review): `points` is not defined in any visible cell -- confirm its source.
far_points = farthest_point_sampling(points, 5)

In [99]:
# Use the first of the sampled far points as the elbow target.
p_elbow = far_points[0]

In [76]:
# Recreate the visual robot and call reduce_dim with indices 7 and 8 and
# values 0.04 (presumably fixes the two gripper joints so the robot takes
# 7-value joint vectors -- TODO confirm). Also add an elbow frame and sphere.
panda = Robot(world.vis, "panda", panda_model)
panda.reduce_dim([7,8], [0.04, 0.04])
frame_elbow = Frame(world.vis, "elbow_frame", 0.2)
elbow = Sphere(world.vis, "elbow", 0.1, "red", 0.5)

In [101]:
# Extra frame handle; later set to the pose of link 2.
frame_1 = Frame(world.vis, "frame1", 0.2)

In [81]:
# Display the first seven joints of the neutral configuration.
panda.set_joint_angles(panda_model.neutral[:7])

In [107]:
# Poses of links 2 and 5 at the neutral configuration, shown with the two frames.
fks = panda_model.fk_fn(panda_model.neutral)
pose2 = SE3(fks[2])
pose5 = SE3(fks[5])
frame_1.set_pose(pose2)
frame_elbow.set_pose(pose5)

In [120]:
# Offset from link 2 to link 5, and the y axis (second rotation column) of link 2.
diff = (pose5.translation() - pose2.translation())
y_pose2 = pose2.as_matrix()[:3,1]

In [144]:
# Elbow estimate: project the link2->link5 offset onto link 2's y axis and
# add link 2's origin.
p_elbow = (diff @ y_pose2) * y_pose2 + pose2.translation()
ee_pose = SE3(fks[-1])
# Distances from the fks[-6] position to the elbow estimate (alpha_len) and
# to the end-effector position (beta_len).
alpha_len = jnp.linalg.norm(p_elbow - fks[-6][-3:])
beta_len = jnp.linalg.norm(ee_pose.translation() - fks[-6][-3:])

In [132]:
# Show the end-effector pose with the frame handle.
frame.set_pose(ee_pose)

In [187]:
# given p_elbow, ee_pose, can we solve ik?
# v: vector from the elbow estimate to the end-effector position; c: its length.
v = ee_pose.translation() - p_elbow
c = jnp.linalg.norm(v)
# r is the fraction of v used to place the wrist estimate (p_wrist = p_elbow + r*v).
# NOTE(review): for a triangle with sides alpha_len (elbow to fks[-6]),
# beta_len (fks[-6] to end effector) and c, the law-of-cosines fraction along
# v would be (alpha_len**2 + c**2 - beta_len**2) / (2*c**2) with no square
# root -- confirm this formula is intended.
r = jnp.sqrt((beta_len**2 + c**2 - alpha_len**2)/(2*c**2))

In [190]:
# Wrist estimate placed on the elbow->end-effector line; the perpendicular
# ("up") offset below is currently disabled.
# z_ee = ee_pose.as_matrix()[:3, 2]
# up_vec = jnp.cross(v, jnp.cross(v, z_ee))
# up_vec = up_vec/jnp.linalg.norm(up_vec)
# up_mag = jnp.sqrt(alpha_len**2 - (c*r)**2)
p_wrist = p_elbow + r*v #+up_vec*up_mag

In [191]:
# Move the wrist sphere to the estimated wrist position.
wrist.set_translate(p_wrist)

In [158]:
# Inspect the elbow->end-effector vector.
v

Array([ 6.1263674e-01,  1.1562403e-08, -1.7700875e-01], dtype=float32)

In [139]:
# Move the wrist sphere to the fks[-6] position for comparison.
wrist.set_translate(fks[-6][-3:])

In [129]:
# Move the elbow sphere to the current elbow estimate.
elbow.set_translate(p_elbow)

In [112]:
# Inspect the link2->link5 offset.
diff

Array([4.6650025e-01, 5.9267933e-09, 3.9849854e-01], dtype=float32)

In [80]:
# Display a test configuration with only the second joint moved (-0.5).
panda.set_joint_angles(jnp.array([0,-0.5,0,0,0,0,0]))

In [70]:
# Move the elbow sphere to the fifth sampled far point.
elbow.set_translate(far_points[4])