In [1]:
from PIL import Image
import os
import os.path as osp
import numpy as np
import torch
import matplotlib.pyplot as plt

# Select your GPU
I_GPU = 0

# Autoreload edited modules before each cell runs (comment out the two magics below to disable)
%load_ext autoreload
%autoreload 2

import os
import os.path as osp
import sys
import torch
import numpy as np
from time import time
from omegaconf import OmegaConf
from PIL import Image
# Wall-clock reference; rebound later in the notebook before the dataset is built.
start = time()
import warnings
# Silence every warning for the rest of the session.
warnings.filterwarnings('ignore')

# torch.cuda.set_device(I_GPU)
# Make the project importable: DIR is the parent of the current working
# directory and ROOT is DIR's own parent (expressed as "DIR/..").
DIR = os.path.dirname(os.getcwd())
ROOT = os.path.join(DIR, "..")
# Both are inserted at position 0, so DIR (inserted last) is searched before ROOT.
sys.path.insert(0, ROOT)
sys.path.insert(0, DIR)

from torch_points3d.utils.config import hydra_read
from torch_geometric.data import Data
from torch_points3d.core.multimodal.data import MMData, MMBatch
from torch_points3d.visualization.multimodal_data import visualize_mm_data
from torch_points3d.core.multimodal.image import SameSettingImageData, ImageData
from torch_points3d.datasets.segmentation.multimodal.scannet import ScannetDatasetMM
from torch_points3d.datasets.segmentation.scannet import CLASS_COLORS, CLASS_NAMES, CLASS_LABELS
from torch_points3d.metrics.segmentation_tracker import SegmentationTracker

from pykeops.torch import LazyTensor

import plotly.io as pio

# Pick the plotly renderer matching where the notebook is served from.
#pio.renderers.default = 'jupyterlab'        # for local notebook
pio.renderers.default = 'iframe_connected'  # for remote notebook. Other working (but seemingly slower) options are: 'sphinx_gallery' and 'iframe'

# Override two entries of the imported colour table in place: the first entry
# becomes a light blue and the last entry becomes black.
# NOTE(review): presumably the last entry is what label -1 (ignored) indexes
# when colourising masks with np.array(CLASS_COLORS)[labels] - confirm.
CLASS_COLORS[0] = (174.0, 199.0, 232.0)
CLASS_COLORS[-1] = (0, 0, 0)

# from torch_points3d.datasets.segmentation.scannet import CLASS_COLORS

MMData debug() function changed, please uncomment the 3rd assert line when doing inference without M2F features!


In [2]:
import os
import torchnet as tnt
import torch
from typing import Dict, Any
import wandb
from torch.utils.tensorboard import SummaryWriter
import logging
from torch_points3d.metrics.confusion_matrix import ConfusionMatrix
from torch_points3d.models import model_interface
from torch_points3d.metrics.base_tracker import BaseTracker, meter_value
from torch_points3d.metrics.meters import APMeter
from torch_points3d.datasets.segmentation import IGNORE_LABEL

from torch_geometric.nn.unpool import knn_interpolate
from torch_points3d.core.data_transform import SaveOriginalPosId

log = logging.getLogger(__name__)


def meter_value(meter, dim=0):
    """ Read one component of a meter's value as a float.

    Arguments:
        meter: object exposing a sample count ``n`` and a ``value()`` method
            whose result can be indexed
        dim: index into the result of ``meter.value()`` (default: 0)

    Returns:
        ``float(meter.value()[dim])``, or ``0.0`` when the meter has not
        received any sample yet (``meter.n`` is not positive).
    """
    if not meter.n > 0:
        # Nothing was accumulated: do not query the meter at all
        return 0.0
    return float(meter.value()[dim])


class BaseTracker:
    """ Minimal metric tracker: averages the losses reported by a model and
    publishes them to wandb and/or tensorboard.

    Losses are accumulated in one ``AverageValueMeter`` per
    ``"<stage>_<loss name>"`` key. Sub-classes extend ``track`` and
    ``get_metrics`` with task specific metrics.
    """

    def __init__(self, stage: str, wandb_log: bool, use_tensorboard: bool):
        """
        Arguments:
            stage: name of the stage tracked until the next call to ``reset``
            wandb_log: when True, ``publish`` forwards the metrics to ``wandb.log``
            use_tensorboard: when True, a ``SummaryWriter`` is created in
                ``<cwd>/tensorboard`` and ``publish`` writes scalars to it
        """
        self._wandb = wandb_log
        self._use_tensorboard = use_tensorboard
        self._tensorboard_dir = os.path.join(os.getcwd(), "tensorboard")
        self._n_iter = 0
        self._finalised = False
        self._conv_type = None

        # `stage` used to be ignored here, which left `_stage` and
        # `_loss_meters` undefined (AttributeError in get_metrics / track /
        # publish) until reset() was called. They are initialised directly
        # rather than through self.reset(), because sub-classes override
        # reset() and rely on attributes they only set after this constructor.
        self._stage = stage
        self._loss_meters = {}

        if self._use_tensorboard:
            log.info(
                "Access tensorboard with the following command <tensorboard --logdir={}>".format(self._tensorboard_dir)
            )
            self._writer = SummaryWriter(log_dir=self._tensorboard_dir)

    def reset(self, stage="train"):
        """ Start tracking a new stage: drops all loss meters and re-opens the tracker
        """
        self._stage = stage
        self._loss_meters = {}
        self._finalised = False

    def get_metrics(self, verbose=False) -> Dict[str, Any]:
        """ Returns the average of each tracked loss, keyed by ``"<stage>_<loss name>"``.
        Losses whose average is falsy (0.0, which is also what an empty meter
        yields) are left out. ``verbose`` is unused at this level.
        """
        metrics = {}
        for key, loss_meter in self._loss_meters.items():
            value = meter_value(loss_meter, dim=0)
            if value:
                metrics[key] = value
        return metrics

    @property
    def metric_func(self):
        """ Maps metric names to the function selecting their best value (``min`` for the loss)
        """
        self._metric_func = {"loss": min}
        return self._metric_func

    def track(self, model: model_interface.TrackerInterface, **kwargs):
        """ Accumulates the current losses of ``model``. Does nothing when ``model`` is None.

        Raises:
            RuntimeError: if the tracker was finalised and not reset since
        """
        if self._finalised:
            raise RuntimeError("Cannot track new values with a finalised tracker, you need to reset it first")

        if model is not None:
            losses = self._convert(model.get_current_losses())
            self._append_losses(losses)

    def finalise(self, *args, **kwargs):
        """ Lifecycle method that is called at the end of an epoch. Use this to compute
        end of epoch metrics.
        """
        self._finalised = True

    def _append_losses(self, losses):
        """ Adds each non-None loss to the meter of the current stage, creating meters on first use
        """
        for key, loss in losses.items():
            if loss is None:
                continue
            loss_key = "%s_%s" % (self._stage, key)
            if loss_key not in self._loss_meters:
                self._loss_meters[loss_key] = tnt.meter.AverageValueMeter()
            self._loss_meters[loss_key].add(loss)

    @staticmethod
    def _convert(x):
        """ Turns a torch tensor into a detached CPU numpy array; anything else is returned as is
        """
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        else:
            return x

    def publish_to_tensorboard(self, metrics, step):
        """ Writes every metric as a scalar named ``"<metric without stage prefix>/<stage>"``
        """
        for metric_name, metric_value in metrics.items():
            metric_name = "{}/{}".format(metric_name.replace(self._stage + "_", ""), self._stage)
            self._writer.add_scalar(metric_name, metric_value, step)

    @staticmethod
    def _remove_stage_from_metric_keys(stage, metrics):
        """ Returns a copy of ``metrics`` whose keys have every ``"<stage>_"`` occurrence removed
        """
        new_metrics = {}
        for metric_name, metric_value in metrics.items():
            new_metrics[metric_name.replace(stage + "_", "")] = metric_value
        return new_metrics

    def publish(self, step):
        """ Publishes the current metrics to wandb and tensorboard
        Arguments:
            step: current epoch
        Returns:
            dict with the keys ``"stage"``, ``"epoch"`` and ``"current_metrics"``
        """
        metrics = self.get_metrics()

        if self._wandb:
            wandb.log(metrics, step=step)

        if self._use_tensorboard:
            self.publish_to_tensorboard(metrics, step)

        # Some metrics may be intended for wandb or tensorboard
        # tracking but not for final model selection. Those are
        # the metrics absent from self.metric_func and must be excluded
        # from the output of self.publish
        current_metrics = {
            k: v
            for k, v in self._remove_stage_from_metric_keys(self._stage, metrics).items()
            if k in self.metric_func.keys()}

        return {
            "stage": self._stage,
            "epoch": step,
            "current_metrics": current_metrics,
        }

    def print_summary(self):
        """ Logs every metric (verbose form) between two 50-character rules
        """
        metrics = self.get_metrics(verbose=True)
        rule = "=" * 50
        log.info(rule)
        for key, value in metrics.items():
            log.info("    {} = {}".format(key, value))
        log.info(rule)

    @staticmethod
    def _dict_to_str(dictionnary):
        """ Formats a dict of numbers as ``"{key: 1.23,other: 4.56,}"`` (two decimals)
        """
        string = "{"
        for key, value in dictionnary.items():
            string += "%s: %.2f," % (str(key), value)
        string += "}"
        return string


class SegmentationTracker(BaseTracker):
    def __init__(
        self, dataset, stage="train", wandb_log=False, use_tensorboard: bool = False, ignore_label: int = IGNORE_LABEL
    ):
        """ This is a generic tracker for segmentation tasks.
        It uses a confusion matrix in the back-end to track results.
        Use the tracker to track an epoch.
        You can use the reset function before you start a new epoch

        Arguments:
            dataset  -- dataset to track (used for the number of classes)

        Keyword Arguments:
            stage {str} -- current stage. (train, validation, test, etc...) (default: {"train"})
            wandb_log {bool} --  Log using weight and biases
            use_tensorboard {bool} -- Log using tensorboard
            ignore_label {int} -- label value excluded from every metric
        """
        super(SegmentationTracker, self).__init__(stage, wandb_log, use_tensorboard)
        self._num_classes = dataset.num_classes
        self._ignore_label = ignore_label
        self._dataset = dataset
        self.reset(stage)
        self._metric_func = {
            "miou": max,
            "macc": max,
            "acc": max,
            "loss": min,
            "map": max,
        }  # Those map subsentences to their optimization functions

    def reset(self, stage="train"):
        """ Starts a new stage with an empty confusion matrix and zeroed metrics
        """
        super().reset(stage=stage)
        self._confusion_matrix = ConfusionMatrix(self._num_classes)
        self._acc = 0
        self._macc = 0
        self._miou = 0
        self._miou_per_class = {}

    @staticmethod
    def detach_tensor(tensor):
        """ Detaches ``tensor`` from the autograd graph when it is a torch tensor;
        any other object is returned unchanged
        """
        # Was `torch.torch.is_tensor`: a typo that only worked through an
        # accidental self-reference of the torch package.
        if torch.is_tensor(tensor):
            tensor = tensor.detach()
        return tensor

    @property
    def confusion_matrix(self):
        """ Raw confusion matrix held by the underlying ``ConfusionMatrix`` object
        """
        return self._confusion_matrix.confusion_matrix

    def track(self, model: model_interface.TrackerInterface, pred_labels=None, gt_labels=None, **kwargs):
        """ Add current model predictions (usually the result of a batch) to the tracking

        When both ``pred_labels`` and ``gt_labels`` are given they are scored
        directly and ``model`` is not used (it may be None). Otherwise the
        losses, outputs and labels are read from ``model``.
        Does nothing when the dataset has no labels for the current stage.
        """
        if not self._dataset.has_labels(self._stage):
            return

        # Feng: to evaluate M2F predictions instead of model logits
        if pred_labels is not None and gt_labels is not None:
            outputs = pred_labels
            targets = gt_labels
        else:
            super().track(model)

            outputs = model.get_output()
            targets = model.get_labels()
        self._compute_metrics(outputs, targets)

    def _compute_metrics(self, outputs, labels):
        """ Updates the confusion matrix with a batch and refreshes acc, macc, miou
        and the per-class IoU.

        Arguments:
            outputs -- predicted labels, or per-class scores when there is more
                than one dimension left after masking (argmax over dim 1 is used)
            labels -- reference labels; entries equal to the ignore label are dropped
        """
        mask = labels != self._ignore_label
        outputs = outputs[mask]
        labels = labels[mask]

        outputs = self._convert(outputs)
        labels = self._convert(labels)

        # Nothing left to score once ignored entries are removed
        if len(labels) == 0:
            return

        assert outputs.shape[0] == len(labels)

        # Check if output is predicted label or logits
        if len(outputs.shape) > 1:
            self._confusion_matrix.count_predicted_batch(labels, np.argmax(outputs, 1))
        else:
            self._confusion_matrix.count_predicted_batch(labels, outputs)

        # Metrics are expressed in percent and cover everything tracked since the last reset
        self._acc = 100 * self._confusion_matrix.get_overall_accuracy()
        self._macc = 100 * self._confusion_matrix.get_mean_class_accuracy()
        self._miou = 100 * self._confusion_matrix.get_average_intersection_union()
        self._miou_per_class = {
            i: "{:.2f}".format(100 * v)
            for i, v in enumerate(self._confusion_matrix.get_intersection_union_per_class()[0])
        }

    def get_metrics(self, verbose=False) -> Dict[str, Any]:
        """ Returns a dictionnary of all metrics and losses being tracked.
        With ``verbose`` the per-class IoU (formatted strings) is included.
        """
        metrics = super().get_metrics(verbose)

        metrics["{}_acc".format(self._stage)] = self._acc
        metrics["{}_macc".format(self._stage)] = self._macc
        metrics["{}_miou".format(self._stage)] = self._miou

        if verbose:
            metrics["{}_miou_per_class".format(self._stage)] = self._miou_per_class
        return metrics

    @property
    def metric_func(self):
        """ Maps metric names to the function selecting their best value
        """
        return self._metric_func


class ScannetSegmentationTracker(SegmentationTracker):
    """ Segmentation tracker that can additionally accumulate per-scan votes,
    upsample them to the raw scan resolution and score / export the result.
    """

    def reset(self, stage="train"):
        """ Starts a new stage: clears the low-res metrics and every full-res accumulator
        """
        super().reset(stage=stage)
        self._full_confusion_matrix = ConfusionMatrix(self._num_classes)
        self._raw_datas = {}
        self._votes = {}
        self._vote_counts = {}
        self._full_preds = {}
        self._full_acc = None
        # Initialised together with _full_acc so the attributes always exist;
        # they used to be created only inside finalise().
        self._full_macc = None
        self._full_miou = None

    def track(self, model: model_interface.TrackerInterface, full_res=False, pred_labels=None, gt_labels=None, **kwargs):
        """ Add current model predictions (usually the result of a batch) to the tracking

        When both ``pred_labels`` and ``gt_labels`` are given, they are scored
        directly and ``model`` is ignored. Otherwise the model is tracked and,
        if ``full_res`` is set, the stage is not "train" and a ``data`` keyword
        argument is provided, its outputs are also added to the per-scan votes.
        """
        if pred_labels is not None and gt_labels is not None:
            super().track(model=None, pred_labels=pred_labels, gt_labels=gt_labels)
        else:
            super().track(model)

            # Set conv type
            self._conv_type = model.conv_type

            # Train mode or low res, nothing special to do
            if not full_res or self._stage == "train" or kwargs.get("data") is None:
                return

            # The guard above guarantees `data` is present, so the default is never the value used
            data = kwargs.get("data", model.get_input())
            data = data.data if model.is_multimodal else data
            self._vote(data, model.get_output())

    def get_metrics(self, verbose=False) -> Dict[str, Any]:
        """ Returns a dictionnary of all metrics and losses being tracked.
        Full resolution metrics are added once ``finalise`` computed a truthy
        full resolution accuracy.
        """
        metrics = super().get_metrics(verbose)
        if self._full_acc:
            metrics["{}_full_acc".format(self._stage)] = self._full_acc
            metrics["{}_full_macc".format(self._stage)] = self._full_macc
            metrics["{}_full_miou".format(self._stage)] = self._full_miou
        return metrics

    def finalise(self, full_res=False, make_submission=False, **kwargs):
        """ Computes full resolution predictions from the votes, scores them when
        labels are available and optionally writes submission files.
        Does nothing unless ``full_res`` or ``make_submission`` is set.
        """
        if not full_res and not make_submission:
            return

        self._predict_full_res()

        # Compute full res metrics
        if self._dataset.has_labels(self._stage):
            for scan_id in self._full_preds:
                full_labels = self._raw_datas[scan_id].y
                # Mask ignored labels
                mask = full_labels != self._ignore_label
                full_labels = full_labels[mask]
                full_preds = self._full_preds[scan_id].cpu()[mask].numpy()
                self._full_confusion_matrix.count_predicted_batch(full_labels, full_preds)

            self._full_acc = 100 * self._full_confusion_matrix.get_overall_accuracy()
            self._full_macc = 100 * self._full_confusion_matrix.get_mean_class_accuracy()
            self._full_miou = 100 * self._full_confusion_matrix.get_average_intersection_union()

        # Save files to disk
        if make_submission and self._stage == "test":
            self._make_submission()

    def _make_submission(self):
        """ Writes one ``<scan_name>.txt`` per scan into ``dataset.path_to_submission``
        with the full resolution predictions remapped to the original class ids
        """
        orginal_class_ids = np.asarray(self._dataset.train_dataset.valid_class_idx)
        path_to_submission = self._dataset.path_to_submission
        for scan_id in self._full_preds:
            full_pred = self._full_preds[scan_id].cpu().numpy().astype(np.int8)
            full_pred = orginal_class_ids[full_pred]  # remap labels to original labels between 0 and 40
            scan_name = self._raw_datas[scan_id].scan_name
            path_file = osp.join(path_to_submission, "{}.txt".format(scan_name))
            # Was delimiter="/n" (typo). The delimiter only separates columns,
            # so it does not affect a 1-D array written one value per line.
            np.savetxt(path_file, full_pred, delimiter="\n", fmt="%d")

    def _vote(self, data, output):
        """ Populates scores for the points in data

        Parameters
        ----------
        data : Data
            should contain `pos` and `SaveOriginalPosId.KEY` keys
        output : torch.Tensor
            probablities out of the model, shape: [N,nb_classes]
        """
        id_scans = data.id_scan
        if id_scans.dim() > 1:
            id_scans = id_scans.squeeze()
        if self._conv_type == "DENSE":
            # One row of scores per sample of the batch
            batch_size = len(id_scans)
            output = output.view(batch_size, -1, output.shape[-1])

        # NOTE(review): the dict keys below are whatever iterating `id_scans`
        # yields. If those are tensor elements, dict membership is identity
        # based, so a scan met again in a later batch would be re-initialised
        # instead of accumulated - confirm against the caller.
        for idx_batch, id_scan in enumerate(id_scans):
            # First time we see this scan
            if id_scan not in self._raw_datas:
                raw_data = self._dataset.get_raw_data(self._stage, id_scan, remap_labels=True)
                self._raw_datas[id_scan] = raw_data
                self._vote_counts[id_scan] = torch.zeros(raw_data.pos.shape[0], dtype=torch.int)
                self._votes[id_scan] = torch.zeros((raw_data.pos.shape[0], self._num_classes), dtype=torch.float)
            else:
                raw_data = self._raw_datas[id_scan]

            # DENSE: index the batch dimension. Otherwise: boolean mask over the points
            batch_mask = idx_batch
            if self._conv_type != "DENSE":
                batch_mask = data.batch == idx_batch
            idx = data[SaveOriginalPosId.KEY][batch_mask]

            self._votes[id_scan][idx] += output[batch_mask].cpu()
            self._vote_counts[id_scan][idx] += 1

    def _predict_full_res(self):
        """ Predict full resolution results based on votes """
        for id_scan in self._votes:
            # Average the accumulated scores of the points that received at least one vote (in place)
            has_prediction = self._vote_counts[id_scan] > 0
            self._votes[id_scan][has_prediction] /= self._vote_counts[id_scan][has_prediction].unsqueeze(-1)

            # Upsample and predict
            full_pred = knn_interpolate(
                self._votes[id_scan][has_prediction],
                self._raw_datas[id_scan].pos[has_prediction],
                self._raw_datas[id_scan].pos,
                k=1,
            )
            self._full_preds[id_scan] = full_pred.argmax(-1)


In [3]:
# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/scratch-shared/fsun/dvata'

# Names of the dataset and model configurations, substituted into the overrides below
dataset_config = 'segmentation/multimodal/Feng/scannet-neucon-smallres-m2f-TC'   
models_config = 'segmentation/multimodal/Feng/mvfusion'    # model family
model_name = 'MVFusion_3D_small_6views'                       # specific model

# "key=value" override strings handed to hydra_read
overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
]

cfg = hydra_read(overrides)
OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
cfg.data.load_m2f_masks = True   # load Mask2Former predicted masks

# These fields are set before the dataset is built from cfg.data below
cfg.data.m2f_preds_dirname = 'ViT_masks'
cfg.data.n_views = 6 #cfg.models[model_name].backbone.transformer.n_views
print(cfg.data.n_views)

# Dataset instantiation
# (`start` rebinds the timer created in the first cell)
start = time()
dataset = ScannetDatasetMM(cfg.data)
# print(dataset)
print(f"Time = {time() - start:0.1f} sec.")

6
Load predicted 2D semantic segmentation labels from directory  ViT_masks
initialize train dataset
initialize val dataset
Time = 8.1 sec.


In [89]:
# One independent tracker per mask source compared below (refined / Mask2Former / ground truth).
# All three are built with identical arguments.
# NOTE(review): stage is 'train' although the evaluation loop reads the validation split;
# the stage name only ends up as the prefix of the metric keys and in has_labels() - confirm intent.
tracker_mvfusion, tracker_m2f, tracker_gt = (
    ScannetSegmentationTracker(
        dataset=dataset,
        stage='train',
        wandb_log=False,
        use_tensorboard=False,
        ignore_label=IGNORE_LABEL,
    )
    for _tracker_slot in range(3)
)


In [5]:
from torch_points3d.models.model_factory import instantiate_model

# Set your parameters
# checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/2022-12-07/12-07-34' # 3rd run
checkpoint_dir = '/home/fsun/DeepViewAgg/outputs/ViT_masks_3rd_run'

# Create the model
print(f"Creating model: {cfg.model_name}")
model = instantiate_model(cfg, dataset)

# Load the checkpoint and recover the model weights
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
# The weights stored under checkpoint['models']['latest'] are loaded non-strictly.
checkpoint = torch.load(f'{checkpoint_dir}/{model_name}.pt', map_location='cpu')
model.load_state_dict_with_same_shape(checkpoint['models']['latest'], strict=False)

# Prepare the model for inference
# (eval mode, then moved to the GPU; requires CUDA to be available)
model = model.eval().cuda()
print('Model loaded')

Creating model: MVFusion_3D_small_6views
task:  segmentation.multimodal
tested_model_name:  MVFusion_3D_small_6views
class_name:  MVFusionAPIModel
model_module:  torch_points3d.models.segmentation.multimodal.Feng.mvfusion_3d
name, cls of chosen model_cls:  MVFusionAPIModel <class 'torch_points3d.models.segmentation.multimodal.Feng.mvfusion_3d.MVFusionAPIModel'>
x feature dim:  {'FEAT': 3}
nc_in:  67
nc_in:  64
nc_in:  32
nc_in:  64
nc_in:  128
nc_in:  256
nc_in:  128
nc_in:  128
nc_in:  96
nc_in:  96
Model loaded


# Temporal Consistency using optical flow

In [6]:
from torch_points3d.utils.multimodal import lexargsort
from torch_points3d.core.multimodal.csr import CSRData
import scipy.ndimage

In [7]:
from mmflow.apis import init_model, inference_model
from mmflow.datasets import visualize_flow, write_flow
import mmcv
from mmflow.ops import Warp

# Specify the path to model config and checkpoint file
# config_file = '/home/fsun/DeepViewAgg/flow/configs/pwcnet_ft_4x1_300k_sintel_final_384x768.py'
# checkpoint_file = '/home/fsun/DeepViewAgg/flow/pretrained/pwcnet_ft_4x1_300k_sintel_final_384x768.pth'

# Active choice: FlowNet2 (the PWC-Net alternative is kept commented out above)
config_file = '/home/fsun/DeepViewAgg/flow/configs/flownet2_8x1_sfine_flyingthings3d_subset_384x768.py'
checkpoint_file = '/home/fsun/DeepViewAgg/flow/pretrained/flownet2_8x1_sfine_flyingthings3d_subset_384x768.pth'

# build the model from a config file and a checkpoint file
flow_model = init_model(config_file, checkpoint_file, device='cuda:0')

# Warping operator applied to label images further down. 'nearest' sampling is used,
# presumably so that class ids are never blended.
# NOTE(review): with padding_mode='zeros' and use_mask=True, pixels without a valid source
# are expected to come out as 0; the evaluation loop relies on this - confirm against mmflow.
warp = Warp(mode='nearest',
            padding_mode='zeros',
            align_corners=False,
            use_mask=True).cuda()

2023-01-12 14:22:18,093 - mmflow - INFO - Freeze the parameters in FlowNetCSS
2023-01-12 14:22:19,032 - mmflow - INFO - Freeze the parameters in FlowNetS


load checkpoint from local path: /home/fsun/DeepViewAgg/flow/pretrained/flownet2_8x1_sfine_flyingthings3d_subset_384x768.pth


In [98]:
# mm_data = dataset.val_dataset[0]
# mm_data

In [97]:
# tracker_m2f.reset(stage='train')
# tracker_mvfusion.reset(stage='train')


# for sample_idx in range(len(dataset.val_dataset)):
    
# #     mm_data = dataset.val_dataset[sample_idx]
    
#     # Create a MMBatch and run inference
#     batch = MMBatch.from_mm_data_list([mm_data])

#     with torch.no_grad():
#         model.set_input(batch, model.device)
#         model(batch)

        
#     #################### ORIG CODE IN TRAINER.PY _track_2d_results
        
#     # Recover the predicted labels for visualization
#     mm_data.data.pred = model.output.detach().cpu().argmax(1)
    
#     mappings = mm_data.modalities['image'][0].mappings
#     point_ids = torch.arange(
#                     mappings.num_groups, device=mappings.device).repeat_interleave(
#                     mappings.pointers[1:] - mappings.pointers[:-1])
#     image_ids = mappings.images.repeat_interleave(
#                     mappings.values[1].pointers[1:] - mappings.values[1].pointers[:-1])    
#     pixels_full = mappings.pixels

#     # Sort point and image ids based on image_id
#     idx_sort = lexargsort(image_ids, point_ids)
#     image_ids = image_ids[idx_sort]
#     point_ids = point_ids[idx_sort]
#     pixels_full = pixels_full[idx_sort].long()

#     # Get pointers for easy indexing
#     pointers = CSRData._sorted_indices_to_pointers(image_ids)

    
#     im_paths = mm_data.modalities['image'][0].gt_mask_path
#     scan_dir = os.sep.join(im_paths[0].split(os.sep)[:-2])
    
#     color_im_dir = osp.join(scan_dir, 'color_resized')
    
#     input_mask_name = mm_data.modalities['image'][0].m2f_pred_mask_path[0].split(os.sep)[-2]
    
#     # Dirty workaround for masks in different directory
#     if input_mask_name == 'ViT_masks':
#         scan_id = scan_dir.split(os.sep)[-1]
#         mask_im_dir = osp.join("/home/fsun/data/scannet/scans", scan_id, input_mask_name)
#         refined_mask_im_dir = osp.join(scan_dir, input_mask_name + '_refined')
#     else:
#         mask_im_dir = osp.join(scan_dir, input_mask_name)
#         refined_mask_im_dir = osp.join(scan_dir, input_mask_name + '_refined')
#     print(refined_mask_im_dir)
#     os.makedirs(refined_mask_im_dir, exist_ok=True)
    
#     im_names = [p.split(os.sep)[-1] for p in im_paths]
#     # Image indices of sorted list
#     im_sort_indices = sorted(range(len(im_names)), key=lambda k: int(os.path.splitext(os.path.basename(im_names[k]))[0]))
#     # Sorted image names
# #     im_names = sorted(im_names, key=lambda i: int(os.path.splitext(os.path.basename(i))[0]))


#     # Loop over all N views
# #     for i, x in enumerate(mm_data.modalities['image'][0]):

#     # Skip last image since we grab pairs
#     for i in range(len(im_sort_indices[:-1])):
        
#         print(im_sort_indices[i], im_sort_indices[i+1])

#         im1_name, im2_name = im_names[im_sort_indices[i]], im_names[im_sort_indices[i+1]]
#         print(im1_name, im2_name)
#         if ( int(im2_name.split(".")[0]) - int(im1_name.split(".")[0]) ) > 50:
#             continue
        
#         x = mm_data.modalities['image'][0][i]

#         # Grab the 3D points corresponding to ith view
#         start, end = pointers[im_sort_indices[i]], pointers[im_sort_indices[i] + 1]    
#         points = point_ids[start:end]
#         pixels = pixels_full[start:end]
#         # Image (x, y) pixel index
#         w, h = pixels[:, 0], pixels[:, 1]

#         # Grab set of points visible in current view
#         mm_data_of_view = mm_data[points]

#         im_ref_w, im_ref_h = x.ref_size

#         # Get nearest neighbor interpolated projection image filled with 3D labels
#         pred_mask_2d = -1 * torch.ones((im_ref_h, im_ref_w), dtype=torch.long, device=mm_data_of_view.device)    
#         pred_mask_2d[h, w] = mm_data_of_view.data.pred.squeeze()

#         nearest_neighbor = scipy.ndimage.morphology.distance_transform_edt(
#             pred_mask_2d==-1, return_distances=False, return_indices=True)    
#         pred_mask_2d = pred_mask_2d[nearest_neighbor].numpy().astype(np.uint8)
#         pred_mask_2d = Image.fromarray(pred_mask_2d, 'L') 
        
#         pred_mask_2d = pred_mask_2d.resize((640, 480), resample=0)
        
# #         # SAVE REFINED MASK IN GIVEN DIR
# #         print(osp.join(scan_dir, input_mask_name + '_refined', im1_name))
# #         pred_mask_2d.save(osp.join(scan_dir, input_mask_name + '_refined', im1_name))

        
#         pred_mask_2d = np.asarray(pred_mask_2d)
        
    
            
#         im1_p = osp.join(color_im_dir, im1_name)
#         im2_p = osp.join(color_im_dir, im2_name)

#         # compute flow map from im1 to im2
#         result = inference_model(flow_model, im2_p, im1_p)
#         flow_map = torch.tensor(result).permute(2, 0, 1).unsqueeze(0)
    
#         # load 2D input mask 
#         seg_im_p1 = osp.join(mask_im_dir, im1_name)
#         seg_im_p2 = osp.join(mask_im_dir, im2_name)
#         seg_im1 = np.asarray(Image.open(seg_im_p1)) 
#         seg_im2 = np.asarray(Image.open(seg_im_p2)).astype(np.int) - 1   # Adjust labels


#         # warping im1 to im2
#         seg_im1_semantic = torch.tensor(seg_im1).unsqueeze(0).unsqueeze(0).float()
#         seg_im_warped = warp(seg_im1_semantic, flow_map).permute(0, 2, 3, 1)[0].squeeze() - 1    # Adjust labels

#         # take im2 as 'pseudo gt' for temporal consistency. Thus, invalidly warped pixels should be ignored
#         # by setting the corresponding gt label to the IGNORE_LABEL
#         seg_im2[seg_im_warped == -1] = -1
#         tracker_m2f.track(pred_labels=seg_im_warped.long(), gt_labels=seg_im2, model=None)
        
#         # warping refined im1 to im2
# #         pred_mask_2d
        
# #         tracker_mvfusion.track(pred_labels=mm_data.data.pred, gt_labels=mm_data.data.y, model=None)


# #         # Visualizations of warped images
# #         seg_im1_rgb = np.array(CLASS_COLORS)[seg_im_warped.long()]
# #         seg_im1_rgb = Image.fromarray(seg_im1_rgb.astype('uint8'))

# #         seg_im2_rgb = np.array(CLASS_COLORS)[seg_im2]
# #         seg_im2_rgb = Image.fromarray(seg_im2_rgb.astype('uint8'))


# #         plt.imshow(seg_im1_rgb)
# #         plt.show()
# #         plt.imshow(seg_im2_rgb)
# #         plt.show()
        
#     break
        
    
#     ###############################################################

# #     break
    
# #     # Uncomment
# #     for i in range(len(im_ids)-1):

# #         im1_p = osp.join(scene_dir, 'color_resized', im_ids[i])
# #         im2_p = osp.join(scene_dir, 'color_resized', im_ids[i+1])

# #         # compute flow map from im1 to im2
# #         result = inference_model(model, im2_p, im1_p)
# #         flow_map = torch.tensor(result).permute(2, 0, 1).unsqueeze(0)

# #     #     im1 = to_tensor(Image.open(im1_p)).unsqueeze(0)
# #     #     im2 = to_tensor(Image.open(im1_p)).unsqueeze(0)

# #     #     # warps im1 to im2
# #     #     im_warped = warp(im1, flow_map).permute(0, 2, 3, 1)[0]

# #     #     plt.imshow(Image.open(im1_p))
# #     #     plt.show()
# #     #     plt.imshow(Image.open(im2_p))
# #     #     plt.show()
# #     #     plt.imshow(im_warped)
# #     #     plt.show()


# #         # img loading
# #         seg_im_p1 = osp.join(mask_im_dir, im_ids[i])
# #         seg_im_p2 = osp.join(mask_im_dir, im_ids[i+1])
# #         seg_im1 = np.asarray(Image.open(seg_im_p1)) 
# #         seg_im2 = np.asarray(Image.open(seg_im_p2)).astype(np.int) - 1   # Adjust labels


# #         # warping im1 to im2
# #         seg_im1_semantic = torch.tensor(seg_im1).unsqueeze(0).unsqueeze(0).float()
# #         seg_im_warped = warp(seg_im1_semantic, flow_map).permute(0, 2, 3, 1)[0].squeeze() - 1    # Adjust labels

# #         # take im2 as 'pseudo gt' for temporal consistency. Thus, invalidly warped pixels should be ignored
# #         # by setting the corresponding gt label to the IGNORE_LABEL
# #         seg_im2[seg_im_warped == -1] = -1
# #         tracker_m2f.track(pred_labels=seg_im_warped.long(), gt_labels=seg_im2, model=None)
# #     #     tracker_mvfusion.track(pred_labels=mm_data.data.pred, gt_labels=mm_data.data.y, model=None)


# #         seg_im1_rgb = np.array(CLASS_COLORS)[seg_im_warped.long()]
# #         seg_im1_rgb = Image.fromarray(seg_im1_rgb.astype('uint8'))

# #         seg_im2_rgb = np.array(CLASS_COLORS)[seg_im2]
# #         seg_im2_rgb = Image.fromarray(seg_im2_rgb.astype('uint8'))


# #     #     plt.imshow(seg_im1_rgb)
# #     #     plt.show()
# #     #     plt.imshow(seg_im2_rgb)
# #     #     plt.show()


# #     #     if i == 3:
# #     #         break

# tracker_m2f.get_metrics()

In [103]:
import os.path as osp
import os

###########
# Mask sources. Swap the three active assignments with the commented ones to evaluate the ViT masks.
# mask_foldername = 'ViT_masks'
# refined_mask_foldername = 'ViT_masks_refined'
# mask_scans_dir = '/home/fsun/data/scannet/scans'
###########
mask_foldername = 'm2f_masks'
refined_mask_foldername = 'm2f_masks_refined'
mask_scans_dir = '/scratch-shared/fsun/data/scannet/scans'
###########

scans_dir = "/scratch-shared/fsun/data/scannet/scans"

# Consecutive frames further apart than this many frame ids are not compared
MAX_FRAME_GAP = 50
# (width, height) the "next" masks are resized to where resizing is requested
MASK_SIZE = (640, 480)


def _frame_id(file_name):
    """Integer frame id encoded in a mask file name such as '120.png'."""
    return int(os.path.splitext(os.path.basename(file_name))[0])


def _read_label_image(path, resize_to=None):
    """Load a label image as a numpy array, optionally resized with nearest-neighbour sampling.

    The file is closed before returning, so looping over many frames does not
    accumulate open file handles.
    """
    with Image.open(path) as image:
        if resize_to is not None:
            # resample=0 is nearest neighbour: class ids must not be interpolated
            return np.asarray(image.resize(resize_to, resample=0))
        return np.asarray(image)


def _track_warped_pair(tracker, mask_dir, cur_name, next_name, flow_map, one_based, resize_next):
    """Warp the current frame's labels onto the next frame and score them against the next frame's labels.

    Arguments:
        tracker: tracker receiving the (warped prediction, pseudo ground truth) pair
        mask_dir: directory holding both label images
        cur_name, next_name: file names of the current and next frame
        flow_map: flow tensor handed to `warp` together with the current labels
        one_based: True when the stored labels start at 1 (0 is then the value that
            becomes the ignore label); False when they start at 0
        resize_next: resize the next frame's labels to MASK_SIZE before use
    """
    cur_labels = _read_label_image(osp.join(mask_dir, cur_name))
    if not one_based:
        # Shift to 1-based so that 0 is free to mark invalidly warped pixels
        cur_labels = cur_labels + 1

    next_labels = _read_label_image(
        osp.join(mask_dir, next_name), resize_to=MASK_SIZE if resize_next else None)
    # np.int was removed in NumPy 1.24; int64 is what it resolved to on Linux.
    # astype also returns a writable copy, which is needed below.
    next_labels = next_labels.astype(np.int64)
    if one_based:
        next_labels = next_labels - 1

    # Warp the current labels, then shift back to 0-based: invalid pixels (0) become -1
    cur_labels_t = torch.tensor(cur_labels).unsqueeze(0).unsqueeze(0).float()
    warped = warp(cur_labels_t, flow_map).permute(0, 2, 3, 1)[0].squeeze() - 1

    # take the next frame as 'pseudo gt' for temporal consistency. Thus, invalidly warped pixels
    # should be ignored by setting the corresponding gt label to the IGNORE_LABEL
    next_labels[warped == -1] = -1
    tracker.track(pred_labels=warped.long(), gt_labels=next_labels, model=None)


for tracker in (tracker_m2f, tracker_mvfusion, tracker_gt):
    tracker.reset(stage='train')

with open("/scratch-shared/fsun/data/scannet/splits/scannetv2_val.txt", 'r') as f:
    # First whitespace-separated token of every non-blank line
    scan_ids = sorted(line.split()[0] for line in f if line.strip())

for scan_id in scan_ids:
    print(scan_id)
    scan_dir = osp.join(scans_dir, scan_id)

    mask_dir = osp.join(mask_scans_dir, scan_id, mask_foldername)
    refined_mask_dir = osp.join(scan_dir, refined_mask_foldername)
    gt_mask_dir = osp.join(scan_dir, 'label-filt-scannet20')
    color_im_dir = osp.join(scan_dir, 'color_resized')

    # The refined masks define which frames are evaluated, in increasing frame id order
    refined_mask_names = sorted(os.listdir(refined_mask_dir), key=_frame_id)

    # Loop over all pairs of consecutive images
    for cur_frame_name, next_frame_name in zip(refined_mask_names[:-1], refined_mask_names[1:]):
        if _frame_id(next_frame_name) - _frame_id(cur_frame_name) > MAX_FRAME_GAP:
            continue

        cur_color_p = osp.join(color_im_dir, cur_frame_name)
        next_color_p = osp.join(color_im_dir, next_frame_name)

        # NOTE(review): the images are passed as (next, current); presumably this yields the
        # flow needed to pull the current frame's labels into the next frame - confirm against mmflow.
        result = inference_model(flow_model, next_color_p, cur_color_p)
        flow_map = torch.tensor(result).permute(2, 0, 1).unsqueeze(0)

        # 2D input masks (stored 1-based, used at their native size)
        _track_warped_pair(tracker_m2f, mask_dir, cur_frame_name, next_frame_name, flow_map,
                           one_based=True, resize_next=False)
        # 2D refined masks (stored 0-based)
        _track_warped_pair(tracker_mvfusion, refined_mask_dir, cur_frame_name, next_frame_name, flow_map,
                           one_based=False, resize_next=True)
        # 2D GT masks (stored 1-based)
        _track_warped_pair(tracker_gt, gt_mask_dir, cur_frame_name, next_frame_name, flow_map,
                           one_based=True, resize_next=True)

        # To inspect a pair visually, colourise the label maps with
        # np.array(CLASS_COLORS)[labels] and show them with plt.imshow.

print("Temporal Consistency Scores measured over all validation scenes, using all image views")
print("Mask2Former and refined scores and gt scores:")
print("m2f ", tracker_m2f.get_metrics())
print("mvfusion ", tracker_mvfusion.get_metrics())
print("gt ", tracker_gt.get_metrics())

scene0011_00
Temporal Consistency Scores measured over all validation scenes, using all image views
Mask2Former and refined scores and gt scores:
m2f  {'train_acc': 91.79619787115168, 'train_macc': 73.6956810795494, 'train_miou': 47.85911383305274}
mvfusion  {'train_acc': 90.45814649800153, 'train_macc': 74.24482720161333, 'train_miou': 47.142153319771786}
gt  {'train_acc': 92.0258673582081, 'train_macc': 86.74310367516654, 'train_miou': 70.76243953631104}


In [105]:
# Accuracy, mean class accuracy and mIoU (in percent) accumulated by each tracker since its last reset.
# The values reflect whatever was tracked by the most recently executed evaluation cell.
print("m2f ", tracker_m2f.get_metrics())
print("mvfusion ", tracker_mvfusion.get_metrics())
print("gt ", tracker_gt.get_metrics())

m2f  {'train_acc': 90.5461629810143, 'train_macc': 87.48580362561194, 'train_miou': 78.02501976060385}
mvfusion  {'train_acc': 95.68884017721575, 'train_macc': 95.12010618343393, 'train_miou': 90.59642914159319}
gt  {'train_acc': 97.54413197503662, 'train_macc': 97.44599314771565, 'train_miou': 94.79914213901814}


In [87]:
# Per-class IoU of the Mask2Former tracker (already formatted as strings), joined with " & "
print(" & ".join(list(tracker_m2f._miou_per_class.values())))


'90.48 & 92.49 & 74.11 & 89.10 & 84.05 & 85.09 & 84.04 & 80.20 & 80.28 & 86.60 & 68.65 & 75.91 & 76.75 & 82.47 & 75.49 & 81.56 & 91.85 & 82.37 & 85.53 & 71.37'

In [88]:
# Per-class IoU of the refined-mask (mvfusion) tracker (already formatted as strings), joined with " & "
print(" & ".join(list(tracker_mvfusion._miou_per_class.values())))


'94.06 & 92.56 & 91.99 & 95.58 & 88.37 & 95.12 & 91.12 & 90.76 & 91.35 & 94.83 & 91.32 & 89.51 & 90.82 & 92.90 & 92.93 & 88.93 & 89.07 & 87.16 & 90.50 & 90.74'

In [None]:
# NOTE(review): this repeats the mvfusion per-class printout of the cell before it;
# possibly tracker_gt was intended here - confirm.
print(" & ".join(list(tracker_mvfusion._miou_per_class.values())))
