In [None]:
# Render figures inline, and pick up edits to imported project modules
# (e.g. datasets.scannet.common) without restarting the kernel.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import csv
import os, os.path as osp
from pathlib import Path

import torch
import cv2
import numpy as np
from scipy.spatial import cKDTree
from scipy.spatial.distance import cdist
from tqdm import tqdm
import matplotlib.pyplot as plt
import trimesh

from datasets.scannet.common import load_ply

# Lib

In [None]:
# Root of the ScanNet `scans/` directory. Set the SCANNET_DIR environment variable
# to point elsewhere; the default is the original hardcoded location.
DATA_DIR = Path(os.environ.get('SCANNET_DIR', '/mnt/data/scannet/scans/'))

## Read GT vertices, RGB colors and labels

In [None]:
# Voxel edge length, in mesh units (presumably meters for ScanNet -- confirm).
voxel_size = 0.05

scan_id = 'scene0001_00'
scan_dir = DATA_DIR / scan_id
# ScanNet label table, stored one level above the scans directory
# (not used by the cells shown here).
label_file = DATA_DIR.parent / 'scannetv2-labels.combined.tsv'

# The input mesh supplies vertex colors; the `.labels` mesh supplies GT labels.
input_file = f'{scan_id}_vh_clean_2.ply'
gt_file = f'{scan_id}_vh_clean_2.labels.ply'

input_mesh, gt_mesh = (trimesh.load(scan_dir / name) for name in (input_file, gt_file))
input_mesh, gt_mesh

In [None]:
# Vertex positions and labels come from the GT (labels) mesh; colors come from the
# input mesh. Later cells index `rgb` with GT vertex ids, so both PLYs must have the
# same vertices in the same order. Check this explicitly instead of assuming it.
input_vertices, rgb, _ = load_ply(scan_dir / input_file)
vertices, _, labels = load_ply(scan_dir / gt_file, read_label=True)
assert input_vertices.shape == vertices.shape and np.allclose(input_vertices, vertices), \
    'input and GT meshes do not share the same vertices; rgb cannot be indexed by GT vertex id'
print(vertices.shape, rgb.shape, labels.shape, labels.dtype)

## Input mesh

In [None]:
# The mesh extents and the GT vertex range printed below are expected to agree.
print(input_mesh.extents)

minvertex, maxvertex = vertices.min(axis=0), vertices.max(axis=0)
print('min', minvertex, 'max', maxvertex, 'range', maxvertex-minvertex)

In [None]:
# Open trimesh's interactive viewer on the input mesh.
input_mesh.show()

In [None]:
# Voxelize the input mesh; `pitch` (the first argument) is the voxel edge length.
input_grid = input_mesh.voxelized(voxel_size)

In [None]:
# Grid geometry: dense shape, world position of the first voxel, and that
# position expressed in voxel units.
print('Grid size', input_grid.matrix.shape)
print('Start voxel location', input_grid.translation)
print('Offset', input_grid.translation / voxel_size)

# Fraction of grid cells marked occupied by the voxelization.
occ_vol, total_vol = input_grid.matrix.sum(), input_grid.matrix.size
occ_frac = occ_vol / total_vol
# `.2%` gives two decimals as a percentage; a bare `2f` would only set a field width of 2.
print(f'Occupied: {occ_vol} / {total_vol} = {occ_frac:.2%}')

In [None]:
# Render the voxel grid directly.
# TODO: check whether this is faster than, or equivalent to, rendering `input_grid.as_boxes()`.
input_grid.show()

In [None]:
# Occupied voxels: world-space centers and integer grid indices.
# NOTE(review): the assignment cell zips these two arrays, so it assumes
# `points` is ordered like `sparse_indices` -- confirm for the trimesh version in use.
centers = input_grid.points
indices = input_grid.sparse_indices

print('centers', centers.shape, 'indices', indices.shape, 'vertices', vertices.shape)

# Sanity check: voxel centers should cover roughly the same box as the vertices.
print('Center range', centers.min(axis=0), centers.max(axis=0))
print('Index range', indices.min(axis=0), indices.max(axis=0))
print('Vertex range', vertices.min(axis=0), vertices.max(axis=0))

In [None]:
# Assign a label and an RGB color to every occupied voxel.
#   'nearest': copy the label/color of the GT vertex closest to the voxel center.
#   'voting' : most common label (and mean color) of the GT vertices lying inside
#              the voxel; voxels containing no vertex stay 0 and are counted in `skipped`.
# `rgb` (input mesh) and `labels` (GT mesh) are indexed by the same vertex ids,
# which relies on both PLYs sharing vertex order.
method = 'nearest'
# method = 'voting'

label_grid = np.zeros_like(input_grid.matrix, dtype=np.int16)
rgb_grid = np.zeros(label_grid.shape + (3,), dtype=np.uint8)
skipped = 0

if method == 'nearest':
    # One KD-tree query for all centers replaces a full `cdist` over every vertex per voxel.
    _, closest_vtx_ndx = cKDTree(vertices).query(centers)
    label_grid[indices[:, 0], indices[:, 1], indices[:, 2]] = labels[closest_vtx_ndx]
    rgb_grid[indices[:, 0], indices[:, 1], indices[:, 2]] = rgb[closest_vtx_ndx]
elif method == 'voting':
    # A voxel extends half a pitch on each side of its center.
    half_size = voxel_size / 2
    for center, ndx in tqdm(zip(centers, indices), total=len(centers)):
        in_voxel = np.all((vertices >= center - half_size) & (vertices <= center + half_size), axis=1)
        if not in_voxel.any():
            skipped += 1
            continue
        # np.bincount requires non-negative integer labels.
        label_grid[ndx[0], ndx[1], ndx[2]] = np.bincount(labels[in_voxel]).argmax()
        rgb_grid[ndx[0], ndx[1], ndx[2]] = rgb[in_voxel].mean(axis=0).round()
else:
    raise ValueError(f'Unknown method: {method!r}')

print(f'Found labels for {len(centers) - skipped}/{len(centers)} centers')

# colors of voxel centers
center_colors = rgb_grid[indices[:, 0], indices[:, 1], indices[:, 2]]

In [None]:
# Distinct RGB values among the occupied voxels (unique rows of `center_colors`).
unique_colors = np.unique(center_colors, axis=0)
print('Unique colors:', unique_colors.shape, unique_colors)

In [None]:
# Histogram of voxel labels, one bin centered on each label id 1..40.
# Label 0 (the fill value of `label_grid`, so every unoccupied voxel) lies outside
# the bins and is not counted.
fig, ax = plt.subplots()
ax.hist(label_grid.ravel(), bins=np.arange(0.5, 41.5))
ax.set(xlabel='label id', ylabel='number of voxels', title=f'Voxel labels ({scan_id})');

## Save voxel centers as a colored point cloud

In [None]:
# RGBA colors: append a fully opaque alpha column to the per-voxel RGB.
alpha = np.full((len(centers), 1), 255, dtype=np.uint8)
pc_colors = np.hstack([center_colors, alpha])

pc = trimesh.points.PointCloud(vertices=centers, colors=pc_colors)

out_file = f'{scan_id}_gt_voxelcenters.ply'
print(f'Saving to: {out_file}')
_ = pc.export(scan_dir / out_file)

## Save occupancy grid to file

In [None]:
# Occupancy grid (x), per-voxel labels (y) and colors, plus the grid's placement in space.
x, y = input_grid.matrix, label_grid
print(x.shape, x.dtype, y.shape, y.dtype, rgb_grid.shape, rgb_grid.dtype)
out_file = f'{scan_id}_occ_grid.pth'

data = dict(
    x=x,
    y=y,
    rgb=rgb_grid,
    # grid translation expressed in voxel units
    start_ndx=input_grid.translation / voxel_size,
    translation=input_grid.translation,
)
torch.save(data, scan_dir / out_file)

In [None]:
# Reload the saved file as a round-trip check.
# The dict holds numpy arrays, which torch>=2.6 refuses to unpickle under its default
# `weights_only=True`. The file was just written by this notebook and is trusted, so
# full unpickling is acceptable here.
data = torch.load(scan_dir / out_file, weights_only=False)
# NOTE(review): this rebinds `rgb` (the per-vertex colors from the load cell) to the
# RGB grid; re-running the label-assignment cell afterwards would index the grid by vertex id.
x, y, rgb = data['x'], data['y'], data['rgb']
print(x.shape, x.dtype, y.shape, y.dtype, rgb.shape, rgb.dtype)