In [6]:
import torch
import torch.nn.functional as F
from models import E_3d, PE, BlockMLP
from kornia.utils.grid import create_meshgrid3d
from einops import rearrange, reduce, repeat
import plotly.graph_objects as go
import matplotlib.pyplot as plt
import time
import trimesh
import numpy as np
import mcubes
from metrics import iou


device = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- experiment configuration ---
mesh_name = 'pyramid'  # selects the occupancy file, the mesh file and the checkpoint folder below
input_size = [512 for _ in range(3)]  # occupancy grid resolution, indexed as [w, h, d] by the reconstruction cell
patch_size = [16 for _ in range(3)]   # samples per block along each axis; a block at scale j spans patch_size * 2**j voxels
n_scales = 4
n_layers = 2
n_hidden = 8
exp_name = f'{mesh_name}{input_size[0]}_{n_scales}scale'

# projection matrix for the positional encoding: E_3d scaled by 2**i for i in range(n_freq),
# concatenated along dim 1
n_freq = 5
P = torch.cat([E_3d*2**i for i in range(n_freq)], 1)
pe = PE(P).to(device)

# local sample grid of a single block (depth, height, width order), consumed as
# '1 pd ph pw c' by the reconstruction cell
xyz = create_meshgrid3d(patch_size[2], patch_size[1], patch_size[0], device=device)

# ground-truth occupancy is stored bit-packed; unpack and restore the grid shape
occ_gt = np.unpackbits(np.load(f'occupancies/{mesh_name}_{input_size[0]}.npy')).reshape(*input_size).astype(bool)
mesh_gt = trimesh.load(f'meshes/{mesh_name}.obj', force='mesh', skip_materials=True)
# same preprocessing as in preprocess_mesh.py
bbox = np.amax(mesh_gt.vertices, 0)-np.amin(mesh_gt.vertices, 0)
mesh_whl = bbox/2  # half extent of the mesh along each axis
# center the mesh: max - half extent is the bounding-box center
mesh_gt.vertices -= np.amax(mesh_gt.vertices, 0)-mesh_whl
mesh_whl *= 1.02  # 2% padding of the half extents, matching preprocess_mesh.py

specified material (default)  not loaded!


In [7]:
# Reconstruct the full-resolution occupancy grid coarse-to-fine. The coarsest scale
# (j == n_scales-1, sigmoid head) predicts occupancy directly; every finer scale
# (sin head, multiplied by the checkpoint's 'scales') is added to the trilinearly
# upsampled result of the previous scale and clamped to [0, 1].
t = time.perf_counter()
I_j_u = None  # upsampled reconstruction of the previous (coarser) scale
for j in reversed(range(n_scales)):
    final_act = 'sigmoid' if j==n_scales-1 else 'sin'
    # number of blocks along each axis at scale j
    nd = input_size[2]//(patch_size[2]*2**j)
    nh = input_size[1]//(patch_size[1]*2**j)
    nw = input_size[0]//(patch_size[0]*2**j)
    ckpt = torch.load(f'ckpts/{exp_name}/l{j}.ckpt', map_location=torch.device('cpu'))
    # presumably a boolean mask over the nd*nh*nw blocks that own an MLP at this scale
    active_blocks = ckpt['active_blocks']
    n_blocks = int(active_blocks.sum())
    blockmlp = BlockMLP(n_blocks=n_blocks,
                        n_in=pe.out_dim, n_out=1,
                        n_layers=n_layers,
                        n_hidden=n_hidden,
                        final_act=final_act).to(device=device)
    # strict=False, presumably because the checkpoint also stores non-parameter
    # entries ('active_blocks' and, at finer scales, 'scales')
    blockmlp.load_state_dict(ckpt, strict=False)

    # every active block is queried on the same local sample grid
    xyz_ = repeat(xyz, '1 pd ph pw c -> n (pd ph pw) c', n=n_blocks)
    with torch.no_grad():
        occ_pred_ = blockmlp(xyz_, b_chunks=512, pe=pe, to_cpu=True)
        if j <= n_scales-2:
            occ_pred_ *= ckpt['scales']
    if device == 'cuda':
        # include all queued GPU work in the wall-clock timing (would raise without CUDA)
        torch.cuda.synchronize()

    # scatter block predictions into the dense block grid (inactive blocks stay 0)
    # and lay the blocks out as a (D, H, W, 1) volume
    occ_pred = torch.zeros(nd*nh*nw, np.prod(patch_size), 1)
    occ_pred[active_blocks] = occ_pred_
    occ_pred = rearrange(occ_pred,
                         '(nd nh nw) (pd ph pw) c -> (nd pd) (nh ph) (nw pw) c',
                         nd=nd, nh=nh, nw=nw, pd=patch_size[2], ph=patch_size[1], pw=patch_size[0])
    if j <= n_scales-2:
        occ_pred += I_j_u
        occ_pred = occ_pred.clamp(0, 1)

    if j > 0:
        # upsample by 2 to the resolution of the next finer scale
        I_j_u = F.interpolate(rearrange(occ_pred, 'd h w c -> 1 c d h w'),
                              mode='trilinear',
                              scale_factor=2,
                              align_corners=True)
        I_j_u = rearrange(I_j_u, '1 c d h w -> d h w c')
if device == 'cuda':
    torch.cuda.empty_cache()

occ_pred = occ_pred.numpy()[..., 0]
print(f'total time {time.perf_counter()-t:.3f} s')
print(f'IoU {iou(occ_pred, occ_gt):.6f}')
# Free the per-scale intermediates only. Setup-cell inputs (occ_gt, P, pe, xyz) are
# kept so this cell can be re-run without re-running the setup cell.
del I_j_u, ckpt, active_blocks, blockmlp, xyz_, occ_pred_

total time 4.362 s
IoU 0.999046


In [8]:
# create prediction mesh using marching cubes on the 0.5 iso-surface of the predicted occupancy
vertices, triangles = mcubes.marching_cubes(occ_pred, 0.5)
# voxel indices -> roughly [-1, 1), normalizing each axis by its own resolution
# (identical to a single scalar for cubic grids, correct for non-cubic ones)
half_res = np.array(occ_pred.shape)/2
vertices -= half_res
vertices /= half_res
vertices = vertices[:, [1, 0, 2]] # switch axes
vertices *= mesh_whl  # back to mesh units using the padded half extents

print(f'mesh contains {len(vertices)} vertices and {len(triangles)} triangles')

# TODO: report Chamfer L1 against mesh_gt (scaled by 1e4/mesh_whl.max()), computed on
# points sampled from both surfaces rather than on raw mesh vertices.

mesh contains 3293724 vertices and 6644160 triangles


# Visualize block decomposition

In [9]:
from collections import defaultdict

# Box geometry for the block-decomposition plot. For each scale, one axis-aligned box
# per active block: m[f'x{scale}'], m[f'y{scale}'], m[f'z{scale}'] collect the 8 corner
# coordinates (voxel units) of each box, and m[f'i{scale}'], m[f'j{scale}'],
# m[f'k{scale}'] collect its 12 triangles as indices into that scale's concatenated
# corner arrays (go.Mesh3d convention).
m = defaultdict(list)

# corners of the unit cube ...
unit_x = np.array([0, 0, 1, 1, 0, 0, 1, 1])
unit_y = np.array([0, 1, 1, 0, 0, 1, 1, 0])
unit_z = np.array([0, 0, 0, 0, 1, 1, 1, 1])
# ... and its 12 triangles, two per face
tri_i = np.array([7, 0, 0, 0, 4, 4, 6, 6, 4, 0, 3, 2])
tri_j = np.array([3, 4, 1, 2, 5, 6, 5, 2, 0, 1, 6, 3])
tri_k = np.array([0, 7, 2, 3, 6, 7, 1, 1, 5, 5, 7, 6])

for scale in reversed(range(n_scales)):
    # edge length of one block at this scale, in voxels
    rw, rh, rd = (p*2**scale for p in patch_size)
    ckpt = torch.load(f'ckpts/{exp_name}/l{scale}.ckpt', map_location=torch.device('cpu'))
    grid_shape = (input_size[2]//rd, input_size[1]//rh, input_size[0]//rw)
    # True marks blocks that are NOT drawn; the checkpoint's active blocks are cleared
    training_blocks = np.ones(grid_shape, bool)
    active_blocks = ckpt['active_blocks'].numpy().reshape(*grid_shape)
    training_blocks[active_blocks] = 0 # converged

    # np.argwhere yields indices in C order (depth, then height, then width)
    for box_idx, (k, j, i) in enumerate(np.argwhere(~training_blocks)):
        k, j, i = int(k), int(j), int(i)
        vertex_offset = box_idx*8  # 8 corners were already added per earlier box
        m[f'x{scale}'].append((unit_x+i)*rw)
        m[f'y{scale}'].append((unit_y+j)*rh)
        m[f'z{scale}'].append((unit_z+k)*rd)
        m[f'i{scale}'].append(tri_i+vertex_offset)
        m[f'j{scale}'].append(tri_j+vertex_offset)
        m[f'k{scale}'].append(tri_k+vertex_offset)

In [None]:
COLORS = ['red', 'green', 'blue', 'cyan', 'magenta'] # colors for each scale, cycled if n_scales > len(COLORS)
# meshes at or above this vertex count are not added to the figure, presumably because
# they are too heavy to render interactively
MAX_PLOT_VERTICES = 3e6

fig = go.Figure()

if len(vertices) < MAX_PLOT_VERTICES:
    fig.add_trace(
        go.Mesh3d(
            x=vertices[:, 2],
            y=vertices[:, 0],
            z=vertices[:, 1],
            i=triangles[:, 0],
            j=triangles[:, 1],
            k=triangles[:, 2],
            color='lightgray',
            name='mesh',
            showlegend=True,
        )
    )
else:
    print(f'mesh trace skipped: {len(vertices)} vertices >= {MAX_PLOT_VERTICES:.0e}')

# Block boxes are in voxel-index units; map them to mesh units with the same
# normalization as the marching-cubes vertices and the same plotted axis order
# (x <- width index / mesh_whl[2], y <- height index / mesh_whl[0],
#  z <- depth index / mesh_whl[1]).
for scale in reversed(range(n_scales)):
    if not m[f'x{scale}']:
        continue  # no active blocks at this scale; np.concatenate([]) would raise
    fig.add_trace(
        go.Mesh3d(
            x=(np.concatenate(m[f'x{scale}'])-input_size[0]/2)*2*mesh_whl[2]/input_size[0],
            y=(np.concatenate(m[f'y{scale}'])-input_size[1]/2)*2*mesh_whl[0]/input_size[1],
            z=(np.concatenate(m[f'z{scale}'])-input_size[2]/2)*2*mesh_whl[1]/input_size[2],
            i=np.concatenate(m[f'i{scale}']),
            j=np.concatenate(m[f'j{scale}']),
            k=np.concatenate(m[f'k{scale}']),
            color=COLORS[scale % len(COLORS)],
            name=f'scale {scale}',
            showlegend=True,
            opacity=0.2
        )
    )

fig.update_layout(
    scene_camera=dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=1.5, y=1.5, z=1.5)
    ),
    scene_dragmode='orbit',
    title={
        'text': exp_name,
        'y': 0.9,
        'x': 0.5}
)

fig.show()