In [1]:
import torch, gc
gc.collect()
torch.cuda.empty_cache()

In [2]:
import GPUtil
GPUtil.showUtilization()

| ID | GPU | MEM |
------------------
|  0 |  0% |  0% |


In [3]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, sampler
from sklearn.model_selection import train_test_split

import numpy as np
import pandas as pd
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import random
import os
from torchmetrics import F1Score

In [4]:
random_seed = 12
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed) # if use multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(random_seed)
random.seed(random_seed)

In [5]:
train_dir_path = '/opt/ml/input/data/train/'
train_image_path = '/opt/ml/input/data/train/images/'

dt_train = pd.read_csv(train_dir_path+'train.csv')
dt_train

Unnamed: 0,id,gender,race,age,path
0,1,female,Asian,45,000001_female_Asian_45
1,2,female,Asian,52,000002_female_Asian_52
2,4,male,Asian,54,000004_male_Asian_54
3,5,female,Asian,58,000005_female_Asian_58
4,6,female,Asian,59,000006_female_Asian_59
...,...,...,...,...,...
2695,6954,male,Asian,19,006954_male_Asian_19
2696,6955,male,Asian,19,006955_male_Asian_19
2697,6956,male,Asian,19,006956_male_Asian_19
2698,6957,male,Asian,20,006957_male_Asian_20


In [6]:
def get_age_range(age):
    if age < 30:
        return 0
    elif 30 <= age < 60:
        return 1
    else:
        return 2

In [7]:
dt_train['age_range'] = dt_train['age'].apply(lambda x : get_age_range(x))

In [8]:
over_sixty = dt_train.loc[dt_train['age_range']==2,:]
path1 = []
path2 = []
for i in range(len(over_sixty)):
    path1.append(over_sixty.iloc[i]['path'] + '_1')
    path2.append(over_sixty.iloc[i]['path'] + '_2')
over_sixty_1 = pd.DataFrame({'id':over_sixty['id'], 'gender':over_sixty['gender'], 'race':over_sixty['race'], 'age':over_sixty['age'], 'path':path1, 'age_range':over_sixty['age_range']})
over_sixty_2 = pd.DataFrame({'id':over_sixty['id'], 'gender':over_sixty['gender'], 'race':over_sixty['race'], 'age':over_sixty['age'], 'path':path2, 'age_range':over_sixty['age_range']})
dt_train = dt_train.append(over_sixty_1, ignore_index=True)
dt_train = dt_train.append(over_sixty_2, ignore_index=True)
dt_train

Unnamed: 0,id,gender,race,age,path,age_range
0,1,female,Asian,45,000001_female_Asian_45,1
1,2,female,Asian,52,000002_female_Asian_52,1
2,4,male,Asian,54,000004_male_Asian_54,1
3,5,female,Asian,58,000005_female_Asian_58,1
4,6,female,Asian,59,000006_female_Asian_59,1
...,...,...,...,...,...,...
3113,5453,female,Asian,60,005453_female_Asian_60_2,2
3114,5459,male,Asian,60,005459_male_Asian_60_2,2
3115,5461,female,Asian,60,005461_female_Asian_60_2,2
3116,5504,female,Asian,60,005504_female_Asian_60_2,2


In [9]:
train_idx, valid_idx = train_test_split(np.arange(len(dt_train)),
                                       test_size=0.2,
                                       shuffle=True,
                                       stratify=dt_train['age_range'])

In [10]:
train_image = []
train_label = []

for idx in train_idx:
    path = dt_train.iloc[idx]['path']
    for file_name in [i for i in os.listdir(train_image_path+path) if i[0] != '.']:
        train_image.append(train_image_path+path+'/'+file_name)
        train_label.append((path.split('_')[1], path.split('_')[3], file_name.split('.')[0]))            

In [11]:
valid_image = []
valid_label = []

for idx in valid_idx:
    path = dt_train.iloc[idx]['path']
    for file_name in [i for i in os.listdir(train_image_path+path) if i[0] != '.']:
        valid_image.append(train_image_path+path+'/'+file_name)
        valid_label.append((path.split('_')[1], path.split('_')[3], file_name.split('.')[0]))                

In [12]:
def onehot_enc(x):
    def gender(i):
        if i == 'male':
            return 0
        elif i == 'female':
            return 3
    def age(j):
        j = int(j)
        if j < 30:
            return 0
        elif j >= 30 and j < 60:
            return 1
        elif j >= 60:
            return 2
    def mask(k):
        if k == 'normal':
            return 12
        elif 'incorrect' in k:
            return 6
        else:
            return 0
    return gender(x[0]) + age(x[1]) + mask(x[2])

In [13]:
train_data = pd.Series(train_image)
train_label = pd.Series(train_label)

valid_data = pd.Series(valid_image)
valid_label = pd.Series(valid_label)

In [14]:
class Dataset_Mask(Dataset):
    def __init__(self, data, label, encoding=True, transform=None):
        self.encoding = encoding
        self.data = data
        self.label = label
        self.transform = transform
        
        if encoding:
            self.label = self.label.apply(onehot_enc)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        X = Image.open(self.data[idx])
        y = self.label[idx]
        
        if self.transform:
            return self.transform(X), y
        return X, y

In [15]:
mask_train_set = Dataset_Mask(data=train_data, label=train_label, transform = transforms.Compose([
                                transforms.CenterCrop(350),
                                transforms.Resize(224, Image.BILINEAR),
                                transforms.RandomHorizontalFlip(p=0.5), 
                                transforms.RandomRotation(5),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
                            ]))
mask_val_set = Dataset_Mask(data=valid_data, label=valid_label, transform = transforms.Compose([
                                transforms.CenterCrop(350),
                                transforms.Resize(224, Image.BILINEAR),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
                            ]))



In [16]:
def make_class_weights(labels, class_num):
    print("Labels shape:\n", labels.shape)
    print("Given labels:\n", labels)

    labels = labels.apply(lambda x : x % class_num)

    labels = np.array(labels)
    class_weights = np.zeros_like(labels) 
    
    _, counts = np.unique(labels, return_counts=True)
    # 각 class가 몇 번 등장하는지 count

    print("Labels:\n", labels)
    print("Label count:\n", counts)

    for cls in range(class_num):
        class_weights = np.where(labels == cls, 1/counts[cls], class_weights)
        # label이 class에 해당하면 count의 역수 적용
    return class_weights

class_num = 18
age_class_num = 3

In [17]:
# class_weights = make_class_weights(mask_train_set.label, class_num)
class_weights = make_class_weights(mask_train_set.label, age_class_num)

print("Class weights: ", class_weights)
print("Length: ", len(class_weights))

Labels shape:
 (17458,)
Given labels:
 0         5
1         5
2         5
3        11
4         5
         ..
17453     3
17454     9
17455     3
17456     3
17457    15
Length: 17458, dtype: int64
Labels:
 [2 2 2 ... 0 0 0]
Label count:
 [7175 6769 3514]
Class weights:  [0.00028458 0.00028458 0.00028458 ... 0.00013937 0.00013937 0.00013937]
Length:  17458


In [18]:
sampler = sampler.WeightedRandomSampler(weights=class_weights, num_samples=len(class_weights))

In [19]:
print(f'training data size : {len(mask_train_set)}')
print(f'validation data size : {len(mask_val_set)}')

training data size : 17458
validation data size : 4368


In [20]:
batch_size = 64

train_dataloader_mask = DataLoader(dataset = mask_train_set, batch_size=batch_size, sampler=sampler, num_workers=2)
val_dataloader_mask = DataLoader(dataset = mask_val_set, batch_size=batch_size, num_workers=2)

In [21]:
model = torchvision.models.resnet50(pretrained=True)
print('필요 입력 채널 개수', model.conv1.weight.shape[1])
print('네트워크 출력 채널 개수', model.fc.weight.shape[0])
print(model)



필요 입력 채널 개수 3
네트워크 출력 채널 개수 1000
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256

In [22]:
import math
class_num = 18
model.fc = nn.Sequential(
                nn.Linear(in_features=2048, out_features=1000, bias=True),
                nn.Linear(in_features=1000, out_features=class_num, bias=True)
)
nn.init.xavier_uniform_(model.fc[0].weight)
stdv = 1. / math.sqrt(model.fc[0].weight.size(1))
model.fc[0].bias.data.uniform_(-stdv, stdv)
nn.init.xavier_uniform_(model.fc[1].weight)
stdv = 1. / math.sqrt(model.fc[1].weight.size(1))
model.fc[1].bias.data.uniform_(-stdv, stdv)

print('필요 입력 채널 개수', model.conv1.weight.shape[1])
print('네트워크 출력 채널 개수', model.fc[1].weight.shape[0])

필요 입력 채널 개수 3
네트워크 출력 채널 개수 18


In [24]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"using {device}")

model.to(device)

LEARNING_RATE = 0.0001
NUM_EPOCH = 100

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

using cuda:0


In [26]:
checkpoint = torch.load('/opt/ml/checkpoint/res50_aug_2/best_checkpoint.tar')
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [27]:
best_val_acc = 0.94048
best_val_loss = 0.24948
patience = 10
cur_count = 0
accumulation_steps = 4

f1 = F1Score(num_classes=class_num, average='macro').to(device)
best_f1_score = 0

for epoch in range(NUM_EPOCH):
    model.train()
    loss_value = 0
    matches = 0
    for idx, train_batch in enumerate(tqdm(train_dataloader_mask)):
        inputs, labels = train_batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        outs = model(inputs)
        preds = torch.argmax(outs, dim=-1)
        loss = criterion(outs, labels)
        
        loss.backward()
        
        if (idx+1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            
        if epoch % 10 == 2:
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch' : epoch
                }, f'/opt/ml/checkpoint/res50_aug_2/checkpoint_ep_{epoch+8}.tar')
        
        loss_value += loss.item()
        matches += (preds == labels).sum().item()
        
    train_loss = loss_value / len(train_dataloader_mask)
    train_acc = matches / len(mask_train_set)
        
    print(f"epoch[{epoch+8}/{NUM_EPOCH}] training loss {train_loss:.5f}, training accuracy {train_acc:.5f}")
        
    with torch.no_grad():
        model.eval()
        val_loss_items = []
        val_acc_items = []
        f1_score = 0
        for val_batch in tqdm(val_dataloader_mask):
            inputs, labels = val_batch
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outs = model(inputs)
            preds = torch.argmax(outs, dim=-1)
            
            loss_item = criterion(outs, labels).item()
            acc_item = (labels==preds).sum().item()
            val_loss_items.append(loss_item)
            val_acc_items.append(acc_item)
            f1_score += f1(outs, labels)
            
        val_loss = np.sum(val_loss_items) / len(val_dataloader_mask)
        val_acc = np.sum(val_acc_items) / len(mask_val_set)

        f1_score /= len(val_dataloader_mask)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch' : epoch
                }, f'/opt/ml/checkpoint/res50_aug_2/best_checkpoint.tar')
            print('checkpoint saved!')
            cur_count = 0
        else:
            cur_count += 1
            if cur_count >= patience:
                print("Early Stopping!")
                break
            
        if f1_score > best_f1_score:
            best_f1_score = f1_score
            
            
        print(f"[val] acc : {val_acc:.5f}, loss : {val_loss:.5f}, f1 score: {f1_score:.5f}")
        print(f"best acc : {best_val_acc:.5f}, best loss : {best_val_loss:.5f}, best f1 : {best_f1_score:.5f}")

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[8/100] training loss 0.01877, training accuracy 0.99347


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.93567, loss : 0.31947, f1 score: 0.89328
best acc : 0.94048, best loss : 0.24948, best f1 : 0.89328


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[9/100] training loss 0.01789, training accuracy 0.99341


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.92193, loss : 0.41351, f1 score: 0.85596
best acc : 0.94048, best loss : 0.24948, best f1 : 0.89328


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[10/100] training loss 0.02329, training accuracy 0.99301


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.91598, loss : 0.39205, f1 score: 0.86421
best acc : 0.94048, best loss : 0.24948, best f1 : 0.89328


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[11/100] training loss 0.01394, training accuracy 0.99547


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.92766, loss : 0.36562, f1 score: 0.86638
best acc : 0.94048, best loss : 0.24948, best f1 : 0.89328


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[12/100] training loss 0.01090, training accuracy 0.99622


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


checkpoint saved!
[val] acc : 0.94254, loss : 0.30098, f1 score: 0.89594
best acc : 0.94254, best loss : 0.24948, best f1 : 0.89594


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[13/100] training loss 0.00911, training accuracy 0.99719


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


checkpoint saved!
[val] acc : 0.94322, loss : 0.30397, f1 score: 0.89807
best acc : 0.94322, best loss : 0.24948, best f1 : 0.89807


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[14/100] training loss 0.00738, training accuracy 0.99765


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.94025, loss : 0.32028, f1 score: 0.89492
best acc : 0.94322, best loss : 0.24948, best f1 : 0.89807


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[15/100] training loss 0.00944, training accuracy 0.99702


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.93384, loss : 0.36356, f1 score: 0.88607
best acc : 0.94322, best loss : 0.24948, best f1 : 0.89807


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[16/100] training loss 0.01429, training accuracy 0.99570


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.93132, loss : 0.38816, f1 score: 0.88150
best acc : 0.94322, best loss : 0.24948, best f1 : 0.89807


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[17/100] training loss 0.01766, training accuracy 0.99393


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.92537, loss : 0.44915, f1 score: 0.87258
best acc : 0.94322, best loss : 0.24948, best f1 : 0.89807


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[18/100] training loss 0.01586, training accuracy 0.99456


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.93544, loss : 0.32051, f1 score: 0.88304
best acc : 0.94322, best loss : 0.24948, best f1 : 0.89807


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[19/100] training loss 0.00879, training accuracy 0.99719


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


checkpoint saved!
[val] acc : 0.95238, loss : 0.25093, f1 score: 0.91433
best acc : 0.95238, best loss : 0.24948, best f1 : 0.91433


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[20/100] training loss 0.01079, training accuracy 0.99696


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.93567, loss : 0.33890, f1 score: 0.89299
best acc : 0.95238, best loss : 0.24948, best f1 : 0.91433


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[21/100] training loss 0.00985, training accuracy 0.99696


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.93704, loss : 0.35043, f1 score: 0.89869
best acc : 0.95238, best loss : 0.24948, best f1 : 0.91433


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[22/100] training loss 0.01071, training accuracy 0.99754


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.93956, loss : 0.37026, f1 score: 0.89957
best acc : 0.95238, best loss : 0.24948, best f1 : 0.91433


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[23/100] training loss 0.01023, training accuracy 0.99702


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.93636, loss : 0.33864, f1 score: 0.89121
best acc : 0.95238, best loss : 0.24948, best f1 : 0.91433


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[24/100] training loss 0.01043, training accuracy 0.99645


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.94048, loss : 0.37360, f1 score: 0.89606
best acc : 0.95238, best loss : 0.24948, best f1 : 0.91433


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[25/100] training loss 0.00356, training accuracy 0.99908


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.94368, loss : 0.35562, f1 score: 0.90490
best acc : 0.95238, best loss : 0.24948, best f1 : 0.91433


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[26/100] training loss 0.00574, training accuracy 0.99800


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.93796, loss : 0.34876, f1 score: 0.89417
best acc : 0.95238, best loss : 0.24948, best f1 : 0.91433


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[27/100] training loss 0.00785, training accuracy 0.99805


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.94093, loss : 0.37973, f1 score: 0.90148
best acc : 0.95238, best loss : 0.24948, best f1 : 0.91433


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[28/100] training loss 0.00914, training accuracy 0.99719


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


[val] acc : 0.94414, loss : 0.35830, f1 score: 0.91286
best acc : 0.95238, best loss : 0.24948, best f1 : 0.91433


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=273.0), HTML(value='')))


epoch[29/100] training loss 0.01351, training accuracy 0.99519


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=69.0), HTML(value='')))


Early Stopping!


In [28]:
# meta 데이터와 이미지 경로를 불러옵니다.
test_dir_path = '/opt/ml/input/data/eval/'
test_image_path = '/opt/ml/input/data/eval/images/'

checkpoint = torch.load('/opt/ml/checkpoint/res50_aug_2/best_checkpoint.tar')
model.load_state_dict(checkpoint['model_state_dict'])

submission = pd.read_csv(test_dir_path+'info.csv')
submission.head()

Unnamed: 0,ImageID,ans
0,cbc5c6e168e63498590db46022617123f1fe1268.jpg,0
1,0e72482bf56b3581c081f7da2a6180b8792c7089.jpg,0
2,b549040c49190cedc41327748aeb197c1670f14d.jpg,0
3,4f9cb2a045c6d5b9e50ad3459ea7b791eb6e18bc.jpg,0
4,248428d9a4a5b6229a7081c32851b90cb8d38d0c.jpg,0


In [29]:
image_paths = [os.path.join(test_image_path, img_id) for img_id in submission.ImageID]
test_image = pd.Series(image_paths)

In [30]:
class Test_Dataset(Dataset):
    def __init__(self, transform=None):
        self.data = test_image
        self.transform = transform
        
    def __len__(self):
        return len(test_image)
    
    def __getitem__(self, idx):
        img = Image.open(self.data[idx])
            
        if self.transform:
            img = self.transform(img)
            
        return img

In [31]:
dataset = Test_Dataset(transform = transforms.Compose([
                            transforms.CenterCrop(350),
                            transforms.Resize(224, Image.BILINEAR),
                            transforms.ToTensor(),
                            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
                        ]))

loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)

# 모델을 정의합니다. (학습한 모델이 있다면 torch.load로 모델을 불러주세요!)
device = torch.device('cuda')
model = model.to(device)
model.eval()

# 모델이 테스트 데이터셋을 예측하고 결과를 저장합니다.
all_predictions = []
for images in loader:
    with torch.no_grad():
        images = images.to(device)
        pred = model(images)
        pred = pred.argmax(dim=-1)
        all_predictions.extend(pred.cpu().numpy())
submission['ans'] = all_predictions

# 제출할 파일을 저장합니다.
submission.to_csv(os.path.join(test_dir_path, 'submission_res50_aug_2.csv'), index=False)
print('test inference is done!')



test inference is done!
