In [1]:
import os
import sys
from glob import glob
from tqdm import tqdm
import numpy as np
import pandas as pd
import pandas as pd
import shutil

from PIL import Image
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn

sys.path.insert(0, '../')
from dataset import LabelEncoder
from utils import save_pickle, load_json
from config import Task

In [2]:

def cutmix(batch, alpha):
    """Apply CutMix to a collated ``(data, targets)`` batch.

    One rectangular box is drawn for the whole batch. Inside that box every
    sample is overwritten with the pixels of another sample, chosen by a random
    permutation of the batch.

    Args:
        batch: ``(data, targets)`` pair. ``data`` is indexed as
            ``data[:, :, y, x]`` and ``data.shape[2:]`` is read as
            ``(height, width)``. ``targets`` is indexed along its first axis
            with the same permutation as ``data``.
        alpha: Parameter of the symmetric ``Beta(alpha, alpha)`` distribution
            from which the requested mixing ratio is drawn.

    Returns:
        ``(data, (targets, shuffled_targets, lam))``. ``data`` is the input
        tensor modified in place. ``lam`` is the fraction of each image that
        still comes from the original sample, so ``1 - lam`` is the fraction
        taken from the permuted sample.
    """
    data, targets = batch

    indices = torch.randperm(data.size(0))
    # Indexing with an index tensor yields a copy, so the in-place paste below
    # cannot corrupt the pixels that are being pasted.
    shuffled_data = data[indices]
    shuffled_targets = targets[indices]

    lam = np.random.beta(alpha, alpha)

    image_h, image_w = data.shape[2:]
    cx = np.random.uniform(0, image_w)
    cy = np.random.uniform(0, image_h)
    w = image_w * np.sqrt(1 - lam)
    h = image_h * np.sqrt(1 - lam)
    # The box is centred on (cx, cy) and clipped to the image borders.
    x0 = int(np.round(max(cx - w / 2, 0)))
    x1 = int(np.round(min(cx + w / 2, image_w)))
    y0 = int(np.round(max(cy - h / 2, 0)))
    y1 = int(np.round(min(cy + h / 2, image_h)))

    data[:, :, y0:y1, x0:x1] = shuffled_data[:, :, y0:y1, x0:x1]

    # Clipping and rounding make the pasted area differ from the requested
    # (1 - lam) share, so recompute lam from the box that was really pasted.
    # Otherwise the label weights would not match the pixel mix.
    total_area = image_h * image_w
    if total_area > 0:
        lam = 1.0 - ((x1 - x0) * (y1 - y0)) / float(total_area)

    targets = (targets, shuffled_targets, lam)

    return data, targets


class CutMixCollator:
    """Callable collate step: default collation followed by ``cutmix``.

    The instance stores ``alpha`` and forwards it to ``cutmix`` on every call.
    """

    def __init__(self, alpha):
        # Passed through to ``cutmix`` unchanged.
        self.alpha = alpha

    def __call__(self, batch):
        """Collate ``batch`` with ``default_collate`` and return ``cutmix``'s result."""
        collated = torch.utils.data.dataloader.default_collate(batch)
        return cutmix(collated, self.alpha)


class CutMixCriterion:
    """Cross-entropy loss for the ``(targets1, targets2, lam)`` triple from ``cutmix``."""

    def __init__(self, reduction):
        # ``reduction`` is handed straight to ``nn.CrossEntropyLoss``.
        self.criterion = nn.CrossEntropyLoss(reduction=reduction)

    def __call__(self, preds, targets):
        """Return ``lam * CE(preds, targets1) + (1 - lam) * CE(preds, targets2)``."""
        first_targets, second_targets, lam = targets
        first_loss = self.criterion(preds, first_targets)
        second_loss = self.criterion(preds, second_targets)
        return lam * first_loss + (1 - lam) * second_loss

In [3]:
import cv2
from PIL import Image
def mix_up(img1_path, img2_path, alpha = 0.5):
    """Blend two image files with fixed weights (MixUp) and return a PIL image.

    INPUT
    img1_path, img2_path : paths of the two images to mix
    alpha : weight (label influence) of the img2_path image

    OUTPUT
    im_pil : the mixed image (PIL.Image), converted from BGR to RGB
    beta : weight (label influence) of the img1_path image, ``1.0 - alpha``
    alpha : weight (label influence) of the img2_path image
    (mind the order)

    RAISES
    OSError : if either path cannot be read as an image

    No resizing is done here. Per the OpenCV documentation, ``cv2.addWeighted``
    expects both inputs to have the same size and channel count.
    """
    img1 = cv2.imread(img1_path)
    img2 = cv2.imread(img2_path)
    # cv2.imread reports failure by returning None rather than raising. Fail
    # here with the offending path instead of deep inside cv2.addWeighted.
    for path, loaded in ((img1_path, img1), (img2_path, img2)):
        if loaded is None:
            raise OSError('mix_up: cannot read image: {}'.format(path))

    beta = 1.0 - alpha

    dst = cv2.addWeighted(img1, beta, img2, alpha, 0)
    # OpenCV arrays are BGR; PIL expects RGB.
    img = cv2.cvtColor(dst, cv2.COLOR_BGR2RGB)

    im_pil = Image.fromarray(img)
    return im_pil, beta, alpha

In [4]:
# Directory of training images. The MixUp images generated later in this
# notebook are saved into this same directory.
TRAIN_DIR = '../preprocessed_stratified/train'
# JSON file read with load_json; its entries are looked up by image file name.
META = '../preprocessed_stratified/metadata.json'
# Seed passed to np.random.seed before the MixUp pairs are sampled.
SEED = 42

In [81]:
class TrainDataset(Dataset):
    """Image dataset whose labels are looked up by file name in a metadata JSON."""

    def __init__(self, root: str, transform=None, task: str=Task.All, meta_path:str=META):
        """Dataset that supports labeling for four tasks: mask state, age,
        age group, and class (0-17).

        Args:
            root: Directory whose entries (``root/*``) are used as samples.
            transform: Optional callable applied to each opened PIL image.
            task: ``Task.All`` returns the whole metadata entry as the label;
                any other value is used as a key into that entry.
            meta_path: Path of the metadata JSON, loaded with ``load_json``.
        """
        # glob returns paths in arbitrary order. Sorting keeps the
        # index -> sample mapping stable across runs and machines.
        self.img_paths = sorted(glob(os.path.join(root, '*')))
        self.metadata = load_json(meta_path)
        self.task = task
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(img, label)`` for the sample at ``index``."""
        path = self.img_paths[index]
        # Metadata is keyed by the bare file name, not the full path.
        name = os.path.basename(path)
        img = Image.open(path)
        label = self.metadata[name]

        if self.task != Task.All:
            label = label[self.task]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    def __len__(self):
        """Number of paths found under ``root``."""
        return len(self.img_paths)

In [8]:
# Label entries looked up by image file name. The MixUp generation cell later
# adds one entry per generated image to this same object.
metadata = load_json(META)

In [9]:
# Integer label code -> state name, one table per label field.
# Each list position is the integer code of that state.
decoder = {
    'mask': dict(enumerate(['incorrect', 'wear', 'not_wear'])),
    'gender': dict(enumerate(['male', 'female'])),
    'ageg': dict(enumerate(['young', 'middle', 'old'])),
}

In [10]:
mask_states = ['wear', 'incorrect', 'not_wear']
gender_states = ['male', 'female']
ageg_states = ['young', 'middle', 'old']

# A class index is the position of its (mask, gender, age group) triple in this
# list: mask state varies slowest, age group fastest.
classes = [
    (mask, gender, ageg)
    for mask in mask_states
    for gender in gender_states
    for ageg in ageg_states
]
classes

[('wear', 'male', 'young'),
 ('wear', 'male', 'middle'),
 ('wear', 'male', 'old'),
 ('wear', 'female', 'young'),
 ('wear', 'female', 'middle'),
 ('wear', 'female', 'old'),
 ('incorrect', 'male', 'young'),
 ('incorrect', 'male', 'middle'),
 ('incorrect', 'male', 'old'),
 ('incorrect', 'female', 'young'),
 ('incorrect', 'female', 'middle'),
 ('incorrect', 'female', 'old'),
 ('not_wear', 'male', 'young'),
 ('not_wear', 'male', 'middle'),
 ('not_wear', 'male', 'old'),
 ('not_wear', 'female', 'young'),
 ('not_wear', 'female', 'middle'),
 ('not_wear', 'female', 'old')]

In [11]:
# Group the file names found in TRAIN_DIR by class index (position in `classes`).
train_images = glob(os.path.join(TRAIN_DIR, '*'))
data_per_class = {class_idx: [] for class_idx in range(18)}

for image_path in train_images:
    name = os.path.basename(image_path)
    info = metadata[name]
    # Decode the integer codes into state names, in the same field order as
    # the tuples stored in `classes`.
    state = tuple(decoder[field][info[field]] for field in ('mask', 'gender', 'ageg'))
    data_per_class[classes.index(state)].append(name)

In [12]:
# Number of file names collected for each class index.
counts = pd.Series(data_per_class).apply(len)

In [None]:
# Display the per-class image counts.
counts

각 클래스별 3100개로 증강

In [13]:
# Target number of images per class.
NUM_AUGS = 3100
SEED = 42
# Class index -> how many images are still missing to reach NUM_AUGS.
num_augs = dict(enumerate((NUM_AUGS - counts).tolist()))
num_augs

{0: 1035,
 1: 1560,
 2: 2785,
 3: 350,
 4: 30,
 5: 2685,
 6: 2687,
 7: 2792,
 8: 3037,
 9: 2550,
 10: 2486,
 11: 3017,
 12: 2687,
 13: 2792,
 14: 3037,
 15: 2550,
 16: 2486,
 17: 3017}

In [14]:
np.random.seed(SEED)
# For each class, draw `num_augs[class_]` pairs of two distinct file names from
# that class. Classes are visited in index order, so the sequence of random
# draws is fixed by the seed above.
num_augs_pairs = {
    class_: [
        np.random.choice(data_per_class[class_], size=2, replace=False).tolist()
        for _ in range(num_augs[class_])
    ]
    for class_ in range(18)
}

In [15]:
# Render one MixUp image per sampled pair, save it into TRAIN_DIR and register
# its label entry in `metadata`.
# File names already produced by this run of the cell. Tracking them per run,
# rather than checking `metadata`, keeps a re-run of the cell producing the
# same file names.
used_names = set()
for class_ in range(18):
    for p1, p2 in tqdm(num_augs_pairs[class_]):
        # splitext strips only the extension, so stems that contain dots stay
        # distinct.
        base = 'MIXUP_' + os.path.splitext(p1)[0] + '_' + os.path.splitext(p2)[0]
        name = base + '.png'
        # The same pair can be sampled more than once. Without a distinct name
        # the later image would overwrite the earlier one, leaving the class
        # with fewer images than planned.
        repeat = 0
        while name in used_names:
            repeat += 1
            name = base + '_' + str(repeat) + '.png'
        used_names.add(name)

        p1_path, p2_path = (os.path.join(TRAIN_DIR, p) for p in (p1, p2))
        img_mixed = mix_up(p1_path, p2_path, alpha=0.5)[0]
        img_mixed.save(os.path.join(TRAIN_DIR, name), 'png')
        # Both images come from the same class, so the first image's labels
        # apply. Copy the entry so the new key does not alias the source dict
        # (shallow copy; assumes the entry is a flat dict - TODO confirm).
        metadata[name] = dict(metadata[p1])

In [None]:
# Re-group the file names now present in TRAIN_DIR by class index.
train_images = glob(os.path.join(TRAIN_DIR, '*'))
aug_data_per_class = {class_idx: [] for class_idx in range(18)}

for image_path in train_images:
    name = os.path.basename(image_path)
    info = metadata[name]
    # Decode the integer codes into state names, in the same field order as
    # the tuples stored in `classes`.
    state = tuple(decoder[field][info[field]] for field in ('mask', 'gender', 'ageg'))
    aug_data_per_class[classes.index(state)].append(name)