In [1]:
import glob
import os
import warnings

import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold

In [2]:
# Working-directory setup: resolve the project root from the current directory.
# The root is the first path component named exactly 'synthesis-car-od'; matching
# whole components keeps look-alike folders (e.g. 'synthesis-car-od-old') from
# being sliced in the middle of their name.
working_dir = 'synthesis-car-od'
dir_len = len(working_dir)

path = os.getcwd().replace('\\', '/')  # normalise Windows separators
# Pad both sides with '/' so only whole components can match. The offset found
# in the padded string equals the component's start offset inside `path`.
index = ('/' + path + '/').find('/' + working_dir + '/')
if index == -1:
    # str.find signals "not found" with -1; slicing with it would silently
    # produce a meaningless prefix, so warn and use the current directory.
    warnings.warn(
        f"'{working_dir}' is not part of the current path '{path}'; "
        'falling back to the current directory as the project root.'
    )
    working_dir = path.rstrip('/') + '/'
else:
    # Always end with '/', including when the notebook runs from the root itself.
    working_dir = path[:index + dir_len] + '/'
working_dir

'/Users/a16/Desktop/JJ/self_study/projects/synthesis-car-od/'

In [3]:
# Data locations, both derived from the resolved project root.
DATA_PATH = os.path.join(working_dir, 'data/')
TRAIN_PATH = os.path.join(working_dir, 'data/', 'train/')

In [4]:
# Image paths: every .png directly under TRAIN_PATH, in sorted order.
png_pattern = os.path.join(TRAIN_PATH, '*.png')
imgs = glob.glob(png_pattern)
imgs.sort()
imgs[:5]

['/Users/a16/Desktop/JJ/self_study/projects/synthesis-car-od/data/train/syn_00000.png',
 '/Users/a16/Desktop/JJ/self_study/projects/synthesis-car-od/data/train/syn_00001.png',
 '/Users/a16/Desktop/JJ/self_study/projects/synthesis-car-od/data/train/syn_00002.png',
 '/Users/a16/Desktop/JJ/self_study/projects/synthesis-car-od/data/train/syn_00003.png',
 '/Users/a16/Desktop/JJ/self_study/projects/synthesis-car-od/data/train/syn_00004.png']

In [5]:
# Label file paths: every .txt directly under TRAIN_PATH, in sorted order.
txt_pattern = os.path.join(TRAIN_PATH, '*.txt')
txts = glob.glob(txt_pattern)
txts.sort()
txts[:5]

['/Users/a16/Desktop/JJ/self_study/projects/synthesis-car-od/data/train/syn_00000.txt',
 '/Users/a16/Desktop/JJ/self_study/projects/synthesis-car-od/data/train/syn_00001.txt',
 '/Users/a16/Desktop/JJ/self_study/projects/synthesis-car-od/data/train/syn_00002.txt',
 '/Users/a16/Desktop/JJ/self_study/projects/synthesis-car-od/data/train/syn_00003.txt',
 '/Users/a16/Desktop/JJ/self_study/projects/synthesis-car-od/data/train/syn_00004.txt']

In [6]:
# Flatten every label file into four parallel lists, one entry per annotated
# object: image file name, label file name, class id, and the remaining fields
# of the line kept as a single space-separated string.
img_names = []
txt_names = []
labels = []
bboxes = []

# zip() pairs files purely by position, so make sure the sorted image and label
# lists describe the same samples before trusting that pairing. A single missing
# file would otherwise shift every following pair without any error.
img_stems = [os.path.splitext(p.replace('\\', '/').split('/')[-1])[0] for p in imgs]
txt_stems = [os.path.splitext(p.replace('\\', '/').split('/')[-1])[0] for p in txts]
if img_stems != txt_stems:
    unpaired = sorted(set(img_stems) ^ set(txt_stems))
    raise ValueError(
        f'images and label files do not pair up 1:1 '
        f'({len(imgs)} images, {len(txts)} labels); unpaired stems: {unpaired[:5]}'
    )

for img, txt in zip(imgs, txts):
    # Keep only the file name; backslashes are normalised for Windows paths.
    img_name = img.replace('\\', '/').split('/')[-1]
    txt_name = txt.replace('\\', '/').split('/')[-1]

    with open(txt, 'r') as t:
        for line in t:
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline) carry no annotation.
                continue

            # int(float(...)) also accepts ids written as float text such as '9.0'.
            label = int(float(fields[0]))
            bbox = ' '.join(fields[1:])

            img_names.append(img_name)
            txt_names.append(txt_name)
            labels.append(label)
            bboxes.append(bbox)

print(img_names[:5])
print(txt_names[:5])
print(labels[:5])
print(bboxes[:5])

['syn_00000.png', 'syn_00000.png', 'syn_00000.png', 'syn_00001.png', 'syn_00001.png']
['syn_00000.txt', 'syn_00000.txt', 'syn_00000.txt', 'syn_00001.txt', 'syn_00001.txt']
[9, 25, 12, 16, 14]
['1037 209 1312 209 1312 448 1037 448', '804 425 1127 425 1127 783 804 783', '330 250 583 250 583 511 330 511', '1000 98 1295 98 1295 405 1000 405', '678 175 926 175 926 421 678 421']


In [7]:
# Assemble the per-object annotation table from the parallel lists
# (one row per annotated object).
column_names = ('img', 'txt', 'label', 'bbox')
column_values = (img_names, txt_names, labels, bboxes)
train_df = pd.DataFrame(dict(zip(column_names, column_values)))
train_df.head()

Unnamed: 0,img,txt,label,bbox
0,syn_00000.png,syn_00000.txt,9,1037 209 1312 209 1312 448 1037 448
1,syn_00000.png,syn_00000.txt,25,804 425 1127 425 1127 783 804 783
2,syn_00000.png,syn_00000.txt,12,330 250 583 250 583 511 330 511
3,syn_00001.png,syn_00001.txt,16,1000 98 1295 98 1295 405 1000 405
4,syn_00001.png,syn_00001.txt,14,678 175 926 175 926 421 678 421


In [8]:
# Random seed; passed as `random_state` to the fold splitter so the shuffled
# splits are reproducible across runs.
SEED = 41

In [9]:
# 5-fold stratified group splitter with seeded shuffling.
sgkf_params = {'n_splits': 5, 'shuffle': True, 'random_state': SEED}
sgkf = StratifiedGroupKFold(**sgkf_params)

In [10]:
# Materialise every fold once. Rows are stratified on the class label and
# grouped by image file name, so the grouping key is the image itself.
fold_pairs = list(
    sgkf.split(X=train_df['img'], y=train_df['label'], groups=train_df['img'])
)
train_indices = [pair[0] for pair in fold_pairs]
val_indices = [pair[1] for pair in fold_pairs]

print(train_indices)
print(val_indices)

[array([    3,     4,     5, ..., 16993, 16997, 16998]), array([    0,     1,     2, ..., 16995, 16996, 16999]), array([    0,     1,     2, ..., 16997, 16998, 16999]), array([    0,     1,     2, ..., 16997, 16998, 16999]), array([    0,     1,     2, ..., 16997, 16998, 16999])]
[array([    0,     1,     2, ..., 16995, 16996, 16999]), array([    6,     7,    10, ..., 16982, 16997, 16998]), array([   19,    20,    34, ..., 16986, 16987, 16988]), array([    3,     4,     5, ..., 16977, 16978, 16979]), array([   16,    17,    18, ..., 16973, 16992, 16993])]


In [11]:
# Show the per-class object counts of the train and validation part of each fold.
label_col = train_df.iloc[:, 2]  # third column of train_df holds the class label
for i, (train_idx, val_idx) in enumerate(zip(train_indices, val_indices)):
    print(f'fold{i}')
    for caption, rows in ((f'train fold{i}', train_idx), (f'val_fold{i}', val_idx)):
        print(caption)
        display(label_col.iloc[rows].value_counts().sort_index())

fold0
train fold0


label
0     403
1     404
2     415
3     390
4     409
5     395
6     387
7     399
8     389
9     404
10    411
11    404
12    404
13    402
14    392
15    402
16    389
17    411
18    414
19    409
20    410
21    395
22    398
23    397
24    382
25    405
26    412
27    408
28    393
29    391
30    399
31    413
32    390
33    399
Name: count, dtype: int64

val_fold0


label
0      97
1      96
2      85
3     110
4      91
5     105
6     113
7     101
8     111
9      96
10     89
11     96
12     96
13     98
14    108
15     98
16    111
17     89
18     86
19     91
20     90
21    105
22    102
23    103
24    118
25     95
26     88
27     92
28    107
29    109
30    101
31     87
32    110
33    101
Name: count, dtype: int64

fold1
train fold1


label
0     420
1     405
2     396
3     397
4     404
5     411
6     405
7     409
8     397
9     403
10    395
11    403
12    398
13    404
14    395
15    403
16    404
17    390
18    393
19    400
20    404
21    384
22    386
23    405
24    403
25    412
26    382
27    397
28    402
29    405
30    395
31    392
32    402
33    397
Name: count, dtype: int64

val_fold1


label
0      80
1      95
2     104
3     103
4      96
5      89
6      95
7      91
8     103
9      97
10    105
11     97
12    102
13     96
14    105
15     97
16     96
17    110
18    107
19    100
20     96
21    116
22    114
23     95
24     97
25     88
26    118
27    103
28     98
29     95
30    105
31    108
32     98
33    103
Name: count, dtype: int64

fold2
train fold2


label
0     395
1     400
2     402
3     416
4     402
5     387
6     398
7     398
8     414
9     401
10    402
11    400
12    396
13    406
14    412
15    393
16    397
17    398
18    397
19    397
20    410
21    406
22    404
23    406
24    407
25    394
26    398
27    409
28    382
29    387
30    408
31    397
32    407
33    387
Name: count, dtype: int64

val_fold2


label
0     105
1     100
2      98
3      84
4      98
5     113
6     102
7     102
8      86
9      99
10     98
11    100
12    104
13     94
14     88
15    107
16    103
17    102
18    103
19    103
20     90
21     94
22     96
23     94
24     93
25    106
26    102
27     91
28    118
29    113
30     92
31    103
32     93
33    113
Name: count, dtype: int64

fold3
train fold3


label
0     392
1     398
2     400
3     385
4     392
5     402
6     396
7     406
8     401
9     401
10    393
11    386
12    398
13    389
14    401
15    405
16    407
17    406
18    406
19    396
20    372
21    404
22    412
23    407
24    396
25    391
26    392
27    402
28    413
29    404
30    399
31    412
32    403
33    422
Name: count, dtype: int64

val_fold3


label
0     108
1     102
2     100
3     115
4     108
5      98
6     104
7      94
8      99
9      99
10    107
11    114
12    102
13    111
14     99
15     95
16     93
17     94
18     94
19    104
20    128
21     96
22     88
23     93
24    104
25    109
26    108
27     98
28     87
29     96
30    101
31     88
32     97
33     78
Name: count, dtype: int64

fold4
train fold4


label
0     390
1     393
2     387
3     412
4     393
5     405
6     414
7     388
8     399
9     391
10    399
11    407
12    404
13    399
14    400
15    397
16    403
17    395
18    390
19    398
20    404
21    411
22    400
23    385
24    412
25    398
26    416
27    384
28    410
29    413
30    399
31    386
32    398
33    395
Name: count, dtype: int64

val_fold4


label
0     110
1     107
2     113
3      88
4     107
5      95
6      86
7     112
8     101
9     109
10    101
11     93
12     96
13    101
14    100
15    103
16     97
17    105
18    110
19    102
20     96
21     89
22    100
23    115
24     88
25    102
26     84
27    116
28     90
29     87
30    101
31    114
32    102
33    105
Name: count, dtype: int64