###### ScanNet

This notebook lets you instantiate the **[ScanNet](http://www.scan-net.org/)** dataset from scratch and visualize **3D+2D room samples**.

Note that you will need **at least 1.2T** available for the SanNet raw dataset and **at least 64G** for the processed files at **5cm voxel resolution** and **320x240 image resolution**. 

The ScanNet dataset is composed of **rooms** of video acquisitions of indoor scenes. Thes video streams were used to produce a point cloud and images.

Each room is small enough to be loaded at once into a **64G RAM** memory. The `ScannetDatasetMM` class from `torch_points3d.datasets.segmentation.multimodal.scannet` deals with loading the room and part of the images of the associated video stream.

In [1]:
# Select you GPU
I_GPU = 0


In [2]:
# Uncomment to use autoreload
%load_ext autoreload
%autoreload 2

import os
import os.path as osp
import sys
import torch
import numpy as np
from time import time
from omegaconf import OmegaConf
start = time()
import warnings
warnings.filterwarnings('ignore')

# torch.cuda.set_device(I_GPU)
DIR = os.path.dirname(os.getcwd())
ROOT = os.path.join(DIR, "..")
sys.path.insert(0, ROOT)
sys.path.insert(0, DIR)

from torch_points3d.utils.config import hydra_read
from torch_geometric.data import Data
from torch_points3d.core.multimodal.data import MMData, MMBatch
from torch_points3d.visualization.multimodal_data import visualize_mm_data
from torch_points3d.core.multimodal.image import SameSettingImageData, ImageData
from torch_points3d.datasets.segmentation.multimodal.scannet import ScannetDatasetMM
from torch_points3d.datasets.segmentation.scannet import CLASS_COLORS, CLASS_NAMES, CLASS_LABELS
from torch_points3d.metrics.segmentation_tracker import SegmentationTracker

from pykeops.torch import LazyTensor

MMData debug() function changed, please uncomment the 3rd assert line when doing inference without M2F features!


If `visualize_mm_data` does not throw any error but the visualization does not appear, you may need to change your plotly renderer below.

In [3]:
import plotly.io as pio

#pio.renderers.default = 'jupyterlab'        # for local notebook
pio.renderers.default = 'iframe_connected'  # for remote notebook. Other working (but seemingly slower) options are: 'sphinx_gallery' and 'iframe'

## Dataset creation

The following will instantiate the dataset. If the data is not found at `DATA_ROOT`, the folder structure will be created there and the raw dataset will be downloaded there. 

**Memory-friendly tip** : if you have already downloaded the dataset once and simply want to instantiate a new dataset with different preprocessing (*e.g* change 3D or 2D resolution, mapping parameterization, etc), I recommend you manually replicate the folder hierarchy of your already-existing dataset and create a symlink to its `raw/` directory to avoid downloading and storing (very) large files twice.

You will find the config file ruling the dataset creation at `conf/data/segmentation/multimodal/scannet-sparse.yaml`. You may edit this file or create new configs inheriting from this one using Hydra and create the associated dataset by modifying `dataset_config` accordingly in the following cell.

In [None]:
# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/project/fsun/dvata'

dataset_config = 'segmentation/multimodal/Feng/scannet-neucon-smallres-m2f'   
models_config = 'segmentation/multimodal/Feng/mvfusion'    # model family
model_name = 'MVFusion_3D'                       # specific model

overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

cfg = hydra_read(overrides)
OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
cfg.data.load_m2f_masks = True   # load Mask2Former predicted masks

The dataset will now be created based on the parsed configuration. I recommend having **at least 1.2T** available for the SanNet raw dataset and **at least 64G** for the processed files at **5cm voxel resolution** and **320x240 image resolution**. 

As long as you do not change core dataset parameters, preprocessing should only be performed once for your dataset. It may take some time, **mostly depending on the 3D and 2D resolutions** you choose to work with (the larger the slower).

In [8]:
# Dataset instantiation
start = time()
dataset = ScannetDatasetMM(cfg.data)
# print(dataset)|
print(f"Time = {time() - start:0.1f} sec.")

load_m2f_masks:  True
initialize test dataset
line 720 scannet.py: split == 'test'
Time = 1.3 sec.


In [12]:
dataset.test_dataset[0].num_classes

20

To visualize the multimodal samples produced by the dataset, we need to remove some of the dataset transforms that affect points, images and mappings. The `sample_real_data` function will be used to get samples without breaking mappings consistency for visualization.

At training and evaluation time, these transforms are used for data augmentation, dynamic size batching (see our [paper](https://arxiv.org/submit/4264152)), etc...

In [6]:
from torch_geometric.transforms import *
from torch_points3d.core.data_transform import *
from torch_points3d.core.data_transform.multimodal.image import *
from torch_points3d.datasets.base_dataset import BaseDataset
from torch_points3d.datasets.base_dataset_multimodal import BaseDatasetMM

# Transforms on 3D points that we need to exclude for visualization purposes
augmentations_3d = [
    ElasticDistortion, Random3AxisRotation, RandomNoise, RandomRotate, 
    RandomScaleAnisotropic, RandomSymmetry, ShiftVoxels]
exclude_3d_viz = augmentations_3d + [AddFeatsByKeys, Center, GridSampling3D]

# Transforms on 2D images and mappings that we need to exclude for visualization
# purposes
augmentations_2d = [JitterMappingFeatures, ColorJitter, RandomHorizontalFlip]
exclude_2d_viz = [RandomHorizontalFlip]
exclude_2d_viz = augmentations_2d + [ToFloatImage, Normalize]



def sample_real_data(tg_dataset, idx=0, exclude_3d=None, exclude_2d=None):
    """
    Temporarily remove the 3D and 2D transforms affecting the point 
    positions and images from the dataset to better visualize points 
    and images relative positions.
    """    
    # Remove some 3D transforms
    transform_3d = tg_dataset.transform
    if exclude_3d:
        tg_dataset.transform = BaseDataset.remove_transform(transform_3d, exclude_3d)

    # Remove some 2D transforms, if any
    is_multimodal = hasattr(tg_dataset, 'transform_image')
    if is_multimodal and exclude_2d:
        transform_2d = tg_dataset.transform_image
        tg_dataset.transform_image = BaseDatasetMM.remove_multimodal_transform(transform_2d, exclude_2d)
    
    # Get a sample from the dataset, with transforms excluded
    out = tg_dataset[idx]
    
    # Restore transforms
    tg_dataset.transform = transform_3d
    if is_multimodal and exclude_2d:
        tg_dataset.transform_image = transform_2d
        
    return out

## Visualize a single multimodal sample

We can now pick samples from the train, val and test datasets.

Please refer to `torch_points3d/visualization/multimodal_data` for more details on visualization options.

In [7]:
# mm_data.modalities['image'][0].x.shape
# mm_data.modalities['image'][0].m2f_pred_mask.shape

In [8]:
mm_data = dataset.train_dataset[0]
mm_data

MMData(
    data = Data(coords=[52285, 3], grid_size=[1], id_scan=[1], mapping_index=[52285], mvfusion_input=[42774, 9, 10], origin_id=[52285], pos=[52285, 3], x=[52285, 3], y=[52285])
    image = ImageData(num_settings=1, num_views=100, num_points=52285, device=cpu)
)

In [24]:
temp = mm_data[torch.round(torch.linspace(0, 52285-1, 5000)).long()]

In [21]:
# mm_data = sample_real_data(dataset.train_dataset, idx=2, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)
# print(mm_data)

In [22]:
visualize_mm_data(temp, class_names=CLASS_NAMES, class_colors=CLASS_COLORS, error_color=(0, 0, 0), front='y', back='m2f_pred_mask', figsize=1000, pointsize=3, voxel=0.03, show_2d=True, alpha=0.3)

# if it gives NotImplementedError in multimodal_data.py, please retain original features in data.data.x 
# inside the dataset __getitem__.

DFD
bacl:  m2f_pred_mask
LOLKEK
back is m2f_pred_mask
torch.Size([42, 1, 240, 320])
class_colors:  [(174.0, 199.0, 232.0), (152.0, 223.0, 138.0), (31.0, 119.0, 180.0), (255.0, 187.0, 120.0), (188.0, 189.0, 34.0), (140.0, 86.0, 75.0), (255.0, 152.0, 150.0), (214.0, 39.0, 40.0), (197.0, 176.0, 213.0), (148.0, 103.0, 189.0), (196.0, 156.0, 148.0), (23.0, 190.0, 207.0), (247.0, 182.0, 210.0), (219.0, 219.0, 141.0), (255.0, 127.0, 14.0), (158.0, 218.0, 229.0), (44.0, 160.0, 44.0), (112.0, 128.0, 144.0), (227.0, 119.0, 194.0), (82.0, 84.0, 163.0), (225, 225, 255)]
torch.ByteTensor(class_colors) torch.Size([21, 3])
im.pred torch.Size([42, 240, 320])
 im.background  torch.Size([42, 3, 240, 320])
images[-1].background torch.Size([42, 3, 240, 320])
im pred:  torch.Size([42, 240, 320])
pred
data.y tensor([ 1,  1,  1,  ..., -1, -1,  1])
class_colors [(174.0, 199.0, 232.0), (152.0, 223.0, 138.0), (31.0, 119.0, 180.0), (255.0, 187.0, 120.0), (188.0, 189.0, 34.0), (140.0, 86.0, 75.0), (255.0, 152.0, 

In [None]:
# # exact splatting
# i_room = 1

# # Pick a room in the Train set
# mm_data = sample_real_data(dataset.train_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# # Pick a room in the Val set
# # mm_data = sample_real_data(dataset.val_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# mm_data

# # 2.6422173976898193 previous for_1 loop
# # 3.252363681793213 for1 with random sample
# # 2.0284531116485596 for1 with numpy choice

In [None]:
def computeScores(data, preds, dataset, stage='train', tracker=None, reset=False):
    if len(preds.shape) < 2:
        preds = torch.nn.functional.one_hot(preds, num_classes=20)

    if tracker is None:
        tracker = SegmentationTracker(dataset, stage=stage, wandb_log=False, use_tensorboard=False, ignore_label=-1)
    
    if reset is True:
        tracker.reset()
    
    tracker._compute_metrics(outputs=preds, labels=data.data.y)
    
    return tracker.get_metrics()

In [None]:
# def get_m2f_features(data, n_views):
# #     n_views = 9

# #     ### Make a 3D point cloud populated by mode M2F labels!
# #     m2f_mapped_feats = data.modalities['image'].get_mapped_m2f_features(interpolate=True)[0]
# #     csr_idx = data.modalities['image'].view_cat_csr_indexing

# #     # Calculate amount of empty views. There should be n_points * n_views filled view conditions in total.
# #     n_seen = csr_idx[1:] - csr_idx[:-1]
# #     unfilled_points = n_seen[n_seen < n_views]
# #     n_views_to_fill = int(len(unfilled_points) * n_views - sum(unfilled_points))

# #     random_m2f_preds = m2f_mapped_feats[np.random.choice(range(len(m2f_mapped_feats)), size=n_views_to_fill, replace=True)]


# #     combined_m2f_tensor = torch.cat((m2f_mapped_feats, random_m2f_preds), dim=0)

# #     unused_invalid_view_idx = len(m2f_mapped_feats)
# #     combined_idx = []
# #     for i, n in enumerate(n_seen):
# #         if n < n_views:
# #             n_empty_views = n_views -  n
# #             combined_idx += list(range(csr_idx[i], csr_idx[i+1])) + \
# #                             list(range(unused_invalid_view_idx, unused_invalid_view_idx + n_empty_views))
# #             unused_invalid_view_idx += n_empty_views
# #         elif n > n_views:
# #             sampled_idx = sorted(np.random.choice(range(csr_idx[i], csr_idx[i+1]), size=n_views, replace=False))
# #             combined_idx += sampled_idx
# #         else:
# #             combined_idx += list(range(csr_idx[i], csr_idx[i+1]))

# #     # re-index tensor for MVFusion format
# #     combined_m2f_tensor = combined_m2f_tensor[combined_idx].reshape(data.num_points, n_views)
    
# #     return combined_m2f_tensor


#     pixel_validity = data.data.x[:, 0, 0]
#     print(pixel_validity, pixel_validity.shape)
    
    
#     # same for m2f feats
#     out = data.data.x
#     print(out, out.shape)

In [None]:
def get_mode_pred(data):
    pixel_validity = data.data.x[:, :, 0].bool()
    mv_preds = data.data.x[:, :, -1].long()
        
    n_views = 9
    
    valid_m2f_feats = []
    for i in range(len(mv_preds)):
        valid_m2f_feats.append(mv_preds[i][pixel_validity[i]])

    mode_preds = []
    for m2feats_of_seen_point in valid_m2f_feats:
        mode_preds.append(torch.mode(m2feats_of_seen_point.squeeze(), dim=0)[0])
    mode_preds = torch.stack(mode_preds, dim=0)
        
    return mode_preds

def eval_small_dataset(stage='train', eval_unseen=False, exclude_2d=None, exclude_3d=None):
    # train, val

    tracker = SegmentationTracker(dataset, stage=stage, wandb_log=False, use_tensorboard=False, ignore_label=-1)

    for i_room in range(3):
        if stage == 'train':
#             data = dataset.train_dataset[i_room]
            data = sample_real_data(dataset.train_dataset, idx=i_room, exclude_3d=exclude_3d, exclude_2d=exclude_2d)
        elif stage == 'val':
            data = dataset.val_dataset[i_room]
        
#         if not eval_unseen:
#             # Take subset of seen points
#             images = data.modalities['image']
#             dense_idx_list = [
#                         torch.arange(im.num_points, device=images.device).repeat_interleave(
#                             im.view_csr_indexing[1:] - im.view_csr_indexing[:-1])
#                         for im in images]
#             data = data[dense_idx_list[0].unique()]

#         # Remove points with ignore label for visual clearance
#         data = data[data.data.y != -1]

        # Get M2F features 
        # Semantic predictions as mode of multi-view predictions
        mode_preds = get_mode_pred(data)

#         print("M2F preds: ", mode_preds.unique())
#         print("GT labels: ", data.data.y.unique()[1:])
        
        if eval_unseen:
            seen_mask = csr_idx[1:] > csr_idx[:-1]
                        
            xyz_query_keops = LazyTensor(data.pos[~seen_mask][:, None, :])
            xyz_search_keops = LazyTensor(data.pos[seen_mask][None, :, :])
            d_keops = ((xyz_query_keops - xyz_search_keops) ** 2).sum(dim=2)
            nn_idx = d_keops.argmin(dim=1)
            del xyz_query_keops, xyz_search_keops, d_keops

            temp = torch.zeros(seen_mask.shape)
            temp[seen_mask] = mode_preds
            
            temp[~seen_mask] = temp[seen_mask][nn_idx].squeeze()
            mode_preds = temp
                

        # Calculate scores
        scores = computeScores(data, mode_preds, dataset, stage=stage, tracker=tracker, reset=False)
        return scores[f'{stage}_miou']
        
#         if i_room == 1:
#             data.data.pred = mode_preds
#             data.data.x = None
#             visualize_mm_data(data, class_names=CLASS_NAMES, class_colors=CLASS_COLORS, error_color=(0, 0, 0), front='y', back='m2f_pred_mask', figsize=1000, pointsize=3, voxel=0.05, show_2d=True, alpha=0.3)
     
# s = []
# for i in range(5):
#     s.append(eval_small_dataset(stage='train', eval_unseen=False))
    
# print('avg miou over 5 loops: ', sum(s) / len(s))

# train set when invalid points are removed
#{'train_acc': 91.86458520553573, 'train_macc': 85.37448181167012, 'train_miou': 67.662546940765}
#{'train_acc': 91.77955465848957, 'train_macc': 84.59272635063425, 'train_miou': 63.49855216414911}
#{'train_acc': 91.74896300196484, 'train_macc': 84.32568566228107, 'train_miou': 66.4833291671551}
# avg miou over 10 loops:  62.50393662841291



# train set with invalid points
#{'train_acc': 92.08282038361197, 'train_macc': 85.37962358651747, 'train_miou': 68.2135868206652}
#{'train_acc': 91.5983554227636, 'train_macc': 85.08323064550525, 'train_miou': 67.1099501308721}
#{'train_acc': 92.09589199021353, 'train_macc': 86.20414449812343, 'train_miou': 68.7643864009719}
#{'train_acc': 91.98125655730847, 'train_macc': 85.79685840582141, 'train_miou': 68.38436115692215}
# avg miou over 10 loops:  63.550427023383996


# val set
# avg miou over 10 loops:  56.41112108349539

# avg miou over 10 loops:  54.78561234690687


# train: keeping invalid points
# 3 seeds -> 

# val
# avg miou over 3 loops:  57.956132626631245


In [None]:
augmentations_3d = [
    ElasticDistortion, Random3AxisRotation, RandomNoise, RandomRotate, 
    RandomScaleAnisotropic, RandomSymmetry, ShiftVoxels]

exclude_3d_viz = augmentations_3d + [AddFeatsByKeys, Center, GridSampling3D]
# exclude_3d_viz = None



augmentations_2d = [JitterMappingFeatures, ColorJitter, RandomHorizontalFlip]
exclude_2d_viz = augmentations_2d + [ToFloatImage, Normalize]
# exclude_2d_viz = None #[RandomHorizontalFlip]


# s = []
# for i in range(5):
#     s.append(eval_small_dataset(stage='train', eval_unseen=False, 
#                                 exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz))

# print()
# print('avg miou over 5 loops: ', sum(s) / len(s))
# print()

s = []
for i in range(3):
    s.append(eval_small_dataset(stage='val', eval_unseen=False, 
                                exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz))
print()
print('avg miou over 3 loops: ', sum(s) / len(s))
print()    

# boosdoeners zijn niet: ToFloatImage, Normalize, JitterMappingFeatures, ColorJitter
# miOu ~ 63
# RandomHorizontalFlip -> 47

In [None]:
# Transforms on 3D points that we need to exclude for visualization purposes
augmentations_3d = [
    ElasticDistortion, Random3AxisRotation, RandomNoise, RandomRotate, 
    RandomScaleAnisotropic, RandomSymmetry, ShiftVoxels]
exclude_3d_viz = augmentations_3d + [AddFeatsByKeys, Center, GridSampling3D]

# Transforms on 2D images and mappings that we need to exclude for visualization
# purposes
augmentations_2d = [JitterMappingFeatures, ColorJitter, RandomHorizontalFlip]
exclude_2d_viz = augmentations_2d + [ToFloatImage, Normalize]

for i in range(len(exclude_3d_viz)):
    temp_3d = exclude_3d_viz.copy()
    included_3d_transformation = temp_3d.pop(i)
    
    s = []
    for i in range(5):
        s.append(eval_small_dataset(stage='train', eval_unseen=False, 
                                    exclude_3d=temp_3d, exclude_2d=exclude_2d_viz))
        
    print()
    print(included_3d_transformation)
    print('avg miou over 5 loops: ', sum(s) / len(s))
    print()
        
for i in range(len(exclude_2d_viz)):
    temp_2d = exclude_2d_viz.copy()
    included_2d_transformation = temp_2d.pop(i)
    
    s = []
    for i in range(5):
        s.append(eval_small_dataset(stage='train', eval_unseen=False, 
                                    exclude_3d=exclude_3d_viz, exclude_2d=temp_2d))
    print()
    print(included_2d_transformation)
    print('avg miou over 5 loops: ', sum(s) / len(s))
    print()    
    


In [None]:
images = mm_data.modalities['image']
# Take subset of only seen points
# NOTE: each point is contained multiple times if it has multiple correspondences
dense_idx_list = [
            torch.arange(im.num_points, device=images.device).repeat_interleave(
                im.view_csr_indexing[1:] - im.view_csr_indexing[:-1])
            for im in images]
# take subset of only seen points without re-indexing the same point
data = mm_data[dense_idx_list[0].unique()]

data

In [None]:
# visualize_mm_data(data, class_names=CLASS_NAMES, class_colors=CLASS_COLORS, front='y', back='m2f_pred_mask', figsize=1000, pointsize=3, voxel=0.05, show_2d=True, alpha=1.0)

### Per scene evaluation

In [None]:
### M2F mode pred scores
mode_preds = torch.mode(combined_m2f_tensor, dim=-1)[0].long()


# mode_preds[mode_preds == 4] = 0
# data.data.y[0] = 4


print(mode_preds.unique())
print(data.data.y.unique()[1:])
print(CLASS_LABELS)

In [None]:
print(mode_preds.unique(return_counts=True))


### Check if bookshelf preds are still in preds after filtering out ignore_label from pred and labels

In [None]:
# combined tracker
computeScores(data, mode_preds, dataset, stage='train', tracker=tracker, reset=True)

# """
# {'train_acc': 84.92801236834477,
#  'train_macc': 72.13466411119538,
#  'train_miou': 57.735325170792414}
#  """


# Score with chair prediction, scene0000_00
# {'train_acc': 78.26635359866494,
#  'train_macc': 68.09951978180109,
#  'train_miou': 53.20302327275945}

# without chair pred
# {'train_acc': 78.26635359866494,
#  'train_macc': 68.09951978180109,
#  'train_miou': 55.494691728583014}

In [None]:
print(100 * tracker._confusion_matrix.get_average_intersection_union(missing_as_one=False))
print(tracker._confusion_matrix.get_intersection_union_per_class())

In [None]:
np.asarray([6.37603604e-01, 7.16119607e-01, 6.55039517e-01, 7.14544805e-01,
       1.00000000e-08, 8.21926115e-01, 3.66729689e-01, 6.36812422e-01,
       2.43762006e-01,   2.46813451e-01,
       4.66349027e-01, 7.12784598e-01, 7.48108936e-01, 
       4.62686577e-01, 6.39296198e-01, 4.37693110e-01]).mean()

In [None]:
def save_confusion_matrix(cm, path2save=None, ordered_names=None):
    import seaborn as sns
    import matplotlib.pyplot as plt

    sns.set(font_scale=5)
    
#     template_path = os.path.join(path2save, "{}.svg")
    # PRECISION
    cmn = cm.astype("float") / cm.sum(axis=-1)[:, np.newaxis]
    cmn[np.isnan(cmn) | np.isinf(cmn)] = 0
    fig, ax = plt.subplots(figsize=(31, 31))
    sns.heatmap(
        cmn, annot=True, fmt=".2f", xticklabels=ordered_names, yticklabels=ordered_names, annot_kws={"size": 20}
    )
    # g.set_xticklabels(g.get_xticklabels(), rotation = 35, fontsize = 20)
    plt.ylabel("Actual")
    plt.xlabel("Predicted")
    plt.show()
    
#     path_precision = template_path.format("precision")
#     plt.savefig(path_precision, format="svg")

    # RECALL
    cmn = cm.astype("float") / cm.sum(axis=0)[np.newaxis, :]
    cmn[np.isnan(cmn) | np.isinf(cmn)] = 0
    fig, ax = plt.subplots(figsize=(31, 31))
    sns.heatmap(
        cmn, annot=True, fmt=".2f", xticklabels=ordered_names, yticklabels=ordered_names, annot_kws={"size": 20}
    )
    # g.set_xticklabels(g.get_xticklabels(), rotation = 35, fontsize = 20)
    plt.ylabel("Actual")
    plt.xlabel("Predicted")
    plt.show()
    
#     path_recall = template_path.format("recall")
    #plt.savefig(path_recall, format="svg")


save_confusion_matrix(tracker._confusion_matrix.confusion_matrix, ordered_names=CLASS_LABELS)

In [None]:
# scene 0000_00
computeScores(data, mode_preds, dataset, stage='train', tracker=None)

{'train_acc': 78.5082992024143,
 'train_macc': 68.44232036545344,
 'train_miou': 53.298953827959814}

In [None]:
# scene0190_0
computeScores(data, mode_preds, dataset, stage='train', tracker=None)\

{'train_acc': 88.12224636669923,
 'train_macc': 66.77507086551887,
 'train_miou': 41.859285338265536}

In [None]:
# scene0190_1
computeScores(data, mode_preds, dataset, stage='train', tracker=None)

{'train_acc': 88.02479415302064,
 'train_macc': 75.62845810740538,
 'train_miou': 46.844401523580046}

# Process MMData to compatible input for MVFusion model

#### Step by step process:
1. ~Take the MMData object from the dataloader~
2. ~Find out how the mappings are stored and accessed. Hint: check 'SelectMappingFromPointId' and 'PickImagesFromMemoryCredit'~
3. Select 9 random views for each 3d point, while maximizing differences between camera positions. I have chosen to not use extrinsics as I hypothesize that there already will be large differences in poses when camera position is different.
4. Get M2F labels for each view
5. Add the missing features to each viewing condition vector
6. Input the prost-processed sample to MVFusion model


Random thoughts:
- Adding noise to viewing conditions 
- 3. has to be vectorized (not for looping through n_points)


In [None]:
# # Step 3
# most_seen_point = dense_idx_list[0].unique(return_counts=True)[1].argmax()

# mm_data = seen_mm_data#[[0, most_seen_point, 1, 2]]
# image_data = mm_data.modalities['image']
# samesetting_data = image_data[0]   # take first in SameSettingImageData since ScanNet only has 1 setting
# print(samesetting_data)

# print(samesetting_data.__dict__.keys())

# # camera positions
# print(samesetting_data.pos)

# print(samesetting_data.mappings)

In [None]:
# how to group each point's views?
print(image_data.view_cat_csr_indexing)
print(samesetting_data.view_csr_indexing)

csr_idx = image_data.view_cat_csr_indexing

# Compute dense indices from CSR indices
n_groups = csr_idx.shape[0] - 1
dense_idx = torch.arange(n_groups).to(csr_idx.device).repeat_interleave(
    csr_idx[1:] - csr_idx[:-1])
# if src.dim() > 1:
#     dense_idx = dense_idx.view(-1, 1).repeat(1, src.shape[1])
    
print(dense_idx)

In [None]:
mm_data.num_points

In [None]:
# import time
# s = time.time()

# def extract_viewing_data_per_point(mm_data):
#     n_views = 9
    
    
#     image_data = mm_data.modalities['image']
#     csr_idx = image_data.view_cat_csr_indexing

    
# #     # Compute dense indices from CSR indices
# #     n_groups = csr_idx.shape[0] - 1
# #     dense_idx = torch.arange(n_groups).to(csr_idx.device).repeat_interleave(
# #         csr_idx[1:] - csr_idx[:-1])
    
# #     print(dense_idx)
    
# #     # batch_idx: to which image a view-condition feature belongs
# #     # CSRData: n/a
# #     # viewing_conditions: feature vector of an image-pixel match
# #     for x in image_data[0].mappings.values:
# #         print(x)
        
#     viewing_conditions = image_data[0].mappings.values[2]
    
#     # Add pixel validity as first feature!
#     viewing_conditions = torch.cat((torch.ones(viewing_conditions.shape[0], 1).to(viewing_conditions.device),
#                                     viewing_conditions), dim=1)
    
    
#     # Calculate amount of empty views. There should be n_points * n_views filled view conditions in total.
#     n_seen = csr_idx[1:] - csr_idx[:-1]
    
#     unfilled_points = n_seen[n_seen < n_views]
#     n_views_to_fill = int(len(unfilled_points) * n_views - sum(unfilled_points))
    
#     random_invalid_views = viewing_conditions[np.random.choice(range(len(viewing_conditions)), size=n_views_to_fill, replace=True)]
#     # set pixel validity to invalid
#     random_invalid_views[:, 0] = 0
    
    
#     # faster method: conccaat viewing conditions and random invalid views, then index the tensor such that each point
#     # either has 9 valid subsampled views, or is filled to 9 views with random views
#     combined_tensor = torch.cat((viewing_conditions, random_invalid_views), dim=0)
    
#     unused_invalid_view_idx = len(viewing_conditions)
#     combined_idx = []
#     for i, n in enumerate(n_seen):
#         if n < n_views:
#             n_empty_views = n_views -  n
#             combined_idx += list(range(csr_idx[i], csr_idx[i+1])) + \
#                             list(range(unused_invalid_view_idx, unused_invalid_view_idx + n_empty_views))
#             unused_invalid_view_idx += n_empty_views
#         elif n > n_views:
#             sampled_idx = sorted(np.random.choice(range(csr_idx[i], csr_idx[i+1]), size=n_views, replace=False))
#             combined_idx += sampled_idx
            
#         else:
#             combined_idx += list(range(csr_idx[i], csr_idx[i+1]))
    
#     combined_tensor = combined_tensor[combined_idx]    
#     return combined_tensor.reshape(mm_data.num_points, n_views, -1)
        
# extract_viewing_data_per_point(mm_data)


In [None]:
### Redundant code

# s = time.time()
# ### view sampling per point
# view_data_per_point, random_invalid_views = extract_viewing_data_per_point(mm_data)

# cum_sampled_views = 0
# for i in range(len(view_data_per_point)):
#     if len(view_data_per_point[i]) > 9:
#         sampled_idx = sorted(np.random.choice(range(len(view_data_per_point[i])), size=9, replace=False))
#         view_data_per_point[i] = view_data_per_point[i][sampled_idx]
#     else:
#         # Each point should have 9 views, so fill points with random invalid view conditions till it contains 9 views
#         n_views = len(view_data_per_point[i])
#         n_empty = 9 - n_views
#         view_data_per_point[i] = torch.cat((view_data_per_point[i], 
#                                             random_invalid_views[cum_sampled_views:cum_sampled_views+n_empty]), dim=0)
#         cum_sampled_views += n_empty
                
# print(time.time() - s)

# # view_data_per_point


In [None]:

# idx mapping from each pixel to point
im_data = test_data.modalities['image']    # need SameSettingImageData object so no [0] slicing

dense_idx_list = [
            torch.arange(im.num_points, device=im_data.device).repeat_interleave(
                im.view_csr_indexing[1:] - im.view_csr_indexing[:-1])
            for im in im_data]
dense_idx_list[0]

### Trying to point-image & point-pixel mappings

In [None]:
mapping = mm_data.modalities['image'][0].mappings
mapping

In [None]:
mapping.__dict__
# pointers: csr_idx, aka which tensors belong to the same view
# values[0]: indicates the image idx
# values[1]: CSRData holding pixel indices, called with values[1].values[0]

In [None]:
# Image idx of each point-pixel match
idx_batch = mapping.values[0].repeat_interleave(
    mapping.values[1].pointers[1:] - mapping.values[1].pointers[:-1])
idx_batch

In [None]:
# idx mapping from each pixel to point
im_data = mm_data.modalities['image']

dense_idx_list = [
            torch.arange(im.num_points, device=im_data.device).repeat_interleave(
                im.view_csr_indexing[1:] - im.view_csr_indexing[:-1])
            for im in im_data]
dense_idx_list[0]

In [None]:
dense_idx_list[0].unique(return_counts=True)[1].min()

In [None]:
mm_data.modalities['image'].get_mapped_features(interpolate=False)

def get_mapped_features(mod_data, interpolate=False):
    """Return the mapped features, with optional interpolation. If
    `interpolate=False`, the mappings will be adjusted to
    `self.img_size`: the current size of the feature map `self.x`.
    """
    # Compute the feature map's sampling ratio between the input
    # `mapping_size` and the current `img_size`
    # TODO: treat scales independently. Careful with min or max
    #  depending on upscale and downscale
    scale = 1 / mod_data.downscale

    # If not interpolating, set the mapping to the proper scale
    mappings = mod_data.mappings if interpolate \
        else mod_data.mappings.rescale_images(scale)

    # Index the features with/without interpolation
    if interpolate and scale != 1:
        print("BAAAKA")
        resolution = torch.Tensor([mod_data.mapping_size]).to(mod_data.device)
        coords = mappings.pixels / (resolution - 1)
        coords = coords[:, [1, 0]]  # pixel mappings are in (W, H) format
        batch = mappings.feature_map_indexing[0]
        x = sparse_interpolation(mod_data.x, coords, batch)
    else:
        x = mod_data.x[mappings.feature_map_indexing]

    return x
    
get_mapped_features(mm_data.modalities['image'][0], interpolate=True)

#### Returns indices for extracting mapped data from a batch of image feature maps

In [None]:
im = mm_data.modalities['image'][0]
print(im.feature_map_indexing)

im = mm_data.modalities['image']
print(im.feature_map_indexing)


In [None]:
#################### I found out how indices are stored. 
#################### Now, check out how the ImageMapping is calculated! in MapImages transform


def feature_map_indexing(SameSettingImageData):  # function within ImageMapping object
    """Return the indices for extracting mapped data from the
    corresponding batch of image feature maps.

    The batch of image feature maps X is expected to have the shape
    `[B, C, H, W]`. The returned indexing object idx is intended to
    be used for recovering the mapped features as: `X[idx]`.
    """
    mappings = SameSettingImageData.mappings
    
    print(mappings.features[:10])
    
    idx_batch = mappings.images.repeat_interleave(
        mappings.values[1].pointers[1:] - mappings.values[1].pointers[:-1])
    idx_height = mappings.pixels[:, 1]
    idx_width = mappings.pixels[:, 0]
    idx = (idx_batch.long(), ..., idx_height.long(), idx_width.long())
    return idx

im = mm_data.modalities['image'][0]   # SameSettingImageData
feature_map_indexing(im)

mm_data.data.__dict__

In [None]:
images = seen_mm_data.modalities['image']
n_seen = sum([
            im.mappings.pointers[1:] - im.mappings.pointers[:-1]
            for im in images])
n_seen.argmax()   # by how many views a 3d point is seen

In [None]:
# Compute the unseen points boolean masks and split them in a
# list of masks for easier popping

# if self.use_coverage:
def PickImages(data, images):
    """
    Returns how many points each view sees
    """
    img_unseen_points = torch.zeros(
        images.num_views, data.num_nodes, dtype=torch.bool)
    i_offset = 0
    for im in images:
        mappings = im.mappings
        i_idx = mappings.images + i_offset
        j_idx = mappings.points.repeat_interleave(
            mappings.pointers[1:] - mappings.pointers[:-1])
        img_unseen_points[i_idx, j_idx] = True
        i_offset += im.num_views
    img_unseen_points = [x.numpy() for x in img_unseen_points]
    
    print([img_unseen_points[x].sum() for x in range(len(img_unseen_points))])
    
PickImages(seen_mm_data.data, seen_mm_data.modalities['image'])

### Compare mapping feature statistics between train/val/test proccessed_2d 

In [9]:
for split in ['train', 'val', 'test']:
    data_dir = f"/project/fsun/dvata/scannet-neucon-smallres-m2f/processed/processed_2d_{split}"

    mean = 0
    std = 0
    for name in os.listdir(data_dir):
        file = osp.join(data_dir, name)

        data = torch.load(file)
        data = data._mappings.values[2]

        # note that estimation is biased
        mean += data.mean(axis=0)
        std += data.std(axis=0)
        
    mean /= len(os.listdir(data_dir))
    std /= len(os.listdir(data_dir))


    print(f"{split} mapping feature statistics:")
    print(mean)
    print(std)    



train mapping feature statistics:
tensor([0.2711, 0.1401, 0.6778, 0.1822, 0.6568, 0.4669, 0.2979, 0.6933])
tensor([0.0811, 0.1083, 0.2375, 0.1710, 0.2294, 0.2735, 0.1378, 0.2283])
val mapping feature statistics:
tensor([0.2745, 0.1412, 0.6740, 0.1847, 0.6517, 0.4674, 0.3046, 0.6818])
tensor([0.0856, 0.1096, 0.2406, 0.1730, 0.2322, 0.2735, 0.1415, 0.2301])
test mapping feature statistics:
tensor([0.2839, 0.1362, 0.6823, 0.1815, 0.6587, 0.4749, 0.3009, 0.6832])
tensor([0.0886, 0.1055, 0.2369, 0.1726, 0.2329, 0.2752, 0.1408, 0.2319])


In [4]:
# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/project/fsun/dvata'

dataset_config = 'segmentation/multimodal/Feng/scannet-neucon-smallres-m2f-partial-subsampled'   
models_config = 'segmentation/multimodal/Feng/mvfusion'    # model family
model_name = 'MVFusion_3D'                       # specific model

overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

cfg = hydra_read(overrides)
OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
cfg.data.load_m2f_masks = True   # load Mask2Former predicted masks

# Dataset instantiation
start = time()
dataset = ScannetDatasetMM(cfg.data)
# print(dataset)|
print(f"Time = {time() - start:0.1f} sec.")

load_m2f_masks:  True
initialize train dataset
initialize val dataset
initialize test dataset
line 720 scannet.py: split == 'test'
Time = 4.8 sec.


In [23]:
# mm_data_list = [dataset.test_dataset[0][i] for i in range(1)]
# mm_data_list = [dataset.test_dataset[0][99]]

index = 0
mm_data_list = [dataset.val_dataset[index] for index in range(3)]

# mm_data_list = [dataset.train_dataset[index] for index in range(3)]

# mm_data_list = [dataset.test_dataset[0][index] for index in range(3)]


In [24]:
# for x in mm_data_list:
#     print(x)
    
batch = MMBatch.from_mm_data_list(mm_data_list)
print(batch)
print(batch.pos.mean(axis=0))
print(batch.modalities['image'][0].pos.mean(0))

MMBatch(
    data = Batch(batch=[80077], coords=[80077, 3], grid_size=[3], id_scan=[3], mapping_index=[80077], mvfusion_input=[58962, 9, 10], origin_id=[80077], pos=[80077, 3], ptr=[4], x=[80077, 3], y=[80077])
    image = ImageBatch(num_settings=1, num_views=300, num_points=80077, device=cpu)
)
tensor([-4.1350e-08,  1.1619e-07,  3.3875e-08])
tensor([-0.0588, -0.3896,  0.7278], dtype=torch.float64)


In [25]:
new_batch = batch.clone()
print(batch == new_batch)

[0, 100, 300]

False


[0, 100, 300]

In [26]:
new_batch = batch.clone()
print(new_batch.pos.mean(0), new_batch.x.mean(0))

batch_size = new_batch.ptr.shape[0] - 1
n_merge = 2

new_ptr = [0]

# Merge two individual point clouds to one MMData inside MMBatch. Uneven batches will leave the last sample untouched.
# NOTE: only N_MERGE == 2 is supported.
assert n_merge == 2
for i in range(0, len(new_batch.ptr)-2, n_merge):
    b1, e1 = new_batch.ptr[i], new_batch.ptr[i+1]
    b2, e2 = new_batch.ptr[i+1], new_batch.ptr[i+2]
    
    coords = new_batch.pos[b2:e2]
    r1 = coords.min(0)[0]
    r2 = coords.max(0)[0]
    offset = ( (r1 - r2) * torch.rand(1, 3) + r2 ) / 2
    
    # Slightly translate one of two point clouds
    new_batch.pos[b2:e2] = new_batch.pos[b2:e2] + offset
    new_batch.x[b2:e2] = new_batch.x[b2:e2] + offset

    new_ptr.append(new_batch.ptr[i+2])
    
# Add last pointer for uneven batches
if batch_size % n_merge == 1:
    new_ptr.append(new_batch.ptr[-1])
new_ptr = torch.LongTensor(new_ptr)

# Update batch identifiers
new_batch.ptr = new_ptr 
new_batch.batch = new_batch.batch // n_merge
new_batch.data.batch = new_batch.batch

# What is origin_id and do we need to adjust it when merging batches?
print(new_batch.pos.mean(0), new_batch.x.mean(0))


tensor([-4.1350e-08,  1.1619e-07,  3.3875e-08]) tensor([-4.1350e-08,  1.1619e-07,  3.3875e-08])
tensor([-0.5779,  0.5223,  0.2923]) tensor([-0.5779,  0.5223,  0.2923])


In [27]:
visualize_mm_data(new_batch, figsize=1000, pointsize=3, voxel=0.1, show_2d=False, back='m2f_pred_mask', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

In [90]:
def read_axis_align_matrix(filename):
    lines = open(filename).readlines()
    axis_align_matrix = None
    for line in lines:
        if "axisAlignment" in line:
            axis_align_matrix = torch.Tensor([float(x) for x in line.rstrip().strip("axisAlignment = ").split(" ")]).reshape((4, 4))
            break
    return axis_align_matrix

path = '/project/fsun/data/scannet/scans'
scene_id = dataset.val_dataset.MAPPING_IDX_TO_SCAN_TRAIN_NAMES[index]
path = osp.join(path, scene_id, scene_id + '.txt')
matrix = read_axis_align_matrix(path)
matrix

tensor([[ 0.9455,  0.3256,  0.0000, -5.3844],
        [-0.3256,  0.9455,  0.0000, -2.8718],
        [ 0.0000,  0.0000,  1.0000, -0.0644],
        [ 0.0000,  0.0000,  0.0000,  1.0000]])

In [91]:
inv = torch.linalg.inv(matrix.T)
print(inv)

# inv2 = torch.linalg.inv(matrix[:3, :3])
# inv2_t = (-inv2 @ matrix[:3, 3])

# inv_final = torch.zeros((4, 4))
# inv_final[:3, :3] = inv2
# inv_final[:3, 3] = inv2_t
# inv_final

tensor([[ 9.4552e-01,  3.2557e-01, -4.6566e-10,  1.8626e-09],
        [-3.2557e-01,  9.4552e-01,  0.0000e+00,  3.7253e-09],
        [ 0.0000e+00, -0.0000e+00,  1.0000e+00,  0.0000e+00],
        [ 4.1561e+00,  4.4683e+00,  6.4350e-02,  1.0000e+00]])


In [92]:
batch.pos = torch.concat((batch.pos, torch.ones((len(batch.pos), 1))), axis=-1) @ inv
batch.pos = batch.pos[:, :3]

print(batch.pos.shape)

# y = x @ M^T
# y @ M = x @ (M^T @ M)

torch.Size([26715, 3])


In [93]:
batch.data.pos = batch.pos

In [97]:
# # Reverses the inverse
# batch.data.pos = torch.concat((batch.pos, torch.ones((len(batch.pos), 1))), axis=-1) @ matrix.T
# batch.data.pos = batch.data.pos[:, :3]

In [11]:
visualize_mm_data(batch, figsize=1000, pointsize=3, voxel=0.1, show_2d=False, back='m2f_pred_mask', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

## Run inference from pretrained weights and visualize predictions
It is possible to visualize the pointwise predictions and errors from a model. 

To do so, we will use the pretrained weights made available with this project. See `README.md` to get the download links and manually place the `.pt` files locally. You will need to provide `checkpoint_dir` where you saved those files in the next cell.

In [13]:
from torch_points3d.models.model_factory import instantiate_model

# Set your parameters
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/2022-10-31/19-54-07'  # MVFusion without 3d, fully trained on 30k seen points per pcd
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/2022-11-09/12-51-37'
checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/2022-11-11/17-18-09' # MVFusion_3d 0.1 LR, 0.03 vox, miou 74.7

# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/2022-11-04/15-51-33' # 3D Backbone, 68.04 miou
# checkpoint_dir = '/home/fsun/DeepViewAgg/model_checkpoints' # DVA best model

# Create the model
print(f"Creating model: {cfg.model_name}")
model = instantiate_model(cfg, dataset)
# print(model)

# Load the checkpoint and recover the 'best_miou' model weights
checkpoint = torch.load(f'{checkpoint_dir}/{model_name}.pt', map_location='cpu')
model.load_state_dict_with_same_shape(checkpoint['models']['latest'], strict=False)

# Prepare the model for inference
model = model.eval().cuda()
print('Model loaded')

In [9]:
# model

In [10]:
# model

Now we have loaded the model, we need to run a forward pass on a sample. Howver, if we want to be able to visualize the predictions, we need to pay special attention to which type of 3D and 2D transforms we apply on the data if we do not want to break the mappings. To do so, we will manually apply some sensitive transforms to be able to both infer on the data and visualize it.

In [33]:
i_room = 0

# Pick a room in the Train set
# mm_data = sample_real_data(dataset.train_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Pick a room in the Val set
mm_data = sample_real_data(dataset.val_dataset, idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Pick a room in the Test set
# mm_data = sample_real_data(dataset.test_dataset[0], idx=i_room, exclude_3d=exclude_3d_viz, exclude_2d=exclude_2d_viz)

# Extract point cloud and images from MMData object
data = mm_data.data.clone()
images = mm_data.modalities['image'].clone()

data

In [9]:
# Run cell for validation sample with original validation transformations
mm_data = dataset.val_dataset[0]

# data = mm_data.data.clone()
# images = mm_data.modalities['image'].clone()
# data

In [10]:
mm_data

MMData(
    data = Data(coords=[97257, 3], grid_size=[1], id_scan=[1], mapping_index=[97257], mvfusion_input=[70209, 9, 10], origin_id=[97257], pos=[97257, 3], x=[97257, 3], y=[97257])
    image = ImageData(num_settings=1, num_views=100, num_points=97257, device=cpu)
)

In [27]:
### Select seen points
# csr_idx = mm_data.modalities['image'][0].view_csr_indexing
# n_seen = csr_idx[1:] - csr_idx[:-1]
# seen_mask = ( n_seen > 0 )
# print(seen_mask.shape)
# print(seen_mask.sum())

torch.Size([237360])
tensor(157279)


In [28]:
# # select only first N points
# N = 33260
# mm_data = mm_data[:N]
# csr_idx = mm_data.modalities['image'][0].view_csr_indexing
# n_seen = csr_idx[1:] - csr_idx[:-1]
# seen_mask = ( n_seen > 0 )

# print(seen_mask.shape, seen_mask.sum())

torch.Size([33260]) tensor(17904)


In [29]:
# # Run block to grab only seen points

# # csr_idx = images[0].view_csr_indexing
# # dense_idx_list = torch.arange(images.num_points).repeat_interleave(csr_idx[1:] - csr_idx[:-1])
# # # take subset of only seen points without re-indexing the same point
# # mm_data = mm_data[dense_idx_list.unique()]



# n_seen_points = seen_mask.sum()

# mm_data.mvfusion_input = mm_data.mvfusion_input[:n_seen_points]
# mm_data.data.mvfusion_input = mm_data.data.mvfusion_input[:n_seen_points]

# # Extract point cloud and images from MMData object
# data = mm_data.data.clone()
# images = mm_data.modalities['image'].clone()

# data

Data(id_scan=[1], linearity=[33260], mapping_index=[33260], mvfusion_input=[17904, 9, 10], norm=[33260, 3], origin_id=[33260], planarity=[33260], pos=[33260, 3], pos_x=[33260], pos_y=[33260], pos_z=[33260], rgb=[33260, 3], scattering=[33260], y=[33260])

In [18]:
# For voxel-based 3D backbones such as SparseConv3d and MinkowskiNet, points need to be 
# preprocessed with Center and GridSampling3D. Unfortunately, Center breaks relative 
# positions between points and images. Besides, the combination of Center and GridSampling3D
# may lead to some points being merged into the same voxels, so we must apply it to both the
# inference and visualization data to make sure we have the same voxels. The workaround here 
# is to manually run these while keeping track of the centering offset
center = data.pos.mean(dim=-2, keepdim=True)
data = AddFeatsByKeys(list_add_to_x=[True, True, True], feat_names=['pos_x', 'pos_y', 'pos_z'], delete_feats=[True, True, True])(data)          # add z-height to the features
data = Center()(data)                                                                                 # mean-center the data
data = GridSampling3D(cfg.data.resolution_3d, quantize_coords=True, mode='last')(data)                # quantization for volumetric models

# This last voxelization step with GridSampling3D might have removed some points, so we need
# to update the mappings usign SelectMappingFromPointId. To control the size of the batch, we
# use PickImagesFromMemoryCredit. Besides, 2D models expect normalized float images, which is
# why we call ToFloatImage and Normalize
data, images = SelectMappingFromPointId()(data, images)                                               # update mappings after GridSampling3D
data, images = PickImagesFromMemoryCredit(
    img_size=cfg.data.resolution_2d, 
    k_coverage=cfg.data.multimodal.settings.k_coverage, 
    n_img=cfg.data.multimodal.settings.test_pixel_credit)(data, images)                                      # select images to respect memory constraints
data, images_infer = ToFloatImage()(data, images.clone())                                             # convert uint8 images to float
data, images_infer = Normalize()(data, images_infer)                                                  # RGB normalization

# Create a MMData for inference
mm_data_infer = MMData(data, image=images_infer)
print(mm_data_infer)

# Create a MMBatch and run inference
batch = MMBatch.from_mm_data_list([mm_data_infer])

with torch.no_grad():
    print("input batch: ", batch)
    model.set_input(batch, model.device)
    model(batch)

# Create a MMData for visualization
data.pos += center
mm_data = MMData(data, image=images)

# Recover the predicted labels for visualization
mm_data.data.pred = model.output.detach().cpu().argmax(1)

MMData(
    data = Data(coords=[96829, 3], grid_size=[1], id_scan=[1], linearity=[96829], mapping_index=[96829], mvfusion_input=[155050, 9, 10], norm=[96829, 3], origin_id=[96829], planarity=[96829], pos=[96829, 3], rgb=[96829, 3], scattering=[96829], x=[96829, 3], y=[96829])
    image = ImageData(num_settings=1, num_views=100, num_points=96829, device=cpu)
)
input batch:  MMBatch(
    data = Batch(batch=[96829], coords=[96829, 3], grid_size=[1], id_scan=[1], linearity=[96829], mapping_index=[96829], mvfusion_input=[155050, 9, 10], norm=[96829, 3], origin_id=[96829], planarity=[96829], pos=[96829, 3], ptr=[2], rgb=[96829, 3], scattering=[96829], x=[96829, 3], y=[96829])
    image = ImageBatch(num_settings=1, num_views=100, num_points=96829, device=cpu)
)


In [24]:
# Run inference with augmentations
# mm_data = dataset.train_dataset[0]
# mm_data = dataset.val_dataset[0]
mm_data = dataset.test_dataset[0][0]



# Create a MMBatch and run inference
batch = MMBatch.from_mm_data_list([mm_data])

with torch.no_grad():
    print("input batch: ", batch)
    model.set_input(batch, model.device)
    model(batch)
    
# Recover the predicted labels for visualization
mm_data.data.pred = model.output.detach().cpu().argmax(1)

input batch:  MMBatch(
    data = Batch(batch=[55011], coords=[55011, 3], grid_size=[1], id_scan=[1], mapping_index=[55011], mvfusion_input=[39801, 9, 10], origin_id=[55011], pos=[55011, 3], ptr=[2], x=[55011, 3])
    image = ImageBatch(num_settings=1, num_views=90, num_points=55011, device=cpu)
)


In [18]:
def get_mode_pred(data):
    pixel_validity = data.data.mvfusion_input[:, :, 0].bool()
    mv_preds = data.data.mvfusion_input[:, :, -1].long()
        
    n_views = 9
    
    valid_m2f_feats = []
    for i in range(len(mv_preds)):
        valid_m2f_feats.append(mv_preds[i][pixel_validity[i]])

    mode_preds = []
    for m2feats_of_seen_point in valid_m2f_feats:
        mode_preds.append(torch.mode(m2feats_of_seen_point.squeeze(), dim=0)[0])
    mode_preds = torch.stack(mode_preds, dim=0)
        
    return mode_preds

mode_preds = get_mode_pred(mm_data)

In [None]:
m2f_mm_data = mm_data.clone()
m2f_mm_data.data.x = None
m2f_mm_data.data.pred = mode_preds
# m2f_mm_data.data.pred = m2f_mm_data.data.pred[m2f_mm_data.data.y != -1]
m2f_mm_data = m2f_mm_data[m2f_mm_data.data.y != -1]

visualize_mm_data(m2f_mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=False, back='m2f_mask_pred', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

In [10]:
mm_data.data.x = None
mm_data.data.pred = mm_data.data.pred[mm_data.data.y != -1]
mm_data = mm_data[mm_data.data.y != -1]


print(mm_data.data.pred.unique())
mm_data.data.y.unique()

tensor([ 0,  1,  2,  4,  5,  6,  7,  8, 11, 14, 17, 19])


tensor([ 0,  1,  2,  4,  6,  7,  8, 11, 14, 17, 19])

In [12]:
visualize_mm_data(mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=False, back='m2f_pred_mask', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

In [None]:
mm_data.pred

In [None]:
### Calculate number of correct predictions (accuracy)

print(sum(mm_data.y == mm_data.pred) / len(mm_data.y))

print(sum(m2f_mm_data.y == m2f_mm_data.pred) / len(m2f_mm_data.y))
