In [1]:
import json
import os
import random
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

In [2]:
def split_dataset(input_json, output_dir, random_seed):
    random.seed(random_seed)

    with open(input_json) as json_reader:
        dataset = json.load(json_reader)

    images = dataset['images']
    annotations = dataset['annotations']
    categories = dataset['categories']

    k_var = [(ann['image_id'], ann['category_id']) for ann in annotations]
    X = np.ones((len(annotations),1))
    y = np.array([v[1] for v in k_var])
    groups = np.array([v[0] for v in k_var])
    
    #file_name에 prefix 디렉토리까지 포함 (CocoDataset 클래스를 사용하는 경우)
    #for image in images:
    #    image['file_name'] = '{}/{}'.format(image['file_name'][0], image['file_name'])
        
    cv = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=411)
    
    for k, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
        train_ids = groups[train_idx]
        val_ids = groups[val_idx]
        
        image_ids_val, image_ids_train = set(val_ids), set(train_ids)

        train_images = [x for x in images if x.get('id') in image_ids_train]
        val_images = [x for x in images if x.get('id') in image_ids_val]
        train_annotations = [x for x in annotations if x.get('image_id') in image_ids_train]
        val_annotations = [x for x in annotations if x.get('image_id') in image_ids_val]

        train_data = {
            'images': train_images,
            'annotations': train_annotations,
            'categories': categories,
        }

        val_data = {
            'images': val_images,
            'annotations': val_annotations,
            'categories': categories,
        }

        output_seed_dir = os.path.join(output_dir, f'seed{random_seed}')
        os.makedirs(output_seed_dir, exist_ok=True)
        output_train_json = os.path.join(output_seed_dir, f'train_{k}.json')
        output_val_json = os.path.join(output_seed_dir, f'val_{k}.json')

        print(f'write {output_train_json}')
        with open(output_train_json, 'w') as train_writer:
            json.dump(train_data, train_writer)

        print(f'write {output_val_json}')
        with open(output_val_json, 'w') as val_writer:
            json.dump(val_data, val_writer)
        

In [3]:
split_dataset("../../dataset/train.json", "../../dataset", 2022 )

write ../../dataset/seed2022/train_0.json
write ../../dataset/seed2022/val_0.json
write ../../dataset/seed2022/train_1.json
write ../../dataset/seed2022/val_1.json
write ../../dataset/seed2022/train_2.json
write ../../dataset/seed2022/val_2.json
write ../../dataset/seed2022/train_3.json
write ../../dataset/seed2022/val_3.json
write ../../dataset/seed2022/train_4.json
write ../../dataset/seed2022/val_4.json


In [4]:
train_data = json.load(open('../../dataset/seed2022/train_0.json'))

print('training data')
print(f'images: {len(train_data["images"])}')
print(f'annotations: {len(train_data["annotations"])}')
print(f'categories: {len(train_data["categories"])}')

training data
images: 3914
annotations: 18633
categories: 10


In [5]:
train_data = json.load(open('../../dataset/seed2022/train_1.json'))

print('training data')
print(f'images: {len(train_data["images"])}')
print(f'annotations: {len(train_data["annotations"])}')
print(f'categories: {len(train_data["categories"])}')

training data
images: 3906
annotations: 18075
categories: 10


In [6]:
val_data = json.load(open('../../dataset/seed2022/val_0.json'))

print('validation data')
print(f'images: {len(val_data["images"])}')
print(f'annotations: {len(val_data["annotations"])}')
print(f'categories: {len(val_data["categories"])}')

validation data
images: 969
annotations: 4511
categories: 10


In [7]:
val_data = json.load(open('../../dataset/seed2022/val_1.json'))

print('validation data')
print(f'images: {len(val_data["images"])}')
print(f'annotations: {len(val_data["annotations"])}')
print(f'categories: {len(val_data["categories"])}')

validation data
images: 977
annotations: 5069
categories: 10
