In [10]:
import json
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

annotation_path = "/home/work/MyProject/level2-objectdetection-cv-02/dataset/train.json"

with open(annotation_path) as f: 
    data = json.load(f)

var = [(ann['image_id'], ann['category_id']) for ann in data['annotations']]
X = np.ones((len(data['annotations']),1))
y = np.array([v[1] for v in var])
groups = np.array([v[0] for v in var])

cv = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=411)

train_json = {
    "images": [],
    "annotations": [],
    "categories": data["categories"]
}

validation_json = {
    "images": [],
    "annotations": [],
    "categories": data["categories"]
}

image_ids = set()
for fold, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
    train_groups = set(groups[train_idx])
    val_groups = set(groups[val_idx])
    
    for image in data["images"]:
        if image["id"] in train_groups and image["id"] not in image_ids:
            train_json["images"].append(image)
            image_ids.add(image["id"])
        elif image["id"] in val_groups and image["id"] not in image_ids:
            validation_json["images"].append(image)
            image_ids.add(image["id"])
    
    for ann in data["annotations"]:
        if ann["image_id"] in train_groups:
            train_json["annotations"].append(ann)
        elif ann["image_id"] in val_groups:
            validation_json["annotations"].append(ann)

# JSON 형식으로 저장
with open('train_split.json', 'w') as f:
    json.dump(train_json, f, indent=2)

with open('val_split.json', 'w') as f:
    json.dump(validation_json, f, indent=2)

print(f"Train set: {len(train_json['images'])} images, {len(train_json['annotations'])} annotations")
print(f"Validation set: {len(validation_json['images'])} images, {len(validation_json['annotations'])} annotations")

Train set: 3914 images, 92576 annotations
Validation set: 969 images, 23144 annotations
