###### ScanNet

This notebook lets you instantiate the **[ScanNet](http://www.scan-net.org/)** dataset from scratch and visualize **3D+2D room samples**.

Note that you will need **at least 1.2T** available for the SanNet raw dataset and **at least 64G** for the processed files at **5cm voxel resolution** and **320x240 image resolution**. 

The ScanNet dataset is composed of **rooms** of video acquisitions of indoor scenes. Thes video streams were used to produce a point cloud and images.

Each room is small enough to be loaded at once into a **64G RAM** memory. The `ScannetDatasetMM` class from `torch_points3d.datasets.segmentation.multimodal.scannet` deals with loading the room and part of the images of the associated video stream.

In [1]:
# Select you GPU
I_GPU = 0


In [2]:
# Uncomment to use autoreload
%load_ext autoreload
%autoreload 2

import os
import os.path as osp
import sys
import torch
import numpy as np
from time import time
from omegaconf import OmegaConf
start = time()
import warnings
warnings.filterwarnings('ignore')

# torch.cuda.set_device(I_GPU)
DIR = os.path.dirname(os.getcwd())
ROOT = os.path.join(DIR, "..")
sys.path.insert(0, ROOT)
sys.path.insert(0, DIR)

from torch_points3d.utils.config import hydra_read
from torch_geometric.data import Data
from torch_points3d.core.multimodal.data import MMData, MMBatch
from torch_points3d.visualization.multimodal_data import visualize_mm_data
from torch_points3d.core.multimodal.image import SameSettingImageData, ImageData
from torch_points3d.datasets.segmentation.multimodal.scannet import ScannetDatasetMM
from torch_points3d.datasets.segmentation.scannet import CLASS_COLORS, CLASS_NAMES, CLASS_LABELS
from torch_points3d.metrics.segmentation_tracker import SegmentationTracker
from torch_points3d.datasets.segmentation import IGNORE_LABEL

from pykeops.torch import LazyTensor

import matplotlib.pyplot as plt 

%matplotlib inline

MMData debug() function changed, please uncomment the 3rd assert line when doing inference without M2F features!


In [3]:
CLASS_COLORS[0] = (174.0, 199.0, 232.0)
CLASS_COLORS[-1] = (0, 0, 0)

If `visualize_mm_data` does not throw any error but the visualization does not appear, you may need to change your plotly renderer below.

In [4]:
import plotly.io as pio

#pio.renderers.default = 'jupyterlab'        # for local notebook
pio.renderers.default = 'iframe_connected'  # for remote notebook. Other working (but seemingly slower) options are: 'sphinx_gallery' and 'iframe'

In [5]:
# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/scratch-shared/fsun/dvata'

dataset_config = 'segmentation/multimodal/Feng/scannet-neucon-smallres-m2f'   
models_config = 'segmentation/multimodal/Feng/mvfusion'    # model family
model_name = 'MVFusion_3D_small_6views'                       # specific model

overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

cfg = hydra_read(overrides)
OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
cfg.data.load_m2f_masks = True   # load Mask2Former predicted masks
cfg.data.m2f_preds_dirname = 'm2f_masks'
cfg.data.n_views = cfg.models[model_name].backbone.transformer.n_views
print(cfg.data.n_views)

# Dataset instantiation
start = time()
dataset = ScannetDatasetMM(cfg.data)
# print(dataset)|
print(f"Time = {time() - start:0.1f} sec.")

6
Load predicted 2D semantic segmentation labels from directory  m2f_masks
initialize train dataset
initialize val dataset
Time = 7.5 sec.


In [6]:
from torch_points3d.models.model_factory import instantiate_model

# ViT_masks 3rd run
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/ViT_masks_3rd_run' # 3rd run
checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/MVFusion_3D_6_views_m2f_masks'

# Create the model
print(f"Creating model: {cfg.model_name}")
model = instantiate_model(cfg, dataset)
# print(model)

# Load the checkpoint and recover the 'best_miou' model weights
checkpoint = torch.load(f'{checkpoint_dir}/{model_name}.pt', map_location='cpu')
model.load_state_dict_with_same_shape(checkpoint['models']['latest'], strict=False)

# Prepare the model for training
model = model.cuda()
print('Model loaded')

Creating model: MVFusion_3D_small_6views
task:  segmentation.multimodal
tested_model_name:  MVFusion_3D_small_6views
class_name:  MVFusionAPIModel
model_module:  torch_points3d.models.segmentation.multimodal.Feng.mvfusion_3d
name, cls of chosen model_cls:  MVFusionAPIModel <class 'torch_points3d.models.segmentation.multimodal.Feng.mvfusion_3d.MVFusionAPIModel'>
x feature dim:  {'FEAT': 3}
nc_in:  67
nc_in:  64
nc_in:  32
nc_in:  64
nc_in:  128
nc_in:  256
nc_in:  128
nc_in:  128
nc_in:  96
nc_in:  96
Model loaded


In [7]:
def get_seen_points(mm_data):
    ### Select seen points
    csr_idx = mm_data.modalities['image'][0].view_csr_indexing
    dense_idx_list = torch.arange(mm_data.modalities['image'][0].num_points).repeat_interleave(csr_idx[1:] - csr_idx[:-1])
    # take subset of only seen points without re-indexing the same point
    mm_data = mm_data[dense_idx_list.unique()]
    return mm_data

def get_mode_pred(data):
    pixel_validity = data.data.mvfusion_input[:, :, 0].bool()
    mv_preds = data.data.mvfusion_input[:, :, -1].long()
            
    valid_m2f_feats = []
    for i in range(len(mv_preds)):
        valid_m2f_feats.append(mv_preds[i][pixel_validity[i]])

    mode_preds = []
    for m2feats_of_seen_point in valid_m2f_feats:
        mode_preds.append(torch.mode(m2feats_of_seen_point.squeeze(), dim=0)[0])
    mode_preds = torch.stack(mode_preds, dim=0)
        
    return mode_preds


In [8]:
from torch_points3d.metrics.scannet_segmentation_tracker import ScannetSegmentationTracker

In [9]:
baseline_tracker = ScannetSegmentationTracker(dataset=dataset, stage='train', wandb_log=False, use_tensorboard=False, ignore_label=IGNORE_LABEL)
mvfusion_tracker = ScannetSegmentationTracker(dataset=dataset, stage='train', wandb_log=False, use_tensorboard=False, ignore_label=IGNORE_LABEL)


In [2]:
a = {0: '89.67', 1: '98.16', 2: '72.49', 3: '85.95', 4: '90.83', 5: '81.63', 6: '82.44', 7: '70.54', 8: '66.33', 9: '79.68', 10: '44.70', 11: '69.93', 12: '73.52', 13: '81.89', 14: '73.40', 15: '78.48', 16: '94.22', 17: '71.27', 18: '88.69', 19: '67.45'}
" & ".join(list(a.values()))

'89.67 & 98.16 & 72.49 & 85.95 & 90.83 & 81.63 & 82.44 & 70.54 & 66.33 & 79.68 & 44.70 & 69.93 & 73.52 & 81.89 & 73.40 & 78.48 & 94.22 & 71.27 & 88.69 & 67.45'

In [11]:
baseline_tracker.reset(stage='train')
mvfusion_tracker.reset(stage='train')


for mm_data in dataset.val_dataset:
    print(mm_data.id_scan)

    # Create a MMBatch and run inference
    batch = MMBatch.from_mm_data_list([mm_data])

    model.set_input(batch, model.device)
    model(batch)

    # Recover the predicted labels for visualization
    mm_data.data.pred = model.output.detach().cpu().argmax(1)
        
    mm_data = get_seen_points(mm_data)
    baseline_preds = get_mode_pred(mm_data)
    
    mvfusion_tracker.track(pred_labels=mm_data.data.pred, gt_labels=mm_data.data.y, model=None)       
    baseline_tracker.track(pred_labels=baseline_preds, gt_labels=mm_data.data.y, model=None)       

    break
    
print("3D validation IoU of seen points")
print("baseline ", baseline_tracker.get_metrics())
print(" & ".join(list(baseline_tracker._miou_per_class.values())))

print("mvfusion ", mvfusion_tracker.get_metrics())
print(" & ".join(list(mvfusion_tracker._miou_per_class.values())))


tensor([0])
3D validation IoU of seen points
baseline  {'train_acc': 89.08670027260406, 'train_macc': 83.86834171426358, 'train_miou': 58.07665523950559}
80.45 & 90.96 & 67.72 & 0.00 & 74.94 & 0.00 & 76.04 & 79.09 & 60.49 & 0.00 & 0.00 & 73.31 & 0.00 & 0.00 & 63.95 & 0.00 & 0.00 & 64.10 & 0.00 & 82.00
mvfusion  {'train_acc': 93.6874647822975, 'train_macc': 88.87996157274955, 'train_miou': 68.72089955410117}
81.79 & 99.23 & 68.40 & 0.00 & 93.64 & 0.00 & 90.17 & 88.63 & 67.45 & 0.00 & 0.00 & 76.53 & 0.00 & 0.00 & 63.96 & 0.00 & 0.00 & 70.63 & 0.00 & 92.93


In [None]:
print(mode_preds.shape)
print(mm_data.data.pred.shape)

In [None]:
m2f_mm_data = mm_data.clone()
m2f_mm_data.data.x = None
m2f_mm_data.data.pred = mode_preds
# m2f_mm_data.data.pred = m2f_mm_data.data.pred[m2f_mm_data.data.y != -1]
m2f_mm_data = m2f_mm_data[m2f_mm_data.data.y != -1]

visualize_mm_data(m2f_mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=False, back='m2f_mask_pred', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

In [None]:
mm_data.data.x = None
mm_data.data.pred = mm_data.data.pred[mm_data.data.y != -1]
mm_data = mm_data[mm_data.data.y != -1]


print(mm_data.data.pred.unique())
mm_data.data.y.unique()

In [None]:
visualize_mm_data(mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=False, back='m2f_pred_mask', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)