In [97]:
from sklearn.model_selection import KFold, StratifiedKFold
import numpy as np
import pandas as pd
import numpy as np
import json
from tqdm import tqdm
from pycocotools.coco import COCO
import os
import json
import copy
from glob import glob
import matplotlib.pyplot as plt
import seaborn as sns
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold


In [3]:
# Open the full training annotation file through the COCO API.
coco = COCO('/opt/ml/detection/dataset/train.json')

# Pull every image record and keep the relative file name of each one,
# in the order the COCO index reports them.
img_id = coco.getImgIds()
img_info = coco.loadImgs(img_id)
fnames = []
for info in img_info:
    fnames.append(info['file_name'])

loading annotations into memory...
Done (t=0.18s)
creating index...
index created!
['train/0000.jpg' 'train/0001.jpg' 'train/0002.jpg' ... 'train/4880.jpg'
 'train/4881.jpg' 'train/4882.jpg']


In [24]:
# Image ids as an array, in the same order as `img_info`.
ids = np.asarray(list(map(lambda info: info['id'], img_info)))
# Number of images that will be split.
len_img = ids.shape[0]

def getClsPerImg(coco, ids, len_img, num_classes=10):
    """Count, for every image, how many annotations it has of each class.

    Args:
        coco: object exposing ``getAnnIds(imgIds=...)`` and ``loadAnns(ann_ids)``
            (a pycocotools ``COCO`` instance in this notebook).
        ids: sequence of image ids. Row ``k`` of the result describes ``ids[k]``.
        len_img: number of rows to allocate; expected to equal ``len(ids)``.
        num_classes: number of columns, i.e. the number of distinct
            ``category_id`` values. Defaults to 10, the value that used to be
            hard-coded.

    Returns:
        Float ``np.ndarray`` of shape ``(len_img, num_classes)`` where entry
        ``[k, c]`` is the number of annotations of class ``c`` on image ``ids[k]``.
    """
    cls_per_img = np.zeros((len_img, num_classes))
    # Index rows by position in `ids`, not by the id value itself: the label
    # matrix must stay aligned with `ids` even when ids are not 0..n-1.
    for row, img_id in enumerate(ids):
        ann_ids = coco.getAnnIds(imgIds=img_id)
        for ann in coco.loadAnns(ann_ids):
            cls_per_img[row][ann['category_id']] += 1
    return cls_per_img

cls_per_img = getClsPerImg(coco, ids, len_img)

# Fix the seed so the fold that ends up on disk can be regenerated. The
# iterative stratifier breaks ties with its random state, so an unseeded
# splitter can produce a different split on every run. A seed is only passed
# together with shuffle=True.
SPLIT_SEED = 42
mskf = MultilabelStratifiedKFold(n_splits=8, shuffle=True, random_state=SPLIT_SEED)

# NOTE: the loop only reports fold sizes. After it finishes, `train_index` and
# `test_index` stay bound to the LAST fold, which is what later cells use.
# Both arrays hold positions into `ids`, not image ids.
for train_index, test_index in mskf.split(ids, cls_per_img):
    print("TRAIN:", len(train_index), "TEST:", len(test_index))


    

TRAIN: 4308 TEST: 575
TRAIN: 4270 TEST: 613
TRAIN: 4229 TEST: 654
TRAIN: 4305 TEST: 578
TRAIN: 4269 TEST: 614
TRAIN: 4253 TEST: 630
TRAIN: 4275 TEST: 608
TRAIN: 4272 TEST: 611
4882
[[1. 0. 0. ... 0. 0. 0.]
 [2. 0. 0. ... 2. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [1. 0. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 3. 0. 0.]
 [1. 2. 0. ... 1. 0. 0.]]


In [65]:
def count_annotations_per_class(coco, image_ids, num_classes=10):
    """Return a list where entry ``c`` is the number of annotations of class ``c``
    found on the given images.

    Args:
        coco: object exposing ``getAnnIds(imgIds=...)`` and ``loadAnns(ann_ids)``.
        image_ids: iterable of image ids to inspect.
        num_classes: length of the returned list (number of category ids).
    """
    counts = [0] * num_classes
    for image_id in image_ids:
        for ann in coco.loadAnns(coco.getAnnIds(imgIds=image_id)):
            counts[ann['category_id']] += 1
    return counts


# `train_index` / `test_index` are positions into `ids`; translate them to the
# real image ids before querying COCO so the check stays valid even when image
# ids are not simply 0..n-1.
check_num = count_annotations_per_class(coco, ids[train_index].tolist())
print(check_num)

check_num_test = count_annotations_per_class(coco, ids[test_index].tolist())
print(check_num_test)

[3457, 5588, 785, 820, 812, 2620, 1121, 4523, 127, 423]
[509, 764, 112, 116, 170, 323, 142, 655, 32, 45]


In [None]:
# train_index, test_index

In [75]:
# Read the raw annotation file as a plain dict so it can be re-serialised
# into separate train / validation files.
annotation_path = os.path.join('/opt/ml/detection/dataset/train.json')
with open(annotation_path, 'r') as read_file:
    json_data = json.loads(read_file.read())
print(json_data.keys())

dict_keys(['info', 'licenses', 'images', 'categories', 'annotations'])


In [93]:
# Two independent deep copies of the full annotation dict: one becomes the
# training split, the other the validation split.
json_data_train, json_data_val = (copy.deepcopy(json_data) for _ in range(2))
# Peek at one annotation record to see its fields.
print(json_data['annotations'][0])


{'image_id': 0, 'category_id': 0, 'area': 257301.66, 'bbox': [197.6, 193.7, 547.8, 469.7], 'iscrowd': 0, 'id': 0}


In [95]:
# The splitter returned positions into `ids`; turn them into the set of real
# image ids that belong to the training fold. A set also makes each membership
# test O(1) instead of scanning an array.
train_img_ids = set(ids[train_index].tolist())

# Positions (within json_data['images']) of the images sent to each split.
add_train_idx = []
add_val_idx = []

json_data_train['images'] = []
json_data_train['annotations'] = []
json_data_val['images'] = []
json_data_val['annotations'] = []

# Images are renumbered from 0 inside each split; these dicts map
# original image id -> new image id for the respective split.
img_id_train_cnt = 0
img_id_val_cnt = 0
img_ann_id_match_train = dict()
img_ann_id_match_val = dict()
for i, data in enumerate(json_data['images']):
    temp_info = copy.deepcopy(data)
    if data['id'] in train_img_ids:
        temp_info['id'] = img_id_train_cnt
        json_data_train['images'].append(temp_info)
        img_ann_id_match_train[data['id']] = img_id_train_cnt
        img_id_train_cnt += 1
        add_train_idx.append(i)
    else:
        # Everything that is not in the training fold goes to validation.
        temp_info['id'] = img_id_val_cnt
        json_data_val['images'].append(temp_info)
        img_ann_id_match_val[data['id']] = img_id_val_cnt
        img_id_val_cnt += 1
        add_val_idx.append(i)

# Annotations follow their image and are renumbered from 0 inside each split.
ann_id_train_cnt = 0
ann_id_val_cnt = 0
for data in json_data['annotations']:
    temp_anno = copy.deepcopy(data)
    # Route by the id-remap dicts (keyed by original image id) rather than by
    # list positions, which only coincide with ids when ids are 0..n-1.
    if data['image_id'] in img_ann_id_match_train:
        temp_anno['id'] = ann_id_train_cnt
        temp_anno['image_id'] = img_ann_id_match_train[data['image_id']]
        json_data_train['annotations'].append(temp_anno)
        ann_id_train_cnt += 1
    else:
        # Raises KeyError if an annotation points at an image that is in
        # neither split, i.e. an image missing from json_data['images'].
        temp_anno['id'] = ann_id_val_cnt
        temp_anno['image_id'] = img_ann_id_match_val[data['image_id']]
        json_data_val['annotations'].append(temp_anno)
        ann_id_val_cnt += 1

print(len(json_data_train['images']))
print(len(json_data_val['images']))
print(len(json_data_train['annotations']))
print(len(json_data_val['annotations']))


4272
611
20276
2868


In [96]:
# Persist both splits as tab-indented JSON, training file first.
for out_path, split_dict in (
    ('/opt/ml/SK_train_annotations.json', json_data_train),
    ('/opt/ml/SK_val_annotations.json', json_data_val),
):
    with open(out_path, 'w', encoding='utf-8') as make_file:
        json.dump(split_dict, make_file, indent='\t')

In [106]:
# Re-open the training split that was just written.
# NOTE: this rebinds `coco`, `img_id`, `img_info` and `fnames`, so from here on
# they describe the training split only, not the full dataset.
coco = COCO('/opt/ml/SK_train_annotations.json')

# Category names, in the order returned by the COCO index.
cats = coco.loadCats(coco.getCatIds())
nms = []
for cat in cats:
    nms.append(cat['name'])

# Image records and their relative file names.
img_id = coco.getImgIds()
img_info = coco.loadImgs(img_id)
fnames = []
for info in img_info:
    fnames.append(info['file_name'])
print(nms)

# Upper bound for the image slider below.
num_img = len(img_info)
from ipywidgets import interact
import skimage.io as io
import matplotlib.patches as patches

@interact(idx=(0, num_img-1))
def showImg(idx):
    """Show image number ``idx`` of the loaded split with its bounding boxes.

    Args:
        idx: position of the image in ``fnames`` / ``img_info`` (driven by the
            slider created by ``interact``).

    Each annotation is drawn as a white rectangle with the category name
    placed just above the box, or just below its top edge when the box is
    within 30 px of the top of the image.
    """
    fig, ax = plt.subplots(1, 1, dpi=150)
    img = io.imread(os.path.join('/opt/ml/detection/dataset', fnames[idx]))
    # Query annotations by the image's actual id instead of assuming that the
    # slider position and the image id are the same number.
    annIds = coco.getAnnIds(imgIds=img_info[idx]['id'])
    anns = coco.loadAnns(annIds)

    ax.imshow(img)
    # Hide the axis ticks once, so images without annotations also get a
    # clean frame (previously this only happened inside the loop).
    ax.set_xticks([])
    ax.set_yticks([])
    for ann in anns:
        x, y, w, h = ann['bbox']
        ax.add_patch(
            patches.Rectangle(
                (x, y), w, h,
                edgecolor='white',
                fill=False,
            ),
        )
        text_y = y-30 if y > 30 else y+30
        ax.text(x, text_y, nms[ann['category_id']], color='white', fontsize=10)


loading annotations into memory...
Done (t=0.06s)
creating index...
index created!
['General trash', 'Paper', 'Paper pack', 'Metal', 'Glass', 'Plastic', 'Styrofoam', 'Plastic bag', 'Battery', 'Clothing']


interactive(children=(IntSlider(value=2135, description='idx', max=4271), Output()), _dom_classes=('widget-int…