In [48]:
import os
import pandas as pd
from PIL import Image

import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torch.utils.data as data
import pandas as pd
from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize

In [49]:
### Configurations
data_dir = '/opt/ml/input/data/train'
img_dir = f'{data_dir}/images'
df_path = f'{data_dir}/train.csv'

In [50]:
df = pd.read_csv(df_path)
df.head()

Unnamed: 0,id,gender,race,age,path
0,1,female,Asian,45,000001_female_Asian_45
1,2,female,Asian,52,000002_female_Asian_52
2,4,male,Asian,54,000004_male_Asian_54
3,5,female,Asian,58,000005_female_Asian_58
4,6,female,Asian,59,000006_female_Asian_59


In [51]:
mean, std = (0.5, 0.5, 0.5), (0.2, 0.2, 0.2)

In [52]:
from albumentations import *
from albumentations.pytorch import ToTensorV2


def get_transforms(need=('train', 'val'), img_size=(512, 384), mean=(0.548, 0.504, 0.479), std=(0.237, 0.247, 0.246)):
    """
    train 혹은 validation의 augmentation 함수를 정의합니다. train은 데이터에 많은 변형을 주어야하지만, validation에는 최소한의 전처리만 주어져야합니다.
    
    Args:
        need: 'train', 혹은 'val' 혹은 둘 다에 대한 augmentation 함수를 얻을 건지에 대한 옵션입니다.
        img_size: Augmentation 이후 얻을 이미지 사이즈입니다.
        mean: 이미지를 Normalize할 때 사용될 RGB 평균값입니다.
        std: 이미지를 Normalize할 때 사용될 RGB 표준편차입니다.

    Returns:
        transformations: Augmentation 함수들이 저장된 dictionary 입니다. transformations['train']은 train 데이터에 대한 augmentation 함수가 있습니다.
    """
    transformations = {}
    if 'train' in need:
        transformations['train'] = Compose([
            Resize(img_size[0], img_size[1], p=1.0),
            HorizontalFlip(p=0.5),
            ShiftScaleRotate(p=0.5),
            HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
            RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
            GaussNoise(p=0.5),
            Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ], p=1.0)
    if 'val' in need:
        transformations['val'] = Compose([
            Resize(img_size[0], img_size[1]),
            Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ], p=1.0)
    return transformations

## Define Dataset

In [53]:
# 마스크 여부, 성별 ,나이를 ampping 할 클래스 생성

class MaskLabels:
    mask = 0
    incorrect = 1
    normal = 2

class GenderLabels:
    male = 0
    female = 1

class AgeGroup:
    map_label = lambda x:0 if int(x) < 30 else 1 if int(x) < 60 else 2


In [65]:
class MaskBaseDataset(data.Dataset):
    num_classes = 3 * 2 * 3

    _file_names = {
        "mask1.jpg" : MaskLabels.mask,
        "mask2.jpg" : MaskLabels.mask,
        "mask3.jpg" : MaskLabels.mask,
        "mask4.jpg" : MaskLabels.mask,
        "mask5.jpg" : MaskLabels.mask,
        "incorrect_mask.jpg" : MaskLabels.incorrect,
        "normal.jpg" : MaskLabels.normal,
    }

    image_paths = []
    mask_labels = []
    gender_labels = []
    age_labels = []

    def __init__(self, img_dir, transform=None):

        self.img_dir = img_dir
        self.mean = mean
        self.std = std
        self.transform = transform

        self.setup()


    def set_transform(self, transform):

        self.transform


    def setup(self):
        """
        image의 경로와 각 이미지들의 label을 계산하여 저장하는 함수
        """
        profiles = os.listdir(self.img_dir)
        for profile in profiles:
            for file_name, label in self._file_names.items():
                img_path = os.path.join(self.img_dir, profile, file_name)
                if os.path.exists(img_path):
                    self.image_paths.append(img_path)
                    self.mask_labels.append(label)

                    id, gender, race, age = profile.split("_")
                    gender_label = getattr(GenderLabels, gender)
                    age_label = AgeGroup.map_label(age)

                    self.gender_labels.append(gender_label)
                    self.age_labels.append(age_label)

    
    def __getitem__(self, index):
        """
        데이터를 불러오는 함수입니다.
        데이터셋 class에 데이터 정보가 저장되어 있고, index를 통해 해당 위치에 있는 
        데이터 정보를 불러옵니다

        Args:
            index: 불러올 데이터 인덱스값
        """

        #이미지 불러오기
        image_path = self.image_paths[index]
        image = Image.open(image_path)


        #레이블 불러오기
        mask_label = self.mask_labels[index]
        gender_label = self.gender_labels[index]
        age_label = self.age_labels[index]
        multi_class_label = mask_label * 6 + gender_label * 3 + age_label

        #이미지를 Augmentation 시키기
        img_transform = self.transform(image=np.array(image))['image']
        return image_transform, multi_class_label

    def __len__(self):
        return len(self.image_paths)

In [66]:
# 정의한 Augmentation 함수와 Dataset 클래스 객체 생성
transform = get_transforms(mean=mean, std=std)

dataset = MaskBaseDataset(
    img_dir=img_dir
)

In [67]:
n_val = int(len(dataset) * 0.2)
n_train = len(dataset) - n_val

train_dataset, val_dataset = data.random_split(dataset, [n_train, n_val])

train_dataset.dataset.set_transform(transform['train'])
val_dataset.dataset.set_transform(transform['val'])


In [68]:
train_loader = data.DataLoader(
    train_dataset,
    batch_size = 12,
    num_workers = 4,
    shuffle = True
)


val_loader = data.DataLoader(
    val_dataset,
    batch_size = 12,
    num_workers = 4,
    shuffle=False
)

In [69]:
images, labels = next(iter(train_loader))
print(f'images shape: {images.shape}')
print(f'labels shape: {labels.shape}')

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb4092ff5e0>Exception ignored in: 
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb4092ff5e0>Traceback (most recent call last):

  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
Traceback (most recent call last):
      File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1203, in __del__
self._shutdown_workers()    
self._shutdown_workers()  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers

  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
        Exception ignored in: w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)<function _MultiProcessingDataLoaderIter.__del__ at 0x7fb4092ff5e0>
  File "/opt/conda/lib/python3.8/multiprocessing/

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataset.py", line 272, in __getitem__
    return self.dataset[self.indices[idx]]
  File "<ipython-input-65-f74cc28d77f3>", line 76, in __getitem__
    img_transform = self.transform(image=np.array(image))['image']
TypeError: 'NoneType' object is not callable
