In [1]:
import os
import torch
import numpy as np
from glob import glob
from tqdm import tqdm
from os.path import join, exists
import open3d as o3d

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Input files for one example ScanNet scene. The data root can be overridden
# with the AT3DCV_DATA_ROOT environment variable instead of editing this cell.
DATA_ROOT = os.environ.get("AT3DCV_DATA_ROOT", "/mnt/project/AT3DCV_Data")
SCENE_ID = "scene0000_00"

# OpenScene-preprocessed point cloud (.pth)
original = join(
    DATA_ROOT, "Preprocessed_OpenScene", "data", "scannet_3d", "example",
    f"{SCENE_ID}_vh_clean_2.pth",
)
# need to have distilled features ready (one feature vector per point of `original`)
distilled_features = join(
    DATA_ROOT, "3D_features", f"{SCENE_ID}_vh_clean_2_openscene_feat_distill.npy"
)

In [3]:
# Load the preprocessed scene; entries 0 and 1 of the stored tuple are used as
# point coordinates and per-point colors.
# weights_only=False: since PyTorch 2.6 torch.load defaults to weights_only=True,
# which rejects pickled non-tensor objects (this file presumably stores numpy
# arrays -- NOTE(review): confirm). Only do this for trusted files: full
# unpickling can execute arbitrary code. Requires torch >= 1.13.
original_sample = torch.load(original, weights_only=False)
original_sample_points = original_sample[0]
original_sample_colors = original_sample[1]

In [4]:
# number of points in the loaded scene (same object as original_sample[0])
len(original_sample_points)

81369

In [5]:
# View the original scene with its stored per-point colors.
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.asarray(original_sample_points, dtype=np.float64))
# Open3D expects float RGB in [0, 1]. OpenScene's preprocessing appears to store
# colors scaled to [-1, 1] (NOTE(review): confirm against the preprocessing script),
# so rescale negative-valued colors; values above 1 are treated as 0-255.
scene_colors = np.asarray(original_sample_colors, dtype=np.float64)
if scene_colors.size and scene_colors.min() < 0:
    scene_colors = (scene_colors + 1.0) / 2.0
elif scene_colors.size and scene_colors.max() > 1:
    scene_colors = scene_colors / 255.0
pcd.colors = o3d.utility.Vector3dVector(scene_colors)
o3d.visualization.draw_geometries([pcd])

# Load the distilled per-point 3D features

In [7]:
# just load, no need for masking since we have distilled features for every 3D point
distilled = np.load(distilled_features)
# Fail fast if the feature file does not line up one-to-one with the loaded
# points; a mismatch would break the row-index -> point correspondence used
# when selecting points later.
assert distilled.shape[0] == len(original_sample_points), (
    f"{distilled.shape[0]} distilled feature rows vs {len(original_sample_points)} points"
)

# Encode the text query with the CLIP model

In [8]:
import clip
model, preprocess = clip.load("ViT-L/14@336px")

In [9]:
# Free-text categories to search for in the scene (edit this list).
query = [
    "bed",
]

In [14]:
# Embed each text query with CLIP's text encoder and L2-normalise it, so a dot
# product with normalised point features is a cosine similarity.
if not query:
    raise ValueError("`query` must contain at least one text prompt")
# Run on whatever device clip.load placed the model on instead of assuming CUDA.
clip_device = next(model.parameters()).device
with torch.no_grad():
    all_text_embeddings = []
    for category in tqdm(query):
        texts = clip.tokenize(category).to(clip_device)  # tokenize
        text_embeddings = model.encode_text(texts)  # embed with text encoder
        text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
        # average over the tokenized prompt(s) for this query, then re-normalise
        text_embedding = text_embeddings.mean(dim=0)
        text_embedding /= text_embedding.norm()
        all_text_embeddings.append(text_embedding)

    # (embed_dim, num_queries): one column per query
    all_text_embeddings = torch.stack(all_text_embeddings, dim=1)

# Normalise the per-point features so the matmul below yields cosine similarity.
# The norm is computed in float32: summing many squared fp16 values can overflow
# fp16 (max ~65504), and half-precision reductions are not supported on CPU in
# every PyTorch version. The small floor avoids NaN rows (all-zero features),
# which would otherwise poison the later similarity_matrix.max().
distilled_t = torch.from_numpy(distilled).float()
distilled_t = distilled_t / distilled_t.norm(p=2, dim=-1, keepdim=True).clamp_min(1e-12)
# Match the text embeddings' device and dtype (fp16 when CLIP runs on CUDA,
# float32 on CPU) rather than hard-coding .half().cuda().
distilled_t = distilled_t.to(device=all_text_embeddings.device, dtype=all_text_embeddings.dtype)

# calculating similarity matrix: (num_points, num_queries)
similarity_matrix = torch.matmul(distilled_t, all_text_embeddings)

# Relative threshold: a point is selected when its similarity to any query
# exceeds threshold_percentage x the highest similarity in the scene. Raising it
# keeps only the strongest matches (fewer false positives, more missed points).
# NOTE(review): assumes the maximum similarity is positive; if cap <= 0 nothing
# is selected.
threshold_percentage = 0.4
cap = similarity_matrix.max().item()
# Reduce over the query axis first: each matching point appears once, and the
# result is always a 1-D index tensor, even when exactly one point matches.
point_matches = (similarity_matrix > cap * threshold_percentage).any(dim=1)
found_indices = torch.nonzero(point_matches, as_tuple=True)[0]

# Rebuild the scene point cloud and paint the points matching the query red.
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(np.asarray(original_sample_points, dtype=np.float64))
# Open3D expects float RGB in [0, 1]. OpenScene's preprocessing appears to store
# colors scaled to [-1, 1] (NOTE(review): confirm against the preprocessing script),
# so rescale negative-valued colors; values above 1 are treated as 0-255.
scene_colors = np.asarray(original_sample_colors, dtype=np.float64)
if scene_colors.size and scene_colors.min() < 0:
    scene_colors = (scene_colors + 1.0) / 2.0
elif scene_colors.size and scene_colors.max() > 1:
    scene_colors = scene_colors / 255.0
pcd.colors = o3d.utility.Vector3dVector(scene_colors)

# select_by_index needs a flat list of ints; reshape(-1) also covers a 0-d tensor.
match_ids = found_indices.reshape(-1).tolist()
found_region = pcd.select_by_index(match_ids)
found_region.paint_uniform_color([1.0, 0, 0])  # paint related points to red
rest = pcd.select_by_index(match_ids, invert=True)
o3d.visualization.draw_geometries([rest, found_region])

100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 32.92it/s]


In [13]:
# shape of the selected-index tensor (its length is the number of matched points)
found_indices.size()

torch.Size([4947])