In [145]:
import torch
from data.shapenet import ShapeNet
from model.vertix_model import VertixModel
import tqdm
from util.visualization import visualize_pointcloud, visualize_mesh
import random
%load_ext autoreload
%autoreload 2
import numpy as np
from scipy.spatial import distance_matrix

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [146]:
# Vertex budget shared by the whole notebook: stored as config['num_vertices'], passed to
# ShapeNet as `threshold` and to the inference handler, and used as the side length of the
# edge matrices and the bound of the face-building loops below.
# NOTE(review): presumably the maximum number of vertices per mesh — confirm against ShapeNet.
num_vertices = 100

In [147]:
# Settings handed to vertix_hungarian_train.main(config); the dataset and inference cells
# of this notebook read a few of the same keys.
config = {
    # Run identity and hardware.
    # NOTE(review): the inference cell loads 'runs/vertix_train_hungarian/...', which does
    # not match this experiment name — confirm which directory the trainer writes to.
    'experiment_name': 'vertix_hungarian_train',
    'device': 'cuda:0',

    # Optimisation schedule.
    'is_overfit': False,
    'batch_size': 8,
    'resume_ckpt': None,
    'learning_rate': 1e-3,
    'max_epochs': 10000,
    'print_every_n': 100,
    'validate_every_n': 5,

    # Data locations.
    'sdf_path': 'data/shapenet_dim32_sdf',
    'meshes_path': 'data/shapenet_reduced',
    'class_mapping': 'data/shape_info.json',
    # The dataset cell in this notebook hard-codes split="val" and does not read this key.
    'split': 'train',

    # Model / sample shape.
    'num_vertices': num_vertices,
    'feature_size': 512,
    'num_trajectories': 2,
}

In [148]:
# Dataset used for the qualitative checks below. The split is fixed to "val" here
# (config['split'] is not consulted); every other argument comes from `config`.
dataset = ShapeNet(
    sdf_path=config["sdf_path"],
    meshes_path=config["meshes_path"],
    class_mapping=config["class_mapping"],
    split="val",
    threshold=config["num_vertices"],
    num_trajectories=config["num_trajectories"],
)


In [149]:
# Filter the dataset. The return value is discarded, so the effect must land on `dataset`
# itself; in the recorded run the reported length dropped from 32304 to 5474.
dataset.filter_data()

Length of dataset: 32304
Filtering data ..


100%|███████████████████████████████████| 32304/32304 [00:06<00:00, 4745.01it/s]

Length of dataset: 5474





In [8]:
# Start training with the settings in `config`.
# In the recorded run this cell was interrupted (KeyboardInterrupt) while the training
# dataset was still being filtered, so the inference cells below depend on a checkpoint
# that already exists on disk rather than on anything produced here.
from training import vertix_hungarian_train

vertix_hungarian_train.main(config)

Device: cuda:0
Length of dataset: 153528
Filtering data ..


  2%|▊                                   | 3540/153528 [00:12<08:42, 287.18it/s]


KeyboardInterrupt: 

In [186]:
from inference.inference_vertix import InferenceHandlerVertixModel

# create a handler for inference using a trained checkpoint
# NOTE(review): the directory 'vertix_train_hungarian' differs from
# config['experiment_name'] ('vertix_hungarian_train') — confirm this is the intended run.
inferer = InferenceHandlerVertixModel('runs/vertix_train_hungarian/model_best.ckpt', config["num_vertices"])

In [187]:
side = config["num_vertices"]

# All-ones array of shape (1, side, side, 1).
graph = np.ones((1, side, side, 1))

# Row-major enumeration of every (i, j) pair of vertex indices:
# x_indices carries the first index, y_indices the second.
x_indices = [i for i in range(side) for _ in range(side)]
y_indices = [j for _ in range(side) for j in range(side)]

In [188]:
# Draw one dataset index uniformly at random.
# random.randrange(n) returns a value in [0, n), so it is always a valid index.
# random.randint(0, n) includes n itself and would occasionally index one past the end.
random_sample = random.randrange(len(dataset))

In [189]:
# Unpack the chosen sample. How the fields are used further down:
#   input_sdf       -> fed to inferer.infer_single
#   target_vertices -> ground-truth points; only the first int(sum(mask)) are matched
#   mask            -> summed to obtain the number of target vertices to use
#   target_edges    -> indexed as [i][j] / [i, j] as a vertex-to-vertex edge indicator
#   edges_adj       -> not used by any cell in this notebook
# NOTE(review): shapes and dtypes are defined by ShapeNet.__getitem__ — confirm there.
input_sdf, target_vertices, mask, target_edges, edges_adj = dataset[random_sample]

In [190]:
# Collect one (1, 3) index row per ordered triple (i, j, k) for which both the i-j and the
# j-k entries of target_edges are set, in lexicographic (i, j, k) order.
# NOTE(review): the k-i entry is never checked and i == k is not excluded, so this lists
# two-edge paths rather than closed triangles — confirm that is intended.
faces = [
    np.array([i, j, k]).reshape(1, -1)
    for i in range(num_vertices)
    for j in range(num_vertices)
    if target_edges[i][j]
    for k in range(num_vertices)
    if target_edges[j][k]
]

In [191]:
# Turn the list of (1, 3) rows into a single (F, 3) integer array.
# Unlike np.concatenate(faces, 0), this yields an empty (0, 3) array when no face was
# collected (concatenate raises ValueError on an empty list) and leaves `faces` unchanged
# when the cell is re-run on the already-stacked array (concatenate would flatten it).
faces = np.asarray(faces, dtype=int).reshape(-1, 3)

In [192]:
# Show the ground-truth vertices of the sampled shape as a point cloud.
visualize_pointcloud(target_vertices, point_size=0.01)

Output()

In [193]:
# Run the trained model on the sample's SDF input; the result is used below as a set of
# predicted points (visualised, and matched against the target vertices).
output_pointcloud = inferer.infer_single(input_sdf)

In [194]:
from scipy.optimize import linear_sum_assignment

In [195]:
# Show the ground-truth vertices together with the faces derived from target_edges.
visualize_mesh(target_vertices, faces)

Output()

In [196]:
# Show the predicted points for the same sample.
visualize_pointcloud(output_pointcloud,point_size=0.01)

Output()

In [197]:
# Rebuild the list of (1, 3) index rows from target_edges: one row per ordered triple
# (i, j, k) whose i-j and j-k entries are both set, in lexicographic order.
# NOTE(review): no cell reads `faces` before it is reassigned from matched_edges further
# down, so this recomputation looks redundant — confirm before removing.
faces = [
    np.array([i, j, k]).reshape(1, -1)
    for i in range(num_vertices)
    for j in range(num_vertices)
    if target_edges[i][j]
    for k in range(num_vertices)
    if target_edges[j][k]
]

In [198]:
# Match predicted points to target vertices with the Hungarian algorithm.
# Only the first int(sum(mask)) target vertices take part
# (assumes the valid vertices are stored first — TODO confirm against the dataset).
valid_targets = target_vertices[:int(sum(mask))]

# cost[p, t] is the distance between predicted point p and target vertex t.
cost = distance_matrix(output_pointcloud, valid_targets)

# vertix_idx[n] (prediction) is paired with target_idx[n] (target).
vertix_idx, target_idx = linear_sum_assignment(cost)

In [199]:
# Number of target vertices taking part in the matching (same count used to slice
# target_vertices when the cost matrix was built).
target_size = int(sum(mask))

In [200]:
# Edge matrix over the predicted vertices, initialised to "no edge" everywhere; filled in
# by the next cell from target_edges via the vertex assignment.
matched_edges = np.zeros((num_vertices,num_vertices))

In [201]:
# Transfer the ground-truth connectivity onto the predicted vertices: for every two matched
# pairs (v1 <-> t1) and (v2 <-> t2), set matched_edges[v1, v2] = target_edges[t1, t2].
#
# np.ix_ builds the full cross product of the assignment arrays, so this single assignment
# covers exactly the (v1, v2) cells the nested loop visited. Driving it from the assignment
# arrays themselves, rather than from range(target_size), also stays valid when fewer pairs
# than target_size were matched, a case in which the index loop raised IndexError.
matched_edges[np.ix_(vertix_idx, vertix_idx)] = np.asarray(target_edges)[np.ix_(target_idx, target_idx)]

In [202]:
# Build the (F, 3) face index array for the predicted mesh from matched_edges.
#
# A row (i, j, k) is emitted for every ordered triple whose i-j and j-k entries are both
# non-zero, in lexicographic (i, j, k) order — the same selection and order as a triple
# nested loop, but computed with one broadcast instead of num_vertices**3 Python iterations.
# np.argwhere also returns an empty (0, 3) array when nothing qualifies, whereas
# concatenating an empty list of rows raised ValueError.
#
# NOTE(review): the k-i entry is never checked and i == k is not excluded, so these are
# two-edge paths rather than closed triangles — confirm that is intended.
adjacency = np.asarray(matched_edges)[:num_vertices, :num_vertices] != 0

# two_edge_paths[i, j, k] is True iff adjacency[i, j] and adjacency[j, k].
two_edge_paths = adjacency[:, :, None] & adjacency[None, :, :]
faces = np.argwhere(two_edge_paths)

In [203]:
# Show the predicted points with the faces derived from the transferred (matched) edges.
visualize_mesh(output_pointcloud,faces)

Output()