In [84]:
import torch
from data.shapenet import ShapeNet
from model.vertix_model import VertixModel
import tqdm
from util.visualization import visualize_pointcloud, visualize_mesh
import random
%load_ext autoreload
%autoreload 2
import numpy as np

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
num_vertices = 150

In [5]:
config = {
    'experiment_name': 'vertix_edge_overfitting',
    'device': 'cuda:0',  
    'is_overfit': True,
    'batch_size': 8,
    'resume_ckpt': None,
    'learning_rate': 0.0005,
    'max_epochs': 1000,
    'print_every_n': 10,
    'validate_every_n': 50,
    'sdf_path': 'data/shapenet_dim32_sdf',
    'meshes_path': 'data/shapenet_reduced',
    'class_mapping': 'data/shape_info.json',
    'split': 'overfit',
    'num_vertices': num_vertices,
    'feature_size': 512,
    'num_trajectories': 2
}

In [6]:
dataset = ShapeNet(sdf_path=config["sdf_path"],
                         meshes_path=config["meshes_path"],
                         class_mapping=config["class_mapping"],
                         split = "overfit", threshold=config["num_vertices"], num_trajectories=config["num_trajectories"])


In [7]:
dataset.filter_data()

Length of dataset: 91
Filtering data ..


100%|█████████████████████████████████████████| 91/91 [00:00<00:00, 3655.36it/s]

Length of dataset: 19





In [6]:
from training import vertix_edge_train

vertix_edge_train.main(config)

Device: cuda:0
Length of dataset: 20
Filtering data ..


100%|█████████████████████████████████████████| 20/20 [00:00<00:00, 1877.36it/s]


Length of dataset: 7


100%|██████████████████████████████████████████| 7/7 [00:00<00:00, 89512.59it/s]


Class 02691156 has 7 shapes
Length of dataset: 20
Filtering data ..


100%|█████████████████████████████████████████| 20/20 [00:00<00:00, 2727.03it/s]


Length of dataset: 7


100%|█████████████████████████████████████████| 7/7 [00:00<00:00, 169711.72it/s]

Class 02691156 has 7 shapes





NameError: name 'mask' is not defined

In [8]:
from inference.inference_vertix_edge import InferenceHandlerVertixEdgeModel

# create a handler for inference using a trained checkpoint
inferer = InferenceHandlerVertixEdgeModel('runs/vertix_edge_training/model_best.ckpt', config["num_vertices"], config["feature_size"])

In [9]:
x_indices = []
y_indices = []

graph = np.ones((1,config["num_vertices"], config["num_vertices"],1))

for i in range(config["num_vertices"]):
    for j in range(config["num_vertices"]):
        x_indices.append(i)
        y_indices.append(j)

In [10]:
random_sample = random.randint(0,len(dataset))

In [11]:
random_input = dataset[random_sample]["input_sdf"]
gt_pointcloud = dataset[random_sample]["target_vertices"]
gt_mask = dataset[random_sample]["input_mask"]
gt_edges = dataset[random_sample]["target_edges"]
#faces = dataset[random_sample]["faces"]

In [12]:
output_pointcloud, edges = inferer.infer_single(random_input, x_indices,y_indices,graph)

In [68]:
from scipy.optimize import linear_sum_assignment

In [75]:
cost = np.zeros((int(gt_mask.sum()),int(gt_mask.sum())))

In [76]:
for i in range(int(gt_mask.sum())):
    for j in range(int(gt_mask.sum())):
        distance = np.linalg.norm(output_pointcloud[i] - gt_pointcloud[j])
        cost[i][j] = distance
            

In [72]:
from scipy.optimize import linear_sum_assignment

In [77]:
row_ind, col_ind = linear_sum_assignment(cost)

In [83]:
gt_pointcloud[col_ind].shape

(100, 3)

In [35]:
gt_mask.sum()

100.0

In [65]:
output = algorithm.find_matching(G, matching_type = 'min', return_type = 'list')

AttributeError: 'bool' object has no attribute 'vertices'

{'0': {'0': 0.34783852,
  '1': 0.33254936,
  '2': 0.34757203,
  '3': 0.3323322,
  '4': 0.5435124,
  '5': 0.53385687,
  '6': 0.54334193,
  '7': 0.5337217,
  '8': 0.32589656,
  '9': 0.35578492,
  '10': 0.34111634,
  '11': 0.3637867,
  '12': 0.32392582,
  '13': 0.34661484,
  '14': 0.32392594,
  '15': 0.35603637,
  '16': 0.34111485,
  '17': 0.35577604,
  '18': 0.34090313,
  '19': 0.54879504,
  '20': 0.53923416,
  '21': 0.54862624,
  '22': 0.5391002,
  '23': 0.41305563,
  '24': 0.41305754,
  '25': 0.45349914,
  '26': 0.45324326,
  '27': 0.40382075,
  '28': 0.44510418,
  '29': 0.40382358,
  '30': 0.44484425,
  '31': 0.36100942,
  '32': 0.3884423,
  '33': 0.3748131,
  '34': 0.38820374,
  '35': 0.37462047,
  '36': 0.37480605,
  '37': 0.38216764,
  '38': 0.57035214,
  '39': 0.5611586,
  '40': 0.5701897,
  '41': 0.56103,
  '42': 0.36100844,
  '43': 0.35923135,
  '44': 0.3932238,
  '45': 0.3955212,
  '46': 0.38215843,
  '47': 0.3798148,
  '48': 0.3932066,
  '49': 0.3475421,
  '50': 0.3322393,
  '