In [1]:
import argparse
import glob
import json
import multiprocessing
import os
import random
import re
from pathlib import Path
from enum import Enum
from typing import Tuple, List
from collections import defaultdict
from sklearn.metrics import f1_score

import cv2
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.optim import SGD, Adam, RMSprop, AdamW
from adamp import AdamP
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR, ExponentialLR
from torch.utils.data import Dataset, Subset, random_split, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, models
from albumentations import *
from albumentations.pytorch import ToTensorV2
from torchsampler import ImbalancedDatasetSampler

# config

In [2]:
# config
class Config():
    seed = 42
    
    # 데이터
#     data_dir = './face_input/train' #
    data_dir = './input/data/train'
#     resize = [256, 256]
    resize = [224, 224]
    val_ratio = 0.1
    
    # 학습 설정
    epochs = 50
    batch_size = 64
#     batch_size = 16
    valid_batch_size = 1000
    lr = 0.0001
    lr_decay_step = 5
    log_interval = 50
    
    # 세이브 경로
    save_dir = './newAgeModel_exp'
    multi_dir = './newModel_exp2'
    

config = Config()

In [3]:
# 시드 고정 함수
def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

In [4]:
# 시드 고정
seed_everything(config.seed)

In [5]:
 # -- settings
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Dataset

In [6]:
# base setting
IMG_EXTENSIONS = [
    ".jpg", ".JPG", ".jpeg", ".JPEG", ".png",
    ".PNG", ".ppm", ".PPM", ".bmp", ".BMP",
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

class MaskLabels(int, Enum):
    MASK = 0
    INCORRECT = 1
    NORMAL = 2

class GenderLabels(int, Enum):
    MALE = 0
    FEMALE = 1

    @classmethod
    def from_str(cls, value: str) -> int:
        value = value.lower()
        if value == "male":
            return cls.MALE
        elif value == "female":
            return cls.FEMALE
        else:
            raise ValueError(f"Gender value should be either 'male' or 'female', {value}")

class AgeLabels(int, Enum):
    YOUNG = 0
    MIDDLE = 1
    OLD = 2
    
    @classmethod
    def class2number(cls, value: int) -> int:
        if value <= 2:
            return cls.YOUNG
        elif value <= 8:
            return cls.MIDDLE
        else:
            return cls.OLD
        
    @classmethod
    def from_cls(cls, value: str) -> int:
        try:
            value = int(value)
        except Exception:
            raise ValueError(f"Age value should be numeric, {value}")

        if value < 20:
            return 0
        elif value < 25:
            return 1
        elif value < 30:
            return 2
        elif value < 35:
            return 3
        elif value < 40:
            return 4
        elif value < 45:
            return 5
        elif value < 50:
            return 6
        elif value < 55:
            return 7
        elif value < 60:
            return 8
        else:
            return 9

### define transform (Augmentation)

In [7]:
mean=(0.548, 0.504, 0.479)
std=(0.237, 0.247, 0.246)
train_transform = Compose([
    RandomCrop(height=460, width=320),
    HorizontalFlip(p=0.5),
#     ShiftScaleRotate(p=0.5),
#     HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
    CLAHE(p=0.1),
#     RandomScale(scale_limit=0.1, interpolation=cv2.INTER_LINEAR, p=0.5),
    RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
#     GaussNoise(p=0.5),
    Resize(config.resize[0], config.resize[1], p=1.0, interpolation=cv2.INTER_LINEAR),
    Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0),
], p=1.0)
val_transform = Compose([
    CenterCrop(height=460, width=320),
    Resize(config.resize[0], config.resize[1], p=1.0),
    Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0),
], p=1.0)

### define datasets

In [8]:
class newBaseDataset(Dataset):

    def __init__(self, img_dir, indices, transform=None):
        """
        MaskBaseDataset을 initialize 합니다.

        Args:
            img_dir: 학습 이미지 폴더의 root directory 입니다.
            transform: Augmentation을 하는 함수입니다.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.indices = indices
        self._file_names = {
            "mask1": MaskLabels.MASK,
            "mask2": MaskLabels.MASK,
            "mask3": MaskLabels.MASK,
            "mask4": MaskLabels.MASK,
            "mask5": MaskLabels.MASK,
            "incorrect_mask": MaskLabels.INCORRECT,
            "normal": MaskLabels.NORMAL
        }

        self.image_paths = []
        self.mask_labels = []
        self.gender_labels = []
        self.age_labels = []
        self.setup()

    def set_transform(self, transform):
        """
        transform 함수를 설정하는 함수입니다.
        """
        self.transform = transform

    def setup(self):
        """
        image의 경로와 각 이미지들의 label을 계산하여 저장해두는 함수입니다.
        """
        csv_path = os.path.join(self.img_dir, 'train.csv')
        csv_data = pd.read_csv(csv_path)
        
        csv_data = csv_data.loc[self.indices]
        csv_data_list = list(csv_data.values)
        cnt = 0
        for id, gender, _, age, path in csv_data_list:
            img_folder = os.path.join(self.img_dir, 'images', path)
            for file_name in os.listdir(img_folder):
                _file_name, ext = os.path.splitext(file_name)
                if _file_name not in self._file_names:  # "." 로 시작하는 파일 및 invalid 한 파일들은 무시합니다
                    continue
                    
                img_path = os.path.join(img_folder, file_name)  # (data_path, 000004_male_Asian_54, mask1.jpg)
                mask_label = self._file_names[_file_name]
                gender_label = GenderLabels.from_str(gender)
                age_label = AgeLabels.from_cls(age)
                self.image_paths.append(img_path)
                self.mask_labels.append(mask_label)
                self.gender_labels.append(gender_label)
                self.age_labels.append(age_label)
                
    def __getitem__(self, index):
        """
        데이터를 불러오는 함수입니다. 
        데이터셋 class에 데이터 정보가 저장되어 있고, index를 통해 해당 위치에 있는 데이터 정보를 불러옵니다.
        
        Args:
            index: 불러올 데이터의 인덱스값입니다.
        """
        # 이미지를 불러옵니다.
        image_path = self.image_paths[index]
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#         image = Image.open(image_path)
        
        # 레이블을 불러옵니다.
        mask_label = self.mask_labels[index]
        gender_label = self.gender_labels[index]
        age_label = self.age_labels[index]
        multi_class_label = mask_label * 6 + gender_label * 3 + AgeLabels.class2number(age_label)
        
        # 이미지를 Augmentation 시킵니다.
        image_transform = self.transform(image=np.array(image))['image']
        return image_transform, (mask_label, gender_label, age_label, multi_class_label) # multi_class_label

    def __len__(self):
        return len(self.image_paths)

In [9]:
# class Sampler(ImbalancedDatasetSampler):
#     def _get_labels(self, dataset):
#         return [dataset.dataset.class_list[i] for i in dataset.indices]

class Sampler(ImbalancedDatasetSampler):
    def _get_labels(self, dataset):
        return dataset.age_labels

In [10]:
csv_path = os.path.join(config.data_dir,'train.csv')
data = pd.read_csv(csv_path)
age_0 = data['age'][lambda x : (x < 20)].count()
age_1 = data['age'][lambda x : (x >= 20) & (x<25)].count()
age_2 = data['age'][lambda x : (x >= 25) & (x<30)].count()
age_3 = data['age'][lambda x : (x >= 30) & (x<35)].count()
age_4 = data['age'][lambda x : (x >= 35) & (x<40)].count()
age_5 = data['age'][lambda x : (x >= 40) & (x<45)].count()
age_6 = data['age'][lambda x : (x >= 45) & (x<50)].count()
age_7 = data['age'][lambda x : (x >= 50) & (x<55)].count()
age_8 = data['age'][lambda x : (x >= 55) & (x<60)].count()
age_9 = data['age'][lambda x : x >= 60].count()
age_img_num_per_cls =  torch.Tensor([age_0, age_1, age_2, age_3, age_4, age_5, age_6, age_7, age_8, age_9])

length = len(data)
n_val = int(length * config.val_ratio)
val_indices = set(random.choices(range(length), k=n_val))
train_indices = set(range(length)) - val_indices

In [11]:
train_dataset = newBaseDataset(img_dir = config.data_dir, indices = train_indices)
val_dataset = newBaseDataset(img_dir = config.data_dir, indices = val_indices)
train_dataset.set_transform(train_transform)
val_dataset.set_transform(val_transform)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    num_workers=multiprocessing.cpu_count() // 2,
    sampler=Sampler(train_dataset),
#     shuffle=True,
    drop_last=True,
    pin_memory=use_cuda
    
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config.valid_batch_size,
    num_workers=multiprocessing.cpu_count() // 2,
    shuffle=False,
    pin_memory=use_cuda
)


# define model

In [12]:
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x

In [13]:
# class SwinMultiheadModel(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.base_model = models.swin_b(weights='DEFAULT')
#         self.base_model.head = Identity()

#         self.fc_mask_head = nn.Sequential(
#             nn.Linear(in_features=1024, out_features=1000, bias=True),
#             nn.BatchNorm1d(1000),
#             nn.ELU(True),
#             nn.Dropout(0.5, inplace=True),
#             nn.Linear(in_features=1000, out_features=3, bias=True)
#         )
#         self.fc_gender_head = nn.Sequential(
#             nn.Linear(in_features=1024, out_features=1000, bias=True),
#             nn.BatchNorm1d(1000),
#             nn.ELU(True),
#             nn.Dropout(0.5, inplace=True),
#             nn.Linear(in_features=1000, out_features=2, bias=True)
#         )
#     def forward(self, x):
#         x = self.base_model(x)
#         mask = self.fc_mask_head(x)
#         gender = self.fc_gender_head(x)
#         return mask, gender

In [14]:
# class SwinAgeModel(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.base_model = models.swin_b(weights='DEFAULT')
#         self.base_model.head = nn.Sequential(
#             nn.Linear(in_features=1024, out_features=1000, bias=True),
#             nn.BatchNorm1d(1000),
#             nn.ELU(True),
#             nn.Dropout(0.5, inplace=True),
#             nn.Linear(in_features=1000, out_features=1000, bias=True),
#             nn.BatchNorm1d(1000),
#             nn.ELU(True),
#             nn.Dropout(0.5, inplace=True),
#             nn.Linear(in_features=1000, out_features=10, bias=True)
#         )
#     def forward(self, x):
#         x = self.base_model(x)
#         return x

In [15]:
class ResMultiheadModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_model = models.resnext101_32x8d(weights=models.ResNeXt101_32X8D_Weights.DEFAULT)
        self.base_model.fc = Identity()

        self.fc_mask = nn.Sequential(
            nn.Linear(in_features=2048, out_features=1000, bias=True),
            nn.BatchNorm1d(1000),
            nn.ELU(True),
            nn.Dropout(0.5, inplace=True),
            nn.Linear(in_features=1000, out_features=3, bias=True)
        )
        self.fc_gender = nn.Sequential(
            nn.Linear(in_features=2048, out_features=1000, bias=True),
            nn.BatchNorm1d(1000),
            nn.ELU(True),
            nn.Dropout(0.5, inplace=True),
            nn.Linear(in_features=1000, out_features=2, bias=True)
        )
    def forward(self, x):
        x = self.base_model(x)
        mask = self.fc_mask(x)
        gender = self.fc_gender(x)
        return mask, gender


In [16]:
class ResAgeModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_model = models.resnext101_32x8d(weights=models.ResNeXt101_32X8D_Weights.DEFAULT)
        self.base_model.fc = nn.Sequential(
            nn.Linear(in_features=2048, out_features=1000, bias=True),
            nn.BatchNorm1d(1000),
            nn.ELU(True),
            nn.Dropout(0.5, inplace=True),
            nn.Linear(in_features=1000, out_features=1000, bias=True),
            nn.BatchNorm1d(1000),
            nn.ELU(True),
            nn.Dropout(0.5, inplace=True),
            nn.Linear(in_features=1000, out_features=10, bias=True)
        )
    def forward(self, x):
        x = self.base_model(x)
        return x

In [17]:
model_multi = ResMultiheadModel()
model_age = ResAgeModel()

model_multi_path = os.path.join(config.multi_dir, f'best_total_multi.pth') #best_age.pth
model_multi.load_state_dict(torch.load(model_multi_path, map_location=device))

model_multi = model_multi.to(device)
model_age = model_age.to(device)

# train (Age only)

In [24]:
# https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/8
class FocalLoss(nn.Module):
    def __init__(self, weight=None,
                 gamma=2., reduction='mean'):
        nn.Module.__init__(self)
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input_tensor, target_tensor):
        log_prob = F.log_softmax(input_tensor, dim=-1)
        prob = torch.exp(log_prob)
        return F.nll_loss(
            ((1 - prob) ** self.gamma) * log_prob,
            target_tensor,
            weight=self.weight,
            reduction=self.reduction
        )

class F1Loss(nn.Module):
    def __init__(self, classes=3, epsilon=1e-7):
        super().__init__()
        self.classes = classes
        self.epsilon = epsilon

    def forward(self, y_pred, y_true):
        assert y_pred.ndim == 2
        assert y_true.ndim == 1
        y_true = F.one_hot(y_true, self.classes).to(torch.float32)
        y_pred = F.softmax(y_pred, dim=1)

        tp = (y_true * y_pred).sum(dim=0).to(torch.float32)
        tn = ((1 - y_true) * (1 - y_pred)).sum(dim=0).to(torch.float32)
        fp = ((1 - y_true) * y_pred).sum(dim=0).to(torch.float32)
        fn = (y_true * (1 - y_pred)).sum(dim=0).to(torch.float32)

        precision = tp / (tp + fp + self.epsilon)
        recall = tp / (tp + fn + self.epsilon)

        f1 = 2 * (precision * recall) / (precision + recall + self.epsilon)
        f1 = f1.clamp(min=self.epsilon, max=1 - self.epsilon)
        return 1 - f1.mean()
    
class LADELoss(nn.Module):
    def __init__(self, num_classes=10, img_num_per_cls=None, remine_lambda=0.1):
        super().__init__()
        if img_num_per_cls is not None:
#             self.img_num_per_cls = calculate_prior(num_classes, img_max, prior, prior_txt, return_num=True).float().cuda()
            self.img_num_per_cls = img_num_per_cls
            self.prior = self.img_num_per_cls / self.img_num_per_cls.sum()
        else:
            self.prior = None
        self.prior = self.prior.to(device)
        self.balanced_prior = torch.tensor(1. / num_classes).float().cuda()
        self.remine_lambda = remine_lambda

        self.num_classes = num_classes
        self.cls_weight = (self.img_num_per_cls.float() / torch.sum(self.img_num_per_cls.float())).cuda()

    def mine_lower_bound(self, x_p, x_q, num_samples_per_cls):
        N = x_p.size(-1)
        first_term = torch.sum(x_p, -1) / (num_samples_per_cls + 1e-8)
        second_term = torch.logsumexp(x_q, -1) - np.log(N)

        return first_term - second_term, first_term, second_term

    def remine_lower_bound(self, x_p, x_q, num_samples_per_cls):
        loss, first_term, second_term = self.mine_lower_bound(x_p, x_q, num_samples_per_cls)
        reg = (second_term ** 2) * self.remine_lambda
        return loss - reg, first_term, second_term

    def forward(self, y_pred, target, q_pred=None):
        """
        y_pred: N x C
        target: N
        """
        per_cls_pred_spread = y_pred.T * (target == torch.arange(0, self.num_classes).view(-1, 1).type_as(target))  # C x N
        pred_spread = (y_pred - torch.log(self.prior + 1e-9) + torch.log(self.balanced_prior + 1e-9)).T  # C x N

        num_samples_per_cls = torch.sum(target == torch.arange(0, self.num_classes).view(-1, 1).type_as(target), -1).float()  # C
        estim_loss, first_term, second_term = self.remine_lower_bound(per_cls_pred_spread, pred_spread, num_samples_per_cls)

        loss = -torch.sum(estim_loss * self.cls_weight)
        return loss

### loss , metric , optimizer , scheduler

In [25]:
# criterion = FocalLoss()

# criterion_mask = FocalLoss() # F1Loss(classes=3)
# criterion_age = LADELoss(num_classes=10, img_num_per_cls=age_img_num_per_cls) #F1Loss(classes=3)
# criterion_gender = FocalLoss() #F1Loss(classes=2)

# criterion_age = F1Loss(classes=10)
criterion_age = FocalLoss()
# optimizer_multi = AdamP(
#     filter(lambda p: p.requires_grad, model_multi.parameters()),
#     lr=config.lr,
#     weight_decay=5e-4
# )
optimizer_age = Adam(
    filter(lambda p: p.requires_grad, model_age.parameters()),
    lr=config.lr,
    weight_decay=5e-4
)
scheduler_age = StepLR(optimizer_age, config.lr_decay_step, gamma=0.5)
# scheduler_multi = CosineAnnealingLR(optimizer_multi, T_max=2, eta_min=0.)
# scheduler_age = CosineAnnealingLR(optimizer_age, T_max=2, eta_min=0.)
# -- scheduler: ReduceLROnPlateau
# 성능이 향상되지 않을 때 learning rate를 줄입니다. patience=10은 10회 동안 성능 향상이 없을 경우입니다.
# scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=10)
# -- scheduler: CosineAnnealingLR
# CosineAnnealing은 learning rate를 cosine 그래프처럼 변화시킵니다.
# scheduler = CosineAnnealingLR(optimizer, T_max=2, eta_min=0.)
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [26]:
# 세이브 경로
path = Path(config.save_dir)
if (not path.exists()):
    save_dir = str(path)
else:
    dirs = glob.glob(f"{path}*")
    matches = [re.search(rf"%s(\d+)" % path.stem, d) for d in dirs]
    i = [int(m.groups()[0]) for m in matches if m]
    n = max(i) + 1 if i else 2
    save_dir = f"{path}{n}"

print("save path : " + save_dir)

save path : newAgeModel_exp2


### start training

In [27]:
logger = SummaryWriter(log_dir=save_dir)
with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
    json.dump(vars(config), f, ensure_ascii=False, indent=4)

In [28]:
best_val_acc = 0
# best_val_multi_loss = np.inf
# best_val_multi_f1 = 0
best_val_age_loss = np.inf
best_val_age_f1 = 0
best_val_total_f1 = 0
for epoch in range(config.epochs):
    # train loop
    model_multi.eval()
    model_age.train()
#     multi_loss_value = 0
    age_loss_value = 0
    matches = 0
    label_list=[]
    pred_list=[]
    mask_label_list=[]
    mask_pred_list=[]
    age_label_list=[]
    age_pred_list=[]
    gender_label_list=[]
    gender_pred_list=[]
    for idx, train_batch in enumerate(train_loader):
        inputs, (mask_labels, gender_labels, age_labels, multi_class_labels) = train_batch
        inputs = inputs.to(device)
        
        mask_labels = mask_labels.to(device)
        gender_labels = gender_labels.to(device)
        age_labels = age_labels.to(device)
        multi_class_labels = multi_class_labels.to(device)

#         optimizer_multi.zero_grad()
        optimizer_age.zero_grad()
    
        with torch.no_grad():
            mask_outs, gender_outs = model_multi(inputs)
            mask_preds = torch.argmax(mask_outs, dim=-1)
            mask_preds = mask_preds.detach().cpu()
            gender_preds = torch.argmax(gender_outs, dim=-1)
            gender_preds = gender_preds.detach().cpu()
        
        age_outs = model_age(inputs)
        age_preds = torch.argmax(age_outs, dim=-1)
        age_pred = age_preds.detach().cpu()
        age_preds = age_preds.detach().cpu()
        age_pred_list+=age_pred
        age_preds = age_preds.apply_(AgeLabels.class2number)
        
        multi_preds = mask_preds * 6 + gender_preds * 3 + age_preds
        multi_preds = multi_preds.to(device)
        
#         mask_loss = criterion_mask(mask_outs, mask_labels)
#         gender_loss = criterion_gender(gender_outs, gender_labels)
        
#         multi_loss = mask_loss + gender_loss 
        age_loss = criterion_age(age_outs, age_labels)
        
#         multi_loss.backward()
        age_loss.backward()

#         optimizer_multi.step()
        optimizer_age.step()
        
        label_list+=multi_class_labels.detach().cpu()
        pred_list+=multi_preds.detach().cpu()
        mask_label_list+=mask_labels.detach().cpu()
        mask_pred_list+=mask_preds
        age_label_list+=age_labels.detach().cpu()
        gender_label_list+=gender_labels.detach().cpu()
        gender_pred_list+=gender_preds

#         multi_loss_value += multi_loss.item()
        age_loss_value += age_loss.item()
        matches += (multi_preds == multi_class_labels).sum().item()
        if (idx + 1) % config.log_interval == 0:
#             multi_train_loss = multi_loss_value / config.log_interval
            age_train_loss = age_loss_value / config.log_interval
            
            total_train_acc = matches / config.batch_size / config.log_interval
            total_train_f1 = f1_score(label_list, pred_list, average='macro')
            
            mask_f1 = f1_score(mask_label_list, mask_pred_list, average='macro')
            age_f1 = f1_score(age_label_list, age_pred_list, average='macro')
            gender_f1 = f1_score(gender_label_list, gender_pred_list, average='macro')
            
#             multi_current_lr = get_lr(optimizer_multi)
            age_current_lr = get_lr(optimizer_age)
            
            print(f"Epoch[{epoch}/{config.epochs}]({idx + 1}/{len(train_loader)})")
#             print(f"     [Multi] training loss {multi_train_loss:4.4} || mask f1 {mask_f1:4.4} || gender f1 {gender_f1:4.4} || lr {multi_current_lr}")
            print(f"     [ Age ] training loss {age_train_loss:4.4} || age f1 {age_f1:4.4} || lr {age_current_lr}")
            print(f"     [Total] Acc {total_train_acc:4.2%} || f1-score {total_train_f1:4.4}")
#             logger.add_scalar("Train/multi_loss", multi_train_loss, epoch * len(train_loader) + idx)
            logger.add_scalar("Train/age_loss", age_train_loss, epoch * len(train_loader) + idx)
            logger.add_scalar("Train/total_accuracy", total_train_acc, epoch * len(train_loader) + idx)
            logger.add_scalar("Train/total_f1_score", total_train_f1, epoch * len(train_loader) + idx)

#             multi_loss_value = 0
            age_loss_value = 0
            matches = 0
            label_list=[]
            pred_list=[]
            mask_label_list=[]
            mask_pred_list=[]
            age_label_list=[]
            age_pred_list=[]
            gender_label_list=[]
            gender_pred_list=[]

#     scheduler_multi.step()
    scheduler_age.step()

    # val loop
    with torch.no_grad():
        print("\nCalculating validation results...")
        model_multi.eval()
        model_age.eval()
#         val_multi_loss_items = []
        val_age_loss_items = []
        val_acc_items = []
        label_list=[]
        pred_list=[]
        mask_label_list=[]
        mask_pred_list=[]
        age_label_list=[]
        age_pred_list=[]
        gender_label_list=[]
        gender_pred_list=[]
        for val_batch in val_loader:
            inputs, (mask_labels, gender_labels, age_labels, multi_class_labels) = val_batch
            inputs = inputs.to(device)
            mask_labels = mask_labels.to(device)
            gender_labels = gender_labels.to(device)
            age_labels = age_labels.to(device)
            multi_class_labels = multi_class_labels.to(device)

            mask_outs, gender_outs = model_multi(inputs)
            mask_preds = torch.argmax(mask_outs, dim=-1)
            mask_preds = mask_preds.detach().cpu()
            gender_preds = torch.argmax(gender_outs, dim=-1)
            gender_preds = gender_preds.detach().cpu()

            age_outs = model_age(inputs)
            age_preds = torch.argmax(age_outs, dim=-1)
            age_pred = age_preds.detach().cpu()
            age_preds = age_preds.detach().cpu()
            age_pred_list+=age_pred
            age_preds = age_preds.apply_(AgeLabels.class2number)

            multi_preds = mask_preds * 6 + gender_preds * 3 + age_preds
            multi_preds = multi_preds.to(device)

#             mask_loss = criterion_mask(mask_outs, mask_labels)
#             gender_loss = criterion_gender(gender_outs, gender_labels)

#             multi_loss = mask_loss + gender_loss               
            age_loss = criterion_age(age_outs, age_labels)      
            
            label_list+=multi_class_labels.detach().cpu()
            pred_list+=multi_preds.detach().cpu()
            mask_label_list+=mask_labels.detach().cpu()
            mask_pred_list+=mask_preds
            age_label_list+=age_labels.detach().cpu()
            gender_label_list+=gender_labels.detach().cpu()
            gender_pred_list+=gender_preds
            
            acc_item = (multi_class_labels == multi_preds).sum().item()
#             val_multi_loss_items.append(multi_loss.item())
            val_age_loss_items.append(age_loss.item())
            val_acc_items.append(acc_item)

#         val_multi_loss = float(np.sum(val_multi_loss_items) / len(val_loader))
        val_age_loss = float(np.sum(val_age_loss_items) / len(val_loader))
        
        val_total_acc = float(np.sum(val_acc_items) / len(val_dataset))
        val_total_f1= float(f1_score(label_list,pred_list,average="macro"))
        
        val_mask_f1 = f1_score(mask_label_list, mask_pred_list, average='macro')
        val_age_f1 = f1_score(age_label_list, age_pred_list, average='macro')
        val_gender_f1 = f1_score(gender_label_list, gender_pred_list, average='macro')
        val_multi_f1 = (val_mask_f1 + val_gender_f1) / 2  # mask, gender mean

#         if val_multi_f1 > best_val_multi_f1:
#             print(f"New best model for Multi || f1 : {val_multi_f1:4.4}! saving the best multi model..")
#             torch.save(model_multi.state_dict(), f"{save_dir}/best_multi.pth")
#             best_val_multi_f1 = val_multi_f1
#             best_val_multi_loss = val_multi_loss
        if val_age_f1 > best_val_age_f1:
            print(f"New best model for Age || f1 : {val_age_f1:4.4}! saving the best multi model..")
            torch.save(model_age.state_dict(), f"{save_dir}/best_age.pth")
            best_val_age_f1 = val_age_f1
            best_val_age_loss = val_age_loss
        if val_total_f1 > best_val_total_f1:
            print(f"New best model for Total || f1 : {val_total_f1:4.4}! saving the best multi model..")
#             torch.save(model_multi.state_dict(), f"{save_dir}/best_total_multi.pth")
            torch.save(model_age.state_dict(), f"{save_dir}/best_total_age.pth")
            best_val_total_f1 = val_total_f1
#         torch.save(model_multi.state_dict(), f"{save_dir}/last_multi.pth")
        torch.save(model_age.state_dict(), f"{save_dir}/last_age.pth")
        
#         print(f"[Val_Multi] loss {val_multi_loss:4.4} || mask f1 {val_mask_f1:4.4} || gender f1 {val_gender_f1:4.4}")
        print(f"[ Val_Age ] loss {val_age_loss:4.4} || age f1 {val_age_f1:4.4}")
        print(f"[Val_Total] Acc {val_total_acc:4.2%} || f1-score {val_total_f1:4.4}")
        print(f"[  Best   ] age_f1 {best_val_age_f1:4.4} || total_f1 {best_val_total_f1:4.4}")
#         logger.add_scalar("Val/multi_loss", val_multi_loss, epoch)
        logger.add_scalar("Val/age_loss", val_age_loss, epoch)
        logger.add_scalar("Val/total_accuracy", val_total_acc, epoch)
        logger.add_scalar("Val/total_f1_score", val_total_f1, epoch)
        print()

Epoch[0/50](50/267)
     [ Age ] training loss 0.1808 || age f1 0.839 || lr 0.0001
     [Total] Acc 96.62% || f1-score 0.9446
Epoch[0/50](100/267)
     [ Age ] training loss 0.1423 || age f1 0.8588 || lr 0.0001
     [Total] Acc 97.94% || f1-score 0.9633
Epoch[0/50](150/267)
     [ Age ] training loss 0.1115 || age f1 0.8866 || lr 0.0001
     [Total] Acc 98.09% || f1-score 0.9664
Epoch[0/50](200/267)
     [ Age ] training loss 0.09219 || age f1 0.909 || lr 0.0001
     [Total] Acc 98.34% || f1-score 0.9771
Epoch[0/50](250/267)
     [ Age ] training loss 0.08731 || age f1 0.9149 || lr 0.0001
     [Total] Acc 98.56% || f1-score 0.9793

Calculating validation results...
New best model for Age || f1 : 0.4067! saving the best multi model..
New best model for Total || f1 : 0.7796! saving the best multi model..
[ Val_Age ] loss 0.8614 || age f1 0.4067
[Val_Total] Acc 87.54% || f1-score 0.7796
[  Best   ] age_f1 0.4067 || total_f1 0.7796

Epoch[1/50](50/267)
     [ Age ] training loss 0.06536 ||

KeyboardInterrupt: 