# ScanNet

This notebook lets you instantiate the **[ScanNet](http://www.scan-net.org/)** dataset from scratch and visualize **3D+2D room samples**.

Note that you will need **at least 1.2T** of free disk space for the ScanNet raw dataset and **at least 64G** for the processed files at **5cm voxel resolution** and **320x240 image resolution**.
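
Before launching a long download, it can help to check the free space on the target disk. The snippet below is a minimal sketch using only the standard library. The helper name `has_enough_space` is hypothetical, and reading "1.2T" and "64G" as binary units is an assumption.

```python
import shutil

# Sizes quoted in this notebook, in bytes.
# Assumption: "1.2T" and "64G" are read as binary units (TiB / GiB).
RAW_BYTES = int(1.2 * 1024**4)
PROCESSED_BYTES = 64 * 1024**3


def has_enough_space(path, required_bytes):
    """Return True if the filesystem holding `path` has at least `required_bytes` free."""
    return shutil.disk_usage(path).free >= required_bytes


path = "."  # replace with your dataset root directory
free_gib = shutil.disk_usage(path).free / 1024**3
print(f"Free space at {path!r}: {free_gib:.1f} GiB")
print("Enough for raw + processed:", has_enough_space(path, RAW_BYTES + PROCESSED_BYTES))
```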

The ScanNet dataset is composed of **rooms**, each captured as a video acquisition of an indoor scene. These video streams were used to produce a point cloud and images.

Each room is small enough to be loaded at once into **64G** of RAM. The `ScannetDatasetMM` class from `torch_points3d.datasets.segmentation.multimodal.scannet` loads the room along with part of the images of the associated video stream.

In [None]:
# Select your GPU
I_GPU = 0

In [None]:
# Uncomment to use autoreload
# %load_ext autoreload
# %autoreload 2

import os
import os.path as osp
import sys
import torch
import numpy as np
from time import time
from omegaconf import OmegaConf
start = time()
import warnings
warnings.filterwarnings('ignore')

torch.cuda.set_device(I_GPU)
DIR = os.path.dirname(os.getcwd())
ROOT = os.path.join(DIR, "..")
sys.path.insert(0, ROOT)
sys.path.insert(0, DIR)

from torch_points3d.utils.config import hydra_read
from torch_geometric.data import Data
from torch_points3d.core.multimodal.data import MMData, MMBatch
from torch_points3d.visualization.multimodal_data import visualize_mm_data
from torch_points3d.core.multimodal.image import SameSettingImageData, ImageData
from torch_points3d.datasets.segmentation.multimodal.scannet import ScannetDatasetMM
from torch_points3d.datasets.segmentation.scannet import CLASS_COLORS, CLASS_NAMES

If `visualize_mm_data` does not throw any error but the visualization does not appear, you may need to change your plotly renderer below.

In [None]:
import plotly.io as pio

pio.renderers.default = 'jupyterlab'        # for local notebook
# pio.renderers.default = 'iframe_connected'  # for remote notebook. Other working (but seemingly slower) options are: 'sphinx_gallery' and 'iframe'

## Dataset creation

The following will instantiate the dataset. If the data is not found at `DATA_ROOT`, the folder structure will be created there and the raw dataset will be downloaded there. 

**Memory-friendly tip**: if you have already downloaded the dataset once and simply want to instantiate a new dataset with different preprocessing (*e.g.* a different 3D or 2D resolution, mapping parameterization, etc.), I recommend that you manually replicate the folder hierarchy of your existing dataset and create a symlink to its `raw/` directory. This avoids downloading and storing (very) large files twice.
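
The symlink step of this tip can be sketched as follows. This is an illustration only: `link_raw_dir` is a hypothetical helper, not part of `torch_points3d`. It assumes you pass the two directories that directly contain `raw/`, so the exact hierarchy above them is left to you.

```python
import os


def link_raw_dir(existing_dataset_dir, new_dataset_dir):
    """Create `new_dataset_dir` and symlink its `raw/` to the existing one."""
    src = os.path.abspath(os.path.join(existing_dataset_dir, "raw"))
    dst = os.path.join(new_dataset_dir, "raw")
    if not os.path.isdir(src):
        raise FileNotFoundError(f"No raw/ directory found at {src}")
    os.makedirs(new_dataset_dir, exist_ok=True)
    # Do nothing if the link (or a real directory) is already in place
    if not os.path.lexists(dst):
        os.symlink(src, dst, target_is_directory=True)
    return dst
```

The function is safe to call again: if `raw/` already exists in the new location, it is left untouched.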

You will find the config file ruling the dataset creation at `conf/data/segmentation/multimodal/scannet-sparse.yaml`. You may edit this file or create new configs inheriting from this one using Hydra and create the associated dataset by modifying `dataset_config` accordingly in the following cell.
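
Besides editing the YAML file, individual values can in principle be changed with extra `key=value` override strings. The sketch below only builds such a list. `build_overrides` is a hypothetical helper, and the `data.resolution_3d` value is purely illustrative. Whether a given key can be overridden this way depends on the config, so treat this as an assumption to check against your setup.

```python
def build_overrides(data_root, dataset_config, models_config, model_name, extra=None):
    """Build the list of `key=value` override strings used in this notebook."""
    overrides = [
        "task=segmentation",
        f"data={dataset_config}",
        f"models={models_config}",
        f"model_name={model_name}",
        f"data.dataroot={data_root}",
    ]
    # `extra` maps dotted config keys to values, appended as further overrides
    for key, value in (extra or {}).items():
        overrides.append(f"{key}={value}")
    return overrides


overrides = build_overrides(
    "/path/to/your/dataset/root/directory",
    "segmentation/multimodal/scannet-sparse",
    "segmentation/multimodal/sparseconv3d",
    "Res16UNet34-L4-early",
    extra={"data.resolution_3d": 0.02},  # illustrative value, not a recommendation
)
print(overrides)
```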

In [None]:
# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/path/to/your/dataset/root/directory'

dataset_config = 'segmentation/multimodal/scannet-sparse'   
models_config = 'segmentation/multimodal/sparseconv3d'    # model family
model_name = 'Res16UNet34-L4-early'                       # specific model

overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

cfg = hydra_read(overrides)
# print(OmegaConf.to_yaml(cfg))

The dataset will now be created based on the parsed configuration. I recommend having **at least 1.2T** of free disk space for the ScanNet raw dataset and **at least 64G** for the processed files at **5cm voxel resolution** and **320x240 image resolution**.

As long as you do not change core dataset parameters, preprocessing should only be performed once for your dataset. It may take some time, **mostly depending on the 3D and 2D resolutions** you choose to work with (the larger the slower).

In [None]:
# Dataset instantiation
start = time()
dataset = ScannetDatasetMM(cfg.data)
# print(dataset)
print(f"Time = {time() - start:0.1f} sec.")

To visualize the multimodal samples produced by the dataset, we need to remove some of the dataset transforms that affect points, images and mappings. The `sample_real_data` function will be used to get samples without breaking mappings consistency for visualization.

At training and evaluation time, these transforms are used for data augmentation, dynamic size batching (see our [paper](https://arxiv.org/submit/4264152)), etc.
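
The idea of temporarily removing transforms and restoring them afterwards can be sketched generically with a context manager. This is not the code used by `sample_real_data`. `temporary_attribute` and `ToyDataset` are hypothetical names introduced for illustration only. The `try`/`finally` ensures the original value is restored even if sampling raises an error.

```python
from contextlib import contextmanager


@contextmanager
def temporary_attribute(obj, name, value):
    """Set `obj.name` to `value` inside the `with` block, then restore it."""
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        # Runs even if the body raises, so the original is always restored
        setattr(obj, name, original)


class ToyDataset:
    """Hypothetical stand-in for a dataset exposing a `transform` attribute."""

    def __init__(self):
        self.transform = "augment + voxelize"

    def __getitem__(self, idx):
        return f"sample {idx} with transform: {self.transform}"


toy = ToyDataset()
with temporary_attribute(toy, "transform", "voxelize only"):
    print(toy[0])  # sampled with the reduced transform
print(toy[0])      # original transform is back in place
```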

In [None]:
from torch_geometric.transforms import *
from torch_points3d.core.data_transform import *
from torch_points3d.core.data_transform.multimodal.image import *
from torch_points3d.datasets.base_dataset import BaseDataset
from torch_points3d.datasets.base_dataset_multimodal import BaseDatasetMM

# Transforms on 3D points that we need to exclude for visualization purposes
augmentations_3d = [
    ElasticDistortion, Random3AxisRotation, RandomNoise, RandomRotate, 
    RandomScaleAnisotropic, RandomSymmetry, ShiftVoxels]
exclude_3d_viz = augmentations_3d + [AddFeatsByKeys, Center, GridSampling3D]

# Transforms on 2D images and mappings that we need to exclude for visualization
# purposes
augmentations_2d = [JitterMappingFeatures, ColorJitter, RandomHorizontalFlip]
exclude_2d_viz = augmentations_2d + [ToFloatImage, Normalize]


def sample_real_data(tg_dataset, idx=0, exclude_3d=None, exclude_2d=None):
    """
    Temporarily remove the 3D and 2D transforms affecting the point 
    positions and images from the dataset to better visualize points 
    and images relative positions.
    """    
    # Remove some 3D transforms
    transform_3d = tg_dataset.transform
    if exclude_3d:
        tg_dataset.transform = BaseDataset.remove_transform(transform_3d, exclude_3d)

    # Remove some 2D transforms, if any
    is_multimodal = hasattr(tg_dataset, 'transform_image')
    if is_multimodal and exclude_2d:
        transform_2d = tg_dataset.transform_image
        tg_dataset.transform_image = BaseDatasetMM.remove_multimodal_transform(transform_2d, exclude_2d)
    
    # Get a sample from the dataset, with transforms excluded
    out = tg_dataset[idx]
    
    # Restore transforms
    tg_dataset.transform = transform_3d
    if is_multimodal and exclude_2d:
        tg_dataset.transform_image = transform_2d
        
    return out

## Visualize a single multimodal sample

We can now pick samples from the train, val and test datasets.

Please refer to `torch_points3d/visualization/multimodal_data` for more details on visualization options.

In [None]:
i_room = 20

# Pick a room in the Train set
mm_data = sample_real_data(dataset.train_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Pick a room in the Val set
# mm_data = sample_real_data(dataset.val_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Pick a room in the Test set
# mm_data = sample_real_data(dataset.test_dataset[0], idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

visualize_mm_data(mm_data, class_names=CLASS_NAMES, class_colors=CLASS_COLORS, front='y', figsize=1000, pointsize=3, voxel=0.05, show_2d=True, alpha=0.3)

## Run inference from pretrained weights and visualize predictions
It is possible to visualize the pointwise predictions and errors from a model. 

To do so, we will use the pretrained weights made available with this project. See `README.md` to get the download links and manually place the `.pt` files locally. You will need to provide `checkpoint_dir` where you saved those files in the next cell.
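
Since the checkpoint files are placed manually, a quick existence check can give a clearer error than a failed load. The sketch below assumes the `<checkpoint_dir>/<model_name>.pt` naming used in the next cell. `checkpoint_path` is a hypothetical helper name.

```python
import os


def checkpoint_path(checkpoint_dir, model_name):
    """Return the expected `.pt` path, or raise if the file is missing."""
    path = os.path.join(checkpoint_dir, f"{model_name}.pt")
    if not os.path.isfile(path):
        raise FileNotFoundError(
            f"Checkpoint not found: {path}. "
            "See README.md for the download links."
        )
    return path
```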

In [None]:
from torch_points3d.models.model_factory import instantiate_model

# Set your parameters
checkpoint_dir = '/directory/containing/your/checkpoint/file'

# Create the model
print(f"Creating model: {cfg.model_name}")
model = instantiate_model(cfg, dataset)
# print(model)

# Load the checkpoint and recover the 'latest' model weights
checkpoint = torch.load(f'{checkpoint_dir}/{model_name}.pt')
model.load_state_dict_with_same_shape(checkpoint['models']['latest'], strict=False)

# Prepare the model for inference
model = model.eval().cuda()
print('Model loaded')

Now that we have loaded the model, we need to run a forward pass on a sample. However, to be able to visualize the predictions, we need to pay special attention to which 3D and 2D transforms we apply to the data so that we do not break the mappings. To do so, we will manually apply some sensitive transforms, so that we can both run inference on the data and visualize it.
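
The reason for keeping track of the centering offset can be seen on a toy example. The pure-Python sketch below is not the `Center` or `GridSampling3D` implementation. `center_points` and `voxelize` are hypothetical helpers, and keeping the last point seen in each voxel is an assumption made for simplicity. Two nearby points fall into the same voxel and are merged, and adding the stored offset back returns the surviving points to their original coordinates.

```python
import math


def center_points(points):
    """Subtract the mean position; return the centered points and the offset."""
    n = len(points)
    center = tuple(sum(p[i] for p in points) / n for i in range(3))
    centered = [tuple(p[i] - center[i] for i in range(3)) for p in points]
    return centered, center


def voxelize(points, voxel):
    """Toy grid sampling: keep one point per voxel (here, the last one seen)."""
    cells = {}
    for p in points:
        key = tuple(math.floor(c / voxel) for c in p)
        cells[key] = p
    return list(cells.values())


points = [(0.00, 0.0, 0.0), (0.01, 0.0, 0.0), (1.00, 0.4, 0.2)]
centered, center = center_points(points)
kept = voxelize(centered, voxel=0.05)

# Undo the centering so the surviving points line up with the images again
restored = [tuple(p[i] + center[i] for i in range(3)) for p in kept]
print(len(points), "points in,", len(kept), "points out")
```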

In [None]:
i_room = 0

# Pick a room in the Train set
mm_data = sample_real_data(dataset.train_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Pick a room in the Val set
# mm_data = sample_real_data(dataset.val_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Pick a room in the Test set
# mm_data = sample_real_data(dataset.test_dataset[0], idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Extract point cloud and images from MMData object
data = mm_data.data.clone()
images = mm_data.modalities['image'].clone()

# For voxel-based 3D backbones such as SparseConv3d and MinkowskiNet, points need to be 
# preprocessed with Center and GridSampling3D. Unfortunately, Center breaks relative 
# positions between points and images. Besides, the combination of Center and GridSampling3D
# may lead to some points being merged into the same voxels, so we must apply it to both the
# inference and visualization data to make sure we have the same voxels. The workaround here 
# is to manually run these while keeping track of the centering offset
center = data.pos.mean(dim=-2, keepdim=True)
data = AddFeatsByKeys(list_add_to_x=[True], feat_names=['pos_z'], delete_feats=[True])(data)          # add z-height to the features
data = Center()(data)                                                                                 # mean-center the data
data = GridSampling3D(cfg.data.resolution_3d, quantize_coords=True, mode='last')(data)                # quantization for volumetric models

# This last voxelization step with GridSampling3D might have removed some points, so we need
# to update the mappings using SelectMappingFromPointId. To control the size of the batch, we
# use PickImagesFromMemoryCredit. Besides, 2D models expect normalized float images, which is
# why we call ToFloatImage and Normalize
data, images = SelectMappingFromPointId()(data, images)                                               # update mappings after GridSampling3D
data, images = PickImagesFromMemoryCredit(
    img_size=cfg.data.resolution_2d, 
    k_coverage=cfg.data.multimodal.settings.k_coverage, 
    n_img=cfg.data.multimodal.settings.test_pixel_credit)(data, images)                                      # select images to respect memory constraints
data, images_infer = ToFloatImage()(data, images.clone())                                             # convert uint8 images to float
data, images_infer = Normalize()(data, images_infer)                                                  # RGB normalization

# Create a MMData for inference
mm_data_infer = MMData(data, image=images_infer)

# Create a MMBatch and run inference
batch = MMBatch.from_mm_data_list([mm_data_infer])
model.set_input(batch, model.device)
model(batch)

# Create a MMData for visualization
data.pos += center
mm_data = MMData(data, image=images)

# Recover the predicted labels for visualization
mm_data.data.pred = model.output.detach().cpu().argmax(1)

In [None]:
visualize_mm_data(mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=True, front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)