In [12]:
import os
import numpy as np
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms, datasets
from PIL import Image

# 데이터 불러오기
dir_data = './datasets'

name_label = 'train-labels.tif'
name_input = 'train-volume.tif'

img_label = Image.open(os.path.join(dir_data, name_label))
img_input = Image.open(os.path.join(dir_data, name_input))

ny, nx = img_label.size                  # 512 * 512
nframe = img_label.n_frames              # 30 frame

# dataset을 Train과 Test로 나누기
nframe_train = 24
nframe_val = 3
nframe_test = 3

# data가 저장될 Directory 설정하기
dir_save_train = os.path.join(dir_data, 'train')
dir_save_val = os.path.join(dir_data, 'val')
dir_save_test = os.path.join(dir_data, 'test')

# Dir 생성
if not os.path.exists(dir_save_train) :
    os.makedirs(dir_save_train)
if not os.path.exists(dir_save_val) :
    os.makedirs(dir_save_val)
if not os.path.exists(dir_save_test) :
    os.makedirs(dir_save_test)
    
# data dir에 dataset을 Random 하게 저장하기
id_frame = np.arange(nframe)
np.random.shuffle(id_frame)

#train set 저장하기
offset_nframe = 0

for i in range(nframe_train) :
    img_label.seek(id_frame[i + offset_nframe])
    img_input.seek(id_frame[i + offset_nframe])
    
    label_ = np.asarray(img_label)
    input_ = np.asarray(img_input)
    
    np.save(os.path.join(dir_save_train, 'label_%03d.npy' % i), label_)
    np.save(os.path.join(dir_save_train, 'label_%03d.npy' % i), input_)
    
# val set 저장하기
offset_nframe += nframe_train

for i in range(nframe_val) :
    img_label.seek(id_frame[i + offset_nframe])
    img_input.seek(id_frame[i + offset_nframe])
    
    label_ = np.asarray(img_label)
    input_ = np.asarray(img_input)
    
    np.save(os.path.join(dir_save_val, 'label_%03d.npy' % i), label_)
    np.save(os.path.join(dir_save_val, 'label_%03d.npy' % i), input_)
    
# test set 저장하기
offset_nframe += nframe_val

for i in range(nframe_test) :
    img_label.seek(id_frame[i + offset_nframe])
    img_input.seek(id_frame[i + offset_nframe])
    
    label_ = np.asarray(img_label)
    input_ = np.asarray(img_input)
    
    np.save(os.path.join(dir_save_test, 'label_%03d.npy' % i), label_)
    np.save(os.path.join(dir_save_test, 'label_%03d.npy' % i), input_)

# hyper parameter 설정
lr = 1e-3
batch_size = 4
num_epoch = 100

data_dir = './datasets'
ckpt_dir = './checkpoint'   #train된 Network가 저장될 dir
log_dir = './log'           #tensorboard의 로그가 기록될 dir

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Layer 생성하기
# UNet 네트워크에 nn.Module을 상속하기
class UNet(nn.Module) :      
    def __init__(self) :
        super(UNet, self).__init__()  # 상속 초기화
        
        # Convolution Batch-nomarlization ReLU 2D
        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True) :
            # Convolution Layer 정의하기
            layers = []
            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                kernel_size=kernel_size, stride=stride, padding=padding,
                                bias=bias)]
            # Batch-nomarlization 정희하기
            layers += [nn.BatchNorm2d(num_features=out_channels)]
            # ReLU 정의하기
            layers += [nn.ReLU()]
            
            cbr = nn.Sequential(*layers)
            
            return cbr
        
        # Contracting path (Encoder 부분)
        # kernel_size=3, stride=1, padding=1, bias=True 생략 가능
        self.enc1_1 = CBR2d(in_channels=1, out_channels=64)
        self.enc1_2 = CBR2d(in_channels=64, out_channels=64)
        
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        
        self.enc2_1 = CBR2d(in_channels=64, out_channels=128)
        self.enc2_2 = CBR2d(in_channels=128, out_channels=128)
        
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        
        self.enc3_1 = CBR2d(in_channels=128, out_channels=256)
        self.enc3_2 = CBR2d(in_channels=256, out_channels=256)
        
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        
        self.enc4_1 = CBR2d(in_channels=256, out_channels=512)
        self.enc4_2 = CBR2d(in_channels=512, out_channels=512)
        
        self.pool4 = nn.MaxPool2d(kernel_size=2)
        
        self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)
        
        # Expansive path (Decoder 부분)
        self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)
        
        self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        
        self.dec4_2 = CBR2d(in_channels=512 * 2, out_channels=512) # *2 하는 이유는 UNet의 해당 Decoder부분 그림 잘 보기
        self.dec4_1 = CBR2d(in_channels=512, out_channels=256)
        
        self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        
        self.dec3_2 = CBR2d(in_channels=256 * 2, out_channels=256)
        self.dec3_1 = CBR2d(in_channels=256, out_channels=128)
        
        self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        
        self.dec2_2 = CBR2d(in_channels=128 * 2, out_channels=128)
        self.dec2_1 = CBR2d(in_channels=128, out_channels=64)
        
        self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,
                                          kernel_size=2, stride=2, padding=0, bias=True)
        
        self.dec1_2 = CBR2d(in_channels=64 * 2, out_channels=64)
        self.dec1_1 = CBR2d(in_channels=64, out_channels=64)
        
        self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=2, stride=2, padding=0, bias=True)
        
        
    # UNet 레이어 연결하기
    def forward(self, x) :
        # Encoder 부분 연결하기
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)
        
        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2) 
        
        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)
        
        enc5_1 = self.enc5_1(pool4)
        
        # Decoder 부분 연결하기
        
        dec5_1 = self.dec5_1(enc5_1)
        
        unpool4 = self.unpool4(dec5_1)
        cat4 = torch.cat((unpool4, enc4_2), dim=1) # dim=[0:batch, 1:channel, 2:height, 3:width]
        dec4_2 = self.dec4_2(cat4)
        dec4_1 = self.dec4_1(dec4_2)
        
        unpool3 = self.unpool3(dec4_1)
        cat3 = torch.cat((unpool3, enc3_2), dim=1)
        dec3_2 = self.dec3_2(cat3)
        dec3_1 = self.dec3_1(dec3_2)
        
        unpool2 = self.unpool2(dec3_1)
        cat2 = torch.cat((unpool2, enc2_2), dim=1)
        dec2_2 = self.dec2_2(cat2)
        dec2_1 = self.dec2_1(dec2_2)
        
        unpool1 = self.unpool1(dec2_1)
        cat1 = torch.cat((unpool1, enc1_2), dim=1)
        dec1_2 = self.dec1_2(cat1)
        dec1_1 = self.dec1_1(dec1_2)
        
        x = self.fc(dec1_1)
        
        return x


## 데이터 로더를 구현하기
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform

        lst_data = os.listdir(self.data_dir)

        lst_label = [f for f in lst_data if f.startswith('label')]
        lst_input = [f for f in lst_data if f.startswith('input')]

        lst_label.sort()
        lst_input.sort()

        self.lst_label = lst_label
        self.lst_input = lst_input

    def __len__(self):
        return len(self.lst_label)

    def __getitem__(self, index):
        label = np.load(os.path.join(self.data_dir, self.lst_label[index]))
        input = np.load(os.path.join(self.data_dir, self.lst_input[index]))

        label = label/255.0
        input = input/255.0

        if label.ndim == 2:
            label = label[:, :, np.newaxis]
        if input.ndim == 2:
            input = input[:, :, np.newaxis]

        data = {'input': input, 'label': label}

        if self.transform:
            data = self.transform(data)

        return data


# ToTensor() : numpy -> tensor로 변환

class ToTensor(object) :
    def __call__(self, data) :
        label, input = data['label'], data['input']
        
        # Image의 Numpy 차원 = (Y,X,ch)
        # Image의 tensor 차원 = (ch,Y,X)
        label = label.transpose((2,0,1)).astype(np.float32)
        input = input.transpose((2,0,1)).astype(np.float32)
        
        data = {'label': torch.from_numpy(label), 'input': torch.from_numpy(input)}
        
        return data

class Normalization(object) :
    def __init__(self, mean=0.5, std=0.5) :
        self.mean = mean
        self.std = std
    
    def __call__(self, data) :
        label, input = data['label'], data['input']
        
        input = (input - self.mean) / self.std
        data = {'label' :label, 'input':input}
        
        return data
    
class RandomFlip(object) :
    def __init__ (self, data) :
        label, input = data['label'], data['input']
        
        if np.random.rand() > 0.5: # 50% 확률
            label = np.fliplr(label)
            input = np.fliplr(input)
        
        if np.random.rand() > 0.5 :
            label = np.flipud(label)
            input = np.flipud(input)
        
        data = {'label' :label, 'input':input}
        
        return data




## 네트워크 학습하기위해 data load 부분
train_transform = transforms.Compose([Normalization(mean=0.5, std=0.5),RandomFlip(),ToTensor()])
test_transform = transforms.Compose([Normalization(mean=0.5, std=0.5),ToTensor()])

dataset_train = Dataset(data_dir=os.path.join(data_dir, 'train'), transform=train_transform)
loader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=8)


dataset_test = Dataset(data_dir=os.path.join(data_dir, 'val'), transform=test_transform)
loader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=8)
## 네트워크 생성하기
net = UNet().to(device)

## 손실함수 정의하기
fn_loss = nn.BCEWithLogitsLoss().to(device)

## Optimizer 설정하기
optim = torch.optim.Adam(net.parameters(), lr=lr)

## 그밖에 부수적인 variables 설정하기
num_data_test = len(dataset_test)
num_data_train = len(dataset_train)
num_batch_train = np.ceil(num_data_test / batch_size)

    # 그밖에 부수적인 variables 설정하기


## 그밖에 부수적인 functions 설정하기
# tensor variable에서 Numpy로 변환시키는 함수
fn_tonumpy = lambda x: x.to('cpu').detach().numpy().transpose(0, 2, 3, 1)
# normalization 되어있는 data를 반대로 de-normalization 하는 함수
fn_denorm = lambda x, mean, std: (x * std) + mean
# 네트워크 아웃풋 이미지를 bin 클래스로 분류해주는 함수
fn_class = lambda x: 1.0 * (x > 0.5)


# 텐서보드를 사용하기 위한 summaryWriter 설정
writer_train = SummaryWriter(log_dir=os.path.join(log_dir, 'train'))
writer_val = SummaryWriter(log_dir=os.path.join(log_dir, 'val'))
       
    
## 네트워크 저장하기
def save(ckpt_dir, net, optim, epoch):
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    torch.save({'net': net.state_dict(), 'optim': optim.state_dict()},
               "./%s/model_epoch%d.pth" % (ckpt_dir, epoch))

## 네트워크 불러오기
st_epoch = 0

def load(ckpt_dir, net, optim):
    if not os.path.exists(ckpt_dir):
        epoch = 0
        return net, optim, epoch

    ckpt_lst = os.listdir(ckpt_dir)
    ckpt_lst.sort(key=lambda f: int(''.join(filter(str.isdigit, f))))

    dict_model = torch.load('./%s/%s' % (ckpt_dir, ckpt_lst[-1]))

    net.load_state_dict(dict_model['net'])
    optim.load_state_dict(dict_model['optim'])
    epoch = int(ckpt_lst[-1].split('epoch')[1].split('.pth')[0])

    return net, optim, epoch

# train을 진행시키는 반복문
st_epoch = 0
# 학습 이전에, 저장되어있는 network가 있다면 불러와서 연속적으로 학습할 수 있도록 load 해줌
net, optim, st_epoch = load(ckpt_dir=ckpt_dir, net=net, optim=optim )

for epoch in range(st_epoch + 1, num_epoch + 1) :
    net.train()
    loss_arr = []
    
    for batch, data in enumerate(loader_train, 1):
        # forward pass
        # Netword의 input을 받아 Output을 출력
        label = data['label'].to(device)
        input = data['input'].to(device)
        
        output = net(input)
        # backward
        optim.zero_grad()
        
        loss = fn_loss(output, label)
        loss.backward()
        
        optim.step()
        
        # 손실함수 계산
        loss_arr += [loss.item()]
        print("TRAIN: EPOCH %04d / %04d BATCH %04d / %04d| LOSS %.4f" %
              (epoch,num_epoch, batch, num_batch_train, np.mean(loss_arr)))
        
        # tensorboard에 input & out & label 저장 구문
        label = fn_tonumpy(label)
        input = fn_tonumpy(fn_denorm(input, mean=0.5, std=0.5))
        output = fn_tonumpy(fn_class(output))
        
        writer_train.add_imgae('label', label, num_batch_train * (epoch -1) + batch, dataformats = 'NHWC')
        writer_train.add_imgae('input', label, num_batch_train * (epoch -1) + batch, dataformats = 'NHWC')
        writer_train.add_imgae('output', label, num_batch_train * (epoch -1) + batch, dataformats = 'NHWC')
        
        # Save the loss in tensorboard
    writer_train.add_scalar('loss', np.mean(loss_arr), epoch)

    
    # Network validation
    # backward가 없어, 사전에 막기 위해 torch.no_grad() 실시
    with torch.no_grad() :
        net.eval()               # Network validation 명시를 위해
        loss_arr = []
        
        for batch, data in enumerate(loader_test, 1):
            # forward pass
            label = data['label'].to(device)
            input = data['input'].to(device)    
                
            output = net(input)
                
            # 손실함수 계산
            loss = fn_loss(output, label)
                
            loss_arr += [loss.item()]
            print("VALID: EPOCH %04d / %04d BATCH %04d / %04d| LOSS %.4f" %
                  (epoch,num_epoch, batch, num_batch_val, np.mean(loss_arr)))
        
            # tensorboard에 input & out & label 저장 구문
            label = fn_tonumpy(label)
            input = fn_tonumpy(fn_denorm(input))
            output = fn_tonumpy(fn_class(output))

            writer_val.add_imgae('label', label, num_batch_train * (epoch -1) + batch, dataformats = 'NHWC')
            writer_val.add_imgae('input', label, num_batch_train * (epoch -1) + batch, dataformats = 'NHWC')
            writer_val.add_imgae('output', label, num_batch_train * (epoch -1) + batch, dataformats = 'NHWC')
          
        # Save the loss in tensorboard
        writer_val.add_scalar('loss', np.mean(loss_arr), epoch)   
        
        # epoch이 다섯번씩 진행될 때 마다 네트워크를 저장하는 부분
        if epoch % 5 == 0:
            save(ckpt_dir=ckpt_dir, net =net, optim = optim, epoch= epoch)

writer_train.close()           
writer_val.close()

TypeError: __init__() missing 1 required positional argument: 'data'