In [1]:
import json
import os
import funcy
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

In [3]:
# Directory where the generated annotation files are written.
save_annotation_path = '/data/ephemeral/home/dataset'

# Source annotation file to split.
annotation = '/data/ephemeral/home/dataset/train.json'

# Read explicitly as UTF-8 so loading does not depend on the machine's locale
# and matches the encoding save_coco() uses when writing the folds back out.
with open(annotation, encoding='UTF-8') as f:
    data = json.load(f)

# Unpack the top-level sections once the file handle has been released.
# NOTE: the `licences` spelling is kept because later cells refer to it by
# this name.
info = data['info']
licences = data['licenses']
images = data['images']
categories = data['categories']
anns = data['annotations']

In [4]:
def save_coco(file, info, licenses, images, annotations, categories):
    """Write the given sections to ``file`` as a single JSON document.

    Args:
        file: Destination path; the file is created or overwritten.
        info: Value stored under the ``'info'`` key.
        licenses: Value stored under the ``'licenses'`` key.
        images: Value stored under the ``'images'`` key.
        annotations: Value stored under the ``'annotations'`` key.
        categories: Value stored under the ``'categories'`` key.

    The document is UTF-8 encoded, indented by two spaces, and its keys are
    emitted in the order listed above (no key sorting).
    """
    payload = {
        'info': info,
        'licenses': licenses,
        'images': images,
        'annotations': annotations,
        'categories': categories,
    }
    with open(file, 'wt', encoding='UTF-8') as fp:
        json.dump(payload, fp, indent=2)

def filter_annotations(annotations, images):
    """Return the annotations whose ``image_id`` matches the ``id`` of some image.

    Args:
        annotations: Iterable of mappings, each with an ``'image_id'`` entry.
        images: Iterable of mappings, each with an ``'id'`` entry.

    Returns:
        A new list with the matching annotations in their original order.
        Ids are compared after conversion with ``int()``.
    """
    # A set gives constant-time membership tests; scanning a list of ids for
    # every annotation would cost len(annotations) * len(images) comparisons.
    image_ids = {int(image['id']) for image in images}
    return [ann for ann in annotations if int(ann['image_id']) in image_ids]

def filter_images(images, annotations):
    """Return the images that are referenced by at least one annotation.

    Args:
        images: Iterable of mappings, each with an ``'id'`` entry.
        annotations: Iterable of mappings, each with an ``'image_id'`` entry.

    Returns:
        A new list with the referenced images in their original order.
        Ids are compared after conversion with ``int()``.
    """
    # Collapse the (typically repeated) image ids into a set so each image is
    # checked in constant time instead of scanning one entry per annotation.
    referenced_ids = {int(ann['image_id']) for ann in annotations}
    return [image for image in images if int(image['id']) in referenced_ids]

In [6]:
# One (image_id, category_id) pair per annotation: the folds are stratified on
# the category and grouped by image, so all annotations of an image end up on
# the same side of a split.
var = [(ann['image_id'], ann['category_id']) for ann in anns]
X = np.ones((len(anns), 1))  # dummy features; only y and groups drive the split
y = np.array([category_id for _, category_id in var])   # category_id
groups = np.array([image_id for image_id, _ in var])  # group (image_id)

cv = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=137)

for fold_idx, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
    train_file_name = f'train_kfold_{fold_idx}.json'
    val_file_name = f'val_kfold_{fold_idx}.json'

    # Select with comprehensions instead of a `for id in ...` loop: a loop
    # variable named `id` at notebook scope would rebind the builtin id() for
    # every later cell. Plain lists are kept because the dicts are only
    # filtered and serialised to JSON afterwards.
    train_anns = [anns[i] for i in train_idx]
    val_anns = [anns[i] for i in val_idx]

    # Each file keeps only the images referenced by its own annotations.
    save_coco(
        os.path.join(save_annotation_path, train_file_name),
        info,
        licences,
        filter_images(images, train_anns),
        filter_annotations(train_anns, images),
        categories,
    )
    print(f'{fold_idx} train annotation saved as {train_file_name}')
    save_coco(
        os.path.join(save_annotation_path, val_file_name),
        info,
        licences,
        filter_images(images, val_anns),
        filter_annotations(val_anns, images),
        categories,
    )
    print(f'{fold_idx} val annotation saved as {val_file_name}')
    print('')

0 train annotation saved as train_kfold_0.json
0 val annotation saved as val_kfold_0.json

1 train annotation saved as train_kfold_1.json
1 val annotation saved as val_kfold_1.json

2 train annotation saved as train_kfold_2.json
2 val annotation saved as val_kfold_2.json

3 train annotation saved as train_kfold_3.json
3 val annotation saved as val_kfold_3.json

4 train annotation saved as train_kfold_4.json
4 val annotation saved as val_kfold_4.json

