In [14]:
import json
import funcy
import os
import shutil
from sklearn.model_selection import train_test_split
from skmultilearn.model_selection import iterative_train_test_split
import numpy as np

# Define constants
# Source COCO annotation file that main() loads and splits.
ANNOTATIONS = '/home/guthix/Projects/mail-detector/neuralnet/dataset/COCO/labels.json'
# Output annotation files written by save_coco() for the training and testing subsets.
TRAIN = '/home/guthix/Projects/mail-detector/neuralnet/dataset/COCO/training/labels.json'
TEST = '/home/guthix/Projects/mail-detector/neuralnet/dataset/COCO/testing/labels.json'
# Fraction of the data assigned to training; main() passes it as train_size
# (or 1 - SPLIT as test_size) to the splitting functions.
SPLIT = 0.85
# When True, main() drops images that no annotation refers to before splitting.
HAVING_ANNOTATIONS = True 
# When True, main() splits the annotations with iterative_train_test_split using
# their category ids; otherwise it splits the images with train_test_split.
MULTI_CLASS = True  
# Directory that move_images() copies image files from.
IMAGE_DIR = '/home/guthix/Projects/mail-detector/neuralnet/dataset/COCO/data'  
# Destination directories for the image files of the training and testing subsets.
TRAIN_DIR = '/home/guthix/Projects/mail-detector/neuralnet/dataset/COCO/training/data'  
TEST_DIR = '/home/guthix/Projects/mail-detector/neuralnet/dataset/COCO/testing/data'

def save_coco(file, info, licenses, images, annotations, categories):
    """Serialize the five COCO sections to ``file`` as JSON.

    The sections are stored under the keys ``info``, ``licenses``, ``images``,
    ``annotations`` and ``categories``. The output is UTF-8 text, indented by
    two spaces, with keys sorted alphabetically. An existing file at ``file``
    is overwritten.
    """
    payload = {
        'info': info,
        'licenses': licenses,
        'images': images,
        'annotations': annotations,
        'categories': categories,
    }
    with open(file, 'wt', encoding='UTF-8') as out:
        json.dump(payload, out, indent=2, sort_keys=True)

def filter_annotations(annotations, images):
    """Return the annotations that belong to one of the given images.

    Args:
        annotations: iterable of annotation dicts, each with an ``image_id`` key.
        images: iterable of image dicts, each with an ``id`` key.

    Returns:
        A new list with the annotations whose ``image_id`` equals the ``id`` of
        some image, in their original order. Ids are compared as ``int``.
    """
    # Build the id set once so each membership test is O(1) instead of a list scan.
    image_ids = {int(image['id']) for image in images}
    return [a for a in annotations if int(a['image_id']) in image_ids]

def filter_images(images, annotations):
    """Return the images that are referenced by at least one given annotation.

    Args:
        images: iterable of image dicts, each with an ``id`` key.
        annotations: iterable of annotation dicts, each with an ``image_id`` key.

    Returns:
        A new list with the images whose ``id`` equals the ``image_id`` of some
        annotation, in their original order and without duplicating an image
        that has several annotations. Ids are compared as ``int``.
    """
    # Build the id set once so each membership test is O(1) instead of a list scan.
    referenced_ids = {int(a['image_id']) for a in annotations}
    return [image for image in images if int(image['id']) in referenced_ids]

def move_images(images, target_dir):
    """Copy the files of ``images`` from ``IMAGE_DIR`` into ``target_dir``.

    Despite the name, the files are copied (``shutil.copy``), so the source
    directory is left intact.

    Args:
        images: iterable of image dicts, each with a ``file_name`` key holding
            the path of the file relative to ``IMAGE_DIR``.
        target_dir: directory that receives the copies. The relative path in
            ``file_name`` is preserved below it.
    """
    for img in images:
        relative_name = img['file_name']
        destination = os.path.join(target_dir, relative_name)
        # Create the destination folder (including any sub-directory contained
        # in file_name) so the copy does not fail on a fresh output directory.
        parent = os.path.dirname(destination)
        if parent:
            os.makedirs(parent, exist_ok=True)
        shutil.copy(os.path.join(IMAGE_DIR, relative_name), destination)

def main():
    """Split the COCO dataset in ``ANNOTATIONS`` into training and testing parts.

    The function reads the annotation file, optionally drops images without
    annotations (``HAVING_ANNOTATIONS``), splits the data, writes the two
    resulting COCO files to ``TRAIN`` and ``TEST`` and copies the matching
    image files to ``TRAIN_DIR`` and ``TEST_DIR``.

    With ``MULTI_CLASS`` enabled the split is made over annotations and is
    stratified by category; annotations whose category occurs only once are
    discarded first. Otherwise the images are split with ``train_test_split``
    and every annotation follows its image.
    """
    from collections import Counter

    # Load everything first so the source file is closed before any processing
    # or writing happens.
    with open(ANNOTATIONS, 'rt', encoding='UTF-8') as source:
        coco = json.load(source)

    info = coco['info']
    licenses = coco['licenses']
    images = coco['images']
    annotations = coco['annotations']
    categories = coco['categories']

    if HAVING_ANNOTATIONS:
        # Compare ids as int on both sides, like filter_images/filter_annotations do.
        annotated_image_ids = {int(a['image_id']) for a in annotations}
        images = [img for img in images if int(img['id']) in annotated_image_ids]

    if MULTI_CLASS:
        # A category with a single annotation cannot be represented in both
        # splits, so such annotations are dropped.
        category_counts = Counter(int(a['category_id']) for a in annotations)
        annotations = [a for a in annotations
                       if category_counts[int(a['category_id'])] > 1]
        annotation_categories = [int(a['category_id']) for a in annotations]

        # iterative_train_test_split expects a binary (n_samples, n_labels)
        # indicator matrix. A single column of raw category ids only registers
        # as "label present", which disables the stratification, so build one
        # column per category instead.
        class_ids = sorted(set(annotation_categories))
        column_of = {class_id: col for col, class_id in enumerate(class_ids)}
        labels = np.zeros((len(annotations), len(class_ids)), dtype=int)
        for row, class_id in enumerate(annotation_categories):
            labels[row, column_of[class_id]] = 1

        # Split row indices rather than the dicts themselves and map them back
        # to annotations afterwards.
        indices = np.arange(len(annotations)).reshape(-1, 1)
        X_train, y_train, X_test, y_test = iterative_train_test_split(
            indices, labels, test_size=1 - SPLIT)

        anns_train = [annotations[i] for i in X_train.reshape(-1)]
        anns_test = [annotations[i] for i in X_test.reshape(-1)]

        # NOTE: the split is made per annotation, so an image whose annotations
        # end up in both parts is listed (and copied) in both.
        images_train = filter_images(images, anns_train)
        images_test = filter_images(images, anns_test)
    else:
        images_train, images_test = train_test_split(images, train_size=SPLIT)

        anns_train = filter_annotations(annotations, images_train)
        anns_test = filter_annotations(annotations, images_test)

    save_coco(TRAIN, info, licenses, images_train, anns_train, categories)
    save_coco(TEST, info, licenses, images_test, anns_test, categories)

    move_images(images_train, TRAIN_DIR)
    move_images(images_test, TEST_DIR)

    # The reported counts are numbers of annotations, in both modes.
    print("Saved {} entries in {} and {} in {}".format(len(anns_train), TRAIN, len(anns_test), TEST))

# Run the split only when this file is executed directly (script or notebook
# cell), not when it is imported from another module.
if __name__ == "__main__":
    main()

Saved 128 entries in /home/guthix/Projects/mail-detector/neuralnet/dataset/COCO/training/labels.json and 23 in /home/guthix/Projects/mail-detector/neuralnet/dataset/COCO/testing/labels.json
