In [2]:
import os
import json
import random
import numpy as np
import torch
from sklearn.model_selection import StratifiedGroupKFold

def seed_everything(seed):
    """Seed every RNG used in this notebook and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's legacy global RNG, and PyTorch's
    CPU and CUDA generators. Also sets the ``PYTHONHASHSEED`` environment
    variable and configures cuDNN for reproducible kernel selection.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # NOTE: PYTHONHASHSEED is read at interpreter startup, so this only affects
    # subprocesses launched afterwards, not the already-running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick kernels per run, which
    # defeats `deterministic = True`; it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False


In [32]:
# Split configuration: RNG seed plus input/output locations.
seed = 21
input_json = '/opt/ml/input/data/train.json'
output_dir = '/opt/ml/input/data'

# Reseed from the `seed` variable (not a literal) so changing it above keeps
# the RNGs consistent with the CV split and the seed-named output directory.
seed_everything(seed)

In [33]:
# Load the COCO-style annotation file and report its basic structure.
with open(input_json) as coco_file:
    data = json.load(coco_file)

images = data['images']
annotations = data['annotations']
categories = data['categories']

# Record counts per top-level section.
print(f'\nN images\t {len(images)} \nN annotations\t {len(annotations)} \nN categories\t {len(categories)} \n')
# Field names of the file itself and of a sample image / annotation record.
print(f'data keys\t {data.keys()}')
print(f'image keys\t {images[0].keys()}')
print(f'annotation keys\t {annotations[0].keys()}')


N images	 2617 
N annotations	 20988 
N categories	 10 

data keys	 dict_keys(['info', 'licenses', 'images', 'categories', 'annotations'])
image keys	 dict_keys(['license', 'url', 'file_name', 'height', 'width', 'date_captured', 'id'])
annotation keys	 dict_keys(['id', 'image_id', 'category_id', 'segmentation', 'area', 'bbox', 'iscrowd'])


In [34]:
# Inspect one annotation record: print selected fields, one per line.
sample_annotation = annotations[10]
for field in ('id', 'image_id', 'category_id', 'iscrowd'):
    print(sample_annotation[field])

10
0
8
0


In [35]:
# Split the annotations into 3 folds. StratifiedGroupKFold stratifies on y
# (one category_id per annotation) while keeping all annotations that share an
# image_id (group) in the same fold, so no image is in both train and val.
annotation_infos = [(ann['image_id'], ann['category_id']) for ann in annotations]
X = np.ones((len(annotations), 1))  # placeholder features (20988, 1); only y and groups drive the split
y = np.array([category_id for _, category_id in annotation_infos])  # category_ids (20988)
groups = np.array([image_id for image_id, _ in annotation_infos])   # image_ids (20988)

cv = StratifiedGroupKFold(n_splits=3, shuffle=True, random_state=seed)

# Loop-invariant: every fold for this seed is written into the same directory.
output_seed_dir = os.path.join(output_dir, f'seed{seed}')
os.makedirs(output_seed_dir, exist_ok=True)

# Carry the source file's top-level metadata into each split so the outputs
# remain complete COCO files (in case downstream tooling reads these keys).
coco_meta = {key: data[key] for key in ('info', 'licenses') if key in data}

# NOTE: `groups` is built from annotations only, so images with no
# annotations end up in neither split.
for idx, (train_ids, val_ids) in enumerate(cv.split(X, y, groups)):
    # Use sets: `in` against a NumPy array scans the whole array on every
    # lookup, which made the filters below quadratic in the dataset size.
    train_image_ids = set(groups[train_ids].tolist())
    val_image_ids = set(groups[val_ids].tolist())
    assert train_image_ids.isdisjoint(val_image_ids), f'fold {idx}: image shared by train and val'

    train_images = [x for x in images if x.get('id') in train_image_ids]
    val_images = [x for x in images if x.get('id') in val_image_ids]
    train_annotations = [x for x in annotations if x.get('image_id') in train_image_ids]
    val_annotations = [x for x in annotations if x.get('image_id') in val_image_ids]

    train_data = {
        **coco_meta,
        'images': train_images,
        'annotations': train_annotations,
        'categories': categories,
    }
    val_data = {
        **coco_meta,
        'images': val_images,
        'annotations': val_annotations,
        'categories': categories,
    }

    output_train_json = os.path.join(output_seed_dir, f'train_{idx}.json')
    output_val_json = os.path.join(output_seed_dir, f'val_{idx}.json')

    # Write train then val, mirroring the original output order.
    for output_json, split_data in ((output_train_json, train_data), (output_val_json, val_data)):
        with open(output_json, 'w') as writer:
            json.dump(split_data, writer)
        print(f'done. {output_json}')

done. /opt/ml/input/data/seed21/train_0.json
done. /opt/ml/input/data/seed21/val_0.json
done. /opt/ml/input/data/seed21/train_1.json
done. /opt/ml/input/data/seed21/val_1.json
done. /opt/ml/input/data/seed21/train_2.json
done. /opt/ml/input/data/seed21/val_2.json
