In [None]:
import os
import torch
import numpy as np
from glob import glob
from tqdm.notebook import tqdm
from os.path import join, exists
import open3d as o3d
import matplotlib.pyplot as plt
from itertools import combinations
import copy
from tabulate import tabulate

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
def load_scene(path, visualize=True):
    """Load a preprocessed scene file and optionally show it in Open3D.

    The object read by ``torch.load`` is indexed as ``[0]`` points,
    ``[1]`` colors and ``[2]`` labels.

    Args:
        path: file to load with ``torch.load``.
        visualize: when True, open an Open3D viewer with the colored cloud.

    Returns:
        Tuple ``(points, colors, labels)`` taken from entries 0, 1 and 2.
    """
    scene = torch.load(path)
    points, colors, labels = scene[0], scene[1], scene[2]

    if not visualize:
        return points, colors, labels

    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(np.asarray(points))
    cloud.colors = o3d.utility.Vector3dVector(np.asarray(colors))
    o3d.visualization.draw_geometries([cloud])
    return points, colors, labels

def load_fused_features(path, sample_points, sample_colors, sample_labels):
    """Load 2D-fused features and restrict the scene to the points they cover.

    Args:
        path: file read with ``torch.load``; must contain ``"mask_full"``
            (per-point selection mask) and ``"feat"`` (features of the
            selected points).
        sample_points, sample_colors, sample_labels: full-scene arrays,
            indexed by the selected point indices.

    Returns:
        ``(fused_features, points, colors, labels, indices)``. Each row of
        ``fused_features`` is L2-normalized (``1e-5`` is added to the norm to
        avoid division by zero). ``indices`` is a 1-D tensor of the kept
        point indices.
    """
    features = torch.load(path)
    # squeeze(1) instead of a bare squeeze(): with exactly one selected
    # point a bare squeeze would return a 0-d index, and the filtered arrays
    # would lose their point axis.
    indices = torch.nonzero(features["mask_full"]).squeeze(1)
    filtered_point_cloud = sample_points[indices, :]
    filtered_point_cloud_colors = sample_colors[indices, :]
    filtered_point_cloud_labels = sample_labels[indices]
    feat = features["feat"]
    fused_features = feat / (feat.norm(dim=-1, keepdim=True) + 1e-5)

    return fused_features, filtered_point_cloud, filtered_point_cloud_colors, filtered_point_cloud_labels, indices

def load_distilled_features(path, indices):
    """Load distilled per-point features, keep ``indices`` and L2-normalize.

    Args:
        path: ``.npy`` file loaded with ``np.load``.
        indices: rows to keep (e.g. the indices from ``load_fused_features``).

    Returns:
        float16 tensor of the selected rows, each divided by its norm plus
        ``1e-5`` (to avoid division by zero).
    """
    selected = np.load(path)[indices, :]
    feats = torch.from_numpy(selected).half()
    row_norms = feats.norm(dim=-1, keepdim=True)
    return feats / (row_norms + 1e-5)

def draw_improvement(sim_old, sim_new, fpc, scale=None):
    """Show the per-point change ``sim_new - sim_old`` as a blue-white-red heatmap.

    Args:
        sim_old, sim_new: similarity tensors of matching shape (e.g. the
            values returned by ``highlight_query``).
        fpc: point coordinates, converted with ``np.asarray``.
        scale: the change is mapped to ``change / (2 * scale) + 0.5`` before
            the colormap. When None, it is set to twice the largest absolute
            change (so colors span roughly 0.25..0.75 of the colormap).
    """
    improvement = sim_new - sim_old

    if scale is None:
        scale = 2 * torch.max(improvement.max(), -improvement.min())
        # An all-zero change would give scale == 0 and NaN colors; any
        # positive scale maps "no change" to the colormap midpoint.
        if scale == 0:
            scale = 1.0

    print(f'Scaled by {scale}')
    improvement = improvement / (2 * scale) + 0.5

    # 'bwr': blue = worse, white = unchanged, red = better.
    cmap = plt.get_cmap('bwr')
    # reshape(-1) (not squeeze) keeps a 1-D array even for a single point.
    colors = cmap(improvement.detach().cpu().numpy().reshape(-1))
    pcd_heatmap = o3d.geometry.PointCloud()
    pcd_heatmap.points = o3d.utility.Vector3dVector(np.asarray(fpc))
    pcd_heatmap.colors = o3d.utility.Vector3dVector(colors[:, :3])

    o3d.visualization.draw_geometries([pcd_heatmap])

def highlight_query(query, feature_type, agg_type, distill, fused, fpc, fpcc, device, draw=True, scale=1,
                    quantile=0.5, norm=False):
    """Score every point against CLIP text prompts and optionally show a heatmap.

    Args:
        query: list of text prompts; each is tokenized and embedded separately.
        feature_type: which per-point features are compared with the text
            embeddings. One of:
            - "fused": use ``fused``.
            - "distilled": use ``distill``.
            - "ensembled": per point, use the fused feature where its maximum
              similarity over the prompts beats the distilled one, otherwise
              the distilled feature.
        agg_type: how similarities are combined when more than one prompt
            is given. "mean", "max", "median", "quantile" and "min" reduce
            over prompts. "bare-overlay-weight-max" and "bare-overlay-mean"
            treat prompt 0 as the base and add a weighted / mean overlay of
            the remaining prompts.
        distill, fused: per-point feature tensors (rows are multiplied with
            the text embeddings).
        fpc: point coordinates used for the heatmap (``np.asarray``-able).
        fpcc: point colors; not used for the heatmap, kept for interface
            compatibility.
        device: torch device for tokens and the similarity computation.
        draw: open an Open3D window with the heatmap when True.
        scale: the colormap input is ``scale * similarity + 0.5``.
        quantile: quantile used by ``agg_type="quantile"``.
        norm: subtract the mean similarity before returning/coloring.

    Returns:
        Tensor of shape (num_points, 1) with the aggregated similarity.

    Raises:
        ValueError: unknown ``feature_type``.
        NotImplementedError: unknown ``agg_type`` with more than one prompt.
    """
    import clip
    # NOTE: the CLIP model is loaded on every call.
    model, preprocess = clip.load("ViT-L/14@336px")

    with torch.no_grad():
        per_descriptor_embeds = []
        for descriptor in tqdm(query):
            print(descriptor)
            texts = clip.tokenize(descriptor).to(device)
            text_embeddings = model.encode_text(texts)
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            per_descriptor_embeds.append(text_embeddings)
        # (num_prompts, embed_dim). Concatenating keeps this 2-D even for a
        # single prompt, so ``.T`` below is always a matrix transpose.
        text_embeds = torch.cat(per_descriptor_embeds, dim=0)

    if feature_type == "fused":
        similarity_matrix = fused.to(device) @ text_embeds.T
    elif feature_type == "distilled":
        similarity_matrix = distill.to(device) @ text_embeds.T
    elif feature_type == "ensembled":
        # Everything lives on `device` so the boolean mask can index both
        # feature tensors (mixing a device mask with CPU tensors fails).
        fused_dev = fused.to(device)
        distill_dev = distill.to(device)
        pred_fusion = fused_dev @ text_embeds.T
        pred_distill = distill_dev @ text_embeds.T
        mask_ = pred_distill.max(dim=-1)[0] < pred_fusion.max(dim=-1)[0]
        feat_ensemble = distill_dev.clone().half()
        # Source cast to half: index-put requires matching dtypes.
        feat_ensemble[mask_] = fused_dev[mask_].half()
        similarity_matrix = feat_ensemble @ text_embeds.T
    else:
        raise ValueError(f"Unknown feature_type: {feature_type!r}")

    if similarity_matrix.shape[1] == 1:
        # A single prompt needs no aggregation.
        agg_sim_mat = similarity_matrix[:, 0]
    elif agg_type == "mean":
        agg_sim_mat = torch.mean(similarity_matrix, dim=1)
    elif agg_type == "max":
        agg_sim_mat = torch.max(similarity_matrix, dim=1)[0]
    elif agg_type == "median":
        agg_sim_mat = torch.median(similarity_matrix, dim=1)[0]
    elif agg_type == "quantile":
        agg_sim_mat = torch.quantile(similarity_matrix.float(), quantile, dim=1)
    elif agg_type == "min":
        agg_sim_mat = torch.min(similarity_matrix, dim=1)[0]
    elif agg_type == "bare-overlay-weight-max":
        # Prompt 0 is the base query; the others are added with weights
        # proportional to their maximum similarity over all points.
        overlay = similarity_matrix[:, 1:]
        maximum = overlay.max(dim=0)[0]
        weight = maximum / maximum.sum()
        agg_sim_mat = similarity_matrix[:, 0] + overlay @ weight
    elif agg_type == "bare-overlay-mean":
        agg_sim_mat = similarity_matrix[:, 0] + similarity_matrix[:, 1:].mean(dim=1)
    else:
        raise NotImplementedError(f"Unsupported agg_type: {agg_type!r}")

    if norm:
        agg_sim_mat = agg_sim_mat - agg_sim_mat.mean()

    agg_sim_mat = agg_sim_mat.reshape(-1, 1)

    print(f'Min: {torch.min(agg_sim_mat)}')
    print(f'Max: {torch.max(agg_sim_mat)}')

    if draw:
        # Fixed affine map into the colormap input range: similarity 0 maps
        # to 0.5 (white in 'bwr'). Out-of-range values get the colormap's
        # under/over colors.
        normalized_tensor = scale * agg_sim_mat + 0.5
        cmap = plt.get_cmap('bwr')
        colors = cmap(normalized_tensor.detach().cpu().numpy().reshape(-1))
        pcd_heatmap = o3d.geometry.PointCloud()
        pcd_heatmap.points = o3d.utility.Vector3dVector(np.asarray(fpc))
        pcd_heatmap.colors = o3d.utility.Vector3dVector(colors[:, :3])
        o3d.visualization.draw_geometries([pcd_heatmap])

    return agg_sim_mat

def confusion_matrix(pred_ids, gt_ids, num_classes, unknown_id=None, no_feature_id=None):
    '''Calculate the confusion matrix.

    Entry ``[p, g]`` counts points predicted as class ``p`` whose ground
    truth is ``g`` (flat index ``pred * C + gt``).

    Args:
        pred_ids: predicted class ids (numpy array or CPU tensor). Not modified.
        gt_ids: ground-truth ids with the same shape. Float labels are cast
            to integers. Points equal to ``unknown_id`` are ignored.
        num_classes: number of classes C.
        unknown_id: ignored ground-truth id. Defaults to the module-level
            ``UNKNOWN_ID``.
        no_feature_id: prediction id meaning "no feature". Such points are
            counted in an extra overflow class that is dropped from the
            result. Defaults to the module-level ``NO_FEATURE_ID``.

    Returns:
        ``(C, C)`` array of dtype ``np.ulonglong``.

    Raises:
        ValueError: ``pred_ids`` and ``gt_ids`` have different shapes.
    '''
    if unknown_id is None:
        unknown_id = UNKNOWN_ID
    if no_feature_id is None:
        no_feature_id = NO_FEATURE_ID

    # Integer copies: the caller's arrays are never modified, and float
    # labels (e.g. from np.where with float literals) are valid for bincount.
    pred = np.array(pred_ids, dtype=np.int64)
    gt = np.asarray(gt_ids).astype(np.int64)
    if pred.shape != gt.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {gt.shape}")

    valid = gt != unknown_id
    size = num_classes
    if np.any(pred == no_feature_id):  # some points have no feature assigned for prediction
        print("no features")
        pred[pred == no_feature_id] = num_classes
        size = num_classes + 1

    confusion = np.bincount(pred[valid] * size + gt[valid], minlength=size ** 2)
    confusion = confusion.reshape((size, size)).astype(np.ulonglong)
    return confusion[:num_classes, :num_classes]

def get_iou(label_id, confusion, gts):
    '''Calculate the IoU of one class.

    ``confusion`` is indexed ``[pred, gt]`` (as built by ``confusion_matrix``).
    The row sum counts predictions of the class and the column sum counts
    its ground truth.

    Returns:
        ``(iou, tp, tp + fp + fn, total)``, where ``total`` is the number of
        entries of ``gts`` equal to ``label_id``. ``iou`` is NaN when
        ``tp + fp + fn == 0``. The tuple shape is the same in that case, so
        callers can always index it.
    '''
    tp = np.longlong(confusion[label_id, label_id])
    fp = np.longlong(confusion[label_id, :].sum()) - tp
    fn = np.longlong(confusion[:, label_id].sum()) - tp

    total = np.sum(gts == label_id)
    denom = tp + fp + fn
    iou = float('nan') if denom == 0 else float(tp) / denom
    return iou, tp, denom, total

def evaluate(labelset, descriptors, feature_type, agg_type, distill, fused, gt_ids):
    """Predict one class per point from CLIP text queries and score it against ``gt_ids``.

    ``labelset`` holds plain class names (strings) followed by descriptor
    lists (non-strings).
    - Each plain name is embedded via the prompt ``'a {name} in a scene'``.
    - Each descriptor list forms one class, scored by the ``agg_type``
      aggregate of its descriptors' similarities.
    Prediction index ``i`` matches ``labelset[i]`` only if all strings come
    before all descriptor lists.

    Args:
        labelset: see above.
        descriptors: mapping class name -> descriptor list; used to report
            descriptor-list classes under their name.
        feature_type: "fused", "distilled" or "ensembled" (per point, the
            fused feature where its maximum similarity beats the distilled one).
        agg_type: "mean", "max", "median" or "min".
        distill, fused: per-point feature tensors.
        gt_ids: ground-truth class id per point (compared with class indices).

    Returns:
        ``(class_ious, class_accs, mean_iou, mean_acc, confusion)``. The means
        are averaged over the classes that have at least one ground-truth point.

    Raises:
        ValueError: unknown ``feature_type``.
        NotImplementedError: unknown ``agg_type`` while descriptor lists exist.
    """
    import clip
    # NOTE: the CLIP model is loaded on every call.
    model, preprocess = clip.load("ViT-L/14@336px")

    def _embed(prompt):
        # Uses the module-level `device` (not a hard-coded .cuda()) and
        # returns the L2-normalized text embedding of shape (1, D).
        texts = clip.tokenize(prompt).to(device)
        text_embeddings = model.encode_text(texts)
        return text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

    descriptor_lengths = []  # number of descriptors of each descriptor-list class
    label_embeds = []
    with torch.no_grad():
        for category in labelset:
            if isinstance(category, str):
                label_embeds.append(_embed(f'a {category} in a scene'))
            else:
                descriptor_lengths.append(len(category))
                for desc in category:
                    label_embeds.append(_embed(desc))

    # (num_prompts, D): plain-label rows first, then every descriptor row.
    label_embeds = torch.cat(label_embeds, dim=0)

    if feature_type == "fused":
        similarity_matrix = fused.to(device) @ label_embeds.T
    elif feature_type == "distilled":
        similarity_matrix = distill.to(device) @ label_embeds.T
    elif feature_type == "ensembled":
        fused_dev = fused.to(device)
        distill_dev = distill.to(device)
        pred_fusion = fused_dev @ label_embeds.T
        pred_distill = distill_dev @ label_embeds.T
        mask_ = pred_distill.max(dim=-1)[0] < pred_fusion.max(dim=-1)[0]
        feat_ensemble = distill_dev.clone().half()
        # Source cast to half: index-put requires matching dtypes.
        feat_ensemble[mask_] = fused_dev[mask_].half()
        similarity_matrix = feat_ensemble @ label_embeds.T
    else:
        raise ValueError(f"Unknown feature_type: {feature_type!r}")

    reducers = {
        "mean": lambda sim: torch.mean(sim, dim=1),
        "max": lambda sim: torch.max(sim, dim=1)[0],
        "median": lambda sim: torch.median(sim, dim=1)[0],
        "min": lambda sim: torch.min(sim, dim=1)[0],
    }
    if descriptor_lengths and agg_type not in reducers:
        raise NotImplementedError(f"Unsupported agg_type: {agg_type!r}")

    # Plain labels precede descriptor lists, so their columns come first
    # (20 for the ScanNet label set).
    num_plain = sum(isinstance(category, str) for category in labelset)
    sim_descriptors = similarity_matrix[:, num_plain:]

    # One column per plain label, then one aggregated column per descriptor list.
    columns = [similarity_matrix[:, :num_plain]]
    start = 0
    for length in descriptor_lengths:
        group = sim_descriptors[:, start:start + length]
        columns.append(reducers[agg_type](group).unsqueeze(1))
        start += length
    agg_sim_mat = torch.cat(columns, dim=1)

    pred_ids = torch.max(agg_sim_mat, 1)[1].detach().cpu()

    N_CLASSES = len(labelset)
    confusion = confusion_matrix(pred_ids, gt_ids, N_CLASSES)

    class_ious = {}
    class_accs = {}
    mean_iou = 0
    mean_acc = 0
    count = 0  # classes with at least one ground-truth point
    for i, label in enumerate(labelset):
        label_name = label
        if not isinstance(label, str):
            # Report a descriptor-list class under its first matching key.
            label_name = next((key for key, value in descriptors.items() if value == label), label)

        num_gt = (gt_ids == i).sum()
        if num_gt == 0:  # at least 1 point needs to be in the evaluation for this class
            continue

        class_ious[label_name] = get_iou(i, confusion, gt_ids)
        class_accs[label_name] = class_ious[label_name][1] / num_gt
        count += 1

        mean_iou += class_ious[label_name][0]
        mean_acc += class_accs[label_name]

    # Average over the evaluated classes only. Dividing by N_CLASSES would
    # count every class absent from the scene as a score of 0.
    if count:
        mean_iou /= count
        mean_acc /= count

    return class_ious, class_accs, mean_iou, mean_acc, confusion

def print_results(labelset, class_ious, descriptors):
    """Print one line per class: IoU followed by ``(tp/tp+fp+fn /total)``.

    Descriptor-list entries of ``labelset`` are printed under the first key
    of ``descriptors`` whose value equals them. Classes missing from
    ``class_ious`` (e.g. skipped by ``evaluate`` for having no ground-truth
    points) or with malformed entries print ``'<name> error!'``.
    """
    print('classes                 IoU/ total')
    print('----------------------------')
    for label in labelset:
        label_name = label
        if not isinstance(label, str):
            label_name = next((key for key, value in descriptors.items() if value == label), label)
        try:
            scores = class_ious[label_name]
            line = '{0:<14s}             :          {1:>5.5f}           ({2:>6d}/{3:<6d}   /{4:<6d})'.format(
                label_name, scores[0], scores[1], scores[2], scores[3])
        except (KeyError, IndexError, TypeError, ValueError):
            # Narrow catch (was a bare except) so Ctrl-C and real bugs surface.
            print(f'{label_name} error!')
            continue
        print(line)
            
def print_results_table(labelset, class_ious, descriptors):
    """Print per-class results as a ``tabulate`` table.

    The columns are: class name, IoU (5 decimals), ``tp/(tp + fp + fn)`` and
    the number of ground-truth points. Descriptor-list entries are shown
    under the first matching key of ``descriptors``. Classes missing from
    ``class_ious`` get a row of ``---``.
    """
    results = [["classes", "IoU", "tp/(tp + fp + fn)", "total points"]]

    for label in labelset:
        label_name = label
        if not isinstance(label, str):
            label_name = next((key for key, value in descriptors.items() if value == label), label)
        try:
            scores = class_ious[label_name]
            row = [label_name,
                   format(scores[0], '.5f'),
                   f'{scores[1]}/{scores[2]}',
                   scores[3]]
        except (KeyError, IndexError, TypeError, ValueError):
            row = [label_name, "---", "---", "---"]
        results.append(row)

    # Built once after all rows are collected. Previously this ran inside
    # the loop, so a trailing "---" row was lost, and if no row succeeded
    # `table` was never bound (UnboundLocalError).
    table = tabulate(results, headers="firstrow", tablefmt="rounded_outline")
    print(table)
        
def descriptors_from_prompt(text, verbose=True, api_key=None, template='a bird which has {}.'):
    """Ask an OpenAI completion model for class descriptors and parse them.

    The completion is expected to contain lines ``class: desc1, desc2, ...``.
    Each non-empty descriptor is inserted into ``template``. Lines without a
    ``:`` are skipped.

    Args:
        text: prompt sent to the model.
        verbose: print the raw completion text.
        api_key: OpenAI API key. When None, the ``OPENAI_API_KEY``
            environment variable is used.
        template: format string applied to each descriptor.

    Returns:
        dict mapping class name -> list of formatted descriptors.
    """
    import os

    import openai

    # SECURITY: never hard-code API keys in source. Any key previously
    # committed in this function should be treated as leaked and revoked.
    openai.api_key = api_key or os.environ.get("OPENAI_API_KEY")

    # NOTE(review): openai.Completion is the pre-1.0 SDK API, and
    # text-davinci-003 may no longer be served; confirm before relying on this.
    response = openai.Completion.create(
        engine="text-davinci-003",
        prompt=text,
        temperature=0.5,
        max_tokens=200,
    )
    completion = response["choices"][0].text

    if verbose:
        print(completion)

    descriptors = {}
    for raw_line in completion.splitlines():
        key, sep, rest = raw_line.strip().partition(":")
        if not sep:
            continue  # blank or not a "class: descriptors" line
        descriptors[key.strip()] = [template.format(f.strip()) for f in rest.split(",") if f.strip()]

    return descriptors


def combinations_descriptor(descriptors, subset_len):
    """Pair up same-index descriptor subsets across classes.

    For every class, all ``subset_len``-element combinations of its
    descriptors are enumerated in ``itertools.combinations`` order. The
    i-th returned dict maps each class to its i-th combination, as a list
    of strings.

    If classes yield different numbers of combinations, output stops at
    the shortest. An empty ``descriptors`` returns ``[]``. Previously these
    cases raised IndexError or StopIteration.
    """
    keys = list(descriptors)
    per_class = [combinations(descriptors[key], subset_len) for key in keys]
    return [
        {key: [str(item) for item in combo] for key, combo in zip(keys, combos)}
        for combos in zip(*per_class)
    ]

def try_diff_combs(labelset, comb_dict_list, feature_type, agg_type, distill, fused, gt_ids):
    """Run ``evaluate`` once per descriptor combination.

    For each dict in ``comb_dict_list``, a copy of ``labelset`` is extended
    with the dict's descriptor lists (in dict order) and evaluated.

    Returns:
        Four lists aligned with ``comb_dict_list``: per-class IoU dicts,
        per-class accuracy dicts, mean IoUs and mean accuracies.
    """
    class_IoU_result_list = []
    class_accs_result_list = []
    mean_iou_result_list = []
    mean_acc_result_list = []

    for elem in tqdm(comb_dict_list):
        temp_labelset = copy.deepcopy(labelset) + list(elem.values())

        # evaluate() returns five values; the confusion matrix is not kept.
        # (Unpacking only four raised ValueError on every call.)
        class_ious, class_accs, mean_iou, mean_acc, _confusion = evaluate(
            temp_labelset, elem, feature_type, agg_type, distill, fused, gt_ids)
        class_IoU_result_list.append(class_ious)
        class_accs_result_list.append(class_accs)
        mean_iou_result_list.append(mean_iou)
        mean_acc_result_list.append(mean_acc)

    return class_IoU_result_list, class_accs_result_list, mean_iou_result_list, mean_acc_result_list

In [None]:
# load all the required data
# Earlier data location, kept for reference:
#source_path = "/home/aleks/3dcv/openseg_aug/chair_scene_alex/scannet_3d/example/scene0000_00_vh_clean_2.pth"
#fused_path = "/home/aleks/3dcv/openseg_aug/chair_scene_alex/fused/scene0000_00_0.pt"
#distilled_path = "/home/aleks/3dcv/openseg_aug/chair_scene_alex/features_3D/scene0000_00_vh_clean_2_openscene_feat_distill.npy"

# NOTE(review): absolute, machine-specific paths; adjust before running.
source_path = "/home/aleks/3dcv/openseg_aug_new/scannet_3d/example/scene0000_00_vh_clean_2.pth"
fused_path = "/home/aleks/3dcv/openseg_aug_new/fused/scene0000_00_0.pt"
distilled_path = "/home/aleks/3dcv/openseg_aug_new/features_3D/scene0000_00_vh_clean_2_openscene_feat_distill.npy"

# Points, colors and labels of the scene; False = do not open the viewer.
source_points, source_colors, source_labels = load_scene(source_path, False)

# Keep only the points selected by the fused features' "mask_full";
# fused_f is row-normalized and `indices` are the kept point indices.
fused_f, filtered_pc, filtered_pc_c, filtered_pc_labels, indices = load_fused_features(fused_path,
                                                                              source_points, 
                                                                              source_colors,
                                                                              source_labels)
# Distilled features restricted to the same points.
distilled_f = load_distilled_features(distilled_path, indices)

In [None]:
# Sentence-style descriptors for four chair subtypes that are not among the
# 20 ScanNet classes. Each one is queried on its own in the cells below, and
# the keys are appended to the label set as SCANNET_LABELS_AUG.
descriptors = {
    "stool": [
        "Stool with backless",
        "Backless stool",
        "Stool with cylindrical legs",
        "Small stool",
        "Stool with flat seat",
        "Simple stool",
        "Stool with no armrests",
        "Wooden stool",
        "Stool with round seat",
        "Metal stool",
    ],
    "armchair": [
        "Armchair with padded arms",
        "Comfy armchair",
        "Armchair with high back",
        "Upholstered armchair",
        "Armchair with wooden frame",
        "Leather armchair",
        "Armchair with cushioned seat",
        "Modern armchair",
        "Armchair with decorative legs",
        "Vintage armchair",
    ],
    "rocking chair": [
        "Rocking chair with curved runners",
        "Wooden rocking chair",
        "Rocking chair with slat back",
        "Outdoor rocking chair",
        "Upholstered rocking chair",
        "Rocking chair with armrests",
        "Classic rocking chair",
        "Rocking chair with wide seat",
        "Modern rocking chair",
        "Nursery rocking chair",
    ],
    "ball chair": [
        "Round ball chair",
        "Modern ball chair",
        "High back ball chair",
        "Hanging ball chair",
        "Swivel base ball chair",
        "Outdoor ball chair",
        "Padded seat ball chair",
        "Enclosed ball chair",
        "Sphere-shaped ball chair",
        "Retro ball chair",
    ],
}

In [None]:
# Bare attribute words (without the class name) for the same four chair
# subtypes, used for the "bare" descriptor experiments.
descriptors_bare = {
    "stool": ["backless", "compact", "straight legs", "round seat", "simple design", 
              "no armrests", "low height", "versatile", "footrest", "wooden"],
    
    "armchair": ["upholstered", "armrests", "comfortable", "cushioned", "high back", 
                 "wingback", "padded seat", "elegant design", "curved legs", "fabric"],
    
    "rocking chair": ["curved runners", "sloping back", "rocking motion", "wooden frame", "comfortable seat", 
                      "traditional design", "armrests", "relaxing", "classic", "country-style"],
    
    "ball chair": ["sphere-shaped", "modern", "futuristic", "enclosed space", "swivel base", 
                   "comfortable cushion", "unique design", "bubble chair", "acrylic material", "iconic"]
}

In [None]:
# The 20 ScanNet class names; the first 20 label ids refer to these.
SCANNET_LABELS_20 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 
                     'sofa', 'table', 'door', 'window', 'bookshelf', 
                     'picture','counter', 'desk', 'curtain', 'refrigerator', 
                     'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture']

# Base classes followed by the chair subtypes keyed in `descriptors`
# (ids 20 .. 23).
SCANNET_LABELS_AUG = SCANNET_LABELS_20 + list(descriptors.keys())

# confusion_matrix ignores ground truth equal to UNKNOWN_ID and treats
# predictions equal to NO_FEATURE_ID as "no feature".
UNKNOWN_ID = 255
NO_FEATURE_ID = 256

# Defaults used by the similarity and evaluation cells below.
agg_type = 'mean'
feature_type = 'fused'

In [None]:
# Visual check: heatmaps for a class-name query and descriptor queries, then
# the per-point difference between them.
feature_type = 'ensembled'
obj_class = 'ball chair'

# Heatmap for the plain class name.
sim_class = highlight_query(['armchair'], feature_type, 'mean', distilled_f, fused_f, filtered_pc, 
                            filtered_pc_c, device, draw=True, scale=2.)

# One heatmap per sentence descriptor of "armchair".
# NOTE(review): this loop reassigns sim_class, so draw_improvement below
# compares against the last armchair descriptor, not the 'armchair' query.
# The first query is also 'armchair' while obj_class is 'ball chair'.
# Confirm which comparison is intended.
for d in descriptors['armchair']:
    sim_class = highlight_query([d], feature_type, 'mean', distilled_f, fused_f, filtered_pc, 
                                filtered_pc_c, device, draw=True, scale=2.)

#sim_desc  = highlight_query(descriptors['armchair'], feature_type, 'mean', 
#                            distilled_f, fused_f, filtered_pc, filtered_pc_c, device, draw=True, scale=2.)

# All bare descriptors of obj_class, combined as their 0.4-quantile per
# point and mean-centered (norm=True).
sim_desc  = highlight_query(descriptors_bare[obj_class], feature_type, 'quantile', 
                            distilled_f, fused_f, filtered_pc, filtered_pc_c, device, draw=True, scale=5.,
                            quantile=0.4, norm=True)

#sim_desc  = highlight_query([obj_class] + descriptors_bare[obj_class], feature_type, 'bare-overlay-mean', 
#                            distilled_f, fused_f, filtered_pc, filtered_pc_c, device, draw=True, scale=2.)

# Heatmap of sim_desc - sim_class.
draw_improvement(sim_class, sim_desc, filtered_pc)

In [None]:
# Caches of per-query similarities and of each descriptor's difference to its
# class-name query. They are created only if missing, so re-running this cell
# keeps earlier results, while a fresh kernel no longer hits NameError.
if "similarity_results" not in globals():
    similarity_results = {}
if "difference_results" not in globals():
    difference_results = {}

In [None]:
# Cache the similarity of every class-name query (base classes plus chair
# subtypes), skipping entries that are already cached.
for class_name in SCANNET_LABELS_AUG:
    if class_name not in similarity_results.keys():
        similarity_results[class_name] = highlight_query([class_name], feature_type, agg_type, 
                                                         distilled_f, fused_f, filtered_pc, 
                                                         filtered_pc_c, device, draw=False)

# For each sentence descriptor, cache its similarity and its per-point
# difference to the similarity of its class-name query.
for class_name, descriptor_list in descriptors.items():
    for descriptor in descriptor_list:
        if descriptor not in similarity_results.keys():
            similarity = highlight_query([descriptor], feature_type, agg_type, 
                                         distilled_f, fused_f, filtered_pc, 
                                         filtered_pc_c, device, draw=False)
            similarity_results[descriptor] = similarity
            difference_results[descriptor] = similarity - similarity_results[class_name]

In [None]:
# Caches for the bare-descriptor experiment. They are created only if
# missing, so re-running keeps earlier results, while a fresh kernel no
# longer hits NameError.
if "similarity_results_bare" not in globals():
    similarity_results_bare = {}
if "difference_results_bare" not in globals():
    difference_results_bare = {}

In [None]:
# Reuse the class-name similarities already computed in similarity_results.
for class_name in SCANNET_LABELS_AUG:
    if class_name not in similarity_results_bare.keys():
        similarity_results_bare[class_name] = similarity_results[class_name]

# For each bare descriptor, cache its similarity and its per-point difference
# to the similarity of its class-name query.
for class_name, descriptor_list in descriptors_bare.items():
    for descriptor in descriptor_list:
        if descriptor not in similarity_results_bare.keys():
            similarity = highlight_query([descriptor], feature_type, agg_type, 
                                         distilled_f, fused_f, filtered_pc, 
                                         filtered_pc_c, device, draw=False)
            similarity_results_bare[descriptor] = similarity
            difference_results_bare[descriptor] = similarity - similarity_results_bare[class_name]

In [None]:
# For each chair subtype, plot how much higher the average similarity is on
# the subtype's own points than elsewhere.
# - Blue bars: compared with all other points.
# - Orange bars: compared with points of the other appended chair subtypes.
# There is one bar pair for the class-name query and one per sentence
# descriptor, sorted ascending by the orange value.
for obj_class, descs_class in descriptors.items():
    gt_label = SCANNET_LABELS_AUG.index(obj_class)
    # Points of the *other* appended subtypes: label != gt_label and
    # 20 <= label < len(SCANNET_LABELS_AUG).
    other_chair_mask = np.all([filtered_pc_labels != gt_label, 
                               20 <= filtered_pc_labels, 
                               filtered_pc_labels < len(SCANNET_LABELS_AUG)], axis=0)
    
    y_labels = []
    x_similarity_other_general = []
    x_difference_other_general = []
    
    x_similarity_other_chair = []
    x_difference_other_chair = []
    
    
    # Class-name query.
    sim = similarity_results[obj_class]
        
    avg_similarity_class = sim[filtered_pc_labels == gt_label].mean().cpu()
    avg_similarity_other_chair = sim[other_chair_mask].mean().cpu()
    avg_similarity_other_general = sim[filtered_pc_labels != gt_label].mean().cpu()

    y_labels.append(obj_class)

    x_similarity_other_general.append(avg_similarity_class-avg_similarity_other_general)
    x_similarity_other_chair.append(avg_similarity_class-avg_similarity_other_chair)
    
    # Sentence descriptors of this subtype.
    for desc in descs_class:
        sim = similarity_results[desc]
        
        avg_similarity_class = sim[filtered_pc_labels == gt_label].mean().cpu()
        avg_similarity_other_chair = sim[other_chair_mask].mean().cpu()
        avg_similarity_other_general = sim[filtered_pc_labels != gt_label].mean().cpu()
        
        # NOTE(review): the x_difference_* values below are collected but
        # never plotted.
        diff = difference_results[desc]
        
        avg_difference_class = diff[filtered_pc_labels == gt_label].mean().cpu()
        avg_difference_other_chair = diff[other_chair_mask].mean().cpu()
        avg_difference_other_general = diff[filtered_pc_labels != gt_label].mean().cpu()
        
        y_labels.append(desc)
        
        x_similarity_other_general.append(avg_similarity_class-avg_similarity_other_general)
        x_difference_other_general.append(avg_difference_class-avg_difference_other_general)
        x_similarity_other_chair.append(avg_similarity_class-avg_similarity_other_chair)
        x_difference_other_chair.append(avg_difference_class-avg_difference_other_chair)

    sorted_indices = sorted(range(len(y_labels)), key=lambda i: x_similarity_other_chair[i])
        
    plt.figure(figsize=(8,4))
    Y_axis = np.arange(len(y_labels))
    plt.barh(Y_axis-0.2, [x_similarity_other_general[i] for i in sorted_indices], 0.4, color='blue',
             label='to other objects')
    plt.barh(Y_axis+0.2, [x_similarity_other_chair[i] for i in sorted_indices], 0.4, color='orange',
             label='to other chairs')
    plt.yticks(Y_axis, [y_labels[i] for i in sorted_indices])
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1))
    plt.title(f'Average similarity difference of \'{obj_class}\'')


print('Done!')

In [None]:
# Same plot as the previous cell, but for the bare descriptors
# (similarity_results_bare / difference_results_bare).
for obj_class, descs_class in descriptors_bare.items():
    gt_label = SCANNET_LABELS_AUG.index(obj_class)
    # Points of the *other* appended subtypes: label != gt_label and
    # 20 <= label < len(SCANNET_LABELS_AUG).
    other_chair_mask = np.all([filtered_pc_labels != gt_label, 
                               20 <= filtered_pc_labels, 
                               filtered_pc_labels < len(SCANNET_LABELS_AUG)], axis=0)
    
    y_labels = []
    x_similarity_other_general = []
    x_difference_other_general = []
    
    x_similarity_other_chair = []
    x_difference_other_chair = []
    
    
    # Class-name query.
    sim = similarity_results_bare[obj_class]
        
    avg_similarity_class = sim[filtered_pc_labels == gt_label].mean().cpu()
    avg_similarity_other_chair = sim[other_chair_mask].mean().cpu()
    avg_similarity_other_general = sim[filtered_pc_labels != gt_label].mean().cpu()

    y_labels.append(obj_class)

    x_similarity_other_general.append(avg_similarity_class-avg_similarity_other_general)
    x_similarity_other_chair.append(avg_similarity_class-avg_similarity_other_chair)
    
    # Bare descriptors of this subtype.
    for desc in descs_class:
        sim = similarity_results_bare[desc]
        
        avg_similarity_class = sim[filtered_pc_labels == gt_label].mean().cpu()
        avg_similarity_other_chair = sim[other_chair_mask].mean().cpu()
        avg_similarity_other_general = sim[filtered_pc_labels != gt_label].mean().cpu()
        
        # NOTE(review): the x_difference_* values below are collected but
        # never plotted.
        diff = difference_results_bare[desc]
        
        avg_difference_class = diff[filtered_pc_labels == gt_label].mean().cpu()
        avg_difference_other_chair = diff[other_chair_mask].mean().cpu()
        avg_difference_other_general = diff[filtered_pc_labels != gt_label].mean().cpu()
        
        y_labels.append(desc)
        
        x_similarity_other_general.append(avg_similarity_class-avg_similarity_other_general)
        x_difference_other_general.append(avg_difference_class-avg_difference_other_general)
        x_similarity_other_chair.append(avg_similarity_class-avg_similarity_other_chair)
        x_difference_other_chair.append(avg_difference_class-avg_difference_other_chair)

    sorted_indices = sorted(range(len(y_labels)), key=lambda i: x_similarity_other_chair[i])
        
    plt.figure(figsize=(8,4))
    Y_axis = np.arange(len(y_labels))
    plt.barh(Y_axis-0.2, [x_similarity_other_general[i] for i in sorted_indices], 0.4, 
             label='to other objects', color='blue')
    plt.barh(Y_axis+0.2, [x_similarity_other_chair[i] for i in sorted_indices], 0.4, 
             label='to other chairs', color='orange')
    plt.yticks(Y_axis, [y_labels[i] for i in sorted_indices])
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1))
    plt.title(f'Average similarity difference of \'{obj_class}\'')


print('Done!')

In [None]:
# Evaluate every descriptor set and show its confusion matrix and a
# per-class table.
# NOTE(review): `various_descriptors` is not defined in the visible cells. It
# presumably is a list of {class name: descriptor list} dicts; define it
# before running this cell.
for _descriptors in various_descriptors:

    # The 20 ScanNet names followed by one descriptor list per extra class.
    SCANNET_LABELS_DESC = SCANNET_LABELS_20 + list(_descriptors.values())
    # Display names in the same order as SCANNET_LABELS_DESC.
    tick_labels = SCANNET_LABELS_20 + list(_descriptors.keys())

    print(SCANNET_LABELS_DESC)
    print(f'Aggregation: {agg_type}')
    print(f'Features: {feature_type}')

    class_ious, class_accs, mean_iou, mean_acc, confusion = \
        evaluate(SCANNET_LABELS_DESC, _descriptors, feature_type, agg_type, distilled_f, fused_f, filtered_pc_labels)

    print(f'Mean acc: {mean_acc}')

    # confusion is indexed [pred, gt]. Normalize each ground-truth column to
    # sum to 1; empty columns stay 0.
    col_sums = confusion.sum(axis=0)
    col_sums[col_sums == 0] = 1
    confusion = confusion / col_sums[np.newaxis, :]

    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(confusion)
    fig.colorbar(cax)

    # Rows are predicted classes, labeled with this descriptor set's names.
    # SCANNET_LABELS_AUG only matched when _descriptors had the same keys as
    # `descriptors`.
    ax.set_yticks(np.arange(len(tick_labels)), labels=tick_labels)

    plt.show()

    print_results_table(SCANNET_LABELS_DESC, class_ious, _descriptors)

In [None]:
# Table for the 20 base classes only, using class_ious from the last
# descriptor set evaluated above.
print_results_table(SCANNET_LABELS_20, class_ious, descriptors)

# experiment with different descriptor combinations

In [None]:
# Generate descriptors with the OpenAI completion API. _nr is the number of
# descriptors requested per class. descriptors_from_prompt wraps each one
# as "a bird which has ...".
# NOTE: this rebinds `descriptors`, replacing the chair-subtype dict above.
_nr = 10
_prompt = f'Generate {str(_nr)} visual descriptors for each of the following categories, they are bird species: [Blue-faced Honeyeater, Diamond Firetail, Mouse-colored Tyrannulet]. The descriptors will be used for input queries for a CLIP model. The descriptors should be concise and distinct from the descriptors of the other classes. Do not focus on behavior, but purely on attributes which are recognizable by the CLIP model. The output should be in the following form as a string: *class*: *descriptor1*, *descriptor2*, etc.'
descriptors = descriptors_from_prompt(_prompt, verbose = False)

In [None]:
# Check that the retrieved descriptors have the expected {class: [descriptors]} shape.
descriptors

In [None]:
# Same 20 ScanNet labels and special ids as defined earlier, redefined here
# so this section can run on its own.
SCANNET_LABELS_20 = ['wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa',
                     'table', 'door', 'window', 'bookshelf', 'picture','counter', 'desk', 'curtain', 'refrigerator', 'shower curtain',
                     'toilet', 'sink', 'bathtub', 'otherfurniture']
UNKNOWN_ID = 255
NO_FEATURE_ID = 256

In [None]:
# All combinations of 5 out of the _nr descriptors, paired across classes by index.
comb_dict_list = combinations_descriptor(descriptors, 5)

In [None]:
# Evaluate every combination (one evaluate() call, and one CLIP load, per combination).
class_IoU_result_list, class_accs_result_list, mean_iou_result_list, mean_acc_result_list = try_diff_combs(SCANNET_LABELS_20, comb_dict_list, "fused", "mean", distilled_f, fused_f, filtered_pc_labels)

In [None]:
# Entry i of each result list corresponds to the descriptors in comb_dict_list[i].
class_IoU_result_list[0]

In [None]:
c1, c2, c3 = [],[],[], # IoU lists for: Blue-faced Honeyeater, Diamond Firetail, Mouse-colored Tyrannulet

# Collect the IoU, tp / (tp + fp + fn), of each augmented class for every combination.
for idx in range(len(class_IoU_result_list)):
    c1.append(class_IoU_result_list[idx]["Blue-faced Honeyeater"][0])
    c2.append(class_IoU_result_list[idx]["Diamond Firetail"][0])
    c3.append(class_IoU_result_list[idx]["Mouse-colored Tyrannulet"][0])

In [None]:
max(c3)

In [None]:
# The 5 descriptors that give the highest class IoU for Mouse-colored Tyrannulet.
comb_dict_list[c3.index(max(c3))]['Mouse-colored Tyrannulet']

# ---------------------------------------------------------------------

In [None]:
# should be the preprocessed file path
sample_path_0 = "/mnt/project/AT3DCV_Data/Preprocessed_OpenScene/data/augmented/birds/scannet_3d/example/scene0000_00_vh_clean_2.pth"
#sample_path_1 = "D:/AT3DCV_Data/Preprocessed_OpenScene/data/scannet_3d/train/scene0000_01_vh_clean_2.pth"
#sample_path_2 = "D:/AT3DCV_Data/Preprocessed_OpenScene/data/scannet_3d/train/scene0000_02_vh_clean_2.pth"

In [None]:
sample_0 = torch.load(sample_path_0) # coords,colors,labels
#sample_1 = torch.load(sample_path_1) # coords,colors,labels
#sample_2 = torch.load(sample_path_2) # coords,colors,labels

In [None]:
len(sample_0[0])

In [None]:
# Work on a single partial scan. The partial scans of one scene do not overlap
# perfectly; to aggregate several of them, concatenate their arrays instead, e.g.
#   sample_points = np.concatenate((sample_0[0], sample_1[0], sample_2[0]))
#   sample_colors = np.concatenate((sample_0[1], sample_1[1], sample_2[1]))
sample_points, sample_colors, sample_labels = sample_0[0], sample_0[1], sample_0[2]

In [None]:
# View the original scene with its own colors. To paint it a uniform grey
# instead, use: pcd.paint_uniform_color(np.asarray([200, 200, 200]) / 255.0)
pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(np.asarray(sample_points)))
pcd.colors = o3d.utility.Vector3dVector(np.asarray(sample_colors))
o3d.visualization.draw_geometries([pcd])

# load fused features

In [None]:
# Path of the fused per-point feature file for the same scene (scene0000_00).
feature_path = "/mnt/project/AT3DCV_Data/Preprocessed_OpenScene/data/augmented/birds/fused/scene0000_00_0.pt"

In [None]:
# Fused features; the loaded dict is read below via its "feat" and "mask_full" keys.
feature = torch.load(feature_path)

In [None]:
# Shape of the point mask (it selects the points that have a fused feature row).
feature["mask_full"].shape

In [None]:
# Shape of the fused feature matrix (one row per masked point, presumably).
feature["feat"].shape

In [None]:
# Indices of the points where feature["mask_full"] is True.
# squeeze(1) instead of squeeze(): nonzero() returns (K, 1), and a bare
# squeeze() would collapse it to a 0-d tensor when exactly one point is
# selected, breaking the `[indices, :]` indexing below.
indices = torch.nonzero(feature["mask_full"]).squeeze(1)

In [None]:
# Keep only the masked points; colors and labels stay row-aligned with them.
filtered_point_cloud = sample_points[indices, :]
filtered_point_cloud_colors = sample_colors[indices, :]
filtered_point_cloud_labels = sample_labels[indices]
# Default ground truth: the raw labels (a later cell remaps ids 21/22).
gt_ids = filtered_point_cloud_labels

In [None]:
# Label ids present among the masked points.
np.unique(filtered_point_cloud_labels)

In [None]:
# Fold label ids 21 and 22 into 20 if necessary (to skip the remap, use
# gt_ids = filtered_point_cloud_labels instead).
gt_ids = np.where(
    np.isin(filtered_point_cloud_labels, (21.0, 22.0)),
    20.0,
    filtered_point_cloud_labels,
)

In [None]:
# Label ids remaining after the remap.
np.unique(gt_ids)

In [None]:
# Distinct ground-truth ids and the number of points carrying each of them.
unique_values, counts = np.unique(gt_ids, return_counts=True)

In [None]:
# Points per ground-truth id (aligned with unique_values).
counts

In [None]:
# Shape of the masked point cloud.
filtered_point_cloud.shape

# using clip model

In [None]:
import clip
# CLIP ViT-L/14 at 336px; `model.encode_text` is used below to embed the queries.
model, preprocess = clip.load("ViT-L/14@336px")

In [None]:
# Highlight with a threshold: every point whose similarity to the query exceeds
# `threshold_percentage` times the best similarity in the scene is painted red.
# Type the query here.
query = ["dragon"]

with torch.no_grad():
    all_text_embeddings = []
    for category in tqdm(query):
        # .to(device) instead of .cuda() so the cell also runs on CPU.
        texts = clip.tokenize(category).to(device)
        text_embeddings = model.encode_text(texts)  # embed with the text encoder
        text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
        # Average the normalized prompt embeddings into one unit vector.
        text_embedding = text_embeddings.mean(dim=0)
        text_embedding /= text_embedding.norm()
        all_text_embeddings.append(text_embedding)

    # One column per query: (embedding_dim, n_queries).
    all_text_embeddings = torch.stack(all_text_embeddings, dim=1)

# L2-normalize the fused features (the epsilon guards all-zero rows).
fused_f = (feature["feat"] / (feature["feat"].norm(dim=-1, keepdim=True) + 1e-5)).half()
# (n_points, n_queries) cosine similarities. The fp16 features are cast to the
# text embeddings' device/dtype, which differ e.g. when CLIP runs on CPU (fp32).
similarity_matrix = fused_f.to(
    device=all_text_embeddings.device, dtype=all_text_embeddings.dtype
) @ all_text_embeddings

# Set higher to increase the certainty (not always correct).
# NOTE: assumes the best similarity is positive; otherwise nothing is selected.
threshold_percentage = 0.9
cap = similarity_matrix.max().item()
# Rows where any query clears the threshold. Unlike nonzero().squeeze().T[0],
# this is always a 1-D index tensor (also for a single match) and never lists
# a point twice when several queries match it.
found_indices = torch.nonzero(
    (similarity_matrix > cap * threshold_percentage).any(dim=1), as_tuple=True
)[0]

# Build the point cloud of the masked points.
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud))
pcd.colors = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud_colors))

found_region = pcd.select_by_index(found_indices.tolist())
found_region.paint_uniform_color([1.0, 0, 0])  # paint matching points red
rest = pcd.select_by_index(found_indices.tolist(), invert=True)
o3d.visualization.draw_geometries([rest, found_region])

In [None]:
# Highlight with a heatmap: every point is colored by its similarity to the
# query, and the heatmap copy is drawn next to the original scene.
# A query entry is either one prompt or a list of descriptor prompts whose
# CLIP embeddings are averaged. Other queries that were tried:
#   ["deathwing"], ["bird"], [["Mouse-colored Tyrannulet bird"]],
#   [" a blue-faced, yellow-crowned, white-breasted, black-eyed, long-billed, hooked-beak, yellow-beaked, yellow-breasted, yellow-throated and black-tailed bird"]

# Mouse-colored Tyrannulet
mouse_colored_tyrannulet_query = [["grey-bodied", "yellow-breasted", "black-crowned",
                                   "white-eyed", "black-winged", "yellow-throated",
                                   "white-breasted", "yellow-billed", "grey-headed",
                                   "long-tailed", "bird"]]

# Diamond Firetail
diamond_firetail_query = [["red-breasted", "black-crowned", "gold-winged",
                           "black-winged", "white-eyed", "yellow-billed",
                           "red-headed", "black-tailed", "long-tailed",
                           "white-breasted", "bird"]]

# Select the active query here.
query = diamond_firetail_query

with torch.no_grad():
    all_text_embeddings = []
    for category in tqdm(query):
        # .to(device) instead of .cuda() so the cell also runs on CPU.
        texts = clip.tokenize(category).to(device)
        text_embeddings = model.encode_text(texts)  # embed with the text encoder
        text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
        # Average the normalized descriptor embeddings into one unit vector.
        text_embedding = text_embeddings.mean(dim=0)
        text_embedding /= text_embedding.norm()
        all_text_embeddings.append(text_embedding)

    # One column per query: (embedding_dim, n_queries).
    all_text_embeddings = torch.stack(all_text_embeddings, dim=1)

# L2-normalize the fused features (the epsilon guards all-zero rows).
fused_f = (feature["feat"] / (feature["feat"].norm(dim=-1, keepdim=True) + 1e-5)).half()
# (n_points, n_queries) cosine similarities; features are cast to the text
# embeddings' device/dtype (they differ e.g. when CLIP runs on CPU in fp32).
similarity_matrix = fused_f.to(
    device=all_text_embeddings.device, dtype=all_text_embeddings.dtype
) @ all_text_embeddings

# Build the point cloud of the masked points.
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud))
pcd.colors = o3d.utility.Vector3dVector(np.asarray(filtered_point_cloud_colors))

# Heatmap colormap.
cmap = plt.get_cmap('cividis')

# Min-max normalize to [0, 1] in fp32; the clamp avoids 0/0 when all
# similarities are equal (1e-8 would underflow to 0 in fp16).
sims = similarity_matrix.float()
sim_min, sim_max = sims.min(), sims.max()
normalized_tensor = (sims - sim_min) / (sim_max - sim_min).clamp_min(1e-8)

# Color by the first query. Column indexing (instead of squeeze) keeps a 1-D
# array for a single point and does not break when several queries are given.
colors = cmap(normalized_tensor[:, 0].detach().cpu().numpy())

# Heatmap copy of the scene, shifted by 10 along the second axis so it sits
# beside the original.
pcd_heatmap = o3d.geometry.PointCloud()
pcd_heatmap.points = o3d.utility.Vector3dVector(np.asarray(pcd.points) + [0, 10, 0])
pcd_heatmap.colors = o3d.utility.Vector3dVector(colors[:, :3])  # drop alpha

o3d.visualization.draw_geometries([pcd, pcd_heatmap])

# mIoU evaluation

In [None]:
# ScanNet's 20 class names plus the current query as one extra class. query[0]
# may be a descriptor list rather than a name; to evaluate a plain class name,
# use "bird" in its place.
SCANNET_LABELS_20 = [
    'wall', 'floor', 'cabinet', 'bed', 'chair', 'sofa',
    'table', 'door', 'window', 'bookshelf', 'picture', 'counter', 'desk',
    'curtain', 'refrigerator', 'shower curtain',
    'toilet', 'sink', 'bathtub', 'otherfurniture',
    query[0],  # extra class taken from the current query
]
UNKNOWN_ID = 255
NO_FEATURE_ID = 256

# Same list object, under the name used by the IoU printing cell.
CLASS_LABELS = SCANNET_LABELS_20

In [None]:
# Embed every class of the label set with CLIP. A class given as a list of
# descriptor prompts is embedded prompt by prompt and the normalized
# embeddings are averaged into one unit vector.
with torch.no_grad():
    label_embeds = []
    for category in tqdm(SCANNET_LABELS_20):
        # .to(device) instead of .cuda() so the cell also runs on CPU.
        texts = clip.tokenize(category).to(device)
        text_embeddings = model.encode_text(texts)  # embed with the text encoder
        text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
        text_embedding = text_embeddings.mean(dim=0)
        text_embedding /= text_embedding.norm()
        label_embeds.append(text_embedding)

    # One column per class: (embedding_dim, n_classes).
    label_embeds = torch.stack(label_embeds, dim=1)


In [None]:
# Print a per-class IoU table: IoU followed by the two counts stored in
# class_ious[label] (elements [1] and [2]).
print('classes          IoU')
print('----------------------------')
for label_name in CLASS_LABELS[:N_CLASSES]:
    # Non-string labels (descriptor lists) are reported under target_label.
    if not isinstance(label_name, str):
        label_name = target_label
    try:
        stats = class_ious[label_name]
        print('{0:<14s}: {1:>5.5f}   ({2:>6d}/{3:<6d})'.format(
                label_name,
                stats[0],
                stats[1],
                stats[2]))
    except (KeyError, IndexError, TypeError, ValueError):
        # Label missing from class_ious, or values not matching the format
        # (e.g. non-integer counts). The f-string also works for non-str labels.
        print(f'{label_name} error!')

In [None]:
import openai

# Never commit API keys: read the key from the environment. The key that used
# to be hard-coded in this cell has been exposed and should be revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")
if not openai.api_key:
    raise RuntimeError("Set the OPENAI_API_KEY environment variable to query the OpenAI API.")

# Ask the LLM for visual descriptors of the augmented bird classes. The class
# names must match the ones used in the evaluation ("Blue-faced Honeyeater"
# was previously misspelled "Honeyeate").
prompt = (
    "Could you generate 5 visual descriptors for each of the following object classes, "
    "they are bird species: [Blue-faced Honeyeater, Diamond Firetail, Mouse-colored Tyrannulet]. "
    "The descriptors will be used for input queries for a CLIP model. "
    "The descriptors should be concise and distinct from one another. "
    "Do not focus on behavior, but purely on attributes which are recognizable by the CLIP model. "
    "The output should be in the following form, without any additional text: "
    "object class 1, visual descriptor 1.1, visual descriptor 1.2"
)

response = openai.Completion.create(
    engine="text-davinci-003",
    prompt=prompt,
    temperature=0.5,
    max_tokens=200,
)

In [None]:
# Old version of the aggregated-text-embedding highlighting (originally noted as
# "not properly working"): each query entry is embedded with CLIP and the
# embeddings of its prompts are averaged.
def highlight_query(query, feature_type, model, distill, fused, fpc, fpcc, device):
    """Show a point cloud next to a heatmap of its similarity to a text query.

    Args:
        query: list of queries; each item is passed to ``clip.tokenize`` (a
            prompt string or a list of prompts whose embeddings are averaged).
        feature_type: "fused", "distilled" or "ensembled" — which per-point
            features are compared with the query.
        model: CLIP model providing ``encode_text``.
        distill: distilled per-point features, one row per point.
        fused: fused per-point features, row-aligned with ``distill``.
        fpc: point coordinates, row-aligned with the features (presumably (N, 3)).
        fpcc: point colors, row-aligned with ``fpc``.
        device: device the tokenized prompts are moved to.

    Raises:
        ValueError: if ``feature_type`` is not one of the supported values.
    """
    with torch.no_grad():
        all_text_embeddings = []
        for category in tqdm(query):
            texts = clip.tokenize(category).to(device)  # tokenize
            text_embeddings = model.encode_text(texts)  # embed with text encoder
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            text_embedding = text_embeddings.mean(dim=0)
            text_embedding /= text_embedding.norm()
            all_text_embeddings.append(text_embedding)

        # One column per query: (embedding_dim, n_queries).
        all_text_embeddings = torch.stack(all_text_embeddings, dim=1)

    # Features must match the embeddings' device and dtype for the matmul
    # (e.g. fp16 features vs. fp32 embeddings when CLIP runs on CPU).
    target = dict(device=all_text_embeddings.device, dtype=all_text_embeddings.dtype)
    if feature_type == "fused":
        similarity_matrix = fused.to(**target) @ all_text_embeddings
    elif feature_type == "distilled":
        similarity_matrix = distill.to(**target) @ all_text_embeddings
    elif feature_type == "ensembled":
        fused_t = fused.to(**target)
        distill_t = distill.to(**target)
        pred_fusion = fused_t @ all_text_embeddings
        pred_distill = distill_t @ all_text_embeddings
        # Per point, take the fused feature where its best score beats the
        # distilled one (this previously read the global `fused_f` and mixed
        # a device mask with a CPU tensor).
        mask_ = pred_distill.max(dim=-1)[0] < pred_fusion.max(dim=-1)[0]
        feat_ensemble = distill_t.clone()
        feat_ensemble[mask_] = fused_t[mask_]
        similarity_matrix = feat_ensemble @ all_text_embeddings
    else:
        raise ValueError(
            f"feature_type must be 'fused', 'distilled' or 'ensembled', got {feature_type!r}")

    print(similarity_matrix.shape)
    # Build the point cloud.
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.asarray(fpc))
    pcd.colors = o3d.utility.Vector3dVector(np.asarray(fpcc))

    # Min-max normalize the first query's similarities to [0, 1] in fp32; the
    # clamp avoids 0/0 when all similarities are equal.
    sims = similarity_matrix[:, 0].float()
    sim_min = sims.min()
    normalized = (sims - sim_min) / (sims.max() - sim_min).clamp_min(1e-8)
    colors = plt.get_cmap('cividis')(normalized.detach().cpu().numpy())

    # Heatmap copy of the cloud, shifted by 10 along the second axis.
    pcd_heatmap = o3d.geometry.PointCloud()
    pcd_heatmap.points = o3d.utility.Vector3dVector(np.asarray(pcd.points) + [0, 10, 0])
    pcd_heatmap.colors = o3d.utility.Vector3dVector(colors[:, :3])  # drop alpha

    o3d.visualization.draw_geometries([pcd, pcd_heatmap])
    
def evaluate(labelset, descriptors, feature_type, model, distill, fused, gt_ids):
    """Zero-shot label every point with CLIP text embeddings and score the result.

    Each entry of ``labelset`` is embedded (a list of descriptor prompts is
    averaged into one unit vector), every point gets the label with the
    highest cosine similarity, and the prediction is compared with ``gt_ids``.

    Args:
        labelset: class entries; a non-string entry (descriptor list) is
            reported under the key of ``descriptors`` whose value equals it.
        descriptors: mapping class name -> descriptor list; only used to name
            non-string entries of ``labelset``.
        feature_type: "fused", "distilled" or "ensembled".
        model: CLIP model providing ``encode_text``.
        distill: distilled per-point features.
        fused: fused per-point features, row-aligned with ``distill``.
        gt_ids: ground-truth class index (into ``labelset``) of every point.

    Returns:
        ``(class_ious, class_accs, mean_iou, mean_acc)``. The dicts are keyed by
        class name and contain only classes present in ``gt_ids``;
        ``class_ious`` values come from ``get_iou`` (element [0] is the IoU).
        The means are taken over those evaluated classes (0 if there are none).

    Raises:
        ValueError: if ``feature_type`` is not one of the supported values.
    """
    with torch.no_grad():
        label_embeds = []
        for category in tqdm(labelset):
            # .to(device) instead of .cuda() so evaluation also runs on CPU.
            texts = clip.tokenize(category).to(device)
            text_embeddings = model.encode_text(texts)  # embed with text encoder
            text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
            text_embedding = text_embeddings.mean(dim=0)
            text_embedding /= text_embedding.norm()
            label_embeds.append(text_embedding)

        # One column per label: (embedding_dim, n_labels).
        label_embeds = torch.stack(label_embeds, dim=1)

    # Features must match the embeddings' device and dtype for the matmul
    # (e.g. fp16 features vs. fp32 embeddings when CLIP runs on CPU).
    target = dict(device=label_embeds.device, dtype=label_embeds.dtype)
    if feature_type == "fused":
        similarity_matrix = fused.to(**target) @ label_embeds
    elif feature_type == "distilled":
        similarity_matrix = distill.to(**target) @ label_embeds
    elif feature_type == "ensembled":
        fused_t = fused.to(**target)
        distill_t = distill.to(**target)
        pred_fusion = fused_t @ label_embeds
        pred_distill = distill_t @ label_embeds
        # Per point, take the fused feature where its best label score beats
        # the distilled one (this previously read the global `fused_f`
        # instead of the `fused` argument).
        mask_ = pred_distill.max(dim=-1)[0] < pred_fusion.max(dim=-1)[0]
        feat_ensemble = distill_t.clone()
        feat_ensemble[mask_] = fused_t[mask_]
        similarity_matrix = feat_ensemble @ label_embeds
    else:
        raise ValueError(
            f"feature_type must be 'fused', 'distilled' or 'ensembled', got {feature_type!r}")

    # Predicted label of each point = index of its most similar label.
    pred_ids = similarity_matrix.max(dim=1).indices.detach().cpu()

    N_CLASSES = len(labelset)
    confusion = confusion_matrix(pred_ids, gt_ids, N_CLASSES)
    class_ious = {}
    class_accs = {}
    mean_iou = 0
    mean_acc = 0

    count = 0
    for i, label_name in enumerate(labelset):
        # Report a descriptor list under the class name it belongs to (the
        # last matching key wins if several classes share the same list).
        if not isinstance(label_name, str):
            for key, value in descriptors.items():
                if value == label_name:
                    label_name = key

        n_gt = (gt_ids == i).sum()
        if n_gt == 0:  # at least 1 point needs to be in the evaluation for this class
            continue

        class_ious[label_name] = get_iou(i, confusion)
        class_accs[label_name] = class_ious[label_name][1] / n_gt
        count += 1

        mean_iou += class_ious[label_name][0]
        mean_acc += class_accs[label_name]

    # Average over the evaluated classes only: dividing by N_CLASSES counted
    # every class absent from this scene as an IoU/accuracy of 0.
    if count:
        mean_iou /= count
        mean_acc /= count

    return class_ious, class_accs, mean_iou, mean_acc
