###### ScanNet

This notebook lets you instantiate the **[ScanNet](http://www.scan-net.org/)** dataset from scratch and visualize **3D+2D room samples**.

Note that you will need **at least 1.2T** available for the SanNet raw dataset and **at least 64G** for the processed files at **5cm voxel resolution** and **320x240 image resolution**. 

The ScanNet dataset is composed of **rooms** of video acquisitions of indoor scenes. Thes video streams were used to produce a point cloud and images.

Each room is small enough to be loaded at once into a **64G RAM** memory. The `ScannetDatasetMM` class from `torch_points3d.datasets.segmentation.multimodal.scannet` deals with loading the room and part of the images of the associated video stream.

In [1]:
# Select you GPU
I_GPU = 0


In [2]:
# Uncomment to use autoreload
%load_ext autoreload
%autoreload 2

import os
import os.path as osp
import sys
import torch
import numpy as np
from time import time
from omegaconf import OmegaConf
start = time()
import warnings
warnings.filterwarnings('ignore')

# torch.cuda.set_device(I_GPU)
DIR = os.path.dirname(os.getcwd())
ROOT = os.path.join(DIR, "..")
sys.path.insert(0, ROOT)
sys.path.insert(0, DIR)

from torch_points3d.utils.config import hydra_read
from torch_geometric.data import Data
from torch_points3d.core.multimodal.data import MMData, MMBatch
from torch_points3d.visualization.multimodal_data import visualize_mm_data
from torch_points3d.core.multimodal.image import SameSettingImageData, ImageData
from torch_points3d.datasets.segmentation.multimodal.scannet import ScannetDatasetMM
from torch_points3d.datasets.segmentation.scannet import CLASS_COLORS, CLASS_NAMES, CLASS_LABELS
from torch_points3d.metrics.segmentation_tracker import SegmentationTracker
from torch_points3d.datasets.segmentation import IGNORE_LABEL

from pykeops.torch import LazyTensor

import matplotlib.pyplot as plt 

%matplotlib inline

MMData debug() function changed, please uncomment the 3rd assert line when doing inference without M2F features!


In [3]:
CLASS_COLORS[0] = (174.0, 199.0, 232.0)
CLASS_COLORS[-1] = (0, 0, 0)

If `visualize_mm_data` does not throw any error but the visualization does not appear, you may need to change your plotly renderer below.

In [4]:
import plotly.io as pio

#pio.renderers.default = 'jupyterlab'        # for local notebook
pio.renderers.default = 'iframe_connected'  # for remote notebook. Other working (but seemingly slower) options are: 'sphinx_gallery' and 'iframe'

In [5]:
# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/scratch-shared/fsun/dvata'

dataset_config = 'segmentation/multimodal/Feng/scannet-neucon-smallres-m2f'   
models_config = 'segmentation/multimodal/Feng/mvfusion'    # model family
model_name = 'MVFusion_3D_small_6views'                       # specific model

overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

cfg = hydra_read(overrides)
OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
cfg.data.load_m2f_masks = True   # load Mask2Former predicted masks
cfg.data.m2f_preds_dirname = 'ViT_masks'
cfg.data.n_views = cfg.models[model_name].backbone.transformer.n_views
print(cfg.data.n_views)

# Dataset instantiation
start = time()
dataset = ScannetDatasetMM(cfg.data)
# print(dataset)|
print(f"Time = {time() - start:0.1f} sec.")

6
Load predicted 2D semantic segmentation labels from directory  ViT_masks
initialize train dataset
initialize val dataset
Time = 7.3 sec.


In [8]:
from torch_points3d.models.model_factory import instantiate_model

# ViT_masks 3rd run
checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/ViT_masks_3rd_run' # 3rd run
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/MVFusion_3D_6_views_m2f_masks'

# Create the model
print(f"Creating model: {cfg.model_name}")
model = instantiate_model(cfg, dataset)
# print(model)

# Load the checkpoint and recover the 'best_miou' model weights
checkpoint = torch.load(f'{checkpoint_dir}/{model_name}.pt', map_location='cpu')
model.load_state_dict_with_same_shape(checkpoint['models']['latest'], strict=False)

# Prepare the model for training
model = model.cuda()
print('Model loaded')

Creating model: MVFusion_3D_small_6views
task:  segmentation.multimodal
tested_model_name:  MVFusion_3D_small_6views
class_name:  MVFusionAPIModel
model_module:  torch_points3d.models.segmentation.multimodal.Feng.mvfusion_3d
name, cls of chosen model_cls:  MVFusionAPIModel <class 'torch_points3d.models.segmentation.multimodal.Feng.mvfusion_3d.MVFusionAPIModel'>
x feature dim:  {'FEAT': 3}
nc_in:  67
nc_in:  64
nc_in:  32
nc_in:  64
nc_in:  128
nc_in:  256
nc_in:  128
nc_in:  128
nc_in:  96
nc_in:  96
Model loaded


In [12]:
mm_data = dataset.val_dataset[0]

In [16]:
def get_seen_points(mm_data):
    ### Select seen points
    csr_idx = mm_data.modalities['image'][0].view_csr_indexing
    dense_idx_list = torch.arange(mm_data.modalities['image'][0].num_points).repeat_interleave(csr_idx[1:] - csr_idx[:-1])
    # take subset of only seen points without re-indexing the same point
    mm_data = mm_data[dense_idx_list.unique()]
    return mm_data

def get_mode_pred(data):
    pixel_validity = data.data.mvfusion_input[:, :, 0].bool()
    mv_preds = data.data.mvfusion_input[:, :, -1].long()
            
    valid_m2f_feats = []
    for i in range(len(mv_preds)):
        valid_m2f_feats.append(mv_preds[i][pixel_validity[i]])

    mode_preds = []
    for m2feats_of_seen_point in valid_m2f_feats:
        mode_preds.append(torch.mode(m2feats_of_seen_point.squeeze(), dim=0)[0])
    mode_preds = torch.stack(mode_preds, dim=0)
        
    return mode_preds

def get_random_view_pred(data):
    pixel_validity = data.data.mvfusion_input[:, :, 0].bool()
    mv_preds = data.data.mvfusion_input[:, :, -1].long()
            
    valid_m2f_feats = []
    for i in range(len(mv_preds)):
        valid_m2f_feats.append(mv_preds[i][pixel_validity[i]])

    selected_view_preds = []
    for m2feats_of_seen_point in valid_m2f_feats:
#         print(m2feats_of_seen_point)
        selected_idx = torch.randint(low=0, high=m2feats_of_seen_point.shape[0], size=(1,))
        selected_pred = m2feats_of_seen_point[selected_idx].squeeze(0)
        
        
#         print(m2feats_of_seen_point.shape[0])
#         print(selected_idx)
#         print(selected_pred)
        selected_view_preds.append(selected_pred)
    selected_view_preds = torch.stack(selected_view_preds, dim=0)
        
    return selected_view_preds

get_random_view_pred(mm_data)

tensor([4, 1, 1,  ..., 0, 2, 2])

In [38]:
from torch_points3d.metrics.scannet_segmentation_tracker import ScannetSegmentationTracker

In [39]:
baseline_tracker = ScannetSegmentationTracker(dataset=dataset, stage='train', wandb_log=False, use_tensorboard=False, ignore_label=IGNORE_LABEL)
mvfusion_tracker = ScannetSegmentationTracker(dataset=dataset, stage='train', wandb_log=False, use_tensorboard=False, ignore_label=IGNORE_LABEL)


In [40]:
a = {0: '89.67', 1: '98.16', 2: '72.49', 3: '85.95', 4: '90.83', 5: '81.63', 6: '82.44', 7: '70.54', 8: '66.33', 9: '79.68', 10: '44.70', 11: '69.93', 12: '73.52', 13: '81.89', 14: '73.40', 15: '78.48', 16: '94.22', 17: '71.27', 18: '88.69', 19: '67.45'}
" & ".join(list(a.values()))

'89.67 & 98.16 & 72.49 & 85.95 & 90.83 & 81.63 & 82.44 & 70.54 & 66.33 & 79.68 & 44.70 & 69.93 & 73.52 & 81.89 & 73.40 & 78.48 & 94.22 & 71.27 & 88.69 & 67.45'

In [49]:
baseline_tracker.reset(stage='train')
mvfusion_tracker.reset(stage='train')


for mm_data in dataset.val_dataset:
    print(mm_data.id_scan)

    # Create a MMBatch and run inference
    batch = MMBatch.from_mm_data_list([mm_data])

    model.set_input(batch, model.device)
    model(batch)

    # Recover the predicted labels for visualization
    mm_data.data.pred = model.output.detach().cpu().argmax(1)
        
    mm_data = get_seen_points(mm_data)
    baseline_preds = get_mode_pred(mm_data)
    
    mvfusion_tracker.track(pred_labels=mm_data.data.pred, gt_labels=mm_data.data.y, model=None)       
    baseline_tracker.track(pred_labels=baseline_preds, gt_labels=mm_data.data.y, model=None)       

    break
    
print("3D validation IoU of seen points")
print("baseline ", baseline_tracker.get_metrics())
print(" & ".join(list(baseline_tracker._miou_per_class.values())))

print("mvfusion ", mvfusion_tracker.get_metrics())
print(" & ".join(list(mvfusion_tracker._miou_per_class.values())))


tensor([0])
tensor([1, 4, 1,  ..., 2, 0, 0])
tensor([1, 4, 1,  ..., 2, 0, 0])
3D validation IoU of seen points
baseline  {'train_acc': 88.69502831585535, 'train_macc': 85.95748835062953, 'train_miou': 61.41866356445035}
81.45 & 90.51 & 74.15 & 0.00 & 71.23 & 0.00 & 69.20 & 79.96 & 63.60 & 0.00 & 0.00 & 68.48 & 0.00 & 0.00 & 75.98 & 0.00 & 0.00 & 55.03 & 0.00 & 68.84
mvfusion  {'train_acc': 95.88466059135106, 'train_macc': 93.89504929024365, 'train_miou': 79.23063734957574}
87.83 & 99.13 & 83.22 & 0.00 & 93.52 & 0.00 & 91.54 & 96.12 & 71.81 & 0.00 & 0.00 & 81.49 & 0.00 & 0.00 & 83.65 & 0.00 & 0.00 & 68.84 & 0.00 & 93.62


In [None]:
print(mode_preds.shape)
print(mm_data.data.pred.shape)

In [None]:
m2f_mm_data = mm_data.clone()
m2f_mm_data.data.x = None
m2f_mm_data.data.pred = mode_preds
# m2f_mm_data.data.pred = m2f_mm_data.data.pred[m2f_mm_data.data.y != -1]
m2f_mm_data = m2f_mm_data[m2f_mm_data.data.y != -1]

visualize_mm_data(m2f_mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=False, back='m2f_mask_pred', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

In [None]:
mm_data.data.x = None
mm_data.data.pred = mm_data.data.pred[mm_data.data.y != -1]
mm_data = mm_data[mm_data.data.y != -1]


print(mm_data.data.pred.unique())
mm_data.data.y.unique()

In [None]:
visualize_mm_data(mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=False, back='m2f_pred_mask', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)

# View Selection comparison

In [50]:
from torch_points3d.metrics.scannet_segmentation_tracker import ScannetSegmentationTracker

def get_seen_points(mm_data):
    ### Select seen points
    csr_idx = mm_data.modalities['image'][0].view_csr_indexing
    dense_idx_list = torch.arange(mm_data.modalities['image'][0].num_points).repeat_interleave(csr_idx[1:] - csr_idx[:-1])
    # take subset of only seen points without re-indexing the same point
    mm_data = mm_data[dense_idx_list.unique()]
    return mm_data

def get_mode_pred(data):
    pixel_validity = data.data.mvfusion_input[:, :, 0].bool()
    mv_preds = data.data.mvfusion_input[:, :, -1].long()
            
    valid_m2f_feats = []
    for i in range(len(mv_preds)):
        valid_m2f_feats.append(mv_preds[i][pixel_validity[i]])

    mode_preds = []
    for m2feats_of_seen_point in valid_m2f_feats:
        mode_preds.append(torch.mode(m2feats_of_seen_point.squeeze(), dim=0)[0])
    mode_preds = torch.stack(mode_preds, dim=0)
        
    return mode_preds

def get_random_view_pred(data):
    pixel_validity = data.data.mvfusion_input[:, :, 0].bool()
    mv_preds = data.data.mvfusion_input[:, :, -1].long()
            
    valid_m2f_feats = []
    for i in range(len(mv_preds)):
        valid_m2f_feats.append(mv_preds[i][pixel_validity[i]])

    selected_view_preds = []
    for m2feats_of_seen_point in valid_m2f_feats:
        selected_idx = torch.randint(low=0, high=m2feats_of_seen_point.shape[0], size=(1,))
        selected_pred = m2feats_of_seen_point[selected_idx].squeeze(0)
        selected_view_preds.append(selected_pred)
    selected_view_preds = torch.stack(selected_view_preds, dim=0)
        
    return selected_view_preds

get_random_view_pred(mm_data)

In [51]:
random_selection_tracker = ScannetSegmentationTracker(dataset=dataset, stage='train', wandb_log=False, use_tensorboard=False, ignore_label=IGNORE_LABEL)
average_fusion_tracker = ScannetSegmentationTracker(dataset=dataset, stage='train', wandb_log=False, use_tensorboard=False, ignore_label=IGNORE_LABEL)
mvfusion_tracker = ScannetSegmentationTracker(dataset=dataset, stage='train', wandb_log=False, use_tensorboard=False, ignore_label=IGNORE_LABEL)


In [52]:
random_selection_tracker.reset(stage='train')
average_fusion_tracker.reset(stage='train')
mvfusion_tracker.reset(stage='train')


for mm_data in dataset.val_dataset:
    print(mm_data.id_scan)

    # Create a MMBatch and run inference
    batch = MMBatch.from_mm_data_list([mm_data])

    model.set_input(batch, model.device)
    model(batch)

    # Recover the predicted labels for visualization
    mm_data.data.pred = model.output.detach().cpu().argmax(1)
        
    mm_data = get_seen_points(mm_data)
    random_selection_pred = get_random_view_pred(mm_data)
    average_fusion_pred = get_mode_pred(mm_data)
    
    random_selection_tracker.track(pred_labels=random_selection_pred, gt_labels=mm_data.data.y, model=None)  
    average_fusion_tracker.track(pred_labels=average_fusion_pred, gt_labels=mm_data.data.y, model=None)       
    mvfusion_tracker.track(pred_labels=mm_data.data.pred, gt_labels=mm_data.data.y, model=None)       


    break
    
print("3D validation IoU of seen points")
print("random_selection_tracker ", random_selection_tracker.get_metrics())
print(" & ".join(list(baseline_tracker._miou_per_class.values())))

print("average_fusion_tracker ", average_fusion_tracker.get_metrics())
print(" & ".join(list(baseline_tracker._miou_per_class.values())))

print("mvfusion_tracker ", mvfusion_tracker.get_metrics())
print(" & ".join(list(mvfusion_tracker._miou_per_class.values())))


tensor([0])
3D validation IoU of seen points
random_selection_tracker  {'train_acc': 87.91855478847967, 'train_macc': 85.03762311878698, 'train_miou': 55.869509912625446}
81.45 & 90.51 & 74.15 & 0.00 & 71.23 & 0.00 & 69.20 & 79.96 & 63.60 & 0.00 & 0.00 & 68.48 & 0.00 & 0.00 & 75.98 & 0.00 & 0.00 & 55.03 & 0.00 & 68.84
average_fusion_tracker  {'train_acc': 91.44490739332745, 'train_macc': 88.73124250433506, 'train_miou': 62.13647460620422}
81.45 & 90.51 & 74.15 & 0.00 & 71.23 & 0.00 & 69.20 & 79.96 & 63.60 & 0.00 & 0.00 & 68.48 & 0.00 & 0.00 & 75.98 & 0.00 & 0.00 & 55.03 & 0.00 & 68.84
mvfusion_tracker  {'train_acc': 95.82585687783218, 'train_macc': 94.01124190410147, 'train_miou': 73.3844437401989}
88.67 & 99.25 & 82.61 & 0.00 & 93.29 & 0.00 & 90.32 & 94.76 & 71.95 & 0.00 & 0.00 & 82.14 & 0.00 & 0.00 & 82.26 & 0.00 & 0.00 & 74.63 & 0.00 & 94.13


In [42]:
# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/scratch-shared/fsun/dvata'

dataset_config = 'segmentation/multimodal/Feng/scannet-neucon-smallres-m2f'   
models_config = 'segmentation/multimodal/Feng/view_selection_experiment'    # model family
model_name = 'Average_Fusion'                       # specific model

overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

cfg = hydra_read(overrides)
OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
cfg.data.load_m2f_masks = True   # load Mask2Former predicted masks
cfg.data.m2f_preds_dirname = 'ViT_masks'
cfg.data.n_views = cfg.models[model_name].backbone.transformer.n_views
print(cfg.data.n_views)

# Dataset instantiation
start = time()
dataset = ScannetDatasetMM(cfg.data)
# print(dataset)|
print(f"Time = {time() - start:0.1f} sec.")

ConfigKeyError: Missing key Average_Fusion
    full_key: models.Average_Fusion
    object_type=dict

In [43]:
from torch_points3d.models.model_factory import instantiate_model

# ViT_masks 3rd run
checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/ViT_masks_3rd_run' # 3rd run
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/MVFusion_3D_6_views_m2f_masks'

# Create the model
print(f"Creating model: {cfg.model_name}")
model = instantiate_model(cfg, dataset)
print(model)

# # Load the checkpoint and recover the 'best_miou' model weights
# checkpoint = torch.load(f'{checkpoint_dir}/{model_name}.pt', map_location='cpu')
# model.load_state_dict_with_same_shape(checkpoint['models']['latest'], strict=False)

# # Prepare the model for training
model = model.cuda()
print('Model loaded')

Creating model: Average_Fusion
task:  segmentation.multimodal
tested_model_name:  Average_Fusion


Exception: The model_name Average_Fusion isn t within ['MVFusion_small_6views', 'DeepSetAttention', 'MVFusion']

In [8]:
mm_data = dataset.val_dataset[1]
mm_data

MMData(
    data = Data(coords=[96882, 3], grid_size=[1], id_scan=[1], mapping_index=[96882], mvfusion_input=[70473, 6, 10], origin_id=[96882], pos=[96882, 3], rgb=[96882, 3], x=[96882, 3], y=[96882])
    image = ImageData(num_settings=1, num_views=100, num_points=96882, device=cpu)
)

In [13]:
# Create a MMBatch and run inference
batch = MMBatch.from_mm_data_list([mm_data])

model.set_input(batch, model.device)
model(batch)

# Recover the predicted labels for visualization
mm_data.data.pred = model.output.detach().cpu().argmax(1)

tensor([1, 1, 1,  ..., 0, 2, 0], device='cuda:0') torch.Size([70473])
tensor([[0, 1, 0,  ..., 0, 0, 0],
        [0, 1, 0,  ..., 0, 0, 0],
        [0, 1, 0,  ..., 0, 0, 0],
        ...,
        [1, 0, 0,  ..., 0, 0, 0],
        [0, 0, 1,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 0, 0, 0]], device='cuda:0') torch.Size([70473, 20])
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0') torch.Size([96882, 20])


In [11]:
mm_data.data.pred.unique(return_counts=True)

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 12, 13, 14, 15, 16, 18, 19]),
 tensor([ 4502,  4263, 17765, 16889,  1407,  2851, 29363,    11,  2472,   682,
           462,  2723,  1017,  9667,   127,  2332,     8,   341]))

In [20]:
from torch_points3d.metrics.scannet_segmentation_tracker import ScannetSegmentationTracker
tracker = ScannetSegmentationTracker(dataset=dataset, stage='train', wandb_log=False, use_tensorboard=False, ignore_label=IGNORE_LABEL)

seen_data = get_seen_points(mm_data)
tracker.track(pred_labels=mm_data.data.pred, gt_labels=seen_data.data.y, model=None)       

tracker.get_metrics()

{'train_acc': 0.0015302921327681454,
 'train_macc': 0.0032055391716886783,
 'train_miou': 0.002188944426204394}

In [51]:
viewing_conditions = mm_data.modalities['image'][0].mappings.values[2]

input_preds = mm_data.modalities['image'][0].get_mapped_m2f_features()
input_preds_one_hot = torch.nn.functional.one_hot(input_preds.long().squeeze(), 20)
attention_input = torch.concat((viewing_conditions, input_preds_one_hot), dim=1)


csr_idx = mm_data.modalities['image'][0].view_csr_indexing
csr_idx

tensor([     0,      0,      0,  ..., 513984, 513985, 513986])

In [6]:

def get_mode_pred(data):
    pixel_validity = data.data.mvfusion_input[:, :, 0].bool()
    mv_preds = data.data.mvfusion_input[:, :, -1].long()
            
    valid_m2f_feats = []
    for i in range(len(mv_preds)):
        valid_m2f_feats.append(mv_preds[i][pixel_validity[i]])

    mode_preds = []
    for m2feats_of_seen_point in valid_m2f_feats:
        mode_preds.append(torch.mode(m2feats_of_seen_point.squeeze(), dim=0)[0])
    mode_preds = torch.stack(mode_preds, dim=0)
        
    return mode_preds

def get_random_view_pred(data):
    pixel_validity = data.data.mvfusion_input[:, :, 0].bool()
    mv_preds = data.data.mvfusion_input[:, :, -1].long()
            
    valid_m2f_feats = []
    for i in range(len(mv_preds)):
        valid_m2f_feats.append(mv_preds[i][pixel_validity[i]])

    selected_view_preds = []
    for m2feats_of_seen_point in valid_m2f_feats:
        selected_idx = torch.randint(low=0, high=m2feats_of_seen_point.shape[0], size=(1,))
        selected_pred = m2feats_of_seen_point[selected_idx].squeeze(0)
        selected_view_preds.append(selected_pred)
    selected_view_preds = torch.stack(selected_view_preds, dim=0)
        
    return selected_view_preds

In [5]:
# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/scratch-shared/fsun/dvata'

dataset_config = 'segmentation/multimodal/Feng/scannet-neucon-smallres-m2f.yaml'   
models_config = 'segmentation/multimodal/Feng/view_selection_experiment.yaml'    # model family
model_name = 'Deepset_3D'                       # specific model

overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

cfg = hydra_read(overrides)
OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
cfg.data.load_m2f_masks = True   # load Mask2Former predicted masks
cfg.data.m2f_preds_dirname = 'ViT_masks'
cfg.data.n_views = cfg.models[model_name].backbone.transformer.n_views
print(cfg.data.n_views)

# cfg.models.MVFusion_small_6views.backbone.transformer.max_n_points = 10000
# print(cfg.models.MVFusion_small_6views.backbone.transformer.max_n_points)


# cfg.data.store_random_pred = True
# cfg.data.store_mode_pred = True
# print(cfg.data.store_mode_pred)


# Dataset instantiation
start = time()
dataset = ScannetDatasetMM(cfg.data)
# print(dataset)|
print(f"Time = {time() - start:0.1f} sec.")

6
Load predicted 2D semantic segmentation labels from directory  ViT_masks
initialize train dataset
initialize val dataset
Time = 7.8 sec.


In [6]:
from torch_points3d.models.model_factory import instantiate_model

# ViT_masks 3rd run
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/ViT_masks_3rd_run' # 3rd run
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/MVFusion_3D_6_views_m2f_masks'

# # MVFusion 100 epochs (no superconvergence) old version (31-10-2022)
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/m2f_masks_MVFusion_medium_9views'

# MVFusion_orig 100 epochs (18-1-2023 using old code, no superconvergence)
# checkpoint_dir ='/home/fsun/DeepViewAgg/outputs/MVFusion_orig'


# Create the model
print(f"Creating model: {cfg.model_name}")
model = instantiate_model(cfg, dataset)
print(model)

# # Load the checkpoint and recover the 'best_miou' model weights
# checkpoint = torch.load(f'{checkpoint_dir}/{model_name}.pt', map_location='cpu')
# model.load_state_dict_with_same_shape(checkpoint['models']['best_miou'], strict=False)

# # Prepare the model for training
model = model.cuda()
print('Model loaded')

Creating model: Deepset_3D
task:  segmentation.multimodal
tested_model_name:  Deepset_3D
class_name:  MVAttentionAPIModel
model_module:  torch_points3d.models.segmentation.multimodal.Feng.mvattention_attention_weighted_m2f_pred
name, cls of chosen model_cls:  MVAttentionAPIModel <class 'torch_points3d.models.segmentation.multimodal.Feng.mvattention_attention_weighted_m2f_pred.MVAttentionAPIModel'>
x feature dim:  {'FEAT': 3}
nc_in:  35
nc_in:  64
nc_in:  32
nc_in:  64
nc_in:  128
nc_in:  256
nc_in:  128
nc_in:  128
nc_in:  96
nc_in:  96
MVAttentionAPIModel(
  (backbone): MVAttentionSparseConv3dUnet(
    (inner_modules): ModuleList(
      (0): Identity()
    )
    (down_modules): ModuleList(
      (0): MultimodalBlockDown(
        (block_1): Identity()
        (block_2): Identity()
        (image): MVAttentionUnimodalBranch(
          drop_3d=None
          drop_mod=None
          keep_last_view=False
          checkpointing=c
          (attn_fusion): DeepSetFeat_ViewFusion(
           

Model loaded


In [7]:
mm_data = dataset.val_dataset[0]

In [8]:
mm_data

MMData(
    data = Data(coords=[97387, 3], grid_size=[1], id_scan=[1], mapping_index=[97387], mvfusion_input=[73331, 6, 10], origin_id=[97387], pos=[97387, 3], rgb=[97387, 3], x=[97387, 3], y=[97387])
    image = ImageData(num_settings=1, num_views=100, num_points=97387, device=cpu)
)

In [24]:
# Create a MMBatch and run inference
batch = MMBatch.from_mm_data_list([mm_data])

model.set_input(batch, model.device)
model(batch)

# Recover the predicted labels for visualization
mm_data.data.pred = model.output.detach().cpu().argmax(1)

mm_data

MMData(
    data = Data(coords=[97387, 3], grid_size=[1], id_scan=[1], mapping_index=[97387], mvfusion_input=[73331, 6, 10], origin_id=[97387], pos=[97387, 3], pred=[97387], rgb=[97387, 3], x=[97387, 3], y=[97387])
    image = ImageData(num_settings=1, num_views=100, num_points=97387, device=cpu)
)

In [19]:
def get_seen_points(mm_data):
    ### Select seen points
    csr_idx = mm_data.modalities['image'][0].view_csr_indexing
    dense_idx_list = torch.arange(mm_data.modalities['image'][0].num_points).repeat_interleave(csr_idx[1:] - csr_idx[:-1])
    # take subset of only seen points without re-indexing the same point
    mm_data = mm_data[dense_idx_list.unique()]
    return mm_data

seen_mm_data = get_seen_points(mm_data)



In [20]:
visualize_mm_data(seen_mm_data, figsize=1000, pointsize=3, voxel=0.05, show_2d=False, back='m2f_pred_mask', front='y', class_names=CLASS_NAMES, class_colors=CLASS_COLORS, alpha=0.3)