In [None]:
import torch
import torch.nn as nn

import torchvision
import torchvision.transforms as transforms

from torchvision.transforms.functional import center_crop

class ConvolutionBlock(nn.Module):
    """Two consecutive 3x3 convolutions, each followed by a ReLU.

    Neither convolution uses padding, so every stage trims one pixel from
    each border: an ``(N, in_shape, H, W)`` input produces an
    ``(N, out_shape, H - 4, W - 4)`` output.

    Args:
        in_shape: Number of channels of the incoming feature map.
        out_shape: Number of channels produced by both convolutions.
    """

    def __init__(self, in_shape: int, out_shape: int):
        super().__init__()

        # First stage changes the channel count, second stage keeps it.
        self.conv1 = self._conv_relu(in_shape, out_shape)
        self.conv2 = self._conv_relu(out_shape, out_shape)

    @staticmethod
    def _conv_relu(channels_in: int, channels_out: int) -> nn.Sequential:
        """Build one unpadded 3x3 convolution followed by a ReLU."""
        return nn.Sequential(
            nn.Conv2d(channels_in, channels_out, kernel_size=3),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor):
        """Run ``x`` through both conv + ReLU stages and return the result."""
        return self.conv2(self.conv1(x))

class U_Net(nn.Module):
    """U-Net style encoder/decoder built from unpadded ``ConvolutionBlock``s.

    Layout (channel counts):

    * Encoder: ``in_shape -> 64 -> 128 -> 256 -> 512``, each block followed
      by a 2x2 max-pool, then a ``512 -> 1024`` bottleneck block.
    * Decoder: four steps, each one a 2x2 transposed convolution that halves
      the channels, a concatenation with the centre-cropped encoder output
      of the same depth, and a ``ConvolutionBlock`` that halves the channels
      again (``1024 -> 512 -> 256 -> 128 -> 64``).
    * Head: a 1x1 convolution mapping 64 channels to ``out_shape``.

    Because the convolutions are unpadded, the output is spatially smaller
    than the input, and the input must be large enough to survive four
    poolings.

    Args:
        in_shape: Number of channels of the input image.
        out_shape: Number of output channels (e.g. classes).

    Attributes:
        conv: Encoder blocks; the last entry is the bottleneck.
        conv_pool: Max-pool layers applied after each non-bottleneck block.
        deconv_pool: Transposed convolutions used for upsampling.
        deconv: Decoder blocks; the last entry is the 1x1 output head.
        loss: ``nn.CrossEntropyLoss`` instance.
        optimizer: Adam optimizer over this module's parameters (lr=1e-3).
    """

    def __init__(self, in_shape:int, out_shape:int):
        super().__init__()

        # Encoder: in_shape -> 64, then 64 -> 128 -> 256 -> 512 -> 1024.
        conv_list = [ConvolutionBlock(in_shape, 64)]
        conv_list += [ConvolutionBlock(2 ** i, 2 ** (i + 1)) for i in range(6, 10)]

        # Decoder blocks see the upsampled map concatenated with a skip of
        # equal width, i.e. 2 * 2**(i - 1) == 2**i input channels.
        # The range must count DOWN (step -1); range(10, 6) is empty.
        deconv_list = [ConvolutionBlock(2 ** i, 2 ** (i - 1)) for i in range(10, 6, -1)]
        # Output head: plain 1x1 convolution, no activation, so the raw
        # scores can be fed to CrossEntropyLoss.
        deconv_list.append(nn.Conv2d(64, out_shape, kernel_size=1))

        self.conv = nn.ModuleList(conv_list)
        # One pool per encoder block except the bottleneck.
        self.conv_pool = nn.ModuleList([nn.MaxPool2d(2) for _ in range(len(conv_list) - 1)])

        self.deconv = nn.ModuleList(deconv_list)
        # Each up-convolution doubles the resolution and halves the channels.
        self.deconv_pool = nn.ModuleList([
            nn.ConvTranspose2d(2 ** i, 2 ** (i - 1), kernel_size=2, stride=2)
            for i in range(10, 6, -1)
        ])

        self.loss = nn.CrossEntropyLoss()
        # Created last so that self.parameters() already contains every layer.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    @staticmethod
    def _center_crop(skip: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Return the central ``height`` x ``width`` window of ``skip``.

        Only the last two dimensions are cropped.

        Raises:
            ValueError: If ``skip`` is smaller than the requested window.
        """
        skip_h, skip_w = skip.shape[-2], skip.shape[-1]
        if skip_h < height or skip_w < width:
            raise ValueError(
                f"cannot crop a {skip_h}x{skip_w} feature map to {height}x{width}"
            )
        top = (skip_h - height) // 2
        left = (skip_w - width) // 2
        return skip[..., top:top + height, left:left + width]

    def forward(self, x):
        """Run the encoder, bottleneck, decoder and output head on ``x``.

        Args:
            x: Input tensor with ``in_shape`` channels in dimension 1 and
                spatial dimensions in positions 2 and 3.

        Returns:
            Tensor with ``out_shape`` channels; spatially smaller than ``x``
            because the convolutions are unpadded.
        """
        skips = []

        # Encoder: remember each block's output before pooling it.
        # zip() stops at the shorter conv_pool, leaving the bottleneck out.
        for block, pool in zip(self.conv, self.conv_pool):
            features = block(x)
            skips.append(features)
            x = pool(features)

        x = self.conv[-1](x)

        # Decoder: deepest skip first, hence reversed(). zip() stops before
        # the final entry of self.deconv, which is the output head.
        for upsample, block, skip in zip(self.deconv_pool, self.deconv, reversed(skips)):
            x = upsample(x)
            # The skip is larger than x (unpadded convs), so crop it to match.
            crop = self._center_crop(skip, x.shape[2], x.shape[3])
            x = block(torch.cat([x, crop], dim=1))

        return self.deconv[-1](x)