In [1]:
import os
import numpy as np
import torch

import open3d as o3d

import matplotlib.pyplot as plt
import matplotlib.cm as cm
# `cm.get_cmap` is deprecated in recent Matplotlib releases and warns on use;
# `plt.get_cmap` looks up the same registered colormap without the warning.
cmap = plt.get_cmap('jet')

from tqdm import tqdm

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


  cmap = cm.get_cmap('jet')


In [2]:
# Name of the 2D feature extractor; passed to ModelConfig as
# `feature_2d_extractor` when the 3D model is built.
embedding_space = 'openseg'

In [3]:
# Scan identifier and the region of that scan inspected in this notebook.
scan_name = "2t7WUuJeko7"
region_name = f"{scan_name}_region0"

# Both paths are derived from `scan_name`, so switching scans needs a single edit
# and the features, points and mesh cannot silently refer to different scans.
fused_feats_path = f"/media/rsl_admin/T7/openscene_fused_features/matterport_multiview_openseg_test/{region_name}_0.pt"
data_3d_path = f"/home/rsl_admin/openscene/openscene/data/matterport_3d/test/{region_name}.pth"

## Setup 3D model

In [4]:
class ModelConfig:
    '''Simple config class to recreate what's done in the openscene code.

    Only the two attributes below are stored; an instance is passed to
    `get_model` when the 3D network is built.

    Attributes:
        feature_2d_extractor: name of the 2D feature extractor (e.g. 'openseg').
        arch_3d: name of the 3D backbone architecture (e.g. 'MinkUNet18A').
    '''

    def __init__(self, feature_2d_extractor, arch_3d):
        self.feature_2d_extractor = feature_2d_extractor
        self.arch_3d = arch_3d

In [5]:
# Checkpoint whose 'state_dict' entry is loaded into the 3D model below.
# The commented-out line is an alternative checkpoint path.
checkpoint_path = "/home/rsl_admin/openscene/checkpoints/matterport_openseg.pth"
# checkpoint_path = "/home/rsl_admin/openscene/checkpoints/scannet_openseg.pth"

# map_location calls .cuda() on every loaded storage, so a CUDA device is required here.
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage.cuda())

In [6]:
from run.distill import get_model

# Build the 3D network from the two settings that `get_model` is given here.
model_cfg = ModelConfig(
    feature_2d_extractor=embedding_space,
    arch_3d='MinkUNet18A',
)
model = get_model(model_cfg)
# strict=True: fail loudly if the checkpoint keys do not match the architecture.
model.load_state_dict(checkpoint['state_dict'], strict=True)

model = model.cuda()
# This notebook only runs inference: put normalization/dropout layers in eval
# mode so the output features do not depend on per-batch statistics.
# (A no-op if the runner that wraps the model already does this.)
model.eval()

In [7]:
from comparison_utils import DisNetRunner

In [8]:
# Wrap the loaded model; `disnet_runner.run(points)` is called in
# `process_regions` to obtain the distilled features for a region's points.
disnet_runner = DisNetRunner(model)

## Process a single region

In [9]:
def process_regions(data_3d_path, fused_feat_path, disnet_runner):
    '''Load one region and return its points with fused and distilled features.

    Args:
        data_3d_path: file readable by `torch.load` holding a 3-tuple whose
            first two items are the per-point coordinates and colors (the
            third item is ignored).
        fused_feat_path: file readable by `torch.load` holding a dict with
            "feat" (a 2D torch tensor of features, normalized per row here;
            presumably one row per masked point) and "mask_full" (a mask with
            one entry per point of the region).
        disnet_runner: object whose `run(points)` returns a torch tensor of
            features for the given points.

    Returns:
        dict with "points" and "colors" restricted to the masked points, and
        "fused_feats" / "distill_feats" converted to float16.

    Raises:
        AssertionError: if the mask length differs from the number of points.
    '''
    # NOTE: torch.load unpickles its input; only point it at trusted files.
    data_fused_feat = torch.load(fused_feat_path)
    data_3d = torch.load(data_3d_path)

    points, colors, _ = data_3d
    mask = data_fused_feat["mask_full"]

    # The mask must cover every point of the region, otherwise the masked
    # points would not line up with the fused feature rows.
    assert points.shape[0] == mask.shape[0], (
        f"mask_full has {mask.shape[0]} entries but the region has {points.shape[0]} points"
    )

    # Make sure fused features are unit-norm. Done in torch and in float32 so
    # the result depends neither on implicit numpy<->torch conversion nor on
    # the precision the features were stored in.
    fused_feats = data_fused_feat["feat"].to(torch.float32)
    fused_feats = fused_feats / (fused_feats.norm(dim=1, keepdim=True) + 1e-5)

    # The runner gets its own masked copy, so whatever it does to its input
    # cannot affect the points returned below.
    distill_feats = disnet_runner.run(points[mask])

    return {
        "points": points[mask],
        "colors": colors[mask],

        "fused_feats": fused_feats.to(torch.float16),
        "distill_feats": distill_feats.to(torch.float16),
    }


In [10]:
# Load the configured region and compute both feature sets for its masked points.
region_points_feats = process_regions(
    data_3d_path=data_3d_path,
    fused_feat_path=fused_feats_path,
    disnet_runner=disnet_runner,
)

## Load the scan mesh for visualization

In [11]:
# Mesh of the whole scan, shown behind the region's points in the viewer cells.
scan_ply_filepath = f"/media/rsl_admin/T7/matterport/data/v1/scans/{scan_name}/{scan_name}/house_segmentations/{scan_name}.ply"

mesh_ply = o3d.io.read_triangle_mesh(scan_ply_filepath)
# read_triangle_mesh does not raise on a missing or unreadable file; it hands
# back an empty mesh, which would otherwise just render as nothing.
assert mesh_ply.has_vertices(), f"could not read a mesh from {scan_ply_filepath}"

In [12]:
# Point cloud of the region's masked points, using the colors stored with them.
# NOTE(review): Open3D expects colors as RGB floats in [0, 1] — confirm the range
# of the stored colors and rescale first if they are normalized differently.
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(region_points_feats["points"])
pcd.colors = o3d.utility.Vector3dVector(region_points_feats["colors"])

In [13]:
# Visual sanity check: the region's points drawn together with the scan mesh.
o3d.visualization.draw_geometries([mesh_ply, pcd])

## Get Segmentations

Segment the points using the same method as OpenScene.

Assume we are given a set of classes to segment out. We compare every point's feature with the text embedding of every class and assign each point the label of the class with the highest similarity.

### Set up class vocabulary

In [14]:
from dataset.label_constants import MATTERPORT_LABELS_21

In [15]:
# Class vocabulary; a label's position in this list is its integer class index.
labelset = list(MATTERPORT_LABELS_21)

# Reverse lookup: class name -> position in `labelset`.
label_to_ind = dict(zip(labelset, range(len(labelset))))

print(label_to_ind)

{'wall': 0, 'floor': 1, 'cabinet': 2, 'bed': 3, 'chair': 4, 'sofa': 5, 'table': 6, 'door': 7, 'window': 8, 'bookshelf': 9, 'picture': 10, 'counter': 11, 'desk': 12, 'curtain': 13, 'refrigerator': 14, 'shower curtain': 15, 'toilet': 16, 'sink': 17, 'bathtub': 18, 'other': 19, 'ceiling': 20}


In [16]:
from util.util import extract_clip_feature


# Modified from original implementation in Openscene
def extract_text_feature(
    labelset,
    prompt_eng=True,
    feature_2d_extractor='openseg',
    dataset='matterport_3d'
):
    '''Encode class names as CLIP text features, returned in float16.

    With `prompt_eng` every name is wrapped as "a <name> in a scene", after
    which one fixed slot is reset to the plain word 'other': the last entry
    for dataset 'scannet_3d', the second-to-last for 'matterport_3d'.

    The CLIP variant follows `feature_2d_extractor`: `extract_clip_feature`'s
    default model when the name contains 'lseg', "ViT-L/14@336px" when it
    contains 'openseg'. Any other value raises NotImplementedError.
    '''
    prompts = labelset

    if prompt_eng:
        print('Use prompt engineering: a XX in a scene')
        prompts = ["".join(("a ", name, " in a scene")) for name in labelset]

        # the catch-all class sits at a dataset-specific position and is
        # queried without the prompt template
        if dataset == 'scannet_3d':
            prompts[-1] = 'other'
        elif dataset == 'matterport_3d':
            prompts[-2] = 'other'

    # pick the keyword arguments for the text encoder; 'lseg' is checked first
    if 'lseg' in feature_2d_extractor:
        clip_kwargs = {}
    elif 'openseg' in feature_2d_extractor:
        clip_kwargs = {"model_name": "ViT-L/14@336px"}
    else:
        raise NotImplementedError

    text_features = extract_clip_feature(prompts, **clip_kwargs)
    return text_features.to(torch.float16)

In [17]:
# Text embeddings for the class vocabulary, using the function's defaults
# (prompt engineering on, 'openseg' extractor, 'matterport_3d' dataset).
label_feats = extract_text_feature(labelset)

Use prompt engineering: a XX in a scene
Loading CLIP ViT-L/14@336px model...
Finish loading


In [18]:
# Sanity check on the embedding matrix shape (one row per class is expected).
print(label_feats.shape)

torch.Size([21, 768])


In [19]:
def compute_predictions(
    label_feats,
    fused_feats,
    distill_feats,
    method='fusion',
):
    '''Assign every point the class whose text embedding is most similar.

    Args:
        label_feats: (num_classes, dim) tensor of class text embeddings.
        fused_feats: (num_points, dim) tensor of fused features.
        distill_feats: (num_points, dim) tensor of distilled features; must
            have the same shape as `fused_feats`.
        method: 'fusion' or 'distill' to use only that feature set, or
            'ensemble' to use, per point, whichever feature set has the
            higher best-class similarity (ties go to the distilled features).

    Returns:
        1D numpy array with one class index (row of `label_feats`) per point.

    Raises:
        ValueError: if `method` is not one of the three supported values.
        AssertionError: if the feature shapes are inconsistent.
    '''
    assert fused_feats.shape == distill_feats.shape, "fused and distilled features must have the same shape"
    assert fused_feats.shape[1] == label_feats.shape[1], "point and label features must share the embedding size"

    # Reject unknown methods up front instead of failing later on an unbound `pred`.
    if method not in ('fusion', 'distill', 'ensemble'):
        raise ValueError(f"unknown method {method!r}; expected 'fusion', 'distill' or 'ensemble'")

    # doing the matmul on gpu is way faster, but fall back to cpu when there is none
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    def _on_device(feats):
        feats = feats.to(device)
        # half-precision matmul is not available on cpu in every torch release
        return feats if use_cuda else feats.float()

    label_feats = _on_device(label_feats)

    def _similarity(point_feats):
        # (num_points, num_classes) similarity of every point to every class
        return _on_device(point_feats) @ label_feats.T

    if method == 'fusion':
        pred = torch.argmax(_similarity(fused_feats), dim=1)
    elif method == 'distill':
        pred = torch.argmax(_similarity(distill_feats), dim=1)
    else:  # 'ensemble'
        max_sim_fusion, argmax_sim_fusion = torch.max(_similarity(fused_feats), dim=1)
        max_sim_distill, argmax_sim_distill = torch.max(_similarity(distill_feats), dim=1)
        # keep the distilled label unless the fused features are strictly more confident
        pred = torch.where(max_sim_fusion > max_sim_distill, argmax_sim_fusion, argmax_sim_distill)

    return pred.cpu().numpy()

In [28]:
# Per-point class indices, letting the more confident feature set decide per point.
pred = compute_predictions(
    label_feats=label_feats,
    fused_feats=region_points_feats["fused_feats"],
    distill_feats=region_points_feats["distill_feats"],
    method="ensemble",
)

In [30]:
# Class to highlight; must be a key of `label_to_ind`.
label = "desk"

# Positions of the points whose predicted class is `label`.
label_inds = np.flatnonzero(pred == label_to_ind[label])

# Draw just those points, painted green, together with the scan mesh.
o3d.visualization.draw_geometries(
    [mesh_ply, pcd.select_by_index(label_inds).paint_uniform_color((0, 1, 0))]
)

In [43]:
# Show the indices of the points selected for the highlighted class.
label_inds

array([     1,      3,      4, ..., 165838, 165858, 165920])

In [19]:
# def embed_string(string, encoder):
#     with torch.no_grad():
#         text = clip.tokenize([string]).to('cuda')
#         text_embedding = encoder.encode_text(text)
#         text_embedding /= text_embedding.norm(dim=-1, keepdim=True)
#     return text_embedding.cpu().numpy().astype(np.float32)

In [20]:
# def compute_similarity(
#     query_string, 
#     fused_feat, 
#     distill_feat,
#     method="fusion"
# ):
#     query_embedding = embed_string(query_string, clip_pretrained)
    
#     if method == "fusion":
#         similarity = fused_feat @ query_embedding.T
        
#     elif method == "distill":
#         similarity = distill_feat @ query_embedding.T
        
#     elif method == "ensemble":
#         sim_fusion = fused_feat @ query_embedding.T
#         sim_distill = distill_feat @ query_embedding.T
        
#         similarity = sim_distill
#         use_fusion = sim_fusion > sim_distill
        
#         print(use_fusion.shape)
        
#         similarity[use_fusion] = sim_fusion[use_fusion]
#     else:
#         raise Exception(f"unknown method {method}")
    
#     return similarity

In [22]:
# scores = compute_similarity(
#     "a bed in a scene", 
#     scan_points_feats["fused_feats"],
#     scan_points_feats["distill_feats"],
# #     method="fusion"
# #     method="distill"
#     method="ensemble"
# )

# # under_thresh = scores < 0.2
# # scores[under_thresh] = 0.0

# query_pcd = o3d.geometry.PointCloud()
# query_pcd.points = o3d.utility.Vector3dVector(scan_points_feats["points"])
# query_pcd.colors = o3d.utility.Vector3dVector(cmap(scores).reshape(-1,4)[:,:-1])
# o3d.visualization.draw_geometries([query_pcd])

(840678, 1)
