In [1]:
import random
import numpy as np
import pandas as pd
from collections import Counter, defaultdict

In [2]:
def stratified_group_k_fold(X, y, groups, k, seed=None):
    """Split samples into k folds that keep groups intact and label ratios balanced.

    Groups are assigned to folds greedily. They are visited in descending order
    of the standard deviation of their per-label count vector, and each group is
    placed in the fold that minimises the mean (over labels) of the standard
    deviation (over folds) of the per-fold label share.

    Args:
        X: Unused by this function; kept so existing callers keep working.
        y: Integer class label per sample. Labels index a count vector of
            length ``max(y) + 1``.
        groups: Group id per sample, paired element-wise with ``y``. All
            samples of one group end up on the same side of every split.
        k: Number of folds.
        seed: Seed for the shuffle that decides the order of groups whose
            count vectors have equal standard deviation.

    Yields:
        ``(train_indices, test_indices)`` for each fold: two lists of
        positional indices into ``y`` / ``groups``.
    """
    labels_num = np.max(y) + 1

    # y_counts_per_group: {group id: per-label sample counts (length labels_num)}
    # y_distr: total number of samples per label
    y_counts_per_group = defaultdict(lambda: np.zeros(labels_num))
    y_distr = Counter()
    for label, g in zip(y, groups):
        y_counts_per_group[g][label] += 1
        y_distr[label] += 1

    # Label ids inside range(labels_num) that never occur have a total of 0.
    # Dividing by that total yields nan, which poisons every fold score and
    # makes the "<" comparison below always False, so every group would land
    # in fold 0. Only labels that actually occur take part in the evaluation.
    present_labels = [label for label in range(labels_num) if y_distr[label] > 0]

    y_counts_per_fold = defaultdict(lambda: np.zeros(labels_num))
    groups_per_fold = defaultdict(set)

    def eval_y_counts_per_fold(y_counts, fold):
        """Score the split that results from tentatively adding y_counts to fold."""
        y_counts_per_fold[fold] += y_counts
        std_per_label = []
        for label in present_labels:
            label_std = np.std([y_counts_per_fold[i][label] / y_distr[label] for i in range(k)])
            std_per_label.append(label_std)
        # Undo the tentative assignment; the caller commits the best fold itself.
        y_counts_per_fold[fold] -= y_counts
        return np.mean(std_per_label)

    groups_and_y_counts = list(y_counts_per_group.items())
    random.Random(seed).shuffle(groups_and_y_counts)

    # Visit groups in descending order of the std of their label-count vector.
    # sorted() is stable, so the seeded shuffle decides the order among ties.
    for g, y_counts in sorted(groups_and_y_counts, key=lambda x: -np.std(x[1])):
        best_fold = None
        min_eval = None
        for fold in range(k):
            fold_eval = eval_y_counts_per_fold(y_counts, fold)
            if min_eval is None or fold_eval < min_eval:
                min_eval = fold_eval
                best_fold = fold
        y_counts_per_fold[best_fold] += y_counts
        groups_per_fold[best_fold].add(g)

    all_groups = set(groups)
    for fold in range(k):
        train_groups = all_groups - groups_per_fold[fold]
        test_groups = groups_per_fold[fold]

        train_indices = [idx for idx, g in enumerate(groups) if g in train_groups]
        test_indices = [idx for idx, g in enumerate(groups) if g in test_groups]

        # yield makes this function a generator: one split is produced per iteration.
        yield train_indices, test_indices

In [3]:
# Load the training table; the columns read below are class_id and image_id.
train_x = pd.read_csv('../input/data/train_all.csv')
# Class ids are shifted down by one here (they must be shifted back later).
train_y = (train_x.class_id - 1).values
# Image id of each row, used as the grouping key for the split.
groups = np.array(train_x.image_id.values)

def get_distribution(y_vals, num_labels=None):
    """Return the share of each label in ``y_vals`` as percentage strings.

    Args:
        y_vals: Sequence of integer labels.
        num_labels: Number of labels to report (labels ``0 .. num_labels - 1``).
            Defaults to ``max(y_vals) + 1``. Pass it explicitly when comparing
            subsets, so that a subset missing the highest label still yields a
            row of the same length.

    Returns:
        List of strings formatted with ``.2%``, one per label. Empty input
        gives an empty list, or ``num_labels`` entries of ``'0.00%'`` when
        ``num_labels`` is given.
    """
    y_distr = Counter(y_vals)  # number of samples per label
    y_vals_sum = sum(y_distr.values())  # total number of samples
    if num_labels is None:
        # np.max raises on empty input, so only consult it when there is data.
        num_labels = int(np.max(y_vals)) + 1 if y_vals_sum else 0
    if y_vals_sum == 0:
        return [f'{0:.2%}'] * num_labels
    return [f'{y_distr[i] / y_vals_sum:.2%}' for i in range(num_labels)]

In [4]:
# The first row of the summary table is the label distribution of the full training set.
distrs = [get_distribution(train_y)]
index = ['training set']

# Per-fold row indices returned by the splitter.
developement_set, validation_set = [], []
# Per-fold arrays with the image_id of every row on the development / validation side
# (one array per fold).
dev_group, val_group = [], []

fold_splits = stratified_group_k_fold(train_x, train_y, groups, k=5, seed=42)
for fold_ind, (dev_ind, val_ind) in enumerate(fold_splits):
    dev_y = train_y[dev_ind]
    val_y = train_y[val_ind]
    dev_groups = groups[dev_ind]
    val_groups = groups[val_ind]

    # No image id may appear on both sides of a split.
    assert set(dev_groups).isdisjoint(val_groups)

    distrs.extend([get_distribution(dev_y), get_distribution(val_y)])
    index.extend([
        f'development set - fold {fold_ind}',
        f'validation set - fold {fold_ind}',
    ])

    developement_set.append(dev_ind)
    validation_set.append(val_ind)
    dev_group.append(dev_groups)
    val_group.append(val_groups)

display('Distribution per class:')
# Column headers use the label shifted up by one.
label_columns = [f'Label {class_idx + 1}' for class_idx in range(np.max(train_y) + 1)]
pd.DataFrame(distrs, index=index, columns=label_columns)

'Distribution per class:'

Unnamed: 0,Label 1,Label 2,Label 3,Label 4,Label 5,Label 6,Label 7,Label 8,Label 9,Label 10
training set,10.60%,35.48%,2.51%,2.14%,2.32%,11.78%,5.12%,29.13%,0.24%,0.67%
development set - fold 0,10.60%,35.49%,2.51%,2.14%,2.33%,11.78%,5.12%,29.13%,0.24%,0.67%
validation set - fold 0,10.61%,35.47%,2.51%,2.15%,2.32%,11.77%,5.12%,29.11%,0.25%,0.69%
development set - fold 1,10.60%,35.49%,2.51%,2.14%,2.33%,11.78%,5.12%,29.13%,0.24%,0.67%
validation set - fold 1,10.61%,35.46%,2.51%,2.15%,2.32%,11.77%,5.12%,29.12%,0.25%,0.69%
development set - fold 2,10.60%,35.48%,2.51%,2.14%,2.32%,11.78%,5.12%,29.13%,0.24%,0.68%
validation set - fold 2,10.59%,35.48%,2.52%,2.13%,2.32%,11.78%,5.13%,29.13%,0.25%,0.67%
development set - fold 3,10.60%,35.48%,2.51%,2.14%,2.32%,11.77%,5.12%,29.13%,0.24%,0.68%
validation set - fold 3,10.60%,35.50%,2.52%,2.14%,2.33%,11.78%,5.11%,29.13%,0.23%,0.67%
development set - fold 4,10.60%,35.48%,2.51%,2.14%,2.32%,11.77%,5.12%,29.12%,0.24%,0.68%


In [16]:
# For each fold, print the number of distinct image ids on the development
# side, on the validation side, and their sum.
for i in range(5):
    unique_counts = [len(set(ids)) for ids in (dev_group[i], val_group[i])]
    print(f"k-{i}", *unique_counts, sum(unique_counts))

k-0 2607 664 3271
k-1 2613 658 3271
k-2 2622 649 3271
k-3 2621 650 3271
k-4 2621 650 3271


In [21]:
# For each fold, print the number of row indices on the development side,
# on the validation side, and their sum.
for i in range(5):
    split_sizes = (len(developement_set[i]), len(validation_set[i]))
    print(f"k-{i}", *split_sizes, sum(split_sizes))

k-0 20988 5252 26240
k-1 20989 5251 26240
k-2 20992 5248 26240
k-3 20995 5245 26240
k-4 20996 5244 26240


In [22]:
import json

# Read the full training JSON once and keep references to the original
# 'annotations' and 'images' lists so per-fold subsets can be built from them.
train_json = "/opt/ml/segmentation/input/data/train_all.json"
with open(train_json, "r", encoding="utf8") as infile:
    json_data = json.load(infile)
base_annotations, base_images = json_data['annotations'], json_data['images']

In [None]:
# Preview: select the fold-0 development annotations and print how many there are.
# Note that this overwrites json_data['annotations'] in place.
annotations = list(map(base_annotations.__getitem__, developement_set[0]))
json_data['annotations'] = annotations
print(len(json_data['annotations']))

In [None]:
# Write one training JSON per fold: train_0.json ... train_4.json.
for i in range(5):
    annotations = [base_annotations[row] for row in developement_set[i]]
    # NOTE(review): image ids are used as positions into base_images, which
    # assumes an image's list position equals its id — confirm against the data.
    unique_image_ids = sorted(set(dev_group[i]))
    images = [base_images[image_id] for image_id in unique_image_ids]
    # json_data is reused as the output container; only these two keys change.
    json_data.update(annotations=annotations, images=images)
    with open(f'train_{i}.json', 'w', encoding='utf-8') as fold_file:
        json.dump(json_data, fold_file, indent="\t")

In [None]:
# Write one validation JSON per fold: valid_0.json ... valid_4.json.
for i in range(5):
    annotations = [base_annotations[row] for row in validation_set[i]]
    # NOTE(review): image ids are used as positions into base_images, which
    # assumes an image's list position equals its id — confirm against the data.
    unique_image_ids = sorted(set(val_group[i]))
    images = [base_images[image_id] for image_id in unique_image_ids]
    # json_data is reused as the output container; only these two keys change.
    json_data.update(annotations=annotations, images=images)
    with open(f'valid_{i}.json', 'w', encoding='utf-8') as fold_file:
        json.dump(json_data, fold_file, indent="\t")