적용 논문 : U-Net with spatial pyramid pooling for drusen segmentation in optical coherence tomography
![image.png](attachment:image.png)

In [1]:
#-*- encoding: utf-8 -*-
from sklearn.model_selection import KFold, GroupKFold, StratifiedKFold
from torch.optim.optimizer import Optimizer
from torch.optim import Adam,lr_scheduler,AdamW
from matplotlib import pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch 

import pandas as pd
import numpy as np
import sys, gc, os, glob

# Choose which GPU index CUDA exposes to this process.
# NOTE(review): the original note referred to "GPU 5" on the department DGX
# server, but the value set here is '0' — confirm which device is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

In [2]:
# The train/ and test/ folders live next to Training.ipynb and hold the image
# arrays as .npy files; gather their paths and keep them in sorted order.
train_files = glob.glob('./train/*.npy')
train_files.sort()
test_files = glob.glob('./test/*.npy')
test_files.sort()

In [None]:
# Wrap each list of file paths in a DataFrame with a single 'id' column;
# the frames are what the data-loading code consumes.
train_df, test_df = (
    pd.DataFrame(paths, columns=['id']) for paths in (train_files, test_files)
)

In [None]:
import matplotlib.pylab as plt
# Visual sanity check of how one training sample is laid out.
# Exploration only — nothing below is needed for training.

# `plt.get_cmap` is used instead of `plt.cm.get_cmap`, which newer matplotlib
# releases deprecated and subsequently removed.
color_map = plt.get_cmap('RdBu').reversed()

# View the sample as a 120 x 120 image with 5 channels. reshape() already
# reads the elements in C order, so a preceding flatten() copy is not needed.
image_sample = np.load(train_files[56]).reshape(120, 120, 5)

plt.style.use('fivethirtyeight')
plt.figure(figsize=(20, 20))

# One panel per channel, left to right; the fifth panel is the last channel.
for i in range(5):
    plt.subplot(1, 5, i + 1)
    plt.imshow(image_sample[:, :, i], cmap=color_map)

plt.show()

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# DataLoader 생성
# (h x w x 4) 
# (t-30, t-20, t-10, t)
class CloudTrainDataset(Dataset):
    """Training dataset yielding ``(input, target)`` tensor pairs.

    ``data`` is indexed with ``data.loc[idx, 'id']`` to obtain the path of
    one ``.npy`` sample, so it is expected to be a DataFrame with an ``'id'``
    column and a 0..len-1 index.

    Per the original notes, each sample is an (h x w x channels) array whose
    first four channels are the frames at t-30, t-20, t-10 and t, and whose
    last channel is the frame 10 minutes later that the model must predict.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        to_tensor = transforms.ToTensor()
        sample = np.load(self.data.loc[idx, 'id'])

        # Input: the first four channels. Target: the last channel only.
        frames = sample[:, :, 0:4]
        label = sample[:, :, -1]

        return to_tensor(frames), to_tensor(label)

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class CloudTestDataset(Dataset):
    """Test dataset yielding one input tensor per sample (no target).

    ``data`` is indexed with ``data.loc[idx, 'id']`` to obtain the path of
    one ``.npy`` sample, so it is expected to be a DataFrame with an ``'id'``
    column and a 0..len-1 index.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        to_tensor = transforms.ToTensor()

        # Unlike the training set, every channel of a test sample is fed to
        # the model as input.
        sample = np.load(self.data.loc[idx, 'id'])
        return to_tensor(sample)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Pyramid Pooling Module 생성 
class PPM(nn.Module):
    """Pyramid pooling module.

    For every entry of ``bins`` one branch is built: adaptive average pooling
    to that output size, a bias-free 1x1 convolution from ``in_dim`` to
    ``reduction_dim`` channels, batch norm and ReLU.

    ``forward`` resizes each branch output back to the input's spatial size
    and concatenates the input with all branches along the channel axis, so
    the result has ``in_dim + len(bins) * reduction_dim`` channels.
    """

    def __init__(self, in_dim, reduction_dim, bins):
        super().__init__()

        def make_branch(pool_size):
            # Pool to a pool_size grid, then squeeze the channels with a 1x1 conv.
            return nn.Sequential(
                nn.AdaptiveAvgPool2d(pool_size),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(reduction_dim),
                nn.ReLU(inplace=True),
            )

        # Branches are created in the order given by `bins`.
        self.features = nn.ModuleList([make_branch(size) for size in bins])

    def forward(self, x):
        target_hw = x.size()[2:]
        resized = [
            F.interpolate(branch(x), target_hw, mode='bilinear', align_corners=True)
            for branch in self.features
        ]
        # The untouched input comes first, followed by the branches in order.
        return torch.cat([x, *resized], 1)

In [None]:
# Unet 네트워크 생성 
class UNet(nn.Module):
    """U-Net whose skip connections are augmented with pyramid pooling.

    Takes a 4-channel input of shape (N, 4, H, W) and returns a 1-channel
    map of shape (N, 1, H, W) passed through a sigmoid.

    The input is zero-padded by ``PAD`` pixels on every side before the
    encoder and the output is cropped back to the input's own height and
    width. The encoder halves the resolution four times and the decoder
    doubles it four times before concatenating with the skip tensors, so
    ``H + 2 * PAD`` and ``W + 2 * PAD`` must be multiples of 16
    (e.g. H = W = 120 -> 128).
    """

    # Zero padding added on each side of the input.
    PAD = 4

    def __init__(self):
        super(UNet, self).__init__()

        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
            # Conv -> BatchNorm -> ReLU; the default 3x3 / padding=1 keeps the
            # spatial size unchanged.
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
                nn.BatchNorm2d(num_features=out_channels),
                nn.ReLU(),
            )

        # ---- Contracting path ----
        # Each stage: two CBR blocks, a 2x2 max pool for the next stage, and
        # a PPM applied to the pre-pool features for the skip connection.
        self.pad = nn.ZeroPad2d(self.PAD)
        self.enc1_1 = CBR2d(in_channels=4, out_channels=64)
        self.enc1_2 = CBR2d(in_channels=64, out_channels=64)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.ppm1 = PPM(64, 64 // 4, (1, 2, 3, 6))

        self.enc2_1 = CBR2d(in_channels=64, out_channels=128)
        self.enc2_2 = CBR2d(in_channels=128, out_channels=128)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.ppm2 = PPM(128, 128 // 4, (1, 2, 3, 6))

        self.enc3_1 = CBR2d(in_channels=128, out_channels=256)
        self.enc3_2 = CBR2d(in_channels=256, out_channels=256)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.ppm3 = PPM(256, 256 // 4, (1, 2, 3, 6))

        self.enc4_1 = CBR2d(in_channels=256, out_channels=256)
        self.enc4_2 = CBR2d(in_channels=256, out_channels=256)
        self.pool4 = nn.MaxPool2d(kernel_size=2)
        self.ppm4 = PPM(256, 256 // 4, (1, 2, 3, 6))

        self.enc5_1 = CBR2d(in_channels=256, out_channels=256)

        # ---- Expansive path ----
        # Decoder layers are numbered in reverse so that decK_* mirrors encK_*.
        self.dec5_1 = CBR2d(in_channels=256, out_channels=256)
        self.unpool4 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        # The first decoder conv of every stage consumes the concatenation of
        #   unpoolK (C) + encK_2 (C) + ppmK (C passthrough + 4 branches * C/4)
        # which is 4 * C channels (here 4 * 256 = 1024).
        self.dec4_2 = CBR2d(in_channels=256 * 4, out_channels=256)
        self.dec4_1 = CBR2d(in_channels=256, out_channels=256)

        self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec3_2 = CBR2d(in_channels=256 * 4, out_channels=256)
        self.dec3_1 = CBR2d(in_channels=256, out_channels=128)

        self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec2_2 = CBR2d(in_channels=128 * 4, out_channels=128)
        self.dec2_1 = CBR2d(in_channels=128, out_channels=64)

        self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec1_2 = CBR2d(in_channels=64 * 4, out_channels=64)
        self.dec1_1 = CBR2d(in_channels=64, out_channels=64)
        # 1x1 conv producing the single-channel output map.
        self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)

        self._initialize_weights()

    def forward(self, x):
        """Run the network on ``x`` of shape (N, 4, H, W); returns (N, 1, H, W)."""
        # Remember the unpadded size so the crop at the end follows the input
        # instead of assuming 120 x 120.
        in_h, in_w = x.shape[-2:]

        x = self.pad(x)
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)
        ppm1 = self.ppm1(enc1_2)

        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)
        ppm2 = self.ppm2(enc2_2)

        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)
        ppm3 = self.ppm3(enc3_2)

        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)
        ppm4 = self.ppm4(enc4_2)

        enc5_1 = self.enc5_1(pool4)

        dec5_1 = self.dec5_1(enc5_1)
        unpool4 = self.unpool4(dec5_1)

        # dim=1 is the channel axis ([0: batch, 1: channel, 2: height, 3: width]).
        cat4 = torch.cat((unpool4, enc4_2, ppm4), dim=1)
        dec4_2 = self.dec4_2(cat4)
        dec4_1 = self.dec4_1(dec4_2)
        unpool3 = self.unpool3(dec4_1)

        cat3 = torch.cat((unpool3, enc3_2, ppm3), dim=1)
        dec3_2 = self.dec3_2(cat3)
        dec3_1 = self.dec3_1(dec3_2)
        unpool2 = self.unpool2(dec3_1)

        cat2 = torch.cat((unpool2, enc2_2, ppm2), dim=1)
        dec2_2 = self.dec2_2(cat2)
        dec2_1 = self.dec2_1(dec2_2)
        unpool1 = self.unpool1(dec2_1)

        cat1 = torch.cat((unpool1, enc1_2, ppm1), dim=1)
        dec1_2 = self.dec1_2(cat1)
        dec1_1 = self.dec1_1(dec1_2)

        x = self.fc(dec1_1)
        # Strip the zero padding that was added at the top of forward().
        x = x[:, :, self.PAD:self.PAD + in_h, self.PAD:self.PAD + in_w]
        return torch.sigmoid(x)

    def _initialize_weights(self):
        """Xavier-uniform init for every ``nn.Conv2d`` weight; biases set to zero.

        Modules of any other type keep PyTorch's default initialisation.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.xavier_uniform_(m.weight)

                # xavier_uniform_ cannot be applied to the 1-D bias
                # ("Fan in and fan out can not be computed for tensor with
                # fewer than 2 dimensions"), so biases are zeroed instead.
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

In [None]:
# Sanity check: the network's output size should match its input size.
if __name__ == '__main__':
    cloud_model = UNet()
    # Named `dummy_input` rather than `input` so the builtin input() is not
    # shadowed for the rest of the session.
    dummy_input = torch.rand(32, 4, 120, 120)
    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        output = cloud_model(dummy_input)
    print(output.size(), '\n')

In [None]:
# 시드고정 
import random
def seed_everything(seed=0):
    """Seed every random number generator used here and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy and PyTorch
            (CPU and all CUDA devices).
    """
    random.seed(seed)
    # Hash randomisation is fixed at interpreter start-up, so this only
    # affects processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (ignored without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run; turn it
    # off so algorithm selection is repeatable as well.
    torch.backends.cudnn.benchmark = False


seed_everything()

In [None]:
import time 
from tqdm.notebook import tqdm
# 5-fold cross-validation; the best checkpoint of every fold is saved so the
# folds can later be ensembled.
criterion = nn.L1Loss(reduction='sum')
kf = KFold(5, shuffle=True, random_state=0)
# torch.cuda.set_device() returns None, so build the device object directly
# and fall back to the CPU when no GPU is available.
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_name = 'UNet'
nb_epochs = 35
valid_preds = []  # best validation loss of each fold

# torch.save() does not create missing directories.
os.makedirs('./savedmodel', exist_ok=True)

for i, (tr_idx, val_idx) in enumerate(kf.split(range(0, train_df.shape[0]))):
    print("==" * 20, "{} Fold Start".format(i), "=="*20)
    start_time = time.time()
    # +inf (rather than a 0 sentinel) so the first finite loss always wins
    # and a NaN loss can never be recorded as the best.
    best_valid = float('inf')

    x_trainset = train_df.loc[tr_idx].reset_index(drop=True)
    x_validset = train_df.loc[val_idx].reset_index(drop=True)

    # Pinned host memory only helps when batches are copied to a GPU.
    loader_kwargs = dict(batch_size=128, num_workers=8, pin_memory=(dev.type == 'cuda'))
    train_loader = DataLoader(CloudTrainDataset(data=x_trainset), shuffle=True, **loader_kwargs)
    valid_loader = DataLoader(CloudTrainDataset(data=x_validset), shuffle=False, **loader_kwargs)
    model = UNet().to(dev)

    optimizer = Adam(model.parameters(), lr=0.025)
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 10, 2, eta_min=1e-6)

    # The criterion sums absolute errors, so dividing by the number of pixels
    # in the split gives a per-pixel mean. The factor 255 rescales it —
    # presumably because targets are stored in [0, 1]; TODO confirm.
    train_pixels = len(train_loader.dataset) * 120 * 120
    valid_pixels = len(valid_loader.dataset) * 120 * 120

    train_loss_list, valid_loss_list = [], []
    for epoch in range(nb_epochs):
        train_loss, valid_loss = 0, 0

        # ---- training ----
        model.train()
        for x_train, y_train in train_loader:
            x_train = x_train.to(dev).float()
            y_train = y_train.to(dev).float()
            optimizer.zero_grad()
            prediction = model(x_train)
            loss = criterion(prediction, y_train)
            loss.backward()
            train_loss += (255 * loss.item() / train_pixels)
            optimizer.step()

        # ---- validation ----
        model.eval()
        with torch.no_grad():
            for x_valid, y_valid in valid_loader:
                x_valid = x_valid.to(dev).float()
                y_valid = y_valid.to(dev).float()
                prediction = model(x_valid)
                loss = criterion(prediction, y_valid)
                valid_loss += (255 * loss.item() / valid_pixels)

        # Elapsed time is measured from the start of the fold, not the epoch.
        end_time = time.time()
        print('Epoch {:4d}/{} Time {:4f} Train Loss: {:.3f}  Valid Loss: {:.3f} '.format(epoch+1, nb_epochs, end_time - start_time, train_loss, valid_loss))

        train_loss_list.append(train_loss)
        valid_loss_list.append(valid_loss)

        # Keep only the checkpoint with the lowest validation loss so far.
        if valid_loss < best_valid:
            best_valid = valid_loss
            torch.save(model.state_dict(), './savedmodel/{}_FOLD_PYRAMID{}.pth'.format(model_name, i))

        scheduler.step()
    valid_preds.append(best_valid)

In [None]:
# Show the values collected in valid_preds, then their mean.
print(valid_preds, '\n')
mean_best_valid = np.mean(valid_preds)
print(mean_best_valid)