In [None]:
import copy
import gc
import os
import random

import easydict
import numpy as np
import pandas as pd
import sklearn
import sklearn.metrics
import timm
from PIL import Image
from tqdm import notebook

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize, CenterCrop

## Fix Seed

In [None]:
def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick different kernels from run
    # to run, which undoes the determinism requested on the line above.
    torch.backends.cudnn.benchmark = False


seed_everything(42)

# Load Train Data

In [None]:
# Root folder of the training split; the image folder and the metadata CSV
# both live directly underneath it.
data_dir = '../input/data/train'
img_dir = data_dir + '/images'
df_path = data_dir + '/train.csv'

# Training metadata table.
df = pd.read_csv(df_path)

def get_transforms(need=('train', 'val'), mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build the image preprocessing pipelines for the requested splits.

    Args:
        need: Container of split names; only 'train' and 'val' are recognised.
        mean: Per-channel mean handed to ``Normalize``
            (presumably ImageNet statistics -- confirm).
        std: Per-channel standard deviation handed to ``Normalize``.

    Returns:
        dict mapping each requested split name to its own ``Compose`` object.
        Both splits currently get the same steps: a 224-pixel centre crop,
        tensor conversion and normalisation.
    """
    def _build_pipeline():
        # A new Compose per split so the two entries are distinct objects.
        return transforms.Compose([
            CenterCrop(224),
            ToTensor(),
            Normalize(mean=mean, std=std),
        ])

    return {split: _build_pipeline() for split in ('train', 'val') if split in need}

### Create the classes that map mask status, gender, and age to integer labels.

class MaskLabels:
    """Integer codes for the mask-wearing state shown in a photo."""
    mask, incorrect, normal = 0, 1, 2

class GenderLabels:
    """Integer codes for gender; looked up by name via ``getattr``."""
    male, female = 0, 1

class AgeGroup:
    """Buckets an age into three groups: under 30, 30 to 59, and 60 or over."""

    @staticmethod
    def map_label(x):
        """Return the age-group index for ``x``.

        Declared as a ``staticmethod`` so it works both as
        ``AgeGroup.map_label(x)`` and on an instance. A bare lambda stored on
        the class would receive the instance as its first argument.

        Args:
            x: Age as an int or anything ``int()`` accepts (e.g. ``"54"``).

        Returns:
            0 if age < 30, 1 if 30 <= age < 60, otherwise 2.
        """
        age = int(x)
        if age < 30:
            return 0
        if age < 60:
            return 1
        return 2
    
class MaskBaseDataset(Dataset):
    """Dataset of face photos labelled by mask state, gender and age group.

    ``img_dir`` is expected to contain one sub-folder per profile, named
    ``<id>_<gender>_<race>_<age>``, holding the image files listed in
    ``_file_names``. Each sample is returned with a single combined label
    ``mask * 6 + gender * 3 + age`` (18 classes).
    """
    num_classes = 3 * 2 * 3

    # Expected file name inside a profile folder -> mask label.
    _file_names = {
        "mask1.jpg": MaskLabels.mask,
        "mask2.jpg": MaskLabels.mask,
        "mask3.jpg": MaskLabels.mask,
        "mask4.jpg": MaskLabels.mask,
        "mask5.jpg": MaskLabels.mask,
        "incorrect_mask.jpg": MaskLabels.incorrect,
        "normal.jpg": MaskLabels.normal
    }

    def __init__(self, img_dir, transform=None):
        """Scan ``img_dir`` and record every image path with its labels.

        Args:
            img_dir: Directory holding the per-profile sub-folders.
            transform: Optional callable applied to each PIL image in
                ``__getitem__``; can also be set later via ``set_transform``.
        """
        self.img_dir = img_dir
        self.transform = transform

        # Per-instance containers. As class attributes these lists were shared
        # by every instance, so building a second dataset appended to (and
        # duplicated) the samples of the first.
        self.image_paths = []
        self.mask_labels = []
        self.gender_labels = []
        self.age_labels = []

        self.setup()

    def set_transform(self, transform):
        """Replace the transform applied to images in ``__getitem__``."""
        self.transform = transform

    def setup(self):
        """Populate the path and label lists from the folder structure."""
        # Sorted so sample indices (and any seeded random split built on
        # them) do not depend on the filesystem's listing order.
        for profile in sorted(os.listdir(self.img_dir)):
            if profile.startswith("."):
                # Hidden entries are not profile folders.
                continue
            for file_name, label in self._file_names.items():
                img_path = os.path.join(self.img_dir, profile, file_name)  # (resized_data, 000004_male_Asian_54, mask1.jpg)
                if not os.path.exists(img_path):
                    continue
                self.image_paths.append(img_path)
                self.mask_labels.append(label)

                # Folder name encodes the metadata: id_gender_race_age.
                _profile_id, gender, _race, age = profile.split("_")
                self.gender_labels.append(getattr(GenderLabels, gender))
                self.age_labels.append(AgeGroup.map_label(age))

    def __getitem__(self, index):
        """Return ``(image, multi_class_label)`` for the sample at ``index``.

        The image is passed through ``self.transform`` when one is set;
        otherwise the PIL image is returned unchanged.
        """
        # Load the image.
        image = Image.open(self.image_paths[index])

        # Combine the three labels into a single class index in [0, 18).
        mask_label = self.mask_labels[index]
        gender_label = self.gender_labels[index]
        age_label = self.age_labels[index]
        multi_class_label = mask_label * 6 + gender_label * 3 + age_label

        # Apply the augmentation / preprocessing pipeline, if any.
        if self.transform is not None:
            image = self.transform(image)

        return image, multi_class_label

    def __len__(self):
        """Number of image files found by ``setup``."""
        return len(self.image_paths)

In [None]:
# Build the transform pipelines and the dataset object.
transform = get_transforms()

dataset = MaskBaseDataset(img_dir=img_dir)

# Split into train / validation subsets with an 8:2 ratio.
n_val = int(len(dataset) * 0.2)
n_train = len(dataset) - n_val
train_dataset, val_dataset = data.random_split(dataset, [n_train, n_val])

# Attach the transform for each split.
# NOTE(review): both subsets wrap the same `dataset` object, so the second
# call replaces the transform set by the first. This is harmless while the
# 'train' and 'val' pipelines are identical, but will silently apply the
# 'val' pipeline to training data once they differ.
train_dataset.dataset.set_transform(transform['train'])
val_dataset.dataset.set_transform(transform['val'])

train_loader = data.DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)

val_loader = data.DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=4)

# Model

In [None]:
class swinBaseModel(nn.Module):
    """timm ``swin_base_patch4_window7_224`` backbone followed by a linear head.

    The head maps the backbone output (declared as 1000 features) to
    ``class_n`` logits.
    """

    def __init__(self, class_n=18):
        super().__init__()
        backbone = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
        self.model = backbone
        head = torch.nn.Linear(in_features=1000, out_features=class_n)
        self.classify = head

    def forward(self, x):
        """Run the backbone, then the classification head."""
        return self.classify(self.model(x))
    
class swinTinyModel(nn.Module):
    """timm ``swin_tiny_patch4_window7_224`` backbone followed by a linear head.

    The head maps the backbone output (declared as 1000 features) to
    ``class_n`` logits.
    """

    def __init__(self, class_n=18):
        super().__init__()
        backbone = timm.create_model('swin_tiny_patch4_window7_224',pretrained=True)
        self.model = backbone
        head = torch.nn.Linear(in_features=1000, out_features=class_n)
        self.classify = head

    def forward(self, x):
        """Run the backbone, then the classification head."""
        return self.classify(self.model(x))

class swinLargeModel(nn.Module):
    """timm ``swin_large_patch4_window7_224`` backbone followed by a linear head.

    The head maps the backbone output (declared as 1000 features) to
    ``class_n`` logits.
    """

    def __init__(self, class_n=18):
        super().__init__()
        backbone = timm.create_model('swin_large_patch4_window7_224', pretrained=True)
        self.model = backbone
        head = torch.nn.Linear(in_features=1000, out_features=class_n)
        self.classify = head

    def forward(self, x):
        """Run the backbone, then the classification head."""
        return self.classify(self.model(x))

class EfficientNet7(nn.Module):
    """Pretrained EfficientNet-B7 configured to output ``class_n`` logits."""

    def __init__(self, class_n=18):
        super().__init__()
        # NOTE(review): `EfficientNet` is not imported in the visible part of
        # this notebook (presumably efficientnet_pytorch) -- confirm the
        # import exists before selecting this model.
        # The class count is passed by keyword: the second positional
        # parameter of `from_pretrained` is not `num_classes` in every
        # release of that package.
        self.model = EfficientNet.from_pretrained('efficientnet-b7', num_classes=class_n)

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)
    
class EfficientNet5(nn.Module):
    """Pretrained EfficientNet-B5 configured to output ``class_n`` logits."""

    def __init__(self, class_n=18):
        super().__init__()
        # NOTE(review): `EfficientNet` is not imported in the visible part of
        # this notebook (presumably efficientnet_pytorch) -- confirm the
        # import exists before selecting this model.
        # The class count is passed by keyword: the second positional
        # parameter of `from_pretrained` is not `num_classes` in every
        # release of that package.
        self.model = EfficientNet.from_pretrained('efficientnet-b5', num_classes=class_n)

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)

class caitBaseModel(nn.Module):
    """timm ``cait_s24_224`` backbone followed by a linear head.

    The head maps the backbone output (declared as 1000 features) to
    ``class_n`` logits.
    """

    def __init__(self, class_n=18):
        super().__init__()
        backbone = timm.create_model('cait_s24_224',pretrained=True)
        self.model = backbone
        head = torch.nn.Linear(in_features=1000, out_features=class_n)
        self.classify = head

    def forward(self, x):
        """Run the backbone, then the classification head."""
        return self.classify(self.model(x))

# Train & Validate

In [None]:
def train(args, model, train_loader, optimizer, scheduler, criterion=None):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        args: Config object providing ``device`` and ``log_steps``.
        model: Network to optimise; switched to training mode.
        train_loader: DataLoader yielding ``(images, labels)`` batches.
        optimizer: Optimiser stepped once per batch.
        scheduler: LR scheduler, also stepped once per batch.
            NOTE(review): confirm the scheduler's period is meant to be
            counted in batches rather than epochs.
        criterion: Loss callable ``criterion(outputs, labels)``. Defaults to
            ``nn.CrossEntropyLoss()`` so the function no longer depends on a
            module-level ``criterion`` variable.

    Returns:
        ``(model, train_loss, acc)`` where ``train_loss`` is the mean
        per-sample loss and ``acc`` is the accuracy in percent (a float).
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    model.train()

    n_correct, running_loss = 0, 0.0
    epoch_labels, epoch_preds = [], []
    for step, (images, labels) in enumerate(train_loader):
        images = images.to(args.device)
        labels = labels.to(args.device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if step % args.log_steps == 0:
            print(f"Training Steps: {step} Loss: {str(loss.item())}")

        _, preds = torch.max(outputs, 1)
        # .item() keeps the counter a plain Python int instead of a tensor
        # living on args.device.
        n_correct += torch.sum(preds == labels).item()
        # Weight by batch size so a smaller last batch is averaged correctly.
        running_loss += loss.item() * images.size(0)
        epoch_labels.extend(labels.cpu().numpy())
        epoch_preds.extend(preds.cpu().numpy())

    n_samples = len(train_loader.dataset)
    train_loss = running_loss / n_samples
    acc = n_correct / n_samples * 100
    # Macro-F1 is not an average of per-batch macro-F1 values (a small batch
    # rarely contains every class), so compute it once over the whole epoch.
    f1_scores = sklearn.metrics.f1_score(epoch_labels, epoch_preds, average='macro')
    print(f'Train Loss:{train_loss} acc:{acc} F1-score:{f1_scores}')
    return model, train_loss, acc

In [None]:
def validate(args, model, valid_loader, criterion):
    """Evaluate ``model`` on ``valid_loader`` without updating its weights.

    Args:
        args: Config object providing ``device``.
        model: Network to evaluate; switched to eval mode.
        valid_loader: DataLoader yielding ``(images, labels)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)``.

    Returns:
        ``(valid_loss, acc, f1_scores, all_predictions)``: the mean per-sample
        loss, the accuracy in percent (a float), the macro-F1 over the whole
        loader, and the list of predicted class indices in loader order.
    """
    model.eval()

    all_predictions, all_labels = [], []
    n_correct, running_loss = 0, 0.0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch.
    with torch.no_grad():
        for images, labels in valid_loader:
            images = images.to(args.device)
            labels = labels.to(args.device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)

            all_predictions.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            # .item() keeps the counter (and therefore `acc`) a plain number
            # rather than a tensor on args.device.
            n_correct += torch.sum(preds == labels).item()
            # Weight by batch size so a smaller last batch is averaged correctly.
            running_loss += loss.item() * images.size(0)

    n_samples = len(valid_loader.dataset)
    valid_loss = running_loss / n_samples
    acc = n_correct / n_samples * 100
    # Macro-F1 is not an average of per-batch macro-F1 values (a small batch
    # rarely contains every class), so compute it once over all samples.
    f1_scores = sklearn.metrics.f1_score(all_labels, all_predictions, average='macro')
    print(f'Valid Loss:{valid_loss} Valid Acc:{acc} F1-score:{f1_scores}')
    return valid_loss, acc, f1_scores, all_predictions

## Main Run

In [None]:
def run(args, train_data, valid_data):
    """Train the model selected by ``args.model_name`` and save the best one.

    Each epoch trains on ``train_data`` and evaluates on ``valid_data``; the
    weights of the epoch with the highest validation accuracy are restored
    and the whole model is written to a ``.pt`` file in the working directory.

    Args:
        args: Config object providing ``model_name``, ``device``, ``lr``,
            ``n_epochs`` and ``batch_size`` (the latter is only used in the
            output file name).
        train_data: DataLoader passed to ``train``.
        valid_data: DataLoader passed to ``validate``.

    Raises:
        ValueError: If ``args.model_name`` matches none of the known models.
    """
    gc.collect()
    torch.cuda.empty_cache()

    # Substring matching; 'efficient5' has to be tested before 'efficient'.
    if 'efficient5' in args.model_name:
        model = EfficientNet5().to(args.device)
    elif 'efficient' in args.model_name:
        model = EfficientNet7().to(args.device)
    elif 'swin_base' in args.model_name:
        model = swinBaseModel().to(args.device)
    elif 'swin_tiny' in args.model_name:
        model = swinTinyModel().to(args.device)
    elif 'swin_large' in args.model_name:
        model = swinLargeModel().to(args.device)
    elif 'cait' in args.model_name:
        model = caitBaseModel().to(args.device)
    else:
        # Without this branch `model` stayed unbound and the failure showed
        # up later as an unrelated-looking error.
        raise ValueError(f"Unknown model_name: {args.model_name!r}")

    optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=0.0)
    optimizer.zero_grad()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
    criterion = nn.CrossEntropyLoss()
    # criterion = nn.BCEWithLogitsLoss()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = -1
    best_f1 = 0.0
    best_idx = -1  # stays -1 when no epoch runs (n_epochs == 0)

    for epoch in notebook.tqdm(range(args.n_epochs)):
        print(f'Epoch {epoch + 1}/{args.n_epochs}')
        _, train_loss, train_acc = train(args, model, train_data, optimizer, scheduler)
        # validate() requires the loss function as its fourth argument.
        valid_loss, valid_acc, epoch_f1, outputs = validate(args, model, valid_data, criterion)
        print(f'train acc: {train_acc}\tand valid acc:{valid_acc}\n')
        if valid_acc > best_acc:
            best_acc = valid_acc
            best_idx = epoch
            best_f1 = epoch_f1
            best_model_wts = copy.deepcopy(model.state_dict())

    print('Best valid Acc: %d - %.1f' %(best_idx, best_acc))
    print('Best valid f1: %d - %.1f' %(best_idx, best_f1))
    model.load_state_dict(best_model_wts)
    # float() so a tensor-valued accuracy cannot leak its repr into the name.
    save_name = f"{args.model_name}_ep{args.n_epochs}_batch{args.batch_size}_lr{args.lr}_{float(best_acc):.2f}.pt"
    torch.save(model, save_name)

In [None]:
# Run configuration; wrapped in an EasyDict so entries can be read as
# attributes (args.lr, args.device, ...).
config = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'n_epochs': 20,
    'batch_size': 16,
    'lr': 1e-4,
    'log_steps': 300,
    'model_name': 'swin_base',
}
args = easydict.EasyDict(config)

run(args, train_loader, val_loader)