In [1]:
import os
import torch
import numpy as np
from glob import glob
from tqdm import tqdm
from os.path import join, exists
import open3d as o3d
import matplotlib.pyplot as plt

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [18]:
def load_scene(path, visualize = True):
    
    sample = torch.load(path)
    sample_points  = sample[0]
    sample_colors = sample[1]
    sample_labels = sample[2]
    
    if visualize:
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(np.asarray(sample_points))
        pcd.colors = o3d.utility.Vector3dVector(np.asarray(sample_colors))
        o3d.visualization.draw_geometries([pcd])
        
    return sample_points, sample_colors, sample_labels

def load_fused_features(path, sample_points, sample_colors, sample_labels):
    
    features = torch.load(path)
    indices = torch.nonzero(features["mask_full"]).squeeze()
    filtered_point_cloud = sample_points[indices, :]
    filtered_point_cloud_colors = sample_colors[indices, :]
    filtered_point_cloud_labels = sample_labels[indices]
    fused_features = (features["feat"]/(features["feat"].norm(dim=-1, keepdim=True)+1e-5))
    
    return fused_features, filtered_point_cloud, filtered_point_cloud_colors, filtered_point_cloud_labels, indices

def load_distilled_features(path, indices):
    
    distilled = np.load(path)
    #cast and normalize embeddings for distilled 
    distilled = distilled[indices, :]
    distilled_t = torch.from_numpy(distilled).half()
    distilled_f = (distilled_t/(distilled_t.norm(dim=-1, keepdim=True)+1e-5))
    
    return distilled_f

def highlight_query(query, feature_type, agg_type, distill, fused, fpc, fpcc, device):
    
    import clip
    model, preprocess = clip.load("ViT-L/14@336px")
    
    with torch.no_grad():
        per_descriptor_embeds = []
        for descriptor in tqdm(query):
            texts = clip.tokenize(descriptor)  #tokenize
            texts = texts.to(device)
            text_embeddings = model.encode_text(texts)  #embed with text encoder
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            per_descriptor_embeds.append(text_embeddings)

        per_descriptor_embeds = torch.stack(per_descriptor_embeds, dim=1).squeeze()

    if feature_type == "fused":
        similarity_matrix = fused.to(device) @ per_descriptor_embeds.T
    elif feature_type == "distilled":
        similarity_matrix = distill.to(device) @ per_descriptor_embeds.T
    elif feature_type == "ensembled":
        pred_fusion = fused.to(device) @ per_descriptor_embeds
        pred_distill = distill.to(device) @ per_descriptor_embeds.T
        feat_ensemble = distill.clone().half()
        mask_ = pred_distill.max(dim=-1)[0] < pred_fusion.max(dim=-1)[0]
        feat_ensemble[mask_] = fused_f[mask_]
        similarity_matrix = feat_ensemble @ per_descriptor_embeds.T
        
    if agg_type == "mean":
        agg_sim_mat = torch.mean(similarity_matrix, dim=1)
    elif agg_type == "max":
        agg_sim_mat, _ = torch.max(similarity_matrix, dim=1)
        
    agg_sim_mat = agg_sim_mat.reshape(-1, 1)
    
    # creating pc
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.asarray(fpc))
    pcd.colors = o3d.utility.Vector3dVector(np.asarray(fpcc))

    # heatmap
    cmap = plt.get_cmap('cividis')

    # normalize the tensor to the range [0, 1]
    normalized_tensor = (agg_sim_mat - torch.min(agg_sim_mat)) / (torch.max(agg_sim_mat) - torch.min(agg_sim_mat))

    colors = cmap(normalized_tensor.detach().cpu().numpy().squeeze())
    pcd_heatmap = o3d.geometry.PointCloud()

    pcd_heatmap.points = o3d.utility.Vector3dVector(pcd.points)
    pcd_heatmap.colors = o3d.utility.Vector3dVector(colors[:, :3])

    #transform heatmap to the side
    pcd_heatmap.points = o3d.utility.Vector3dVector(np.asarray(pcd.points) + [0,10,0])

    o3d.visualization.draw_geometries([pcd, pcd_heatmap])
    
    return agg_sim_mat

def confusion_matrix(pred_ids, gt_ids, num_classes):
    '''calculate the confusion matrix.'''

    assert pred_ids.shape == gt_ids.shape, (pred_ids.shape, gt_ids.shape)
    idxs = gt_ids != UNKNOWN_ID
    if NO_FEATURE_ID in pred_ids: # some points have no feature assigned for prediction
        pred_ids[pred_ids==NO_FEATURE_ID] = num_classes
        confusion = np.bincount(
            pred_ids[idxs] * (num_classes+1) + gt_ids[idxs],
            minlength=(num_classes+1)**2).reshape((
            num_classes+1, num_classes+1)).astype(np.ulonglong)
        return confusion[:num_classes, :num_classes]

    return np.bincount(
        pred_ids[idxs] * num_classes + gt_ids[idxs],
        minlength=num_classes**2).reshape((
        num_classes, num_classes)).astype(np.ulonglong)

def get_iou(label_id, confusion):
    '''calculate IoU.'''

    # true positives
    tp = np.longlong(confusion[label_id, label_id])
    # false positives
    fp = np.longlong(confusion[label_id, :].sum()) - tp
    # false negatives
    fn = np.longlong(confusion[:, label_id].sum()) - tp

    denom = (tp + fp + fn)
    if denom == 0:
        return float('nan')
    return float(tp) / denom, tp, denom

def evaluate(labelset, descriptors, feature_type, agg_type, distill, fused, gt_ids):
    
    import clip
    model, preprocess = clip.load("ViT-L/14@336px")
    
    with torch.no_grad():
        label_embeds = []
        for category in tqdm(labelset):
            texts = clip.tokenize(category)  #tokenize
            texts = texts.cuda()
            text_embeddings = model.encode_text(texts)  #embed with text encoder
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            label_embeds.append(text_embeddings)

    label_embeds = torch.cat(label_embeds, dim=0) # has the shape of [768, *original labels*+*descriptors*]
    
    if feature_type == "fused":
        similarity_matrix = fused.to(device) @ label_embeds.T
    elif feature_type == "distilled":
        similarity_matrix = distill.to(device) @ label_embeds.T
    elif feature_type == "ensembled":
        pred_fusion = fused.to(device) @ label_embeds.T
        pred_distill = distill.to(device) @ label_embeds.T
        feat_ensemble = distill.clone().half()
        mask_ = pred_distill.max(dim=-1)[0] < pred_fusion.max(dim=-1)[0]
        feat_ensemble[mask_] = fused_f[mask_]
        similarity_matrix = feat_ensemble.to(device) @ label_embeds.T
    
    # separating the similarity matrix for the original labels and descriptors
    sim_labels = similarity_matrix[:, :20]
    sim_descriptors = similarity_matrix[:, 20:]
    
    # get lengths of the descriptors
    descriptor_lengths = []
    for label in labelset:
        if not isinstance(label, str):
            descriptor_lengths.append(len(label))
            
    # aggregate the corresponding descriptor vectors by the length of each
    # keep them in a list to stack it later
    _idx = 0
    agg_desc= []
    for elem in descriptor_lengths:
        if agg_type == "mean":
            agg_desc_sim_mat = torch.mean(sim_descriptors[:, _idx : _idx + elem], dim=1)
        elif agg_type == "max":
            agg_desc_sim_mat, _ = torch.max(sim_descriptors[:, _idx : _idx + elem], dim=1)
        
        _idx += elem
        agg_desc.append(agg_desc_sim_mat)
        
    # stack aggregated descriptor similarity matrices
    agg_sim_descriptors = torch.stack(agg_desc, dim = 1)
    # combine with the similarity matrix of labels
    agg_sim_mat = torch.cat([sim_labels,agg_sim_descriptors], dim = 1)

    # get the predictions
    pred_ids = torch.max(agg_sim_mat, 1)[1].detach().cpu()    
    
    N_CLASSES = len(labelset)
    confusion = confusion_matrix(pred_ids, gt_ids, N_CLASSES)
    class_ious = {}
    class_accs = {}
    mean_iou = 0
    mean_acc = 0
    
    count = 0
    for i in range(N_CLASSES):
        label_name = labelset[i]

        if not isinstance(label_name, str): 
            for key, value in descriptors.items():
                if value == label_name:
                    label_name = key
                    
        if (gt_ids==i).sum() == 0: # at least 1 point needs to be in the evaluation for this class
            continue

        class_ious[label_name] = get_iou(i, confusion)
        class_accs[label_name] = class_ious[label_name][1] / (gt_ids==i).sum()
        count+=1

        mean_iou += class_ious[label_name][0]
        mean_acc += class_accs[label_name]

    mean_iou /= N_CLASSES
    mean_acc /= N_CLASSES
    
    return class_ious, class_accs, mean_iou, mean_acc

def print_results(labelset, class_ious):
    
    print('classes                 IoU')
    print('----------------------------')
    for i in range(len(labelset)):
        label_name = labelset[i]
        if not isinstance(label_name, str): 
            for key, value in descriptors.items():
                if value == label_name:
                    label_name = key
        try:
            print('{0:<14s}             :          {1:>5.5f}           ({2:>6d}/{3:<6d})'.format(
                    label_name,
                    class_ious[label_name][0],
                    class_ious[label_name][1],
                    class_ious[label_name][2]))
        except:
            print(label_name + ' error!')
            continue
            
def descriptors_from_prompt(text, verbose = True):
    
    import openai

    openai.api_key = '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'

    response = openai.Completion.create(
      engine="text-davinci-003",
      prompt=text,

      temperature=0.5,
      max_tokens=200
    )
    
    print(response["choices"][0].text)
    
    lines = [s for s in [line.strip() for line in response["choices"][0].text.splitlines()] if s]
    
    descriptors = {}
    
    for line in lines:
        parts = line.split(":")
        key = parts[0].strip()
        features = [f'a bird which is {f.strip()}.' for f in parts[1].split(",")]
        descriptors[key] = features
        
    return descriptors

In [3]:
source_path = "/mnt/project/AT3DCV_Data/Preprocessed_OpenScene/data/augmented/birds/scannet_3d/example/scene0000_00_vh_clean_2.pth"
fused_path = "/mnt/project/AT3DCV_Data/Preprocessed_OpenScene/data/augmented/birds/fused/scene0000_00_0.pt"
distilled_path = "/mnt/project/AT3DCV_Data/Preprocessed_OpenScene/data/augmented/birds/features_3D/scene0000_00_vh_clean_2_openscene_feat_distill.npy"

source_points, source_colors, source_labels = load_scene(source_path, False)

fused_f, filtered_pc, filtered_pc_c, filtered_pc_labels, indices = load_fused_features(fused_path,
                                                                              source_points, 
                                                                              source_colors,
                                                                              source_labels)
distilled_f = load_distilled_features(distilled_path, indices)

In [None]:
# info about scene with birds
# diamond firetail         class label = 20, 1798 points, on the bed
# Blue-faced Honeyeater    class label = 21, 1771 points, on the kitchen counter
# Mouse-colored Tyrannulet class label = 22, 1476 points, on the sofa corner

In [21]:
# query
query = ["a bird which is grey-breasted", "a bird which is brown-crowned", "a bird which is yellow-eyed"]

similarity = highlight_query(query, "fused", "mean", distilled_f, fused_f, filtered_pc, filtered_pc_c, device)

100%|█████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 98.00it/s]


torch.Size([3, 768])
torch.Size([82723, 768])
torch.Size([82723, 3])


In [11]:
#parse descriptors from openai api with gpt
_prompt = "Generate 10 visual descriptors for each of the following categories, they are bird species: [Blue-faced Honeyeater, Diamond Firetail, Mouse-colored Tyrannulet]. The descriptors will be used for input queries for a CLIP model. The descriptors should be concise and distinct from the descriptors of the other classes. Do not focus on behavior, but purely on attributes which are recognizable by the CLIP model. The output should be in the following form as a string: *class*: *descriptor1*, *descriptor2*, etc."

descriptors = descriptors_from_prompt(_prompt, verbose = False)



Blue-faced Honeyeater: black-masked, yellow-bellied, white-throated, white-eyed, yellow-breasted, black-billed, long-tailed, blue-faced, black-crowned, white-tipped.

Diamond Firetail: red-headed, black-shouldered, white-rumped, grey-backed, white-bellied, yellow-breasted, orange-breasted, black-billed, red-tailed, black-winged.

Mouse-colored Tyrannulet: olive-backed, yellow-breasted, white-bellied, grey-crowned, grey-tailed, white-eyed, yellow-throated, black-billed, white-winged, grey-breasted.


In [12]:
descriptors

{'Blue-faced Honeyeater': ['a bird which is black-masked.',
  'a bird which is yellow-bellied.',
  'a bird which is white-throated.',
  'a bird which is white-eyed.',
  'a bird which is yellow-breasted.',
  'a bird which is black-billed.',
  'a bird which is long-tailed.',
  'a bird which is blue-faced.',
  'a bird which is black-crowned.',
  'a bird which is white-tipped..'],
 'Diamond Firetail': ['a bird which is red-headed.',
  'a bird which is black-shouldered.',
  'a bird which is white-rumped.',
  'a bird which is grey-backed.',
  'a bird which is white-bellied.',
  'a bird which is yellow-breasted.',
  'a bird which is orange-breasted.',
  'a bird which is black-billed.',
  'a bird which is red-tailed.',
  'a bird which is black-winged..'],
 'Mouse-colored Tyrannulet': ['a bird which is olive-backed.',
  'a bird which is yellow-breasted.',
  'a bird which is white-bellied.',
  'a bird which is grey-crowned.',
  'a bird which is grey-tailed.',
  'a bird which is white-eyed.',
  '

In [136]:
# manually set descriptors if necessary
descriptors = {"mouse-colored tyrannulet": ["a bird which is black-winged", "a bird which is white-breasted", "a bird which is long-tailed"],
               "diamond firetail" :         ["a bird which is red-breasted", "a bird which is white-eyed", "a bird which is black-tailed"],
               "blue-faced honeyeater":     ["a bird which is blue-faced", "a bird which is hooked-beak", "a bird which is yellow-throated"]
              }


In [13]:
SCANNET_LABELS_20 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa',
                     'table', 'door', 'window', 'bookshelf', 'picture','counter', 'desk', 'curtain', 'refrigerator', 'shower curtain',
                     'toilet', 'sink', 'bathtub', 'otherfurniture']
UNKNOWN_ID = 255
NO_FEATURE_ID = 256

for key, value in descriptors.items():
    SCANNET_LABELS_20.append(value)

In [19]:
class_ious, class_accs, mean_iou, mean_acc = evaluate(SCANNET_LABELS_20, descriptors, "fused", "max" , distilled_f, fused_f, filtered_pc_labels)

100%|██████████████████████████████████████████████████████████████████████████████████| 23/23 [00:00<00:00, 111.53it/s]


In [23]:
# combine all the added object labels into 20 
# gt_ids = np.where(np.logical_or(filtered_pc_labels == 21, filtered_pc_labels == 22), 20, filtered_pc_labels)

In [20]:
print_results(SCANNET_LABELS_20, class_ious)

classes                 IoU
----------------------------
wall                       :          0.60410           (  7991/13228 )
floor                      :          0.80123           (  8344/10414 )
cabinet                    :          0.48422           (  5722/11817 )
bed                        :          0.56505           (  2810/4973  )
chair error!
sofa                       :          0.77087           (  4562/5918  )
table                      :          0.18677           (  1347/7212  )
door                       :          0.25350           (  1287/5077  )
window                     :          0.65585           (   566/863   )
bookshelf error!
picture error!
counter                    :          0.17139           (   272/1587  )
desk                       :          0.01610           (    32/1987  )
curtain                    :          0.64372           (  4712/7320  )
refrigerator               :          0.66776           (  1015/1520  )
shower curtain error!
toilet      

# ---------------------------------------------------------------------

In [3]:
# should be the preprocessed file path
sample_path_0 = "/mnt/project/AT3DCV_Data/Preprocessed_OpenScene/data/augmented/birds/scannet_3d/example/scene0000_00_vh_clean_2.pth"
#sample_path_1 = "D:/AT3DCV_Data/Preprocessed_OpenScene/data/scannet_3d/train/scene0000_01_vh_clean_2.pth"
#sample_path_2 = "D:/AT3DCV_Data/Preprocessed_OpenScene/data/scannet_3d/train/scene0000_02_vh_clean_2.pth"

In [4]:
sample_0 = torch.load(sample_path_0) # coords,colors,labels
#sample_1 = torch.load(sample_path_1) # coords,colors,labels
#sample_2 = torch.load(sample_path_2) # coords,colors,labels

In [5]:
len(sample_0[0])

86414

In [6]:
# aggregating all of the partial point clouds of the same scene (they don't overlap perfectly)
#sample_points = np.concatenate((sample_0[0], sample_1[0], sample_2[0]))
#sample_colors = np.concatenate((sample_0[1], sample_1[1], sample_2[1]))

# single partial point cloud
sample_points  = sample_0[0]
sample_colors = sample_0[1]
sample_labels = sample_0[2]

In [7]:
#to view original scene
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.asarray(sample_points))
#original colors
pcd.colors = o3d.utility.Vector3dVector(np.asarray(sample_colors))
#------
#paint uniform
#sample_paint_uniform = np.asarray([200,200,200])/255.0 #redish
#pcd.paint_uniform_color(sample_paint_uniform)
o3d.visualization.draw_geometries([pcd])

# load fused features

In [8]:
# should be the fused feature path
feature_path = "/mnt/project/AT3DCV_Data/Preprocessed_OpenScene/data/augmented/birds/fused/scene0000_00_0.pt"

In [9]:
feature = torch.load(feature_path)

In [10]:
feature["mask_full"].shape

torch.Size([86414])

In [11]:
feature["feat"].shape

torch.Size([82723, 768])

In [12]:
# Get the indices where the mask is True
indices = torch.nonzero(feature["mask_full"]).squeeze()

In [13]:
filtered_point_cloud = sample_points[indices, :]
filtered_point_cloud_colors = sample_colors[indices, :]
filtered_point_cloud_labels = sample_labels[indices]
gt_ids = filtered_point_cloud_labels

In [14]:
np.unique(filtered_point_cloud_labels)

array([  0.,   1.,   2.,   3.,   5.,   6.,   7.,   8.,  11.,  12.,  13.,
        14.,  16.,  17.,  19.,  20.,  21.,  22., 255.])

In [15]:
# Replace every occurrence of 21 with 20 if necessary
gt_ids= np.where(filtered_point_cloud_labels == 21.0, 20.0, filtered_point_cloud_labels)
gt_ids= np.where(gt_ids == 22.0, 20.0, gt_ids)
# gt_ids = filtered_point_cloud_labels

In [15]:
np.unique(gt_ids)

array([  0.,   1.,   2.,   3.,   5.,   6.,   7.,   8.,  11.,  12.,  13.,
        14.,  16.,  17.,  19.,  20.,  21.,  22., 255.])

In [29]:
unique_values, counts = np.unique(gt_ids, return_counts=True)

In [30]:
counts

array([12425, 10081,  9583,  2948,  5098,  4428,  1414,   593,   462,
        1891,  5271,  1404,   454,   362,   642,  1798,  1771,  1476,
       20622])

In [16]:
filtered_point_cloud.shape

(82723, 3)

# using clip model

In [14]:
import clip
model, preprocess = clip.load("ViT-L/14@336px")

In [15]:
# highlight with a threshold
# type the query here 
query = ["dragon"]

with torch.no_grad():
    all_text_embeddings = []
    for category in tqdm(query):
        texts = clip.tokenize(category)  #tokenize
        texts = texts.cuda()
        text_embeddings = model.encode_text(texts)  #embed with text encoder
        text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
        text_embedding = text_embeddings.mean(dim=0)
        text_embedding /= text_embedding.norm()
        all_text_embeddings.append(text_embedding)

    all_text_embeddings = torch.stack(all_text_embeddings, dim=1)

# normalizing 
fused_f = (feature["feat"]/(feature["feat"].norm(dim=-1, keepdim=True)+1e-5)).half()
# calculating similarity matrix
# similarity_matrix = torch.matmul(feature["feat"].cuda(), all_text_embeddings) # 
similarity_matrix = fused_f.cuda() @ all_text_embeddings    
    
# set higher to increase the certainty (not always correct)
threshold_percentage = 0.9
cap = similarity_matrix.max().item()
found_indices = torch.nonzero(similarity_matrix > cap*threshold_percentage, as_tuple=False).squeeze().T[0]

# creating pc
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud))
pcd.colors = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud_colors))

found_region = pcd.select_by_index(found_indices.tolist())
found_region.paint_uniform_color([1.0, 0, 0]) # paint related points to red
rest = pcd.select_by_index(found_indices.tolist(), invert=True)
o3d.visualization.draw_geometries([rest,found_region])

100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.34s/it]


In [16]:
# highlight with a heatmap
# type the query here 
# query = ["deathwing"]
# query = [" a blue-faced, yellow-crowned, white-breasted, black-eyed, long-billed, hooked-beak, yellow-beaked, yellow-breasted, yellow-throated and black-tailed bird"]

# mouse-colored tyrannulet 
query = [["grey-bodied","yellow-breasted","black-crowned",
          "white-eyed","black-winged","yellow-throated",
          "white-breasted","yellow-billed","grey-headed",
          "long-tailed","bird"]]

# diamong firetail
query = [["red-breasted","black-crowned","gold-winged",
          "black-winged","white-eyed","yellow-billed",
          "red-headed","black-tailed","long-tailed",
          "white-breasted","bird"]]



#query = ["bird"]
#query = [["Mouse-colored Tyrannulet bird"]]

with torch.no_grad():
    all_text_embeddings = []
    for category in tqdm(query):
        texts = clip.tokenize(category)  #tokenize
        texts = texts.cuda()
        text_embeddings = model.encode_text(texts)  #embed with text encoder
        text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
        text_embedding = text_embeddings.mean(dim=0)
        text_embedding /= text_embedding.norm()
        all_text_embeddings.append(text_embedding)

    all_text_embeddings = torch.stack(all_text_embeddings, dim=1)

# normalizing 
fused_f = (feature["feat"]/(feature["feat"].norm(dim=-1, keepdim=True)+1e-5)).half()
# calculating similarity matrix
# similarity_matrix = torch.matmul(feature["feat"].cuda(), all_text_embeddings) # 
similarity_matrix = fused_f.cuda() @ all_text_embeddings    

# creating pc
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud))
pcd.colors = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud_colors))

# heatmap
cmap = plt.get_cmap('cividis')

# normalize the tensor to the range [0, 1]
normalized_tensor = (similarity_matrix - torch.min(similarity_matrix)) / (torch.max(similarity_matrix) - torch.min(similarity_matrix))

colors = cmap(normalized_tensor.detach().cpu().numpy().squeeze())
pcd_heatmap = o3d.geometry.PointCloud()

pcd_heatmap.points = o3d.utility.Vector3dVector(pcd.points)
pcd_heatmap.colors = o3d.utility.Vector3dVector(colors[:, :3])

#transform heatmap to the side
pcd_heatmap.points = o3d.utility.Vector3dVector(np.asarray(pcd.points) + [0,10,0])

o3d.visualization.draw_geometries([pcd, pcd_heatmap])

100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  4.92it/s]


# mIoU evaluation

In [48]:
SCANNET_LABELS_20 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa',
                     'table', 'door', 'window', 'bookshelf', 'picture','counter', 'desk', 'curtain', 'refrigerator', 'shower curtain',
                     'toilet', 'sink', 'bathtub', 'otherfurniture']
UNKNOWN_ID = 255
NO_FEATURE_ID = 256

SCANNET_LABELS_20.append(query[0])
#SCANNET_LABELS_20.append("bird")

CLASS_LABELS = SCANNET_LABELS_20

In [51]:
with torch.no_grad():
    label_embeds = []
    for category in tqdm(SCANNET_LABELS_20):
        texts = clip.tokenize(category)  #tokenize
        texts = texts.cuda()
        text_embeddings = model.encode_text(texts)  #embed with text encoder
        text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
        text_embedding = text_embeddings.mean(dim=0)
        text_embedding /= text_embedding.norm()
        label_embeds.append(text_embedding)

    label_embeds = torch.stack(label_embeds, dim=1)


100%|███████████████████████████████████████████████████████████████████████████████████| 21/21 [00:00<00:00, 52.34it/s]


In [57]:
print('classes          IoU')
print('----------------------------')
for i in range(N_CLASSES):
    label_name = CLASS_LABELS[i]
    if not isinstance(label_name, str): label_name = target_label
    try:
        print('{0:<14s}: {1:>5.5f}   ({2:>6d}/{3:<6d})'.format(
                label_name,
                class_ious[label_name][0],
                class_ious[label_name][1],
                class_ious[label_name][2]))
    except:
        print(label_name + ' error!')
        continue

classes          IoU
----------------------------
wall          : 0.60079   (  8020/13349 )
floor         : 0.80275   (  8363/10418 )
cabinet       : 0.49356   (  5866/11885 )
bed           : 0.81338   (  2820/3467  )
chair error!
sofa          : 0.79770   (  4586/5749  )
table         : 0.16482   (  1059/6425  )
door          : 0.26445   (  1286/4863  )
window        : 0.65738   (   566/861   )
bookshelf error!
picture error!
counter       : 0.15129   (   300/1983  )
desk          : 0.01605   (    32/1994  )
curtain       : 0.67212   (  4928/7332  )
refrigerator  : 0.66579   (  1014/1523  )
shower curtain error!
toilet        : 0.45985   (   418/909   )
sink          : 0.25487   (   157/616   )
bathtub error!
otherfurniture: 0.01146   (    49/4274  )
bird          : 0.00000   (     0/1301  )


In [16]:
import openai

openai.api_key = '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'

response = openai.Completion.create(
  engine="text-davinci-003",
  prompt="Could you generate 5 visual descriptors for each of the following object classes, they are bird species: [Blue-faced Honeyeate, Diamond Firetail, Mouse-colored Tyrannulet]. The descriptors will be used for input queries for a CLIP model. The descriptors should be concise and distinct from one another. Do not focus on behavior, but purely on attributes which are recognizable by the CLIP model. The output should be in the following form, without any additional text: object class 1, visual descriptor 1.1, visual descriptor 1.2",

  temperature=0.5,
  max_tokens=200
)

In [None]:
# old version of the aggregating text embeddings, it's not properly working
def highlight_query(query, feature_type, model, distill, fused, fpc, fpcc, device):
    
    
    with torch.no_grad():
        all_text_embeddings = []
        for category in tqdm(query):
            texts = clip.tokenize(category)  #tokenize
            texts = texts.to(device)
            text_embeddings = model.encode_text(texts)  #embed with text encoder
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            text_embedding = text_embeddings.mean(dim=0)
            text_embedding /= text_embedding.norm()
            all_text_embeddings.append(text_embedding)

        all_text_embeddings = torch.stack(all_text_embeddings, dim=1)

        
    if feature_type == "fused":
        similarity_matrix = fused.to(device) @ all_text_embeddings
    elif feature_type == "distilled":
        similarity_matrix = distill.to(device) @ all_text_embeddings
    elif feature_type == "ensembled":
        pred_fusion = fused.to(device) @ all_text_embeddings
        pred_distill = distill.to(device) @ all_text_embeddings
        feat_ensemble = distill.clone().half()
        mask_ = pred_distill.max(dim=-1)[0] < pred_fusion.max(dim=-1)[0]
        feat_ensemble[mask_] = fused_f[mask_]
        similarity_matrix = feat_ensemble @ all_text_embeddings
        
    print(similarity_matrix.shape)
    # creating pc
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.asarray(fpc))
    pcd.colors = o3d.utility.Vector3dVector(np.asarray(fpcc))

    # heatmap
    cmap = plt.get_cmap('cividis')

    # normalize the tensor to the range [0, 1]
    normalized_tensor = (similarity_matrix - torch.min(similarity_matrix)) / (torch.max(similarity_matrix) - torch.min(similarity_matrix))

    colors = cmap(normalized_tensor.detach().cpu().numpy().squeeze())
    pcd_heatmap = o3d.geometry.PointCloud()

    pcd_heatmap.points = o3d.utility.Vector3dVector(pcd.points)
    pcd_heatmap.colors = o3d.utility.Vector3dVector(colors[:, :3])

    #transform heatmap to the side
    pcd_heatmap.points = o3d.utility.Vector3dVector(np.asarray(pcd.points) + [0,10,0])

    o3d.visualization.draw_geometries([pcd, pcd_heatmap])
    
def evaluate(labelset, descriptors, feature_type, model, distill, fused, gt_ids):
    
    with torch.no_grad():
        label_embeds = []
        for category in tqdm(labelset):
            texts = clip.tokenize(category)  #tokenize
            texts = texts.cuda()
            text_embeddings = model.encode_text(texts)  #embed with text encoder
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            text_embedding = text_embeddings.mean(dim=0)
            text_embedding /= text_embedding.norm()
            label_embeds.append(text_embedding)

        label_embeds = torch.stack(label_embeds, dim=1)
        
    if feature_type == "fused":
        similarity_matrix = fused.to(device) @ label_embeds
    elif feature_type == "distilled":
        similarity_matrix = distill.to(device) @ label_embeds
    elif feature_type == "ensembled":
        pred_fusion = fused.to(device) @ label_embeds
        pred_distill = distill.to(device) @ label_embeds
        feat_ensemble = distill.clone().half()
        mask_ = pred_distill.max(dim=-1)[0] < pred_fusion.max(dim=-1)[0]
        feat_ensemble[mask_] = fused_f[mask_]
        similarity_matrix = feat_ensemble.to(device) @ label_embeds
        
    pred_ids = torch.max(similarity_matrix, 1)[1].detach().cpu()    
    
    N_CLASSES = len(labelset)
    confusion = confusion_matrix(pred_ids, gt_ids, N_CLASSES)
    class_ious = {}
    class_accs = {}
    mean_iou = 0
    mean_acc = 0
    
    count = 0
    for i in range(N_CLASSES):
        label_name = labelset[i]

        if not isinstance(label_name, str): 
            for key, value in descriptors.items():
                if value == label_name:
                    label_name = key
                    
        if (gt_ids==i).sum() == 0: # at least 1 point needs to be in the evaluation for this class
            continue


        class_ious[label_name] = get_iou(i, confusion)
        class_accs[label_name] = class_ious[label_name][1] / (gt_ids==i).sum()
        count+=1

        mean_iou += class_ious[label_name][0]
        mean_acc += class_accs[label_name]


    mean_iou /= N_CLASSES
    mean_acc /= N_CLASSES
    
    return class_ious, class_accs, mean_iou, mean_acc
