In [1]:
import os
import torch
import numpy as np
from glob import glob
from tqdm import tqdm
from os.path import join, exists
import open3d as o3d
import matplotlib.pyplot as plt

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Path of the preprocessed scene file (a torch-saved file read in the next cell).
# The dataset root defaults to the original location but can be redirected with the
# AT3DCV_DATA_ROOT environment variable, so the notebook can run on another machine
# without editing this cell.
data_root = os.environ.get(
    "AT3DCV_DATA_ROOT",
    "/mnt/project/AT3DCV_Data/Preprocessed_OpenScene/data/augmented/deathwing",
).rstrip("/")
sample_path_0 = data_root + "/scannet_3d/example/scene0000_00_vh_clean_2.pth"
# Optional further partial scans of the same scene (currently unused):
#sample_path_1 = "D:/AT3DCV_Data/Preprocessed_OpenScene/data/scannet_3d/train/scene0000_01_vh_clean_2.pth"
#sample_path_2 = "D:/AT3DCV_Data/Preprocessed_OpenScene/data/scannet_3d/train/scene0000_02_vh_clean_2.pth"

In [3]:
# Load the preprocessed scene; later cells read entries 0, 1 and 2 as
# coordinates, colors and labels respectively.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
sample_0 = torch.load(sample_path_0) # coords,colors,labels
#sample_1 = torch.load(sample_path_1) # coords,colors,labels
#sample_2 = torch.load(sample_path_2) # coords,colors,labels

In [4]:
# Number of points in the loaded scene (length of the coordinates entry).
len(sample_0[0])

82936

In [5]:
# Option A (disabled): merge several partial point clouds of the same scene.
# They don't overlap perfectly, so the result is a denser aggregate.
#sample_points = np.concatenate((sample_0[0], sample_1[0], sample_2[0]))
#sample_colors = np.concatenate((sample_0[1], sample_1[1], sample_2[1]))

# Option B (active): work on a single partial point cloud.
# Entry order in the loaded sample is coordinates, colors, labels.
sample_points, sample_colors, sample_labels = sample_0[0], sample_0[1], sample_0[2]

In [6]:
# Show the untouched scene with its original per-point colors.
scene_xyz = np.asarray(sample_points)
scene_rgb = np.asarray(sample_colors)

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(scene_xyz)
pcd.colors = o3d.utility.Vector3dVector(scene_rgb)

# Alternative (disabled): paint every point with one uniform light-grey color
# instead of using the original colors.
#sample_paint_uniform = np.asarray([200,200,200])/255.0
#pcd.paint_uniform_color(sample_paint_uniform)

o3d.visualization.draw_geometries([pcd])

# load fused features

In [7]:
# Path of the fused feature file that belongs to the scene loaded above.
# The dataset root defaults to the original location but can be redirected with the
# AT3DCV_DATA_ROOT environment variable, so the notebook can run on another machine
# without editing this cell.
data_root = os.environ.get(
    "AT3DCV_DATA_ROOT",
    "/mnt/project/AT3DCV_Data/Preprocessed_OpenScene/data/augmented/deathwing",
).rstrip("/")
feature_path = data_root + "/fused/scene0000_00_0.pt"

In [8]:
# Load the fused features; later cells read the "mask_full" and "feat" entries.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
feature = torch.load(feature_path)

In [9]:
# Shape of the per-point mask; it should have one entry per point of the full scene.
feature["mask_full"].shape

torch.Size([82936])

In [10]:
# Shape of the fused feature matrix; its rows are matched to the masked points further below.
feature["feat"].shape

torch.Size([79248, 768])

In [11]:
# Get the indices where the mask is True.
# torch.nonzero on the 1-D mask returns a (K, 1) tensor. squeeze(1) removes only that
# trailing axis, so the result stays 1-D even when exactly one entry is True
# (a bare squeeze() would yield a 0-d tensor and the row selections that use
# `indices` would silently drop their leading axis).
indices = torch.nonzero(feature["mask_full"], as_tuple=False).squeeze(1)

In [12]:
# Restrict the scene to the points selected by `indices`, i.e. the points the
# mask marks as True, so rows line up with the fused feature rows.
filtered_point_cloud, filtered_point_cloud_colors = (
    per_point_attr[indices, :] for per_point_attr in (sample_points, sample_colors)
)
filtered_point_cloud_labels = sample_labels[indices]

In [13]:
# Replace every occurrence of label 21 with 20 if necessary.
# NOTE(review): presumably 21 is the id of the inserted object and 20 is the index of
# the query label appended to the class list — confirm against the preprocessing.
# Integer literals plus an explicit int64 cast keep the class ids integral; the former
# float literals promoted the whole array to float, which only worked downstream
# because of an implicit cast when the ids were used as bin indices.
_labels_arr = np.asarray(filtered_point_cloud_labels)
gt_ids = np.where(_labels_arr == 21, 20, _labels_arr).astype(np.int64)

In [14]:
# Sanity check: the number of masked points should equal the number of feature rows.
filtered_point_cloud.shape

(79248, 3)

# using clip model

In [15]:
# Load the CLIP model whose text encoder embeds the queries below.
# `preprocess` is returned alongside the model and is not used in the cells visible here.
# NOTE(review): later cells move tokens to the GPU with .cuda(), so this presumably
# loads the model on CUDA as well — confirm on machines without a GPU.
import clip
model, preprocess = clip.load("ViT-L/14@336px")

In [45]:
# highlight with a threshold
# type the query here
query = ["dragon"]

# Embed every query string with the CLIP text encoder and L2-normalise the result.
with torch.no_grad():
    all_text_embeddings = []
    for category in tqdm(query):
        texts = clip.tokenize(category)  #tokenize
        texts = texts.cuda()
        text_embeddings = model.encode_text(texts)  #embed with text encoder
        text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
        text_embedding = text_embeddings.mean(dim=0)
        text_embedding /= text_embedding.norm()
        all_text_embeddings.append(text_embedding)

    # stacking along dim=1 gives one column per query
    all_text_embeddings = torch.stack(all_text_embeddings, dim=1)

# L2-normalise the fused point features (epsilon guards all-zero rows) and cast to
# half precision for the matrix product with the text embeddings.
fused_f = (feature["feat"]/(feature["feat"].norm(dim=-1, keepdim=True)+1e-5)).half()
# similarity of every point (row) to every query (column)
similarity_matrix = fused_f.cuda() @ all_text_embeddings

# set higher to increase the certainty (not always correct)
threshold_percentage = 0.9
# Best score per point over all queries. Reducing to one value per point keeps the
# index extraction 1-D for any number of hits; the former `.squeeze().T[0]` collapsed
# to a scalar when exactly one point passed the threshold, and repeated indices when
# several queries matched the same point.
best_similarity = torch.max(similarity_matrix, dim=1)[0]
cap = best_similarity.max().item()
# NOTE: the relative threshold assumes cap > 0; with a negative maximum nothing is selected.
found_indices = torch.nonzero(best_similarity > cap*threshold_percentage, as_tuple=False)[:, 0]
found_index_list = found_indices.tolist()

# creating pc
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud))
pcd.colors = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud_colors))

found_region = pcd.select_by_index(found_index_list)
found_region.paint_uniform_color([1.0, 0, 0]) # paint related points to red
rest = pcd.select_by_index(found_index_list, invert=True)
o3d.visualization.draw_geometries([rest,found_region])

100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 65.22it/s]


In [55]:
# highlight with a heatmap
# type the query here
query = ["deathwing"]

# Embed every query string with the CLIP text encoder and L2-normalise the result.
with torch.no_grad():
    all_text_embeddings = []
    for category in tqdm(query):
        texts = clip.tokenize(category)  #tokenize
        texts = texts.cuda()
        text_embeddings = model.encode_text(texts)  #embed with text encoder
        text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
        text_embedding = text_embeddings.mean(dim=0)
        text_embedding /= text_embedding.norm()
        all_text_embeddings.append(text_embedding)

    # stacking along dim=1 gives one column per query
    all_text_embeddings = torch.stack(all_text_embeddings, dim=1)

# L2-normalise the fused point features (epsilon guards all-zero rows) and cast to
# half precision for the matrix product with the text embeddings.
fused_f = (feature["feat"]/(feature["feat"].norm(dim=-1, keepdim=True)+1e-5)).half()
# similarity of every point (row) to every query (column)
similarity_matrix = fused_f.cuda() @ all_text_embeddings

# creating pc
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud))
pcd.colors = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud_colors))

# heatmap
cmap = plt.get_cmap('cividis')

# Min-max normalise the scores of the first query to [0, 1].
# The column is picked explicitly (a bare squeeze() only works for a single query),
# the arithmetic runs in float32 rather than fp16, and a constant score vector maps
# to 0 instead of producing NaN colors through a 0/0 division.
scores = similarity_matrix[:, 0].detach().float()
score_min = scores.min()
score_range = (scores.max() - score_min).item()
if score_range > 0:
    normalized_tensor = (scores - score_min) / score_range
else:
    normalized_tensor = torch.zeros_like(scores)

# the colormap returns RGBA rows; only RGB is used for the point colors
colors = cmap(normalized_tensor.cpu().numpy())

# The heatmap cloud is the same geometry shifted by 10 units along y, so both
# clouds are shown side by side.
pcd_heatmap = o3d.geometry.PointCloud()
pcd_heatmap.points = o3d.utility.Vector3dVector(np.asarray(pcd.points) + [0,10,0])
pcd_heatmap.colors = o3d.utility.Vector3dVector(colors[:, :3])

o3d.visualization.draw_geometries([pcd, pcd_heatmap])

100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 44.88it/s]


# mIoU evaluation

In [56]:
# Ground-truth id that excludes a point from the confusion matrix.
UNKNOWN_ID = 255
# Prediction id that marks a point without any feature.
NO_FEATURE_ID = 256

# The 20 ScanNet benchmark class names, in class-id order.
SCANNET_LABELS_20 = [
    'wall', 'floor', 'cabinet', 'bed', 'chair',
    'sofa', 'table', 'door', 'window', 'bookshelf',
    'picture', 'counter', 'desk', 'curtain', 'refrigerator',
    'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture',
]

# The current query becomes an extra class at the end of the same list object,
# so both names below refer to the extended list.
SCANNET_LABELS_20 += [query[0]]
CLASS_LABELS = SCANNET_LABELS_20

In [57]:
# Show the final label list: the 20 ScanNet labels plus the query appended in the previous cell.
CLASS_LABELS

['wall',
 'floor',
 'cabinet',
 'bed',
 'chair',
 'sofa',
 'table',
 'door',
 'window',
 'bookshelf',
 'picture',
 'counter',
 'desk',
 'curtain',
 'refrigerator',
 'shower curtain',
 'toilet',
 'sink',
 'bathtub',
 'otherfurniture',
 'dragon red black wings tail claws']

In [58]:
# Embed every class label with the CLIP text encoder; gradients are not needed.
with torch.no_grad():
    per_label_embeddings = []
    for label in tqdm(SCANNET_LABELS_20):
        tokens = clip.tokenize(label).cuda()
        encoded = model.encode_text(tokens)
        # unit-normalise each encoded row, average the rows, then normalise the average
        encoded = encoded / encoded.norm(dim=-1, keepdim=True)
        pooled = encoded.mean(dim=0)
        per_label_embeddings.append(pooled / pooled.norm())

    # stacking along dim=1 gives one column per class label
    label_embeds = torch.stack(per_label_embeddings, dim=1)


100%|███████████████████████████████████████████████████████████████████████████████████| 21/21 [00:00<00:00, 47.36it/s]


In [59]:
# Similarity of every point feature (row) to every class embedding (column).
# `fused_f` comes from the query cells above, so one of them must have run first.
similarity_matrix = torch.matmul(fused_f.cuda(), label_embeds)
# torch.max along dim 1 yields (values, indices); the indices are the predicted class ids.
per_point_max = torch.max(similarity_matrix, 1)
pred_ids = per_point_max[1].detach().cpu()

In [60]:
def confusion_matrix(pred_ids, gt_ids, num_classes, unknown_id=None, no_feature_id=None):
    '''Calculate the confusion matrix of a point-wise labelling.

    Rows are indexed by the predicted class and columns by the ground-truth class.

    Args:
        pred_ids: 1-D predicted class ids (NumPy array, CPU torch tensor or sequence).
            Entries equal to ``no_feature_id`` mark points without a prediction.
            The argument is never modified.
        gt_ids: 1-D ground-truth class ids with the same shape as ``pred_ids``.
            Entries equal to ``unknown_id`` are left out of the matrix.
        num_classes: number of valid classes; valid ids are ``0 .. num_classes-1``.
        unknown_id: ground-truth id to ignore; defaults to the module-level UNKNOWN_ID.
        no_feature_id: prediction id for "no feature"; defaults to the module-level
            NO_FEATURE_ID.

    Returns:
        ``np.ndarray`` of dtype ``np.ulonglong`` and shape ``(num_classes, num_classes)``.
    '''
    if unknown_id is None:
        unknown_id = UNKNOWN_ID
    if no_feature_id is None:
        no_feature_id = NO_FEATURE_ID

    # Work on private int64 copies: the caller's data is left untouched and mixed
    # tensor / float inputs no longer depend on an implicit cast inside np.bincount.
    pred = np.asarray(pred_ids).astype(np.int64)
    gt = np.asarray(gt_ids).astype(np.int64)
    assert pred.shape == gt.shape, (pred.shape, gt.shape)

    valid = gt != unknown_id

    # Some points may have no feature assigned for prediction: give them an extra
    # class slot so they are counted, then cut that slot off again.
    has_no_feature = bool((pred == no_feature_id).any())
    side = num_classes + 1 if has_no_feature else num_classes
    if has_no_feature:
        pred[pred == no_feature_id] = num_classes

    # Encode each (prediction, ground truth) pair as one flat bin index.
    counts = np.bincount(pred[valid] * side + gt[valid], minlength=side**2)
    confusion = counts.reshape((side, side)).astype(np.ulonglong)
    return confusion[:num_classes, :num_classes]

def get_iou(label_id, confusion):
    '''Calculate the intersection-over-union of one class.

    Args:
        label_id: index of the class inside ``confusion``.
        confusion: square confusion matrix with predictions on the rows and
            ground truth on the columns.

    Returns:
        A tuple ``(iou, tp, denom)`` where ``tp`` is the number of true positives and
        ``denom = tp + fp + fn``. When ``denom`` is 0 the IoU is ``nan``, and the
        result is still a 3-tuple so callers can index it in every case.
    '''

    # true positives
    tp = np.longlong(confusion[label_id, label_id])
    # false positives: predicted as this class but labelled otherwise (rest of the row)
    fp = np.longlong(confusion[label_id, :].sum()) - tp
    # false negatives: labelled as this class but predicted otherwise (rest of the column)
    fn = np.longlong(confusion[:, label_id].sum()) - tp

    denom = (tp + fp + fn)
    if denom == 0:
        return float('nan'), tp, denom
    return float(tp) / denom, tp, denom

In [61]:
# Number of evaluated classes and the confusion matrix of predictions against ground truth.
N_CLASSES = len(CLASS_LABELS)
confusion = confusion_matrix(pred_ids, gt_ids, N_CLASSES)

# Per-class results keyed by label name, and the running sums for the means.
class_ious, class_accs = {}, {}
mean_iou = mean_acc = 0

In [62]:
# Accumulate IoU and accuracy for every class that occurs in the ground truth.
count = 0
for i in range(N_CLASSES):
    label_name = CLASS_LABELS[i]
    n_gt_points = (gt_ids == i).sum()
    if n_gt_points == 0:
        # at least 1 point needs to be in the evaluation for this class
        continue

    iou_entry = get_iou(i, confusion)
    class_ious[label_name] = iou_entry
    # accuracy = true positives / ground-truth points of the class
    class_accs[label_name] = iou_entry[1] / n_gt_points
    count += 1

    mean_iou += iou_entry[0]
    mean_acc += class_accs[label_name]


In [63]:
# Average over the classes that were actually evaluated. Classes without ground-truth
# points are skipped by the loop above, so dividing by N_CLASSES would deflate the means.
# Recomputing from the per-class dictionaries (instead of `/=` on a running sum) also
# makes this cell safe to re-run.
n_evaluated = len(class_ious)
if n_evaluated:
    mean_iou = sum(entry[0] for entry in class_ious.values()) / n_evaluated
    mean_acc = sum(class_accs.values()) / n_evaluated
else:
    mean_iou = float('nan')
    mean_acc = float('nan')

In [64]:
# Per-class report: IoU followed by (true positives / union size).
print('classes          IoU')
print('----------------------------')
for i in range(N_CLASSES):
    label_name = CLASS_LABELS[i]
    # Classes without ground-truth points were never evaluated. They are reported with
    # the same message as before, but any other failure is no longer swallowed by a
    # bare except.
    if label_name not in class_ious:
        print(label_name + ' error!')
        continue
    print('{0:<14s}: {1:>5.5f}   ({2:>6d}/{3:<6d})'.format(
            label_name,
            class_ious[label_name][0],
            class_ious[label_name][1],
            class_ious[label_name][2]))

classes          IoU
----------------------------
wall          : 0.59753   (  8078/13519 )
floor         : 0.80171   (  8353/10419 )
cabinet       : 0.50060   (  5863/11712 )
bed           : 0.87822   (  2798/3186  )
chair error!
sofa          : 0.79371   (  4571/5759  )
table         : 0.15513   (   955/6156  )
door          : 0.26407   (  1286/4870  )
window        : 0.65814   (   566/860   )
bookshelf error!
picture error!
counter       : 0.15548   (   301/1936  )
desk          : 0.01666   (    32/1921  )
curtain       : 0.67594   (  4956/7332  )
refrigerator  : 0.66776   (  1017/1523  )
shower curtain error!
toilet        : 0.46035   (   418/908   )
sink          : 0.26351   (   156/592   )
bathtub error!
otherfurniture: 0.01622   (    66/4068  )
dragon red black wings tail claws: 0.76643   (  1201/1567  )
