#### 1. 라이브러리 호출

In [7]:
import glob
import os
import random

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb

import albumentations as A
import segmentation_models_pytorch as smp
from albumentations.pytorch import transforms

import model.metric as module_metric
from data_loader.data_loaders import make_dataloder
from model.loss import *
from trainer.trainer import *
from pathlib import Path  # kept after the star imports so nothing they export can shadow it
from model.VGGNet_BN import VGGNet_BN

#### 2. 시드고정

In [8]:
# Fix every RNG this pipeline can draw from so runs are repeatable.
# Python's `random` must be seeded too: albumentations (1.x) samples its augmentation
# decisions/parameters from it, so seeding only torch/numpy leaves augmentation random.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # every visible GPU, not only the current device
# Prefer deterministic cuDNN kernels over autotuned (possibly nondeterministic) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# NOTE(review): DataLoader worker processes (num_workers > 0) need their own seeding
# via worker_init_fn / generator — that lives in make_dataloder, not visible here.

#### 3. 하이퍼 파라미터 설정

In [9]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Optimisation settings.
lr, batch_size, num_epoch = 1e-3, 32, 200

# Input pipeline: per-channel mean/std (the commonly used ImageNet statistics)
# for normalisation, and the square side length images are resized to.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
image_size = 224

In [10]:
# Training pipeline: square resize, random flip / brightness-contrast jitter /
# small rotation, then normalisation and conversion to a tensor.
transform_train = A.Compose(
    [
        A.Resize(height=image_size, width=image_size),
        A.HorizontalFlip(),
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
        A.Rotate(limit=(-10, 10), border_mode=cv2.BORDER_REFLECT, p=0.5),
        A.Normalize(mean=mean, std=std),
        transforms.ToTensorV2(transpose_mask=True),
    ]
)

# Validation pipeline: same resize and normalisation, no random augmentation.
transform_val = A.Compose(
    [
        A.Resize(height=image_size, width=image_size),
        A.Normalize(mean=mean, std=std),
        transforms.ToTensorV2(transpose_mask=True),
    ]
)

In [11]:
# One loader per split from the project factory; only the split flag and the
# transform pipeline differ between them.
train_dataloader = make_dataloder(
    batch_size=batch_size, train_=True, transform=transform_train
)
val_dataloader = make_dataloder(
    batch_size=batch_size, train_=False, transform=transform_val
)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
# Instantiate the classifier and move its parameters/buffers to `device`.
# `nn.Module.to` returns the module itself, so the bare name on the last line
# displays the architecture exactly as before.
model = VGGNet_BN()
model = model.to(device)
model

VGGNet_BN(
  (conv_block_1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv_block_2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv_block_3): Sequential(

In [16]:
# Loss, optimiser, LR schedule and evaluation metrics for training.
# NOTE(review): with T_0=20 and T_mult=2 the cycles are 20, 40, 80, 160 epochs long,
# so restarts fall at epochs 20, 60 and 140 (assuming the trainer steps the scheduler
# once per epoch). A num_epoch=200 run therefore stops partway through the fourth
# cycle rather than at eta_min — confirm this is intended.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer=optimizer,
    T_0=20,
    T_mult=2,
    eta_min=1e-5,
)
# Resolve metric callables by name from the project's metric module.
metrics = [getattr(module_metric, name) for name in ('Accuracy',)]

#### 4. wandb config 생성

In [17]:
# Run configuration logged to wandb: hyper-parameters plus the class names of the
# training components. Key order matches the order the values were defined above.
train_config = {
    'Batch size': batch_size,
    'Learning Rate': lr,
    'Epochs': num_epoch,
    'Image size': image_size,
    'Loss fn': type(criterion).__name__,
    'Optimizer': type(optimizer).__name__,
    'LR Scheduler': type(scheduler).__name__,
    # Metric function names keyed by 1-based position, e.g. {'1': 'Accuracy'}.
    'Metric': {str(idx): metric.__name__ for idx, metric in enumerate(metrics, start=1)},
}

In [23]:
# Output directory handed to Trainer, named after the model class
# (presumably where checkpoints/logs are written — see Trainer).
save_dir = "./saved/{}/".format(type(model).__name__)
trainer = Trainer(
    model, criterion, metrics, optimizer, device, num_epoch, save_dir, mean, std,
    data_loader=train_dataloader,
    valid_data_loader=val_dataloader,
    lr_scheduler=scheduler,
)

In [24]:
# Early-stopping patience (unit defined by Trainer — presumably epochs; confirm),
# read back from the trainer so the logged config reflects what it actually holds.
trainer.early_stop = 30
train_config.update({'Early stop': trainer.early_stop})

In [27]:
# Start the wandb run for this experiment.
# NOTE(review): the run name is the second '/'-separated component of trainer.dir,
# whose format is built inside Trainer (not visible here). If trainer.dir starts with
# "./saved/...", this component would be "saved" for every run — confirm.
run_name = trainer.dir.split('/')[1]
wandb.init(project='Classification', name=str(run_name), config=train_config)

[34m[1mwandb[0m: Currently logged in as: [33mimlim[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [28]:
# Train, then close the wandb run. In a notebook the run otherwise stays open until
# the kernel dies or another wandb.init() is called. On an exception or interrupt the
# run is marked failed (non-zero exit code) before the error propagates.
try:
    trainer.train()
except BaseException:
    wandb.finish(exit_code=1)
    raise
wandb.finish()


Epoch : 0 | 