In [1]:
# Reference: parsing XML files into Python dictionaries with xmltodict
# https://linuxhint.com/python_xml_to_dictionary/

In [80]:
import os
import xmltodict
import math

In [6]:
# Folder holding the XML annotation files that get converted below.
base = "./data/ori"

In [38]:
def convert2coco(base):
    """Convert a folder of Pascal VOC-style XML annotation files to a COCO-style dict.

    Every ``*.xml`` file directly inside ``base`` is parsed with ``xmltodict``.
    Each file becomes one COCO image entry, and each of its ``<object>``
    elements becomes one COCO annotation.

    Parameters
    ----------
    base : str
        Directory containing the XML annotation files.

    Returns
    -------
    annotations : dict
        COCO-style dict with keys ``"type"``, ``"images"``, ``"categories"``
        and ``"annotations"``. Boxes are ``[xmin, ymin, width, height]``.
        Image ``height``/``width`` are ints. Category ids start at 1 and are
        numbered in order of first appearance over the sorted filenames.
    img2annots : dict
        Maps each XML filename to a dict with these keys:
        ``"data"`` (its image entry), ``"annotations"`` (its annotation
        entries) and ``"num_objects"`` (``{category_id: count}`` covering
        every category in the dataset).

    Raises
    ------
    KeyError
        If a file lacks an element this code indexes directly
        (``annotation``, ``size``, ``filename``, ``bndbox``, ``name``, ...).
    """

    def xml_to_dict(base, filename):
        # Binary mode lets the XML parser honour the file's own encoding
        # declaration instead of the platform's default text encoding.
        with open(os.path.join(base, filename), "rb") as xml_obj:
            return xmltodict.parse(xml_obj)

    def as_list(value):
        # xmltodict yields a dict for a single <object>, a list for several,
        # and nothing (missing key / None) when the file has no objects.
        if value is None:
            return []
        if isinstance(value, list):
            return value
        return [value]

    def to_int(value):
        # Some annotation tools write numbers as "12.0"; int("12.0") would fail.
        return int(float(value))

    # Sorted so image/category ids are reproducible: os.listdir order is arbitrary.
    filenames = sorted(
        filename for filename in os.listdir(base)
        if os.path.splitext(filename)[-1].lower() == '.xml'
    )

    # Parse each file once; both passes below reuse the result.
    parsed = {filename: xml_to_dict(base, filename)['annotation'] for filename in filenames}

    annotations = {
        "type": "instances",
        "images": [],
        "categories": [],
        "annotations": []
    }

    img2annots = {}

    # Pass 1: assign a 1-based id to every class, in order of first appearance.
    cls_dict = {}
    for ann in parsed.values():
        for obj in as_list(ann.get('object')):
            cls_name = obj['name']
            if cls_name not in cls_dict:
                cls_dict[cls_name] = len(cls_dict) + 1

    # Pass 2: build image and annotation entries.
    ann_id = 1
    for img_id, (filename, ann) in enumerate(parsed.items(), start=1):
        image_annots = []
        for ann_obj in as_list(ann.get('object')):
            bbx = ann_obj['bndbox']
            xmin, ymin = to_int(bbx['xmin']), to_int(bbx['ymin'])
            xmax, ymax = to_int(bbx['xmax']), to_int(bbx['ymax'])
            dx = xmax - xmin
            dy = ymax - ymin

            image_annots.append({
                "id": ann_id,
                "bbox": [xmin, ymin, dx, dy],
                "image_id": img_id,
                "category_id": cls_dict[ann_obj['name']],
                "segmentation": [],
                "area": dx * dy,
                "iscrowd": 0
            })
            ann_id += 1
        annotations["annotations"].extend(image_annots)

        size = ann['size']
        image = {
            "file_name": ann['filename'],
            "height": to_int(size['height']),
            "width": to_int(size['width']),
            "id": img_id
        }
        annotations["images"].append(image)

        # Per-image object count for every category (zero when absent).
        num_objects = {cls_id: 0 for cls_id in cls_dict.values()}
        for annot in image_annots:
            num_objects[annot['category_id']] += 1

        img2annots[filename] = {
            'data': image,
            'annotations': image_annots,
            'num_objects': num_objects
        }

    annotations["categories"] = [
        {"supercategory": "none", "name": cls_name, "id": cls_id}
        for cls_name, cls_id in cls_dict.items()
    ]

    return annotations, img2annots

In [39]:
# Build the COCO-style annotation dict plus a per-image index from every XML file in `base`.
annotations, img2annots = convert2coco(base=base)

In [127]:
def train_test_split(img2annots, split_dictionary, max_iter=10, verbose=False):
    """Assign a contiguous ``[start, end)`` image index range to each split.

    Ranges refer to the iteration order of ``img2annots``. Images are not
    shuffled. Each split receives ``ceil(fraction * n_images)`` images, taken in
    ``split_dictionary`` order. Earlier splits win when rounding over-allocates,
    and all ranges are clamped to ``[0, n_images]``.

    Parameters
    ----------
    img2annots : dict
        Per-image mapping. Only its length and order are used here.
    split_dictionary : dict
        ``{split_name: weight}``. The weights are normalised by their sum, so
        they need not add up to 1.
    max_iter : int, optional
        Currently unused. NOTE(review): it presumably reserves room for an
        iterative per-category (stratified) balancing step that is not
        implemented.
    verbose : bool, optional
        If True, also print the resulting ranges.

    Returns
    -------
    dict
        ``{split_name: [start, end]}``, with ``0 <= start <= end <= n_images``.

    Raises
    ------
    ValueError
        If any weight is negative, or the weights do not sum to a positive number.
    """
    if any(weight < 0 for weight in split_dictionary.values()):
        raise ValueError("split_dictionary weights must be non-negative")
    total = sum(split_dictionary.values())
    if split_dictionary and total <= 0:
        raise ValueError("split_dictionary weights must sum to a positive number")

    total_img = len(img2annots)

    split_img_dict = {}
    start_idx = 0
    for name, weight in split_dictionary.items():
        # Round before ceil so float noise (e.g. 70.00000000000001) does not
        # grab an extra image.
        size = math.ceil(round(weight / total * total_img, 9))
        # Clamp both ends: ceil can over-allocate, so a later split could
        # otherwise start past the last image (start > end).
        split_img_dict[name] = [min(start_idx, total_img), min(start_idx + size, total_img)]
        start_idx += size

    if verbose:
        print(split_img_dict)
    return split_img_dict

In [128]:
# Relative split weights; train_test_split divides by their sum, so they need not add up to 1.
split_dictionary = dict(train=0.60, val=0.20, test=0.20)

train_test_split(img2annots, split_dictionary)

{'train': [0, 108], 'val': [108, 144], 'test': [144, 180]}
