In [1]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from loader import SingleChannelDataset
from torchsummary import summary


In [50]:
class ConvBlock(nn.Module):
    """One encoder stage: dilated Conv3d -> BatchNorm3d -> non-linearity -> MaxPool3d.

    Channel widths follow a schedule keyed on the 1-based ``depth``: the
    depth-1 block maps the single input channel to 16, and the output width
    doubles every two stages (depths 1..5 -> 16, 16, 32, 32, 64).

    Spatially, the 5x5x5 convolution with dilation 2 (effective extent 9) and
    padding 1 shrinks each dimension by 6. The stride-1 3x3x3 max-pool then
    shrinks it by another 2, so every block reduces each spatial dimension by 8.

    Args:
        depth: 1-based position of this block in the encoder; must be >= 1.
        non_linearity: callable applied after batch norm (e.g. ``F.relu``).

    Raises:
        ValueError: if ``depth`` < 1.
    """

    def __init__(self, depth, non_linearity):
        super(ConvBlock, self).__init__()

        if depth < 1:
            # depth <= 0 would yield a fractional channel count (2 ** negative)
            # and fail with an unrelated error inside nn.Conv3d.
            raise ValueError(f"depth must be >= 1, got {depth}")

        self.depth = depth
        self.non_linearity = non_linearity

        # Floor division matches int(x / 2) for every depth >= 1 and keeps the
        # channel counts as Python ints.
        self.in_channels = 16 * 2 ** ((depth - 2) // 2) if depth != 1 else 1
        self.out_channels = 16 * 2 ** ((depth - 1) // 2)

        # NOTE(review): the conv bias is redundant ahead of BatchNorm. It is kept
        # so parameter names and counts stay unchanged.
        self.conv = nn.Conv3d(
            in_channels = self.in_channels,
            out_channels = self.out_channels,
            kernel_size = (5, 5, 5),
            stride = (1, 1, 1),
            padding = (1, 1, 1),
            dilation = (2, 2, 2)
        )

        self.bn = nn.BatchNorm3d(self.out_channels, affine = True)

        # Stride 1: this pool trims 2 voxels per side rather than halving resolution.
        self.maxpool = nn.MaxPool3d(
            kernel_size = (3, 3, 3),
            stride = (1, 1, 1)
        )

    def forward(self, x):
        """Apply conv -> batch norm -> non-linearity -> max-pool to ``x``."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.non_linearity(x)
        return self.maxpool(x)








class ConvInverseBlock(nn.Module):
    """One decoder stage: trilinear Upsample -> Conv3d -> BatchNorm3d -> non-linearity.

    Channel widths are those of the encoder stage at the same ``depth`` with
    input and output swapped. The depth-1 block maps 16 channels back to 1.

    The block upsamples to ``original_input_size - 4 * (depth - 1) + 2`` voxels
    per side. The 5x5x5 convolution with padding 1 then removes 2, so the block
    outputs ``original_input_size - 4 * (depth - 1)`` voxels per side. The
    depth-1 block therefore restores exactly ``original_input_size``.

    Args:
        depth: 1-based stage index; must be >= 1.
        num_layers: total number of decoder stages. Not read by this block;
            kept so the constructor signature stays compatible with callers.
        original_input_size: spatial edge length (int) of the encoder input.
        non_linearity: callable applied after batch norm (e.g. ``F.relu``).

    Raises:
        ValueError: if ``depth`` < 1, or if the upsampled size is too small
            for the convolution to produce a non-empty output.
    """

    def __init__(self, depth, num_layers, original_input_size, non_linearity):
        super(ConvInverseBlock, self).__init__()

        if depth < 1:
            # depth <= 0 would yield a fractional channel count (2 ** negative)
            # and fail with an unrelated error inside nn.BatchNorm3d.
            raise ValueError(f"depth must be >= 1, got {depth}")

        self.depth = depth
        self.non_linearity = non_linearity

        # Floor division matches int(x / 2) for every depth >= 1 and keeps the
        # channel counts as Python ints.
        self.in_channels = 16 * 2 ** ((depth - 1) // 2)
        self.out_channels = 16 * 2 ** ((depth - 2) // 2) if depth != 1 else 1

        # The "+ 2" compensates for the 2 voxels the kernel-5 / padding-1 conv
        # removes from each side length.
        upsample_size = original_input_size - 4 * (depth - 1) + 2
        if upsample_size < 3:
            # Fewer than 3 voxels leaves nothing after the conv; fail here
            # instead of at the first forward pass.
            raise ValueError(
                f"original_input_size={original_input_size} is too small for "
                f"depth={depth}: upsample size {upsample_size} < 3"
            )

        self.upsample = nn.Upsample(
            size = upsample_size,
            mode = 'trilinear',
            align_corners = True
        )

        self.conv_r = nn.Conv3d(
            in_channels = self.in_channels,
            out_channels = self.out_channels,
            kernel_size = (5, 5, 5),
            stride = (1, 1, 1),
            padding = (1, 1, 1)
        )

        self.bn = nn.BatchNorm3d(self.out_channels, affine = True)

    def forward(self, x):
        """Apply upsample -> conv -> batch norm -> non-linearity to ``x``."""
        x = self.upsample(x)
        x = self.conv_r(x)
        x = self.bn(x)
        return self.non_linearity(x)







class Encoder(nn.Module):
    """Sequential stack of ConvBlock stages with depths 1, 2, ..., num_layers.

    Args:
        num_layers: number of ConvBlock stages to chain.
        non_linearity: callable shared by every stage (e.g. ``F.relu``).
    """

    def __init__(self, num_layers, non_linearity):
        super(Encoder, self).__init__()

        # Assigned before the stages so module registration order is unchanged.
        self.non_linearity = non_linearity

        stages = [ConvBlock(depth, non_linearity) for depth in range(1, num_layers + 1)]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through every encoder stage in order of increasing depth."""
        return self.conv(x)






class Decoder(nn.Module):
    """Sequential stack of ConvInverseBlock stages with depths num_layers, ..., 1.

    Each stage is built with the same ``original_input_size``. The final
    (depth-1) stage emits volumes of that edge length.

    Args:
        num_layers: number of ConvInverseBlock stages to chain.
        original_input_size: spatial edge length of the encoder's input.
        non_linearity: callable shared by every stage (e.g. ``F.relu``).
    """

    def __init__(self, num_layers, original_input_size, non_linearity):
        super(Decoder, self).__init__()

        descending_depths = reversed(range(1, num_layers + 1))
        self.conv_inv = nn.Sequential(*(
            ConvInverseBlock(depth, num_layers, original_input_size, non_linearity)
            for depth in descending_depths
        ))

    def forward(self, x):
        """Run ``x`` through every decoder stage, deepest first."""
        return self.conv_inv(x)

In [51]:
# Model hyper-parameters: 5 encoder/decoder stages on 72^3 single-channel volumes.
NUM_LAYERS = 5
INPUT_SIZE = 72

# Use the first GPU when one is available, otherwise fall back to CPU so the
# notebook also runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
encoder = Encoder(NUM_LAYERS, F.relu).to(device)
decoder = Decoder(NUM_LAYERS, INPUT_SIZE, F.relu).to(device)

In [52]:
# torchsummary allocates its dummy input on the device given here (default
# 'cuda'); pass the device the models were moved to so this also works on CPU.
summary(encoder, (1, 72, 72, 72), device=device.type)
# Decoder input is the encoder's latent: 64 channels, 72 - 5 * 8 = 32 voxels per side.
summary(decoder, (64, 32, 32, 32), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 16, 66, 66, 66]           2,016
       BatchNorm3d-2       [-1, 16, 66, 66, 66]              32
         MaxPool3d-3       [-1, 16, 64, 64, 64]               0
         ConvBlock-4       [-1, 16, 64, 64, 64]               0
            Conv3d-5       [-1, 16, 58, 58, 58]          32,016
       BatchNorm3d-6       [-1, 16, 58, 58, 58]              32
         MaxPool3d-7       [-1, 16, 56, 56, 56]               0
         ConvBlock-8       [-1, 16, 56, 56, 56]               0
            Conv3d-9       [-1, 32, 50, 50, 50]          64,032
      BatchNorm3d-10       [-1, 32, 50, 50, 50]              64
        MaxPool3d-11       [-1, 32, 48, 48, 48]               0
        ConvBlock-12       [-1, 32, 48, 48, 48]               0
           Conv3d-13       [-1, 32, 42, 42, 42]         128,032
      BatchNorm3d-14       [-1, 32, 42,