0826_(1) 에서 valid patience를 추가하고 train:val 을 85:15 로 함

## 0. Libarary 불러오기 및 경로설정

In [1]:
import os
import pandas as pd
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from torchvision import models
from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize, CenterCrop

In [2]:
train_dir = '/opt/ml/input/data/train'

## 2. Train Dataset 정의

In [3]:
class TrainDataset(Dataset):
    def __init__(self,img_paths,labels,trans_dict,val=False):
        self.img_paths = img_paths
        self.labels = labels
        self.trans_dict = trans_dict
        self.val = val
        
    def __getitem__(self,index):
        image = Image.open(self.img_paths[index])
        label = self.labels[index]
        if self.trans_dict:
            if self.val:
                image = self.trans_dict['val'](image)
            elif label in [0,1,3,4]:
                image = self.trans_dict['train1'](image)
            else:
                image = self.trans_dict['train2'](image)
        
        return image,label

    def __len__(self):
        return len(self.img_paths)

sample 이미지 살펴보고 transform 결정

In [12]:
#hyperparam 정의
num_epoch = 20
num_classes = 18
batch_size = 64
val_split = 0.2
random_seed= 48
shuffle_dataset = True

In [7]:
# meta 데이터와 이미지 경로를 불러옵니다.
train_info = pd.read_csv(os.path.join(train_dir, 'train_15.csv'))
train_imgpaths = train_info['path']
train_labels = train_info['category']

val_info = pd.read_csv(os.path.join(train_dir, 'valid_15.csv'))
val_imgpaths = val_info['path']
val_labels = val_info['category']

#transform 정의
data_transform = {
    'train1': transforms.Compose([
                CenterCrop(300),  
                Resize((224, 224)),
                transforms.RandomHorizontalFlip(p=0.5),
                ToTensor(),
                Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
            ]),
    'train2': transforms.Compose([
                    CenterCrop(300),  
                    Resize((224, 224)),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomApply(transforms=[transforms.ColorJitter(brightness=0.5)],p=0.5),
                    #transforms.RandomApply(transforms=[transforms.Pad(padding=5, fill=0, padding_mode='constant')],p=0.5),
                    transforms.RandomApply(transforms=[transforms.RandomPerspective(distortion_scale=0.2, p=0.5)],p=0.5),
                    ToTensor(),
                    Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2))
                ]),
    'val': transforms.Compose([
                Resize((224, 224), Image.BILINEAR),
                ToTensor(),
                Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
            ])
}

#train val split
'''
dataset_size = len(labels)
indices = list(range(dataset_size))
split = int(np.floor(val_split * dataset_size))
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]
'''
train_dataset = TrainDataset(train_imgpaths,train_labels,data_transform)
val_dataset = TrainDataset(val_imgpaths,val_labels,data_transform,val=True)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    drop_last=True
)
print(len(train_dataset),len(val_dataset))

16065 2835


## 4. Model 정의

In [8]:
from custom_models import ResnetModel as MyModel

## 5. Train

In [17]:
from sklearn.metrics import f1_score
import numpy as np
from loss import f1_loss

# 모델을 정의
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyModel(num_classes=num_classes).to(device)

#loss function, optimizer 정의
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

#모델 학습
from tqdm import tqdm
right_count = [0 for _ in range(num_classes)]
wrong_count = [0 for _ in range(num_classes)]

temp_valid_f1=0.0; patience=0

for epoch in tqdm(range(num_epoch)):
    #train
    model.train()
    cur_loss,cur_acc,cur_f1 = 0.0,0.0,0.0
    for i,data in enumerate(train_dataloader):
        images,labels = data
        images,labels = images.to(device),labels.to(device)
        
        optimizer.zero_grad()
        model_output = model(images)
        loss = criterion(model_output,labels)
        loss.backward()
        optimizer.step()
        
        cur_loss += loss.item()
        
        predict = model_output.argmax(dim=-1)
        cur_acc += torch.sum(labels==predict)
        
        f1 = f1_loss(labels,model_output)
        cur_f1 += f1
                
        #for k in range(batch_size):
        #    if predict[k]==labels[k]: right_count[labels[k]]+=1
        #    else: wrong_count[labels[k]]+=1
        
        if i%100==99:
            print('[%d,%5d] train_loss: %.3f, train_acc:%.3f, train_f1:%.3f'% (epoch+1,i+1,cur_loss/100,cur_acc/(batch_size*100),cur_f1/100))
            cur_loss,cur_acc,cur_f1 = 0.0,0.0,0.0
        
    #eval
    model.eval()
    valid_loss,valid_acc,valid_f1=0.0,0.0,0.0
    for i,data in enumerate(val_dataloader):
        images,labels = data
        images,labels = images.to(device),labels.to(device)
        
        model_output = model(images)
        #print(model_output,labels)
        loss = criterion(model_output,labels)
        valid_loss += loss.item()
        
        predict = model_output.argmax(dim=-1)
        valid_acc += torch.sum(labels==predict)
        
        f1 = f1_loss(labels,model_output)
        valid_f1 += f1
    print('epoch [%d] valid_loss: %.3f, valid_acc:%.3f, valid_f1:%.3f'% (epoch+1,valid_loss/(i+1),valid_acc/(16*(i+1)),valid_f1/(i+1)))
    patience = patience+1 if valid_f1 < temp_valid_f1 else 0
    if patience>=3: 
        torch.save(model,'..models/0826_4_model.pth')
        break
    temp_valid_f1 = valid_f1
    
print('training finished')

  0%|          | 0/20 [00:00<?, ?it/s]

[1,  100] train_loss: 2.057, train_acc:0.353, train_f1:5.807
[1,  200] train_loss: 1.496, train_acc:0.529, train_f1:8.129


  5%|▌         | 1/20 [02:12<41:50, 132.11s/it]

epoch [1] valid_loss: 1.350, valid_acc:0.147, valid_f1:7.132
[2,  100] train_loss: 1.213, train_acc:0.627, train_f1:8.678
[2,  200] train_loss: 1.132, train_acc:0.647, train_f1:8.774


 10%|█         | 2/20 [04:27<39:55, 133.06s/it]

epoch [2] valid_loss: 1.147, valid_acc:0.160, valid_f1:7.367
[3,  100] train_loss: 1.055, train_acc:0.665, train_f1:8.823
[3,  200] train_loss: 1.012, train_acc:0.677, train_f1:8.900


 15%|█▌        | 3/20 [06:40<37:44, 133.20s/it]

epoch [3] valid_loss: 1.051, valid_acc:0.166, valid_f1:7.443
[4,  100] train_loss: 0.948, train_acc:0.700, train_f1:8.896
[4,  200] train_loss: 0.940, train_acc:0.702, train_f1:8.946


 20%|██        | 4/20 [08:50<35:15, 132.25s/it]

epoch [4] valid_loss: 0.999, valid_acc:0.171, valid_f1:7.498
[5,  100] train_loss: 0.892, train_acc:0.711, train_f1:8.925
[5,  200] train_loss: 0.882, train_acc:0.713, train_f1:8.967


 25%|██▌       | 5/20 [11:01<32:56, 131.76s/it]

epoch [5] valid_loss: 0.965, valid_acc:0.172, valid_f1:7.515
[6,  100] train_loss: 0.861, train_acc:0.720, train_f1:8.950
[6,  200] train_loss: 0.865, train_acc:0.719, train_f1:8.989


 30%|███       | 6/20 [12:59<29:46, 127.63s/it]

epoch [6] valid_loss: 0.943, valid_acc:0.173, valid_f1:7.509
[7,  100] train_loss: 0.831, train_acc:0.726, train_f1:8.958
[7,  200] train_loss: 0.834, train_acc:0.729, train_f1:8.998


 35%|███▌      | 7/20 [14:53<26:44, 123.46s/it]

epoch [7] valid_loss: 0.929, valid_acc:0.173, valid_f1:7.566
[8,  100] train_loss: 0.807, train_acc:0.729, train_f1:8.979
[8,  200] train_loss: 0.808, train_acc:0.742, train_f1:9.030


 40%|████      | 8/20 [17:02<25:03, 125.27s/it]

epoch [8] valid_loss: 0.910, valid_acc:0.174, valid_f1:7.559
[9,  100] train_loss: 0.789, train_acc:0.735, train_f1:8.999
[9,  200] train_loss: 0.798, train_acc:0.741, train_f1:9.029


 45%|████▌     | 9/20 [19:07<22:55, 125.07s/it]

epoch [9] valid_loss: 0.906, valid_acc:0.174, valid_f1:7.561
[10,  100] train_loss: 0.779, train_acc:0.730, train_f1:8.993
[10,  200] train_loss: 0.776, train_acc:0.745, train_f1:9.051


 50%|█████     | 10/20 [21:20<21:15, 127.51s/it]

epoch [10] valid_loss: 0.895, valid_acc:0.174, valid_f1:7.560
[11,  100] train_loss: 0.765, train_acc:0.745, train_f1:9.018
[11,  200] train_loss: 0.765, train_acc:0.744, train_f1:9.041


 55%|█████▌    | 11/20 [23:33<19:22, 129.13s/it]

epoch [11] valid_loss: 0.885, valid_acc:0.175, valid_f1:7.576
[12,  100] train_loss: 0.749, train_acc:0.745, train_f1:9.013
[12,  200] train_loss: 0.762, train_acc:0.749, train_f1:9.047


 60%|██████    | 12/20 [25:47<17:24, 130.57s/it]

epoch [12] valid_loss: 0.886, valid_acc:0.176, valid_f1:7.602
[13,  100] train_loss: 0.747, train_acc:0.743, train_f1:9.019
[13,  200] train_loss: 0.754, train_acc:0.746, train_f1:9.034


 65%|██████▌   | 13/20 [27:52<15:01, 128.86s/it]

epoch [13] valid_loss: 0.879, valid_acc:0.175, valid_f1:7.578
[14,  100] train_loss: 0.739, train_acc:0.749, train_f1:9.025
[14,  200] train_loss: 0.742, train_acc:0.751, train_f1:9.054


 70%|███████   | 14/20 [29:56<12:45, 127.51s/it]

epoch [14] valid_loss: 0.871, valid_acc:0.176, valid_f1:7.621
[15,  100] train_loss: 0.722, train_acc:0.752, train_f1:9.032
[15,  200] train_loss: 0.736, train_acc:0.751, train_f1:9.043


 75%|███████▌  | 15/20 [32:05<10:40, 128.03s/it]

epoch [15] valid_loss: 0.864, valid_acc:0.176, valid_f1:7.586
[16,  100] train_loss: 0.725, train_acc:0.750, train_f1:9.047
[16,  200] train_loss: 0.730, train_acc:0.758, train_f1:9.070


 80%|████████  | 16/20 [34:12<08:30, 127.70s/it]

epoch [16] valid_loss: 0.864, valid_acc:0.176, valid_f1:7.622
[17,  100] train_loss: 0.719, train_acc:0.755, train_f1:9.044
[17,  200] train_loss: 0.732, train_acc:0.750, train_f1:9.050


 85%|████████▌ | 17/20 [36:25<06:27, 129.13s/it]

epoch [17] valid_loss: 0.858, valid_acc:0.177, valid_f1:7.605
[18,  100] train_loss: 0.708, train_acc:0.755, train_f1:9.044
[18,  200] train_loss: 0.726, train_acc:0.755, train_f1:9.081


 90%|█████████ | 18/20 [38:38<04:20, 130.48s/it]

epoch [18] valid_loss: 0.854, valid_acc:0.178, valid_f1:7.633
[19,  100] train_loss: 0.706, train_acc:0.751, train_f1:9.053
[19,  200] train_loss: 0.732, train_acc:0.749, train_f1:9.051


 95%|█████████▌| 19/20 [40:55<02:12, 132.27s/it]

epoch [19] valid_loss: 0.855, valid_acc:0.176, valid_f1:7.626
[20,  100] train_loss: 0.701, train_acc:0.757, train_f1:9.032
[20,  200] train_loss: 0.727, train_acc:0.753, train_f1:9.080


100%|██████████| 20/20 [42:56<00:00, 128.83s/it]

epoch [20] valid_loss: 0.854, valid_acc:0.177, valid_f1:7.614
training finished





In [20]:
torch.save(model,'../models/0826_4_model.pth')

In [18]:
print(right_count)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [19]:
print(wrong_count)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
