In [None]:
# Notebook setup: draw matplotlib figures inline, and load the autoreload extension
# in mode 2 so edited modules are re-imported before each cell runs.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import time
import csv

import cv2
import numpy as np
from scipy.spatial.distance import cdist

from tqdm import tqdm

import os, os.path as osp
from pathlib import Path

import matplotlib.pyplot as plt

import trimesh

# Library: color palette and label-mapping helpers

In [None]:
# Class ids that are kept as-is further down in this notebook; any other id is
# replaced by 0 when the per-vertex labels are filtered. Id 0 itself is in the list.
valid_classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39] 
print(f'{len(valid_classes)} classes')
# NYU labels
def create_color_palette():
    """Build the RGB lookup table used to color class ids.

    Returns:
        A new list of 41 ``(r, g, b)`` tuples, ordered so that ``palette[class_id]``
        is the color for that id (ids 0 through 40). Entries with a name in their
        trailing comment are the named classes; the others carry only their index.
    """
    return [
        (0, 0, 0),        # 0
        (174, 199, 232),  # 1  wall
        (152, 223, 138),  # 2  floor
        (31, 119, 180),   # 3  cabinet
        (255, 187, 120),  # 4  bed
        (188, 189, 34),   # 5  chair
        (140, 86, 75),    # 6  sofa
        (255, 152, 150),  # 7  table
        (214, 39, 40),    # 8  door
        (197, 176, 213),  # 9  window
        (148, 103, 189),  # 10 bookshelf
        (196, 156, 148),  # 11 picture
        (23, 190, 207),   # 12 counter
        (178, 76, 76),    # 13
        (247, 182, 210),  # 14 desk
        (66, 188, 102),   # 15
        (219, 219, 141),  # 16 curtain
        (140, 57, 197),   # 17
        (202, 185, 52),   # 18
        (51, 176, 203),   # 19
        (200, 54, 131),   # 20
        (92, 193, 61),    # 21
        (78, 71, 183),    # 22
        (172, 114, 82),   # 23
        (255, 127, 14),   # 24 refrigerator
        (91, 163, 138),   # 25
        (153, 98, 156),   # 26
        (140, 153, 101),  # 27
        (158, 218, 229),  # 28 shower curtain
        (100, 125, 154),  # 29
        (178, 127, 135),  # 30
        (120, 185, 128),  # 31
        (146, 111, 194),  # 32
        (44, 160, 44),    # 33 toilet
        (112, 128, 144),  # 34 sink
        (96, 207, 209),   # 35
        (227, 119, 194),  # 36 bathtub
        (213, 92, 176),   # 37
        (94, 106, 211),   # 38
        (82, 84, 163),    # 39 otherfurn
        (100, 85, 144),   # 40
    ]

# map scannet -> nyu40
def map_labels(arr, label_mapping):
    """Relabel an array of ids through a lookup dict.

    Args:
        arr: array of label ids; it is not modified.
        label_mapping: dict of ``old_id -> new_id``.

    Returns:
        A ``np.uint8`` copy of ``arr`` in which every element equal to a key of
        ``label_mapping`` holds the corresponding value. Elements matching no key
        keep their value, subject to the final cast to ``uint8``.
    """
    remapped = np.copy(arr)
    for old_id in label_mapping:
        new_id = label_mapping[old_id]
        # Match against the untouched input so already-remapped entries are never remapped again.
        hits = arr == old_id
        remapped[hits] = new_id
    return remapped.astype(np.uint8)

# if string s represents an int
def represents_int(s):
    """Report whether ``int(s)`` succeeds.

    Returns:
        True if ``int(s)`` returns normally, False if it raises ``ValueError``.
        Any other exception raised by ``int`` propagates to the caller.
    """
    try:
        int(s)
    except ValueError:
        return False
    else:
        return True

# read the TSV file
def read_label_mapping(filename, label_from='id', label_to='nyu40id'):
    """Load a label lookup table from a tab-separated file with a header row.

    Args:
        filename: path to the TSV file.
        label_from: name of the column whose values become the dict keys.
        label_to: name of the column whose values, parsed with ``int``, become
            the dict values.

    Returns:
        dict mapping ``row[label_from]`` to ``int(row[label_to])``. If the first
        key parses as an integer, all keys are converted to ``int``; otherwise
        they stay strings. A file with no data rows gives an empty dict.

    Raises:
        AssertionError: if ``filename`` is not an existing file.
        KeyError: if either column is missing from the header.
        ValueError: if a ``label_to`` value is not an integer, or if the first
            key is an integer but a later one is not.
    """
    assert os.path.isfile(filename), f'label mapping file not found: {filename}'
    mapping = dict()
    # newline='' lets the csv module do its own newline handling, as its docs ask.
    with open(filename, newline='') as csvfile:
        reader = csv.DictReader(csvfile, delimiter='\t')
        for row in reader:
            mapping[row[label_from]] = int(row[label_to])
    # Keys arrive as strings; convert them when they are numeric ids.
    # The `mapping and` guard avoids an IndexError on a file with no data rows.
    if mapping and represents_int(next(iter(mapping))):
        mapping = {int(k): v for k, v in mapping.items()}
    return mapping

In [None]:
# Directory holding one sub-folder per scan. Set the SCANNET_SCANS_DIR environment
# variable to point the notebook at another location without editing this cell;
# when it is unset the original hard-coded path is used.
DATA_DIR = Path(os.environ.get('SCANNET_SCANS_DIR', '/mnt/data/scannet/scans/'))

In [None]:
# Voxel edge length; passed as `pitch` when the input mesh is voxelized further down.
# presumably in metres — TODO confirm the unit of the mesh coordinates
voxel_size = 0.05

# Scan to process and the folder its files live in.
scan_id = 'scene0000_01'
scan_dir = DATA_DIR / scan_id
# The label table is looked up one level above the scans directory.
label_file = DATA_DIR.parent / 'scannetv2-labels.combined.tsv'

# File names of the input mesh and of the mesh that carries per-vertex labels.
input_file = f'{scan_id}_vh_clean_2.ply' 
gt_file = f'{scan_id}_vh_clean_2.labels.ply' 

input_mesh = trimesh.load(scan_dir / input_file)
gt_mesh = trimesh.load(scan_dir / gt_file)
# Last expression of the cell: shows both loaded objects as the cell output.
input_mesh, gt_mesh

## Input mesh: inspect and voxelize

In [None]:
# Extents of the input mesh's bounding box, of the input mesh itself and of the
# GT mesh, printed in that order; all three are expected to be the same.
for extents in (input_mesh.bounding_box.extents, input_mesh.extents, gt_mesh.extents):
    print(extents)

In [None]:
# Display the input mesh as the cell output.
input_mesh.show()

In [None]:
# Voxelize the input mesh with cubic voxels of edge `voxel_size`.
input_grid = input_mesh.voxelized(pitch=voxel_size)
print(input_grid.matrix.shape)

# Occupancy as a voxel count: filled cells over all cells of the dense grid.
occ_vol, total_vol = input_grid.matrix.sum(), np.prod(input_grid.matrix.shape)
occ_frac = occ_vol / total_vol
# NOTE(review): `:2f` means "minimum width 2, default 6 decimals"; `.2f` may have
# been intended. Left unchanged so the printed output stays the same.
print(f'Occupied: {occ_vol} / {total_vol} = {occ_frac:2f}')

# The same ratio expressed as volumes: the grid's reported volume over the volume
# of the full dense grid (cell count times the volume of one voxel).
occ_vol = input_grid.volume
# np.prod replaces np.product, which was removed in NumPy 2.0.
total_vol = np.prod(input_grid.matrix.shape)*(voxel_size**3)
print(f'Occupied volume: {occ_vol} / {total_vol} = {occ_vol/total_vol:2f}')

In [None]:
# Alternative display: build an explicit box mesh from the voxel grid and show that.
# Much heavier than showing the grid directly, because the mesh is actually constructed.
input_boxes = input_grid.as_boxes()
input_boxes.show()

In [None]:
# Show the voxel grid directly.
# TODO: check whether this is faster than the as_boxes() route or does the same thing.
input_grid.show()

## Read GT vertex positions (xyz), colors (RGB) and labels

In [None]:
# NOTE(review): parse_header / ply_binary belong to trimesh's PLY reader; their names and
# return values may differ between trimesh versions — confirm against the installed one.
from trimesh.exchange.ply import parse_header, ply_binary

# Parse the labeled PLY by hand so the per-vertex 'label' property can be read.
with open(scan_dir / gt_file, 'rb') as f:
    elements, is_ascii, image_name = parse_header(f)
    # The return value is discarded; presumably this fills elements[...]['data'] in place,
    # since the code below reads the vertex data from there — TODO confirm.
    ply_binary(elements, f)
    
# Per-vertex label ids exactly as stored in the file.
scannet_labels = elements['vertex']['data']['label']
print('ScanNet Labels:', scannet_labels.shape, 'Range:', scannet_labels.min(), scannet_labels.max(), 
      'Unique:', len(np.unique(scannet_labels)))

# Distribution of the raw label ids.
plt.hist(scannet_labels, bins=40)
plt.show()

In [None]:
# Per-vertex columns of the parsed PLY 'vertex' element.
vertex_data = elements['vertex']['data']
x, y, z = (vertex_data[axis] for axis in ('x', 'y', 'z'))
r, g, b = (vertex_data[channel] for channel in ('red', 'green', 'blue'))

# Stack the columns along a new last axis: one row per vertex.
gt_vertices = np.stack([x, y, z], axis=-1)
gt_rgb = np.stack([r, g, b], axis=-1)

print(gt_vertices.shape, gt_rgb.shape)

In [None]:
# map scannet labels to NYU40 labels
# Builds the lookup from the label TSV using read_label_mapping's default columns
# ('id' -> 'nyu40id') and applies it to every vertex label.
# NOTE: the next cell rebinds nyu_labels to scannet_labels, so the array computed
# here is not the one the later cells consume.
mapping = read_label_mapping(label_file)
nyu_labels = map_labels(scannet_labels, mapping)
print('NYU Labels:', nyu_labels.shape, 'range:', nyu_labels.min(), nyu_labels.max(), 'unique:', len(np.unique(nyu_labels)))

# Distribution of the mapped label ids.
plt.hist(nyu_labels, bins=40)
plt.show()

In [None]:
# Use the labels read from the PLY as they are, discarding the mapped array from the
# previous cell.
# NOTE(review): presumably the '.labels.ply' file already stores NYU40 ids, which would
# make the remapping unnecessary — confirm, then keep only one of the two cells.
nyu_labels = scannet_labels

In [None]:
# keep only the required labels, rest are 0
# Vectorized membership test: ids found in valid_classes are kept, every other id
# becomes 0. Gives the same array as filtering the labels one by one in Python.
vertex_labels = np.where(np.isin(nyu_labels, valid_classes), nyu_labels, 0).astype(np.uint8)
print('Selected NYU Labels:', vertex_labels.shape, 'range:', vertex_labels.min(), vertex_labels.max(), 'unique:', len(np.unique(vertex_labels)))
print('Classes present:', np.unique(vertex_labels))

# Distribution of the filtered label ids.
plt.hist(vertex_labels, bins=40)
plt.show()

## Check whether GT colors and labels map one-to-one

In [None]:
# For every label id, count how many distinct vertex colors carry it.
# A first dimension of 1 in the printed shape means that label has a single color.
# Only the label -> color direction is checked here.
print(gt_rgb.shape, scannet_labels.shape)
unique_labels = np.unique(scannet_labels)

for label in unique_labels:
    # all colors for this label
    c = gt_rgb[scannet_labels == label]
    # distinct color rows among them
    unique_c = np.unique(c, axis=0)
    print(label, unique_c.shape)

In [None]:
# Centers of the filled voxels of the input grid.
centers = input_grid.points
# Grid indices of those centers, used below to address the dense label/color grids.
indices = input_grid.points_to_indices(centers)
# Untested alternative: take the indices straight from the grid.
# indices = input_grid.sparse_indices
# Untested alternative: compute the indices by hand - gives almost the same result.
# NOTE(review): the uint8 cast in the line below would wrap for grids with more than 255 cells on an axis.
# indices = np.floor((centers - centers.min(axis=0)) / voxel_size).astype(np.uint8)
# Vertices of the GT mesh, i.e. the points that carry the labels.
vertices = gt_vertices

print('centers', centers.shape, 'indices', indices.shape, 'vertices', vertices.shape)

# Per-axis min/max, to check that centers and vertices cover the same region.
print('Center range', centers.min(axis=0), centers.max(axis=0))
print('Vertex range', vertices.min(axis=0), vertices.max(axis=0))
print('Index range', indices.min(axis=0), indices.max(axis=0))

In [None]:
# assign labels and colors to grid
# Imported here so this cell runs on its own; used for the nearest-vertex lookup.
from scipy.spatial import cKDTree

label_grid = np.zeros_like(input_grid.matrix, dtype=np.uint8)
print('label grid', label_grid.shape)
colors = np.zeros(label_grid.shape + (3,), dtype=np.uint8)
palette = create_color_palette()

method = 'nearest'
# method = 'voting'
skipped = 0

# One label per filled voxel, in the same order as `centers`; 0 means "no label".
center_labels = np.zeros(len(centers), dtype=np.uint8)

if method == 'nearest':
    # A single KD-tree query finds the closest vertex for every voxel center at once,
    # instead of measuring the distance from each center to all vertices in turn.
    _, closest_vtx_ndx = cKDTree(vertices).query(centers)
    # label of the closest vertex
    center_labels = vertex_labels[closest_vtx_ndx]
elif method == 'voting':
    for i, center in enumerate(tqdm(centers)):
        # Vertices inside an axis-aligned box around this voxel center.
        # NOTE(review): the box reaches `voxel_size` to each side, i.e. it is twice as wide
        # as a voxel of pitch `voxel_size`; kept as-is — confirm whether voxel_size / 2 was meant.
        low, high = center - voxel_size, center + voxel_size
        vtx_in_voxel = np.all(np.logical_and((vertices >= low), (vertices <= high)), axis=1)
        # labels of these vertices
        labels = vertex_labels[vtx_in_voxel]
        if labels.size == 0:
            # no vertex in the box: leave this voxel unlabeled and count it
            skipped += 1
            continue
        # most common label
        center_labels[i] = np.bincount(labels).argmax()
else:
    # Fail loudly rather than reuse a `label` variable left over from an earlier cell.
    raise ValueError(f'Unknown method: {method!r}')

# Labels outside valid_classes are written as 0, which palette[0] colors black,
# so such voxels end up identical to untouched (all-zero) grid cells.
center_labels = np.where(np.isin(center_labels, valid_classes), center_labels, 0).astype(np.uint8)

# assign to label and color grid
label_grid[indices[:, 0], indices[:, 1], indices[:, 2]] = center_labels
colors[indices[:, 0], indices[:, 1], indices[:, 2]] = np.array(palette, dtype=np.uint8)[center_labels]

print(f'Found labels for {len(centers) - skipped}/{len(centers)} centers')

# colors of voxel centers
center_colors = colors[indices[:, 0], indices[:, 1], indices[:, 2]]

In [None]:
# Distinct colors that were assigned to voxel centers (one row per color).
unique_colors = np.unique(center_colors, axis=0)
print('Unique colors:', unique_colors.shape, unique_colors)

In [None]:
# Histogram of the voxel labels over the whole grid. The range starts at 1, so
# cells with label 0 (empty or unlabeled) are left out.
plt.hist(label_grid.reshape((-1)), bins=40, range=(1, 40))
# Render the figure explicitly, like the other histogram cells, instead of
# echoing the return value of plt.hist as the cell output.
plt.show()

## Save the colored voxel centers as a point cloud

In [None]:
# Append a constant alpha column of 255 so every voxel center gets an RGBA color.
alpha = np.full((len(centers), 1), 255, dtype=np.uint8)
pc_colors = np.hstack((center_colors, alpha))

# Write the colored voxel centers next to the scan's other files.
pc = trimesh.points.PointCloud(vertices=centers, colors=pc_colors)
out_path = scan_dir / f'{scan_id}_gt_voxelcenters.ply'
_ = pc.export(out_path)