In [1]:
import os
import random
from enum import Enum
from sklearn.model_selection import StratifiedKFold

In [2]:
class MaskLabels(int, Enum):
    MASK = 0
    INCORRECT = 1
    NORMAL = 2


class GenderLabels(int, Enum):
    MALE = 0
    FEMALE = 1

    @classmethod
    def from_str(cls, value: str) -> int:
        value = value.lower()
        if value == "male":
            return cls.MALE
        elif value == "female":
            return cls.FEMALE
        else:
            raise ValueError(f"Gender value should be either 'male' or 'female', {value}")


class AgeLabels(int, Enum):
    YOUNG = 0
    MIDDLE = 1
    OLD = 2

    @classmethod
    def from_number(cls, value: str) -> int:
        try:
            value = int(value)
        except Exception:
            raise ValueError(f"Age value should be numeric, {value}")

        if value < 30:
            return cls.YOUNG
        elif value < 60:
            return cls.MIDDLE
        else:
            return cls.OLD
        
_file_names = {
    "mask1": MaskLabels.MASK,
    "mask2": MaskLabels.MASK,
    "mask3": MaskLabels.MASK,
    "mask4": MaskLabels.MASK,
    "mask5": MaskLabels.MASK,
    "incorrect_mask": MaskLabels.INCORRECT,
    "normal": MaskLabels.NORMAL
}

def encode_multi_class(mask_label, gender_label, age_label) -> int:
    return mask_label * 6 + gender_label * 3 + age_label

In [3]:
data_dir = '/opt/ml/input/data/train/images'
downsample = True
n_fold = 5

image_paths = []

mask_labels = []
gender_labels = []
age_labels = []
multi_labels = []

indices = {}
folds = []

def k_stratified_fold(profiles):
# profiles = [profile for profile in os.listdir(data_dir) if not profile.startswith('.')]
    profile_labels = []
    for profile in profiles:
        id, gender, _, age = profile.split('_')
        gender_label = GenderLabels.from_str(gender)
        age_label = AgeLabels.from_number(age)
        profile_label = encode_multi_class(0, gender_label, age_label)
        profile_labels.append(profile_label)

    skf = StratifiedKFold(n_splits=n_fold)
    kfold_profiles = []
    for train_index, val_index in skf.split(profiles, profile_labels):
        kfold_profiles.append({
            'train': train_index,
            'val': val_index
        })
    return kfold_profiles


def setup():
    profiles = [profile for profile in os.listdir(data_dir) if not profile.startswith('.')]
    kfold_profiles = k_stratified_fold(profiles)
    for kfold_profile in kfold_profiles:
        for phase, indices in kfold_profile:
            for _idx in indices:
                include_mask = True
                profile = profiles[_idx]
                img_folder = os.path.join(data_dir, profile)
                lst_dir = os.listdir(img_folder)
                random.shuffle(lst_dir)  # in-place operation (add randomness in selected images with mask label)
                for file_name in lst_dir:
                    _file_name, ext = os.path.splitext(file_name)
                    if _file_name not in _file_names:  # "." 로 시작하는 파일 및 invalid 한 파일들은 무시합니다
                        continue
                    if ext != '.jpg':
                        continue
                    if downsample and file_name.startswith('mask'):
                        if not include_mask:
                            continue
                        include_mask = False # include only 1 mask image per profile

                    img_path = os.path.join(data_dir, profile, file_name)  # (resized_data, 000004_male_Asian_54, mask1.jpg)
                    mask_label = _file_names[_file_name]
                    id, gender, race, age = profile.split("_")
                    gender_label = GenderLabels.from_str(gender)
                    age_label = AgeLabels.from_number(age)
                    image_paths.append(img_path)
                    mask_labels.append(mask_label)
                    gender_labels.append(gender_label)
                    age_labels.append(age_label)
                    multi_labels.append(self.encode_multi_class(mask_label, gender_label, age_label))
                    indices[phase].append(cnt)
                    cnt += 1

In [5]:
import os
import random
from collections import defaultdict
from enum import Enum
from typing import Tuple, List

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, Subset, random_split, WeightedRandomSampler
# from torchvision import transforms
# from torchvision.transforms import *
from albumentations import *
from albumentations.pytorch import ToTensorV2

class MaskBaseDataset(Dataset):
    num_classes = 3 * 2 * 3

    _file_names = {
        "mask1": MaskLabels.MASK,
        "mask2": MaskLabels.MASK,
        "mask3": MaskLabels.MASK,
        "mask4": MaskLabels.MASK,
        "mask5": MaskLabels.MASK,
        "incorrect_mask": MaskLabels.INCORRECT,
        "normal": MaskLabels.NORMAL
    }
    
    image_paths = []
    profiles = []
    
    mask_labels = []
    gender_labels = []
    age_labels = []
    multi_labels = []  
    
    def __init__(self, data_dir, mean=(0.548, 0.504, 0.479), std=(0.237, 0.247, 0.246), val_ratio=0.2):
        self.data_dir = data_dir
        self.mean = mean
        self.std = std
        self.downsample = True
        
        self.setup()
    
    def __getitem__(self, index):
        return self.read_image(index), self.multi_labels[index]
        
    def __len__(self):
        return len(image_paths)
    
    def setup(self):
        self.profiles = [profile for profile in os.listdir(self.data_dir) if not profile.startswith('.')]
        for profile in self.profiles:
            _, gender, _, age = profile.split("_")
            gender_label = GenderLabels.from_str(gender)
            age_label = AgeLabels.from_number(age)
            
            include_mask = True
            img_folder = os.path.join(self.data_dir, profile)
            lst_files = os.listdir(img_folder)
            random.shuffle(lst_files) # in-place operation (add randomness in selected images with mask label)
            for file_name in lst_files:
                _file_name, ext = os.path.splitext(file_name)
                if _file_name not in self._file_names:  # "." 로 시작하는 파일 및 invalid 한 파일들은 무시합니다
                    continue
                # if ext != '.jpg': # contains some png files
                #     continue
                if self.downsample and file_name.startswith('mask'):
                    if not include_mask:
                        continue
                    include_mask = False # include only 1 mask image per profile
                
                img_path = os.path.join(self.data_dir, profile, file_name)  # (resized_data, 000004_male_Asian_54, mask1.jpg)
                mask_label = self._file_names[_file_name]

                self.image_paths.append(img_path)
                self.mask_labels.append(mask_label)
                self.gender_labels.append(gender_label)
                self.age_labels.append(age_label)
                self.multi_labels.append(self.encode_multi_class(mask_label, gender_label, age_label))

    def get_mask_label(self, index) -> MaskLabels:
        return self.mask_labels[index]

    def get_gender_label(self, index) -> GenderLabels:
        return self.gender_labels[index]

    def get_age_label(self, index) -> AgeLabels:
        return self.age_labels[index]

    @staticmethod
    def encode_multi_class(mask_label, gender_label, age_label) -> int:
        return mask_label * 6 + gender_label * 3 + age_label

    @staticmethod
    def decode_multi_class(multi_class_label) -> Tuple[MaskLabels, GenderLabels, AgeLabels]:
        mask_label = (multi_class_label // 6) % 3
        gender_label = (multi_class_label // 3) % 2
        age_label = multi_class_label % 3
        return mask_label, gender_label, age_label

    @staticmethod
    def denormalize_image(image, mean, std):
        img_cp = image.copy()
        img_cp *= std
        img_cp += mean
        img_cp *= 255.0
        img_cp = np.clip(img_cp, 0, 255).astype(np.uint8)
        return img_cp

    def read_image(self, index):
        # read an image from directory and return it as a numpy array
        image_path = self.image_paths[index]
        return np.array(Image.open(image_path))


class MaskSplitByProfileDataset(MaskBaseDataset):
    # calc_statistics self.image_paths[:3000]:
    #   [0.5573112  0.52429302 0.50174594] [0.61373778 0.58633636 0.56743769]
    def __init__(self, data_dir, label='multi', n_fold:int=2, mean=(0.548, 0.504, 0.479), std=(0.237, 0.247, 0.246), val_ratio=0.2):
        super().__init__(data_dir, mean, std)
        
        self.label = label
        self.set_target_label()
        
        self.n_fold = n_fold
        self.kfold_indices = []
        self.stratified_kfold()
        self.indices = self.kfold_indices[0]
        
        self.class_weights = self.compute_class_weight()
        
    def set_target_label(self):
        if self.label == 'multi':
            self.num_classes = 3 * 2 * 3
            self.target_label = self.multi_labels
        elif self.label == 'mask':
            self.num_classes = 3
            self.target_label = self.mask_labels
        elif self.label == 'gender':
            self.num_classes = 2
            self.target_label = self.gender_labels
        elif self.label == 'age':
            self.num_classes = 3
            self.target_label = self.age_labels
        else:
            raise ValueError(f"label must be 'multi', 'mask', 'gender', or 'age', {self.label}")

    def __getitem__(self, index):
        return self.read_image(index), self.target_label[index]

    def stratified_kfold(self):
        from sklearn.model_selection import StratifiedKFold
        
        profile_labels = []
        for profile in self.profiles:
            _, gender, _, age = profile.split("_")
            gender_label = GenderLabels.from_str(gender)
            age_label = AgeLabels.from_number(age)
            profile_label = self.encode_multi_class(0, gender_label, age_label)
            profile_labels.append(profile_label)
        
        skf = StratifiedKFold(n_splits=self.n_fold)
        for train_profiles, val_profiles in skf.split(self.profiles, profile_labels):
            train_index, val_index = [], []
            for profile_idx in train_profiles:
                train_index.extend(range(profile_idx*3, profile_idx*3+3))
            for profile_idx in val_profiles:
                val_index.extend(range(profile_idx*3, profile_idx*3+3))
            self.kfold_indices.append({
                'train': train_index,
                'val': val_index
            })

    def split_dataset(self) -> List[Subset]:
        return [Subset(self, indices) for phase, indices in self.indices.items()]
    
    def get_train_labels(self, label):
        # returns train data of the input label
        train_index = self.indices['train']
        return [label[idx] for idx in train_index]
    
    def get_classweight_label(self, label) -> torch.tensor:
        # returns class weight of a label within train dataset
        train_labels = self.get_train_labels(label)
        _, n_samples = np.unique(train_labels, return_counts=True)
        weights = 1. / torch.tensor(n_samples, dtype=torch.float)
        return weights
    
    def normalize_weight(self, weights):
        norm_weights = [1 - (weight / sum(weights)) for weight in weights]
        return torch.tensor(norm_weights, dtype=torch.float)

    ##################### need refactoring ##################### 
    def weight0(self):
        # v0: weights on target label
        train_index = self.indices['train'] # indices of train dataset
        train_labels = [self.target_label[idx] for idx in train_index] # target_label of train dataset
        class_counts = np.array([len(np.where(train_labels==t)[0]) for t in np.unique(train_labels)]) # get counts of each class 
        weights = 1. / torch.tensor(class_counts, dtype=torch.float) # get weights (more class count == less weight(frequent) it will be sampled)
        samples_weights = weights[train_labels] # map weights for each train dataset, len(samples_weights) == len(train dataset)
        return WeightedRandomSampler(weights=samples_weights, num_samples=len(samples_weights), replacement=True)
    
    def weight1(self):
        # v1: normalized weights on target label (better than v0)
        sample_weight = [self.class_weights[self.target_label[idx]] for idx in self.indices['train']]
        return WeightedRandomSampler(weights=sample_weight, num_samples=len(sample_weight), replacement=True)

    def weight2(self):
        # # v2: normalized weights on of specific ratio ``age=.9 : gender=.1``
        age_weight = self.get_classweight_label(self.age_labels)
        gender_weight = self.get_classweight_label(self.gender_labels)
        weights = [age_weight[self.age_labels[idx]]*.9 + gender_weight[self.gender_labels[idx]]*.1 for idx in self.indices['train']]
        return WeightedRandomSampler(weights=weights, num_samples=len(weights), replacement=True)

    def weight3(self):
        # v3: normalized weights on multi label
        multi_weight = self.get_classweight_label(self.multi_labels)
        multi_weight = self.normalize_weight(multi_weight)
        sample_weight = [multi_weight[self.multi_labels[idx]] for idx in self.indices['train']]
        return WeightedRandomSampler(weights=sample_weight, num_samples=len(sample_weight), replacement=True)

    def weight4(self):
        # v4: weights on multi label
        multi_weight = self.get_classweight_label(self.multi_labels)
        sample_weight = [multi_weight[self.multi_labels[idx]] for idx in self.indices['train']]
        return WeightedRandomSampler(weights=sample_weight, num_samples=len(sample_weight), replacement=True)

    def get_weighted_sampler(self, ver: int=0) -> WeightedRandomSampler:  
        """
        returns WeightedRandomSampler based on the distribution of the train label
        used to prevent overfitting due to unbalanced dataset
        """
        if ver==0: return self.weight0()
        elif ver==1: return self.weight1()
        elif ver==2: return self.weight2()
        elif ver==3: return self.weight3()
        elif ver==4: return self.weight4()
        else: raise ValueError(f'invalid version of {ver}')

    def compute_class_weight(self) -> torch.tensor:
        """
        estimate class weights for unbalanced dataset
        `` 1 - n_sample / sum(n_samples) ````
        used for loss function: weighted_cross_entropy
        """
        train_index = self.indices['train']
        train_labels = [self.target_label[idx] for idx in train_index]
        _, n_samples = np.unique(train_labels, return_counts=True)
        norm_weights = [1 - (sample / sum(n_samples)) for sample in n_samples]
        return torch.tensor(norm_weights, dtype=torch.float).to(device='cuda')

In [6]:
label = 'age'
k_folds = 5

dataset = MaskSplitByProfileDataset(
    data_dir=data_dir,
    label=label,
    n_fold=k_folds
)
num_classes = dataset.num_classes  # 18


6480
8100
image_paths   = 8100
profiles      = 2700
mask_labels   = 8100
gender_labels = 8100
age_labels    = 8100
multi_labels  = 8100


In [12]:
print(dataset.num_classes)

3


In [14]:
print(set(dataset.target_label))

{<AgeLabels.YOUNG: 0>, <AgeLabels.MIDDLE: 1>, <AgeLabels.OLD: 2>}


In [15]:
print(set(dataset.multi_labels))
print(set(dataset.mask_labels))
print(set(dataset.gender_labels))
print(set(dataset.age_labels))


{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}
{<MaskLabels.MASK: 0>, <MaskLabels.INCORRECT: 1>, <MaskLabels.NORMAL: 2>}
{<GenderLabels.MALE: 0>, <GenderLabels.FEMALE: 1>}
{<AgeLabels.YOUNG: 0>, <AgeLabels.MIDDLE: 1>, <AgeLabels.OLD: 2>}


In [71]:
for idx, indices in enumerate(dataset.kfold_indices):
    print(f'{idx}: train={len(indices["train"])}, test={len(indices["val"])}, total={len(indices["train"])+len(indices["val"])}')
    print(f'train set count={len(set(indices["train"]))}, val set count={len(set(indices["val"]))}')

0: train=6480, test=1620, total=8100
train set count=6480, val set count=1620
1: train=6480, test=1620, total=8100
train set count=6480, val set count=1620
2: train=6480, test=1620, total=8100
train set count=6480, val set count=1620
3: train=6480, test=1620, total=8100
train set count=6480, val set count=1620
4: train=6480, test=1620, total=8100
train set count=6480, val set count=1620


In [72]:
fold0_profiles = {
    'train': [dataset.image_paths[idx].split(sep='/')[7] for idx in dataset.kfold_indices[0]['train']],
    'val': [dataset.image_paths[idx].split(sep='/')[7] for idx in dataset.kfold_indices[0]['val']]
}

In [76]:
len(dataset.image_paths)

30932

In [73]:
len(fold0_profiles['train'])


6480

In [56]:
idx = dataset.kfold_indices[0]['train'][0]
dataset.image_paths[idx].split(sep='/')[7]

'001596_male_Asian_26'

In [7]:
trainset_0 = dataset.kfold_indices[0]['train']
set_train_0 = set([dataset.image_paths[idx].split(sep='/')[7] for idx in trainset_0])

In [8]:
len(set_train_0)

2160

In [9]:
len(dataset.image_paths)

8100

In [10]:
valset_0 = dataset.kfold_indices[0]['val']
set_val_0 = set([dataset.image_paths[idx].split(sep='/')[7] for idx in valset_0])

In [11]:
lst_cheating = []
for train_profile in set_train_0:
    if train_profile in set_val_0:
        lst_cheating.append(train_profile)
print(len(lst_cheating))

0


In [18]:
from torch.utils.data import DataLoader
import multiprocessing

# -- augmentation
train_set, val_set = dataset.split_dataset()
        
# -- data_loader
sampler = dataset.get_weighted_sampler(0) # WeightedRandomSampler

train_loader = DataLoader(
    train_set,
    batch_size=16,
    num_workers=multiprocessing.cpu_count()//2,
    # shuffle=True,
    pin_memory=torch.cuda.is_available(),
    drop_last=True,
    sampler=sampler,
)


In [20]:
for idx, train_batch in enumerate(train_loader):
    inputs, labels = train_batch
    print(inputs)
    print(labels)
    break

tensor([[[[203, 172, 128],
          [206, 175, 131],
          [208, 177, 133],
          ...,
          [144, 135, 154],
          [144, 135, 154],
          [144, 135, 154]],

         [[203, 172, 128],
          [206, 175, 131],
          [208, 177, 133],
          ...,
          [143, 134, 153],
          [144, 135, 154],
          [144, 135, 154]],

         [[204, 173, 129],
          [206, 175, 131],
          [208, 177, 133],
          ...,
          [143, 134, 153],
          [143, 134, 153],
          [144, 135, 154]],

         ...,

         [[157, 123,  96],
          [152, 123,  93],
          [148, 127,  96],
          ...,
          [  4,  25, 106],
          [  4,  25, 106],
          [  5,  26, 107]],

         [[154, 123,  92],
          [144, 118,  85],
          [142, 123,  91],
          ...,
          [  4,  25, 106],
          [  4,  25, 106],
          [  5,  26, 107]],

         [[152, 122,  88],
          [139, 113,  78],
          [137, 118,  86],
         

In [None]:

val_loader = DataLoader(
    val_set,
    batch_size=args.valid_batch_size,
    num_workers=multiprocessing.cpu_count()//2,
    shuffle=False,
    pin_memory=use_cuda,
    drop_last=True,
)