In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jaxlie import SE3, SO3
import orbax
import optax
import pandas as pd

from flax.training import orbax_utils
from flax.training.train_state import TrainState
from sdf_world.sdf_world import *
from sdf_world.robots import *
from sdf_world.util import *
from sdf_world.network import *

from scipy.spatial import cKDTree
import pickle

In [2]:
obj_folder_path = "./sdf_world/assets/waffle_box/"
df_succ = pd.read_csv(obj_folder_path+'succ.csv')
df_fail = pd.read_csv(obj_folder_path+'fail.csv')
with open(obj_folder_path+'info.pkl', 'rb') as f:
    obj_data = pickle.load(f)
scale_to_norm = obj_data["scale_to_norm"]

#if non scaled, 
xyz_label = ["x", "y", "z"]
df_succ.loc[:, xyz_label] *= scale_to_norm
df_fail.loc[:, xyz_label] *= scale_to_norm

In [3]:
columns = ["x", "y", "z", "d", "a1", "a2", "a3", "g1", "g2", "g3", "depth", "width"]

In [4]:
succ_points = df_succ[["x", "y", "z"]].to_numpy()
rand_points = np.random.uniform(-1, 1, [50000,3])
noise = np.random.normal(scale=0.1, size=[len(succ_points),3])
df_rand = pd.DataFrame(rand_points, columns=["x", "y", "z"])
df_perturb = pd.DataFrame(succ_points+noise, columns=["x", "y", "z"])

In [5]:
df_succ["succ"] = 1
df_succ["d"] = 0.
df_rand["succ"] = 0
df_fail["succ"] = 0
df_perturb["succ"] = 0

df = pd.concat([df_succ, df_rand, df_fail, df_perturb])
df.reset_index(inplace=True, drop=True)

In [8]:
# succ_point_tree = cKDTree(succ_points)
# df = pd.concat([df_succ, df_rand, df_fail, df_perturb])
# df.reset_index(inplace=True, drop=True)

# calculate_distance
# for i, row in df.iterrows():
#     if row["succ"] == 1: continue
#     point = row[["x", "y", "z"]].to_numpy()
#     d, idx = succ_point_tree.query(point, k=1, p=2)
#     df.loc[i, "d"] = d

In [6]:
def rot6d_to_qtn(rot6d):
    z, y = rot6d[:3], rot6d[3:6]
    x = jnp.cross(y, z)
    R = jnp.vstack([x, y, z]).T
    return SO3.from_matrix(R).parameters()
rot6d_to_qtn_batch = jax.vmap(rot6d_to_qtn)

In [7]:
rot6ds = jnp.array(df[["a1", "a2", "a3", "g1", "g2", "g3"]].to_numpy())
qtns = rot6d_to_qtn_batch(rot6ds)
dw = df[["depth", "width"]].to_numpy()
is_succ = df['succ'].to_numpy()[:,None]
# d = df["d"].to_numpy()[:,None]
inputs = df[["x", "y", "z"]].to_numpy()
labels = np.asarray(jnp.hstack([is_succ, qtns, dw]))

In [8]:
import open3d as o3d
def to_pointcloud(xyz):
    return o3d.geometry.PointCloud(o3d.utility.Vector3dVector(xyz))

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [58]:
xyz = df_fail[["x", "y", "z"]].to_numpy()
frame = o3d.geometry.TriangleMesh.create_coordinate_frame(0.1)
o3d.visualization.draw_geometries([to_pointcloud(xyz), frame])

In [10]:
def rot_loss_qtn(qtn_pred, qtn_true):
    normalization = lambda x: x/safe_2norm(x)
    qtn_loss = lambda x, y: 1 - jnp.abs(x@y)
    flip_z180 = lambda qtn: (SO3(qtn) @ SO3.from_z_radians(jnp.pi)).parameters()

    qtn_pred_norm = jax.vmap(normalization)(qtn_pred)
    qtn_true_flip = jax.vmap(flip_z180)(qtn_true)
    loss1 = jax.vmap(qtn_loss)(qtn_pred_norm, qtn_true)
    loss2 = jax.vmap(qtn_loss)(qtn_pred_norm, qtn_true_flip)
    return jnp.minimum(loss1, loss2)

def grasp_config_loss_fn(pred, label):
    # _, d_pred, qtn_pred = pred[:,0], pred[:,1], pred[:,2:]
    # succ, d_true, qtn_true = label[:,0], label[:,1], label[:,2:]
    _, qtn_pred = pred[:,0], pred[:,1:5]
    succ, qtn_true = label[:,0], label[:,1:5]
    dw_pred, dw_true = pred[:,5:], label[:, 5:]
    
    num_succ_samples = succ.sum()
    no_succ = num_succ_samples == 0.
    den = jnp.where(no_succ, 1., num_succ_samples)
    
    losses_rot = jnp.where(succ, rot_loss_qtn(qtn_pred, jnp.nan_to_num(qtn_true)), 0.).sum()
    loss_rot = jnp.where(no_succ, 0., losses_rot/den)
    
    losses_dw = jnp.where(succ, optax.l2_loss(dw_pred, jnp.nan_to_num(dw_true)).sum(axis=-1), 0.).sum()
    loss_dw = jnp.where(no_succ, 0., losses_dw/den)
    return loss_rot + loss_dw

def grasp_loss_qtn_fn(state:TrainState, params, batch):
    x, y = batch
    pred = state.apply_fn(params, x).squeeze()
    p_pred = pred[:,0]
    p_true = y[:,0]
    loss_p = optax.sigmoid_binary_cross_entropy(p_pred, p_true).mean()
    # loss_d = optax.l2_loss(d_pred, d_true).mean()
    loss_grasp = grasp_config_loss_fn(pred, y)
    return loss_p + loss_grasp

@jax.jit
def train_step(state:TrainState, batch):
    losses, grads = jax.value_and_grad(grasp_loss_qtn_fn, argnums=1)(state, state.params, batch)
    state = state.apply_gradients(grads=grads)
    return state, losses

In [11]:
hidden_dim = 32
grasp_net = GraspNet(hidden_dim, 7)
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key1, (3,)) # Dummy input data
params = grasp_net.init(key2, x) # Initialization call
tx = optax.adam(learning_rate=0.001)
state = TrainState.create(apply_fn=grasp_net.apply, params=params, tx=tx)

In [9]:
from sdf_world.dataset import NumpyDataset, numpy_collate
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

batch_size = 256
grasp_dataset = NumpyDataset(inputs, labels)
train_dataset, val_dataset = train_test_split(grasp_dataset, train_size=0.9, shuffle=True)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, collate_fn=numpy_collate)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, collate_fn=numpy_collate)

In [20]:
for epoch in range(100):
    for i, batch in enumerate(train_loader):
        state, losses = train_step(state, batch)    
    print(f"{epoch} : losses{losses.item()}")

0 : losses0.035095714032649994
1 : losses0.04895775392651558
2 : losses0.04885614663362503
3 : losses0.028917066752910614
4 : losses0.047990307211875916
5 : losses0.013853542506694794
6 : losses0.03493417054414749
7 : losses0.012565002776682377
8 : losses0.012770280241966248
9 : losses0.0390767864882946
10 : losses0.025458360090851784
11 : losses0.0360000841319561
12 : losses0.03757218271493912
13 : losses0.02851737290620804
14 : losses0.011066142469644547
15 : losses0.04934161901473999
16 : losses0.02899705059826374
17 : losses0.02733563631772995
18 : losses0.02459053508937359
19 : losses0.032926153391599655
20 : losses0.01940654218196869
21 : losses0.04499977082014084
22 : losses0.03145228326320648
23 : losses0.01725192554295063
24 : losses0.026518424972891808
25 : losses0.017685260623693466
26 : losses0.013019411824643612
27 : losses0.06600137799978256
28 : losses0.038713276386260986
29 : losses0.02998065948486328
30 : losses0.0476401187479496
31 : losses0.03572981804609299
32 : los

In [27]:
#save
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
ckpt = {
    "params": state.params,
    "hidden_dim": hidden_dim,
    "out_dim":7,
    "scale_to_norm": scale_to_norm,
    "distance?":False,
}
save_args = orbax_utils.save_args_from_target(ckpt)
orbax_checkpointer.save('model/grasp_net_waffle', ckpt, save_args=save_args)

In [21]:
import trimesh
import open3d as o3d
def to_pointcloud(points:np.ndarray):
    points = o3d.utility.Vector3dVector(points)
    return o3d.geometry.PointCloud(points)

In [22]:
grasp_fn = lambda x: grasp_net.apply(state.params, x)

In [23]:
mesh_path = "./sdf_world/assets/waffle_box/waffle_box_scale12.03.obj"
mesh = trimesh.load_mesh(mesh_path)
eval_points = mesh.as_open3d.sample_points_uniformly(10000).points #surface points
# eval_points = np.random.uniform(-1, 1, size=(100000, 3))
preds = grasp_fn(eval_points)
qual, qtns = preds[:,0], preds[:,1:5]
bool_idx = qual > 1 #jnp.sign(qual) == 1

In [24]:
def get_approach_dir(qtn):
    return SO3(qtn).normalize().as_matrix()[:,-1]

In [25]:
zs = jax.vmap(get_approach_dir)(qtns)
zs = np.array(zs).astype(float)[bool_idx]

succ_points = jnp.array(eval_points)[bool_idx]
succ_pc = to_pointcloud(succ_points)
succ_pc.normals = o3d.utility.Vector3dVector(zs)

In [26]:
o3d.visualization.draw_geometries([succ_pc])

In [133]:
# o3d.visualization.draw_geometries([
#     #to_pointcloud(df_succ.loc[:, xyz_label].to_numpy()),
#     to_pointcloud(df_fail.loc[:, xyz_label].to_numpy())
# ])
