In [None]:
import os, csv
import tensorflow as tf
from tqdm import tqdm
from lxml import etree
from PIL import Image, ImageDraw

from object_detection.utils import dataset_util
from object_detection.utils import label_map_util

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# Root directory searched recursively for *.jpg images; each image is expected
# to have a same-named .xml annotation beside it.
path = '/media/data/LocalizationData/Validation'

# Label map used to translate annotation class names into integer class ids.
label_map_path = '/media/data/LocalizationData/TFObjectDetection/data/copter_label_map.pbtxt'
label_map_dict = label_map_util.get_label_map_dict(label_map_path)

# Class ids to reject: an image is skipped entirely if ANY of its objects
# belongs to one of these classes (or to a class missing from the label map).
exclude_classes = [2, 3, 5, 6]

In [None]:
def parseXML(file):
    """Parse an XML annotation file into a nested dict.

    The file is read as raw bytes. lxml's ``etree.fromstring`` raises
    ``ValueError`` when given a *str* that carries an
    ``<?xml ... encoding="..."?>`` declaration. Passing bytes lets lxml
    decode the document according to that declaration instead.

    Args:
        file: path of the XML file (anything ``tf.gfile.GFile`` can open).

    Returns:
        The dict built by ``dataset_util.recursive_parse_xml_to_dict``. It is
        keyed by the root tag; for Pascal-VOC style files that is
        ``'annotation'``, as used by the caller.
    """
    with tf.gfile.GFile(file, 'rb') as fid:
        xml_bytes = fid.read()
    xml = etree.fromstring(xml_bytes)
    return dataset_util.recursive_parse_xml_to_dict(xml)

def areClassesSupported(data, label_map=None, excluded=None):
    """Return True if every object in an annotation may be exported.

    An object is rejected when its class name is missing from the label map,
    or when its class id is listed in ``excluded``.

    Args:
        data: the ``annotation`` dict returned by ``parseXML(...)['annotation']``.
            Objects are read from ``data['object']``, a list of dicts with a
            ``'name'`` key. When the annotation has no ``object`` entry (an
            image without boxes), it is treated as supported.
        label_map: mapping class name -> integer class id. Defaults to the
            notebook-level ``label_map_dict``.
        excluded: collection of class ids to reject. Defaults to the
            notebook-level ``exclude_classes``.

    Returns:
        False as soon as one unsupported object is found, otherwise True.
    """
    if label_map is None:
        label_map = label_map_dict
    if excluded is None:
        excluded = exclude_classes
    # recursive_parse_xml_to_dict only creates the 'object' key when at least
    # one <object> element exists, so fall back to an empty list.
    for obj in data.get('object', []):
        name = obj['name']
        if name not in label_map or label_map[name] in excluded:
            return False
    return True

In [None]:
img_out_path = "/media/data/LocalizationData/Output/ValidationSecondStage/c1478p/"
os.makedirs(img_out_path, exist_ok=True)

# Every .jpg below `path` (recursively), sorted so the processing order -- and
# therefore the img<i> numbering of the crops written below -- is deterministic.
# A generator keeps os.walk's loop variables out of the notebook namespace.
jpgs = sorted(
    os.path.join(root, fname)
    for root, _dirs, fnames in os.walk(path)
    for fname in fnames
    if fname.endswith(".jpg")
)
# i numbers the images that pass the class filter; j numbers objects within an image.
i = 0
csv_rows = []
for jpg in tqdm(jpgs):
    # The annotation sits next to the image with the same stem. splitext only
    # touches the extension; str.replace would also rewrite ".jpg" occurring
    # in a directory name.
    xml = os.path.splitext(jpg)[0] + ".xml"
    data = parseXML(xml)['annotation']
    if not areClassesSupported(data):
        continue

    # The context manager closes the image file once all crops are saved;
    # the original code leaked one open handle per image.
    with Image.open(jpg) as image:
        # Annotations without <object> elements have no 'object' key.
        for j, obj in enumerate(data.get('object', [])):
            bndbox = obj['bndbox']
            # int(float(...)) also accepts coordinates written as e.g. "12.0".
            xmin, ymin, xmax, ymax = (
                int(float(bndbox[key])) for key in ('xmin', 'ymin', 'xmax', 'ymax')
            )
            # Keep the class name as str: csv.writer would write a bytes
            # value as the literal text "b'name'".
            subclass_text = obj['name']
            subclass = label_map_dict[obj['name']]
            # Assumes <pose> holds a numeric rotation for this dataset
            # (standard VOC uses text such as "Unspecified", which float() rejects).
            rot = float(obj['pose'])

            roi = image.crop((xmin, ymin, xmax, ymax))
            roiname = "img" + str(i) + "-obj" + str(j) + ".png"
            roi.save(os.path.join(img_out_path, roiname))

            # CSV columns: crop file name, class name, class id, rotation.
            csv_rows.append([roiname, subclass_text, subclass, rot])
    i += 1
        
# newline='' is required by the csv module so it controls line endings itself
# (otherwise every row is followed by a blank line on Windows). An explicit
# utf-8 encoding keeps non-ASCII class names independent of the platform locale.
groundtruth_path = os.path.join(img_out_path, "0groundtruth.csv")
with open(groundtruth_path, 'w', newline='', encoding='utf-8') as csvfile:
    writer = csv.writer(csvfile, delimiter=',',
                        quotechar='"', quoting=csv.QUOTE_MINIMAL)
    writer.writerows(csv_rows)
print("Number of objects: " + str(len(csv_rows)))