In [None]:
import sys
sys.path.append('src/utils')
import mesh_tools as mt
from SDF_CNN import CNN_3d_multiple_split
from CNN_to_PoNQ_or_lite import CNN_to_PoNQ
from ABC_dataset import make_mask_close
import torch
from meshplot import plot
import igl
from tqdm import tqdm

In [None]:
# Inputs: `state_dict` is the *path* to the pretrained PoNQ checkpoint (it is
# loaded with torch.load in the next cell), `example_mesh` the demo mesh.
state_dict = 'data/pretrained_PoNQ_ABC.pt'
example_mesh = 'data/bunny.obj'

# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Build the network and restore the pretrained weights, mapped onto `device`.
model = CNN_3d_multiple_split(device=device)
model.load_state_dict(torch.load(state_dict, map_location=device))
# nn.Module.to and nn.Module.eval both return the module itself, so this moves
# `model` to `device`, switches it to inference mode, and displays its repr.
model.to(device).eval()

In [None]:
# Sample the signed distance of the demo mesh on a grid_n^3 lattice.
grid_n = 65
v, f = igl.read_triangle_mesh(example_mesh)
v = 2 * mt.NDCnormalize(v)

points = mt.mesh_grid(grid_n, True)
sdf = igl.signed_distance(points, v, f)[0].reshape((grid_n,) * 3)

# Network inputs: the SDF as a (1, 1, grid_n, grid_n, grid_n) float32 tensor,
# and the boolean mask from make_mask_close flattened to shape (1, -1).
tensor_sdf = torch.tensor(sdf, dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(0)
tensor_mask = torch.tensor(make_mask_close(sdf, grid_n), dtype=torch.bool, device=device).reshape(1, -1)

In [None]:
# Build the PoNQ output from the network applied to the whole grid_n^3 SDF at
# once (the "Large tensor" section below does the same patch by patch).
# NOTE(review): subd=1 is presumably a subdivision level — confirm in CNN_to_PoNQ.
ponq = CNN_to_PoNQ(model, tensor_sdf, grid_n, tensor_mask, device=device, subd=1)

In [None]:
# Size of the first tensor returned by get_vstars (Tensor.size() == .shape).
ponq.get_vstars()[0].size()

In [None]:
# Display the mesh returned by ponq.min_cut_surface at the grid_n resolution.
plot(*ponq.min_cut_surface(grid_n))

In [None]:
# For reference: display the mesh that mt.mesh_from_voxels builds directly from the input SDF.
plot(*mt.mesh_from_voxels(sdf))

### Large tensor: patch-wise inference

In [None]:
# Patch-wise inference: run the network on overlapping kernel_size^3 SDF
# patches and stitch the per-cell predictions of each patch's core together.
kernel_size = 33
stride = (kernel_size-1)//2
# Cells dropped from each *interior* face of a patch (presumably because
# predictions near a patch face see truncated context). Faces lying on the
# outer boundary of the full grid are kept, otherwise those cells are lost.
margin = stride//2
n_cells = kernel_size-1  # cells per axis inside one patch

# unfold silently drops any remainder, and the kept cores only tile the grid
# without gaps or duplicates when n_cells == stride + 2*margin.
assert (grid_n - kernel_size) % stride == 0, 'patches must tile the grid exactly'
assert n_cells == stride + 2*margin, 'kept patch cores must tile without overlap'


def _kept_cells(patch_idx, n_patches, n_cells, margin):
    """Slice of the `n_cells` cells along one patch axis whose predictions are kept.

    Interior faces drop `margin` cells (the neighbouring patch's core covers
    them); the first/last patch along the axis keeps its outer-boundary cells.
    """
    start = 0 if patch_idx == 0 else margin
    end = n_cells if patch_idx == n_patches - 1 else n_cells - margin
    return slice(start, end)


with torch.no_grad():
    x = tensor_sdf.clone()

    # Unfolding dims 4, 3, 2 (in that order) gives shape
    # (1, 1, n_a, n_b, n_c, k, k, k) whose window axes are ordered (c, b, a).
    patches = x.unfold(4, kernel_size, stride).unfold(3, kernel_size, stride).unfold(2, kernel_size, stride)
    n_a, n_b, n_c = patches.shape[2:5]
    patches = patches.reshape(1, -1, kernel_size, kernel_size, kernel_size).permute(1, 0, 2, 3, 4)

    # Same unfolding for the cell mask (grid_n-1 cells per axis), one row per patch.
    patch_mask = tensor_mask.reshape(1, 1, grid_n-1, grid_n-1,grid_n-1)
    patch_mask = patch_mask.unfold(4, kernel_size-1, stride).unfold(3, kernel_size-1, stride).unfold(2, kernel_size-1, stride)
    patch_mask = patch_mask.reshape(1, -1,  kernel_size-1,  kernel_size-1,  kernel_size-1).permute(1, 0, 2, 3, 4)
    patch_mask = patch_mask.reshape(-1, (kernel_size-1)**3)

    # Per-patch cell coordinates in global units, shape (n_patches, n_cells**3, 3).
    # NOTE(review): the scratch cell at the end of this notebook rescales a
    # (n-1)-point mesh_grid by (n-2)/(n-1), i.e. (grid_n-2)/(grid_n-1) here,
    # not (grid_n-1)/grid_n — confirm which factor matches the model's grid.
    patch_grid = torch.tensor(mt.mesh_grid(grid_n-1, True)*(grid_n-1)/grid_n, dtype=torch.float32)
    patch_grid = patch_grid.reshape(grid_n-1, grid_n-1, grid_n-1, 3).permute((3, 0, 1, 2))
    patch_grid = patch_grid.unfold(3, kernel_size-1, stride).unfold(2, kernel_size-1, stride).unfold(1, kernel_size-1, stride)
    patch_grid = patch_grid.reshape(3, -1, (kernel_size-1), (kernel_size-1), (kernel_size-1))
    patch_grid = patch_grid.reshape(3, -1, (kernel_size-1)**3).permute((1, 2, 0))

    # The lines below mutate `model` in place; re-run the model-loading cell
    # before using it again on the full grid.
    model.change_grid_size(kernel_size)
    model.decoder_vstars.scale = grid_n
    model.decoder_points.scale = grid_n
    all_vstars = []
    all_mean_normals = []
    all_quadrics = []
    test_POINTS = []
    for i in tqdm(range(len(patches))):
        model.grid = patch_grid[i].to(device)
        # NOTE(review): `/32` presumably rescales SDF values to the resolution
        # the network was trained at — confirm.
        _, predicted_vstars, predicted_mean_normals, predicted_quadrics, predicted_bool = model(patches[None, i]*(grid_n-1)/32)

        # Flat index i enumerates patches as (a, b, c); the window axes are
        # (c, b, a), so axis 0 of each cube is sliced according to p_c.
        p_a, rest = divmod(i, n_b*n_c)
        p_b, p_c = divmod(rest, n_c)
        keep = (_kept_cells(p_c, n_c, n_cells, margin),
                _kept_cells(p_b, n_b, n_cells, margin),
                _kept_cells(p_a, n_a, n_cells, margin))

        final_mask = (predicted_bool*patch_mask[i]).reshape(n_cells, n_cells, n_cells)[keep] > .5
        predicted_vstars = predicted_vstars.reshape(n_cells, n_cells, n_cells, 4, 3)[keep][final_mask]
        predicted_mean_normals = predicted_mean_normals.reshape(n_cells, n_cells, n_cells, 4, 3)[keep][final_mask]
        predicted_quadrics = predicted_quadrics.reshape(n_cells, n_cells, n_cells, 4, 3, 3)[keep][final_mask]

        all_vstars.append(predicted_vstars)
        all_mean_normals.append(predicted_mean_normals)
        all_quadrics.append(predicted_quadrics)
        test_POINTS.append(model.grid.reshape(n_cells, n_cells, n_cells, 3)[keep].reshape(-1, 3))
        if device == 'cuda':
            torch.cuda.empty_cache()

    all_vstars = torch.cat(all_vstars)
    all_mean_normals = torch.cat(all_mean_normals)
    all_quadrics = torch.cat(all_quadrics)
    test_POINTS = torch.cat(test_POINTS)

In [None]:
# Size of the stitched patch-wise vstars (Tensor.size() == .shape).
all_vstars.size()

In [None]:
# Smallest value in the full-grid vstars.
torch.min(ponq.get_vstars()[0])

In [None]:
# Smallest value after averaging the patch-wise vstars over dim -2.
torch.mean(all_vstars, dim=-2).min()

In [None]:
# Patch-wise (averaged over dim -2) vs full-grid vstars sizes, side by side.
torch.mean(all_vstars, dim=-2).size(), ponq.get_vstars()[0].size()

In [None]:
# Point cloud of the full-grid vstars (detached, copied to host memory).
plot(ponq.get_vstars()[0].detach().cpu().numpy(), shading={'point_size': 0.051})

In [None]:
# Point cloud of the patch-wise vstars, averaged over dim -2.
plot(torch.mean(all_vstars, dim=-2).detach().cpu().numpy(), shading={'point_size': 0.051})

### h5py check: precomputed ground-truth SDF

In [None]:
import h5py
# NOTE: this rebinds the notebook-global grid_n (65 above) to 33; re-run the
# earlier cells before going back to the bunny example.
grid_n=33
# Open read-only and release the handle once the dataset is in memory
# (`[:]` copies it into a numpy array); h5py < 3.0 did not default to 'r'.
with h5py.File('/data/nmaruani/DATASETS/gt_Quadrics/00000002.hdf5', 'r') as file:
    # original SDF is in [-0.5, 0.5]^3
    sdf0 = 2 * file[f'{grid_n-1}_sdf'][:][None, :]

In [None]:
# Recompute the SDF of ABC model 00000002 on a grid_n^3 lattice (grid_n is set
# in the previous cell), presumably for comparison with the stored sdf0.
points = mt.mesh_grid(grid_n, True)
v, f = igl.read_triangle_mesh('/data/nmaruani/DATASETS/ABC/00000002/model.obj')
v = 2 * mt.NDCnormalize(v)
sdf = igl.signed_distance(points, v, f)[0].reshape((grid_n,) * 3)

In [None]:
# Dimensions of the recomputed (numpy) SDF grid.
tuple(sdf.shape)

In [None]:
import numpy as np  # NOTE(review): np is not used in this cell
# Scratch check: rescale a (grid_p-1)-point mesh_grid by (grid_p-2)/(grid_p-1)
# for a small grid_p, presumably to inspect cell-centre positions.
# NOTE(review): the patch-wise cell scales by (grid_n-1)/grid_n instead —
# confirm which factor is intended.
grid_p = 3
mt.mesh_grid(grid_p-1, True)*(grid_p-2)/(grid_p-1)