In [1]:
import os

# Pin this notebook to a single GPU. CUDA reads these variables when it
# initializes, so they are set here, before `torch` is imported below.
#  - CUDA_DEVICE_ORDER=PCI_BUS_ID: number devices by PCI bus ID (see issue #152).
#  - CUDA_VISIBLE_DEVICES=1: expose only that device to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

import k3d
from matplotlib import cm, colors
import torch
import numpy as np

In [2]:
def visualize_pointcloud(point_cloud, point_size=0.05, colors=None, flip_axes=False, name='point_cloud'):
    """Render a point cloud as a k3d scatter plot and display it.

    Args:
        point_cloud: array-like of point positions; columns 0-2 are treated as the
            three axes (presumably (N, 3) xyz — TODO confirm for callers that pass
            extra columns, which are forwarded unchanged).
        point_size: marker size forwarded to ``k3d.points``.
        colors: optional sequence of per-point packed 0xRRGGBB integers. When None,
            every point uses the default grey ``0xd0d0d0``.
        flip_axes: if True, negate the third column and then swap the second and
            third columns, i.e. (x, y, z) -> (x, -z, y). The caller's array is
            never modified.
        name: name given to the k3d plot.

    Returns:
        Whatever the k3d plot's ``display()`` returns.
    """
    # Always work on a float32 copy. Flipping in place would mutate the caller's
    # array, so plotting the same coordinates twice would rotate them twice.
    positions = np.array(point_cloud, dtype=np.float32)
    if flip_axes:
        positions[:, 2] *= -1
        positions = positions[:, [0, 2, 1]]

    # Named `k3d_plot` so it does not shadow the notebook-level `plot` helper.
    k3d_plot = k3d.plot(name=name, grid_visible=False)
    plt_points = k3d.points(
        positions=positions,
        point_size=point_size,
        colors=colors if colors is not None else [],
        color=0xd0d0d0,
    )
    k3d_plot += plt_points
    plt_points.shader = '3d'
    return k3d_plot.display()


In [3]:
def plot(coord, prediction):
    """Display ``coord`` with each point coloured by its label.

    Args:
        coord: point positions forwarded to ``visualize_pointcloud`` (with
            ``flip_axes=True``, so the displayed cloud is reoriented).
        prediction: 1-D array of per-point labels (predicted or ground truth),
            one per row of ``coord``.
    """
    import matplotlib  # only needed for the colormap registry lookup below

    labels = np.asarray(prediction)
    # Map each distinct label to its rank 0..k-1 and spread the ranks over [0, 1).
    # Dividing by k (not k-1) keeps values below 1.0. 'hsv' is cyclic, so 0.0 and
    # 1.0 are both red: the smallest and largest labels would otherwise be
    # indistinguishable. This also avoids 0/0 when every point shares one label.
    unique_labels, ranks = np.unique(labels, return_inverse=True)
    point_labels = ranks.reshape(-1) / max(len(unique_labels), 1)

    # `cm.get_cmap` was removed in Matplotlib 3.9; use the registry when available.
    registry = getattr(matplotlib, 'colormaps', None)
    cmap = registry['hsv'] if registry is not None else cm.get_cmap('hsv')
    rgb = (cmap(point_labels)[:, :3] * 255).astype(np.int64)

    # k3d expects colours packed as 0xRRGGBB, i.e. R * 256**2 + G * 256 + B.
    point_colors = ((rgb[:, 0] << 16) | (rgb[:, 1] << 8) | rgb[:, 2]).astype(np.uint32)
    visualize_pointcloud(coord, colors=point_colors, point_size=0.025, flip_axes=True)

In [13]:
def to_categorical(y, num_classes):
    """Return a one-hot float encoding of the integer tensor ``y``.

    The result has ``y``'s shape plus a trailing axis of length ``num_classes``.
    It is built on the CPU from an identity matrix and moved with ``.cuda()``
    only when ``y`` itself lives on a CUDA device.
    """
    indices = y.cpu().data.numpy()
    one_hot = torch.eye(num_classes)[indices,]
    return one_hot.cuda() if y.is_cuda else one_hot

def infer(model, item, num_category=None):
    """Predict per-point part labels for a single dataset item.

    Args:
        model: part-segmentation network (a ``torch.nn.Module``). It is put in eval
            mode and called on the item's points concatenated with the one-hot
            object category. Its output is treated as per-point part logits,
            assumed shaped (1, N, 50) — TODO confirm.
        item: ``(points, label, target)`` numpy arrays. Presumably points is (N, C)
            per-point features, label the object-category index (shape (1,)) and
            target the (N,) ground-truth part labels — verify against the dataset.
        num_category: length of the one-hot category vector appended to each
            point. Defaults to the number of categories in ``seg_classes`` (16),
            so the function no longer depends on a notebook-global variable.

    Returns:
        ``(points, pred, target)``: the batched (1, N, C) float32 input tensor on
        the model's device, an int32 (1, N) array of predicted part labels, and the
        (1, N) int64 ground-truth part labels as a numpy array.
    """
    seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
                   'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37],
                   'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49],
                   'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
    # Inverse lookup: part label -> category name, e.g. {0: 'Airplane', ..., 49: 'Table'}.
    seg_label_to_cat = {part: cat for cat, parts in seg_classes.items() for part in parts}
    if num_category is None:
        num_category = len(seg_classes)

    # Run on whatever device the model's parameters live on instead of assuming CUDA.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model = model.eval()
    points_np, label_np, target_np = item
    with torch.no_grad():
        # Add a leading batch dimension of 1 to every input.
        points = torch.from_numpy(points_np).unsqueeze(0).float().to(device)
        label = torch.from_numpy(label_np).unsqueeze(0).long().to(device)
        target = torch.from_numpy(target_np).unsqueeze(0).long().numpy()

        batch_size, num_points, _ = points.size()
        # Condition the network on the object category by appending its one-hot
        # vector to every point's features.
        category_one_hot = to_categorical(label, num_category).repeat(1, num_points, 1)
        seg_pred = model(torch.cat([points, category_one_hot], -1))
        logits = seg_pred.detach().cpu().numpy()

        pred = np.zeros((batch_size, num_points), dtype=np.int32)
        for i in range(batch_size):
            # Restrict the argmax to the parts of this object's category. Note that
            # the category is read from the ground-truth part label of the first
            # point, not from the model's output.
            parts = seg_classes[seg_label_to_cat[target[i, 0]]]
            pred[i, :] = np.argmax(logits[i][:, parts], axis=1) + parts[0]

    return points, pred, target

In [14]:
from dataset import PartNormalDataset
from models.Hengshuang.model import PointTransformerSeg
from easydict import EasyDict as edict

# Project checkout; defaults to the original author's path, override with the
# POINT_TRANSFORMERS_DIR environment variable instead of editing the notebook.
PROJECT_DIR = os.environ.get(
    "POINT_TRANSFORMERS_DIR",
    "/home/nazmicancalik/workspace/pc-transformers-with-convolution/Point-Transformers",
)
model_path = os.path.join(PROJECT_DIR, "log", "partseg", "Hengshuang", "best_model.pth")
# Trailing "" keeps the trailing path separator of the dataset root.
root = os.path.join(PROJECT_DIR, "data", "shapenetcore_partanno_segmentation_benchmark_v0_normal", "")

num_point = 1024
normal = True
batch_size = 16
num_category = 16  # number of object categories (length of the one-hot category input)

TEST_DATASET = PartNormalDataset(root=root, npoints=num_point, split='test', normal_channel=normal)
testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=batch_size, shuffle=False, num_workers=10)

num_class = 50  # total number of part labels across all categories
num_part = num_class

# Set config
cfg = edict()
cfg.num_point = num_point  # derived, so it cannot drift from the dataset sampling
cfg.num_class = num_class
# Network input = per-point channels + one-hot category (infer() concatenates both).
# Assumes the dataset yields xyz (3) or xyz + normals (6) channels — TODO confirm.
cfg.input_dim = (6 if normal else 3) + num_category  # 22 with normals
cfg.model = edict()
cfg.model.nblocks = 4
cfg.model.nneighbor = 16
cfg.model.transformer_dim = 512

model = PointTransformerSeg(cfg).cuda()
# map_location="cpu" lets the checkpoint load regardless of which GPU index it was
# saved from; load_state_dict then copies the weights onto the CUDA model.
# torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load(model_path, map_location="cpu")
state_dict = checkpoint["model_state_dict"]
model.load_state_dict(state_dict, strict=True)

<All keys matched successfully>

In [73]:
def random_shuffle_region(arr, start=250, stop=300, rng=None):
    """Shuffle the elements of ``arr[start:stop]`` in place and return ``arr``.

    Useful for perturbing a contiguous run of labels (e.g. predictions) before
    plotting them.

    Args:
        arr: a numpy array or mutable sequence (e.g. a list); shuffled along its
            first axis within ``[start, stop)``.
        start: first index of the region (default 250, the previously hard-coded value).
        stop: end index (exclusive) of the region (default 300).
        rng: optional ``numpy.random.Generator`` for reproducible shuffles. When
            None, numpy's global random state is used (``np.random.shuffle``).

    Returns:
        The same ``arr`` object, modified in place.
    """
    region = arr[start:stop]
    if rng is None:
        np.random.shuffle(region)
    else:
        rng.shuffle(region)
    # For numpy arrays `region` is a view, so this write-back is a no-op. It is
    # required for lists, where slicing produces a copy.
    arr[start:stop] = region
    return arr

# Run the model on one test shape, then strip the batch dimension for plotting.
item_index = 1500
item = TEST_DATASET[item_index]
points, pred, label = infer(model, item)

# First three input channels of the single batch entry (presumably xyz).
coordinates = points[0, :, :3].cpu().numpy()
pred, label = pred[0], label[0]
# Optional check of the colouring: scramble part of the prediction first.
# pred = random_shuffle_region(pred)

In [77]:
# Colour the shape by its ground-truth part labels (pass `pred` to view the model's output).
plot(coord=coordinates, prediction=label)