In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_pil_image
import matplotlib.pylab as plt
%matplotlib inline

import os
import numpy as np
import time

# Run on the default CUDA device when one is visible to PyTorch, else on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

In [None]:
class UNet(nn.Module):
    """U-Net encoder/decoder with concatenation skip connections.

    Layout:
      * ``down_layer_num`` contracting stages: two Conv-BN-ReLU blocks followed
        by 2x2 max-pooling.
      * ``bottom_layer_num`` bottleneck Conv-BN-ReLU blocks.
      * ``up_layer_num`` expanding stages: 2x2 transposed conv, concatenation
        with the (center-cropped) output of the matching contracting stage,
        then two Conv-BN-ReLU blocks.
      * a final 1x1 convolution.

    Channel widths double at every contracting stage (``next_channel``,
    ``2 * next_channel``, ...), double once more in the bottleneck and are
    halved back stage by stage on the way up. With the defaults this gives the
    classic 64 -> 128 -> 256 -> 512 -> 1024 layout.

    Args:
        down_layer_num: number of contracting stages.
        bottom_layer_num: number of bottleneck Conv-BN-ReLU blocks.
        up_layer_num: number of expanding stages. Each one consumes one encoder
            skip (deepest first), so it must not exceed ``down_layer_num``.
        start_channel: channels of the input tensor.
        next_channel: channels produced by the first contracting stage.
        final_channel: channels produced by the final 1x1 convolution.
        conv_ker: kernel size of the Conv-BN-ReLU convolutions.
        conv_st: stride of the Conv-BN-ReLU convolutions.
        conv_pad: padding of the Conv-BN-ReLU convolutions. ``0`` (default)
            gives the original "valid" U-Net, whose output is spatially smaller
            than its input. ``conv_ker // 2`` keeps each conv size-preserving.

    Raises:
        ValueError: if ``up_layer_num > down_layer_num``.
    """

    def __init__(self, down_layer_num=4, bottom_layer_num=2, up_layer_num=4, start_channel=1, next_channel=64, final_channel=2, conv_ker=3, conv_st=1, conv_pad=0):
        super().__init__()
        if up_layer_num > down_layer_num:
            raise ValueError(
                f"up_layer_num ({up_layer_num}) must not exceed down_layer_num "
                f"({down_layer_num}): every expanding stage needs an encoder skip"
            )
        self.down_layer_num = down_layer_num
        self.bottom_layer_num = bottom_layer_num
        self.up_layer_num = up_layer_num

        # nn.ModuleDict (not a plain dict) so every sub-layer is registered as a
        # submodule. Otherwise .parameters(), .to(device), .train()/.eval() and
        # state_dict() would silently skip the whole network.
        self.layers = nn.ModuleDict()

        def conv_block(in_ch, out_ch):
            return self.ConvBatchnorm2dReLU(in_ch, out_ch, conv_ker, conv_st, conv_pad)

        in_ch, out_ch = start_channel, next_channel
        skip_channels = []  # output width of each contracting stage, shallow -> deep
        for i in range(1, down_layer_num + 1):
            self.layers[f"down_conv_{i}"] = nn.Sequential(conv_block(in_ch, out_ch),
                                                          conv_block(out_ch, out_ch))
            self.layers[f"down_maxpool_{i}"] = nn.MaxPool2d(kernel_size=2, stride=2)
            skip_channels.append(out_ch)
            in_ch, out_ch = out_ch, out_ch * 2  # widths double per stage

        for i in range(1, bottom_layer_num + 1):
            # The first bottleneck conv widens to out_ch; later ones keep that width.
            self.layers[f"bottom_conv_{i}"] = conv_block(in_ch, out_ch)
            in_ch = out_ch

        for i in range(1, up_layer_num + 1):
            # Stage i consumes the skip of contracting stage (down_layer_num + 1 - i).
            skip_ch = skip_channels[down_layer_num - i]
            self.layers[f"up_upconv_{i}"] = nn.ConvTranspose2d(in_ch, skip_ch, kernel_size=2, stride=2)
            # Input = cropped skip (skip_ch) concatenated with upsampled map (skip_ch).
            self.layers[f"up_conv_{i}"] = nn.Sequential(conv_block(2 * skip_ch, skip_ch),
                                                        conv_block(skip_ch, skip_ch))
            in_ch = skip_ch

        self.layers["final_conv1x1"] = nn.Conv2d(in_ch, final_channel, kernel_size=1, stride=1)

    def ConvBatchnorm2dReLU(self, in_channels, out_channels, kernel_size, stride, padding=0):
        """Build a ``Conv2d -> BatchNorm2d -> ReLU`` sequence.

        Args:
            in_channels: input channels of the convolution.
            out_channels: output channels of the convolution and the batch norm.
            kernel_size: convolution kernel size.
            stride: convolution stride.
            padding: convolution padding. The default 0 is a "valid" convolution.

        Returns:
            An ``nn.Sequential`` with the three layers at indices 0, 1, 2.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    @staticmethod
    def _center_crop(feature, size):
        """Crop the last two dims of ``feature`` to ``size`` around their center.

        Slicing-based replacement for torchvision's CenterCrop. It works on a
        whole (N, C, H, W) batch and keeps the model free of torchvision.

        Raises:
            ValueError: if ``size`` is larger than ``feature`` in either dim.
        """
        target_h, target_w = int(size[0]), int(size[1])
        height, width = feature.shape[-2:]
        if target_h > height or target_w > width:
            raise ValueError(
                f"cannot crop a {height}x{width} feature map to {target_h}x{target_w}"
            )
        top = (height - target_h) // 2
        left = (width - target_w) // 2
        return feature[..., top:top + target_h, left:left + target_w]

    def forward(self, x):
        """Run the U-Net.

        Args:
            x: 4-D batch of shape (N, start_channel, H, W). BatchNorm2d requires
               4-D input.

        Returns:
            Raw output of the final 1x1 conv, shape (N, final_channel, H', W').
            No activation is applied. With ``conv_pad=0``, H' < H and W' < W.
        """
        skips = {}  # contracting-stage outputs for the concatenation skips
        for i in range(1, self.down_layer_num + 1):
            x = self.layers[f"down_conv_{i}"](x)
            skips[i] = x
            x = self.layers[f"down_maxpool_{i}"](x)

        for i in range(1, self.bottom_layer_num + 1):
            x = self.layers[f"bottom_conv_{i}"](x)

        for i in range(1, self.up_layer_num + 1):
            x = self.layers[f"up_upconv_{i}"](x)
            # Deepest skip first; crop it because "valid" convs shrink the maps.
            skip = self._center_crop(skips[self.down_layer_num + 1 - i], x.shape[-2:])
            x = torch.cat([skip, x], dim=1)
            x = self.layers[f"up_conv_{i}"](x)

        return self.layers["final_conv1x1"](x)
        
        