## Instance Segmentation Using Habitat-Sim and Mask-R-CNN

author: Michael Piseno (mpiseno@gatech.edu)

This notebook demonstrates how to set up an efficient data pipeline for instance segmentation using PyTorch, Mask-R-CNN, and Habitat-Sim as a data generator.

Other resources:
* [Mask-R-CNN paper](https://arxiv.org/pdf/1703.06870.pdf)
* [PyTorch instance segmentation tutorial](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)

In [None]:
%load_ext autoreload
%autoreload 2

import os
import time
import math
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torch.utils.data import Dataset

from habitat_sim.utils.data.dataextractor import ImageExtractor

### Data preparation and preprocessing

Below we will define the data extraction and preprocessing steps. Habitat-Sim's image data extraction API will be used to gather images from within the simulator for use inside a PyTorch Dataset subclass, which is subsequently fed into a PyTorch dataloader. Also, we will define a function to filter our semantic mask output from the extractor to hide instances that we don't want. For example, if the semantic output has instances of wall, chair, table, pillow, and background, but we only want to do instance segmentation on chairs and tables, we will simply set the semantic mask pixel values for wall and pillow to be the same as background.

In [None]:
# Close any extractor left over from a previous run of this cell: only one
# ImageExtractor instance can exist at a time.
try:
    extractor.close()
except NameError:
    # First run of the notebook: no extractor has been created yet.
    pass

# ImageExtractor accepts either a single scene file or a directory of scenes.
# Single-file alternatives:
# scene_dir = "/private/home/mpiseno/Documents/Habitat/data/scene1/17DRP5sb8fy.glb" # mp3d
# scene_dir = "/private/home/mpiseno/Documents/Habitat/sorted_faces/18_scenes/apartment_1/mesh.ply" # Replica
scene_dir = "/private/home/mpiseno/Documents/Habitat/data/" # Replace with your scene directory
extractor = ImageExtractor(scene_dir, output=["rgba", "depth", "semantic"],
                           img_size=(1080, 1080), shuffle=False)

In [None]:
# Build the ordered list of classes to train on and a filter that maps the
# semantic output of every other class to background.
from examples.instance_segmentation.common import create_mask_filter, area_filter

# A hand-picked subset can be used instead, e.g.:
# labels = ['bed', 'cushion', 'table', 'chair', 'sofa', 'tv_monitor',
#           'floor', 'door', 'cabinet', 'counter', 'stool', 'blinds']

# 'background' is forced to index 0 (class id 0) so that mask_filter works;
# placeholder / catch-all class names are dropped.
labels = ['background'] + [
    class_name
    for class_name in extractor.get_semantic_class_names()
    if class_name not in {'background', 'void', '', 'objects'}
]

mask_filter = create_mask_filter(labels, extractor)

In [None]:
def collate_fn(batch):
    """Regroup a list of ``(img, target)`` samples into ``(imgs, targets)``.

    Samples are transposed field-by-field rather than stacked, because each
    target carries a different number of instances.
    """
    return tuple(field for field in zip(*batch))
    
class HabitatDataset(Dataset):
    """Wrap a Habitat-Sim ``ImageExtractor`` as a Mask R-CNN dataset.

    Each item is an ``(img, target)`` pair. ``target`` holds ``boxes``,
    ``labels`` and ``masks`` for the model, plus ``image_id``, ``area`` and
    ``iscrowd`` for evaluation.

    Args:
        extractor: the ``ImageExtractor``; its samples must provide "rgba"
            and "semantic" outputs.
        labels_we_care_about: ordered class names; a name's position in this
            list is its class id.
        transform: optional callable applied to the RGB image.

    Note:
        ``__getitem__`` uses the module-level ``mask_filter`` and
        ``area_filter`` helpers defined earlier in the notebook.
    """

    def __init__(self, extractor, labels_we_care_about, transform=None):
        self.extractor = extractor
        self.transform = transform
        # This is the extractor's own dict (not a copy), so the background
        # entry added below is also visible through the extractor.
        self.instance_id_to_name = extractor.instance_id_to_name
        if 0 not in self.instance_id_to_name:
            self.instance_id_to_name[0] = 'background'

        # Class name <-> class id, where the id is the name's list position.
        self.name_to_sem_id = {
            name: sem_id for sem_id, name in enumerate(labels_we_care_about)
        }
        self.sem_id_to_name = {
            sem_id: name for name, sem_id in self.name_to_sem_id.items()
        }

    def __len__(self):
        return len(self.extractor)

    @staticmethod
    def _bounding_box(instance_mask, H, W):
        """Return ``(xmin, ymin, xmax, ymax)`` enclosing the True pixels of a non-empty mask.

        A box that is only one pixel wide (or tall) is widened by one pixel on
        each side, clamped to the image, so that no box has zero area.
        """
        rows, cols = np.where(instance_mask)
        xmin, xmax = cols.min(), cols.max()
        ymin, ymax = rows.min(), rows.max()
        if xmin == xmax:
            xmin = max(0, xmin - 1)
            xmax = min(W, xmax + 1)
        if ymin == ymax:
            ymin = max(0, ymin - 1)
            ymax = min(H, ymax + 1)
        return (xmin, ymin, xmax, ymax)

    def __getitem__(self, idx):
        """Return ``(img, target)`` for sample ``idx`` (``img`` transformed if a transform was given)."""
        sample = self.extractor[idx]
        img = sample["rgba"][:, :, :3]  # keep RGB, drop alpha
        mask = mask_filter(sample["semantic"])
        H, W = mask.shape

        masks, labels, boxes, areas = [], [], [], []
        # Background (id 0) is intentionally not skipped.
        # There are much more efficient ways to build these targets (caching,
        # preprocessing), but efficiency is not the focus of this example.
        for instance_id in np.unique(mask):
            cur_mask = mask == instance_id
            box = self._bounding_box(cur_mask, H, W)
            if not area_filter(cur_mask, box, H, W):
                continue
            xmin, ymin, xmax, ymax = box
            boxes.append(list(box))
            masks.append(cur_mask)
            name = self.instance_id_to_name[instance_id]
            labels.append(self.name_to_sem_id[name])
            areas.append((ymax - ymin) * (xmax - xmin))

        # Every per-instance field must be sized by the instances that
        # survived area_filter, not by all unique ids in the mask.
        num_kept = len(boxes)
        if num_kept:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            # Stack first: building a tensor from a list of arrays is slow.
            masks = torch.as_tensor(np.stack(masks), dtype=torch.uint8)
        else:
            # Keep the (N, 4) / (N, H, W) shapes even when nothing is kept.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            masks = torch.zeros((0, H, W), dtype=torch.uint8)

        target = {
            "boxes": boxes,
            "labels": torch.as_tensor(labels, dtype=torch.int64),
            "masks": masks,
            "image_id": torch.tensor([idx]),
            "area": torch.as_tensor(areas, dtype=torch.float32),
            # Suppose no instance is a crowd region.
            "iscrowd": torch.zeros((num_kept,), dtype=torch.int64),
        }

        if self.transform:
            img = self.transform(img)

        return img, target

First we instantiate an ImageExtractor from Habitat-Sim. This requires that we provide either the filepath to a scene or the filepath to a directory that contains several scene files. Optionally, we can specify the type of output we would like from the extractor; the default is just RGBA images. For details on the ImageExtractor, refer to the "Image-Data_Extraction-API" notebook.

Example
```python
scene_filepath = "./data/scene1/skokloster-castle.glb"
extractor = ImageExtractor(scene_filepath, output=["rgb", "semantic"])
```

We then create a custom class that subclasses PyTorch's `Dataset` and overrides the `__len__` and `__getitem__` methods. Mask-R-CNN requires that we provide the image, bounding boxes, per-instance masks, and class labels for each example, so we have implemented that functionality in the `__getitem__` method. The `area` and `iscrowd` keys are required for the evaluation metrics we use in this notebook.

#### Setting up the datasets and dataloaders

In [None]:
# The only preprocessing is ToTensor, which turns the HxWx3 image array into a
# CxHxW tensor.
transform = T.Compose([T.ToTensor()])

# Both datasets wrap the same extractor. Which split they read is chosen by
# calling extractor.set_mode('train') / extractor.set_mode('test') beforehand.
dataset_train = HabitatDataset(extractor, labels, transform=transform)
dataset_test = HabitatDataset(extractor, labels, transform=transform)

dataloader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_fn,
)
dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_fn,
)

#### Visualize the data

In [None]:
import matplotlib.patches as patches


def show_data(data_idx, show_masks=False):
    """Plot a training sample with its ground-truth boxes, and optionally its masks.

    Args:
        data_idx: index into ``dataset_train``.
        show_masks: if True, also draw every instance mask in a 4-column grid,
            each titled with its class name.
    """
    img, target = dataset_train[data_idx]
    img = img.permute(1, 2, 0).numpy()  # CxHxW tensor -> HxWxC array for imshow
    masks = target['masks'].numpy()
    boxes = target['boxes'].numpy()
    labels = target['labels']

    fig, ax = plt.subplots(1)
    ax.imshow(img)
    colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k', 'w']
    for i, (xmin, ymin, xmax, ymax) in enumerate(boxes):
        rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, linewidth=2,
                                 edgecolor=colors[i % len(colors)], facecolor='none')
        ax.add_patch(rect)

    if show_masks:
        plt.show()
        fig = plt.figure(figsize=(8, 8))
        columns = 4
        rows = math.ceil(len(masks) / columns)
        # Subplot indices are 1-based.
        for i, (mask, label) in enumerate(zip(masks, labels), start=1):
            ax = fig.add_subplot(rows, columns, i)
            ax.title.set_text(dataset_train.sem_id_to_name[int(label)])
            plt.xticks([])
            plt.yticks([])
            plt.imshow(mask)


show_data(0, show_masks=True)


### Model and Training Setup

In [None]:
from torch.utils.tensorboard import SummaryWriter

# Credit: https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
def build_model(num_classes, hidden_layer=256):
    """Build a COCO-pretrained Mask R-CNN with new heads for ``num_classes`` classes.

    Args:
        num_classes: number of output classes, including background.
        hidden_layer: width of the hidden layer of the replacement mask
            predictor (256, as in the torchvision tutorial, by default).

    Returns:
        A torchvision ``maskrcnn_resnet50_fpn`` model whose box and mask
        predictors are freshly initialized.
    """
    # Load an instance segmentation model pre-trained on COCO.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # Replace the pre-trained box classifier with one sized for our classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask predictor the same way, keeping its input channels.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model


num_epochs = 160000
# `labels` already starts with 'background', and HabitatDataset assigns class ids
# 0..len(labels)-1, so len(labels) is the full class count (background included).
# NOTE: checkpoints saved with the previous value (len(labels) + 1) will likely
# fail to load, since the predictor heads have a different size.
num_classes = len(labels)
model_state_path = "examples/instance_segmentation/runs/maskrcnn-example-state.pt"
load_state = False

model = build_model(num_classes)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# Only optimize parameters that are not frozen.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.002, momentum=0.9, weight_decay=0.0001)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=4)
writer = SummaryWriter("examples/instance_segmentation/runs/")


### Training and Evaluation

In [None]:
from examples.instance_segmentation.engine import train_one_epoch, load_model_state, save_model_state, evaluate

epoch = 1
if load_state:
    # Resume from the checkpoint; the return value is used as the starting epoch.
    # The lr override dict is passed inline so the global `params` list is kept.
    epoch = load_model_state(model, optimizer, model_state_path, {'lr': 0.002})

# Reset the learning rate. SGD reads it from each param group; assigning
# `optimizer.lr` would only create an unused attribute.
for param_group in optimizer.param_groups:
    param_group['lr'] = 0.002

while epoch < num_epochs:
    # We have to explicitly set the extractor mode because there can only be one instance of an extractor at a time,
    # so the dataset_train and dataset_test must share the same extractor
    extractor.set_mode('train')
    train_one_epoch(model, optimizer, dataloader_train, device, epoch, print_freq=10,
                    writer=writer, grad_clip=0, lr_scheduler=lr_scheduler)

    # Every 50 epochs: checkpoint, then evaluate on the test split.
    if epoch % 50 == 0:
        save_model_state(model, optimizer, epoch, model_state_path)
        extractor.set_mode('test')
        evaluate(model, dataloader_test, device=device)

    epoch += 1

In [None]:
# ==== Testing ====
from examples.instance_segmentation.engine import evaluate

# dataset_test shares its extractor with dataset_train, so switch the
# extractor to its test split before evaluating.
extractor.set_mode('test')
evaluator = evaluate(
    model,
    dataloader_test,
    device=device,
)
# Inspect which attributes/results the returned evaluator exposes.
print(dir(evaluator))

In [None]:
from examples.instance_segmentation.common import InstanceVisualizer

# Run the model on images from dataloader_train and plot the returned
# visualizations (max_num_outputs presumably caps how many are produced).
visualizer = InstanceVisualizer(model, dataloader_train, device=device, num_classes=num_classes)
visuals = visualizer.visualize_instance_segmentation_output(max_num_outputs=12)
for vis in visuals.values():
    # TODO: add a legend mapping mask values to class names via
    # dataset_train.sem_id_to_name (avoid naming it `patches`, which would
    # shadow the matplotlib.patches import).
    plt.imshow(vis)
    plt.show()

In [None]:
# Show every training image, batch by batch. Looping over each batch instead of
# unpacking exactly two images also handles a final, smaller batch.
for images, _targets in dataloader_train:
    for image in images:
        plt.imshow(image.permute(1, 2, 0).numpy())  # CxHxW -> HxWxC
        plt.show()