## 0. Libarary 불러오기 및 경로설정

In [6]:
import os
import pandas as pd
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from torchvision import models
from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize, CenterCrop

In [7]:
train_dir = '/opt/ml/input/data/train'

## 2. Train Dataset 정의

In [8]:
class TrainDataset(Dataset):
    def __init__(self,img_paths,labels,trans_dict,val=False):
        self.img_paths = img_paths
        self.labels = labels
        self.trans_dict = trans_dict
        self.val = val
        
    def __getitem__(self,index):
        image = Image.open(self.img_paths[index])
        label = self.labels[index]
        if self.trans_dict:
            if self.val:
                image = self.trans_dict['val'](image)
            elif label in [0,1,3,4]:
                image = self.trans_dict['train1'](image)
            else:
                image = self.trans_dict['train2'](image)
        
        return image,label

    def __len__(self):
        return len(self.img_paths)

sample 이미지 살펴보고 transform 결정

In [9]:
#hyperparam 정의
num_epoch = 15
num_classes = 18
batch_size = 64
val_split = 0.2
random_seed= 48
shuffle_dataset = True

In [10]:
#hyperparam 정의
num_epoch = 15
num_classes = 18
batch_size = 16
val_split = 0.2
random_seed= 48
shuffle_dataset = True

In [11]:
# meta 데이터와 이미지 경로를 불러옵니다.
train_info = pd.read_csv(os.path.join(train_dir, 'train3.csv'))
image_paths = train_info['path']
labels = train_info['category']

#transform 정의
data_transform = {
    'train1': transforms.Compose([
                CenterCrop(300),  
                Resize((224, 224)),
                transforms.RandomHorizontalFlip(p=0.5),
                ToTensor(),
                Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
            ]),
    'train2': transforms.Compose([
                    CenterCrop(300),  
                    Resize((224, 224)),
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomApply(transforms=[transforms.ColorJitter(brightness=0.5)],p=0.5),
                    #transforms.RandomApply(transforms=[transforms.Pad(padding=5, fill=0, padding_mode='constant')],p=0.5),
                    transforms.RandomApply(transforms=[transforms.RandomPerspective(distortion_scale=0.2, p=0.5)],p=0.5),
                    ToTensor(),
                    Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2))
                ]),
    'val': transforms.Compose([
                Resize((224, 224), Image.BILINEAR),
                ToTensor(),
                Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
            ])
}

#train val split
dataset_size = len(labels)
indices = list(range(dataset_size))
split = int(np.floor(val_split * dataset_size))
if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

train_dataset = TrainDataset(list(image_paths[train_indices]),list(labels[train_indices]),data_transform)
val_dataset = TrainDataset(list(image_paths[val_indices]),list(labels[val_indices]),data_transform,val=True)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    drop_last=True
)
print(len(train_dataset),len(val_dataset),dataset_size)

15120 3780 18900


## 4. Model 정의

In [12]:
from custom_models import ResnetModel2 as MyModel

## 5. Train

In [14]:
from sklearn.metrics import f1_score
import numpy as np
from loss import f1_loss

# 모델을 정의
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyModel(num_classes=num_classes).to(device)


#loss function, optimizer 정의
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)

#모델 학습
from tqdm import tqdm
right_count = [0 for _ in range(num_classes)]
wrong_count = [0 for _ in range(num_classes)]

for epoch in tqdm(range(num_epoch)):
    #train
    model.train()
    cur_loss,cur_acc,cur_f1 = 0.0,0.0,0.0
    for i,data in enumerate(train_dataloader):
        images,labels = data
        images,labels = images.to(device),labels.to(device)
        
        optimizer.zero_grad()
        model_output = model(images)
        loss = criterion(model_output,labels)
        loss.backward()
        optimizer.step()
        
        cur_loss += loss.item()
        
        predict = model_output.argmax(dim=-1)
        cur_acc += torch.sum(labels==predict)
        
        f1 = f1_loss(labels,model_output)
        cur_f1 += f1
                
        for k in range(batch_size):
            if predict[k]==labels[k]: right_count[labels[k]]+=1
            else: wrong_count[labels[k]]+=1
        
        if i%100==99:
            print('[%d,%5d] train_loss: %.3f, train_acc:%.3f, train_f1:%.3f'% (epoch+1,i+1,cur_loss/100,cur_acc/1600,cur_f1/100))
            cur_loss,cur_acc,cur_f1 = 0.0,0.0,0.0
        
    #eval
    model.eval()
    valid_loss,valid_acc,valid_f1=0.0,0.0,0.0
    for i,data in enumerate(val_dataloader):
        images,labels = data
        images,labels = images.to(device),labels.to(device)
        
        model_output = model(images)
        #print(model_output,labels)
        loss = criterion(model_output,labels)
        valid_loss += loss.item()
        
        predict = model_output.argmax(dim=-1)
        valid_acc += torch.sum(labels==predict)
        
        f1 = f1_loss(labels,model_output)
        valid_f1 += f1
        
    print('epoch [%d] valid_loss: %.3f, valid_acc:%.3f, valid_f1:%.3f'% (epoch+1,valid_loss/(i+1),valid_acc/(4*(i+1)),valid_f1/(i+1)))

print('training finished')

  0%|          | 0/15 [00:00<?, ?it/s]

[1,  100] train_loss: 2.284, train_acc:0.296, train_f1:4.213
[1,  200] train_loss: 1.763, train_acc:0.467, train_f1:6.416
[1,  300] train_loss: 1.509, train_acc:0.548, train_f1:7.650
[1,  400] train_loss: 1.378, train_acc:0.585, train_f1:8.177
[1,  500] train_loss: 1.250, train_acc:0.614, train_f1:8.318
[1,  600] train_loss: 1.144, train_acc:0.649, train_f1:8.548
[1,  700] train_loss: 1.144, train_acc:0.637, train_f1:8.568
[1,  800] train_loss: 1.120, train_acc:0.646, train_f1:8.197
[1,  900] train_loss: 1.015, train_acc:0.676, train_f1:8.326


  7%|▋         | 1/15 [05:46<1:20:44, 346.02s/it]

epoch [1] valid_loss: 0.993, valid_acc:0.683, valid_f1:7.321
[2,  100] train_loss: 1.017, train_acc:0.667, train_f1:8.514
[2,  200] train_loss: 0.923, train_acc:0.693, train_f1:8.548
[2,  300] train_loss: 0.916, train_acc:0.699, train_f1:8.504
[2,  400] train_loss: 0.946, train_acc:0.692, train_f1:8.465
[2,  500] train_loss: 0.909, train_acc:0.712, train_f1:8.697
[2,  600] train_loss: 0.895, train_acc:0.711, train_f1:8.745
[2,  700] train_loss: 0.937, train_acc:0.685, train_f1:8.662
[2,  800] train_loss: 0.878, train_acc:0.713, train_f1:8.394
[2,  900] train_loss: 0.859, train_acc:0.716, train_f1:8.490


 13%|█▎        | 2/15 [10:53<1:12:27, 334.41s/it]

epoch [2] valid_loss: 0.842, valid_acc:0.716, valid_f1:7.470
[3,  100] train_loss: 0.905, train_acc:0.712, train_f1:8.675
[3,  200] train_loss: 0.824, train_acc:0.725, train_f1:8.574
[3,  300] train_loss: 0.789, train_acc:0.741, train_f1:8.541
[3,  400] train_loss: 0.884, train_acc:0.704, train_f1:8.526
[3,  500] train_loss: 0.827, train_acc:0.726, train_f1:8.747
[3,  600] train_loss: 0.827, train_acc:0.722, train_f1:8.811
[3,  700] train_loss: 0.850, train_acc:0.703, train_f1:8.703
[3,  800] train_loss: 0.867, train_acc:0.704, train_f1:8.356
[3,  900] train_loss: 0.872, train_acc:0.705, train_f1:8.513


 20%|██        | 3/15 [16:08<1:05:42, 328.56s/it]

epoch [3] valid_loss: 0.790, valid_acc:0.731, valid_f1:7.512
[4,  100] train_loss: 0.851, train_acc:0.716, train_f1:8.694
[4,  200] train_loss: 0.755, train_acc:0.753, train_f1:8.691
[4,  300] train_loss: 0.778, train_acc:0.734, train_f1:8.595
[4,  400] train_loss: 0.820, train_acc:0.719, train_f1:8.536
[4,  500] train_loss: 0.811, train_acc:0.724, train_f1:8.663
[4,  600] train_loss: 0.786, train_acc:0.731, train_f1:8.849
[4,  700] train_loss: 0.853, train_acc:0.709, train_f1:8.718
[4,  800] train_loss: 0.839, train_acc:0.704, train_f1:8.349
[4,  900] train_loss: 0.839, train_acc:0.721, train_f1:8.554


 27%|██▋       | 4/15 [21:40<1:00:26, 329.67s/it]

epoch [4] valid_loss: 0.789, valid_acc:0.730, valid_f1:7.420
[5,  100] train_loss: 0.809, train_acc:0.726, train_f1:8.724
[5,  200] train_loss: 0.760, train_acc:0.754, train_f1:8.650
[5,  300] train_loss: 0.755, train_acc:0.744, train_f1:8.616
[5,  400] train_loss: 0.794, train_acc:0.723, train_f1:8.605
[5,  500] train_loss: 0.791, train_acc:0.748, train_f1:8.781
[5,  600] train_loss: 0.790, train_acc:0.741, train_f1:8.803
[5,  700] train_loss: 0.814, train_acc:0.717, train_f1:8.809
[5,  800] train_loss: 0.817, train_acc:0.725, train_f1:8.355
[5,  900] train_loss: 0.784, train_acc:0.734, train_f1:8.560


 33%|███▎      | 5/15 [26:43<53:35, 321.56s/it]  

epoch [5] valid_loss: 0.766, valid_acc:0.730, valid_f1:7.375
[6,  100] train_loss: 0.776, train_acc:0.734, train_f1:8.773
[6,  200] train_loss: 0.711, train_acc:0.756, train_f1:8.673
[6,  300] train_loss: 0.745, train_acc:0.748, train_f1:8.658
[6,  400] train_loss: 0.761, train_acc:0.739, train_f1:8.624
[6,  500] train_loss: 0.732, train_acc:0.749, train_f1:8.743
[6,  600] train_loss: 0.717, train_acc:0.759, train_f1:8.822
[6,  700] train_loss: 0.793, train_acc:0.726, train_f1:8.784
[6,  800] train_loss: 0.795, train_acc:0.724, train_f1:8.398
[6,  900] train_loss: 0.740, train_acc:0.756, train_f1:8.649


 40%|████      | 6/15 [32:12<48:35, 323.96s/it]

epoch [6] valid_loss: 0.729, valid_acc:0.746, valid_f1:7.496
[7,  100] train_loss: 0.807, train_acc:0.723, train_f1:8.733
[7,  200] train_loss: 0.678, train_acc:0.767, train_f1:8.663
[7,  300] train_loss: 0.763, train_acc:0.741, train_f1:8.573
[7,  400] train_loss: 0.796, train_acc:0.730, train_f1:8.572
[7,  500] train_loss: 0.753, train_acc:0.743, train_f1:8.776
[7,  600] train_loss: 0.762, train_acc:0.736, train_f1:8.801
[7,  700] train_loss: 0.772, train_acc:0.741, train_f1:8.774
[7,  800] train_loss: 0.774, train_acc:0.736, train_f1:8.377
[7,  900] train_loss: 0.742, train_acc:0.746, train_f1:8.582


 47%|████▋     | 7/15 [37:46<43:35, 326.94s/it]

epoch [7] valid_loss: 0.740, valid_acc:0.746, valid_f1:7.495
[8,  100] train_loss: 0.780, train_acc:0.735, train_f1:8.719
[8,  200] train_loss: 0.719, train_acc:0.749, train_f1:8.647
[8,  300] train_loss: 0.739, train_acc:0.741, train_f1:8.585
[8,  400] train_loss: 0.774, train_acc:0.732, train_f1:8.602
[8,  500] train_loss: 0.738, train_acc:0.748, train_f1:8.762
[8,  600] train_loss: 0.756, train_acc:0.752, train_f1:8.812
[8,  700] train_loss: 0.751, train_acc:0.734, train_f1:8.769
[8,  800] train_loss: 0.779, train_acc:0.724, train_f1:8.405
[8,  900] train_loss: 0.721, train_acc:0.757, train_f1:8.615


 53%|█████▎    | 8/15 [40:32<32:29, 278.55s/it]

epoch [8] valid_loss: 0.733, valid_acc:0.749, valid_f1:7.519
[9,  100] train_loss: 0.788, train_acc:0.731, train_f1:8.733
[9,  200] train_loss: 0.680, train_acc:0.769, train_f1:8.707
[9,  300] train_loss: 0.725, train_acc:0.743, train_f1:8.577
[9,  400] train_loss: 0.748, train_acc:0.749, train_f1:8.626
[9,  500] train_loss: 0.724, train_acc:0.752, train_f1:8.791
[9,  600] train_loss: 0.737, train_acc:0.756, train_f1:8.900
[9,  700] train_loss: 0.772, train_acc:0.746, train_f1:8.830
[9,  800] train_loss: 0.774, train_acc:0.728, train_f1:8.371
[9,  900] train_loss: 0.732, train_acc:0.746, train_f1:8.582


 60%|██████    | 9/15 [42:47<23:33, 235.64s/it]

epoch [9] valid_loss: 0.729, valid_acc:0.751, valid_f1:7.573
[10,  100] train_loss: 0.753, train_acc:0.739, train_f1:8.776
[10,  200] train_loss: 0.679, train_acc:0.778, train_f1:8.706
[10,  300] train_loss: 0.700, train_acc:0.750, train_f1:8.641
[10,  400] train_loss: 0.761, train_acc:0.744, train_f1:8.677
[10,  500] train_loss: 0.729, train_acc:0.762, train_f1:8.813
[10,  600] train_loss: 0.727, train_acc:0.760, train_f1:8.796
[10,  700] train_loss: 0.736, train_acc:0.752, train_f1:8.848
[10,  800] train_loss: 0.798, train_acc:0.722, train_f1:8.316
[10,  900] train_loss: 0.721, train_acc:0.755, train_f1:8.572


 67%|██████▋   | 10/15 [44:55<16:56, 203.34s/it]

epoch [10] valid_loss: 0.726, valid_acc:0.752, valid_f1:7.550
[11,  100] train_loss: 0.748, train_acc:0.741, train_f1:8.773
[11,  200] train_loss: 0.671, train_acc:0.776, train_f1:8.692
[11,  300] train_loss: 0.714, train_acc:0.751, train_f1:8.665
[11,  400] train_loss: 0.725, train_acc:0.749, train_f1:8.587
[11,  500] train_loss: 0.730, train_acc:0.762, train_f1:8.814
[11,  600] train_loss: 0.732, train_acc:0.752, train_f1:8.790
[11,  700] train_loss: 0.755, train_acc:0.741, train_f1:8.796
[11,  800] train_loss: 0.783, train_acc:0.719, train_f1:8.398
[11,  900] train_loss: 0.739, train_acc:0.759, train_f1:8.603


 73%|███████▎  | 11/15 [47:15<12:17, 184.42s/it]

epoch [11] valid_loss: 0.708, valid_acc:0.759, valid_f1:7.565
[12,  100] train_loss: 0.790, train_acc:0.732, train_f1:8.715
[12,  200] train_loss: 0.699, train_acc:0.764, train_f1:8.710
[12,  300] train_loss: 0.703, train_acc:0.750, train_f1:8.622
[12,  400] train_loss: 0.751, train_acc:0.740, train_f1:8.626
[12,  500] train_loss: 0.733, train_acc:0.753, train_f1:8.845
[12,  600] train_loss: 0.723, train_acc:0.754, train_f1:8.811
[12,  700] train_loss: 0.732, train_acc:0.736, train_f1:8.813
[12,  800] train_loss: 0.771, train_acc:0.722, train_f1:8.406
[12,  900] train_loss: 0.712, train_acc:0.752, train_f1:8.637


 80%|████████  | 12/15 [49:31<08:29, 169.76s/it]

epoch [12] valid_loss: 0.713, valid_acc:0.753, valid_f1:7.562
[13,  100] train_loss: 0.761, train_acc:0.738, train_f1:8.797
[13,  200] train_loss: 0.696, train_acc:0.769, train_f1:8.722
[13,  300] train_loss: 0.695, train_acc:0.768, train_f1:8.625
[13,  400] train_loss: 0.746, train_acc:0.743, train_f1:8.600
[13,  500] train_loss: 0.696, train_acc:0.764, train_f1:8.814
[13,  600] train_loss: 0.716, train_acc:0.764, train_f1:8.837
[13,  700] train_loss: 0.735, train_acc:0.750, train_f1:8.855
[13,  800] train_loss: 0.762, train_acc:0.735, train_f1:8.393
[13,  900] train_loss: 0.709, train_acc:0.756, train_f1:8.605


 87%|████████▋ | 13/15 [51:51<05:21, 160.87s/it]

epoch [13] valid_loss: 0.726, valid_acc:0.751, valid_f1:7.495
[14,  100] train_loss: 0.753, train_acc:0.747, train_f1:8.794
[14,  200] train_loss: 0.684, train_acc:0.781, train_f1:8.710
[14,  300] train_loss: 0.691, train_acc:0.767, train_f1:8.577
[14,  400] train_loss: 0.721, train_acc:0.758, train_f1:8.598
[14,  500] train_loss: 0.732, train_acc:0.762, train_f1:8.774
[14,  600] train_loss: 0.729, train_acc:0.741, train_f1:8.836
[14,  700] train_loss: 0.751, train_acc:0.748, train_f1:8.788
[14,  800] train_loss: 0.796, train_acc:0.732, train_f1:8.457
[14,  900] train_loss: 0.733, train_acc:0.754, train_f1:8.642


 93%|█████████▎| 14/15 [53:54<02:29, 149.40s/it]

epoch [14] valid_loss: 0.709, valid_acc:0.757, valid_f1:7.603
[15,  100] train_loss: 0.734, train_acc:0.740, train_f1:8.797
[15,  200] train_loss: 0.683, train_acc:0.769, train_f1:8.684
[15,  300] train_loss: 0.723, train_acc:0.749, train_f1:8.638
[15,  400] train_loss: 0.722, train_acc:0.756, train_f1:8.620
[15,  500] train_loss: 0.718, train_acc:0.759, train_f1:8.840
[15,  600] train_loss: 0.716, train_acc:0.759, train_f1:8.813
[15,  700] train_loss: 0.743, train_acc:0.744, train_f1:8.823
[15,  800] train_loss: 0.749, train_acc:0.738, train_f1:8.405
[15,  900] train_loss: 0.728, train_acc:0.749, train_f1:8.606


100%|██████████| 15/15 [56:04<00:00, 224.31s/it]

epoch [15] valid_loss: 0.737, valid_acc:0.742, valid_f1:7.565
training finished





In [17]:
torch.save(model,'../models/0826_2_model.pth')

In [18]:
print(right_count)

[27442, 17138, 1944, 36831, 38144, 2512, 4090, 2554, 107, 5699, 4901, 194, 5245, 3364, 114, 7222, 7611, 201]


In [19]:
print(wrong_count)

[5633, 7627, 3126, 7314, 10366, 4088, 2450, 2396, 853, 3031, 5104, 1021, 1280, 1631, 891, 1433, 2214, 1029]
