In [1]:
%load_ext autoreload
%autoreload 2


import pandas as pd
import numpy as np
import torch
import os
import os.path as osp
import sys
from time import time
from omegaconf import OmegaConf
import warnings
# Silence every Python warning for the rest of the session (no category or
# module filter is given). NOTE(review): this also hides deprecation and
# numerical warnings raised by the libraries used below - consider narrowing.
warnings.filterwarnings('ignore')

# torch.cuda.set_device(I_GPU)

# Put the parent of the working directory (DIR) and DIR's own parent (ROOT,
# kept as the un-normalised "<DIR>/..") on the module search path.
DIR = osp.dirname(os.getcwd())
ROOT = osp.join(DIR, "..")

# ROOT is inserted first and DIR second, so DIR ends up ahead of ROOT.
for _search_path in (ROOT, DIR):
    sys.path.insert(0, _search_path)

from torch_points3d.utils.config import hydra_read
from torch_geometric.data import Data
from torch_points3d.core.multimodal.data import MMData, MMBatch
from torch_points3d.visualization.multimodal_data import visualize_mm_data
from torch_points3d.core.multimodal.image import SameSettingImageData, ImageData
from torch_points3d.datasets.segmentation.multimodal.scannet import ScannetDatasetMM
from torch_points3d.datasets.segmentation.scannet import CLASS_COLORS, CLASS_NAMES, CLASS_LABELS
from torch_points3d.metrics.scannet_segmentation_tracker import ScannetSegmentationTracker
from torch_points3d.datasets.segmentation import IGNORE_LABEL
from torch_points3d.models.model_factory import instantiate_model
from torch_points3d.metrics.colored_tqdm import Coloredtqdm as Ctq


from pykeops.torch import LazyTensor

import matplotlib.pyplot as plt 

%matplotlib inline

# Patch the imported colour table in place: entry 0 becomes (174.0, 199.0, 232.0)
# and the last entry becomes black.
# NOTE(review): presumably done so these two slots render sensibly in the
# visualisations - confirm which classes they correspond to.
CLASS_COLORS[0], CLASS_COLORS[-1] = (174.0, 199.0, 232.0), (0, 0, 0)

import plotly.io as pio

# Plotly renderer selection:
#   - for a local notebook, use instead:
#       pio.renderers.default = 'jupyterlab'
#   - for a remote notebook, 'iframe_connected' is used here. Other working
#     (but seemingly slower) options are 'sphinx_gallery' and 'iframe'.
pio.renderers.default = 'iframe_connected'

**Note:** the `MMData.debug()` function has changed. When doing inference without M2F features, please uncomment the 3rd `assert` line!


In [2]:
# Root directory of the dataset; forwarded to the config as `data.dataroot`.
# NOTE(review): absolute, machine-specific path - adjust for your environment.
DATA_ROOT = '/scratch-shared/fsun/dvata'

# Config names that are passed to hydra through the overrides below.
dataset_config = 'segmentation/multimodal/Feng/scannet-neucon-smallres-m2f'   
models_config = 'segmentation/multimodal/Feng/mvfusion'    # model family
model_name = 'MVFusion_3D_small_6views'                       # specific model

# Hydra override strings in `key=value` form.
overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

# Build the config from the overrides, then adjust it for this experiment.
cfg = hydra_read(overrides)
OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
cfg.data.load_m2f_masks = True   # load Mask2Former predicted masks
cfg.data.m2f_preds_dirname = 'ViT_masks'  # directory name of the predicted masks
# Copy the number of views from the selected model's transformer config into
# the data config, then print it as a sanity check.
cfg.data.n_views = cfg.models[model_name].backbone.transformer.n_views
print(cfg.data.n_views)

# Dataset instantiation, with the elapsed wall-clock time printed afterwards
start = time()
dataset = ScannetDatasetMM(cfg.data)
# print(dataset)
print(f"Time = {time() - start:0.1f} sec.")



# Checkpoint directory of the 3rd ViT_masks run.
# NOTE(review): absolute, user-specific path - adjust for your environment.
checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/ViT_masks_3rd_run' # 3rd run


# Create the model
print(f"Creating model: {cfg.model_name}")
model = instantiate_model(cfg, dataset)
# print(model)

# Load the checkpoint onto the CPU and recover the 'best_miou' model weights.
# NOTE(review): strict=False is passed and, judging by the method name, only
# parameters with matching shapes are presumably restored - confirm that no
# expected weights are silently skipped.
checkpoint = torch.load(f'{checkpoint_dir}/{model_name}.pt', map_location='cpu')
model.load_state_dict_with_same_shape(checkpoint['models']['best_miou'], strict=False)

# Move the model to the GPU.
# NOTE(review): the train/eval mode is not changed here - call model.eval()
# before running inference if that is not done further down.
model = model.cuda()
print('Model loaded')

# Create validation loader.
# The positional arguments after `model` are presumably batch_size=1,
# shuffle=False, num_workers=17 and precompute_multi_scale=False - TODO confirm
# against the create_dataloaders signature.
dataset.create_dataloaders(
    model,
    1,
    False,
    17,
    False,
    train_only=False,
    val_only=True,
    test_batch_size=1
)

# Metrics tracker with wandb and tensorboard logging disabled.
# NOTE(review): created with stage='train' although the loaders above were
# requested with val_only=True - confirm the stage is reset before use.
tracker = ScannetSegmentationTracker(dataset=dataset, stage='train', wandb_log=False, use_tensorboard=False, ignore_label=IGNORE_LABEL)

6
Load predicted 2D semantic segmentation labels from directory  ViT_masks
initialize train dataset
initialize val dataset
Time = 8.3 sec.
Creating model: MVFusion_3D_small_6views
task:  segmentation.multimodal
tested_model_name:  MVFusion_3D_small_6views
class_name:  MVFusionAPIModel
model_module:  torch_points3d.models.segmentation.multimodal.Feng.mvfusion_3d
name, cls of chosen model_cls:  MVFusionAPIModel <class 'torch_points3d.models.segmentation.multimodal.Feng.mvfusion_3d.MVFusionAPIModel'>
x feature dim:  {'FEAT': 3}
nc_in:  67
nc_in:  64
nc_in:  32
nc_in:  64
nc_in:  128
Return attention maps!
Enabling checkpointing for:  c
nc_in:  256
nc_in:  128
nc_in:  128
nc_in:  96
nc_in:  96
Model loaded


In [3]:
# Accumulators filled with one entry per batch of the validation loader.
data_list, input_labels_list, gt_labels_list = [], [], []

with Ctq(dataset._val_loader) as loader:
    for mm_data in loader:
        # Everything below is read from element 0 of the 'image' modality.
        image_setting = mm_data.modalities['image'][0]

        # Last entry of the mapping values, the mapped M2F features and the
        # mapped ground-truth labels.
        # NOTE(review): shapes and meaning of these tensors are not visible
        # here - confirm that their rows line up.
        data = image_setting._mappings.values[-1]
        input_labels = image_setting.get_mapped_m2f_features()
        gt_labels = image_setting.get_mapped_gt_labels()

        data_list.append(data)
        input_labels_list.append(input_labels)
        gt_labels_list.append(gt_labels)

  0%|          | 0/312 [00:01<?, ?it/s]

In [4]:
# Stack each list of collected tensors row-wise into a single tensor.
data, input_labels, gt_labels = (
    torch.vstack(chunks)
    for chunks in (data_list, input_labels_list, gt_labels_list)
)

In [5]:
# Column names: eight feature columns followed by the input-label and
# ground-truth-label columns.
columns = (
    ['feat1', 'feat2', 'feat3', 'feat4', 'feat5', 'feat6', 'feat7', 'feat8']
    + ['input_label', 'gt_label']
)