In [83]:
import os
import numpy as np
import torch

import open3d as o3d

import matplotlib.pyplot as plt
import matplotlib.cm as cm
# Colormap used to turn normalized [0, 1] similarity scores into point colors.
# `matplotlib.cm.get_cmap` was deprecated in 3.7 and removed in 3.9;
# `plt.get_cmap` works on both old and new matplotlib versions.
cmap = plt.get_cmap('jet')

from tqdm import tqdm


In [2]:
# Data locations. Each one can be overridden with an environment variable of the
# same name so the notebook runs on other machines; the defaults are the
# original paths.

# Root directory of the raw Matterport scans.
MATTERPORT_DIR = os.environ.get(
    "MATTERPORT_DIR", "/media/rsl_admin/T7/matterport/data/v1/scans"
)

# Directory of per-region fused features. Its name also decides which embedding
# space ('openseg' / 'lseg') the features live in.
FUSED_FEATURES_DIR = os.environ.get(
    "FUSED_FEATURES_DIR",
    "/media/rsl_admin/T7/openscene_fused_features/matterport_multiview_openseg_test",
)

In [3]:
# Infer which 2D model produced the fused features from the directory name.
# This choice selects the matching CLIP text encoder when the model is loaded.
if 'openseg' in FUSED_FEATURES_DIR:
    embedding_space = 'openseg'
elif 'lseg' in FUSED_FEATURES_DIR:
    embedding_space = 'lseg'
else:
    raise NotImplementedError(
        f"Cannot infer embedding space from FUSED_FEATURES_DIR={FUSED_FEATURES_DIR!r}; "
        "expected the path to contain 'openseg' or 'lseg'."
    )

## Combine scan regions
Define a function that reads the preprocessed 3D data and the fused features, both stored as separate regions, and combines them into single arrays for the whole scan.

In [90]:
import re

In [167]:
def combine_scan_regions(scan_name, data_3d_dir, fused_feat_dir):
    """Load every region of a scan and concatenate points, colors and fused features.

    Every file in ``data_3d_dir`` whose name contains ``scan_name`` must contain
    ``_region<N>.pth``. For each one, the fused-feature file
    ``<scan_name>_region<N>_0.pt`` is loaded from ``fused_feat_dir``. The region's
    points and colors are filtered by that file's ``mask_full``, so each kept point
    lines up with one row of its ``feat``. Regions are processed in ascending
    numeric region order, which makes the output order deterministic.

    Args:
        scan_name: Matterport scan identifier, e.g. "2t7WUuJeko7".
        data_3d_dir: Directory of preprocessed per-region 3D files (``.pth``),
            each holding ``(points, colors, ...)``.
        fused_feat_dir: Directory of per-region fused-feature files (``.pt``),
            each a dict with ``"feat"`` and ``"mask_full"``.

    Returns:
        Tuple ``(points, colors, fused_feats)`` of numpy arrays concatenated
        along axis 0; all three have the same number of rows.

    Raises:
        FileNotFoundError: if ``data_3d_dir`` has no file for ``scan_name``.
        RuntimeError: if a matching file name has no ``_region<N>.pth`` part.
        AssertionError: if a mask's length differs from the region's point count,
            or the number of feature rows differs from the number of kept points.
    """
    # Collect (numeric region, region string, filename). os.listdir order is
    # arbitrary, so sort numerically (region2 before region10) for reproducibility.
    regions = []
    for fname in os.listdir(data_3d_dir):
        if scan_name not in fname:
            continue
        match = re.search(r'_region(\d+)\.pth', fname)
        if match is None:
            raise RuntimeError(f"No region number for file {fname}")
        region_number = match.group(1)
        regions.append((int(region_number), region_number, fname))
    regions.sort()

    if not regions:
        # Without this, np.concatenate([]) fails with an unhelpful ValueError.
        raise FileNotFoundError(
            f"No region files for scan {scan_name!r} in {data_3d_dir}"
        )

    scan_points, scan_colors, scan_fused_feats = [], [], []
    for _, region_number, fname in regions:
        fused_feat_fname = f"{scan_name}_region{region_number}_0.pt"

        # NOTE(review): torch>=2.6 defaults to torch.load(weights_only=True), which
        # may reject pickled numpy arrays; pass weights_only=False for these
        # trusted local files if loading fails.
        data_fused_feat = torch.load(os.path.join(fused_feat_dir, fused_feat_fname))
        data_3d = torch.load(os.path.join(data_3d_dir, fname))

        # Normalize to numpy so boolean masking behaves the same whether the
        # files stored numpy arrays or torch tensors.
        mask = np.asarray(data_fused_feat["mask_full"], dtype=bool)
        points = np.asarray(data_3d[0])
        colors = np.asarray(data_3d[1])
        feats = np.asarray(data_fused_feat["feat"])

        # The mask spans every point of the region...
        assert points.shape[0] == mask.shape[0], (
            f"{fname}: {points.shape[0]} points but mask_full has {mask.shape[0]} entries"
        )
        # ...and each kept point must have exactly one fused-feature row, because
        # callers pair points and features row by row.
        assert feats.shape[0] == int(mask.sum()), (
            f"{fname}: mask keeps {int(mask.sum())} points but feat has {feats.shape[0]} rows"
        )

        scan_points.append(points[mask])
        scan_colors.append(colors[mask])
        scan_fused_feats.append(feats)

    return (
        np.concatenate(scan_points, axis=0),
        np.concatenate(scan_colors, axis=0),
        np.concatenate(scan_fused_feats, axis=0),
    )


In [168]:
# Matterport scan to inspect. The mesh-loading cell reads `scan_name` too, so it
# must be a module-level name rather than only a literal argument.
scan_name = "2t7WUuJeko7"

scan_points, scan_colors, scan_fused_feats = combine_scan_regions(
    scan_name,
    "/home/rsl_admin/openscene/openscene/data/matterport_3d/test",
    FUSED_FEATURES_DIR,
)

# Later cells pair points with features row by row.
assert len(scan_points) == len(scan_colors) == len(scan_fused_feats)

## Load the scan mesh for visualization

In [169]:
# Full-house mesh of the scan, used as visual context for the fused points.
# NOTE(review): `scan_name` must be defined before this cell runs and must name
# the same scan that was passed to combine_scan_regions.
scan_ply_filepath = os.path.join(
    MATTERPORT_DIR, scan_name, scan_name, "house_segmentations", f"{scan_name}.ply"
)
# Without this check, a wrong path gives an empty mesh and only a printed
# warning, not an exception.
if not os.path.isfile(scan_ply_filepath):
    raise FileNotFoundError(f"Scan mesh not found: {scan_ply_filepath}")

mesh_ply = o3d.io.read_triangle_mesh(scan_ply_filepath)

In [170]:
# Point cloud of the masked scan points, overlaid on the mesh below.
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(scan_points)

# Open3D expects float RGB in [0, 1].
# NOTE(review): OpenScene-style preprocessing appears to store colors scaled to
# [-1, 1], so map that range (or 8-bit [0, 255]) back to [0, 1]. Colors already
# in [0, 1] pass through unchanged.
pcd_colors = np.asarray(scan_colors, dtype=np.float64)
if pcd_colors.size and pcd_colors.min() < 0:
    pcd_colors = (pcd_colors + 1.0) / 2.0
elif pcd_colors.size and pcd_colors.max() > 1:
    pcd_colors = pcd_colors / 255.0
pcd.colors = o3d.utility.Vector3dVector(np.clip(pcd_colors, 0.0, 1.0))

In [171]:
# Show the fused scan points on top of the full house mesh.
o3d.visualization.draw_geometries(geometry_list=[mesh_ply, pcd])

## Load CLIP text encoder model

In [114]:
import clip

# CLIP text encoder that matches each 2D model whose features were fused:
# OpenSeg features pair with ViT-L/14@336px, LSeg features with ViT-B/32.
CLIP_MODEL_BY_EMBEDDING = {
    'openseg': 'ViT-L/14@336px',
    'lseg': 'ViT-B/32',
}
# An explicit error replaces the NameError an unmatched if/elif would raise later.
if embedding_space not in CLIP_MODEL_BY_EMBEDDING:
    raise NotImplementedError(
        f"No CLIP model configured for embedding space {embedding_space!r}"
    )
clip_model = CLIP_MODEL_BY_EMBEDDING[embedding_space]

clip_pretrained, _ = clip.load(clip_model, device='cuda', jit=False)

In [115]:
def compute_text_embedding(query_string, encoder, device='cuda'):
    """Encode one text query into an L2-normalized CLIP text embedding.

    Args:
        query_string: Free-form text query, e.g. "a fireplace".
        encoder: CLIP model exposing ``encode_text`` (as returned by ``clip.load``).
        device: Device the tokenized query is moved to; must match the encoder's
            device. Defaults to 'cuda', the previously hard-coded value.

    Returns:
        float32 numpy array with one row (a single query was tokenized), scaled
        to unit L2 norm along the last axis.
    """
    with torch.no_grad():
        tokens = clip.tokenize([query_string]).to(device)
        # Upcast before normalizing: CLIP on CUDA typically runs in fp16, and
        # computing the norm in float32 is more accurate.
        text_embedding = encoder.encode_text(tokens).float()
        text_embedding = text_embedding / text_embedding.norm(dim=-1, keepdim=True)
    return text_embedding.cpu().numpy().astype(np.float32)

In [116]:
def compute_point_scores(query_string, voxel_embeddings, encoder=None):
    """Score every point against a text query, min-max normalized to [0, 1].

    Args:
        query_string: Free-form text query, e.g. "a fireplace".
        voxel_embeddings: Per-point feature array, assumed to be (N, D) with D
            equal to the CLIP text-embedding size (TODO confirm for other
            feature sources).
        encoder: CLIP model used to embed the query. Defaults to the notebook's
            global ``clip_pretrained``.

    Returns:
        Array of shape (N, 1). Dot-product similarities are rescaled so the lowest
        point scores 0 and the highest scores 1. If there are no points, or every
        point has the same similarity, no rescaling is possible and zeros of the
        same shape are returned instead of NaNs.
    """
    if encoder is None:
        encoder = clip_pretrained
    query_embedding = compute_text_embedding(query_string, encoder)

    # Raw similarity of each point's feature with the query embedding.
    similarity = voxel_embeddings @ query_embedding.T

    if similarity.size == 0:
        return np.zeros_like(similarity)
    sim_min = similarity.min()
    sim_range = similarity.max() - sim_min
    # Guard the 0/0 -> NaN case when all similarities are identical.
    if sim_range == 0:
        return np.zeros_like(similarity)
    return (similarity - sim_min) / sim_range

In [124]:
# Score each fused point against a free-form text query, then color the point
# cloud by score with the 'jet' colormap (alpha channel dropped).
scores = compute_point_scores("a fireplace", scan_fused_feats)

query_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(scan_points))
query_pcd.colors = o3d.utility.Vector3dVector(cmap(scores.ravel())[:, :3])
o3d.visualization.draw_geometries([query_pcd])