## Instance Segmentation Using Habitat-Sim and Mask-R-CNN

author: Michael Piseno (mpiseno@fb.com)

This notebook demonstrates how to set up an efficient data pipeline for instance segmentation using PyTorch, Mask-R-CNN, and Habitat-Sim as a data generator.

Other resources:
* [Mask-R-CNN paper](https://arxiv.org/pdf/1703.06870.pdf)
* [PyTorch instance segmentation tutorial](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)

In [1]:
%load_ext autoreload
%autoreload 2

import os
import time
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torch.utils.data import Dataset

from habitat_sim.utils.data.dataextractor import ImageExtractor

### Data preparation and preprocessing

Below we define the data extraction and preprocessing steps. Habitat-Sim's image extraction API gathers images from within the simulator for use inside a PyTorch `Dataset` subclass, which is then fed into a PyTorch `DataLoader`.
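Because every image can contain a different number of instances, the targets cannot simply be stacked into batch tensors. The `collate_fn` in the next cell therefore keeps each batch as a tuple of images and a tuple of target dicts. A minimal illustration of that transpose follows; strings stand in for image tensors so it runs without PyTorch or Habitat-Sim, and the sample data is made up.

```python
def collate_fn(batch):
    """Transpose [(img0, tgt0), (img1, tgt1), ...] into
    ((img0, img1, ...), (tgt0, tgt1, ...))."""
    return tuple(zip(*batch))


# Two hypothetical samples whose targets hold different numbers of instances.
# With real tensors, PyTorch's default collation would try to stack each field
# across the batch and fail because the instance counts differ.
batch = [
    ("image_a", {"boxes": [[0, 0, 4, 4]]}),
    ("image_b", {"boxes": [[1, 1, 2, 2], [3, 3, 5, 5]]}),
]
images, targets = collate_fn(batch)
print(images)                    # one entry per sample
print(len(targets[1]["boxes"]))  # per-sample instance count is preserved
```

This matches the input format torchvision's Mask R-CNN expects in training mode: a list of image tensors plus a list of target dicts, one per image.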

In [25]:
def collate_fn(batch):
    return tuple(zip(*batch))
    
    
class HabitatDataset(Dataset):
    def __init__(self, extractor, transform=None):
        self.extractor = extractor
        self.transform = transform
        
    def __len__(self):
        return len(self.extractor.poses)
    
    def __getitem__(self, idx):
        sample = self.extractor[idx]
        img, mask = sample["rgba"][:, :, :3], sample["semantic"]
        H, W = mask.shape
        
        instance_ids = np.unique(mask)
        # Every ID, including the background, is treated as an instance. To skip
        # the background (assuming its ID sorts first), use:
        # instance_ids = instance_ids[1:]

        # Get the bounding box and binary mask for each instance ID
        masks = []
        boxes = []
        num_instances = len(instance_ids)

        # There are much more efficient ways to build these targets (e.g. caching
        # or vectorized preprocessing), but efficiency is not the focus of this example
        for instance_id in instance_ids:
            cur_mask = mask == instance_id
            ys, xs = np.where(cur_mask)
            xmin, xmax = xs.min(), xs.max()
            ymin, ymax = ys.min(), ys.max()

            # Avoid zero-area boxes, which Mask R-CNN rejects during training
            if xmin == xmax:
                xmin = max(0, xmin - 1)
                xmax = min(W, xmax + 1)
            if ymin == ymax:
                ymin = max(0, ymin - 1)
                ymax = min(H, ymax + 1)

            boxes.append([xmin, ymin, xmax, ymax])
            masks.append(cur_mask)

        # convert everything into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # there is only one class, so every instance gets label 1 (0 is background)
        labels = torch.ones((num_instances,), dtype=torch.int64)
        # stack into one (N, H, W) array first; building a tensor from a list of arrays is slow
        masks = torch.as_tensor(np.stack(masks), dtype=torch.uint8)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # treat no instance as a crowd region
        iscrowd = torch.zeros((num_instances,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd
        
        if self.transform:
            img = self.transform(img)
        
        return img, target
    

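# Close any extractor left over from a previous run of this cell
try:
    extractor.close()
except NameError:
    # First run: no extractor has been created yet
    pass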
    
scene_filepath = "/private/home/mpiseno/Documents/Habitat/data/scene1/17DRP5sb8fy.glb" # Replace with your filepath
extractor = ImageExtractor(scene_filepath, output=["rgba", "semantic"])


I0221 17:04:29.392554 40787 simulator.py:142] Loaded navmesh /private/home/mpiseno/Documents/Habitat/data/scene1/17DRP5sb8fy.navmesh


First we instantiate an `ImageExtractor` from Habitat-Sim. This requires the filepath to a scene from which to extract images. Optionally, we can specify which types of output the extractor should return; the default is RGB images only.

```python
scene_filepath = "../../data/scene_datasets/habitat-test-scenes/skokloster-castle.glb"
extractor = ImageExtractor(scene_filepath, output=["rgba", "semantic"])
```

We then subclass PyTorch's `Dataset` and override its `__len__` and `__getitem__` methods. For each example, Mask-R-CNN requires the image along with a bounding box, a binary mask, and a class label for every instance, so `__getitem__` derives these from the semantic observation.
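The per-instance loop in `__getitem__` is deliberately simple. As a sketch, assuming `mask` is an H x W integer array of instance IDs like `sample["semantic"]`, the same masks, boxes, and areas can be computed with NumPy broadcasting instead. `instances_from_semantic` is a hypothetical helper, not part of Habitat-Sim or torchvision; like the notebook, it keeps every ID (including any background ID) and pads one-pixel-wide boxes.

```python
import numpy as np


def instances_from_semantic(mask):
    """Hypothetical helper: per-instance binary masks, (xmin, ymin, xmax, ymax)
    boxes, and areas from an H x W array of instance IDs."""
    H, W = mask.shape
    ids = np.unique(mask)                              # sorted unique IDs
    masks = mask[None, :, :] == ids[:, None, None]     # (N, H, W) boolean

    # Which rows / columns contain at least one pixel of each instance.
    rows = masks.any(axis=2)                           # (N, H)
    cols = masks.any(axis=1)                           # (N, W)
    ymin = rows.argmax(axis=1)                         # first occupied row
    ymax = H - 1 - rows[:, ::-1].argmax(axis=1)        # last occupied row
    xmin = cols.argmax(axis=1)
    xmax = W - 1 - cols[:, ::-1].argmax(axis=1)

    # Pad degenerate boxes so every box has positive width and height,
    # using the same clamping as HabitatDataset.__getitem__.
    flat_x = xmin == xmax
    xmin = np.where(flat_x, np.maximum(xmin - 1, 0), xmin)
    xmax = np.where(flat_x, np.minimum(xmax + 1, W), xmax)
    flat_y = ymin == ymax
    ymin = np.where(flat_y, np.maximum(ymin - 1, 0), ymin)
    ymax = np.where(flat_y, np.minimum(ymax + 1, H), ymax)

    boxes = np.stack([xmin, ymin, xmax, ymax], axis=1).astype(np.float32)
    areas = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
    return ids, masks.astype(np.uint8), boxes, areas


semantic = np.array([[0, 0, 1],
                     [0, 2, 2],
                     [0, 2, 2]])
ids, masks, boxes, areas = instances_from_semantic(semantic)
print(ids, boxes, areas, sep="\n")
```

The outputs can be wrapped with `torch.as_tensor` to fill the `"masks"`, `"boxes"`, and `"area"` entries of the target dict. Note that the `(N, H, W)` mask array uses as much memory as the list built in the loop.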

In [26]:
# Specify which transforms to apply to the data in preprocessing
transform = T.Compose([
    T.ToTensor()
])

dataset = HabitatDataset(extractor, transform=transform)

data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True,
                                          collate_fn=collate_fn)

### Model and Training Setup

In [27]:
# Credit: https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
def build_model(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model


num_epochs = 100
num_classes = 10  # includes the background class (0); HabitatDataset only uses label 1
model_weights = "maskrcnn-weights"
load_weights = True

model = build_model(num_classes)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.00005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
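`StepLR` multiplies the learning rate by `gamma` once every `step_size` epochs, so the rate used during epoch `e` (counting from 0, with `lr_scheduler.step()` called at the end of each epoch) is `lr0 * gamma ** (e // step_size)`. The sketch below evaluates that closed form in plain Python for the settings above; `step_lr` is a hypothetical helper and does not call PyTorch. It also does not model the per-iteration warmup that the training log further down suggests `train_one_epoch` applies during epoch 0.

```python
def step_lr(base_lr, epoch, step_size=3, gamma=0.1):
    """Closed-form learning rate during `epoch` (0-indexed) under StepLR,
    assuming the scheduler is stepped once at the end of every epoch."""
    return base_lr * gamma ** (epoch // step_size)


base_lr = 0.00005  # same value passed to torch.optim.SGD above
for epoch in (0, 2, 3, 6, 9, 99):
    print(f"epoch {epoch:3d}: lr = {step_lr(base_lr, epoch):.1e}")
```

With these settings the rate drops tenfold every three epochs, so by epoch 12 it is already four orders of magnitude below its starting value. If all 100 epochs are meant to keep updating the model, a larger `step_size` may be worth trying.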


### Training and Evaluation

In [28]:
# Example of training
from examples.instance_segmentation.engine import train_one_epoch
import examples.instance_segmentation.utils


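if load_weights and os.path.exists(model_weights):
    # Resume from a previously saved state_dict, if one exists
    model.load_state_dict(torch.load(model_weights, map_location=device))

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    lr_scheduler.step()

    # Save model weights after every 100 completed epochs
    if (epoch + 1) % 100 == 0:
        torch.save(model.state_dict(), model_weights)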


Epoch: [0]  [ 0/11]  eta: 0:00:49  lr: 0.000005  loss: 8.7761 (8.7761)  loss_classifier: 2.3126 (2.3126)  loss_box_reg: 0.6126 (0.6126)  loss_mask: 1.0407 (1.0407)  loss_objectness: 3.9861 (3.9861)  loss_rpn_box_reg: 0.8241 (0.8241)  time: 4.4937  data: 4.0778  max mem: 2639
Epoch: [0]  [10/11]  eta: 0:00:03  lr: 0.000050  loss: 7.9588 (8.4265)  loss_classifier: 2.3126 (2.2926)  loss_box_reg: 0.6132 (0.5983)  loss_mask: 1.1804 (1.1773)  loss_objectness: 2.6987 (3.4175)  loss_rpn_box_reg: 0.8421 (0.9408)  time: 3.4395  data: 3.1233  max mem: 2991
Epoch: [0] Total time: 0:00:37 (3.4397 s / it)
Epoch: [1]  [ 0/11]  eta: 0:00:33  lr: 0.000050  loss: 5.4564 (5.4564)  loss_classifier: 1.9528 (1.9528)  loss_box_reg: 0.6485 (0.6485)  loss_mask: 1.0494 (1.0494)  loss_objectness: 1.2651 (1.2651)  loss_rpn_box_reg: 0.5405 (0.5405)  time: 3.0356  data: 2.7274  max mem: 2991


KeyboardInterrupt: 