## Instance Segmentation Using Habitat-Sim and Mask-R-CNN

author: Michael Piseno (mpiseno@fb.com)

This notebook demonstrates how to set up an efficient data pipeline for instance segmentation using PyTorch and Mask-R-CNN, with Habitat-Sim as the data generator.

Other resources:
* [Mask-R-CNN paper](https://arxiv.org/pdf/1703.06870.pdf)
* [PyTorch instance segmentation tutorial](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)

In [7]:
import os
import time
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torch.utils.data import Dataset

from habitat_sim.utils.data.dataextractor import ImageExtractor

### Data preparation and preprocessing

Below we define the data extraction and preprocessing steps. Habitat-Sim's image data extraction API gathers images from within the simulator for use inside a PyTorch `Dataset` subclass, which is then fed into a PyTorch `DataLoader`.

In [3]:
def collate_fn(batch):
    return tuple(zip(*batch))
    
    
class HabitatDataset(Dataset):
    def __init__(self, extractor, transform=None):
        self.extractor = extractor
        self.transform = transform
        
    def __len__(self):
        return len(self.extractor.poses)
    
    def __getitem__(self, idx):
        sample = self.extractor[idx]
        img, mask = sample["rgb"][:, :, :3], sample["semantic"]
        obj_ids = np.unique(mask)
        masks = np.array([mask == obj_id for obj_id in obj_ids])
        
        # get bounding box coordinates for each mask
        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])

        # convert everything into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # there is only one class
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd
        
        if self.transform:
            img = self.transform(img)
        
        return img, target
    

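try:
    # Close any extractor left over from a previous run of this cell
    extractor.close()
except NameError:
    # No extractor exists yet (first run), so there is nothing to close
    pass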
    
scene_filepath = "../../data/scene_datasets/habitat-test-scenes/skokloster-castle.glb" # Replace with your filepath
extractor = ImageExtractor(scene_filepath, output=["rgb", "semantic"])


I0204 15:12:17.414058 38362 simulator.py:131] Loaded navmesh ../../data/scene_datasets/habitat-test-scenes/skokloster-castle.navmesh


sim_cfg.physics_config_file = ./data/default.phys_scene_config.json


First, we instantiate an `ImageExtractor` from Habitat-Sim. This requires that we provide the filepath to a scene from which we will extract images. Optionally, we can specify the types of output we would like from the extractor; the default is RGB images only.

```python
scene_filepath = "../../data/scene_datasets/habitat-test-scenes/skokloster-castle.glb"
extractor = ImageExtractor(scene_filepath, output=["rgb", "semantic"])
```

We then create a custom class that subclasses PyTorch's `Dataset` and overrides the `__len__` and `__getitem__` methods. Mask-R-CNN requires the image, bounding boxes, per-object binary masks (derived here from the semantic sensor output), and class labels for each example, so the `__getitem__` method builds all of these.
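
The box computation inside `__getitem__` can be tried in isolation on a toy array. The sketch below is only an illustration: `boxes_from_semantic` and `toy_mask` are hypothetical names that are not part of the notebook or of Habitat-Sim. It follows the same `np.where` approach, and additionally skips boxes with zero width or height, on the assumption that the detection model rejects such degenerate boxes during training.

```python
import numpy as np


def boxes_from_semantic(mask):
    """Return (kept_ids, boxes), one [xmin, ymin, xmax, ymax] box per object ID."""
    kept_ids, boxes = [], []
    for obj_id in np.unique(mask):
        # Row and column indices of every pixel that belongs to this object
        rows, cols = np.where(mask == obj_id)
        xmin, xmax = cols.min(), cols.max()
        ymin, ymax = rows.min(), rows.max()
        # Assumption: drop boxes with zero width or height (degenerate boxes)
        if xmax > xmin and ymax > ymin:
            kept_ids.append(int(obj_id))
            boxes.append([int(xmin), int(ymin), int(xmax), int(ymax)])
    return kept_ids, boxes


# Toy 4x5 "semantic" image: ID 0 is the surrounding region, 1 and 2 are objects
toy_mask = np.array([
    [0, 0, 0, 0, 0],
    [0, 1, 1, 0, 2],
    [0, 1, 1, 0, 2],
    [0, 0, 0, 0, 2],
])

ids, boxes = boxes_from_semantic(toy_mask)
print(ids)    # object 2 is one pixel wide, so it is filtered out
print(boxes)
```

If such a filter were added to `HabitatDataset`, the `masks`, `labels`, `area`, and `iscrowd` entries would need to be filtered with the same IDs so that every field of `target` stays aligned.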

In [5]:
# Specify which transforms to apply to the data in preprocessing
transform = T.Compose([
    T.ToTensor()
])

dataset = HabitatDataset(extractor, transform=transform)

data_loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True,
                                          collate_fn=collate_fn)
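
The custom `collate_fn` is needed because each image can contain a different number of objects, so the per-sample `target` dictionaries generally cannot be stacked into a single batch tensor. A minimal sketch of what it does, using plain Python stand-ins (the strings and the `num_boxes` key are hypothetical placeholders, not real samples):

```python
def collate_fn(batch):
    # Same definition as in the notebook: turn a list of (img, target) pairs
    # into a tuple of images and a tuple of targets
    return tuple(zip(*batch))


# Stand-ins for two (image, target) samples with different object counts
batch = [
    ("image_0", {"num_boxes": 3}),
    ("image_1", {"num_boxes": 7}),
]

images, targets = collate_fn(batch)
print(images)   # all images, in batch order
print(targets)  # all targets, in the same order
```

The training step receives the images and targets as two parallel sequences, which matches the list-of-images and list-of-targets input format that the commented-out `model(images, targets)` call in the training cell relies on.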

### Model and Training Setup

In [6]:
# Credit: https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
def build_model(num_classes):
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model


num_epochs = 100
num_classes = 10
model_weights = "maskrcnn-weights"
load_weights = True

model = build_model(num_classes)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
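
To get a feel for the schedule configured above, the following sketch evaluates the step-decay rule that `StepLR` applies, `base_lr * gamma ** (epoch // step_size)`, in plain Python. The helper `step_lr` is a hypothetical name used only for illustration, and the sketch ignores any warmup that a training helper may apply during the first epoch.

```python
def step_lr(base_lr, epoch, step_size=3, gamma=0.1):
    """Learning rate after `epoch` scheduler steps under a step-decay schedule."""
    return base_lr * gamma ** (epoch // step_size)


# Learning rate at a few selected epochs for lr=0.005, step_size=3, gamma=0.1
schedule = {epoch: step_lr(0.005, epoch) for epoch in (0, 2, 3, 6, 9, 30)}
for epoch, lr in schedule.items():
    print(f"epoch {epoch:2d}: lr = {lr:.3e}")
```

With these settings the learning rate shrinks tenfold every three epochs, so over `num_epochs = 100` it becomes negligible long before training ends. A larger `step_size` or a smaller number of epochs may be a better fit, depending on the dataset.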


### Training and Evaluation

In [9]:
# Example of training
from examples.instance_segmentation.engine import train_one_epoch
import examples.instance_segmentation.utils


if load_weights:
    # Load model weights
    pass

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    lr_scheduler.step()
    
    if epoch % 100 == 0:
        # Save model weights
        pass
    
    
# # For Training
# images, targets = next(iter(data_loader))
# images = list(image for image in images)
# targets = [{k: v for k, v in t.items()} for t in targets]
# output = model(images,targets)   # Returns losses and detections
# # For inference
# model.eval()
# x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
# predictions = model(x)           # Returns predictions

Epoch: [0]  [0/3]  eta: 0:01:04  lr: 0.002502  loss: 4.6411 (4.6411)  loss_classifier: 2.3846 (2.3846)  loss_box_reg: 0.0750 (0.0750)  loss_mask: 2.0145 (2.0145)  loss_objectness: 0.0922 (0.0922)  loss_rpn_box_reg: 0.0749 (0.0749)  time: 21.5655  data: 0.2227
Epoch: [0]  [2/3]  eta: 0:00:17  lr: 0.005000  loss: 4.9862 (6.0747)  loss_classifier: 2.3846 (1.9543)  loss_box_reg: 0.0750 (0.0649)  loss_mask: 2.0145 (2.0435)  loss_objectness: 0.0922 (1.4433)  loss_rpn_box_reg: 0.0749 (0.5687)  time: 17.2008  data: 0.0959
Epoch: [0] Total time: 0:00:51 (17.2013 s / it)
Epoch: [1]  [0/3]  eta: 0:00:58  lr: 0.005000  loss: 1.8169 (1.8169)  loss_classifier: 0.1961 (0.1961)  loss_box_reg: 0.0446 (0.0446)  loss_mask: 0.8425 (0.8425)  loss_objectness: 0.2259 (0.2259)  loss_rpn_box_reg: 0.5079 (0.5079)  time: 19.5119  data: 0.0408
Epoch: [1]  [2/3]  eta: 0:00:16  lr: 0.005000  loss: 1.8169 (1.8073)  loss_classifier: 0.1961 (0.1966)  loss_box_reg: 0.0492 (0.0509)  loss_mask: 0.9531 (1.0600)  loss_obje