In [1]:
import ast
import contextlib
import io
import re
import sys

import torch
import torch.nn as nn
import numpy as np
from torchsummary import summary
# from UNet import UNet3D

In [2]:
class Swish(torch.nn.Module):
    """Swish activation: returns ``x * sigmoid(x)`` elementwise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

In [9]:
class ConvBlock3D(nn.Module):
    """Convolution block whose layers are spelled out by the code string `seq`.

    Each character of `seq` becomes one layer, in order:

    * ``'U'`` - ``ConvTranspose3d`` (kernel 2, stride 2, dilation 3).
    * ``'D'`` - ``Conv3d`` (kernel 2, stride 2, padding 3).
    * ``'C'`` - ``Conv3d`` with `kernel_size` and no padding.
    * ``'B'`` - ``BatchNorm3d`` over the channel count produced by the most
      recent convolution (or `in_channels` if no convolution precedes it).
    * ``'A'`` - ``Swish`` activation.

    The first convolution maps ``in_channels -> mid_channels``, the last maps
    ``mid_channels -> out_channels``, and any in between keep `mid_channels`.
    A block with a single convolution maps ``in_channels -> out_channels``.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count of the output; defaults to `in_channels`.
        mid_channels: channel count between convolutions; defaults to
            ``max(in_channels, out_channels)``.
        kernel_size: kernel size of the ``'C'`` convolutions.
        seq: layer-code string, e.g. ``'CBA'``.

    Raises:
        NotImplementedError: if `seq` contains an unknown layer code.
    """
    def __init__(self, in_channels, out_channels=None, mid_channels=None,
            kernel_size=3, seq='CBA'):
        super().__init__()

        if out_channels is None:
            out_channels = in_channels
        if mid_channels is None:
            mid_channels = max(in_channels, out_channels)

        self.in_channels = in_channels
        self.out_channels = out_channels
        # Always set: an explicitly passed mid_channels used to be dropped,
        # leaving the attribute undefined for _setup_conv.
        self.mid_channels = mid_channels
        self.kernel_size = kernel_size

        # Channel count the next 'B' layer normalizes; updated after each conv.
        self.bn_channels = in_channels
        # Convolutions built so far / in total, used by _setup_conv to detect
        # the first and last convolution of the block.
        self.idx_conv = 0
        self.num_conv = sum(seq.count(code) for code in 'UDC')

        layers = [self._get_layer(code) for code in seq]

        self.convs = nn.Sequential(*layers)

    def _get_layer(self, l):
        """Build the layer for the single layer code `l` (see class docstring)."""
        if l == 'U':
            in_channels, out_channels = self._setup_conv()
            # Output size per spatial dim is (n - 1) * 2 + 3 * (2 - 1) + 1 = 2n + 2.
            return nn.ConvTranspose3d(in_channels, out_channels, 2, stride=2, dilation=3)
        elif l == 'D':
            in_channels, out_channels = self._setup_conv()
            # Output size per spatial dim is (n + 2 * 3 - 2) // 2 + 1 = (n + 4) // 2 + 1.
            return nn.Conv3d(in_channels, out_channels, 2, stride=2, padding=3)
        elif l == 'C':
            in_channels, out_channels = self._setup_conv()
            # Unpadded: each spatial dim shrinks by kernel_size - 1.
            return nn.Conv3d(in_channels, out_channels, self.kernel_size)
        elif l == 'B':
            return nn.BatchNorm3d(self.bn_channels)
        elif l == 'A':
            return Swish()
        else:
            raise NotImplementedError('layer type {} not supported'.format(l))

    def _setup_conv(self):
        """Return ``(in_channels, out_channels)`` for the next convolution.

        Advances the convolution counter. The first convolution reads
        `in_channels`, the last writes `out_channels`, and all others use
        `mid_channels`. The chosen output width is stored in `bn_channels` so a
        following ``'B'`` layer normalizes the right number of channels.
        """
        self.idx_conv += 1

        in_channels = out_channels = self.mid_channels
        if self.idx_conv == 1:
            in_channels = self.in_channels
        if self.idx_conv == self.num_conv:
            out_channels = self.out_channels

        self.bn_channels = out_channels

        return in_channels, out_channels

    def forward(self, x):
        """Apply the layers built from `seq`, in order."""
        return self.convs(x)


def narrow_like3D(a, b):
    """Crop the spatial dims of `a` so their sizes match those of `b`.

    Every dimension from index 2 onward is cut to ``b``'s length. Dims 0 and 1
    are left alone; they are presumably batch and channel. The excess is
    removed as ``excess // 2`` from the start and the remainder from the end,
    so an odd difference trims one more element on the right. This is
    consistent with the downsampling.
    """
    cropped = a
    for dim in range(2, a.dim()):
        target = b.shape[dim]
        left_cut = (cropped.shape[dim] - target) // 2
        cropped = cropped.narrow(dim, left_cut, target)
    return cropped


In [10]:
class UNet3D(nn.Module):
    """Two-level 3-D U-Net assembled from `ConvBlock3D` blocks.

    Layout:

    * Encoder: conv block, downsample, conv block, downsample.
    * Bottleneck: one central conv block.
    * Decoder: upsample, concatenate skip, conv block, then again upsample,
      concatenate skip, conv block.

    Before each concatenation the skip tensor is center-cropped to the
    upsampled tensor with `narrow_like3D`. The ``'C'`` convolutions are
    unpadded, so the output is spatially smaller than the input (see
    `check_pad_size`).

    Args:
        in_channels: channel count of the input volume.
        out_channels: channel count of the output volume.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Encoder (left side of the "U").
        self.conv_0l = ConvBlock3D(in_channels, 64, seq='CAC')
        self.down_0l = ConvBlock3D(64, seq='BADBA')
        self.conv_1l = ConvBlock3D(64, seq='CBAC')
        self.down_1l = ConvBlock3D(64, seq='BADBA')

        # Bottleneck.
        self.conv_2c = ConvBlock3D(64, seq='CBAC')

        # Decoder (right side): inputs are 128 = 64 skip + 64 upsampled channels.
        self.up_1r = ConvBlock3D(64, seq='BAUBA')
        self.conv_1r = ConvBlock3D(128, 64, seq='CBAC')
        self.up_0r = ConvBlock3D(64, seq='BAUBA')
        self.conv_0r = ConvBlock3D(128, out_channels, seq='CAC')

    def forward(self, x):
        """Run the network on `x`, a 5-D ``(N, in_channels, D, H, W)`` tensor."""
        y0 = self.conv_0l(x)
        x = self.down_0l(y0)

        y1 = self.conv_1l(x)
        x = self.down_1l(y1)

        x = self.conv_2c(x)

        x = self.up_1r(x)
        y1 = narrow_like3D(y1, x)
        x = torch.cat([y1, x], dim=1)
        del y1  # release the skip tensor before the next block allocates
        x = self.conv_1r(x)

        x = self.up_0r(x)
        y0 = narrow_like3D(y0, x)
        x = torch.cat([y0, x], dim=1)
        del y0
        x = self.conv_0r(x)

        return x

    def check_pad_size(self, size=100):
        """Return how many voxels the network trims from each side of its input.

        A zero-filled probe of shape ``(2, in_channels, size, size, size)`` is
        pushed through the network. The result is
        ``(size - output_size) / 2`` on the first spatial dim, truncated to int.

        The probe is created on the device of the model's parameters. It runs
        under ``torch.no_grad()`` in eval mode, so no autograd graph is built
        and BatchNorm running statistics are not touched. The model's previous
        train/eval mode is restored afterwards.

        Args:
            size: spatial edge length of the probe volume.
        """
        device = next(self.parameters()).device
        check_data = torch.zeros(2, self.conv_0l.in_channels, size, size, size,
                                 device=device)
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self(check_data)
        finally:
            self.train(was_training)
        pad_size = (check_data.shape[2] - output.shape[2]) / 2.
        return int(pad_size)


In [11]:
def short_summary(model, input_size, block_name='ConvBlock'):
    """Print a compact per-block digest of ``torchsummary.summary``.

    Runs ``summary(model, input_size)`` with its stdout captured. For every
    table row whose layer name contains `block_name`, it prints that layer's
    output channel count and the size of its first spatial dimension.

    Args:
        model: the ``nn.Module`` to summarize.
        input_size: input shape without the batch dimension, as passed to
            ``summary`` (e.g. ``(C, D, H, W)``).
        block_name: substring that selects the table rows to report.

    Raises:
        Whatever ``summary`` raises; a short notice is printed to the real
        stdout first.
    """
    captured = io.StringIO()
    try:
        with contextlib.redirect_stdout(captured):
            summary(model, input_size)
    except Exception:
        # Print outside the redirect so the user actually sees it, and only on
        # failure (it used to run unconditionally in a `finally`).
        print('torch summary failed')
        raise

    # Rows look like "ConvBlock3D-4   [-1, 64, 68, 68, 68]   0": capture the
    # channel count and the first spatial size from the output-shape column.
    shape_pattern = re.compile(r'\[\s*-?\d+,\s*(\d+),\s*(\d+)')

    block_counter = 1
    for line in captured.getvalue().splitlines():
        if block_name not in line:
            continue
        match = shape_pattern.search(line)
        if match is None:
            continue
        channels, size = (int(group) for group in match.groups())
        print(f'Block {block_counter} - output channels = {channels}, output size = {size}')
        block_counter += 1

In [12]:
# First CUDA GPU; the model below is moved onto it.
device = torch.device('cuda', 0)
model = UNet3D(in_channels=1, out_channels=128)
model = model.to(device)

In [13]:
short_summary(model, input_size=(1, 72, 72, 72))  # 1 channel, 72^3 volume

Block 1 - output channels = 64, output size = 68
Block 2 - output channels = 64, output size = 37
Block 3 - output channels = 64, output size = 33
Block 4 - output channels = 64, output size = 19
Block 5 - output channels = 64, output size = 15
Block 6 - output channels = 64, output size = 32
Block 7 - output channels = 64, output size = 28
Block 8 - output channels = 64, output size = 58
Block 9 - output channels = 128, output size = 54


In [8]:
summary(model, input_size=(1, 72, 72, 72))  # full layer-by-layer table

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 64, 70, 70, 70]           1,792
             Swish-2       [-1, 64, 70, 70, 70]               0
            Conv3d-3       [-1, 64, 68, 68, 68]         110,656
       ConvBlock3D-4       [-1, 64, 68, 68, 68]               0
       BatchNorm3d-5       [-1, 64, 68, 68, 68]             128
             Swish-6       [-1, 64, 68, 68, 68]               0
            Conv3d-7       [-1, 64, 37, 37, 37]          32,832
       BatchNorm3d-8       [-1, 64, 37, 37, 37]             128
             Swish-9       [-1, 64, 37, 37, 37]               0
      ConvBlock3D-10       [-1, 64, 37, 37, 37]               0
           Conv3d-11       [-1, 64, 35, 35, 35]         110,656
      BatchNorm3d-12       [-1, 64, 35, 35, 35]             128
            Swish-13       [-1, 64, 35, 35, 35]               0
           Conv3d-14       [-1, 64, 33,