In [1]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from loader import SingleChannelDataset
from torchsummary import summary


In [36]:


class Encoder(nn.Module):
    
    def __init__(self, num_layers, non_linearity):
        super(Encoder, self).__init__()
        
        self.non_linearity = non_linearity
        
        self.conv1 = nn.Conv3d(
            in_channels = 1,
            out_channels = 8,
            kernel_size = (3, 3, 3),
            stride = (1, 1, 1),
            padding = (1, 1, 1)
        )
        
        self.maxpool1 = nn.MaxPool3d(
            kernel_size = (3, 3, 3),
            stride = (1, 1, 1)
        )
        
        self.conv2 = nn.Conv3d(
            in_channels = 8,
            out_channels = 16,
            kernel_size = (3, 3, 3),
            stride = (1, 1, 1),
            padding = (1, 1, 1)
        )
        
        self.maxpool2 = nn.MaxPool3d(
            kernel_size = (3, 3, 3),
            stride = (2, 2, 2),
            padding = (1, 1, 1)
        )
        
        self.conv3 = nn.Conv3d(
            in_channels = 16,
            out_channels = 32,
            kernel_size = (3, 3, 3),
            stride = (1, 1, 1),
            padding = (1, 1, 1)
        )
        
        self.maxpool3 = nn.MaxPool3d(
            kernel_size = (2, 2, 2),
            stride = (2, 2, 2)
        )
        
        self.bn1 = nn.BatchNorm3d(8, affine = True)
        self.bn2 = nn.BatchNorm3d(16, affine = True)
        self.bn3 = nn.BatchNorm3d(32, affine = True)
        
    
    def forward(self, x):
        x = self.non_linearity(self.conv1(x))
        x = self.bn1(x)
        x = self.maxpool1(x)
        x = self.non_linearity(self.conv2(x))
        x = self.bn2(x)
        x = self.maxpool2(x)
        x = self.non_linearity(self.conv3(x))
        x = self.bn3(x)
        x = self.maxpool3(x)
        
        return x



In [40]:
encoder = Encoder(F.relu).to(device)
summary(encoder, (1, 72, 72, 72))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 16, 72, 72, 72]             448
       BatchNorm3d-2       [-1, 16, 72, 72, 72]              32
         MaxPool3d-3       [-1, 16, 70, 70, 70]               0
Total params: 480
Trainable params: 480
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.42
Forward/backward pass size (MB): 133.00
Params size (MB): 0.00
Estimated Total Size (MB): 134.42
----------------------------------------------------------------
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 16, 70, 70, 70]           6,928
       BatchNorm3d-2       [-1, 16, 70, 70, 70]              32
         MaxPool3d-3       [-1, 16, 68, 68, 68]               0
Total params: 6,960
Trainable params: 6,960
N

In [9]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        
        self.upsample3 = nn.Upsample(
            scale_factor = 2,
            mode = 'nearest'
        )
        
        self.conv3_r = nn.Conv3d(
            in_channels = 32,
            out_channels = 16,
            kernel_size = (3, 3, 3),
            stride = (1, 1, 1),
            padding = (1, 1, 1)
        )
        
        self.upsample2 = nn.Upsample(
            scale_factor = 2,
            mode = 'nearest'
        )
        
        self.conv2_r = nn.Conv3d(
            in_channels = 16,
            out_channels = 8,
            kernel_size = (3, 3, 3),
            stride = (1, 1, 1),
            padding = (1, 1, 1)
        )
        
        self.upsample1 = nn.Upsample(
            scale_factor = 2,
            mode = 'nearest'
        )
        
        self.conv1_r = nn.Conv3d(
            in_channels = 8,
            out_channels = 1,
            kernel_size = (3, 3, 3),
            stride = (1, 1, 1),
            padding = (1, 1, 1)
        )
        
        self.bn3_r = nn.BatchNorm3d(16, affine = True)
        self.bn2_r = nn.BatchNorm3d(8, affine = True)
        self.bn1_r = nn.BatchNorm3d(1, affine = True)
    
    def forward(self, x):
        x = self.upsample3(x)
        x = F.relu(self.conv3_r(x))
        x = self.bn3_r(x)
        x = self.upsample2(x)
        x = F.relu(self.conv2_r(x))
        x = self.bn2_r(x)
        x = self.upsample1(x)
        x = F.relu(self.conv1_r(x))
        x = self.bn1_r(x)
        
        return x

In [10]:
device = torch.device('cpu')
encoder = Encoder().to(device)
decoder = Decoder().to(device)

In [11]:

summary(encoder, (1, 72, 72, 72))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1        [-1, 8, 72, 72, 72]             224
       BatchNorm3d-2        [-1, 8, 72, 72, 72]              16
         MaxPool3d-3        [-1, 8, 71, 71, 71]               0
            Conv3d-4       [-1, 16, 71, 71, 71]           3,472
       BatchNorm3d-5       [-1, 16, 71, 71, 71]              32
         MaxPool3d-6       [-1, 16, 36, 36, 36]               0
            Conv3d-7       [-1, 32, 36, 36, 36]          13,856
       BatchNorm3d-8       [-1, 32, 36, 36, 36]              64
         MaxPool3d-9       [-1, 32, 18, 18, 18]               0
Total params: 17,664
Trainable params: 17,664
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.42
Forward/backward pass size (MB): 184.69
Params size (MB): 0.07
Estimated Total Size (MB): 186.18
-----------------------------------------

In [137]:
summary(decoder, (32, 9, 9, 9))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
          Upsample-1       [-1, 32, 18, 18, 18]               0
            Conv3d-2       [-1, 16, 18, 18, 18]          13,840
       BatchNorm3d-3       [-1, 16, 18, 18, 18]              32
          Upsample-4       [-1, 16, 36, 36, 36]               0
            Conv3d-5        [-1, 8, 36, 36, 36]           3,464
       BatchNorm3d-6        [-1, 8, 36, 36, 36]              16
          Upsample-7        [-1, 8, 72, 72, 72]               0
            Conv3d-8        [-1, 1, 72, 72, 72]             217
       BatchNorm3d-9        [-1, 1, 72, 72, 72]               2
Total params: 17,571
Trainable params: 17,571
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.09
Forward/backward pass size (MB): 42.71
Params size (MB): 0.07
Estimated Total Size (MB): 42.87
-------------------------------------------

In [3]:
class ConvBlock(nn.Module):
    
    def __init__(self, depth, non_linearity):
        super(ConvBlock, self).__init__()
        
        self.depth = depth
        self.non_linearity = non_linearity
        
        self.in_channels = 16 * (2 ** int((depth - 2) / 2)) if depth != 1 else 1
        self.out_channels = 16 * (2 ** int((depth - 1) / 2))
        
        # reduces input shape by 4
        self.conv = nn.Conv3d(
            in_channels = self.in_channels,
            out_channels = self.out_channels,
            kernel_size = (5, 5, 5),
            stride = (1, 1, 1),
            padding = (1, 1, 1)
        )
        
        self.bn = nn.BatchNorm3d(self.out_channels, affine = True)
        
        self.maxpool = nn.MaxPool3d(
            kernel_size = (3, 3, 3),
            stride = (1, 1, 1)
        )
    
    def forward(self, x):
        x = self.maxpool(
                self.non_linearity(
                    self.bn(
                        self.conv(x)
                    )
                )
            )
        
        return x


In [5]:
device = torch.device('cuda:0')

conv_block = ConvBlock(1, F.relu).to(device)
summary(conv_block, (1, 72, 72, 72))

conv_block = ConvBlock(2, F.relu).to(device)
summary(conv_block, (16, 68, 68, 68))

# conv_block = ConvBlock(3, F.relu).to(device)
# summary(conv_block, (16, 68, 68, 68))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 16, 70, 70, 70]           2,016
       BatchNorm3d-2       [-1, 16, 70, 70, 70]              32
         MaxPool3d-3       [-1, 16, 68, 68, 68]               0
Total params: 2,048
Trainable params: 2,048
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 1.42
Forward/backward pass size (MB): 122.12
Params size (MB): 0.01
Estimated Total Size (MB): 123.55
----------------------------------------------------------------
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 16, 66, 66, 66]          32,016
       BatchNorm3d-2       [-1, 16, 66, 66, 66]              32
         MaxPool3d-3       [-1, 16, 64, 64, 64]               0
Total params: 32,048
Trainable params: 32

In [6]:


class Encoder(nn.Module):
    
    def __init__(self, num_layers, non_linearity):
        super(Encoder, self).__init__()
        
        self.non_linearity = non_linearity
        
        modules = []
        for i in range(1, num_layers + 1):
            modules.append(ConvBlock(i, self.non_linearity))
        
        self.conv = nn.Sequential(*modules)
    
    def forward(self, x):
        x = self.conv(x)
        
        return x



In [7]:
encoder = Encoder(7, F.relu).to(device)
summary(encoder, (1, 72, 72, 72))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 16, 70, 70, 70]           2,016
       BatchNorm3d-2       [-1, 16, 70, 70, 70]              32
         MaxPool3d-3       [-1, 16, 68, 68, 68]               0
         ConvBlock-4       [-1, 16, 68, 68, 68]               0
            Conv3d-5       [-1, 16, 66, 66, 66]          32,016
       BatchNorm3d-6       [-1, 16, 66, 66, 66]              32
         MaxPool3d-7       [-1, 16, 64, 64, 64]               0
         ConvBlock-8       [-1, 16, 64, 64, 64]               0
            Conv3d-9       [-1, 32, 62, 62, 62]          64,032
      BatchNorm3d-10       [-1, 32, 62, 62, 62]              64
        MaxPool3d-11       [-1, 32, 60, 60, 60]               0
        ConvBlock-12       [-1, 32, 60, 60, 60]               0
           Conv3d-13       [-1, 32, 58, 58, 58]         128,032
      BatchNorm3d-14       [-1, 32, 58,

In [15]:
class ConvInverseBlock(nn.Module):
    
    def __init__(self, depth, num_layers, original_input_size, non_linearity):
        super(ConvInverseBlock, self).__init__()
        
        self.depth = depth
        self.non_linearity = non_linearity
        
        self.in_channels = 16 * (2 ** int((depth - 1) / 2))
        self.out_channels = 16 * (2 ** int((depth - 2) / 2)) if depth != 1 else 1
        
        upsample_size = original_input_size - 4 * (num_layers + depth - 7) + 2
        
        self.upsample = nn.Upsample(
            size = upsample_size,
            mode = 'nearest'
        )
        
        self.conv_r = nn.Conv3d(
            in_channels = self.in_channels,
            out_channels = self.out_channels,
            kernel_size = (5, 5, 5),
            stride = (1, 1, 1),
            padding = (1, 1, 1)
        )
        
        self.bn = nn.BatchNorm3d(self.out_channels, affine = True)

    
    def forward(self, x):
        x = self.non_linearity(
                self.bn(
                    self.conv_r(
                        self.upsample(x)
                    )
                )
            )
        
        return x


In [16]:
device = torch.device('cuda:0')

conv_block = ConvInverseBlock(7, 7, 72, F.relu).to(device)
summary(conv_block, (128, 44, 44, 44))

conv_block = ConvInverseBlock(6, 7, 72, F.relu).to(device)
summary(conv_block, (64, 48, 48, 48))

# conv_block = ConvInverseBlock(6, 7, 72, F.relu).to(device)
# summary(conv_block, (64, 48, 48, 48))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
          Upsample-1      [-1, 128, 50, 50, 50]               0
            Conv3d-2       [-1, 64, 48, 48, 48]       1,024,064
       BatchNorm3d-3       [-1, 64, 48, 48, 48]             128
Total params: 1,024,192
Trainable params: 1,024,192
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 41.59
Forward/backward pass size (MB): 230.07
Params size (MB): 3.91
Estimated Total Size (MB): 275.57
----------------------------------------------------------------
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
          Upsample-1       [-1, 64, 54, 54, 54]               0
            Conv3d-2       [-1, 64, 52, 52, 52]         512,064
       BatchNorm3d-3       [-1, 64, 52, 52, 52]             128
Total params: 512,192
Trainable 

In [17]:
class Decoder(nn.Module):
    def __init__(self, num_layers, original_input_size, non_linearity):
        super(Decoder, self).__init__()
                
        modules = []
        for i in range(num_layers, 0, -1):
            modules.append(ConvInverseBlock(i, num_layers, original_input_size, non_linearity))
        
        self.conv_inv = nn.Sequential(*modules)
    
    def forward(self, x):
        x = self.conv_inv(x)
        
        return x

In [18]:
decoder = Decoder(7, 72, F.relu).to(device)
summary(decoder, (128, 44, 44, 44))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
          Upsample-1      [-1, 128, 50, 50, 50]               0
            Conv3d-2       [-1, 64, 48, 48, 48]       1,024,064
       BatchNorm3d-3       [-1, 64, 48, 48, 48]             128
  ConvInverseBlock-4       [-1, 64, 48, 48, 48]               0
          Upsample-5       [-1, 64, 54, 54, 54]               0
            Conv3d-6       [-1, 64, 52, 52, 52]         512,064
       BatchNorm3d-7       [-1, 64, 52, 52, 52]             128
  ConvInverseBlock-8       [-1, 64, 52, 52, 52]               0
          Upsample-9       [-1, 64, 58, 58, 58]               0
           Conv3d-10       [-1, 32, 56, 56, 56]         256,032
      BatchNorm3d-11       [-1, 32, 56, 56, 56]              64
 ConvInverseBlock-12       [-1, 32, 56, 56, 56]               0
         Upsample-13       [-1, 32, 62, 62, 62]               0
           Conv3d-14       [-1, 32, 60,