In [None]:
import random
import os, sys

import numpy as np
import torch
import timm
from torch.utils.data import Subset
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

from sklearn.model_selection import StratifiedKFold

sys.path.append(os.path.abspath('..'))

# BaseLine 코드로 주어진 dataset.py model.py, loss.py를 Import 합니다.
from dataset import MaskBaseDataset, BaseAugmentation
from model import *
from loss import create_criterion

sys.path.append('../')

def seed_everything(seed):
    """
    Fix every random number generator used in training so runs are repeatable.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    and puts cuDNN into deterministic mode with kernel autotuning disabled.

    Args:
        seed: integer seed value
    """
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers multi-GPU setups

    # Autotuning picks kernels by timing them, which can differ between runs.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_everything(42)

In [None]:
# -- parameters
# Training data root. The dataset cell later overrides this with the images sub-folder.
img_root = '/opt/ml/input/data/train'

# Data loading
batch_size, num_workers = 64, 4
num_classes = 18  # number of combined target classes

# Schedule / logging
num_epochs = 5           # epochs to train per fold
log_interval = 80        # print training metrics every this many iterations
train_log_interval = 20  # NOTE(review): the training loop logs with log_interval; this value is not read there

# Optimisation
lr = 1e-4
# NOTE(review): the scheduler steps once per epoch, so with num_epochs = 5 a step size of 10 never decays lr within a fold.
lr_decay_step = 10
criterion_name = 'cross_entropy'  # loss name passed to create_criterion
# criterion_name = 'f1'  # alternative loss

name = "02_model_results"  # name of the folder results are saved under

# -- settings
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)


In [None]:
def getDataloader(dataset, train_idx, valid_idx, batch_size, num_workers):
    """Build the train/validation DataLoaders for one cross-validation fold.

    Args:
        dataset: dataset that ``train_idx`` / ``valid_idx`` index into.
        train_idx: indices of this fold's training samples.
        valid_idx: indices of this fold's validation samples.
        batch_size: batch size used by both loaders.
        num_workers: number of loader worker processes.

    Returns:
        ``(train_loader, val_loader)``. The train loader shuffles and drops the
        last incomplete batch. The validation loader keeps index order and
        yields every sample in ``valid_idx``.
    """
    # Views of the shared dataset restricted to this fold's indices.
    train_set = torch.utils.data.Subset(dataset, indices=train_idx)
    val_set = torch.utils.data.Subset(dataset, indices=valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=True,  # keep every training batch full-sized
        shuffle=True,
    )
    # The tail must NOT be dropped here. Validation metrics have to cover every
    # sample in valid_idx, and accuracy is normalised by the number of validation
    # samples, so drop_last=True would silently bias it downwards.
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
        shuffle=False,
    )

    return train_loader, val_loader

In [None]:
# Build the training dataset from the image folder (this overrides the img_root
# set in the parameters cell) and attach the baseline augmentation.
img_root = '/opt/ml/input/data/train/images'
dataset = MaskBaseDataset(img_root)

# Normalise with the mean/std the dataset object exposes
# (presumably statistics of its own images - confirm in dataset.py).
transform = BaseAugmentation(resize=[128, 96], mean=dataset.mean, std=dataset.std)
dataset.set_transform(transform)

In [None]:
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Only the first group is reported, even when the optimizer has several.
    An optimizer with no parameter groups yields ``None``.
    """
    first_group = next(iter(optimizer.param_groups), None)
    return None if first_group is None else first_group['lr']
    
# from torchvision.models import vgg19_bn

os.makedirs(os.path.join(os.getcwd(), 'results', name), exist_ok=True)

# 5-fold stratified K-fold: build 5 folds and run cross-validation once per fold.
n_splits = 5
skf = StratifiedKFold(n_splits=n_splits)

# Checkpoint folder. torch.save does not create missing directories, so create it up front.
save_dir = "ck"
os.makedirs(save_dir, exist_ok=True)

counter = 0    # NOTE(review): not used below; no early stopping is implemented
patience = 10  # NOTE(review): not used below
accumulation_steps = 2  # batches whose gradients are accumulated before each optimizer step

# The model's three heads, in the order model(inputs) returns them.
TASKS = ("age", "gender", "mask")
# Weight of each task in the combined loss.
LOSS_WEIGHTS = {"age": 0.5, "gender": 0.25, "mask": 0.25}
# Metric keys: the three tasks plus "all" (every head correct / weighted combined loss).
METRIC_KEYS = TASKS + ("all",)

# Combined class id per image, used as the stratification target.
encode_labels = [dataset.encode_multi_class(mask, gender, age) for mask, gender, age in zip(dataset.mask_labels, dataset.gender_labels, dataset.age_labels)]


def _decode_labels(labels):
    """Split combined class ids into per-task targets.

    age = label % 3, gender = (label // 3) % 2, mask = label // 6.
    """
    return {"age": labels % 3, "gender": labels // 3 % 2, "mask": labels // 6}


def _task_prefix(key):
    """Log prefix for a metric key; the combined ("all") metric is printed without one."""
    return "" if key == "all" else f"{key} "


def _all_heads_correct(correct):
    """Element-wise AND of the per-task correctness masks."""
    return correct["age"] & correct["gender"] & correct["mask"]


model = None

# Generate train/valid indices per fold, stratified on the combined class id.
for i, (train_idx, valid_idx) in enumerate(skf.split(dataset.image_paths, encode_labels)):
    # Build this fold's loaders from the generated indices.
    train_loader, val_loader = getDataloader(dataset, train_idx, valid_idx, batch_size, num_workers)

    # -- model
    if model is None:
        model = build_model(device)
    else:
        # NOTE(review): later folds start from the previous fold's best weights. Those weights were
        # trained on samples that belong to this fold's validation split, so validation metrics for
        # folds >= 1 are optimistic rather than independent CV estimates. Build a fresh model per
        # fold if an unbiased CV score is required.
        model.load_state_dict(torch.load(f"{save_dir}/best.pth", map_location=device))

    # -- loss & metric
    criterion = create_criterion(criterion_name)
    # Backbone blocks get a 10x smaller learning rate than the classifier head.
    train_params = [{'params': model.blocks.parameters(), 'lr': lr / 10, 'weight_decay': 5e-4},
                    {'params': model.classifier.parameters(), 'lr': lr, 'weight_decay': 5e-4}]
    optimizer = Adam(train_params)
    scheduler = StepLR(optimizer, lr_decay_step, gamma=0.5)

    # -- logging
    logger = SummaryWriter(log_dir=f"results/cv{i}_{name}")

    # Start below any reachable accuracy so the first epoch always writes best.pth,
    # which the next fold loads.
    best_val_accs = dict.fromkeys(METRIC_KEYS, -1.0)
    best_val_losses = dict.fromkeys(METRIC_KEYS, np.inf)

    for epoch in range(num_epochs):
        # train loop
        model.train()
        # Discard gradients left over from a previous epoch or fold.
        optimizer.zero_grad()
        running_loss = dict.fromkeys(METRIC_KEYS, 0.0)
        running_matches = dict.fromkeys(METRIC_KEYS, 0)
        window_batches, window_samples = 0, 0

        for idx, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)  # combined class ids in 0..17
            targets = _decode_labels(labels)

            outs = dict(zip(TASKS, model(inputs)))
            preds = {t: torch.argmax(outs[t], dim=-1) for t in TASKS}
            losses = {t: criterion(outs[t], targets[t]) for t in TASKS}
            # loss balancing: weighted sum of the per-task losses
            loss = sum(LOSS_WEIGHTS[t] * losses[t] for t in TASKS)

            # -- Gradient Accumulation
            # Scale so the accumulated gradient is the mean over the window rather than the sum.
            (loss / accumulation_steps).backward()
            # Also step on the last batch so a partial window is not carried into the next epoch.
            if (idx + 1) % accumulation_steps == 0 or idx + 1 == len(train_loader):
                optimizer.step()
                optimizer.zero_grad()

            correct = {t: preds[t] == targets[t] for t in TASKS}
            for t in TASKS:
                running_loss[t] += losses[t].item()
                running_matches[t] += correct[t].sum().item()
            running_loss["all"] += loss.item()
            running_matches["all"] += _all_heads_correct(correct).sum().item()
            window_batches += 1
            window_samples += labels.size(0)

            if (idx + 1) % log_interval == 0 or idx + 1 == len(train_loader):
                # This is the first parameter group's lr (the backbone group, lr / 10 initially).
                current_lr = get_lr(optimizer)
                global_step = epoch * len(train_loader) + idx
                for key in METRIC_KEYS:
                    # Average over what was actually seen since the last log: the final
                    # window of an epoch is usually shorter than log_interval.
                    train_loss = running_loss[key] / window_batches
                    train_acc = running_matches[key] / window_samples
                    print(
                        f"Epoch[{epoch}/{num_epochs}]({idx + 1}/{len(train_loader)}) || "
                        f"{_task_prefix(key)}training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || lr {current_lr}"
                    )
                    logger.add_scalar(f"Train/{key}_loss", train_loss, global_step)
                    logger.add_scalar(f"Train/{key}_accuracy", train_acc, global_step)
                running_loss = dict.fromkeys(METRIC_KEYS, 0.0)
                running_matches = dict.fromkeys(METRIC_KEYS, 0)
                window_batches, window_samples = 0, 0

        scheduler.step()

        # val loop
        with torch.no_grad():
            print("Calculating validation results...")
            model.eval()
            val_loss_sums = dict.fromkeys(METRIC_KEYS, 0.0)
            val_match_sums = dict.fromkeys(METRIC_KEYS, 0)
            n_val_batches, n_val_samples = 0, 0

            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)  # combined class ids in 0..17
                targets = _decode_labels(labels)

                outs = dict(zip(TASKS, model(inputs)))
                correct = {t: torch.argmax(outs[t], dim=-1) == targets[t] for t in TASKS}
                batch_losses = {t: criterion(outs[t], targets[t]).item() for t in TASKS}

                for t in TASKS:
                    val_loss_sums[t] += batch_losses[t]
                    val_match_sums[t] += correct[t].sum().item()
                val_loss_sums["all"] += sum(LOSS_WEIGHTS[t] * batch_losses[t] for t in TASKS)
                val_match_sums["all"] += _all_heads_correct(correct).sum().item()
                n_val_batches += 1
                n_val_samples += labels.size(0)

            if n_val_samples == 0:
                raise RuntimeError(f"Fold {i}: the validation loader produced no batches")

            # Losses are means of per-batch losses; accuracies are normalised by the number of
            # samples actually evaluated (len(valid_idx) would be wrong if the loader drops a tail).
            val_losses = {k: val_loss_sums[k] / n_val_batches for k in METRIC_KEYS}
            val_accs = {k: val_match_sums[k] / n_val_samples for k in METRIC_KEYS}

            for k in METRIC_KEYS:
                best_val_losses[k] = min(best_val_losses[k], val_losses[k])

            # Checkpoint on the combined accuracy; per-task "best" accuracies are the ones
            # recorded at that same epoch.
            if val_accs["all"] > best_val_accs["all"]:
                print(f"New best model for val accuracy : {val_accs['all']:4.2%}! saving the best model..")
                torch.save(model.state_dict(), f"{save_dir}/best.pth")
                best_val_accs = dict(val_accs)

            torch.save(model.state_dict(), f"{save_dir}/last.pth")
            for k in METRIC_KEYS:
                print(
                    f"[Val] {_task_prefix(k)}acc : {val_accs[k]:4.2%}, loss: {val_losses[k]:4.2} || "
                    f"best acc : {best_val_accs[k]:4.2%}, best loss: {best_val_losses[k]:4.2}"
                )
                logger.add_scalar(f"Val/{k}_loss", val_losses[k], epoch)
                logger.add_scalar(f"Val/{k}_accuracy", val_accs[k], epoch)

    # Flush and release this fold's TensorBoard event file.
    logger.close()