In [2]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
from loader import SingleChannelDataset



In [128]:
class Encoder(nn.Module):
    """3D convolutional encoder built from three conv -> ReLU -> batch-norm -> max-pool stages.

    Channel widths go 1 -> 8 -> 16 -> 32. Every convolution is 3x3x3 with
    stride 1 and padding 1, so the convolutions keep the spatial size and only
    the pooling layers (all stride 2) shrink it.
    """

    def __init__(self):
        super().__init__()

        # Sub-modules are registered convs/pools first (interleaved) and batch
        # norms last; that assignment order is what named_children() and
        # state_dict() report, so it is kept stable.
        self.conv1 = nn.Conv3d(1, 8, kernel_size=3, stride=1, padding=1)
        self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv2 = nn.Conv3d(8, 16, kernel_size=3, stride=1, padding=1)
        # The middle pool is the odd one out: 3x3x3 window with padding 1
        # (still stride 2) instead of a plain 2x2x2 window.
        self.maxpool2 = nn.MaxPool3d(
            kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)
        )

        self.conv3 = nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1)
        self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        # One batch norm per stage, sized to that stage's conv output channels.
        self.bn1 = nn.BatchNorm3d(num_features=8, affine=True)
        self.bn2 = nn.BatchNorm3d(num_features=16, affine=True)
        self.bn3 = nn.BatchNorm3d(num_features=32, affine=True)

    def forward(self, x):
        """Run ``x`` through the three stages and return the pooled feature map.

        ``x`` must be accepted by ``conv1`` (a single-channel volume).
        Within each stage the order is conv, ReLU, batch norm, max-pool.
        """
        stages = (
            (self.conv1, self.bn1, self.maxpool1),
            (self.conv2, self.bn2, self.maxpool2),
            (self.conv3, self.bn3, self.maxpool3),
        )
        for conv, norm, pool in stages:
            # Note the activation is applied *before* normalisation.
            x = pool(norm(F.relu(conv(x))))
        return x



In [134]:
class Decoder(nn.Module):
    """3D convolutional decoder built from three upsample -> conv -> ReLU -> batch-norm stages.

    Channel widths go 32 -> 16 -> 8 -> 1. Each stage doubles the spatial size
    with nearest-neighbour upsampling; the 3x3x3 convolutions (stride 1,
    padding 1) leave the spatial size unchanged.
    """

    def __init__(self):
        super().__init__()

        # Sub-modules are registered upsample/conv pairs first and batch norms
        # last; that assignment order is what named_children() and
        # state_dict() report, so it is kept stable.
        self.upsample3 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv3_r = nn.Conv3d(32, 16, kernel_size=3, stride=1, padding=1)

        self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv2_r = nn.Conv3d(16, 8, kernel_size=3, stride=1, padding=1)

        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1_r = nn.Conv3d(8, 1, kernel_size=3, stride=1, padding=1)

        # One batch norm per stage, sized to that stage's conv output channels.
        self.bn3_r = nn.BatchNorm3d(num_features=16, affine=True)
        self.bn2_r = nn.BatchNorm3d(num_features=8, affine=True)
        self.bn1_r = nn.BatchNorm3d(num_features=1, affine=True)

    def forward(self, x):
        """Run ``x`` through the three stages and return a single-channel volume.

        ``x`` must be accepted by ``conv3_r`` (32 channels). Within each stage
        the order is upsample, conv, ReLU, batch norm.
        """
        stages = (
            (self.upsample3, self.conv3_r, self.bn3_r),
            (self.upsample2, self.conv2_r, self.bn2_r),
            (self.upsample1, self.conv1_r, self.bn1_r),
        )
        # NOTE(review): the last stage also ends in ReLU + batch norm, so the
        # returned reconstruction is a batch-normalised signal — confirm this
        # matches the scaling of the reconstruction target.
        for upsample, conv, norm in stages:
            x = norm(F.relu(conv(upsample(x))))
        return x

In [135]:
# Both halves of the autoencoder are pinned to the CPU; the summary cells
# further down in this notebook use these two instances.
device = torch.device("cpu")
encoder = Encoder().to(device=device)
decoder = Decoder().to(device=device)

In [136]:
from torchsummary import summary

# torchsummary builds its own random probe tensor and defaults to
# device="cuda"; on a machine with a GPU that probe would not match the
# CPU-resident model and the forward pass would fail. Pass the model's device
# type explicitly so the probe is created on the same device.
# NOTE(review): the saved output of this cell lists Upsample layers that
# Encoder does not define — presumably from a stale kernel state; re-run after
# a kernel restart to refresh it.
summary(encoder, (1, 72, 72, 72), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1        [-1, 8, 72, 72, 72]             224
       BatchNorm3d-2        [-1, 8, 72, 72, 72]              16
         MaxPool3d-3        [-1, 8, 36, 36, 36]               0
            Conv3d-4       [-1, 16, 36, 36, 36]           3,472
       BatchNorm3d-5       [-1, 16, 36, 36, 36]              32
         MaxPool3d-6       [-1, 16, 18, 18, 18]               0
            Conv3d-7       [-1, 32, 18, 18, 18]          13,856
       BatchNorm3d-8       [-1, 32, 18, 18, 18]              64
         MaxPool3d-9          [-1, 32, 9, 9, 9]               0
         Upsample-10       [-1, 32, 18, 18, 18]               0
           Conv3d-11       [-1, 16, 18, 18, 18]          13,840
      BatchNorm3d-12       [-1, 16, 18, 18, 18]              32
         Upsample-13       [-1, 16, 36, 36, 36]               0
           Conv3d-14        [-1, 8, 36,

In [137]:
# Input size is (channels, D, H, W) of the 32-channel feature map the decoder
# expects. The device is passed explicitly because torchsummary defaults to
# device="cuda", which would mismatch the CPU-resident model on GPU machines.
summary(decoder, (32, 9, 9, 9), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
          Upsample-1       [-1, 32, 18, 18, 18]               0
            Conv3d-2       [-1, 16, 18, 18, 18]          13,840
       BatchNorm3d-3       [-1, 16, 18, 18, 18]              32
          Upsample-4       [-1, 16, 36, 36, 36]               0
            Conv3d-5        [-1, 8, 36, 36, 36]           3,464
       BatchNorm3d-6        [-1, 8, 36, 36, 36]              16
          Upsample-7        [-1, 8, 72, 72, 72]               0
            Conv3d-8        [-1, 1, 72, 72, 72]             217
       BatchNorm3d-9        [-1, 1, 72, 72, 72]               2
Total params: 17,571
Trainable params: 17,571
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.09
Forward/backward pass size (MB): 42.71
Params size (MB): 0.07
Estimated Total Size (MB): 42.87
-------------------------------------------