In [3]:
#!pip install -U scikit-learn

Collecting scikit-learn
  Downloading scikit_learn-1.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (9.7 MB)
[K     |████████████████████████████████| 9.7 MB 19.3 MB/s eta 0:00:01
Collecting joblib>=1.1.1
  Downloading joblib-1.2.0-py3-none-any.whl (297 kB)
[K     |████████████████████████████████| 297 kB 54.8 MB/s eta 0:00:01
Collecting threadpoolctl>=2.0.0
  Downloading threadpoolctl-3.1.0-py3-none-any.whl (14 kB)
Installing collected packages: joblib, threadpoolctl, scikit-learn
Successfully installed joblib-1.2.0 scikit-learn-1.2.0 threadpoolctl-3.1.0


In [83]:
import os
import json
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold
from collections import defaultdict

In [84]:
# annotation = {path to the train.json dataset file}: the annotation file whose
# "images" / "annotations" / "categories" lists are split into folds below.
annotation = '/opt/ml/input/data/train_all.json'

# JSON text is UTF-8 by specification; pin the encoding so loading does not
# depend on the platform's default locale encoding.
with open(annotation, encoding='utf-8') as f:
    data = json.load(f)

In [85]:
# Output location: directory plus file-name stem for the generated fold files.
# "_train{k}.json" / "_val{k}.json" suffixes are appended when the folds are written.
output_filename = "/".join(("/opt/ml/input/data", "K-fold"))

In [86]:
# Sanity check: list the distinct category ids that occur in the annotations.
# (numpy is already imported as `np` in the imports cell; no re-import here.)
category_ids = [ann['category_id'] for ann in data['annotations']]

np.unique(category_ids)

array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])

In [87]:
# One sample per annotation. `y` (category_id) is what the folds are stratified
# on; `groups` (image_id) keeps every annotation of an image in the same fold.
# X is only a dummy all-ones column with one row per annotation.
var = [(ann['image_id'], ann['category_id']) for ann in data['annotations']]
X = np.ones((len(var), 1))
y = np.array([category_id for _, category_id in var])   # category_id
groups = np.array([image_id for image_id, _ in var])    # image_id

cv = StratifiedGroupKFold(n_splits=5, random_state=41, shuffle=True)

# Print a compact preview of each fold's image ids and category ids.
for train_idx, val_idx in cv.split(X, y, groups):
    print("TRAIN:", groups[train_idx])  # image_id
    print(" ", y[train_idx])            # category_id
    print(" TEST:", groups[val_idx])
    print(" ", y[val_idx])

TRAIN: [   0    0    0 ... 3270 3270 3271]
  [8 8 6 ... 2 2 4]
 TEST: [   5    5    9 ... 3263 3263 3265]
  [2 6 3 ... 1 7 4]
TRAIN: [   0    0    0 ... 3269 3269 3271]
  [8 8 6 ... 1 1 4]
 TEST: [  15   15   15 ... 3270 3270 3270]
  [2 8 8 ... 2 2 2]
TRAIN: [   1    1    1 ... 3270 3270 3271]
  [8 8 8 ... 2 2 4]
 TEST: [   0    0    0 ... 3269 3269 3269]
  [8 8 6 ... 7 1 1]
TRAIN: [   0    0    0 ... 3270 3270 3271]
  [8 8 6 ... 2 2 4]
 TEST: [   4    4    7 ... 3261 3261 3261]
  [6 6 1 ... 1 2 2]
TRAIN: [   0    0    0 ... 3270 3270 3270]
  [8 8 6 ... 2 2 2]
 TEST: [   1    1    1 ... 3264 3264 3271]
  [8 8 8 ... 5 6 4]


In [88]:
from collections import Counter
import pandas as pd

In [89]:
def get_distribution(y, labels=None):
    """Return the relative frequency of each label in ``y`` as percent strings.

    Args:
        y: iterable of integer category ids (list or 1-D numpy array).
        labels: ids to report, in output order. Defaults to ``1..max(y)``
            (the original behaviour). Pass the full id range when ``y`` is a
            subset such as one fold, so every row has the same length even if
            the subset happens to lack the largest id.

    Returns:
        One string per label formatted like ``'35.48%'``. Labels absent from
        ``y`` get ``'0.00%'``. An empty ``y`` yields ``[]`` with the default
        labels, or ``'0.00%'`` for every explicitly requested label.
    """
    counts = Counter(y)
    total = sum(counts.values())
    if labels is None:
        # max over the distinct ids equals max(y); default=0 yields an empty
        # range for empty input instead of raising.
        labels = range(1, max(counts, default=0) + 1)
    # Guard the division so an empty y cannot raise ZeroDivisionError.
    return [f'{(counts[label] / total if total else 0.0):.2%}' for label in labels]

In [90]:
# Seed the comparison table: `distrs` holds the rows and `index` their labels.
# The first row is the class distribution over all annotations, before splitting.
distrs, index = [get_distribution(y)], ['training set']

In [91]:
for fold, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
    # Group integrity: no image may contribute annotations to both sides.
    assert set(groups[train_idx]).isdisjoint(groups[val_idx])

    # Append the train row, then the val row, with matching row labels.
    distrs.extend([get_distribution(y[train_idx]), get_distribution(y[val_idx])])
    index.extend([f'train - fold{fold}', f'val - fold{fold}'])

In [92]:
# Table columns follow the 1..max(y) id order that get_distribution reports.
# Names are looked up by category id rather than list position. The headers
# stay correct even if data['categories'] is not sorted by id, and a missing
# id raises KeyError instead of silently mislabelling a column.
categories = [d['name'] for d in data['categories']]
category_name_by_id = {d['id']: d['name'] for d in data['categories']}
pd.DataFrame(distrs, index=index,
             columns=[category_name_by_id[i] for i in range(1, np.max(y) + 1)])

Unnamed: 0,General trash,Paper,Paper pack,Metal,Glass,Plastic,Styrofoam,Plastic bag,Battery,Clothing
training set,10.60%,35.48%,2.51%,2.14%,2.32%,11.78%,5.12%,29.13%,0.24%,0.67%
train - fold0,10.68%,35.82%,2.41%,2.07%,2.44%,11.14%,5.21%,29.26%,0.23%,0.74%
val - fold0,10.29%,34.15%,2.93%,2.44%,1.86%,14.31%,4.74%,28.60%,0.29%,0.40%
train - fold1,10.73%,35.50%,2.55%,2.06%,2.23%,11.74%,4.95%,29.34%,0.26%,0.64%
val - fold1,10.13%,35.41%,2.38%,2.44%,2.68%,11.91%,5.72%,28.33%,0.18%,0.82%
train - fold2,10.66%,35.29%,2.58%,2.08%,2.34%,12.22%,5.01%,28.95%,0.22%,0.64%
val - fold2,10.37%,36.27%,2.22%,2.40%,2.24%,9.96%,5.57%,29.84%,0.31%,0.82%
train - fold3,10.53%,35.30%,2.52%,2.23%,2.28%,11.94%,5.19%,29.08%,0.24%,0.70%
val - fold3,10.90%,36.22%,2.48%,1.80%,2.50%,11.13%,4.84%,29.31%,0.25%,0.58%
train - fold4,10.42%,35.50%,2.50%,2.27%,2.33%,11.84%,5.23%,29.00%,0.25%,0.65%


In [98]:
def _reindex_images(image_ids, images_by_id):
    """Copy the image records for one split and renumber their ids from 0.

    Args:
        image_ids: image ids of the split's annotations (repeats allowed); the
            first occurrence of each id fixes its position in the output.
        images_by_id: mapping original image id -> image dict from data["images"].

    Returns:
        (images, id_map): shallow copies of the image dicts with "id" set to
        0..n-1 in first-seen order, and the mapping original id -> new id.
    """
    images, id_map = [], {}
    for old_id in dict.fromkeys(image_ids):  # de-duplicate, keep first-seen order
        new_id = len(id_map)
        image = images_by_id[old_id].copy()
        image["id"] = new_id
        id_map[old_id] = new_id
        images.append(image)
    return images, id_map


# Write one train/val file pair per fold.
# - Images are looked up by their "id" field, not by list position.
# - Fold membership comes from `groups` (built from the annotations), so only
#   images with at least one annotation end up in a split.
# - Loop variables avoid the names `index` (the distribution table's row labels)
#   and `annotation` (the input path), so earlier cells can still be re-run.
images_by_id = {image["id"]: image for image in data["images"]}

for idx, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
    train_images, train_id_map = _reindex_images(groups[train_idx].tolist(), images_by_id)
    val_images, val_id_map = _reindex_images(groups[val_idx].tolist(), images_by_id)

    # Route each annotation to the split that owns its image and point it at the
    # image's new id. Dict membership replaces a scan of groups[val_idx] per
    # annotation (which also re-allocated that array on every iteration).
    train_annotations, val_annotations = [], []
    for ann in data["annotations"]:
        remapped = ann.copy()
        if ann["image_id"] in val_id_map:
            remapped["image_id"] = val_id_map[ann["image_id"]]
            val_annotations.append(remapped)
        else:
            remapped["image_id"] = train_id_map[ann["image_id"]]
            train_annotations.append(remapped)

    train_split = {
            "info": data.get("info", {}),
            "licenses": data.get("licenses", []),
            "images": train_images,
            "categories": data["categories"],
            "annotations": train_annotations,
        }

    val_split = {
            "info": data.get("info", {}),
            "licenses": data.get("licenses", []),
            "images": val_images,
            "categories": data["categories"],
            "annotations": val_annotations,
        }

    # Fold numbers in the file names are 1-based: <prefix>_train1.json, <prefix>_val1.json, ...
    output_files = []
    for split_type, split in zip(["train", "val"], [train_split, val_split]):
        output_files.append(output_filename + f"_{split_type}{idx+1}.json")
        with open(output_files[-1], "w") as f:
            json.dump(split, f, indent=2)

print("Split Done !")

Split Done !
