In [1]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()

In [2]:
import GPUtil
GPUtil.showUtilization()

| ID | GPU | MEM |
------------------
|  0 |  0% | 12% |


In [3]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import random
import os
from torchmetrics import F1Score

In [4]:
os.chdir('../input/data')
os.getcwd()

'/opt/ml/input/data'

In [5]:
!find . -regex ".*\.\_[a-zA-Z0-9._]+" -delete

In [6]:
random_seed = 12
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed) # if use multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)

In [7]:
train_dir_path = '/opt/ml/input/data/train/'
train_image_path = '/opt/ml/input/data/train/images/'

dt_train = pd.read_csv(train_dir_path+'train.csv')
dt_train

Unnamed: 0,id,gender,race,age,path
0,000001,female,Asian,45,000001_female_Asian_45
1,000002,female,Asian,52,000002_female_Asian_52
2,000004,male,Asian,54,000004_male_Asian_54
3,000005,female,Asian,58,000005_female_Asian_58
4,000006,female,Asian,59,000006_female_Asian_59
...,...,...,...,...,...
2695,006954,male,Asian,19,006954_male_Asian_19
2696,006955,male,Asian,19,006955_male_Asian_19
2697,006956,male,Asian,19,006956_male_Asian_19
2698,006957,male,Asian,20,006957_male_Asian_20


In [8]:
def get_age_range(age):
    if age < 30:
        return 0
    elif 30 <= age < 60:
        return 1
    else:
        return 2

In [9]:
dt_train['age_range'] = dt_train['age'].apply(lambda x : get_age_range(x))

In [10]:
dt_train

Unnamed: 0,id,gender,race,age,path,age_range
0,000001,female,Asian,45,000001_female_Asian_45,1
1,000002,female,Asian,52,000002_female_Asian_52,1
2,000004,male,Asian,54,000004_male_Asian_54,1
3,000005,female,Asian,58,000005_female_Asian_58,1
4,000006,female,Asian,59,000006_female_Asian_59,1
...,...,...,...,...,...,...
2695,006954,male,Asian,19,006954_male_Asian_19,0
2696,006955,male,Asian,19,006955_male_Asian_19,0
2697,006956,male,Asian,19,006956_male_Asian_19,0
2698,006957,male,Asian,20,006957_male_Asian_20,0


In [11]:
train_idx, valid_idx = train_test_split(np.arange(len(dt_train)),
                                       test_size=0.2,
                                       shuffle=True,
                                       stratify=dt_train['age_range'])

In [12]:
train_image = []
train_label = []

for idx in train_idx:
    path = dt_train.iloc[idx]['path']
    for file_name in [i for i in os.listdir(train_image_path+path) if i[0] != '.']:
        train_image.append(train_image_path+path+'/'+file_name)
        train_label.append((path.split('_')[1], path.split('_')[3], file_name.split('.')[0]))                                 

In [13]:
valid_image = []
valid_label = []

for idx in valid_idx:
    path = dt_train.iloc[idx]['path']
    for file_name in [i for i in os.listdir(train_image_path+path) if i[0] != '.']:
        valid_image.append(train_image_path+path+'/'+file_name)
        valid_label.append((path.split('_')[1], path.split('_')[3], file_name.split('.')[0]))                                 

In [14]:
def onehot_enc(x):
    def gender(i):
        if i == 'male':
            return 0
        elif i == 'female':
            return 3
    def age(j):
        j = int(j)
        if j < 30:
            return 0
        elif j >= 30 and j < 60:
            return 1
        elif j >= 60:
            return 2
    def mask(k):
        if k == 'normal':
            return 12
        elif 'incorrect' in k:
            return 6
        else:
            return 0
    return gender(x[0]) + age(x[1]) + mask(x[2])

In [15]:
train_data = pd.Series(train_image)
train_label = pd.Series(train_label)

valid_data = pd.Series(valid_image)
valid_label = pd.Series(valid_label)

In [16]:
class Dataset_Mask(Dataset):
    def __init__(self, data, label, encoding=True, midcrop=True, transform=None):
        self.encoding = encoding
        self.midcrop = midcrop
        self.data = data
        self.label = label
        self.transform = transform
        
        if encoding:
            self.label = self.label.apply(onehot_enc)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        X = cv2.cvtColor(cv2.imread(self.data[idx]), cv2.COLOR_BGR2RGB)
        y = self.label[idx]
        
        if self.midcrop:
            X = X[70:420, 17:367]
        
        if self.transform:
            return self.transform(X), y
        return X, y

In [17]:
mask_train_set = Dataset_Mask(data=train_data, label=train_label, transform = transforms.Compose([
                                transforms.ToTensor()
                            ]))

In [18]:
mask_val_set = Dataset_Mask(data=valid_data, label=valid_label, transform = transforms.Compose([
                                transforms.ToTensor()
                            ]))

In [19]:
# t_image = [mask_train_set[i][1] for i in range(len(mask_train_set))]
# v_image = [mask_val_set[i][1] for i in range(len(mask_val_set))]

In [20]:
# t_df = pd.DataFrame(t_image, columns=['counts'])
# v_df = pd.DataFrame(v_image, columns=['counts'])

In [21]:
# import seaborn as sns

# fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# sns.countplot(x='counts', data=t_df, ax=axes[0])
# axes[0].set_xlabel("train set labels")
# sns.countplot(x='counts', data=v_df, ax=axes[1])
# axes[1].set_xlabel("valid set labels")

In [22]:
print(f'training data size : {len(mask_train_set)}')
print(f'validation data size : {len(mask_val_set)}')

training data size : 15120
validation data size : 3780


In [23]:
batch_size = 512

train_dataloader_mask = DataLoader(dataset = mask_train_set, batch_size=batch_size, num_workers=2)
val_dataloader_mask = DataLoader(dataset = mask_val_set, batch_size=batch_size, num_workers=2)

In [24]:
model = torchvision.models.efficientnet_v2_s(weights='IMAGENET1K_V1')
for param in model.parameters():
    param.requires_grad = False
    
# print('필요 입력 채널 개수', model.conv1.weight.shape[1])
# print('네트워크 출력 채널 개수', model.fc.weight.shape[0])
print(model)

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  

In [25]:
import math
class_num = 18

model.classifier[1] = nn.Linear(in_features=1280, out_features=class_num, bias=True)

nn.init.xavier_uniform_(model.classifier[1].weight)
stdv = 1. / math.sqrt(model.classifier[1].weight.size(1))
model.classifier[1].bias.data.uniform_(-stdv, stdv)

# print('필요 입력 채널 개수', model.conv1.weight.shape[1])
print('네트워크 출력 채널 개수', model.classifier[1].weight.shape[0])

네트워크 출력 채널 개수 18


In [26]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"using {device}")

model.to(device)

LEARNING_RATE = 0.0001
NUM_EPOCH = 50

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

using cuda:0


In [27]:
np.set_printoptions(precision=3)
n_param = 0
for p_idx, (param_name, param) in enumerate(model.named_parameters()):
    if param.requires_grad:
        param_numpy = param.detach().cpu().numpy()
        n_param += len(param_numpy.reshape(-1))
        print ("[%d] name:[%s] shape:[%s]."%(p_idx,param_name,param_numpy.shape))
        print ("    val:%s"%(param_numpy.reshape(-1)[:5]))
print ("Total number of parameters:[%s]."%(format(n_param,',d')))

[450] name:[classifier.1.weight] shape:[(18, 1280)].
    val:[ 0.067 -0.059  0.026 -0.015 -0.002]
[451] name:[classifier.1.bias] shape:[(18,)].
    val:[-0.007  0.028 -0.011  0.014  0.017]
Total number of parameters:[23,058].


In [28]:
best_val_acc = 0
best_val_loss = np.inf
patience = 10
cur_count = 0

f1 = F1Score(num_classes=class_num, average='macro').to(device)
best_f1_score = 0

for epoch in range(NUM_EPOCH):
    model.train()
    loss_value = 0
    matches = 0
    for train_batch in tqdm(train_dataloader_mask):
        inputs, labels = train_batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        outs = model(inputs)
        preds = torch.argmax(outs, dim=-1)
        loss = criterion(outs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if epoch % 5 == 0:
            torch.save(model, '/opt/ml/checkpoint/efficientnet/checkpoint_ep_%d.pt'% epoch)
        
        loss_value += loss.item()
        matches += (preds == labels).sum().item()
        
        train_loss = loss_value / batch_size
        train_acc = matches / batch_size
        
        loss_value = 0
        matches = 0
    print(f"epoch[{epoch}/{NUM_EPOCH}] training loss {train_loss:.3f}, training accuracy {train_acc:.3f}")
        
    with torch.no_grad():
        model.eval()
        val_loss_items = []
        val_acc_items = []
        f1_score = 0
        for val_batch in val_dataloader_mask:
            inputs, labels = val_batch
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outs = model(inputs)
            preds = torch.argmax(outs, dim=-1)
            
            loss_item = criterion(outs, labels).item()
            acc_item = (labels==preds).sum().item()
            val_loss_items.append(loss_item)
            val_acc_items.append(acc_item)
            f1_score += f1(outs, labels)
            
        val_loss = np.sum(val_loss_items) / len(val_dataloader_mask)
        val_acc = np.sum(val_acc_items) / len(mask_val_set)
        f1_score /= len(val_dataloader_mask)
        
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            
        if f1_score > best_f1_score:
            best_f1_score = f1_score
            cur_count = 0
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch' : epoch,
                'val_acc' : best_val_acc,
                'val_loss' : best_val_loss,
                'f1_score' : best_f1_score}, '/opt/ml/checkpoint/efficientnet/checkpoint.tar')
#             torch.save(model, '/opt/ml/checkpoint/efficientnet/checkpoint_best.pt')
            print("Update checkpoint!!!")
        else:
            cur_count += 1
            if cur_count >= patience:
                print("Early Stopping!")
                break
            
            
        print(f"[val] acc : {val_acc:.3f}, loss : {val_loss:.3f}, f1 score: {f1_score:.3f}")
        print(f"best acc : {best_val_acc:.3f}, best loss : {best_val_loss:.3f}, best f1 : {best_f1_score:.3f}")

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[0/50] training loss 0.005, training accuracy 0.074
Update checkpoint!!!
[val] acc : 0.199, loss : 2.646, f1 score: 0.044
best acc : 0.199, best loss : 2.646, best f1 : 0.044


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[1/50] training loss 0.005, training accuracy 0.109
Update checkpoint!!!
[val] acc : 0.276, loss : 2.413, f1 score: 0.066
best acc : 0.276, best loss : 2.413, best f1 : 0.066


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[2/50] training loss 0.005, training accuracy 0.141
Update checkpoint!!!
[val] acc : 0.333, loss : 2.268, f1 score: 0.094
best acc : 0.333, best loss : 2.268, best f1 : 0.094


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[3/50] training loss 0.004, training accuracy 0.178
Update checkpoint!!!
[val] acc : 0.387, loss : 2.148, f1 score: 0.126
best acc : 0.387, best loss : 2.148, best f1 : 0.126


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[4/50] training loss 0.004, training accuracy 0.195
Update checkpoint!!!
[val] acc : 0.436, loss : 2.045, f1 score: 0.189
best acc : 0.436, best loss : 2.045, best f1 : 0.189


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[5/50] training loss 0.004, training accuracy 0.223
Update checkpoint!!!
[val] acc : 0.479, loss : 1.950, f1 score: 0.221
best acc : 0.479, best loss : 1.950, best f1 : 0.221


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[6/50] training loss 0.004, training accuracy 0.250
Update checkpoint!!!
[val] acc : 0.516, loss : 1.869, f1 score: 0.236
best acc : 0.516, best loss : 1.869, best f1 : 0.236


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[7/50] training loss 0.004, training accuracy 0.250
[val] acc : 0.544, loss : 1.793, f1 score: 0.234
best acc : 0.544, best loss : 1.793, best f1 : 0.236


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[8/50] training loss 0.004, training accuracy 0.266
Update checkpoint!!!
[val] acc : 0.569, loss : 1.723, f1 score: 0.254
best acc : 0.569, best loss : 1.723, best f1 : 0.254


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[9/50] training loss 0.003, training accuracy 0.281
Update checkpoint!!!
[val] acc : 0.596, loss : 1.660, f1 score: 0.271
best acc : 0.596, best loss : 1.660, best f1 : 0.271


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[10/50] training loss 0.003, training accuracy 0.289
Update checkpoint!!!
[val] acc : 0.615, loss : 1.604, f1 score: 0.277
best acc : 0.615, best loss : 1.604, best f1 : 0.277


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[11/50] training loss 0.003, training accuracy 0.293
Update checkpoint!!!
[val] acc : 0.626, loss : 1.553, f1 score: 0.282
best acc : 0.626, best loss : 1.553, best f1 : 0.282


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[12/50] training loss 0.003, training accuracy 0.312
Update checkpoint!!!
[val] acc : 0.637, loss : 1.506, f1 score: 0.286
best acc : 0.637, best loss : 1.506, best f1 : 0.286


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[13/50] training loss 0.003, training accuracy 0.297
[val] acc : 0.647, loss : 1.462, f1 score: 0.283
best acc : 0.647, best loss : 1.462, best f1 : 0.286


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[14/50] training loss 0.003, training accuracy 0.330
Update checkpoint!!!
[val] acc : 0.655, loss : 1.419, f1 score: 0.315
best acc : 0.655, best loss : 1.419, best f1 : 0.315


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[15/50] training loss 0.003, training accuracy 0.338
Update checkpoint!!!
[val] acc : 0.667, loss : 1.385, f1 score: 0.328
best acc : 0.667, best loss : 1.385, best f1 : 0.328


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[16/50] training loss 0.003, training accuracy 0.328
Update checkpoint!!!
[val] acc : 0.672, loss : 1.349, f1 score: 0.329
best acc : 0.672, best loss : 1.349, best f1 : 0.329


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[17/50] training loss 0.003, training accuracy 0.346
Update checkpoint!!!
[val] acc : 0.678, loss : 1.317, f1 score: 0.342
best acc : 0.678, best loss : 1.317, best f1 : 0.342


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[18/50] training loss 0.003, training accuracy 0.369
Update checkpoint!!!
[val] acc : 0.681, loss : 1.286, f1 score: 0.343
best acc : 0.681, best loss : 1.286, best f1 : 0.343


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[19/50] training loss 0.003, training accuracy 0.336
Update checkpoint!!!
[val] acc : 0.688, loss : 1.260, f1 score: 0.365
best acc : 0.688, best loss : 1.260, best f1 : 0.365


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[20/50] training loss 0.003, training accuracy 0.363
Update checkpoint!!!
[val] acc : 0.692, loss : 1.232, f1 score: 0.395
best acc : 0.692, best loss : 1.232, best f1 : 0.395


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[21/50] training loss 0.002, training accuracy 0.367
Update checkpoint!!!
[val] acc : 0.699, loss : 1.208, f1 score: 0.403
best acc : 0.699, best loss : 1.208, best f1 : 0.403


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[22/50] training loss 0.002, training accuracy 0.363
Update checkpoint!!!
[val] acc : 0.703, loss : 1.188, f1 score: 0.415
best acc : 0.703, best loss : 1.188, best f1 : 0.415


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[23/50] training loss 0.002, training accuracy 0.361
Update checkpoint!!!
[val] acc : 0.707, loss : 1.164, f1 score: 0.423
best acc : 0.707, best loss : 1.164, best f1 : 0.423


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[24/50] training loss 0.002, training accuracy 0.363
[val] acc : 0.711, loss : 1.144, f1 score: 0.423
best acc : 0.711, best loss : 1.144, best f1 : 0.423


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[25/50] training loss 0.002, training accuracy 0.377
Update checkpoint!!!
[val] acc : 0.713, loss : 1.125, f1 score: 0.445
best acc : 0.713, best loss : 1.125, best f1 : 0.445


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[26/50] training loss 0.002, training accuracy 0.367
Update checkpoint!!!
[val] acc : 0.717, loss : 1.105, f1 score: 0.457
best acc : 0.717, best loss : 1.105, best f1 : 0.457


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[27/50] training loss 0.002, training accuracy 0.367
Update checkpoint!!!
[val] acc : 0.720, loss : 1.088, f1 score: 0.464
best acc : 0.720, best loss : 1.088, best f1 : 0.464


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[28/50] training loss 0.002, training accuracy 0.373
[val] acc : 0.720, loss : 1.072, f1 score: 0.464
best acc : 0.720, best loss : 1.072, best f1 : 0.464


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[29/50] training loss 0.002, training accuracy 0.387
[val] acc : 0.724, loss : 1.056, f1 score: 0.464
best acc : 0.724, best loss : 1.056, best f1 : 0.464


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[30/50] training loss 0.002, training accuracy 0.375
[val] acc : 0.728, loss : 1.042, f1 score: 0.464
best acc : 0.728, best loss : 1.042, best f1 : 0.464


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[31/50] training loss 0.002, training accuracy 0.385
Update checkpoint!!!
[val] acc : 0.730, loss : 1.029, f1 score: 0.465
best acc : 0.730, best loss : 1.029, best f1 : 0.465


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[32/50] training loss 0.002, training accuracy 0.389
[val] acc : 0.731, loss : 1.017, f1 score: 0.462
best acc : 0.731, best loss : 1.017, best f1 : 0.465


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[33/50] training loss 0.002, training accuracy 0.391
Update checkpoint!!!
[val] acc : 0.733, loss : 1.003, f1 score: 0.465
best acc : 0.733, best loss : 1.003, best f1 : 0.465


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[34/50] training loss 0.002, training accuracy 0.369
Update checkpoint!!!
[val] acc : 0.736, loss : 0.989, f1 score: 0.468
best acc : 0.736, best loss : 0.989, best f1 : 0.468


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[35/50] training loss 0.002, training accuracy 0.393
Update checkpoint!!!
[val] acc : 0.738, loss : 0.978, f1 score: 0.478
best acc : 0.738, best loss : 0.978, best f1 : 0.478


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[36/50] training loss 0.002, training accuracy 0.383
[val] acc : 0.741, loss : 0.966, f1 score: 0.468
best acc : 0.741, best loss : 0.966, best f1 : 0.478


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[37/50] training loss 0.002, training accuracy 0.396
Update checkpoint!!!
[val] acc : 0.743, loss : 0.958, f1 score: 0.488
best acc : 0.743, best loss : 0.958, best f1 : 0.488


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[38/50] training loss 0.002, training accuracy 0.393
Update checkpoint!!!
[val] acc : 0.745, loss : 0.947, f1 score: 0.502
best acc : 0.745, best loss : 0.947, best f1 : 0.502


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[39/50] training loss 0.002, training accuracy 0.402
Update checkpoint!!!
[val] acc : 0.746, loss : 0.938, f1 score: 0.516
best acc : 0.746, best loss : 0.938, best f1 : 0.516


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[40/50] training loss 0.002, training accuracy 0.398
[val] acc : 0.747, loss : 0.930, f1 score: 0.506
best acc : 0.747, best loss : 0.930, best f1 : 0.516


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[41/50] training loss 0.002, training accuracy 0.398
Update checkpoint!!!
[val] acc : 0.749, loss : 0.921, f1 score: 0.522
best acc : 0.749, best loss : 0.921, best f1 : 0.522


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[42/50] training loss 0.002, training accuracy 0.391
Update checkpoint!!!
[val] acc : 0.750, loss : 0.912, f1 score: 0.536
best acc : 0.750, best loss : 0.912, best f1 : 0.536


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[43/50] training loss 0.002, training accuracy 0.402
[val] acc : 0.750, loss : 0.902, f1 score: 0.536
best acc : 0.750, best loss : 0.902, best f1 : 0.536


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[44/50] training loss 0.002, training accuracy 0.396
[val] acc : 0.752, loss : 0.896, f1 score: 0.526
best acc : 0.752, best loss : 0.896, best f1 : 0.536


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[45/50] training loss 0.002, training accuracy 0.385
[val] acc : 0.752, loss : 0.889, f1 score: 0.536
best acc : 0.752, best loss : 0.889, best f1 : 0.536


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[46/50] training loss 0.002, training accuracy 0.398
[val] acc : 0.753, loss : 0.879, f1 score: 0.536
best acc : 0.753, best loss : 0.879, best f1 : 0.536


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[47/50] training loss 0.002, training accuracy 0.402
[val] acc : 0.753, loss : 0.872, f1 score: 0.536
best acc : 0.753, best loss : 0.872, best f1 : 0.536


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[48/50] training loss 0.002, training accuracy 0.402
[val] acc : 0.754, loss : 0.866, f1 score: 0.536
best acc : 0.754, best loss : 0.866, best f1 : 0.536


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=30.0), HTML(value='')))


epoch[49/50] training loss 0.002, training accuracy 0.393
Update checkpoint!!!
[val] acc : 0.755, loss : 0.860, f1 score: 0.544
best acc : 0.755, best loss : 0.860, best f1 : 0.544


In [31]:
print(f'Best f1 score:{best_f1_score}')

Best f1 score:0.5443476438522339


In [32]:
print(f"best acc : {best_val_acc:.3f}, best loss : {best_val_loss:.3f}, best f1 : {best_f1_score:.3f}")

best acc : 0.755, best loss : 0.860, best f1 : 0.544


### fine tuning

In [33]:
batch_size = 32

train_dataloader_mask = DataLoader(dataset = mask_train_set, batch_size=batch_size, num_workers=2)
val_dataloader_mask = DataLoader(dataset = mask_val_set, batch_size=batch_size, num_workers=2)

In [35]:
checkpoint = torch.load('/opt/ml/checkpoint/efficientnet/checkpoint.tar')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

checkpoint_epoch = checkpoint['epoch']
best_val_acc = checkpoint['val_acc']
best_val_loss = checkpoint['val_loss']
best_f1_score = checkpoint['f1_score']
patience = 10
cur_count = 0

In [36]:
for epoch in range(NUM_EPOCH):
    model.train()
    loss_value = 0
    matches = 0
    for train_batch in tqdm(train_dataloader_mask):
        inputs, labels = train_batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        outs = model(inputs)
        preds = torch.argmax(outs, dim=-1)
        loss = criterion(outs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if epoch % 5 == 0:
            torch.save(model, '/opt/ml/checkpoint/efficientnet/fineTuning/checkpoint_ep_%d.pt'% epoch)
        
        loss_value += loss.item()
        matches += (preds == labels).sum().item()
        
        train_loss = loss_value / batch_size
        train_acc = matches / batch_size
        
        loss_value = 0
        matches = 0
    print(f"epoch[{epoch}/{NUM_EPOCH}] training loss {train_loss:.3f}, training accuracy {train_acc:.3f}")
        
    with torch.no_grad():
        model.eval()
        val_loss_items = []
        val_acc_items = []
        for val_batch in val_dataloader_mask:
            inputs, labels = val_batch
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outs = model(inputs)
            preds = torch.argmax(outs, dim=-1)
            
            loss_item = criterion(outs, labels).item()
            acc_item = (labels==preds).sum().item()
            val_loss_items.append(loss_item)
            val_acc_items.append(acc_item)
            
        val_loss = np.sum(val_loss_items) / len(val_dataloader_mask)
        val_acc = np.sum(val_acc_items) / len(mask_val_set)

        f1_score = f1(outs, labels)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            cur_count = 0
            torch.save(model, '/opt/ml/checkpoint/efficientnet/fineTuning/checkpoint_best.pt')
            print("Update checkpoint!!!")
        else: 
            if val_acc == best_val_acc and val_loss < best_val_loss:
                best_val_loss = val_loss
                cur_count = 0
                torch.save(model, '/opt/ml/checkpoint/efficientnet/fineTuning/checkpoint_best.pt')
                print("Update checkpoint!!!")
            else:
                cur_count += 1
                if cur_count >= patience:
                    print("Early Stopping!")
                    break
                
        if f1_score > best_f1_score:
            best_f1_score = f1_score
            
        print(f"[val] acc : {val_acc:.3f}, loss : {val_loss:.3f}, f1 score: {f1_score:.3f}")
        print(f"best acc : {best_val_acc:.3f}, best loss : {best_val_loss:.3f}, best f1 : {best_f1_score:.3f}")

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[0/50] training loss 0.047, training accuracy 0.219
[val] acc : 0.755, loss : 1.292, f1 score: 0.250
best acc : 0.755, best loss : 0.860, best f1 : 0.544


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[1/50] training loss 0.050, training accuracy 0.188
Update checkpoint!!!
[val] acc : 0.756, loss : 1.431, f1 score: 0.250
best acc : 0.756, best loss : 0.860, best f1 : 0.544


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[2/50] training loss 0.042, training accuracy 0.188
Update checkpoint!!!
[val] acc : 0.759, loss : 1.029, f1 score: 0.250
best acc : 0.759, best loss : 0.860, best f1 : 0.544


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[3/50] training loss 0.037, training accuracy 0.312
Update checkpoint!!!
[val] acc : 0.761, loss : 1.213, f1 score: 0.250
best acc : 0.761, best loss : 0.860, best f1 : 0.544


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[4/50] training loss 0.039, training accuracy 0.250
Update checkpoint!!!
[val] acc : 0.764, loss : 0.822, f1 score: 0.250
best acc : 0.764, best loss : 0.822, best f1 : 0.544


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[5/50] training loss 0.038, training accuracy 0.281
Update checkpoint!!!
[val] acc : 0.766, loss : 1.033, f1 score: 0.250
best acc : 0.766, best loss : 0.822, best f1 : 0.544


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[6/50] training loss 0.037, training accuracy 0.281
Update checkpoint!!!
[val] acc : 0.768, loss : 0.874, f1 score: 0.250
best acc : 0.768, best loss : 0.822, best f1 : 0.544


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[7/50] training loss 0.035, training accuracy 0.219
[val] acc : 0.767, loss : 0.797, f1 score: 0.250
best acc : 0.768, best loss : 0.797, best f1 : 0.544


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[8/50] training loss 0.040, training accuracy 0.281
Update checkpoint!!!
[val] acc : 0.774, loss : 0.858, f1 score: 0.250
best acc : 0.774, best loss : 0.797, best f1 : 0.544


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[9/50] training loss 0.038, training accuracy 0.188
[val] acc : 0.770, loss : 1.016, f1 score: 0.250
best acc : 0.774, best loss : 0.797, best f1 : 0.544


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[10/50] training loss 0.035, training accuracy 0.344
[val] acc : 0.771, loss : 1.233, f1 score: 0.417
best acc : 0.774, best loss : 0.797, best f1 : 0.544


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[11/50] training loss 0.037, training accuracy 0.281
Update checkpoint!!!
[val] acc : 0.775, loss : 0.751, f1 score: 0.250
best acc : 0.775, best loss : 0.751, best f1 : 0.544


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[12/50] training loss 0.027, training accuracy 0.375
[val] acc : 0.774, loss : 0.819, f1 score: 0.500
best acc : 0.775, best loss : 0.751, best f1 : 0.544


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[13/50] training loss 0.037, training accuracy 0.281
Update checkpoint!!!
[val] acc : 0.775, loss : 1.209, f1 score: 0.778
best acc : 0.775, best loss : 0.751, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[14/50] training loss 0.032, training accuracy 0.344
Update checkpoint!!!
[val] acc : 0.776, loss : 1.111, f1 score: 0.500
best acc : 0.776, best loss : 0.751, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[15/50] training loss 0.032, training accuracy 0.312
Update checkpoint!!!
[val] acc : 0.780, loss : 1.586, f1 score: 0.778
best acc : 0.780, best loss : 0.751, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[16/50] training loss 0.031, training accuracy 0.312
[val] acc : 0.779, loss : 2.105, f1 score: 0.778
best acc : 0.780, best loss : 0.751, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[17/50] training loss 0.033, training accuracy 0.344
[val] acc : 0.780, loss : 0.790, f1 score: 0.778
best acc : 0.780, best loss : 0.751, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[18/50] training loss 0.026, training accuracy 0.406
Update checkpoint!!!
[val] acc : 0.784, loss : 1.747, f1 score: 0.778
best acc : 0.784, best loss : 0.751, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[19/50] training loss 0.031, training accuracy 0.312
[val] acc : 0.783, loss : 0.792, f1 score: 0.778
best acc : 0.784, best loss : 0.751, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[20/50] training loss 0.026, training accuracy 0.406
[val] acc : 0.783, loss : 1.425, f1 score: 0.778
best acc : 0.784, best loss : 0.751, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[21/50] training loss 0.027, training accuracy 0.375
Update checkpoint!!!
[val] acc : 0.785, loss : 0.704, f1 score: 0.778
best acc : 0.785, best loss : 0.704, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[22/50] training loss 0.029, training accuracy 0.344
Update checkpoint!!!
[val] acc : 0.786, loss : 0.821, f1 score: 0.778
best acc : 0.786, best loss : 0.704, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[23/50] training loss 0.024, training accuracy 0.406
Update checkpoint!!!
[val] acc : 0.788, loss : 1.166, f1 score: 0.500
best acc : 0.788, best loss : 0.704, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[24/50] training loss 0.031, training accuracy 0.312
Update checkpoint!!!
[val] acc : 0.788, loss : 1.439, f1 score: 0.778
best acc : 0.788, best loss : 0.704, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[25/50] training loss 0.027, training accuracy 0.438
[val] acc : 0.787, loss : 1.633, f1 score: 0.778
best acc : 0.788, best loss : 0.704, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[26/50] training loss 0.034, training accuracy 0.250
[val] acc : 0.787, loss : 1.088, f1 score: 0.778
best acc : 0.788, best loss : 0.704, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[27/50] training loss 0.022, training accuracy 0.406
Update checkpoint!!!
[val] acc : 0.790, loss : 1.995, f1 score: 0.778
best acc : 0.790, best loss : 0.704, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[28/50] training loss 0.023, training accuracy 0.406
[val] acc : 0.788, loss : 1.979, f1 score: 0.778
best acc : 0.790, best loss : 0.704, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[29/50] training loss 0.029, training accuracy 0.312
[val] acc : 0.789, loss : 0.933, f1 score: 0.500
best acc : 0.790, best loss : 0.704, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[30/50] training loss 0.024, training accuracy 0.375
[val] acc : 0.789, loss : 1.958, f1 score: 0.500
best acc : 0.790, best loss : 0.704, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[31/50] training loss 0.027, training accuracy 0.406
[val] acc : 0.790, loss : 2.523, f1 score: 0.778
best acc : 0.790, best loss : 0.704, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[32/50] training loss 0.025, training accuracy 0.406
Update checkpoint!!!
[val] acc : 0.793, loss : 1.102, f1 score: 0.778
best acc : 0.793, best loss : 0.704, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[33/50] training loss 0.026, training accuracy 0.375
[val] acc : 0.792, loss : 0.681, f1 score: 0.778
best acc : 0.793, best loss : 0.681, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[34/50] training loss 0.026, training accuracy 0.344
Update checkpoint!!!
[val] acc : 0.793, loss : 0.681, f1 score: 0.778
best acc : 0.793, best loss : 0.681, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[35/50] training loss 0.023, training accuracy 0.344
Update checkpoint!!!
[val] acc : 0.794, loss : 0.672, f1 score: 0.778
best acc : 0.794, best loss : 0.672, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[36/50] training loss 0.020, training accuracy 0.438
[val] acc : 0.793, loss : 1.480, f1 score: 0.778
best acc : 0.794, best loss : 0.672, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[37/50] training loss 0.023, training accuracy 0.344
Update checkpoint!!!
[val] acc : 0.794, loss : 1.054, f1 score: 0.778
best acc : 0.794, best loss : 0.672, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[38/50] training loss 0.024, training accuracy 0.406
[val] acc : 0.792, loss : 1.609, f1 score: 0.778
best acc : 0.794, best loss : 0.672, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[39/50] training loss 0.026, training accuracy 0.312
[val] acc : 0.792, loss : 2.284, f1 score: 0.778
best acc : 0.794, best loss : 0.672, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[40/50] training loss 0.027, training accuracy 0.375
[val] acc : 0.792, loss : 0.716, f1 score: 0.778
best acc : 0.794, best loss : 0.672, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[41/50] training loss 0.026, training accuracy 0.375
[val] acc : 0.793, loss : 0.668, f1 score: 0.778
best acc : 0.794, best loss : 0.668, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[42/50] training loss 0.026, training accuracy 0.375
Update checkpoint!!!
[val] acc : 0.796, loss : 1.540, f1 score: 0.778
best acc : 0.796, best loss : 0.668, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[43/50] training loss 0.024, training accuracy 0.344
[val] acc : 0.794, loss : 0.693, f1 score: 0.778
best acc : 0.796, best loss : 0.668, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[44/50] training loss 0.026, training accuracy 0.344
[val] acc : 0.793, loss : 0.733, f1 score: 0.500
best acc : 0.796, best loss : 0.668, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[45/50] training loss 0.031, training accuracy 0.375
[val] acc : 0.794, loss : 0.660, f1 score: 0.778
best acc : 0.796, best loss : 0.660, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[46/50] training loss 0.024, training accuracy 0.344
[val] acc : 0.795, loss : 3.457, f1 score: 0.778
best acc : 0.796, best loss : 0.660, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[47/50] training loss 0.025, training accuracy 0.344
[val] acc : 0.796, loss : 1.679, f1 score: 0.778
best acc : 0.796, best loss : 0.660, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[48/50] training loss 0.025, training accuracy 0.375
Update checkpoint!!!
[val] acc : 0.796, loss : 0.652, f1 score: 0.778
best acc : 0.796, best loss : 0.652, best f1 : 0.778


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=473.0), HTML(value='')))


epoch[49/50] training loss 0.023, training accuracy 0.406
[val] acc : 0.796, loss : 1.524, f1 score: 0.778
best acc : 0.796, best loss : 0.652, best f1 : 0.778


In [37]:
print(f"best acc : {best_val_acc:.3f}, best loss : {best_val_loss:.3f}, best f1 : {best_f1_score:.3f}")

best acc : 0.796, best loss : 0.652, best f1 : 0.778


In [80]:
# meta 데이터와 이미지 경로를 불러옵니다.
test_dir_path = '/opt/ml/input/data/eval/'
test_image_path = '/opt/ml/input/data/eval/images/'

model = torch.load('/opt/ml/checkpoint/efficientnet/fineTuning/checkpoint_best.pt')
submission = pd.read_csv(test_dir_path+'info.csv')
submission.head()

Unnamed: 0,ImageID,ans
0,cbc5c6e168e63498590db46022617123f1fe1268.jpg,0
1,0e72482bf56b3581c081f7da2a6180b8792c7089.jpg,0
2,b549040c49190cedc41327748aeb197c1670f14d.jpg,0
3,4f9cb2a045c6d5b9e50ad3459ea7b791eb6e18bc.jpg,0
4,248428d9a4a5b6229a7081c32851b90cb8d38d0c.jpg,0


In [81]:
image_paths = [os.path.join(test_image_path, img_id) for img_id in submission.ImageID]
test_image = pd.Series(image_paths)

In [82]:
class Test_Dataset(Dataset):
    def __init__(self, midcrop=True, transform=None):
        self.midcrop = midcrop
        self.data = test_image
        self.transform = transform
        
    def __len__(self):
        return len(test_image)
    
    def __getitem__(self, idx):
        img = cv2.cvtColor(cv2.imread(self.data[idx]), cv2.COLOR_BGR2RGB)
        
        if self.midcrop:
            img = img[70:420, 17:367]
            
        if self.transform:
            img = self.transform(img)
            
        return img

In [83]:
dataset = Test_Dataset(transform = transforms.Compose([
                            transforms.ToTensor()
                        ]))

loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)

# 모델을 정의합니다. (학습한 모델이 있다면 torch.load로 모델을 불러주세요!)
device = torch.device('cuda')
model = model.to(device)
model.eval()

# 모델이 테스트 데이터셋을 예측하고 결과를 저장합니다.
all_predictions = []
for images in loader:
    with torch.no_grad():
        images = images.to(device)
        pred = model(images)
        pred = pred.argmax(dim=-1)
        all_predictions.extend(pred.cpu().numpy())
submission['ans'] = all_predictions

# 제출할 파일을 저장합니다.
submission.to_csv(os.path.join(test_dir_path, 'submission.csv'), index=False)
print('test inference is done!')

test inference is done!
