# ScanNet

This notebook lets you instantiate the **[ScanNet](http://www.scan-net.org/)** dataset from scratch and visualize **3D+2D room samples**.

Note that you will need **at least 1.2T** available for the ScanNet raw dataset and **at least 64G** for the processed files at **5cm voxel resolution** and **320x240 image resolution**.

The ScanNet dataset is composed of **rooms**, each captured as a video acquisition of an indoor scene. These video streams were used to produce a point cloud and images for each room.

Each room is small enough to be loaded at once into **64G** of RAM. The `ScannetDatasetMM` class from `torch_points3d.datasets.segmentation.multimodal.scannet` handles loading the room along with part of the images of the associated video stream.

In [1]:
# Select your GPU
I_GPU = 0

In [2]:
# Uncomment to use autoreload
# %load_ext autoreload
# %autoreload 2

import os
import os.path as osp
import sys
import torch
import numpy as np
from time import time
from omegaconf import OmegaConf
start = time()
import warnings
warnings.filterwarnings('ignore')

# torch.cuda.set_device(I_GPU)
DIR = os.path.dirname(os.getcwd())
ROOT = os.path.join(DIR, "..")
sys.path.insert(0, ROOT)
sys.path.insert(0, DIR)

from torch_points3d.utils.config import hydra_read
from torch_geometric.data import Data
from torch_points3d.core.multimodal.data import MMData, MMBatch
from torch_points3d.visualization.multimodal_data import visualize_mm_data
from torch_points3d.core.multimodal.image import SameSettingImageData, ImageData
from torch_points3d.datasets.segmentation.multimodal.scannet import ScannetDatasetMM
from torch_points3d.datasets.segmentation.scannet import CLASS_COLORS, CLASS_NAMES



If `visualize_mm_data` does not throw any error but the visualization does not appear, you may need to change your plotly renderer below.

In [3]:
import plotly.io as pio

#pio.renderers.default = 'jupyterlab'        # for local notebook
pio.renderers.default = 'iframe_connected'  # for remote notebook. Other working (but seemingly slower) options are: 'sphinx_gallery' and 'iframe'

## Dataset creation

The following will instantiate the dataset. If the data is not found at `DATA_ROOT`, the folder structure will be created and the raw dataset will be downloaded there.

**Memory-friendly tip**: if you have already downloaded the dataset once and simply want to instantiate a new dataset with different preprocessing (*e.g.* change the 3D or 2D resolution, the mapping parameterization, etc.), I recommend you manually replicate the folder hierarchy of your existing dataset and create a symlink to its `raw/` directory, to avoid downloading and storing (very) large files twice.
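A minimal sketch of that tip using only the Python standard library. It assumes the `<DATA_ROOT>/<dataset_name>/raw` layout visible in the mask paths printed by this notebook, and that the `processed/` files will simply be rebuilt for the new dataset; `reuse_raw_dir` is a hypothetical helper, not part of torch_points3d.

```python
import os
import os.path as osp


def reuse_raw_dir(existing_root, new_root):
    """Hypothetical helper: create `new_root` and make its `raw/` entry a
    symlink to the `raw/` directory of an already-downloaded dataset.
    Assumes the `<dataset_dir>/raw` layout; `processed/` is not linked so
    that it can be rebuilt with the new preprocessing parameters."""
    src = osp.join(osp.abspath(existing_root), "raw")
    dst = osp.join(new_root, "raw")
    if not osp.isdir(src):
        raise FileNotFoundError(f"No raw/ directory found at {src}")
    os.makedirs(new_root, exist_ok=True)
    # Leave any existing raw/ entry (link or directory) untouched
    if not osp.lexists(dst):
        os.symlink(src, dst, target_is_directory=True)
    return dst

# Hypothetical usage:
# reuse_raw_dir(osp.join(DATA_ROOT, "scannet-old"), osp.join(DATA_ROOT, "scannet-new"))
```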

You will find the config file ruling the dataset creation at `conf/data/segmentation/multimodal/scannet-sparse.yaml`. You may edit this file or create new configs inheriting from this one using Hydra and create the associated dataset by modifying `dataset_config` accordingly in the following cell.

In [21]:
# Inspect a processed split file (the load must run before `a` is indexed)
a = torch.load("/project/fsun/dvata/scannet-neucon-smallres/processed/val.pt")
a[1]['rgb']


tensor([       0,   237360,   476621,   693707,   830476,   995832,  1169224,
         1334944,  1512844,  1806655,  2091456,  2375699,  2562556,  2761027,
         2960801,  3172207,  3371241,  3568056,  3645022,  3875694,  4070946,
         4163753,  4254134,  4305515,  4485470,  4639399,  4756244,  4868077,
         4966423,  5090995,  5215205,  5339296,  5506168,  5651818,  5822503,
         5959315,  6174210,  6390464,  6442768,  6493979,  6544307,  6721398,
         6893619,  7034218,  7106225,  7233104,  7349445,  7409883,  7469213,
         7533596,  7695992,  7743701,  7793254,  7887569,  7975642,  8068467,
         8157657,  8397011,  8625085,  8798837,  8943000,  9044783,  9138451,
         9284483,  9490239,  9696449,  9934841, 10239237, 10508511, 10855833,
        11263219, 11389604, 11575432, 11746086, 11980917, 12216694, 12636231,
        13074796, 13427809, 13628871, 13959902, 14186667, 14295184, 14406801,
        14498870, 14679798, 14758982, 14841610, 14923861, 150008

In [4]:
# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/project/fsun/dvata'

dataset_config = 'segmentation/multimodal/Feng/scannet-neucon-smallres-partial'   
models_config = 'segmentation/multimodal/sparseconv3d'    # model family
model_name = 'Res16UNet34-L4-early'                       # specific model

overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

cfg = hydra_read(overrides)


In [5]:
# print(OmegaConf.to_yaml(cfg))

The dataset will now be created based on the parsed configuration. As mentioned above, make sure you have **at least 1.2T** available for the ScanNet raw dataset and **at least 64G** for the processed files at **5cm voxel resolution** and **320x240 image resolution**.

As long as you do not change core dataset parameters, preprocessing should only be performed once for your dataset. It may take some time, **mostly depending on the 3D and 2D resolutions** you choose to work with (the larger the slower).
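Before launching a long preprocessing run, the free space can be checked with `shutil.disk_usage`. This is a small sketch: reading "1.2T" and "64G" as decimal bytes is an assumption, and `has_enough_space` is a hypothetical helper name.

```python
import shutil

# Space figures quoted above, interpreted as decimal units (assumption)
RAW_BYTES = int(1.2e12)
PROCESSED_BYTES = int(64e9)


def has_enough_space(path, required_bytes):
    """Return True if the filesystem holding `path` has at least
    `required_bytes` bytes free. `path` must already exist."""
    return shutil.disk_usage(path).free >= required_bytes

# Hypothetical usage, once DATA_ROOT exists:
# has_enough_space(DATA_ROOT, RAW_BYTES + PROCESSED_BYTES)
```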

In [6]:
# Dataset instantiation
start = time()
dataset = ScannetDatasetMM(cfg.data)
# print(dataset)
print(f"Time = {time() - start:0.1f} sec.")

initialize train dataset
initialize val dataset
Time = 0.7 sec.


To visualize the multimodal samples produced by the dataset, we need to remove some of the dataset transforms that affect points, images and mappings. The `sample_real_data` function fetches samples with these transforms disabled, so that the mappings remain consistent for visualization.

At training and evaluation time, these transforms are used for data augmentation, dynamic-size batching (see our [paper](https://arxiv.org/submit/4264152)), etc.

In [7]:
from torch_geometric.transforms import *
from torch_points3d.core.data_transform import *
from torch_points3d.core.data_transform.multimodal.image import *
from torch_points3d.datasets.base_dataset import BaseDataset
from torch_points3d.datasets.base_dataset_multimodal import BaseDatasetMM

# Transforms on 3D points that we need to exclude for visualization purposes
augmentations_3d = [
    ElasticDistortion, Random3AxisRotation, RandomNoise, RandomRotate, 
    RandomScaleAnisotropic, RandomSymmetry, ShiftVoxels]
exclude_3d_viz = augmentations_3d + [AddFeatsByKeys, Center, GridSampling3D]

# Transforms on 2D images and mappings that we need to exclude for visualization
# purposes
augmentations_2d = [JitterMappingFeatures, ColorJitter, RandomHorizontalFlip]
exclude_2d_viz = augmentations_2d + [ToFloatImage, Normalize]


def sample_real_data(tg_dataset, idx=0, exclude_3d=None, exclude_2d=None):
    """
    Temporarily remove the 3D and 2D transforms affecting the point 
    positions and images from the dataset to better visualize points 
    and images relative positions.
    """    
    # Remove some 3D transforms
    transform_3d = tg_dataset.transform
    if exclude_3d:
        tg_dataset.transform = BaseDataset.remove_transform(transform_3d, exclude_3d)

    # Remove some 2D transforms, if any
    is_multimodal = hasattr(tg_dataset, 'transform_image')
    if is_multimodal and exclude_2d:
        transform_2d = tg_dataset.transform_image
        tg_dataset.transform_image = BaseDatasetMM.remove_multimodal_transform(transform_2d, exclude_2d)
    
    # Get a sample from the dataset with the transforms excluded, and
    # restore the original transforms even if indexing raises
    try:
        out = tg_dataset[idx]
    finally:
        tg_dataset.transform = transform_3d
        if is_multimodal and exclude_2d:
            tg_dataset.transform_image = transform_2d

    return out
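The swap-and-restore logic in `sample_real_data` is a general pattern. Below is a minimal sketch of it as a reusable context manager built on the standard `contextlib` module; `swapped_attr` is a hypothetical helper, and the commented usage assumes the `BaseDataset.remove_transform` call shown above.

```python
from contextlib import contextmanager


@contextmanager
def swapped_attr(obj, name, value):
    """Temporarily set `obj.<name>` to `value`; the original value is
    restored on exit, even if the body raises."""
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        setattr(obj, name, original)

# Hypothetical usage with the objects defined in this notebook:
# ds = dataset.train_dataset
# no_aug = BaseDataset.remove_transform(ds.transform, exclude_3d_viz)
# with swapped_attr(ds, 'transform', no_aug):
#     sample = ds[0]
```

Nesting two `with swapped_attr(...)` blocks would cover both `transform` and `transform_image`.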

## Visualize a single multimodal sample

We can now pick samples from the train, val and test datasets.

Please refer to `torch_points3d/visualization/multimodal_data` for more details on visualization options.

In [8]:
# exact splatting
i_room = 0

# Pick a room in the Train set
mm_data = sample_real_data(dataset.train_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Pick a room in the Val set
# mm_data = sample_real_data(dataset.val_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Pick a room in the Test set
# mm_data = sample_real_data(dataset.test_dataset[0], idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

['/project/fsun/dvata/scannet-neucon-smallres-partial/raw/scans/scene0000_00/m2f_masks/364.png'
 '/project/fsun/dvata/scannet-neucon-smallres-partial/raw/scans/scene0000_00/m2f_masks/43.png'
 '/project/fsun/dvata/scannet-neucon-smallres-partial/raw/scans/scene0000_00/m2f_masks/341.png'
 '/project/fsun/dvata/scannet-neucon-smallres-partial/raw/scans/scene0000_00/m2f_masks/4098.png'
 '/project/fsun/dvata/scannet-neucon-smallres-partial/raw/scans/scene0000_00/m2f_masks/2143.png'
 '/project/fsun/dvata/scannet-neucon-smallres-partial/raw/scans/scene0000_00/m2f_masks/2748.png'
 '/project/fsun/dvata/scannet-neucon-smallres-partial/raw/scans/scene0000_00/m2f_masks/3748.png'
 '/project/fsun/dvata/scannet-neucon-smallres-partial/raw/scans/scene0000_00/m2f_masks/3566.png'
 '/project/fsun/dvata/scannet-neucon-smallres-partial/raw/scans/scene0000_00/m2f_masks/116.png'
 '/project/fsun/dvata/scannet-neucon-smallres-partial/raw/scans/scene0000_00/m2f_masks/1026.png'
 '/project/fsun/dvata/scannet-neuco

In [9]:
# visualize_mm_data(mm_data, class_names=CLASS_NAMES, class_colors=CLASS_COLORS, front='y', figsize=1000, pointsize=3, voxel=0.05, show_2d=False, alpha=0.3)

In [10]:
mm_data.modalities['image'][0].x

tensor([[[[ 13,  64,  63,  ...,  62,  53,   8],
          [ 24, 213, 228,  ..., 194, 173,  18],
          [ 13, 187, 230,  ..., 198, 174,  20],
          ...,
          [  1, 116, 158,  ..., 195, 141,   1],
          [  1, 104, 140,  ..., 178, 135,   4],
          [  4,  13,  14,  ...,  24,  35,   4]],

         [[ 10,  66,  67,  ...,  59,  51,  10],
          [ 26, 222, 239,  ..., 192, 172,  21],
          [ 15, 198, 243,  ..., 187, 169,  21],
          ...,
          [  0, 116, 168,  ..., 165, 122,   0],
          [  2, 110, 151,  ..., 151, 117,   2],
          [  2,  13,  16,  ...,  16,  28,   4]],

         [[ 10,  66,  68,  ...,  53,  49,  14],
          [ 25, 221, 240,  ..., 160, 149,  15],
          [ 16, 199, 245,  ..., 160, 151,  24],
          ...,
          [  1, 122, 177,  ..., 134, 103,   0],
          [  1, 112, 158,  ..., 129, 104,   4],
          [  1,  12,  18,  ...,  14,  25,   7]]],


        [[[  6,  43,  47,  ...,  35,  30,  12],
          [ 23, 136, 155,  ..., 116

### Visualizing only the seen points

In [9]:
# idx mapping from each view (point-image match) to its point
# NOTE: each point is contained multiple times if it has multiple correspondences
im_data = mm_data.modalities['image']

dense_idx_list = [
            torch.arange(im.num_points, device=im_data.device).repeat_interleave(
                im.view_csr_indexing[1:] - im.view_csr_indexing[:-1])
            for im in im_data]
dense_idx_list[0]

# take subset of only seen points without re-indexing the same point
seen_mm_data = mm_data[dense_idx_list[0].unique()]

del im_data, mm_data
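Since `dense_idx_list[0].unique()` only keeps points with at least one view, the same indices can be read directly from the CSR pointers without materializing the dense index. A small sketch, with `seen_point_indices` as a hypothetical helper; it would be called on `mm_data.modalities['image'][0].view_csr_indexing` before the `del` above.

```python
import torch


def seen_point_indices(view_csr_indexing):
    """Indices of points that have at least one view, in increasing
    order. Should match `dense_idx.unique()` for the same CSR pointers."""
    n_views_per_point = view_csr_indexing[1:] - view_csr_indexing[:-1]
    return torch.nonzero(n_views_per_point > 0, as_tuple=True)[0]
```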

In [12]:
print(seen_mm_data)

visualize_mm_data(seen_mm_data, class_names=CLASS_NAMES, class_colors=CLASS_COLORS, front='y', figsize=1000, pointsize=3, voxel=0.05, show_2d=False, alpha=0.3)

MMData(
    data = Data(id_scan=[1], linearity=[17483], mapping_index=[17483], norm=[17483, 3], origin_id=[17483], planarity=[17483], pos=[17483, 3], pos_z=[17483], rgb=[17483, 3], scattering=[17483], y=[17483])
    image = ImageData(num_settings=1, num_views=5, num_points=17483, m2f_pred_mask=torch.Size([5, 1, 240, 320]), device=cpu)
)


# Process MMData into a compatible input for the MVFusion model

#### Step-by-step process:
1. ~Take the MMData object from the dataloader~
2. ~Find out how the mappings are stored and accessed. Hint: check `SelectMappingFromPointId` and `PickImagesFromMemoryCredit`~
3. Select 9 random views for each 3D point, while maximizing differences between camera positions. I have chosen not to use extrinsics, as I hypothesize that poses will already differ substantially whenever camera positions differ.
4. Get the M2F labels for each view
5. Add the missing features to each viewing-condition vector
6. Feed the post-processed sample to the MVFusion model


Random thoughts:
- Add noise to the viewing conditions
- Step 3 has to be vectorized (no Python loop over `n_points`)
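Since step 3 has to be vectorized, here is a minimal loop-free sketch of the random part of it. Assumptions: `features` holds one viewing-condition row per view, grouped by point according to CSR pointers (as `view_cat_csr_indexing` appears to do in the cells below); `sample_views_vectorized` is a hypothetical name. Unlike `extract_viewing_data_per_point` below, sampled views come out in random rather than sorted order, and no attempt is made to maximize camera-position differences.

```python
import torch


def sample_views_vectorized(features, csr_idx, n_views=9):
    """Hypothetical loop-free version of step 3.

    features: [N, F] viewing-condition features, rows grouped by point
    csr_idx:  [n_points + 1] long CSR pointers into `features`
    Returns [n_points, n_views, F + 1]; channel 0 is a validity flag
    (1 = real view, 0 = padding copied from a random view).
    """
    device, dtype = features.device, features.dtype
    n_total = features.shape[0]
    n_points = csr_idx.shape[0] - 1
    n_seen = csr_idx[1:] - csr_idx[:-1]

    # Point index of every row (non-decreasing, since rows are grouped)
    group = torch.arange(n_points, device=device).repeat_interleave(n_seen)

    # Shuffle all rows, then stable-sort them back by point index: rows end
    # up grouped by point again, in random order within each point
    perm = torch.randperm(n_total, device=device)
    _, order = torch.sort(group[perm], stable=True)
    shuffled = perm[order]

    # Position of each shuffled row within its point; keep the first n_views
    rank = torch.arange(n_total, device=device) - csr_idx[group]
    keep = rank < n_views

    # Start with random padding views flagged invalid (channel 0 = 0) ...
    pad_idx = torch.randint(n_total, (n_points, n_views), device=device)
    out = torch.cat(
        (torch.zeros(n_points, n_views, 1, dtype=dtype, device=device),
         features[pad_idx]),
        dim=-1)
    # ... then overwrite the first slots of each point with its sampled views
    sampled = features[shuffled[keep]]
    valid = torch.cat(
        (torch.ones(sampled.shape[0], 1, dtype=dtype, device=device), sampled),
        dim=1)
    out[group[keep], rank[keep]] = valid
    return out

# Hypothetical usage with the notebook objects:
# sample_views_vectorized(image_data[0].mappings.values[2],
#                         image_data.view_cat_csr_indexing)
```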


In [10]:
# Step 3
most_seen_point = dense_idx_list[0].unique(return_counts=True)[1].argmax()

mm_data = seen_mm_data#[[0, most_seen_point, 1, 2]]
image_data = mm_data.modalities['image']
samesetting_data = image_data[0]   # take the first SameSettingImageData, since ScanNet has only one setting
print(samesetting_data)

print(samesetting_data.__dict__.keys())

# camera positions
print(samesetting_data.pos)

print(samesetting_data.mappings)

SameSettingImageData(num_views=100, num_points=64888, m2f_pred_mask=torch.Size([100, 1, 240, 320]), device=cpu)
dict_keys(['_ref_size', '_proj_upscale', '_rollings', '_downscale', '_crop_size', '_crop_offsets', '_x', '_mappings', '_mask', 'path', 'pos', 'opk', 'fx', 'fy', 'mx', 'my', 'xi', 'k1', 'k2', 'gamma1', 'gamma2', 'u0', 'v0', 'extrinsic', 'visibility', 'm2f_pred_mask', 'm2f_pred_mask_path'])
tensor([[-2.1312, -0.9857,  1.3165],
        [-1.8866, -0.6650,  1.2376],
        [-1.5644, -1.0577,  1.3521],
        [ 2.1509, -0.6062,  1.2494],
        [-0.5664, -0.5263,  1.2969],
        [-2.2808, -0.0989,  1.4455],
        [ 1.3209, -0.6113,  1.2768],
        [-1.2930, -0.2358,  1.3766],
        [-1.7464, -0.7069,  1.4956],
        [ 0.4178, -0.4069,  1.2682],
        [ 1.4366, -0.6533,  1.2904],
        [-1.2774, -0.7161,  1.3763],
        [ 0.5294, -0.3331,  1.3684],
        [ 1.7830,  2.2212,  1.3499],
        [ 1.5115, -0.2909,  1.2593],
        [ 2.1264, -0.0098,  1.3579],
      

In [25]:
# how to group each point's views?
print(image_data.view_cat_csr_indexing)
print(samesetting_data.view_csr_indexing)

csr_idx = image_data.view_cat_csr_indexing

# Compute dense indices from CSR indices
n_groups = csr_idx.shape[0] - 1
dense_idx = torch.arange(n_groups).to(csr_idx.device).repeat_interleave(
    csr_idx[1:] - csr_idx[:-1])
# if src.dim() > 1:
#     dense_idx = dense_idx.view(-1, 1).repeat(1, src.shape[1])
    
print(dense_idx)

tensor([     0,      4,      8,  ..., 379406, 379407, 379408])
tensor([     0,      4,      8,  ..., 379406, 379407, 379408])
tensor([    0,     0,     0,  ..., 64885, 64886, 64887])
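To make the CSR-to-dense conversion above concrete, here is a toy sketch of both directions with plain PyTorch. The helper names are hypothetical; the reverse direction assumes a sorted dense index, as produced by `repeat_interleave`.

```python
import torch


def csr_to_dense(csr_idx):
    """Group index of every item, e.g. [0, 2, 5, 6] -> [0, 0, 1, 1, 1, 2]."""
    n_groups = csr_idx.shape[0] - 1
    return torch.arange(n_groups, device=csr_idx.device).repeat_interleave(
        csr_idx[1:] - csr_idx[:-1])


def dense_to_csr(dense_idx, n_groups):
    """Inverse for a sorted dense index: count items per group, then
    take the cumulative sum, prepending a 0."""
    counts = torch.bincount(dense_idx, minlength=n_groups)
    return torch.cat([counts.new_zeros(1), counts.cumsum(0)])
```

Empty groups are handled naturally: `[0, 2, 2, 3]` expands to `[0, 0, 2]` and `bincount` with `minlength` restores the zero count for group 1.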


In [35]:
mm_data.num_points

64888

In [39]:
def extract_viewing_data_per_point(mm_data, n_views=9):
    """
    Gather a fixed number of viewing-condition vectors for each point.

    Returns a [num_points, n_views, 1 + F] tensor whose first feature is a
    pixel-validity flag. Points seen in more than `n_views` views get
    `n_views` of them sampled without replacement; points seen in fewer
    views are padded with copies of random views flagged as invalid.
    """
    image_data = mm_data.modalities['image']
    csr_idx = image_data.view_cat_csr_indexing

    # Feature vector of each point-image match (view)
    viewing_conditions = image_data[0].mappings.values[2]

    # Add pixel validity as the first feature
    viewing_conditions = torch.cat(
        (torch.ones(viewing_conditions.shape[0], 1, device=viewing_conditions.device),
         viewing_conditions), dim=1)

    # Count the padding views needed: in total, there should be
    # n_points * n_views view conditions
    n_seen = csr_idx[1:] - csr_idx[:-1]
    unfilled_points = n_seen[n_seen < n_views]
    n_views_to_fill = int(len(unfilled_points) * n_views - unfilled_points.sum())

    # Padding views are copies of random views, with validity set to 0
    random_idx = torch.randint(
        len(viewing_conditions), (n_views_to_fill,), device=viewing_conditions.device)
    random_invalid_views = viewing_conditions[random_idx]
    random_invalid_views[:, 0] = 0

    # Faster than per-point concatenation: concatenate the real and padding
    # views once, then build an index such that each point either has
    # n_views subsampled valid views, or is padded to n_views with invalid ones
    combined_tensor = torch.cat((viewing_conditions, random_invalid_views), dim=0)

    unused_invalid_view_idx = len(viewing_conditions)
    combined_idx = []
    # Iterate over plain Python ints rather than 0-dim tensors
    for start, end in zip(csr_idx[:-1].tolist(), csr_idx[1:].tolist()):
        n = end - start
        if n < n_views:
            n_empty_views = n_views - n
            combined_idx += list(range(start, end))
            combined_idx += list(range(unused_invalid_view_idx, unused_invalid_view_idx + n_empty_views))
            unused_invalid_view_idx += n_empty_views
        elif n > n_views:
            sampled = start + np.random.choice(n, size=n_views, replace=False)
            combined_idx += sorted(sampled.tolist())
        else:
            combined_idx += list(range(start, end))

    combined_tensor = combined_tensor[combined_idx]
    return combined_tensor.reshape(mm_data.num_points, n_views, -1)


extract_viewing_data_per_point(mm_data)


tensor([[[1.0000, 0.6665, 0.1154,  ..., 0.4548, 0.1306, 0.0784],
         [1.0000, 0.4791, 0.1154,  ..., 0.4731, 0.1306, 0.3529],
         [1.0000, 0.3359, 0.1154,  ..., 0.7392, 0.1306, 0.2941],
         ...,
         [0.0000, 0.3266, 0.0614,  ..., 0.7972, 0.0277, 0.8235],
         [0.0000, 0.3633, 0.0870,  ..., 0.2801, 0.0584, 0.9216],
         [0.0000, 0.2830, 0.2495,  ..., 0.3253, 0.0429, 0.9412]],

        [[1.0000, 0.3535, 0.3130,  ..., 0.9383, 0.0848, 0.4902],
         [1.0000, 0.3318, 0.3130,  ..., 0.7669, 0.0848, 0.3725],
         [1.0000, 0.4663, 0.3130,  ..., 0.4011, 0.0848, 0.5098],
         ...,
         [0.0000, 0.5995, 0.3395,  ..., 0.0942, 0.1187, 0.3333],
         [0.0000, 0.5148, 0.1281,  ..., 0.4793, 0.0541, 0.4902],
         [0.0000, 0.3071, 0.1375,  ..., 0.1411, 0.0577, 1.0000]],

        [[1.0000, 0.4762, 0.2326,  ..., 0.4781, 0.0989, 0.5686],
         [1.0000, 0.3325, 0.2326,  ..., 0.7586, 0.0989, 0.3529],
         [1.0000, 0.5181, 0.2326,  ..., 0.3389, 0.0989, 0.

In [23]:
### Redundant code

# s = time.time()
# ### view sampling per point
# view_data_per_point, random_invalid_views = extract_viewing_data_per_point(mm_data)

# cum_sampled_views = 0
# for i in range(len(view_data_per_point)):
#     if len(view_data_per_point[i]) > 9:
#         sampled_idx = sorted(np.random.choice(range(len(view_data_per_point[i])), size=9, replace=False))
#         view_data_per_point[i] = view_data_per_point[i][sampled_idx]
#     else:
#         # Each point should have 9 views, so fill points with random invalid view conditions till it contains 9 views
#         n_views = len(view_data_per_point[i])
#         n_empty = 9 - n_views
#         view_data_per_point[i] = torch.cat((view_data_per_point[i], 
#                                             random_invalid_views[cum_sampled_views:cum_sampled_views+n_empty]), dim=0)
#         cum_sampled_views += n_empty
                
# print(time.time() - s)

# # view_data_per_point


1.5583267211914062


In [12]:

# idx mapping from each view (point-image match) to its point
im_data = test_data.modalities['image']    # full ImageData, iterated below over its SameSettingImageData, so no [0] slicing

dense_idx_list = [
            torch.arange(im.num_points, device=im_data.device).repeat_interleave(
                im.view_csr_indexing[1:] - im.view_csr_indexing[:-1])
            for im in im_data]
dense_idx_list[0]

tensor([0, 0, 1, 1])

### Understanding point-image and point-pixel mappings

In [120]:
mapping = mm_data.modalities['image'][0].mappings
mapping

ImageMapping(num_groups=81369, num_items=81821, device=cpu)

In [122]:
mapping.__dict__
# pointers: CSR pointers grouping the views (point-image matches) by point
# values[0]: image idx of each view
# values[1]: CSRData holding the pixel indices of each view, accessed with values[1].values[0]
# values[2]: viewing-condition features of each view

{'pointers': tensor([    0,     0,     0,  ..., 81821, 81821, 81821]),
 'values': [tensor([ 9,  9,  9,  ..., 13,  4, 18]),
  CSRData(num_groups=81821, num_items=81821, device=cpu),
  tensor([[0.3528, 0.2270, 0.6533,  ..., 0.4220, 0.0727, 0.8235],
          [0.3552, 0.1550, 0.5808,  ..., 0.3821, 0.0875, 0.4314],
          [0.3505, 0.0111, 0.6325,  ..., 0.3873, 0.0810, 0.3922],
          ...,
          [0.4889, 0.2754, 0.6022,  ..., 0.1002, 0.0594, 0.4314],
          [0.3923, 0.0640, 0.7785,  ..., 0.0044, 0.0873, 0.0980],
          [0.4105, 0.4782, 0.3603,  ..., 0.0367, 0.0568, 0.2157]])],
 'is_index_value': tensor([ True, False, False])}

In [126]:
# Image idx of each point-pixel match
idx_batch = mapping.values[0].repeat_interleave(
    mapping.values[1].pointers[1:] - mapping.values[1].pointers[:-1])
idx_batch

tensor([ 9,  9,  9,  ..., 13,  4, 18])
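Read together, these fields form a two-level CSR structure: points → views → pixels. The toy sketch below uses hand-made tensors standing in for `mapping.pointers`, `mapping.values[0]` and `mapping.values[1].pointers` (the values are made up for illustration), and broadcasts the point and image indices down to every mapped pixel.

```python
import torch

# Toy stand-ins (made-up values) for the fields inspected above
point_ptr = torch.tensor([0, 2, 3])      # like mapping.pointers: point 0 has 2 views, point 1 has 1
view_image = torch.tensor([5, 7, 5])     # like mapping.values[0]: image index of each view
pixel_ptr = torch.tensor([0, 3, 4, 6])   # like mapping.values[1].pointers: 3, 1 and 2 pixels per view

views_per_point = point_ptr[1:] - point_ptr[:-1]
pixels_per_view = pixel_ptr[1:] - pixel_ptr[:-1]

# Level 1: point index of each view
view_point = torch.arange(len(views_per_point)).repeat_interleave(views_per_point)
# Level 2: broadcast the point and image indices down to each pixel
pixel_point = view_point.repeat_interleave(pixels_per_view)
pixel_image = view_image.repeat_interleave(pixels_per_view)

print(pixel_point.tolist())  # [0, 0, 0, 0, 1, 1]
print(pixel_image.tolist())  # [5, 5, 5, 7, 5, 5]
```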

In [72]:
# idx mapping from each view (point-image match) to its point
im_data = mm_data.modalities['image']

dense_idx_list = [
            torch.arange(im.num_points, device=im_data.device).repeat_interleave(
                im.view_csr_indexing[1:] - im.view_csr_indexing[:-1])
            for im in im_data]
dense_idx_list[0]

NameError: name 'mm_data' is not defined

In [135]:
dense_idx_list[0].unique(return_counts=True)[1].min()

tensor(1)

In [125]:
mm_data.modalities['image'].get_mapped_features(interpolate=False)

# Standalone copy of the `get_mapped_features` logic, for inspection.
# NOTE: `sparse_interpolation` is not imported in this notebook; it must be
# brought into scope from torch_points3d before running the interpolation
# branch (`interpolate=True` with `scale != 1`)
def get_mapped_features(mod_data, interpolate=False):
    """Return the mapped features of a SameSettingImageData, with optional
    interpolation. If `interpolate=False`, the mappings will be adjusted to
    `mod_data.img_size`: the current size of the feature map `mod_data.x`.
    """
    # Compute the feature map's sampling ratio between the input
    # `mapping_size` and the current `img_size`
    # TODO: treat scales independently. Careful with min or max
    #  depending on upscale and downscale
    scale = 1 / mod_data.downscale

    # If not interpolating, set the mapping to the proper scale
    mappings = mod_data.mappings if interpolate \
        else mod_data.mappings.rescale_images(scale)

    # Index the features with/without interpolation
    if interpolate and scale != 1:
        resolution = torch.Tensor([mod_data.mapping_size]).to(mod_data.device)
        coords = mappings.pixels / (resolution - 1)
        coords = coords[:, [1, 0]]  # pixel mappings are in (W, H) format
        batch = mappings.feature_map_indexing[0]
        x = sparse_interpolation(mod_data.x, coords, batch)
    else:
        x = mod_data.x[mappings.feature_map_indexing]

    return x


get_mapped_features(mm_data.modalities['image'][0], interpolate=True)
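The implementation of `sparse_interpolation` is not shown in this notebook. As a point of comparison only, here is a plain-PyTorch bilinear interpolation at sparse pixel locations, assuming `coords` follow the `(row, col)` order normalized by `size - 1` used in `get_mapped_features` above. This is a swapped-in technique rather than the library's implementation, and `bilinear_sample` is a hypothetical name.

```python
import torch


def bilinear_sample(x, batch, coords):
    """Bilinearly sample feature maps at sparse, fractional pixel locations.

    x:      [B, C, H, W] batch of feature maps
    batch:  [N] long, index of the feature map each sample is read from
    coords: [N, 2] (row, col) positions normalized to [0, 1]
    Returns [N, C] interpolated features.
    """
    _, _, H, W = x.shape
    r = coords[:, 0] * (H - 1)
    c = coords[:, 1] * (W - 1)
    r0 = r.floor().long().clamp(0, H - 1)
    c0 = c.floor().long().clamp(0, W - 1)
    r1 = (r0 + 1).clamp(max=H - 1)
    c1 = (c0 + 1).clamp(max=W - 1)
    # Fractional offsets, shaped [N, 1] to broadcast over channels
    wr = (r - r0).unsqueeze(1)
    wc = (c - c0).unsqueeze(1)
    # Advanced indices separated by a slice put N first: each gather is [N, C]
    f00 = x[batch, :, r0, c0]
    f01 = x[batch, :, r0, c1]
    f10 = x[batch, :, r1, c0]
    f11 = x[batch, :, r1, c1]
    return (f00 * (1 - wr) * (1 - wc) + f01 * (1 - wr) * wc
            + f10 * wr * (1 - wc) + f11 * wr * wc)
```

At integer pixel positions the weights collapse to a single neighbor, so the result matches plain indexing with `x[batch, :, row, col]`.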

#### Returns indices for extracting mapped data from a batch of image feature maps

In [11]:
im = mm_data.modalities['image'][0]
print(im.feature_map_indexing)

im = mm_data.modalities['image']
print(im.feature_map_indexing)


(tensor([22, 22, 22,  ...,  2,  2, 10]), Ellipsis, tensor([72, 75, 74,  ...,  3,  2,  8]), tensor([216, 215, 212,  ...,  12,   9, 300]))
[(tensor([22, 22, 22,  ...,  2,  2, 10]), Ellipsis, tensor([72, 75, 74,  ...,  3,  2,  8]), tensor([216, 215, 212,  ...,  12,   9, 300]))]


In [None]:
#################### Indices are stored as shown above.
#################### Next step: check how the ImageMapping is computed in the MapImages transform


def feature_map_indexing(im):
    """Standalone copy of the `feature_map_indexing` logic for a
    SameSettingImageData `im`. Return the indices for extracting mapped
    data from the corresponding batch of image feature maps.

    The batch of image feature maps X is expected to have the shape
    `[B, C, H, W]`. The returned indexing object idx is intended to
    be used for recovering the mapped features as: `X[idx]`.
    """
    mappings = im.mappings

    # Image index of every mapped pixel
    idx_batch = mappings.images.repeat_interleave(
        mappings.values[1].pointers[1:] - mappings.values[1].pointers[:-1])
    # Pixel mappings are stored in (W, H) order
    idx_height = mappings.pixels[:, 1]
    idx_width = mappings.pixels[:, 0]
    idx = (idx_batch.long(), ..., idx_height.long(), idx_width.long())
    return idx

im = mm_data.modalities['image'][0]   # SameSettingImageData
feature_map_indexing(im)

mm_data.data.__dict__
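To see what `X[idx]` returns with a `(batch, ..., height, width)` tuple like the one built above, here is a toy example on a small integer tensor. The three mapped pixels are made up for illustration.

```python
import torch

# Toy batch of feature maps X: [B=2, C=3, H=4, W=5]
X = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# Three made-up mapped pixels: (image index, row, col)
idx_batch = torch.tensor([0, 1, 1])
idx_height = torch.tensor([2, 0, 3])
idx_width = torch.tensor([4, 1, 0])

# Same tuple layout as feature_map_indexing: (batch, ..., height, width)
idx = (idx_batch, ..., idx_height, idx_width)
mapped = X[idx]        # one C-dimensional feature vector per mapped pixel
print(mapped.shape)    # torch.Size([3, 3])
```

Because the advanced indices are separated by the `Ellipsis`, the indexed dimension comes first and the channel dimension last, which gives one row of features per mapped pixel.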

In [27]:
images = seen_mm_data.modalities['image']
# Number of views in which each 3D point is seen, summed over settings
n_seen = sum([
            im.mappings.pointers[1:] - im.mappings.pointers[:-1]
            for im in images])
n_seen.argmax()   # index of the point seen in the most views

tensor(26556)

In [21]:
# Count how many points each view sees, by building an [n_views, n_points]
# boolean visibility matrix (adapted from the `use_coverage` branch of the
# image-picking transforms)
def count_points_per_view(data, images):
    """
    Return the number of points seen by each view, across all settings.
    """
    img_seen_points = torch.zeros(
        images.num_views, data.num_nodes, dtype=torch.bool)
    i_offset = 0
    for im in images:
        mappings = im.mappings
        i_idx = mappings.images + i_offset
        j_idx = mappings.points.repeat_interleave(
            mappings.pointers[1:] - mappings.pointers[:-1])
        img_seen_points[i_idx, j_idx] = True
        i_offset += im.num_views
    return img_seen_points.sum(dim=1).tolist()

print(count_points_per_view(seen_mm_data.data, seen_mm_data.modalities['image']))

[3803, 4053, 4472, 3447, 4429, 3796, 3490, 3764, 4196, 3650, 3520, 2221, 3775, 3414, 3240, 3109, 3219, 3967, 3825, 3156, 5118, 2617, 3521, 4046, 2274]
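The dense `[n_views, n_points]` boolean matrix grows with both dimensions. If each (point, image) pair appears at most once in the mappings (an assumption here), counting views per image gives the same numbers with a single `torch.bincount`. `count_points_per_image` is a hypothetical helper; for one setting, the notebook equivalent would be `count_points_per_image(im.mappings.images, im.num_views)`.

```python
import torch


def count_points_per_image(view_image, num_images):
    """Number of points seen by each image, from the image index of every
    point-image view. Assumes each (point, image) pair appears at most
    once, so counting views equals counting distinct points."""
    return torch.bincount(view_image.long(), minlength=num_images)
```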


## Run inference from pretrained weights and visualize predictions
It is possible to visualize the pointwise predictions and errors from a model. 

To do so, we will use the pretrained weights made available with this project. See `README.md` to get the download links and manually place the `.pt` files locally. You will need to provide `checkpoint_dir` where you saved those files in the next cell.

In [None]:
from torch_points3d.models.model_factory import instantiate_model

# Set your parameters
checkpoint_dir = '/home/fsun/DeepViewAgg/model_checkpoints'

# Create the model
print(f"Creating model: {cfg.model_name}")
model = instantiate_model(cfg, dataset)
# print(model)

# Load the checkpoint and recover the 'latest' model weights
checkpoint = torch.load(f'{checkpoint_dir}/{model_name}.pt', map_location='cpu')
model.load_state_dict_with_same_shape(checkpoint['models']['latest'], strict=False)

# Prepare the model for inference
model = model.eval().cuda()
print('Model loaded')

Now that we have loaded the model, we need to run a forward pass on a sample. However, to be able to visualize the predictions, we need to pay special attention to the 3D and 2D transforms we apply to the data, so as not to break the mappings. To do so, we will manually apply some sensitive transforms, so that we can both run inference on the data and visualize it.

In [None]:
i_room = 3

# Pick a room in the Train set
mm_data = sample_real_data(dataset.train_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Pick a room in the Val set
# mm_data = sample_real_data(dataset.val_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Pick a room in the Test set
# mm_data = sample_real_data(dataset.test_dataset[0], idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Extract point cloud and images from MMData object
data = mm_data.data.clone()
images = mm_data.modalities['image'].clone()

# For voxel-based 3D backbones such as SparseConv3d and MinkowskiNet, points need to be 
# preprocessed with Center and GridSampling3D. Unfortunately, Center breaks relative 
# positions between points and images. Besides, the combination of Center and GridSampling3D
# may lead to some points being merged into the same voxels, so we must apply it to both the
# inference and visualization data to make sure we have the same voxels. The workaround here 
# is to manually run these while keeping track of the centering offset
center = data.pos.mean(dim=-2, keepdim=True)
data = AddFeatsByKeys(list_add_to_x=[True], feat_names=['pos_z'], delete_feats=[True])(data)          # add z-height to the features
data = Center()(data)                                                                                 # mean-center the data
data = GridSampling3D(cfg.data.resolution_3d, quantize_coords=True, mode='last')(data)                # quantization for volumetric models

# This last voxelization step with GridSampling3D might have removed some points, so we need
# to update the mappings using SelectMappingFromPointId. To control the size of the batch, we
# use PickImagesFromMemoryCredit. Besides, 2D models expect normalized float images, which is
# why we call ToFloatImage and Normalize
data, images = SelectMappingFromPointId()(data, images)                                               # update mappings after GridSampling3D
data, images = PickImagesFromMemoryCredit(
    img_size=cfg.data.resolution_2d, 
    k_coverage=cfg.data.multimodal.settings.k_coverage, 
    n_img=cfg.data.multimodal.settings.test_pixel_credit)(data, images)                                      # select images to respect memory constraints
data, images_infer = ToFloatImage()(data, images.clone())                                             # convert uint8 images to float
data, images_infer = Normalize()(data, images_infer)                                                  # RGB normalization

# Create a MMData for inference
mm_data_infer = MMData(data, image=images_infer)

# Create a MMBatch and run inference
batch = MMBatch.from_mm_data_list([mm_data_infer])
model.set_input(batch, model.device)
model(batch)

# Create a MMData for visualization
data.pos += center
mm_data = MMData(data, image=images)

# Recover the predicted labels for visualization
mm_data.data.pred = model.output.detach().cpu().argmax(1)

In [None]:
visualize_mm_data(mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=True, front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)