In [1]:
import os
import pandas as pd
from pandas_streaming.df import train_test_apart_stratify
import numpy as np
from PIL import Image
import easydict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
import torch.utils.data as data
from torch.utils.tensorboard import SummaryWriter

from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize, CenterCrop, RandomHorizontalFlip

import sklearn
from sklearn.model_selection import StratifiedKFold

from tqdm import notebook
import gc
import random
import copy

In [3]:
# fix seed
def seed_everything(seed):
    """Fix every random seed so that repeated runs under the same conditions match.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU, the
    current CUDA device and all CUDA devices), then puts cuDNN into
    deterministic, non-benchmarking mode.

    Args:
        seed: integer seed value.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Current device first, then every GPU (multi-GPU runs).
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)
    # Give up cuDNN autotuning in exchange for reproducible kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_everything(42)

# Model

In [4]:
# Base Model
from efficientnet_pytorch import EfficientNet
import timm
class EfficientNet5(nn.Module):
    """ImageNet-pretrained EfficientNet-B5 (``efficientnet_pytorch``) with a ``class_n``-way head.

    Args:
        class_n: number of output logits (18 = 3 mask x 2 gender x 3 age classes).
    """

    def __init__(self, class_n=18):
        super().__init__()
        # num_classes must be passed by keyword. The second positional parameter
        # of EfficientNet.from_pretrained is not num_classes (it is weights_path in
        # recent efficientnet_pytorch releases), so passing class_n positionally
        # silently kept the default 1000-way ImageNet classifier.
        self.model = EfficientNet.from_pretrained('efficientnet-b5', num_classes=class_n)

    def forward(self, x):
        """Return the ``class_n`` logits for a batch of images ``x``."""
        return self.model(x)
    
class swinLargeModel(nn.Module):
    """Pretrained timm Swin-L (patch 4, window 7, 224px) followed by a linear classifier.

    ``classify`` maps the backbone's 1000-dimensional output (timm's default
    ImageNet head) to ``class_n`` logits.

    NOTE(review): the usual approach is ``timm.create_model(..., num_classes=class_n)``,
    which replaces the head instead of stacking on ImageNet logits. That would
    change the state_dict keys, so existing checkpoints would no longer load.
    """

    def __init__(self, class_n=18):
        super().__init__()
        backbone = timm.create_model('swin_large_patch4_window7_224', pretrained=True)
        head = torch.nn.Linear(in_features=1000, out_features=class_n)
        # Registration order (model, then classify) matches the saved state_dict layout.
        self.model = backbone
        self.classify = head

    def forward(self, x):
        """Return ``class_n`` logits for the image batch ``x``."""
        features = self.model(x)
        return self.classify(features)
    
class caitBaseModel(nn.Module):
    """Pretrained timm CaiT ('cait_s24_224') followed by a linear classifier.

    Despite the class name, 'cait_s24_224' is the CaiT-S24 variant with 224px input.
    ``classify`` maps the backbone's 1000-dimensional output (timm's default
    ImageNet head) to ``class_n`` logits.
    """

    def __init__(self, class_n=18):
        super().__init__()
        backbone = timm.create_model('cait_s24_224', pretrained=True)
        head = torch.nn.Linear(in_features=1000, out_features=class_n)
        # Registration order (model, then classify) matches the saved state_dict layout.
        self.model = backbone
        self.classify = head

    def forward(self, x):
        """Return ``class_n`` logits for the image batch ``x``."""
        features = self.model(x)
        return self.classify(features)

# Data Loading

In [5]:
class MaskLabels:
    """Integer codes for the mask-wearing status encoded by each image file name."""
    mask = 0
    incorrect = 1
    normal = 2


class GenderLabels:
    """Integer codes for the gender token of a profile directory name."""
    male = 0
    female = 1


class AgeGroup:
    """Buckets a raw age into one of three age groups."""

    @staticmethod
    def map_label(x):
        """Return 0 for age < 30, 1 for 30 <= age < 60 and 2 for age >= 60.

        Args:
            x: age as an int or numeric string (the last token of a profile name).
        """
        age = int(x)
        if age < 30:
            return 0
        if age < 60:
            return 1
        return 2


class MaskBaseDataset(Dataset):
    """Mask / gender / age classification dataset read from a directory of profiles.

    Each sub-directory of ``img_dir`` is named ``"{id}_{gender}_{race}_{age}"``
    and may contain any of the files listed in ``_file_names``. Every file found
    becomes one sample whose label is the 18-way class
    ``mask * 6 + gender * 3 + age_group``.

    Args:
        img_dir: directory containing one sub-directory per profile.
        transform: callable applied to each PIL image in ``__getitem__``; when
            None the raw RGB PIL image is returned.
    """

    num_classes = 3 * 2 * 3

    _file_names = {
        "mask1.jpg": MaskLabels.mask,
        "mask2.jpg": MaskLabels.mask,
        "mask3.jpg": MaskLabels.mask,
        "mask4.jpg": MaskLabels.mask,
        "mask5.jpg": MaskLabels.mask,
        "incorrect_mask.jpg": MaskLabels.incorrect,
        "normal.jpg": MaskLabels.normal
    }

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform

        # Per-instance lists. Class-level lists would be shared by every instance,
        # so each new dataset would append into (and duplicate) the same data.
        self.image_paths = []
        self.mask_labels = []
        self.gender_labels = []
        self.age_labels = []

        self.setup()

    def set_transform(self, transform):
        """Replace the transform applied in ``__getitem__``."""
        self.transform = transform

    def setup(self):
        """Scan ``img_dir`` and fill the path and label lists (index-aligned)."""
        # sorted(): os.listdir order is arbitrary. A fixed order keeps seeded
        # random splits reproducible across machines.
        for profile in sorted(os.listdir(self.img_dir)):
            for file_name, label in self._file_names.items():
                img_path = os.path.join(self.img_dir, profile, file_name)
                if not os.path.exists(img_path):
                    continue
                # Parse before appending so a malformed name cannot leave the lists misaligned.
                _id, gender, _race, age = profile.split("_")
                self.image_paths.append(img_path)
                self.mask_labels.append(label)
                self.gender_labels.append(getattr(GenderLabels, gender))
                self.age_labels.append(AgeGroup.map_label(age))

    def __getitem__(self, index):
        """Return ``(image, multi_class_label)`` for sample ``index``."""
        # Open in a ``with`` block so the file handle is released immediately, and
        # force 3-channel RGB to match the 3-channel Normalize used in this notebook.
        with Image.open(self.image_paths[index]) as img:
            image = img.convert("RGB")

        multi_class_label = self.encode_multi_class(
            self.mask_labels[index], self.gender_labels[index], self.age_labels[index]
        )

        if self.transform is not None:
            image = self.transform(image)
        return image, multi_class_label

    def encode_multi_class(self, mask, gender, age):
        """Combine (mask, gender, age-group) codes into one class id in [0, 18)."""
        return mask * 6 + gender * 3 + age

    def __len__(self):
        return len(self.image_paths)

In [6]:
def getDataloader(dataset, train_idx, valid_idx, batch_size, num_workers):
    """Build train/validation DataLoaders over two index subsets of ``dataset``.

    Args:
        dataset: indexable dataset both subsets are drawn from.
        train_idx: indices (into ``dataset``) of the training samples.
        valid_idx: indices (into ``dataset``) of the validation samples.
        batch_size: samples per batch for both loaders.
        num_workers: DataLoader worker processes for both loaders.

    Returns:
        ``(train_loader, val_loader)``. The train loader shuffles and drops the
        last incomplete batch. The validation loader keeps order and yields every sample.
    """
    train_set = torch.utils.data.Subset(dataset, indices=train_idx)
    val_set = torch.utils.data.Subset(dataset, indices=valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=True,
        shuffle=True,
    )
    # drop_last=False: validation must score every held-out sample. Dropping the
    # tail batch silently removed up to batch_size - 1 samples from the metrics.
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=batch_size,
        num_workers=num_workers,
        drop_last=False,
        shuffle=False,
    )

    return train_loader, val_loader

In [7]:
data_dir = '../input/data/train'
img_dir = f'{data_dir}/images'
df_path = f'{data_dir}/train.csv'  # NOTE(review): not read anywhere in the visible notebook

dataset = MaskBaseDataset(img_dir)
# 18-way class id of every sample, index-aligned with `dataset`.
dataset_labels = [
    dataset.encode_multi_class(mask, gender, age)
    for mask, gender, age in zip(dataset.mask_labels, dataset.gender_labels, dataset.age_labels)
]

# train : test = 8 : 2
n_test = int(len(dataset) * 0.2)
n_train = len(dataset) - n_test
train_dataset, test_dataset = data.random_split(dataset, [n_train, n_test])
# Take the label splits from the SAME permutation as the image split. A second
# random_split over dataset_labels draws a different permutation, so the labels
# passed to StratifiedKFold would not correspond to train_dataset.
train_labels = [dataset_labels[i] for i in train_dataset.indices]
test_labels = [dataset_labels[i] for i in test_dataset.indices]

# NOTE(review): train_transform is never applied below, and
# RandomHorizontalFlip(p=1.0) flips every image (deterministic, not a random augmentation).
train_transform = transforms.Compose([
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    transforms.RandomHorizontalFlip(p=1.0)
])

transform = transforms.Compose([
    CenterCrop(224),
    ToTensor(),
    Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Set the preprocessing transform. train_dataset and test_dataset are both
# Subsets of `dataset`, so one call configures both splits. Per-split transforms
# cannot be set this way: the last set_transform call would apply to both.
dataset.set_transform(transform)

# test data loader
test_loader = data.DataLoader(
    test_dataset,
    batch_size=32,
    num_workers=4,
    shuffle=False
)

# Train

In [8]:
def train(args, model, train_loader, optimizer, scheduler, criterion, writer):
    """Run one training epoch over ``train_loader``.

    The scheduler is stepped once per batch, right after the optimizer.

    Args:
        args: config providing ``device`` and ``log_steps`` (the loss is printed
            every ``log_steps`` batches).
        model: network to train; switched to train mode.
        train_loader: DataLoader yielding ``(images, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped every batch.
        criterion: loss function applied to ``(logits, labels)``.
        writer: TensorBoard SummaryWriter; currently unused, kept for interface compatibility.

    Returns:
        ``(model, train_loss, acc)``: the trained model, the mean per-sample loss
        over the samples actually seen, and accuracy in percent as a float.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()

    running_loss = 0.0
    n_correct = 0
    n_seen = 0
    all_preds, all_labels = [], []
    for step, (images, labels) in enumerate(train_loader):
        images = images.to(args.device)
        labels = labels.to(args.device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        loss_value = loss.item()
        if step % args.log_steps == 0:
            print(f"Training Steps: {step} Loss: {str(loss_value)}")

        _, preds = torch.max(outputs, 1)
        batch_size = images.size(0)
        running_loss += loss_value * batch_size
        n_correct += (preds == labels).sum().item()
        n_seen += batch_size
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    if n_seen == 0:
        raise ValueError("train_loader yielded no batches (check batch_size / drop_last)")

    # Divide by the samples actually seen. With drop_last=True the loader skips
    # the tail, so len(train_loader.dataset) would under-report loss and accuracy.
    train_loss = running_loss / n_seen
    acc = 100.0 * n_correct / n_seen
    # Macro F1 over the whole epoch. An average of per-batch macro F1 scores is a
    # different (and biased) quantity.
    f1_scores = sklearn.metrics.f1_score(all_labels, all_preds, average='macro')
    print(f'Train Loss:{train_loss} acc:{acc} F1-score:{f1_scores}')
    return model, train_loss, acc

In [9]:
def validate(args, model, valid_loader, criterion, writer):
    """Evaluate ``model`` on ``valid_loader`` without updating it.

    Args:
        args: config providing ``device``.
        model: network to evaluate; switched to eval mode.
        valid_loader: DataLoader yielding ``(images, labels)`` batches.
        criterion: loss function applied to ``(logits, labels)``.
        writer: TensorBoard SummaryWriter; currently unused, kept for interface compatibility.

    Returns:
        ``(valid_loss, acc, f1_scores, all_predictions)``: the mean per-sample loss,
        accuracy in percent (float), macro F1 over the whole set, and the predicted
        class ids in loader order.

    Raises:
        ValueError: if ``valid_loader`` yields no batches.
    """
    model.eval()

    all_predictions = []
    all_labels = []
    running_loss = 0.0
    n_correct = 0
    n_seen = 0
    # no_grad: evaluation needs no autograd graph, and building one wastes GPU memory.
    with torch.no_grad():
        for images, labels in valid_loader:
            images = images.to(args.device)
            labels = labels.to(args.device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)

            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            n_correct += (preds == labels).sum().item()
            n_seen += batch_size
            all_predictions.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if n_seen == 0:
        raise ValueError("valid_loader yielded no batches")

    valid_loss = running_loss / n_seen
    acc = 100.0 * n_correct / n_seen
    # Macro F1 over the full validation set. Classes missing from a single batch
    # distort that batch's macro score, so per-batch averaging is not equivalent.
    f1_scores = sklearn.metrics.f1_score(all_labels, all_predictions, average='macro')
    print(f'Valid Loss:{valid_loss} Valid Acc:{acc} F1-score:{f1_scores}')
    return valid_loss, acc, f1_scores, all_predictions

In [10]:
n_splits = 5


def _build_model(args):
    """Create the network selected by ``args.model_name`` on ``args.device``.

    Raises:
        ValueError: if the name contains none of 'efficient5', 'swin' or 'cait'.
    """
    if 'efficient5' in args.model_name:
        print('load efficient5')
        return EfficientNet5().to(args.device)
    if 'swin' in args.model_name:
        print('load swin')
        return swinLargeModel().to(args.device)
    if 'cait' in args.model_name:
        print('load cait')
        return caitBaseModel().to(args.device)
    raise ValueError(
        f"Unknown model_name {args.model_name!r}: expected it to contain 'efficient5', 'swin' or 'cait'"
    )


def _predict_proba(model, loader, device, tta=False):
    """Return softmax class probabilities of ``model`` for every sample in ``loader``.

    Args:
        model: trained network; switched to eval mode.
        loader: DataLoader yielding ``(images, _)`` batches.
        device: device the images are moved to.
        tta: if True, also predict on horizontally flipped images and average
            the two probability maps.

    Returns:
        np.ndarray with one row of class probabilities per sample, in loader order.
    """
    model.eval()
    batches = []
    with torch.no_grad():
        for images, _ in loader:
            images = images.to(device)
            probs = torch.softmax(model(images), dim=1)
            if tta:
                # Batches are (N, C, H, W) (ToTensor + default collate), so dim 3 is width.
                flipped = torch.flip(images, dims=[3])
                probs = (probs + torch.softmax(model(flipped), dim=1)) / 2
            batches.append(probs.cpu().numpy())
    return np.concatenate(batches, axis=0)


def ensemble_kfold(args, train_dataset, train_labels, test_dataset, n_splits=5):
    """Train one model per stratified K-fold split and average their test predictions.

    For each fold, train for ``args.n_epochs`` epochs and keep the weights of the
    epoch where validation accuracy and macro-F1 both improved. Save those
    weights to a ``.pt`` file, then predict the test set with them.

    Args:
        args: config with ``model_name``, ``device``, ``n_epochs``, ``batch_size``,
            ``lr`` and ``log_steps``. An optional ``tta`` flag enables horizontal-flip TTA.
        train_dataset: dataset split into folds.
        train_labels: class id per sample of ``train_dataset`` (same order), used
            for stratification.
        test_dataset: the test data, either a DataLoader (as the calling cell
            passes) or a Dataset, which is wrapped in a non-shuffling DataLoader.
        n_splits: number of folds.

    Returns:
        np.ndarray of fold-averaged softmax probabilities, one row per test sample.
    """
    if isinstance(test_dataset, DataLoader):
        test_loader = test_dataset
    else:
        test_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=4, shuffle=False)
    use_tta = getattr(args, 'tta', False)

    writer = SummaryWriter(args.model_name + '0921')
    skf = StratifiedKFold(n_splits=n_splits)
    # Only the labels drive stratification; X is a length-matching placeholder.
    split_x = np.zeros(len(train_dataset))
    split_y = list(train_labels)

    test_oof_pred = None
    try:
        # Folds index into train_dataset, like K-fold cross validation.
        for fold, (train_idx, valid_idx) in enumerate(skf.split(split_x, split_y)):
            gc.collect()
            torch.cuda.empty_cache()
            print(f'k-fold:{fold+1}')

            train_loader, val_loader = getDataloader(train_dataset, train_idx, valid_idx, args.batch_size, 4)
            model = _build_model(args)

            criterion = nn.CrossEntropyLoss()
            optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=0.0)
            # NOTE(review): when args.lr equals eta_min (the config cell uses lr=1e-4)
            # this schedule keeps the learning rate constant; lower eta_min to anneal.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=100, T_mult=1, eta_min=1e-4
            )

            best_acc, best_f1, best_epoch = 0.0, 0.0, None
            best_model_wts = copy.deepcopy(model.state_dict())
            for epoch in notebook.tqdm(range(args.n_epochs)):
                model, _, _ = train(args, model, train_loader, optimizer, scheduler, criterion, writer)
                _, valid_acc, valid_f1, _ = validate(args, model, val_loader, criterion, writer)
                # Checkpoint only when accuracy and macro-F1 both improve.
                if valid_acc > best_acc and valid_f1 > best_f1:
                    # float(): validate may return 0-d tensors; keep plain numbers.
                    best_acc, best_f1, best_epoch = float(valid_acc), float(valid_f1), epoch
                    best_model_wts = copy.deepcopy(model.state_dict())
            print(f'k-fold:{fold+1} best epoch:{best_epoch} acc:{best_acc} F1-score:{best_f1}')

            model.load_state_dict(best_model_wts)
            torch.save(
                model.state_dict(),
                f"{args.model_name}_kfold{fold+1}_ep{args.n_epochs}_batch{args.batch_size}_lr{args.lr}_{best_acc:.4f}.pt",
            )

            # Ensemble by averaging class probabilities (softmax), one share per fold.
            fold_pred = _predict_proba(model, test_loader, args.device, tta=use_tta)
            if test_oof_pred is None:
                test_oof_pred = fold_pred / n_splits
            else:
                test_oof_pred += fold_pred / n_splits
    finally:
        writer.close()

    return test_oof_pred

In [None]:
# Run configuration for the K-fold ensemble below.
config = dict(
    device='cuda' if torch.cuda.is_available() else 'cpu',
    n_epochs=20,
    batch_size=32,
    lr=1e-4,
    log_steps=500,
    model_name='swin_large',
)
args = easydict.EasyDict(config)

# Train one model per fold and keep the fold-averaged test-set predictions.
swin_test = ensemble_kfold(
    args,
    train_dataset,
    train_labels,
    test_loader,
    n_splits=5,
)