In [5]:
import os
import sys
from glob import glob
from tqdm import tqdm
import numpy as np
import pandas as pd
import shutil
sys.path.insert(0, '../')
from dataset import LabelEncoder
from utils import save_pickle, load_json

In [55]:
# Input: one image folder per subject (named like train.csv's `path` column).
TRAIN_DIR = '../input/data/train/images/'
# Output root; the split writes into its train/, valid/ and test/ sub-folders.
MOVE_DIR = '../preprocessed_stratified/'

# Share of each (gender, age group) stratum sent to validation / test.
VALID_RATE, TEST_RATE = 0.15, 0.1

# Seed for the random subject split; it only takes effect if the split cell
# seeds its RNG with it.
SEED = 42

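# Stratified Split
- Split at the **subject (person) level**: all images of one person go to the same split, so no person appears in more than one of train/valid/test. Mask status is therefore not used for stratification. (사람 단위로 분배. 즉, 마스크 상태는 고려하지 않음)
- Instead, stratify by gender × age group (<30, 30–59, ≥60). (따라서 성별, 나이를 고려하여 데이터 분배)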

In [34]:
# train.csv: one row per subject; its `path` column is the name of that
# subject's image folder under TRAIN_DIR.
train_info = pd.read_csv('../input/data/train/train.csv')
# Every entry directly under TRAIN_DIR (expected: one folder per subject).
train_dirs = glob(os.path.join(TRAIN_DIR, '*'))

display(train_info.head())
len(train_dirs)

Unnamed: 0,id,gender,race,age,path
0,1,female,Asian,45,000001_female_Asian_45
1,2,female,Asian,52,000002_female_Asian_52
2,4,male,Asian,54,000004_male_Asian_54
3,5,female,Asian,58,000005_female_Asian_58
4,6,female,Asian,59,000006_female_Asian_59


2700

In [37]:
def age2ageg(age):
    """Map an age to the age-group code used for stratification.

    Returns:
        0 (young) for ages below 30,
        1 (middle) for ages from 30 up to but not including 60,
        2 (old) otherwise -- i.e. 60 and above, and also any value for which
          both comparisons are False (e.g. NaN).
    """
    if age < 30:
        return 0  # young
    if age < 60:
        return 1  # middle
    return 2  # old

In [38]:
# Bucket every subject's age into the 0/1/2 age-group code used for stratification.
train_info['ageg'] = train_info['age'].map(age2ageg)

In [40]:
# Share of subjects (rows of train_info) in each (gender, age group) stratum.
(
    train_info
    .groupby(['gender', 'ageg'])[['id']]
    .count()
    .div(len(train_info))
    .rename(columns={'id': 'density'})
)

Unnamed: 0_level_0,Unnamed: 1_level_0,density
gender,ageg,Unnamed: 2_level_1
female,0,0.271111
female,1,0.302593
female,2,0.04037
male,0,0.203333
male,1,0.151852
male,2,0.030741


In [41]:
gender = ['female', 'male']
ageg = [0, 1, 2]

# Target share of subjects per (gender, age group): one row per entry of
# `gender`, one column per entry of `ageg`, rounded from the density table.
# NOTE(review): ('male', 2) is 0.04 although the observed density is ~0.031;
# presumably rounded up so the shares sum to 1.0 -- confirm.
# NOTE(review): the name looks like a typo of "interaction_dist"; the split
# loop in this notebook does not read it.
interation_dist = {
    (g, a): share
    for g, row in zip(gender, [[0.27, 0.30, 0.04],
                               [0.20, 0.15, 0.04]])
    for a, share in zip(ageg, row)
}

In [None]:
# Create the output folders used by the stratified split;
# shutil.copy does not create missing parent directories.
# Images are deliberately NOT bulk-copied into train/ here: putting every
# subject into train/ would also place the validation/test subjects there
# and leak them into training.
for split_name in ('train', 'valid', 'test'):
    os.makedirs(os.path.join(MOVE_DIR, split_name), exist_ok=True)

In [75]:
def copy_subject_images(subject_dirs, split):
    """Copy every image of the given subjects into ``MOVE_DIR/<split>/``.

    Args:
        subject_dirs: subject folder names under ``TRAIN_DIR`` (values of the
            ``path`` column of train.csv).
        split: output sub-folder name, e.g. ``'train'``, ``'valid'`` or ``'test'``.

    Files are renamed ``<subject folder>_<file name>`` so images of different
    subjects do not collide in the flat output folder.
    """
    out_dir = os.path.join(MOVE_DIR, split)
    # shutil.copy does not create missing parent directories.
    os.makedirs(out_dir, exist_ok=True)
    for subject_dir in tqdm(subject_dirs, desc=split):
        for img_path in glob(os.path.join(TRAIN_DIR, subject_dir, '*')):
            new_name = f'{subject_dir}_{os.path.basename(img_path)}'
            shutil.copy(img_path, os.path.join(out_dir, new_name))


# Seed the generator with SEED so the split is reproducible: re-running the
# cell must reproduce the same assignment, otherwise files from an earlier run
# would remain in a different split's folder and leak subjects across splits.
rng = np.random.default_rng(SEED)

# Subject folders per split, accumulated over all strata. Kept separate from
# `train_dirs`, which holds every subject folder found on disk.
split_dirs = {'valid': [], 'test': [], 'train': []}

for g in gender:
    for a in ageg:
        print('Gender', g, 'Ageg', a)
        in_group = (train_info['gender'] == g) & (train_info['ageg'] == a)
        group_dirs = train_info.loc[in_group, 'path'].tolist()
        n_valid = int(len(group_dirs) * VALID_RATE)
        n_test = int(len(group_dirs) * TEST_RATE)

        # Shuffle the stratum once and cut it into consecutive, disjoint
        # chunks -- valid | test | train (remainder) -- in linear time.
        shuffled = rng.permutation(group_dirs).tolist()
        group_split = {
            'valid': shuffled[:n_valid],
            'test': shuffled[n_valid:n_valid + n_test],
            'train': shuffled[n_valid + n_test:],
        }
        for split, dirs in group_split.items():
            copy_subject_images(dirs, split)
            split_dirs[split].extend(dirs)

        

100%|██████████| 109/109 [00:00<00:00, 1292.48it/s]
100%|██████████| 73/73 [00:00<00:00, 1371.77it/s]
  0%|          | 0/550 [00:00<?, ?it/s]Gender female Ageg 0
100%|██████████| 550/550 [00:00<00:00, 1144.67it/s]
100%|██████████| 122/122 [00:00<00:00, 1140.67it/s]
100%|██████████| 81/81 [00:00<00:00, 1190.85it/s]
  0%|          | 0/614 [00:00<?, ?it/s]Gender female Ageg 1
100%|██████████| 614/614 [00:00<00:00, 1162.01it/s]
100%|██████████| 16/16 [00:00<00:00, 601.40it/s]
100%|██████████| 10/10 [00:00<00:00, 1080.14it/s]
100%|██████████| 83/83 [00:00<00:00, 1132.28it/s]
100%|██████████| 82/82 [00:00<00:00, 1218.87it/s]
  0%|          | 0/54 [00:00<?, ?it/s]Gender female Ageg 2
Gender male Ageg 0
100%|██████████| 54/54 [00:00<00:00, 1213.89it/s]
100%|██████████| 413/413 [00:00<00:00, 1068.93it/s]
100%|██████████| 61/61 [00:00<00:00, 1158.24it/s]
100%|██████████| 41/41 [00:00<00:00, 1070.26it/s]
 39%|███▊      | 119/308 [00:00<00:00, 1185.79it/s]Gender male Ageg 1
100%|██████████| 308/30

확인 — check how many images ended up in each split and that no subject appears in more than one split.

In [77]:
# Print the number of images written to each split, then check for subject
# leakage: every subject must land in exactly one of train/valid/test.
subjects_per_split = {}
for split in ('train', 'valid', 'test'):
    split_files = glob(os.path.join(MOVE_DIR, split, '*'))
    print(len(split_files))
    # Output names are '<subject folder>_<image file>'. The subject folders in
    # train.csv start with a '<id>_' prefix, which is used as the subject key
    # (assumes ids are unique per subject -- confirm).
    subjects_per_split[split] = {
        os.path.basename(path).split('_', 1)[0] for path in split_files
    }

for first, second in (('train', 'valid'), ('train', 'test'), ('valid', 'test')):
    shared = subjects_per_split[first] & subjects_per_split[second]
    assert not shared, f'{len(shared)} subject(s) found in both {first} and {second}'

14217
2814
1869
