# Unet Blocks

In [33]:
import torch
from torch import nn

class conv_block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv2d_1 = nn.Conv2d(in_channels=in_channels,
                                 out_channels=out_channels,
                                 kernel_size=(3,3),
                                 stride=1,
                                 padding=1)
        self.relu_1 = nn.ReLU()
        self.conv2d_2 = nn.Conv2d(in_channels=out_channels,
                                 out_channels=out_channels,
                                 kernel_size=(3,3),
                                 stride=1,
                                 padding=1)
        self.relu_2 = nn.ReLU()
        
    def forward(self, x):
        return self.relu_2(self.conv2d_2(self.relu_1(self.conv2d_1(x))))   


class encoder_block(nn.Module):
    def __init__(self, in_features=[3,64,128,256,512]):
        super().__init__()
        self.encBlock = nn.ModuleList([conv_block(in_features[x],
                                                  in_features[x+1]) 
                                       for x in range(len(in_features)-1)])
        self.maxpool = nn.MaxPool2d(kernel_size=(2, 2), 
                                    stride=2, 
                                    padding=0)

    def forward(self, x):
        encBlock_out = []
        for block in self.encBlock:
            x = block(x)
            encBlock_out.append(x)
            x = self.maxpool(x)
        return encBlock_out
        

class decoder_block(nn.Module):
    def __init__(self, in_features=[1024,512,256,128,64]):
        super().__init__()
        self.channels = in_features
        self.decBlock = nn.ModuleList([conv_block(in_features[x],
                                                  in_features[x+1]) 
                                       for x in range(len(in_features)-1)])
        self.upscaling = nn.ModuleList(nn.ConvTranspose2d(in_features[x],
                                                  in_features[x+1],
                                            kernel_size=(3,3),
                                            stride=2,
                                            padding=1,
                                            output_padding=1) for x in range(len(in_features)-1))

    def forward(self, x, encFeatures):
        for i in range(len(self.channels)-1):
            x = self.upscaling[i](x)
            x = torch.cat(x,encFeat[i])
            x = self.decBlock(x)                             
        return x
        

class Unet(nn.Module):
    def __init__(self,enc_channels=[3,64,128,256,512],
                    dec_channels=[1024,512,256,128,64],
                    n_classes = 1,
                    out_size=(256,256)):
        super().__init__()
        self.encoder = encoder_block(enc_channels)
        self.decoder = decoder_block(dec_channels)

        self.head = nn.Conv2d(dec_channels[-1], n_classes,
                             kernel_size=(3,3), stride=1, padding=1)
        self.out_size = out_size  
        self.conv2d_intermedian = nn.Conv2d(dec_channels[1], enc_channels[-1]*2,
                                           kernel_size=(3,3),stride=1,
                                           padding=1)

    def forward(self, x):
        encFeatures = self.encoder(x)
        
        intermedial_layer = self.conv2d_intermedian(encFeatures[-1])
        
        decFeatures = self.decoder(intermedial_layer,encFeatures[::-1])
        
        classifier = self.head(decFeatures)
        
        return classifier       

# Testing Blocks

In [17]:
from torchsummary import summary

img = torch.randn(1,3,256,256).to('cuda')
conv_block_m = conv_block(3,10).to('cuda')
conv_block_m(img).shape

torch.Size([1, 10, 256, 256])

In [18]:
summary(conv_block_m, (3,100,100))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 10, 100, 100]             280
              ReLU-2         [-1, 10, 100, 100]               0
            Conv2d-3         [-1, 10, 100, 100]             910
              ReLU-4         [-1, 10, 100, 100]               0
Total params: 1,190
Trainable params: 1,190
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.11
Forward/backward pass size (MB): 3.05
Params size (MB): 0.00
Estimated Total Size (MB): 3.17
----------------------------------------------------------------


## Encoder Block

In [19]:
encoder = encoder_block().to('cuda')


In [20]:
encoder.encBlock

ModuleList(
  (0): conv_block(
    (conv2d_1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu_1): ReLU()
    (conv2d_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu_2): ReLU()
  )
  (1): conv_block(
    (conv2d_1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu_1): ReLU()
    (conv2d_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu_2): ReLU()
  )
  (2): conv_block(
    (conv2d_1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu_1): ReLU()
    (conv2d_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu_2): ReLU()
  )
  (3): conv_block(
    (conv2d_1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu_1): ReLU()
    (conv2d_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu_2): ReLU()
  )
)

In [21]:
img_enc = encoder(img)
[x.shape for x in img_enc]

[torch.Size([1, 64, 256, 256]),
 torch.Size([1, 128, 128, 128]),
 torch.Size([1, 256, 64, 64]),
 torch.Size([1, 512, 32, 32])]

# Decoder Block

In [34]:
decoder = decoder_block().to('cuda')
decoder

decoder_block(
  (decBlock): ModuleList(
    (0): conv_block(
      (conv2d_1): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu_1): ReLU()
      (conv2d_2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu_2): ReLU()
    )
    (1): conv_block(
      (conv2d_1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu_1): ReLU()
      (conv2d_2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu_2): ReLU()
    )
    (2): conv_block(
      (conv2d_1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu_1): ReLU()
      (conv2d_2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu_2): ReLU()
    )
    (3): conv_block(
      (conv2d_1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu_1): ReLU()
      (conv2d_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (rel

In [35]:
img_interm = torch.randn(1,1024,32,32).to('cuda')
decoder(img_interm, img_enc)


NameError: name 'encFeat' is not defined

# Unet

In [31]:
model = Unet().to('cuda')

In [32]:
model(img)

RuntimeError: Given transposed=1, weight of size [512, 256, 3, 3], expected input[1, 1024, 32, 32] to have 512 channels, but got 1024 channels instead