## Acknowledgements

* Dataset creation and training code adapted from the PyTorch tutorial [Torchvision Object Detection Finetuning Tutorial](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
* Code for plotting detection bounding boxes taken from the tutorial [Faster R-CNN Object Detection with PyTorch](https://www.learnopencv.com/faster-r-cnn-object-detection-with-pytorch/)

In [None]:
from typing import Sequence, Dict

import os
import random
import yaml

from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt

import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

from engine import train_one_epoch, evaluate
import utils
import transforms as T

In [None]:
# Load class names and the directory-name -> label mapping.
# safe_load is used because yaml.load without an explicit Loader is unsafe
# and is rejected outright by PyYAML >= 6.
with open('ycb_labels.yaml', encoding='utf-8') as ycb_label_file:
    ycb_label_data = yaml.safe_load(ycb_label_file)
category_names = ycb_label_data['category_names']
label_map = ycb_label_data['label_map']
num_classes = len(category_names)
num_epochs = 3

In [None]:
class YCBDataset(object):
    '''Map-style dataset of YCB object images with segmentation masks.

    ``root`` is expected to contain one sub-directory per object class, each
    holding an ``images`` directory of ``*.jpg`` files and a ``masks``
    directory with the matching ``<name>_mask.pbm`` files.  Class directory
    names are looked up in the module-level ``label_map`` to get the label.

    Keyword arguments:
    root -- path to the dataset root directory
    transforms -- optional callable applied as ``transforms(img, target)``
    '''
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms

        self.data_list = []
        # sorted() makes sample indices independent of the file system's
        # directory listing order
        for x in sorted(os.listdir(self.root)):
            dir_path = os.path.join(self.root, x)
            if os.path.isdir(dir_path):
                self.data_list.extend(self.__get_image_data(x, os.path.join(dir_path, 'images'),
                                                            os.path.join(dir_path, 'masks')))

    def __get_image_data(self, class_name: str, image_dir: str,
                         mask_dir: str) -> Sequence[Dict[str, object]]:
        '''Returns a list of dictionaries containing image-specific
        data for all ``.jpg`` images from the specified class.  Each
        dictionary in the resulting list is of the following format:
        {
            'img': path to an image (starting from self.root),
            'mask': path to the image segmentation mask (starting from self.root),
            'label': label of the class, taken from label_map[class_name]
        }
        Files whose extension is not ``.jpg`` (case-insensitive) are skipped.

        Keyword arguments:
        class_name: str -- name of the image class
        image_dir: str -- name of a directory with RGB images
        mask_dir: str -- name of a directory with image masks

        '''
        class_image_list = []
        for x in sorted(os.listdir(image_dir)):
            # splitext keeps names that contain extra dots intact
            name, ext = os.path.splitext(x)
            if ext.lower() != '.jpg':
                continue
            class_image_list.append({
                'img': os.path.join(image_dir, x),
                'mask': os.path.join(mask_dir, name + '_mask.pbm'),
                'label': label_map[class_name],
            })
        return class_image_list

    def __getitem__(self, idx):
        '''Return ``(img, target)`` for sample ``idx``.

        ``target`` holds ``boxes`` (N x 4, x1/y1/x2/y2), ``labels``,
        ``masks`` (N x H x W, uint8), ``image_id``, ``area`` and ``iscrowd``.
        Image and mask are shrunk in place to fit within 600x600 pixels.
        '''
        # load image and mask, closing the underlying files
        img_path = self.data_list[idx]['img']
        mask_path = self.data_list[idx]['mask']

        with Image.open(img_path) as img_file:
            img = img_file.convert("RGB")
        img.thumbnail((600, 600), Image.LANCZOS)

        # the mask is not converted to RGB: its raw pixel values identify
        # the regions.  NEAREST keeps those values discrete when resizing.
        with Image.open(mask_path) as mask_file:
            mask_file.thumbnail((600, 600), Image.NEAREST)
            mask = np.array(mask_file)

        # distinct pixel values; the first (lowest) one is dropped as background
        obj_ids = np.unique(mask)[1:]

        # One binary mask per remaining value.  The object is the set of
        # pixels whose value differs from obj_id (for a two-valued mask: the
        # lower-valued pixels).  Boxes, areas and stored masks are all derived
        # from this same region so they stay consistent with each other.
        # NOTE(review): assumes objects are the dark pixels of the .pbm
        # masks -- confirm against the data.
        masks = mask != obj_ids[:, None, None]

        num_objs = len(obj_ids)
        boxes = []
        for obj_mask in masks:
            ys, xs = np.where(obj_mask)
            boxes.append([xs.min(), ys.min(), xs.max(), ys.max()])

        # reshape keeps a (0, 4) shape when there are no objects, so the
        # column indexing used for the area below does not fail
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        # we assume that all objects in the image are of the same class
        labels = torch.full((num_objs,), self.data_list[idx]['label'], dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # every instance is a single object; 'iscrowd' follows the torchvision
        # detection tutorial format (presumably read by engine.evaluate)
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.data_list)

In [None]:
def get_transform(train):
    """Return the transform pipeline applied to (image, target) pairs.

    The PIL image is always converted to a PyTorch tensor.  When ``train``
    is true, a random horizontal flip (probability 0.5) of the image and its
    ground truth is appended for data augmentation.
    """
    steps = [T.ToTensor()]
    if train:
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)

def get_model_instance_segmentation(num_classes):
    """Build a COCO-pretrained Mask R-CNN whose heads predict ``num_classes``.

    The box predictor and the mask predictor of torchvision's
    ``maskrcnn_resnet50_fpn`` are replaced by freshly initialised ones sized
    for ``num_classes``; everything else keeps its pretrained weights.
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    heads = model.roi_heads

    # new box classification/regression head, fed by the same features
    box_in = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in, num_classes)

    # new mask head with 256 hidden channels
    mask_in = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(mask_in, 256, num_classes)

    return model

def get_prediction(img_path, threshold):
    """Run the global ``model`` on one image and keep confident detections.

    Keyword arguments:
    img_path -- path of the image to load
    threshold -- detections with a score strictly greater than this are kept

    Returns ``(pred_boxes, pred_class)``: boxes as ``[(x1, y1), (x2, y2)]``
    corner pairs and the matching names from ``category_names``.  Both lists
    are empty when no detection passes the threshold.
    """
    with Image.open(img_path) as img_file:
        # the model was trained on 3-channel RGB images
        img = img_file.convert("RGB")
    img, _ = T.ToTensor()(img, None)  # PIL image -> tensor

    # torchvision detection models only accept target-less input in eval
    # mode; after training the model would otherwise still be in train mode
    model.eval()
    with torch.no_grad():
        pred = model([img.to(device)])[0]

    # boolean-mask filtering: no reliance on index lookups (which break on
    # tied scores) and no IndexError when nothing passes the threshold
    keep = (pred['scores'] > threshold).cpu().numpy()
    labels = pred['labels'].cpu().numpy()[keep]
    boxes = pred['boxes'].cpu().numpy()[keep]
    pred_class = [category_names[i] for i in labels]
    pred_boxes = [[(b[0], b[1]), (b[2], b[3])] for b in boxes]
    return pred_boxes, pred_class

def object_detection_api(img_path, threshold=0.5, rect_th=10, text_size=0.5, text_th=1):
    """Draw the model's detections for one image and display it.

    Keyword arguments:
    img_path -- path of the image to run the detector on
    threshold -- minimum score (exclusive) for a detection to be drawn
    rect_th -- line thickness of the bounding boxes
    text_size -- font scale of the class labels
    text_th -- line thickness of the class labels

    Raises OSError if OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising
        raise OSError(f"cv2 could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # cv2 loads BGR; matplotlib expects RGB

    boxes, pred_cls = get_prediction(img_path, threshold)
    for (top_left, bottom_right), cls_name in zip(boxes, pred_cls):
        # OpenCV drawing functions require integer pixel coordinates
        pt1 = (int(top_left[0]), int(top_left[1]))
        pt2 = (int(bottom_right[0]), int(bottom_right[1]))
        cv2.rectangle(img, pt1, pt2, color=(0, 255, 0), thickness=rect_th)
        cv2.putText(img, str(cls_name), pt1, cv2.FONT_HERSHEY_SIMPLEX, text_size,
                    (0, 255, 0), thickness=text_th)

    plt.figure(figsize=(20, 30))
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])
    plt.show()

In [None]:
# train on the GPU or on the CPU, if a GPU is not available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# set to True to run evaluate() on the held-out images after every epoch
evaluate_each_epoch = False

# use our dataset and defined transformations; the two instances differ
# only in the transform pipeline (random flipping for training)
dataset = YCBDataset('ycb_dataset', get_transform(train=True))
dataset_test = YCBDataset('ycb_dataset', get_transform(train=False))

# split the dataset in train and test set: hold out up to 50 images, but
# never more than half, so small datasets still leave images to train on
indices = torch.randperm(len(dataset)).tolist()
num_test = min(50, len(indices) // 2)
num_train = len(indices) - num_test
dataset = torch.utils.data.Subset(dataset, indices[:num_train])
dataset_test = torch.utils.data.Subset(dataset_test, indices[num_train:])

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=4,
    collate_fn=utils.collate_fn)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=4,
    collate_fn=utils.collate_fn)

# get the model using our helper function
model = get_model_instance_segmentation(num_classes)

# move model to the right device
model.to(device)

# construct an optimizer over the trainable parameters
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)
# and a learning rate scheduler (LR x0.1 every 3 epochs)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)

    # update the learning rate
    lr_scheduler.step()

    # optionally evaluate on the held-out dataset
    if evaluate_each_epoch:
        evaluate(model, data_loader_test, device=device)

print('Training over')

In [None]:
# sanity check - evaluating the detection on an image from the data set
# (path is relative to the same 'ycb_dataset' root used for training)
object_detection_api(os.path.join('ycb_dataset', 'sponge', 'images', 'N1_252.jpg'), threshold=0.4)