In [1]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import torch
import torchvision
import torch.nn as nn



In [138]:
class UNet(nn.Module):
    """U-Net style encoder/decoder built from a list of channel widths.

    The encoder applies, per level, two 3x3 conv + ReLU layers followed by a
    2x2 max-pool (the deepest level is not pooled). The decoder mirrors it:
    a stride-2 transposed conv, concatenation with the encoder feature map of
    the same level, then two 3x3 conv + ReLU layers. A final 1x1 conv maps
    back to ``img_channels``.

    All layers are created once in ``__init__`` and stored in
    ``nn.ModuleList`` containers so that they are registered as parameters
    (visible to optimizers, ``state_dict``, ``.to(device)``) and reused on
    every forward pass.

    Args:
        filters: Sequence of channel widths, shallowest level first
            (e.g. ``[32, 64, 128]``). Must contain at least one entry.
            The sequence is copied, never modified.
        img_channels: Number of channels of the input (and output) image.
    """

    def __init__(self, filters, img_channels=3):
        super(UNet, self).__init__()
        # Work on a private copy so the caller's list is left untouched.
        filters = list(filters)
        self.up_channels = filters[::-1]
        # Prepend the image channels so consecutive pairs give (in, out).
        self.down_channels = [img_channels] + filters
        self.relu = nn.ReLU()
        self.output = nn.Conv2d(self.down_channels[1], img_channels, kernel_size=1)

        # Encoder: one [conv, conv, pool] group per level.
        self.down_blocks = nn.ModuleList([
            nn.ModuleList(self.block_down(in_chan, out_chan))
            for in_chan, out_chan in zip(self.down_channels[:-1], self.down_channels[1:])
        ])
        # Decoder: one [transposed conv, conv, conv] group per level. The skip
        # tensor concatenated at each level has ``out_chan`` channels.
        self.up_blocks = nn.ModuleList([
            nn.ModuleList(self.block_up(in_chan, out_chan, skip_chan=out_chan))
            for in_chan, out_chan in zip(self.up_channels[:-1], self.up_channels[1:])
        ])

    def block_down(self, in_chan, out_chan, kernel_conv=3, pad_conv=1, kernel_pool=2):
        """Create the layers of one encoder level.

        Returns:
            list: ``[Conv2d(in_chan -> out_chan), Conv2d(out_chan -> out_chan),
            MaxPool2d(kernel_pool)]``.
        """
        return [
            nn.Conv2d(in_chan, out_chan, kernel_size=kernel_conv, padding=pad_conv),
            nn.Conv2d(out_chan, out_chan, kernel_size=kernel_conv, padding=pad_conv),
            nn.MaxPool2d(kernel_size=kernel_pool),
        ]

    def block_up(self, in_chan, out_chan, kernel_conv=3, pad_conv=1, kernel_trans=2,
                 strd_trans=2, skip_chan=None):
        """Create the layers of one decoder level.

        Args:
            in_chan: Channels entering the transposed convolution.
            out_chan: Channels produced by the transposed convolution and by
                both following convolutions.
            skip_chan: Channels of the skip tensor that is concatenated to the
                upsampled tensor before the first convolution. When ``None``
                (default) the first convolution takes ``in_chan`` inputs,
                which matches the concatenation only if the skip tensor has
                ``in_chan - out_chan`` channels.

        Returns:
            list: ``[ConvTranspose2d, Conv2d, Conv2d]``.
        """
        conv_in = in_chan if skip_chan is None else out_chan + skip_chan
        return [
            nn.ConvTranspose2d(in_chan, out_chan, kernel_size=kernel_trans, stride=strd_trans),
            nn.Conv2d(conv_in, out_chan, kernel_size=kernel_conv, padding=pad_conv),
            nn.Conv2d(out_chan, out_chan, kernel_size=kernel_conv, padding=pad_conv),
        ]

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape ``(N, img_channels, H, W)``. With the default
                pooling/upsampling factors, ``H`` and ``W`` must be divisible
                by ``2 ** (len(filters) - 1)``; otherwise the upsampled maps
                do not line up with the skip connections.

        Returns:
            Tensor of shape ``(N, img_channels, H, W)``.
        """
        # Down block
        skips = []
        deepest = len(self.down_blocks) - 1
        for level, (conv_a, conv_b, pool) in enumerate(self.down_blocks):
            x = self.relu(conv_a(x))
            x = self.relu(conv_b(x))
            skips.append(x)
            if level != deepest:  # The bottleneck is not pooled
                x = pool(x)

        # Up block
        skips.reverse()  # skips[0] is now the bottleneck itself
        for level, (upsample, conv_a, conv_b) in enumerate(self.up_blocks):
            upsampled = upsample(x)
            x = torch.cat((upsampled, skips[level + 1]), dim=1)
            x = self.relu(conv_a(x))
            x = self.relu(conv_b(x))

        # Output layer
        return self.output(x)  # Same size as input image


In [140]:
# Smoke test: push one random RGB image through a freshly built network
# and check that the prediction has the same shape as the input.
filters = [32, 64, 128]  # channel widths, shallowest level first
net = UNet(filters)

# Uniform noise in [0.4, 0.6), laid out as (batch, channels, height, width).
img = 0.4 + 0.2 * torch.rand(1, 3, 128, 128)
pred = net(img)
print(pred.shape)

torch.Size([1, 3, 128, 128])
