In [7]:
import sys
sys.path.append('src/utils')
import mesh_tools as mt
from SDF_CNN import CNN_3d_multiple_split
from CNN_to_PoNQ_or_lite import CNN_to_PoNQ
from ABC_dataset import make_mask_close
import torch
from meshplot import plot
import igl
from tqdm import tqdm

In [8]:
state_dict = 'data/pretrained_PoNQ_ABC.pt'
example_mesh = 'data/bunny.obj'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [9]:
model = CNN_3d_multiple_split(device=device)
model.load_state_dict(torch.load(state_dict, map_location=device))
model.to(device)
model.eval()

CNN_3d_multiple_split(
  (encoder): Sequential(
    (0): Conv3d(1, 128, kernel_size=(2, 2, 2), stride=(1, 1, 1))
    (1): LeakyReLU(negative_slope=0.01, inplace=True)
    (2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (3): LeakyReLU(negative_slope=0.01, inplace=True)
    (4): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (5): LeakyReLU(negative_slope=0.01, inplace=True)
    (6): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (7): LeakyReLU(negative_slope=0.01, inplace=True)
    (8): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (9): LeakyReLU(negative_slope=0.01, inplace=True)
    (10): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (11): LeakyReLU(negative_slope=0.01, inplace=True)
  )
  (decoder_points): SDF_decoder(
    (decoder): Sequential(
      (0): resnet_block(
        (conv_1): Conv3d(128, 128, kernel_

In [10]:
v, f = igl.read_triangle_mesh(example_mesh)
v = 2*mt.NDCnormalize(v)
grid_n = 65

points = mt.mesh_grid(grid_n, True)
sdf = igl.signed_distance(points, v, f)[0].reshape(grid_n, grid_n, grid_n)

tensor_sdf = torch.tensor(sdf, dtype=torch.float32, device=device)[None, None, ...]
tensor_mask = torch.tensor(make_mask_close(sdf, grid_n), dtype=torch.bool, device=device).flatten()[None, ...]

In [11]:
ponq = CNN_to_PoNQ(model, tensor_sdf, grid_n, tensor_mask, device=device, subd=1)

In [None]:
ponq.get_vstars()[0].shape

In [12]:
2/32

0.0625

In [None]:
plot(*ponq.min_cut_surface(grid_n))

In [None]:
plot(*mt.mesh_from_voxels(sdf))

In [49]:
stride 


16

In [54]:
import numpy as np 
grid_n = 65
kernel_size = 33
stride = (kernel_size-1)//2

patch_grid = torch.tensor(mt.mesh_grid(grid_n-1, False), dtype=torch.int)
patch_grid = patch_grid.reshape(grid_n-1, grid_n-1, grid_n-1, 3).permute((3, 0, 1, 2))
print(patch_grid.shape)
patch_grid = patch_grid.unfold(3, kernel_size-1, stride).unfold(2, kernel_size-1, stride).unfold(1, kernel_size-1, stride)
print(patch_grid.shape)
patch_grid = patch_grid.reshape(3, -1, (kernel_size-1), (kernel_size-1), (kernel_size-1))
patch_grid = patch_grid.reshape(3, -1, (kernel_size-1)**3).permute((1, 2, 0))

full_grid = np.zeros((grid_n-1,grid_n-1,grid_n-1))
for patch in patch_grid:
    full_grid[patch[:, 0], patch[:, 1], patch[:, 2]] += 1
np.unique(full_grid)

torch.Size([3, 64, 64, 64])
torch.Size([3, 3, 3, 3, 32, 32, 32])


array([1., 2., 4., 8.])

In [136]:
patch_mask = tensor_mask.reshape(1, 1, grid_n-1, grid_n-1,grid_n-1)

print(patch_mask.shape)
patch_mask = patch_mask.unfold(2, kernel_size-1, stride).unfold(3, kernel_size-1, stride).unfold(4, kernel_size-1, stride)
patch_mask.shape


torch.Size([1, 1, 64, 64, 64])


torch.Size([1, 1, 3, 3, 3, 32, 32, 32])

torch.Size([1, 1, 3, 64, 64, 32])

### Large tensor

In [139]:
kernel_size = 33
stride = (kernel_size-1)//2

with torch.no_grad():
    x = tensor_sdf.clone()

    patches = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride).unfold(4, kernel_size, stride)
    patches = patches.reshape(1, -1, kernel_size, kernel_size, kernel_size).permute(1, 0, 2, 3, 4)

    patch_mask = tensor_mask.reshape(1, 1, grid_n-1, grid_n-1,grid_n-1)
    patch_mask = patch_mask.unfold(2, kernel_size-1, stride).unfold(3, kernel_size-1, stride).unfold(4, kernel_size-1, stride)
    patch_mask = patch_mask.reshape(1, -1,  kernel_size-1,  kernel_size-1,  kernel_size-1).permute(1, 0, 2, 3, 4)
    patch_mask = patch_mask.reshape(-1, (kernel_size-1)**3)

    patch_grid = torch.tensor(mt.mesh_grid(grid_n-1, True)*(grid_n-1)/grid_n, dtype=torch.float32)
    patch_grid = patch_grid.reshape(grid_n-1, grid_n-1, grid_n-1, 3).permute((3, 0, 1, 2))
    patch_grid = patch_grid.unfold(1, kernel_size-1, stride).unfold(2, kernel_size-1, stride).unfold(3, kernel_size-1, stride)
    patch_grid = patch_grid.reshape(3, -1, (kernel_size-1), (kernel_size-1), (kernel_size-1))
    patch_grid = patch_grid.reshape(3, -1, (kernel_size-1)**3).permute((1, 2, 0))
    i=0
    model.change_grid_size(kernel_size)
    model.decoder_vstars.scale = grid_n
    model.decoder_points.scale = grid_n
    all_vstars = []
    all_mean_normals = []
    all_quadrics = []
    test_POINTS=[]
    for i in tqdm(range(len(patches))):
        model.grid = patch_grid[i].to(device)
        # if i==len(patches)-1:
        #     stride = kernel_size-1
        _, predicted_vstars, predicted_mean_normals, predicted_quadrics, predicted_bool = model(patches[None, i]*(grid_n-1)/32)
        print(predicted_vstars.shape)
        start = stride//2
        end = kernel_size-1-stride//2
        final_mask = (predicted_bool*patch_mask[i]).reshape(kernel_size-1, kernel_size-1, kernel_size-1)[start:end, start:end, start:end]>.5
        predicted_vstars = predicted_vstars.reshape(kernel_size-1, kernel_size-1, kernel_size-1, 4, 3)[start:end, start:end, start:end][final_mask]
        predicted_mean_normals = predicted_mean_normals.reshape(kernel_size-1, kernel_size-1, kernel_size-1, 4, 3)[start:end, start:end, start:end][final_mask]
        predicted_quadrics = predicted_quadrics.reshape(kernel_size-1, kernel_size-1, kernel_size-1, 4, 3, 3)[start:end, start:end, start:end][final_mask]
        
        all_vstars.append(predicted_vstars)
        all_mean_normals.append(predicted_mean_normals)
        all_quadrics.append(predicted_quadrics)
        torch.cuda.empty_cache()
        test_POINTS.append(model.grid.reshape(kernel_size-1, kernel_size-1,kernel_size-1, 3)[start:end, start:end, start:end].reshape(-1, 3))
        # if len(predicted_vstars)>0:
        #     print(i)
        #     break

    all_vstars = torch.cat(all_vstars)
    all_mean_normals = torch.cat(all_mean_normals)
    all_quadrics = torch.cat(all_quadrics)
    test_POINTS = torch.cat(test_POINTS)
# # # ...

  4%|▎         | 1/27 [00:01<00:44,  1.72s/it]

torch.Size([1, 32768, 4, 3])


  7%|▋         | 2/27 [00:03<00:44,  1.78s/it]

torch.Size([1, 32768, 4, 3])


 11%|█         | 3/27 [00:05<00:42,  1.75s/it]

torch.Size([1, 32768, 4, 3])


 15%|█▍        | 4/27 [00:06<00:39,  1.73s/it]

torch.Size([1, 32768, 4, 3])


 19%|█▊        | 5/27 [00:08<00:37,  1.71s/it]

torch.Size([1, 32768, 4, 3])


 22%|██▏       | 6/27 [00:10<00:35,  1.71s/it]

torch.Size([1, 32768, 4, 3])


 26%|██▌       | 7/27 [00:12<00:35,  1.77s/it]

torch.Size([1, 32768, 4, 3])


 30%|██▉       | 8/27 [00:13<00:32,  1.73s/it]

torch.Size([1, 32768, 4, 3])


 33%|███▎      | 9/27 [00:15<00:31,  1.76s/it]

torch.Size([1, 32768, 4, 3])


 37%|███▋      | 10/27 [00:17<00:29,  1.74s/it]

torch.Size([1, 32768, 4, 3])


 41%|████      | 11/27 [00:19<00:28,  1.76s/it]

torch.Size([1, 32768, 4, 3])


 44%|████▍     | 12/27 [00:20<00:25,  1.73s/it]

torch.Size([1, 32768, 4, 3])


 48%|████▊     | 13/27 [00:22<00:24,  1.72s/it]

torch.Size([1, 32768, 4, 3])


 52%|█████▏    | 14/27 [00:24<00:21,  1.66s/it]

torch.Size([1, 32768, 4, 3])


 56%|█████▌    | 15/27 [00:25<00:20,  1.70s/it]

torch.Size([1, 32768, 4, 3])


 59%|█████▉    | 16/27 [00:27<00:18,  1.67s/it]

torch.Size([1, 32768, 4, 3])


 63%|██████▎   | 17/27 [00:29<00:16,  1.67s/it]

torch.Size([1, 32768, 4, 3])


 67%|██████▋   | 18/27 [00:30<00:15,  1.70s/it]

torch.Size([1, 32768, 4, 3])


 70%|███████   | 19/27 [00:32<00:13,  1.72s/it]

torch.Size([1, 32768, 4, 3])


 74%|███████▍  | 20/27 [00:34<00:12,  1.72s/it]

torch.Size([1, 32768, 4, 3])


 78%|███████▊  | 21/27 [00:36<00:10,  1.78s/it]

torch.Size([1, 32768, 4, 3])


 81%|████████▏ | 22/27 [00:38<00:08,  1.76s/it]

torch.Size([1, 32768, 4, 3])


 85%|████████▌ | 23/27 [00:39<00:06,  1.75s/it]

torch.Size([1, 32768, 4, 3])


 89%|████████▉ | 24/27 [00:41<00:05,  1.73s/it]

torch.Size([1, 32768, 4, 3])


 93%|█████████▎| 25/27 [00:43<00:03,  1.72s/it]

torch.Size([1, 32768, 4, 3])


 96%|█████████▋| 26/27 [00:44<00:01,  1.73s/it]

torch.Size([1, 32768, 4, 3])


100%|██████████| 27/27 [00:46<00:00,  1.73s/it]

torch.Size([1, 32768, 4, 3])





In [119]:
with torch.no_grad():
    model.change_grid_size(grid_n)
    _, gt_predicted_vstars, _, _, _ = model(tensor_sdf*(grid_n-1)/32)
gt_predicted_vstars = gt_predicted_vstars.reshape(grid_n-1, grid_n-1, grid_n-1, 4, 3)

In [124]:
predicted_vstars = predicted_vstars.reshape(32, 32, 32, 4, 3)

In [125]:
gt_predicted_vstars.shape

torch.Size([64, 64, 64, 4, 3])

In [126]:
gt_predicted_vstars.shape

torch.Size([64, 64, 64, 4, 3])

In [74]:
all_vstars.mean(-2).shape, ponq.get_vstars()[0].shape

(torch.Size([398, 3]), torch.Size([5318, 3]))

In [75]:
mt.meshplot_add_points?

[0;31mSignature:[0m [0mmt[0m[0;34m.[0m[0mmeshplot_add_points[0m[0;34m([0m[0mmp[0m[0;34m,[0m [0mpoints[0m[0;34m,[0m [0msize[0m[0;34m=[0m[0;36m0.04[0m[0;34m,[0m [0mc[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mFile:[0m      ~/Documents/These/PoNQ/src/utils/mesh_tools.py
[0;31mType:[0m      function


In [144]:
mp = plot(ponq.get_vstars()[0].cpu().detach().numpy(), shading={'point_size':.051, 'point_color': 'black'})
# mt.meshplot_add_points(mp, all_vstars.mean(-2).cpu().detach().numpy())

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0003400…

In [145]:
plot(all_vstars.mean(-2).cpu().detach().numpy(), shading={'point_size':.051})

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0002218…

<meshplot.Viewer.Viewer at 0x15a86ac10>

### H5py check

In [None]:
import h5py
grid_n=33
file = h5py.File('/data/nmaruani/DATASETS/gt_Quadrics/00000002.hdf5')
# original SDF is in [-0.5, 0.5]^3
sdf0 = 2 * file['{}_sdf'.format(grid_n-1)][:][None, :]

In [None]:
v, f = igl.read_triangle_mesh('/data/nmaruani/DATASETS/ABC/00000002/model.obj')
v = 2*mt.NDCnormalize(v)
points = mt.mesh_grid(grid_n, True)
sdf = igl.signed_distance(points, v, f)[0].reshape(grid_n, grid_n, grid_n)

In [None]:
sdf.shape

In [None]:
import numpy as np
grid_p = 3
mt.mesh_grid(grid_p-1, True)*(grid_p-2)/(grid_p-1)