###### ScanNet

This notebook lets you instantiate an inference dataset from scratch and visualize **3D+2D room samples**.
Default settings: **5cm voxel resolution** and **960x720 image resolution**. 

The dataset is composed of **rooms** recorded as video acquisitions of indoor scenes. These video streams were used to produce a point cloud and a set of images for each room.
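A **5cm voxel resolution** presumably means that points are subsampled on a regular 3D grid of 5cm cells. The NumPy sketch below illustrates that idea by averaging the points that fall into each voxel. It is only an illustration under that assumption, not the torch_points3d `GridSampling3D` transform, which also handles features and labels.

```python
import numpy as np


def voxel_grid_subsample(pos, voxel_size=0.05):
    """Average the points falling into each cubic voxel of side `voxel_size`.

    pos: (N, 3) float array of XYZ coordinates, in meters.
    Returns the (M, 3) voxel centroids and the (N,) voxel index of each point.
    """
    # Integer voxel coordinates of each point
    voxel_coords = np.floor(pos / voxel_size).astype(np.int64)
    # One id per occupied voxel, and the voxel id of every point
    _, inverse = np.unique(voxel_coords, axis=0, return_inverse=True)
    inverse = inverse.reshape(-1)  # some NumPy versions return a 2D inverse with axis=0
    n_voxels = inverse.max() + 1
    # Sum the coordinates per voxel, then divide by the number of points per voxel
    sums = np.zeros((n_voxels, 3))
    np.add.at(sums, inverse, pos)
    counts = np.bincount(inverse, minlength=n_voxels)
    return sums / counts[:, None], inverse


# Three points: the first two share a 5cm voxel, the third lies 10cm away
pos = np.array([[0.01, 0.01, 0.01],
                [0.03, 0.02, 0.04],
                [0.11, 0.01, 0.01]])
centroids, voxel_idx = voxel_grid_subsample(pos, voxel_size=0.05)
print(centroids)
print(voxel_idx)
```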


In [1]:
# Select your GPU
I_GPU = 0


In [2]:
# Enable autoreload (comment out the two lines below to disable it)
%load_ext autoreload
%autoreload 2

import os
import os.path as osp
import sys
import torch
import numpy as np
from time import time
from omegaconf import OmegaConf
start = time()
import warnings
warnings.filterwarnings('ignore')

# torch.cuda.set_device(I_GPU)
DIR = os.path.dirname(os.getcwd())
ROOT = os.path.join(DIR, "..")
sys.path.insert(0, ROOT)
sys.path.insert(0, DIR)

from torch_points3d.utils.config import hydra_read
from torch_geometric.data import Data
from torch_points3d.core.multimodal.data import MMData, MMBatch
from torch_points3d.visualization.multimodal_data import visualize_mm_data
from torch_points3d.core.multimodal.image import SameSettingImageData, ImageData
from torch_points3d.datasets.segmentation.multimodal.scannet_inference import ScannetDatasetMM_Inference
from torch_points3d.datasets.segmentation.multimodal.scannet import ScannetDatasetMM
from torch_points3d.datasets.segmentation.scannet import CLASS_COLORS, CLASS_NAMES, CLASS_LABELS
from torch_points3d.metrics.segmentation_tracker import SegmentationTracker

from pykeops.torch import LazyTensor

import matplotlib.pyplot as plt 

%matplotlib inline

MMData debug() function changed, please uncomment the 3rd assert line when doing inference without M2F features!


In [3]:
CLASS_COLORS[0] = (174.0, 199.0, 232.0)
CLASS_COLORS[-1] = (0, 0, 0)
CLASS_COLORS


[(174.0, 199.0, 232.0),
 (152.0, 223.0, 138.0),
 (31.0, 119.0, 180.0),
 (255.0, 187.0, 120.0),
 (188.0, 189.0, 34.0),
 (140.0, 86.0, 75.0),
 (255.0, 152.0, 150.0),
 (214.0, 39.0, 40.0),
 (197.0, 176.0, 213.0),
 (148.0, 103.0, 189.0),
 (196.0, 156.0, 148.0),
 (23.0, 190.0, 207.0),
 (247.0, 182.0, 210.0),
 (219.0, 219.0, 141.0),
 (255.0, 127.0, 14.0),
 (158.0, 218.0, 229.0),
 (44.0, 160.0, 44.0),
 (112.0, 128.0, 144.0),
 (227.0, 119.0, 194.0),
 (82.0, 84.0, 163.0),
 (0, 0, 0)]

If `visualize_mm_data` does not throw any error but the visualization does not appear, you may need to change your plotly renderer below.

In [4]:
import plotly.io as pio

#pio.renderers.default = 'jupyterlab'        # for local notebook
pio.renderers.default = 'iframe_connected'  # for remote notebook. Other working (but seemingly slower) options are: 'sphinx_gallery' and 'iframe'

## Dataset creation

The following cell instantiates the dataset. If the data is not found at `DATA_ROOT`, the folder structure will be created and the raw dataset downloaded there.

**Memory-friendly tip**: if you have already downloaded the dataset once and simply want to instantiate a new dataset with different preprocessing (*e.g.* a different 3D or 2D resolution, mapping parameterization, etc.), I recommend manually replicating the folder hierarchy of your existing dataset and symlinking its `raw/` directory, to avoid downloading and storing (very) large files twice.
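A minimal sketch of that tip using only the standard library. The folder names (`scannet-old`, `scannet-new`, `scans/`) are hypothetical placeholders; replicate whatever hierarchy your existing dataset actually uses. Temporary directories stand in for real dataset roots so the snippet can run anywhere.

```python
import os
import tempfile


def share_raw_dir(old_dataset_dir, new_dataset_dir):
    """Create `new_dataset_dir` and point its `raw/` entry at the existing
    `raw/` directory of `old_dataset_dir`, so raw files are stored only once."""
    os.makedirs(new_dataset_dir, exist_ok=True)
    old_raw = os.path.join(os.path.abspath(old_dataset_dir), 'raw')
    new_raw = os.path.join(new_dataset_dir, 'raw')
    if not os.path.lexists(new_raw):
        os.symlink(old_raw, new_raw, target_is_directory=True)
    return new_raw


# Demo with temporary folders standing in for real dataset roots (hypothetical layout)
root = tempfile.mkdtemp()
old_dir = os.path.join(root, 'scannet-old')
os.makedirs(os.path.join(old_dir, 'raw', 'scans'))

new_raw = share_raw_dir(old_dir, os.path.join(root, 'scannet-new'))
print(os.path.islink(new_raw), os.listdir(new_raw))  # the new raw/ exposes the old scans/
```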

The config files controlling dataset creation live under `conf/data/`, e.g. `conf/data/segmentation/multimodal/scannet-sparse.yaml`; this notebook uses the `segmentation/multimodal/Feng/inference` config. You may edit these files or create new configs inheriting from them using Hydra, and create the associated dataset by modifying `dataset_config` accordingly in the following cell.

In [5]:
# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/home/fsun/data/inference_data/dva_processed'

dataset_config = 'segmentation/multimodal/Feng/inference'   
models_config = 'segmentation/multimodal/Feng/mvfusion'    # model family
model_name = 'MVFusion_3D_small_6views'                       # specific model

overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

cfg = hydra_read(overrides)
OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
cfg.data.load_m2f_masks = True   # load Mask2Former predicted masks

cfg.data.m2f_preds_dirname = 'ViT_masks'
cfg.data.n_views = cfg.models[model_name].backbone.transformer.n_views
print(cfg.data.n_views)

6


As long as you do not change core dataset parameters, preprocessing should only be performed once for your dataset. It may take some time, **mostly depending on the 3D and 2D resolutions** you choose to work with (the larger the slower).

In [6]:
# Dataset instantiation
start = time()
dataset = ScannetDatasetMM_Inference(cfg.data)
print(f"Time = {time() - start:0.1f} sec.")

Load predicted 2D semantic segmentation labels from directory  ViT_masks
initialize test dataset
line 720 scannet.py: split == 'test'
Time = 0.0 sec.


In [7]:
mm_data = dataset.test_dataset[0][0]
mm_data

Rotating PCD with get_Rx(270)
pred_mask:  <PIL.Image.Image image mode=L size=960x720 at 0x14614DB78C50>

self.ref_size:  (960, 720)
self.m2f_pred_mask.shape:  torch.Size([128, 1, 720, 960])


MMData(
    data = Data(coords=[135545, 3], grid_size=[1], id_scan=[1], mapping_index=[135545], mvfusion_input=[84012, 6, 10], origin_id=[135545], pos=[135545, 3], rgb=[135545, 3], x=[135545, 3])
    image = ImageData(num_settings=1, num_views=128, num_points=135545, device=cpu)
)

To visualize the multimodal samples produced by the dataset, we need to remove some of the dataset transforms that affect points, images and mappings. The `sample_real_data` function (commented out in cell In [8], since it is not called in this inference notebook) can be used to get samples without breaking mapping consistency for visualization.

At training and evaluation time, these transforms are used for data augmentation, dynamic size batching (see our [paper](https://arxiv.org/submit/4264152)), etc.
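The idea behind `sample_real_data` is to drop selected transform *types* from a composed pipeline, fetch a sample, then restore the pipeline. Below is a self-contained sketch of that pattern with a toy `Compose` class and dummy transforms. All names here are hypothetical stand-ins; this is not the implementation of `BaseDataset.remove_transform`.

```python
class Compose:
    """Toy stand-in for a composed transform pipeline."""
    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x


class AddOne:  # stands in for a benign transform
    def __call__(self, x):
        return x + 1


class ToyAugmentation:  # stands in for a data augmentation
    def __call__(self, x):
        return x + 1000


def remove_transform_types(compose, exclude):
    """Return a new Compose without the transforms whose type is in `exclude`."""
    kept = [t for t in compose.transforms if not isinstance(t, tuple(exclude))]
    return Compose(kept)


class ToyDataset:
    def __init__(self, transform):
        self.transform = transform

    def __getitem__(self, idx):
        return self.transform(idx)


def sample_without(dataset, idx, exclude):
    """Temporarily remove some transform types to fetch one sample."""
    original = dataset.transform
    dataset.transform = remove_transform_types(original, exclude)
    try:
        return dataset[idx]
    finally:
        dataset.transform = original  # always restore the original pipeline


ds = ToyDataset(Compose([AddOne(), ToyAugmentation()]))
print(ds[0], sample_without(ds, 0, [ToyAugmentation]), ds[0])  # 1001 1 1001
```

Using `try`/`finally` guarantees the pipeline is restored even if fetching the sample raises, which the commented `sample_real_data` does not do.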

In [8]:
# from torch_geometric.transforms import *
# from torch_points3d.core.data_transform import *
# from torch_points3d.core.data_transform.multimodal.image import *
# from torch_points3d.datasets.base_dataset import BaseDataset
# from torch_points3d.datasets.base_dataset_multimodal import BaseDatasetMM

# # Transforms on 3D points that we need to exclude for visualization purposes
# augmentations_3d = [
#     ElasticDistortion, Random3AxisRotation, RandomNoise, RandomRotate, 
#     RandomScaleAnisotropic, RandomSymmetry, ShiftVoxels]
# exclude_3d_viz = augmentations_3d + [AddFeatsByKeys, Center, GridSampling3D]

# # Transforms on 2D images and mappings that we need to exclude for visualization
# # purposes
# augmentations_2d = [JitterMappingFeatures, ColorJitter, RandomHorizontalFlip]
# exclude_2d_viz = augmentations_2d + [ToFloatImage, Normalize]



# def sample_real_data(tg_dataset, idx=0, exclude_3d=None, exclude_2d=None):
#     """
#     Temporarily remove the 3D and 2D transforms affecting the point 
#     positions and images from the dataset to better visualize points 
#     and images relative positions.
#     """    
#     # Remove some 3D transforms
#     transform_3d = tg_dataset.transform
#     if exclude_3d:
#         tg_dataset.transform = BaseDataset.remove_transform(transform_3d, exclude_3d)

#     # Remove some 2D transforms, if any
#     is_multimodal = hasattr(tg_dataset, 'transform_image')
#     if is_multimodal and exclude_2d:
#         transform_2d = tg_dataset.transform_image
#         tg_dataset.transform_image = BaseDatasetMM.remove_multimodal_transform(transform_2d, exclude_2d)
    
#     # Get a sample from the dataset, with transforms excluded
#     out = tg_dataset[idx]
    
#     # Restore transforms
#     tg_dataset.transform = transform_3d
#     if is_multimodal and exclude_2d:
#         tg_dataset.transform_image = transform_2d
        
#     return out

## Visualize a single multimodal sample

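We can now pick samples from the dataset. Note that the inference dataset only initializes a test split (see the instantiation output above), so samples are taken from `dataset.test_dataset`.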

Please refer to `torch_points3d/visualization/multimodal_data.py` for more details on visualization options.

In [9]:
# visualize_mm_data(mm_data, class_names=CLASS_NAMES, class_colors=CLASS_COLORS, error_color=(0, 0, 0), front='rgb', back='x', figsize=1000, pointsize=3, voxel=0.15, show_2d=True, alpha=0.3)

# # If this raises NotImplementedError in multimodal_data.py, keep the original
# # features in data.data.x inside the dataset's __getitem__.

## Run inference from pretrained weights and visualize predictions
It is possible to visualize the pointwise predictions and errors from a model. 

To do so, we will use the pretrained weights made available with this project. See `README.md` for the download links, and manually place the `.pt` files locally. In the following model-loading cell, set `checkpoint_dir` to the directory where you saved those files.
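The model-loading cell calls `model.load_state_dict_with_same_shape(..., strict=False)`. Judging by its name, it presumably keeps only the checkpoint entries whose shapes match the model's parameters. The sketch below shows that filtering step on plain dictionaries of NumPy arrays; it is an assumption about the behavior, not the torch_points3d implementation. With PyTorch, the filtered dict would then be passed to `model.load_state_dict(filtered, strict=False)`.

```python
import numpy as np


def filter_same_shape(model_state, ckpt_state):
    """Split checkpoint entries into those loadable into the model (same key
    and same shape) and those skipped (missing key or shape mismatch)."""
    loadable, skipped = {}, []
    for key, value in ckpt_state.items():
        if key in model_state and model_state[key].shape == value.shape:
            loadable[key] = value
        else:
            skipped.append(key)
    return loadable, skipped


# Hypothetical parameter names: the classifier head has a different number of classes
model_state = {'backbone.w': np.zeros((64, 3)), 'head.w': np.zeros((20, 64))}
ckpt_state = {'backbone.w': np.ones((64, 3)), 'head.w': np.ones((21, 64)), 'extra.b': np.ones(5)}

loadable, skipped = filter_same_shape(model_state, ckpt_state)
print(sorted(loadable), skipped)  # ['backbone.w'] ['head.w', 'extra.b']
```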

In [10]:
from torch_points3d.models.model_factory import instantiate_model

# Set your parameters
# Alternative checkpoints (uncomment one to use it instead):
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/2022-12-04/15-22-16'     # MVFusion_3D_small, default m2f_masks
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/2022-12-07/12-07-34'     # 3rd run
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/2022-12-04/15-48-56'     # MVFusion_3D_small, default swin_l_early
# checkpoint_dir = '/project/fsun/DeepViewAgg/outputs/2022-11-04/15-51-33'  # 3D backbone, 68.04 mIoU
# checkpoint_dir = '/home/fsun/DeepViewAgg/model_checkpoints'               # DVA best model

# ViT_masks, 3rd run
checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/ViT_masks_3rd_run'

# Create the model
print(f"Creating model: {cfg.model_name}")
model = instantiate_model(cfg, dataset)
# print(model)

# Load the checkpoint and recover the 'latest' model weights
# (use 'best_miou' instead to recover the best-mIoU weights, if your checkpoint stores them)
checkpoint = torch.load(f'{checkpoint_dir}/{model_name}.pt', map_location='cpu')
model.load_state_dict_with_same_shape(checkpoint['models']['latest'], strict=False)

# Prepare the model for inference
model = model.eval().cuda()
print('Model loaded')

Creating model: MVFusion_3D_small_6views
task:  segmentation.multimodal
tested_model_name:  MVFusion_3D_small_6views
Rotating PCD with get_Rx(270)
pred_mask:  <PIL.Image.Image image mode=L size=960x720 at 0x14614CCBEAD0>

self.ref_size:  (960, 720)
self.m2f_pred_mask.shape:  torch.Size([128, 1, 720, 960])
Manually set number of classes for model initialization in DeepViewAgg/torch_points3d/datasets/base_dataset.py, line 471
Manually set number of classes for model initialization in DeepViewAgg/torch_points3d/datasets/base_dataset.py, line 471
class_name:  MVFusionAPIModel
model_module:  torch_points3d.models.segmentation.multimodal.Feng.mvfusion_3d
name, cls of chosen model_cls:  MVFusionAPIModel <class 'torch_points3d.models.segmentation.multimodal.Feng.mvfusion_3d.MVFusionAPIModel'>
Manually set number of classes for model initialization in DeepViewAgg/torch_points3d/datasets/base_dataset.py, line 471
Rotating PCD with get_Rx(270)
pred_mask:  <PIL.Image.Image image mode=L size=960x720 at 0x14614CD5C190>

self.ref_size:  (960, 720)
self.m2f_pred_mask.shape:  torch.Size([128, 1, 720, 960])
x feature dim:  {'FEAT': 3}
nc_in:  67
nc_in:  64
nc_in:  32
nc_in:  64
nc_in:  128
nc_in:  256
nc_in:  128
nc_in:  128
nc_in:  96
nc_in:  96
Manually set number of classes for model initialization in DeepViewAgg/torch_points3d/datasets/base_dataset.py, line 471
Model loaded


Now that the model is loaded, we need to run a forward pass on a sample. However, if we want to visualize the predictions, we need to pay special attention to which 3D and 2D transforms are applied to the data, so as not to break the mappings. To this end, sensitive transforms can be applied manually, so that the same sample can be used both for inference and for visualization.

In [11]:
mm_data = dataset.test_dataset[0][0]

# Create a MMBatch and run inference
batch = MMBatch.from_mm_data_list([mm_data])

with torch.no_grad():
    print("input batch: ", batch)
    model.set_input(batch, model.device)
    model(batch)
    
# Recover the predicted labels for visualization
mm_data.data.pred = model.output.detach().cpu().argmax(1)

Rotating PCD with get_Rx(270)
pred_mask:  <PIL.Image.Image image mode=L size=960x720 at 0x14614DB78A90>

self.ref_size:  (960, 720)
self.m2f_pred_mask.shape:  torch.Size([128, 1, 720, 960])
input batch:  MMBatch(
    data = Batch(batch=[135545], coords=[135545, 3], grid_size=[1], id_scan=[1], mapping_index=[135545], mvfusion_input=[83973, 6, 10], origin_id=[135545], pos=[135545, 3], ptr=[2], rgb=[135545, 3], x=[135545, 3])
    image = ImageBatch(num_settings=1, num_views=128, num_points=135545, device=cpu)
)


In [43]:
# Select seen points, i.e. points mapped to at least one image view
csr_idx = mm_data.modalities['image'][0].view_csr_indexing
dense_idx_list = torch.arange(mm_data.modalities['image'][0].num_points).repeat_interleave(csr_idx[1:] - csr_idx[:-1])
# Keep only the seen points; unique() selects each point once, even if it is seen in several views
mm_data = mm_data[dense_idx_list.unique()]
mm_data

MMData(
    data = Data(coords=[83973, 3], grid_size=[1], id_scan=[1], mapping_index=[83973], mvfusion_input=[83973, 6, 10], origin_id=[83973], pos=[83973, 3], pred=[83973], rgb=[83973, 3], x=[83973, 3])
    image = ImageData(num_settings=1, num_views=128, num_points=83973, device=cpu)
)
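`view_csr_indexing` appears to be a CSR-style pointer array: the views of point `i` occupy entries `csr_idx[i]` to `csr_idx[i+1]`, so `csr_idx[1:] - csr_idx[:-1]` counts the views of each point. Under that reading, the `repeat_interleave` + `unique` combination above is equivalent to keeping the points with a positive view count, which avoids materializing the dense index list. A NumPy sketch on a toy pointer array:

```python
import numpy as np

# Toy CSR pointers for 5 points seen in 2, 0, 1, 0 and 3 views respectively
csr_idx = np.array([0, 2, 2, 3, 3, 6])
num_points = len(csr_idx) - 1
views_per_point = csr_idx[1:] - csr_idx[:-1]

# Approach used above: repeat each point index once per view, then deduplicate
dense_idx = np.repeat(np.arange(num_points), views_per_point)
seen_via_repeat = np.unique(dense_idx)

# Equivalent: the indices of points seen in at least one view
seen_direct = np.flatnonzero(views_per_point > 0)

print(seen_via_repeat, seen_direct)  # [0 2 4] [0 2 4]
```

In PyTorch, the same selection would be `torch.where(csr_idx[1:] > csr_idx[:-1])[0]`.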

In [47]:
import struct
from plyfile import PlyData, PlyElement


def write_pointcloud(filename, xyz_points, rgb_points=None):
    """Write an Nx3 point cloud, with optional Nx3 uint8 colors, to a binary
    little-endian .ply file."""
    assert xyz_points.shape[1] == 3, 'Input XYZ points should be an Nx3 float array'
    if rgb_points is None:
        rgb_points = np.full(xyz_points.shape, 255, dtype=np.uint8)
    assert xyz_points.shape == rgb_points.shape, \
        'Input RGB colors should be an Nx3 array with the same size as the input XYZ points'
    rgb_points = np.asarray(rgb_points).astype(np.uint8)

    with open(filename, 'wb') as fid:
        # Write the .ply header
        fid.write(b'ply\n')
        fid.write(b'format binary_little_endian 1.0\n')
        fid.write(b'element vertex %d\n' % xyz_points.shape[0])
        for prop in ('float x', 'float y', 'float z', 'uchar red', 'uchar green', 'uchar blue'):
            fid.write(b'property %s\n' % prop.encode())
        fid.write(b'end_header\n')

        # Write each vertex as 3 little-endian float32 followed by 3 uint8
        for (x, y, z), (r, g, b) in zip(xyz_points, rgb_points):
            fid.write(struct.pack('<fffBBB', float(x), float(y), float(z), int(r), int(g), int(b)))


def to_ply(pos, label, file):
    """Save points colored by their semantic label to a binary (big-endian) .ply file."""
    assert len(label.shape) == 1
    assert pos.shape[0] == label.shape[0]
    pos = np.asarray(pos)
    # Use a signed integer type so that label -1 indexes the last (black) color
    colors = np.array(CLASS_COLORS)[np.asarray(label).astype(np.int64)].astype(np.uint8)

    print(pos)
    print(colors)
    ply_array = np.ones(
        pos.shape[0], dtype=[("x", "f4"), ("y", "f4"), ("z", "f4"), ("red", "u1"), ("green", "u1"), ("blue", "u1")]
    )
    ply_array["x"] = pos[:, 0]
    ply_array["y"] = pos[:, 1]
    ply_array["z"] = pos[:, 2]
    ply_array["red"] = colors[:, 0]
    ply_array["green"] = colors[:, 1]
    ply_array["blue"] = colors[:, 2]
    el = PlyElement.describe(ply_array, "vertex")
    ply_data = PlyData([el], byte_order=">")
    ply_data.write(file)
    print(ply_data)

# xyz = mm_data.pos.numpy()
# rgb = np.array(CLASS_COLORS)[mm_data.data.pred.numpy().astype(np.uint8)].astype(np.uint8)
# NOTE: m2f_mm_data is created in cell In [46] further below; run that cell first
to_ply(m2f_mm_data.pos, m2f_mm_data.data.pred, "pcd_2_2d_projected_semantic.ply")

[[-2.6684558 -1.7330347 -2.3226166]
 [-2.907451  -1.773885  -1.9865711]
 [-1.613929  -1.8432987 -1.9432268]
 ...
 [ 1.3498178 -1.529822   1.6153016]
 [ 1.0164851 -1.5042164  1.6134924]
 [ 1.4469929 -2.269065   1.6366584]]
[[ 31 119 180]
 [ 31 119 180]
 [214  39  40]
 ...
 [174 199 232]
 [174 199 232]
 [174 199 232]]
ply
format binary_big_endian 1.0
element vertex 83973
property float x
property float y
property float z
property uchar red
property uchar green
property uchar blue
end_header
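The header printed above fully describes the binary layout: after `end_header`, each vertex is stored as three big-endian `float32` coordinates followed by three `uint8` color channels, i.e. 15 bytes per vertex. To illustrate what that means, the snippet below writes such a file with NumPy alone (no `plyfile`) and reads it back. The file name and the two sample points are hypothetical placeholders.

```python
import os
import tempfile
import numpy as np

# Big-endian vertex layout matching the header above: 3 x float32 + 3 x uint8
VERTEX_DTYPE = np.dtype([('x', '>f4'), ('y', '>f4'), ('z', '>f4'),
                         ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])


def write_ply_numpy(path, pos, colors):
    vertices = np.empty(len(pos), dtype=VERTEX_DTYPE)
    vertices['x'], vertices['y'], vertices['z'] = pos[:, 0], pos[:, 1], pos[:, 2]
    vertices['red'], vertices['green'], vertices['blue'] = colors[:, 0], colors[:, 1], colors[:, 2]
    header = ('ply\nformat binary_big_endian 1.0\n'
              f'element vertex {len(pos)}\n'
              'property float x\nproperty float y\nproperty float z\n'
              'property uchar red\nproperty uchar green\nproperty uchar blue\n'
              'end_header\n')
    with open(path, 'wb') as f:
        f.write(header.encode('ascii'))
        f.write(vertices.tobytes())  # packed records: 15 bytes per vertex


def read_ply_numpy(path):
    with open(path, 'rb') as f:
        data = f.read()
    # The vertex records start right after the end_header line
    end = data.index(b'end_header\n') + len(b'end_header\n')
    return np.frombuffer(data[end:], dtype=VERTEX_DTYPE)


path = os.path.join(tempfile.mkdtemp(), 'example.ply')
pos = np.array([[0.0, 1.0, 2.0], [-1.5, 0.5, 3.25]], dtype=np.float32)
colors = np.array([[31, 119, 180], [174, 199, 232]], dtype=np.uint8)
write_ply_numpy(path, pos, colors)
vertices = read_ply_numpy(path)
print(VERTEX_DTYPE.itemsize, vertices['z'], vertices['red'])
```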


In [41]:
mm_data

MMData(
    data = Data(coords=[135545, 3], grid_size=[1], id_scan=[1], mapping_index=[135545], mvfusion_input=[83973, 6, 10], origin_id=[135545], pos=[135545, 3], pred=[135545], rgb=[135545, 3], x=[135545, 3])
    image = ImageData(num_settings=1, num_views=128, num_points=135545, device=cpu)
)

In [13]:
# # Randomly sample views
# mm_data.modalities['image'] = ImageData(mm_data.modalities['image'][0][:25])

In [14]:
# print(mm_data.data)
# mm_data = mm_data[mm_data.pos[:, 1] <= 3.29]
# # mm_data.modalities['image'] = None
# # mm_data

### The point cloud is rotated by 270° about the X-axis in the `__getitem__` method of the inference dataset class.
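For reference, a rotation by angle θ about the X-axis is R_x(θ) = [[1, 0, 0], [0, cos θ, -sin θ], [0, sin θ, cos θ]]. Assuming `get_Rx(270)` follows this standard right-handed convention, takes degrees, and is applied to points as `pos @ R.T`, it maps (x, y, z) to (x, z, -y), and a further R_x(90) rotation (or R.T) undoes it. The exact convention of `get_Rx` is not shown in this notebook, so check it before relying on this sketch; `rot_x` below is a hypothetical helper.

```python
import numpy as np


def rot_x(angle_deg):
    """Right-handed rotation matrix about the X-axis (angle in degrees)."""
    t = np.deg2rad(angle_deg)
    c, s = np.cos(t), np.sin(t)
    return np.array([[1.0, 0.0, 0.0],
                     [0.0, c, -s],
                     [0.0, s, c]])


R = rot_x(270)
pos = np.array([[1.0, 2.0, 3.0]])
rotated = pos @ R.T               # (x, y, z) -> (x, z, -y)
restored = rotated @ rot_x(90).T  # rotating by a further 90 degrees undoes it

print(np.round(rotated, 6), np.round(restored, 6))
```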

In [12]:
fig = visualize_mm_data(mm_data, no_output=True, figsize=1000, pointsize=3, voxel=0.03, show_2d=False, back='x', front='m2f_pred_mask', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

In [44]:
def get_mode_pred(data):
    """Majority vote of the per-view 2D predictions of each seen point."""
    # mvfusion_input is [num_points, n_views, num_feats]: channel 0 flags valid
    # views and the last channel holds the 2D predicted label of that view
    pixel_validity = data.data.mvfusion_input[:, :, 0].bool()
    mv_preds = data.data.mvfusion_input[:, :, -1].long()

    mode_preds = []
    for preds, valid in zip(mv_preds, pixel_validity):
        # Most frequent label among the valid views of this point
        mode_preds.append(torch.mode(preds[valid], dim=0)[0])
    return torch.stack(mode_preds, dim=0)

mode_preds = get_mode_pred(mm_data)
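`get_mode_pred` loops over points in Python, which can be slow for tens of thousands of points. A vectorized alternative, sketched here in NumPy under the same reading of `mvfusion_input` (channel 0 flags valid views, the last channel holds the predicted label), counts the votes per class with `np.add.at` and takes the `argmax`. `num_classes` is a parameter of this sketch and must exceed the largest label of any valid view. Ties are broken towards the smallest class index, which may differ from `torch.mode`.

```python
import numpy as np


def vectorized_mode_pred(preds, valid, num_classes):
    """Per-point majority vote over the valid views.

    preds: (N, V) integer labels, valid: (N, V) boolean mask.
    Assumes each point has at least one valid view.
    """
    n_points, n_views = preds.shape
    counts = np.zeros((n_points, num_classes), dtype=np.int64)
    rows = np.repeat(np.arange(n_points), n_views)
    # Invalid views may hold arbitrary labels: redirect them to class 0 with weight 0
    labels = np.where(valid, preds, 0).ravel()
    np.add.at(counts, (rows, labels), valid.ravel().astype(np.int64))
    return counts.argmax(axis=1)  # ties -> smallest class index


preds = np.array([[1, 1, 2],
                  [0, 2, 2],
                  [3, 3, 3]])
valid = np.array([[True, True, True],
                  [True, False, True],
                  [False, False, True]])
print(vectorized_mode_pred(preds, valid, num_classes=4))
```

With the notebook's tensors, it would be called with `mm_data.data.mvfusion_input[:, :, -1].long().numpy()` and `mm_data.data.mvfusion_input[:, :, 0].bool().numpy()`.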

In [45]:
print(mode_preds.shape)
print(mm_data.data.pred.shape)

torch.Size([83973])
torch.Size([83973])


In [46]:
m2f_mm_data = mm_data.clone()
m2f_mm_data.data.x = None
m2f_mm_data.data.pred = mode_preds

# visualize_mm_data(m2f_mm_data, figsize=1000, pointsize=3, voxel=0.03, show_2d=False, back='m2f_pred_mask', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

In [None]:
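# NOTE: requires ground-truth labels `y`, which the inference dataset does not provide
mm_data.data.x = None
# Keep only labeled points; indexing mm_data also filters mm_data.data.pred
mm_data = mm_data[mm_data.data.y != -1]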


print(mm_data.data.pred.unique())
mm_data.data.y.unique()

In [None]:
visualize_mm_data(mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=False, back='m2f_pred_mask', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

### Visualize a scan

In [None]:
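# NOTE: the inference dataset only initializes a test split; this cell assumes a dataset
# with a validation split, such as `scannet_dataset` instantiated at the end of this notebook
mm_data = dataset.val_dataset[0]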

In [None]:
# Select seen points, i.e. points mapped to at least one image view
csr_idx = mm_data.modalities['image'][0].view_csr_indexing
dense_idx_list = torch.arange(mm_data.modalities['image'][0].num_points).repeat_interleave(csr_idx[1:] - csr_idx[:-1])
# Keep only the seen points; unique() selects each point once, even if it is seen in several views
seen_mm_data = mm_data[dense_idx_list.unique()]
seen_mm_data[54945:54946]

In [None]:
print(seen_mm_data)
seen_mm_data = seen_mm_data[seen_mm_data.pos[:, 1] <= 3.29]
seen_mm_data

In [None]:
seen_mm_data.modalities['image'] = None


In [None]:
visualize_mm_data(seen_mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=False, back='x', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

In [None]:
from PIL import Image

In [None]:
imgs = np.array(['/project/fsun/dvata/scannet-neucon-smallres-m2f/raw/scans/scene0011_00/sens/color/1138.jpg',
        '/project/fsun/dvata/scannet-neucon-smallres-m2f/raw/scans/scene0011_00/sens/color/1604.jpg',
        '/project/fsun/dvata/scannet-neucon-smallres-m2f/raw/scans/scene0011_00/sens/color/1188.jpg'])

masks = np.array(['/home/fsun/data/scannet/scans/scene0011_00/ViT_masks/1138.jpg',
        '/home/fsun/data/scannet/scans/scene0011_00/ViT_masks/1604.jpg',
        '/home/fsun/data/scannet/scans/scene0011_00/ViT_masks/1188.jpg'])

for i, im in enumerate(masks):
#     im = im.split("/")
#     im[1] = 'home'
#     im[-3] = 'color_resized'
#     im.pop(-2)
#     im = "/".join(im)
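    # The masks are stored as .png files
    im = osp.splitext(im)[0] + '.png'
    seg_im = Image.open(im)
    # Cast to a signed type before shifting labels by -1, so that label 0 becomes -1
    # (the last, black entry of CLASS_COLORS) instead of wrapping around to 255
    seg_im_np = np.array(seg_im).astype(np.int64) - 1

    # Ad hoc relabeling for the third mask only (class 4 -> 5)
    if i == 2:
        seg_im_np[seg_im_np == 4] = 5
    print(np.unique(seg_im_np))

    seg_im_rgb = np.array(CLASS_COLORS)[seg_im_np]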

    seg_im_rgb = Image.fromarray(seg_im_rgb.astype(np.uint8))
    plt.imshow(seg_im_rgb)
    plt.show()
    
      

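The loop above colors each mask by indexing the `CLASS_COLORS` palette with the label image. The sketch below isolates that lookup on a tiny synthetic mask and a hypothetical 3-color palette whose last entry is black, mirroring how `CLASS_COLORS[-1]` was set to `(0, 0, 0)` earlier. The key detail is the dtype: with `uint8`, subtracting 1 from label 0 wraps around to 255, whereas a signed integer type yields -1, which indexes the last (black) palette entry.

```python
import numpy as np

# Hypothetical palette: two classes, last entry black (used for index -1)
palette = np.array([(174, 199, 232), (152, 223, 138), (0, 0, 0)], dtype=np.uint8)

# Hypothetical stored mask, where 0 means "no class" and labels are shifted by +1
mask = np.array([[1, 2],
                 [0, 1]], dtype=np.uint8)

# uint8 arithmetic wraps around: 0 - 1 becomes 255, an invalid palette index
wrapped = mask - np.uint8(1)

# Signed arithmetic keeps 0 - 1 == -1, which indexes the last (black) color
labels = mask.astype(np.int64) - 1
rgb = palette[labels]  # (H, W, 3) color image

print(wrapped.max(), rgb.shape)
print(rgb[1, 0], rgb[0, 0])
```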
In [17]:
# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/scratch-shared/fsun/dvata'

dataset_config = 'segmentation/multimodal/Feng/scannet-neucon-smallres-m2f'   
models_config = 'segmentation/multimodal/Feng/mvfusion'    # model family
model_name = 'MVFusion_3D_small_6views'                       # specific model

overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

cfg = hydra_read(overrides)
OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
cfg.data.load_m2f_masks = True   # load Mask2Former predicted masks

cfg.data.m2f_preds_dirname = 'ViT_masks'
cfg.data.n_views = cfg.models[model_name].backbone.transformer.n_views
print(cfg.data.n_views)

6


In [18]:
# Dataset instantiation
start = time()
scannet_dataset = ScannetDatasetMM(cfg.data)
print(f"Time = {time() - start:0.1f} sec.")

Load predicted 2D semantic segmentation labels from directory  ViT_masks
initialize train dataset
initialize val dataset
Time = 8.2 sec.
