In [1]:
import argparse
import glob
import json
import multiprocessing
import os
import random
import re
from pathlib import Path
from enum import Enum
from typing import Tuple, List
from collections import defaultdict
from sklearn.metrics import f1_score

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR
from torch.utils.data import Dataset, Subset, random_split, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, models
from albumentations import *
from albumentations.pytorch import ToTensorV2


# config

In [2]:
# config
class Config():
    seed = 42
    
    # 데이터
#     data_dir = './face_input/train' #
    data_dir = './input/data/train'
    resize = [224, 224]
    val_ratio = 0.2
    
    # 학습 설정
    epochs = 70
    batch_size = 64
#     batch_size = 16
    valid_batch_size = 1000
    lr = 1e-4
    lr_decay_step = 10
    log_interval = 50
    
    # 세이브 경로
    save_dir = './resnextmultihead_exp'
    

config = Config()

In [3]:
# 시드 고정 함수
def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

In [4]:
# 시드 고정
seed_everything(config.seed)

In [5]:
 # -- settings
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Dataset

In [6]:
# base setting
IMG_EXTENSIONS = [
    ".jpg", ".JPG", ".jpeg", ".JPEG", ".png",
    ".PNG", ".ppm", ".PPM", ".bmp", ".BMP",
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

class MaskLabels(int, Enum):
    MASK = 0
    INCORRECT = 1
    NORMAL = 2

class GenderLabels(int, Enum):
    MALE = 0
    FEMALE = 1

    @classmethod
    def from_str(cls, value: str) -> int:
        value = value.lower()
        if value == "male":
            return cls.MALE
        elif value == "female":
            return cls.FEMALE
        else:
            raise ValueError(f"Gender value should be either 'male' or 'female', {value}")

class AgeLabels(int, Enum):
    YOUNG = 0
    MIDDLE = 1
    OLD = 2

    @classmethod
    def from_number(cls, value: str) -> int:
        try:
            value = int(value)
        except Exception:
            raise ValueError(f"Age value should be numeric, {value}")

        if value < 30:
            return cls.YOUNG
        elif value < 60:
            return cls.MIDDLE
        else:
            return cls.OLD

### define transform (Augmentation)

In [7]:
mean=(0.548, 0.504, 0.479)
std=(0.237, 0.247, 0.246)
train_transform = Compose([
    CenterCrop(height=440, width=320),
    Resize(config.resize[0], config.resize[1], p=1.0, interpolation=cv2.INTER_LINEAR),
    HorizontalFlip(p=0.5),
#     ShiftScaleRotate(p=0.5),
#     HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
    RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    GaussNoise(p=0.5),
    Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0),
], p=1.0)
val_transform = Compose([
    CenterCrop(height=440, width=320),
    Resize(config.resize[0], config.resize[1], p=1.0),
    Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0),
], p=1.0)

### define datasets

In [8]:
# base dataset class
class BaseDataset(Dataset):
    num_classes = 3*2*3
    
    _file_names = {
        "mask1": MaskLabels.MASK,
        "mask2": MaskLabels.MASK,
        "mask3": MaskLabels.MASK,
        "mask4": MaskLabels.MASK,
        "mask5": MaskLabels.MASK,
        "incorrect_mask": MaskLabels.INCORRECT,
        "normal": MaskLabels.NORMAL
    }
    
    image_paths = []
    mask_labels = []
    gender_labels = []
    age_labels = []

    def __init__(self, img_dir, val_ratio=0.2, transform=None):
        """
        MaskBaseDataset을 initialize 합니다.

        Args:
            img_dir: 학습 이미지 폴더의 root directory 입니다.
            transform: Augmentation을 하는 함수입니다.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.indices = defaultdict(list)
        
        self.val_ratio = val_ratio
        
        
        self.setup()

    def set_transform(self, transform):
        """
        transform 함수를 설정하는 함수입니다.
        """
        self.transform = transform
    
    def setup(self):
        """
        image의 경로와 각 이미지들의 label을 계산하여 저장해두는 함수입니다.
        """
        profiles = os.listdir(self.img_dir)
        profiles = [profile for profile in profiles if not profile.startswith(".")]
        split_profiles = self._split_profile(profiles, self.val_ratio)

        cnt = 0
        for phase, indices in split_profiles.items():
            for _idx in indices:
                profile = profiles[_idx]
                img_folder = os.path.join(self.img_dir, profile)
                for file_name in os.listdir(img_folder):
                    _file_name, ext = os.path.splitext(file_name)
                    if _file_name not in self._file_names:  # "." 로 시작하는 파일 및 invalid 한 파일들은 무시합니다
                        continue

                    img_path = os.path.join(self.img_dir, profile, file_name)  # (resized_data, 000004_male_Asian_54, mask1.jpg)
                    mask_label = self._file_names[_file_name]

                    id, gender, race, age = profile.split("_")
                    gender_label = GenderLabels.from_str(gender)
                    age_label = AgeLabels.from_number(age)

                    self.image_paths.append(img_path)
                    self.mask_labels.append(mask_label)
                    self.gender_labels.append(gender_label)
                    self.age_labels.append(age_label)

                    self.indices[phase].append(cnt)
                    cnt += 1
                    
    def split_dataset(self) -> List[Subset]:
        return [Subset(self, indices) for phase, indices in self.indices.items()]
                
    def __getitem__(self, index):
        """
        데이터를 불러오는 함수입니다. 
        데이터셋 class에 데이터 정보가 저장되어 있고, index를 통해 해당 위치에 있는 데이터 정보를 불러옵니다.
        
        Args:
            index: 불러올 데이터의 인덱스값입니다.
        """
        # 이미지를 불러옵니다.
        image_path = self.image_paths[index]
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


#         image = Image.open(image_path)
        
        # 레이블을 불러옵니다.
        mask_label = self.mask_labels[index]
        gender_label = self.gender_labels[index]
        age_label = self.age_labels[index]
        multi_class_label = mask_label * 6 + gender_label * 3 + age_label
        
        # 이미지를 Augmentation 시킵니다.
        image_transform = self.transform(image=np.array(image))['image']
        return image_transform, (mask_label, gender_label, age_label, multi_class_label) # multi_class_label

    def __len__(self):
        return len(self.image_paths)
    
    
    # 이외 utils
    @staticmethod
    def _split_profile(profiles, val_ratio):
        length = len(profiles)
        n_val = int(length * val_ratio)

        val_indices = set(random.choices(range(length), k=n_val))
        train_indices = set(range(length)) - val_indices
        return {
            "train": train_indices,
            "val": val_indices
        }
    
    @staticmethod
    def encode_multi_class(mask_label, gender_label, age_label) -> int:
        return mask_label * 6 + gender_label * 3 + age_label
    
    @staticmethod
    def decode_multi_class(multi_class_label) -> Tuple[MaskLabels, GenderLabels, AgeLabels]:
        mask_label = (multi_class_label // 6) % 3
        gender_label = (multi_class_label // 3) % 2
        age_label = multi_class_label % 3
        return mask_label, gender_label, age_label

In [9]:
dataset = BaseDataset(img_dir = f'{config.data_dir}/images', val_ratio = config.val_ratio)

train_dataset, val_dataset = dataset.split_dataset()

train_dataset.dataset.set_transform(train_transform)
val_dataset.dataset.set_transform(val_transform)

train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    num_workers=multiprocessing.cpu_count() // 2,
    shuffle=True,
    pin_memory=use_cuda
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.valid_batch_size,
    num_workers=multiprocessing.cpu_count() // 2,
    shuffle=False,
    pin_memory=use_cuda
)

# define model

In [10]:
# class BaseModel(nn.Module):
#     def __init__(self, num_classes):
#         super().__init__()
#         self.base_model = models.vit_b_16(pretrained=True)
#         self.base_model.heads.head = nn.Linear(in_features=768, out_features=num_classes, bias=True)

#     def forward(self, x):
#         return self.base_model(x)

In [11]:
# class ResnextModel(nn.Module):
#     def __init__(self, num_classes):
#         super().__init__()
#         self.base_model = models.resnext50_32x4d(pretrained=True)
#         self.base_model.fc = nn.Sequential(
#             nn.Linear(in_features=2048, out_features=1000, bias=True),
#             nn.ELU(True),
#             nn.Dropout(),
#             nn.Linear(in_features=1000, out_features=num_classes, bias=True)
#         )
#     def forward(self, x):
#         return self.base_model(x)

In [12]:
# class Resnext101Model(nn.Module):
#     def __init__(self, num_classes):
#         super().__init__()
#         self.base_model = models.resnext101_32x8d(weights=models.ResNeXt101_32X8D_Weights.DEFAULT)
#         self.base_model.fc = nn.Sequential(
#             nn.Linear(in_features=2048, out_features=1000, bias=True),
#             nn.ELU(True),
#             nn.Dropout(),
#             nn.Linear(in_features=1000, out_features=num_classes, bias=True)
#         )
#     def forward(self, x):
#         return self.base_model(x)

In [13]:
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x
    
class ResnextMultiheadModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_model = models.resnext50_32x4d(pretrained=True)
        self.base_model.fc = Identity()
#         self.base_model = nn.Sequential(*list(self.base_model.children())[:-2])
#         self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        self.fc_mask_classifier = nn.Sequential(
            nn.Linear(in_features=2048, out_features=1000, bias=True),
            nn.ELU(True),
            nn.Dropout(0.5, inplace=True),
            nn.Linear(in_features=1000, out_features=3, bias=True)
#             nn.Softmax(dim=1)
        )
        self.fc_age_classifier = nn.Sequential(
            nn.Linear(in_features=2048, out_features=1000, bias=True),
            nn.ELU(True),
            nn.Dropout(0.5, inplace=True),
            nn.Linear(in_features=1000, out_features=1000, bias=True),
            nn.ELU(True),
            nn.Dropout(0.5, inplace=True),
            nn.Linear(in_features=1000, out_features=3, bias=True)
#             nn.Softmax(dim=1)
        )
        self.fc_gender_classifier = nn.Sequential(
            nn.Linear(in_features=2048, out_features=1000, bias=True),
            nn.ELU(True),
            nn.Dropout(0.5, inplace=True),
            nn.Linear(in_features=1000, out_features=2, bias=True)
#             nn.Sigmoid()
        )
    def forward(self, x):
        x = self.base_model(x)
#         x = self.avgpool(x)
#         print(x.shape)
        mask = self.fc_mask_classifier(x)
        age = self.fc_age_classifier(x)
        gender = self.fc_gender_classifier(x)
        return mask, age, gender
#         return x

# train

In [14]:
# model
# num_classes = train_dataset.dataset.num_classes
# print(num_classes)
# model = BaseModel(num_classes=num_classes).to(device)
# model = Resnext101Model(num_classes=num_classes).to(device)
# a = torch.tensor(np.ones([64, 3, 224, 224]), dtype=torch.float32).to(device)
model = ResnextMultiheadModel().to(device)
model = torch.nn.DataParallel(model)



In [15]:
# https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/8
class FocalLoss(nn.Module):
    def __init__(self, weight=None,
                 gamma=2., reduction='mean'):
        nn.Module.__init__(self)
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input_tensor, target_tensor):
        log_prob = F.log_softmax(input_tensor, dim=-1)
        prob = torch.exp(log_prob)
        return F.nll_loss(
            ((1 - prob) ** self.gamma) * log_prob,
            target_tensor,
            weight=self.weight,
            reduction=self.reduction
        )

### loss , metric , optimizer , scheduler

In [16]:
criterion = FocalLoss()
optimizer = Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=config.lr,
    weight_decay=5e-4
)
# scheduler = StepLR(optimizer, config.lr_decay_step, gamma=0.5)
# -- scheduler: ReduceLROnPlateau
# 성능이 향상되지 않을 때 learning rate를 줄입니다. patience=10은 10회 동안 성능 향상이 없을 경우입니다.
# scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=10)
# -- scheduler: CosineAnnealingLR
# CosineAnnealing은 learning rate를 cosine 그래프처럼 변화시킵니다.
scheduler = CosineAnnealingLR(optimizer, T_max=2, eta_min=0.)
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [17]:
# 세이브 경로
path = Path(config.save_dir)
if (not path.exists()):
    save_dir = str(path)
else:
    dirs = glob.glob(f"{path}*")
    matches = [re.search(rf"%s(\d+)" % path.stem, d) for d in dirs]
    i = [int(m.groups()[0]) for m in matches if m]
    n = max(i) + 1 if i else 2
    save_dir = f"{path}{n}"

print("save path : " + save_dir)

save path : resnextmultihead_exp4


### start training

In [18]:
logger = SummaryWriter(log_dir=save_dir)
with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
    json.dump(vars(config), f, ensure_ascii=False, indent=4)

In [19]:
best_val_acc = 0
best_val_loss = np.inf
best_val_f1 = 0
for epoch in range(config.epochs):
    # train loop
    model.train()
    loss_value = 0
    matches = 0
#     f1_sum = 0
    label_list=[]
    pred_list=[]
    mask_label_list=[]
    mask_pred_list=[]
    age_label_list=[]
    age_pred_list=[]
    gender_label_list=[]
    gender_pred_list=[]
    for idx, train_batch in enumerate(train_loader):
        inputs, (mask_labels, gender_labels, age_labels, multi_class_labels) = train_batch
        inputs = inputs.to(device)
        mask_labels = mask_labels.to(device)
        gender_labels = gender_labels.to(device)
        age_labels = age_labels.to(device)
        multi_class_labels = multi_class_labels.to(device)

        optimizer.zero_grad()

        mask_outs, age_outs, gender_outs = model(inputs)
        mask_preds = torch.argmax(mask_outs, dim=-1)
        age_preds = torch.argmax(age_outs, dim=-1)
        gender_preds = torch.argmax(gender_outs, dim=-1)
        
        multi_preds = mask_preds * 6 + gender_preds * 3 + age_preds
        
        mask_loss = criterion(mask_outs, mask_labels)
        age_loss = criterion(age_outs, age_labels)
        gender_loss = criterion(gender_outs, gender_labels)
        
        loss = mask_loss + gender_loss + 2*age_loss
        
        loss.backward()
#         mask_loss.backward()
#         age_loss.backward()
#         gender_loss.backward()
        optimizer.step()
        
        label_list+=multi_class_labels.detach().cpu()
        pred_list+=multi_preds.detach().cpu()
        
        mask_label_list+=mask_labels.detach().cpu()
        mask_pred_list+=mask_preds.detach().cpu()
        age_label_list+=age_labels.detach().cpu()
        age_pred_list+=age_preds.detach().cpu()
        gender_label_list+=gender_labels.detach().cpu()
        gender_pred_list+=gender_preds.detach().cpu()

        loss_value += loss.item()
        matches += (multi_preds == multi_class_labels).sum().item()
#         f1_sum += f1_score(labels.cpu(), preds.cpu(), average='macro')
        if (idx + 1) % config.log_interval == 0:
            train_loss = loss_value / config.log_interval
            train_acc = matches / config.batch_size / config.log_interval
#             train_f1 = f1_sum / config.log_interval
            train_f1 = f1_score(label_list, pred_list, average='macro')
            mask_f1 = f1_score(mask_label_list, mask_pred_list, average='macro')
            age_f1 = f1_score(age_label_list, age_pred_list, average='macro')
            gender_f1 = f1_score(gender_label_list, gender_pred_list, average='macro')
            current_lr = get_lr(optimizer)
            print(
                f"Epoch[{epoch}/{config.epochs}]({idx + 1}/{len(train_loader)}) || "
                f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || training f1_score {train_f1:4.4} || lr {current_lr}"
            )
            print(f"          training each f1 || mask f1 {mask_f1:4.4} || age f1 {age_f1:4.4} || gender f1 {gender_f1:4.4}")
            logger.add_scalar("Train/loss", train_loss, epoch * len(train_loader) + idx)
            logger.add_scalar("Train/accuracy", train_acc, epoch * len(train_loader) + idx)
            logger.add_scalar("Train/f1_score", train_f1, epoch * len(train_loader) + idx)

            loss_value = 0
            matches = 0
#             f1_sum = 0
            label_list=[]
            pred_list=[]

    scheduler.step()

    # val loop
    with torch.no_grad():
        print("Calculating validation results...")
        model.eval()
        val_loss_items = []
        val_acc_items = []
#         val_f1_items = []
        label_list=[]
        pred_list=[]
        mask_label_list=[]
        mask_pred_list=[]
        age_label_list=[]
        age_pred_list=[]
        gender_label_list=[]
        gender_pred_list=[]
        for val_batch in val_loader:
            inputs, (mask_labels, gender_labels, age_labels, multi_class_labels) = val_batch
            inputs = inputs.to(device)
            mask_labels = mask_labels.to(device)
            gender_labels = gender_labels.to(device)
            age_labels = age_labels.to(device)
            multi_class_labels = multi_class_labels.to(device)

            mask_outs, age_outs, gender_outs = model(inputs)
            mask_preds = torch.argmax(mask_outs, dim=-1)
            age_preds = torch.argmax(age_outs, dim=-1)
            gender_preds = torch.argmax(gender_outs, dim=-1)
            multi_preds = mask_preds * 6 + gender_preds * 3 + age_preds
            
            label_list+=multi_class_labels.detach().cpu()
            pred_list+=multi_preds.detach().cpu()
            mask_label_list+=mask_labels.detach().cpu()
            mask_pred_list+=mask_preds.detach().cpu()
            age_label_list+=age_labels.detach().cpu()
            age_pred_list+=age_preds.detach().cpu()
            gender_label_list+=gender_labels.detach().cpu()
            gender_pred_list+=gender_preds.detach().cpu()

#             loss_item = criterion(outs, labels).item() 
            mask_loss = criterion(mask_outs, mask_labels)
            age_loss = criterion(age_outs, age_labels)
            gender_loss = criterion(gender_outs, gender_labels)
            loss = mask_loss + age_loss + gender_loss
            
            acc_item = (multi_class_labels == multi_preds).sum().item()
#             f1_item = f1_score(labels.cpu(), preds.cpu(), average='macro').item()
            val_loss_items.append(loss.item())
            val_acc_items.append(acc_item)
#             val_f1_items.append(f1_item)

        val_loss = float(np.sum(val_loss_items) / len(val_loader))
        val_acc = float(np.sum(val_acc_items) / len(val_dataset))
        val_f1= float(f1_score(label_list,pred_list,average="macro"))
        val_mask_f1 = f1_score(mask_label_list, mask_pred_list, average='macro')
        val_age_f1 = f1_score(age_label_list, age_pred_list, average='macro')
        val_gender_f1 = f1_score(gender_label_list, gender_pred_list, average='macro')
#         val_f1 = np.sum(val_f1_items) / len(val_dataset)
        
#         best_val_loss = min(best_val_loss, val_loss)
#         if val_acc > best_val_acc:
        if val_f1 > best_val_f1:
            print(f"New best model for val f1 : {val_f1:4.4}! saving the best model..")
            torch.save(model.module.state_dict(), f"{save_dir}/best.pth")
            best_val_f1 = val_f1
            best_val_acc = val_acc
            best_val_loss = val_loss
        torch.save(model.module.state_dict(), f"{save_dir}/last.pth")
        print(
            f"[Val] acc : {val_acc:4.2%}, loss: {val_loss:4.4}, f1: {val_f1:4.4} || "
            f"best acc : {best_val_acc:4.2%}, best loss: {best_val_loss:4.4}, best f1: {best_val_f1:4.4}"
        )
        print(f"[Val] each f1 || mask f1 {val_mask_f1:4.4} || age f1 {val_age_f1:4.4} || gender f1 {val_gender_f1:4.4}")
        logger.add_scalar("Val/loss", val_loss, epoch)
        logger.add_scalar("Val/accuracy", val_acc, epoch)
        logger.add_scalar("Val/f1", val_f1, epoch)
        print()

Epoch[0/70](50/242) || training loss 0.4622 || training accuracy 72.28% || training f1_score 0.5562 || lr 0.0001
          training each f1 || mask f1 0.8939 || age f1 0.6773 || gender f1 0.8964
Epoch[0/70](100/242) || training loss 0.1736 || training accuracy 87.47% || training f1_score 0.7621 || lr 0.0001
          training each f1 || mask f1 0.9408 || age f1 0.7516 || gender f1 0.9309
Epoch[0/70](150/242) || training loss 0.1076 || training accuracy 91.81% || training f1_score 0.8482 || lr 0.0001
          training each f1 || mask f1 0.9573 || age f1 0.7987 || gender f1 0.947
Epoch[0/70](200/242) || training loss 0.09275 || training accuracy 93.62% || training f1_score 0.8787 || lr 0.0001
          training each f1 || mask f1 0.9655 || age f1 0.8286 || gender f1 0.9558
Calculating validation results...
New best model for val f1 : 0.783! saving the best model..
[Val] acc : 87.91%, loss: 0.1226, f1: 0.783 || best acc : 87.91%, best loss: 0.1226, best f1: 0.783
[Val] each f1 || mask f1