In [None]:
import os
import torch
import numpy as np
from glob import glob
from tqdm.notebook import tqdm
from os.path import join, exists
import open3d as o3d
import matplotlib.pyplot as plt
from itertools import combinations
import copy
from tabulate import tabulate
import ipywidgets as widgets
from IPython.display import display, clear_output

# Run CLIP and the feature matmuls on the GPU when one is available, else on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
def load_scene(path, visualize=True):
    """Load a preprocessed scene file and optionally show it in an Open3D window.

    The file is read with ``torch.load``; its entries 0, 1 and 2 are returned as the
    point coordinates, point colours and point labels (in that order).

    Args:
        path: path of the saved scene.
        visualize: when true, open an Open3D viewer with the coloured point cloud.

    Returns:
        ``(points, colors, labels)`` exactly as stored in the file.
    """
    scene = torch.load(path)
    points, colors, labels = scene[0], scene[1], scene[2]

    if visualize:
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(np.asarray(points))
        cloud.colors = o3d.utility.Vector3dVector(np.asarray(colors))
        o3d.visualization.draw_geometries([cloud])

    return points, colors, labels

def load_fused_features(path, sample_points, sample_colors, sample_labels):
    """Load fused per-point features and restrict the scene to the points that have one.

    The file is a dict with a ``"mask_full"`` boolean mask over all scene points and a
    ``"feat"`` tensor holding one feature row per selected point.

    Args:
        path: path of the saved fused-feature dict.
        sample_points, sample_colors, sample_labels: full-scene arrays/tensors indexed
            along their first axis.

    Returns:
        ``(fused_features, points, colors, labels, indices)`` where ``fused_features``
        are L2-normalised (with a 1e-5 epsilon) and ``indices`` is a 1-D tensor of the
        selected scene-point indices.
    """
    features = torch.load(path)
    # reshape(-1) rather than squeeze(): with exactly one selected point, squeeze()
    # would yield a 0-d index and drop the point dimension of every filtered array.
    indices = torch.nonzero(features["mask_full"]).reshape(-1)

    filtered_point_cloud = sample_points[indices, :]
    filtered_point_cloud_colors = sample_colors[indices, :]
    filtered_point_cloud_labels = sample_labels[indices]

    feats = features["feat"]
    fused_features = feats / (feats.norm(dim=-1, keepdim=True) + 1e-5)

    return fused_features, filtered_point_cloud, filtered_point_cloud_colors, filtered_point_cloud_labels, indices

def load_distilled_features(path, indices):
    """Load distilled per-point features, keep the rows in ``indices`` and L2-normalise.

    Args:
        path: ``.npy`` file with one feature row per scene point.
        indices: row indices to keep (e.g. the ones returned by ``load_fused_features``).

    Returns:
        float16 tensor of the selected rows, each divided by (its norm + 1e-5).
    """
    selected_rows = np.load(path)[indices, :]
    half_feats = torch.from_numpy(selected_rows).half()
    row_norms = half_feats.norm(dim=-1, keepdim=True)
    return half_feats / (row_norms + 1e-5)

def highlight_query(query, feature_type, agg_type, distill, fused, fpc, fpcc, device, scale=1,
                    quantile=0.5):
    """Render a per-point heatmap of CLIP text similarity for one or more prompts.

    Each prompt in ``query`` is encoded with the CLIP ViT-L/14@336px text encoder and
    L2-normalised. It is then dot-multiplied with the per-point features. With two or
    more prompts, the per-prompt scores are aggregated into one score per point. Scores
    are mapped to colours with the 'bwr' colormap as ``scale * score + 0.5``, and the
    coloured cloud is shown in an Open3D window.

    Args:
        query: list of prompt strings (a single string is treated as one prompt).
        feature_type: "fused", "distilled" or "ensembled". "ensembled" keeps, per point,
            whichever of the fused/distilled features has the higher best-prompt score.
        agg_type: "mean", "quantile" or "overlay"; only used with two or more prompts.
        distill: distilled per-point features, rows aligned with ``fpc``.
        fused: fused per-point features, rows aligned with ``fpc``.
        fpc: point coordinates of the rendered cloud.
        fpcc: point colours; accepted for call compatibility, not used by the heatmap.
        device: torch device for the CLIP model and the similarity computation.
        scale: slope of the score -> colormap mapping.
        quantile: quantile used when ``agg_type == "quantile"``.

    Returns:
        Tensor of shape (N, 1) with the aggregated similarity of every point.

    Raises:
        ValueError: empty ``query`` or unknown ``feature_type``.
        NotImplementedError: unknown ``agg_type`` when two or more prompts are given.
    """
    import clip

    if isinstance(query, str):
        # A bare string would otherwise be iterated character by character.
        query = [query]
    if len(query) == 0:
        raise ValueError("query must contain at least one prompt")

    model, _ = clip.load("ViT-L/14@336px", device=device)

    with torch.no_grad():
        prompt_embeds = []
        for prompt in tqdm(query):
            tokens = clip.tokenize(prompt).to(device)
            emb = model.encode_text(tokens)
            prompt_embeds.append(emb / emb.norm(dim=-1, keepdim=True))
        text_embeds = torch.cat(prompt_embeds, dim=0)  # one row per prompt

        def point_scores(feats):
            # Score of every point against every prompt -> (N, num_prompts).
            return feats @ text_embeds.to(feats.dtype).T

        if feature_type == "fused":
            similarity_matrix = point_scores(fused.to(device))
        elif feature_type == "distilled":
            similarity_matrix = point_scores(distill.to(device))
        elif feature_type == "ensembled":
            # Keep all tensors on `device` so the boolean mask can index them.
            fused_d = fused.to(device)
            distill_d = distill.to(device)
            use_fused = point_scores(distill_d).max(dim=-1)[0] < point_scores(fused_d).max(dim=-1)[0]
            feat_ensemble = distill_d.clone().half()
            feat_ensemble[use_fused] = fused_d[use_fused].to(feat_ensemble.dtype)
            similarity_matrix = point_scores(feat_ensemble)
        else:
            raise ValueError(f"Unknown feature_type: {feature_type!r}")

    if similarity_matrix.shape[1] == 1:
        # A single prompt needs no aggregation.
        agg_sim_mat = similarity_matrix[:, 0]
    elif agg_type == "mean":
        agg_sim_mat = torch.mean(similarity_matrix, dim=1)
    elif agg_type == "quantile":
        agg_sim_mat = torch.quantile(similarity_matrix.float(), quantile, dim=1)
    elif agg_type == "overlay":
        # First prompt's score plus the mean score of the remaining prompts.
        agg_sim_mat = similarity_matrix[:, 0] + similarity_matrix[:, 1:].mean(dim=1)
    else:
        raise NotImplementedError()

    agg_sim_mat = agg_sim_mat.reshape(-1, 1)

    print(f'Min: {torch.min(agg_sim_mat)}')
    print(f'Max: {torch.max(agg_sim_mat)}')

    # Fixed affine map into colormap space: a score of 0 lands on the colormap midpoint.
    # reshape(-1) keeps a 1-D array even for a single point, so colors stays (N, 4).
    heat = (scale * agg_sim_mat + 0.5).detach().float().cpu().numpy().reshape(-1)
    colors = plt.get_cmap('bwr')(heat)

    pcd_heatmap = o3d.geometry.PointCloud()
    pcd_heatmap.points = o3d.utility.Vector3dVector(np.asarray(fpc))
    pcd_heatmap.colors = o3d.utility.Vector3dVector(colors[:, :3])
    o3d.visualization.draw_geometries([pcd_heatmap])

    return agg_sim_mat

def classification(labelset, descriptors, feature_type, agg_type, distill, fused, gt_ids,
                   ignore_ids=None, quantile=0.5, device=None):
    """Zero-shot per-point classification against a label set using CLIP text prompts.

    Each ``labelset`` entry is either a class-name string (one prompt) or a sequence of
    descriptor strings. For a sequence, each descriptor is its own prompt, and the
    per-point scores are aggregated into one score with ``agg_type``. A point's
    predicted id is the index of its best-scoring ``labelset`` entry.

    Args:
        labelset: list of str and/or lists of str; its order defines the prediction ids.
        descriptors: only kept for call compatibility; the descriptor prompts are taken
            from the list entries of ``labelset``.
        feature_type: "fused", "distilled" or "ensembled". "ensembled" keeps, per point,
            whichever of the fused/distilled features has the higher best-prompt score.
        agg_type: "mean", "quantile" or "overlay"; applied to list entries only.
        distill: distilled per-point features.
        fused: fused per-point features with the same rows as ``distill``.
        gt_ids: unused; kept for call compatibility.
        ignore_ids: prediction ids whose scores are forced to -1 before the argmax.
        quantile: quantile used when ``agg_type == "quantile"``.
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        CPU tensor with one predicted ``labelset`` index per point.

    Raises:
        ValueError: unknown ``feature_type``.
        NotImplementedError: unknown ``agg_type`` when ``labelset`` has a list entry.
    """
    import clip

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, _ = clip.load("ViT-L/14@336px", device=device)

    # Encode every prompt and remember which columns belong to which labelset entry,
    # so no fixed label count (e.g. 20 for ScanNet) has to be assumed.
    prompt_embeds = []
    entry_columns = []  # (first_column, end_column, is_single_name) per labelset entry
    with torch.no_grad():
        for category in labelset:
            is_name = isinstance(category, str)
            prompts = [category] if is_name else category
            first = len(prompt_embeds)
            for prompt in prompts:
                emb = model.encode_text(clip.tokenize(prompt).to(device))
                prompt_embeds.append(emb / emb.norm(dim=-1, keepdim=True))
            entry_columns.append((first, len(prompt_embeds), is_name))
        label_embeds = torch.cat(prompt_embeds, dim=0)  # one row per prompt

        def point_scores(feats):
            # Score of every point against every prompt -> (N, num_prompts).
            return feats @ label_embeds.to(feats.dtype).T

        if feature_type == "fused":
            similarity_matrix = point_scores(fused.to(device))
        elif feature_type == "distilled":
            similarity_matrix = point_scores(distill.to(device))
        elif feature_type == "ensembled":
            # Keep all tensors on `device` so the boolean mask can index them.
            fused_d = fused.to(device)
            distill_d = distill.to(device)
            use_fused = point_scores(distill_d).max(dim=-1)[0] < point_scores(fused_d).max(dim=-1)[0]
            feat_ensemble = distill_d.clone().half()
            feat_ensemble[use_fused] = fused_d[use_fused].to(feat_ensemble.dtype)
            similarity_matrix = point_scores(feat_ensemble)
        else:
            raise ValueError(f"Unknown feature_type: {feature_type!r}")

    def aggregate(cols):
        # Collapse the descriptor columns of one labelset entry into one score per point.
        if agg_type == "mean":
            return torch.mean(cols, dim=1)
        if agg_type == "quantile":
            return torch.quantile(cols.float(), quantile, dim=1)
        if agg_type == "overlay":
            return cols[:, 0] + cols[:, 1:].mean(dim=1)
        raise NotImplementedError()

    per_entry = []
    for first, end, is_name in entry_columns:
        cols = similarity_matrix[:, first:end]
        per_entry.append(cols[:, 0] if is_name else aggregate(cols))
    # Cast to float32 so half-precision and quantile (float32) scores stack cleanly.
    agg_sim_mat = torch.stack([scores.float() for scores in per_entry], dim=1)

    if ignore_ids:
        agg_sim_mat[:, list(ignore_ids)] = -1

    # get the predictions
    pred_ids = torch.max(agg_sim_mat, dim=1)[1].detach().cpu()

    return pred_ids

In [None]:
# load all the required data
# Root of the feature export. Override with the OPENSEG_AUG_ROOT environment variable,
# e.g. the earlier "/home/aleks/3dcv/openseg_aug/chair_scene_alex" export, which uses
# the same sub-directory layout. The default reproduces the paths used so far.
DATA_ROOT = os.environ.get("OPENSEG_AUG_ROOT", "/home/aleks/3dcv/openseg_aug_new")
SCENE_ID = "scene0000_00"

source_path = join(DATA_ROOT, "scannet_3d", "example", f"{SCENE_ID}_vh_clean_2.pth")
fused_path = join(DATA_ROOT, "fused", f"{SCENE_ID}_0.pt")
distilled_path = join(DATA_ROOT, "features_3D", f"{SCENE_ID}_vh_clean_2_openscene_feat_distill.npy")

# Fail early, naming the missing file, rather than deep inside torch.load / np.load.
for _required_path in (source_path, fused_path, distilled_path):
    if not exists(_required_path):
        raise FileNotFoundError(_required_path)

source_points, source_colors, source_labels = load_scene(source_path, False)

fused_f, filtered_pc, filtered_pc_c, filtered_pc_labels, indices = load_fused_features(
    fused_path, source_points, source_colors, source_labels)
distilled_f = load_distilled_features(distilled_path, indices)

In [None]:
# Short attribute phrases for each fine-grained chair type.
# The keys are the extra class names appended after the 20 ScanNet labels
# (SCANNET_LABELS_AUG and the 'class-label' classification mode). In the 'bare' mode
# each value is passed as a list of separate CLIP prompts, and their per-point scores
# are aggregated into one score for that class.
descriptors_bare = {
    "stool": ["backless", "compact", "straight legs", "round seat", "simple design", 
              "no armrests", "low height", "versatile", "footrest", "wooden"],
    
    "armchair": ["upholstered", "armrests", "comfortable", "cushioned", "high back", 
                 "wingback", "padded seat", "elegant design", "curved legs", "fabric"],
    
    "rocking chair": ["curved runners", "sloping back", "rocking motion", "wooden frame", "comfortable seat", 
                      "traditional design", "armrests", "relaxing", "classic", "country-style"],
    
    "ball chair": ["sphere-shaped", "modern", "futuristic", "enclosed space", "swivel base", 
                   "comfortable cushion", "unique design", "bubble chair", "acrylic material", "iconic"]
}

In [None]:
# Same four chair types as descriptors_bare, but every descriptor phrase also names the
# class itself (e.g. "Wooden stool"), so each prompt is a self-contained description.
# Used by the 'combined' descriptor mode.
descriptors_combined = {
    "stool": [
        "Stool with backless",
        "Backless stool",
        "Stool with cylindrical legs",
        "Small stool",
        "Stool with flat seat",
        "Simple stool",
        "Stool with no armrests",
        "Wooden stool",
        "Stool with round seat",
        "Metal stool",
    ],
    "armchair": [
        "Armchair with padded arms",
        "Comfy armchair",
        "Armchair with high back",
        "Upholstered armchair",
        "Armchair with wooden frame",
        "Leather armchair",
        "Armchair with cushioned seat",
        "Modern armchair",
        "Armchair with decorative legs",
        "Vintage armchair",
    ],
    "rocking chair": [
        "Rocking chair with curved runners",
        "Wooden rocking chair",
        "Rocking chair with slat back",
        "Outdoor rocking chair",
        "Upholstered rocking chair",
        "Rocking chair with armrests",
        "Classic rocking chair",
        "Rocking chair with wide seat",
        "Modern rocking chair",
        "Nursery rocking chair",
    ],
    "ball chair": [
        "Curved shape ball chair",
        "Modern ball chair",
        "High back ball chair",
        "Hanging ball chair",
        "Swivel base ball chair",
        "Outdoor ball chair",
        "Padded seat ball chair",
        "Wicker ball chair",
        "Metal frame ball chair",
        "Retro ball chair",
    ],
}

In [None]:
# The 20 ScanNet benchmark classes; their order fixes the prediction ids 0-19 used when
# classifying points against this list.
SCANNET_LABELS_20 = [
    'wall', 'floor', 'cabinet', 'bed', 'chair',
    'sofa', 'table', 'door', 'window', 'bookshelf',
    'picture', 'counter', 'desk', 'curtain', 'refrigerator',
    'shower curtain', 'toilet', 'sink', 'bathtub', 'otherfurniture',
]

# ScanNet classes followed by the fine-grained chair types (dict insertion order).
SCANNET_LABELS_AUG = [*SCANNET_LABELS_20, *descriptors_bare]

# Sentinel label ids; NOTE(review): not referenced by the visible code.
UNKNOWN_ID = 255
NO_FEATURE_ID = 256

In [None]:
def desc_change(change):
    """Observer for the `desc` toggle: show/hide the aggregation controls.

    'class-label' mode hides both `agg` and `quant`. The descriptor modes show `agg`,
    and show `quant` only while 'quantile' aggregation is selected.
    """
    if change['new'] == 'class-label':
        agg.layout.display = 'none'
        quant.layout.display = 'none'
        return

    # 'flex' (not 'block') matches the value used for `quant` everywhere else.
    # NOTE(review): 'flex' presumably also matches the default row layout of
    # ToggleButtons (label beside the buttons) — confirm visually.
    agg.layout.display = 'flex'
    quant.layout.display = 'flex' if agg.value == 'quantile' else 'none'
        
def agg_change(change):
    """Observer for the `agg` toggle: the quantile slider is visible only for 'quantile'."""
    show_slider = change['new'] == 'quantile'
    quant.layout.display = 'flex' if show_slider else 'none'
        
# Prompt source: the plain class name, or one of the two descriptor dictionaries.
desc = widgets.ToggleButtons(options=['class-label', 'combined', 'bare'], value='class-label',
                             description='Descriptors', disabled=False)

# Class to highlight; in the descriptor modes it is used as a descriptor-dictionary key.
class_name = widgets.Text(value='', placeholder='e.g. rocking chair',
                          description='Class name', disabled=False)

# Which per-point features are scored.
feat = widgets.ToggleButtons(options=['ensembled', 'fused', 'distilled'], value='ensembled',
                             description='Features', disabled=False)

# How per-descriptor scores are combined into one score per class.
agg = widgets.ToggleButtons(options=['mean', 'quantile'], value='mean',
                            description='Aggregation', disabled=False)

# Quantile used by the 'quantile' aggregation.
quant = widgets.FloatSlider(value=0.5, min=0, max=1.0, step=0.05,
                            description='Quantile', disabled=False)


def _make_visualize_button():
    """Create one 'Visualize' button; each output cell gets its own instance."""
    return widgets.Button(description='Visualize', disabled=False, button_style='',
                          tooltip='Visualize', icon='check')


button_sim = _make_visualize_button()    # similarity-heatmap cell
button_class = _make_visualize_button()  # classification cell

In [None]:
display(class_name, desc, feat, agg, quant, button_sim)

# Bring control visibility in line with the current selections, then keep it in sync.
if desc.value == 'class-label':
    agg.layout.display = 'none'
quant.layout.display = (
    'flex' if desc.value != 'class-label' and agg.value == 'quantile' else 'none'
)

desc.observe(desc_change, 'value')
agg.observe(agg_change, 'value')

def visualize_similarity(b):
    """Button callback: render the similarity heatmap for the typed class name.

    In 'class-label' mode the typed text is the single prompt ('mean' is passed as the
    aggregation, which a single prompt does not use). In the 'bare' / 'combined' modes
    the text must be a key of the matching descriptor dictionary. Its descriptors are
    then aggregated with the selected strategy. An empty or unknown name prints a
    message instead of raising inside the widget callback.
    """
    print('Visualizing Similarity...')

    name = class_name.value.strip()
    if not name:
        print('Enter a class name first.')
        return

    if desc.value == 'class-label':
        query, agg_type = [name], 'mean'
    else:
        descriptor_sets = {'bare': descriptors_bare, 'combined': descriptors_combined}
        descriptors = descriptor_sets[desc.value]
        if name not in descriptors:
            print(f'No {desc.value} descriptors for {name!r}; choose one of {sorted(descriptors)}.')
            return
        query, agg_type = descriptors[name], agg.value

    highlight_query(query, feat.value, agg_type, distilled_f, fused_f,
                    filtered_pc, filtered_pc_c, device, scale=2., quantile=quant.value)


button_sim.on_click(visualize_similarity)

In [None]:
display(desc, feat, agg, quant, button_class)

# Same visibility rules as the similarity cell: no aggregation controls for plain class
# names, and the quantile slider only for 'quantile' aggregation.
_class_label_mode = desc.value == 'class-label'
if _class_label_mode:
    agg.layout.display = 'none'
quant.layout.display = 'none' if _class_label_mode or agg.value != 'quantile' else 'flex'

desc.observe(desc_change, 'value')
agg.observe(agg_change, 'value')

def visualize_classification(b):
    """Button callback: classify every point and render the chair sub-classes in colour.

    The label set is the 20 ScanNet classes, followed by one entry per fine-grained
    chair type. That entry is the bare type name in 'class-label' mode, or the type's
    descriptor list in the 'bare' / 'combined' modes. The generic 'chair' class is
    suppressed, so it cannot win over the chair sub-types. Points predicted as a
    ScanNet class are drawn black; each sub-type gets its own colour.
    """
    print('Visualizing Classification...')

    if desc.value == 'class-label':
        descriptors = {}
        extra_labels = list(descriptors_bare)
    else:
        descriptor_sets = {'bare': descriptors_bare, 'combined': descriptors_combined}
        descriptors = descriptor_sets[desc.value]
        extra_labels = list(descriptors.values())
    labelset = SCANNET_LABELS_20 + extra_labels

    pred = classification(labelset, descriptors, feat.value, agg.value,
                          distilled_f, fused_f, filtered_pc_labels,
                          ignore_ids=[SCANNET_LABELS_20.index('chair')], quantile=quant.value)

    # Base ScanNet classes stay black; sub-types cycle through this palette.
    palette = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 1]], dtype=float)
    num_base = len(SCANNET_LABELS_20)
    colors = np.zeros((len(labelset), 3))
    for offset in range(len(labelset) - num_base):
        colors[num_base + offset] = palette[offset % len(palette)]

    # One colour row per point, looked up by its predicted label id.
    pc_colors = colors[np.asarray(pred)]

    pcd_class = o3d.geometry.PointCloud()
    pcd_class.points = o3d.utility.Vector3dVector(np.asarray(filtered_pc))
    pcd_class.colors = o3d.utility.Vector3dVector(pc_colors)

    o3d.visualization.draw_geometries([pcd_class])


button_class.on_click(visualize_classification)