In [1]:
import argparse
import glob
import json
import multiprocessing
import os
import random
import re
from pathlib import Path
from enum import Enum
from typing import Tuple, List
from collections import defaultdict
from sklearn.metrics import f1_score

import cv2
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.optim import SGD, Adam, RMSprop
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR, ExponentialLR
from torch.utils.data import Dataset, Subset, random_split, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, models
from albumentations import *
from albumentations.pytorch import ToTensorV2

from sklearn.model_selection import StratifiedKFold
from torchsampler import ImbalancedDatasetSampler

# config

In [2]:
# config
class Config:
    """Static hyper-parameters and paths shared by every cell of the notebook."""

    # Seed handed to seed_everything().
    seed = 42

    # --- data ---
    # Alternative dataset root that was used previously: './face_input/train'
    data_dir = './input/data/train'
    resize = [224, 224]
    val_ratio = 0.2

    # --- training ---
    epochs = 50
    batch_size = 64  # a smaller alternative that was tried: 16
    valid_batch_size = 1000
    lr = 1e-4
    lr_decay_step = 5
    log_interval = 50  # number of training batches between progress prints
    patience = 10  # early stopping
    n_splits = 5  # number of folds for k-fold cross-validation

    # --- output ---
    save_dir = './kfoldEnsemble_exp'
    

config = Config()

In [3]:
# 시드 고정 함수
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Args:
        seed: integer seed applied to every random number generator below.
    """
    # Pure-Python and NumPy generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators: CPU, current GPU, and all GPUs (multi-GPU runs).
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Ask cuDNN for deterministic kernels and disable its auto-tuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [4]:
# 시드 고정
seed_everything(config.seed)

In [5]:
 # -- settings
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# `use_cuda` is also reused as the DataLoader `pin_memory` flag in getDataloader().
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# Dataset

In [6]:
# base setting
# File suffixes treated as images. Matching is case-sensitive, which is why
# each extension is listed in both lower and upper case.
IMG_EXTENSIONS = [
    ".jpg", ".JPG", ".jpeg", ".JPEG", ".png",
    ".PNG", ".ppm", ".PPM", ".bmp", ".BMP",
]


def is_image_file(filename):
    """Return True when `filename` ends with one of the suffixes in IMG_EXTENSIONS."""
    # str.endswith accepts a tuple of candidate suffixes and checks them all.
    return filename.endswith(tuple(IMG_EXTENSIONS))

class MaskLabels(int, Enum):
    """Integer class ids for the mask-wearing state of a photo."""

    MASK = 0       # mask worn
    INCORRECT = 1  # mask worn incorrectly
    NORMAL = 2     # no mask

class GenderLabels(int, Enum):
    """Integer class ids for gender."""

    MALE = 0
    FEMALE = 1

    @classmethod
    def from_str(cls, value: str) -> int:
        """Map a gender string (case-insensitive) to its enum member.

        Raises:
            ValueError: if the lower-cased string is neither "male" nor "female".
        """
        value = value.lower()
        by_name = {"male": cls.MALE, "female": cls.FEMALE}
        if value not in by_name:
            raise ValueError(f"Gender value should be either 'male' or 'female', {value}")
        return by_name[value]

class AgeLabels(int, Enum):
    """Integer class ids for the three age buckets."""

    YOUNG = 0   # age < 30
    MIDDLE = 1  # 30 <= age < 60
    OLD = 2     # age >= 60

    @classmethod
    def from_number(cls, value: str) -> int:
        """Convert an age (anything `int()` accepts) to its bucket.

        Raises:
            ValueError: if `value` cannot be converted with `int()`.
        """
        try:
            age = int(value)
        except Exception:
            raise ValueError(f"Age value should be numeric, {value}")

        # Check the buckets from oldest to youngest.
        if age >= 60:
            return cls.OLD
        if age >= 30:
            return cls.MIDDLE
        return cls.YOUNG

### define transform (Augmentation)

In [7]:
# Per-channel mean / std passed to Normalize below.
# presumably computed on the training images in RGB order — TODO confirm
mean=(0.548, 0.504, 0.479)
std=(0.237, 0.247, 0.246)
# Training pipeline: centre crop, random photometric/geometric augmentation,
# then resize to config.resize, normalise, and convert to a tensor.
train_transform = Compose([
    CenterCrop(height=480, width=320),
    HorizontalFlip(p=0.5),
#     ShiftScaleRotate(p=0.5),
    # NOTE(review): the shift limits here are 0.2, while the albumentations
    # defaults are on the order of 20-30 for uint8 images — this augmentation
    # may be close to a no-op; confirm the intended magnitude.
    HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
    RandomScale(scale_limit=0.1, interpolation=cv2.INTER_LINEAR, p=0.5),
    RandomBrightnessContrast(brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=0.5),
    GaussNoise(p=0.5),
    # Resize comes after RandomScale so every sample ends at the same size.
    Resize(config.resize[0], config.resize[1], p=1.0, interpolation=cv2.INTER_LINEAR),
    Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0),
], p=1.0)
# Validation pipeline: the same crop / resize / normalise steps with no random augmentation.
val_transform = Compose([
    CenterCrop(height=480, width=320),
    Resize(config.resize[0], config.resize[1], p=1.0),
    Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
    ToTensorV2(p=1.0),
], p=1.0)

### define datasets

In [8]:
# base dataset class
class newBaseDataset(Dataset):
    """Mask / gender / age dataset built from a subset of the rows of ``train.csv``.

    Each selected CSV row points at one image folder; every recognised image
    file inside that folder (see ``_file_names``) becomes one sample.
    """

    def __init__(self, img_dir, indices, transform=None):
        """
        Initialise the dataset and index all image paths and labels.

        Args:
            img_dir: root directory; ``train.csv`` and the ``images`` folder are
                looked up directly underneath it.
            indices: row labels of ``train.csv`` to include (used with ``DataFrame.loc``).
            transform: augmentation callable invoked as ``transform(image=ndarray)['image']``.
                When None, ``__getitem__`` returns the RGB image unchanged.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.indices = indices
        # File stem (name without extension) -> mask label.
        # Files with any other stem are ignored.
        self._file_names = {
            "mask1": MaskLabels.MASK,
            "mask2": MaskLabels.MASK,
            "mask3": MaskLabels.MASK,
            "mask4": MaskLabels.MASK,
            "mask5": MaskLabels.MASK,
            "incorrect_mask": MaskLabels.INCORRECT,
            "normal": MaskLabels.NORMAL
        }

        # Parallel lists: position k of each list describes sample k.
        self.image_paths = []
        self.mask_labels = []
        self.gender_labels = []
        self.age_labels = []
        self.setup()

    def set_transform(self, transform):
        """
        Replace the augmentation callable applied in ``__getitem__``.
        """
        self.transform = transform

    def setup(self):
        """
        Collect the path and the labels of every image belonging to the selected CSV rows.
        """
        csv_path = os.path.join(self.img_dir, 'train.csv')
        csv_data = pd.read_csv(csv_path)
        csv_data = csv_data.loc[self.indices]

        # Each row is unpacked positionally; the first and third columns are unused here.
        for _, gender, _, age, path in csv_data.values:
            img_folder = os.path.join(self.img_dir, 'images', path)
            # os.listdir() returns entries in arbitrary, filesystem-dependent order;
            # sorting keeps the sample order (and therefore seeded shuffling) reproducible.
            for file_name in sorted(os.listdir(img_folder)):
                stem, _ = os.path.splitext(file_name)
                if stem not in self._file_names:  # skips dot-files and any unexpected file
                    continue

                img_path = os.path.join(img_folder, file_name)  # (data_path, 000004_male_Asian_54, mask1.jpg)
                self.image_paths.append(img_path)
                self.mask_labels.append(self._file_names[stem])
                self.gender_labels.append(GenderLabels.from_str(gender))
                self.age_labels.append(AgeLabels.from_number(age))

    def __getitem__(self, index):
        """
        Load one sample.

        Args:
            index: position of the sample in the internal lists.

        Returns:
            ``(image, (mask_label, gender_label, age_label, multi_class_label))`` where
            ``multi_class_label = mask_label * 6 + gender_label * 3 + age_label``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        # Load the image; OpenCV decodes to BGR, so convert to RGB.
        image_path = self.image_paths[index]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Failed to read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Look up the labels.
        mask_label = self.mask_labels[index]
        gender_label = self.gender_labels[index]
        age_label = self.age_labels[index]
        multi_class_label = mask_label * 6 + gender_label * 3 + age_label

        # Apply the augmentation pipeline, if one is configured.
        if self.transform is not None:
            image = self.transform(image=image)['image']
        return image, (mask_label, gender_label, age_label, multi_class_label)

    def __len__(self):
        """Number of indexed images."""
        return len(self.image_paths)
    
    
    # 이외 utils
#     @staticmethod
#     def split_data(val_ratio):
#         return val
#     @staticmethod
#     def encode_multi_class(mask_label, gender_label, age_label) -> int:
#         return mask_label * 6 + gender_label * 3 + age_label
    
#     @staticmethod
#     def decode_multi_class(multi_class_label) -> Tuple[MaskLabels, GenderLabels, AgeLabels]:
#         mask_label = (multi_class_label // 6) % 3
#         gender_label = (multi_class_label // 3) % 2
#         age_label = multi_class_label % 3
#         return mask_label, gender_label, age_label

In [9]:
# class Sampler(ImbalancedDatasetSampler):
#     def _get_labels(self, dataset):
#         return dataset.age_labels

In [10]:
def getDataloader(train_idx, valid_idx):
    """Build the training and validation DataLoaders for one fold.

    Args:
        train_idx: CSV row indices used for the training dataset.
        valid_idx: CSV row indices used for the validation dataset.

    Returns:
        (train_loader, val_loader)
    """
    # Options shared by both loaders: half of the CPU cores as workers.
    shared_options = dict(
        num_workers=multiprocessing.cpu_count() // 2,
        pin_memory=use_cuda,
    )

    train_dataset = newBaseDataset(img_dir=config.data_dir, indices=train_idx, transform=train_transform)
    val_dataset = newBaseDataset(img_dir=config.data_dir, indices=valid_idx, transform=val_transform)

    # Training batches are reshuffled every epoch; validation order is fixed.
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, **shared_options)
    val_loader = DataLoader(val_dataset, batch_size=config.valid_batch_size, shuffle=False, **shared_options)
    return train_loader, val_loader

# define model

In [11]:
# class BaseModel(nn.Module):
#     def __init__(self, num_classes):
#         super().__init__()
#         self.base_model = models.vit_b_16(pretrained=True)
#         self.base_model.heads.head = nn.Linear(in_features=768, out_features=num_classes, bias=True)

#     def forward(self, x):
#         return self.base_model(x)

In [12]:
# class ResnextModel(nn.Module):
#     def __init__(self, num_classes):
#         super().__init__()
#         self.base_model = models.resnext50_32x4d(pretrained=True)
#         self.base_model.fc = nn.Sequential(
#             nn.Linear(in_features=2048, out_features=1000, bias=True),
#             nn.ELU(True),
#             nn.Dropout(),
#             nn.Linear(in_features=1000, out_features=num_classes, bias=True)
#         )
#     def forward(self, x):
#         return self.base_model(x)

In [13]:
# class Resnext101Model(nn.Module):
#     def __init__(self, num_classes):
#         super().__init__()
#         self.base_model = models.resnext101_32x8d(weights=models.ResNeXt101_32X8D_Weights.DEFAULT)
#         self.base_model.fc = nn.Sequential(
#             nn.Linear(in_features=2048, out_features=1000, bias=True),
#             nn.ELU(True),
#             nn.Dropout(),
#             nn.Linear(in_features=1000, out_features=num_classes, bias=True)
#         )
#     def forward(self, x):
#         return self.base_model(x)

In [14]:
class Identity(nn.Module):
    """Pass-through module: `forward` returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return `x` as-is."""
        return x
    
class ResnextMultiheadModel(nn.Module):
    """ResNeXt-50 (32x4d) backbone followed by three independent classification heads.

    The heads predict mask state (3 classes), age bucket (3 classes) and
    gender (2 classes) from the same backbone features.
    """

    def __init__(self):
        super().__init__()
        self.base_model = models.resnext50_32x4d(pretrained=True)
        # Replace the backbone's own classifier so it outputs the features the heads consume.
        self.base_model.fc = Identity()

        # Heads are created in the order mask -> age -> gender so that weight
        # initialisation consumes the RNG in a fixed order.
        self.fc_mask_classifier = self._make_head(hidden_sizes=[1000], num_classes=3)
        self.fc_age_classifier = self._make_head(hidden_sizes=[1000, 1000], num_classes=3)
        self.fc_gender_classifier = self._make_head(hidden_sizes=[1000], num_classes=2)

    @staticmethod
    def _make_head(hidden_sizes, num_classes, in_features=2048):
        """Build an MLP head: (Linear -> ELU -> Dropout) per hidden size, then a final Linear."""
        layers = []
        width = in_features
        for hidden in hidden_sizes:
            layers += [
                nn.Linear(in_features=width, out_features=hidden, bias=True),
                # The first positional argument of nn.ELU is `alpha`, not `inplace`;
                # the previous `nn.ELU(True)` therefore meant alpha=True (== 1.0).
                # Spelled out explicitly here. It must stay out-of-place because the
                # following Dropout already works in place on its output.
                nn.ELU(alpha=1.0),
                nn.Dropout(0.5, inplace=True),
            ]
            width = hidden
        layers.append(nn.Linear(in_features=width, out_features=num_classes, bias=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits as a tuple ``(mask, age, gender)`` — note the order."""
        x = self.base_model(x)
        mask = self.fc_mask_classifier(x)
        age = self.fc_age_classifier(x)
        gender = self.fc_gender_classifier(x)
        return mask, age, gender

# train

In [15]:
# model = ResnextMultiheadModel().to(device)
# model = torch.nn.DataParallel(model)

In [16]:
# https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/8
class FocalLoss(nn.Module):
    """Focal loss: negative log-likelihood with each log-probability scaled by (1 - p) ** gamma.

    Args:
        weight: optional per-class weight forwarded to ``F.nll_loss``.
        gamma: exponent of the modulating factor.
        reduction: reduction mode forwarded to ``F.nll_loss``.
    """

    def __init__(self, weight=None,
                 gamma=2., reduction='mean'):
        super().__init__()
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input_tensor, target_tensor):
        """Compute the loss from raw logits (`input_tensor`) and class indices (`target_tensor`)."""
        log_p = F.log_softmax(input_tensor, dim=-1)
        p = log_p.exp()
        # Down-weight classes that are already predicted with high probability.
        modulated_log_p = torch.pow(1 - p, self.gamma) * log_p
        return F.nll_loss(
            modulated_log_p,
            target_tensor,
            weight=self.weight,
            reduction=self.reduction,
        )

class F1Loss(nn.Module):
    """Soft macro-F1 loss: 1 minus the mean per-class F1 computed from softmax probabilities.

    Args:
        classes: number of classes used for the one-hot encoding of the targets.
        epsilon: small constant added to denominators and used as the clamp margin.
    """

    def __init__(self, classes=3, epsilon=1e-7):
        super().__init__()
        self.classes = classes
        self.epsilon = epsilon

    def forward(self, y_pred, y_true):
        """`y_pred` is a 2-D tensor of logits; `y_true` is a 1-D tensor of class indices."""
        assert y_pred.ndim == 2
        assert y_true.ndim == 1
        eps = self.epsilon
        onehot = F.one_hot(y_true, self.classes).to(torch.float32)
        probs = F.softmax(y_pred, dim=1)

        # Soft confusion counts per class, summed over the batch dimension.
        true_pos = (onehot * probs).sum(dim=0).to(torch.float32)
        false_pos = ((1 - onehot) * probs).sum(dim=0).to(torch.float32)
        false_neg = (onehot * (1 - probs)).sum(dim=0).to(torch.float32)

        precision = true_pos / (true_pos + false_pos + eps)
        recall = true_pos / (true_pos + false_neg + eps)

        f1 = 2 * (precision * recall) / (precision + recall + eps)
        f1 = torch.clamp(f1, min=eps, max=1 - eps)
        return 1 - f1.mean()
    
class LADELoss(nn.Module):
    """LADE regulariser built on a regularised MINE-style lower bound (see the methods below).

    Args:
        num_classes: number of classes C.
        img_num_per_cls: per-class sample counts (length C) used to derive the
            class prior and the per-class loss weights. When None, a uniform
            prior is assumed (previously this case crashed in ``__init__``).
        remine_lambda: weight of the squared-second-term regulariser.
    """

    def __init__(self, num_classes=10, img_num_per_cls=None, remine_lambda=0.1):
        super().__init__()
        # Choose the device locally instead of hard-coding `.cuda()` / relying on a
        # module-level global, so the loss can also be built on CPU-only machines.
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if img_num_per_cls is None:
            img_num_per_cls = torch.ones(num_classes)
        self.img_num_per_cls = torch.as_tensor(img_num_per_cls, dtype=torch.float32)

        self.prior = (self.img_num_per_cls / self.img_num_per_cls.sum()).to(target_device)
        self.balanced_prior = torch.tensor(1. / num_classes).float().to(target_device)
        self.remine_lambda = remine_lambda

        self.num_classes = num_classes
        self.cls_weight = (self.img_num_per_cls / torch.sum(self.img_num_per_cls)).to(target_device)

    def mine_lower_bound(self, x_p, x_q, num_samples_per_cls):
        """Return ``(first_term - second_term, first_term, second_term)``, reducing over the last axis."""
        N = x_p.size(-1)
        # 1e-8 avoids division by zero for classes absent from the batch.
        first_term = torch.sum(x_p, -1) / (num_samples_per_cls + 1e-8)
        second_term = torch.logsumexp(x_q, -1) - np.log(N)

        return first_term - second_term, first_term, second_term

    def remine_lower_bound(self, x_p, x_q, num_samples_per_cls):
        """Same as `mine_lower_bound`, minus ``remine_lambda * second_term ** 2``."""
        loss, first_term, second_term = self.mine_lower_bound(x_p, x_q, num_samples_per_cls)
        reg = (second_term ** 2) * self.remine_lambda
        return loss - reg, first_term, second_term

    def forward(self, y_pred, target, q_pred=None):
        """
        y_pred: N x C
        target: N
        q_pred: unused; kept for interface compatibility.
        """
        # Follow the device of the logits (a no-op when they already match).
        prior = self.prior.to(y_pred.device)
        balanced_prior = self.balanced_prior.to(y_pred.device)
        cls_weight = self.cls_weight.to(y_pred.device)

        # C x N boolean mask: entry (c, n) is True when sample n has class c.
        class_ids = torch.arange(0, self.num_classes, device=target.device).view(-1, 1).type_as(target)
        class_mask = target == class_ids

        per_cls_pred_spread = y_pred.T * class_mask  # C x N
        pred_spread = (y_pred - torch.log(prior + 1e-9) + torch.log(balanced_prior + 1e-9)).T  # C x N

        num_samples_per_cls = torch.sum(class_mask, -1).float()  # C
        estim_loss, first_term, second_term = self.remine_lower_bound(per_cls_pred_spread, pred_spread, num_samples_per_cls)

        # The bound is maximised, so the loss is its negated weighted sum (it can be negative).
        loss = -torch.sum(estim_loss * cls_weight)
        return loss

### Loss, metric, optimizer, scheduler

In [17]:
# criterion = FocalLoss()
# optimizer = Adam(
#     filter(lambda p: p.requires_grad, model.parameters()),
#     lr=config.lr,
#     weight_decay=5e-4
# )
# optimizer = SGD(
#     filter(lambda p: p.requires_grad, model.parameters()),
#     lr=config.lr,
#     momentum=0.9,
#     nesterov=True
# )
# optimizer = RMSprop(
#     filter(lambda p: p.requires_grad, model.parameters()),
#     lr=config.lr,
#     weight_decay=5e-4,
#     momentum=0.9
# )
# scheduler = StepLR(optimizer, config.lr_decay_step, gamma=0.5)
# -- scheduler: ReduceLROnPlateau
# 성능이 향상되지 않을 때 learning rate를 줄입니다. patience=10은 10회 동안 성능 향상이 없을 경우입니다.
# scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=10)
# -- scheduler: CosineAnnealingLR
# CosineAnnealing은 learning rate를 cosine 그래프처럼 변화시킵니다.
# scheduler = CosineAnnealingLR(optimizer, T_max=2, eta_min=0.)
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group (None if it has none)."""
    return next((group['lr'] for group in optimizer.param_groups), None)

In [18]:
# Output directory: use config.save_dir as-is when it does not exist yet,
# otherwise append the next free numeric suffix (exp -> exp2 -> exp3 ...).
path = Path(config.save_dir)
if (not path.exists()):
    save_dir = str(path)
else:
    dirs = glob.glob(f"{path}*")
    # Escape the stem so characters such as '.', '+' or '(' in the directory
    # name are matched literally instead of being interpreted as regex syntax.
    matches = [re.search(rf"{re.escape(path.stem)}(\d+)", d) for d in dirs]
    i = [int(m.groups()[0]) for m in matches if m]
    n = max(i) + 1 if i else 2
    save_dir = f"{path}{n}"

print("save path : " + save_dir)
# Fails if the directory already exists, so two runs never share one output folder.
os.makedirs(save_dir)

save path : kfoldEnsemble_exp6


### start training

In [19]:
# logger = SummaryWriter(log_dir=save_dir)
# with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
#     json.dump(vars(config), f, ensure_ascii=False, indent=4)

In [20]:
# k-fold setting
# Fold generator; the folds are stratified on the age bucket computed below.
skf = StratifiedKFold(n_splits=config.n_splits)

csv_path = os.path.join(config.data_dir, 'train.csv')
data = pd.read_csv(csv_path)

# Number of CSV rows per age bucket (<30, 30-59, >=60); handed to LADELoss as class counts.
age_img_num_per_cls = torch.Tensor([
    data['age'].lt(30).sum(),
    (data['age'].ge(30) & data['age'].lt(60)).sum(),
    data['age'].ge(60).sum(),
])

# One age-bucket label per CSV row, plus the raw rows, as inputs for skf.split().
labels = [AgeLabels.from_number(age) for age in data.age.values]
data_list = list(data.values)
# for i, (train_idx, valid_idx) in enumerate(skf.split(data_list, labels)):
#     print(i)
#     print(len(train_idx), len(valid_idx))
#     print(data.loc[valid_idx])


# Train one freshly initialised model per fold. For every fold the best
# checkpoint (by validation macro-F1) and the last checkpoint are saved.
for i, (train_idx, valid_idx) in enumerate(skf.split(data_list, labels)):
    # Per-fold early-stopping counter and best-so-far validation metrics.
    counter = 0
    best_val_acc = 0
    best_val_loss = np.inf
    best_val_f1 = 0
    train_loader, val_loader = getDataloader(train_idx, valid_idx)
    model = ResnextMultiheadModel().to(device)
    model = torch.nn.DataParallel(model)
#     criterion = FocalLoss()
    # One criterion per head; the age head uses LADELoss with the CSV-level class counts.
    criterion_mask = FocalLoss() #F1Loss(classes=3)
    criterion_age = LADELoss(num_classes=3, img_num_per_cls=age_img_num_per_cls) #F1Loss(classes=3)
    criterion_gender = FocalLoss() #F1Loss(classes=2)
    optimizer = Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=config.lr,
        weight_decay=5e-4
    )
    # The learning rate is multiplied by 0.85 after every epoch (scheduler.step() below).
    scheduler = ExponentialLR(optimizer, gamma=0.85)
#     scheduler = StepLR(optimizer, config.lr_decay_step, gamma=0.5)
    print(f"kfold[{i+1}/{config.n_splits}]")
    for epoch in range(config.epochs):
        # train loop
        model.train()
        loss_value = 0
        matches = 0
        label_list=[]
        pred_list=[]
        mask_label_list=[]
        mask_pred_list=[]
        age_label_list=[]
        age_pred_list=[]
        gender_label_list=[]
        gender_pred_list=[]
        for idx, train_batch in enumerate(train_loader):
            inputs, (mask_labels, gender_labels, age_labels, multi_class_labels) = train_batch
            inputs = inputs.to(device)
            mask_labels = mask_labels.to(device)
            gender_labels = gender_labels.to(device)
            age_labels = age_labels.to(device)
            multi_class_labels = multi_class_labels.to(device)

            optimizer.zero_grad()

            # The model returns (mask, age, gender), a different order from the label tuple.
            mask_outs, age_outs, gender_outs = model(inputs)
            mask_preds = torch.argmax(mask_outs, dim=-1)
            age_preds = torch.argmax(age_outs, dim=-1)
            gender_preds = torch.argmax(gender_outs, dim=-1)

            # Same encoding as the dataset's multi_class_label.
            multi_preds = mask_preds * 6 + gender_preds * 3 + age_preds


#             mask_loss = criterion(mask_outs, mask_labels)
#             age_loss = criterion(age_outs, age_labels)
#             gender_loss = criterion(gender_outs, gender_labels)
            mask_loss = criterion_mask(mask_outs, mask_labels)
            age_loss = criterion_age(age_outs, age_labels)
            gender_loss = criterion_gender(gender_outs, gender_labels)

            # Unweighted sum of the three head losses.
            loss = mask_loss + gender_loss + age_loss

            loss.backward()
            optimizer.step()

            # NOTE(review): `list += tensor` iterates the tensor and appends one
            # 0-d tensor per sample; `.tolist()` would store plain ints instead.
            label_list+=multi_class_labels.detach().cpu()
            pred_list+=multi_preds.detach().cpu()

            mask_label_list+=mask_labels.detach().cpu()
            mask_pred_list+=mask_preds.detach().cpu()
            age_label_list+=age_labels.detach().cpu()
            age_pred_list+=age_preds.detach().cpu()
            gender_label_list+=gender_labels.detach().cpu()
            gender_pred_list+=gender_preds.detach().cpu()

            loss_value += loss.item()
            matches += (multi_preds == multi_class_labels).sum().item()
            if (idx + 1) % config.log_interval == 0:
                train_loss = loss_value / config.log_interval
                # NOTE(review): the denominator assumes every batch in the interval
                # holds exactly config.batch_size samples.
                train_acc = matches / config.batch_size / config.log_interval
                train_f1 = f1_score(label_list, pred_list, average='macro')
                mask_f1 = f1_score(mask_label_list, mask_pred_list, average='macro')
                age_f1 = f1_score(age_label_list, age_pred_list, average='macro')
                gender_f1 = f1_score(gender_label_list, gender_pred_list, average='macro')
                current_lr = get_lr(optimizer)
                print(
                    f"Epoch[{epoch}/{config.epochs}]({idx + 1}/{len(train_loader)}) || "
                    f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || training f1_score {train_f1:4.4} || lr {current_lr}"
                )
                print(f"          training each f1 || mask f1 {mask_f1:4.4} || age f1 {age_f1:4.4} || gender f1 {gender_f1:4.4}")
    #             logger.add_scalar("Train/loss", train_loss, epoch * len(train_loader) + idx)
    #             logger.add_scalar("Train/accuracy", train_acc, epoch * len(train_loader) + idx)
    #             logger.add_scalar("Train/f1_score", train_f1, epoch * len(train_loader) + idx)

                # NOTE(review): only the combined lists are reset here; the per-head
                # lists keep growing, so the per-head F1 values are cumulative over
                # the epoch while train_f1 covers the last interval only — confirm intent.
                loss_value = 0
                matches = 0
                label_list=[]
                pred_list=[]

        scheduler.step()

        # val loop
        with torch.no_grad():
            print("Calculating validation results...")
            model.eval()
            val_loss_items = []
            val_acc_items = []
            label_list=[]
            pred_list=[]
            mask_label_list=[]
            mask_pred_list=[]
            age_label_list=[]
            age_pred_list=[]
            gender_label_list=[]
            gender_pred_list=[]
            data_len = 0
            for val_batch in val_loader:
                inputs, (mask_labels, gender_labels, age_labels, multi_class_labels) = val_batch
                data_len += len(multi_class_labels)
                inputs = inputs.to(device)
                mask_labels = mask_labels.to(device)
                gender_labels = gender_labels.to(device)
                age_labels = age_labels.to(device)
                multi_class_labels = multi_class_labels.to(device)

                mask_outs, age_outs, gender_outs = model(inputs)
                mask_preds = torch.argmax(mask_outs, dim=-1)
                age_preds = torch.argmax(age_outs, dim=-1)
                gender_preds = torch.argmax(gender_outs, dim=-1)
                multi_preds = mask_preds * 6 + gender_preds * 3 + age_preds

#                 mask_loss = criterion(mask_outs, mask_labels)
#                 age_loss = criterion(age_outs, age_labels)
#                 gender_loss = criterion(gender_outs, gender_labels)
                mask_loss = criterion_mask(mask_outs, mask_labels)
                age_loss = criterion_age(age_outs, age_labels)
                gender_loss = criterion_gender(gender_outs, gender_labels)
                loss = mask_loss + age_loss + gender_loss

                label_list+=multi_class_labels.detach().cpu()
                pred_list+=multi_preds.detach().cpu()
                mask_label_list+=mask_labels.detach().cpu()
                mask_pred_list+=mask_preds.detach().cpu()
                age_label_list+=age_labels.detach().cpu()
                age_pred_list+=age_preds.detach().cpu()
                gender_label_list+=gender_labels.detach().cpu()
                gender_pred_list+=gender_preds.detach().cpu()



                acc_item = (multi_class_labels == multi_preds).sum().item()
                val_loss_items.append(loss.item())
                val_acc_items.append(acc_item)

            # val_loss is the plain mean of the per-batch losses (batches are not
            # weighted by size); val_acc is correct predictions / number of samples.
            val_loss = float(np.sum(val_loss_items) / len(val_loader))
            val_acc = float(np.sum(val_acc_items) / data_len)
            val_f1= float(f1_score(label_list,pred_list,average="macro"))
            val_mask_f1 = f1_score(mask_label_list, mask_pred_list, average='macro')
            val_age_f1 = f1_score(age_label_list, age_pred_list, average='macro')
            val_gender_f1 = f1_score(gender_label_list, gender_pred_list, average='macro')

            # Model selection uses the combined macro-F1; acc / loss recorded as
            # "best" are the values from that same epoch.
            if val_f1 > best_val_f1:
                print(f"New best model for val f1 : {val_f1:4.4}! saving the best model..")
                # Save the wrapped module's weights (without the DataParallel wrapper).
                torch.save(model.module.state_dict(), f"{save_dir}/best_{i}.pth")
                best_val_f1 = val_f1
                best_val_acc = val_acc
                best_val_loss = val_loss
                counter = 0
            else:
                counter += 1
            torch.save(model.module.state_dict(), f"{save_dir}/last_{i}.pth")
            print(
                f"[Val] acc : {val_acc:4.2%}, loss: {val_loss:4.4}, f1: {val_f1:4.4} || "
                f"best acc : {best_val_acc:4.2%}, best loss: {best_val_loss:4.4}, best f1: {best_val_f1:4.4}"
            )
            print(f"[Val] each f1 || mask f1 {val_mask_f1:4.4} || age f1 {val_age_f1:4.4} || gender f1 {val_gender_f1:4.4}")
            # NOTE(review): with `>` training stops after patience + 1 consecutive
            # epochs without improvement — confirm this is the intended threshold.
            if counter > config.patience:
                print("Early Stopping!")
                counter = 0
                break
#             logger.add_scalar("Val/loss", val_loss, epoch)
#             logger.add_scalar("Val/accuracy", val_acc, epoch)
#             logger.add_scalar("Val/f1", val_f1, epoch)
            print()



kfold[1/5]
Epoch[0/50](50/237) || training loss -0.486 || training accuracy 74.25% || training f1_score 0.5796 || lr 0.0001
          training each f1 || mask f1 0.8991 || age f1 0.6683 || gender f1 0.9047
Epoch[0/50](100/237) || training loss -0.7937 || training accuracy 87.91% || training f1_score 0.7535 || lr 0.0001
          training each f1 || mask f1 0.9414 || age f1 0.7436 || gender f1 0.9375
Epoch[0/50](150/237) || training loss -0.8551 || training accuracy 90.09% || training f1_score 0.8162 || lr 0.0001
          training each f1 || mask f1 0.958 || age f1 0.7861 || gender f1 0.9478
Epoch[0/50](200/237) || training loss -0.8884 || training accuracy 92.22% || training f1_score 0.8453 || lr 0.0001
          training each f1 || mask f1 0.9659 || age f1 0.8152 || gender f1 0.9553
Calculating validation results...
New best model for val f1 : 0.6732! saving the best model..
[Val] acc : 84.50%, loss: -0.6852, f1: 0.6732 || best acc : 84.50%, best loss: -0.6852, best f1: 0.6732
[Val] 



kfold[2/5]
Epoch[0/50](50/237) || training loss -0.4784 || training accuracy 72.25% || training f1_score 0.5338 || lr 0.0001
          training each f1 || mask f1 0.8946 || age f1 0.6635 || gender f1 0.8909
Epoch[0/50](100/237) || training loss -0.805 || training accuracy 88.66% || training f1_score 0.7957 || lr 0.0001
          training each f1 || mask f1 0.9414 || age f1 0.7572 || gender f1 0.9241
Epoch[0/50](150/237) || training loss -0.8579 || training accuracy 90.97% || training f1_score 0.8089 || lr 0.0001
          training each f1 || mask f1 0.9549 || age f1 0.7986 || gender f1 0.9403
Epoch[0/50](200/237) || training loss -0.909 || training accuracy 92.91% || training f1_score 0.8591 || lr 0.0001
          training each f1 || mask f1 0.9644 || age f1 0.8248 || gender f1 0.9501
Calculating validation results...
New best model for val f1 : 0.6586! saving the best model..
[Val] acc : 80.32%, loss: -0.7333, f1: 0.6586 || best acc : 80.32%, best loss: -0.7333, best f1: 0.6586
[Val] 



kfold[3/5]
Epoch[0/50](50/237) || training loss -0.5301 || training accuracy 75.34% || training f1_score 0.5625 || lr 0.0001
          training each f1 || mask f1 0.9088 || age f1 0.6662 || gender f1 0.9092
Epoch[0/50](100/237) || training loss -0.8266 || training accuracy 88.91% || training f1_score 0.7953 || lr 0.0001
          training each f1 || mask f1 0.9487 || age f1 0.7675 || gender f1 0.9359
Epoch[0/50](150/237) || training loss -0.8723 || training accuracy 91.34% || training f1_score 0.8295 || lr 0.0001
          training each f1 || mask f1 0.9623 || age f1 0.7987 || gender f1 0.9493
Epoch[0/50](200/237) || training loss -0.9122 || training accuracy 93.75% || training f1_score 0.8857 || lr 0.0001
          training each f1 || mask f1 0.9692 || age f1 0.826 || gender f1 0.9582
Calculating validation results...
New best model for val f1 : 0.6112! saving the best model..
[Val] acc : 76.83%, loss: -0.4977, f1: 0.6112 || best acc : 76.83%, best loss: -0.4977, best f1: 0.6112
[Val]



kfold[4/5]
Epoch[0/50](50/237) || training loss -0.5013 || training accuracy 73.88% || training f1_score 0.5489 || lr 0.0001
          training each f1 || mask f1 0.9097 || age f1 0.6798 || gender f1 0.8892
Epoch[0/50](100/237) || training loss -0.8065 || training accuracy 88.09% || training f1_score 0.7297 || lr 0.0001
          training each f1 || mask f1 0.9488 || age f1 0.7633 || gender f1 0.9221
Epoch[0/50](150/237) || training loss -0.8738 || training accuracy 91.50% || training f1_score 0.8507 || lr 0.0001
          training each f1 || mask f1 0.9612 || age f1 0.8105 || gender f1 0.9399
Epoch[0/50](200/237) || training loss -0.9069 || training accuracy 92.72% || training f1_score 0.872 || lr 0.0001
          training each f1 || mask f1 0.9669 || age f1 0.8404 || gender f1 0.9492
Calculating validation results...
New best model for val f1 : 0.554! saving the best model..
[Val] acc : 70.63%, loss: -0.9466, f1: 0.554 || best acc : 70.63%, best loss: -0.9466, best f1: 0.554
[Val] ea



kfold[5/5]
Epoch[0/50](50/237) || training loss -0.5103 || training accuracy 74.16% || training f1_score 0.5331 || lr 0.0001
          training each f1 || mask f1 0.8937 || age f1 0.6604 || gender f1 0.8855
Epoch[0/50](100/237) || training loss -0.8184 || training accuracy 89.03% || training f1_score 0.7481 || lr 0.0001
          training each f1 || mask f1 0.9389 || age f1 0.7503 || gender f1 0.9291
Epoch[0/50](150/237) || training loss -0.8673 || training accuracy 90.78% || training f1_score 0.8159 || lr 0.0001
          training each f1 || mask f1 0.9551 || age f1 0.7984 || gender f1 0.942
Epoch[0/50](200/237) || training loss -0.8919 || training accuracy 92.34% || training f1_score 0.8377 || lr 0.0001
          training each f1 || mask f1 0.9629 || age f1 0.8238 || gender f1 0.9504
Calculating validation results...
New best model for val f1 : 0.6491! saving the best model..
[Val] acc : 73.62%, loss: -0.7345, f1: 0.6491 || best acc : 73.62%, best loss: -0.7345, best f1: 0.6491
[Val]