In [1]:
# Reference
# https://linuxhint.com/python_xml_to_dictionary/

In [97]:
import os
import xmltodict
import math
import random
from copy import deepcopy

In [2]:
# Directory that convert2coco scans for '.xml' annotation files.
base = './data/ori'

In [7]:
def convert2coco(base):
    """Convert a directory of XML annotation files into COCO-style records.

    Every file directly inside ``base`` whose extension is ``.xml`` is parsed
    with ``xmltodict``. Each document is expected to have an ``annotation``
    root holding ``filename``, ``size`` (``width``/``height``) and zero or
    more ``object`` elements, each with a ``name`` and a ``bndbox``
    (``xmin``/``ymin``/``xmax``/``ymax``).

    Files are processed in sorted filename order so that image, annotation
    and category ids are reproducible between runs (``os.listdir`` gives no
    ordering guarantee). All ids start at 1; category ids follow the order in
    which class names are first encountered.

    Args:
        base: Path of the directory containing the XML files.

    Returns:
        A dict with three keys:
            'type': the string "instances".
            'categories': list of ``{"supercategory": "none", "name", "id"}``
                dicts, one per class name.
            'img2annots': dict mapping each XML filename to
                ``{'data': image_record, 'annotations': [annotation, ...],
                'num_objects': {category_id: count}}``. ``num_objects`` has
                an entry (possibly 0) for every category in the dataset.
                Boxes are stored as ``[xmin, ymin, width, height]``.

    Raises:
        KeyError: if a document lacks ``annotation``, ``size``, ``filename``
            or a required field inside an ``object``.
        ValueError: if a coordinate or image dimension is not an integer
            literal.
    """

    def as_list(node):
        # xmltodict returns a single dict when a tag occurs once, a list when
        # it is repeated, and no key at all when it is absent. Normalise all
        # three cases to a list so callers can always iterate.
        if node is None:
            return []
        if isinstance(node, list):
            return node
        return [node]

    # Sorted for reproducible id assignment.
    filenames = sorted(
        filename for filename in os.listdir(base)
        if os.path.splitext(filename)[-1] == '.xml'
    )

    cls_dict = {}    # class name -> category id (1-based, first-seen order)
    img2annots = {}
    ann_id = 1       # running annotation id across the whole dataset

    for img_id, filename in enumerate(filenames, start=1):
        # Parse each file exactly once; the context manager closes it.
        with open(os.path.join(base, filename), "r") as xml_obj:
            ann = xmltodict.parse(xml_obj.read())['annotation']

        records = []
        for ann_obj in as_list(ann.get('object')):
            cls_name = ann_obj['name']
            # len() is evaluated before insertion, so a new class gets the
            # next free id and a known class keeps its existing one.
            category_id = cls_dict.setdefault(cls_name, len(cls_dict) + 1)

            bbx = ann_obj['bndbox']
            xmin, ymin = int(bbx['xmin']), int(bbx['ymin'])
            xmax, ymax = int(bbx['xmax']), int(bbx['ymax'])
            dx = xmax - xmin
            dy = ymax - ymin

            records.append({
                "id": ann_id,
                "bbox": [xmin, ymin, dx, dy],
                "image_id": img_id,
                "category_id": category_id,
                "segmentation": [],
                "area": dx * dy,
                "iscrowd": 0
            })
            ann_id += 1

        size = ann['size']
        image = {
            "file_name": ann['filename'],
            # xmltodict yields text; store numbers like the bbox values are.
            "height": int(size['height']),
            "width": int(size['width']),
            "id": img_id
        }

        img2annots[filename] = {
            'data': image,
            'annotations': records,
            'num_objects': None
        }

    categories = [
        {"supercategory": "none", "name": cls_name, "id": cls_id}
        for cls_name, cls_id in cls_dict.items()
    ]

    # Per-image object counts can only be filled in once every class is
    # known, because each image reports a count for every category.
    for entry in img2annots.values():
        counts = {cls_id: 0 for cls_id in cls_dict.values()}
        for obj in entry['annotations']:
            counts[obj['category_id']] += 1
        entry['num_objects'] = counts

    return {
        'type': "instances",
        'categories': categories,
        'img2annots': img2annots
    }

In [8]:
# Parse the XML files under `base`; the result is a dict with the keys
# 'type', 'categories' and 'img2annots' (per-file image record and boxes).
annotations = convert2coco(base)

In [204]:
def train_test_split(input_annotations, split_dictionary, max_iter=100, seed=None):
    """Split images into named subsets while balancing per-category objects.

    The image order is shuffled ``max_iter`` times. For each shuffle the
    images are cut into consecutive slices (one per subset), and the shuffle
    whose slices hold a share of every category's objects closest to the
    requested subset fraction (sum of squared differences) is kept.

    Args:
        input_annotations: Dict with the keys 'type', 'categories' and
            'img2annots'. 'img2annots' maps an image key to a dict with
            'data' (image record), 'annotations' (list of dicts carrying
            'category_id') and 'num_objects' ({category_id: count}).
        split_dictionary: Maps subset name -> relative weight. Weights are
            divided by their sum, so 0.6/0.2/0.2 and 60/20/20 are equivalent.
        max_iter: Number of random shuffles to evaluate. With 0 the original
            image order is used unshuffled.
        seed: Optional seed for a private ``random.Random`` instance so the
            split is reproducible. When ``None`` the module-level ``random``
            state is used.

    Returns:
        Dict mapping each subset name to
        ``{'type', 'categories', 'images', 'annotations'}``, where 'type' and
        'categories' are taken from ``input_annotations`` and the two lists
        hold the image records and annotations assigned to that subset.

    Notes:
        Subset sizes are ``ceil(fraction * n_images)``, handed out in
        dictionary order and clamped to the number of images, so later
        subsets may receive fewer images than requested when the rounded
        sizes add up to more than the total.

    Raises:
        ZeroDivisionError: if the weights in ``split_dictionary`` sum to 0.
        KeyError: if an annotation uses a category id that is missing from
            every 'num_objects' mapping.
    """
    img2annots = input_annotations['img2annots']

    # Normalise the weights to fractions that sum to 1.
    total = sum(split_dictionary.values())
    split_dict = {key: val / total for key, val in split_dictionary.items()}

    # Total number of objects per category across the whole dataset.
    total_objects = {}
    for entry in img2annots.values():
        for cat_id, cat_n in entry['num_objects'].items():
            total_objects[cat_id] = total_objects.get(cat_id, 0) + cat_n

    # Get the split size (in images) for every subset.
    total_img = len(img2annots)
    split_size_img = {
        key: math.ceil(frac * total_img) for key, frac in split_dict.items()
    }

    # Half-open [start, stop) index range of every subset in the shuffled list.
    split_img_dict = {}
    start_idx = 0
    for key, size in split_size_img.items():
        split_img_dict[key] = [
            min(start_idx, total_img),
            min(start_idx + size, total_img),
        ]
        start_idx = start_idx + size

    # Per-image category counts do not depend on the shuffle, so compute
    # them once instead of on every iteration.
    per_image_counts = {}
    for name, entry in img2annots.items():
        counts = {}
        for ann in entry['annotations']:
            category_id = ann['category_id']
            counts[category_id] = counts.get(category_id, 0) + 1
        per_image_counts[name] = counts

    def object_fractions(names):
        # Share of each category's objects that falls into `names`.
        count = {cat_id: 0 for cat_id in total_objects}
        for name in names:
            for category_id, n in per_image_counts[name].items():
                count[category_id] += n
        return {
            cat_id: (count[cat_id] / cat_total if cat_total else 0.0)
            for cat_id, cat_total in total_objects.items()
        }

    # Optimization: keep the shuffle with the smallest squared error.
    rng = random if seed is None else random.Random(seed)
    img_name = list(img2annots.keys())
    # Start from the unshuffled order so a result always exists, even when
    # max_iter is 0.
    best_img_name_seq = list(img_name)
    best_error = float('inf')
    for _ in range(max_iter):
        rng.shuffle(img_name)

        error = 0.0
        for key, (lo, hi) in split_img_dict.items():
            # Compare against the normalised fraction, not the raw weight.
            target = split_dict[key]
            for frac in object_fractions(img_name[lo:hi]).values():
                error = error + (target - frac) ** 2

        if error < best_error:
            best_error = error
            best_img_name_seq = list(img_name)

    # Split the dataset according to the best ordering found.
    annotations = {}
    for key, (lo, hi) in split_img_dict.items():
        subset = {
            'type': input_annotations['type'],
            'categories': input_annotations['categories'],
            'images': [],
            'annotations': []
        }
        for name in best_img_name_seq[lo:hi]:
            entry = img2annots[name]
            subset['images'].append(entry['data'])
            subset['annotations'].extend(entry['annotations'])
        annotations[key] = subset
    return annotations

In [205]:
# Relative weight of each subset; train_test_split divides every weight by
# the sum of all weights before using it.
split_dictionary = {
    'train': 0.60,
    'val': 0.20,
    'test': 0.20
}

# Maps each subset name to a dict with 'type', 'categories', 'images' and
# 'annotations'. The split relies on random shuffling, so results vary
# between runs unless the random state is fixed beforehand.
ann_split = train_test_split(annotations, split_dictionary)

In [213]:
# Number of annotations (bounding boxes) assigned to each subset.
print(len(ann_split['test']['annotations']))
print(len(ann_split['train']['annotations']))
print(len(ann_split['val']['annotations']))

4558
13897
4684
