## 0. Library 불러오기 및 경로설정

In [6]:
import os
import pandas as pd
from PIL import Image
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

from torchvision import models
from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize, CenterCrop

In [7]:
# Root directory of the training data; the metadata CSV is read from here.
train_dir = '/opt/ml/input/data/train'

## 2. Train Dataset 정의

In [8]:
class TrainDataset(Dataset):
    """Image dataset that chooses a transform pipeline per sample.

    Args:
        img_paths: Sequence of image file paths, indexable by position.
        labels: Sequence of class labels aligned with ``img_paths``.
        trans_dict: Mapping holding callables under the keys ``'val'``,
            ``'train1'`` and ``'train2'``; each is applied to a PIL image.
            A falsy value disables transformation and the RGB PIL image is
            returned unchanged.
        val: When True, every sample goes through ``trans_dict['val']``.
    """

    # Labels whose training samples use the 'train1' pipeline; every other
    # label is routed to 'train2'.
    _TRAIN1_LABELS = (0, 1, 3, 4)

    def __init__(self, img_paths, labels, trans_dict, val=False):
        self.img_paths = img_paths
        self.labels = labels
        self.trans_dict = trans_dict
        self.val = val

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        # Image.open() is lazy and keeps the file handle open. convert()
        # decodes the pixels into a new image, so the handle can be released
        # right away, and it guarantees three channels regardless of the
        # mode the file was stored in (grayscale, palette, RGBA, ...).
        with Image.open(self.img_paths[index]) as img:
            image = img.convert('RGB')
        label = self.labels[index]

        if not self.trans_dict:
            return image, label

        if self.val:
            key = 'val'
        elif label in self._TRAIN1_LABELS:
            key = 'train1'
        else:
            key = 'train2'
        return self.trans_dict[key](image), label

    def __len__(self):
        """Return the number of samples."""
        return len(self.img_paths)

sample 이미지 살펴보고 transform 결정

In [9]:
# Hyperparameter definitions

# Training loop settings
num_epoch = 15      # number of passes over the training set
batch_size = 16     # samples per training batch
num_classes = 18    # size of the model's output layer

# Train/validation split settings
val_split = 0.2         # fraction of the samples held out for validation
shuffle_dataset = True  # shuffle the indices before splitting
random_seed = 48        # seed for that shuffle

In [10]:
# Hyperparameter definitions.
# NOTE(review): this cell repeats the hyperparameter cell directly above it
# with identical values; one of the two can be dropped.
num_epoch = 15
num_classes = 18
batch_size = 16

val_split = 0.2
random_seed = 48
shuffle_dataset = True

In [11]:
# Load the metadata CSV that lists every image path and its class label.
train_info = pd.read_csv(os.path.join(train_dir, 'train3.csv'))
image_paths = train_info['path']
labels = train_info['category']

# Transform pipelines, keyed the way TrainDataset looks them up.
data_transform = {
    # Light augmentation: crop, resize and a random horizontal flip.
    'train1': transforms.Compose([
        CenterCrop(300),
        Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        ToTensor(),
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
    ]),
    # Heavier augmentation: additionally jitters brightness and perspective.
    'train2': transforms.Compose([
        CenterCrop(300),
        Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply(transforms=[transforms.ColorJitter(brightness=0.5)], p=0.5),
        # disabled: transforms.RandomApply(transforms=[transforms.Pad(padding=5, fill=0, padding_mode='constant')],p=0.5),
        # NOTE(review): RandomPerspective(p=0.5) inside RandomApply(p=0.5)
        # fires with an effective probability of 0.25 - confirm intended.
        transforms.RandomApply(transforms=[transforms.RandomPerspective(distortion_scale=0.2, p=0.5)], p=0.5),
        ToTensor(),
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
    ]),
    # NOTE(review): validation images are not centre-cropped, unlike both
    # training pipelines, so they are framed differently - confirm intended.
    'val': transforms.Compose([
        # Bilinear is Resize's default interpolation, so it is not passed
        # explicitly (integer / PIL constants are deprecated there).
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
    ]),
}

# Train / validation split: the first `split` (optionally shuffled) indices
# form the validation set, the remainder the training set.
dataset_size = len(labels)
indices = list(range(dataset_size))
split = int(np.floor(val_split * dataset_size))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# .iloc selects by position, so the split does not depend on the DataFrame
# carrying a default 0..n-1 index.
train_dataset = TrainDataset(
    image_paths.iloc[train_indices].tolist(),
    labels.iloc[train_indices].tolist(),
    data_transform,
)
val_dataset = TrainDataset(
    image_paths.iloc[val_indices].tolist(),
    labels.iloc[val_indices].tolist(),
    data_transform,
    val=True,
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    # Reshuffle every epoch so batches differ between epochs; the dedicated
    # seeded generator keeps the order reproducible across runs.
    shuffle=True,
    generator=torch.Generator().manual_seed(random_seed),
    drop_last=True,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    drop_last=True,
)
print(len(train_dataset), len(val_dataset), dataset_size)

15120 3780 18900


## 4. Model 정의

In [12]:
from custom_models import ResnetModel2 as MyModel

## 5. Train

In [None]:
from sklearn.metrics import f1_score
import numpy as np
from tqdm import tqdm
from loss import f1_loss

# Build the model on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MyModel(num_classes=num_classes).to(device)

# Loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Per-class tallies of correct / incorrect training predictions. They are
# never reset, so they accumulate over all epochs.
right_count = [0] * num_classes
wrong_count = [0] * num_classes

# Running training metrics are reported every `log_interval` batches.
log_interval = 100

for epoch in tqdm(range(num_epoch)):
    # ---- train ----
    model.train()
    cur_loss, cur_f1 = 0.0, 0.0
    cur_correct, cur_samples = 0, 0
    # The batch labels are bound to `targets` so the module-level `labels`
    # Series is not overwritten by a tensor.
    for i, (images, targets) in enumerate(train_dataloader):
        images, targets = images.to(device), targets.to(device)

        optimizer.zero_grad()
        model_output = model(images)
        loss = criterion(model_output, targets)
        loss.backward()
        optimizer.step()

        # Everything below is bookkeeping only: keep plain Python numbers so
        # no device tensors or autograd graphs stay alive between batches.
        cur_loss += loss.item()

        predict = model_output.argmax(dim=-1)
        cur_correct += (predict == targets).sum().item()
        cur_samples += targets.size(0)

        # NOTE(review): `f1_loss` comes from the project's `loss` module and
        # its value is logged as "train_f1" - confirm it returns a score
        # rather than a loss, and that (labels, outputs) is its argument order.
        with torch.no_grad():
            cur_f1 += float(f1_loss(targets, model_output))

        # One device-to-host copy per batch instead of one per sample; the
        # zip also copes with a batch smaller than `batch_size`.
        for target, pred in zip(targets.tolist(), predict.tolist()):
            if pred == target:
                right_count[target] += 1
            else:
                wrong_count[target] += 1

        if (i + 1) % log_interval == 0:
            print(f'[{epoch + 1},{i + 1:5d}] train_loss: {cur_loss / log_interval:.3f}, '
                  f'train_acc:{cur_correct / cur_samples:.3f}, train_f1:{cur_f1 / log_interval:.3f}')
            cur_loss, cur_f1 = 0.0, 0.0
            cur_correct, cur_samples = 0, 0

    # ---- eval ----
    model.eval()
    valid_loss, valid_f1 = 0.0, 0.0
    valid_correct, valid_samples, valid_batches = 0, 0, 0
    # No gradients are needed for validation; skipping them saves memory.
    with torch.no_grad():
        for images, targets in val_dataloader:
            images, targets = images.to(device), targets.to(device)

            model_output = model(images)
            valid_loss += criterion(model_output, targets).item()

            predict = model_output.argmax(dim=-1)
            valid_correct += (predict == targets).sum().item()
            valid_samples += targets.size(0)

            valid_f1 += float(f1_loss(targets, model_output))
            valid_batches += 1

    # Divide by what was actually seen (guarding against an empty loader)
    # instead of assuming a fixed validation batch size.
    batch_denom = max(valid_batches, 1)
    sample_denom = max(valid_samples, 1)
    print(f'epoch [{epoch + 1}] valid_loss: {valid_loss / batch_denom:.3f}, '
          f'valid_acc:{valid_correct / sample_denom:.3f}, valid_f1:{valid_f1 / batch_denom:.3f}')

print('training finished')

  0%|          | 0/15 [00:00<?, ?it/s]

[1,  100] train_loss: 2.284, train_acc:0.296, train_f1:4.213
[1,  200] train_loss: 1.763, train_acc:0.467, train_f1:6.416
[1,  300] train_loss: 1.509, train_acc:0.548, train_f1:7.650
[1,  400] train_loss: 1.378, train_acc:0.585, train_f1:8.177
[1,  500] train_loss: 1.250, train_acc:0.614, train_f1:8.318
[1,  600] train_loss: 1.144, train_acc:0.649, train_f1:8.548
[1,  700] train_loss: 1.144, train_acc:0.637, train_f1:8.568
[1,  800] train_loss: 1.120, train_acc:0.646, train_f1:8.197


In [None]:
# Create the output directory first, so a missing folder cannot make the
# save fail after the whole training run has finished.
os.makedirs('models', exist_ok=True)
# NOTE: this pickles the entire model object (not just a state_dict), so
# loading it back requires the model class to be importable.
torch.save(model, 'models/0826_2_model.pth')

In [None]:
# Per-class number of correctly classified training samples, accumulated
# over all epochs of the training loop.
print(right_count)

In [None]:
# Per-class number of misclassified training samples (indexed by the true
# label), accumulated over all epochs of the training loop.
print(wrong_count)