# Train Mask RCNN

In [1]:
import random
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import numpy as np
import torch.utils.data
import cv2
import torchvision.models.segmentation
import torch
import os
from matplotlib import pyplot as plt

In [2]:
# Training configuration (neither value is referenced by the cells shown here).
BATCH_SIZE = 4
# Two-element image size; NOTE(review): the axis order (height/width) is not
# established by this cell - confirm where it is consumed.
IMG_SIZE = [1024, 1024]

In [3]:
# Input locations, all relative to the notebook's working directory.
DATASET_IMGS = "../../Data/Images"                  # source images
DATASET_SEGM = "../../Data/instance_segmentation"   # masks, looked up by image file name
LABELS_NAMES = "../../Data/bounding_box"            # "<image id>.txt" bounding-box files

# Output locations for the isolated instances.
DATASET_SEPA = "../../Data/isolated_strawberrys"           # full-size, masked
DATASET_CROP = "../../Data/isolated_strawberrys_cropped"   # cropped to the instance

In [4]:
# Build two parallel lists: the path of every source image and the path of the
# segmentation mask expected under the same file name.
#
# os.listdir returns entries in arbitrary order, so the listing is sorted to
# make runs reproducible. Sub-directories are skipped because they cannot be
# read as images. Paths are joined with "/" on purpose: the image id is later
# recovered from the last "/"-separated component.
image_names = sorted(
    name
    for name in os.listdir(DATASET_IMGS)
    if os.path.isfile(DATASET_IMGS + "/" + name)
)

base_images = [DATASET_IMGS + "/" + name for name in image_names]
segm_images = [DATASET_SEGM + "/" + name for name in image_names]

In [5]:
def eucl_box(x1, y1, x2, y2):
    """Return the Euclidean distance between the points (x1, y1) and (x2, y2)."""
    first_point = np.array([x1, y1])
    second_point = np.array([x2, y2])

    # Square root of the summed squared coordinate differences.
    offset = first_point - second_point
    return np.sqrt(np.sum(offset ** 2))

def load_bounding_boxes(filename):
    """Read a bounding-box label file into a list of rows.

    Every non-blank line must hold exactly five whitespace-separated fields:
    an integer id followed by four numbers (centre x, centre y, width,
    height). Blank lines are ignored.

    Args:
        filename: Path of the text file to read.

    Returns:
        A list of ``[id, cent_x, cent_y, width, height]`` lists in file order,
        with ``id`` converted to ``int`` and the other fields to ``float``.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a non-blank line does not have exactly five fields or
            a field cannot be converted to a number.
    """
    boxes = []

    # The context manager guarantees the handle is closed, even on a bad line.
    with open(filename, "r") as f:
        for line in f:

            # split() with no argument tolerates trailing newlines and
            # repeated or trailing spaces, which split(' ') does not.
            elements = line.split()
            if not elements:
                continue

            box_id, cent_x, cent_y, width, height = elements
            boxes.append(
                [int(box_id), float(cent_x), float(cent_y), float(width), float(height)]
            )

    return boxes

In [6]:
# Split every image into one file per segmented instance. Each instance is
# named after the class id of the annotated bounding box whose centre is
# closest to the centre of the instance's own bounding rectangle.

# cv2.imwrite does not create missing folders and only reports failure through
# its return value, so make sure both output folders exist up front.
for out_dir in (DATASET_SEPA, DATASET_CROP):
    os.makedirs(out_dir, exist_ok=True)

for image, segmentation in zip(base_images, segm_images):

    # Load the image and its mask (flag 0 reads the mask as one grayscale channel).
    img = cv2.imread(image)
    mask = cv2.imread(segmentation, 0)

    # cv2.imread returns None instead of raising when a file cannot be read.
    if img is None:
        raise OSError("Could not read image: " + image)
    if mask is None:
        raise OSError("Could not read segmentation mask: " + segmentation)

    # Every distinct non-zero mask value is treated as one instance.
    masks_labels = [value for value in np.unique(mask) if value != 0]

    # Get the image id (file name without extension) and the mask size.
    height, width = mask.shape
    img_id = os.path.splitext(os.path.basename(image))[0]

    # Parse the label file once per image rather than once per instance, and
    # only when there is at least one instance that needs it.
    boxes = []
    if masks_labels:
        boxes = load_bounding_boxes(LABELS_NAMES + '/' + img_id + '.txt')

    # Iterate through the mask values to isolate each instance.
    for num, label_id in enumerate(masks_labels):

        # Binary (0/1) mask of this instance only.
        mask_i = (mask == label_id).astype(np.uint8)

        # Bounding rectangle of the instance and its centre, in pixels.
        bb_coords = cv2.findNonZero(mask_i)
        mask_bb_x, mask_bb_y, mask_bb_w, mask_bb_h = cv2.boundingRect(bb_coords)
        bb_cent = (mask_bb_x + mask_bb_w / 2, mask_bb_y + mask_bb_h / 2)

        # Find the annotated box whose centre is nearest to the instance centre.
        # Box centres are scaled by the mask width/height before comparing.
        best_box_id = None
        min_distance = float("inf")
        for box_id, box_cent_x, box_cent_y, _, _ in boxes:

            dist = eucl_box(box_cent_x * width, box_cent_y * height, bb_cent[0], bb_cent[1])

            if dist < min_distance:
                min_distance = dist
                best_box_id = box_id

        if best_box_id is None:
            raise ValueError(
                "No bounding boxes found for image '" + img_id
                + "'; cannot label instance " + str(num)
            )

        # File names carry the numeric class id; ids 0 / 1 / 2 correspond to
        # 'unripe' / 'partially_ripe' / 'fully_ripe'.
        label = str(best_box_id)

        # Zero out every pixel that does not belong to this instance.
        new_img = img * cv2.merge([mask_i, mask_i, mask_i])

        new_filename = img_id + '-' + str(num) + '-' + label + '.png'

        cv2.imwrite(DATASET_SEPA + '/' + new_filename, new_img)

        # Crop the isolated instance to its bounding rectangle.
        crop_img = new_img[mask_bb_y:mask_bb_y + mask_bb_h, mask_bb_x:mask_bb_x + mask_bb_w]

        cv2.imwrite(DATASET_CROP + '/' + img_id + '_' + str(num) + '_' + label + '.png', crop_img)