In [2]:
## 라이브러리 추가하기
import os
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt

from torchvision import transforms, datasets

In [3]:
# Training hyper-parameters.
lr = 0.001          # learning rate handed to the optimizer
batch_size = 4      # samples per mini-batch
num_epoch = 100     # last epoch index of the training loop (inclusive)

# Directory layout used by the rest of the notebook.
data_dir = './datasets'      # root folder holding the dataset splits
ckpt_dir = './checkpoint'    # where trained network checkpoints are stored
log_dir = './log'            # where the TensorBoard logs are written

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [5]:
## Data pipeline for training the network
# Normalization, RandomFlip, ToTensor and Dataset are not defined in this part
# of the notebook; presumably an earlier cell provides them.
transform = transforms.Compose([
    Normalization(mean=0.5, std=0.5),
    RandomFlip(),
    ToTensor(),
])

# Keyword arguments shared by both loaders.
_loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=8)

# NOTE(review): the training loop iterates loader_test, which reads the 'test'
# split without shuffling - confirm that this is the intended training data.
_test_root = os.path.join(data_dir, 'test')
dataset_test = Dataset(data_dir=_test_root, transform=transform)
loader_test = DataLoader(dataset_test, **_loader_kwargs)

# NOTE(review): the validation split goes through the same transform,
# including RandomFlip - confirm that augmentation is wanted here.
_val_root = os.path.join(data_dir, 'val')
dataset_val = Dataset(data_dir=_val_root, transform=transform)
loader_val = DataLoader(dataset_val, **_loader_kwargs)
## Build the network and move it to the selected device
net = UNet()
net = net.to(device)

## Loss function: binary cross-entropy evaluated on raw logits
fn_loss = nn.BCEWithLogitsLoss()
fn_loss = fn_loss.to(device)

## Optimizer: Adam over every parameter of the network
optim = torch.optim.Adam(params=net.parameters(), lr=lr)

## Auxiliary variables
# Number of samples in each dataset.
num_data_test = len(dataset_test)
# Fixed: this line used to call len() on num_data_val itself, which is not
# defined yet at this point and raised a NameError.
num_data_val = len(dataset_val)

# Batches per epoch. np.ceil returns a float, so cast to int: these values are
# printed with %d and used to build integer TensorBoard step indices.
num_batch_train = int(np.ceil(num_data_test / batch_size))
num_batch_val = int(np.ceil(num_data_val / batch_size))

## Auxiliary helper functions
def fn_tonumpy(x):
    """Convert a 4-D tensor to a NumPy array with axes reordered (0, 2, 3, 1)."""
    array = x.detach().cpu().numpy()
    return array.transpose(0, 2, 3, 1)


def fn_denorm(x, mean, std):
    """Undo a normalization: scale ``x`` by ``std`` and shift it by ``mean``."""
    return mean + (x * std)


def fn_class(x):
    """Binarize ``x``: 1.0 where the value exceeds 0.5, 0.0 elsewhere."""
    # NOTE(review): the loss in this notebook is BCEWithLogitsLoss, so the
    # network output looks like raw logits; confirm whether 0.5 on the logit
    # (rather than on a sigmoid probability) is the intended threshold.
    return (x > 0.5) * 1.0

# SummaryWriter instances for TensorBoard: one log sub-directory per phase.
writer_train, writer_val = (
    SummaryWriter(log_dir=os.path.join(log_dir, phase))
    for phase in ('train', 'val')
)
           
## Save the network
def save(ckpt_dir, net, optim, epoch):
    """Write the model and optimizer state to ``<ckpt_dir>/model_epoch<epoch>.pth``.

    Args:
        ckpt_dir: Directory that receives the checkpoint; created if missing.
        net: Module whose ``state_dict()`` is stored under the ``'net'`` key.
        optim: Optimizer whose ``state_dict()`` is stored under the ``'optim'`` key.
        epoch: Integer epoch number embedded in the file name.
    """
    # exist_ok replaces the separate exists() check and its check-then-create race.
    os.makedirs(ckpt_dir, exist_ok=True)

    # Build the path with os.path.join rather than "./%s/...": the old format
    # turned an absolute ckpt_dir into "./<abs path>/...", i.e. a directory
    # different from the one created above.
    ckpt_path = os.path.join(ckpt_dir, "model_epoch%d.pth" % epoch)
    torch.save({'net': net.state_dict(), 'optim': optim.state_dict()}, ckpt_path)

## Load the network
st_epoch = 0

def load(ckpt_dir, net, optim):
    """Restore the newest ``model_epoch<N>.pth`` checkpoint found in ``ckpt_dir``.

    Args:
        ckpt_dir: Directory to search for checkpoints.
        net: Module that receives the stored ``'net'`` state dict.
        optim: Optimizer that receives the stored ``'optim'`` state dict.

    Returns:
        Tuple ``(net, optim, epoch)``. ``epoch`` is the number parsed from the
        loaded file name, or 0 when the directory is missing or holds no
        checkpoint (in which case ``net`` and ``optim`` are returned untouched).
    """
    if not os.path.isdir(ckpt_dir):
        return net, optim, 0

    # Only consider files that follow the naming scheme used by save(). The
    # previous digit-based sort key crashed on any name without digits
    # (e.g. ".ipynb_checkpoints") and on an empty directory.
    prefix, suffix = 'model_epoch', '.pth'
    candidates = []
    for fname in os.listdir(ckpt_dir):
        if fname.startswith(prefix) and fname.endswith(suffix):
            number = fname[len(prefix):-len(suffix)]
            if number.isdecimal():
                candidates.append((int(number), fname))

    if not candidates:
        return net, optim, 0

    epoch, latest = max(candidates)

    # os.path.join keeps absolute directories intact (the old "./%s/%s" format
    # did not). map_location='cpu' lets a checkpoint written on a GPU machine
    # be opened without CUDA; load_state_dict then copies the values into the
    # existing parameters on whatever device they live on.
    dict_model = torch.load(os.path.join(ckpt_dir, latest), map_location='cpu')

    net.load_state_dict(dict_model['net'])
    optim.load_state_dict(dict_model['optim'])

    return net, optim, epoch

# Training / validation loop.
st_epoch = 0
# Before training, resume from a stored checkpoint (if any) so that training
# continues from the epoch where it stopped.
# Fixed: the keyword was misspelled "chkpt_dir"; load() takes "ckpt_dir".
net, optim, st_epoch = load(ckpt_dir=ckpt_dir, net=net, optim=optim)


def _log_images(writer, label, inputs, output, step):
    """Log the label, de-normalized input and thresholded output batches.

    The arrays produced by fn_tonumpy are batches in NHWC order, so the
    batch-aware ``add_images`` is used. Each tag receives its own tensor;
    previously all three tags were fed ``label``.
    """
    writer.add_images('label', fn_tonumpy(label), step, dataformats='NHWC')
    writer.add_images('input', fn_tonumpy(fn_denorm(inputs, mean=0.5, std=0.5)),
                      step, dataformats='NHWC')
    writer.add_images('output', fn_tonumpy(fn_class(output)), step, dataformats='NHWC')


for epoch in range(st_epoch + 1, num_epoch + 1):
    # ---- training phase ----
    net.train()
    loss_arr = []

    for batch, data in enumerate(loader_test, 1):
        # Forward pass: feed the input batch through the network.
        label = data['label'].to(device)
        inputs = data['input'].to(device)   # renamed from `input` (shadowed the builtin)

        # Fixed: the forward call was missing, so `output` was undefined here.
        output = net(inputs)

        # Backward pass and parameter update.
        optim.zero_grad()

        loss = fn_loss(output, label)
        loss.backward()

        optim.step()

        # Running mean of the loss over the batches seen so far in this epoch.
        loss_arr.append(loss.item())
        print("TRAIN: EPOCH %04d / %04d BATCH %04d / %04d| LOSS %.4f" %
              (epoch, num_epoch, batch, num_batch_train, np.mean(loss_arr)))

        # Store input, output and label images in TensorBoard. int() guards
        # against a float batch count, which would yield a float step.
        step = int(num_batch_train) * (epoch - 1) + batch
        _log_images(writer_train, label, inputs, output, step)

    # Store the mean training loss of this epoch in TensorBoard.
    writer_train.add_scalar('loss', np.mean(loss_arr), epoch)

    # ---- validation phase ----
    # No backward pass is needed, so gradient tracking is switched off.
    with torch.no_grad():
        net.eval()   # put the network into evaluation mode
        loss_arr = []

        # Fixed: validation iterated loader_test; it must use loader_val.
        for batch, data in enumerate(loader_val, 1):
            # Forward pass.
            label = data['label'].to(device)
            inputs = data['input'].to(device)

            output = net(inputs)

            # Loss computation (fixed the `lable` typo).
            loss = fn_loss(output, label)

            loss_arr.append(loss.item())
            print("VALID: EPOCH %04d / %04d BATCH %04d / %04d| LOSS %.4f" %
                  (epoch, num_epoch, batch, num_batch_val, np.mean(loss_arr)))

            # Validation steps are counted in validation batches, not training batches.
            step = int(num_batch_val) * (epoch - 1) + batch
            _log_images(writer_val, label, inputs, output, step)

    # Store the mean validation loss of this epoch in TensorBoard.
    writer_val.add_scalar('loss', np.mean(loss_arr), epoch)

    # Save the network every fifth epoch (fixed the `epochch` typo).
    if epoch % 5 == 0:
        save(ckpt_dir=ckpt_dir, net=net, optim=optim, epoch=epoch)

writer_train.close()
writer_val.close()
                               
                    

