In [None]:
# Index of the CUDA device used by this notebook (passed to
# torch.cuda.set_device in the setup cell below)
I_GPU = 0

In [None]:
# %load_ext autoreload
# %autoreload 2
import os
import sys
import numpy as np
import torch
import glob
from matplotlib.colors import ListedColormap
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
# Silence every Python warning for the rest of the session
warnings.filterwarnings('ignore')

# DIR is the parent of the current working directory and ROOT is the parent
# of DIR. Both are inserted at the front of sys.path (DIR ends up first).
# NOTE(review): presumably this lets the `torch_points3d` imports below
# resolve when the notebook is launched from a sub-directory of the
# repository - confirm.
DIR = os.path.dirname(os.getcwd())
ROOT = os.path.join(DIR, "..")
sys.path.insert(0, ROOT)
sys.path.insert(0, DIR)
from torch_points3d.datasets.segmentation.multimodal.s3dis import S3DISFusedDataset
from torch_points3d.datasets.segmentation.multimodal import IGNORE_LABEL

# The return values of these two calls are discarded.
# NOTE(review): presumably they are only here to trigger CUDA initialization
# before the device is selected - confirm.
_ = torch.cuda.is_available()
_ = torch.cuda.memory_allocated()
# Make GPU number I_GPU the current CUDA device
torch.cuda.set_device(I_GPU)

# Load config

In [None]:
from omegaconf import OmegaConf
from torch_points3d.utils.config import hydra_read

# Set root to the DATA drive, where the data was downloaded
# Exactly one DATA_ROOT line must stay uncommented.
# NOTE(review): the trailing comment on each line presumably names the
# machine or site the path belongs to - confirm.
# DATA_ROOT = '/mnt/fa444ffd-fdb4-4701-88e7-f00297a8e29b/projects/datasets/s3dis'  # ???
# DATA_ROOT = '/media/drobert-admin/DATA/datasets/s3dis'  # IGN DATA
DATA_ROOT = '/media/drobert-admin/DATA2/datasets/s3dis'  # IGN DATA2
# DATA_ROOT = '/var/data/drobert/datasets/s3dis'  # AI4GEO
# DATA_ROOT = '/home/qt/robertda/scratch/datasets/s3dis'  # CNES
# DATA_ROOT = '/raid/datasets/pointcloud/data/s3dis'  # ENGIE

# Overrides handed to hydra_read() to build the run configuration. The
# dataset root is the '5cm_exact_768x384' sub-directory of DATA_ROOT.
overrides = [
    'task=segmentation',
    'data=segmentation/multimodal/s3disfused/no3d_5cm_768x384-exact',
    'models=segmentation/multimodal/no3d',
    'model_name=RGB_ResNet18PPM_mean-feat',
    'data.fold=5',
    'data.sample_per_epoch=2000',
    f"data.dataroot={os.path.join(DATA_ROOT, '5cm_exact_768x384')}",
]

# Configuration object: `cfg.data` is used below to build the dataset and
# `cfg.model_name` / `cfg` to load the model from its checkpoint
cfg = hydra_read(overrides)

# print(OmegaConf.to_yaml(cfg))

# Load S3DIS dataset

In [None]:
# Class names, in label-index order. Used as tick labels and legends in the
# plots of this notebook.
CLASSES = [
    'ceiling',
    'floor',
    'wall',
    'beam',
    'column',
    'window',
    'door',
    'chair',
    'table',
    'bookcase',
    'sofa',
    'board',
    'clutter',
]
# One [R, G, B] color (0-255) per class, in the same order as CLASSES
OBJECT_COLOR = [
    [180, 180, 80],  # ceiling  -> yellow
    [95, 156, 196],  # floor    -> blue
    [179, 116, 81],  # wall     -> brown
    [241, 149, 131],  # beam     -> salmon
    [81, 163, 148],  # column   -> bluegreen
    [77, 174, 84],  # window   -> bright green
    [108, 135, 75],  # door     -> dark green
    [41, 49, 101],  # chair    -> dark blue
    [79, 79, 76],  # table    -> dark grey
    [223, 52, 52],  # bookcase -> red
    [89, 47, 95],  # sofa     -> purple
    [81, 109, 114],  # board    -> grey
    [125, 125, 125],  # clutter  -> light grey
]

num_classes = len(CLASSES)

# Names of the projection features. The position of a name in this list is
# used below as the column index into `x_proj` (see `x_proj[:, i_feat]`) and
# as the `feat` argument of PickMappingsFromProjectionFeatures.
PROJ_FEATS = [
    'normalized depth',
    'linearity',
    'planarity',
    'scattering',
    'orientation to the surface',
    'normalized pixel height',
    'density',
    'occlusion'
]

In [None]:
from time import time
# Build the S3DIS multimodal dataset from the configuration and report how
# long the instantiation took
start = time()
dataset = S3DISFusedDataset(cfg.data)
# print(dataset)
elapsed = time() - start

print(f"Time = {elapsed:0.1f} sec.")

In [None]:
# Remove pixel memory credit transform
from torch_points3d.core.data_transform.multimodal.image import PickImagesFromMemoryCredit
from torch_points3d.datasets.base_dataset_multimodal import BaseDatasetMM
# Strip the PickImagesFromMemoryCredit step from the image pipeline of the
# train, val and (first) test splits
for x in (dataset.train_dataset, dataset.val_dataset, dataset.test_dataset[0]):
    x.transform_image = BaseDatasetMM.remove_multimodal_transform(
        x.transform_image,
        [PickImagesFromMemoryCredit])

# Load model

### Instantiate model

In [None]:
# from torch_points3d.models.model_factory import instantiate_model
# 
# model = instantiate_model(cfg, dataset)
# model = model.train().cuda()

### Load model from checkpoint

In [None]:
from torch_points3d.metrics.model_checkpoint import ModelCheckpoint

# Candidate checkpoint directories. Exactly one assignment must stay active:
# earlier revisions left several of them uncommented, in which case only the
# last one had any effect. The other runs are kept here for reference.
# checkpoint_dir = '/home/ign.fr/drobert-admin/Bureau/benchmark_checkpoints/benchmark-Res16UNet21-15_ResImage3_light_1_a4_concatenation-20210304_210217'
# checkpoint_dir = '/home/ign.fr/drobert-admin/Bureau/benchmark_checkpoints/benchmark-Res16UNet21-15_ResImage3_light_1_mean_concatenation-20210301_230608'
# checkpoint_dir = '/home/ign.fr/drobert-admin/Bureau/benchmark_checkpoints/XYZ+RGB_a4-dim_cat-1'
# checkpoint_dir = '/home/ign.fr/drobert-admin/Bureau/benchmark_checkpoints/XYZ+RGB_mean_cat-1'
# checkpoint_dir = '/home/ign.fr/drobert-admin/Bureau/benchmark_checkpoints/XYZ+RGB_attention_debug_fast/files'
# checkpoint_dir = '/workspace/projects/torch-points3d/outputs/benchmark/benchmark-Res16UNet21-15_light-20210330_193749/wandb/run-20210330_193750-1ltttctz/files'
# checkpoint_dir = '/home/ign.fr/drobert-admin/Bureau/benchmark_checkpoints/RGB_light_drop-50_view-loss_fold5'

# RGB light drop 50 trained on "5cm exact 512x256"
# checkpoint_dir = '/home/drobert/projects/torch-points3d/outputs/benchmark/benchmark-Res16UNet21-15_light_drop-50_image-view-loss-20210419_193538'

# RGB light drop 50 trained on "5cm exact 768x384"
# checkpoint_dir = '/home/drobert/projects/torch-points3d/outputs/benchmark/benchmark-Res16UNet21-15_light_drop-50_image-view-loss-20210428_165425'
# checkpoint_dir = '/home/drobert/projects/torch-points3d/outputs/benchmark/benchmark-RGB_D32-4_persistent-indrop-50_mean_view-20210518_141358'
checkpoint_dir = '/home/drobert/projects/torch-points3d/outputs/benchmark/benchmark-RGB_D32-4_persistent-indrop-50_mean_view-20210517_230329'

# Load model from checkpoint
selection_stage = 'val'  # train, val, test
weight_name = 'loss_seg'  # miou, macc, acc, ..., latest
checkpoint = ModelCheckpoint(checkpoint_dir, cfg.model_name, selection_stage, run_config=cfg, resume=False)
model = checkpoint.create_model(dataset, weight_name=weight_name)
# Evaluation mode, on the current CUDA device
model = model.eval().cuda()

In [None]:
# Activate the save_last option to investigate AttentiveBimodalCSRPool module
# `i_pool_branch` selects the down module whose image view-pooling module is
# inspected; inference() below reads it as a global.
# NOTE(review): save_last presumably makes the module keep the `_last_*`
# tensors that inference() reads after each forward pass - confirm.
i_pool_branch = 0
# i_pool_branch = 1
model.backbone.down_modules[i_pool_branch].image.view_pool.save_last = True

# Inference on TRAIN/VAL/TEST set

In [None]:
from torch_points3d.core.multimodal.data import MMBatch
# Run the model on one TEST sample (index 2 of the first test dataset) and
# attach the predicted labels to the batch as `batch.data.pred`.
batch = MMBatch.from_mm_data_list([dataset.test_dataset[0][2]])

# NOTE(review): when the model is not multimodal `batch` is replaced by its
# 3D data, yet `batch.data.pred` is still written at the end of this cell -
# confirm this path works for non-multimodal models.
if not model.is_multimodal:
    batch = batch.data

# Pure inference: no gradient is needed, so skip building the autograd graph
with torch.no_grad():
    model.set_input(batch, model.device)
    _ = model(batch)

# gt = model.labels.cpu().numpy()
# pred = model.output.argmax(dim=1).cpu().numpy()

# Predicted class = index of the highest model output along dim 1
batch.data.pred = model.output.argmax(dim=1).cpu()

In [None]:
# The import must be live here: this is the first cell calling
# visualize_mm_data when the notebook is run top to bottom.
from torch_points3d.visualization.multimodal_data import visualize_mm_data

# 3D view only (show_2d=False) of the batch processed in the previous cell
# NOTE(review): color_mode='y' presumably colors points by their label - confirm.
visualize_mm_data(batch, class_names=CLASSES, class_colors=OBJECT_COLOR, figsize=800, voxel=0.05, show_3d=True, show_2d=False, color_mode='y', alpha=3, pointsize=5)

In [None]:
from torch_points3d.core.multimodal.data import MMBatch
from torch_points3d.metrics.confusion_matrix import ConfusionMatrix
from tqdm import tqdm
from torch_points3d.modules.multimodal.pooling import BimodalCSRPool, AttentiveBimodalCSRPool, HeuristicBimodalCSRPool

def inference(model, dataset, set_name='TEST', n_infer=1000):
    """
    Run `model` on randomly drawn samples of one split of `dataset`, print
    segmentation metrics, plot a row-normalized confusion matrix and return
    the metrics together with the tensors saved by the view-pooling module.

    The view-pooling module is looked up on the down module selected by the
    notebook-level `i_pool_branch`, and its `_last_*` attributes are read
    after every forward pass.
    NOTE(review): these attributes are presumably only populated when the
    pooling module has `save_last=True` - confirm.

    Parameters
    ----------
    model:
        Model exposing `backbone.down_modules`, `is_multimodal`, `device`,
        `set_input`, `labels` and `output`.
    dataset:
        Dataset wrapper exposing `train_dataset`, `val_dataset`,
        `test_dataset`, `num_classes` and `INV_OBJECT_LABEL`.
    set_name: str
        'TRAIN', 'VAL' or 'TEST' (case-insensitive).
    n_infer: int
        Number of sample indices drawn with `np.random.choice`. Indices are
        drawn with replacement, so a sample may be visited more than once.

    Returns
    -------
    dict
        'oa', 'macc', 'miou', 'iou_dict': metrics in percent.
        'idx', 'group_size', 'x_proj', 'x_mod': concatenated pooling data.
        'Y', 'Y_pred': concatenated labels and predictions.
        'group_size_view', 'Y_view', 'Y_pred_view': the above repeated
        `group_size` times per entry.
        'Y_pred_view_indiv': argmax of `x_mod` along dim 1.
        'K', 'Q', 'C', 'A', 'G', 'G_view': only when the pooling module is
        an AttentiveBimodalCSRPool.

    Raises
    ------
    ValueError
        If `set_name` is not one of the supported splits.
    """
    split = set_name.upper()
    if split == 'TRAIN':
        dataset_ = dataset.train_dataset
    elif split == 'VAL':
        dataset_ = dataset.val_dataset
    elif split == 'TEST':
        dataset_ = dataset.test_dataset[0]
    else:
        raise ValueError(f"Unknown set '{set_name.upper()}'")

    c = ConfusionMatrix(dataset.num_classes)

    attention = model.backbone.down_modules[i_pool_branch].image.view_pool
    is_attentive = isinstance(attention, AttentiveBimodalCSRPool)

    # Accumulators, one list entry per processed sample
    idx = []
    group_size = []
    x_proj = []
    x_mod = []
    Y = []
    Y_pred = []
    # Attention-specific tensors, read from attention._last_K, _last_Q, ...
    att_keys = ('K', 'Q', 'C', 'A', 'G')
    att = {key: [] for key in att_keys}
    # Running number of labels seen so far, used to offset the per-sample
    # indices so that they stay unique after concatenation
    count = 0

    for i in tqdm(np.random.choice(len(dataset_), n_infer)):

        # Skip these three sphere samples, they make the model crash
        if i in [903, 1470, 1471] and split == 'TEST':
            continue

        batch = MMBatch.from_mm_data_list([dataset_[int(i)]])

        # ------------------------------------------------------------------
        # # TEMPORARY FIX TO DROP SOME LOCAL PROJECTION FEATURES
        # for s in range(batch.modalities['image'].num_settings):
        #     batch.modalities['image'][s].mappings.features = batch.modalities['image'][s].mappings.features[:, :-2]
        # ------------------------------------------------------------------

        if not model.is_multimodal:
            batch = batch.data

        # Pure inference: do not build the autograd graph
        with torch.no_grad():
            model.set_input(batch, model.device)
            _ = model(batch)

        gt = model.labels.cpu().numpy()
        pred = model.output.argmax(dim=1).cpu().numpy()
        c.count_predicted_batch(gt, pred)

        # ------------------------------------------------------------------
        # POOLING DATA
        idx.append(attention._last_idx.detach().cpu() + count)
        group_size.append(attention._last_view_num.detach().cpu())
        x_proj.append(attention._last_x_proj.detach().cpu())
        x_mod.append(attention._last_x_mod.detach().cpu())
        if is_attentive:
            for key in att_keys:
                att[key].append(getattr(attention, f'_last_{key}').detach().cpu())
        Y.append(gt)
        Y_pred.append(pred)
        # count += attention._last_idx.detach().cpu().max() + 1
        count += gt.shape[0]
        # ------------------------------------------------------------------

    oa = np.round(c.get_overall_accuracy() * 100, decimals=2)
    macc = np.round(c.get_mean_class_accuracy() * 100, decimals=2)
    miou = np.round(c.get_average_intersection_union() * 100, decimals=2)
    iou_dict = {dataset.INV_OBJECT_LABEL[k]: np.round(v * 100, decimals=1)
                for k, v in enumerate(c.get_intersection_union_per_class()[0])}

    print(f"OA : {oa}")
    print(f"macc : {macc}")
    print(f"mIoU : {miou}")
    print("Per class IoU")
    for k, v in iou_dict.items():
        print(f"    {k:<9}: {v}")
    print()

    # Merge the per-sample lists into single tensors
    idx = torch.cat(idx, dim=0)
    group_size = torch.cat(group_size, dim=0)
    x_proj = torch.cat(x_proj, dim=0)
    x_mod = torch.cat(x_mod, dim=0)
    Y = torch.cat([torch.from_numpy(y) for y in Y], dim=0)
    Y_pred = torch.cat([torch.from_numpy(y) for y in Y_pred], dim=0)

    # Expand the per-group values to one value per view
    group_size_view = torch.repeat_interleave(group_size, group_size)
    Y_view = torch.repeat_interleave(Y, group_size)
    Y_pred_view = torch.repeat_interleave(Y_pred, group_size)
    if is_attentive:
        att = {key: torch.cat(values, dim=0) for key, values in att.items()}
        att['G_view'] = torch.repeat_interleave(att['G'], group_size)

    # For pure-2d mean logit fusion models, x_mod carries the logits
    Y_pred_view_indiv = torch.max(x_mod, dim=1).indices

    # Row-normalized confusion matrix - rows sum up to 1. Each row carries
    # the distribution of predictions for an expected label. This
    # illustrates how classes are misclassified.
    confusion = np.zeros((dataset.num_classes, dataset.num_classes))
    np.add.at(confusion, (Y.numpy(), Y_pred.numpy()), 1)
    # Divide by the row sum (not the row max) so that each row is an actual
    # distribution; rows of classes that were never seen are left at 0
    # instead of producing a division by zero.
    row_total = confusion.sum(axis=1, keepdims=True)
    confusion = np.divide(confusion, row_total, out=np.zeros_like(confusion), where=row_total > 0)
    sns.heatmap(confusion, xticklabels=CLASSES, yticklabels=CLASSES)
    plt.show()
    print()

    # Wrap up everything in an output dictionary
    out = {
        'oa': oa,
        'macc': macc,
        'miou': miou,
        'iou_dict': iou_dict,
        'idx': idx,
        'group_size': group_size,
        'x_proj': x_proj,
        'x_mod': x_mod,
        'Y': Y,
        'Y_pred': Y_pred,
        'group_size_view': group_size_view,
        'Y_view': Y_view,
        'Y_pred_view': Y_pred_view,
        'Y_pred_view_indiv': Y_pred_view_indiv,
    }

    # Attention tensors are appended in the order K, Q, C, A, G, G_view
    if is_attentive:
        out.update(att)

    return out

In [None]:
# Inference on TEST for average-pooled logit model
# `out` is the dictionary returned by inference(): metrics plus the tensors
# collected from the view-pooling module.
out = inference(model, dataset, set_name='TEST', n_infer=500)

Computed <span style="color:red">**logits mean-pool score: 51.2 mIoU**</span> (RGB drop 50 on exact 768x384 Area 5 5cm, with $n_{infer}=1000$)

In [None]:
# Inference on TEST for min-depth heuristic
# The view-pooling module of down module `i_pool_branch` is overwritten in
# place; this cell keeps no reference to the module it replaces.
model.backbone.down_modules[i_pool_branch].image.view_pool = HeuristicBimodalCSRPool(mode='min', feat='normalized_depth', save_last=True)
out = inference(model, dataset, set_name='TEST', n_infer=1000)

In [None]:
# Inference on TEST for max-occlusion heuristic
# Overwrites the view-pooling module of down module `i_pool_branch` in place
model.backbone.down_modules[i_pool_branch].image.view_pool = HeuristicBimodalCSRPool(mode='max', feat='occlusion', save_last=True)
out = inference(model, dataset, set_name='TEST', n_infer=100)

In [None]:
# Inference on TEST for max-orientation heuristic
# Overwrites the view-pooling module of down module `i_pool_branch` in place
model.backbone.down_modules[i_pool_branch].image.view_pool = HeuristicBimodalCSRPool(mode='max', feat='orientation_to_the_surface', save_last=True)
out = inference(model, dataset, set_name='TEST', n_infer=100)

In [None]:
# Inference on TEST for max-planarity heuristic
# Overwrites the view-pooling module of down module `i_pool_branch` in place
model.backbone.down_modules[i_pool_branch].image.view_pool = HeuristicBimodalCSRPool(mode='max', feat='planarity', save_last=True)
out = inference(model, dataset, set_name='TEST', n_infer=100)

In [None]:
# Inference on TEST for depth < 0.6 heuristic
from torch_points3d.core.data_transform.multimodal.image import PickMappingsFromProjectionFeatures
t = PickMappingsFromProjectionFeatures(feat=PROJ_FEATS.index('normalized depth'), lower=None, upper=0.6)

# Temporarily append the mapping filter to the image pipeline of every split
splits = [dataset.train_dataset, dataset.val_dataset, dataset.test_dataset[0]]
for x in splits:
    x.transform_image.transforms.append(t)

# Always remove the temporary transform, even if inference fails or is
# interrupted; otherwise it would silently stay in the pipelines and alter
# the results of every later cell.
try:
    out = inference(model, dataset, set_name='TEST', n_infer=500)
finally:
    for x in splits:
        x.transform_image.transforms.pop()

Column, window, door, chair benefit from this.

In [None]:
# Inference on TEST for planarity > 0.5 heuristic
from torch_points3d.core.data_transform.multimodal.image import PickMappingsFromProjectionFeatures
t = PickMappingsFromProjectionFeatures(feat=PROJ_FEATS.index('planarity'), lower=0.5, upper=None)

# Temporarily append the mapping filter to the image pipeline of every split
splits = [dataset.train_dataset, dataset.val_dataset, dataset.test_dataset[0]]
for x in splits:
    x.transform_image.transforms.append(t)

# Always remove the temporary transform, even if inference fails or is
# interrupted; otherwise it would silently stay in the pipelines and alter
# the results of every later cell.
try:
    out = inference(model, dataset, set_name='TEST', n_infer=500)
finally:
    for x in splits:
        x.transform_image.transforms.pop()

In [None]:
# Inference on TEST for orientation > 0.25 heuristic
from torch_points3d.core.data_transform.multimodal.image import PickMappingsFromProjectionFeatures
t = PickMappingsFromProjectionFeatures(feat=PROJ_FEATS.index('orientation to the surface'), lower=0.25, upper=None)

# Temporarily append the mapping filter to the image pipeline of every split
splits = [dataset.train_dataset, dataset.val_dataset, dataset.test_dataset[0]]
for x in splits:
    x.transform_image.transforms.append(t)

# Always remove the temporary transform, even if inference fails or is
# interrupted; otherwise it would silently stay in the pipelines and alter
# the results of every later cell.
try:
    out = inference(model, dataset, set_name='TEST', n_infer=500)
finally:
    for x in splits:
        x.transform_image.transforms.pop()

As projection features could predict it, window and column benefit from this.

# Assess bounds on the multi-view pooling performance
Here we take a closer look at the individual and multi-view predictions. We want to estimate an upper bound on the multi-view prediction performance, the probability of a single view of being right, the probability of views of a group to be in agreement, etc.

In [None]:
import torch_scatter

# Expose the tensors returned by the last inference() call under the global
# names read by the analysis cells of this section. They only exist inside
# `out`, so without this step the cells fail on a freshly started kernel.
idx = out['idx']
group_size = out['group_size']
x_proj = out['x_proj']
x_mod = out['x_mod']
Y = out['Y']
Y_pred = out['Y_pred']
group_size_view = out['group_size_view']
Y_view = out['Y_view']
Y_pred_view = out['Y_pred_view']
Y_pred_view_indiv = out['Y_pred_view_indiv']
# Attention tensors are only returned for an AttentiveBimodalCSRPool
if 'K' in out:
    K, Q, C, A, G, G_view = (out[key] for key in ('K', 'Q', 'C', 'A', 'G', 'G_view'))

# Points that received at least one view
idx_seen = torch.unique(idx)
n_points = Y.shape[0]

# Upper bound: for every seen point, keep a view whose individual prediction
# is correct whenever the group contains one (argmax of the 0/1 correctness
# over the views of the group).
confusion = ConfusionMatrix(dataset.num_classes)
idx_best_of_group = torch_scatter.scatter_max((Y_view == Y_pred_view_indiv).float(), idx, dim_size=n_points)[1][idx_seen]
confusion.count_predicted_batch(Y[idx_seen].numpy(), Y_pred_view_indiv[idx_best_of_group].numpy())

oa = np.round(confusion.get_overall_accuracy() * 100, decimals=2)
macc = np.round(confusion.get_mean_class_accuracy() * 100, decimals=2)
miou = np.round(confusion.get_average_intersection_union() * 100, decimals=2)
iou_dict = {dataset.INV_OBJECT_LABEL[k]: np.round(v * 100, decimals=1)
            for k, v in enumerate(confusion.get_intersection_union_per_class()[0])}

print(f"Unseen points: {100 * (1 - idx_seen.shape[0] / n_points):0.1f}%")
print(f"OA : {oa}")
print(f"macc : {macc}")
print(f"mIoU : {miou}")
print("Per class IoU")
for k, v in iou_dict.items():
    print(f"    {k:<9}: {v}")
del confusion

Computed <span style="color:red">**multi-view upper bound: ~64.9 mIoU**</span> (RGB drop 50 on exact 768x384 Area 5 5cm, with 2 runs of $n_{infer}=1000$). 

When compared to XYZRGB*, this would hypothetically bring the following improvements: 
- ceiling  : $-2.00$
- floor    : $-6.10$
- wall     : $3.90$
- beam     : $1.60$
- column   : $10.90$
- window   : $21.80$
- door     : $14.40$
- chair    : $-16.30$
- table    : $-2.40$
- bookcase : $7.70$
- sofa     : $-31.90$
- board    : $28.10$
- clutter  : $4.80$

That is to say, there is a lot to hope for classes such as window, door and board.

In [None]:
# Sanity checks on the tensors collected by inference()
idx_seen = torch.unique(idx)
n_points = Y.shape[0]

# Sizes: number of labels, number of distinct indices in `idx`, number of
# view-level entries
print(f"N_points      : {n_points}")
print(f"N_seen        : {idx_seen.shape[0]}")
print(f"Unseen points : {100 * (1 - idx_seen.shape[0] / n_points):0.1f}%")
print(f"N_views       : {Y_view.shape[0]}")
print()
# `idx` must be non-decreasing
print(f"CHECK - idx    increases           : {torch.all(idx[1:] >= idx[:-1])}")
# `idx` must equal 0..n_points-1 with each value repeated group_size times
print(f"CHECK - idx    matches group_size  : {torch.all(torch.arange(n_points).repeat_interleave(group_size) == idx)}")
# The view-level tensors must be the group-level ones repeated group_size times
print(f"CHECK - Y      matches Y_view      : {torch.all(torch.repeat_interleave(Y, group_size) == Y_view)}")
print(f"CHECK - Y_pred matches Y_pred_view : {torch.all(torch.repeat_interleave(Y_pred, group_size) == Y_pred_view)}")

Remark on `Y_pred_view_indiv`: although not often, <span style="color:red">there may be a discrepancy between `Y_pred_view` (view-expanded mean-pooled prediction) and `Y_pred_view_indiv` (view-wise individual predictions)</span>. It typically happens when the second highest logit is the same in the individual predictions, and they disagree on the first one.

In [None]:
# Agreement between the view-expanded group prediction (used here as the
# reference) and the individual, view-wise predictions
confusion = ConfusionMatrix(dataset.num_classes)
confusion.count_predicted_batch(Y_pred_view.numpy(), Y_pred_view_indiv.numpy())

# Global scores, in percent
oa, macc, miou = (
    np.round(100 * score, decimals=2)
    for score in (
        confusion.get_overall_accuracy(),
        confusion.get_mean_class_accuracy(),
        confusion.get_average_intersection_union()))

# Per-class IoU in percent, keyed by dataset.INV_OBJECT_LABEL
iou_dict = dict(
    (dataset.INV_OBJECT_LABEL[k], np.round(100 * v, decimals=1))
    for k, v in enumerate(confusion.get_intersection_union_per_class()[0]))

# Only the overall accuracy is reported for this comparison
print(f"Unseen points: {100 * (1 - len(idx_seen) / n_points):0.1f}%")
print(f"OA : {oa}")
del confusion

The OA here is interesting: it gives an idea of <span style="color:red">how often the individual image predictions agree with the group's final prediction</span>. This, however, is no guarantee that the group's prediction in these cases is any better...

In [None]:
# Individual, view-wise predictions against the view-expanded ground truth
confusion = ConfusionMatrix(dataset.num_classes)
confusion.count_predicted_batch(Y_view.numpy(), Y_pred_view_indiv.numpy())

# Global scores, in percent
oa, macc, miou = (
    np.round(100 * score, decimals=2)
    for score in (
        confusion.get_overall_accuracy(),
        confusion.get_mean_class_accuracy(),
        confusion.get_average_intersection_union()))

# Per-class IoU in percent, keyed by dataset.INV_OBJECT_LABEL
iou_dict = dict(
    (dataset.INV_OBJECT_LABEL[k], np.round(100 * v, decimals=1))
    for k, v in enumerate(confusion.get_intersection_union_per_class()[0]))

print(f"Unseen points: {100 * (1 - len(idx_seen) / n_points):0.1f}%")
print(f"OA : {oa}")
print(f"macc : {macc}")
print(f"mIoU : {miou}")
print("Per class IoU")
for k, v in iou_dict.items():
    print(f"    {k:<9}: {v}")
del confusion

These measures give an idea of <span style="color:red">how individual image predictions perform wrt the ground truth</span>. This illustrates how <span style="color:red">30% of the time, what the images say is wrong</span>. And not all images are equal. It would be good to link these erroneous predictions with any of the projection features...

In [None]:
# Final (pooled) predictions against the ground truth
confusion = ConfusionMatrix(dataset.num_classes)
confusion.count_predicted_batch(Y.numpy(), Y_pred.numpy())

# Global scores, in percent
oa, macc, miou = (
    np.round(100 * score, decimals=2)
    for score in (
        confusion.get_overall_accuracy(),
        confusion.get_mean_class_accuracy(),
        confusion.get_average_intersection_union()))

# Per-class IoU in percent, keyed by dataset.INV_OBJECT_LABEL
iou_dict = dict(
    (dataset.INV_OBJECT_LABEL[k], np.round(100 * v, decimals=1))
    for k, v in enumerate(confusion.get_intersection_union_per_class()[0]))

print(f"Unseen points: {100 * (1 - len(idx_seen) / n_points):0.1f}%")
print(f"OA : {oa}")
print(f"macc : {macc}")
print(f"mIoU : {miou}")
print("Per class IoU")
for k, v in iou_dict.items():
    print(f"    {k:<9}: {v}")
del confusion

In [None]:
import torch_scatter

# Points that received at least one view
idx_seen = torch.unique(idx)
n_points = Y.shape[0]

# Counterpart of the upper bound: scatter_min keeps, for every seen point, a
# view whose individual prediction is WRONG whenever the group contains one.
# NOTE: despite its name, `idx_best_of_group` therefore points to the worst
# view of each group in this cell.
confusion = ConfusionMatrix(dataset.num_classes)
is_view_correct = (Y_view == Y_pred_view_indiv).float()
idx_best_of_group = torch_scatter.scatter_min(is_view_correct, idx, dim_size=n_points)[1][idx_seen]
confusion.count_predicted_batch(Y[idx_seen].numpy(), Y_pred_view_indiv[idx_best_of_group].numpy())

# Global scores, in percent
oa, macc, miou = (
    np.round(100 * score, decimals=2)
    for score in (
        confusion.get_overall_accuracy(),
        confusion.get_mean_class_accuracy(),
        confusion.get_average_intersection_union()))

# Per-class IoU in percent, keyed by dataset.INV_OBJECT_LABEL
iou_dict = dict(
    (dataset.INV_OBJECT_LABEL[k], np.round(100 * v, decimals=1))
    for k, v in enumerate(confusion.get_intersection_union_per_class()[0]))

print(f"Unseen points: {100 * (1 - len(idx_seen) / n_points):0.1f}%")
print(f"OA : {oa}")
print(f"macc : {macc}")
print(f"mIoU : {miou}")
print("Per class IoU")
for k, v in iou_dict.items():
    print(f"    {k:<9}: {v}")
del confusion

# Search worst-case examples for each class

In [None]:
from torch_points3d.core.multimodal.data import MMBatch
from torch_points3d.metrics.confusion_matrix import ConfusionMatrix
from tqdm import tqdm

# Split to scan: 'TRAIN', 'VAL' or 'TEST'. TEST is the split used by the
# inference cells above.
SET = 'TEST'

if SET.upper() == 'TRAIN':
    dataset_ = dataset.train_dataset
elif SET.upper() == 'VAL':
    dataset_ = dataset.val_dataset
elif SET.upper() == 'TEST':
    dataset_ = dataset.test_dataset[0]
else:
    raise ValueError(f"Unknown set '{SET.upper()}'")

# Per class: index of the sample with the lowest / highest IoU so far, and
# that IoU. Worst starts at 1 and best at 0 so that any eligible sample
# updates them.
idx_worst_sample = np.zeros(dataset.num_classes, dtype='int')
iou_worst_sample = np.ones(dataset.num_classes)
idx_best_sample = np.zeros(dataset.num_classes, dtype='int')
iou_best_sample = np.zeros(dataset.num_classes)

# A sample only competes for a class if its confusion-matrix row for that
# class holds more than this many points
n_points_min = 100

for i in tqdm(range(len(dataset_))):

    # Same samples as skipped in inference(): they make the model crash.
    # As in inference(), the skip only applies to the TEST split.
    if SET.upper() == 'TEST' and i in [903, 1470, 1471]:
        continue

    batch = MMBatch.from_mm_data_list([dataset_[int(i)]])

    if not model.is_multimodal:
        batch = batch.data

    # Pure inference: do not build the autograd graph
    with torch.no_grad():
        model.set_input(batch, model.device)
        _ = model(batch)

    gt = model.labels.cpu().numpy()
    pred = model.output.argmax(dim=1).cpu().numpy()

    # Compute the local miou
    confusion_sample = ConfusionMatrix(dataset.num_classes)
    confusion_sample.count_predicted_batch(gt, pred)

    # Check whether the sampling contains the class in the GT
    is_class_seen = confusion_sample.confusion_matrix.sum(axis=1) > n_points_min

    # Compute the per-class IoU
    iou = confusion_sample.get_intersection_union_per_class()[0]

    # Update worse values
    idx_worst_update = np.logical_and(iou < iou_worst_sample, is_class_seen)
    iou_worst_sample[idx_worst_update] = iou[idx_worst_update]
    idx_worst_sample[idx_worst_update] = i

    # Update best values
    idx_best_update = np.logical_and(iou > iou_best_sample, is_class_seen)
    iou_best_sample[idx_best_update] = iou[idx_best_update]
    idx_best_sample[idx_best_update] = i

# Report once, after the whole split has been scanned, rather than after
# every sample
print(f"Best - Worst samples")
for class_name, idx_b, idx_w in zip(CLASSES, idx_best_sample, idx_worst_sample):
    print(f"    {class_name:<9}: {idx_b} - {idx_w}")
print()

In [None]:
from torch_geometric.transforms import Center, RandomRotate
from torch_points3d.core.data_transform import RandomNoise, RandomScaleAnisotropic, RandomSymmetry, \
    DropFeature, XYZFeature, AddFeatsByKeys
from torch_points3d.core.data_transform.multimodal.image import ToFloatImage, AddPixelHeightFeature, \
    PickImagesFromMappingArea, PickImagesFromMemoryCredit, CropImageGroups
from torch_points3d.datasets.base_dataset import BaseDataset
from torch_points3d.datasets.base_dataset_multimodal import BaseDatasetMM

def sample_real_data(tg_dataset, idx=0):
    """
    Temporarily remove the 3D and 2D transforms affecting the point
    positions and images from the dataset to better visualize points
    and images relative positions.

    Parameters
    ----------
    tg_dataset:
        Dataset exposing the `transform` and `transform_image` attributes
        and supporting integer indexing.
    idx: int
        Index of the sample to read.

    Returns
    -------
    The item `tg_dataset[idx]` read while the reduced transforms are active.

    The original `transform` and `transform_image` are restored before
    returning, including when reading the sample raises.
    """
    transform = tg_dataset.transform
    transform_image = tg_dataset.transform_image

    try:
        tg_dataset.transform = BaseDataset.remove_transform(transform, [Center, RandomNoise,
            RandomRotate, RandomScaleAnisotropic, RandomSymmetry, DropFeature, AddFeatsByKeys])
        tg_dataset.transform_image = BaseDatasetMM.remove_multimodal_transform(
            transform_image, [ToFloatImage, AddPixelHeightFeature])
        return tg_dataset[idx]
    finally:
        # Put the dataset back in its original state no matter what happened
        tg_dataset.transform = transform
        tg_dataset.transform_image = transform_image

In [None]:
from torch_points3d.visualization.multimodal_data import visualize_mm_data

# Index of the sample to inspect. A dedicated name is used so that the `idx`
# tensor read by the multi-view analysis cells is not overwritten.
i_sample = 82

# Compute the predicted labels
batch = MMBatch.from_mm_data_list([dataset_[i_sample]])
if not model.is_multimodal:
    batch = batch.data
# Pure inference: do not build the autograd graph
with torch.no_grad():
    model.set_input(batch, model.device)
    _ = model(batch)
y_pred = model.output.argmax(dim=1).cpu()

# Visualize the ground truth. The sample is read from `dataset_`, i.e. the
# same split the prediction was computed on.
mm_data = sample_real_data(dataset_, idx=i_sample)
visualize_mm_data(mm_data, class_names=CLASSES, class_colors=OBJECT_COLOR, figsize=800, voxel=0.05, show_3d=True, show_2d=True, color_mode='light', alpha=3, pointsize=5)

# Visualize the predictions
mm_data.data.y = y_pred
visualize_mm_data(mm_data, class_names=CLASSES, class_colors=OBJECT_COLOR, figsize=800, voxel=0.05, show_3d=True, show_2d=True, color_mode='y', alpha=3, pointsize=5)

In [None]:
# # Scattering unskewed
# i_feat = 3
# fig = plt.figure()
# sns.distplot(np.power(np.random.choice(x_proj[:, i_feat][select].numpy(), N_PRINT), 1/3), label='True',)
# sns.distplot(np.power(np.random.choice(x_proj[:, i_feat][~select].numpy(), N_PRINT), 1/3), label='False', )
# plt.title(PROJ_FEATS[i_feat])
# plt.legend()
# plt.show()

In [None]:
# # Occlusion unskewed
# i_feat = 7
# fig = plt.figure()
# sns.distplot(np.power(np.random.choice(x_proj[:, i_feat][select].numpy(), N_PRINT), 4), label='True',)
# sns.distplot(np.power(np.random.choice(x_proj[:, i_feat][~select].numpy(), N_PRINT), 4), label='False', )
# ax.set_title(PROJ_FEATS[i_feat])
# ax.legend()
# ax.set_ylabel('')
# plt.show()

In [None]:
# Number of values drawn (with np.random.choice) for each plotted distribution
N_PRINT = 5000

# Define the plots
# Grid layout: one column per projection feature; first row is global, then
# one row per class
fig = plt.figure(figsize=(4*len(PROJ_FEATS), 4*(len(CLASSES)+1)))
axes = fig.subplots(len(CLASSES)+1, len(PROJ_FEATS))

for ax in axes.flatten():
    ax.set_yticks([])

for ax, col_name in zip(axes[0], PROJ_FEATS):
    ax.set_title(col_name.upper(), fontsize=12)

# NOTE(review): `iou_dict` is whatever the most recently executed metrics
# cell left behind; when the notebook runs top to bottom that is the
# scatter_min cell - confirm this is the IoU meant for the row labels.
for ax, row_name in zip(axes[:,0], ['GLOBAL'] + [f"{class_name.upper()} - {list(iou_dict.values())[i_class]} IoU" for i_class, class_name in enumerate(CLASSES)]):
    ax.set_ylabel(row_name.upper(), fontsize=12)

# Plot the global distribution of projection features wrt prediction errors
# `select` marks the views whose individual prediction matches the label
select = Y_pred_view_indiv == Y_view
for i_feat, (ax, feat_name) in enumerate(zip(axes[0], PROJ_FEATS)):
    sns.distplot(np.random.choice(x_proj[:, i_feat][select].numpy(), N_PRINT), label='True', ax=ax, color='black')
    sns.distplot(np.random.choice(x_proj[:, i_feat][~select].numpy(), N_PRINT), label='False', ax=ax, color='red')
    ax.legend()

# Plot the per-class distribution of projection features wrt prediction errors
for i_class, class_name in enumerate(CLASSES):
    # View-level masks: true positives, false positives and false negatives
    # of the individual predictions for this class
    t = Y_view == i_class
    p = Y_pred_view_indiv == i_class
    tp = torch.logical_and(t, p)
    fp = torch.logical_and(~t, p)
    fn = torch.logical_and(t, ~p)

    for i_feat, (ax, feat_name) in enumerate(zip(axes[i_class+1], PROJ_FEATS)):
        for select, label, color in zip((tp, fp, fn), ('tp', 'fp', 'fn'), ('black', 'green', 'red')):
            # Empty groups are skipped (nothing to draw from)
            if select.sum() > 0:
                sns.distplot(np.random.choice(x_proj[:, i_feat][select].numpy(), N_PRINT), label=label.upper(), ax=ax, color=color)
        ax.legend()
plt.show()

### Thoughts
The distribution of projection features wrt pure-2D model predictions suggests the former should be able to help disambiguate some cases. 

However, this intuition means we would build an **_attention/selection mechanism_** able to attend to the views, based on their **_projection features AND their class_**. Since we don't know the class, the mechanism should make use of the features. 

- But will the network be able to learn this ? Are we not thinking in reverse, expecting the network to already know what we would like it to learn ?
- And what should the queries and keys be built upon ? Shouldn't the 2D features be taken into account in the queries ? For the pure-2D model, we can think of queries built from the 2D features, much like in a self-attention mechanism. Why shouldn't this apply to the multimodal architecture too then ?
- Besides, since the geometric information - some of which is carried in the local geometric features of the projection - seems to help disambiguate some 2D features, wouldn't the 2D encoder benefit from receiving some of it ? In other words: can the 2D features also be fed contracted 3D features / multimodal attention outputs ? 
- Since in Transformer networks, the self-attention Q and K are all computed from the same data, shouldn't we do the same ? Rather than 'hiding' the 2D feats and projection feats from the Queries and 'hiding' the 3D feats from the keys, shouldn't they all benefit from all ? Maybe right now either the Keys or the Queries just don't know enough to actually do anything useful ?

In [None]:
from sklearn.manifold import TSNE

# Total number of views embedded per t-SNE plot
N_PRINT = 1000
n_half = N_PRINT // 2   # views drawn per group in the global plot
n_third = N_PRINT // 3  # views drawn per group in the per-class plots

fig = plt.figure(figsize=((len(CLASSES)+1)*5, 5))
axes = fig.subplots(1, len(CLASSES)+1)

# Plot the global t-SNE distribution of projection features wrt prediction errors
# `select` marks the views whose individual prediction matches the label
select = Y_pred_view_indiv == Y_view
rng = np.random.default_rng()
x_tsne = TSNE(n_components=2).fit_transform(np.vstack((
    rng.choice(x_proj[select].numpy(), n_half),
    rng.choice(x_proj[~select].numpy(), n_half))))

axes[0].scatter(x_tsne[:n_half, 0], x_tsne[:n_half, 1], color="blue", alpha=0.5, label='True')
axes[0].scatter(x_tsne[n_half:, 0], x_tsne[n_half:, 1], color="red", alpha=0.5, label='False')
axes[0].legend()
axes[0].axis('off')
axes[0].set_title('GLOBAL'.upper(), fontsize=20)

# Plot the per-class t-SNE distribution of projection features wrt prediction errors
for i_class, class_name in enumerate(CLASSES):

    # View-level masks: true positives, false positives and false negatives
    # of the individual predictions for this class
    t = Y_view == i_class
    p = Y_pred_view_indiv == i_class
    tp = torch.logical_and(t, p)
    fp = torch.logical_and(~t, p)
    fn = torch.logical_and(t, ~p)

    ax = axes[i_class+1]
    ax.axis('off')
    ax.set_title(f"{class_name.upper()} - {list(iou_dict.values())[i_class]} IoU", fontsize=20)

    # Keep only the non-empty groups, in TP, FP, FN order
    groups = [
        (label, color, x_proj[mask].numpy())
        for mask, label, color in zip((tp, fp, fn), ('tp', 'fp', 'fn'), ('black', 'green', 'red'))
        if bool(mask.any())]

    # A class with no TP, FP or FN view at all (neither present in the
    # labels nor ever predicted) has nothing to embed: leave its panel empty
    # instead of failing on an empty np.vstack.
    if not groups:
        continue

    # Every group contributes exactly n_third rows, stacked in group order
    x_tsne = TSNE(n_components=2).fit_transform(np.vstack(
        [rng.choice(feats, n_third) for _, _, feats in groups]))

    for i_group, (label, color, _) in enumerate(groups):
        a = i_group * n_third
        b = a + n_third
        ax.scatter(x_tsne[a:b, 0], x_tsne[a:b, 1], color=color, alpha=0.3, label=label)
    ax.legend()
plt.show()

Judging from the t-SNE distribution of projection features and error cases (**_on the test data_**), we can conclude the following:
- GLOBAL: there is little to hope from a class-agnostic (or 2D-feature) attention mechanism on the projection features. That means <span style="color:red">keys should be computed based on the projection features AND the 2D features</span>. What about <span style="color:red">queries</span> ?
- COLUMN, WINDOW, DOOR, CHAIR, TABLE, BOARD: may benefit from projection features
- CEILING, FLOOR: maybe but unsure
- WALL, BOOKCASE, CLUTTER: unlikely to benefit from the projection features
- Would the classes with no correlation with projection features still benefit from a self-attention mechanism ?
- Would the projection-feature disambiguation help special cases where the XYZ and XYZRGB* models fail, or would it just bring the 2D model closer to XYZRGB* ?

In [None]:
# Raw confusion counts: rows are expected labels, columns are predictions
confusion = np.zeros((num_classes, num_classes))
np.add.at(confusion, (Y.numpy(), Y_pred.numpy()), 1)

# sns.heatmap(confusion / confusion.max(axis=0).reshape(1, -1), xticklabels=CLASSES, yticklabels=CLASSES)
# plt.show()

# Row-normalized confusion matrix - rows sum up to 1. Each row carries
# the distribution of predictions for an expected label. This
# illustrates how classes are misclassified.
# The rows are divided by their sum (not their max) so that they are actual
# distributions; rows of classes that never occur are left at 0 instead of
# producing a division by zero.
row_total = confusion.sum(axis=1, keepdims=True)
confusion_row_norm = np.divide(confusion, row_total, out=np.zeros_like(confusion), where=row_total > 0)
sns.heatmap(confusion_row_norm, xticklabels=CLASSES, yticklabels=CLASSES)
plt.show()

In [None]:
# Only the first N_PRINT entries of each tensor are plotted
N_PRINT = 10000

# NOTE: this cell needs C, A and G, which inference() only returns when the
# view-pooling module is an AttentiveBimodalCSRPool.
fig = plt.figure(figsize=(32, 4))
(ax0, ax1, ax2, ax3, ax4, ax5) = fig.subplots(1, 6)

# Group sizes, with their mean as a dotted red line
sns.distplot(group_size.numpy()[:N_PRINT], kde_kws={'bw': 0.3}, bins=np.arange(21), label="Group Size", color='purple', ax=ax0)
ax0.axvline(x=group_size.numpy()[:N_PRINT].mean(), ls=':', c='red')

# Values are clamped to [-1, 1.5] before plotting
sns.distplot(torch.clamp(C, -1, 1.5).numpy()[:N_PRINT], label="Compatibilities", color='b', ax=ax1)

sns.distplot(torch.clamp(A, -1, 1.5).numpy()[:N_PRINT], label="Attention", ax=ax2)
# Dotted guides at 1/k for k = 1..9
# NOTE(review): presumably the attention of a view under uniform weighting
# in a group of k views - confirm.
for k in range(1, 10):
    ax2.axvline(x=1/k, ls=':', c='black')

# Attention of each view multiplied by the size of its group
sns.distplot((A * torch.repeat_interleave(group_size, group_size)).numpy()[:N_PRINT], label="Attention * Group Size", ax=ax3)

sns.distplot(torch.clamp(G, -1, 1.5).numpy()[:N_PRINT], label="Gating", color='r', ax=ax4)

# Label and prediction histograms (no KDE) on the same axis
sns.distplot(Y.numpy()[:N_PRINT], kde=False, label="Y", ax=ax5)
sns.distplot(Y_pred.numpy()[:N_PRINT], kde=False, label="Y_pred", ax=ax5)

for ax in (ax0, ax1, ax2, ax3, ax4, ax5):
    ax.legend()
    ax.set_ylabel('')
plt.show()

In [None]:
# One panel per class: distribution of the group sizes of the entries
# labelled with that class. N_PRINT is the value set in a previous cell.
fig = plt.figure(figsize=(32, 2))
axes = fig.subplots(1, num_classes)
for i_class, ax in enumerate(axes):
    idx_class = torch.where(Y == i_class)
    group_size_class = group_size[idx_class]
    # Group sizes are clamped to [0, 10] for display; the color is the class
    # color rescaled from 0-255 to 0-1
    sns.distplot(
        torch.clamp(group_size_class, 0, 10).numpy()[:N_PRINT],
        kde_kws={'bw': 0.3}, 
        bins=np.arange(11),
        label=f"{CLASSES[i_class]}", 
        color=tuple([x / 255. for x in OBJECT_COLOR[i_class]]),
        ax=ax)
    ax.legend()
    ax.set_ylabel('')
    # Black dotted line: mean group size over all classes;
    # red dotted line: mean group size of this class (unclamped)
    ax.axvline(x=group_size.numpy()[:N_PRINT].mean(), ls=':', c='black')
    ax.axvline(x=group_size_class.numpy()[:N_PRINT].mean(), ls=':', c='red')
plt.suptitle("Per-class group sizes")
plt.show()

In [None]:
# Compare K, Q, compatibilities and attention between the entries whose
# x_proj is below 0.3 (labelled "toxic") and the others (labelled "safe").
# This cell needs K, Q, C and A, which inference() only returns when the
# view-pooling module is an AttentiveBimodalCSRPool.
# NOTE(review): `(x_proj < 0.3).squeeze()` only yields a 1-D mask if x_proj
# has a single column, whereas other cells index it as `x_proj[:, i_feat]`
# with one column per entry of PROJ_FEATS - confirm this cell still matches
# the current layout of x_proj.
fig = plt.figure(figsize=(32, 4))
(ax0, ax1, ax2, ax3, ax4, ax5) = fig.subplots(1, 6)

sns.distplot(K[(x_proj < 0.3).squeeze()].numpy()[:N_PRINT], label="K toxic", ax=ax0)
sns.distplot(K[(x_proj >= 0.3).squeeze()].numpy()[:N_PRINT], label="K safe", ax=ax0)

sns.distplot(Q[(x_proj < 0.3).squeeze()].numpy()[:N_PRINT], label="Q toxic", ax=ax1)
sns.distplot(Q[(x_proj >= 0.3).squeeze()].numpy()[:N_PRINT], label="Q safe", ax=ax1)

sns.distplot(C[(x_proj < 0.3).squeeze()].numpy()[:N_PRINT], label="Compatibility toxic", ax=ax2)
sns.distplot(C[(x_proj >= 0.3).squeeze()].numpy()[:N_PRINT], label="Compatibility safe", ax=ax2)

sns.distplot(A[(x_proj < 0.3).squeeze()].numpy()[:N_PRINT], label="Attention toxic", ax=ax3)
sns.distplot(A[(x_proj >= 0.3).squeeze()].numpy()[:N_PRINT], label="Attention safe", ax=ax3)

# Attention of the "toxic" views, split by group size k = 2..9
for k in range(2, 10):
    idx_group = group_size_view == k
    sns.distplot(A[idx_group][(x_proj[idx_group] < 0.3).squeeze()].numpy()[:N_PRINT], label=f"Attention toxic k={k}", ax=ax4)

# Attention of the "safe" views, split by group size k = 2..9
for k in range(2, 10):
    idx_group = group_size_view == k
    sns.distplot(A[idx_group][(x_proj[idx_group] >= 0.3).squeeze()].numpy()[:N_PRINT], label=f"Attention safe k={k}", ax=ax5)

for ax in (ax0, ax1, ax2, ax3, ax4, ax5):
    ax.legend()
    ax.set_ylabel('')
plt.show()