In [6]:
import os
import sys
from glob import glob
from tqdm import tqdm
import numpy as np
import pandas as pd
import pandas as pd
import shutil

from PIL import Image
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn

sys.path.insert(0, '../')
from dataset import LabelEncoder
from utils import save_pickle, load_json, save_json
from config import Task

In [7]:

def cutmix(batch, alpha):
    data, targets = batch

    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_targets = targets[indices]

    lam = np.random.beta(alpha, alpha)

    image_h, image_w = data.shape[2:]
    cx = np.random.uniform(0, image_w)
    cy = np.random.uniform(0, image_h)
    w = image_w * np.sqrt(1 - lam)
    h = image_h * np.sqrt(1 - lam)
    x0 = int(np.round(max(cx - w / 2, 0)))
    x1 = int(np.round(min(cx + w / 2, image_w)))
    y0 = int(np.round(max(cy - h / 2, 0)))
    y1 = int(np.round(min(cy + h / 2, image_h)))

    data[:, :, y0:y1, x0:x1] = shuffled_data[:, :, y0:y1, x0:x1]
    targets = (targets, shuffled_targets, lam)

    return data, targets


class CutMixCollator:
    def __init__(self, alpha):
        self.alpha = alpha

    def __call__(self, batch):
        batch = torch.utils.data.dataloader.default_collate(batch)
        batch = cutmix(batch, self.alpha)
        return batch


class CutMixCriterion:
    def __init__(self, reduction):
        self.criterion = nn.CrossEntropyLoss(reduction=reduction)

    def __call__(self, preds, targets):
        targets1, targets2, lam = targets
        return lam * self.criterion(
            preds, targets1) + (1 - lam) * self.criterion(preds, targets2)

In [8]:
import cv2
from PIL import Image
def mix_up(img1_path, img2_path, alpha = 0.5):
    """
    INPUT
    img1_path, img2_path : Mixup 하고자하는 image path
    alpha : img2_path 이미지의 label(영향도)

    OUTPUT
    im_pil : MixUp img(PIL.Image 타입)
    beta : img1_path 이미지의 label(영향도)
    alpha : img2_path 이미지의 label(영향도)
    (순서 유의)
    """
    img1 = cv2.imread(img1_path)
    img2 = cv2.imread(img2_path)
    beta = 1.0 - alpha
    
    dst = cv2.addWeighted(img1, beta, img2, alpha, 0)
    img = cv2.cvtColor(dst, cv2.COLOR_BGR2RGB)
    
    im_pil = Image.fromarray(img)
    return im_pil, beta, alpha

In [25]:
TRAIN_DIR = '../preprocessed_stratified/train'
TEST_DIR = '../preprocessed_stratified/test'
META = '../preprocessed_stratified/metadata.json'
SEED = 42

metadata = load_json(META)

train_size = len(glob(TRAIN_DIR + '/*'))
test_size = len(glob(TEST_DIR + '/*'))

print(train_size, test_size)

17031 1869


In [26]:
decoder = {
        'mask': {0: 'incorrect', 1:'wear', 2:'not_wear'},
        'gender': {0:'male', 1:'female'},
        'ageg': {0: 'young', 1: 'middle', 2: 'old'}
        }

In [27]:
mask_states = ['wear','incorrect', 'not_wear']
gender_states = ['male', 'female']
ageg_states = ['young', 'middle', 'old']

classes = []
for m in mask_states:
    for g in gender_states:
        for a in ageg_states:
            classes.append((m, g, a))
classes

[('wear', 'male', 'young'),
 ('wear', 'male', 'middle'),
 ('wear', 'male', 'old'),
 ('wear', 'female', 'young'),
 ('wear', 'female', 'middle'),
 ('wear', 'female', 'old'),
 ('incorrect', 'male', 'young'),
 ('incorrect', 'male', 'middle'),
 ('incorrect', 'male', 'old'),
 ('incorrect', 'female', 'young'),
 ('incorrect', 'female', 'middle'),
 ('incorrect', 'female', 'old'),
 ('not_wear', 'male', 'young'),
 ('not_wear', 'male', 'middle'),
 ('not_wear', 'male', 'old'),
 ('not_wear', 'female', 'young'),
 ('not_wear', 'female', 'middle'),
 ('not_wear', 'female', 'old')]

In [28]:
data_per_class = {i:[] for i in range(18)}

train_images = glob(os.path.join(TRAIN_DIR, '*'))

for train_image in train_images:
    name = os.path.basename(train_image)
    mask_state = decoder['mask'][metadata[name]['mask']]
    gender_state = decoder['gender'][metadata[name]['gender']]
    ageg_state = decoder['ageg'][metadata[name]['ageg']]
    main_state = classes.index((mask_state, gender_state, ageg_state))

    data_per_class[main_state].append(name)

In [31]:
counts = pd.Series(data_per_class).apply(lambda x: len(x))
counts

0     2475
1     1845
2      375
3     3295
4     3680
5      495
6      495
7      369
8       75
9      659
10     736
11      99
12     495
13     369
14      75
15     659
16     736
17      99
dtype: int64

In [32]:
(75 * 74) / 2

2775.0

In [34]:
NUM_AUGS = 3700
SEED = 42
num_augs = {i: j for i, j in enumerate((NUM_AUGS - counts).tolist())}
num_augs

{0: 1225,
 1: 1855,
 2: 3325,
 3: 405,
 4: 20,
 5: 3205,
 6: 3205,
 7: 3331,
 8: 3625,
 9: 3041,
 10: 2964,
 11: 3601,
 12: 3205,
 13: 3331,
 14: 3625,
 15: 3041,
 16: 2964,
 17: 3601}

In [35]:
np.random.seed(SEED)
num_augs_pairs = {i:[] for i in range(18)}
choosen = list()

for class_ in range(18):
    for _ in tqdm(range(num_augs[class_])):
        pair = np.random.choice(data_per_class[class_], size=2, replace=False).tolist()
        if pair not in choosen and pair[::-1] not in choosen:
            num_augs_pairs[class_].append(pair)
            choosen.extend([pair, pair[::-1]])
            

100%|██████████| 1225/1225 [00:00<00:00, 2790.99it/s]
100%|██████████| 1855/1855 [00:00<00:00, 1968.92it/s]
100%|██████████| 3325/3325 [00:02<00:00, 1389.81it/s]
100%|██████████| 405/405 [00:00<00:00, 858.70it/s]
100%|██████████| 20/20 [00:00<00:00, 781.29it/s]
100%|██████████| 3205/3205 [00:03<00:00, 912.12it/s]
100%|██████████| 3205/3205 [00:05<00:00, 597.62it/s]
100%|██████████| 3331/3331 [00:06<00:00, 485.43it/s]
100%|██████████| 3625/3625 [00:06<00:00, 546.24it/s]
100%|██████████| 3041/3041 [00:08<00:00, 338.76it/s]
100%|██████████| 2964/2964 [00:10<00:00, 285.33it/s]
100%|██████████| 3601/3601 [00:10<00:00, 327.43it/s]
100%|██████████| 3205/3205 [00:12<00:00, 256.67it/s]
100%|██████████| 3331/3331 [00:15<00:00, 210.88it/s]
100%|██████████| 3625/3625 [00:15<00:00, 238.85it/s]
100%|██████████| 3041/3041 [00:17<00:00, 171.72it/s]
100%|██████████| 2964/2964 [00:19<00:00, 153.39it/s]
100%|██████████| 3601/3601 [00:21<00:00, 171.44it/s]


In [36]:
for i in range(18):
    print(len(num_augs_pairs[i]))

1225
1855
3254
405
20
3170
3171
3274
2018
3018
2952
2544
3171
3241
1993
3019
2946
2533


In [37]:
for class_ in range(18):
    for p1, p2 in tqdm(num_augs_pairs[class_]):
        name = 'MIXUP_' + p1.split('.')[0] + '_' + p2.split('.')[0] + '.png'
        metainfo = metadata[p1]

        p1_path, p2_path = map(lambda x: os.path.join(TRAIN_DIR, x), [p1, p2])
        img_mixed = mix_up(p1_path, p2_path, alpha=0.5)[0]
        img_mixed.save(os.path.join(TRAIN_DIR, name), 'png')
        metadata[name] = metainfo

100%|██████████| 1225/1225 [01:44<00:00, 11.71it/s]
100%|██████████| 1855/1855 [02:53<00:00, 10.71it/s]
100%|██████████| 3254/3254 [05:07<00:00, 10.58it/s]
100%|██████████| 405/405 [00:34<00:00, 11.76it/s]
100%|██████████| 20/20 [00:01<00:00, 10.12it/s]
100%|██████████| 3170/3170 [04:37<00:00, 11.43it/s]
100%|██████████| 3171/3171 [04:29<00:00, 11.77it/s]
100%|██████████| 3274/3274 [05:06<00:00, 10.68it/s]
100%|██████████| 2018/2018 [03:12<00:00, 10.50it/s]
100%|██████████| 3018/3018 [04:18<00:00, 11.69it/s]
100%|██████████| 2952/2952 [04:39<00:00, 10.58it/s]
100%|██████████| 2544/2544 [03:40<00:00, 11.55it/s]
100%|██████████| 3171/3171 [04:28<00:00, 11.83it/s]
100%|██████████| 3241/3241 [05:04<00:00, 10.64it/s]
100%|██████████| 1993/1993 [03:10<00:00, 10.47it/s]
100%|██████████| 3019/3019 [04:18<00:00, 11.70it/s]
100%|██████████| 2946/2946 [04:37<00:00, 10.64it/s]
100%|██████████| 2533/2533 [03:39<00:00, 11.52it/s]


In [41]:
save_json('../preprocessed_stratified/metadata.json', metadata)

In [42]:
aug_data_per_class = {i:[] for i in range(18)}

train_images = glob(os.path.join(TRAIN_DIR, '*'))

for train_image in train_images:
    name = os.path.basename(train_image)
    mask_state = decoder['mask'][metadata[name]['mask']]
    gender_state = decoder['gender'][metadata[name]['gender']]
    ageg_state = decoder['ageg'][metadata[name]['ageg']]
    main_state = classes.index((mask_state, gender_state, ageg_state))

    aug_data_per_class[main_state].append(name)

In [43]:
for i in range(18):
    print(len(aug_data_per_class[i]))

3700
3700
3629
3700
3700
3665
3666
3643
2093
3677
3688
2643
3666
3610
2068
3678
3682
2632
