# Code for running Torchvision Mask R-CNN on the BRACOT dataset
*Version for Google Colab*

*Based on the [PyTorch TorchVision object detection fine-tuning tutorial](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)*

***Connect to drive (if your data is on GDrive)***

In [None]:
# Mount Google Drive under /content/drive; the dataset and weight paths used
# further down in this notebook all live below that mount point.
from google.colab import drive
drive.mount('/content/drive')

***Importing some libs***

In [None]:
import sys
import os
sys.path.insert(0, '/content/drive/My Drive/Colab Notebooks/torchvision_utils') # Insert path to utils/
import numpy as np
import torch
import torch.utils.data
from PIL import Image
from pycocotools.coco import COCO

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

from engine import train_one_epoch, evaluate
from coco_utils import convert_coco_poly_to_mask
import utils
import transforms as T
import matplotlib.pyplot as plt
import cv2
from google.colab.patches import cv2_imshow
import visualize_maskrcnn_predictions as vis

***Defining the dataset class***

In [None]:
class LeafDataset(torch.utils.data.Dataset):
    """COCO-annotated instance segmentation dataset with one foreground class.

    Every item is an ``(image, target)`` pair. ``target`` is a dict with the
    keys ``masks``, ``boxes``, ``labels``, ``image_id``, ``area`` and
    ``iscrowd`` built from the COCO annotations of that image.
    """

    def __init__(self, root, annotation, transforms=None):
        """Index the annotation file.

        Args:
            root: Directory onto which each annotation ``file_name`` is joined.
            annotation: Path of the COCO annotation file passed to ``COCO``.
            transforms: Optional callable applied to the image only (the
                target dict is returned untouched).
        """
        self.root = root
        self.transforms = transforms
        self.coco = COCO(annotation)
        # Sorted so that a given index always maps to the same image id.
        self.ids = list(sorted(self.coco.imgs.keys()))

    def __getitem__(self, index):
        """Load the image at ``index`` together with its target dict.

        Returns:
            tuple: ``(img, target)`` where ``img`` is the (optionally
            transformed) RGB image and ``target`` is the annotation dict.
        """
        coco = self.coco
        img_id = self.ids[index]

        # All annotation records attached to this image.
        annotations = coco.loadAnns(coco.getAnnIds(imgIds=img_id))

        file_name = coco.loadImgs(img_id)[0]['file_name']
        # Force three channels: grayscale, RGBA or palette files would
        # otherwise yield tensors with a channel count other than 3.
        img = Image.open(os.path.join(self.root, file_name)).convert("RGB")
        w, h = img.size  # PIL reports (width, height)

        num_objs = len(annotations)

        # COCO stores boxes as [x, y, width, height]; the target uses
        # [xmin, ymin, xmax, ymax].
        boxes = []
        for ann in annotations:
            xmin, ymin, box_w, box_h = ann['bbox'][:4]
            boxes.append([xmin, ymin, xmin + box_w, ymin + box_h])
        segmentation = [ann['segmentation'] for ann in annotations]

        masks = convert_coco_poly_to_mask(segmentation, h, w)

        # reshape keeps the (N, 4) layout even when the image has no
        # annotations; an empty list alone would become a tensor of shape (0,).
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # Single foreground class: every object gets label 1 (0 is background).
        labels = torch.ones((num_objs,), dtype=torch.int64)
        # Area exactly as recorded in each annotation's 'area' field.
        areas = torch.as_tensor([ann['area'] for ann in annotations],
                                dtype=torch.float32)
        # Every instance is marked as non-crowd.
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        my_annotation = {
            "masks": masks,
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([img_id]),
            "area": areas,
            "iscrowd": iscrowd,
        }

        if self.transforms is not None:
            img = self.transforms(img)

        return img, my_annotation

    def __len__(self):
        """Return the number of images listed in the annotation file."""
        return len(self.ids)

***Defining the model***

In [None]:
def get_transform():
    """Return the image preprocessing pipeline: a ``Compose`` holding one ``ToTensor`` step."""
    to_tensor = torchvision.transforms.ToTensor()
    return torchvision.transforms.Compose([to_tensor])

def get_model_instance_segmentation(num_classes, box_score_thresh=0.7, hidden_layer=256):
    """Build a Mask R-CNN (ResNet-50 FPN) whose heads predict ``num_classes`` classes.

    Args:
        num_classes: Number of output classes for both the box and the mask
            head (callers in this notebook pass 2: background + leaf).
        box_score_thresh: Forwarded to ``maskrcnn_resnet50_fpn``. Defaults to
            0.7, the value previously hard-coded here (the original author
            noted 0.05 as the alternative).
        hidden_layer: Width of the hidden layer of the new mask predictor.

    Returns:
        The model with freshly initialised box and mask predictors.
    """
    # Load an instance segmentation model pre-trained on COCO.
    # NOTE(review): newer torchvision releases may expect `weights=` instead of
    # `pretrained=` -- confirm against the installed version.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(
        pretrained=True, box_score_thresh=box_score_thresh)

    # Swap the box classification head for one sized to num_classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Swap the mask head likewise, keeping its input channel count.
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                       hidden_layer,
                                                       num_classes)

    return model

***Defining train, eval on dataset and eval on single image functions***

In [None]:
def train(path_imgs_train, path_annotation_train, batch_size=1, lr=0.005, epochs=100,
          output_path="model_maskrcnn_final.pth"):
    """Fine-tune the Mask R-CNN on the training set and save its weights.

    Args:
        path_imgs_train: Directory containing the training images.
        path_annotation_train: Path of the COCO annotation file for them.
        batch_size: Number of images per batch.
        lr: Initial learning rate of the SGD optimizer.
        epochs: Number of passes over the training set.
        output_path: Where the final ``state_dict`` is written. Defaults to
            the filename that used to be hard-coded.
    """
    # train on the GPU or on the CPU, if a GPU is not available
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # our dataset has two classes only - background and leaf
    num_classes = 2

    dataset = LeafDataset(root=path_imgs_train,
                          annotation=path_annotation_train,
                          transforms=get_transform())

    # Shuffle so every epoch visits the training images in a different order.
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=4,
        collate_fn=utils.collate_fn)

    # Report how many training images were found.
    print(len(data_loader.dataset))

    model = get_model_instance_segmentation(num_classes)
    model.to(device)

    # Optimise only the parameters that require gradients.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=lr,
                                momentum=0.9, weight_decay=0.0005)
    # NOTE(review): the learning rate is multiplied by 0.1 every 3 epochs, so
    # with the default 100 epochs it becomes vanishingly small after the
    # first few dozen epochs -- confirm this schedule is intended.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=3,
                                                   gamma=0.1)

    for epoch in range(epochs):
        # train for one epoch, printing every 10 iterations
        train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
        # update the learning rate
        lr_scheduler.step()

    torch.save(model.state_dict(), output_path)


def eval_on_dataset(weights_path, path_imgs_test, path_annotation_test, batch_size=1):
    """Evaluate saved weights on the test set.

    Args:
        weights_path: Path of a ``state_dict`` file for the two-class model.
        path_imgs_test: Directory containing the test images.
        path_annotation_test: Path of the COCO annotation file for them.
        batch_size: Number of images per batch.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = get_model_instance_segmentation(num_classes=2)
    # map_location lets weights saved on a GPU load on the CPU fallback.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)

    # Define test dataset and dataset loader
    dataset_test = LeafDataset(root=path_imgs_test,
                               annotation=path_annotation_test,
                               transforms=get_transform())

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=batch_size, shuffle=False, num_workers=4,
        collate_fn=utils.collate_fn)

    evaluate(model, data_loader_test, device=device)


'''
Function for evaluating on single image and create outputs for next framework stage
'''
# Auxiliary function for creating masked copies of an image
def create_fig(img, mask):
    """Return a copy of ``img`` whose background pixels are painted with 1s.

    Args:
        img: ``(H, W, C)`` NumPy array; it is left unmodified.
        mask: 2-D array-like covering at least ``H x W``. Positions whose
            value is below 0.1 are treated as background.

    Returns:
        A new array with the shape and dtype of ``img`` in which the first
        three channels of every background pixel are set to 1.
    """
    painted = img.copy()
    height, width, _ = img.shape
    # Vectorised replacement for the former per-pixel Python loop.
    background = np.asarray(mask)[:height, :width] < 0.1
    painted[background, :3] = 1
    return painted

def eval_single_img(weights_path, path_to_img):
    """Run the model on one image, show the prediction and save per-instance crops.

    Writes ``instanceseg.png`` (the visualisation returned by ``vis.predict``)
    and one ``<index>.jpg`` per predicted mask into
    ``/content/outputs/instance_seg_result``.

    Args:
        weights_path: Path of a ``state_dict`` file for the two-class model.
        path_to_img: Path of the image to evaluate.
    """
    # Load model
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = get_model_instance_segmentation(num_classes=2)
    # map_location lets weights saved on a GPU load on the CPU fallback.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)
    model.eval()
    to_tensor = get_transform()

    # Open image (forced to three channels) and run it through the model.
    img = Image.open(path_to_img).convert("RGB")
    img_tensor = to_tensor(img)
    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        outputs = model([img_tensor.to(device)])

    # Visualize predictions
    result, _, _ = vis.predict(img, model)
    cv2_imshow(result)

    # Create the full output path; the directory may be missing even when
    # its parent already exists.
    out_dir = '/content/outputs/instance_seg_result'
    os.makedirs(out_dir, exist_ok=True)
    cv2.imwrite(os.path.join(out_dir, 'instanceseg.png'), result)

    # Save cropped masks on files
    hwc_img = np.transpose(img_tensor.numpy(), (1, 2, 0))
    # Binarise the predicted masks at 0.5 and scale them to 0/255.
    masks = outputs[0]['masks'].ge(0.5).mul(255).byte().cpu().numpy()

    array_to_tensor = torchvision.transforms.ToTensor()
    tensor_to_pil = torchvision.transforms.ToPILImage()
    for i, mask in enumerate(masks):
        # mask[0] is the 2-D mask of this instance.
        crop = create_fig(hwc_img, mask[0])
        tensor_to_pil(array_to_tensor(crop)).save(os.path.join(out_dir, f'{i}.jpg'))

### Example

***Defining some params***

In [None]:
# Training hyper-parameters: images per batch, initial learning rate and
# number of epochs.
batch_size = 1
lr = 0.005
epochs = 100

# Training split: image folder and its COCO annotation file.
path_annotation_train = '/content/drive/My Drive/ic2-dataset/train/train_annotation.json'
path_imgs_train = '/content/drive/My Drive/ic2-dataset/train/'

# Test split: image folder and its COCO annotation file.
path_annotation_test = '/content/drive/My Drive/ic2-dataset/test/test_annotation.json'
path_imgs_test = '/content/drive/My Drive/ic2-dataset/test/'

***Training on dataset***

In [None]:
# Start training with the parameters defined in the cell above.
train(
    path_imgs_train,
    path_annotation_train,
    batch_size=batch_size,
    lr=lr,
    epochs=epochs,
)

***Evaluating on dataset***

In [None]:
# Weights file to be evaluated.
weights_path = '/content/drive/My Drive/ic2-dataset/model_final_torchvision_100epochs.pth'

# Evaluate those weights on the test split.
eval_on_dataset(
    weights_path,
    path_imgs_test,
    path_annotation_test,
    batch_size=batch_size,
)

***Evaluating on single image***

In [None]:
# Weights file and the single image to run it on.
# NOTE(review): img_path spells the Drive folder 'MyDrive' while every other
# path in this notebook uses 'My Drive' -- confirm both resolve on the mount.
img_path = '/content/drive/MyDrive/ic2-dataset/test/20190831_163222.jpg'
weights_path = '/content/drive/My Drive/ic2-dataset/model_final_torchvision_100epochs.pth'

# Predict, display and save the per-instance outputs.
eval_single_img(weights_path=weights_path, path_to_img=img_path)