In [23]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
import orbax
import optax
import pandas as pd
from flax.training import orbax_utils
from flax.training.train_state import TrainState
from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

In [2]:
df_succ = pd.read_csv('./data/003_cracker_box/succ.csv')
df_fail = pd.read_csv('./data/003_cracker_box/fail.csv')

In [3]:
columns = ["x", "y", "z", "d", "a1", "a2", "a3", "g1", "g2", "g3", "depth", "width"]

In [4]:
rand_points = np.random.uniform(-1, 1, [50000,3])
df_rand = pd.DataFrame(rand_points, columns=["x", "y", "z"])

In [5]:
df_succ["succ"] = 1
df_rand["succ"] = 0
df_fail["succ"] = 0

In [6]:
df = pd.concat([df_succ, df_rand, df_fail])

In [7]:
def rot6d_to_qtn(rot6d):
    z, y = rot6d[:3], rot6d[3:6]
    x = jnp.cross(y, z)
    R = jnp.vstack([x, y, z]).T
    return SO3.from_matrix(R).parameters()
rot6d_to_qtn_batch = jax.vmap(rot6d_to_qtn)

In [18]:
rot6ds = jnp.array(df[["a1", "a2", "a3", "g1", "g2", "g3"]].to_numpy())
qtns = rot6d_to_qtn_batch(rot6ds)
is_succ = df['succ'].to_numpy()[:,None]
inputs = df[["x", "y", "z"]].to_numpy()
labels = np.asarray(jnp.hstack([is_succ, qtns]))

In [20]:
from sdf_world.dataset import NumpyDataset, numpy_collate
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

batch_size = 128
grasp_dataset = NumpyDataset(inputs, labels)
train_dataset, val_dataset = train_test_split(grasp_dataset, train_size=0.9, shuffle=True)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, collate_fn=numpy_collate)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, collate_fn=numpy_collate)

In [27]:
from flax import linen as nn

class GraspNet(nn.Module):
    hidden_dim: int
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        logit = nn.Dense(features=5)(x)
        return logit

In [263]:
def rot_loss_qtn(qtn_pred, qtn_true):
    normalization = lambda x: x/safe_2norm(x)
    qtn_loss = lambda x, y: 1 - jnp.abs(x@y)
    flip_z180 = lambda qtn: (SO3(qtn) @ SO3.from_z_radians(jnp.pi)).parameters()

    qtn_pred_norm = jax.vmap(normalization)(qtn_pred)
    qtn_true_flip = jax.vmap(flip_z180)(qtn_true)
    loss1 = jax.vmap(qtn_loss)(qtn_pred_norm, qtn_true)
    loss2 = jax.vmap(qtn_loss)(qtn_pred_norm, qtn_true_flip)
    return jnp.minimum(loss1, loss2)

def grasp_config_loss_fn(pred, label):
    _, qtn_pred = pred[:,0], pred[:,1:5]
    succ, qtn_true = label[:,0], label[:,1:5]
    num_succ_samples = succ.sum()
    no_succ = num_succ_samples==0.
    den = jnp.where(no_succ, 1., num_succ_samples)
    losses_rot = jnp.where(succ, rot_loss_qtn(qtn_pred, jnp.nan_to_num(qtn_true)), 0.).sum()
    loss_rot = jnp.where(no_succ, 0., losses_rot/den)
    return loss_rot
def grasp_loss_qtn_fn(state:TrainState, params, batch):
    x, y = batch
    pred = state.apply_fn(params, x).squeeze()
    p_pred = pred[:,0]
    p_true = y[:,0]
    loss_p = optax.sigmoid_binary_cross_entropy(p_pred, p_true).mean()
    loss_grasp = grasp_config_loss_fn(pred, y)
    return loss_p + loss_grasp

@jax.jit
def train_step(state:TrainState, batch):
    losses, grads = jax.value_and_grad(grasp_loss_qtn_fn, argnums=1)(state, state.params, batch)
    state = state.apply_gradients(grads=grads)
    return state, losses

In [255]:
grasp_net = GraspNet(10)
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key1, (3,)) # Dummy input data
params = grasp_net.init(key2, x) # Initialization call
tx = optax.adam(learning_rate=0.001)
state = TrainState.create(apply_fn=grasp_net.apply, params=params, tx=tx)

In [285]:
for epoch in range(100):
    for i, batch in enumerate(train_loader):
        state, losses = train_step(state, batch)    
    print(f"{epoch} : losses{losses.item()}")

0 : losses0.03429203853011131
1 : losses0.015390394255518913
2 : losses0.08615678548812866
3 : losses0.016262270510196686
4 : losses0.022181129083037376
5 : losses0.03709304332733154
6 : losses0.026119589805603027
7 : losses0.040343768894672394
8 : losses0.047262948006391525
9 : losses0.03480619564652443
10 : losses0.0408630408346653
11 : losses0.009641844779253006
12 : losses0.016654474660754204
13 : losses0.055467333644628525
14 : losses0.05414344370365143
15 : losses0.04399104416370392
16 : losses0.04738979414105415
17 : losses0.05603944510221481
18 : losses0.04310968145728111
19 : losses0.07680554687976837
20 : losses0.03505518287420273
21 : losses0.024633269757032394
22 : losses0.035928063094615936
23 : losses0.04586592689156532


KeyboardInterrupt: 

In [288]:
#save
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
ckpt = {
    "params": state.params,
    "hidden_dim": 10
}
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save('model/grasp_net', ckpt, save_args=save_args, )

In [266]:
import trimesh
import open3d as o3d
def to_pointcloud(points:np.ndarray):
    points = o3d.utility.Vector3dVector(points)
    return o3d.geometry.PointCloud(points)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [277]:
grasp_fn = lambda x: grasp_net.apply(state.params, x)

In [278]:
mesh_path = "sdf_world/assets/object/norm_mesh.obj"
mesh = trimesh.load_mesh(mesh_path)
surface_points = mesh.as_open3d.sample_points_uniformly(10000)

In [279]:
preds = grasp_fn(surface_points.points)

In [280]:
qual, qtns = preds[:,0], preds[:,1:5]

In [281]:
v = qual #- jnp.exp(qual) # + d
thres = 0.5
bool_idx = v > thres  # jnp.sign(qual) > 0 #d < 0.03 # jnp.sign(qual) == 1
fail_bool_idx = v < thres #jnp.sign(qual) == -1

In [282]:
def get_yaxis_from_qtn(qtn):
    return SO3(qtn).normalize().as_matrix()[:,1]
def get_zaxis_from_qtn(qtn):
    return SO3(qtn).normalize().as_matrix()[:,2]

In [283]:
succ_points = jnp.array(surface_points.points)[bool_idx]
succ_pc = to_pointcloud(succ_points)

In [284]:
o3d.visualization.draw_geometries([succ_pc])