In [16]:
import argparse
import copy
import glob
import json
import multiprocessing
import os
import random
import re
from enum import Enum
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, Subset, random_split, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, models
from albumentations import *
from albumentations.pytorch import ToTensorV2


# config

In [17]:
# Experiment configuration
class Config:
    """All tunable settings for this run, stored as class attributes.

    Read them through the module-level instance, e.g. ``config.lr``.
    """

    seed = 42  # global RNG seed passed to seed_everything

    # -- data
    data_dir = './input/data/train'  # images are read from f"{data_dir}/images"
    resize = [224, 224]  # target size handed to the Resize transform
    val_ratio = 0.2  # fraction of samples held out for validation

    # -- training
    epochs = 100
    batch_size = 64
    valid_batch_size = 1000
    lr = 1e-3
    lr_decay_step = 20  # StepLR step_size
    log_interval = 50  # print / log training metrics every N batches

    # -- output
    save_dir = './exp'  # run outputs; a numeric suffix is added if it already exists


config = Config()

In [18]:
# Seed-fixing helper
def seed_everything(seed):
    """Seed every RNG this notebook relies on so that runs are repeatable.

    Seeds Python's ``random``, NumPy's global RNG, and torch (CPU and every CUDA
    device). It also switches cuDNN to deterministic mode with autotuning disabled.
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # all visible GPUs (multi-GPU runs)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [19]:
# Fix all random seeds before any data splitting or model initialisation.
seed_everything(seed=config.seed)

In [20]:
 # -- settings
use_cuda = torch.cuda.is_available()  # also decides pin_memory for the loaders
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Dataset

In [21]:
# File suffixes accepted as images (matched case-sensitively).
IMG_EXTENSIONS = [
    ".jpg", ".JPG",
    ".jpeg", ".JPEG",
    ".png", ".PNG",
    ".ppm", ".PPM",
    ".bmp", ".BMP",
]


def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS`` (exact case)."""
    return filename.endswith(tuple(IMG_EXTENSIONS))

class MaskLabels(int, Enum):
    """Mask-wearing state of a photo (integer-valued enum)."""

    MASK = 0       # mask worn correctly
    INCORRECT = 1  # mask worn incorrectly
    NORMAL = 2     # no mask

class GenderLabels(int, Enum):
    """Gender class index (integer-valued enum)."""

    MALE = 0
    FEMALE = 1

    @classmethod
    def from_str(cls, value: str) -> int:
        """Map ``'male'`` / ``'female'`` (any letter case) to the matching member.

        Raises:
            ValueError: if ``value`` is any other string.
        """
        value = value.lower()
        by_name = {"male": cls.MALE, "female": cls.FEMALE}
        if value not in by_name:
            raise ValueError(f"Gender value should be either 'male' or 'female', {value}")
        return by_name[value]

class AgeLabels(int, Enum):
    """Age band: under 30, 30 to 59, and 60 or older."""

    YOUNG = 0
    MIDDLE = 1
    OLD = 2

    @classmethod
    def from_number(cls, value: str) -> int:
        """Convert an age (numeric string or int) to its age band.

        Raises:
            ValueError: if ``value`` cannot be converted with ``int()``.
        """
        try:
            age = int(value)
        except Exception:
            raise ValueError(f"Age value should be numeric, {value}")

        if age >= 60:
            return cls.OLD
        if age >= 30:
            return cls.MIDDLE
        return cls.YOUNG

### Define transforms (augmentation)

In [22]:
# Per-channel mean / std handed to Normalize (pixel values are scaled by 255 first).
mean = (0.548, 0.504, 0.479)
std = (0.237, 0.247, 0.246)


def _resize_step():
    """Fresh Resize to config.resize; shared by the train and validation pipelines."""
    return Resize(config.resize[0], config.resize[1], p=1.0)


def _to_tensor_steps():
    """Fresh Normalize + ToTensorV2 tail; shared by the train and validation pipelines."""
    return [
        Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]


train_transform = Compose(
    [
        _resize_step(),
        # Augmentations currently disabled:
        #     HorizontalFlip(p=0.5),
        #     ShiftScaleRotate(p=0.5),
        #     HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        #     RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        #     GaussNoise(p=0.5),
        *_to_tensor_steps(),
    ],
    p=1.0,
)
val_transform = Compose([_resize_step(), *_to_tensor_steps()], p=1.0)

### define datasets

In [23]:
# base dataset class
class BaseDataset(Dataset):
    """Mask / gender / age dataset built from a directory of per-person folders.

    Layout, as parsed by ``setup``::

        img_dir/<id>_<gender>_<race>_<age>/<mask1..mask5|incorrect_mask|normal>.<ext>

    Each sample's label is the combined class id
    ``mask_label * 6 + gender_label * 3 + age_label``, in the range [0, 18).
    """

    num_classes = 3 * 2 * 3  # 3 mask states x 2 genders x 3 age bands

    # File stem -> mask label. Files whose stem is not listed here are skipped.
    _file_names = {
        "mask1": MaskLabels.MASK,
        "mask2": MaskLabels.MASK,
        "mask3": MaskLabels.MASK,
        "mask4": MaskLabels.MASK,
        "mask5": MaskLabels.MASK,
        "incorrect_mask": MaskLabels.INCORRECT,
        "normal": MaskLabels.NORMAL
    }

    def __init__(self, img_dir, transform=None):
        """Index every labelled image under ``img_dir``.

        Args:
            img_dir: Root directory that contains one sub-folder per person.
            transform: Augmentation callable, invoked as
                ``transform(image=ndarray)["image"]``. If None, ``__getitem__``
                returns the raw H x W x 3 uint8 array.
        """
        self.img_dir = img_dir
        self.transform = transform

        # Per-instance storage. As class attributes, these lists were shared by every
        # BaseDataset and kept growing each time another instance was constructed.
        self.image_paths = []
        self.mask_labels = []
        self.gender_labels = []
        self.age_labels = []

        self.setup()

    def set_transform(self, transform):
        """Replace the augmentation pipeline applied in ``__getitem__``."""
        self.transform = transform

    def setup(self):
        """Scan ``img_dir`` and record the path and labels of every recognised image.

        Entries are visited in sorted order. Dataset indices, and therefore any seeded
        random split, then do not depend on the filesystem's listing order.
        """
        for profile in sorted(os.listdir(self.img_dir)):
            if profile.startswith("."):  # skip hidden entries (e.g. .DS_Store)
                continue

            img_folder = os.path.join(self.img_dir, profile)
            if not os.path.isdir(img_folder):  # ignore stray files at the top level
                continue

            for file_name in sorted(os.listdir(img_folder)):
                stem, _ext = os.path.splitext(file_name)
                if stem not in self._file_names:  # hidden or unrelated files
                    continue

                # The folder name encodes the person: <id>_<gender>_<race>_<age>.
                _id, gender, _race, age = profile.split("_")

                self.image_paths.append(os.path.join(img_folder, file_name))
                self.mask_labels.append(self._file_names[stem])
                self.gender_labels.append(GenderLabels.from_str(gender))
                self.age_labels.append(AgeLabels.from_number(age))

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``.

        Args:
            index: Position of the sample in this dataset.

        Returns:
            The transformed image (or the raw RGB array if no transform is set), and
            the combined class id ``mask * 6 + gender * 3 + age``.
        """
        # Using a `with` block closes the file right away. Otherwise PIL keeps the
        # handle open until the image is garbage-collected.
        with Image.open(self.image_paths[index]) as img:
            # Force 3-channel RGB so that per-channel normalisation always sees three
            # channels. Palette, RGBA and grayscale files otherwise keep their own mode.
            image = np.array(img.convert("RGB"))

        multi_class_label = (
            self.mask_labels[index] * 6
            + self.gender_labels[index] * 3
            + self.age_labels[index]
        )

        if self.transform is not None:
            image = self.transform(image=image)["image"]
        return image, multi_class_label

    def __len__(self):
        """Number of indexed images."""
        return len(self.image_paths)

In [24]:
# Build the full dataset once, then split it into train / validation indices.
dataset = BaseDataset(img_dir=f'{config.data_dir}/images')

n_val = int(len(dataset) * config.val_ratio)
n_train = len(dataset) - n_val
# A dedicated generator makes the split depend only on config.seed. Otherwise it
# depends on how much global torch RNG earlier (re-)run cells consumed.
split_generator = torch.Generator().manual_seed(config.seed)
train_split, val_split = random_split(dataset, [n_train, n_val], generator=split_generator)

# Both Subsets returned by random_split wrap the SAME dataset object. Calling
# set_transform on each would leave both splits using whichever transform was set
# last (val_transform). Give each split a shallow copy with its own transform; the
# copies still share the path / label lists.
train_base = copy.copy(dataset)
train_base.set_transform(train_transform)
val_base = copy.copy(dataset)
val_base.set_transform(val_transform)

train_dataset = Subset(train_base, train_split.indices)
val_dataset = Subset(val_base, val_split.indices)

# Same worker count for both loaders. The validation loader previously used
# cpu_count() // 24, which is 0 (main-process loading) on most machines.
num_workers = multiprocessing.cpu_count() // 2

train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    num_workers=num_workers,
    shuffle=True,
    pin_memory=use_cuda
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.valid_batch_size,
    num_workers=num_workers,
    shuffle=False,
    pin_memory=use_cuda
)

# define model

In [25]:
class BaseModel(nn.Module):
    """torchvision ViT-B/16 whose classification head is resized to ``num_classes``."""

    def __init__(self, num_classes, pretrained=True):
        """
        Args:
            num_classes: Number of output logits.
            pretrained: Whether to load torchvision's pretrained weights. Defaults to
                True, as before. NOTE: newer torchvision releases deprecate
                ``pretrained`` in favour of ``weights=``.
        """
        super().__init__()
        self.base_model = models.vit_b_16(pretrained=pretrained)
        # Take the feature width from the existing head instead of hard-coding 768.
        in_features = self.base_model.heads.head.in_features
        self.base_model.heads.head = nn.Linear(
            in_features=in_features, out_features=num_classes, bias=True
        )

    def forward(self, x):
        """Return raw class logits for a batch of images ``x``."""
        return self.base_model(x)

# train

In [26]:
# Build the model with one logit per combined class and wrap it in DataParallel.
num_classes = train_dataset.dataset.num_classes
print(num_classes)
model = torch.nn.DataParallel(BaseModel(num_classes=num_classes).to(device))

18


### Loss, metric, optimizer, and scheduler

In [27]:
criterion = nn.CrossEntropyLoss()

# Optimise only the parameters that still require gradients.
optimizer = Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=config.lr,
    weight_decay=5e-4,
)

# Halve the learning rate every `config.lr_decay_step` scheduler steps.
scheduler = StepLR(optimizer, config.lr_decay_step, gamma=0.5)

def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group, or None if it has none."""
    return next((group['lr'] for group in optimizer.param_groups), None)

In [28]:
# Output directory: use config.save_dir as-is if it does not exist yet. Otherwise
# append the next free integer suffix (exp -> exp2, exp3, ...).
path = Path(config.save_dir)
if not path.exists():
    save_dir = str(path)
else:
    # Only sibling entries named exactly "<name><digits>" count as earlier runs.
    # Both patterns are escaped, so glob / regex metacharacters in the directory
    # name are matched literally.
    suffix_re = re.compile(rf"{re.escape(path.name)}(\d+)")
    candidates = glob.glob(glob.escape(str(path)) + "*")
    used_suffixes = [
        int(m.group(1))
        for m in (suffix_re.fullmatch(Path(d).name) for d in candidates)
        if m
    ]
    n = max(used_suffixes) + 1 if used_suffixes else 2
    save_dir = f"{path}{n}"

print("save path : " + save_dir)

save path : exp


### start training

In [29]:
# Create the run directory explicitly before anything writes into it.
os.makedirs(save_dir, exist_ok=True)
logger = SummaryWriter(log_dir=save_dir)

# Config stores its settings as class attributes, so vars(config) is an empty dict
# and config.json used to be "{}". Collect the public, non-callable attributes instead.
config_dict = {
    key: getattr(config, key)
    for key in dir(config)
    if not key.startswith("_") and not callable(getattr(config, key))
}
with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
    json.dump(config_dict, f, ensure_ascii=False, indent=4)

In [30]:
best_val_acc = 0
best_val_loss = np.inf
# Checkpoints store the unwrapped network's weights. `model` is a DataParallel
# wrapper here; fall back to `model` itself if it is ever used unwrapped.
net_to_save = getattr(model, "module", model)

for epoch in range(config.epochs):
    # ---- train loop
    model.train()
    loss_value = 0
    matches = 0
    n_seen = 0  # samples processed since the last log line
    for idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outs = model(inputs)
        preds = torch.argmax(outs, dim=-1)
        loss = criterion(outs, labels)

        loss.backward()
        optimizer.step()

        loss_value += loss.item()
        matches += (preds == labels).sum().item()
        n_seen += labels.size(0)
        if (idx + 1) % config.log_interval == 0:
            # Mean of the per-batch mean losses over this interval.
            train_loss = loss_value / config.log_interval
            # Divide by the samples actually seen. The loader does not drop the last
            # batch, so a batch may be smaller than config.batch_size.
            train_acc = matches / n_seen
            current_lr = get_lr(optimizer)
            print(
                f"Epoch[{epoch}/{config.epochs}]({idx + 1}/{len(train_loader)}) || "
                f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || lr {current_lr}"
            )
            logger.add_scalar("Train/loss", train_loss, epoch * len(train_loader) + idx)
            logger.add_scalar("Train/accuracy", train_acc, epoch * len(train_loader) + idx)

            loss_value = 0
            matches = 0
            n_seen = 0

    scheduler.step()

    # ---- val loop
    with torch.no_grad():
        print("Calculating validation results...")
        model.eval()
        val_loss_sum = 0.0
        val_correct = 0
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outs = model(inputs)
            preds = torch.argmax(outs, dim=-1)

            # criterion returns a batch mean. Weight it by batch size so a short final
            # batch counts in proportion to its number of samples.
            val_loss_sum += criterion(outs, labels).item() * labels.size(0)
            val_correct += (labels == preds).sum().item()

        n_val_samples = len(val_dataset)
        # An empty validation split yields NaN metrics instead of dividing by zero.
        val_loss = val_loss_sum / n_val_samples if n_val_samples else float("nan")
        val_acc = val_correct / n_val_samples if n_val_samples else float("nan")
        best_val_loss = min(best_val_loss, val_loss)
        if val_acc > best_val_acc:
            print(f"New best model for val accuracy : {val_acc:4.2%}! saving the best model..")
            torch.save(net_to_save.state_dict(), f"{save_dir}/best.pth")
            best_val_acc = val_acc
        torch.save(net_to_save.state_dict(), f"{save_dir}/last.pth")
        print(
            f"[Val] acc : {val_acc:4.2%}, loss: {val_loss:4.2} || "
            f"best acc : {best_val_acc:4.2%}, best loss: {best_val_loss:4.2}"
        )
        logger.add_scalar("Val/loss", val_loss, epoch)
        logger.add_scalar("Val/accuracy", val_acc, epoch)
        print()

# Flush and release the TensorBoard event file.
logger.close()

Epoch[0/100](50/237) || training loss 2.52 || training accuracy 19.59% || lr 0.001
Epoch[0/100](100/237) || training loss 2.404 || training accuracy 19.66% || lr 0.001
Epoch[0/100](150/237) || training loss 2.398 || training accuracy 20.28% || lr 0.001
Epoch[0/100](200/237) || training loss 2.344 || training accuracy 22.75% || lr 0.001
Calculating validation results...
New best model for val accuracy : 23.70%! saving the best model..
[Val] acc : 23.70%, loss:  2.3 || best acc : 23.70%, best loss:  2.3

Epoch[1/100](50/237) || training loss 2.278 || training accuracy 25.69% || lr 0.001
Epoch[1/100](100/237) || training loss 2.249 || training accuracy 28.28% || lr 0.001
Epoch[1/100](150/237) || training loss 2.221 || training accuracy 26.78% || lr 0.001
Epoch[1/100](200/237) || training loss 2.176 || training accuracy 28.34% || lr 0.001
Calculating validation results...
New best model for val accuracy : 29.31%! saving the best model..
[Val] acc : 29.31%, loss:  2.3 || best acc : 29.31%, 

In [31]:
# from torchvision.models import vit_b_16
# model = models.vit_b_16(pretrained=True)
# model.heads.head = nn.Linear(in_features=768, out_features=18, bias=True)
# model