In [1]:
import os, torch, copy, cv2, sys, random
# from datetime import datetime, timezone, timedelta
from PIL import Image
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

In [2]:
# 시드(seed) 설정

RANDOM_SEED = 2022

torch.manual_seed(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

In [3]:
# parameters

### 데이터 디렉토리 설정 ###
DATA_DIR= 'data'
NUM_CLS = 2

EPOCHS = 30
BATCH_SIZE = 32
LEARNING_RATE = 0.0005
EARLY_STOPPING_PATIENCE = 10
INPUT_SHAPE = 128

os.environ["CUDA_VISIBLE_DEVICES"]="0"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
class CustomDataset(Dataset):
    def __init__(self, data_dir, mode, input_shape):
        self.data_dir = data_dir
        self.mode = mode
        self.input_shape = input_shape
        
        # Loading dataset
        self.db = self.data_loader()
        
        # Dataset split
        if self.mode == 'train':
            self.db = self.db[:int(len(self.db) * 0.9)]
        elif self.mode == 'val':
            self.db = self.db[int(len(self.db) * 0.9):]
            self.db.reset_index(inplace=True)
        else:
            print(f'!!! Invalid split {self.mode}... !!!')
            
        # Transform function
        self.transform = transforms.Compose([transforms.Resize(self.input_shape),
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    def data_loader(self):
        print('Loading ' + self.mode + ' dataset..')
        if not os.path.isdir(self.data_dir):
            print(f'!!! Cannot find {self.data_dir}... !!!')
            sys.exit()
        
        # (COVID : 1, No : 0)
        db = pd.read_csv(os.path.join(self.data_dir, 'train.csv'))
        
        return db

    def __len__(self):
        return len(self.db)

    def __getitem__(self, index):
        data = copy.deepcopy(self.db.loc[index])

        # Loading image
        cvimg = cv2.imread(os.path.join(self.data_dir,'train',data['file_name']), cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        if not isinstance(cvimg, np.ndarray):
            raise IOError("Fail to read %s" % data['file_name'])

        # Preprocessing images
        trans_image = self.transform(Image.fromarray(cvimg))

        return trans_image, data['COVID']

In [5]:
import torch.nn.functional as F

class custom_CNN(nn.Module):
    def __init__(self, num_classes):
        super(custom_CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=25, kernel_size=5)
        
        self.fc1 = nn.Linear(in_features=25*29*29, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=num_classes)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x))) # (32, 3, 128, 128) -> (32, 8, 62, 62)
        x = self.pool(F.relu(self.conv2(x))) # (32, 8, 62, 62) -> (32, 25, 29, 29)
        
        x = torch.flatten(x,1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        
        output = self.softmax(x)
        
        return output

In [6]:
class LossEarlyStopper():
    """Early stopper
    
    Attributes:
        patience (int): loss가 줄어들지 않아도 학습할 epoch 수
        patience_counter (int): loss 가 줄어들지 않을 때 마다 1씩 증가, 감소 시 0으로 리셋
        min_loss (float): 최소 loss
        stop (bool): True 일 때 학습 중단

    """

    def __init__(self, patience: int)-> None:
        self.patience = patience

        self.patience_counter = 0
        self.min_loss = np.Inf
        self.stop = False
        self.save_model = False

    def check_early_stopping(self, loss: float)-> None:
        """Early stopping 여부 판단"""  

        if self.min_loss == np.Inf:
            self.min_loss = loss
            return None

        elif loss > self.min_loss:
            self.patience_counter += 1
            msg = f"Early stopping counter {self.patience_counter}/{self.patience}"

            if self.patience_counter == self.patience:
                self.stop = True
                
        elif loss <= self.min_loss:
            self.patience_counter = 0
            self.save_model = True
            msg = f"Validation loss decreased {self.min_loss} -> {loss}"
            self.min_loss = loss
        
        print(msg)

In [7]:
class Trainer():
    """ epoch에 대한 학습 및 검증 절차 정의"""
    
    def __init__(self, loss_fn, model, device, metric_fn, optimizer=None, scheduler=None):
        """ 초기화
        """
        self.loss_fn = loss_fn
        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.metric_fn = metric_fn

    def train_epoch(self, dataloader, epoch_index):
        """ 한 epoch에서 수행되는 학습 절차"""
        
        self.model.train()
        train_total_loss = 0
        target_lst = []
        pred_lst = []
        prob_lst = []

        for batch_index, (img, label) in enumerate(dataloader):
            img = img.to(self.device)
            label = label.to(self.device).float()
            
            pred = self.model(img)
            
            loss = self.loss_fn(pred[:,1], label)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            
            train_total_loss += loss.item()
            prob_lst.extend(pred[:, 1].cpu().tolist())
            target_lst.extend(label.cpu().tolist())
            pred_lst.extend(pred.argmax(dim=1).cpu().tolist())
        self.train_mean_loss = train_total_loss / batch_index
        self.train_score, f1 = self.metric_fn(y_pred=pred_lst, y_answer=target_lst)
        msg = f'Epoch {epoch_index}, Train loss: {self.train_mean_loss}, Acc: {self.train_score}, F1-Macro: {f1}'
        print(msg)

    def validate_epoch(self, dataloader, epoch_index):
        """ 한 epoch에서 수행되는 검증 절차
        """
        self.model.eval()
        val_total_loss = 0
        target_lst = []
        pred_lst = []
        prob_lst = []

        for batch_index, (img, label) in enumerate(dataloader):
            img = img.to(self.device)
            label = label.to(self.device).float()
            pred = self.model(img)
            
            loss = self.loss_fn(pred[:,1], label)
            val_total_loss += loss.item()
            prob_lst.extend(pred[:, 1].cpu().tolist())
            target_lst.extend(label.cpu().tolist())
            pred_lst.extend(pred.argmax(dim=1).cpu().tolist())
        self.val_mean_loss = val_total_loss / batch_index
        self.validation_score, f1 = self.metric_fn(y_pred=pred_lst, y_answer=target_lst)
        msg = f'Epoch {epoch_index}, Val loss: {self.val_mean_loss}, Acc: {self.validation_score}, F1-Macro: {f1}'
        print(msg)

In [8]:
from sklearn.metrics import accuracy_score, f1_score

def get_metric_fn(y_pred, y_answer):
    """ 성능을 반환하는 함수"""
    
    assert len(y_pred) == len(y_answer), 'The size of prediction and answer are not same.'
    accuracy = accuracy_score(y_answer, y_pred)
    f1 = f1_score(y_answer, y_pred, average='macro')
    return accuracy, f1

In [9]:
# Load dataset & dataloader
train_dataset = CustomDataset(data_dir=DATA_DIR, mode='train', input_shape=INPUT_SHAPE)
validation_dataset = CustomDataset(data_dir=DATA_DIR, mode='val', input_shape=INPUT_SHAPE)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True)
print('Train set samples:',len(train_dataset),  'Val set samples:', len(validation_dataset))

Loading train dataset..
Loading val dataset..
Train set samples: 581 Val set samples: 65


In [10]:
# Load Model
model = custom_CNN(NUM_CLS).to(DEVICE)

# # Save Initial Model
# torch.save(model.state_dict(), 'initial.pt')

# Set optimizer, scheduler, loss function, metric function
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler =  optim.lr_scheduler.OneCycleLR(optimizer=optimizer, pct_start=0.1, div_factor=1e5, max_lr=0.0001, epochs=EPOCHS, steps_per_epoch=len(train_dataloader))
loss_fn = nn.BCELoss()
metric_fn = get_metric_fn


# Set trainer
trainer = Trainer(loss_fn, model, DEVICE, metric_fn, optimizer, scheduler)

# Set earlystopper
early_stopper = LossEarlyStopper(patience=EARLY_STOPPING_PATIENCE)

In [11]:
for epoch_index in tqdm(range(EPOCHS)):

    trainer.train_epoch(train_dataloader, epoch_index)
    trainer.validate_epoch(validation_dataloader, epoch_index)

    # early_stopping check
    early_stopper.check_early_stopping(loss=trainer.val_mean_loss)

    if early_stopper.stop:
        print('Early stopped')
        break

    if early_stopper.save_model:
        check_point = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()
        }
        torch.save(check_point, 'best.pt')


  0% 0/30 [00:00<?, ?it/s]

Epoch 0, Train loss: 0.7320066326194339, Acc: 0.5851979345955249, F1-Macro: 0.562621627615677


  3% 1/30 [01:34<45:52, 94.91s/it]

Epoch 0, Val loss: 1.0244255661964417, Acc: 0.5384615384615384, F1-Macro: 0.5125
Epoch 1, Train loss: 0.7180679142475128, Acc: 0.5886402753872634, F1-Macro: 0.5884012864405022
Epoch 1, Val loss: 0.8881474882364273, Acc: 0.6307692307692307, F1-Macro: 0.5962732919254659
Validation loss decreased 1.0244255661964417 -> 0.8881474882364273


  7% 2/30 [03:17<45:18, 97.07s/it]

Epoch 2, Train loss: 0.7049072682857513, Acc: 0.5714285714285714, F1-Macro: 0.5599087383070956
Epoch 2, Val loss: 1.1285955309867859, Acc: 0.5846153846153846, F1-Macro: 0.5190463140586462
Early stopping counter 1/10


 10% 3/30 [04:56<43:59, 97.77s/it]

Epoch 3, Train loss: 0.6896312601036496, Acc: 0.621342512908778, F1-Macro: 0.6212875394060063
Epoch 3, Val loss: 0.7613501474261284, Acc: 0.6307692307692307, F1-Macro: 0.6285714285714286
Validation loss decreased 0.8881474882364273 -> 0.7613501474261284


 13% 4/30 [06:35<42:29, 98.05s/it]

Epoch 4, Train loss: 0.6587940057118734, Acc: 0.6437177280550774, F1-Macro: 0.6437177280550775
Epoch 4, Val loss: 0.7195451259613037, Acc: 0.676923076923077, F1-Macro: 0.675694939415538
Validation loss decreased 0.7613501474261284 -> 0.7195451259613037


 17% 5/30 [08:16<41:15, 99.02s/it]

Epoch 5, Train loss: 0.6641549633608924, Acc: 0.6592082616179001, F1-Macro: 0.657502679528403
Epoch 5, Val loss: 0.906316339969635, Acc: 0.676923076923077, F1-Macro: 0.6690909090909092
Early stopping counter 1/10


 20% 6/30 [09:54<39:32, 98.86s/it]

Epoch 6, Train loss: 0.6229612777630488, Acc: 0.6815834767641996, F1-Macro: 0.6815797035759887
Epoch 6, Val loss: 0.7854321002960205, Acc: 0.6461538461538462, F1-Macro: 0.6167649320687003
Early stopping counter 2/10


 23% 7/30 [11:26<37:06, 96.81s/it]

Epoch 7, Train loss: 0.6156510992182626, Acc: 0.7108433734939759, F1-Macro: 0.7090982785751752
Epoch 7, Val loss: 0.7753697633743286, Acc: 0.6615384615384615, F1-Macro: 0.6595238095238096
Early stopping counter 3/10


 27% 8/30 [13:02<35:21, 96.44s/it]

Epoch 8, Train loss: 0.5741558753781848, Acc: 0.729776247848537, F1-Macro: 0.7295166307374339
Epoch 8, Val loss: 0.6857179999351501, Acc: 0.7846153846153846, F1-Macro: 0.7845643939393939
Validation loss decreased 0.7195451259613037 -> 0.6857179999351501


 30% 9/30 [14:37<33:34, 95.95s/it]

Epoch 9, Train loss: 0.5744568490319781, Acc: 0.7418244406196214, F1-Macro: 0.7404024878467258
Epoch 9, Val loss: 0.6752972155809402, Acc: 0.8307692307692308, F1-Macro: 0.8306088604596067
Validation loss decreased 0.6857179999351501 -> 0.6752972155809402


 33% 10/30 [16:11<31:47, 95.40s/it]

Epoch 10, Train loss: 0.5676499572065141, Acc: 0.7228915662650602, F1-Macro: 0.7211435555754296
Epoch 10, Val loss: 0.6901214644312859, Acc: 0.6307692307692307, F1-Macro: 0.5962732919254659
Early stopping counter 1/10


 37% 11/30 [17:50<30:31, 96.42s/it]

Epoch 11, Train loss: 0.5479644619756274, Acc: 0.7573149741824441, F1-Macro: 0.755647160238265
Epoch 11, Val loss: 0.6773225665092468, Acc: 0.7384615384615385, F1-Macro: 0.738213693437574
Early stopping counter 2/10


 40% 12/30 [19:24<28:43, 95.76s/it]

Epoch 12, Train loss: 0.5250816643238068, Acc: 0.76592082616179, F1-Macro: 0.764967637540453
Epoch 12, Val loss: 0.7691425681114197, Acc: 0.7230769230769231, F1-Macro: 0.7115384615384615
Early stopping counter 3/10


 43% 13/30 [21:01<27:15, 96.22s/it]

Epoch 13, Train loss: 0.5012414885891808, Acc: 0.7831325301204819, F1-Macro: 0.7824306331581825
Epoch 13, Val loss: 0.7222307473421097, Acc: 0.6615384615384615, F1-Macro: 0.6549227799227799
Early stopping counter 4/10


 47% 14/30 [22:40<25:52, 97.06s/it]

Epoch 14, Train loss: 0.5064700378312005, Acc: 0.7796901893287436, F1-Macro: 0.7792815252748295
Epoch 14, Val loss: 0.6739183664321899, Acc: 0.6923076923076923, F1-Macro: 0.6886973180076629
Validation loss decreased 0.6752972155809402 -> 0.6739183664321899


 50% 15/30 [24:16<24:11, 96.75s/it]

Epoch 15, Train loss: 0.500286739733484, Acc: 0.7796901893287436, F1-Macro: 0.7776980653801087
Epoch 15, Val loss: 0.8175498247146606, Acc: 0.7538461538461538, F1-Macro: 0.7523809523809524
Early stopping counter 1/10


 53% 16/30 [25:53<22:35, 96.81s/it]

Epoch 16, Train loss: 0.505645043320126, Acc: 0.7831325301204819, F1-Macro: 0.7824306331581825
Epoch 16, Val loss: 1.0814284086227417, Acc: 0.6153846153846154, F1-Macro: 0.583440143552935
Early stopping counter 2/10


 57% 17/30 [27:29<20:54, 96.47s/it]

Epoch 17, Train loss: 0.5317173467742072, Acc: 0.7607573149741824, F1-Macro: 0.7588254090552087
Epoch 17, Val loss: 0.6262766867876053, Acc: 0.7076923076923077, F1-Macro: 0.6973780936045086
Validation loss decreased 0.6739183664321899 -> 0.6262766867876053


 60% 18/30 [29:05<19:17, 96.49s/it]

Epoch 18, Train loss: 0.4940005871984694, Acc: 0.7762478485370051, F1-Macro: 0.7754329004329005
Epoch 18, Val loss: 0.8391310572624207, Acc: 0.676923076923077, F1-Macro: 0.6655231560891939
Early stopping counter 1/10


 63% 19/30 [30:32<17:09, 93.61s/it]

Epoch 19, Train loss: 0.4723478886816237, Acc: 0.7951807228915663, F1-Macro: 0.7941047716328615
Epoch 19, Val loss: 0.7294589132070541, Acc: 0.7384615384615385, F1-Macro: 0.7374673319078165
Early stopping counter 2/10


 67% 20/30 [31:44<14:31, 87.13s/it]

Epoch 20, Train loss: 0.4859016107188331, Acc: 0.7934595524956971, F1-Macro: 0.7924258663808503
Epoch 20, Val loss: 1.2388448119163513, Acc: 0.7076923076923077, F1-Macro: 0.7076923076923077
Early stopping counter 3/10


 70% 21/30 [33:03<12:40, 84.47s/it]

Epoch 21, Train loss: 0.4505847924285465, Acc: 0.8055077452667814, F1-Macro: 0.8043858472998138
Epoch 21, Val loss: 0.7408526688814163, Acc: 0.6923076923076923, F1-Macro: 0.6904761904761905
Early stopping counter 4/10


 73% 22/30 [34:39<11:43, 87.93s/it]

Epoch 22, Train loss: 0.4619402206606335, Acc: 0.8175559380378657, F1-Macro: 0.8167301511724795
Epoch 22, Val loss: 0.7155401855707169, Acc: 0.7384615384615385, F1-Macro: 0.7344388368180726
Early stopping counter 5/10


 77% 23/30 [36:16<10:34, 90.65s/it]

Epoch 23, Train loss: 0.4282050629456838, Acc: 0.8278829604130808, F1-Macro: 0.8263266135782099
Epoch 23, Val loss: 0.7742991894483566, Acc: 0.7846153846153846, F1-Macro: 0.7845643939393939
Early stopping counter 6/10


 80% 24/30 [37:54<09:16, 92.83s/it]

Epoch 24, Train loss: 0.43082623928785324, Acc: 0.8037865748709122, F1-Macro: 0.8034786014384391
Epoch 24, Val loss: 0.5784971863031387, Acc: 0.7692307692307693, F1-Macro: 0.7690120824449184
Validation loss decreased 0.6262766867876053 -> 0.5784971863031387


 83% 25/30 [39:24<07:39, 91.98s/it]

Epoch 25, Train loss: 0.45767996046278214, Acc: 0.8141135972461274, F1-Macro: 0.8131832797427652
Epoch 25, Val loss: 0.8614400625228882, Acc: 0.7538461538461538, F1-Macro: 0.7533206831119544
Early stopping counter 1/10


 87% 26/30 [40:32<05:39, 84.79s/it]

Epoch 26, Train loss: 0.45562876264254254, Acc: 0.8123924268502581, F1-Macro: 0.8114068916637135
Epoch 26, Val loss: 0.7129895836114883, Acc: 0.7230769230769231, F1-Macro: 0.7198275862068966
Early stopping counter 2/10


 90% 27/30 [41:38<03:57, 79.15s/it]

Epoch 27, Train loss: 0.44910454915629494, Acc: 0.8261617900172117, F1-Macro: 0.8246470139999104
Epoch 27, Val loss: 0.7172622829675674, Acc: 0.7692307692307693, F1-Macro: 0.7690120824449184
Early stopping counter 3/10


 93% 28/30 [42:46<02:31, 75.82s/it]

Epoch 28, Train loss: 0.4375220421287749, Acc: 0.8296041308089501, F1-Macro: 0.8286212290502794
Epoch 28, Val loss: 0.6607215106487274, Acc: 0.6923076923076923, F1-Macro: 0.6904761904761905
Early stopping counter 4/10


 97% 29/30 [43:48<01:11, 71.84s/it]

Epoch 29, Train loss: 0.4514554755555259, Acc: 0.8382099827882961, F1-Macro: 0.8368523563712837
Epoch 29, Val loss: 0.8458727598190308, Acc: 0.7538461538461538, F1-Macro: 0.7523809523809524
Early stopping counter 5/10


100% 30/30 [44:51<00:00, 89.73s/it]


In [None]:
TRAINED_MODEL_PATH = 'best.pt'

In [None]:
class TestDataset(Dataset):
    def __init__(self, data_dir, input_shape):
        self.data_dir = data_dir
        self.input_shape = input_shape
        
        # Loading dataset
        self.db = self.data_loader()
        
        # Transform function
        self.transform = transforms.Compose([transforms.Resize(self.input_shape),
                                             
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    def data_loader(self):
        print('Loading test dataset..')
        if not os.path.isdir(self.data_dir):
            print(f'!!! Cannot find {self.data_dir}... !!!')
            sys.exit()
        
        db = pd.read_csv(os.path.join(self.data_dir, 'sample_submission.csv'))
        return db
    
    def __len__(self):
        return len(self.db)
    
    def __getitem__(self, index):
        data = copy.deepcopy(self.db.loc[index])
        
        # Loading image
        cvimg = cv2.imread(os.path.join(self.data_dir,'test',data['file_name']), cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION)
        if not isinstance(cvimg, np.ndarray):
            raise IOError("Fail to read %s" % data['file_name'])

        # Preprocessing images
        trans_image = self.transform(Image.fromarray(cvimg))

        return trans_image, data['file_name']