In [1]:
import os

import pickle

import numpy as np
import torch

import open3d as o3d

import matplotlib.pyplot as plt
import matplotlib.cm as cm
cmap = cm.get_cmap('jet')

from tqdm.notebook import tqdm

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
import ipywidgets as widgets
from IPython.display import display

In [3]:
PREPROCESSED_3D_DIR = "/home/rsl_admin/openscene/openscene/data/matterport_3d/test"
INSTANCE_MASKS_FEATS_DIR = "/home/rsl_admin/openmask3d/cluster_output/test"


OUTPUT_DIR = "/home/rsl_admin/openmask3d/labeled_instance_outputs"

In [4]:
split = os.path.basename(PREPROCESSED_3D_DIR)

print(f"split: {split}")

split: test


##
First check that every 3D region has fused features for that region, pack this data for use later

In [5]:
from collections import namedtuple

In [6]:
RegionDataPaths = namedtuple("RegionDataPaths", ["data_3d_path", "masks_path", "mask_feats_path"])

In [7]:
import re

In [8]:
data_paths = []

for fname in os.listdir(PREPROCESSED_3D_DIR):
    region_name = fname.split('.')[0]
    
    masks_fname = f"{region_name}_masks.pt"
    
    
    masks_path = os.path.join(
        INSTANCE_MASKS_FEATS_DIR, 
        region_name, 
        f"{region_name}_masks.pt"
    )
    if not os.path.isfile(masks_path):
        print(f"no mask file at {masks_path}, skipping")
        continue
    
    
    
    
    mask_feats_path = os.path.join(
        INSTANCE_MASKS_FEATS_DIR, 
        region_name, 
        f"{region_name}_openmask3d_features.npy"
    )
    if not os.path.isfile(mask_feats_path):
        print(f"no mask features file at {mask_feats_path}, skipping")
        continue
    
              
    data_paths.append(
        RegionDataPaths(
            data_3d_path=os.path.join(PREPROCESSED_3D_DIR, fname),
            masks_path=masks_path,
            mask_feats_path=mask_feats_path,
        )
    )

In [9]:
for dp in data_paths:
    print(dp.data_3d_path)
    print(dp.masks_path)
    print(dp.mask_feats_path)
    print()

/home/rsl_admin/openscene/openscene/data/matterport_3d/test/yqstnuAEVhm_region23.pth
/home/rsl_admin/openmask3d/cluster_output/test/yqstnuAEVhm_region23/yqstnuAEVhm_region23_masks.pt
/home/rsl_admin/openmask3d/cluster_output/test/yqstnuAEVhm_region23/yqstnuAEVhm_region23_openmask3d_features.npy

/home/rsl_admin/openscene/openscene/data/matterport_3d/test/q9vSo1VnCiC_region13.pth
/home/rsl_admin/openmask3d/cluster_output/test/q9vSo1VnCiC_region13/q9vSo1VnCiC_region13_masks.pt
/home/rsl_admin/openmask3d/cluster_output/test/q9vSo1VnCiC_region13/q9vSo1VnCiC_region13_openmask3d_features.npy

/home/rsl_admin/openscene/openscene/data/matterport_3d/test/wc2JMjhGNzB_region27.pth
/home/rsl_admin/openmask3d/cluster_output/test/wc2JMjhGNzB_region27/wc2JMjhGNzB_region27_masks.pt
/home/rsl_admin/openmask3d/cluster_output/test/wc2JMjhGNzB_region27/wc2JMjhGNzB_region27_openmask3d_features.npy

/home/rsl_admin/openscene/openscene/data/matterport_3d/test/5ZKStnWn8Zo_region0.pth
/home/rsl_admin/openmask3

### Set up class vocabulary
This step requires the user to select a label set to query the segmentations on

In [10]:
labelset_selection_dropdown  = widgets.Dropdown(
    options=[
        "HABITAT_OGN_LABELS",
        "MATTERPORT_REGION_LABELS",
    ],
    disabled=False,
)

display(labelset_selection_dropdown)

Dropdown(options=('HABITAT_OGN_LABELS', 'MATTERPORT_REGION_LABELS'), value='HABITAT_OGN_LABELS')

In [11]:
if labelset_selection_dropdown.value == "HABITAT_OGN_LABELS":
    from comparison_label_mappings import HABITAT_OGN_LABELS, HABITAT_OGN_LABELS_TO_TEXT_PROMPTS
    
    labelset = list(HABITAT_OGN_LABELS)
    label_text_prompts = [
        HABITAT_OGN_LABELS_TO_TEXT_PROMPTS[label] for label in HABITAT_OGN_LABELS
    ]
    
    labelset_name = "object"
    
    
elif labelset_selection_dropdown.value == "MATTERPORT_REGION_LABELS":
    from comparison_label_mappings import MATTERPORT_REGION_LABELS, MATTERPORT_REGION_LABELS_TO_TEXT_PROMPTS
    
    labelset = list(MATTERPORT_REGION_LABELS)
    label_text_prompts = [
        MATTERPORT_REGION_LABELS_TO_TEXT_PROMPTS[label] for label in MATTERPORT_REGION_LABELS
    ]
    
    labelset_name = "region"
    
else: 
    raise Exception


In [12]:
label_to_ind = {label: i for i, label in enumerate(labelset)}

print(label_to_ind)

{'bathroom': 0, 'bedroom': 1, 'closet': 2, 'dining room': 3, 'garage': 4, 'hallway': 5, 'library': 6, 'laundryroom/mudroom': 7, 'kitchen': 8, 'living room': 9, 'meetingroom/conferenceroom': 10, 'office': 11, 'porch/terrace/deck/driveway': 12, 'rec/game': 13, 'stairs': 14, 'utilityroom/toolroom': 15, 'tv': 16, 'workout/gym/exercise': 17, 'balcony': 18, 'bar': 19, 'classroom': 20, 'spa/sauna': 21, 'entryway/foyer/lobby': 22, 'outdoor': 23, 'dining booth': 24, 'other room': 25}


In [13]:
import clip

In [14]:
def extract_text_feature(labelset, model_name="ViT-L/14@336px"):
    """
    Modified from Openscene
    """
    print("Loading CLIP {} model...".format(model_name))
    clip_pretrained, _ = clip.load(model_name, device='cuda', jit=False)
    clip_pretrained.eval()
    print("Finish loading")

    if isinstance(labelset, str):
        lines = labelset.split(',')
    elif isinstance(labelset, list):
        lines = labelset
    else:
        raise NotImplementedError
    
    labels = []
    for line in lines:
        label = line
        labels.append(label)
    text = clip.tokenize(labels)
    text = text.cuda()
    with torch.no_grad():
        text_features = clip_pretrained.encode_text(text)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    return text_features.cpu().numpy()

In [15]:
label_feats = extract_text_feature(
    label_text_prompts,
)

print(label_feats.shape)

Loading CLIP ViT-L/14@336px model...
Finish loading
(26, 768)


## We may want to remove redundant/overlaping mask instances

In [69]:
REMOVE_REDUNDANT_INSTANCES = True

IOU_THRESHOLD = 0.6

In [70]:
def aabb_intersect(
    min_1, max_1, 
    min_2, max_2
):
    if np.any(max_1 < min_2) or np.any(max_2 < min_1):
        return False
    
    return True

In [71]:
def aabb_iou(    
    min_1, max_1, 
    min_2, max_2
):
    EPS = 1e-5  # could have zero volume (e.g. 2D) boxes
    
    intersect_min = np.maximum(min_1, min_2)
    intersect_max = np.minimum(max_1, max_2)
    
    intersect_lengths = np.maximum(0, intersect_max - intersect_min)
    
    intersect_volume = np.prod(intersect_lengths) + EPS
    
    volume_1 = np.prod(max_1 - min_1) + EPS
    volume_2 = np.prod(max_2 - min_2) + EPS
    
    return intersect_volume / (volume_1 + volume_2 - intersect_volume)

## Generate and save the segmentations

In [72]:
save_dirname = f"{labelset_name}-{split}"

if REMOVE_REDUNDANT_INSTANCES:
    save_dirname = f"filt_iou_{IOU_THRESHOLD}/{save_dirname}"

os.makedirs(
    os.path.join(OUTPUT_DIR, save_dirname),
    exist_ok=True
)

print(f"saving to: {os.path.join(OUTPUT_DIR, save_dirname)}")

saving to: /home/rsl_admin/openmask3d/labeled_instance_outputs/filt_iou_0.6/region-test


In [None]:
for dp in tqdm(data_paths):
    
    data_3d = torch.load(dp.data_3d_path)
    points, colors, _ = data_3d
    
    masks = torch.load(dp.masks_path).astype(bool)
    
    mask_feats = np.load(dp.mask_feats_path)
    
#     print(type(points))
#     print(points.shape)
    
#     print(type(masks))
#     print(masks.shape)
    
#     print(type(mask_feats))
#     print(mask_feats.shape)
    
    assert points.shape[0] == masks.shape[0]
    assert masks.shape[1] == mask_feats.shape[0]
    

    instance_boxes = []
    for i in range(masks.shape[1]):
        instance_points = points[masks[:,i]].astype(np.float16)
        
        min_bound = np.min(instance_points, axis=0)
        max_bound = np.max(instance_points, axis=0)
        
        instance_boxes.append(
            (min_bound, max_bound)
        )
        

    if REMOVE_REDUNDANT_INSTANCES:

        nr_box_inds = []
        
        for box_1_ind, box_1 in enumerate(instance_boxes):
            redundant = False
            
            for box_2_ind in nr_box_inds:
                box_2 = instance_boxes[box_2_ind]
                             
                if not aabb_intersect(box_1[0], box_1[1], box_2[0], box_2[1]):
                    continue
                    
                iou = aabb_iou(box_1[0], box_1[1], box_2[0], box_2[1])
                if iou > IOU_THRESHOLD:
                    redundant = True
                    break
            
            if not redundant:
                nr_box_inds.append(box_1_ind)
        
        
#         print(len(instance_boxes))
#         print(len(nr_box_inds))
        
    
        masks = masks[:, nr_box_inds]
        mask_feats = mask_feats[nr_box_inds, :]
        
        instance_boxes = [
            box for i, box in enumerate(instance_boxes) if i in nr_box_inds
        ]
        
        
    instance_preds = np.argmax(mask_feats @ label_feats.T, axis=1).astype(int)

    
    out = {
        "points": points,
        "masks": masks,
        "instance_preds": instance_preds,
        "instance_min_max_bounds": tuple(instance_boxes),
        "label_to_ind": label_to_ind,
    }
    
    save_filename = os.path.basename(dp.data_3d_path).replace('.pth', '.pkl')
    save_path = os.path.join(
        os.path.join(OUTPUT_DIR, save_dirname, save_filename),
    )
    
    
    with open(save_path, 'wb') as f:
        pickle.dump(out, f)

  0%|          | 0/406 [00:00<?, ?it/s]

## Visualize an output segmentation

In [51]:
MATTERPORT_DIR = "/media/rsl_admin/T7/matterport/data/v1/scans"

In [52]:
np.random.seed(0)

In [53]:
# output_fname = np.random.choice(
#     os.listdir(os.path.join(OUTPUT_DIR, save_dirname))
# )

output_fname = "2t7WUuJeko7_region1.pkl"

output_path = os.path.join(OUTPUT_DIR, save_dirname, output_fname)

with open(output_path, 'rb') as f:
    out = pickle.load(f)
    
print(f"loaded {output_fname}")

loaded 2t7WUuJeko7_region1.pkl


In [54]:
scan_name = output_fname.split('_')[0]

scan_mesh = o3d.io.read_triangle_mesh(
    os.path.join(MATTERPORT_DIR, f"{scan_name}/{scan_name}/house_segmentations/{scan_name}.ply")
)

In [55]:
label_to_ind = out["label_to_ind"]
labels = label_to_ind.keys()

In [56]:
print(labels)

dict_keys(['bathroom', 'bedroom', 'closet', 'dining room', 'garage', 'hallway', 'library', 'laundryroom/mudroom', 'kitchen', 'living room', 'meetingroom/conferenceroom', 'office', 'porch/terrace/deck/driveway', 'rec/game', 'stairs', 'utilityroom/toolroom', 'tv', 'workout/gym/exercise', 'balcony', 'bar', 'classroom', 'spa/sauna', 'entryway/foyer/lobby', 'outdoor', 'dining booth', 'other room'])


In [57]:
label_dropdown = widgets.Dropdown(
    options=['all'] + sorted(labels),
    disabled=False,
)

display(label_dropdown)

Dropdown(options=('all', 'balcony', 'bar', 'bathroom', 'bedroom', 'classroom', 'closet', 'dining booth', 'dini…

In [67]:
instance_inds = np.where(out["instance_preds"] == label_to_ind[label_dropdown.value])[0]

instance_pcds = []
instance_boxes = []
for ind in instance_inds:
    instance_color = np.random.uniform(size=3)
    
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(
        out["points"][out["masks"][:,ind]]
    )
    pcd.paint_uniform_color(instance_color)
    
    instance_pcds.append(pcd)
    
    box_bounds = out["instance_min_max_bounds"][ind]
    box = o3d.geometry.AxisAlignedBoundingBox(
        box_bounds[0], box_bounds[1]
    )
    box.color = instance_color
    
    instance_boxes.append(box)
    
    
    
#     o3d.visualization.draw_geometries(
#         [scan_mesh, pcd, box]
#     )
    

In [68]:
o3d.visualization.draw_geometries(
    [scan_mesh] + instance_pcds + instance_boxes
)