In [13]:
import pickle
import os
import numpy as np
try:
    import xml.etree.cElementTree as ET
except ImportError:
    import xml.etree.ElementTree as ET

In [14]:
class XML_preprocessor(object):
    """Turn a directory of XML annotation files into per-image arrays.

    Each annotation is expected to contain a ``size`` element (with ``width``
    and ``height``), a ``filename`` element, and zero or more ``object``
    elements, each with a ``name`` and a ``bndbox`` (``xmin``, ``ymin``,
    ``xmax``, ``ymax``).

    After construction, ``self.data`` maps the text of each annotation's
    ``filename`` element to a 2-D float array with one row per object:
    ``[xmin/width, ymin/height, xmax/width, ymax/height, <one-hot class>]``,
    i.e. ``4 + num_classes`` columns.

    Parameters
    ----------
    data_path : str
        Directory holding the annotation files. A trailing path separator
        is accepted but not required.
    """

    # Position in this tuple == position of the 1 in the one-hot vector.
    _CLASS_NAMES = (
        'stop_sign', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
        'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
        'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
        'train', 'tvmonitor',
    )

    def __init__(self, data_path):
        self.path_prefix = data_path
        self.num_classes = len(self._CLASS_NAMES)  # 21
        self.data = dict()
        self._preprocess_XML()

    def _preprocess_XML(self):
        """Parse every ``.xml`` file in ``self.path_prefix`` into ``self.data``.

        Entries that are not regular files or do not end in ``.xml``
        (case-insensitive) are ignored, so stray directories such as
        ``.ipynb_checkpoints`` do not abort the run. Files are visited in
        sorted order so the resulting dict is built deterministically.
        """
        for filename in sorted(os.listdir(self.path_prefix)):
            # os.path.join works whether or not path_prefix ends with a separator.
            xml_path = os.path.join(self.path_prefix, filename)
            if not filename.lower().endswith('.xml'):
                continue
            if not os.path.isfile(xml_path):
                continue
            image_name, image_data = self._parse_annotation(xml_path)
            self.data[image_name] = image_data

    def _parse_annotation(self, xml_path):
        """Parse one annotation file.

        Returns
        -------
        (image_name, image_data)
            ``image_name`` is the text of the ``filename`` element and
            ``image_data`` is a float array of shape
            ``(n_objects, 4 + num_classes)``; ``n_objects`` may be 0.
        """
        root = ET.parse(xml_path).getroot()
        size_tree = root.find('size')
        width = float(size_tree.find('width').text)
        height = float(size_tree.find('height').text)

        bounding_boxes = []
        one_hot_classes = []
        for object_tree in root.findall('object'):
            # Only the object's own (direct-child) bndbox is used; boxes that
            # belong to nested elements must not replace it.
            bndbox = object_tree.find('bndbox')
            if bndbox is None:
                # Without a box there are no coordinates to record; skip the
                # object rather than reuse a previous object's values.
                print('object without bndbox skipped in %s' % xml_path)
                continue
            bounding_boxes.append([
                float(bndbox.find('xmin').text) / width,
                float(bndbox.find('ymin').text) / height,
                float(bndbox.find('xmax').text) / width,
                float(bndbox.find('ymax').text) / height,
            ])
            class_name = object_tree.find('name').text
            one_hot_classes.append(self._to_one_hot(class_name))

        image_name = root.find('filename').text
        # reshape(-1, k) keeps the arrays 2-D even when there are no objects.
        boxes = np.asarray(bounding_boxes, dtype=float).reshape(-1, 4)
        classes = np.asarray(one_hot_classes, dtype=float).reshape(
            -1, self.num_classes)
        return image_name, np.hstack((boxes, classes))

    def _to_one_hot(self, name):
        """Return a ``num_classes``-long list with a 1 at the class index.

        An unrecognised ``name`` is reported on stdout and yields an
        all-zero vector.
        """
        one_hot_vector = [0] * self.num_classes
        try:
            index = self._CLASS_NAMES.index(name)
        except ValueError:
            print('unknown label: %s' %name)
        else:
            one_hot_vector[index] = 1
        return one_hot_vector

In [15]:
# Build the {image name: boxes + one-hot classes} mapping from the annotation
# files in 'xml/' and persist it to 'data.p'.
data = XML_preprocessor('xml/').data
# The with-block guarantees the pickle file is flushed and closed.
with open('data.p', 'wb') as pickle_file:
    pickle.dump(data, pickle_file)