In [1]:
import copy
import os
import sys
from glob import glob
from time import time

import cv2
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

import torch
import torch.utils.data as data

import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# Root of the training split: per-profile image folders plus a metadata CSV.
data_dir = '/opt/ml/input/data/train'
img_dir, df_path = (os.path.join(data_dir, leaf) for leaf in ('images', 'train.csv'))

In [3]:
# Load the per-profile metadata table and preview its first rows.
df = pd.read_csv(filepath_or_buffer=df_path)
df.head(n=5)

Unnamed: 0,id,gender,race,age,path
0,1,female,Asian,45,000001_female_Asian_45
1,2,female,Asian,52,000002_female_Asian_52
2,4,male,Asian,54,000004_male_Asian_54
3,5,female,Asian,58,000005_female_Asian_58
4,6,female,Asian,59,000006_female_Asian_59


In [4]:
def get_ext(img_dir, img_id):
    """Return the lower-cased extension (e.g. ``'.jpg'``) of one entry in a profile folder.

    Hidden entries (names starting with '.', such as macOS ``._*`` files or
    ``.DS_Store``) are ignored. Names are sorted so the result does not depend on
    the arbitrary ordering of ``os.listdir``.

    Args:
        img_dir: Root folder containing one sub-folder per profile.
        img_id: Name of the profile sub-folder to inspect.

    Returns:
        The extension of the first visible entry (alphabetically), lower-cased.

    Raises:
        FileNotFoundError: if the folder contains no visible entries.
    """
    folder = os.path.join(img_dir, img_id)
    visible = sorted(name for name in os.listdir(folder) if not name.startswith('.'))
    if not visible:
        raise FileNotFoundError(f'no visible files in {folder!r}')
    return os.path.splitext(visible[0])[-1].lower()

In [5]:
def get_img_stats(img_dir, img_ids):
    """Collect per-image size and per-channel pixel statistics.

    Args:
        img_dir: Root folder containing one sub-folder per id.
        img_ids: Iterable of sub-folder names (e.g. ``df['path']``).

    Returns:
        dict with lists ``heights`` and ``widths`` (pixel counts), and ``means`` and
        ``stds`` (length-3 RGB arrays in 0-255 pixel units), with one entry per image.
    """
    img_info = dict(heights=[], widths=[], means=[], stds=[])
    for img_id in tqdm(img_ids):
        # sorted() gives a deterministic order; glob's '*' already skips dotfiles.
        for path in sorted(glob(os.path.join(img_dir, img_id, '*'))):
            # Close the file promptly, and force 3 channels so that grayscale or RGBA
            # files neither break the shape unpacking nor yield ragged stats.
            with Image.open(path) as pil_img:
                img = np.asarray(pil_img.convert('RGB'))
            h, w, _ = img.shape
            img_info['heights'].append(h)
            img_info['widths'].append(w)
            img_info['means'].append(img.mean(axis=(0, 1)))
            img_info['stds'].append(img.std(axis=(0, 1)))
    return img_info
            

In [6]:
# Scans every image of every profile; slow (the recorded run was interrupted).
img_info = get_img_stats(img_dir=img_dir, img_ids=df['path'])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2700.0), HTML(value='')))




KeyboardInterrupt: 

In [None]:
# Per-channel RGB statistics scaled to [0, 1], the scale Normalize uses with max_pixel_value=255.
# NOTE: the mean of per-image stds ignores variation *between* images, so it
# underestimates the global pixel std; treat it as an approximation.
print(f"RGB Mean: {np.mean(img_info['means'], axis=0) / 255.}")
print(f"RGB Standard Deviation: {np.mean(img_info['stds'], axis=0) / 255.}")

# Augmentation Function

In [None]:
# Normalization statistics shared by the transforms and by the de-normalization used for plotting.
# These match the get_transforms defaults, which are presumably the dataset statistics from
# the stats cell (TODO: confirm after a full run). They replace the placeholder
# 0.5 / 0.2 values, which did not reflect the data.
mean, std = (0.548, 0.504, 0.479), (0.237, 0.247, 0.246)

In [None]:
import albumentations
from albumentations.pytorch import ToTensorV2

In [None]:
def get_transforms(need=('train', 'val'), img_size=(512, 384), mean=(0.548, 0.504, 0.479), std=(0.237, 0.247, 0.246)):
    """Build albumentations pipelines for the requested splits.

    Args:
        need: Container of split names; 'train' and/or 'val' are recognized.
        img_size: (height, width) passed to ``Resize``.
        mean, std: Per-channel normalization statistics as fractions of 255
            (``Normalize`` is called with ``max_pixel_value=255.0``).

    Returns:
        dict mapping each requested split name to an ``albumentations.Compose``.
        Both pipelines end in Resize -> Normalize -> ToTensorV2; 'train' adds
        random augmentations in between.
    """
    transformations = {}
    if 'train' in need:
        transformations['train'] = albumentations.Compose([
            albumentations.Resize(img_size[0], img_size[1], p=1.0),
            # `p` must be passed by keyword. In albumentations 1.x the first positional
            # parameter of HorizontalFlip is `always_apply`, so HorizontalFlip(0.5) flipped
            # every image. ShiftScaleRotate's first positional is `shift_limit`, so
            # ShiftScaleRotate(0.5) allowed shifts of up to 50% instead of setting p.
            albumentations.HorizontalFlip(p=0.5),
            albumentations.ShiftScaleRotate(p=0.5),
            # NOTE(review): for uint8 input these limits are in hue/pixel-value units
            # (library defaults are 20/30/20), so 0.2 is close to a no-op. Confirm intent.
            albumentations.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
            albumentations.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1),
            albumentations.GaussNoise(p=0.5),
            albumentations.Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ], p=1.0)
    if 'val' in need:
        # Deterministic pipeline: resize and normalize only.
        transformations['val'] = albumentations.Compose([
            albumentations.Resize(img_size[0], img_size[1], p=1.0),
            albumentations.Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ], p=1.0)
    return transformations

In [None]:
### Classes that map mask status, gender and age to integer labels.

class MaskLabels:
    """Mask-wearing status codes."""
    mask, incorrect, normal = 0, 1, 2

class GenderLabels:
    """Gender codes; looked up by name via getattr(GenderLabels, 'male'/'female')."""
    male, female = 0, 1

class AgeGroup:
    """Age bucket codes: 0 = under 30, 1 = 30-59, 2 = 60 and over."""

    @staticmethod
    def map_label(x):
        age = int(x)
        if age < 30:
            return 0
        return 1 if age < 60 else 2

In [None]:
class MaskBaseDataset(data.Dataset):
    """Mask / gender / age classification dataset built from per-profile image folders.

    ``img_dir`` is expected to contain one sub-folder per profile named
    ``<id>_<gender>_<race>_<age>``, each holding the images listed in
    ``_file_names``. Each sample gets one of ``num_classes`` (18) labels,
    computed as ``mask * 6 + gender * 3 + age``.
    """
    num_classes = 3 * 2 * 3

    # Expected image file names inside each profile folder -> mask label.
    # Matching in setup() is done on the stem, so other extensions are accepted too.
    _file_names = {
        "mask1.jpg": MaskLabels.mask,
        "mask2.jpg": MaskLabels.mask,
        "mask3.jpg": MaskLabels.mask,
        "mask4.jpg": MaskLabels.mask,
        "mask5.jpg": MaskLabels.mask,
        "incorrect_mask.jpg": MaskLabels.incorrect,
        "normal.jpg": MaskLabels.normal
    }

    def __init__(self, img_dir, transform=None):
        """
        Args:
            img_dir: Root folder with one sub-folder per profile.
            transform: Optional albumentations-style callable, invoked as
                ``transform(image=array)['image']``.
        """
        self.img_dir = img_dir
        self.transform = transform
        # Per-instance lists. Class-level lists would be shared by every instance,
        # so building a second dataset would duplicate all samples.
        self.image_paths = []
        self.mask_labels = []
        self.gender_labels = []
        self.age_labels = []

        self.setup()

    def set_transform(self, transform):
        """Replace the transform applied in ``__getitem__``."""
        self.transform = transform

    def setup(self):
        """Scan ``img_dir`` and fill the parallel path and label lists."""
        labels_by_stem = {os.path.splitext(name)[0]: label
                          for name, label in self._file_names.items()}
        # sorted() gives a deterministic sample order across runs and machines.
        for profile in sorted(os.listdir(self.img_dir)):
            profile_dir = os.path.join(self.img_dir, profile)
            # Skip hidden entries (e.g. macOS "._*" files) and anything that is not a folder.
            if profile.startswith('.') or not os.path.isdir(profile_dir):
                continue

            _, gender, _, age = profile.split('_')
            gender_label = getattr(GenderLabels, gender)
            age_label = AgeGroup.map_label(age)

            for file_name in sorted(os.listdir(profile_dir)):
                stem = os.path.splitext(file_name)[0]
                if file_name.startswith('.') or stem not in labels_by_stem:
                    continue
                self.image_paths.append(os.path.join(profile_dir, file_name))
                self.mask_labels.append(labels_by_stem[stem])
                self.gender_labels.append(gender_label)
                self.age_labels.append(age_label)

    def __getitem__(self, index):
        """Return ``(image, multi_class_label)`` for sample ``index``.

        ``image`` is ``self.transform``'s output when a transform is set,
        otherwise the raw RGB numpy array.
        """
        # Load eagerly and close the file handle. Force 3 channels so that PNG/RGBA
        # or grayscale files don't break a 3-channel Normalize.
        with Image.open(self.image_paths[index]) as img:
            image = np.array(img.convert('RGB'))

        multi_class_label = (self.mask_labels[index] * 6
                             + self.gender_labels[index] * 3
                             + self.age_labels[index])

        if self.transform is not None:
            image = self.transform(image=image)['image']
        return image, multi_class_label

    def __len__(self):
        return len(self.image_paths)

In [None]:
transform = get_transforms(mean=mean, std=std)

dataset = MaskBaseDataset(img_dir=img_dir)

# Split at the profile (person-folder) level rather than per image. Every profile
# contributes several photos, so an image-level random split puts the same person
# in both train and val and inflates validation scores.
# The seeded generator makes the split reproducible across kernel restarts.
SPLIT_SEED = 42
VAL_RATIO = 0.2
profile_dirs = sorted({os.path.dirname(path) for path in dataset.image_paths})
n_val_profiles = int(len(profile_dirs) * VAL_RATIO)
shuffled = torch.randperm(len(profile_dirs), generator=torch.Generator().manual_seed(SPLIT_SEED)).tolist()
val_profile_dirs = {profile_dirs[j] for j in shuffled[:n_val_profiles]}

train_indices = [i for i, path in enumerate(dataset.image_paths) if os.path.dirname(path) not in val_profile_dirs]
val_indices = [i for i, path in enumerate(dataset.image_paths) if os.path.dirname(path) in val_profile_dirs]

n_train, n_val = len(train_indices), len(val_indices)
train_dataset = data.Subset(dataset, train_indices)
val_dataset = data.Subset(dataset, val_indices)

In [None]:
# train_dataset and val_dataset are Subsets of the SAME underlying dataset object.
# Calling set_transform on each in turn therefore left both splits on the val
# pipeline, so training ran without augmentation. Give each Subset its own shallow
# copy: the path/label lists stay shared and only `transform` differs.
train_dataset.dataset = copy.copy(train_dataset.dataset)
train_dataset.dataset.set_transform(transform['train'])
val_dataset.dataset = copy.copy(val_dataset.dataset)
val_dataset.dataset.set_transform(transform['val'])

# DataLoader

In [None]:
# Both loaders share batch size and worker count. Only shuffling differs:
# train reshuffles every epoch, while val keeps a fixed order.
loader_kwargs = dict(batch_size=12, num_workers=4)
train_loader = data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Pull one training batch to sanity-check what the loader yields.
first_batch = next(iter(train_loader))
images, labels = first_batch
print(f'images shape: {images.shape}')
print(f'labels shape: {labels.shape}')

In [None]:
from torchvision import transforms

# The augmentation pipeline normalized the images, so undo that before display:
# Normalize(-m/s, 1/s) maps x to x * s + m.
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(mean, std)],
    std=[1 / s for s in std]
)

n_rows, n_cols = 4, 3
# Never index past the batch, since the batch size may be smaller than the grid.
n_show = min(n_rows * n_cols, len(images))

# squeeze=False keeps `axes` 2-D even for a single row or column.
fig, axes = plt.subplots(n_rows, n_cols, sharex=True, sharey=True, figsize=(16, 24), squeeze=False)
for i in range(n_show):
    # Column-major fill: row = i % n_rows, col = i // n_rows. This is valid for any grid
    # shape; the old col index i // (n_cols + 1) only worked when n_rows == n_cols + 1.
    ax = axes[i % n_rows][i // n_rows]
    # Clamp to [0, 1] so matplotlib doesn't warn about clipping float pixels.
    ax.imshow(inv_normalize(images[i]).permute(1, 2, 0).clamp(0, 1))
    ax.set_title(f'Label: {labels[i]}', color='r')
plt.tight_layout()