In [9]:
from torch import nn
from typing import List, Tuple
from torchvision import io
from torchvision import transforms
import torch

In [6]:
class ConvLayer(nn.Module):
    def __init__(
        self,
        inchannels,
        outchannels,
        kernel_size: int=3,
        prev_kernel_size: int=3,
        prev_n_channels: int=64):
        super(ConvLayer, self).__init__()
        self._kernel_size = kernel_size
        self._prev_kernel_size = prev_kernel_size
        self._prev_n_channels = prev_n_channels
        self.conv1 = nn.Conv2d(
            inchannels, 
            outchannels,
            kernel_size=self._kernel_size
            )
        self.conv2 = nn.Conv2d(
            outchannels,
            outchannels,
            kernel_size=self._kernel_size
            )
        self.net = nn.Sequential(
            self.conv1, 
            nn.ReLU(),
            self.conv2,
            nn.ReLU())
    
    def initialize_params(self):
        for layer in self.named_layers():
            print(layer, torch.sqrt(2./(self._prev_kernel_size**2) * self._prev_n_channels))

    def forward(self, x):
        return self.net(x)

class UNet(nn.Module):
    def __init__(self, inchannels: int=1, channel_list: Tuple=(1,64,128,256,512,1024,)):
        super(UNet, self).__init__()
        self.conv1 = ConvLayer(channel_list[0], channel_list[1],3,1,1)
        self.pool1 = nn.MaxPool2d((2,2), 2)
        self.conv2 = ConvLayer(
            channel_list[1], 
            channel_list[2], 
            prev_kernel_size=3, 
            prev_n_channels=channel_list[1])
        self.pool2 = nn.MaxPool2d((2,2), 2)
        self.conv3 = ConvLayer(
            channel_list[2], 
            channel_list[3],
            prev_kernel_size=3,
            prev_n_channels=channel_list[2]
            )
        self.pool3 = nn.MaxPool2d((2,2), 2)
        self.conv4 = ConvLayer(
            channel_list[3], 
            channel_list[4],
            prev_kernel_size=3,
            prev_n_channels=channel_list[3]
            )
        self.pool4 = nn.MaxPool2d((2,2), 2)
        self.conv5 = ConvLayer(
            channel_list[4], 
            channel_list[5],
            prev_kernel_size=3,
            prev_n_channels=channel_list[4]
            )        

        self.upsample1 = nn.ConvTranspose2d(channel_list[5], channel_list[4], 2, 2)
        self.conv6 = ConvLayer(
            channel_list[5],
            channel_list[4],
            prev_kernel_size=3,
            prev_n_channels=channel_list[4]
            )
        self.upsample2 = nn.ConvTranspose2d(channel_list[4], channel_list[3], 2,2)
        
        self.conv7 = ConvLayer(
            channel_list[4],
            channel_list[3],
            prev_kernel_size=3,
            prev_n_channels=channel_list[3]
            )
        
        self.upsample3 = nn.ConvTranspose2d(channel_list[3], channel_list[2], 2,2)
        self.conv8 = ConvLayer(
            channel_list[3],
            channel_list[2],
            prev_kernel_size=3,
            prev_n_channels=channel_list[2]
            )
        self.upsample4 = nn.ConvTranspose2d(channel_list[2], channel_list[1], 2,2)
        self.conv9 = ConvLayer(
            channel_list[2],
            channel_list[1],
            prev_kernel_size=3,
            prev_n_channels=channel_list[1]
            )
        self.conv10 = nn.Conv2d(channel_list[1], 2, 1, 1)
        self.activation = nn.Softmax(dim=1)

    def initialize_params(self):
        for name, layer in self.named_parameters():
            if "conv" in name:
                layer.initialize_params()
            raise NotImplementedError

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(self.pool1(x1))
        x3 = self.conv3(self.pool2(x2))
        x4 = self.conv4(self.pool3(x3))
        x5 = self.conv5(self.pool4(x4))
        
        x_up_1 = self.upsample1(x5)
        x_up_1 = torch.concat((
            transforms.CenterCrop(x_up_1.shape[2:3])(x4),
            x_up_1), dim=1)

        x_up_2 = self.upsample2(self.conv6(x_up_1))
        x_up_2 = torch.concat(
            (transforms.CenterCrop(x_up_2.shape[2:3])(x3),
            x_up_2), dim=1)

        x_up_3 = self.upsample3(self.conv7(x_up_2))
        x_up_3 = torch.concat(
            (transforms.CenterCrop(x_up_3.shape[2:3])(x2),
            x_up_3), dim=1)

        x_up_4 = self.upsample4(self.conv8(x_up_3))
        x_up_4 = torch.concat(
            (transforms.CenterCrop(x_up_4.shape[2:3])(x1),
            x_up_4), dim=1)

        xout = self.conv10(self.conv9(x_up_4))
        return self.activation(xout)


In [7]:
# Test the output shape of UNet
test_im = io.read_image('../xray_samp/sample/images/00000013_005.png')
test_im = test_im.float().unsqueeze(0)
test_im.shape

torch.Size([1, 1, 1024, 1024])

In [8]:
net = UNet(1)
net.initialize_params()
out = net(test_im)
out.shape


('conv1.conv1.weight', Parameter containing:
tensor([[[[-0.2236,  0.1261, -0.0761],
          [ 0.0939,  0.2668,  0.1707],
          [ 0.0949,  0.0703,  0.2829]]],


        [[[ 0.1893,  0.2719,  0.1629],
          [ 0.2106,  0.2429, -0.1048],
          [ 0.2301, -0.0748, -0.0665]]],


        [[[-0.2061,  0.3115, -0.1006],
          [ 0.1470, -0.0535,  0.0560],
          [-0.0963, -0.2960,  0.1497]]],


        [[[ 0.2358, -0.0700,  0.1266],
          [-0.0021, -0.2337,  0.3283],
          [-0.0972, -0.2653, -0.0564]]],


        [[[-0.0634,  0.0131,  0.3323],
          [ 0.3140, -0.0408,  0.3287],
          [-0.0625,  0.0909,  0.2788]]],


        [[[-0.0443,  0.3160, -0.2832],
          [ 0.1095, -0.1287,  0.2741],
          [-0.2002,  0.1846, -0.1373]]],


        [[[-0.2954, -0.2281,  0.1123],
          [ 0.0411,  0.0865, -0.2304],
          [ 0.2855, -0.1077, -0.1207]]],


        [[[ 0.3038,  0.2378, -0.1094],
          [ 0.1660, -0.1099, -0.1408],
          [ 0.0223, -0.2931, -

NotImplementedError: 

###### Initialize network weights as follows :  "For a network with our architecture (alternating convolution and ReLU layers) this can be achievedby  drawing  the  initial  weights  from  a  Gaussian  distribution  with  a  standard deviation of √2/N, where N denotes the number of incoming nodes of one neuron [5]. E.g. for a 3x3 convolution and 64 feature channels in the previous layer N= 9·64 = 576"


In [53]:
'''
only convolutional and upsampleolutional layers have learnable parameters.
So, search for each convolutional layer, to initialize, 
simply ask what the previous convolutional layer's kernel size was,
and find the number of channels of that previous layer (dim[1])
'''

In [7]:
for name, param in net.named_parameters():
    print(name)

conv1._conv1.weight
conv1._conv1.bias
conv1._conv2.weight
conv1._conv2.bias
conv2._conv1.weight
conv2._conv1.bias
conv2._conv2.weight
conv2._conv2.bias
conv3._conv1.weight
conv3._conv1.bias
conv3._conv2.weight
conv3._conv2.bias
conv4._conv1.weight
conv4._conv1.bias
conv4._conv2.weight
conv4._conv2.bias
conv5._conv1.weight
conv5._conv1.bias
conv5._conv2.weight
conv5._conv2.bias
deconv1.weight
deconv1.bias
conv6._conv1.weight
conv6._conv1.bias
conv6._conv2.weight
conv6._conv2.bias
deconv2.weight
deconv2.bias
conv7._conv1.weight
conv7._conv1.bias
conv7._conv2.weight
conv7._conv2.bias
deconv3.weight
deconv3.bias
conv8._conv1.weight
conv8._conv1.bias
conv8._conv2.weight
conv8._conv2.bias
deconv4.weight
deconv4.bias
conv9._conv1.weight
conv9._conv1.bias
conv9._conv2.weight
conv9._conv2.bias
conv10.weight
conv10.bias
