In [1]:
import sys
from os import listdir
from xml.etree import ElementTree
from numpy import zeros
from numpy import asarray
from Mask_RCNN.mrcnn.utils import Dataset
import skimage
import re

In [2]:
from mrcnn.utils import Dataset
from mrcnn.config import Config
from mrcnn.model import MaskRCNN
import matplotlib.pyplot as pyplot

In [24]:
class ObjectDataset(Dataset):
    
    def load_dataset(self, dataset_dir, is_train=True):
        self.add_class("boat_dataset", 1, "boat")
        self.add_class("building_dataset", 2, "building")
        self.add_class("car_dataset", 3, "car")
        self.add_class("drone_dataset", 4, "drone")
        self.add_class("person_dataset", 5, "person")
        self.add_class("horseride_dataset", 6, "horseride")
        self.add_class("paraglider_dataset", 7, "paraglider")
        self.add_class("riding_dataset", 8, "riding")
        self.add_class("truck_dataset", 9, "truck")
        self.add_class("whale_dataset", 10, "whale")
        self.add_class("wakeboard_dataset", 11, "wakeboard")
        self.add_class("group_dataset", 12, "group")
        
        # images 
        # folders to skip over
        dir_file_c = {}
        
        for d in listdir(dataset_dir):
            # create count for each file
            if d == '.DS_Store':
                continue
            dir_file_c[d] = 0
            for f in listdir(dataset_dir+'/'+d):
                dir_file_c[d]+=1
            dir_file_c[d]/=2
    
        im_id = 0
        for thing in listdir(dataset_dir):
            if thing == '.DS_Store':
                continue
            # gets class by removing the nunmbers form name of dir    
            image_class = re.sub('[^a-zA-Z]+', '', str(thing))
            
            for i, file in enumerate(listdir(dataset_dir + '/' + thing)):
                if file[-1:] == '.xml':
                    continue
                # find what percentage of a folder we're through
                percent = float(i)/float(dir_file_c[thing])*50
                # make sure that we aare only gethering trainig ot not traing images
                if is_train and (percent >= 65):
                    continue
                if not is_train and (percent < 65):
                    continue
                
                image_id = im_id
                im_id +=1
                # locations for images and annotations
                im = dataset_dir + '/' + thing + '/' + file
                an = dataset_dir + '/' + thing + '/' + file[:-4] + '.xml'
                
                self.add_image(image_class+'_dataset', image_id=image_id, path=im, annotation=an)


    def extract_boxes(self, file):
        tree = ElementTree.parse(file)
        root = tree.getroot()
        boxes = list()
        # get box coordinates
        for box in root.findall('.//bndbox'):
            xmax = int(box.find('xmax').text)
            xmin = int(box.find('xmin').text)
            ymax = int(box.find('ymax').text)
            ymin = int(box.find('ymin').text)
            box_coords = [xmax, xmin, ymax, ymin]
            boxes.append(box_coords)

        height = int(root.find('.//size/height').text)
        width = int(root.find('.//size/width').text)
        # return value is a list with box coordinates and the dimensions of pic overall
        return boxes, width, height
    
        
    def load_masks(self, image_id):
        # get details of image
        info = self.image_info[image_id]
        # define box file location
        path = info['annotation']
        print(path)
        ind = re.sub('[^a-zA-Z]+', '', path[14:].split('/')[0])
        # load XML
        boxes, w, h = self.extract_boxes(path)
        # create one array for all masks, each on a different channel
        masks = zeros([h, w, len(boxes)], dtype='uint8')
        # create masks
        class_ids = list()
        for i in range(len(boxes)):
            box = boxes[i]
            row_s, row_e = box[3], box[2]
            col_s, col_e = box[1], box[0]
            masks[row_s:row_e, col_s:col_e, i] = 1
            for c in self.class_info:
                if c['name'] == ind:
                    class_ids.append(c['id'])
                    break
        return masks, asarray(class_ids, dtype='int32')
        
    def image_reference(self, image_id):
        info = self.image_info[image_id]
        return info ['path']

In [25]:
train_set = ObjectDataset()
train_set.load_dataset('data_training', is_train=True)
train_set.prepare()
print('Train: %d' % len(train_set.image_ids))
 
# test/val set
test_set = ObjectDataset()
test_set.load_dataset('data_training', is_train=False)
test_set.prepare()
print('Test: %d' % len(test_set.image_ids))

Train: 121620
Test: 65420


In [26]:
class ObjectConfg(Config):
    NAME = "object_cfg"
    NUM_CLASSES = 1 + 12
    STEPS_PER_EPOCH = 121620

In [23]:
config = ObjectDataset()
config.display()

AttributeError: 'ObjectDataset' object has no attribute 'display'