In [None]:
import io
import os
import warnings

import tensorflow as tf
from tqdm import tqdm
from lxml import etree
from PIL import Image, ImageDraw

from object_detection.utils import dataset_util
from object_detection.utils import label_map_util

warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# Root directory that is walked recursively for .jpg images; each image is
# paired with an .xml annotation file of the same name next to it.
path = '/media/data/LocalizationData/Validation'
# Label map file that is loaded into label_map_dict below.
label_map_path = '/media/data/LocalizationData/TF/data/copter_label_map.pbtxt'
# Destination TFRecord file produced by this notebook.
out_file = '/media/data/LocalizationData/TF/data/test-coptersflying-c1478.record'
# Lookup from class name (as found in the annotations) to its label id.
label_map_dict = label_map_util.get_label_map_dict(label_map_path)
# Label ids to leave out: an image is skipped entirely as soon as one of its
# annotated objects maps to an id listed here.
exclude_classes = [2,3,5,6]

In [None]:
def parseXML(file):
    """Parse an annotation XML file into a nested dict.

    Args:
        file: Path of the XML file; it is opened through ``tf.gfile.GFile``.

    Returns:
        Whatever ``dataset_util.recursive_parse_xml_to_dict`` produces for the
        document's root element.
    """
    # Read raw bytes instead of text: on Python 3 lxml refuses a ``str`` that
    # carries an XML encoding declaration (``<?xml ... encoding="..."?>``)
    # with a ValueError, whereas bytes are accepted and decoded by lxml itself.
    with tf.gfile.GFile(file, 'rb') as fid:
        xml_bytes = fid.read()
    root = etree.fromstring(xml_bytes)
    return dataset_util.recursive_parse_xml_to_dict(root)

def areClassesSupported(data):
    """Tell whether every object of an annotation has a usable class.

    An object is usable when its ``name`` is a key of ``label_map_dict`` and
    the label id it maps to is not listed in ``exclude_classes``.

    Args:
        data: Annotation dict providing the list of objects under ``'object'``.

    Returns:
        True if all objects are usable (also when the list is empty),
        False as soon as one object is not.
    """
    return all(
        # The membership test comes first so the lookup never hits an
        # unknown class name.
        entry['name'] in label_map_dict
        and label_map_dict[entry['name']] not in exclude_classes
        for entry in data['object']
    )

In [None]:
# Gather every .jpg below `path`. The list is sorted so that the processing
# order does not depend on the order in which the OS lists directories.
jpgs = []
for root, dirs, files in os.walk(path):
    for f in files:
        if f.endswith(".jpg"):
            jpgs.append(os.path.join(root, f))
jpgs.sort()

# Build one tf.train.Example per image whose annotation passes the class filter.
tf_examples = []
for jpg in tqdm(jpgs):
    # The annotation shares the image's stem. splitext swaps only the final
    # extension; str.replace would also rewrite ".jpg" occurring elsewhere in
    # the path (e.g. in a directory name).
    xml_path = os.path.splitext(jpg)[0] + ".xml"
    data = parseXML(xml_path)['annotation']

    # An annotation without any object entry cannot be processed: the marker
    # colour below is taken from the first object. Skip it instead of failing.
    objects = data.get('object', [])
    if not objects:
        continue
    if not areClassesSupported(data):
        #print("skip unsupported cats")
        continue

    # Paint a 100x100 marker square in the top-left corner; its colour is the
    # first object's class name with the "copter_" prefix removed.
    image = Image.open(jpg)
    draw = ImageDraw.Draw(image)
    draw.rectangle(((0, 0), (100, 100)), fill=objects[0]['name'].replace("copter_",""))

    # Re-encode to JPEG in memory. This avoids a round trip through a fixed
    # shared temp file, which concurrent runs would overwrite for each other.
    jpeg_buffer = io.BytesIO()
    image.save(jpeg_buffer, "JPEG")
    encoded_jpg = jpeg_buffer.getvalue()

    # Image dimensions are taken from the annotation, not from the decoded image.
    width = int(data['size']['width'])
    height = int(data['size']['height'])

    # Box coordinates are divided by the image size before being stored.
    xmin, ymin, xmax, ymax = [],[],[],[]
    classes, classes_text = [],[]
    for obj in objects:
        box = obj['bndbox']
        xmin.append(float(box['xmin']) / width)
        ymin.append(float(box['ymin']) / height)
        xmax.append(float(box['xmax']) / width)
        ymax.append(float(box['ymax']) / height)
        classes_text.append(obj['name'].encode('utf8'))
        classes.append(label_map_dict[obj['name']])

    filename_bytes = data['filename'].encode('utf8')
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename_bytes),
        'image/source_id': dataset_util.bytes_feature(filename_bytes),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    tf_examples.append(example)

# len() returns an int, so it must be converted before string concatenation.
print("Number of tf examples: " + str(len(tf_examples)))

# The context manager closes the record file even if a write fails midway.
with tf.python_io.TFRecordWriter(out_file) as writer:
    for example in tf_examples:
        writer.write(example.SerializeToString())