## Imports

In [None]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import numpy as np
from matplotlib import pyplot as plt
from argparse import ArgumentParser, Namespace

from arguments import ModelParams, PipelineParams
from scene import Scene, GaussianModel, FeatureGaussianModel
from gaussian_renderer import render, render_contrastive_feature

def get_combined_args(parser : ArgumentParser, model_path, target_cfg_file = None):
    """Merge the arguments saved with a trained model with freshly parsed ones.

    The parser is run on a synthetic command line that only contains
    ``--model_path``. The values it produces are layered on top of the
    ``Namespace(...)`` literal stored in the model's config file.

    Args:
        parser: Parser defining at least ``--model_path``. ``--target`` is read
            only when ``target_cfg_file`` is not given.
        model_path: Directory that contains the saved config file.
        target_cfg_file: Name of the config file inside ``model_path``. When
            ``None`` the name is derived from the parsed ``target`` value.

    Returns:
        A ``Namespace`` with the config-file values, overridden by every parsed
        argument whose value is not ``None``. If no config file is known for
        the target, only the parsed arguments are returned.

    Raises:
        FileNotFoundError: Propagated from ``open`` when a config file name was
            resolved but the file does not exist.
    """
    cmdlne_string = ['--model_path', model_path]
    cfgfile_string = "Namespace()"
    args_cmdline = parser.parse_args(cmdlne_string)

    if target_cfg_file is None:
        target = args_cmdline.target
        if target == 'seg':
            target_cfg_file = "seg_cfg_args"
        elif target in ('scene', 'xyz'):
            target_cfg_file = "cfg_args"
        elif target in ('feature', 'coarse_seg_everything', 'contrastive_feature'):
            target_cfg_file = "feature_cfg_args"

    if target_cfg_file is None:
        # Unknown target: fall back to the empty namespace instead of failing
        # inside os.path.join with an unrelated error.
        print("Config file not found: no config file is known for target {!r}".format(args_cmdline.target))
    else:
        try:
            cfgfilepath = os.path.join(model_path, target_cfg_file)
            print("Looking for config file in", cfgfilepath)
            with open(cfgfilepath) as cfg_file:
                print("Config file found: {}".format(cfgfilepath))
                cfgfile_string = cfg_file.read()
        except TypeError:
            # model_path / target_cfg_file could not be joined into a path.
            print("Config file not found in {}".format(model_path))

    # NOTE(security): eval executes whatever the config file contains; only
    # point this at model directories you trust.
    args_cfgfile = eval(cfgfile_string)

    merged_dict = vars(args_cfgfile).copy()
    for k, v in vars(args_cmdline).items():
        if v is not None:
            merged_dict[k] = v

    return Namespace(**merged_dict)

## Hyper-parameters

In [None]:
import os
FEATURE_DIM = 32 # fixed

# MODEL_PATH = './output/lerf-fruit_aisle/'
MODEL_PATH = '.../results/splatting_models/room1_final' # 30000

FEATURE_GAUSSIAN_ITERATION = 10000


def _feature_iteration_file(file_name):
    """Return the path of ``file_name`` inside the feature iteration's point-cloud folder."""
    return os.path.join(MODEL_PATH, f'point_cloud/iteration_{FEATURE_GAUSSIAN_ITERATION}/{file_name}')


# All three artefacts live in the same point_cloud/iteration_<N> folder.
SCALE_GATE_PATH = _feature_iteration_file('scale_gate.pt')
FEATURE_PCD_PATH = _feature_iteration_file('contrastive_feature_point_cloud.ply')
SCENE_PCD_PATH = _feature_iteration_file('scene_point_cloud.ply')

## Data and Model Preparation


In [None]:
# Scale gate: a linear layer that maps one scalar to 32 values, followed by a
# sigmoid. Its weights are restored from SCALE_GATE_PATH.
# NOTE(review): the output width 32 presumably has to equal FEATURE_DIM — confirm.
scale_gate = torch.nn.Sequential(
    torch.nn.Linear(1, 32, bias=True),
    torch.nn.Sigmoid()
)

scale_gate.load_state_dict(torch.load(SCALE_GATE_PATH))
scale_gate = scale_gate.cuda()

# Register the model and pipeline argument groups on the parser. '--target'
# must be added before get_combined_args runs, because that function reads
# args.target to pick the config file ('scene' selects "cfg_args").
parser = ArgumentParser(description="Testing script parameters")
model = ModelParams(parser, sentinel=True)
pipeline = PipelineParams(parser)
parser.add_argument('--target', default='scene', type=str)

# Parsed defaults merged with the configuration stored in MODEL_PATH.
args = get_combined_args(parser, MODEL_PATH)

dataset = model.extract(args)

# If use language-driven segmentation, load clip feature and original masks
dataset.need_features = True

# To obtain mask scales
dataset.need_masks = True

scene_gaussians = GaussianModel(dataset.sh_degree)

feature_gaussians = FeatureGaussianModel(FEATURE_DIM)
# load_iteration=-1 is passed for the scene Gaussians, while the feature
# Gaussians are loaded from FEATURE_GAUSSIAN_ITERATION.
# NOTE(review): -1 presumably means "latest available iteration" — confirm in Scene.
scene = Scene(dataset, scene_gaussians, feature_gaussians, load_iteration=-1, feature_load_iteration=FEATURE_GAUSSIAN_ITERATION, shuffle=False, mode='eval', target='contrastive_feature')


In [None]:
# Concatenate the mask scales of every training camera into one tensor ...
all_scales = torch.cat([cam.mask_scales for cam in scene.getTrainCameras()])

# ... and keep the largest one as the upper bound.
upper_bound_scale = all_scales.max().item()

In [None]:
from copy import deepcopy
cameras = scene.getTrainCameras()
print("There are",len(cameras),"views in the dataset.")
print(upper_bound_scale)

In [None]:
ref_img_camera_id = 0
mask_img_camera_id = 0

# Work on a copy so the camera stored in `cameras` stays unmodified.
view = deepcopy(cameras[ref_img_camera_id])

# Render the feature map at the full image resolution.
view.feature_height = view.image_height
view.feature_width = view.image_width

# Original image scaled by 255 and moved to channel-last uint8 for plotting.
img = (view.original_image * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)

# All-zero background with one entry per feature channel.
bg_color = [0] * FEATURE_DIM
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
rendered_feature = render_contrastive_feature(
    view,
    feature_gaussians,
    pipeline.extract(args),
    background,
    norm_point_features=True,
    smooth_type=None,
)['render']
feature_h, feature_w = rendered_feature.shape[-2:]


plt.imshow(img)

In [None]:
with torch.no_grad():
    # Gate values for scale 1.0, one per feature channel.
    scale = torch.tensor([1.]).cuda()
    gates = scale_gate(scale)
    # Two trailing singleton axes let the gates broadcast over the two spatial axes.
    feature_with_scale = rendered_feature * gates[..., None, None]
    # Channel-first -> channel-last.
    scale_conditioned_feature = feature_with_scale.permute(1, 2, 0)

In [None]:
def preprocess_point_features(feature_gaussians, scale_gates=None, sample_threshold=0.98, generator=None):
    """Scale-condition the per-point features and draw a random subsample.

    Args:
        feature_gaussians: Object exposing the raw point features as
            ``get_point_features`` (one feature vector per row).
        scale_gates: 1-D gate tensor multiplied onto every L2-normalised
            feature vector. Defaults to the notebook-level ``gates`` tensor, so
            existing calls keep their behaviour.
        sample_threshold: A point is sampled when its uniform random draw in
            [0, 1) is greater than this value (0.98 keeps roughly 2 %).
        generator: Optional CPU ``torch.Generator`` that makes the subsample
            reproducible. ``None`` uses the global RNG as before.

    Returns:
        ``(normed_point_features, normed_sampled_point_features)``: the
        L2-normalised scale-conditioned features of all points, and of the
        random subsample.
    """
    if scale_gates is None:
        scale_gates = gates  # notebook-level gates computed in an earlier cell

    point_features = feature_gaussians.get_point_features
    scale_conditioned_point_features = torch.nn.functional.normalize(point_features, dim = -1, p = 2) * scale_gates.unsqueeze(0)
    normed_point_features = torch.nn.functional.normalize(scale_conditioned_point_features, dim = -1, p = 2)

    # The random draws are made on the CPU, which is why `generator` has to be
    # a CPU generator.
    keep = torch.rand(scale_conditioned_point_features.shape[0], generator=generator) > sample_threshold
    sampled_point_features = scale_conditioned_point_features[keep]
    # F.normalize clamps the norm, so an all-zero feature vector stays zero
    # instead of becoming NaN as with a plain division by the norm.
    normed_sampled_point_features = torch.nn.functional.normalize(sampled_point_features, dim = -1, p = 2)

    return normed_point_features, normed_sampled_point_features

normed_point_features, normed_sampled_point_features = preprocess_point_features(feature_gaussians)

-------------------------------------------

## "Our" Code starts here

## Perform 3D Clustering

For each of the 200 views in the dataset, render an image that shows the clusters obtained from the 3D clustering.

In [None]:
import torch
import numpy as np
import hdbscan

# Cluster the sampled, scale-conditioned point features.
clusterer = hdbscan.HDBSCAN(min_cluster_size=50, cluster_selection_epsilon=0.1)
cluster_labels = clusterer.fit_predict(normed_sampled_point_features.detach().cpu().numpy())

# HDBSCAN labels noise points with -1. Select the real clusters explicitly
# instead of assuming that a noise label is always present; otherwise the last
# cluster is dropped whenever no point is classified as noise.
valid_cluster_ids = np.unique(cluster_labels[cluster_labels >= 0])
if len(valid_cluster_ids) == 0:
    raise ValueError("HDBSCAN found no clusters; adjust min_cluster_size / cluster_selection_epsilon")

# One L2-normalised centre per cluster: the mean feature of its members.
cluster_centers = torch.zeros(len(valid_cluster_ids), normed_sampled_point_features.shape[-1])
for row, cluster_id in enumerate(valid_cluster_ids):
    cluster_centers[row] = torch.nn.functional.normalize(normed_sampled_point_features[cluster_labels == cluster_id].mean(dim=0), dim=-1)

# Score of every point against every cluster centre: shape [points, clusters].
seg_score = torch.einsum('nc,bc->bn', cluster_centers.cpu(), normed_point_features.cpu())

# TODO: Can we use the same color for the same cluster every time? OpenNerf has
# the SCANNET color list, maybe that is a start. Let's get rid of the randomness.
np.random.seed(12)
# At least 100 rows as before, but never fewer rows than clusters, so the
# colour lookup below cannot run out of range.
label_to_color = np.random.rand(max(100, len(valid_cluster_ids)), 3)

# Colour every point by its best-scoring cluster; points whose best score is
# below 0.5 are painted black.
point_colors = label_to_color[seg_score.argmax(dim=-1).cpu().numpy()]
point_colors[seg_score.max(dim=-1)[0].detach().cpu().numpy() < 0.5] = (0, 0, 0)

# Best effort: a failing roll_back must not abort the cell.
# NOTE(review): narrow the exception type once it is clear what roll_back raises.
try:
    scene_gaussians.roll_back()
except Exception:
    pass

print(len(np.unique(cluster_labels)))
print(np.unique(cluster_labels))

------------------------------------------------------------------------------------------------

### Extract single cluster

In [None]:
def extract_cluster(cluster_label, point_colors, seg_score):
    """Return a copy of ``point_colors`` in which only one cluster keeps its colour.

    A point belongs to ``cluster_label`` when that label is the argmax of its
    row in ``seg_score``. Every other point is painted black. ``point_colors``
    itself is not modified.
    """
    best_label_per_point = torch.argmax(seg_score, dim=-1).cpu().numpy()
    outside_cluster = best_label_per_point != cluster_label

    isolated_colors = np.array(point_colors, copy=True)
    isolated_colors[outside_cluster] = (0, 0, 0)
    return isolated_colors

Visualize a single cluster in this image

In [None]:

# The background does not depend on the cluster, so build it once.
#TODO: What exactly is this background color?
bg_color = [0] * FEATURE_DIM
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

# seg_score has one column per cluster centre, so its last dimension is
# exactly the range of labels that extract_cluster can match. Counting the
# unique HDBSCAN labels instead also counts the noise label and produces one
# extra, empty iteration.
for i in range(seg_score.shape[-1]):
    new_points = extract_cluster(i, point_colors, seg_score)

    rendered_seg_map = render(cameras[0], scene_gaussians, pipeline.extract(args), background, override_color=torch.from_numpy(new_points).cuda().float())['render']
    img_seg = rendered_seg_map.permute([1,2,0]).detach().cpu().numpy()

    #TODO: Check whether "any channel > 0.1" is a valid criterion for a cluster being visible in this picture
    mask = img_seg > 0.1
    if mask.ndim > 2:
        mask = mask.any(axis=-1)
    segmented_image = np.zeros_like(img)
    segmented_image[mask] = img[mask]

    images_sum = img_seg.sum(axis=2)
    #TODO: 0.5 is a magic number (threshold on the sum over all three channels); find a better way to determine it
    masks = torch.where(torch.tensor(images_sum > 0.5), 1.0, 0.)

    # Plot the two mask variants and the masked original image side by side.
    plt.subplot(1,3,1)
    plt.title("any channel > 0.1")
    plt.imshow(mask)

    plt.subplot(1,3,2)
    plt.title("channel sum > 0.5")
    plt.imshow(masks)

    plt.subplot(1,3,3)
    plt.title(f"cluster {i}")
    plt.imshow(segmented_image)
    plt.show()

## Move individual masks for every cluster into one array  
Desired Size: [37, 679, 1199]  
This can then be matched to the output of OV-SEG, because for the reduced replica classes this output is [51, 679, 1199]

Everything over 0.5 is left in the mask, everything else is thrown out

In [None]:
# HDBSCAN labels noise points with -1. Used as a row index, -1 would silently
# alias the LAST row of the tensor and add an extra all-zero mask, so only the
# real cluster ids are kept.
cluster_ids = np.unique(cluster_labels)
cluster_ids = cluster_ids[cluster_ids >= 0]

# zeros instead of empty: a row that is never written must not contain
# uninitialised memory.
cluster_masks = torch.zeros([len(cluster_ids), img.shape[0], img.shape[1]])

# The background does not depend on the cluster, so build it once.
bg_color = [0] * FEATURE_DIM
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

for row, cluster_id in enumerate(cluster_ids):
    new_points = extract_cluster(cluster_id, point_colors, seg_score)
    rendered_seg_map = render(cameras[0], scene_gaussians, pipeline.extract(args), background, override_color=torch.from_numpy(new_points).cuda().float())['render']
    img_seg = rendered_seg_map.permute([1,2,0]).detach().cpu()
    #TODO: 0.5 is a magic number (threshold on the sum over all three channels); find a better way to determine it
    images_sum = img_seg.sum(axis=2)
    mask = torch.where(images_sum > 0.5, 1.0, 0.)
    cluster_masks[row] = mask

# Preview mask 20, or the last one if there are fewer clusters.
preview_index = min(20, cluster_masks.shape[0] - 1)
plt.imshow(cluster_masks[preview_index].cpu().numpy())
cluster_masks.shape

This is now the constructed tensor for a single image. It contains the masks for every found cluster. This is exported to pickle for further use in the OV-SEG Notebook

In [None]:
import pickle

with open('.../segment_3d_gaussians/segmentation_res/cluster_masks.pickle', 'wb') as handle:
    pickle.dump(cluster_masks, handle, protocol=pickle.HIGHEST_PROTOCOL)

## Compare and match clusters to OVSEG

Match the SAGA Clusters to the OVSEG Predictions
=> See other notebook in ov-seg/test.ipynb


Run all the code in the other notebook and copy the resulting `cluster_labels` here (the list is named `cluster_semantics` in this notebook, because the variable `cluster_labels` already exists).

These semantics were copied AFTER running the test.ipynb of OVSeg and pasted here to check whether the matching makes sense. If this notebook is restarted, they might no longer match, because the clusters get reshuffled. (We have to understand the randomness in SAGA and ideally make it reproducible for our own iterations.)

In [None]:
cluster_semantics = ['bin',
 'bed',
 'floor',
 'chair',
 'blinds',
 'vase',
 'sofa',
 'table',
 'picture',
 'rug',
 'window',
 'lamp',
 'wall-plug',
 'tv-stand',
 'pillow',
 'bench',
 'pot',
 'tv-screen',
 'cabinet',
 'pillar',
 'blanket',
 'door',
 'cushion',
 'basket',
 'indoor-plant',
 'vent',
 'shelf',
 'stool',
 'desk',
 'comforter',
 'nightstand',
 'plant-stand',
 'ceiling',
 'plate',
 'monitor',
 'pipe',
 'wall',
 'panel']

In [None]:
import pickle
import matplotlib.pyplot as plt
with open('.../results/saga_masks/room1_new/frame_00001_pred.pkl', 'rb') as handle:
    cluster_masks_unpadded = pickle.load(handle)

In [None]:
# Show every loaded mask together with the semantic label at the same index.
for idx, cluster_mask in enumerate(cluster_masks_unpadded):
    print(cluster_semantics[idx])
    plt.imshow(cluster_mask)
    plt.show()

In [None]:
# The background does not depend on the cluster, so build it once.
bg_color = [0] * FEATURE_DIM
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

for cluster_id in cluster_ids:
    # -1 is HDBSCAN's noise label, not a cluster. As a list index it would
    # silently pick the LAST semantic label.
    if cluster_id < 0:
        continue
    print(cluster_semantics[cluster_id])
    new_points = extract_cluster(cluster_id, point_colors, seg_score)

    rendered_seg_map = render(cameras[0], scene_gaussians, pipeline.extract(args), background, override_color=torch.from_numpy(new_points).cuda().float())['render']

    img_seg = rendered_seg_map.permute([1,2,0]).detach().cpu().numpy()

    # NOTE: masking hyper-parameter (a pixel is kept when any channel exceeds it)
    mask = img_seg > 0.3
    if mask.ndim > 2:
        mask = mask.any(axis=-1)
    segmented_image = np.zeros_like(img)
    segmented_image[mask] = img[mask]

    plt.imshow(segmented_image)
    plt.show()

## 3: Compare to ground-truth labels

This works only in the opennerf env

In [None]:
import open3d as o3d
import numpy as np
import pickle
import torch
import os

scene = "room1"

mesh_path = f".../opennerf/data/nerfstudio/replica_{scene}/{scene}_mesh.ply"
scene_point_cloud = o3d.io.read_point_cloud(mesh_path)
points = np.array(scene_point_cloud.points)
points.shape

In [None]:
saga_masks_path = f".../results/saga_masks/{scene}_new/frame_00001_pred.pkl"
with open(saga_masks_path, 'rb') as handle:
    cluster_masks = pickle.load(handle)
merged_masks = torch.argmax(cluster_masks, dim=0)
print(cluster_masks.shape)#, cluster_masks)
print(merged_masks.shape, merged_masks)

In [None]:
# Directory holding one pickled mask stack per frame. It is deliberately not
# named `dir`, which would shadow the builtin.
saga_masks_dir = f".../results/saga_masks/{scene}_new/"

# Only the pickle files are frames; other directory entries (checkpoint
# folders, OS metadata files, ...) would make pickle.load fail.
path_list = sorted(name for name in os.listdir(saga_masks_dir) if name.endswith('.pkl'))

semantic_masks_saga = torch.empty([len(path_list), merged_masks.shape[0], merged_masks.shape[1]])
for idx, path in enumerate(path_list):
    with open(os.path.join(saga_masks_dir, path), 'rb') as handle:
        cluster_masks = pickle.load(handle)
    # Per pixel, the index of the cluster with the largest mask value.
    # NOTE(review): a pixel that is 0 in every mask presumably ends up with
    # index 0 (first maximum) and is then treated as cluster 0 — confirm this
    # is intended.
    merged_masks = torch.argmax(cluster_masks, dim=0)
    semantic_masks_saga[idx] = merged_masks

In [None]:
semantic_masks_saga[0]

In [None]:
# Semantic label of every cluster; the position in this list is the cluster index.
_cluster_label_names = [
    'bin', 'bed', 'floor', 'chair', 'blinds', 'vase', 'sofa', 'table',
    'picture', 'rug', 'window', 'lamp', 'wall-plug', 'tv-stand', 'pillow',
    'bench', 'pot', 'tv-screen', 'cabinet', 'pillar', 'blanket', 'door',
    'cushion', 'basket', 'indoor-plant', 'vent', 'shelf', 'stool', 'desk',
    'comforter', 'nightstand', 'plant-stand', 'ceiling', 'plate', 'monitor',
    'pipe', 'wall', 'panel',
]

# Same mapping keyed by the cluster index as a string ("0", "1", ...).
cluster_semantics_dict = {str(idx): name for idx, name in enumerate(_cluster_label_names)}

In [None]:
import json

# Load the JSON mapping
# mapping_path = f'.../results/class_assignments/{scene}_openseg.json'
# with open(mapping_path, 'r') as f:
#     mapping = json.load(f)
mapping = cluster_semantics_dict

print(type(mapping))

# Define the reduced labels; the position in the list is the class index.
class_names_reduced = [
    'wall', 'ceiling', 'floor', 'chair', 'blinds', 'sofa', 'table', 'rug',
    'window', 'lamp', 'door', 'pillow', 'bench', 'tv-screen', 'cabinet',
    'pillar', 'blanket', 'tv-stand', 'cushion', 'bin', 'vent', 'bed', 'stool',
    'picture', 'indoor-plant', 'desk', 'comforter', 'nightstand', 'shelf',
    'vase', 'plant-stand', 'basket', 'plate', 'monitor', 'pipe', 'panel',
    'desk-organizer', 'wall-plug', 'book', 'box', 'clock', 'sculpture',
    'tissue-paper', 'camera', 'tablet', 'pot', 'bottle', 'candle', 'bowl',
    'cloth', 'switch',
    ]

# Label name -> class index.
label_to_index = {label: idx for idx, label in enumerate(class_names_reduced)}

# Cluster index -> class index.
mapping_indices = {int(key): label_to_index[label] for key, label in mapping.items()}
if min(mapping_indices) < 0:
    raise ValueError("cluster indices in the mapping must be non-negative")

# Convert semantic_masks_saga to a NumPy array if it's not already
semantic_masks_saga_np = semantic_masks_saga.cpu().numpy() if isinstance(semantic_masks_saga, torch.Tensor) else semantic_masks_saga

# The cluster indices are stored as floats; use integers for the lookup.
cluster_index_maps = semantic_masks_saga_np.astype(np.int64)

# Fail loudly when a cluster has no semantic label. A dict.get-based lookup
# would silently return None for such pixels.
unmapped = sorted(set(np.unique(cluster_index_maps).tolist()) - set(mapping_indices))
if unmapped:
    raise KeyError(f"cluster indices without a semantic label: {unmapped}")

# One array lookup replaces a Python-level call per pixel (np.vectorize is
# only a convenience loop and is very slow on full-resolution frame stacks).
lookup_table = np.zeros(max(mapping_indices) + 1, dtype=np.int64)
for cluster_index, class_index in mapping_indices.items():
    lookup_table[cluster_index] = class_index
mapped_indices = lookup_table[cluster_index_maps]

# Convert back to a torch.Tensor
semantic_class_labels = torch.tensor(mapped_indices, device=semantic_masks_saga.device)

In [None]:
semantic_class_labels[0]

In [None]:
with open(f'.../results/semantic_labels/semantic_class_labels_sm_{scene}_openseg.pickle', 'wb') as handle:
    pickle.dump(semantic_class_labels, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
import torch.nn.functional as F

# The label maps are padded by one pixel at the right and one at the bottom.
input_tensor = semantic_class_labels

# F.pad lists the widths for the last dimension first:
# (padding_left, padding_right, padding_top, padding_bottom)
padding = (0, 1, 0, 1)

# NOTE(review): the fill value 0 is also the class index of 'wall' in
# class_names_reduced — confirm the padded pixels are ignored downstream.
padded_tensor = F.pad(input_tensor, pad=padding, mode="constant", value=0)

# Verify the new shape
padded_tensor.shape

In [None]:
padded_tensor[0]

In [None]:
with open(f'.../results/semantic_labels/semantic_class_labels_padded_sm_{scene}_openseg.pickle', 'wb') as handle:
    pickle.dump(padded_tensor, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
saga_semantics_path = f".../results/semantic_labels/semantic_class_labels_padded_sm_{scene}_openseg.pickle"
with open(saga_semantics_path, 'rb') as handle:
    saga_semantics = pickle.load(handle)

In [None]:
saga_semantics_img1 = padded_tensor[0]
saga_semantics_img1

In [None]:
# RGB colours in class-index order (index 0 first); the last entry, index 51,
# is black.
_SCANNET_PALETTE = [
    (196., 51., 182.), (174., 199., 232.), (188., 189., 34.), (152., 223., 138.),   # 0-3
    (255., 152., 150.), (214., 39., 40.), (91., 135., 229.), (31., 119., 180.),     # 4-7
    (229., 91., 104.), (247., 182., 210.), (91., 229., 110.), (255., 187., 120.),   # 8-11
    (141., 91., 229.), (112., 128., 144.), (196., 156., 148.), (197., 176., 213.),  # 12-15
    (44., 160., 44.), (148., 103., 189.), (229., 91., 223.), (219., 219., 141.),    # 16-19
    (192., 229., 91.), (88., 218., 137.), (58., 98., 137.), (177., 82., 239.),      # 20-23
    (255., 127., 14.), (237., 204., 37.), (41., 206., 32.), (62., 143., 148.),      # 24-27
    (34., 14., 130.), (143., 45., 115.), (137., 63., 14.), (23., 190., 207.),       # 28-31
    (16., 212., 139.), (90., 119., 201.), (125., 30., 141.), (150., 53., 56.),      # 32-35
    (186., 197., 62.), (227., 119., 194.), (38., 100., 128.), (120., 31., 243.),    # 36-39
    (154., 59., 103.), (169., 137., 78.), (143., 245., 111.), (37., 230., 205.),    # 40-43
    (14., 16., 155.), (208., 49., 84.), (237., 80., 38.), (138., 175., 62.),        # 44-47
    (158., 218., 229.), (38., 96., 167.), (190., 77., 246.), (0., 0., 0.),          # 48-51
]

# Class index -> colour.
SCANNET_COLOR_MAP_200 = dict(enumerate(_SCANNET_PALETTE))

In [None]:
import matplotlib.pyplot as plt

# Build the palette ONCE. The colours are addressed by their position in the
# colour map, exactly as before, but no longer by rebuilding the list of
# values for every single pixel.
color_palette = np.array(list(SCANNET_COLOR_MAP_200.values()), dtype=np.uint8)

# Class-index image as an integer NumPy array.
class_index_map = saga_semantics_img1.cpu().numpy() if isinstance(saga_semantics_img1, torch.Tensor) else np.asarray(saga_semantics_img1)

# Indexing the palette with the index image yields an image with a trailing
# RGB axis of length 3.
saga_semantics_img1_colors = color_palette[class_index_map]
plt.imshow(saga_semantics_img1_colors)
plt.show()

In [None]:
saga_mask_path_new = f".../results/saga_masks/room1_new/frame_00001_pred.pkl"
saga_mask_path_old = f".../results/saga_masks/room1_final/frame_00001_pred.pkl"
with open(saga_mask_path_new, 'rb') as handle:
    saga_masks_new = pickle.load(handle)
with open(saga_mask_path_old, 'rb') as handle:
    saga_masks_old = pickle.load(handle)
saga_masks_new_merged = torch.argmax(saga_masks_new, dim=0)
plt.imshow(saga_masks_new_merged)
plt.show()
saga_masks_old_merged = torch.argmax(saga_masks_old, dim=0)
plt.imshow(saga_masks_old_merged)

After the evaluation is done you can check your metrics here:

In [None]:
# Scenes to evaluate; the commented-out list names further scenes.
scenes = ['room1']#['office0', 'office1', 'office2', 'office3', 'office4', 'room0', 'room1', 'room2'] 

# Raw class ids that take part in the evaluation.
# NOTE(review): presumably indices into class_names — confirm.
valid_class_ids = [0, 3, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 22, 23,
                   26, 29, 31, 34, 35, 37, 40, 44, 47, 52, 54, 56, 59, 60, 61,
                   62, 63, 64, 65, 70, 71, 76, 78, 79, 80, 82, 83, 87, 88, 91,
                   92, 93, 95, 97, 98]

# NOTE(review): not read by any code visible in this notebook — confirm it is
# still needed.
class_popularity = [1, 47, 18, 22, 13,  5, 37, 40, 50, 49, 24, 21, 30,  2, 43, 11,
                    29,  4, 44, 17,  3, 46,  1, 38, 28, 23, 19, 16, 26, 36, 45,
                    32,  0, 33, 25, 31, 27, 20, 48,  6,  8, 14, 35, 42, 10, 41,
                    34,  7, 12,  9, 15, 39]
# The list above has 52 entries, so num_classes is 52.
# NOTE(review): map_to_reduced sends the ignored ids to index 51 and the
# head/comm/tail report slices 3 x 17 = 51 values — confirm that 52 (and not
# 51) is the intended class count.
num_classes = len(valid_class_ids)

# Raw class ids, listed in the order of their reduced class index (0..50).
_raw_ids_in_reduced_order = [
    93, 31, 40, 20, 12, 76, 80, 98, 97, 47,
    37, 61, 8, 87, 18, 60, 11, 88, 29, 10,
    92, 7, 78, 59, 44, 34, 26, 54, 71, 91,
    63, 3, 64, 52, 62, 56, 35, 95, 13, 15,
    22, 70, 83, 17, 82, 65, 14, 19, 16, 23,
    79,
]

# Raw class id -> reduced class index.
map_to_reduced = {raw_id: reduced_id for reduced_id, raw_id in enumerate(_raw_ids_in_reduced_order)}

# Three further raw ids all share the extra index 51.
map_to_reduced.update({-1: 51, -2: 51, 256: 51})

class_names_reduced = [
    'wall',
    'ceiling',
    'floor',
    'chair',
    'blinds',
    'sofa',
    'table',
    'rug',
    'window',
    'lamp',
    'door',
    'pillow',
    'bench',
    'tv-screen',
    'cabinet',
    'pillar',
    'blanket',
    'tv-stand',
    'cushion',
    'bin',
    'vent',
    'bed',
    'stool',
    'picture',
    'indoor-plant',
    'desk',
    'comforter',
    'nightstand',
    'shelf',
    'vase',
    'plant-stand',
    'basket',
    'plate',
    'monitor',
    'pipe',
    'panel',
    'desk-organizer',
    'wall-plug',
    'book',
    'box',
    'clock',
    'sculpture',
    'tissue-paper',
    'camera',
    'tablet',
    'pot',
    'bottle',
    'candle',
    'bowl',
    'cloth',
    'switch',
    '0',
    ]

class_names = [
    '0',                # 0
    'backpack',         # 1
    'base-cabinet',     # 2
    'basket',           # 3
    'bathtub',          # 4
    'beam',
    'beanbag',
    'bed',
    'bench',
    'bike',
    'bin',
    'blanket',
    'blinds',
    'book',
    'bottle',
    'box',
    'bowl',
    'camera',
    'cabinet',
    'candle',
    'chair',
    'chopping-board',
    'clock',
    'cloth',
    'clothing',
    'coaster',
    'comforter',
    'computer-keyboard',
    'cup',
    'cushion',
    'curtain',
    'ceiling',
    'cooktop',
    'countertop',
    'desk',
    'desk-organizer',
    'desktop-computer',
    'door',
    'exercise-ball',
    'faucet',
    'floor',
    'handbag',
    'hair-dryer',
    'handrail',
    'indoor-plant',
    'knife-block',
    'kitchen-utensil',
    'lamp',
    'laptop',
    'major-appliance',
    'mat',
    'microwave',
    'monitor',
    'mouse',
    'nightstand',
    'pan',
    'panel',
    'paper-towel',
    'phone',
    'picture',
    'pillar',
    'pillow',
    'pipe',
    'plant-stand',
    'plate',
    'pot',
    'rack',
    'refrigerator',
    'remote-control',
    'scarf',
    'sculpture',
    'shelf',
    'shoe',
    'shower-stall',
    'sink',
    'small-appliance',
    'sofa',
    'stair',
    'stool',
    'switch',
    'table',
    'table-runner',
    'tablet',
    'tissue-paper',
    'toilet',
    'toothbrush',
    'towel',
    'tv-screen',
    'tv-stand',
    'umbrella',
    'utensil-holder',
    'vase',
    'vent',
    'wall',
    'wall-cabinet',
    'wall-plug',
    'wardrobe',
    'window',
    'rug',
    'logo',
    'bag',
    'set-of-clothing',
]

SCANNET_COLOR_MAP_200 = {
0: (196., 51., 182.),   # 0
1: (174., 199., 232.),  # 1
2: (188., 189., 34.),   # 2
3: (152., 223., 138.),  # 3
4: (255., 152., 150.),  # 4
5: (214., 39., 40.),    # 5
6: (91., 135., 229.),   # 6
7: (31., 119., 180.),   # 7
8: (229., 91., 104.),   # 8
9: (247., 182., 210.),  # 9
10: (91., 229., 110.),  # 10
11: (255., 187., 120.), # 11
13: (141., 91., 229.),  # 12
14: (112., 128., 144.), # 13
15: (196., 156., 148.), # 14
16: (197., 176., 213.), # 15
17: (44., 160., 44.),   # 16
18: (148., 103., 189.), # 17
19: (229., 91., 223.),  # 18
21: (219., 219., 141.), # 19
22: (192., 229., 91.),  # 20
23: (88., 218., 137.),  # 21
24: (58., 98., 137.),   # 22
26: (177., 82., 239.),  # 23
27: (255., 127., 14.),  # 24
28: (237., 204., 37.),  # 25
29: (41., 206., 32.),   # 26
31: (62., 143., 148.),  # 27
32: (34., 14., 130.),   # 28
33: (143., 45., 115.),  # 29
34: (137., 63., 14.),   # 30
35: (23., 190., 207.),  # 31
36: (16., 212., 139.),  # 32
38: (90., 119., 201.),  # 33
39: (125., 30., 141.),  # 34
40: (150., 53., 56.),   # 35
41: (186., 197., 62.),  # 36
42: (227., 119., 194.), # 37
44: (38., 100., 128.),  # 38
45: (120., 31., 243.),  # 39
46: (154., 59., 103.),  # 40
47: (169., 137., 78.),  # 41
48: (143., 245., 111.), # 42
49: (37., 230., 205.),  # 43
50: (14., 16., 155.),   # 44
51: (208., 49., 84.),   # 45
52: (237., 80., 38.),   # 46
54: (138., 175., 62.),  # 47
55: (158., 218., 229.), # 48
56: (38., 96., 167.),   # 49
57: (190., 77., 246.),  # 50
58: (0., 0., 0.),       # 51
59: (208., 193., 72.),
62: (55., 220., 57.),
63: (10., 125., 140.),
64: (76., 38., 202.),
65: (191., 28., 135.),
66: (211., 120., 42.),
67: (118., 174., 76.),
68: (17., 242., 171.),
69: (20., 65., 247.),
70: (208., 61., 222.),
71: (162., 62., 60.),
72: (210., 235., 62.),
73: (45., 152., 72.),
74: (35., 107., 149.),
75: (160., 89., 237.),
76: (227., 56., 125.),
77: (169., 143., 81.),
78: (42., 143., 20.),
79: (25., 160., 151.),
80: (82., 75., 227.),
82: (253., 59., 222.),
84: (240., 130., 89.),
86: (123., 172., 47.),
87: (71., 194., 133.),
88: (24., 94., 205.),
89: (134., 16., 179.),
90: (159., 32., 52.),
93: (213., 208., 88.),
95: (64., 158., 70.),
96: (18., 163., 194.),
97: (65., 29., 153.),
98: (177., 10., 109.),
99: (152., 83., 7.),
100: (83., 175., 30.),
101: (18., 199., 153.),
102: (61., 81., 208.),
103: (213., 85., 216.),
104: (170., 53., 42.),
105: (161., 192., 38.),
106: (23., 241., 91.),
107: (12., 103., 170.),
110: (151., 41., 245.),
112: (133., 51., 80.),
115: (184., 162., 91.),
116: (50., 138., 38.),
118: (31., 237., 236.),
120: (39., 19., 208.),
121: (223., 27., 180.),
122: (254., 141., 85.),
125: (97., 144., 39.),
128: (106., 231., 176.),
130: (12., 61., 162.),
131: (124., 66., 140.),
132: (137., 66., 73.),
134: (250., 253., 26.),
136: (55., 191., 73.),
138: (60., 126., 146.),
139: (153., 108., 234.),
140: (184., 58., 125.),
141: (135., 84., 14.),
145: (139., 248., 91.),
148: (53., 200., 172.),
154: (63., 69., 134.),
155: (190., 75., 186.),
156: (127., 63., 52.),
157: (141., 182., 25.),
159: (56., 144., 89.),
161: (64., 160., 250.),
163: (182., 86., 245.),
165: (139., 18., 53.),
166: (134., 120., 54.),
168: (49., 165., 42.),
169: (51., 128., 133.),
170: (44., 21., 163.),
177: (232., 93., 193.),
180: (176., 102., 54.),
185: (116., 217., 17.),
188: (54., 209., 150.),
191: (60., 99., 204.),
193: (129., 43., 144.),
195: (252., 100., 106.),
202: (187., 196., 73.),
208: (13., 158., 40.),
213: (52., 122., 152.),
214: (128., 76., 202.),
221: (187., 50., 115.),
229: (180., 141., 71.),
230: (77., 208., 35.),
232: (72., 183., 168.),
233: (97., 99., 203.),
242: (172., 22., 158.),
250: (155., 64., 40.),
261: (118., 159., 30.),
264: (69., 252., 148.),
276: (45., 103., 173.),
283: (111., 38., 149.),
286: (184., 9., 49.),
300: (188., 174., 67.),
304: (53., 206., 53.),
312: (97., 235., 252.),
323: (66., 32., 182.),
325: (236., 114., 195.),
331: (241., 154., 83.),
342: (133., 240., 52.),
356: (16., 205., 144.),
370: (75., 101., 198.),
392: (237., 95., 251.),
395: (191., 52., 49.),
399: (227., 254., 54.),
408: (49., 206., 87.),
417: (48., 113., 150.),
488: (125., 73., 182.),
540: (229., 32., 114.),
562: (158., 119., 28.),
570: (60., 205., 27.),
572: (18., 215., 201.),
581: (79., 76., 153.),
609: (134., 13., 116.),
748: (192., 97., 63.),
776: (108., 163., 18.),
1156: (95., 220., 156.),
1163: (98., 141., 208.),
1164: (144., 19., 193.),
1165: (166., 36., 57.),
1166: (212., 202., 34.),
1167: (23., 206., 34.),
1168: (91., 211., 236.),
1169: (79., 55., 137.),
1170: (182., 19., 117.),
1171: (134., 76., 14.),
1172: (87., 185., 28.),
1173: (82., 224., 187.),
1174: (92., 110., 214.),
1175: (168., 80., 171.),
1176: (197., 63., 51.),
1178: (175., 199., 77.),
1179: (62., 180., 98.),
1180: (8., 91., 150.),
1181: (77., 15., 130.),
1182: (154., 65., 96.),
1183: (197., 152., 11.),
1184: (59., 155., 45.),
1185: (12., 147., 145.),
1186: (54., 35., 219.),
1187: (210., 73., 181.),
1188: (221., 124., 77.),
1189: (149., 214., 66.),
1190: (72., 185., 134.),
1191: (42., 94., 198.),
}


In [None]:
import sys
import math
import numpy as np

PREFIX = ".../opennerf"
experiment_name = "benchmark"
scene = "room1"

def process_txt(filename):
    """Read a text file and return its lines with trailing whitespace removed."""
    with open(filename) as handle:
        return [raw_line.rstrip() for raw_line in handle]

def eval_semantics():
    """Evaluate the predicted semantics of ``scene`` against the ground truth.

    Accumulates a confusion matrix over the prediction / ground-truth file
    pair, then prints IoU and accuracy per class, their means, and the means
    over three consecutive groups of 17 classes (head / comm / tail).
    """
    # Predicted and ground-truth label files (one pair for the current scene).
    pr_files = [f'{PREFIX}/outputs/replica_{scene}/opennerf/{experiment_name}/semantics_{scene}.txt']
    gt_files = [f'{PREFIX}/datasets/replica_gt_semantics/semantic_labels_{scene}.txt']

    confusion = np.zeros([num_classes, num_classes], dtype=np.ulonglong)

    print('evaluating', len(pr_files), 'scans...')
    for scan_number, (pr_file, gt_file) in enumerate(zip(pr_files, gt_files), start=1):
        evaluate_scan(pr_file, gt_file, confusion)
        sys.stdout.write("\rscans processed: {}".format(scan_number))
        sys.stdout.flush()

    # (IoU, accuracy) per class name, in class-index order.
    class_ious = {
        class_names_reduced[label_id]: get_iou(label_id, confusion)
        for label_id in range(num_classes)
    }

    print('classes \t IoU \t Acc')
    print('----------------------------')
    for label_id in range(num_classes):
        label_name = class_names_reduced[label_id]
        iou, acc = class_ious[label_name]
        print('{0:<14s}: {1:>5.2%}   {2:>6.2%}'.format(label_name, iou, acc))

    iou_values = np.array([scores[0] for scores in class_ious.values()])
    acc_values = np.array([scores[1] for scores in class_ious.values()])
    print()
    print(f'mIoU: \t {np.mean(iou_values):.2%}')
    print(f'mAcc: \t {np.mean(acc_values):.2%}')
    print()
    # Each split covers 17 consecutive classes; the first line printed per
    # split is the mean IoU, the second the mean accuracy.
    for split_index, split in enumerate(['head', 'comm', 'tail']):
        window = slice(17 * split_index, 17 * (split_index + 1))
        print(f'{split}: \t {np.mean(iou_values[window]):.2%}')
        print(f'{split}: \t {np.mean(acc_values[window]):.2%}')
        print('---')


def evaluate_scan(pr_file, gt_file, confusion):
    """Add one scan's predictions to ``confusion`` (rows: ground truth, columns: prediction).

    Predicted ids are read as integers. Ground-truth ids are read as integers
    and translated through ``map_to_reduced``; an id missing from that dict
    comes back as ``None`` from ``dict.get``. Entries whose translated
    ground-truth value equals ``num_classes`` are skipped.

    NOTE(review): ``map_to_reduced`` maps the ignored ids to 51 — confirm that
    this equals ``num_classes``, otherwise those entries are counted.
    """
    pr_ids = np.array(process_txt(pr_file), dtype=np.int64)
    raw_gt_ids = np.array(process_txt(gt_file)).astype(np.int64)
    gt_ids = np.vectorize(map_to_reduced.get)(raw_gt_ids)

    # sanity check: only reported; zip below stops at the shorter sequence
    if pr_ids.shape != gt_ids.shape:
        print(f'number of predicted values does not match number of vertices: {pr_file}')

    for gt_val, pr_val in zip(gt_ids, pr_ids):
        if gt_val != num_classes:
            confusion[gt_val][pr_val] += 1


def get_iou(label_id, confusion):
    """Return ``(IoU, accuracy)`` of one class from a confusion matrix.

    Rows of ``confusion`` are ground truth, columns are predictions.
    ``(0, 0)`` is returned when the class occurs neither in the ground truth
    nor in the predictions; the accuracy is 0 when the class has no
    ground-truth entries.
    """
    tp = np.longlong(confusion[label_id, label_id])
    row_total = np.longlong(confusion[label_id, :].sum())
    column_total = np.longlong(confusion[:, label_id].sum())
    fn = row_total - tp
    fp = column_total - tp

    denom = float(tp + fp + fn)
    if denom == 0:
        return (0, 0) #float('nan')

    iou = tp / denom
    acc = 0 if (tp == 0 and fn == 0) else tp / float(tp + fn)
    return (iou, acc)

In [None]:
eval_semantics()