In [1]:
from torch import nn
from typing import List
from torchvision import io
from torchvision import transforms
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [43]:
class ConvLayer(nn.Module):
    def __init__(self, inchannels, upsamp_multiplier):
        super(ConvLayer, self).__init__()
        outchannels = int(inchannels*upsamp_multiplier)
        self._conv1 = nn.Conv2d(
            inchannels, 
            outchannels,
            kernel_size=3)
        self._conv2 = nn.Conv2d(
            outchannels,
            outchannels,
            kernel_size=3
            )
        self.net = nn.Sequential(
            self._conv1, 
            nn.ReLU(),
            self._conv2,
            nn.ReLU())
    
    def forward(self, x):
        return self.net(x)

class UNet(nn.Module):
    def __init__(self, inchannels: int=1):
        super(UNet, self).__init__()
        self.conv1 = ConvLayer(inchannels, 64)
        self.pool1 = nn.MaxPool2d((2,2), 2)
        self.conv2 = ConvLayer(inchannels*64, 2)
        self.pool2 = nn.MaxPool2d((2,2), 2)
        self.conv3 = ConvLayer(inchannels*64*2, 2)
        self.pool3 = nn.MaxPool2d((2,2), 2)
        self.conv4 = ConvLayer(inchannels*64*4, 2)
        self.pool4 = nn.MaxPool2d((2,2), 2)
        self.conv5 = ConvLayer(inchannels*64*8, 2)        

        self.deconv1 = nn.ConvTranspose2d(inchannels*16*64, inchannels*8*64, (2,2),2)
        self.conv6 = ConvLayer(inchannels*16*64, 0.5)
        self.deconv2 = nn.ConvTranspose2d(inchannels*8*64, inchannels*4*64, (2,2),2)
        self.conv7 = ConvLayer(inchannels*8*64, 0.5)
        self.deconv3 = nn.ConvTranspose2d(inchannels*4*64, inchannels*2*64, (2,2),2)
        self.conv8 = ConvLayer(inchannels*4*64, 0.5)
        self.deconv4 = nn.ConvTranspose2d(inchannels*2*64, inchannels*64, (2,2),2)
        self.conv9 = ConvLayer(inchannels*2*64, 0.5)
        self.conv10 = nn.Conv2d(inchannels*64, 2, (1,1), 1)
        self.activation = nn.Softmax(dim=1)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(self.pool1(x1))
        x3 = self.conv3(self.pool2(x2))
        x4 = self.conv4(self.pool3(x3))
        x5 = self.conv5(self.pool4(x4))
        
        x_up_1 = self.deconv1(x5)
        x_up_1 = torch.concat((
            transforms.CenterCrop(x_up_1.shape[2:3])(x4),
            x_up_1), dim=1)

        x_up_2 = self.deconv2(self.conv6(x_up_1))
        x_up_2 = torch.concat(
            (transforms.CenterCrop(x_up_2.shape[2:3])(x3),
            x_up_2), dim=1)

        x_up_3 = self.deconv3(self.conv7(x_up_2))
        x_up_3 = torch.concat(
            (transforms.CenterCrop(x_up_3.shape[2:3])(x2),
            x_up_3), dim=1)

        x_up_4 = self.deconv4(self.conv8(x_up_3))
        x_up_4 = torch.concat(
            (transforms.CenterCrop(x_up_4.shape[2:3])(x1),
            x_up_4), dim=1)

        xout = self.conv10(self.conv9(x_up_4))
        return self.activation(xout)

In [32]:
# Test the output shape of UNet
test_im = io.read_image('../xray_samp/sample/images/00000013_005.png')
test_im = test_im.float().unsqueeze(0)
test_im.shape

torch.Size([1, 1, 1024, 1024])

In [44]:
net = UNet(1)
out = net(test_im)
out.shape


tensor([[[[0.5552, 0.5487, 0.5449,  ..., 0.6751, 0.7867, 0.7300],
          [0.5434, 0.5430, 0.5418,  ..., 0.6584, 0.7687, 0.7245],
          [0.5453, 0.5444, 0.5435,  ..., 0.6457, 0.7685, 0.7334],
          ...,
          [0.5757, 0.5836, 0.5907,  ..., 0.5812, 0.5736, 0.5644],
          [0.5772, 0.5797, 0.5787,  ..., 0.5666, 0.5710, 0.5733],
          [0.5752, 0.5729, 0.5644,  ..., 0.5520, 0.5613, 0.5705]],

         [[0.4448, 0.4513, 0.4551,  ..., 0.3249, 0.2133, 0.2700],
          [0.4566, 0.4570, 0.4582,  ..., 0.3416, 0.2313, 0.2755],
          [0.4547, 0.4556, 0.4565,  ..., 0.3543, 0.2315, 0.2666],
          ...,
          [0.4243, 0.4164, 0.4093,  ..., 0.4188, 0.4264, 0.4356],
          [0.4228, 0.4203, 0.4213,  ..., 0.4334, 0.4290, 0.4267],
          [0.4248, 0.4271, 0.4356,  ..., 0.4480, 0.4387, 0.4295]]]],
       grad_fn=<SoftmaxBackward0>)


###### Initialize network weights as follows :  "For a network with our architecture (alternating convolution and ReLU layers) this can be achievedby  drawing  the  initial  weights  from  a  Gaussian  distribution  with  a  standard deviation of √2/N, where N denotes the number of incoming nodes of one neuron [5]. E.g. for a 3x3 convolution and 64 feature channels in the previous layer N= 9·64 = 576"


In [53]:
'''
only convolutional and deconvolutional layers have learnable parameters.
So, search for each convolutional
layer, to initialize, simply ask what the previous convolutional
layer's kernel size was,
and find the number of channels of that previous layer (dim[1])
'''

In [62]:
net.apply(lambda x: print(x.__class__.__name__))


Conv2d
Conv2d
Conv2d
ReLU
Conv2d
ReLU
Sequential
ConvLayer
MaxPool2d
Conv2d
Conv2d
Conv2d
ReLU
Conv2d
ReLU
Sequential
ConvLayer
MaxPool2d
Conv2d
Conv2d
Conv2d
ReLU
Conv2d
ReLU
Sequential
ConvLayer
MaxPool2d
Conv2d
Conv2d
Conv2d
ReLU
Conv2d
ReLU
Sequential
ConvLayer
MaxPool2d
Conv2d
Conv2d
Conv2d
ReLU
Conv2d
ReLU
Sequential
ConvLayer
ConvTranspose2d
Conv2d
Conv2d
Conv2d
ReLU
Conv2d
ReLU
Sequential
ConvLayer
ConvTranspose2d
Conv2d
Conv2d
Conv2d
ReLU
Conv2d
ReLU
Sequential
ConvLayer
ConvTranspose2d
Conv2d
Conv2d
Conv2d
ReLU
Conv2d
ReLU
Sequential
ConvLayer
ConvTranspose2d
Conv2d
Conv2d
Conv2d
ReLU
Conv2d
ReLU
Sequential
ConvLayer
Conv2d
Softmax
UNet


UNet(
  (conv1): ConvLayer(
    (_conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
    (_conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (net): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (3): ReLU()
    )
  )
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): ConvLayer(
    (_conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (_conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (net): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
      (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (3): ReLU()
    )
  )
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): ConvLayer(
    (_conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    (_conv2): Conv2d(256, 256, kernel_

torch.Size([1, 2, 836, 836])