# Mask R-CNN

This is a PyTorch implementation of Mask R-CNN that follows [this tutorial](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html). It is useful for any custom dataset on which you would like to perform instance segmentation. The code below makes it easy to perform transfer learning from the Mask R-CNN model pre-trained on COCO train2017 that PyTorch provides.

For my personal use case, I had a small dataset of 10,000 images collected using the Reddit API wrapper [PRAW](https://github.com/praw-dev/praw). I labelled 800 of them, created this entire notebook, and was making accurate out-of-sample predictions within a single sitting. The Dataset class below reads project files (JSON) generated by the [VGG Image Annotator](https://www.robots.ox.ac.uk/~vgg/software/via/). For your own image dataset, tweak the load_images() method as needed to load your images of choice. Other than that, this notebook is plug and play and will let you get up and running with Mask R-CNN quickly. 

In [1]:
import os
import json
import numpy as np
from PIL import Image, ImageDraw, ImageOps, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import torch
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as T

In [2]:
# Use the GPU when one is available, otherwise fall back to the CPU so that
# `device` is always defined for the cells that follow.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Dataset root directory. It looks like a Colab-mounted Google Drive path,
# so adjust it for other environments. The cells below expect it to contain
# `Labels/labels_json.json`, an `Images/` folder and an `ExtraTools/` folder.
root = '/content/drive/MyDrive/Data/RCNN Dataset'

In [None]:
from collections.abc import Mapping


class _LazyImages(Mapping):
  """Read-only ``file name -> PIL.Image`` mapping that loads images on demand.

  Each lookup opens the file, decodes it to RGB and closes it again. The
  dataset therefore never keeps file handles open, nor every decoded image in
  memory. Holding no open handles also matters once the dataset is forked
  into DataLoader workers: forked processes share the parent's file offsets,
  and concurrent lazy reads through them can corrupt image data.
  """

  def __init__(self, img_dir, file_names):
    self._img_dir = img_dir
    self._file_names = list(file_names)
    self._known = set(self._file_names)

  def _open(self, file_name):
    # Unknown keys must raise KeyError (the Mapping contract), not an OSError.
    if file_name not in self._known:
      raise KeyError(file_name)
    return Image.open(os.path.join(self._img_dir, file_name))

  def __getitem__(self, file_name):
    with self._open(file_name) as im:
      # convert() forces a full decode, so the returned image outlives the
      # closed file. RGB also gives a fixed 3-channel layout whatever the
      # source mode (e.g. RGBA or grayscale files).
      return im.convert('RGB')

  def size_of(self, file_name):
    """Return ``(width, height)`` of an image, reading only its header."""
    with self._open(file_name) as im:
      return im.size

  def __contains__(self, file_name):
    return file_name in self._known

  def __iter__(self):
    return iter(self._file_names)

  def __len__(self):
    return len(self._file_names)


class Dataset:
  """Images plus one VGG Image Annotator (VIA) polygon per image.

  Expects ``root`` to contain ``Labels/labels_json.json`` (a VIA project
  file) and an ``Images/`` folder holding the annotated files.

  Attributes:
    masks: file name -> flat polygon vertex list ``[x1, y1, x2, y2, ...]``.
    file_names: annotated file names, in the order they were first seen in
      the annotation file.
    images: file name -> RGB ``PIL.Image``, loaded lazily on access.
    train: when True, ``transform()`` adds random flip/rotation augmentation.
  """
  def __init__(self, root, train = True):
    self.root = root
    self.train = train

    self.mask_path = os.path.join(root, 'Labels/labels_json.json')
    self.masks = self.parse_annotations(self.mask_path)

    self.img_path = os.path.join(root, "Images")
    self.file_names = list(self.masks.keys())
    self.images = self.load_images()

  def load_images(self):
    """Return a lazy ``file name -> RGB PIL.Image`` mapping for ``file_names``.

    Raises:
      FileNotFoundError: if any annotated image is missing from ``img_path``.
        This is checked up front so a broken dataset fails at construction.
    """
    missing = [f for f in self.file_names
               if not os.path.isfile(os.path.join(self.img_path, f))]
    if missing:
      raise FileNotFoundError(
          f'{len(missing)} annotated image(s) not found in {self.img_path}: '
          f'{missing[:5]}')
    return _LazyImages(self.img_path, self.file_names)

  def transform(self):
    """Build the transform for one sample: ToTensor, plus augmentation when training.

    The random parameters (whether to flip, and the rotation angle) are drawn
    once here and baked into the returned pipeline. Calling it on an image
    and then on that image's mask therefore applies exactly the same geometry
    to both. The parameters come from torch's RNG, which DataLoader seeds
    separately in each worker process.
    """
    operations = [T.ToTensor()]
    if self.train:
      if torch.rand(1).item() < 0.5:
        operations.append(T.Lambda(T.functional.hflip))
      deg = torch.empty(1).uniform_(0, 360).item()
      operations.append(T.Lambda(lambda t: T.functional.rotate(t, deg)))
    return T.Compose(operations)

  def clean_coords(self, coords):
    """Split a flat ``[x1, y1, x2, y2, ...]`` list into ``[xs, ys]`` of ints."""
    x = [int(i) for i in coords[::2]]
    y = [int(i) for i in coords[1::2]]
    return [x,y]

  def _image_size(self, file_name):
    # Header-only read when the lazy loader is in use; full image otherwise
    # (e.g. if a subclass's load_images() returns a plain dict).
    images = self.images
    if isinstance(images, _LazyImages):
      return images.size_of(file_name)
    return images[file_name].size

  def create_mask(self, file_name):
    """Rasterize an image's polygon and compute the polygon's bounding box.

    Returns:
      ``(poly, box)``. ``poly`` is an RGBA image the size of the source image,
      with the polygon filled opaque black and everything else fully
      transparent. ``box`` is ``[x_min, y_min, x_max, y_max]`` of the polygon
      vertices, in untransformed pixel coordinates.
    """
    coords = self.masks[file_name]
    poly = Image.new('RGBA', self._image_size(file_name))
    ImageDraw.Draw(poly).polygon(coords, fill = 'black')

    x, y = self.clean_coords(coords)
    box = [min(x), min(y), max(x), max(y)]
    return poly, box

  def apply_mask(self, file_name):
    """Return the (possibly augmented) image with everything outside the polygon blacked out."""
    # The same transform instance is applied to image and mask, so both get
    # identical geometry.
    transform = self.transform()
    img = T.ToPILImage(mode='RGB')(transform(self.images[file_name]))
    mask, _ = self.create_mask(file_name)
    transformed_mask = T.ToPILImage(mode='RGBA')(transform(mask))
    new_image = Image.new('RGBA', img.size)
    # Paste through the transformed mask's alpha channel so the cut-out
    # follows the augmented image.
    new_image.paste(img, (0, 0), mask = transformed_mask)
    return new_image.convert('RGB')

  def parse_annotations(self, mask_path):
    """Read a VIA project JSON into ``{file name: [x1, y1, x2, y2, ...]}``.

    Only polygon regions with at least 3 vertices are kept; vertex
    coordinates are truncated to int.

    NOTE(review): if an image has several polygons, only the last one
    survives. The rest of this class assumes one object per image.
    """
    with open(mask_path, 'r', encoding='utf-8') as fh:
      via_json = json.load(fh)
    file_map = {key: value['fname'] for key, value in via_json['file'].items()}
    img_mask = {}
    for value in via_json['metadata'].values():
      coords = value['xy']
      # xy = [shape_id, x1, y1, ...]; VIA shape id 7 is a polygon, and a
      # polygon needs >= 3 vertices (shape id + 6 numbers).
      if len(coords) >= 7 and coords[0] == 7:
        img_mask[file_map[value['vid']]] = [int(c) for c in coords[1:]]
    return img_mask

class DataSetLoader(Dataset):
  """Map-style dataset yielding ``(image, target)`` pairs for torchvision's Mask R-CNN.

  ``target`` is a dict with these keys:

  - ``boxes``: (N, 4) float32, ``x1, y1, x2, y2`` in pixels.
  - ``labels``: (N,) int64, all 1.
  - ``masks``: (N, H, W) uint8.
  - ``image_id``, ``area`` and ``iscrowd``.

  N is 1, or 0 if augmentation left none of the polygon in view.
  """
  def __init__(self, root, train = True):
    super().__init__(root, train)
    self.train = train

  def __len__(self):
    return len(self.file_names)

  def __getitem__(self, idx):
    # One transform per sample, applied to both image and mask, so they get
    # the same augmentation.
    transform = self.transform()
    file_name = self.file_names[idx]
    img = transform(self.images[file_name])
    raw_mask, _ = self.create_mask(file_name)

    # raw_mask is RGBA with the polygon filled opaque black and everything
    # else transparent. After the transform, the object therefore lives only
    # in the alpha (last) channel; the RGB channels are all zero. Use just
    # that channel so exactly one instance mask is produced.
    alpha = np.asarray(transform(raw_mask))[-1]
    obj = alpha > 0.5
    masks = obj[None]  # (1, H, W): one object per image

    # Derive the box from the *transformed* mask so it follows any
    # flip/rotation. The vertex box from create_mask is pre-augmentation.
    ys, xs = np.where(obj)
    if xs.size:
      # +1 makes the box cover the last pixel and guarantees a positive
      # width and height.
      boxes = torch.as_tensor(
          [[int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1]],
          dtype=torch.float32)
    else:
      # Object rotated fully out of view: emit an empty (negative) target.
      boxes = torch.zeros((0, 4), dtype=torch.float32)
      masks = masks[:0]
    num_objs = boxes.shape[0]
    area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

    # torch conversions
    image_id = torch.tensor([idx])
    labels = torch.ones((num_objs,), dtype=torch.int64)
    masks = torch.as_tensor(masks, dtype=torch.uint8)
    iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
    target = {'boxes': boxes, 'labels': labels,
              'masks': masks, 'image_id': image_id,
              'area': area, 'iscrowd': iscrowd}
    return img, target

In [None]:
# Move into the folder that holds the torchvision detection reference helpers
# (utils.py / engine.py from the repo linked below). The star-imports that
# follow can then presumably be found via the current working directory;
# notebooks normally put '' on sys.path. Note this changes the process cwd
# for the rest of the session.
os.chdir(os.path.join(root, 'ExtraTools')) 
# https://github.com/pytorch/vision/tree/master/references/detection
# NOTE(review): these star-imports presumably supply `collate_fn` and
# `train_one_epoch`, which later cells use but this notebook never defines.
from utils import * 
from engine import *

# Replacement heads for transfer learning (used by the function below).
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

def get_model_instance_segmentation(num_classes):
  """Build a COCO-pretrained Mask R-CNN (ResNet-50 FPN) with fresh heads.

  The box predictor and the mask predictor are swapped for new ones sized to
  ``num_classes``; the rest of the network keeps its pretrained weights.

  Args:
    num_classes: number of output classes, counting background.

  Returns:
    The torchvision Mask R-CNN model.
  """
  model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
  heads = model.roi_heads

  # Read both input widths from the pretrained heads before replacing them.
  box_in_features = heads.box_predictor.cls_score.in_features
  mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
  mask_hidden = 256

  heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)
  heads.mask_predictor = MaskRCNNPredictor(
      mask_in_channels, mask_hidden, num_classes)
  return model

In [None]:
# object or background
num_classes = 2

# Number of shuffled samples held out for evaluation.
num_test = 50

# Two views of the same files: `train` applies augmentation, `test` does not.
train = DataSetLoader(root, train = True)
test = DataSetLoader(root, train = False)
if len(train) <= num_test:
  raise ValueError(
      f'need more than {num_test} labelled images to split, found {len(train)}')
indices = torch.randperm(len(train)).tolist()
train = torch.utils.data.Subset(train, indices[:-num_test])
test = torch.utils.data.Subset(test, indices[-num_test:])

# loader init
train_loader = DataLoader(train, batch_size=2,
                          shuffle=True, num_workers=4,
                          collate_fn=collate_fn)
# Evaluation does not benefit from shuffling; keep its order reproducible.
test_loader = DataLoader(test, batch_size=1,
                         shuffle=False, num_workers=4,
                         collate_fn=collate_fn)

In [None]:
# model init
# Pick the GPU when present, otherwise the CPU. Hard-coding 'cuda' makes
# model.to(device) fail on CPU-only runtimes.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model_instance_segmentation(num_classes)
model.to(device)

# optimizer init: SGD over the trainable parameters only
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params, lr=0.005, momentum=0.9, weight_decay=0.0005)
# Multiply the LR by 0.1 every 3 scheduler steps; the training loop calls
# lr_scheduler.step() once per epoch.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=3, gamma=0.1)

Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth


HBox(children=(FloatProgress(value=0.0, max=178090079.0), HTML(value='')))




In [3]:
num_epochs = 10
for epoch in range(num_epochs):
  # One pass over the training data, logging every 10 iterations.
  train_one_epoch(model, optimizer, train_loader, device, epoch,
                  print_freq=10)
  # Advance the StepLR schedule once per epoch.
  lr_scheduler.step()