# Pipeline Simulation - Test

In [1]:
import numpy as np
import torch
from feature_detection import Net
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from _datasets import ObjectPointCloudDataset

## Sanity Check

### Label Assignation

In [2]:
import os
import re

names = os.listdir('ycb')

labels = {}

for name in names:
    id = int(name[:3])
    
    labels[id] = name

{71: '071_nine_hole_peg_test.stl', 38: '038_padlock.stl', 65: '065-b_cups.stl', 6: '006_mustard_bottle.stl', 73: '073-g_lego_duplo.stl', 32: '032_knife.stl', 4: '004_sugar_box.stl', 9: '009_gelatin_box.stl', 70: '070-b_colored_wood_blocks.stl', 30: '030_fork.stl', 5: '005_tomato_soup_can.stl', 2: '002_master_chef_can.stl', 31: '031_spoon.stl', 33: '033_spatula.stl', 11: '011_banana.stl', 72: '072-c_toy_airplane.stl', 48: '048_hammer.stl', 55: '055_baseball.stl', 8: '008_pudding_box.stl', 62: '062_dice.stl', 40: '040_large_marker.stl', 57: '057_racquetball.stl', 59: '059_chain.stl', 58: '058_golf_ball.stl', 3: '003_cracker_box.stl', 43: '043_phillips_screwdriver.stl', 21: '021_bleach_cleanser.stl', 24: '024_bowl.stl', 63: '063-a_marbles.stl', 50: '050_medium_clamp.stl', 77: '077_rubiks_cube.stl', 26: '026_sponge.stl', 51: '051_large_clamp.stl', 42: '042_adjustable_wrench.stl', 22: '022_windex_bottle.stl', 19: '019_pitcher_base.stl', 17: '017_orange.stl', 29: '029_plate.stl', 54: '054_so


### Models Instantiation

In [16]:
# Load the model from pt files
model_paths = [path for path in os.listdir('models')]
models = []
for path in model_paths:
    model = Net()
    model.to('cuda')
    model.load_state_dict(torch.load('models/'+path)['model_state_dict'])
    model.eval()
    models.append(model)

# Loading point clouds from nympy file
pc = np.load("sanity_check.npy")
data = Data(x = torch.tensor(pc, dtype=torch.float))
loader = DataLoader(dataset=[data], batch_size=1)

for i in range(len(models)):
    for d in loader:
        d = d.to('cuda')
        result = models[i](d)[0].max(1)[1]

    print(f"{model_paths[i]} => ", result, " : ", labels[result.item()])

model_64_0.9912256773958902.pt =>  tensor([73], device='cuda:0')  :  073-g_lego_duplo.stl
model.pt =>  tensor([73], device='cuda:0')  :  073-g_lego_duplo.stl
model_37_0.9930441898527005.pt =>  tensor([73], device='cuda:0')  :  073-g_lego_duplo.stl
model_11_0.9883160574649936.pt =>  tensor([73], device='cuda:0')  :  073-g_lego_duplo.stl
model_61_0.9912711402073104.pt =>  tensor([73], device='cuda:0')  :  073-g_lego_duplo.stl
model_41_0.9903618839789052.pt =>  tensor([73], device='cuda:0')  :  073-g_lego_duplo.stl
model_90_0.9905891980360065.pt =>  tensor([73], device='cuda:0')  :  073-g_lego_duplo.stl
model_28_0.9886797599563557.pt =>  tensor([73], device='cuda:0')  :  073-g_lego_duplo.stl
model_52_0.9916348426986725.pt =>  tensor([73], device='cuda:0')  :  073-g_lego_duplo.stl
model_23_0.988907074013457.pt =>  tensor([73], device='cuda:0')  :  073-g_lego_duplo.stl
model_96_0.9913166030187307.pt =>  tensor([73], device='cuda:0')  :  073-g_lego_duplo.stl
model_88_0.9863156937625023.pt =>

In [4]:
import open3d as o3d

def visualize_point_cloud(pc):
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(pc)
    o3d.visualization.draw_geometries([pcd])

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [10]:
model = Net()
model.to('cuda')
model.load_state_dict(torch.load('models/model.pt')['model_state_dict'])
model.eval()
models.append(model)


# Loading test dataset frompt files
test_dataset = ObjectPointCloudDataset(root = '.', 
                                    chunk = (87984, 109980), 
                                    sample_count = 512,
                                    output_name = 'test'
                                    )

loader = DataLoader(test_dataset,
                    batch_size=1)

point_clouds = []
label = []

test_num = 0
for d in loader:
    test_num += 1
    d = d.to('cuda')
    result = models[i](d)[0].max(1)[1]

    print(result, " : ", labels[result.item()], " | ", d.y)
    print(d.x.dtype)

    pc = d.x.cpu().numpy()

    point_clouds.append(pc)
    label.append(labels[result.item()])

    if test_num > 9:
        break

done loading
tensor([73], device='cuda:0')  :  073-g_lego_duplo.stl  |  tensor([73], device='cuda:0', dtype=torch.int8)
torch.float32
tensor([72], device='cuda:0')  :  072-c_toy_airplane.stl  |  tensor([72], device='cuda:0', dtype=torch.int8)
torch.float32
tensor([62], device='cuda:0')  :  062_dice.stl  |  tensor([62], device='cuda:0', dtype=torch.int8)
torch.float32
tensor([38], device='cuda:0')  :  038_padlock.stl  |  tensor([38], device='cuda:0', dtype=torch.int8)
torch.float32
tensor([57], device='cuda:0')  :  057_racquetball.stl  |  tensor([57], device='cuda:0', dtype=torch.int8)
torch.float32
tensor([65], device='cuda:0')  :  065-b_cups.stl  |  tensor([65], device='cuda:0', dtype=torch.int8)
torch.float32
tensor([73], device='cuda:0')  :  073-g_lego_duplo.stl  |  tensor([73], device='cuda:0', dtype=torch.int8)
torch.float32
tensor([38], device='cuda:0')  :  038_padlock.stl  |  tensor([38], device='cuda:0', dtype=torch.int8)
torch.float32
tensor([63], device='cuda:0')  :  063-a_ma

In [7]:
for i in range(len(point_clouds)):
    np.save(f"pcs/{label[i]}_{i}", point_clouds[i])

In [None]:
files = os.listdir('../dataset/v4/')
labels = set()

for file in files[:8794]:
    labels.add(file[:3])

print(labels)

test = set()
for file in files[8794:]:
    test.add(file[:3])

print(test)


In [None]:
repetitions = {}

for label in files:
    try:
        repetitions[label[:3]] += 1
    except:
        repetitions[label[:3]] = 1

print(repetitions)

In [None]:
# %matplotlib widget
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def plot_point_cloud(mat):
    # Create a 3D figure
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')

    # Define the x, y, z coordinates of the point cloud
    x = mat[:, 0]
    y = mat[:, 1]
    z = mat[:, 2]

    # Plot the point cloud data
    ax.scatter(x, y, z, s=1)

    # Set the axis labels
    ax.set_xlabel('X Label')
    ax.set_ylabel('Y Label')
    ax.set_zlabel('Z Label')

    ax.scatter(0, 0, 0, s=10)

    ax.set_xlim(-0.3, 0.3)
    ax.set_ylim(-0.3, 0.3)
    ax.set_zlim(-0.3, 0.3)

    # Show the plot
    plt.show()
    # print

plot_point_cloud(pc)
