In [1]:
import os
import numpy as np
import torch

import open3d as o3d

import matplotlib.pyplot as plt
import matplotlib.cm as cm
# Look the colormap up through pyplot: `matplotlib.cm.get_cmap` is deprecated
# and no longer available in recent Matplotlib releases.
cmap = plt.get_cmap('jet')

from tqdm import tqdm

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


  cmap = cm.get_cmap('jet')


In [2]:
# Directory holding the Matterport scans (one sub-folder per scan name).
# NOTE(review): machine-specific absolute path — adjust for your setup.
MATTERPORT_DIR = "/media/rsl_admin/T7/matterport/data/v1/scans"

# Directory holding the fused per-region feature files. The 'openseg' / 'lseg'
# substring of this path is what the next cell uses to pick the embedding space.
FUSED_FEATURES_DIR = "/media/rsl_admin/T7/openscene_fused_features/matterport_multiview_openseg_test"

In [3]:
# Infer the embedding space from the fused-features directory name.
# Anything that is neither OpenSeg nor LSeg is unsupported; 'openseg' wins
# when both substrings are present.
if 'openseg' not in FUSED_FEATURES_DIR and 'lseg' not in FUSED_FEATURES_DIR:
    raise NotImplementedError
embedding_space = 'openseg' if 'openseg' in FUSED_FEATURES_DIR else 'lseg'

In [4]:
# Name of the scan to process; used below to select the region files and to
# build the path of the scan mesh.
scan_name = "2t7WUuJeko7"

## Setup 3D model

In [5]:
class ModelConfig:
    '''Simple config class to recreate what's done in the openscene code.

    It only carries the two attributes that this notebook passes on to
    ``get_model``.

    Attributes:
        feature_2d_extractor: name of the 2D feature extractor / embedding
            space (this notebook passes 'openseg' or 'lseg').
        arch_3d: name of the 3D architecture (this notebook passes
            'MinkUNet18A').
    '''

    def __init__(self, feature_2d_extractor, arch_3d):
        self.feature_2d_extractor = feature_2d_extractor
        self.arch_3d = arch_3d

In [6]:
# Checkpoint of the distilled 3D model; the commented-out line is an
# alternative checkpoint that can be swapped in.
checkpoint_path = "/home/rsl_admin/openscene/checkpoints/matterport_openseg.pth"
# checkpoint_path = "/home/rsl_admin/openscene/checkpoints/scannet_openseg.pth"

# The map_location callable moves every loaded storage onto the GPU, so a
# CUDA device is required here.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage.cuda())

In [7]:
from run.distill import get_model

# Build the 3D network for the embedding space inferred from the feature
# directory, then load the checkpoint weights into it.
model_cfg = ModelConfig(
    feature_2d_extractor=embedding_space, 
    arch_3d='MinkUNet18A',
)
model = get_model(model_cfg)
# strict=True: fail loudly if the checkpoint keys do not match the model.
model.load_state_dict(checkpoint['state_dict'], strict=True)

# NOTE(review): model.eval() is not called in the visible cells — confirm that
# DisNetRunner puts the model into eval mode before inference.
model = model.cuda()

In [8]:
from comparison_utils import DisNetRunner

In [9]:
# Wrap the model in the project's runner helper; it is used later through
# `disnet_runner.run(points)`.
disnet_runner = DisNetRunner(model)

## Process a scan from its regions

In [10]:
import re

In [11]:
def process_scan_regions(scan_name, data_3d_dir, fused_feat_dir, disnet_runner):
    '''Collect points, colors and both feature types for all regions of a scan.

    Every file in ``data_3d_dir`` whose name contains ``scan_name`` is treated
    as one pre-processed region. For a region file ``..._region<N>.pth`` the
    matching fused-feature file ``<scan_name>_region<N>_0.pt`` is loaded from
    ``fused_feat_dir``. Only the points selected by the fused file's
    ``"mask_full"`` entry are kept, and the per-region results are
    concatenated along axis 0.

    Args:
        scan_name: name of the scan; used as a substring filter on the region
            filenames and to build the fused-feature filenames.
        data_3d_dir: directory with the pre-processed 3D region files. Each
            file is expected to unpack into ``(points, colors, _)``.
        fused_feat_dir: directory with the fused-feature files. Each file is
            expected to hold a mapping with ``"feat"`` and ``"mask_full"``.
        disnet_runner: object exposing ``run(points)``; its per-region outputs
            are concatenated along axis 0 (assumes array-like per-point
            features — TODO confirm against DisNetRunner).

    Returns:
        dict with keys ``"points"``, ``"colors"``, ``"fused_feats"`` and
        ``"distill_feats"``. The two feature entries are cast to float16;
        fused features are L2-normalized along axis 1.

    Raises:
        RuntimeError: if no region file matches ``scan_name``, a region
            filename carries no region number, or a fused-feature file is
            missing.
        AssertionError: if the number of points differs from the length of
            ``"mask_full"``.
    '''
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # concatenation order (and therefore the output) reproducible.
    data_3d_filenames = sorted(
        fname for fname in os.listdir(data_3d_dir) if scan_name in fname
    )
    if not data_3d_filenames:
        # Without this guard the failure would surface much later as an
        # opaque "need at least one array to concatenate" error.
        raise RuntimeError(
            f"No pre-processed 3D data files for scan {scan_name} in {data_3d_dir}"
        )

    # Resolve and validate every fused-feature file up front, so nothing is
    # loaded or run through the network if one of them is missing.
    fused_feat_filenames = []
    for fname in data_3d_filenames:
        match = re.search(r'_region(\d+)\.pth', fname)
        if not match:
            raise RuntimeError(f"No region number for pre-processed 3D data file {fname}")

        fused_feat_fname = f"{scan_name}_region{match.group(1)}_0.pt"
        if not os.path.isfile(os.path.join(fused_feat_dir, fused_feat_fname)):
            raise RuntimeError(f"{fused_feat_fname} doesn't exist")

        fused_feat_filenames.append(fused_feat_fname)

    scan_points = []
    scan_colors = []
    scan_fused_feats = []
    scan_distill_feats = []

    for data_3d_fname, fused_feat_fname in zip(data_3d_filenames, fused_feat_filenames):
        data_fused_feat = torch.load(os.path.join(fused_feat_dir, fused_feat_fname))
        data_3d = torch.load(os.path.join(data_3d_dir, data_3d_fname))

        points, colors, _ = data_3d
        mask_full = data_fused_feat["mask_full"]
        feat = data_fused_feat["feat"]

        # The mask must cover every point of the region, otherwise the
        # selections below would be misaligned.
        assert points.shape[0] == mask_full.shape[0], (
            f"{data_3d_fname}: {points.shape[0]} points but mask_full has "
            f"{mask_full.shape[0]} entries"
        )

        scan_points.append(points[mask_full])
        scan_colors.append(colors[mask_full])

        # Make sure fused features are normalized; the epsilon avoids a
        # division by zero for all-zero feature rows.
        scan_fused_feats.append(
            feat / (np.linalg.norm(feat, axis=1, keepdims=True) + 1e-5)
        )

        # Index again instead of reusing the array stored in scan_points, so
        # the runner always receives its own copy of the masked points.
        scan_distill_feats.append(disnet_runner.run(points[mask_full]))

    return {
        "points": np.concatenate(scan_points, axis=0),
        "colors": np.concatenate(scan_colors, axis=0),
        "fused_feats": np.concatenate(scan_fused_feats, axis=0).astype(np.float16),
        "distill_feats": np.concatenate(scan_distill_feats, axis=0).astype(np.float16),
    }


In [12]:
# Gather points, colors, fused features and distilled features for the scan.
# The fused-feature directory is taken from FUSED_FEATURES_DIR (the same value
# the embedding space was inferred from) so the two cannot drift apart.
scan_points_feats = process_scan_regions(
    scan_name,
    "/home/rsl_admin/openscene/openscene/data/matterport_3d/test",
    FUSED_FEATURES_DIR,
    disnet_runner
)

## Load the scan mesh for visualization

In [13]:
# Build the mesh path from MATTERPORT_DIR instead of repeating the absolute
# scans directory, so the location only has to be changed in one place.
scan_ply_filepath = os.path.join(
    MATTERPORT_DIR,
    scan_name,
    scan_name,
    "house_segmentations",
    f"{scan_name}.ply",
)

mesh_ply = o3d.io.read_triangle_mesh(scan_ply_filepath)

In [14]:
# Build a point cloud from the masked scan points so it can be displayed
# together with the mesh.
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(scan_points_feats["points"])
# NOTE(review): Open3D expects RGB colors as floats in [0, 1] — confirm the
# stored colors use that range.
pcd.colors = o3d.utility.Vector3dVector(scan_points_feats["colors"])

In [15]:
# Show the scan mesh and the point cloud together in one Open3D viewer.
o3d.visualization.draw_geometries([mesh_ply, pcd])

## Get Segmentations

Segment the scan using the method from OpenScene.

Assuming we are given a set of classes to segment out, we compare the feature of every point with the text embedding of every class and assign each point the label with the highest similarity.

### Set up class vocabulary

In [16]:
from dataset.label_constants import MATTERPORT_LABELS_21

In [17]:
# Copy the label constants into a list; this is the class vocabulary used
# for segmentation.
labelset = list(MATTERPORT_LABELS_21)

In [19]:
from util.util import extract_clip_feature


# Modified from original implementation in Openscene
def extract_text_feature(
    labelset,
    prompt_eng=True,
    feature_2d_extractor='openseg',
    dataset='matterport_3d'
):
    '''Extract CLIP text features for a set of class labels.

    Args:
        labelset: sequence of label strings.
        prompt_eng: if True, every label is wrapped as "a <label> in a scene"
            before encoding.
        feature_2d_extractor: selects the CLIP model. A value containing
            'lseg' uses the default model of ``extract_clip_feature``; a value
            containing 'openseg' uses "ViT-L/14@336px".
        dataset: 'scannet_3d' or 'matterport_3d'. With prompt engineering the
            prompt at a fixed position is replaced by the plain word 'other'
            (last entry for ScanNet, second-to-last for Matterport).

    Returns:
        The text features as a float16 numpy array, one row per label
        (e.g. (21, 768) for the 21 Matterport labels with 'openseg').

    Raises:
        NotImplementedError: if ``feature_2d_extractor`` names neither 'lseg'
            nor 'openseg'.

    NOTE: the 'other' replacement is purely positional. With a custom label
    set it overwrites whatever label sits at that position, so pass
    ``dataset=None`` (or ``prompt_eng=False``) when that is not wanted.
    '''
    # a bit of prompt engineering
    if prompt_eng:
        print('Use prompt engineering: a XX in a scene')
        labelset = ["a " + label + " in a scene" for label in labelset]

        # Guard the positional overwrite by length: label sets that are too
        # short to have that position used to fail with an IndexError.
        if dataset == 'scannet_3d' and len(labelset) >= 1:
            labelset[-1] = 'other'
        if dataset == 'matterport_3d' and len(labelset) >= 2:
            labelset[-2] = 'other'

    if 'lseg' in feature_2d_extractor:
        text_features = extract_clip_feature(labelset)
    elif 'openseg' in feature_2d_extractor:
        text_features = extract_clip_feature(labelset, model_name="ViT-L/14@336px")
    else:
        raise NotImplementedError

    return text_features.cpu().numpy().astype(np.float16)

In [20]:
# Encode the labels with the text encoder that matches the embedding space of
# the point features, instead of relying on the function's 'openseg' default.
label_feats = extract_text_feature(labelset, feature_2d_extractor=embedding_space)

Use prompt engineering: a XX in a scene
Loading CLIP ViT-L/14@336px model...
Finish loading


In [21]:
# Sanity check: expect one embedding row per label in `labelset`.
print(label_feats.shape)

(21, 768)


In [27]:
def compute_predictions(
    label_feats,
    fused_feats,
    distill_feats,
    method='fusion',
):
    '''Assign every point the label whose text embedding is most similar.

    Args:
        label_feats: array-like of label text embeddings, one row per label.
        fused_feats: array-like of fused point features, one row per point.
        distill_feats: array-like of distilled point features, one row per
            point. Currently unused because only 'fusion' is implemented.
        method: 'fusion', 'distill' or 'ensemble'.

    Returns:
        numpy array with one entry per point: the index (into the rows of
        ``label_feats``) of the label with the highest similarity.

    Raises:
        NotImplementedError: for the 'distill' and 'ensemble' methods.
        ValueError: for any other unknown ``method``.
    '''
    if method not in ('fusion', 'distill', 'ensemble'):
        # Previously an unknown method fell through and silently returned None.
        raise ValueError(f"unknown method {method}")
    if method != 'fusion':
        raise NotImplementedError

    # Use the GPU when there is one, but keep the function usable on CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Only move what this method needs; the distilled features are as large
    # as the fused ones and are not used by 'fusion'.
    label_feats = torch.as_tensor(label_feats, dtype=torch.float32, device=device)
    fused_feats = torch.as_tensor(fused_feats, dtype=torch.float32, device=device)

    # Row i, column j: similarity between point i and label j.
    sim = fused_feats @ label_feats.T

    print(sim.shape)

    # Take the max similarity to assign a label to each point.
    return sim.argmax(dim=1).cpu().numpy()

In [28]:
# Run the 'fusion' method on all points of the scan.
compute_predictions(
    label_feats,
    scan_points_feats["fused_feats"],
    scan_points_feats["distill_feats"],
    method="fusion"
)

torch.Size([840678, 21])


In [19]:
# def embed_string(string, encoder):
#     with torch.no_grad():
#         text = clip.tokenize([string]).to('cuda')
#         text_embedding = encoder.encode_text(text)
#         text_embedding /= text_embedding.norm(dim=-1, keepdim=True)
#     return text_embedding.cpu().numpy().astype(np.float32)

In [20]:
# def compute_similarity(
#     query_string, 
#     fused_feat, 
#     distill_feat,
#     method="fusion"
# ):
#     query_embedding = embed_string(query_string, clip_pretrained)
    
#     if method == "fusion":
#         similarity = fused_feat @ query_embedding.T
        
#     elif method == "distill":
#         similarity = distill_feat @ query_embedding.T
        
#     elif method == "ensemble":
#         sim_fusion = fused_feat @ query_embedding.T
#         sim_distill = distill_feat @ query_embedding.T
        
#         similarity = sim_distill
#         use_fusion = sim_fusion > sim_distill
        
#         print(use_fusion.shape)
        
#         similarity[use_fusion] = sim_fusion[use_fusion]
#     else:
#         raise Exception(f"unknown method {method}")
    
#     return similarity

In [22]:
# scores = compute_similarity(
#     "a bed in a scene", 
#     scan_points_feats["fused_feats"],
#     scan_points_feats["distill_feats"],
# #     method="fusion"
# #     method="distill"
#     method="ensemble"
# )

# # under_thresh = scores < 0.2
# # scores[under_thresh] = 0.0

# query_pcd = o3d.geometry.PointCloud()
# query_pcd.points = o3d.utility.Vector3dVector(scan_points_feats["points"])
# query_pcd.colors = o3d.utility.Vector3dVector(cmap(scores).reshape(-1,4)[:,:-1])
# o3d.visualization.draw_geometries([query_pcd])

(840678, 1)
