# ScanNet

This notebook lets you instantiate the **[ScanNet](http://www.scan-net.org/)** dataset from scratch and visualize **3D+2D room samples**.

Note that you will need **at least 1.2T** available for the ScanNet raw dataset and **at least 64G** for the processed files at **5cm voxel resolution** and **320x240 image resolution**.

The ScanNet dataset is composed of **rooms**, each captured as a video acquisition of an indoor scene. These video streams were used to produce a point cloud and images.

Each room is small enough to be loaded at once into **64G of RAM**. The `ScannetDatasetMM` class from `torch_points3d.datasets.segmentation.multimodal.scannet` handles loading a room along with a subset of the images from its associated video stream.

In [1]:
# Select your GPU
I_GPU = 0


In [2]:
# Enable autoreload (comment out the two lines below to disable it)
%load_ext autoreload
%autoreload 2

import os
import os.path as osp
import sys
import torch
import numpy as np
from time import time
from omegaconf import OmegaConf
start = time()
import warnings
warnings.filterwarnings('ignore')

# torch.cuda.set_device(I_GPU)
DIR = os.path.dirname(os.getcwd())
ROOT = os.path.join(DIR, "..")
sys.path.insert(0, ROOT)
sys.path.insert(0, DIR)

from torch_points3d.utils.config import hydra_read
from torch_geometric.data import Data
from torch_points3d.core.multimodal.data import MMData, MMBatch
from torch_points3d.visualization.multimodal_data import visualize_mm_data
from torch_points3d.core.multimodal.image import SameSettingImageData, ImageData
from torch_points3d.datasets.segmentation.multimodal.scannet import ScannetDatasetMM
from torch_points3d.datasets.segmentation.scannet import CLASS_COLORS, CLASS_NAMES, CLASS_LABELS
from torch_points3d.metrics.segmentation_tracker import SegmentationTracker

from pykeops.torch import LazyTensor

MMData debug() function changed, please uncomment the 3rd assert line when doing inference without M2F features!


In [4]:
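# Note: `temp` is an alias of CLASS_COLORS (not a copy), so this also sets
# the last CLASS_COLORS entry to black
temp = CLASS_COLORS
temp[-1] = (0.0, 0.0, 0.0)
# Integer RGB version of the palette (not displayed, since it is not the
# last expression of the cell)
[list([int(y) for y in x]) for x in temp]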
CLASS_COLORS

[(174.0, 199.0, 232.0),
 (152.0, 223.0, 138.0),
 (31.0, 119.0, 180.0),
 (255.0, 187.0, 120.0),
 (188.0, 189.0, 34.0),
 (140.0, 86.0, 75.0),
 (255.0, 152.0, 150.0),
 (214.0, 39.0, 40.0),
 (197.0, 176.0, 213.0),
 (148.0, 103.0, 189.0),
 (196.0, 156.0, 148.0),
 (23.0, 190.0, 207.0),
 (247.0, 182.0, 210.0),
 (219.0, 219.0, 141.0),
 (255.0, 127.0, 14.0),
 (158.0, 218.0, 229.0),
 (44.0, 160.0, 44.0),
 (112.0, 128.0, 144.0),
 (227.0, 119.0, 194.0),
 (82.0, 84.0, 163.0),
 (0.0, 0.0, 0.0)]

In [13]:
d = [
    {"name": 'wall', "id" : 1, "trainId" : 1},
    {"name": 'floor', "id" : 2, "trainId" : 2},
    {"name": 'cabinet', "id" : 3, "trainId" : 3},
    {"name": 'bed', "id" : 4, "trainId" : 4},
    {"name": 'chair', "id" : 5, "trainId" : 5},
    {"name": 'sofa', "id" : 6, "trainId" : 6},
    {"name": 'table', "id" : 7, "trainId" : 7},
    {"name": 'door', "id" : 8, "trainId" : 8},
    {"name": 'window', "id" : 9, "trainId" : 9},
    {"name": 'bookshelf', "id" : 10, "trainId" : 10},
    {"name": 'picture', "id" : 11, "trainId" : 11},
    {"name": 'counter', "id" : 12, "trainId" : 12},
    {"name": 'desk', "id" : 13, "trainId" : 13},
    {"name": 'curtain', "id" : 14, "trainId" : 14},
    {"name": 'refrigerator', "id" : 15, "trainId" : 15},
    {"name": 'shower curtain', "id" : 16, "trainId" : 16},
    {"name": 'toilet', "id" : 17, "trainId" : 17},
    {"name": 'sink', "id" : 18, "trainId" : 18},
    {"name": 'bathtub', "id" : 19, "trainId" : 19},
    {"name": 'otherfurniture', "id" : 20, "trainId" : 20}]


for i, k in enumerate(d):
    k['color'] = temp[i]
    k['id'] -= 1
    k['trainId'] -= 1
    
d

[{'name': 'wall', 'id': 0, 'trainId': 0, 'color': (174.0, 199.0, 232.0)},
 {'name': 'floor', 'id': 1, 'trainId': 1, 'color': (152.0, 223.0, 138.0)},
 {'name': 'cabinet', 'id': 2, 'trainId': 2, 'color': (31.0, 119.0, 180.0)},
 {'name': 'bed', 'id': 3, 'trainId': 3, 'color': (255.0, 187.0, 120.0)},
 {'name': 'chair', 'id': 4, 'trainId': 4, 'color': (188.0, 189.0, 34.0)},
 {'name': 'sofa', 'id': 5, 'trainId': 5, 'color': (140.0, 86.0, 75.0)},
 {'name': 'table', 'id': 6, 'trainId': 6, 'color': (255.0, 152.0, 150.0)},
 {'name': 'door', 'id': 7, 'trainId': 7, 'color': (214.0, 39.0, 40.0)},
 {'name': 'window', 'id': 8, 'trainId': 8, 'color': (197.0, 176.0, 213.0)},
 {'name': 'bookshelf', 'id': 9, 'trainId': 9, 'color': (148.0, 103.0, 189.0)},
 {'name': 'picture', 'id': 10, 'trainId': 10, 'color': (196.0, 156.0, 148.0)},
 {'name': 'counter', 'id': 11, 'trainId': 11, 'color': (23.0, 190.0, 207.0)},
 {'name': 'desk', 'id': 12, 'trainId': 12, 'color': (247.0, 182.0, 210.0)},
 {'name': 'curtain', 

If `visualize_mm_data` does not throw any error but the visualization does not appear, you may need to change your plotly renderer below.

In [3]:
import plotly.io as pio

#pio.renderers.default = 'jupyterlab'        # for local notebook
pio.renderers.default = 'iframe_connected'  # for remote notebook. Other working (but seemingly slower) options are: 'sphinx_gallery' and 'iframe'

## Dataset creation

The following cell instantiates the dataset. If the data is not found at `DATA_ROOT`, the folder structure is created and the raw dataset is downloaded there.

**Memory-friendly tip**: if you have already downloaded the dataset once and simply want to instantiate a new dataset with different preprocessing (*e.g.* a different 3D or 2D resolution, mapping parameterization, etc.), I recommend you manually replicate the folder hierarchy of your existing dataset and create a symlink to its `raw/` directory, to avoid downloading and storing (very) large files twice.
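A minimal sketch of the symlink trick, assuming the raw files sit in a `raw/` subfolder of each dataset folder, as the tip describes. `link_raw_dir` is a hypothetical helper (not part of torch_points3d), and only `raw/` is shared: the new processed files are still generated separately.

```python
import os
import os.path as osp


def link_raw_dir(existing_dataset_dir, new_dataset_dir):
    """Create `new_dataset_dir` and make `new_dataset_dir/raw` a symlink to
    `existing_dataset_dir/raw`, so raw files are not downloaded twice.

    Hypothetical helper; assumes raw data lives in a `raw/` subfolder.
    """
    src = osp.abspath(osp.join(existing_dataset_dir, "raw"))
    dst = osp.join(new_dataset_dir, "raw")
    if not osp.isdir(src):
        raise FileNotFoundError(f"No raw/ directory found at {src}")
    # Replicate the dataset folder itself, then link only its raw/ part
    os.makedirs(new_dataset_dir, exist_ok=True)
    if osp.lexists(dst):
        raise FileExistsError(f"{dst} already exists, not overwriting it")
    os.symlink(src, dst, target_is_directory=True)
    return dst
```

For instance, `link_raw_dir('/path/to/old_root/scannet', osp.join(DATA_ROOT, 'scannet'))`, where both paths are placeholders that should match the folder names your configs expect. On Windows, creating symlinks may require extra privileges.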

You will find the config file ruling the dataset creation at `conf/data/segmentation/multimodal/scannet-sparse.yaml`. You may edit this file or create new configs inheriting from this one using Hydra and create the associated dataset by modifying `dataset_config` accordingly in the following cell.

In [4]:
# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/project/fsun/dvata'

dataset_config = 'segmentation/multimodal/Feng/scannet-neucon-smallres-m2f-partial-subsampled'   
models_config = 'segmentation/multimodal/Feng/mvfusion'    # model family
model_name = 'MVFusion_3D_small'                       # specific model

overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

cfg = hydra_read(overrides)
OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
cfg.data.load_m2f_masks = True   # load Mask2Former predicted masks

In [5]:
cfg.data.m2f_preds_dirname = 'm2f_masks'
cfg.data.n_views = cfg.models[model_name].backbone.transformer.n_views
print(cfg.data.n_views)

9


The dataset will now be created based on the parsed configuration. I recommend having **at least 1.2T** available for the ScanNet raw dataset and **at least 64G** for the processed files at **5cm voxel resolution** and **320x240 image resolution**.
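Before launching the download and preprocessing, you can check the free space on the target filesystem with the standard library's `shutil.disk_usage`. `has_free_space` is a hypothetical helper, the thresholds are the figures quoted above, and the path passed to it must already exist.

```python
import shutil


def has_free_space(path, required_gb):
    """Return True if the filesystem holding `path` has at least
    `required_gb` gigabytes free (1 GB counted as 1024**3 bytes here)."""
    free_bytes = shutil.disk_usage(path).free
    return free_bytes >= required_gb * 1024 ** 3
```

For example, `has_free_space(DATA_ROOT, 1.2 * 1024 + 64)` checks for roughly 1.2T of raw data plus 64G of processed files.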

As long as you do not change core dataset parameters, preprocessing should only be performed once for your dataset. It may take some time, **mostly depending on the 3D and 2D resolutions** you choose to work with (the larger the slower).

In [6]:
# Dataset instantiation
start = time()
dataset = ScannetDatasetMM(cfg.data)
# print(dataset)
print(f"Time = {time() - start:0.1f} sec.")

Load predicted 2D semantic segmentation labels from directory  m2f_masks
initialize train dataset
initialize val dataset
initialize test dataset
line 720 scannet.py: split == 'test'
Time = 1.9 sec.


In [7]:
mm_data = dataset.val_dataset[0]

To visualize the multimodal samples produced by the dataset, we need to remove some of the dataset transforms that affect points, images and mappings. The `sample_real_data` function will be used to get samples for visualization without breaking mapping consistency.

At training and evaluation time, these transforms are used for data augmentation, dynamic size batching (see our [paper](https://arxiv.org/submit/4264152)), etc.
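Conceptually, excluding transforms amounts to filtering a composed pipeline by transform class, which is why `exclude_3d_viz` and `exclude_2d_viz` in the next cell hold classes rather than instances. The sketch below illustrates the idea with a hypothetical `Compose` stand-in and a hypothetical `remove_transform_types` helper; it is not the code of `BaseDataset.remove_transform` or `BaseDatasetMM.remove_multimodal_transform`, which may handle nested or multimodal pipelines differently.

```python
class Compose:
    """Minimal stand-in for a composed transform: applies callables in order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x


def remove_transform_types(compose, excluded_types):
    """Return a new Compose without the transforms that are instances of any
    class in `excluded_types`. The original pipeline is left unchanged."""
    excluded = tuple(excluded_types)
    kept = [t for t in compose.transforms if not isinstance(t, excluded)]
    return Compose(kept)
```

Returning a new pipeline instead of mutating the original one is what makes it easy to put the full pipeline back afterwards.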

In [None]:
from torch_geometric.transforms import *
from torch_points3d.core.data_transform import *
from torch_points3d.core.data_transform.multimodal.image import *
from torch_points3d.datasets.base_dataset import BaseDataset
from torch_points3d.datasets.base_dataset_multimodal import BaseDatasetMM

# Transforms on 3D points that we need to exclude for visualization purposes
augmentations_3d = [
    ElasticDistortion, Random3AxisRotation, RandomNoise, RandomRotate, 
    RandomScaleAnisotropic, RandomSymmetry, ShiftVoxels]
exclude_3d_viz = augmentations_3d + [AddFeatsByKeys, Center, GridSampling3D]

# Transforms on 2D images and mappings that we need to exclude for visualization
# purposes
augmentations_2d = [JitterMappingFeatures, ColorJitter, RandomHorizontalFlip]
exclude_2d_viz = augmentations_2d + [ToFloatImage, Normalize]



def sample_real_data(tg_dataset, idx=0, exclude_3d=None, exclude_2d=None):
    """
    Temporarily remove the 3D and 2D transforms affecting the point
    positions and images from the dataset, to better visualize the
    relative positions of points and images.
    """
    # Keep references to the original transforms so they can be restored
    transform_3d = tg_dataset.transform
    is_multimodal = hasattr(tg_dataset, 'transform_image')
    transform_2d = tg_dataset.transform_image if is_multimodal else None

    try:
        # Remove some 3D transforms
        if exclude_3d:
            tg_dataset.transform = BaseDataset.remove_transform(transform_3d, exclude_3d)

        # Remove some 2D transforms, if any
        if is_multimodal and exclude_2d:
            tg_dataset.transform_image = BaseDatasetMM.remove_multimodal_transform(transform_2d, exclude_2d)

        # Get a sample from the dataset, with transforms excluded
        out = tg_dataset[idx]
    finally:
        # Restore transforms, even if sampling raised an error
        tg_dataset.transform = transform_3d
        if is_multimodal:
            tg_dataset.transform_image = transform_2d

    return out
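`sample_real_data` swaps transforms in and out by hand. The same save/replace/restore pattern can be packaged as a context manager with `contextlib.contextmanager`, which guarantees the original attributes come back even if fetching the sample fails. `swapped_attrs` is a hypothetical helper shown for illustration only.

```python
from contextlib import contextmanager


@contextmanager
def swapped_attrs(obj, **new_values):
    """Temporarily set attributes on `obj`, restoring the previous values
    on exit, even if the body of the `with` block raises."""
    old_values = {name: getattr(obj, name) for name in new_values}
    try:
        for name, value in new_values.items():
            setattr(obj, name, value)
        yield obj
    finally:
        for name, value in old_values.items():
            setattr(obj, name, value)
```

Usage would look like `with swapped_attrs(dataset.val_dataset, transform=reduced_transform): sample = dataset.val_dataset[0]`, where `reduced_transform` stands for whatever filtered pipeline you built.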

## Visualize a single multimodal sample

We can now pick samples from the train, val and test datasets.

Please refer to `torch_points3d/visualization/multimodal_data` for more details on visualization options.

In [None]:
from PIL import Image

# Collect the label values present in the swin_l_early masks of one scene
mask_dir = "/project/fsun/data/scannet/scans/scene0000_00/swin_l_early"
label_unique = set()
for n in os.listdir(mask_dir):
    a = Image.open(osp.join(mask_dir, n))
    label_unique.update(np.unique(np.array(a)).tolist())
print(label_unique)

In [None]:
# Collect the label values present in the m2f_masks of one scene
mask_dir = "/project/fsun/data/scannet/scans/scene0000_00/m2f_masks"
label_unique = set()
for n in os.listdir(mask_dir):
    a = Image.open(osp.join(mask_dir, n))
    label_unique.update(np.unique(np.array(a)).tolist())
print(label_unique)

In [None]:
mm_data = dataset.train_dataset[0]
mm_data

In [None]:
visualize_mm_data(mm_data, class_names=CLASS_NAMES, class_colors=CLASS_COLORS, error_color=(0, 0, 0), front='y', back='m2f_pred_mask', figsize=1000, pointsize=3, voxel=0.03, show_2d=True, alpha=0.3)

# If this raises a NotImplementedError in multimodal_data.py, keep the original
# features in data.data.x inside the dataset __getitem__.

### Compare mapping feature statistics between train/val/test processed_2d

In [None]:
for split in ['train', 'val', 'test']:
    data_dir = f"/project/fsun/dvata/scannet-neucon-smallres-m2f/processed/processed_2d_{split}"
    files = os.listdir(data_dir)

    mean = 0
    std = 0
    for name in files:
        # Mapping features stored in the processed 2D file of one room
        data = torch.load(osp.join(data_dir, name))
        feats = data._mappings.values[2]

        # Note: averaging per-file statistics is a biased estimate of the
        # statistics over all mappings (files have different sizes)
        mean += feats.mean(dim=0)
        std += feats.std(dim=0)

    mean /= len(files)
    std /= len(files)

    print(f"{split} mapping feature statistics:")
    print(mean)
    print(std)
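The loop above averages per-file means and standard deviations, which ignores that files hold different numbers of mappings and that the variance between files also contributes to the overall spread. A minimal sketch of exact pooled statistics, accumulating counts, sums and sums of squares file by file; it assumes each file yields a 2D `(N_i, F)` array (e.g. after `.numpy()` on the feature tensor), and `pooled_mean_std` is a hypothetical helper.

```python
import numpy as np


def pooled_mean_std(chunks, ddof=1):
    """Exact per-feature mean and std over the concatenation of several
    (N_i, F) arrays, without holding them all in memory at once.
    ddof=1 gives the unbiased variance, like torch.std's default."""
    n = 0
    s = ss = None
    for x in chunks:
        x = np.asarray(x, dtype=np.float64)
        if s is None:
            s = np.zeros(x.shape[1])
            ss = np.zeros(x.shape[1])
        n += x.shape[0]
        s += x.sum(axis=0)            # running sum per feature
        ss += (x ** 2).sum(axis=0)    # running sum of squares per feature
    mean = s / n
    var = (ss - n * mean ** 2) / (n - ddof)
    return mean, np.sqrt(np.maximum(var, 0.0))
```

The sum-of-squares formula is simple but can lose precision when features have a large offset relative to their spread; Welford-style updates are the usual remedy in that case.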



In [77]:
# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/project/fsun/dvata'

dataset_config = 'segmentation/multimodal/Feng/scannet-neucon-smallres-m2f-partial-subsampled'   
models_config = 'segmentation/multimodal/Feng/mvfusion'    # model family
model_name = 'MVFusion_3D_small'                       # specific model

# models_config = 'segmentation/multimodal/Feng/3d_only'    # model family
# model_name = 'Res16UNet34'                       # specific model



# dataset_config = 'segmentation/multimodal/scannet-sparse'
# models_config = 'segmentation/multimodal/sparseconv3d'    # model family
# model_name = 'Res16UNet34-L4-early'                       # specific model

overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

cfg = hydra_read(overrides)
OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
cfg.data.load_m2f_masks = True   # load Mask2Former predicted masks

In [79]:
cfg.data.m2f_preds_dirname = 'm2f_masks'
cfg.data.n_views = 9  # or: cfg.models[model_name].backbone.transformer.n_views
print(cfg.data.n_views)

9


In [80]:
# Dataset instantiation
start = time()
dataset = ScannetDatasetMM(cfg.data)
# print(dataset)
print(f"Time = {time() - start:0.1f} sec.")

Load predicted 2D semantic segmentation labels from directory  m2f_masks
initialize train dataset
initialize val dataset
initialize test dataset
line 720 scannet.py: split == 'test'
Time = 1.4 sec.


## Run inference from pretrained weights and visualize predictions
It is possible to visualize the pointwise predictions and errors from a model.

To do so, we will use the pretrained weights made available with this project. See `README.md` for the download links and manually place the `.pt` files locally. In the next cell, set `checkpoint_dir` to the directory where you saved those files.

In [87]:
from torch_points3d.models.model_factory import instantiate_model

# Set your parameters
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/2022-12-04/15-22-16' # MVFusion_3D_small default m2f_masks
checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/2022-12-07/12-07-34' # 3rd run
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/2022-12-04/15-48-56' # MVFusion_3D_small default swin_l_early
# checkpoint_dir = '/project/fsun/DeepViewAgg/outputs/2022-11-04/15-51-33' # 3D Backbone, 68.04 miou
# checkpoint_dir = '/home/fsun/DeepViewAgg/model_checkpoints' # DVA best model

# Create the model
print(f"Creating model: {cfg.model_name}")
model = instantiate_model(cfg, dataset)
# print(model)

# Load the checkpoint and recover the 'latest' model weights
checkpoint = torch.load(f'{checkpoint_dir}/{model_name}.pt', map_location='cpu')
model.load_state_dict_with_same_shape(checkpoint['models']['latest'], strict=False)

# Prepare the model for inference
model = model.eval().cuda()
print('Model loaded')

Creating model: MVFusion_3D_small
task:  segmentation.multimodal
tested_model_name:  MVFusion_3D_small
class_name:  MVFusionAPIModel
model_module:  torch_points3d.models.segmentation.multimodal.Feng.mvfusion_3d
name, cls of chosen model_cls:  MVFusionAPIModel <class 'torch_points3d.models.segmentation.multimodal.Feng.mvfusion_3d.MVFusionAPIModel'>
x feature dim:  {'FEAT': 3}
nc_in:  67
nc_in:  64
nc_in:  32
nc_in:  64
nc_in:  128
nc_in:  256
nc_in:  128
nc_in:  128
nc_in:  96
nc_in:  96
Model loaded


Now that we have loaded the model, we need to run a forward pass on a sample. However, if we want to visualize the predictions, we need to pay special attention to which 3D and 2D transforms we apply to the data, so as not to break the mappings. To do so, we will manually apply some sensitive transforms, so that we can both run inference on the data and visualize it.

In [22]:
i_room = 0

# Pick a room in the Train set
mm_data = sample_real_data(dataset.train_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Pick a room in the Val set
# mm_data = sample_real_data(dataset.val_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Pick a room in the Test set
# mm_data = sample_real_data(dataset.test_dataset[0], idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Extract point cloud and images from MMData object
data = mm_data.data.clone()
images = mm_data.modalities['image'].clone()

data


In [23]:
# Run cell for validation sample with original validation transformations
mm_data = dataset.val_dataset[0]

# data = mm_data.data.clone()
# images = mm_data.modalities['image'].clone()
# data

In [None]:
mm_data

In [None]:
# For voxel-based 3D backbones such as SparseConv3d and MinkowskiNet, points need to be 
# preprocessed with Center and GridSampling3D. Unfortunately, Center breaks relative 
# positions between points and images. Besides, the combination of Center and GridSampling3D
# may lead to some points being merged into the same voxels, so we must apply it to both the
# inference and visualization data to make sure we have the same voxels. The workaround here 
# is to manually run these while keeping track of the centering offset
center = data.pos.mean(dim=-2, keepdim=True)
data = AddFeatsByKeys(list_add_to_x=[True, True, True], feat_names=['pos_x', 'pos_y', 'pos_z'], delete_feats=[True, True, True])(data)          # add x, y, z coordinates to the features
data = Center()(data)                                                                                 # mean-center the data
data = GridSampling3D(cfg.data.resolution_3d, quantize_coords=True, mode='last')(data)                # quantization for volumetric models

# This last voxelization step with GridSampling3D might have removed some points, so we need
# to update the mappings using SelectMappingFromPointId. To control the size of the batch, we
# use PickImagesFromMemoryCredit. Besides, 2D models expect normalized float images, which is
# why we call ToFloatImage and Normalize
data, images = SelectMappingFromPointId()(data, images)                                               # update mappings after GridSampling3D
data, images = PickImagesFromMemoryCredit(
    img_size=cfg.data.resolution_2d, 
    k_coverage=cfg.data.multimodal.settings.k_coverage, 
    n_img=cfg.data.multimodal.settings.test_pixel_credit)(data, images)                                      # select images to respect memory constraints
data, images_infer = ToFloatImage()(data, images.clone())                                             # convert uint8 images to float
data, images_infer = Normalize()(data, images_infer)                                                  # RGB normalization

# Create a MMData for inference
mm_data_infer = MMData(data, image=images_infer)
print(mm_data_infer)

# Create a MMBatch and run inference
batch = MMBatch.from_mm_data_list([mm_data_infer])

with torch.no_grad():
    print("input batch: ", batch)
    model.set_input(batch, model.device)
    model(batch)

# Create a MMData for visualization
data.pos += center
mm_data = MMData(data, image=images)

# Recover the predicted labels for visualization
mm_data.data.pred = model.output.detach().cpu().argmax(1)
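The comments in the cell above describe mean-centering followed by voxelization that keeps a single point per voxel (`mode='last'`), which may drop points and therefore requires updating the mappings. As a rough illustration, here is a simplified numpy analogue; it is a sketch under stated assumptions, not torch_points3d's `Center`/`GridSampling3D` implementation, whose voxel coordinates and notion of the "last" point may differ. The returned `center` plays the role of the offset added back before visualization, and `kept` is the kind of point selection the point-to-image mappings must then follow. `center_and_grid_sample_last` is hypothetical.

```python
import numpy as np


def center_and_grid_sample_last(pos, voxel_size):
    """Mean-center `pos` (N, 3) and keep one point per voxel: the one with
    the largest index, i.e. the last one in input order.

    Returns the centering offset, the sorted indices of the kept points and
    their integer voxel coordinates.
    """
    pos = np.asarray(pos, dtype=np.float64)
    center = pos.mean(axis=0, keepdims=True)
    coords = np.floor((pos - center) / voxel_size).astype(np.int64)
    # Voxel id of every point
    _, voxel_id = np.unique(coords, axis=0, return_inverse=True)
    voxel_id = voxel_id.reshape(-1)
    # For each voxel, the largest point index falling into it
    last = np.full(voxel_id.max() + 1, -1, dtype=np.int64)
    np.maximum.at(last, voxel_id, np.arange(len(pos)))
    kept = np.sort(last)
    return center, kept, coords[kept]
```

`np.maximum.at` is used rather than plain fancy assignment because NumPy does not guarantee which value wins when the same index is assigned several times.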

In [88]:
# Run inference with augmentations
# mm_data = dataset.train_dataset[0]
mm_data = dataset.val_dataset[0]
# mm_data = dataset.test_dataset[0][0]

# Create a MMBatch and run inference
batch = MMBatch.from_mm_data_list([mm_data])

with torch.no_grad():
    print("input batch: ", batch)
    model.set_input(batch, model.device)
    model(batch)
    
# Recover the predicted labels for visualization
mm_data.data.pred = model.output.detach().cpu().argmax(1)

input batch:  MMBatch(
    data = Batch(batch=[97387], coords=[97387, 3], grid_size=[1], id_scan=[1], mapping_index=[97387], mvfusion_input=[70411, 9, 10], origin_id=[97387], pos=[97387, 3], ptr=[2], x=[97387, 3], y=[97387])
    image = ImageBatch(num_settings=1, num_views=100, num_points=97387, device=cpu)
)


In [89]:
CLASS_COLORS[0] = (174.0, 199.0, 232.0)
CLASS_COLORS[-1] = (0, 0, 0)
# CLASS_COLORS

In [90]:
print(mm_data.data)
mm_data.modalities['image'] = None
mm_data

Data(coords=[97387, 3], grid_size=[1], id_scan=[1], mapping_index=[97387], mvfusion_input=[70411, 9, 10], origin_id=[97387], pos=[97387, 3], pred=[97387], x=[97387, 3], y=[97387])


MMData(
    data = Data(coords=[97387, 3], grid_size=[1], id_scan=[1], mapping_index=[97387], mvfusion_input=[70411, 9, 10], origin_id=[97387], pos=[97387, 3], pred=[97387], x=[97387, 3], y=[97387])
    image = None
)

In [91]:
fig = visualize_mm_data(mm_data, no_output=True, figsize=1000, pointsize=3, voxel=0.05, show_2d=False, back='x', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

In [None]:
### Select seen points
csr_idx = mm_data.modalities['image'][0].view_csr_indexing
dense_idx_list = torch.arange(mm_data.modalities['image'][0].num_points).repeat_interleave(csr_idx[1:] - csr_idx[:-1])
# take subset of only seen points without re-indexing the same point
mm_data = mm_data[dense_idx_list.unique()]
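The cell above recovers the points seen by at least one image by expanding the CSR pointer array with `repeat_interleave` and then taking `unique`. Assuming `view_csr_indexing` is a standard CSR pointer array of length `num_points + 1` (which the `csr_idx[1:] - csr_idx[:-1]` difference suggests), the same set is simply the points with a non-zero view count. A numpy sketch of both routes, with hypothetical helper names:

```python
import numpy as np


def seen_point_ids(csr_idx):
    """Indices of points with at least one view, where point i's views are
    stored at positions csr_idx[i]:csr_idx[i + 1]."""
    counts = np.diff(np.asarray(csr_idx))   # number of views per point
    return np.flatnonzero(counts > 0)       # sorted, unique point indices


def seen_point_ids_via_repeat(csr_idx):
    """Same result, mirroring the repeat_interleave + unique route."""
    counts = np.diff(np.asarray(csr_idx))
    dense = np.repeat(np.arange(len(counts)), counts)  # point id of each view
    return np.unique(dense)
```

In torch, the direct route would read `torch.where(csr_idx[1:] > csr_idx[:-1])[0]`, which avoids building an intermediate tensor with one entry per view.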


In [None]:
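def get_mode_pred(data):
    """Majority vote (mode) of the per-view 2D predictions of each seen point.

    Expects `data.data.mvfusion_input` of shape (num_seen_points, n_views, C),
    where channel 0 flags valid views and the last channel holds the
    predicted 2D label for that view.
    """
    pixel_validity = data.data.mvfusion_input[:, :, 0].bool()
    mv_preds = data.data.mvfusion_input[:, :, -1].long()

    # Keep, for each point, only the predictions from valid views
    valid_m2f_feats = []
    for i in range(len(mv_preds)):
        valid_m2f_feats.append(mv_preds[i][pixel_validity[i]])

    # Most frequent label across the valid views of each point
    mode_preds = []
    for m2feats_of_seen_point in valid_m2f_feats:
        mode_preds.append(torch.mode(m2feats_of_seen_point.squeeze(), dim=0)[0])
    mode_preds = torch.stack(mode_preds, dim=0)

    return mode_preds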

mode_preds = get_mode_pred(mm_data)
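`get_mode_pred` loops over points in Python, which can be slow for tens of thousands of seen points. A vectorized alternative counts one vote per valid (point, view) pair. This numpy sketch assumes the valid entries of the last channel are integer labels in `[0, num_classes)`; `mode_over_valid_views` is hypothetical, ties are broken toward the smallest label (which may not match `torch.mode`), and points without any valid view get `-1`, where the loop above would call `torch.mode` on an empty tensor.

```python
import numpy as np


def mode_over_valid_views(preds, valid, num_classes):
    """Per-point majority vote of 2D predictions over valid views.

    preds: (N, V) int array of per-view labels in [0, num_classes)
    valid: (N, V) bool array, True where the view actually sees the point
    Returns (N,) labels; ties go to the smallest label and points with no
    valid view get -1.
    """
    preds = np.asarray(preds)
    valid = np.asarray(valid, dtype=bool)
    n = preds.shape[0]
    counts = np.zeros((n, num_classes), dtype=np.int64)
    rows = np.broadcast_to(np.arange(n)[:, None], preds.shape)
    # Accumulate one vote per valid (point, view) pair
    np.add.at(counts, (rows[valid], preds[valid]), 1)
    labels = counts.argmax(axis=1)
    labels[counts.sum(axis=1) == 0] = -1
    return labels
```

With `mv_preds` and `pixel_validity` built as in `get_mode_pred`, a call would look like `mode_over_valid_views(mv_preds.numpy(), pixel_validity.numpy(), num_classes=...)`, with `num_classes` set to the number of labels the 2D model predicts.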

In [None]:
print(mode_preds.shape)
print(mm_data.data.pred.shape)

In [None]:
m2f_mm_data = mm_data.clone()
m2f_mm_data.data.x = None
m2f_mm_data.data.pred = mode_preds
# m2f_mm_data.data.pred = m2f_mm_data.data.pred[m2f_mm_data.data.y != -1]
m2f_mm_data = m2f_mm_data[m2f_mm_data.data.y != -1]

visualize_mm_data(m2f_mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=False, back='m2f_mask_pred', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

In [None]:
# Drop the features and keep only labeled points (y == -1 marks ignored points).
# Indexing the MMData also filters `pred`, as in the Mask2Former cell above.
mm_data.data.x = None
mm_data = mm_data[mm_data.data.y != -1]


print(mm_data.data.pred.unique())
mm_data.data.y.unique()

In [None]:
visualize_mm_data(mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=False, back='m2f_pred_mask', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

### Swin_l_early masks

In [None]:
m2f_mm_data = mm_data.clone()
m2f_mm_data.data.x = None
m2f_mm_data.data.pred = mode_preds
# m2f_mm_data.data.pred = m2f_mm_data.data.pred[m2f_mm_data.data.y != -1]
m2f_mm_data = m2f_mm_data[m2f_mm_data.data.y != -1]

visualize_mm_data(m2f_mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=False, back='m2f_mask_pred', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

In [None]:
visualize_mm_data(mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=False, back='m2f_pred_mask', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

In [None]:
### Calculate the fraction of correct predictions (accuracy) on labeled points

# Model predictions
print((mm_data.data.y == mm_data.data.pred).float().mean())

# Multi-view mode of the Mask2Former predictions
print((m2f_mm_data.data.y == m2f_mm_data.data.pred).float().mean())
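Accuracy alone can hide poor performance on rare classes, so semantic segmentation is usually also reported with mean IoU (the metric quoted in the checkpoint comments), where each class gets IoU = TP / (TP + FP + FN). A standalone numpy sketch built on a confusion matrix, assuming labels in `[0, num_classes)` and `-1` for ignored points as in the filtering above. `confusion_and_miou` is hypothetical and may not match `SegmentationTracker`'s conventions; for instance, here classes absent from both labels and predictions are left out of the mean.

```python
import numpy as np


def confusion_and_miou(y_true, y_pred, num_classes, ignore_label=-1):
    """Confusion matrix, per-class IoU, mean IoU and accuracy."""
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    keep = y_true != ignore_label
    y_true, y_pred = y_true[keep], y_pred[keep]
    for name, arr in (("y_true", y_true), ("y_pred", y_pred)):
        if arr.size and (arr.min() < 0 or arr.max() >= num_classes):
            raise ValueError(f"{name} has labels outside [0, {num_classes})")
    # conf[i, j] = number of points of true class i predicted as class j
    conf = np.bincount(num_classes * y_true + y_pred,
                       minlength=num_classes ** 2).reshape(num_classes, num_classes)
    tp = np.diag(conf)
    union = conf.sum(axis=0) + conf.sum(axis=1) - tp  # TP + FP + FN
    present = union > 0
    iou = np.full(num_classes, np.nan)
    iou[present] = tp[present] / union[present]
    accuracy = tp.sum() / max(conf.sum(), 1)
    return conf, iou, np.nanmean(iou), accuracy
```

A call would look like `confusion_and_miou(mm_data.data.y.numpy(), mm_data.data.pred.numpy(), num_classes=...)`, with `num_classes` set to your number of training classes.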


In [None]:
i_room = 0

# Pick a room in the Train set
mm_data = sample_real_data(dataset.train_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Pick a room in the Val set
# mm_data = sample_real_data(dataset.val_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Pick a room in the Test set
# mm_data = sample_real_data(dataset.test_dataset[0], idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Extract point cloud and images from MMData object
data = mm_data.data.clone()
images = mm_data.modalities['image'].clone()

data

In [None]:
mm_data = dataset.train_dataset[0]

In [None]:
### Select seen points
csr_idx = mm_data.modalities['image'][0].view_csr_indexing
dense_idx_list = torch.arange(mm_data.modalities['image'][0].num_points).repeat_interleave(csr_idx[1:] - csr_idx[:-1])
# take subset of only seen points without re-indexing the same point
mm_data = mm_data[dense_idx_list.unique()]

In [None]:
mode_preds = get_mode_pred(mm_data)
m2f_mm_data = mm_data.clone()
m2f_mm_data.data.x = None
m2f_mm_data.data.pred = mode_preds
# m2f_mm_data.data.pred = m2f_mm_data.data.pred[m2f_mm_data.data.y != -1]
m2f_mm_data = m2f_mm_data[m2f_mm_data.data.y != -1]

visualize_mm_data(m2f_mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=False, back='m2f_mask_pred', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)
