In [1]:
from torch import nn
from typing import List
from torchvision import io
from torchvision import transforms
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [29]:
class ConvLayer(nn.Module):
    def __init__(self, inchannels, upsamp_multiplier):
        super(ConvLayer, self).__init__()
        outchannels = int(inchannels*upsamp_multiplier)
        self._conv1 = nn.Conv2d(
            inchannels, 
            outchannels,
            kernel_size=3)
        self._conv2 = nn.Conv2d(
            outchannels,
            outchannels,
            kernel_size=3
            )
        self.net = nn.Sequential(
            self._conv1, 
            nn.ReLU(),
            self._conv2,
            nn.ReLU())
    
    def forward(self, x):
        return self.net(x)

class UNet(nn.Module):
    def __init__(self, inchannels: int=1):
        super(UNet, self).__init__()
        self.conv1 = ConvLayer(inchannels, 64)
        self.pool1 = nn.MaxPool2d((2,2), 2)
        self.conv2 = ConvLayer(inchannels*64, 2)
        self.pool2 = nn.MaxPool2d((2,2), 2)
        self.conv3 = ConvLayer(inchannels*64*2, 2)
        self.pool3 = nn.MaxPool2d((2,2), 2)
        self.conv4 = ConvLayer(inchannels*64*4, 2)
        self.pool4 = nn.MaxPool2d((2,2), 2)
        self.conv5 = ConvLayer(inchannels*64*8, 2)        

        self.deconv1 = nn.ConvTranspose2d(inchannels*16*64, inchannels*8*64, (2,2),2)
        self.conv6 = ConvLayer(inchannels*16*64, 0.5)
        self.deconv2 = nn.ConvTranspose2d(inchannels*8*64, inchannels*4*64, (2,2),2)
        self.conv7 = ConvLayer(inchannels*8*64, 0.5)
        self.deconv3 = nn.ConvTranspose2d(inchannels*4*64, inchannels*2*64, (2,2),2)
        self.conv8 = ConvLayer(inchannels*4*64, 0.5)
        self.deconv4 = nn.ConvTranspose2d(inchannels*2*64, inchannels*64, (2,2),2)
        self.conv9 = ConvLayer(inchannels*2*64, 0.5)
        self.conv10 = nn.Conv2d(inchannels*64, 2, (1,1), 1)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(self.pool1(x1))
        x3 = self.conv3(self.pool2(x2))
        x4 = self.conv4(self.pool3(x3))
        x5 = self.conv5(self.pool4(x4))
        print(x5.shape)
        
        x_up_1 = self.deconv1(x5)
        x_up_1 = torch.concat((
            transforms.CenterCrop(x_up_1.shape[2:3])(x4),
            x_up_1), axis=1)
        print(x_up_1.shape)

        x_up_2 = self.deconv2(self.conv6(x_up_1))
        x_up_2 = torch.concat(
            (transforms.CenterCrop(x_up_2.shape[2:3])(x3),
            x_up_2), dim=1)
        print(x_up_2.shape)

        x_up_3 = self.deconv3(self.conv7(x_up_2))
        x_up_3 = torch.concat(
            (transforms.CenterCrop(x_up_3.shape[2:3])(x2),
            x_up_3), dim=1)
        print(x_up_3.shape)

        x_up_4 = self.deconv4(self.conv8(x_up_3))
        x_up_4 = torch.concat(
            (transforms.CenterCrop(x_up_4.shape[2:3])(x1),
            x_up_4), dim=1)
        print(x_up_4.shape)
        xout = self.conv10(self.conv9(x_up_4))
        return xout

In [26]:
# Test the output shape of UNet
test_im = io.read_image('../xray_samp/sample/images/00000013_005.png')
test_im = test_im.float().unsqueeze(0)
test_im.shape

torch.Size([1, 1, 1024, 1024])

In [30]:
net = UNet(1)
out = net(test_im)

torch.Size([1, 1024, 56, 56])
torch.Size([1, 1024, 112, 112])


RuntimeError: Given transposed=1, weight of size [512, 256, 2, 2], expected input[1, 1024, 112, 112] to have 512 channels, but got 1024 channels instead

###### Initialize network weights as follows : √2/N, where N denotes the number of incoming nodes of one neuron [5]. E.g. for a 3x3 convolution and 64 feature channels in the previous layerN= 9·64 = 576


In [5]:
torch.cuda.is_available()

False