In [None]:
# Split a COCO annotation file into train and validation subsets.
import json
import random

def split_coco(annotation_path, train_out, val_out, val_ratio=0.2, seed=42):
    """Split a COCO annotation file into train and validation files, by image.

    Image ids are shuffled with a private random generator seeded by ``seed``;
    the first ``int(n_images * val_ratio)`` ids form the validation subset and
    the remainder form the training subset. Each annotation is written to the
    subset that contains its ``image_id``. The ``categories`` list and any
    other top-level keys of the input (e.g. ``info``, ``licenses``) are copied
    unchanged into both output files.

    Args:
        annotation_path: Path of the COCO JSON file to read (UTF-8).
        train_out: Path the training subset JSON is written to.
        val_out: Path the validation subset JSON is written to.
        val_ratio: Fraction of images assigned to validation; must lie in
            [0, 1]. The image count is truncated, not rounded.
        seed: Seed for the shuffle. The same seed and input give the same
            split. The module-level ``random`` state is left untouched.

    Raises:
        ValueError: If ``val_ratio`` is outside [0, 1].
        KeyError: If the input lacks ``images``, ``annotations`` or
            ``categories``.
    """
    # A negative ratio would make the slices below silently produce a
    # meaningless split, so reject out-of-range values up front.
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be between 0 and 1, got {val_ratio!r}")

    with open(annotation_path, 'r', encoding='utf-8') as f:
        coco = json.load(f)

    images = coco['images']
    anns = coco['annotations']
    cats = coco['categories']

    # Everything that is not split per image is carried over as-is so the
    # outputs keep the dataset metadata of the input file.
    split_keys = ('images', 'annotations', 'categories')
    extras = {key: value for key, value in coco.items() if key not in split_keys}

    # A dedicated generator yields the same sequence as seeding the global
    # one, without resetting RNG state that other notebook cells rely on.
    rng = random.Random(seed)
    img_ids = [img['id'] for img in images]
    rng.shuffle(img_ids)

    n_val = int(len(img_ids) * val_ratio)
    val_ids = set(img_ids[:n_val])
    train_ids = set(img_ids[n_val:])

    def build_subset(ids):
        # Filtering the original lists keeps images/annotations in file order.
        return {
            **extras,
            'images': [img for img in images if img['id'] in ids],
            'annotations': [ann for ann in anns if ann['image_id'] in ids],
            'categories': cats,
        }

    with open(train_out, 'w', encoding='utf-8') as f:
        json.dump(build_subset(train_ids), f, ensure_ascii=False)
    with open(val_out, 'w', encoding='utf-8') as f:
        json.dump(build_subset(val_ids), f, ensure_ascii=False)


In [None]:
if __name__ == "__main__":
    # Hold out 20% of the images listed in train.json as a validation set,
    # with a fixed seed so the same split is produced on every run.
    split_coco(
        "../../dataset/train.json",
        "../../dataset/train_split.json",
        "../../dataset/val_split.json",
        val_ratio=0.2,
        seed=42,
    )