In [1]:
# Uncomment to use autoreload
%load_ext autoreload
%autoreload 2

import os
import os.path as osp
import sys
import torch
import numpy as np
from time import time
from omegaconf import OmegaConf
start = time()
import warnings
warnings.filterwarnings('ignore')

# torch.cuda.set_device(I_GPU)
DIR = os.path.dirname(os.getcwd())
ROOT = os.path.join(DIR, "..")
sys.path.insert(0, ROOT)
sys.path.insert(0, DIR)

from torch_points3d.utils.config import hydra_read
from torch_geometric.data import Data
from torch_points3d.core.multimodal.data import MMData, MMBatch
from torch_points3d.visualization.multimodal_data import visualize_mm_data
from torch_points3d.core.multimodal.image import SameSettingImageData, ImageData
from torch_points3d.datasets.segmentation.multimodal.scannet import ScannetDatasetMM
from torch_points3d.datasets.segmentation.scannet import CLASS_COLORS, CLASS_NAMES, CLASS_LABELS
from torch_points3d.metrics.segmentation_tracker import SegmentationTracker
from torch_points3d.datasets.segmentation import IGNORE_LABEL
from torch_points3d.metrics.scannet_segmentation_tracker import ScannetSegmentationTracker

CLASS_COLORS[0] = (174.0, 199.0, 232.0)
CLASS_COLORS[-1] = (0, 0, 0)
import plotly.io as pio

#pio.renderers.default = 'jupyterlab'        # for local notebook
pio.renderers.default = 'iframe_connected'  # for remote notebook. Other working (but seemingly slower) options are: 'sphinx_gallery' and 'iframe'

import matplotlib.pyplot as plt 
%matplotlib inline

MMData debug() function changed, please uncomment the 3rd assert line when doing inference without M2F features!


In [2]:
def get_seen_points(mm_data):
    ### Select seen points
    csr_idx = mm_data.modalities['image'][0].view_csr_indexing
    dense_idx_list = torch.arange(mm_data.modalities['image'][0].num_points).repeat_interleave(csr_idx[1:] - csr_idx[:-1])
    # take subset of only seen points without re-indexing the same point
    mm_data = mm_data[dense_idx_list.unique()]
    return mm_data

def get_mode_pred(data):
    pixel_validity = data.data.mvfusion_input[:, :, 0].bool()
    mv_preds = data.data.mvfusion_input[:, :, -1].long()
            
    valid_m2f_feats = []
    for i in range(len(mv_preds)):
        valid_m2f_feats.append(mv_preds[i][pixel_validity[i]])

    mode_preds = []
    for m2feats_of_seen_point in valid_m2f_feats:
        mode_preds.append(torch.mode(m2feats_of_seen_point.squeeze(), dim=0)[0])
    mode_preds = torch.stack(mode_preds, dim=0)
        
    return mode_preds

def get_random_view_pred(data):
    pixel_validity = data.data.mvfusion_input[:, :, 0].bool()
    mv_preds = data.data.mvfusion_input[:, :, -1].long()
            
    valid_m2f_feats = []
    for i in range(len(mv_preds)):
        valid_m2f_feats.append(mv_preds[i][pixel_validity[i]])

    selected_view_preds = []
    for m2feats_of_seen_point in valid_m2f_feats:
        selected_idx = torch.randint(low=0, high=m2feats_of_seen_point.shape[0], size=(1,))
        selected_pred = m2feats_of_seen_point[selected_idx].squeeze(0)
        selected_view_preds.append(selected_pred)
    selected_view_preds = torch.stack(selected_view_preds, dim=0)
        
    return selected_view_preds


In [3]:
# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/scratch-shared/fsun/dvata'

dataset_config = 'segmentation/multimodal/Feng/scannet-neucon-smallres-m2f'   
models_config = 'segmentation/multimodal/Feng/mvfusion'    # model family
model_name = 'MVFusion_3D_small_6views'                       # specific model

overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

cfg = hydra_read(overrides)
OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
cfg.data.load_m2f_masks = True   # load Mask2Former predicted masks
cfg.data.m2f_preds_dirname = 'ViT_masks'
cfg.data.n_views = cfg.models[model_name].backbone.transformer.n_views
print(cfg.data.n_views)

# Dataset instantiation
start = time()
dataset = ScannetDatasetMM(cfg.data)
# print(dataset)|
print(f"Time = {time() - start:0.1f} sec.")

6
Load predicted 2D semantic segmentation labels from directory  ViT_masks
initialize train dataset
initialize val dataset
Time = 7.6 sec.


In [4]:
from torch_points3d.models.model_factory import instantiate_model

# ViT_masks 3rd run
checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/ViT_masks_3rd_run' # 3rd run
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/MVFusion_3D_6_views_m2f_masks'

# Create the model
print(f"Creating model: {cfg.model_name}")
model = instantiate_model(cfg, dataset)
# print(model)

# Load the checkpoint and recover the 'best_miou' model weights
checkpoint = torch.load(f'{checkpoint_dir}/{model_name}.pt', map_location='cpu')
model.load_state_dict_with_same_shape(checkpoint['models']['latest'], strict=False)

# Prepare the model for training
model = model.cuda()
print('Model loaded')

Creating model: MVFusion_3D_small_6views
task:  segmentation.multimodal
tested_model_name:  MVFusion_3D_small_6views
class_name:  MVFusionAPIModel
model_module:  torch_points3d.models.segmentation.multimodal.Feng.mvfusion_3d
name, cls of chosen model_cls:  MVFusionAPIModel <class 'torch_points3d.models.segmentation.multimodal.Feng.mvfusion_3d.MVFusionAPIModel'>
x feature dim:  {'FEAT': 3}
nc_in:  67
nc_in:  64
nc_in:  32
nc_in:  64
nc_in:  128
nc_in:  256
nc_in:  128
nc_in:  128
nc_in:  96
nc_in:  96
Model loaded


In [5]:
import pandas as pd


# Create validation loader
dataset.create_dataloaders(
    model,
    1,
    False,
    17,
    False,
    train_only=False,
    val_only=True,
    test_batch_size=1
)
tracker = ScannetSegmentationTracker(dataset=dataset, stage='val', wandb_log=False, use_tensorboard=False, ignore_label=IGNORE_LABEL)


macc = TP / (TP + TN) $\rightarrow$ tells performance without accounting for false predictions

In [6]:
mapping_idx_to_names = dataset.val_dataset.MAPPING_IDX_TO_SCAN_VAL_NAMES

rows_list = []
for batch in dataset._val_loader:
    
    # Inference
    model.set_input(batch, model.device)
    model(batch)
    batch.data.pred = model.output.detach().cpu().argmax(1)

    # Evaluate on seen points and valid gt only
    batch = get_seen_points(batch)
    batch = batch[batch.y != -1]

    tracker.reset(stage='val')
    tracker.track(pred_labels=batch.data.pred, gt_labels=batch.y, model=None)
    metrics = tracker.get_metrics()

    print(batch)
    print(metrics)
    # Save for dataframe
    dict1 = {}
    dict1['scan_id'] = mapping_idx_to_names[batch.id_scan.item()]
    dict1['num_points'] = batch.num_points
    dict1['miou'] = metrics['val_miou']
    dict1['macc'] = metrics['val_macc']
    dict1['acc'] = metrics['val_acc']
    
    rows_list.append(dict1)

    break
    

df = pd.DataFrame(rows_list)
df

MMData(
    data = Batch(batch=[65475], coords=[65475, 3], grid_size=[1], id_scan=[1], mapping_index=[65475], mvfusion_input=[65475, 6, 10], origin_id=[65475], pos=[65475, 3], pred=[65475], ptr=[2], rgb=[65475, 3], x=[65475, 3], y=[65475])
    image = ImageBatch(num_settings=1, num_views=100, num_points=65475, device=cpu)
)
{'val_acc': 96.07025582283313, 'val_macc': 94.34477011094998, 'val_miou': 74.08247078402268}


Unnamed: 0,scan_id,num_points,miou,macc,acc
0,scene0011_00,65475,74.082471,94.34477,96.070256


In [8]:
input_labels = batch.modalities['image'][0].get_mapped_m2f_features()


csr_idx = batch.modalities['image'][0].view_csr_indexing


# Set invalid views to NaN for analysis purposes
pixel_validity = batch.data.mvfusion_input[:, :, 0] == 1

input_labels = batch.data.mvfusion_input[:, :, -1]
input_labels[~pixel_validity] = float('nan')

num_views = pixel_validity.sum(dim=-1)

input_labels_correct = torch.eq(input_labels, batch.y.unsqueeze(1))

In [9]:
# x = np.unique(input_labels.numpy(), axis=-1)
# x = x[~np.isnan(x)]
# x.shape, num_views.sum()

# x[60:70], input_labels[10:12]

In [27]:
num_correct = input_labels_correct.sum(dim=1)
num_incorrect = num_views - input_labels_correct.sum(dim=1)


d = {}
d['num_views'] = num_views
d['num_correct'] = num_correct
d['num_incorrect'] = num_incorrect
# d['num_distinct'] = torch.unique()
d['refined'] = batch.data.pred
d['gt_label'] = batch.y
d['correct_output'] = d['refined'] == d['gt_label']

for i in range(6):
    d[str(i+1)] = input_labels[:, i]


df = pd.DataFrame(d)
df.sort_values(['num_incorrect', 'correct_output', 'refined'], ascending=False)[:20]

Unnamed: 0,num_views,num_correct,num_incorrect,refined,gt_label,correct_output,1,2,3,4,5,6
50700,6,0,6,19,19,True,0.0,0.0,0.0,0.0,0.0,0.0
46073,6,0,6,17,17,True,11.0,11.0,11.0,11.0,11.0,11.0
41915,6,0,6,11,11,True,2.0,2.0,2.0,2.0,2.0,2.0
41926,6,0,6,11,11,True,2.0,19.0,2.0,4.0,19.0,19.0
42696,6,0,6,8,8,True,17.0,11.0,17.0,11.0,11.0,11.0
42749,6,0,6,8,8,True,17.0,11.0,11.0,11.0,11.0,11.0
38303,6,0,6,6,6,True,0.0,0.0,0.0,0.0,0.0,0.0
38304,6,0,6,6,6,True,0.0,0.0,0.0,0.0,4.0,0.0
50531,6,0,6,6,6,True,5.0,4.0,5.0,5.0,19.0,0.0
28901,6,0,6,4,4,True,1.0,6.0,1.0,1.0,1.0,6.0


#### Descriptors for problem type
divide all points across class types
 or bucket number of views?

- number of views
- number of distinct input label clases
- number of correct label estimates
- number of incorrect label estimates
- refined outut label
- ground-truth label


- confusion matrix that accumulates all view preds, and one that accumulates refined pred x n views?

In [72]:
dict1 = {}
dict1['scan_id'] = mapping_idx_to_names[batch.id_scan.item()]
dict1['num_points'] = batch.num_points
dict1['miou'] = metrics['val_miou']
dict1['macc'] = metrics['val_macc']
dict1['acc'] = metrics['val_acc']
dict1[]

df = pandas.DataFrame(dict1)
df

KeyError: ''

In [24]:
from IPython.display import Image
from IPython.core.display import HTML 
Image(url="https://kaldir.vc.in.tum.de/scannet_benchmark/img/legend.jpg")