# **Semantic Segmentation Models**

In this notebook we will create a pipeline to perform semantic segmentation on point clouds of indoor spaces. The pipeline uses a pretrained PointNet segmentation model to predict a class for every point in an input point cloud. We then use Open3D to search the point cloud for clusters of points that share a predicted class; these clusters serve as object proposals.

In [None]:
import os
import re
from glob import glob
import time
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from torchmetrics.classification import MulticlassMatthewsCorrCoef
import open3d as o3
# from open3d import JVisualizer # For Colab Visualization
from open3d.web_visualizer import draw # for non Colab

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# TEMP for supressing pytorch user warnings
import warnings
warnings.filterwarnings("ignore")

In [None]:
# dataset location: set the S3DIS_ROOT environment variable to point at your copy;
# the fallback is the original author's local Windows path
ROOT = os.environ.get(
    'S3DIS_ROOT',
    r'C:\Users\itber\Documents\datasets\S3DIS\Stanford3dDataset_v1.2_Reduced_Partitioned_Aligned_Version_1m',
)

# feature selection hyperparameters
NUM_TRAIN_POINTS = 4096  # train/valid points
NUM_TEST_POINTS = 15000  # points per test sample (passed to S3DIS and PointNetSegHead)

BATCH_SIZE = 16

# seed every RNG the notebook draws from (point sampling below uses np.random)
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

#### Class Categories and Color Map

In [None]:
# class name -> integer label (labels run 0..13 in this order)
CATEGORIES = {
    name: label
    for label, name in enumerate((
        'ceiling', 'floor', 'wall', 'beam', 'column',
        'window', 'door', 'table', 'chair', 'sofa',
        'bookcase', 'board', 'stairs', 'clutter',
    ))
}

# integer label -> RGB color (0-255); unique palette generated via
# https://mokole.com/palette.html
COLOR_MAP = dict(enumerate((
    (47, 79, 79),     # 0  ceiling  - darkslategray
    (139, 69, 19),    # 1  floor    - saddlebrown
    (34, 139, 34),    # 2  wall     - forestgreen
    (75, 0, 130),     # 3  beam     - indigo
    (255, 0, 0),      # 4  column   - red
    (255, 255, 0),    # 5  window   - yellow
    (0, 255, 0),      # 6  door     - lime
    (0, 255, 255),    # 7  table    - aqua
    (0, 0, 255),      # 8  chair    - blue
    (255, 0, 255),    # 9  sofa     - fuchsia
    (238, 232, 170),  # 10 bookcase - palegoldenrod
    (100, 149, 237),  # 11 board    - cornflower
    (255, 105, 180),  # 12 stairs   - hotpink
    (0, 0, 0),        # 13 clutter  - black
)))

# elementwise label -> color lookup; because each color is a tuple,
# np.vectorize returns a tuple of (r, g, b) arrays
v_map_colors = np.vectorize(lambda label: COLOR_MAP[label])

NUM_CLASSES = len(CATEGORIES)

#### Get the Test Dataset and Dataloader

In [None]:
from torch.utils.data import DataLoader
from s3dis_dataset import S3DIS

# held-out test split (area 6); npoints presumably sets the points drawn per sample
s3dis_test = S3DIS(ROOT, area_nums='6', split='test', npoints=NUM_TEST_POINTS)

# fixed order (no shuffling) so batches are reproducible across runs
test_dataloader = DataLoader(dataset=s3dis_test, batch_size=BATCH_SIZE, shuffle=False)

Get an example from the test set and display it, with each point colored by its ground-truth class.

In [None]:
points, targets = s3dis_test[10]

# build the cloud directly from the xyz points, then attach one RGB color per
# point, rescaled from COLOR_MAP's 0-255 values to Open3D's 0-1 range
pcd = o3.geometry.PointCloud(o3.utility.Vector3dVector(points))
pcd.colors = o3.utility.Vector3dVector(np.vstack(v_map_colors(targets)).T / 255)

o3.visualization.draw_plotly([pcd])

#### Load the Pretrained Segmentation Model

In [None]:
# prefer the GPU when one is visible to PyTorch
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
DEVICE

In [None]:
from point_net import PointNetSegHead

# build the segmentation architecture for the test-time point count,
# then load the trained weights into it
MODEL_PATH = 'trained_models/seg_focal/seg_model_60.pth'

model = PointNetSegHead(num_points=NUM_TEST_POINTS, m=NUM_CLASSES).to(DEVICE)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
# (DEVICE falls back to 'cpu' when CUDA is unavailable)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval();

### Detection pipeline

Now it's time to build the object detection pipeline. In Appendix D of the PointNet paper, the authors pick a random point, find its predicted class, and then search for other points with the same predicted class within a 0.2 m radius. If the resulting cluster contains more than 200 points, the cluster's bounding box is added to a list of proposals. We compute the average point score for each proposed object by dividing the total number of points assigned to the object by the total number of points evaluated within the radius.

First, prototype the radius search with Open3D's KD-tree.

In [None]:
# FIXME(review): `norm_points` is never defined earlier in this notebook; the cell
# depends on an inference cell that is missing, so it fails on a fresh kernel.
# NOTE(review): permute(2, 0, 1).reshape(3, -1).T gives an Nx3 xyz array only if
# norm_points is laid out as (batch, num_points, 3); a channels-first
# (batch, 3, num_points) tensor would get its coordinates scrambled — TODO confirm.
pcd_points = norm_points.permute(2, 0, 1).reshape(3, -1).to('cpu').T

# wrap the points in an Open3D cloud and index it with a KD-tree
pcd = o3.geometry.PointCloud(o3.utility.Vector3dVector(pcd_points))
pcd_tree = o3.geometry.KDTreeFlann(pcd)

# neighbors of point 1500 within r = 0.2: count, indices, squared distances
k, idx, a = pcd_tree.search_radius_vector_3d(pcd.points[1500], 0.2)

In [None]:
def get_downsample_choices(points, npoints, rng=None):
    """Pick ``npoints`` random row indices into ``points``.

    Samples without replacement whenever at least ``npoints`` points exist, so
    no point is duplicated. Only when there are fewer points than requested does
    it sample with replacement (upsampling) to reach the target count.

    Args:
        points: any sized sequence/array/tensor of points; only ``len(points)`` is used.
        npoints: number of indices to return.
        rng: optional ``numpy.random.Generator`` or ``RandomState`` for
            reproducible draws; defaults to the global ``np.random`` state.

    Returns:
        1-D integer array of ``npoints`` indices in ``[0, len(points))``.

    Raises:
        ValueError: if ``points`` is empty but ``npoints > 0``.
    """
    num_points = len(points)
    if num_points == 0 and npoints > 0:
        raise ValueError('cannot sample indices from an empty set of points')

    sampler = np.random if rng is None else rng
    # replace only when upsampling; with exactly npoints points, sampling with
    # replacement would needlessly duplicate some points and drop others
    return sampler.choice(num_points, npoints, replace=num_points < npoints)

In [None]:
# FIXME(review): `pred_choice` and `norm_points` are never defined earlier in this
# notebook (hidden kernel state); this cell fails on a fresh kernel.
predictions = pred_choice.reshape(-1).to('cpu')                       # Nx1 labels
pcd_points = norm_points.permute(2, 0, 1).reshape(3, -1).to('cpu').T  # Nx3 points

# randomly sample 1500 points, keeping labels aligned with their points
choice = get_downsample_choices(pcd_points, 1500)
pcd_points, predictions = pcd_points[choice], predictions[choice]

# keep only the points predicted as category 0 (ceiling)
pcd_points = pcd_points[predictions == 0, :]

# wrap them in an Open3D cloud and build a KD-tree over it
pcd = o3.geometry.PointCloud(o3.utility.Vector3dVector(pcd_points))
pcd_tree = o3.geometry.KDTreeFlann(pcd)

# neighbors of remaining point #10 within r = 0.2
k, idx, a = pcd_tree.search_radius_vector_3d(pcd.points[10], 0.2)

In [None]:
def get_stuff(predictions, points, cat, npoints=15000, radius=0.2, M=500):
    """Radius-search around random seed points of one predicted category.

    Points are randomly downsampled and filtered to those predicted as ``cat``,
    then indexed with an Open3D KD-tree. Up to ``M`` random seed points are each
    queried for the neighbors that lie within ``radius``.

    Args:
        predictions: tensor of predicted labels; flattened with ``reshape(-1)``.
        points: tensor of normalized points, converted to a (num_points, 3) array
            via ``permute(2, 0, 1).reshape(3, -1).T``. This presumably expects a
            (batch, num_points, 3) layout aligned with ``predictions`` — TODO confirm.
        cat: integer category label to search within.
        npoints: maximum number of points kept after random downsampling.
        radius: search radius, in the same units as ``points``.
        M: maximum number of random seed points to query.

    Returns:
        ``(xyz, neighborhoods)``:
            ``xyz`` is the float64 (num_cat_points, 3) array of the kept points
            predicted as ``cat``.
            ``neighborhoods`` holds, for each seed point, an int array of the
            indices into ``xyz`` that lie within ``radius`` of it.
        Both are empty when no kept point is predicted as ``cat``.
    """
    # use the arguments rather than the notebook globals pred_choice / norm_points
    labels = predictions.reshape(-1).to('cpu')
    pcd_points = points.permute(2, 0, 1).reshape(3, -1).to('cpu').T

    # random downsample without replacement; never request more points than exist
    keep = np.random.choice(len(pcd_points), min(npoints, len(pcd_points)), replace=False)
    pcd_points = pcd_points[keep]
    labels = labels[keep]

    # only keep points predicted as the requested category (as float64 for Open3D)
    xyz = np.asarray(pcd_points[labels == cat], dtype=np.float64).reshape(-1, 3)
    if len(xyz) == 0:
        return xyz, []

    pcd = o3.geometry.PointCloud()
    pcd.points = o3.utility.Vector3dVector(xyz)
    pcd_tree = o3.geometry.KDTreeFlann(pcd)

    # up to M distinct seeds; fewer when the category has fewer points than M
    seed_idxs = np.random.choice(len(xyz), min(M, len(xyz)), replace=False)
    neighborhoods = []
    for p in seed_idxs:
        _, idx, _ = pcd_tree.search_radius_vector_3d(pcd.points[p], radius=radius)
        neighborhoods.append(np.asarray(idx))
    return xyz, neighborhoods
