In [None]:
import matplotlib.pyplot as plt
import torchvision
from pytorch3d.ops.marching_cubes import marching_cubes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures import Meshes

from cob3d_dataset import COB3D
from torch_utils import *
from fcon_model import FCON
import vis


### Initialize the dataset

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Change this to point to where you have the data downloaded!
# (This note used to live inside the string literal, which made the default
# path invalid; keep it as a comment so `root` is a real filesystem path.)
root = '/tmp/cob3d/v2'

# Load the COB-3D dataset rooted at `root`.
# NOTE(review): `target_scale` is forwarded to COB3D.load unchanged; presumably
# the image size the samples are rescaled to -- confirm against cob3d_dataset.
dset = COB3D.load(root, target_scale=800)

### Load a scene and visualize inputs

In [None]:
# Look up one scene by its ID and pass it through the `to_torch` helper with
# `recursive=True` so nested entries are handled too.
# NOTE(review): presumably this converts arrays to tensors on `device` -- confirm
# against torch_utils.
scene_id = '26ac1594-3d5e-a3ee-3278-d3a1e4b11d19'
batch = to_torch(dset[scene_id], recursive=True, device=device)

rgb_image = batch['rgb']

# One row, three panels: the RGB image, the same image with every instance
# mask outlined in white, and the depth map.
plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
vis.plot_rgb(rgb_image)
plt.title("RGB:")

plt.subplot(1, 3, 2)
vis.plot_rgb(rgb_image)
for instance_mask in batch['masks']:
    vis.plot_mask(instance_mask, edgecolor='w')
plt.title("Instance Masks:")

plt.subplot(1, 3, 3)
plt.imshow(to_np(batch['depth_map']))
plt.title("Depth Map:")

print(f'Scene ID: {scene_id}')

### Visualize GT in 3D

In [None]:
# Build a point map from the depth map and camera intrinsics, then move its
# last axis to the front.
# NOTE(review): presumably depth2cloud back-projects each pixel to a 3D point
# in the camera frame, giving an (H, W, 3) result -- confirm against torch_utils.
point_map = depth2cloud(batch['depth_map'], batch['intrinsic']).permute(2, 0, 1)

# Ground-truth annotations: occupancy grids plus per-object pose and scale.
voxel_grid = batch['voxel_grid']
obj_poses = batch['obj_poses']
voxels = voxel_grid['voxels']
extents = voxel_grid['extents']
cam_from_obj = obj_poses['poses']
scales = obj_poses['scales']

# Regular lattice spanning [-1, 1] along every voxel axis; the leading
# dimension of `voxels` is skipped (it is iterated per object below).
voxel_shape = voxels.shape[1:]
grid_axes = [torch.linspace(-1, 1, s, device=device) for s in voxel_shape]
pts_normed = torch.stack(torch.meshgrid(*grid_axes, indexing='ij'), dim=-1)

# Stretch the normalized lattice by extents / scales per object, then map it
# through `cam_from_obj`.
# NOTE(review): names suggest object frame -> camera frame; confirm the
# convention used by transform_points.
gt_pts_obj = (extents / scales)[:, None, None, None] * pts_normed
gt_pts_cam = transform_points(cam_from_obj[:, None, None, None], gt_pts_obj)

# Scene = observed point map colored by RGB + one colored point set per object.
builder = vis.SceneBuilder.from_point_map(point_map, batch['rgb'])
cmap = plt.get_cmap("hsv")
n_objects = len(gt_pts_cam)
for obj_idx, (obj_pts, obj_voxels) in enumerate(zip(gt_pts_cam, voxels)):
    # Spread object colors over the hsv colormap; drop the alpha channel.
    color = cmap(obj_idx / n_objects)[0:3]
    # Keep only the lattice points selected by the voxel grid, and only
    # every 10th of those to make the rendering faster.
    selected_pts = obj_pts[obj_voxels]
    builder.add_points(selected_pts[::10], color=color)
builder.show()

### Load the pretrained F-CON model

In [None]:
# Load the pretrained weights onto the CPU first so the checkpoint can be read
# on machines without a GPU; the model itself is moved to `device` below.
# SECURITY: torch.load unpickles the file, so only load checkpoints from a
# trusted source (recent PyTorch versions also accept `weights_only=True`).
ckpt = torch.load('checkpoints/fcon.pt', map_location='cpu')

model = FCON(n_depth_bins=96, patch_size=64).to(device)

# This notebook only runs inference: switch to eval mode so layers such as
# dropout / batch-norm (if the model has any) behave deterministically.
model.eval()

# Kept as the last expression so the notebook still displays the key-matching
# report returned by load_state_dict.
model.load_state_dict(ckpt)

### Do inference with the model and visualize predictions

In [None]:
# Run the model on the scene. The arguments are passed positionally, in the
# order `model.predict` receives them.
predict_inputs = (
    batch['rgb'],
    batch['intrinsic'],
    point_map,
    batch['boxes'],
    batch['masks'],
    batch['near_plane'],
    batch['far_plane'],
)
out = model.predict(*predict_inputs)

# Turn the logits into probabilities; `grid_centers` is read alongside them.
# NOTE(review): presumably the camera-frame position of every predicted voxel
# -- confirm against fcon_model.
probs = out['logits'].float().sigmoid()
voxel_centers_cam = out['grid_centers']

# Extract the 0.5 iso-surface of every probability volume. With
# return_local_coords=False the vertices are not normalized to local coords.
verts_idx_lst, faces_lst = marching_cubes(probs, 0.5, return_local_coords=False)

# Look up each vertex in the corresponding grid of voxel centers.
# NOTE(review): the flip(-1) reverses the coordinate order of each vertex,
# presumably to match the axis order interp3d expects -- confirm.
verts_cam_lst = []
for centers, verts_idx in zip(voxel_centers_cam, verts_idx_lst):
    centers_first = centers.permute(3, 0, 1, 2)  # last axis moved to the front
    verts_cam_lst.append(interp3d(centers_first, verts_idx.flip(-1)))

# Sample 4096 surface points from every predicted mesh.
pred_meshes = Meshes(verts_cam_lst, faces_lst)
pred_pts_cam = sample_points_from_meshes(pred_meshes, 4096)

# Same scene layout as the ground-truth view: observed point map plus one
# colored point set per object. `cmap` is the hsv colormap created in the
# ground-truth visualization cell.
builder = vis.SceneBuilder.from_point_map(point_map, batch['rgb'])
n_pred = len(pred_pts_cam)
for obj_idx, obj_pts in enumerate(pred_pts_cam):
    color = cmap(obj_idx / n_pred)[0:3]
    builder.add_points(obj_pts, color=color)

builder.show()