In [94]:
import json
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

In [109]:
# Path to the train.json dataset (annotation) file.
annotation = './dataset/train.json'

with open(annotation) as f:
    data = json.load(f)

# One (image_id, category_id) pair per annotation, in file order.
var = [(ann['image_id'], ann['category_id']) for ann in data['annotations']]

# groups: image each annotation belongs to; y: its class label.
groups = np.array([image_id for image_id, _ in var])
y = np.array([category_id for _, category_id in var])
# The splitter only needs a feature matrix with one row per sample,
# so a column of ones stands in for real features.
X = np.ones((len(var), 1))

cv = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=41)

for train_idx, val_idx in cv.split(X, y, groups):
    for label, fold_idx in (("TRAIN:", train_idx), (" TEST:", val_idx)):
        print(label, groups[fold_idx])  # image_id
        print(" ", y[fold_idx])         # category_id

TRAIN: [   0    1    1 ... 4882 4882 4882]
  [0 3 7 ... 0 1 1]
 TEST: [   3    3    5 ... 4870 4877 4877]
  [2 6 7 ... 1 7 7]
TRAIN: [   1    1    1 ... 4882 4882 4882]
  [3 7 4 ... 0 1 1]
 TEST: [   0    4    4 ... 4876 4876 4880]
  [0 1 1 ... 0 2 0]
TRAIN: [   0    1    1 ... 4879 4879 4880]
  [0 3 7 ... 7 7 0]
 TEST: [   7    7   16 ... 4882 4882 4882]
  [9 9 6 ... 0 1 1]
TRAIN: [   0    2    3 ... 4882 4882 4882]
  [0 3 2 ... 0 1 1]
 TEST: [   1    1    1 ... 4879 4879 4879]
  [3 7 4 ... 0 7 7]
TRAIN: [   0    1    1 ... 4882 4882 4882]
  [0 3 7 ... 0 1 1]
 TEST: [   2    6   15 ... 4867 4867 4873]
  [3 1 6 ... 8 8 0]


In [110]:
from collections import Counter
import pandas as pd

In [111]:
def get_distribution(y, num_classes=None):
    """Return the share of each class id in ``y`` as percentage strings.

    Args:
        y: Iterable of non-negative integer class ids (one per sample).
        num_classes: Number of entries to report (class ids
            ``0 .. num_classes - 1``). When ``None`` it is inferred as
            ``max(y) + 1``. Pass it explicitly when comparing subsets
            (e.g. folds): a subset that lacks the highest class id would
            otherwise produce a shorter list than the others.

    Returns:
        A list of strings formatted with ``:.2%``; position ``i`` is the
        fraction of samples whose class id is ``i``. Classes absent from
        ``y`` are reported as ``'0.00%'``. Empty ``y`` yields all-zero
        entries (or an empty list when ``num_classes`` is not given).
    """
    y_distr = Counter(y)
    y_vals_sum = sum(y_distr.values())

    if num_classes is None:
        # np.max raises on empty input, so only infer from non-empty data.
        num_classes = int(np.max(y)) + 1 if y_vals_sum else 0

    if y_vals_sum == 0:
        # Avoid dividing by zero when there are no samples at all.
        return ['0.00%'] * num_classes

    return [f'{y_distr[i]/y_vals_sum:.2%}' for i in range(num_classes)]

In [112]:
# Parallel lists for the summary table: distrs[k] is the row of class
# percentages labelled index[k]. Row 0 covers every annotation.
index = ['training set']
distrs = [get_distribution(y)]

In [113]:
# Rebuild the table rows from scratch so that re-running this cell does not
# append a second copy of every fold row to lists created in another cell.
distrs = [get_distribution(y)]
index = ['training set']

for fold_ind, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
    train_y, val_y = y[train_idx], y[val_idx]
    train_gr, val_gr = groups[train_idx], groups[val_idx]

    # Group-aware split: no image id may appear on both sides of a fold.
    assert set(train_gr).isdisjoint(val_gr), \
        f'fold{fold_ind}: image ids shared between train and val'

    # Two rows per fold, train first, with labels kept in step with the rows.
    distrs.append(get_distribution(train_y))
    distrs.append(get_distribution(val_y))
    index.append(f'train - fold{fold_ind}')
    index.append(f'val - fold{fold_ind}')

In [64]:
# Category names in the order they are listed in the annotation file.
categories = [category['name'] for category in data['categories']]

# One column per class id 0..max(y); the name is looked up by list position,
# which assumes category ids coincide with positions in data['categories'].
n_columns = np.max(y) + 1
column_names = [categories[class_id] for class_id in range(n_columns)]

# Bare expression on the last line so the notebook renders the table.
pd.DataFrame(distrs, index=index, columns=column_names)

Unnamed: 0,General trash,Paper,Paper pack,Metal,Glass,Plastic,Styrofoam,Plastic bag,Battery,Clothing
training set,17.14%,27.45%,3.88%,4.04%,4.24%,12.72%,5.46%,22.37%,0.69%,2.02%
train - fold0,17.36%,28.05%,3.88%,3.71%,4.11%,12.63%,5.27%,22.36%,0.69%,1.94%
val - fold0,16.29%,25.11%,3.88%,5.34%,4.74%,13.06%,6.16%,22.43%,0.67%,2.33%
train - fold1,16.96%,27.43%,3.92%,4.09%,4.21%,12.79%,5.49%,22.44%,0.63%,2.03%
val - fold1,17.87%,27.50%,3.70%,3.85%,4.38%,12.40%,5.34%,22.08%,0.91%,1.97%
train - fold2,17.12%,26.88%,3.76%,4.30%,4.64%,12.68%,5.58%,22.52%,0.61%,1.91%
val - fold2,17.19%,29.76%,4.33%,2.99%,2.64%,12.87%,4.96%,21.78%,1.01%,2.47%
train - fold3,17.10%,27.39%,3.79%,3.97%,4.32%,12.82%,5.45%,22.26%,0.76%,2.14%
val - fold3,17.26%,27.67%,4.21%,4.34%,3.96%,12.30%,5.49%,22.81%,0.38%,1.57%
train - fold4,17.14%,27.48%,4.03%,4.15%,3.94%,12.66%,5.49%,22.28%,0.74%,2.08%


In [82]:
# Path prefix for the generated split files ({dataset dir}/K-fold);
# a "_{train|val}{fold number}.json" suffix is appended when writing.
output_filename = './dataset/K-fold'

In [86]:
# Write every fold as a pair of COCO-style JSON files named
# "{output_filename}_train{k}.json" / "{output_filename}_val{k}.json" (k starts at 1).
for idx, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
    # `groups` holds one image id per *annotation*; collapse to sets so each
    # image is emitted once and membership tests are O(1) instead of a
    # linear scan of a NumPy array for every annotation.
    train_image_ids = set(groups[train_idx].tolist())
    val_image_ids = set(groups[val_idx].tolist())

    # Select images by their own "id" field (COCO format) rather than by list
    # position. Images with no annotations belong to neither set and are
    # left out of both splits, as before.
    train_images, val_images = [], []
    for image in data["images"]:
        if image["id"] in val_image_ids:
            val_images.append(image.copy())
        elif image["id"] in train_image_ids:
            train_images.append(image.copy())

    # Loop variable is `ann`, not `annotation`, so the module-level
    # `annotation` path string is not overwritten by this cell.
    train_annotations, val_annotations = [], []
    for ann in data["annotations"]:
        if ann["image_id"] in val_image_ids:
            val_annotations.append(ann.copy())
        else:
            train_annotations.append(ann.copy())

    # Both splits share the dataset-level metadata unchanged.
    train_split = {
            "images": train_images,
            "annotations": train_annotations,
            "info": data.get("info", {}),
            "licenses": data.get("licenses", []),
            "categories": data["categories"],
        }

    val_split = {
            "images": val_images,
            "annotations": val_annotations,
            "info": data.get("info", {}),
            "licenses": data.get("licenses", []),
            "categories": data["categories"],
        }

    # Paths written for the current fold (train first, then val).
    output_files = []
    for split_type, split in zip(["train", "val"], [train_split, val_split]):
        output_files.append(output_filename + f"_{split_type}{idx+1}.json")
        with open(output_files[-1], "w") as f:
            json.dump(split, f, indent=2)