In [1]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import glob
import pickle
import jax
import jax.numpy as jnp
import bayes3d as b
import genjax
import matplotlib.pyplot as plt

In [3]:
# Set up the bayes3d visualizer; the recorded output lists the URL where it can be viewed.
b.setup_visualizer()

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7000/static/


In [5]:
# Load the pickled scene description from the parent directory.
# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only run this on a data.pkl that comes from a trusted source.
with open('../data.pkl', 'rb') as f:
    data = pickle.load(f)

# Unpack the "init" record by position.
camera_image_1 = data["init"][0]  # indexed below with 'camera_matrix', 'rgbPixels', 'depthPixels', 'camera_pose'
categories_on_table = data["init"][1]  # tested with `in` against the YCB model names in a later cell
target_category = data["init"][2]
table_info = data["init"][3] # table pose and dimensions
n_objects = 5  # NOTE(review): not referenced by any of the visible cells

# Convert the table pose (table_info[0]) and the camera pose to transforms,
# then compose X_CT = inverse(X_WC) @ X_WT.
# presumably W = world, C = camera, T = table, so X_CT is the table pose in
# the camera frame — confirm the naming convention.
X_WT = b.t3d.pybullet_pose_to_transform(table_info[0])
X_WC = b.t3d.pybullet_pose_to_transform(camera_image_1["camera_pose"])
X_CT = b.t3d.inverse_pose(X_WC) @ X_WT

def image_to_rgbd(camera_image_1):
    """Pack one camera-image record into a `b.RGBD` observation.

    Reads the 'camera_matrix', 'rgbPixels', 'depthPixels' and 'camera_pose'
    entries of `camera_image_1`, converts the pose with
    `b.t3d.pybullet_pose_to_transform`, and builds a `b.Intrinsics` from the
    depth image shape plus the (0,0), (1,1), (0,2) and (1,2) entries of the
    camera matrix.

    Returns the resulting `b.RGBD` instance.
    """
    intrinsics_matrix = camera_image_1['camera_matrix']
    color = camera_image_1['rgbPixels']
    depth_map = camera_image_1['depthPixels']
    pose = b.t3d.pybullet_pose_to_transform(camera_image_1['camera_pose'])

    focal_x, focal_y = intrinsics_matrix[0, 0], intrinsics_matrix[1, 1]
    center_x, center_y = intrinsics_matrix[0, 2], intrinsics_matrix[1, 2]
    height, width = depth_map.shape

    # The trailing 0.001 and 10000.0 are the last two Intrinsics arguments;
    # presumably near/far clipping distances — TODO confirm against b.Intrinsics.
    intrinsics = b.Intrinsics(
        height, width, focal_x, focal_y, center_x, center_y, 0.001, 10000.0
    )
    return b.RGBD(color, depth_map, pose, intrinsics)

# Convert the first camera image into an RGBD observation.
rgbd_original = image_to_rgbd(camera_image_1)
print("Got rgbd_original")

# Downscale the observation by scaling_factor; the renderer set up in a later
# cell uses the intrinsics of this scaled-down version.
scaling_factor = 0.2
rgbd_scaled_down = b.RGBD.scale_rgbd(rgbd_original, scaling_factor)
# NOTE(review): "immage" below is a typo in the printed message; it is left
# untouched here because it is runtime output.
print("Got rgb immage scaled down.")

# Reset the visualizer and show the unprojected depth, flattened to rows of
# 3 coordinates, as a point cloud registered under the name "1".
b.clear_visualizer()
b.show_cloud("1", b.unproject_depth(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics).reshape(-1,3))

Got rgbd_original
Got rgb immage scaled down.


In [6]:
# Resolve the YCB-V model directory relative to the parent of the working
# directory and display it.
# NOTE(review): a later cell reassigns model_dir from '../..' instead of '..',
# and the mesh paths printed there sit under a different root than the one
# shown here — confirm which location is intended.
model_dir = os.path.join(os.path.abspath('..'), 'bayes3d/assets/bop/ycbv/models')
model_dir

'/home/georgematheos/bayes3d_genjax/bayes3d/assets/bop/ycbv/models'

In [7]:
# Configure the renderer with the intrinsics of the downscaled observation.
b.setup_renderer(rgbd_scaled_down.intrinsics)

# Locate the YCB-V .ply meshes two directory levels above the working directory.
model_dir = os.path.join(os.path.abspath('../..'), 'bayes3d/assets/bop/ycbv/models')
mesh_path = os.path.join(model_dir, f"obj_{13 + 1:06d}.ply")
ycb_filenames = glob.glob(os.path.join(model_dir, "*.ply"))

# Order the mesh files by the integer embedded in "obj_<index>.ply".
ycb_index_order = [
    int(path.rsplit("/", 1)[-1].rsplit("_", 1)[-1].partition(".")[0])
    for path in ycb_filenames
]
sorted_ycb_filenames = [path for _, path in sorted(zip(ycb_index_order, ycb_filenames))]

# Flag every model name that contains one of the categories on the table, then
# keep the flagged names and the mesh files found at the same list positions.
# NOTE(review): this relies on sorted_ycb_filenames lining up index-for-index
# with MODEL_NAMES — confirm for the installed asset set.
relevant_objects = [
    any(category in name for category in categories_on_table)
    for name in b.utils.ycb_loader.MODEL_NAMES
]
relevant_object_names = [
    name
    for name, is_relevant in zip(b.utils.ycb_loader.MODEL_NAMES, relevant_objects)
    if is_relevant
]
filtered_filenames = [
    path
    for position, path in enumerate(sorted_ycb_filenames)
    if relevant_objects[position]
]

# Register meshes with the renderer in a fixed order: table first, then the
# pillar, then the selected YCB meshes.
table_dims = table_info[1:]
table_mesh = b.utils.make_cuboid_mesh(table_dims)
b.RENDERER.add_mesh(table_mesh, "table")
print("Added table mesh.")

pillar_mesh = b.utils.make_cuboid_mesh(jnp.array([0.02, 0.02, 0.5]))
b.RENDERER.add_mesh(pillar_mesh, "pillar")
print("Added pillar mesh.")

# The mesh files are loaded with a 1/1000 scale factor
# (presumably millimetres to metres — TODO confirm).
for model_path in filtered_filenames:
    b.RENDERER.add_mesh_from_file(model_path, scaling_factor=1.0/1000.0)
    print(f"Added mesh at path {model_path}.")

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (64, 64, 1024)
Added table mesh.
Added pillar mesh.
Added mesh at path /home/georgematheos/bayes3d/assets/bop/ycbv/models/obj_000002.ply.
Added mesh at path /home/georgematheos/bayes3d/assets/bop/ycbv/models/obj_000003.ply.
Added mesh at path /home/georgematheos/bayes3d/assets/bop/ycbv/models/obj_000011.ply.
Added mesh at path /home/georgematheos/bayes3d/assets/bop/ycbv/models/obj_000013.ply.
Added mesh at path /home/georgematheos/bayes3d/assets/bop/ycbv/models/obj_000021.ply.


In [23]:
# Display the YCB model names whose meshes were selected for this scene.
relevant_object_names

['003_cracker_box',
 '004_sugar_box',
 '019_pitcher_base',
 '024_bowl',
 '061_foam_brick']

In [10]:
from src.model import model, viz_trace_meshcat
# JIT-compile the model's importance method so it can be called repeatedly.
importance_jit = jax.jit(model.importance)
# Fixed PRNG key (seed 0).
# NOTE(review): the visible cells pass this same key, unsplit, to every later call.
key = jax.random.PRNGKey(0)

In [11]:
# Observation: the downscaled depth image unprojected with its own intrinsics.
obs_img = b.unproject_depth_jit(rgbd_scaled_down.depth, rgbd_scaled_down.intrinsics)

# Choices pinned for the initial trace: an identity camera pose, the table
# transform X_CT as the root pose of object 0, the observed image, and one
# child object (index 1) attached to parent 0 with zero contact parameters.
initial_constraints = genjax.choice_map({
    "parent_0": -1,
    "parent_1": 0,
    "id_0": jnp.int32(0),
    "camera_pose": jnp.eye(4),
    "root_pose_0": X_CT,
    "face_parent_1": 2,
    "face_child_1": 3,
    "image": obs_img,
    "variance": 0.02,
    "outlier_prob": 0.0001,
    "contact_params_1": jnp.array([0.0, 0.0, 0.0])
})

# Positional model arguments; their meaning is defined in src.model.
# presumably: object slots, candidate mesh ids, root-position bounds,
# contact-parameter bounds, and per-mesh box dimensions — TODO confirm.
model_args = (
    jnp.arange(1),
    jnp.arange(5),
    jnp.array([-jnp.ones(3)*100.0, jnp.ones(3)*100.0]),
    jnp.array([jnp.array([-12.0, -12.0, -jnp.pi]), jnp.array([12.0, 12.0, jnp.pi])]),
    b.RENDERER.model_box_dims
)

trace, weight = importance_jit(key, initial_constraints, model_args)

viz_trace_meshcat(trace)

In [12]:
# Coarse-to-fine search stages. Each entry is
#   (half-width of the window for the first two coordinates,
#    half-width of the window for the third coordinate,
#    number of grid points along each of the three axes).
# presumably the three coordinates are (x, y, rotation angle) of a contact
# parameter — TODO confirm against the model.
grid_params = [
    (0.65, jnp.pi, (30,30,15)),
    (0.2, jnp.pi, (15,15,15)),
    (0.1, jnp.pi, (15,15,15)),
    (0.05, jnp.pi/3, (15,15,15)),
    # Disabled stages kept for reference:
    # (0.02, jnp.pi, (9,9,51)),
    # (0.01, jnp.pi/5, (15,15,15)),
    # (0.05, 0.0, (31,31,1))
]

# One symmetric offset grid per stage, spanning [-half_width, half_width] on
# the first two axes and [-half_angle, half_angle] on the third.
contact_param_gridding_schedule = [
    b.utils.make_translation_grid_enumeration_3d(
        -half_width, -half_width, -half_angle,
        half_width, half_width, half_angle,
        *counts
    )
    for half_width, half_angle, counts in grid_params
]

def _c2f(key, tr, object_idx, contact_param_gridding_schedule):
    """Coarse-to-fine search over one object's contact parameters.

    Starting from the value stored in `tr` at "contact_params_<object_idx>",
    each grid in `contact_param_gridding_schedule` is added to the current
    best value, every candidate is scored by updating the trace with it and
    reading the updated trace's score, and the highest-scoring candidate
    becomes the new best value.

    Returns the trace obtained by updating `tr` with the final best value.
    """
    address = f"contact_params_{object_idx}"
    # Debug output; under jax.jit this runs while the function is traced,
    # not on every call.
    print(address)

    def score_candidate(trace, candidate):
        # Score of the trace after overwriting the contact parameters.
        updated = trace.update(key, genjax.choice_map({address: candidate}))
        return updated[0].get_score()

    # Vectorize over candidates only; the trace is shared across the batch.
    score_all = jax.vmap(score_candidate, in_axes=(None, 0))

    best = tr[address]
    for offsets in contact_param_gridding_schedule:
        candidates = best + offsets
        best = candidates[score_all(tr, candidates).argmax()]

    return tr.update(key, genjax.choice_map({address: best}))[0]


# object_idx (argument 2) is static because it is formatted into the address string.
c2f = jax.jit(_c2f, static_argnums=[2])

In [15]:
from src.inference import add_object_jit

In [19]:
# id = 2: cracker box
# Add the object to the initial trace, then refine "contact_params_1" with the
# coarse-to-fine search and visualize the result.
# presumably the trailing add_object_jit arguments (0, 2, 3) are parent index,
# parent face and child face — confirm in src.inference.
tr2 = c2f(key, add_object_jit(trace, key, 2, 0, 2, 3), 1, contact_param_gridding_schedule)
viz_trace_meshcat(tr2)

contact_params_1


In [25]:
# id = 4: pitcher base
# Extend tr2 with the object, refine "contact_params_2" with the
# coarse-to-fine search, and visualize the result.
tr3 = c2f(key, add_object_jit(tr2, key, 4, 0, 2, 3), 2, contact_param_gridding_schedule)
viz_trace_meshcat(tr3)

In [26]:
# id = 3: sugar box
# Extend tr3 with the object, refine "contact_params_3" with the
# coarse-to-fine search, and visualize the result.
tr4 = c2f(key, add_object_jit(tr3, key, 3, 0, 2, 3), 3, contact_param_gridding_schedule)
viz_trace_meshcat(tr4)

contact_params_3


In [28]:
# id = 5: bowl
# Extend tr4 with the object, refine "contact_params_4" with the
# coarse-to-fine search, and visualize the result.
tr5 = c2f(key, add_object_jit(tr4, key, 5, 0, 2, 3), 4, contact_param_gridding_schedule)
viz_trace_meshcat(tr5)

In [29]:
# id = 1: pillar
# Two pillars are added one after the other; their contact parameters live at
# "contact_params_5" and "contact_params_6" and are each refined with the
# coarse-to-fine search. Only the final trace is visualized.
tr6 = c2f(key, add_object_jit(tr5, key, 1, 0, 2, 3), 5, contact_param_gridding_schedule)
tr7 = c2f(key, add_object_jit(tr6, key, 1, 0, 2, 3), 6, contact_param_gridding_schedule)
viz_trace_meshcat(tr7)

contact_params_5
contact_params_6
