In [115]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
import orbax
import optax
import pandas as pd
from flax.training import orbax_utils
from flax.training.train_state import TrainState
from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *

from scipy.spatial import cKDTree
import pickle

In [6]:
df_succ = pd.read_csv('./data/003_cracker_box/succ.csv')
df_fail = pd.read_csv('./data/003_cracker_box/fail.csv')
with open("./sdf_world/assets/object"+'/info.pkl', 'rb') as f:
    obj_data = pickle.load(f)
scale_to_norm = obj_data["scale_to_norm"]

In [7]:
columns = ["x", "y", "z", "d", "a1", "a2", "a3", "g1", "g2", "g3", "depth", "width"]

In [41]:
succ_points = df_succ[["x", "y", "z"]].to_numpy()
rand_points = np.random.uniform(-1, 1, [50000,3])
noise = np.random.normal(scale=0.1, size=[len(succ_points),3])
df_rand = pd.DataFrame(rand_points, columns=["x", "y", "z"])
df_perturb = pd.DataFrame(succ_points+noise, columns=["x", "y", "z"])

In [42]:
df_succ["succ"] = 1
df_succ["d"] = 0.
df_rand["succ"] = 0
df_fail["succ"] = 0
df_perturb["succ"] = 0

In [62]:
succ_point_tree = cKDTree(succ_points)
df = pd.concat([df_succ, df_rand, df_fail, df_perturb])
df.reset_index(inplace=True, drop=True)

for i, row in df.iterrows():
    if row["succ"] == 1: continue
    point = row[["x", "y", "z"]].to_numpy()
    d, idx = succ_point_tree.query(point, k=1, p=2)
    df.loc[i, "d"] = d

In [66]:
def rot6d_to_qtn(rot6d):
    z, y = rot6d[:3], rot6d[3:6]
    x = jnp.cross(y, z)
    R = jnp.vstack([x, y, z]).T
    return SO3.from_matrix(R).parameters()
rot6d_to_qtn_batch = jax.vmap(rot6d_to_qtn)

In [73]:
rot6ds = jnp.array(df[["a1", "a2", "a3", "g1", "g2", "g3"]].to_numpy())
qtns = rot6d_to_qtn_batch(rot6ds)
is_succ = df['succ'].to_numpy()[:,None]
d = df["d"].to_numpy()[:,None]
inputs = df[["x", "y", "z"]].to_numpy()
labels = np.asarray(jnp.hstack([is_succ, d, qtns]))

In [74]:
from sdf_world.dataset import NumpyDataset, numpy_collate
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

batch_size = 128
grasp_dataset = NumpyDataset(inputs, labels)
train_dataset, val_dataset = train_test_split(grasp_dataset, train_size=0.9, shuffle=True)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, collate_fn=numpy_collate)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, collate_fn=numpy_collate)

In [81]:
from flax import linen as nn

class GraspNet(nn.Module):
    hidden_dim: int
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        logit = nn.Dense(features=6)(x)
        return logit

In [82]:
def rot_loss_qtn(qtn_pred, qtn_true):
    normalization = lambda x: x/safe_2norm(x)
    qtn_loss = lambda x, y: 1 - jnp.abs(x@y)
    flip_z180 = lambda qtn: (SO3(qtn) @ SO3.from_z_radians(jnp.pi)).parameters()

    qtn_pred_norm = jax.vmap(normalization)(qtn_pred)
    qtn_true_flip = jax.vmap(flip_z180)(qtn_true)
    loss1 = jax.vmap(qtn_loss)(qtn_pred_norm, qtn_true)
    loss2 = jax.vmap(qtn_loss)(qtn_pred_norm, qtn_true_flip)
    return jnp.minimum(loss1, loss2)

def grasp_config_loss_fn(pred, label):
    _, d_pred, qtn_pred = pred[:,0], pred[:,1], pred[:,2:]
    succ, d_true, qtn_true = label[:,0], label[:,1], label[:,2:]

    num_succ_samples = succ.sum()
    no_succ = num_succ_samples==0.
    den = jnp.where(no_succ, 1., num_succ_samples)
    losses_rot = jnp.where(succ, rot_loss_qtn(qtn_pred, jnp.nan_to_num(qtn_true)), 0.).sum()
    loss_rot = jnp.where(no_succ, 0., losses_rot/den)
    return loss_rot

def grasp_loss_qtn_fn(state:TrainState, params, batch):
    x, y = batch
    pred = state.apply_fn(params, x).squeeze()
    p_pred, d_pred = pred[:,0], pred[:,1]
    p_true, d_true = y[:,0], y[:,1]
    loss_p = optax.sigmoid_binary_cross_entropy(p_pred, p_true).mean()
    loss_d = optax.l2_loss(d_pred, d_true).mean()
    loss_grasp = grasp_config_loss_fn(pred, y)
    return loss_p + loss_d + loss_grasp

@jax.jit
def train_step(state:TrainState, batch):
    losses, grads = jax.value_and_grad(grasp_loss_qtn_fn, argnums=1)(state, state.params, batch)
    state = state.apply_gradients(grads=grads)
    return state, losses

In [83]:
grasp_net = GraspNet(32)
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key1, (3,)) # Dummy input data
params = grasp_net.init(key2, x) # Initialization call
tx = optax.adam(learning_rate=0.001)
state = TrainState.create(apply_fn=grasp_net.apply, params=params, tx=tx)

In [84]:
for epoch in range(100):
    for i, batch in enumerate(train_loader):
        state, losses = train_step(state, batch)    
    print(f"{epoch} : losses{losses.item()}")

0 : losses0.16402554512023926
1 : losses0.20534221827983856
2 : losses0.10707911103963852
3 : losses0.07724257558584213
4 : losses0.1458583027124405
5 : losses0.046451959758996964
6 : losses0.04527512565255165
7 : losses0.06805877387523651
8 : losses0.08287391811609268
9 : losses0.13976643979549408
10 : losses0.01128773856908083
11 : losses0.22251591086387634
12 : losses0.07920899987220764
13 : losses0.10807240754365921
14 : losses0.03861228749155998
15 : losses0.09120075404644012
16 : losses0.06354366987943649
17 : losses0.035901132971048355
18 : losses0.02986280992627144
19 : losses0.019774653017520905
20 : losses0.07369060814380646
21 : losses0.030681263655424118
22 : losses0.06776847690343857
23 : losses0.007744227070361376
24 : losses0.026766926050186157
25 : losses0.09204673022031784
26 : losses0.0461374893784523
27 : losses0.0549372136592865
28 : losses0.03419981151819229
29 : losses0.02735881507396698
30 : losses0.06729982793331146
31 : losses0.0632430911064148
32 : losses0.019

In [118]:
#save
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
ckpt = {
    "params": state.params,
    "hidden_dim": 10,
    "scale_to_norm": scale_to_norm
}
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save('model/grasp_net_prob_dist', ckpt, save_args=save_args)

In [85]:
import trimesh
import open3d as o3d
def to_pointcloud(points:np.ndarray):
    points = o3d.utility.Vector3dVector(points)
    return o3d.geometry.PointCloud(points)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [86]:
grasp_fn = lambda x: grasp_net.apply(state.params, x)

In [87]:
mesh_path = "sdf_world/assets/object/norm_mesh.obj"
mesh = trimesh.load_mesh(mesh_path)

In [104]:
# eval_points = mesh.as_open3d.sample_points_uniformly(10000).point #surface points
eval_points = np.random.uniform(-1, 1, size=(100000, 3))

In [107]:
preds = grasp_fn(eval_points)

In [108]:
qual, d, qtns = preds[:,0], preds[:,1], preds[:,2:]

In [112]:
v = d # - jnp.exp(qual) # + d
thres = 0.05
bool_idx = v < thres   # jnp.sign(qual) > 0 #d < 0.03 # jnp.sign(qual) == 1
fail_bool_idx = v > thres #jnp.sign(qual) == -1

In [91]:
def get_yaxis_from_qtn(qtn):
    return SO3(qtn).normalize().as_matrix()[:,1]
def get_zaxis_from_qtn(qtn):
    return SO3(qtn).normalize().as_matrix()[:,2]

In [110]:
succ_points = jnp.array(eval_points)[bool_idx]
succ_pc = to_pointcloud(succ_points)

In [113]:
o3d.visualization.draw_geometries([succ_pc])