In [14]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, random_split, SubsetRandomSampler, DataLoader
from typing import Callable, Optional, Any

import torchvision.models as models
from torchvision import transforms

from torch.utils.tensorboard import SummaryWriter

from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import KFold, StratifiedKFold

from PIL import Image
from tqdm import tqdm

import pandas as pd


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

device

device(type='cuda', index=0)

In [3]:
# tensorboard logs, model.state_dict() saved PATH & DIRECTORY

tensorboard_saved_dir = 'logs'
param_saved_dir = 'saved'

os.makedirs(tensorboard_saved_dir, exist_ok=True)
os.makedirs(param_saved_dir, exist_ok=True)

In [4]:
# train data path
train_data_path = '/opt/ml/code/dataset'

train_csv = pd.read_csv(os.path.join(train_data_path, 'image_label.csv'))

train_csv.head()

Unnamed: 0,id,path,file_name,absolute_path,label
0,1,000001_female_Asian_45,incorrect_mask.jpg,/opt/ml/input/data/train/images/000001_female_...,10
1,1,000001_female_Asian_45,mask4.jpg,/opt/ml/input/data/train/images/000001_female_...,4
2,1,000001_female_Asian_45,mask2.jpg,/opt/ml/input/data/train/images/000001_female_...,4
3,1,000001_female_Asian_45,mask1.jpg,/opt/ml/input/data/train/images/000001_female_...,4
4,1,000001_female_Asian_45,mask3.jpg,/opt/ml/input/data/train/images/000001_female_...,4


In [5]:
class MaskDataset(Dataset):

    def __init__(self,table,transform:Optional[callable]=None):
        self.table = table
        self.transform = transform

        self.x, self.y = self.load_data()

    def load_data(self):
        x,y = [],[]
        for i in tqdm(range(len(self.table))):
            im = Image.open(self.table['absolute_path'][i])

            if self.transform:
                im = self.transform(im)
            
            x.append(im)
            y.append(torch.tensor(self.table['label'][i]))
        return x, y

    def __len__(self):
        return len(self.y)
    
    def __getitem__(self,idx):
        return self.x[idx], self.y[idx]

In [6]:
image_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.CenterCrop(150),
    transforms.ToTensor()
])

In [7]:
# make dataset

mask_dataset = MaskDataset(train_csv, image_transform)

100%|██████████| 18900/18900 [02:07<00:00, 147.68it/s]


In [8]:
# Stratified KFold 설정

k = 5
skfold = StratifiedKFold(n_splits=k, shuffle=True)

In [116]:
# train 함수
def model_train(model,device,dataloader,criterion, optimizer):
    model.train()
    model.to(device)
    acc, f1,losses = 0,0,0
    for x,y in dataloader:
        x, y = x.to(device), y.to(device)
        y_pred = model(x)

        loss = criterion(y_pred,y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        y = y.flatten().to('cpu').detach().numpy()
        y_pred = torch.argmax(y_pred,dim=-1).to('cpu').detach().numpy()
        acc += accuracy_score(y,y_pred)
        f1 += f1_score(y,y_pred,average='macro')
        losses += loss
    return acc / len(dataloader), f1 / len(dataloader), loss / len(dataloader)

In [117]:
# validate 함수
def model_val(model,device,dataloader,criterion):
    model.eval()
    model.to(device)
    acc, f1,losses = 0,0,0
    with torch.no_grad():
        for x,y in dataloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)

            loss = criterion(y_pred,y)

            y = y.flatten().to('cpu').detach().numpy()
            y_pred = torch.argmax(y_pred,dim=-1).to('cpu').detach().numpy()
            acc += accuracy_score(y,y_pred)
            f1 += f1_score(y,y_pred,average='macro')
            losses += loss
    return acc / len(dataloader), f1 / len(dataloader), loss / len(dataloader)

In [112]:
class MaskResNet18(nn.Module):
    def __init__(self):
        super(MaskResNet18,self).__init__()
        self.resnet18 = models.resnet18(pretrained=True)
        self.resnet18.fc = nn.Linear(in_features=128,out_features=18)

    def frozen(self):
        for param in self.resnet18.parameters():
            param.requires_grad = False
        for param in self.resnet18.fc.parameters():
            param.requires_grad = True

    def unfrozen(self):
        for param in self.resnet18.parameters():
            param.requires_grad = True

    def forward(self, x):
        x = self.resnet18.conv1(x)
        x = self.resnet18.bn1(x)
        x = self.resnet18.relu(x)
        x = self.resnet18.maxpool(x)
        x = self.resnet18.layer1(x)
        x = self.resnet18.layer2(x)
        x = self.resnet18.avgpool(x)

        x = x.flatten(1)
        x = self.resnet18.fc(x)
        return x

In [118]:
mymodel = MaskResNet18()

In [49]:
writer = SummaryWriter(tensorboard_saved_dir)

In [119]:
BATCH = 30
EPOCH = 10

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mymodel.parameters(),lr=0.0001)


pre_train_f1, pre_train_acc = 0,0
pre_val_f1, pre_val_acc = 0,0

for fold, (train_idx, val_idx) in enumerate(skfold.split(torch.arange(len(mask_dataset)),mask_dataset.y)):
    print('Fold {}'.format(fold+1))

    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)

    train_loader = DataLoader(mask_dataset,batch_size=BATCH,sampler=train_sampler,shuffle=False)
    val_loader = DataLoader(mask_dataset,batch_size=BATCH,sampler=val_sampler,shuffle=False)
    for epoch in range(EPOCH):
        train_acc, train_f1, train_loss = model_train(mymodel, device, train_loader,criterion,optimizer)
        val_acc, val_f1, val_loss = model_val(mymodel, device, train_loader,criterion)

        if pre_train_f1 < train_f1:
            pre_train_acc, pre_train_f1 = train_acc, train_f1
            torch.save({
                'epoch': epoch,
                'model_state_dict':mymodel.state_dict(),
                'optimizer_state_dict':optimizer.state_dict(),
                'train_acc': train_acc,
                'train_f1': train_f1
            },f"saved/checkpoint_model_train_{fold}_{epoch}_{train_acc}_{train_f1}.pt") 
        if pre_val_f1 < val_f1:
            pre_val_f1, pre_val_acc = val_f1, val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict':mymodel.state_dict(),
                'optimizer_state_dict':optimizer.state_dict(),
                'train_acc': val_acc,
                'train_f1': val_f1
            },f"saved/checkpoint_model_val_{fold}_{epoch}_{val_acc}_{val_f1}.pt") 
        
        writer.add_scalar(f'Acc/fold{fold}/train',train_acc,epoch)
        writer.add_scalar(f'f1/fold{fold}/train',train_f1,epoch)
        writer.add_scalar(f'loss/fold{fold}/train',train_loss,epoch)

        writer.add_scalar(f'Acc/fold{fold}/val',val_acc,epoch)
        writer.add_scalar(f'f1/fold{fold}/val',val_f1,epoch)
        writer.add_scalar(f'loss/fold{fold}/train',val_loss,epoch)
        
        print(f"EPOCH: {epoch} train_acc: {train_acc}, train_f1: {train_f1}, loss: {train_loss}")
        print(f"EPOCH: {epoch} val_acc:{val_acc}, val_f1:{val_f1}, loss: {val_loss}")

Fold 1
EPOCH: 0 train_acc: 0.43974867724867706, train_f1: 0.1703040245563125, loss: 0.003077421337366104
EPOCH: 0 val_acc:0.5276455026455024, val_f1:0.2251404051101525, loss: 0.0022668696474283934
EPOCH: 1 train_acc: 0.6146164021164007, train_f1: 0.3202116809145975, loss: 0.0016764879692345858
EPOCH: 1 val_acc:0.698544973544973, val_f1:0.43257243520383065, loss: 0.002298537874594331
EPOCH: 2 train_acc: 0.7306878306878306, train_f1: 0.4931482978913052, loss: 0.0011985249584540725
EPOCH: 2 val_acc:0.7545634920634925, val_f1:0.5542364661206166, loss: 0.0018170926487073302
EPOCH: 3 train_acc: 0.8179894179894175, train_f1: 0.63740022514339, loss: 0.0011417614296078682
EPOCH: 3 val_acc:0.8660052910052901, val_f1:0.7188184541647106, loss: 0.001402572263032198
EPOCH: 4 train_acc: 0.8534391534391527, train_f1: 0.7038131449170523, loss: 0.0014214685652405024
EPOCH: 4 val_acc:0.8608465608465595, val_f1:0.7265722890957965, loss: 0.0014630734222009778
EPOCH: 5 train_acc: 0.8769841269841255, train_f