In [1]:
# Cell 1: Import required libraries
import argparse
import os.path as osp
import json
import mmengine
import numpy as np
import os

from sklearn.model_selection import StratifiedGroupKFold

In [None]:
# Cell 2: Build the run configuration (argparse is not used in the notebook, so defaults are set here)
def parse_args(data_root='/data/ephemeral/home/level2-objectdetection-cv-12/dataset',
               out_dir='/data/ephemeral/home/dataset/k-fold-final/',
               fold=10):
    """Return the K-fold split configuration as an ``argparse.Namespace``.

    Stands in for ``ArgumentParser.parse_args()`` inside a notebook. The rest of
    the notebook can keep reading ``args.data_root`` / ``args.out_dir`` /
    ``args.fold``. Individual values can be overridden per call instead of
    editing this function.

    Args:
        data_root: Directory from which the dataset JSON (``clean-final.json``)
            is read.
        out_dir: Directory where the per-fold train/val JSON files are written.
        fold: Number of folds, passed as ``n_splits`` to ``StratifiedGroupKFold``.

    Returns:
        argparse.Namespace with attributes ``data_root``, ``out_dir``, ``fold``.
    """
    return argparse.Namespace(data_root=data_root, out_dir=out_dir, fold=fold)

In [None]:
# Cell 3: Function to extract labels and groups from the dataset
def extract_labels_and_groups(data, empty_label=0):
    """Build per-image stratification labels and group ids for K-fold splitting.

    Each image's label is the ``category_id`` of its first annotation, in the
    order the annotations appear in ``data['annotations']``. Images with no
    annotations get ``empty_label``.

    Args:
        data: COCO-style dict with ``'images'`` (records with ``'id'``) and
            ``'annotations'`` (records with ``'image_id'`` and ``'category_id'``).
        empty_label: Label used for images without annotations (default 0, as
            before). NOTE(review): if category ids start at 0, unannotated
            images share a label with category 0; consider passing e.g. -1.

    Returns:
        ``(labels, groups)`` as numpy arrays, aligned with ``data['images']``.
        ``groups`` is each image's own id. Assuming ids are unique, every group
        contains exactly one image, so grouping does not constrain the split.
    """
    # One pass over annotations: remember the first category seen per image.
    # This replaces a full rescan of all annotations for every image
    # (O(images * annotations)).
    first_category = {}
    for ann in data['annotations']:
        first_category.setdefault(ann['image_id'], ann['category_id'])

    labels = [first_category.get(img['id'], empty_label) for img in data['images']]
    groups = [img['id'] for img in data['images']]
    return np.array(labels), np.array(groups)

In [None]:
# Cell 4: Function to save annotation files
def save_anns(name, images, annotations, original_data, out_dir):
    """Write a COCO-style annotation subset to ``out_dir/name``.

    The new file holds the given ``images`` and ``annotations``. It reuses the
    ``licenses``, ``categories`` and ``info`` entries of ``original_data``
    unchanged.

    Args:
        name: Output file name, joined onto ``out_dir``.
        images: Image records belonging to this subset.
        annotations: Annotation records belonging to this subset.
        original_data: Full dataset dict that the shared metadata is copied from.
        out_dir: Destination directory; created if it does not already exist.
    """
    payload = dict(
        images=images,
        annotations=annotations,
        licenses=original_data['licenses'],
        categories=original_data['categories'],
        info=original_data['info'],
    )
    # Make sure the destination exists before writing the file.
    mmengine.mkdir_or_exist(out_dir)
    target_path = os.path.join(out_dir, name)
    mmengine.dump(payload, target_path)

In [None]:
# Cell 5: Function to perform Stratified Group K-Fold split
def stratified_group_kfold_split(data, out_dir, fold, random_state=2024):
    """Split a COCO-style dataset into ``fold`` train/val folds and save them.

    Images are split with ``StratifiedGroupKFold``, using the labels and groups
    from ``extract_labels_and_groups``. Each annotation follows its image
    (matched by ``image_id``) into the same fold. For every fold ``f``
    (numbered from 1), ``train_fold_{f}.json`` and ``val_fold_{f}.json`` are
    written to ``out_dir`` via ``save_anns``.

    Args:
        data: COCO-style dict with ``'images'`` and ``'annotations'`` (plus the
            metadata keys that ``save_anns`` copies).
        out_dir: Directory for the fold files.
        fold: Number of folds (``n_splits``).
        random_state: Shuffle seed for reproducible folds (default 2024, as before).
    """
    labels, groups = extract_labels_and_groups(data)
    sgkf = StratifiedGroupKFold(n_splits=fold, shuffle=True, random_state=random_state)
    images = data['images']
    annotations = data['annotations']
    for f, (train_idx, val_idx) in enumerate(sgkf.split(groups, labels, groups), 1):
        train_images = [images[i] for i in train_idx]
        val_images = [images[i] for i in val_idx]
        # train_idx / val_idx are *positions* in data['images'], not image ids.
        # Match annotations against the ids of the selected images instead.
        # Sets also make each membership test O(1).
        train_ids = {img['id'] for img in train_images}
        val_ids = {img['id'] for img in val_images}
        train_annotations = [ann for ann in annotations if ann['image_id'] in train_ids]
        val_annotations = [ann for ann in annotations if ann['image_id'] in val_ids]
        save_anns(f'train_fold_{f}.json', train_images, train_annotations, data, out_dir)
        save_anns(f'val_fold_{f}.json', val_images, val_annotations, data, out_dir)

In [None]:
# Cell 6: Load the dataset and execute the K-fold split
# Instead of argparse, we set the arguments manually here
args = parse_args()
# JSON is UTF-8 by specification (RFC 8259), so don't rely on the platform's locale encoding.
with open(os.path.join(args.data_root, 'clean-final.json'), encoding='utf-8') as f:
    data = json.load(f)

stratified_group_kfold_split(data, args.out_dir, args.fold)