In [None]:
from pathlib import Path
import json

import torch
import numpy as np
import open3d as o3d
from scipy.spatial import cKDTree

from pointcept.models import build_model
from pointcept.utils.visualization import save_point_cloud

# Define model

In [None]:
# Model configuration handed to build_model(): segmentor type, backbone
# hyper-parameters and the training criteria.
cfg_model = {
    "type": "DefaultSegmentor",
    "backbone": {
        "type": "PT-v2m2",
        # Per-point feature width; transform() builds `feat` by concatenating
        # coord and strength (presumably 3 + 1 columns -- confirm against the data).
        "in_channels": 4,
        # get_learning_map() maps the dataset classes onto training ids 0..7.
        "num_classes": 8,
        "patch_embed_depth": 1,
        "patch_embed_channels": 48,
        "patch_embed_groups": 6,
        "patch_embed_neighbours": 8,
        "enc_depths": (2, 2, 6, 2),
        "enc_channels": (96, 192, 384, 512),
        "enc_groups": (12, 24, 48, 64),
        "enc_neighbours": (16, 16, 16, 16),
        "dec_depths": (1, 1, 1, 1),
        "dec_channels": (48, 96, 192, 384),
        "dec_groups": (6, 12, 24, 48),
        "dec_neighbours": (16, 16, 16, 16),
        "grid_sizes": (0.15, 0.375, 0.9375, 2.34375),  # x3, x2.5, x2.5, x2.5
        "attn_qkv_bias": True,
        "pe_multiplier": False,
        "pe_bias": True,
        "attn_drop_rate": 0.0,
        "drop_path_rate": 0.3,
        "enable_checkpoint": False,
        "unpool_backend": "map",  # map / interp
    },
    # Both criteria ignore label -1, the same value as the notebook's ignore_index.
    "criteria": [
        {"type": "CrossEntropyLoss", "loss_weight": 1.0, "ignore_index": -1},
        {"type": "LovaszLoss", "mode": "multiclass", "loss_weight": 1.0, "ignore_index": -1},
    ],
}

def get_learning_map(ignore_index):
    """Return the lookup from annotation class ids to training label ids.

    Args:
        ignore_index: Label used for points without a class; it is mapped
            onto itself so it survives the remapping unchanged.

    Returns:
        dict: ``{annotation_class_id: training_label_id}`` covering the eight
        classes listed below plus the ``ignore_index`` entry.
    """
    class_id_to_train_id = (
        (0, 2),  # "Two-wheel Vehicle"
        (1, 3),  # "Pedestrian"
        (3, 0),  # "Car"
        (8, 1),  # "Truck/Bus"
        (10, 7),  # "Traffic Light"
        (12, 6),  # "Traffic Sign"
        (40, 4),  # "Road"
        (48, 5),  # "Sidewalk"
    )
    mapping = {ignore_index: ignore_index}
    # Class entries are applied last, so they win if ignore_index happens to
    # collide with a real class id.
    mapping.update(class_id_to_train_id)
    return mapping


# Label carried by every point that no annotation claims.
ignore_index = -1
learning_map = get_learning_map(ignore_index)

# Build a model

In [None]:
# Instantiate the network described by cfg_model, then move it to the GPU and
# switch it to evaluation mode, one step at a time.
model = build_model(cfg_model)
model = model.cuda()
model = model.eval()

# Load all weights from a checkpoint

In [None]:
# Path of the checkpoint to restore. Despite the "_dir" suffix this is a single
# .pth file; it is passed straight to torch.load in the next cell.
checkpoint_dir = "./exp/nia_cycle1/semseg-pt-v2m2-1-ft-split-val-yh/model/model_last.pth"

In [None]:
# Deserialize onto the CPU: without map_location, torch.load restores every
# tensor onto the device it was saved from, which fails when that GPU index
# does not exist here and keeps a second copy of the weights on the GPU.
# model.load_state_dict() copies the values into the (CUDA) parameters anyway.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(checkpoint_dir, map_location="cpu")
# The weights are stored under the "state_dict" key of the checkpoint.
state_dict = checkpoint["state_dict"]

## state_dict의 key에 "module."이 추가되어 있는 경우

In [None]:
# Drop a leading "module." from every parameter name so the keys line up with
# the bare model (presumably the prefix comes from a DataParallel /
# DistributedDataParallel wrapper used at training time -- confirm).
# Keys without the prefix are kept untouched.
new_state_dict = {
    (key[len("module."):] if key.startswith("module.") else key): tensor
    for key, tensor in state_dict.items()
}
state_dict = new_state_dict

In [None]:
# strict=True: loading raises if the checkpoint keys and the model's keys do
# not match exactly, so a partially initialised model cannot go unnoticed.
model.load_state_dict(state_dict, strict=True)

# Load point cloud data

In [None]:
# Dataset root. get_label_path_from_lidar_path() expects the lidar files under
# "<data_root>/collections" and the labels under "<data_root>/annotations".
data_root = Path("/datasets/nia")
# Directory that is searched recursively for the .pcd frames to run inference on.
lidar_dataset_path_to_predict = data_root / "collections/230822/230822_162106_K/"

In [None]:
def read_json(json_path):
    """Load a JSON file and return the decoded Python object.

    Args:
        json_path: Path (``str`` or path-like) of the JSON file to read.

    Returns:
        Whatever ``json.load`` decodes from the file (dict, list, ...).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # JSON text is UTF-8; decode it explicitly instead of relying on the
    # platform's default encoding, which can reject or garble non-ASCII text.
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)


def get_label_path_from_lidar_path(lidar_path):
    """Map a lidar file under ``collections`` to its label under ``annotations``.

    The path relative to ``<data_root>/collections`` is re-rooted at
    ``<data_root>/annotations`` and its suffix is replaced by ``.json``.

    Args:
        lidar_path: ``pathlib.Path`` of a lidar file located below
            ``<data_root>/collections`` (``relative_to`` is called on it).

    Returns:
        pathlib.Path: Location of the matching JSON label file.

    Raises:
        ValueError: If ``lidar_path`` is not inside ``<data_root>/collections``.
    """
    dataset_root = Path(data_root)
    relative_lidar_path = lidar_path.relative_to(dataset_root / 'collections')
    return dataset_root / 'annotations' / relative_lidar_path.with_suffix('.json')


def get_data_list(lidar_dataset_path):
    """Collect every ``.pcd`` file below a directory with its label path.

    Args:
        lidar_dataset_path: Directory (``str`` or ``Path``) searched
            recursively for ``*.pcd`` files.

    Returns:
        list[tuple[str, str]]: ``(lidar_path, label_path)`` pairs as strings,
        ordered by sorted lidar path. The label paths are derived, not checked
        for existence.
    """
    pairs = []
    for lidar_path in sorted(Path(lidar_dataset_path).rglob('*.pcd')):
        label_path = get_label_path_from_lidar_path(lidar_path)
        pairs.append((str(lidar_path), str(label_path)))
    return pairs

# (lidar_path, label_path) string pairs for every frame of the sequence to predict.
data_list = get_data_list(lidar_dataset_path_to_predict)

In [None]:
def find_common_points(cloud1, cloud2):
    """List, for each point of ``cloud1``, the identical points of ``cloud2``.

    Args:
        cloud1: Array-like of points to look up.
        cloud2: Array-like of points to search in.

    Returns:
        list[list[int]]: One entry per point of ``cloud1`` (same order); each
        entry holds the indices into ``cloud2`` of the points lying at distance
        zero from it, and is empty when there is no such point,
        e.g. ``[[j], [k], [], ...]``.
    """
    query_tree = cKDTree(cloud1)
    reference_tree = cKDTree(cloud2)
    # Radius 0 keeps only neighbours at distance zero, i.e. exact coordinate matches.
    return query_tree.query_ball_tree(reference_tree, r=0)


def create_segment_array(label_path, cloud_points):
    """Build the per-point class-id array of a cloud from its label file.

    Every cloud point whose coordinates exactly equal one of an annotation's
    ``3D_points`` receives that annotation's ``class_id``; all other points
    keep ``ignore_index``. The ids written here are the raw annotation class
    ids (no learning-map remapping is applied in this function).

    Args:
        label_path: Path of the JSON label file. It must contain an
            ``annotations`` list whose items provide ``3D_points`` and
            ``class_id``.
        cloud_points: Array of cloud coordinates; ``shape[0]`` is the number
            of points (assumed to be (N, 3) like the annotation points -- confirm).

    Returns:
        numpy.ndarray: Integer array of length ``cloud_points.shape[0]``.
    """
    label = read_json(label_path)

    segment = np.full((cloud_points.shape[0],), ignore_index, dtype=int)

    # The cloud is identical for every annotation, so its KD-tree is built once
    # (lazily, on the first non-empty annotation) instead of once per instance.
    cloud_tree = None

    for annotation in label['annotations']:
        instance_points = annotation['3D_points']
        instance_class_id = annotation['class_id']

        # An annotation without points labels nothing (and cannot be indexed).
        if len(instance_points) == 0:
            continue

        if cloud_tree is None:
            cloud_tree = cKDTree(cloud_points)

        # One list of cloud indices per instance point; radius 0 = exact match.
        matches_per_point = cKDTree(instance_points).query_ball_tree(cloud_tree, r=0)

        # The lists may be ragged (no match, or several identical cloud points),
        # so flatten them explicitly rather than relying on a rectangular array.
        matched_indices = np.array(
            [cloud_idx for matches in matches_per_point for cloud_idx in matches],
            dtype=int,
        )
        segment[matched_indices] = instance_class_id

    return segment

In [None]:
def get_data(idx):
    """Load one frame (points, strength and remapped labels) from ``data_list``.

    Args:
        idx: Frame index; it wraps around via ``idx % len(data_list)``.

    Returns:
        dict with
            ``coord``: the cloud's ``positions`` attribute as a NumPy array,
            ``strength``: the ``reflectivity`` attribute divided by 255,
            ``segment``: per-point training label ids (``int64``), i.e. the
            annotation class ids passed through ``learning_map``.

    Raises:
        KeyError: If an annotation class id is missing from ``learning_map``.
    """
    lidar_path, label_path = data_list[idx % len(data_list)]

    # Lidar: coordinates and reflectivity come from the tensor-API point cloud.
    point_cloud = o3d.t.io.read_point_cloud(lidar_path)
    xyz = point_cloud.point['positions'].numpy()
    reflectivity = point_cloud.point['reflectivity'].numpy()

    # Labels: raw class ids first, then element-wise lookup in learning_map.
    raw_segment = create_segment_array(label_path, xyz)
    remap = np.vectorize(learning_map.__getitem__)

    return dict(
        coord=xyz,
        strength=reflectivity / 255,
        segment=remap(raw_segment).astype(np.int64),
    )

# Inference

In [None]:
# First frame of the sequence, used for the single-sample walkthrough below.
sample = get_data(0)

## Transform data before feeding

In [None]:
def transform(data_dict, device="cuda"):
    """Turn a NumPy sample into the tensor dict fed to the model.

    Mirrors the "ToTensor" and "Collect" steps: the arrays become float
    tensors on ``device`` and the model inputs ``offset`` and ``feat`` are
    derived from them. The ``segment`` labels are not forwarded.

    Args:
        data_dict: Sample with NumPy arrays under ``"coord"`` and
            ``"strength"``; both must be 2-D with the same number of rows
            because they are concatenated along dim 1.
        device: Target device for all returned tensors. Defaults to
            ``"cuda"``, which matches the previous hard-coded behaviour.

    Returns:
        dict with
            ``coord``: float tensor of the coordinates,
            ``strength``: float tensor of the strengths,
            ``offset``: one-element tensor holding the number of points
            (the sample is treated as a single-cloud batch),
            ``feat``: ``coord`` and ``strength`` concatenated column-wise.
    """
    # ToTensor
    coord = torch.from_numpy(data_dict["coord"]).float().to(device)
    strength = torch.from_numpy(data_dict["strength"]).float().to(device)

    # Collect (key order kept: coord, strength, offset, feat)
    data = dict()
    data["coord"] = coord
    data["strength"] = strength
    data["offset"] = torch.tensor([coord.shape[0]], device=device)
    # Both inputs are already float tensors on `device`, so no further casts.
    data["feat"] = torch.cat([coord, strength], dim=1)

    return data

In [None]:
# Model-ready tensors for the sample (coord, strength, offset, feat).
data_dict = transform(sample)

## Forward model

In [None]:
# Single forward pass with gradient tracking disabled; the per-point class
# scores are read from the 'seg_logits' entry of the model output.
with torch.no_grad():
    output_dict = model(data_dict)
    logits = output_dict['seg_logits']

    # Class probabilities and the most likely class id, as NumPy arrays on the host.
    pred_probs = logits.softmax(dim=1).cpu().numpy()
    pred_ids = logits.argmax(dim=1).cpu().numpy()

# Define color map to visualize

In [None]:
# RGB colour (0-255) per training label id, plus one entry for the ignore label.
color_map = {
    ignore_index: [255, 255, 255],  # unlabeled -> white
    2: [0, 255, 0],  # "Two-wheel Vehicle" -> lime
    3: [255, 255, 0],  # "Pedestrian" -> yellow
    0: [0, 255, 255],  # "Car" -> cyan
    1: [255, 0, 0],  # "Truck/Bus" -> red
    7: [0, 0, 255],  # "Traffic Light" -> blue
    6: [0, 128, 128],  # "Traffic Sign" -> teal
    4: [128, 128, 128],  # "Road" -> gray
    5: [255, 0, 255],  # "Sidewalk" -> magenta
}
# Colour lookup tables scaled to [0, 1].
# NOTE: this construction only works while ignore_index == -1.
#  - ground truth: row k holds the colour of label k - 1, so row 0 is the
#    ignore colour and labels must be shifted by +1 before indexing;
#  - prediction: row k holds the colour of label k (no ignore entry).
color_map_lut_gt = np.array(
    [color_map[label_id] for label_id in range(-1, len(color_map) - 1)]
) / 255
color_map_lut_pd = np.array(
    [color_map[label_id] for label_id in range(len(color_map) - 1)]
) / 255

In [None]:
coord = sample['coord']
# Ground-truth labels start at -1 (ignore), hence the +1 shift into the LUT.
gt_color = color_map_lut_gt[sample['segment'] + 1]
# Predicted ids come from argmax and start at 0, so they index the LUT directly.
pd_color = color_map_lut_pd[pred_ids]

In [None]:
# Open3D point cloud of the sample coloured by its ground-truth labels.
gt_pcd = o3d.geometry.PointCloud()
gt_pcd.points = o3d.utility.Vector3dVector(coord)
gt_pcd.colors = o3d.utility.Vector3dVector(gt_color)

In [None]:
# Same coordinates as the ground-truth cloud, coloured by the predicted labels.
pd_pcd = o3d.geometry.PointCloud()
pd_pcd.points = o3d.utility.Vector3dVector(coord)
pd_pcd.colors = o3d.utility.Vector3dVector(pd_color)

# Visualize

### 추론 데이터 시각화는 Jupyter 보다는 Visual Studio Code의 extension이 좋다.

### Groundtruth

In [None]:
# Render the ground-truth coloured cloud inline with Plotly.
o3d.visualization.draw_plotly([gt_pcd])

### Prediction

In [None]:
# Render the prediction coloured cloud inline with Plotly.
o3d.visualization.draw_plotly([pd_pcd])

# Infer all samples in the dataset

In [None]:
# Number of frames that will be processed (shown as the cell output).
len(data_list)

In [None]:
# Run the model on every frame; entry i of each list belongs to data_list[i].
# NOTE: the loop rebinds data_dict, output_dict, logits, pred_ids and
# pred_probs, so after this cell they refer to the LAST frame, not to the
# single sample used in the walkthrough above.
all_pred_ids, all_pred_probs = [], []
for i in range(len(data_list)):
    data_dict = transform(get_data(i))

    with torch.no_grad():
        output_dict = model(data_dict)
        logits = output_dict['seg_logits']

        pred_probs = logits.softmax(dim=1).cpu().numpy()
        pred_ids = logits.argmax(dim=1).cpu().numpy()

    all_pred_probs.append(pred_probs)
    all_pred_ids.append(pred_ids)

# Save groundtruths and predictions in pcd format

In [None]:
# Write two files per frame: the cloud coloured by ground truth and the same
# cloud coloured by the stored prediction for that frame.
for i in range(len(data_list)):
    sample = get_data(i)
    coord = sample['coord']

    # Ground-truth labels are shifted by +1 so that -1 (ignore) hits LUT row 0.
    gt_color = color_map_lut_gt[sample['segment'] + 1]
    save_point_cloud(coord, gt_color, file_path=f'./exp/demo/groundtruth/gt{i:03d}.pcd')

    pd_color = color_map_lut_pd[all_pred_ids[i]]
    save_point_cloud(coord, pd_color, file_path=f'./exp/demo/prediction/pd{i:03d}.pcd')

## Filter unlabeled groundtruths/low confidence predictions

### Save only labeled groundtruths

In [None]:
# Per frame, keep only the points that carry a real label and save them
# coloured by ground truth.
for i in range(len(data_list)):
    sample = get_data(i)
    coord = sample['coord']
    # +1 shift: the ground-truth LUT stores the ignore colour in row 0.
    gt_color = color_map_lut_gt[sample['segment'] + 1]

    # Unlabeled points hold ignore_index (learning_map maps it onto itself);
    # compare against the shared constant instead of a hard-coded -1.
    labeled_mask = sample['segment'] != ignore_index
    labeled_coord = coord[labeled_mask]
    labeled_gt_color = gt_color[labeled_mask]

    save_point_cloud(labeled_coord, labeled_gt_color, file_path=f'./exp/demo/labeled_groundtruth/lgt{i:03d}.pcd')

### Save only confident predictions

In [None]:
# Per frame, keep only the points whose top class probability exceeds 0.9
# (the "_90" in the output directory name) and save them coloured by prediction.
for i in range(len(data_list)):
    sample = get_data(i)
    coord = sample['coord']
    pd_color = color_map_lut_pd[all_pred_ids[i]]

    # Confidence of a point = its largest class probability.
    confidence = all_pred_probs[i].max(axis=1)
    confident_mask = confidence > 0.9
    confident_coord, confident_pd_color = coord[confident_mask], pd_color[confident_mask]

    save_point_cloud(confident_coord, confident_pd_color, file_path=f'./exp/demo/confident_prediction_90/cpd{i:03d}.pcd')

### Save predictions for labeled groundtruths only

In [None]:
# Per frame, save the predictions restricted to the points that have a
# ground-truth label, which makes them directly comparable with the labeled
# ground-truth files.
for i in range(len(data_list)):
    sample = get_data(i)
    coord = sample['coord']
    pd_color = color_map_lut_pd[all_pred_ids[i]]

    # Unlabeled points hold ignore_index (learning_map maps it onto itself);
    # compare against the shared constant instead of a hard-coded -1.
    labeled_mask = sample['segment'] != ignore_index
    labeled_coord = coord[labeled_mask]
    labeled_pd_color = pd_color[labeled_mask]

    save_point_cloud(labeled_coord, labeled_pd_color, file_path=f'./exp/demo/labeled_prediction/lpd{i:03d}.pcd')