In [1]:
# Enable autoreload so edits to imported project modules are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

import os
import os.path as osp
import sys
import torch
import numpy as np
from time import time
from omegaconf import OmegaConf
start = time()
import warnings
warnings.filterwarnings('ignore')

# Make the project importable: put the parent of the current working directory (DIR)
# and that directory's own parent (ROOT) at the front of sys.path.
# NOTE(review): this depends on the directory the kernel was started in —
# presumably a sub-folder of the repository; confirm before moving the notebook.
# torch.cuda.set_device(I_GPU)  # disabled; I_GPU is not defined in the visible cells
DIR = os.path.dirname(os.getcwd())
ROOT = os.path.join(DIR, "..")
sys.path.insert(0, ROOT)
sys.path.insert(0, DIR)

from torch_points3d.utils.config import hydra_read
from torch_geometric.data import Data
from torch_points3d.core.multimodal.data import MMData, MMBatch
from torch_points3d.visualization.multimodal_data import visualize_mm_data
from torch_points3d.core.multimodal.image import SameSettingImageData, ImageData
from torch_points3d.datasets.segmentation.multimodal.scannet import ScannetDatasetMM
from torch_points3d.datasets.segmentation.scannet import CLASS_COLORS, CLASS_NAMES, CLASS_LABELS
from torch_points3d.metrics.segmentation_tracker import SegmentationTracker
from torch_points3d.datasets.segmentation import IGNORE_LABEL
from torch_points3d.metrics.scannet_segmentation_tracker import ScannetSegmentationTracker
from torch_points3d.metrics.colored_tqdm import Coloredtqdm as Ctq


from PIL import Image

import matplotlib.pyplot as plt 

%matplotlib inline

# Override two entries of the imported ScanNet colour table in place (first and last).
# NOTE(review): the last entry is presumably the colour of the ignored/unlabelled
# class — confirm against torch_points3d.datasets.segmentation.scannet.
CLASS_COLORS[0] = (174.0, 199.0, 232.0)
CLASS_COLORS[-1] = (0, 0, 0)
import plotly.io as pio

# Choose the plotly renderer depending on where the notebook is served from.
#pio.renderers.default = 'jupyterlab'        # for local notebook
pio.renderers.default = 'iframe_connected'  # for remote notebook. Other working (but seemingly slower) options are: 'sphinx_gallery' and 'iframe'

MMData debug() function changed, please uncomment the 3rd assert line when doing inference without M2F features!


In [2]:
# Functions for evaluation

def get_seen_points(mm_data, return_mask=False):
    """Keep only the points that have at least one entry in the image mappings.

    The CSR pointer array ``view_csr_indexing`` of the first image setting is
    inspected: a point is kept when its pointer strictly increases towards the
    next entry, i.e. when its CSR slice is non-empty.

    Args:
        mm_data: multimodal data object exposing ``modalities['image'][0]`` and
            supporting boolean-mask indexing.
        return_mask: when True, also return the boolean mask that was applied.

    Returns:
        ``mm_data[mask]``, or the tuple ``(mm_data[mask], mask)`` when
        ``return_mask`` is True.
    """
    pointers = mm_data.modalities['image'][0].view_csr_indexing

    # Non-empty CSR slice <=> consecutive pointers differ
    seen_mask = pointers[:-1] < pointers[1:]
    seen_points = mm_data[seen_mask]

    return (seen_points, seen_mask) if return_mask else seen_points

def get_mode_pred(data):
    """Return, for every point, the most frequent label among its valid views.

    Reads ``data.data.mvfusion_input`` (indexed as ``[point, view, feature]``):
    feature 0 is interpreted as a per-view validity flag and the last feature
    as the per-view predicted label.

    Returns:
        A tensor with one mode label per point, stacked along dim 0.
    """
    is_valid = data.data.mvfusion_input[:, :, 0].bool()
    view_labels = data.data.mvfusion_input[:, :, -1].long()

    # Mode over the valid views of each point, one point at a time
    per_point_modes = [
        torch.mode(view_labels[point][is_valid[point]].squeeze(), dim=0).values
        for point in range(len(view_labels))
    ]

    return torch.stack(per_point_modes, dim=0)

def get_normalized_entropy(labels):
    """Shannon entropy of the label distribution, normalised to [0, 1].

    The entropy (base 2) of the empirical distribution of ``labels`` is divided
    by ``log2(k)``, where ``k`` is the number of distinct labels present.

    Args:
        labels: tensor of labels (any shape; it is flattened by ``torch.unique``).

    Returns:
        A 0-dim float tensor. It is 0 when ``labels`` is empty or contains a
        single distinct value, so the return type is the same in every branch.
    """
    # Nothing to measure: avoid dividing by log2(0) further down
    if labels.numel() == 0:
        return torch.zeros((), device=labels.device)

    counts = torch.unique(labels, return_counts=True)[1]
    num_distinct = counts.numel()

    # A single class has zero entropy, and log2(1) == 0 cannot be used to normalise
    if num_distinct == 1:
        return torch.zeros((), device=counts.device)

    pk = counts / counts.sum()
    # Tensor reduction instead of Python's builtin sum, which iterates element-wise
    entropy = -(pk * torch.log2(pk)).sum()
    max_entropy = torch.log2(torch.tensor(float(num_distinct), device=pk.device))
    return entropy / max_entropy
        
def get_semantic_image_from_camera(dataset, scene, mesh_triangles, intrinsic, extrinsic, class_id_faces, im_size=(480, 640)):
    """
    Returns the back-projected semantic label image given camera parameters and (semantic) mesh.

    One ray per pixel is cast into ``scene``. Every pixel whose ray hits a
    triangle receives the label of the triangle vertex closest to the hit
    point; pixels whose ray hits nothing are filled with the label of the
    nearest pixel that did hit the mesh.

    Args:
        dataset: object whose ``val_dataset._remap_labels`` maps raw labels to
            the training label range.
        scene: ray-casting scene providing ``cast_rays`` and ``INVALID_ID``.
        mesh_triangles: per-triangle vertex indices, indexed by the primitive
            ids returned by the ray cast.
        intrinsic: camera intrinsic matrix passed to ``create_rays_pinhole``.
        extrinsic: camera extrinsic matrix passed to ``create_rays_pinhole``.
        class_id_faces: label lookup; despite its name it is indexed with the
            vertex indices stored in ``mesh_triangles``.
        im_size: ``(height, width)`` of the rendered image in pixels.

    Returns:
        NumPy array of shape ``im_size`` holding the remapped labels.
    """
    # Imported locally so the function does not rely on another cell having
    # imported scipy. The public `scipy.ndimage` namespace is used because
    # `scipy.ndimage.morphology` is deprecated.
    from scipy.ndimage import distance_transform_edt

    # NOTE(review): `o3d` is not imported in the cells visible here; presumably
    # `import open3d as o3d` runs elsewhere before this is called — confirm.
    # Initialize rays for given camera
    rays = o3d.t.geometry.RaycastingScene.create_rays_pinhole(
        intrinsic_matrix=intrinsic,
        extrinsic_matrix=extrinsic,
        width_px=im_size[1],
        height_px=im_size[0],
    )

    # Get result
    ans = scene.cast_rays(rays)

    primitive_ids = ans['primitive_ids'].numpy()
    primitive_uvs = ans['primitive_uvs'].numpy()

    # Pixels whose ray actually hit a triangle
    valid_mask = primitive_ids != scene.INVALID_ID

    # Select the closest vertex for each hit: the vertex with the largest
    # barycentric weight, the weights being (1 - u - v, u, v).
    # https://stackoverflow.com/questions/45212949/vertex-of-a-3d-triangle-that-is-closest-to-a-point-given-barycentric-parameter-o
    hit_uvs = primitive_uvs[valid_mask]
    w_coords = 1 - hit_uvs[:, 0] - hit_uvs[:, 1]
    barycentric_coords = np.concatenate((w_coords[:, None], hit_uvs), axis=-1)

    selected_vertex_idx = np.argmax(barycentric_coords, axis=-1)

    contained_mesh_triangles = mesh_triangles[primitive_ids[valid_mask]]
    closest_mesh_vertices = contained_mesh_triangles[np.arange(len(barycentric_coords)), selected_vertex_idx]

    # Map mesh vertices to semantic label
    labels = class_id_faces[closest_mesh_vertices]
    # Remap to [0 ; num_labels - 1]
    labels = dataset.val_dataset._remap_labels(torch.tensor(labels))

    # Label image, with -1 marking pixels that did not hit the mesh.
    # The mask is converted explicitly so the assignment does not depend on how
    # torch interprets a raw NumPy array used as an index.
    image = torch.full(im_size, -1, dtype=torch.long)
    image[torch.from_numpy(valid_mask)] = labels

    # NN interpolation at invalid pixels: for every pixel, get the (row, col) of
    # the nearest pixel that carries a label.
    image_np = image.numpy()
    nearest_neighbor = distance_transform_edt(
        image_np == -1, return_distances=False, return_indices=True)

    # `nearest_neighbor` has shape (2, H, W); it must be unpacked into one index
    # array per image axis to gather a (H, W) result.
    return image_np[tuple(nearest_neighbor)]

def read_axis_align_matrix(filename):
    """Read the 4x4 axis-alignment matrix from a scene metadata text file.

    The first line containing ``axisAlignment`` is expected to look like
    ``axisAlignment = v0 v1 ... v15``; the 16 whitespace-separated values after
    the ``=`` are returned as a ``(4, 4)`` tensor.

    Args:
        filename: path of the text file to parse.

    Returns:
        A ``(4, 4)`` float tensor, or ``None`` when no such line is found.
    """
    # The context manager guarantees the file handle is closed
    with open(filename) as meta_file:
        for line in meta_file:
            if "axisAlignment" in line:
                # Split on '=' rather than str.strip(chars), which removes a
                # *set of characters* from both ends and not a prefix.
                _, _, raw_values = line.partition("=")
                # split() without arguments tolerates repeated or trailing whitespace
                values = [float(token) for token in raw_values.split()]
                return torch.tensor(values).reshape((4, 4))
    return None

def save_semantic_prediction_as_txt(tracker, model_name, mask_name):
    """Write the tracker's full-resolution predictions to one text file per scan.

    Files are written to ``<tracker._dataset.path_to_submission>/<model_name>/<mask_name>/``
    as ``<scan_name>.txt``, with one integer label per line. Predicted training
    labels are first mapped back to the original class ids through
    ``tracker._dataset.train_dataset.valid_class_idx``.

    Args:
        tracker: object exposing ``_dataset``, ``_full_preds`` (scan id -> label
            tensor) and ``_raw_datas`` (scan id -> object with ``scan_name``).
        model_name: sub-directory naming the model.
        mask_name: sub-directory naming the input masks.

    Returns:
        The directory the files were written to.
    """
    original_class_ids = np.asarray(tracker._dataset.train_dataset.valid_class_idx)

    path_to_submission = osp.join(tracker._dataset.path_to_submission, model_name, mask_name)
    # exist_ok avoids the check-then-create race and makes repeated calls safe
    os.makedirs(path_to_submission, exist_ok=True)

    for scan_id, scan_pred in tracker._full_preds.items():
        # int64 index: an int8 cast would silently wrap for label ids above 127
        full_pred = scan_pred.cpu().numpy().astype(np.int64)
        full_pred = original_class_ids[full_pred]  # remap labels to original labels between 0 and 40
        scan_name = tracker._raw_datas[scan_id].scan_name
        path_file = osp.join(path_to_submission, "{}.txt".format(scan_name))

        # The original passed the literal two-character string "/n" as delimiter
        np.savetxt(path_file, full_pred, delimiter="\n", fmt="%d")

    return path_to_submission
        
        


In [3]:
### TODO (as the original note reads): move this configuration into a Python script driven by argparse
# Name of the directory holding the predicted 2D masks (assigned to cfg.data.m2f_preds_dirname below)
MASK_NAME = 'ViT_masks'


# Exactly one of the experiment blocks below should be active; the others stay commented out.
# NOTE(review): the checkpoint directories are absolute, user-specific paths.

# # # DeepSet - DeepSetAttention (check model/config correctness)
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/2023-02-11/10-52-09'   # ViT_masks
# # checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/2023-02-11/10-54-19'   # m2f_masks
# dataset_config = 'segmentation/multimodal/Feng/scannet-neucon-smallres-m2f-allviews.yaml'   
# models_config = 'segmentation/multimodal/Feng/view_selection_experiment'    # model family
# model_name = 'DeepSetAttention'                       # specific model



# # Transformer - MVFusion_orig (active)
checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/MVFusion_orig'   # ViT_masks
# checkpoint_dir = '/home/fsun/DeepViewAgg_31-10-22/DeepViewAgg/outputs/2023-02-11/22-17-12'   # m2f_masks
dataset_config = 'segmentation/multimodal/Feng/scannet-neucon-smallres-m2f-allviews.yaml'   
models_config = 'segmentation/multimodal/Feng/mvfusion_orig'    # model family
model_name = 'MVFusion_orig'                       # specific model

# Hydra-style overrides selecting task, dataset config, model family and model.
# NOTE(review): the data root is an absolute path and ends in 'dvata' — possibly
# a typo for 'data'; confirm the directory name before changing it.
overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'data.dataroot=/scratch-shared/fsun/dvata',
    f'models={models_config}',
    f'model_name={model_name}',
]

cfg = hydra_read(overrides)
OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
cfg.data.load_m2f_masks = True   # load input masks
cfg.data.m2f_preds_dirname = MASK_NAME
# Keep the dataset's number of views per point in sync with the model's transformer setting
cfg.data.n_views = cfg.models[model_name].backbone.transformer.n_views
print(cfg.data.n_views)

# Dataset instantiation
# `start` is re-assigned here, replacing the timestamp taken in the import cell
start = time()
dataset = ScannetDatasetMM(cfg.data)
print(f"Dataset Creation Time = {time() - start:0.1f} sec.")



9
Load predicted 2D semantic segmentation labels from directory  ViT_masks
initialize train dataset
initialize val dataset
Dataset Creation Time = 7.8 sec.


In [4]:
# Build the model described by `cfg` and load its weights from `checkpoint_dir`.
from torch_points3d.models.model_factory import instantiate_model

# Set your parameters
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/2022-12-07/12-07-34' # 3rd run
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/ViT_masks_3rd_run'

# Create the model
print(f"Creating model: {cfg.model_name}")
model = instantiate_model(cfg, dataset)

# Load the checkpoint and recover the model weights
# The checkpoint is read onto the CPU first; the 'latest' weights are used.
checkpoint = torch.load(f'{checkpoint_dir}/{model_name}.pt', map_location='cpu')
# NOTE(review): with strict=False (and presumably only same-shape tensors being
# copied), a model/checkpoint mismatch may go unnoticed — verify that all
# expected weights were actually loaded.
model.load_state_dict_with_same_shape(checkpoint['models']['latest'], strict=False)

# Prepare the model for inference
# .cuda() requires a CUDA device to be available
model = model.eval().cuda()
print('Model loaded')

Creating model: MVFusion_orig
task:  segmentation.multimodal
tested_model_name:  MVFusion_orig
loading gt mask from :  label-filt-scannet20
class_name:  MVFusion_model_orig
model_module:  torch_points3d.models.segmentation.multimodal.Feng.mvfusion_orig
name, cls of chosen model_cls:  MVFusion_model_orig <class 'torch_points3d.models.segmentation.multimodal.Feng.mvfusion_orig.MVFusion_model_orig'>
opt:   {'class': 'Feng.mvfusion_orig.MVFusion_model_orig', 'down_conv': {'image': {'down_conv': {'module_name': 'ADE20KResNet18PPM', 'frozen': False}, 'atomic_pooling': {'module_name': 'BimodalCSRPool', 'mode': 'max'}, 'view_pooling': {'module_name': 'GroupBimodalCSRPool', 'in_map': 8, 'in_mod': 512, 'num_groups': 4, 'use_mod': False, 'gating': True, 'group_scaling': True, 'map_encoder': 'DeepSetFeat', 'use_num': True, 'pool': 'max', 'fusion': 'concatenation'}, 'fusion': {'module_name': 'BimodalFusion', 'mode': 'concatenation'}, 'drop_mod': 0.0, 'branching_index': 0}}, 'backbone': {'transforme

In [5]:
# Set dataloaders
# Only the validation loader is requested (val_only=True).
# NOTE(review): the four positional values are presumably batch_size, shuffle,
# num_workers and precompute_multi_scale — confirm against the signature of
# create_dataloaders and consider passing them by keyword.
dataset.create_dataloaders(
    model,
    1,
    True,
    17,
    False,
    train_only=False,
    val_only=True,
    test_batch_size=1
)

In [6]:
# Metric tracker for the dataset.
# NOTE(review): the two flags presumably disable wandb and tensorboard logging — confirm.
tracker = dataset.get_tracker(False, False)

# First validation sample, used below for single-scene inference
mm_data = dataset.val_dataset[0]

loading gt mask from :  label-filt-scannet20


In [8]:
# Single-scene sanity check: run the model on `mm_data` and report the tracker metrics.
tracker.reset(stage='val')
# Create a MMBatch and run inference
batch = MMBatch.from_mm_data_list([mm_data])

# Keep only the points that appear in at least one image view
batch = get_seen_points(batch)

with torch.no_grad():
    print("input batch: ", batch)
    
#     batch.data.mvfusion_input[:, :, -1] = 0
    
    model.set_input(batch, model.device)
    model.forward(epoch=1)
    
# Recover the predicted labels for visualization
# NOTE(review): the predictions belong to the filtered `batch` (seen points only)
# but are stored on the unfiltered `mm_data`; if some points are unseen the
# lengths will not match — confirm this is intended.
mm_data.data.pred = model.output.detach().cpu().argmax(1)

tracker.track(model)

tracker.get_metrics()

input batch:  MMData(
    data = Batch(batch=[79550], coords=[79550, 3], grid_size=[1], id_scan=[1], mapping_index=[79550], mvfusion_input=[79550, 9, 10], origin_id=[79550], pos=[79550, 3], pred=[79550], ptr=[2], rgb=[79550, 3], x=[79550, 3], y=[79550])
    image = ImageBatch(num_settings=1, num_views=342, num_points=79550, device=cpu)
)


{'val_loss_seg': 0.22071437537670135,
 'val_acc': 93.77695104422085,
 'val_macc': 91.51102796521786,
 'val_miou': 64.89498269639637}

In [None]:
# Reference metrics kept for comparison; these dict literals are not assigned to anything.

# deepset — presumably the DeepSetAttention checkpoint from the config cell; TODO confirm
{'val_loss_seg': 0.21620771288871765,
 'val_acc': 93.81218116738536,
 'val_macc': 91.32494575166828,
 'val_miou': 64.8205182841954}

# trans — identical to the MVFusion_orig single-scene metrics printed by the inference cell above
{'val_loss_seg': 0.22071437537670135,
 'val_acc': 93.77695104422085,
 'val_macc': 91.51102796521786,
 'val_miou': 64.89498269639637}

# Viewing Conditions Ablation Study
- Measure the change in mIoU when each feature is replaced by its mean value.

- Do this for MVFusion_3D, MVFusion and MVAttention:

    - MVFusion_3D, because it is the general model.

    - MVFusion, because it removes the 3D network's dependency on those features.

    - MVAttention, because it truly shows how well each viewing condition predicts the quality of a view.

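In `mvfusion_input`, these are the 2nd through 9th features along the last axis (indices 1 to 8); index 0 is the pixel-validity flag and the last index is the predicted 2D label: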

1. normalized depth
2. linearity
3. planarity
4. scattering
5. orientation to the surface
6. normalized pixel height
7. density
8. occlusion

In [9]:
# Recorded statistics of the 8 mapping (viewing-condition) features.
# NOTE(review): for each split the first tensor is presumably the per-feature
# mean and the second the standard deviation — confirm.

# train mapping feature statistics:
# tensor([0.2711, 0.1401, 0.6778, 0.1822, 0.6568, 0.4669, 0.2979, 0.6933])
# tensor([0.0811, 0.1083, 0.2375, 0.1710, 0.2294, 0.2735, 0.1378, 0.2283])

# val mapping feature statistics:
# tensor([0.2745, 0.1412, 0.6740, 0.1847, 0.6517, 0.4674, 0.3046, 0.6818])
# tensor([0.0856, 0.1096, 0.2406, 0.1730, 0.2322, 0.2735, 0.1415, 0.2301])

# Values copied from the first train tensor above; used as the replacement value in the ablations
weighted_mean = torch.tensor([0.2711, 0.1401, 0.6778, 0.1822, 0.6568, 0.4669, 0.2979, 0.6933])

In [17]:
# Ablation on the cached single-scene `batch`: replace the viewing-condition
# features in `mvfusion_input` by their mean and record the resulting metrics.
metrics = []

# The assignment below writes into `batch` in place, so keep a pristine copy and
# restore it after each pass; otherwise re-running this cell (or enabling the
# per-feature variant) would evaluate on inputs that were already modified.
pristine_mvfusion_input = batch.data.mvfusion_input.clone()

for idx in range(8):
    tracker.reset(stage='val')

    for b in [batch]:

        # Per-feature variant (disabled): change only the feature at idx to its mean value
#         b.data.mvfusion_input[:, :, idx+1] = weighted_mean[idx]

        # Active variant: set all 8 viewing-condition features to their mean at once
        b.data.mvfusion_input[:, :, 1:-1] = weighted_mean

        # Alternative (disabled): overwrite the same feature in the image mappings
#         print(b.modalities['image'][0]._mappings.values[-1])
#         b.modalities['image'][0]._mappings.values[-1][:, idx] = weighted_mean[idx]
#         print(b.modalities['image'][0]._mappings.values[-1])

        # Inference; gradients are not needed, as in the other evaluation cells
        with torch.no_grad():
            model.set_input(b, model.device)
            model.forward(epoch=1)

        b.data.pred = model.output.detach().cpu().argmax(1)

        tracker.track(model, full_res=False)

        # Undo the in-place edit so `batch` is unchanged once this cell has run
        b.data.mvfusion_input.copy_(pristine_mvfusion_input)

    metrics.append(tracker.get_metrics())

    # The active variant does not depend on idx, so a single pass is enough.
    # Remove this break when switching to the per-feature variant.
    break

print(metrics)

[{'val_loss_seg': 0.6055614948272705, 'val_acc': 83.4291592683408, 'val_macc': 80.53872289999802, 'val_miou': 34.64399044593892}]


In [11]:
# res = {}
# for l in metrics:
#     for k, v in l.items():
#         if k not in res:
#             res[k] = [v]
#         else:
#             res[k].append(v)

In [12]:
# res

In [18]:
# Clear the tracker's accumulated validation statistics (the loop in the next cell resets it again)
tracker.reset(stage='val')

In [19]:
# Ablation over the whole validation loader: replace one viewing-condition
# feature by its mean value and accumulate the metrics in `tracker`.
# Same values as the `weighted_mean` defined in the statistics cell above.
weighted_mean = torch.tensor([0.2711, 0.1401, 0.6778, 0.1822, 0.6568, 0.4669, 0.2979, 0.6933])

# range(1): only the first feature (index 1 of mvfusion_input) is ablated.
# The tracker is reset at the start of every iteration and nothing is stored per
# idx, so metrics must be collected inside this loop before widening the range.
for idx in range(1):
    tracker.reset(stage='val')

    with Ctq(dataset._val_loader) as tq_loader:
        for data in tq_loader:
            
            with torch.no_grad():
                # Drop the points that are not seen in any image view
                data = get_seen_points(data)

                # Feature idx lives at position idx+1 because position 0 is the validity flag
                data.data.mvfusion_input[:, :, idx+1] = weighted_mean[idx]

                model.set_input(data, model.device)

                with torch.cuda.amp.autocast(enabled=model.is_mixed_precision()):
                    model.forward(epoch=1)


                # 3D mIoU on the points kept above, i.e. seen points only
                # (`data` was already filtered by get_seen_points)
                tracker.track(model, full_res=False, data=data)

    #             # 3D mIoU, seen points
    #             data = get_seen_points(data)
    #             tracker.track(self._model, full_res=False, data=data)


  0%|          | 0/312 [00:01<?, ?it/s]

loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt mask from :  label-filt-scannet20
loading gt

In [21]:
# Metrics accumulated in `tracker` since its last reset (the validation-loader ablation above)
tracker.get_metrics()

{'val_loss_seg': 0.33554776523930907,
 'val_acc': 91.66327335990314,
 'val_macc': 84.60007118345999,
 'val_miou': 75.20456407483012}