In [6]:
import os
import random
import copy
import cv2
import timm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from statistics import mean
from PIL import Image
from glob import glob
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models, transforms
from torchvision.transforms import Resize, CenterCrop, RandomRotation, ToTensor, Normalize, Grayscale
from torch.cuda.amp import GradScaler, autocast
from sklearn.metrics import f1_score

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from efficientnet_pytorch import EfficientNet

# Use the first CUDA device when torch reports CUDA as available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [None]:
# Seed handed to seed_everything().
SEED = 43

# DataLoader / split settings
BATCH_SIZE = 64
TRAIN_RATIO = 0.8
NUM_WORKERS = 4

# Training settings, keyed by task name
EPOCHS = dict(mask=10, gender=10, age=10)

LEARNING_RATE = dict(mask=4e-4, gender=4e-4, age=1e-4)

# Number of folds; with K_FOLD == 1 a single train/validation split is built instead.
K_FOLD = 5

In [7]:
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to every generator below.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seeding calls as before, applied in the same order.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers the multi-GPU case
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [8]:
# Apply the configured SEED to every RNG handled by seed_everything().
seed_everything(SEED)

In [9]:
# Root directories of the train and test (eval) datasets.
train_root = '/opt/ml/input/data/train'
test_root = '/opt/ml/input/data/eval'

In [10]:
class MaskDataset(Dataset):
    """Image dataset for the mask task with a label-dependent training transform.

    Args:
        img_paths: Sequence of image file paths.
        img_labels: Sequence of integer labels aligned with ``img_paths``.
        transform: For ``phase='train'`` a mapping with the keys ``'mask'`` and
            ``'not_mask'``; for ``phase='valid'`` a single callable. Each
            transform is called as ``t(image=...)`` and must return a mapping
            with an ``'image'`` entry. ``None`` disables transformation.
        phase: ``'train'`` or ``'valid'``. Any other value returns the image
            untransformed.
    """

    def __init__(self, img_paths, img_labels, transform=None, phase='train'):
        super(MaskDataset, self).__init__()
        self.img_paths = np.array(img_paths)
        self.img_labels = np.array(img_labels)

        self.transform = transform
        self.phase = phase

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        path = self.img_paths[idx]
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = self.img_labels[idx]

        if self.transform:
            if self.phase == 'train':
                # Label 0 uses the 'mask' pipeline, every other label 'not_mask'.
                key = 'mask' if label == 0 else 'not_mask'
                image = self.transform[key](image=image)['image']
            elif self.phase == 'valid':
                image = self.transform(image=image)['image']

        return image, label

    def __len__(self):
        return len(self.img_paths)
    
    
class GenderDataset(Dataset):
    """Image dataset for the gender task with a label-dependent training transform.

    Args:
        img_paths: Sequence of image file paths.
        img_labels: Sequence of integer labels aligned with ``img_paths``.
        transform: For ``phase='train'`` a mapping with the keys ``'male'``
            (label 0) and ``'female'`` (label 1); for ``phase='valid'`` a
            single callable. Each transform is called as ``t(image=...)`` and
            must return a mapping with an ``'image'`` entry. ``None`` disables
            transformation.
        phase: ``'train'`` or ``'valid'``. Any other value returns the image
            untransformed.
    """

    def __init__(self, img_paths, img_labels, transform=None, phase='train'):
        super(GenderDataset, self).__init__()
        self.img_paths = np.array(img_paths)
        self.img_labels = np.array(img_labels)

        self.transform = transform
        self.phase = phase

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        path = self.img_paths[idx]
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = self.img_labels[idx]

        if self.transform:
            if self.phase == 'train':
                # Labels other than 0 and 1 are left untransformed, as before.
                if label == 0:
                    image = self.transform['male'](image=image)['image']
                elif label == 1:
                    image = self.transform['female'](image=image)['image']
            elif self.phase == 'valid':
                image = self.transform(image=image)['image']

        return image, label

    def __len__(self):
        return len(self.img_paths)
    
    
class AgeDataset(Dataset):
    """Image dataset for the age task with a label-dependent training transform.

    Args:
        img_paths: Sequence of image file paths.
        img_labels: Sequence of integer labels aligned with ``img_paths``.
        transform: For ``phase='train'`` a mapping with the keys ``'old'``
            (label 2) and ``'not_old'`` (every other label); for
            ``phase='valid'`` a single callable. Each transform is called as
            ``t(image=...)`` and must return a mapping with an ``'image'``
            entry. ``None`` disables transformation.
        phase: ``'train'`` or ``'valid'``. Any other value returns the image
            untransformed.
    """

    def __init__(self, img_paths, img_labels, transform=None, phase='train'):
        super(AgeDataset, self).__init__()
        self.img_paths = np.array(img_paths)
        self.img_labels = np.array(img_labels)

        self.transform = transform
        self.phase = phase

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        path = self.img_paths[idx]
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = self.img_labels[idx]

        if self.transform:
            if self.phase == 'train':
                # Label 2 uses the 'old' pipeline, every other label 'not_old'.
                key = 'old' if label == 2 else 'not_old'
                image = self.transform[key](image=image)['image']
            elif self.phase == 'valid':
                image = self.transform(image=image)['image']

        return image, label

    def __len__(self):
        return len(self.img_paths)
    

class FaceTestDataset(Dataset):
    """Unlabelled image dataset used for inference.

    Args:
        img_paths: Sequence of image file paths.
        transform: Optional callable invoked as ``transform(image=...)`` that
            returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, img_paths, transform=None):
        super(FaceTestDataset, self).__init__()
        self.img_paths = img_paths
        self.transform = transform

    def __getitem__(self, idx):
        """Return the (optionally transformed) RGB image at ``idx``.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        path = self.img_paths[idx]
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image=image)['image']

        return image

    def __len__(self):
        return len(self.img_paths)


In [11]:
# Augmentations
# Per-channel statistics passed to A.Normalize by every pipeline in this cell.
_NORM_MEAN = (0.561, 0.524, 0.501)
_NORM_STD = (0.223, 0.243, 0.246)


def _make_train_transform(with_rotation):
    """Build one training pipeline; ``with_rotation`` adds the A.Rotate step."""
    steps = [
        A.Resize(400, 400, p=1),
        A.CenterCrop(224, 224, p=1),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(brightness=0.1, contrast=0, saturation=0, hue=0, p=0.5),
    ]
    if with_rotation:
        steps.append(A.Rotate(limit=10, p=0.5))
    steps.append(A.Normalize(mean=_NORM_MEAN, std=_NORM_STD))
    steps.append(ToTensorV2())
    return A.Compose(steps)


# Every key gets its own Compose instance, exactly as when they were written out.
mask_train_transform = {
    'not_mask': _make_train_transform(with_rotation=True),
    'mask': _make_train_transform(with_rotation=True),
}

# The gender pipelines are the only ones without the rotation step.
gender_train_transform = {
    'male': _make_train_transform(with_rotation=False),
    'female': _make_train_transform(with_rotation=False),
}

age_train_transform = {
    'not_old': _make_train_transform(with_rotation=True),
    'old': _make_train_transform(with_rotation=True),
}

# Pipeline without random augmentation: resize, centre crop, normalise, to tensor.
test_transform = A.Compose([
    A.Resize(400, 400, p=1),
    A.CenterCrop(224, 224, p=1),
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
])

In [12]:
# Map a raw age onto the categories 0, 1 and 2.
def convert_age(age):
    """Return 0 for ages below 30, 1 for ages in [30, 60), and 2 otherwise."""
    if age < 30:
        return 0
    return 1 if age < 60 else 2
    
# Load the training metadata, then overwrite the 'age' column with its 0/1/2 category.
img_info_df = pd.read_csv(os.path.join(train_root, 'train.csv'))
img_info_df['age'] = img_info_df['age'].apply(convert_age)

In [1]:
def make_mask_img_paths_labels(paths, phase='train'):
    """Collect image files and mask labels for the given per-person directories.

    For every entry ``p`` in ``paths`` the directory ``<train_root>/images/<p>``
    is searched for ``*.jpg``, ``*.png`` and ``*.jpeg`` files (in that order).
    The file name up to its first dot selects the label. In the ``'train'``
    phase, labels 1 and 2 are repeated five times each to oversample them.

    Args:
        paths: Iterable of directory names relative to ``<train_root>/images``.
        phase: ``'train'`` (oversample labels 1 and 2) or ``'valid'`` (every
            image once).

    Returns:
        Tuple ``(img_paths, img_labels)`` of equally long lists.

    Raises:
        ValueError: If ``phase`` is neither ``'train'`` nor ``'valid'``.
        KeyError: If a matched file name is not a known image name.
    """
    # Mask worn: mask1 ~ mask5, mask worn incorrectly: incorrect_mask, no mask: normal
    img_name2label = {
        'mask1': 0,
        'mask2': 0,
        'mask3': 0,
        'mask4': 0,
        'mask5': 0,
        'incorrect_mask': 1,
        'normal': 2,
    }

    # Number of times each label's images are repeated in the output.
    if phase == 'train':
        label2multiply = {0: 1, 1: 5, 2: 5}
    elif phase == 'valid':
        label2multiply = {0: 1, 1: 1, 2: 1}
    else:
        raise ValueError(f"phase must be 'train' or 'valid', got {phase!r}")

    img_paths = []
    img_labels = []
    # The images come with jpg, png and jpeg extensions.
    for p in paths:
        for ext in ('jpg', 'png', 'jpeg'):
            pattern = os.path.join(train_root, f'images/{p}/*.{ext}')
            # glob returns matches in arbitrary order; sort so the result is
            # reproducible across runs and machines.
            for img in sorted(glob(pattern)):
                img_name = os.path.basename(img).split('.')[0]
                label = img_name2label[img_name]
                repeat = label2multiply[label]

                img_paths.extend([img] * repeat)
                img_labels.extend([label] * repeat)

    return img_paths, img_labels


def make_img_paths_labels(paths, label_value, multiply=1):
    """Collect image files under the given directories and give them one label.

    For each extension (``jpg``, ``png``, ``jpeg``, in that order) all matching
    files under ``<train_root>/images/<p>`` for every ``p`` in ``paths`` are
    gathered, and that per-extension list is repeated ``multiply`` times.

    Args:
        paths: Iterable of directory names relative to ``<train_root>/images``.
        label_value: Label assigned to every collected image.
        multiply: Repetition factor used for oversampling.

    Returns:
        Tuple ``(img_paths, img_labels)`` of equally long lists.
    """
    img_paths = []
    for ext in ('jpg', 'png', 'jpeg'):
        # glob returns matches in arbitrary order; sort so the result is
        # reproducible across runs and machines.
        matched = [
            img
            for p in paths
            for img in sorted(glob(os.path.join(train_root, f'images/{p}/*.{ext}')))
        ]
        img_paths.extend(matched * multiply)
    img_labels = [label_value] * len(img_paths)

    return img_paths, img_labels


def count_dataset_label(train_labels, val_labels):
    """Print ``label : count`` lines for the training labels, then the validation labels.

    Within each group the lines are ordered by label.

    Args:
        train_labels: Iterable of hashable, mutually comparable training labels.
        val_labels: Iterable of hashable, mutually comparable validation labels.
    """
    # Local import: this cell previously used collections.defaultdict without
    # importing it, which raised NameError on the first call.
    from collections import Counter

    for labels in (train_labels, val_labels):
        for label, count in sorted(Counter(labels).items(), key=lambda item: item[0]):
            print(f'{label} : {count}')

SyntaxError: invalid syntax (<ipython-input-1-f5c49c07ce3e>, line 1)

In [None]:
# Build the mask-task datasets and loaders.
# Validation uses `test_transform`, the non-augmenting pipeline defined with the
# augmentations; the name previously used here (`test_transform1`) is not defined
# anywhere in this notebook and failed on a fresh kernel.
if K_FOLD == 1:
    # mask image path
    mask_paths = img_info_df['path'].values.tolist()

    # Author's note: one shared shuffle for mask, gender and age looked possible,
    # but gender and age need splits that take the class ratios into account.
    random.Random(42).shuffle(mask_paths)

    train_size = int(TRAIN_RATIO * len(mask_paths))
    val_size = len(mask_paths) - train_size

    train_mask_paths, val_mask_paths = mask_paths[:train_size], mask_paths[train_size:]

    train_mask_img_paths, train_mask_img_labels = make_mask_img_paths_labels(train_mask_paths, phase='train')

    del train_mask_paths

    val_mask_img_paths, val_mask_img_labels = make_mask_img_paths_labels(val_mask_paths, phase='valid')

    del val_mask_paths

    # Create the mask Dataset & DataLoader
    train_mask_dataset = MaskDataset(train_mask_img_paths, train_mask_img_labels, transform=mask_train_transform, phase='train')
    val_mask_dataset = MaskDataset(val_mask_img_paths, val_mask_img_labels, transform=test_transform, phase='valid')
    train_mask_loader = DataLoader(train_mask_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    val_mask_loader = DataLoader(val_mask_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

elif K_FOLD > 1:
    # mask image path (k-fold)
    # Building every fold's loaders at once may exhaust memory, so they are
    # yielded one fold at a time.
    mask_paths = img_info_df['path'].values.tolist()

    random.Random(42).shuffle(mask_paths)

    mask_train_size = int(TRAIN_RATIO * len(mask_paths))
    mask_val_size = len(mask_paths) - mask_train_size

    def mask_k_fold(k_fold=5):
        """Yield ``(train_loader, val_loader)`` for each of ``k_fold`` folds.

        NOTE: the fold width is ``mask_val_size``, which is derived from
        TRAIN_RATIO rather than from ``k_fold``; the folds tile the data only
        when the two agree (e.g. TRAIN_RATIO = 0.8 with k_fold = 5).
        """
        for i in range(k_fold):
            val_mask_paths = mask_paths[mask_val_size*i:mask_val_size*(i+1)]
            train_mask_paths = mask_paths[:mask_val_size*i] + mask_paths[mask_val_size*(i+1):]

            train_mask_img_paths, train_mask_img_labels = make_mask_img_paths_labels(train_mask_paths, phase='train')

            val_mask_img_paths, val_mask_img_labels = make_mask_img_paths_labels(val_mask_paths, phase='valid')

            # Create the mask Dataset & DataLoader
            train_mask_dataset = MaskDataset(train_mask_img_paths, train_mask_img_labels, transform=mask_train_transform, phase='train')
            val_mask_dataset = MaskDataset(val_mask_img_paths, val_mask_img_labels, transform=test_transform, phase='valid')
            train_mask_loader = DataLoader(train_mask_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
            val_mask_loader = DataLoader(val_mask_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

            yield train_mask_loader, val_mask_loader

In [None]:
# Build the gender-task datasets and loaders.
# Male and female paths are split separately. In the training set the male list
# is repeated 3x and the female list 2x.
# Validation uses `test_transform`, the non-augmenting pipeline defined with the
# augmentations; the name previously used here (`test_transform1`) is not defined
# anywhere in this notebook and failed on a fresh kernel.
if K_FOLD == 1:
    # gender image path
    male_paths = img_info_df[img_info_df['gender'] == 'male']['path'].values.tolist()
    female_paths = img_info_df[img_info_df['gender'] == 'female']['path'].values.tolist()

    random.Random(42).shuffle(male_paths)
    random.Random(42).shuffle(female_paths)

    male_train_size = int(TRAIN_RATIO * len(male_paths))
    female_train_size = int(TRAIN_RATIO * len(female_paths))

    train_male_paths, val_male_paths = male_paths[:male_train_size], male_paths[male_train_size:]
    train_female_paths, val_female_paths = female_paths[:female_train_size], female_paths[female_train_size:]

    train_male_img_paths, train_male_img_labels = make_img_paths_labels(train_male_paths, 0, 3)
    train_female_img_paths, train_female_img_labels = make_img_paths_labels(train_female_paths, 1, 2)

    train_gender_img_paths = train_male_img_paths + train_female_img_paths
    train_gender_img_labels = train_male_img_labels + train_female_img_labels

    del train_male_img_paths
    del train_male_img_labels
    del train_female_img_paths
    del train_female_img_labels

    val_male_img_paths, val_male_img_labels = make_img_paths_labels(val_male_paths, 0, 1)
    val_female_img_paths, val_female_img_labels = make_img_paths_labels(val_female_paths, 1, 1)

    val_gender_img_paths = val_male_img_paths + val_female_img_paths
    val_gender_img_labels = val_male_img_labels + val_female_img_labels

    del val_male_img_paths
    del val_male_img_labels
    del val_female_img_paths
    del val_female_img_labels

    # Create the gender Dataset & DataLoader
    train_gender_dataset = GenderDataset(train_gender_img_paths, train_gender_img_labels, transform=gender_train_transform, phase='train')
    val_gender_dataset = GenderDataset(val_gender_img_paths, val_gender_img_labels, transform=test_transform, phase='valid')
    train_gender_loader = DataLoader(train_gender_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    val_gender_loader = DataLoader(val_gender_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

elif K_FOLD > 1:
    # gender image path (k-fold)
    male_paths = img_info_df[img_info_df['gender'] == 'male']['path'].values.tolist()
    female_paths = img_info_df[img_info_df['gender'] == 'female']['path'].values.tolist()

    random.Random(42).shuffle(male_paths)
    random.Random(42).shuffle(female_paths)

    male_train_size = int(TRAIN_RATIO * len(male_paths))
    male_val_size = len(male_paths) - male_train_size
    female_train_size = int(TRAIN_RATIO * len(female_paths))
    female_val_size = len(female_paths) - female_train_size

    def gender_k_fold(k_fold=5):
        """Yield ``(train_loader, val_loader)`` for each of ``k_fold`` folds.

        NOTE: the fold widths are ``male_val_size`` / ``female_val_size``, which
        are derived from TRAIN_RATIO rather than from ``k_fold``; the folds tile
        the data only when the two agree (e.g. TRAIN_RATIO = 0.8 with k_fold = 5).
        """
        for i in range(k_fold):
            val_male_paths = male_paths[male_val_size*i:male_val_size*(i+1)]
            train_male_paths = male_paths[:male_val_size*i] + male_paths[male_val_size*(i+1):]
            val_female_paths = female_paths[female_val_size*i:female_val_size*(i+1)]
            train_female_paths = female_paths[:female_val_size*i] + female_paths[female_val_size*(i+1):]

            train_male_img_paths, train_male_img_labels = make_img_paths_labels(train_male_paths, 0, 3)
            train_female_img_paths, train_female_img_labels = make_img_paths_labels(train_female_paths, 1, 2)

            train_gender_img_paths = train_male_img_paths + train_female_img_paths
            train_gender_img_labels = train_male_img_labels + train_female_img_labels

            val_male_img_paths, val_male_img_labels = make_img_paths_labels(val_male_paths, 0, 1)
            val_female_img_paths, val_female_img_labels = make_img_paths_labels(val_female_paths, 1, 1)

            val_gender_img_paths = val_male_img_paths + val_female_img_paths
            val_gender_img_labels = val_male_img_labels + val_female_img_labels

            # Create the gender Dataset & DataLoader
            train_gender_dataset = GenderDataset(train_gender_img_paths, train_gender_img_labels, transform=gender_train_transform, phase='train')
            val_gender_dataset = GenderDataset(val_gender_img_paths, val_gender_img_labels, transform=test_transform, phase='valid')
            train_gender_loader = DataLoader(train_gender_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
            val_gender_loader = DataLoader(val_gender_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

            yield train_gender_loader, val_gender_loader

In [17]:
# Build the age-task datasets and loaders.
# The three age categories are split separately. In the training set the
# category-2 list is repeated 6x.
# Validation uses `test_transform`, the non-augmenting pipeline defined with the
# augmentations; the name previously used here (`test_transform1`) is not defined
# anywhere in this notebook and failed on a fresh kernel.
if K_FOLD == 1:
    # age image path
    age0_paths = img_info_df[img_info_df['age'] == 0]['path'].values.tolist()
    age1_paths = img_info_df[img_info_df['age'] == 1]['path'].values.tolist()
    age2_paths = img_info_df[img_info_df['age'] == 2]['path'].values.tolist()

    random.Random(42).shuffle(age0_paths)
    random.Random(42).shuffle(age1_paths)
    random.Random(42).shuffle(age2_paths)

    age0_train_size = int(TRAIN_RATIO * len(age0_paths))
    age0_val_size = len(age0_paths) - age0_train_size
    age1_train_size = int(TRAIN_RATIO * len(age1_paths))
    age1_val_size = len(age1_paths) - age1_train_size
    age2_train_size = int(TRAIN_RATIO * len(age2_paths))
    age2_val_size = len(age2_paths) - age2_train_size

    train_age0_paths, val_age0_paths = age0_paths[:age0_train_size], age0_paths[age0_train_size:]
    train_age1_paths, val_age1_paths = age1_paths[:age1_train_size], age1_paths[age1_train_size:]
    train_age2_paths, val_age2_paths = age2_paths[:age2_train_size], age2_paths[age2_train_size:]

    train_age0_img_paths, train_age0_img_labels = make_img_paths_labels(train_age0_paths, 0, 1)
    train_age1_img_paths, train_age1_img_labels = make_img_paths_labels(train_age1_paths, 1, 1)
    train_age2_img_paths, train_age2_img_labels = make_img_paths_labels(train_age2_paths, 2, 6)

    train_age_img_paths = train_age0_img_paths + train_age1_img_paths + train_age2_img_paths
    train_age_img_labels = train_age0_img_labels + train_age1_img_labels + train_age2_img_labels

    del train_age0_img_paths
    del train_age0_img_labels
    del train_age1_img_paths
    del train_age1_img_labels
    del train_age2_img_paths
    del train_age2_img_labels

    val_age0_img_paths, val_age0_img_labels = make_img_paths_labels(val_age0_paths, 0, 1)
    val_age1_img_paths, val_age1_img_labels = make_img_paths_labels(val_age1_paths, 1, 1)
    val_age2_img_paths, val_age2_img_labels = make_img_paths_labels(val_age2_paths, 2, 1)

    val_age_img_paths = val_age0_img_paths + val_age1_img_paths + val_age2_img_paths
    val_age_img_labels = val_age0_img_labels + val_age1_img_labels + val_age2_img_labels

    del val_age0_img_paths
    del val_age0_img_labels
    del val_age1_img_paths
    del val_age1_img_labels
    del val_age2_img_paths
    del val_age2_img_labels

    # Create the age Dataset & DataLoader
    train_age_dataset = AgeDataset(train_age_img_paths, train_age_img_labels, transform=age_train_transform, phase='train')
    val_age_dataset = AgeDataset(val_age_img_paths, val_age_img_labels, transform=test_transform, phase='valid')
    train_age_loader = DataLoader(train_age_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    val_age_loader = DataLoader(val_age_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

elif K_FOLD > 1:
    # age image path (k-fold)
    age0_paths = img_info_df[img_info_df['age'] == 0]['path'].values.tolist()
    age1_paths = img_info_df[img_info_df['age'] == 1]['path'].values.tolist()
    age2_paths = img_info_df[img_info_df['age'] == 2]['path'].values.tolist()

    random.Random(42).shuffle(age0_paths)
    random.Random(42).shuffle(age1_paths)
    random.Random(42).shuffle(age2_paths)

    age0_train_size = int(TRAIN_RATIO * len(age0_paths))
    age0_val_size = len(age0_paths) - age0_train_size
    age1_train_size = int(TRAIN_RATIO * len(age1_paths))
    age1_val_size = len(age1_paths) - age1_train_size
    age2_train_size = int(TRAIN_RATIO * len(age2_paths))
    age2_val_size = len(age2_paths) - age2_train_size

    # The cross-validation split is made here.
    def age_k_fold(k_fold=5):
        """Yield ``(train_loader, val_loader)`` for each of ``k_fold`` folds.

        NOTE: the fold widths are the ``age*_val_size`` values, which are derived
        from TRAIN_RATIO rather than from ``k_fold``; the folds tile the data only
        when the two agree (e.g. TRAIN_RATIO = 0.8 with k_fold = 5).
        """
        for i in range(k_fold):
            val_age0_paths = age0_paths[age0_val_size*i:age0_val_size*(i+1)]
            train_age0_paths = age0_paths[:age0_val_size*i] + age0_paths[age0_val_size*(i+1):]

            val_age1_paths = age1_paths[age1_val_size*i:age1_val_size*(i+1)]
            train_age1_paths = age1_paths[:age1_val_size*i] + age1_paths[age1_val_size*(i+1):]

            val_age2_paths = age2_paths[age2_val_size*i:age2_val_size*(i+1)]
            train_age2_paths = age2_paths[:age2_val_size*i] + age2_paths[age2_val_size*(i+1):]

            train_age0_img_paths, train_age0_img_labels = make_img_paths_labels(train_age0_paths, 0, 1)
            train_age1_img_paths, train_age1_img_labels = make_img_paths_labels(train_age1_paths, 1, 1)
            train_age2_img_paths, train_age2_img_labels = make_img_paths_labels(train_age2_paths, 2, 6)

            train_age_img_paths = train_age0_img_paths + train_age1_img_paths + train_age2_img_paths
            train_age_img_labels = train_age0_img_labels + train_age1_img_labels + train_age2_img_labels

            val_age0_img_paths, val_age0_img_labels = make_img_paths_labels(val_age0_paths, 0, 1)
            val_age1_img_paths, val_age1_img_labels = make_img_paths_labels(val_age1_paths, 1, 1)
            val_age2_img_paths, val_age2_img_labels = make_img_paths_labels(val_age2_paths, 2, 1)

            val_age_img_paths = val_age0_img_paths + val_age1_img_paths + val_age2_img_paths
            val_age_img_labels = val_age0_img_labels + val_age1_img_labels + val_age2_img_labels

            # Create the age Dataset & DataLoader
            train_age_dataset = AgeDataset(train_age_img_paths, train_age_img_labels, transform=age_train_transform, phase='train')
            val_age_dataset = AgeDataset(val_age_img_paths, val_age_img_labels, transform=test_transform, phase='valid')

            train_age_loader = DataLoader(train_age_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
            val_age_loader = DataLoader(val_age_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

            yield train_age_loader, val_age_loader

In [18]:
def plot_dataset_images(dataset, n_rows=2, n_cols=3):
    """Show randomly drawn samples from ``dataset`` in an ``n_rows`` x ``n_cols`` grid.

    Args:
        dataset: Indexable, sized collection whose items are ``(img, label)``.
            ``img`` must support ``permute(1, 2, 0)`` — presumably a
            channel-first tensor; confirm against the dataset's transform.
        n_rows: Number of grid rows (default 2).
        n_cols: Number of grid columns (default 3).
    """
    # squeeze=False keeps `axes` two-dimensional even for a single row or column.
    fig, axes = plt.subplots(n_rows, n_cols, sharex=True, sharey=True,
                             figsize=(15, 15), squeeze=False)

    for i in range(n_rows * n_cols):
        rand_idx = random.randint(0, len(dataset) - 1)
        img, label = dataset[rand_idx]

        # Row-major placement. The row is i // n_cols; the earlier
        # i // (n_rows + 1) was only right because n_rows + 1 == n_cols for 2x3.
        ax = axes[i // n_cols][i % n_cols]
        ax.imshow(img.permute(1, 2, 0))
        ax.set_title(f'{label}')

    plt.tight_layout()

In [None]:
# train_mask_dataset is only created at module level when K_FOLD == 1; in the
# k-fold configuration the datasets live inside the fold generator.
if K_FOLD == 1:
    plot_dataset_images(train_mask_dataset)

In [19]:
class ResNet50(nn.Module):
    """torchvision ResNet-50 whose ``fc`` layer is replaced by a two-layer head.

    Args:
        num_classes: Output size of the final linear layer.
        pretrained: Forwarded to ``models.resnet50``.
        freeze: When true, every backbone parameter existing before the head
            swap gets ``requires_grad = False``; the new head stays trainable.
    """

    def __init__(self, num_classes=3, pretrained=True, freeze=False):
        super().__init__()
        backbone = models.resnet50(pretrained=pretrained)
        self.resnet50 = backbone

        # Freeze before replacing `fc`, so only the original weights are frozen.
        if freeze:
            for weight in backbone.parameters():
                weight.requires_grad = False

        self.n_features = backbone.fc.in_features

        # Head layout: Linear -> ReLU -> Dropout -> Linear (indices 0..3).
        backbone.fc = nn.Sequential(
            nn.Linear(self.n_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(512, num_classes),
        )

        self._init_weight_()

    def _init_weight_(self):
        """Xavier-uniform initialise the weights of both linear head layers."""
        head = self.resnet50.fc
        for linear_idx in (0, 3):
            nn.init.xavier_uniform_(head[linear_idx].weight)

    def forward(self, x):
        """Run ``x`` through the backbone and return the head's output."""
        return self.resnet50(x)
    
    
class EfficientNetB0(nn.Module):
    """timm ``efficientnet_b0`` whose ``classifier`` is replaced by a two-layer head.

    Args:
        num_classes: Output size of the final linear layer.
        pretrained: Forwarded to ``timm.create_model``.
        freeze: When true, every backbone parameter existing before the head
            swap gets ``requires_grad = False``; the new head stays trainable.
    """

    def __init__(self, num_classes=3, pretrained=True, freeze=False):
        super().__init__()
        backbone = timm.create_model('efficientnet_b0', pretrained=pretrained)
        self.efficientnet_b0 = backbone

        # Freeze before replacing `classifier`, so only the original weights are frozen.
        if freeze:
            for weight in backbone.parameters():
                weight.requires_grad = False

        self.n_features = backbone.classifier.in_features

        # Head layout: Linear -> ReLU -> Dropout -> Linear (indices 0..3).
        backbone.classifier = nn.Sequential(
            nn.Linear(self.n_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(512, num_classes),
        )

        self._init_weight_()

    def _init_weight_(self):
        """Xavier-uniform initialise the weights of both linear head layers."""
        head = self.efficientnet_b0.classifier
        for linear_idx in (0, 3):
            nn.init.xavier_uniform_(head[linear_idx].weight)

    def forward(self, x):
        """Run ``x`` through the backbone and return the head's output."""
        return self.efficientnet_b0(x)
    

class EfficientNetB1(nn.Module):
    """timm ``efficientnet_b1`` whose ``classifier`` is replaced by a two-layer head.

    Args:
        num_classes: Output size of the final linear layer.
        pretrained: Forwarded to ``timm.create_model``.
        freeze: When true, every backbone parameter existing before the head
            swap gets ``requires_grad = False``; the new head stays trainable.
    """

    def __init__(self, num_classes=3, pretrained=True, freeze=False):
        super().__init__()
        backbone = timm.create_model('efficientnet_b1', pretrained=pretrained)
        self.efficientnet_b1 = backbone

        # Freeze before replacing `classifier`, so only the original weights are frozen.
        if freeze:
            for weight in backbone.parameters():
                weight.requires_grad = False

        self.n_features = backbone.classifier.in_features

        # Head layout: Linear -> ReLU -> Dropout -> Linear (indices 0..3).
        backbone.classifier = nn.Sequential(
            nn.Linear(self.n_features, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(512, num_classes),
        )

        self._init_weight_()

    def _init_weight_(self):
        """Xavier-uniform initialise the weights of both linear head layers."""
        head = self.efficientnet_b1.classifier
        for linear_idx in (0, 3):
            nn.init.xavier_uniform_(head[linear_idx].weight)

    def forward(self, x):
        """Run ``x`` through the backbone and return the head's output."""
        return self.efficientnet_b1(x)

In [20]:
class AverageMeter(object):
    """Tracks the latest value, weighted sum, count and running mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [22]:
class FocalLoss(nn.Module):
    """Focal loss built on ``F.nll_loss`` over ``(1 - p) ** gamma * log p``.

    Args:
        weight: Optional per-class weight tensor forwarded to ``F.nll_loss``.
        gamma: Exponent applied to ``1 - p``.
        reduction: Reduction mode forwarded to ``F.nll_loss``.
    """

    def __init__(self, weight=None,
                 gamma=2., reduction='mean'):
        super().__init__()
        self.weight = weight
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, input_tensor, target_tensor):
        """Return the focal loss of logits ``input_tensor`` against ``target_tensor``."""
        log_p = F.log_softmax(input_tensor, dim=-1)
        p = torch.exp(log_p)
        modulated = ((1 - p) ** self.gamma) * log_p

        return F.nll_loss(
            modulated,
            target_tensor,
            weight=self.weight,
            reduction=self.reduction,
        )
    
    
# Label Smoothing Loss
class LabelSmoothingLoss(nn.Module):
    """Cross-entropy against a smoothed target distribution.

    The target class gets probability ``1 - smoothing``; each of the other
    ``classes - 1`` entries gets ``smoothing / (classes - 1)``.

    Args:
        classes: Number of classes used to spread the smoothing mass.
        smoothing: Probability mass moved away from the target class.
        dim: Dimension for the log-softmax and the per-sample sum.
    """

    def __init__(self, classes=3, smoothing=0.1, dim=-1):
        super(LabelSmoothingLoss, self).__init__()
        self.cls = classes
        self.dim = dim
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, pred, target):
        """Return the mean smoothed cross-entropy of logits ``pred`` vs. ``target``."""
        log_probs = pred.log_softmax(dim=self.dim)

        # The target distribution is a constant; keep it out of the autograd graph.
        with torch.no_grad():
            off_value = self.smoothing / (self.cls - 1)
            true_dist = torch.zeros_like(log_probs).fill_(off_value)
            true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)

        per_sample = torch.sum(-true_dist * log_probs, dim=self.dim)
        return torch.mean(per_sample)

In [23]:
# Early Stopping : 기준 : validation f1 score
class EarlyStopping():
    """Signals when the validation F1 score has stopped improving.

    Args:
        patience: number of consecutive worse epochs tolerated before stopping.
        verbose: when truthy, print a message at the moment training is stopped.
    """

    def __init__(self, patience=0, verbose=0):
        self.verbose = verbose
        self.patience = patience
        self._val_f1 = 0.0
        self._step = 0

    def validate(self, val_f1, epoch):
        """Register this epoch's validation F1 and return ``True`` if training should stop."""
        got_worse = self._val_f1 > val_f1
        if not got_worse:
            # Equal or better score: remember it and restart the patience counter.
            self._step = 0
            self._val_f1 = val_f1
            return False

        self._step += 1
        if self._step > self.patience:
            if self.verbose:
                print(f'Training process is stopped early at Epoch {epoch}!')
            return True

        return False

# mask 모델 학습

In [None]:
def _run_mask_epoch(model, loader, criterion, optimizer=None, scaler=None, desc=None):
    """Run one pass over ``loader`` for the mask classifier and collect its metrics.

    With an ``optimizer`` the model is switched to train mode and updated with
    mixed precision (``autocast`` + ``GradScaler``). Without one the model is
    switched to eval mode and only evaluated, with gradients disabled.

    Args:
        model: mask classifier returning one logit per class.
        loader: DataLoader yielding ``(img, label)`` batches.
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimizer to step; ``None`` selects validation mode.
        scaler: ``GradScaler`` used for the training step (needed with ``optimizer``).
        desc: caption of the progress bar (training mode only).

    Returns:
        ``(loss_meter, acc_meter, f1_meter)`` as ``AverageMeter`` objects; the F1
        meter averages the per-batch macro F1, weighted by batch size.
    """
    is_training = optimizer is not None
    # Fix: set the mode explicitly so validation does not run with training-time
    # dropout / batch-norm behaviour.
    model.train(is_training)

    loss_meter, acc_meter, f1_meter = AverageMeter(), AverageMeter(), AverageMeter()
    batches = tqdm(loader, total=len(loader), desc=desc) if is_training else loader

    with torch.set_grad_enabled(is_training):
        for img, label in batches:
            img = img.to(device)
            mask = label.to(device)

            if is_training:
                # Fix: clear the previous batch's gradients; without this they accumulate.
                optimizer.zero_grad()
                with autocast():
                    mask_outputs = model(img)
                    mask_loss = criterion(mask_outputs, mask)

                scaler.scale(mask_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                mask_outputs = model(img)
                mask_loss = criterion(mask_outputs, mask)

            # Accuracy
            pred_mask = torch.argmax(mask_outputs, 1)
            mask_acc = (pred_mask == mask).sum().item() / len(img)

            loss_meter.update(mask_loss.item(), len(img))
            acc_meter.update(mask_acc, len(img))
            f1_meter.update(f1_score(mask.cpu(), pred_mask.cpu(), average='macro'), len(img))

    return loss_meter, acc_meter, f1_meter


if K_FOLD == 1:
    # Train the mask classifier on a single train/validation split (no K-fold).

    # Model (EfficientNetB0 / EfficientNetB1 accept the same arguments)
    mask_model = ResNet50(num_classes=3, pretrained=True).to(device)

    # Optimizer: only parameters that require gradients are updated
    mask_optimizer = optim.Adam(filter(lambda p: p.requires_grad, mask_model.parameters()), lr=LEARNING_RATE['mask'])

    # Scheduler
    # NOTE(review): this scheduler is never stepped, so the learning rate stays at
    # LEARNING_RATE['mask']. Its eta_min (0.05) is larger than that base rate, so it
    # must be re-tuned before a .step() call is added.
    mask_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(mask_optimizer, T_0=10, eta_min=0.05, last_epoch=-1)

    # Loss function
    mask_criterion = nn.CrossEntropyLoss()

    best_val_mask_f1 = 0.0
    best_mask_model = None

    mask_early_stopping = EarlyStopping(patience=4, verbose=1)
    scaler = GradScaler()

    for epoch in range(EPOCHS['mask']):
        train_mask_loss, train_mask_acc, train_mask_f1 = _run_mask_epoch(
            mask_model, train_mask_loader, mask_criterion,
            optimizer=mask_optimizer, scaler=scaler,
            desc=f"Epoch {epoch+1}/{EPOCHS['mask']}")

        # validation
        val_mask_loss, val_mask_acc, val_mask_f1 = _run_mask_epoch(mask_model, val_mask_loader, mask_criterion)

        print(f'train_mask_loss: {train_mask_loss.avg:.4f}, train_mask_acc: {train_mask_acc.avg:.4f}, train_mask_f1: {train_mask_f1.avg:.4f}')
        print(f'val_mask_loss: {val_mask_loss.avg:.4f}, val_mask_acc: {val_mask_acc.avg:.4f}, val_mask_f1 : {val_mask_f1.avg:.4f}')

        # Keep a copy of the model with the best validation F1 so far
        cur_val_mask_f1 = val_mask_f1.avg
        if cur_val_mask_f1 > best_val_mask_f1:
            best_mask_model = copy.deepcopy(mask_model)

            best_val_mask_f1 = cur_val_mask_f1
            print(f'mask model saved! at Epoch {epoch+1}')

        if mask_early_stopping.validate(val_mask_f1.avg, epoch):
            break

        print('=' * 100)

elif K_FOLD > 1:
    # Train one mask classifier per fold (K-fold); best_mask_model[i] holds fold i's best model.
    best_mask_model = [None] * K_FOLD

    for i, (train_mask_loader, val_mask_loader) in enumerate(mask_k_fold(k_fold=K_FOLD)):
        best_val_mask_f1 = 0.0
        print(f'< {i+1}번째 Fold 진행 중... >')
        print('=' * 100)

        # Model (EfficientNetB1 accepts the same arguments)
        mask_model = EfficientNetB0(num_classes=3, pretrained=True).to(device)

        # Optimizer: only parameters that require gradients are updated
        mask_optimizer = optim.Adam(filter(lambda p: p.requires_grad, mask_model.parameters()), lr=LEARNING_RATE['mask'])

        # Scheduler
        # NOTE(review): never stepped, and eta_min (0.05) exceeds the base learning
        # rate; re-tune before adding a .step() call.
        mask_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(mask_optimizer, T_0=10, eta_min=0.05, last_epoch=-1)

        # Loss function
        mask_criterion = nn.CrossEntropyLoss()

        mask_early_stopping = EarlyStopping(patience=4, verbose=1)
        scaler = GradScaler()

        for epoch in range(EPOCHS['mask']):
            train_mask_loss, train_mask_acc, train_mask_f1 = _run_mask_epoch(
                mask_model, train_mask_loader, mask_criterion,
                optimizer=mask_optimizer, scaler=scaler,
                desc=f"Epoch {epoch+1}/{EPOCHS['mask']}")

            # validation
            val_mask_loss, val_mask_acc, val_mask_f1 = _run_mask_epoch(mask_model, val_mask_loader, mask_criterion)

            print(f'train_mask_loss: {train_mask_loss.avg:.4f}, train_mask_acc: {train_mask_acc.avg:.4f}, train_mask_f1: {train_mask_f1.avg:.4f}')
            print(f'val_mask_loss: {val_mask_loss.avg:.4f}, val_mask_acc: {val_mask_acc.avg:.4f}, val_mask_f1 : {val_mask_f1.avg:.4f}')

            # Keep a copy of this fold's model with the best validation F1 so far
            cur_val_mask_f1 = val_mask_f1.avg
            if cur_val_mask_f1 > best_val_mask_f1:
                best_mask_model[i] = copy.deepcopy(mask_model)

                best_val_mask_f1 = cur_val_mask_f1
                print(f'mask model saved! at Epoch {epoch+1}')

            if mask_early_stopping.validate(val_mask_f1.avg, epoch):
                break

            print('=' * 100)

# Gender 모델 학습

In [None]:
def _run_gender_epoch(model, loader, criterion, optimizer=None, scaler=None, desc=None):
    """Run one pass over ``loader`` for the binary gender classifier and collect its metrics.

    With an ``optimizer`` the model is switched to train mode and updated with
    mixed precision (``autocast`` + ``GradScaler``). Without one the model is
    switched to eval mode and only evaluated, with gradients disabled.

    Args:
        model: gender classifier returning a single logit per image.
        loader: DataLoader yielding ``(img, label)`` batches.
        criterion: loss called as ``criterion(logits, float_labels)``.
        optimizer: optimizer to step; ``None`` selects validation mode.
        scaler: ``GradScaler`` used for the training step (needed with ``optimizer``).
        desc: caption of the progress bar (training mode only).

    Returns:
        ``(loss_meter, acc_meter, f1_meter)`` as ``AverageMeter`` objects; the F1
        meter averages the per-batch macro F1, weighted by batch size.
    """
    is_training = optimizer is not None
    # Fix: set the mode explicitly so validation does not run with training-time
    # dropout / batch-norm behaviour.
    model.train(is_training)

    loss_meter, acc_meter, f1_meter = AverageMeter(), AverageMeter(), AverageMeter()
    batches = tqdm(loader, total=len(loader), desc=desc) if is_training else loader

    with torch.set_grad_enabled(is_training):
        for img, label in batches:
            img = img.to(device)
            gender = label.float().to(device)

            # Fix: flatten with reshape(-1) instead of a bare squeeze(); squeeze() turns a
            # batch of one into a 0-d tensor, which no longer matches the label shape.
            if is_training:
                # Fix: clear the previous batch's gradients; without this they accumulate.
                optimizer.zero_grad()
                with autocast():
                    gender_outputs = model(img)
                    gender_loss = criterion(gender_outputs.reshape(-1), gender)

                scaler.scale(gender_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                gender_outputs = model(img)
                gender_loss = criterion(gender_outputs.reshape(-1), gender)

            # Accuracy: threshold the sigmoid probability at 0.5
            pred_gender = torch.round(torch.sigmoid(gender_outputs)).reshape(-1)
            gender_acc = (pred_gender == gender).sum().item() / len(img)

            loss_meter.update(gender_loss.item(), len(img))
            acc_meter.update(gender_acc, len(img))
            f1_meter.update(
                f1_score(gender.detach().cpu().numpy(), pred_gender.detach().cpu().numpy(), average='macro'),
                len(img))

    return loss_meter, acc_meter, f1_meter


if K_FOLD == 1:
    # Train the gender classifier on a single train/validation split (no K-fold).

    # Model (EfficientNetB0 / EfficientNetB1 accept the same arguments)
    gender_model = ResNet50(num_classes=1, pretrained=True).to(device)

    # Optimizer: only parameters that require gradients are updated
    gender_optimizer = optim.Adam(filter(lambda p: p.requires_grad, gender_model.parameters()), lr=LEARNING_RATE['gender'])

    # Scheduler
    # NOTE(review): this scheduler is never stepped, so the learning rate stays at
    # LEARNING_RATE['gender']. Its eta_min (0.05) is larger than that base rate, so it
    # must be re-tuned before a .step() call is added.
    gender_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(gender_optimizer, T_0=10, eta_min=0.05, last_epoch=-1)

    # Loss function
    gender_criterion = nn.BCEWithLogitsLoss() # sigmoid + BCELoss

    best_val_gender_f1 = 0.0
    best_gender_model = None

    gender_early_stopping = EarlyStopping(patience=4, verbose=1)
    scaler = GradScaler()

    for epoch in range(EPOCHS['gender']):
        train_gender_loss, train_gender_acc, train_gender_f1 = _run_gender_epoch(
            gender_model, train_gender_loader, gender_criterion,
            optimizer=gender_optimizer, scaler=scaler,
            desc=f"Epoch {epoch+1}/{EPOCHS['gender']}")

        # validation
        val_gender_loss, val_gender_acc, val_gender_f1 = _run_gender_epoch(gender_model, val_gender_loader, gender_criterion)

        print(f'train_gender_loss: {train_gender_loss.avg:.4f}, train_gender_acc: {train_gender_acc.avg:.4f}, train_gender_f1: {train_gender_f1.avg:.4f}')
        print(f'val_gender_loss: {val_gender_loss.avg:.4f}, val_gender_acc: {val_gender_acc.avg:.4f}, val_gender_f1 : {val_gender_f1.avg:.4f}')

        # Keep a copy of the model with the best validation F1 so far
        cur_val_gender_f1 = val_gender_f1.avg
        if cur_val_gender_f1 > best_val_gender_f1:
            best_gender_model = copy.deepcopy(gender_model)

            best_val_gender_f1 = cur_val_gender_f1
            print(f'gender model saved! at Epoch {epoch+1}')

        if gender_early_stopping.validate(val_gender_f1.avg, epoch):
            break

        print('=' * 100)

elif K_FOLD > 1:
    # Train one gender classifier per fold (K-fold); best_gender_model[i] holds fold i's best model.
    best_gender_model = [None] * K_FOLD

    for i, (train_gender_loader, val_gender_loader) in enumerate(gender_k_fold(k_fold=K_FOLD)):
        best_val_gender_f1 = 0.0
        print(f'< {i+1}번째 Fold 진행 중... >')
        print('=' * 100)

        # Model (EfficientNetB1 accepts the same arguments)
        gender_model = EfficientNetB0(num_classes=1, pretrained=True).to(device)

        # Optimizer: only parameters that require gradients are updated
        gender_optimizer = optim.Adam(filter(lambda p: p.requires_grad, gender_model.parameters()), lr=LEARNING_RATE['gender'])

        # Scheduler
        # NOTE(review): never stepped, and eta_min (0.05) exceeds the base learning
        # rate; re-tune before adding a .step() call.
        gender_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(gender_optimizer, T_0=10, eta_min=0.05, last_epoch=-1)

        # Loss function
        gender_criterion = nn.BCEWithLogitsLoss() # sigmoid + BCELoss

        gender_early_stopping = EarlyStopping(patience=4, verbose=1)
        scaler = GradScaler()

        for epoch in range(EPOCHS['gender']):
            train_gender_loss, train_gender_acc, train_gender_f1 = _run_gender_epoch(
                gender_model, train_gender_loader, gender_criterion,
                optimizer=gender_optimizer, scaler=scaler,
                desc=f"Epoch {epoch+1}/{EPOCHS['gender']}")

            # validation
            val_gender_loss, val_gender_acc, val_gender_f1 = _run_gender_epoch(gender_model, val_gender_loader, gender_criterion)

            print(f'train_gender_loss: {train_gender_loss.avg:.4f}, train_gender_acc: {train_gender_acc.avg:.4f}, train_gender_f1: {train_gender_f1.avg:.4f}')
            print(f'val_gender_loss: {val_gender_loss.avg:.4f}, val_gender_acc: {val_gender_acc.avg:.4f}, val_gender_f1 : {val_gender_f1.avg:.4f}')

            # Keep a copy of this fold's model with the best validation F1 so far
            cur_val_gender_f1 = val_gender_f1.avg
            if cur_val_gender_f1 > best_val_gender_f1:
                best_gender_model[i] = copy.deepcopy(gender_model)

                best_val_gender_f1 = cur_val_gender_f1
                print(f'gender model saved! at Epoch {epoch+1}')

            if gender_early_stopping.validate(val_gender_f1.avg, epoch):
                break

            print('=' * 100)

# age 모델 학습

In [None]:
def _run_age_epoch(model, loader, criterion, optimizer=None, scaler=None, desc=None):
    """Run one pass over ``loader`` for the age classifier and collect its metrics.

    With an ``optimizer`` the model is switched to train mode and updated with
    mixed precision (``autocast`` + ``GradScaler``). Without one the model is
    switched to eval mode and only evaluated, with gradients disabled.

    Args:
        model: age classifier returning one logit per class.
        loader: DataLoader yielding ``(img, label)`` batches.
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimizer to step; ``None`` selects validation mode.
        scaler: ``GradScaler`` used for the training step (needed with ``optimizer``).
        desc: caption of the progress bar (training mode only).

    Returns:
        ``(loss_meter, acc_meter, f1_meter)`` as ``AverageMeter`` objects; the F1
        meter averages the per-batch macro F1, weighted by batch size.
    """
    is_training = optimizer is not None
    # Fix: set the mode explicitly so validation does not run with training-time
    # dropout / batch-norm behaviour.
    model.train(is_training)

    loss_meter, acc_meter, f1_meter = AverageMeter(), AverageMeter(), AverageMeter()
    batches = tqdm(loader, total=len(loader), desc=desc) if is_training else loader

    with torch.set_grad_enabled(is_training):
        for img, label in batches:
            img = img.to(device)
            age = label.to(device)

            if is_training:
                # Fix: clear the previous batch's gradients; without this they accumulate.
                optimizer.zero_grad()
                with autocast():
                    age_outputs = model(img)
                    age_loss = criterion(age_outputs, age)

                scaler.scale(age_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                age_outputs = model(img)
                age_loss = criterion(age_outputs, age)

            # Accuracy
            pred_age = torch.argmax(age_outputs, 1)
            age_acc = (pred_age == age).sum().item() / len(img)

            loss_meter.update(age_loss.item(), len(img))  # age loss update
            acc_meter.update(age_acc, len(img))  # age accuracy update
            f1_meter.update(f1_score(age.cpu(), pred_age.cpu(), average='macro'), len(img))

    return loss_meter, acc_meter, f1_meter


if K_FOLD == 1:
    # Train the age classifier on a single train/validation split (no K-fold).
    best_val_age_f1 = 0.0
    best_age_model = None

    # Model (EfficientNetB0 / EfficientNetB1 accept the same arguments)
    age_model = ResNet50(num_classes=3, pretrained=True).to(device)

    # Optimizer: only parameters that require gradients are updated
    age_optimizer = optim.Adam(filter(lambda p: p.requires_grad, age_model.parameters()), lr=LEARNING_RATE['age'])

    # Scheduler
    # NOTE(review): this scheduler is never stepped, so the learning rate stays at
    # LEARNING_RATE['age']. Its eta_min (0.05) is larger than that base rate, so it
    # must be re-tuned before a .step() call is added.
    age_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(age_optimizer, T_0=10, eta_min=0.05, last_epoch=-1)

    # Loss function
    age_criterion = nn.CrossEntropyLoss()

    age_early_stopping = EarlyStopping(patience=4, verbose=1)

    scaler = GradScaler()
    for epoch in range(EPOCHS['age']):
        train_age_loss, train_age_acc, train_age_f1 = _run_age_epoch(
            age_model, train_age_loader, age_criterion,
            optimizer=age_optimizer, scaler=scaler,
            desc=f"Epoch {epoch+1}/{EPOCHS['age']}")

        # validation
        val_age_loss, val_age_acc, val_age_f1 = _run_age_epoch(age_model, val_age_loader, age_criterion)

        print(f'train_age_loss: {train_age_loss.avg:.4f}, train_age_acc: {train_age_acc.avg:.4f}, train_age_f1: {train_age_f1.avg:.4f}')
        print(f'val_age_loss: {val_age_loss.avg:.4f}, val_age_acc: {val_age_acc.avg:.4f}, val_age_f1 : {val_age_f1.avg:.4f}')

        # Keep a copy of the model with the best validation F1 so far
        cur_val_age_f1 = val_age_f1.avg
        if cur_val_age_f1 > best_val_age_f1:
            best_age_model = copy.deepcopy(age_model)

            best_val_age_f1 = cur_val_age_f1
            print(f'age model saved! at Epoch {epoch+1}')

        if age_early_stopping.validate(val_age_f1.avg, epoch):
            break

        print('=' * 100)

elif K_FOLD > 1:
    # Train one age classifier per fold (K-fold); best_age_model[i] holds fold i's best model.
    best_age_model = [None] * K_FOLD

    for i, (train_age_loader, val_age_loader) in enumerate(age_k_fold(k_fold=K_FOLD)):
        best_val_age_f1 = 0.0
        print(f'< {i+1}번째 Fold 진행 중... >')
        print('=' * 100)

        # Model (EfficientNetB0 / EfficientNetB1 accept the same arguments)
        age_model = ResNet50(num_classes=3, pretrained=True).to(device)

        # Optimizer: only parameters that require gradients are updated
        age_optimizer = optim.Adam(filter(lambda p: p.requires_grad, age_model.parameters()), lr=LEARNING_RATE['age'])

        # Scheduler
        # NOTE(review): never stepped, and eta_min (0.05) exceeds the base learning
        # rate; re-tune before adding a .step() call.
        age_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(age_optimizer, T_0=10, eta_min=0.05, last_epoch=-1)

        # Loss function
        age_criterion = nn.CrossEntropyLoss()

        age_early_stopping = EarlyStopping(patience=4, verbose=1)
        scaler = GradScaler()

        for epoch in range(EPOCHS['age']):
            train_age_loss, train_age_acc, train_age_f1 = _run_age_epoch(
                age_model, train_age_loader, age_criterion,
                optimizer=age_optimizer, scaler=scaler,
                desc=f"Epoch {epoch+1}/{EPOCHS['age']}")

            # validation
            val_age_loss, val_age_acc, val_age_f1 = _run_age_epoch(age_model, val_age_loader, age_criterion)

            print(f'train_age_loss: {train_age_loss.avg:.4f}, train_age_acc: {train_age_acc.avg:.4f}, train_age_f1: {train_age_f1.avg:.4f}')
            print(f'val_age_loss: {val_age_loss.avg:.4f}, val_age_acc: {val_age_acc.avg:.4f}, val_age_f1 : {val_age_f1.avg:.4f}')

            # Keep a copy of this fold's model with the best validation F1 so far
            cur_val_age_f1 = val_age_f1.avg
            if cur_val_age_f1 > best_val_age_f1:
                best_age_model[i] = copy.deepcopy(age_model)

                best_val_age_f1 = cur_val_age_f1
                print(f'age model saved! at Epoch {epoch+1}')

            if age_early_stopping.validate(val_age_f1.avg, epoch):
                break

            print('=' * 100)

# 각 best model 저장

In [29]:
# Submission file name pattern: submission_{month}_{day}_{submission index}.
# The index is the number of entries currently present in ./submissions.
existing_submissions = os.listdir('./submissions')
n_submissions = len(existing_submissions)

SUBMISSION_PATH = 'submission_4_8_' + str(n_submissions)

In [31]:
# Checkpoint path for each of the three models
MASK_SAVE_PATH = f'./model_save/mask/mask_{SUBMISSION_PATH}.pth'
GENDER_SAVE_PATH = f'./model_save/gender/gender_{SUBMISSION_PATH}.pth'
AGE_SAVE_PATH = f'./model_save/age/age_{SUBMISSION_PATH}.pth'


def _fold_state_dicts(fold_models, prefix):
    """Build the K-fold checkpoint dict for a list of per-fold models.

    Args:
        fold_models: sequence of models, one per fold.
        prefix: key prefix, e.g. ``'best_mask_model'``.

    Returns:
        Dict mapping ``'{prefix}_{k}_state_dict'`` (``k`` counted from 1) to the
        ``state_dict()`` of the k-th model.
    """
    return {
        f'{prefix}_{fold_number}_state_dict': fold_model.state_dict()
        for fold_number, fold_model in enumerate(fold_models, start=1)
    }


if K_FOLD == 1:
    # Single-model run: each checkpoint is a plain state dict.
    torch.save(best_mask_model.state_dict(), MASK_SAVE_PATH)
    torch.save(best_gender_model.state_dict(), GENDER_SAVE_PATH)
    torch.save(best_age_model.state_dict(), AGE_SAVE_PATH)
elif K_FOLD > 1:
    # Fix: save however many fold models exist instead of hard-coding exactly five,
    # which raised IndexError for K_FOLD < 5 and dropped folds for K_FOLD > 5.
    # The keys are unchanged for K_FOLD == 5.
    torch.save(_fold_state_dicts(best_mask_model, 'best_mask_model'), MASK_SAVE_PATH)

    torch.save(_fold_state_dicts(best_gender_model, 'best_gender_model'), GENDER_SAVE_PATH)

    torch.save(_fold_state_dicts(best_age_model, 'best_age_model'), AGE_SAVE_PATH)

# 저장한 모델 load

In [None]:
# Load the mask model.
# map_location lets a checkpoint saved on GPU be restored on whatever `device` is in use.
MASK_LOAD_PATH = './model_save/mask/mask_submission_4_6_7.pth'
best_mask_model = EfficientNetB0(num_classes=3, pretrained=True).to(device)
best_mask_model.load_state_dict(torch.load(MASK_LOAD_PATH, map_location=device))

# Load the gender model
GENDER_LOAD_PATH = './model_save/gender/gender_submission_4_6_7.pth'
best_gender_model = EfficientNetB0(num_classes=1, pretrained=True).to(device)
best_gender_model.load_state_dict(torch.load(GENDER_LOAD_PATH, map_location=device))

# Load the age model.
# Fix: bind it to best_age_model. It was previously assigned to best_gender_model,
# which replaced the gender model with a 3-class network and left best_age_model unset.
# NOTE(review): this yields a single model; confirm the inference code accepts a
# single age model rather than a list of fold models.
AGE_LOAD_PATH = './model_save/age/age_submission_4_6_4.pth'
best_age_model = EfficientNetB0(num_classes=3, pretrained=True).to(device)
best_age_model.load_state_dict(torch.load(AGE_LOAD_PATH, map_location=device))

# Inference

In [None]:
# Test-time augmentation (TTA) is applied to the age prediction only; mask and gender
# are predicted from the first transform. Each best_*_model global may be either a
# single model or a list of K-fold models, whose outputs are averaged.
def _as_model_list(model_or_models):
    """Return ``model_or_models`` as a list, wrapping a single model in one."""
    if isinstance(model_or_models, (list, tuple)):
        return list(model_or_models)
    return [model_or_models]


def _mean_model_output(models, img):
    """Return the element-wise mean of ``model(img)`` over every model in ``models``."""
    total = None
    for model in models:
        output = model(img)
        total = output if total is None else total + output
    return total / len(models)


def tta_inference(transforms):
    """Predict the test set and write the submission CSV.

    Reads ``info.csv`` under ``test_root``, predicts mask and gender with the first
    transform, predicts age with every transform and averages the outputs, then
    writes ``./{SUBMISSION_PATH}.csv`` with the combined class
    ``6 * mask + 3 * gender + age`` in the ``ans`` column.

    Args:
        transforms: non-empty sequence of transforms; one ``FaceTestDataset`` is
            built per transform.

    Raises:
        ValueError: if ``transforms`` is empty.
    """
    if len(transforms) == 0:
        raise ValueError('tta_inference requires at least one transform')

    # Fix: accept a single model or a list of fold models for all three heads
    # (previously mask/gender had to be single models and age had to be a list).
    mask_models = _as_model_list(best_mask_model)
    gender_models = _as_model_list(best_gender_model)
    age_models = _as_model_list(best_age_model)
    for model in mask_models + gender_models + age_models:
        model.eval()

    submission = pd.read_csv(os.path.join(test_root, 'info.csv'))

    img_dir = os.path.join(test_root, 'images')
    img_paths = [os.path.join(img_dir, img_id) for img_id in submission['ImageID']]

    mask_outputs = None
    gender_outputs = None
    age_outputs = None

    loop = tqdm(enumerate(transforms), total=len(transforms))
    for i, transform in loop:
        loop.set_description(f"Transform {i+1}/{len(transforms)}")

        test_dataset = FaceTestDataset(img_paths, transform=transform)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

        if i == 0:
            mask_outputs = torch.zeros(len(test_dataset), 3)
            gender_outputs = torch.zeros(len(test_dataset), 1)
            age_outputs = torch.zeros(len(test_dataset), 3)

        # Age outputs of the current transform, accumulated into age_outputs below.
        tmp_age_outputs = torch.zeros(len(test_dataset), 3)

        with torch.no_grad():
            for j, img in enumerate(test_loader):
                img = img.to(device)
                # Rows of this batch; relies on shuffle=False and batch_size=BATCH_SIZE.
                batch_rows = slice(j * BATCH_SIZE, (j + 1) * BATCH_SIZE)

                if i == 0:
                    # Mask and gender are predicted with the first transform only.
                    mask_outputs[batch_rows] = _mean_model_output(mask_models, img)
                    gender_outputs[batch_rows] = _mean_model_output(gender_models, img)

                tmp_age_outputs[batch_rows] = _mean_model_output(age_models, img)

        age_outputs = age_outputs + tmp_age_outputs

    age_outputs /= len(transforms)

    # Fix: compute the labels vectorised and as integers; the per-row version
    # produced 0-d float arrays because the rounded sigmoid output is a float.
    pred_mask = torch.argmax(mask_outputs, dim=1)
    pred_gender = torch.round(torch.sigmoid(gender_outputs)).squeeze(1).long()
    pred_age = torch.argmax(age_outputs, dim=1)

    submission['ans'] = (6 * pred_mask + 3 * pred_gender + pred_age).tolist()

    submission.to_csv(f'./{SUBMISSION_PATH}.csv', index=False)
    print('test inference is done!')

In [94]:
# Normalization statistics shared by every TTA pipeline in this cell.
_TTA_MEAN = (0.561, 0.524, 0.501)
_TTA_STD = (0.223, 0.243, 0.246)


def _make_tta_transform(resize, crop, extra_ops=()):
    """Compose Resize -> CenterCrop -> ``extra_ops`` -> Normalize -> ToTensorV2."""
    steps = [A.Resize(resize, resize, p=1), A.CenterCrop(crop, crop, p=1)]
    steps.extend(extra_ops)
    steps.append(A.Normalize(mean=_TTA_MEAN, std=_TTA_STD))
    steps.append(ToTensorV2())
    return A.Compose(steps)


def _brightness_jitter(p):
    """Brightness-only ColorJitter applied with probability ``p``."""
    return A.ColorJitter(brightness=0.1, contrast=0, saturation=0, hue=0, p=p)


# Higher input resolution
test_transform2 = _make_tta_transform(500, 300)

# Rotation
test_transform3 = _make_tta_transform(400, 224, [A.Rotate(limit=10, p=1)])

# Horizontal flip
test_transform4 = _make_tta_transform(400, 224, [A.HorizontalFlip(p=1)])

# Color jitter
test_transform5 = _make_tta_transform(400, 224, [_brightness_jitter(1)])

# All of the augmentations above combined, each applied with probability 0.5
test_transform6 = _make_tta_transform(400, 224, [
    A.HorizontalFlip(p=0.5),
    _brightness_jitter(0.5),
    A.Rotate(limit=10, p=0.5),
])

test_transforms = [
    test_transform,
    test_transform2,
    test_transform3,
    test_transform4,
    test_transform5,
    test_transform6,
]

tta_inference(test_transforms)

Transform 1/4:   0%|          | 0/4 [00:00<?, ?it/s]

6 6 10


Transform 4/4: 100%|██████████| 4/4 [21:03<00:00, 315.78s/it]


test inference is done!


##### 