In [1]:
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3

from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

from network import *
from loss import *

In [3]:
world = SDFWorld()
world.show_in_jupyter()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7003/static/


In [None]:
panda_model = RobotModel(PANDA_URDF, PANDA_PACKAGE)

In [4]:
panda = Robot(world.vis, "panda", panda_model, alpha=0.5)
frame = Frame(world.vis, "frame", 0.1)
elbow = Sphere(world.vis, "elbow", 0.1, "red", alpha=0.5)

In [5]:
fks = panda_model.fk_fn(panda_model.neutral)
p_shoulder = fks[1][-3:]
p_elbow = fks[4][-3:]
upper_arm_len = jnp.linalg.norm(p_elbow-p_shoulder)
sphere_points = fibonacci_sphere(1000)
shoulder_sphere_points = sphere_points * upper_arm_len + p_shoulder
ws_lb, ws_ub = [-1, -1, -0.5], [1, 1, 1.5]

In [6]:
#generator
def generate_random_joints(num_samples):
    return jnp.array(np.random.uniform(
        panda_model.lb[:7], panda_model.ub[:7], size=(num_samples,7)))
def generate_successful_sample(q):
    gripper_joints = jnp.full((2), 0.04)
    qnew = jnp.hstack([q, gripper_joints])
    fks = panda_model.fk_fn(qnew)
    ee_pose = fks[-1]
    p_elbow = fks[4][-3:]
    #rtp = to_spherical_coord(p_elbow - p_shoulder)
    return jnp.hstack([ee_pose, p_elbow])

In [None]:
from flax import linen as nn
class Encoder(nn.Module):
    hidden_dim: int
    latent_dim: int
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        mean = nn.Dense(features=self.latent_dim)(x)
        stddev = nn.Dense(features=self.latent_dim)(x)
        return mean, stddev
class Decoder(nn.Module):
    hidden_dim: int
    out_dim: int
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.out_dim)(x)
        return x
    
class CVAE(nn.Module):
    cond_dim: int
    hidden_dim: int
    latent_dim: int
    out_dim: int
    def setup(self):
        self.encoder = Encoder(hidden_dim=self.hidden_dim, latent_dim=self.latent_dim)
        self.decoder = Decoder(hidden_dim=self.hidden_dim, out_dim=self.out_dim)
    
    def __call__(self, x): # x: cond
        x = x.reshape(-1, x.shape[-1])
        cond = x[:,:self.cond_dim]
        mean, stddev = self.encoder(x)
        z = mean + stddev * np.random.normal(size=mean.shape)
        x_hat = self.decoder(jnp.hstack([cond, z]))
        return x_hat, mean, stddev

In [146]:
cvae = CVAE(7, 32, 3, 3)
params = cvae.init(key2, jnp.zeros((10,)))
tx = optax.adam(learning_rate=0.001)
state = TrainState.create(
    apply_fn=cvae.apply,
    params=params,
    tx=tx)

In [147]:
def mse_loss(x1, x2):
    return jnp.mean(jnp.square(x1 - x2), axis=-1)
def kl_gaussian(mean, var):
    return 0.5 * jnp.sum(-jnp.log(var) - 1.0 + var + jnp.square(mean), axis=-1)

def loss_fn(params, batch):
    output = cvae.apply(params, batch)
    x = batch[:, -3:]
    x_hat, mean, stddev = output
    log_likelihood = -mse_loss(x, x_hat)
    kl = kl_gaussian(mean, jnp.square(stddev))
    elbo = log_likelihood - kl
    return -jnp.mean(elbo)

@jax.jit
def update(state:TrainState, batch):
    losses, grads = jax.value_and_grad(loss_fn)(params, batch)
    state = state.apply_gradients(grads=grads)
    return state, losses

In [148]:
num_batch = 128
epochs = 1000
for i in range(epochs):
    qs = generate_random_joints(num_batch)
    xs = jax.vmap(generate_successful_sample)(qs)
    state, loss = update(state, xs)
    if i % 100 == 0:
        print(f"epoch{i}  loss: {loss.item()}")

epoch0  loss: 4.974002838134766
epoch100  loss: 5.00502872467041
epoch200  loss: 4.812084197998047
epoch300  loss: 4.6000471115112305
epoch400  loss: 4.47087287902832
epoch500  loss: 4.6234822273254395
epoch600  loss: 4.427453994750977
epoch700  loss: 4.841214179992676
epoch800  loss: 4.605772018432617
epoch900  loss: 4.870190620422363


In [44]:
from jax import random
hp_enc = Hyperparam()
hp_enc.dims = [10, 32, 32, 32, 4]
hp_enc.lr = 0.001
hp_dec = Hyperparam()
hp_dec.dims = [9, 32, 32, 32, 3]
hp_dec.lr = 0.001
hp_dec.batch_size = 128

model_names = ["enc", "dec"]
hps = [hp_enc, hp_dec]
model_vae = {}
for i, name in enumerate(model_names):
    model = {}
    hp = hps[i]
    key1, key2 = random.split(random.PRNGKey(0))
    x = random.normal(key1, (hp.dims[0],))
    model["network"] = get_mlp(hps[i])
    model["params"] = model["network"].init(key2, x)
    model["tx"] = optax.adam(learning_rate=hp.lr)
    model["state"] = TrainState.create(
        apply_fn=model["network"].apply,
        params=model["params"],
        tx=model["tx"])
    model_vae[name] = model

In [45]:
qs = generate_random_joints(100)
xs = jax.vmap(generate_successful_sample)(qs)

In [46]:
conds = xs[:,:7]
enc_params = model_vae["enc"]["params"]
dec_params = model_vae["dec"]["params"]
enc_out = model_vae["enc"]["state"].apply_fn(enc_params, xs)
means = enc_out[:,:2]
stddevs = jnp.exp(enc_out[:,2:])
z = means + stddevs * np.random.normal(size=means.shape)
dec_out = model_vae["dec"]["state"].apply_fn(dec_params, jnp.hstack([conds, z]))

In [59]:
def mse_loss(x1, x2):
    return jnp.mean(jnp.square(x1 - x2), axis=-1)
def kl_gaussian(mean, var):
    return 0.5 * jnp.sum(-jnp.log(var) - 1.0 + var + jnp.square(mean), axis=-1)

def loss():
    log_likelihood = -mse_loss(xs[:,-3:], dec_out)
    kl = kl_gaussian(means, jnp.square(stddevs))
    elbo = log_likelihood - kl
    return -jnp.mean(elbo)

In [57]:
kl

Array([0.00207544, 0.00192553, 0.03571136, 0.0057008 , 0.00998244,
       0.00864025, 0.09721689, 0.00386007, 0.06744651, 0.01449117,
       0.00438253, 0.01051119, 0.02535182, 0.00685101, 0.03107384,
       0.05049649, 0.07380195, 0.00869743, 0.00584108, 0.00276431,
       0.00452977, 0.18243921, 0.0721527 , 0.00126294, 0.00306005,
       0.00853154, 0.0677243 , 0.00424208, 0.07571533, 0.03864774,
       0.00751176, 0.00186073, 0.05415118, 0.00357051, 0.12922677,
       0.00134353, 0.01685642, 0.09726626, 0.00486335, 0.004632  ,
       0.01597006, 0.01502506, 0.03697343, 0.00996438, 0.00367836,
       0.02304477, 0.01161906, 0.01869755, 0.01043234, 0.0188982 ,
       0.05332375, 0.01749801, 0.05166814, 0.00283588, 0.00662553,
       0.01371285, 0.1004499 , 0.00417402, 0.00635153, 0.00140335,
       0.02027686, 0.00314885, 0.00409268, 0.01013746, 0.08053275,
       0.00172006, 0.00717903, 0.01305985, 0.0303546 , 0.00841585,
       0.00737336, 0.02665936, 0.00212327, 0.04401578, 0.00537

In [195]:
q = generate_random_joints(1)[0]
x_sample = generate_successful_sample(q)
ee_pose = x_sample[:7]
frame.set_pose(SE3(ee_pose))
panda.set_joint_angles(jnp.hstack([q, 0.04, 0.04]))

In [196]:
cond = pd.DataFrame(
    np.tile(ee_pose, 100).reshape(-1,7),
    columns=cond_columns)
samples = flow.sample(1, conditions=cond)

In [197]:
indices = jnp.exp(flow.log_prob(samples)) > 0.5
tp_elbows = samples.loc[:,data_columns].to_numpy()[indices]
#tp_elbows = samples.loc[:,data_columns].to_numpy()
rtp_elbows = jnp.hstack([jnp.ones((len(tp_elbows),1))*upper_arm_len, tp_elbows])
p_elbows = jax.vmap(to_cartesian_coord)(rtp_elbows) + p_shoulder
pc = PointCloud(world.vis, "pc", p_elbows, color="blue")

In [194]:
del pc

In [52]:
frame = Frame(world.vis, "ee")

In [13]:
q = generate_random_joints(1)[0]
x_succ = generate_successful_sample(q)

Array([ 0.3105307 ,  0.942123  ,  0.01204117,  0.12581576,  0.49856573,
       -0.12538093,  0.8066797 ,  0.33753774,  0.13996994], dtype=float32)