In [1]:
from sklearn.cluster import DBSCAN
import numpy as np
import pickle
import pdb
import time
import os
from tree_utils import flatten_scores, flatten_indices
import sys
from utils import *
import open3d as o3d
from open3d import JVisualizer
import yaml

In [2]:
def evaluate(inds):
    return np.mean(objectness_objects[inds]).item()
    
    
def segment(id_, eps_list, cloud, original_indices=None, aggr_func='min'):
    if not all(eps_list[i] > eps_list[i+1] for i in range(len(eps_list)-1)):
        raise ValueError('eps_list is not sorted in descending order')
    # pick the first threshold from the list
    max_eps = eps_list[0]
    #
    if original_indices is None: original_indices = np.arange(cloud.shape[0])
    if isinstance(original_indices, list): original_indices = np.array(original_indices)
    # spatial segmentation
    dbscan = DBSCAN(max_eps, min_samples=1).fit(cloud[original_indices,:])
    labels = dbscan.labels_
    # evaluate every segment
    indices, scores = [], []
    for unique_label in np.unique(labels):
        inds = original_indices[np.flatnonzero(labels == unique_label)]
        indices.append(inds.tolist())
        scores.append(evaluate(inds))
    # return if we are done
    if len(eps_list) == 1: return indices, scores
    # expand recursively
    final_indices, final_scores = [], []
    for i, (inds, score) in enumerate(zip(indices, scores)):
        # focus on this segment
        fine_indices, fine_scores = segment(id_, eps_list[1:], cloud, inds)
        # flatten scores to get the minimum (keep structure)
        flat_fine_scores = flatten_scores(fine_scores)
        if aggr_func == 'min':
            aggr_score = np.min(flat_fine_scores)
        elif aggr_func == 'avg':
            aggr_score = np.mean(flat_fine_scores)
        elif aggr_func == 'sum':
            aggr_score = np.sum(flat_fine_scores)
        elif aggr_func == 'wavg':
            # compute a weighted average (each score is weighted by the number of points)
            flat_fine_indices = flatten_indices(fine_indices)
            sum_count, sum_score = 0, 0.0
            for indices, score in zip(flat_fine_indices, flat_fine_scores):
                sum_count += len(indices)
                sum_score += len(indices)*score
            aggr_score = float(sum_score)/sum_count
        elif aggr_func == 'd2wavg':
            # compute a weighted average (each score is weighted by the number of points)
            flat_fine_indices = flatten_indices(fine_indices)
            sum_count, sum_score = 0, 0.0
            for indices, score in zip(flat_fine_indices, flat_fine_scores):
                squared_dists = np.sum(cloud[inds,:]**2, axis=1)
                sum_count += np.sum(squared_dists)
                sum_score += np.sum(squared_dists * score)
            aggr_score = float(sum_score)/sum_count

        # COMMENTING THIS OUT BECAUSE OF ADDING SUM AS AN AGGR FUNC
        # assert(aggr_score <= 1 and aggr_score >= 0)

        # if splitting is better
        if score < aggr_score:
            final_indices.append(fine_indices)
            final_scores.append(fine_scores)
        else: # otherwise
            final_indices.append(inds)
            final_scores.append(score)
    return final_indices, final_scores


def vis_instance_o3d():
    # visualization
    pcd_objects = o3d.geometry.PointCloud()
    colors = np.zeros((len(pts_velo_cs_objects), 4))
    max_instance = len(flat_indices)
    print(f"point cloud has {max_instance + 1} clusters")
    colors_instance = plt.get_cmap("tab20")(np.arange(len(flat_indices)) / (max_instance if max_instance > 0 else 1))

    for idx in range(len(flat_indices)):
        colors[flat_indices[idx]] = colors_instance[idx]

    pcd_objects.points = o3d.utility.Vector3dVector(pts_velo_cs_objects[:, :3])
    pcd_objects.colors = o3d.utility.Vector3dVector(colors[:, :3])

    pcd_background = o3d.geometry.PointCloud()
    pcd_background.points = o3d.utility.Vector3dVector(pts_velo_cs[background_mask, :3])
    pcd_background.paint_uniform_color([0.5, 0.5, 0.5])

#     o3d.visualization.draw_geometries([pcd_objects, pcd_background])
    visualizer = JVisualizer()
    visualizer.add_geometry(pcd_objects) # Ani: adds the colourful points (each color is a new segmented instance)
    visualizer.add_geometry(pcd_background) # Ani: adds the gray bg points (technically, these are not bg but actually known classes)
    visualizer.show()

    
def vis_4dpls():
    fg = o3d.geometry.PointCloud()
    colors = np.zeros((len(pts_velo_cs_objects), 4))
    max_cls = unk_label + 1 # [0, 11] for TS1
    colors_cls = plt.get_cmap("tab20")(np.arange(max_cls) / (max_cls if max_cls > 0 else 1))

    colors = colors_cls[labels_objects]

    visualizer = JVisualizer()
    fg.points = o3d.utility.Vector3dVector(pts_velo_cs_objects[:, :3])
    fg.colors = o3d.utility.Vector3dVector(colors[:, :3])
    visualizer.add_geometry(fg)

    bg = o3d.geometry.PointCloud()
    bg.points = o3d.utility.Vector3dVector(pts_velo_cs[background_mask, :3])
    bg.paint_uniform_color([0.5, 0.5, 0.5])
    visualizer.add_geometry(bg)
    visualizer.show()


def load_config(file, task_set):
    with open(file, 'r') as stream:
        doc = yaml.safe_load(stream)
#         all_labels = doc['task_set_map'][task_set]['labels']
#         learning_map_inv = doc['task_set_map'][task_set]['learning_map_inv']
        learning_map = doc['task_set_map'][task_set]['learning_map']
    learning_map_arr = np.zeros((np.max([k for k in learning_map.keys()]) + 1), dtype=np.int32)
    for k, v in learning_map.items():
        learning_map_arr[k] = v
    return learning_map_arr



### Visualize individual images

In [5]:
seq = '08'
# seq_prefix = '08_0'

config_file = '/project_data/ramanan/achakrav/4D-PLS/data/SemanticKitti/semantic-kitti.yaml'

task_set = 1
# write_dir = '/project_data/ramanan/achakrav/hu-segmentation/kitti_raw_ts1_segmented/'
write_dir = '/project_data/ramanan/achakrav/hu-segmentation/ts{}_trial/'.format(task_set)
if not os.path.exists(write_dir):
    os.makedirs(write_dir)

# scan_folder = '/media/data/dataset/kitti-odometry/dataset/sequences/' + seq + '/velodyne'
scan_folder = '/project_data/ramanan/achakrav/4D-PLS/data/SemanticKitti/sequences/' + seq + '/velodyne/'
scan_files = load_paths(scan_folder)

# objectness_folder = '/media/data/tmp/testsetobj'
# objectness_files_raw = load_paths(objectness_folder)
# objectness_files = [path for path in objectness_files_raw if seq_prefix in path]

# semantic_folder = '/media/data/tmp/testsetsem'
# semantic_files_raw = load_paths(semantic_folder)
# semantic_files = [path for path in semantic_files_raw if seq_prefix in path]

# objsem_folder = 'xieyuanlichen/tmp/seq08objsem'
objsem_folder = '/project_data/ramanan/achakrav/4D-PLS/val_preds_TS{}/val_preds/'.format(task_set)
objsem_files = load_paths(objsem_folder)

label_folder = '/project_data/ramanan/achakrav/4D-PLS/data/SemanticKitti/sequences/' + seq + '/labels/'
label_files = load_paths(label_folder)
label_files = [x for x in label_files if '.label' in x]

# for task set 1
if task_set == 1:
    class_strings = ["car", "truck", "person", "road", "sidewalk", "building", "fence", "vegetation", "terrain", "pole", "unknown"]
else:
    assert False

sem_file_mask = []
obj_file_mask = []
for idx, file in enumerate(objsem_files):
    if '_c' in file:
        obj_file_mask.append(idx)
    elif '_u' not in file and '_i' not in file and '_e' not in file and '_pots' not in file:
        sem_file_mask.append(idx)

objectness_files = objsem_files[obj_file_mask]
semantic_files = objsem_files[sem_file_mask]

# for scan_file in scan_files:
#     base_file = os.path.basename(scan_file)
#     idx_num = base_file.split('.')[0]
#     sem_file = '08_0{}.npy'.format(idx_num)
#     sem_file = os.path.join(objsem_folder, sem_file)
#     if sem_file not in objsem_files:
#         pdb.set_trace()


assert (len(semantic_files) == len(objectness_files))
assert (len(semantic_files) == len(scan_files))

learning_map = load_config(config_file, task_set)

for idx in tqdm(range(len(objectness_files))):
    # load scan
    # frame_idx = int(os.path.basename(semantic_files[idx]).replace('.npy', '').replace('08_', ''))
    # scan_file = scan_files[frame_idx]
    scan_file = scan_files[idx]
    pts_velo_cs = load_vertex(scan_file)
    pts_indexes = np.arange(len(pts_velo_cs))

    # load objectness
    objectness_file = objectness_files[idx]
    objectness = np.load(objectness_file)

    # labels
    label_file = semantic_files[idx]
    labels = np.load(label_file)

    unk_label = 11 # for task set 1

    # ==========================================================================================
    # Ani: plot predictions from 4DPLS 
    # ==========================================================================================
    mask = labels != unk_label # Note: opposite to other plot since we plot known classes only
    background_mask = labels == unk_label
    
    pts_velo_cs_objects = pts_velo_cs[mask]
    objectness_objects = objectness[mask]
    pts_indexes_objects = pts_indexes[mask]
    labels_objects = labels[mask]

    # visualize predictions
#     vis_4dpls()
#     break
    # ==========================================================================================


    # ==========================================================================================
    # Ani: plot GT
    # ==========================================================================================
    gt_label_file = label_files[idx]
    gt_label = np.fromfile(gt_label_file, dtype=np.int32)
    gt_label = gt_label & 0xFFFF
    gt_label = learning_map[gt_label]
    
    # to plot clustering segmentation output
    mask = gt_label != unk_label
    background_mask = gt_label == unk_label

    pts_velo_cs_objects = pts_velo_cs[mask]
    objectness_objects = objectness[mask]
    pts_indexes_objects = pts_indexes[mask]
    labels_objects = gt_label[mask]
    
    # visualize GT
#     vis_4dpls()
#     break
    # ==========================================================================================
    mask = labels == unk_label
    background_mask = labels != unk_label
    
    pts_velo_cs_objects = pts_velo_cs[mask]
    objectness_objects = objectness[mask]
    pts_indexes_objects = pts_indexes[mask]
    labels_objects = labels[mask]


    # debug = o3d.geometry.PointCloud()
    # debug.points = o3d.utility.Vector3dVector(pts_velo_cs_objects[:, :3])
    # o3d.visualization.draw_geometries([debug])

    assert (len(pts_velo_cs_objects) == len(objectness_objects))

    if len(pts_velo_cs_objects) < 1:
#         np.savez_compressed(os.path.join(write_dir, seq + '_' + str(idx).zfill(6)),
#                             instances=[], segment_scores=[])
        continue

    # segmentation with point-net
    id_ = 0
    # eps_list = [2.0, 1.0, 0.5, 0.25]
    eps_list_tum = [1.2488, 0.8136, 0.6952, 0.594, 0.4353, 0.3221]
    indices, scores = segment(id_, eps_list_tum, pts_velo_cs_objects[:, :3])

    # flatten list(list(...(indices))) into list(indices)
    flat_indices = flatten_indices(indices)
    # map from object_indexes to pts_indexes
    mapped_indices = []
    for indexes in flat_indices:
        mapped_indices.append(pts_indexes_objects[indexes].tolist())

    # mapped_flat_indices = pts_indexes_objects
    flat_scores = flatten_scores(scores)

    # visualizer
    vis_instance_o3d()
    break

#     # save results
#     np.savez_compressed(os.path.join(write_dir, seq + '_'+str(idx).zfill(6)),
#                         instances=mapped_indices, segment_scores=flat_scores, allow_pickle = True)

  0%|          | 0/4071 [00:00<?, ?it/s]

point cloud has 66 clusters


JVisualizer with 2 geometries

  0%|          | 0/4071 [00:02<?, ?it/s]


### Visualize video

In [None]:
seq = '08'
# seq_prefix = '08_0'

config_file = '/project_data/ramanan/achakrav/4D-PLS/data/SemanticKitti/semantic-kitti.yaml'

task_set = 1
# write_dir = '/project_data/ramanan/achakrav/hu-segmentation/kitti_raw_ts1_segmented/'
write_dir = '/project_data/ramanan/achakrav/hu-segmentation/ts{}_trial/'.format(task_set)
if not os.path.exists(write_dir):
    os.makedirs(write_dir)

# scan_folder = '/media/data/dataset/kitti-odometry/dataset/sequences/' + seq + '/velodyne'
scan_folder = '/project_data/ramanan/achakrav/4D-PLS/data/SemanticKitti/sequences/' + seq + '/velodyne/'
scan_files = load_paths(scan_folder)

# objsem_folder = 'xieyuanlichen/tmp/seq08objsem'
objsem_folder = '/project_data/ramanan/achakrav/4D-PLS/val_preds_TS{}/val_preds/'.format(task_set)
objsem_files = load_paths(objsem_folder)

label_folder = '/project_data/ramanan/achakrav/4D-PLS/data/SemanticKitti/sequences/' + seq + '/labels/'
label_files = load_paths(label_folder)
label_files = [x for x in label_files if '.label' in x]

# for task set 1
if task_set == 1:
    class_strings = ["car", "truck", "person", "road", "sidewalk", "building", "fence", "vegetation", "terrain", "pole", "unknown"]
else:
    assert False

sem_file_mask = []
obj_file_mask = []
for idx, file in enumerate(objsem_files):
    if '_c' in file:
        obj_file_mask.append(idx)
    elif '_c' not in file and '_i' not in file and '_e' not in file and '_pots' not in file:
        sem_file_mask.append(idx)

objectness_files = objsem_files[obj_file_mask]
semantic_files = objsem_files[sem_file_mask]

assert (len(semantic_files) == len(objectness_files))
assert (len(semantic_files) == len(scan_files))

learning_map = load_config(config_file, task_set)

vis = JVisualizer()
# vis.create_window()
geometry = o3d.geometry.PointCloud()
vis.add_geometry(geometry)


for idx in tqdm(range(len(objectness_files))):
    # load scan
    scan_file = scan_files[idx]
    pts_velo_cs = load_vertex(scan_file)
    pts_indexes = np.arange(len(pts_velo_cs))

    # load objectness
    objectness_file = objectness_files[idx]
    objectness = np.load(objectness_file)

    # labels
    label_file = semantic_files[idx]
    labels = np.load(label_file)

    unk_label = 11 # for task set 1

    # ==========================================================================================
    # Ani: plot predictions from 4DPLS 
    # ==========================================================================================
    mask = labels != unk_label # Note: opposite to other plot since we plot known classes only
    background_mask = labels == unk_label
    
    pts_velo_cs_objects = pts_velo_cs[mask]
    objectness_objects = objectness[mask]
    pts_indexes_objects = pts_indexes[mask]
    labels_objects = labels[mask]
    
    # copied function from above
    fg = o3d.geometry.PointCloud()
    colors = np.zeros((len(pts_velo_cs_objects), 4))
    max_cls = unk_label + 1 # [0, 11] for TS1
    colors_cls = plt.get_cmap("tab20")(np.arange(max_cls) / (max_cls if max_cls > 0 else 1))

    colors = colors_cls[labels_objects]

    # visualizer = JVisualizer()
    visualizer = o3d.visualization.Visualizer()
#     visualizer.create_window(visible=False)
#     visualizer.get_render_option().point_color_option = o3d.visualization.PointColorOption.Color
#     visualizer.get_render_option().point_size = 3.0
    fg.points = o3d.utility.Vector3dVector(pts_velo_cs_objects[:, :3])
    fg.colors = o3d.utility.Vector3dVector(colors[:, :3])
    visualizer.add_geometry(fg)

    bg = o3d.geometry.PointCloud()
    bg.points = o3d.utility.Vector3dVector(pts_velo_cs[background_mask, :3])
    bg.paint_uniform_color([0.5, 0.5, 0.5])
    visualizer.add_geometry(bg)
    
#     ctr = vis.get_view_control()
    visualizer.capture_screen_image("file.jpg", do_render=True)
    visualizer.destroy_window()
#     visualizer.show()
    break

    # visualize predictions
#     vis_4dpls()
#     break
    # ==========================================================================================


    # ==========================================================================================
    # Ani: plot GT
    # ==========================================================================================
#     gt_label_file = label_files[idx]
#     gt_label = np.fromfile(gt_label_file, dtype=np.int32)
#     gt_label = gt_label & 0xFFFF
#     gt_label = learning_map[gt_label]
    
#     # to plot clustering segmentation output
#     mask = gt_label != unk_label
#     background_mask = gt_label == unk_label

#     pts_velo_cs_objects = pts_velo_cs[mask]
#     objectness_objects = objectness[mask]
#     pts_indexes_objects = pts_indexes[mask]
#     labels_objects = gt_label[mask]
    
#     # visualize GT
#     vis_4dpls()
#     break
    # ==========================================================================================

    # debug = o3d.geometry.PointCloud()
    # debug.points = o3d.utility.Vector3dVector(pts_velo_cs_objects[:, :3])
    # o3d.visualization.draw_geometries([debug])

#     assert (len(pts_velo_cs_objects) == len(objectness_objects))

#     if len(pts_velo_cs_objects) < 1:
# #         np.savez_compressed(os.path.join(write_dir, seq + '_' + str(idx).zfill(6)),
# #                             instances=[], segment_scores=[])
#         continue

#     # segmentation with point-net
#     id_ = 0
#     # eps_list = [2.0, 1.0, 0.5, 0.25]
#     eps_list_tum = [1.2488, 0.8136, 0.6952, 0.594, 0.4353, 0.3221]
#     indices, scores = segment(id_, eps_list_tum, pts_velo_cs_objects[:, :3])

#     # flatten list(list(...(indices))) into list(indices)
#     flat_indices = flatten_indices(indices)
#     # map from object_indexes to pts_indexes
#     mapped_indices = []
#     for indexes in flat_indices:
#         mapped_indices.append(pts_indexes_objects[indexes].tolist())

#     # mapped_flat_indices = pts_indexes_objects
#     flat_scores = flatten_scores(scores)

#     # visualizer
#     vis_instance_o3d()
#     break

# #     # save results
# #     np.savez_compressed(os.path.join(write_dir, seq + '_'+str(idx).zfill(6)),
# #                         instances=mapped_indices, segment_scores=flat_scores, allow_pickle = True)

  0%|          | 0/4071 [00:00<?, ?it/s]