In [5]:
import torch
from torch import nn
from torchsummary import summary

Defaulting to user installation because normal site-packages is not writeable
Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1
You should consider upgrading via the '/usr/local/bin/python3.8 -m pip install --upgrade pip' command.


In [13]:
class DoubleConv(nn.Module):
    """Two 3D convolutions, each followed by a LeakyReLU activation.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (used by both convolutions).
        kernel_size: kernel size of both convolutions (int or 3-tuple, as
            accepted by ``nn.Conv3d``).
        padding: zero-padding of both convolutions. The default ``0`` keeps the
            original unpadded ("valid") behaviour, where every spatial dim
            shrinks by ``2 * (kernel_size - 1)`` (e.g. 188 -> 184 for k=3).
            For odd kernels, ``padding=kernel_size // 2`` preserves spatial size.
        negative_slope: LeakyReLU slope applied to negative inputs.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, padding=0, negative_slope=0.2):
        super().__init__()
        # Creation order (conv1, relu, conv2) and attribute names are kept so
        # state_dict keys and weight initialisation order stay unchanged.
        self.conv1 = nn.Conv3d(in_ch, out_ch, kernel_size, padding=padding)
        # One parameter-free activation module is reused after both convs;
        # inplace=True overwrites each conv's output to save memory.
        self.relu = nn.LeakyReLU(negative_slope, inplace=True)
        self.conv2 = nn.Conv3d(out_ch, out_ch, kernel_size, padding=padding)

    def forward(self, x):
        """Apply conv1 -> LeakyReLU -> conv2 -> LeakyReLU and return the result."""
        x = self.relu(self.conv1(x))
        return self.relu(self.conv2(x))


class DownBlock(nn.Module):
    """Encoder stage: DoubleConv feature extraction followed by 2x max-pooling.

    Args:
        in_ch: channels of the incoming tensor.
        out_ch: channels produced by the DoubleConv (and thus by this block).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Keep registration order (doubleconv before maxpool) so state_dict keys
        # and the order in which weights are initialised do not change.
        self.doubleconv = DoubleConv(in_ch, out_ch)
        self.maxpool = nn.MaxPool3d(2)

    def forward(self, x):
        """Return ``maxpool(doubleconv(x))``; the pool halves each spatial dim (floored)."""
        features = self.doubleconv(x)
        return self.maxpool(features)


class UpBlock(nn.Module):
    """Decoder stage: 2x spatial upsampling followed by a DoubleConv.

    Args:
        in_ch: channels of the incoming tensor; the upsampler keeps this count.
        out_ch: channels produced by the DoubleConv (and thus by this block).
        strat: ``'trans_conv'`` selects a learned ``ConvTranspose3d`` (kernel 2,
            stride 2); any other value selects parameter-free nearest-neighbour
            ``nn.Upsample``.
    """

    def __init__(self, in_ch, out_ch, strat='trans_conv'):
        super().__init__()
        # NOTE(review): unrecognised strat values (including typos) silently
        # fall back to nearest upsampling instead of raising.
        # The upsampler is still built before the DoubleConv, so the weight
        # initialisation order and state_dict layout are unchanged.
        self.upsample = (
            nn.ConvTranspose3d(in_ch, in_ch, kernel_size=2, stride=2)
            if strat == 'trans_conv'
            else nn.Upsample(scale_factor=2, mode='nearest')
        )
        self.doubleconv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        """Double every spatial dim of ``x``, then apply the DoubleConv."""
        return self.doubleconv(self.upsample(x))


class UNet(nn.Module):
    """3D encoder-decoder built from four DownBlocks and four UpBlocks.

    Encoder widths are init_filters, 2x, 4x, 8x; the decoder halves them back
    to init_filters, then maps to ``in_channels``. A final 1x1x1 convolution
    produces ``out_channels``. DoubleConv's default unpadded 3x3x3 convolutions
    make the output spatially smaller than the input. With the defaults, a
    188^3 patch gives a 68^3 output. Every encoder stage must keep a positive
    size after its convolutions, and odd sizes are floored by the max-pool.

    NOTE(review): despite the name, ``forward`` has no encoder->decoder skip
    connections (no concatenation of encoder features), so this is a plain
    encoder-decoder. Adding them would change the decoder's input channels, so
    existing checkpoints would not load. It would also need the pre-pool
    DownBlock features, which DownBlock.forward does not return, cropped to
    the smaller decoder sizes.
    NOTE(review): ``up4`` reduces ``init_filters`` channels to ``in_channels``
    before ``final_conv``; confirm this bottleneck is intended.

    Args:
        in_channels: channels of the input volume.
        init_filters: channel width of the first encoder stage.
        up_strat: upsampling strategy forwarded to every UpBlock
            (``'trans_conv'`` = learned transposed conv; anything else =
            nearest-neighbour).
        out_channels: channels produced by ``final_conv``. ``None`` (default)
            means ``in_channels``, matching the original behaviour.
    """

    def __init__(self, in_channels=1, init_filters=32, up_strat='trans_conv', out_channels=None):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels

        # encoder (module creation order is unchanged so state_dict keys and
        # weight initialisation order match the original model)
        self.down1 = DownBlock(in_channels, init_filters)
        self.down2 = DownBlock(init_filters, init_filters*2)
        self.down3 = DownBlock(init_filters*2, init_filters*4)
        self.down4 = DownBlock(init_filters*4, init_filters*8)

        # decoder
        self.up1 = UpBlock(init_filters*8, init_filters*4, strat=up_strat)
        self.up2 = UpBlock(init_filters*4, init_filters*2, strat=up_strat)
        self.up3 = UpBlock(init_filters*2, init_filters, strat=up_strat)
        self.up4 = UpBlock(init_filters, in_channels, strat=up_strat)

        self.final_conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the four encoder stages, the four decoder stages, then ``final_conv``."""
        for stage in (self.down1, self.down2, self.down3, self.down4,
                      self.up1, self.up2, self.up3, self.up4):
            x = stage(x)
        return self.final_conv(x)

In [14]:
unet = UNet()
patch_size = (188, 188, 188)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
unet = unet.to(device)

# torchsummary runs a real forward pass. At 188^3, each 32-channel fp32
# activation is ~0.8 GB. Only output shapes are needed, so skip autograd
# bookkeeping to avoid keeping every activation alive for a backward pass.
# device is passed explicitly so the dummy input matches where the model lives.
with torch.no_grad():
    summary(unet, input_size=(1, *patch_size), batch_size=1, device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1     [1, 32, 186, 186, 186]             896
         LeakyReLU-2     [1, 32, 186, 186, 186]               0
            Conv3d-3     [1, 32, 184, 184, 184]          27,680
         LeakyReLU-4     [1, 32, 184, 184, 184]               0
        DoubleConv-5     [1, 32, 184, 184, 184]               0
         MaxPool3d-6        [1, 32, 92, 92, 92]               0
         DownBlock-7        [1, 32, 92, 92, 92]               0
            Conv3d-8        [1, 64, 90, 90, 90]          55,360
         LeakyReLU-9        [1, 64, 90, 90, 90]               0
           Conv3d-10        [1, 64, 88, 88, 88]         110,656
        LeakyReLU-11        [1, 64, 88, 88, 88]               0
       DoubleConv-12        [1, 64, 88, 88, 88]               0
        MaxPool3d-13        [1, 64, 44, 44, 44]               0
        DownBlock-14        [1, 64, 44,

In [17]:
# Return cached-but-unused blocks from PyTorch's CUDA caching allocator to the
# driver (presumably to reclaim memory after the summary cell above).
# Memory still held by live tensors is not released by this call.
if torch.cuda.is_available():
    torch.cuda.empty_cache()