In [None]:
import numpy as np
import utils
import matplotlib.pyplot as plt
%matplotlib inline
import torch

In [None]:
# --- Data selection (both values are handed to the dataset loader) ---
category = "Chair"  # object category to load
n_points = 2048     # number of points requested per shape

# --- Run settings ---
batch_size = 256      # samples per DataLoader batch
latent_size = 128     # bottleneck size of the Autoencoder model
log_folder = "logs/"  # folder holding the run artifacts; the saved model is read from here

In [None]:
from data.load_dataset import get_dataset
from torch.utils.data import TensorDataset, DataLoader

# Load the "test" split for the configured category and point count.
test_set = get_dataset(category, "test", n_points)

# Count parts from the label channel only. Taking the max over the whole array
# would also include the xyz coordinates, which gives a wrong count whenever a
# coordinate exceeds the largest label.
# NOTE(review): assumes the part label lives in channel 3 of a
# (samples, points, channels) array, matching how the rest of this notebook
# slices channels 0:3 as coordinates — confirm against get_dataset.
part_count = int(test_set[:, :, 3].max())

print("Test set shape :" + str(test_set.shape))
print("Number of points : " + str(n_points))
print("Part count : " + str(part_count))

test_tensor = torch.from_numpy(test_set).float()

# Shuffling is kept on purpose: the next cells draw "random samples" from this
# loader. Pinned memory only helps host-to-GPU copies, so enable it only when
# CUDA is actually available (avoids a warning on CPU-only machines).
test_loader = DataLoader(
    dataset=test_tensor,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Pick the compute device first so the checkpoint can be remapped onto it.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# map_location lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine; without it torch.load tries to restore the tensors onto the
# original CUDA device and fails when none is present.
# NOTE(security): torch.load unpickles arbitrary Python objects — only load
# checkpoints from a trusted source.
# NOTE(review): on PyTorch releases where torch.load defaults to
# weights_only=True, loading a fully pickled model object additionally needs
# weights_only=False — confirm against the installed version.
model = torch.load(log_folder + "model_save", map_location=device)
model = model.to(device)
model.eval()  # inference mode for layers such as dropout / batch norm
model

In [None]:
def reconstruct(data):
    """Auto-encode a batch of point clouds and label the reconstruction.

    The first three channels of ``data`` are sent through the global ``model``
    once to obtain the decoded cloud, and the decoded cloud is sent through the
    model a second time to obtain per-point segmentation scores.

    Args:
        data: tensor indexed as (batch, point, channel); only channels 0:3 are
            used.

    Returns:
        NumPy array holding the decoded points with the arg-max segmentation
        label (as float) appended as one extra trailing channel.
    """
    with torch.no_grad():
        xyz = data[:, :, 0:3].to(device)

        # First pass: keep only the decoded cloud.
        _, rebuilt = model(xyz)
        # Second pass: segment the reconstruction itself.
        part_scores, _ = model(rebuilt)
        part_ids = torch.argmax(part_scores, dim=2, keepdim=True)

        labelled = torch.cat((rebuilt, part_ids.float()), dim=2)
        return labelled.detach().cpu().numpy()

In [None]:
def segmentall(pc):
    """Segment a batch of point clouds with the global ``model``.

    A zero-filled label channel is appended to ``pc``, the result is passed
    through the model, and the arg-max of the segmentation scores is written
    into channel 3.

    Args:
        pc: tensor indexed as (batch, point, channel). Channel index 3 is the
            one overwritten with labels, so ``pc`` is expected to carry three
            coordinate channels.

    Returns:
        NumPy array of ``pc`` with the predicted part label stored in channel 3.
    """
    # Inference only: never build autograd state, even when the caller is not
    # already inside a no_grad block.
    with torch.no_grad():
        # Size the label channel from the input itself rather than from the
        # notebook-level n_points, so clouds of any point count are accepted.
        label_channel = pc.new_zeros((pc.shape[0], pc.shape[1], 1))
        t_data = torch.cat([pc, label_channel], 2)

        seg_results, _ = model(t_data.to(device))

        # argmax without keepdim already yields (batch, point); a bare
        # squeeze() could additionally drop a size-1 batch or point axis.
        seg_labels = seg_results.argmax(dim=2)

        t_data[:, :, 3] = seg_labels.to(t_data.device)

    return t_data.cpu().detach().numpy()

In [None]:
# Take the first batch from the shuffled test loader, so the shapes differ
# from run to run.
test_samples = next(iter(test_loader)) # random samples
# Auto-encode the batch and attach predicted part labels (see reconstruct).
test_output = reconstruct(test_samples)
# Hand the input batch and its reconstruction to utils.plotPC as a two-item list.
utils.plotPC([test_samples.numpy(),test_output])

In [None]:
# Indices into the sampled batch of the two shapes whose latent codes are
# interpolated.
sample1 = 0
sample2 = 1
with torch.no_grad():
    
    # Run the model stage by stage instead of calling model(...) directly, so
    # the intermediate features are available.
    # NOTE(review): the full sample tensor is passed here, whereas
    # reconstruct() feeds only channels 0:3 — confirm the model accepts both.
    point_features = model.get_point_features(test_samples.to(device))
    seg_results = model.segment(point_features)
    part_features = model.get_part_features(point_features, seg_results)
    # Max-reduce over dim 1 (presumably the part axis — confirm) and keep the
    # values, discarding the arg-max indices.
    global_feature = torch.max(part_features, 1)[0]
    
    # utils.interpolateArray receives the two selected codes and the count 9;
    # presumably it returns a sequence of codes between them — verify in utils.
    latent_interpolation = utils.interpolateArray(global_feature[sample1],global_feature[sample2],9)
    decoded = model.decode(latent_interpolation)
    # Label every decoded cloud; segmentall returns a NumPy array.
    segmented = segmentall(decoded)
# Plot the segmented interpolation results.
utils.plotPC(segmented)