In [1]:
import os
import sys
import re
import argparse
import glob
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cv2
import random
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from pathlib import Path
from importlib import import_module
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from time import time
from enum import Enum
import torch.optim as optim

import torch
import torch.utils.data as data

import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
class cfg:
    """Filesystem locations of the training data, exposed as class-level constants."""

    # Root of the training split.
    data_dir = '/opt/ml/input/data/train'
    # Image root: one sub-folder per profile, as read by MaskBaseDataset.setup.
    img_dir = data_dir + '/images'
    # CSV file stored next to the images.
    df_path = data_dir + '/train.csv'

In [3]:
# Read the paths straight off the `cfg` class. The previous `cfg = cfg()` rebound the
# class name to an instance, so re-running this cell raised
# "TypeError: 'cfg' object is not callable".
data_dir = cfg.data_dir
img_dir = cfg.img_dir
df_path = cfg.df_path

In [4]:
# Per-channel (R, G, B) normalisation statistics passed to get_transforms.
# They are on a 0-1 scale: albumentations' Normalize multiplies them by max_pixel_value=255.
# NOTE(review): presumably computed over the training images — confirm.
mean = (0.55800916, 0.51224077, 0.47767341)
std = (0.21817792, 0.23804603, 0.25183411)

In [5]:
### 마스크 여부, 성별, 나이를 mapping할 클래스를 생성합니다.

class MaskLabels(int, Enum):
    """Mask-wearing status of an image; the integer value is the class index."""

    MASK, INCORRECT, NORMAL = range(3)


class GenderLabels(int, Enum):
    """Gender of the person in an image; the integer value is the class index."""

    MALE = 0
    FEMALE = 1

    @classmethod
    def from_str(cls, value: str) -> int:
        """Map a case-insensitive 'male' / 'female' string to its member.

        Args:
            value: the gender string, e.g. parsed from a profile folder name.

        Raises:
            ValueError: if the lower-cased string is neither 'male' nor 'female'.
        """
        normalized = value.lower()
        if normalized == "female":
            return cls.FEMALE
        if normalized != "male":
            raise ValueError(f"Gender value should be either 'male' or 'female', {normalized}")
        return cls.MALE


class AgeLabels(int, Enum):
    """Age bucket of the person in an image; the integer value is the class index."""

    YOUNG = 0
    MIDDLE = 1
    OLD = 2

    @classmethod
    def from_number(cls, value: str) -> int:
        """Bucket an age into YOUNG (< 30), MIDDLE (30-59) or OLD (>= 60).

        Args:
            value: the age, either a numeric string (e.g. "54", as parsed from a
                profile folder name) or anything else ``int()`` accepts.

        Raises:
            ValueError: if ``value`` cannot be converted with ``int()``. The
                underlying conversion error is chained as ``__cause__``.
        """
        try:
            age = int(value)
        except Exception as err:
            # Chain the original error so the real parse failure stays visible.
            raise ValueError(f"Age value should be numeric, {value}") from err

        if age < 30:
            return cls.YOUNG
        elif age < 60:
            return cls.MIDDLE
        else:
            return cls.OLD

In [6]:
class MaskBaseDataset(data.Dataset):
    """Mask / gender / age classification dataset over a folder of profile directories.

    Layout parsed by ``setup``::

        img_dir/
            <id>_<gender>_<race>_<age>/      e.g. 000004_male_Asian_54
                mask1..mask5, incorrect_mask, normal   (any extension, e.g. mask1.jpg)

    Each sample gets a single combined label ``mask * 6 + gender * 3 + age``,
    in ``range(num_classes)``.
    """
    num_classes = 3 * 2 * 3

    # File stem -> mask label. Files whose stem is not listed here are ignored.
    _file_names = {
        "mask1": MaskLabels.MASK,
        "mask2": MaskLabels.MASK,
        "mask3": MaskLabels.MASK,
        "mask4": MaskLabels.MASK,
        "mask5": MaskLabels.MASK,
        "incorrect_mask": MaskLabels.INCORRECT,
        "normal": MaskLabels.NORMAL
    }

    def __init__(self, img_dir, mean, std, transform=None):
        """Initialise the dataset and scan ``img_dir`` for images.

        Args:
            img_dir: root directory of the training images (one sub-folder per profile).
            mean: per-channel mean; stored on the instance, not used by this class itself.
            std: per-channel std; stored on the instance, not used by this class itself.
            transform: augmentation callable invoked as ``transform(image=ndarray)['image']``.
                May be None, in which case ``__getitem__`` returns the raw array.
        """
        self.img_dir = img_dir
        self.mean = mean
        self.std = std
        self.transform = transform

        self.setup()

    def set_transform(self, transform):
        """Replace the augmentation callable used by ``__getitem__``."""
        self.transform = transform

    def setup(self):
        """Scan ``img_dir`` and record every image path with its mask / gender / age labels.

        The lists are per-instance and rebuilt from scratch on every call. They used to be
        class attributes, so every instance (and every re-run) appended into the same
        shared lists and duplicated the data. Entries are visited in sorted order so the
        index -> image mapping, and hence any random split over it, is reproducible.
        """
        self.image_paths = []
        self.mask_labels = []
        self.gender_labels = []
        self.age_labels = []

        for profile in sorted(os.listdir(self.img_dir)):
            if profile.startswith("."):  # skip hidden entries such as .DS_Store
                continue

            img_folder = os.path.join(self.img_dir, profile)
            if not os.path.isdir(img_folder):  # stray files at the root are not profiles
                continue

            for file_name in sorted(os.listdir(img_folder)):
                stem, _ = os.path.splitext(file_name)
                if stem not in self._file_names:  # hidden or otherwise unrecognised files
                    continue

                img_path = os.path.join(img_folder, file_name)
                mask_label = self._file_names[stem]

                # Folder name is "<id>_<gender>_<race>_<age>"; id and race are unused.
                _, gender, _, age = profile.split("_")
                gender_label = GenderLabels.from_str(gender)
                age_label = AgeLabels.from_number(age)

                self.image_paths.append(img_path)
                self.mask_labels.append(mask_label)
                self.gender_labels.append(gender_label)
                self.age_labels.append(age_label)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: position of the sample in the scanned lists.

        Returns:
            ``(image, multi_class_label)``. ``image`` is ``transform(image=...)['image']``,
            or the raw RGB ``numpy`` array when no transform is set. The label is
            ``mask * 6 + gender * 3 + age``.
        """
        image_path = self.image_paths[index]
        # The context manager closes the file handle. convert("RGB") guarantees three
        # channels for the 3-channel Normalize even if a file is grayscale or RGBA.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"))

        mask_label = self.mask_labels[index]
        gender_label = self.gender_labels[index]
        age_label = self.age_labels[index]
        multi_class_label = mask_label * 6 + gender_label * 3 + age_label

        if self.transform is None:
            return image, multi_class_label
        image_transform = self.transform(image=image)['image']
        return image_transform, multi_class_label

    def __len__(self):
        """Number of images found by ``setup``."""
        return len(self.image_paths)

In [7]:
from albumentations import *
from albumentations.pytorch import ToTensorV2


def get_transforms(need=('train', 'val'), img_size=(512, 384), mean=(0.548, 0.504, 0.479), std=(0.237, 0.247, 0.246)):
    """Build albumentations pipelines for the requested splits.

    The train pipeline applies heavy random variation. The val pipeline applies only the
    minimal preprocessing: resize, normalise, convert to tensor.

    Args:
        need: which pipelines to build; any container holding 'train' and/or 'val'.
        img_size: ``(height, width)`` passed to ``Resize``.
        mean: RGB mean used by ``Normalize`` (0-1 scale; multiplied by 255 internally).
        std: RGB standard deviation used by ``Normalize`` (same scale as ``mean``).

    Returns:
        dict mapping 'train' and/or 'val' to a ``Compose`` pipeline.
    """
    def _with_tail(head_steps):
        # Every pipeline ends the same way: normalise, then convert HWC array -> CHW tensor.
        return Compose(head_steps + [
            Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ], p=1.0)

    pipelines = {}
    if 'train' in need:
        pipelines['train'] = _with_tail([
            Resize(img_size[0], img_size[1], p=1.0),
            HorizontalFlip(p=0.5),
            ShiftScaleRotate(p=0.5),
            # NOTE(review): for uint8 input albumentations applies these limits in pixel
            # units, so +/-0.2 may be close to a no-op — confirm the intended magnitude.
            HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
            RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
            GaussNoise(p=0.5),
        ])
    if 'val' in need:
        pipelines['val'] = _with_tail([
            Resize(img_size[0], img_size[1]),
        ])
    return pipelines

In [8]:
import copy

# Build the augmentation pipelines and the dataset (its transform is assigned per split below).
transform = get_transforms(mean=mean, std=std)

dataset = MaskBaseDataset(
    img_dir=img_dir,
    mean=mean,
    std=std
)

# Split into train / validation at 8:2. A fixed generator makes the split reproducible.
SPLIT_SEED = 42
n_val = int(len(dataset) * 0.2)
n_train = len(dataset) - n_val
train_dataset, val_dataset = data.random_split(
    dataset, [n_train, n_val], generator=torch.Generator().manual_seed(SPLIT_SEED)
)
# NOTE(review): this splits individual images, so one person's photos can land in both
# splits and inflate val accuracy. Consider splitting by profile folder instead.

# random_split returns two Subsets wrapping the *same* dataset object, so setting the
# transform through one Subset overwrote the other: train silently used the val pipeline.
# Give each Subset its own shallow copy. The path and label lists stay shared, read-only.
train_dataset.dataset = copy.copy(dataset)
val_dataset.dataset = copy.copy(dataset)
train_dataset.dataset.set_transform(transform['train'])
val_dataset.dataset.set_transform(transform['val'])

In [9]:
!pip install timm



In [10]:
import timm

# EfficientNetV2-M backbone with one output per combined mask/gender/age class. The class
# count comes from the dataset so the head and the labels cannot drift apart.
# NOTE(review): `pretrained` is not set, so weights are randomly initialised — confirm
# that training from scratch is intended.
m = timm.create_model('efficientnetv2_m', num_classes=MaskBaseDataset.num_classes)

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/8
class FocalLoss(nn.Module):
    """Multi-class focal loss: NLL of ``(1 - p) ** gamma * log p`` at the target class.

    With ``gamma == 0`` this reduces to (optionally class-weighted) cross-entropy.

    Args:
        weight: optional per-class weight tensor forwarded to ``F.nll_loss``. It is
            registered as a buffer so ``criterion.to(device)`` moves it with the module.
        gamma: focusing parameter; larger values down-weight well-classified samples.
        reduction: 'none' | 'mean' | 'sum', as in ``F.nll_loss``.
    """

    def __init__(self, weight=None,
                 gamma=2., reduction='mean'):
        super().__init__()
        # A plain tensor attribute would stay on its original device; a buffer follows .to().
        self.register_buffer('weight', weight)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input_tensor, target_tensor):
        """Compute the focal loss.

        Args:
            input_tensor: raw logits with classes along the last dimension.
            target_tensor: integer class indices accepted by ``F.nll_loss``.
        """
        log_prob = F.log_softmax(input_tensor, dim=-1)
        prob = torch.exp(log_prob)
        # nll_loss picks the target entry, so the modulating factor is applied per sample.
        return F.nll_loss(
            ((1 - prob) ** self.gamma) * log_prob,
            target_tensor,
            weight=self.weight,
            reduction=self.reduction
        )


class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a smoothed one-hot target.

    The true class receives ``1 - smoothing``. Each of the other ``classes - 1`` classes
    receives ``smoothing / (classes - 1)``.

    Args:
        classes: number of classes (size of ``pred`` along ``dim``).
        smoothing: probability mass moved off the true class.
        dim: class dimension of ``pred``.
    """

    def __init__(self, classes=3, smoothing=0.0, dim=-1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """Return the smoothed cross-entropy averaged over all non-class positions.

        Args:
            pred: raw logits with classes along ``self.dim``.
            target: integer class indices shaped like ``pred`` without the class dim.
        """
        pred = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            true_dist = torch.full_like(pred, self.smoothing / (self.cls - 1))
            # Scatter along the same dim used by the softmax. This was hard-coded to 1,
            # which only agreed with `self.dim` for 2-D input.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))


# https://gist.github.com/SuperShinyEyes/dcc68a08ff8b615442e3bc6a9b55a354
class F1Loss(nn.Module):
    """Differentiable ``1 - macro-F1`` loss computed from softmax probabilities.

    Soft TP / FP / FN counts are accumulated per class over the batch and turned into
    per-class F1 scores. Those are clamped to ``[epsilon, 1 - epsilon]`` and averaged.

    Args:
        classes: number of classes (logit columns).
        epsilon: small constant guarding the divisions and the clamp.
    """

    def __init__(self, classes=3, epsilon=1e-7):
        super().__init__()
        self.classes = classes
        self.epsilon = epsilon

    def forward(self, y_pred, y_true):
        """Compute the loss.

        Args:
            y_pred: logits of shape ``(batch, classes)``.
            y_true: integer labels of shape ``(batch,)``.

        Returns:
            Scalar tensor ``1 - mean_c(F1_c)``.
        """
        assert y_pred.ndim == 2
        assert y_true.ndim == 1
        y_true = F.one_hot(y_true, self.classes).to(torch.float32)
        y_pred = F.softmax(y_pred, dim=1)

        # Soft confusion counts per class. True negatives do not enter F1, so they are
        # not computed (the previous unused `tn` local was dropped).
        tp = (y_true * y_pred).sum(dim=0).to(torch.float32)
        fp = ((1 - y_true) * y_pred).sum(dim=0).to(torch.float32)
        fn = (y_true * (1 - y_pred)).sum(dim=0).to(torch.float32)

        precision = tp / (tp + fp + self.epsilon)
        recall = tp / (tp + fn + self.epsilon)

        f1 = 2 * (precision * recall) / (precision + recall + self.epsilon)
        f1 = f1.clamp(min=self.epsilon, max=1 - self.epsilon)
        return 1 - f1.mean()

In [12]:
# Registry: loss name (as accepted by create_criterion) -> loss class.
_criterion_entrypoints = dict(
    cross_entropy=nn.CrossEntropyLoss,
    focal=FocalLoss,
    label_smoothing=LabelSmoothingLoss,
    f1=F1Loss,
)


def criterion_entrypoint(criterion_name):
    """Return the loss class registered under ``criterion_name``; raises KeyError if unknown."""
    registry = _criterion_entrypoints
    loss_cls = registry[criterion_name]
    return loss_cls


def is_criterion(criterion_name):
    """Return True if ``criterion_name`` is a registered loss name."""
    known_names = _criterion_entrypoints.keys()
    return criterion_name in known_names


def create_criterion(criterion_name, **kwargs):
    """Instantiate the loss registered under ``criterion_name``.

    Args:
        criterion_name: a key of ``_criterion_entrypoints`` (e.g. 'cross_entropy', 'focal').
        **kwargs: forwarded unchanged to the loss constructor.

    Raises:
        RuntimeError: if ``criterion_name`` is not registered.
    """
    if not is_criterion(criterion_name):
        raise RuntimeError('Unknown loss (%s)' % criterion_name)
    loss_cls = criterion_entrypoint(criterion_name)
    return loss_cls(**kwargs)

In [13]:
# def seed_everything(seed):
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed(seed)
#     torch.cuda.manual_seed_all(seed)  # if use multi-GPU
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False
#     np.random.seed(seed)
#     random.seed(seed)

# def get_lr(optimizer):
#     for param_group in optimizer.param_groups:
#         return param_group['lr']
    
# def increment_path(path, exist_ok=False):
#     """ Automatically increment path, i.e. runs/exp --> runs/exp0, runs/exp1 etc.

#     Args:
#         path (str or pathlib.Path): f"{model_dir}/{args.name}".
#         exist_ok (bool): whether increment path (increment if False).
#     """
#     path = Path(path)
#     if (path.exists() and exist_ok) or (not path.exists()):
#         return str(path)
#     else:
#         dirs = glob.glob(f"{path}*")
#         matches = [re.search(rf"%s(\d+)" % path.stem, d) for d in dirs]
#         i = [int(m.groups()[0]) for m in matches if m]
#         n = max(i) + 1 if i else 2
#         return f"{path}{n}"

    
# def train(data_dir, model_dir, args):
#     seed_everything(args.seed)

#     save_dir = increment_path(os.path.join(model_dir, args.name))
#     os.makedirs(save_dir, exist_ok=True)
#     print('model save path',save_dir)

#     # -- settings
#     use_cuda = torch.cuda.is_available()
#     device = torch.device("cuda" if use_cuda else "cpu")
    
#     # -- dataset
#     dataset_module = getattr(import_module("dataset"), args.dataset)  # default: MaskBaseDataset
#     dataset = dataset_module(
#         data_dir=data_dir, transform=My_transform('train')
#     )

#     # -- data_loader
#     train_set, val_set = dataset.split_dataset()
#     print('train data 개수:',len(train_set))
#     print('val data 개수:',len(val_set))

#     train_loader = DataLoader(
#         train_set,
#         batch_size=args.batch_size,
#         num_workers=2,
#         shuffle=True,
#         pin_memory=use_cuda,
#         drop_last=False,
#     )

#     val_loader = DataLoader(
#         val_set,
#         batch_size=args.valid_batch_size,
#         num_workers=2,
#         shuffle=False,
#         pin_memory=use_cuda,
#         drop_last=False,
#     )

#     # -- model
#     model = m.to(device)
#     model = torch.nn.DataParallel(model)
# #     model = Multi_ModelClassification()
# #     path = os.path.join('/opt/ml/code/baseline/v2/model/exp4', 'last.pth')
# #     model.load_state_dict(torch.load(path, map_location=device))
# #     model = torch.nn.DataParallel(model)

#     # -- loss & metric
#     criterion = create_criterion(args.criterion)  # default: cross_entropy
#     opt_module = getattr(import_module("torch.optim"), args.optimizer)  # default: SGD
#     optimizer = opt_module(
#         filter(lambda p: p.requires_grad, model.parameters()),
#         lr=args.lr,
#     )
#     print(optimizer)
#     scheduler = StepLR(optimizer, args.lr_decay_step, gamma=0.5)

#     age_best_val_acc,gender_best_val_acc,mask_best_val_acc=0,0,0
#     age_best_val_loss,gender_best_val_loss,mask_best_val_loss=np.inf,np.inf,np.inf
#     for epoch in range(args.epochs):
#         # train loop
#         model.train()
#         loss_value = 0
#         age_loss_value,gender_loss_value,mask_loss_value=0,0,0
#         age_matches,gender_matches,mask_matches = 0,0,0
#         for idx, train_batch in enumerate(train_loader):
#             inputs,age_label,gender_label,mask_label=train_batch
#             inputs = inputs.to(device)
#             age_label = age_label.to(device)
#             gender_label = gender_label.to(device)
#             mask_label = mask_label.to(device)
            
#             optimizer.zero_grad()

#             age_outs, gender_outs, mask_outs = model(inputs)
            
#             age_preds = torch.argmax(age_outs, dim=-1)
#             age_loss = criterion(age_outs, age_label)
            
#             gender_preds = torch.argmax(gender_outs, dim=-1)
#             gender_loss = criterion(gender_outs, gender_label)
            
#             mask_preds = torch.argmax(mask_outs, dim=-1)
#             mask_loss = criterion(mask_outs, mask_label)
#             # loss balancing (이렇게 주는게 맞나)
#             loss = 0.5*age_loss+0.25*gender_loss+0.25*mask_loss

#             loss.backward()
#             optimizer.step()

#             loss_value += loss.item()
           
#             age_loss_value+=age_loss.item()
#             gender_loss_value+=gender_loss.item()
#             mask_loss_value+=mask_loss.item()
            
#             age_matches += (age_preds == age_label).sum().item()
#             gender_matches += (gender_preds == gender_label).sum().item()
#             mask_matches += (mask_preds == mask_label).sum().item()
#             if (idx + 1) % args.log_interval == 0:
#                 age_train_loss = age_loss_value / args.log_interval
#                 age_train_acc = age_matches / args.batch_size / args.log_interval
#                 gender_train_loss = gender_loss_value / args.log_interval
#                 gender_train_acc = gender_matches / args.batch_size / args.log_interval
#                 mask_train_loss = mask_loss_value / args.log_interval
#                 mask_train_acc = mask_matches / args.batch_size / args.log_interval
#                 total_loss = loss_value / args.log_interval
#                 current_lr = get_lr(optimizer)
#                 print(
#                     f"Epoch[{epoch}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
#                     f"age training loss {age_train_loss:4.4} || training accuracy {age_train_acc:4.2%} || lr {current_lr}"
#                 )
#                 print(
#                     f"Epoch[{epoch}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
#                     f"gender training loss {gender_train_loss:4.4} || training accuracy {gender_train_acc:4.2%} || lr {current_lr}"
#                 )
#                 print(
#                     f"Epoch[{epoch}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
#                     f"mask training loss {mask_train_loss:4.4} || training accuracy {mask_train_acc:4.2%} || lr {current_lr}"
#                 )

#                 loss_value = 0
#                 age_matches,gender_matches,mask_matches = 0,0,0

#         scheduler.step()
#         wandb.log({'total_loss': total_loss, 'epoch': epoch})
#         wandb.log({'age_train_accuracy': age_train_acc, 'age_train_loss': age_train_loss,'epoch': epoch})
#         wandb.log({'gender_train_accuracy': gender_train_acc, 'gender_train_loss': gender_train_loss,'epoch': epoch})
#         wandb.log({'mask_train_accuracy': mask_train_acc, 'mask_train_loss': mask_train_loss,'epoch': epoch})

#         # val loop
#         with torch.no_grad():
#             print("Calculating validation results...")
#             model.eval()
#             age_val_loss_items,gender_val_loss_items,mask_val_loss_items = [],[],[]
#             age_val_acc_items,gender_val_acc_items,mask_val_acc_items = [],[],[]
#             figure = None
#             for val_batch in val_loader:
#                 inputs,age_label,gender_label,mask_label=val_batch
#                 inputs = inputs.to(device)
#                 age_label = age_label.to(device)
#                 gender_label = gender_label.to(device)
#                 mask_label = mask_label.to(device)
                
#                 age_outs, gender_outs, mask_outs = model(inputs)
#                 age_preds = torch.argmax(age_outs, dim=-1)
#                 age_loss = criterion(age_outs, age_label).item()
#                 gender_preds = torch.argmax(gender_outs, dim=-1)
#                 gender_loss = criterion(gender_outs, gender_label).item()
#                 mask_preds = torch.argmax(mask_outs, dim=-1)
#                 mask_loss = criterion(mask_outs, mask_label).item()
                
#                 age_matches = (age_preds == age_label).sum().item()
#                 gender_matches = (gender_preds == gender_label).sum().item()
#                 mask_matches = (mask_preds == mask_label).sum().item()


#                 age_val_loss_items.append(age_loss)
#                 age_val_acc_items.append(age_matches)
#                 gender_val_loss_items.append(gender_loss)
#                 gender_val_acc_items.append(gender_matches)
#                 mask_val_loss_items.append(mask_loss)
#                 mask_val_acc_items.append(mask_matches)

#             age_val_loss = np.sum(age_val_loss_items) / len(val_loader)
#             age_val_acc = np.sum(age_val_acc_items) / len(val_set)
#             gender_val_loss = np.sum(gender_val_loss_items) / len(val_loader)
#             gender_val_acc = np.sum(gender_val_acc_items) / len(val_set)
#             mask_val_loss = np.sum(mask_val_loss_items) / len(val_loader)
#             mask_val_acc = np.sum(mask_val_acc_items) / len(val_set)
            
#             age_best_val_loss = min(age_best_val_loss, age_val_loss)
#             gender_best_val_loss = min(gender_best_val_loss, gender_val_loss)
#             mask_best_val_loss = min(mask_best_val_loss, mask_val_loss)

#             if age_val_acc > age_best_val_acc:
#                 print(f"New best model for val accuracy : {age_val_acc:4.2%}! saving the best model..")
#                 torch.save(model.module.state_dict(), f"{save_dir}/best.pth")
#                 age_best_val_acc = age_val_acc
#             torch.save(model.module.state_dict(), f"{save_dir}/last.pth")
#             print(
#                 f"[Val] age acc : {age_val_acc:4.2%}, loss: {age_val_loss:4.2} || "
#                 f"best acc : {age_best_val_acc:4.2%}, best loss: {age_best_val_loss:4.2}"
#             )
#             print(
#                 f"[Val] gender acc : {gender_val_acc:4.2%}, loss: {gender_val_loss:4.2} || "
#                 f"best acc : {gender_best_val_acc:4.2%}, best loss: {gender_best_val_loss:4.2}"
#             )
#             print(
#                 f"[Val] mask acc : {mask_val_acc:4.2%}, loss: {mask_val_loss:4.2} || "
#                 f"best acc : {mask_best_val_acc:4.2%}, best loss: {mask_best_val_loss:4.2}"
#             )

#             wandb.log({'total_loss': total_loss, 'epoch': epoch})
#             wandb.log({'age_val_accuracy': age_val_acc, 'age_val_loss': age_val_loss,'epoch': epoch})
#             wandb.log({'gender_val_accuracy': gender_val_acc, 'gender_val_loss': gender_val_loss,'epoch': epoch})
#             wandb.log({'mask_val_accuracy': mask_val_acc, 'mask_val_loss': mask_val_loss,'epoch': epoch})
            
#             print()


# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()

#     # Data and model checkpoints directories
#     parser.add_argument('--seed', type=int, default=42, help='random seed (default: 42)')
#     parser.add_argument('--epochs', type=int, default=30, help='number of epochs to train (default: 1)')
#     parser.add_argument('--dataset', type=str, default='MaskSplitByProfileDataset', help='dataset augmentation type (default: MaskBaseDataset)')
#     parser.add_argument('--augmentation', type=str, default='BaseAugmentation', help='data augmentation type (default: BaseAugmentation)')
#     parser.add_argument("--resize", nargs="+", type=list, default=[128, 96], help='resize size for image when training')
#     parser.add_argument('--batch_size', type=int, default=128, help='input batch size for training (default: 64)')
#     parser.add_argument('--valid_batch_size', type=int, default=1000, help='input batch size for validing (default: 1000)')
#     parser.add_argument('--model', type=str, default='BaseModel', help='model type (default: BaseModel)')
#     parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer type (default: SGD)')
#     parser.add_argument('--lr', type=float, default=1e-4, help='learning rate (default: 1e-3)')
#     parser.add_argument('--val_ratio', type=float, default=0.1, help='ratio for validaton (default: 0.2)')
#     parser.add_argument('--criterion', type=str, default='focal', help='criterion type (default: cross_entropy)')
#     parser.add_argument('--lr_decay_step', type=int, default=10, help='learning rate scheduler deacy step (default: 20)')
#     parser.add_argument('--log_interval', type=int, default=20, help='how many batches to wait before logging training status')
#     parser.add_argument('--name', default='exp', help='model save at {SM_MODEL_DIR}/{name}')

#     # Container environment
#     parser.add_argument('--data_dir', type=str, default=os.environ.get('SM_CHANNEL_TRAIN', '/opt/ml/input/data/train/images'))
#     parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR', '/opt/ml/code/baseline/v2/model'))
#     parser.add_argument("-f", "--fff", help="a dummy argument to fool ipython", default="1")



#     args = parser.parse_args()
#     print(args)

#     data_dir = args.data_dir
#     model_dir = args.model_dir

#     train(data_dir, model_dir, args)

In [None]:
# -- Run settings
counter = 0              # consecutive epochs without a val-accuracy improvement
best_val_acc = 0
num_epochs = 10
accumulation_steps = 2   # batches whose gradients are summed before each optimizer step
train_log_interval = 20  # log every N training batches
batch_size = 24
patience = 1
name = 'test'

# Checkpoints go here. Previously the cell created results_0412/test but saved into
# results/test, so torch.save failed unless that folder already existed.
save_dir = os.path.join(os.getcwd(), 'results_0412', name)
os.makedirs(save_dir, exist_ok=True)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
model = m.to(device)
best_val_loss = np.inf
criterion = create_criterion('cross_entropy')
optimizer = optim.Adam(model.parameters(), lr=0.0001)
scheduler = StepLR(optimizer, 10, gamma=0.5)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,  # keep in sync with the accuracy denominator below
    num_workers=2,
    shuffle=True,
    pin_memory=use_cuda,
    drop_last=False,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=100,
    num_workers=2,
    shuffle=False,
    pin_memory=use_cuda,
    drop_last=False,
)

for epoch in range(num_epochs):
    # -- train loop
    model.train()
    loss_value = 0
    matches = 0
    optimizer.zero_grad()
    for idx, train_batch in enumerate(train_loader):
        inputs, labels = train_batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        outs = model(inputs)
        preds = torch.argmax(outs, dim=-1)
        loss = criterion(outs, labels)

        # Gradient accumulation: scale so the accumulated gradient is the mean over
        # `accumulation_steps` batches rather than their sum.
        (loss / accumulation_steps).backward()
        # Also step on the final batch so a partial group is not carried into the next epoch.
        if (idx + 1) % accumulation_steps == 0 or (idx + 1) == len(train_loader):
            optimizer.step()
            optimizer.zero_grad()

        loss_value += loss.item()
        matches += (preds == labels).sum().item()
        if (idx + 1) % train_log_interval == 0:
            train_loss = loss_value / train_log_interval
            train_acc = matches / batch_size / train_log_interval
            current_lr = scheduler.get_last_lr()
            print(
                f"Epoch[{epoch}/{num_epochs}]({idx + 1}/{len(train_loader)}) || "
                f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || lr {current_lr}"
            )

            loss_value = 0
            matches = 0

    scheduler.step()

    # -- val loop
    with torch.no_grad():
        print("Calculating validation results...")
        model.eval()
        val_loss_items = []  # per-batch loss summed over samples (mean loss * batch size)
        val_acc_items = []   # per-batch number of correct predictions
        for val_batch in val_loader:
            inputs, labels = val_batch
            inputs = inputs.to(device)
            labels = labels.to(device)

            outs = model(inputs)
            preds = torch.argmax(outs, dim=-1)

            # Weight by batch size so a short final batch does not skew the epoch mean.
            loss_item = criterion(outs, labels).item() * labels.size(0)
            acc_item = (labels == preds).sum().item()
            val_loss_items.append(loss_item)
            val_acc_items.append(acc_item)

        val_loss = np.sum(val_loss_items) / len(val_dataset)
        val_acc = np.sum(val_acc_items) / len(val_dataset)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
        # Callback 1: save a checkpoint whenever validation accuracy improves.
        if val_acc > best_val_acc:
            print("New best model for val accuracy! saving the model..")
            torch.save(model.state_dict(), os.path.join(save_dir, f"{epoch:03}_accuracy_{val_acc:4.2%}.ckpt"))
            best_val_acc = val_acc
            counter = 0
        else:
            counter += 1

        # Report before the early-stopping check so the final epoch's results are shown too.
        print(
            f"[Val] acc : {val_acc:4.2%}, loss: {val_loss:4.2} || "
            f"best acc : {best_val_acc:4.2%}, best loss: {best_val_loss:4.2}"
        )

        # Callback 2: stop once accuracy has not improved for more than `patience` epochs.
        if counter > patience:
            print("Early Stopping...")
            break

Epoch[0/10](20/630) || training loss 3.09 || training accuracy 12.29% || lr [0.0001]
Epoch[0/10](40/630) || training loss 2.823 || training accuracy 19.79% || lr [0.0001]
Epoch[0/10](60/630) || training loss 2.667 || training accuracy 19.79% || lr [0.0001]
Epoch[0/10](80/630) || training loss 2.606 || training accuracy 24.17% || lr [0.0001]
Epoch[0/10](100/630) || training loss 2.563 || training accuracy 26.25% || lr [0.0001]
Epoch[0/10](120/630) || training loss 2.565 || training accuracy 25.42% || lr [0.0001]
Epoch[0/10](140/630) || training loss 2.455 || training accuracy 27.92% || lr [0.0001]
Epoch[0/10](160/630) || training loss 2.315 || training accuracy 29.58% || lr [0.0001]
Epoch[0/10](180/630) || training loss 2.369 || training accuracy 30.00% || lr [0.0001]
Epoch[0/10](200/630) || training loss 2.318 || training accuracy 32.50% || lr [0.0001]
Epoch[0/10](220/630) || training loss 2.301 || training accuracy 31.25% || lr [0.0001]
Epoch[0/10](240/630) || training loss 2.316 || t