In [2]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F
import torchvision.transforms.functional as TF

from torchinfo import summary

In [3]:
# One-dimensional tensor holding 64 ones; the trailing expression displays its shape.
# NOTE(review): presumably meant as a per-element positive-class weight for a loss,
# but nothing in the visible cells consumes it — confirm or remove.
pos_weight = torch.ones(64)
pos_weight.shape

torch.Size([64])

In [4]:
def build_conv_block(in_channels, out_channels):
  """Return a ``Conv3d -> ReLU -> Conv3d -> ReLU`` stack.

  Both convolutions use a 3x3x3 kernel with zero padding of 1, so they do not
  change the spatial size of their input.

  Args:
    in_channels: indexable with two entries; entry ``i`` is the input channel
      count of the i-th convolution.
    out_channels: indexable with two entries; entry ``i`` is the output channel
      count of the i-th convolution.

  Returns:
    An ``nn.Sequential`` with four modules. Because the layers are chained,
    ``in_channels[1]`` has to equal ``out_channels[0]`` for a forward pass to work.
  """
  conv_options = dict(kernel_size=3, padding=1, padding_mode='zeros')

  layers = []
  for stage in range(2):
    # Convolutions are created in order so parameter initialisation is unchanged.
    layers.append(nn.Conv3d(in_channels[stage], out_channels[stage], **conv_options))
    layers.append(nn.ReLU())
  return nn.Sequential(*layers)


class DownBlock(nn.Module):
  """Encoder stage: a double-convolution block followed by 2x max pooling.

  Args:
    in_channels: indexable with two entries, the input widths of the two convs.
    out_channels: indexable with two entries, the output widths of the two convs.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    first_in, second_in = in_channels[0], in_channels[1]
    first_out, second_out = out_channels[0], out_channels[1]
    self.conv = build_conv_block(in_channels=[first_in, second_in], out_channels=[first_out, second_out])
    self.down = nn.MaxPool3d(kernel_size=2)

  def forward(self, x):
    """Return ``(features, pooled)``: the conv output and its max-pooled version."""
    features = self.conv(x)
    return features, self.down(features)


class UpBlock(nn.Module):
  """Decoder stage: 2x transposed-conv upsampling, skip concatenation, double conv.

  Args:
    in_channels: indexable with two entries. ``in_channels[0]`` is the channel
      count of the tensor being upsampled and also the input width of the
      first conv (i.e. the width after concatenation); ``in_channels[1]`` is
      the input width of the second conv.
    out_channels: indexable with two entries. ``out_channels[0]`` is the width
      produced by the upsampling layer; ``out_channels[1]`` is the output
      width of both convs.

  For a forward pass to work, the skip tensor's channels plus
  ``out_channels[0]`` must equal ``in_channels[0]``, and ``in_channels[1]``
  must equal ``out_channels[1]``.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.up = nn.ConvTranspose3d(in_channels=in_channels[0], out_channels=out_channels[0], kernel_size=2, stride=2)
    # Alternative kept for reference (no learned weights, keeps the channel count):
    # self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
    self.conv = build_conv_block(in_channels=[in_channels[0], in_channels[1]], out_channels=[out_channels[1], out_channels[1]])

  def forward(self, x, skip):
    """Upsample ``x``, concatenate it after ``skip`` on the channel axis, and convolve.

    Args:
      x: tensor from the deeper stage, laid out as (N, C, D, H, W).
      skip: encoder feature map to concatenate with the upsampled ``x``.

    Returns:
      The output of the double-conv block applied to ``cat([skip, upsampled])``.
    """
    upscaled = self.up(x)

    # 2x pooling floors odd sizes, so doubling the pooled tensor can come out
    # smaller than the skip tensor. Pad (or crop, for a negative difference)
    # the upsampled tensor so torch.cat sees matching spatial sizes.
    if upscaled.shape[2:] != skip.shape[2:]:
      pads = []
      # F.pad expects (before, after) pairs starting from the LAST dimension.
      for up_size, skip_size in zip(reversed(upscaled.shape[2:]), reversed(skip.shape[2:])):
        diff = skip_size - up_size
        pads.extend([diff // 2, diff - diff // 2])
      upscaled = F.pad(upscaled, pads)

    concat = torch.cat([skip, upscaled], dim=1)
    return self.conv(concat)


class Unet(nn.Module):
  """Four-level 3-D U-Net with 2 input channels and a single sigmoid output channel.

  Encoder widths are 16, 32, 64 and 128, the bottleneck is 256 wide, and the
  decoder mirrors the encoder. The head is a 1x1x1 convolution followed by a
  sigmoid.
  """

  def __init__(self):
    super().__init__()

    # Contracting path.
    self.down1 = DownBlock([2, 16], [16, 16])
    self.down2 = DownBlock([16, 32], [32, 32])
    self.down3 = DownBlock([32, 64], [64, 64])
    self.down4 = DownBlock([64, 128], [128, 128])

    # Deepest stage: no pooling, only the double convolution.
    self.bottle_neck = build_conv_block(in_channels=[128, 256], out_channels=[256, 256])

    # Expanding path.
    self.up1 = UpBlock([256, 128], [128, 128])
    self.up2 = UpBlock([128, 64], [64, 64])
    self.up3 = UpBlock([64, 32], [32, 32])
    self.up4 = UpBlock([32, 16], [16, 16])

    # Head.
    head_conv = nn.Conv3d(in_channels=16, out_channels=1, kernel_size=1)
    self.output = nn.Sequential(head_conv, nn.Sigmoid())

  def forward(self, x):
    """Run the encoder, bottleneck and decoder; return the sigmoid output of the head."""
    skip1, pooled = self.down1(x)
    skip2, pooled = self.down2(pooled)
    skip3, pooled = self.down3(pooled)
    skip4, pooled = self.down4(pooled)

    features = self.bottle_neck(pooled)

    # Deepest skip connection is consumed first.
    decoder = (
        (self.up1, skip4),
        (self.up2, skip3),
        (self.up3, skip2),
        (self.up4, skip1),
    )
    for up_stage, skip in decoder:
      features = up_stage(features, skip)

    return self.output(features)
  
# Instantiate the four-level network; the bare name on the last line makes the
# notebook print the module tree.
model = Unet()
model

Unet(
  (down1): DownBlock(
    (conv): Sequential(
      (0): Conv3d(2, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): ReLU()
      (2): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (3): ReLU()
    )
    (down): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down2): DownBlock(
    (conv): Sequential(
      (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): ReLU()
      (2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (3): ReLU()
    )
    (down): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (down3): DownBlock(
    (conv): Sequential(
      (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): ReLU()
      (2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (3): ReLU()
    )
    (down): MaxPool3d(kernel_size

In [5]:
# torchinfo per-layer report (output shapes and parameter counts) for an input
# of size (2, 2, 80, 128, 128); the second entry matches the 2 input channels
# of the first convolution.
summary(model, input_size=(2, 2, 80, 128, 128))

Layer (type:depth-idx)                   Output Shape              Param #
Unet                                     [2, 1, 80, 128, 128]      --
├─DownBlock: 1-1                         [2, 16, 80, 128, 128]     --
│    └─Sequential: 2-1                   [2, 16, 80, 128, 128]     --
│    │    └─Conv3d: 3-1                  [2, 16, 80, 128, 128]     880
│    │    └─ReLU: 3-2                    [2, 16, 80, 128, 128]     --
│    │    └─Conv3d: 3-3                  [2, 16, 80, 128, 128]     6,928
│    │    └─ReLU: 3-4                    [2, 16, 80, 128, 128]     --
│    └─MaxPool3d: 2-2                    [2, 16, 40, 64, 64]       --
├─DownBlock: 1-2                         [2, 32, 40, 64, 64]       --
│    └─Sequential: 2-3                   [2, 32, 40, 64, 64]       --
│    │    └─Conv3d: 3-5                  [2, 32, 40, 64, 64]       13,856
│    │    └─ReLU: 3-6                    [2, 32, 40, 64, 64]       --
│    │    └─Conv3d: 3-7                  [2, 32, 40, 64, 64]       27,680
│  

In [156]:
# Smoke test: push an all-zero batch through the model and display the output shape.
# NOTE(review): the convs keep the spatial size (kernel 3, padding 1) and each 2x
# pooling is undone by a stride-2 transposed conv, so this input should give
# [2, 1, 80, 128, 128]. The recorded output [2, 1, 128, 128, 80] looks stale from
# a run with a differently ordered input — re-run the notebook top to bottom to confirm.
o = model(torch.zeros((2, 2, 80, 128, 128)))
o.shape

torch.Size([2, 1, 128, 128, 80])

In [94]:
# o.shape

torch.Size([2, 256, 8, 8, 5])

In [87]:
# c = UpBlock(in_channels=[256, 128], out_channels=[128, 128])
# t = torch.zeros([2, 256, 8, 8, 5])
# t.shape

torch.Size([2, 256, 8, 8, 5])

In [88]:
# c(t, t)

In [23]:
# summary(model, input_size=(2, 2, 128, 128, 80))

Layer (type:depth-idx)                   Output Shape              Param #
Unet                                     --                        --
├─Sequential: 1-1                        [2, 16, 128, 128, 80]     --
│    └─Conv3d: 2-1                       [2, 16, 128, 128, 80]     880
│    └─ReLU: 2-2                         [2, 16, 128, 128, 80]     --
│    └─Conv3d: 2-3                       [2, 16, 128, 128, 80]     6,928
│    └─ReLU: 2-4                         [2, 16, 128, 128, 80]     --
├─Sequential: 1-2                        [2, 32, 64, 64, 40]       --
│    └─Conv3d: 2-5                       [2, 32, 64, 64, 40]       13,856
│    └─ReLU: 2-6                         [2, 32, 64, 64, 40]       --
│    └─Conv3d: 2-7                       [2, 32, 64, 64, 40]       27,680
│    └─ReLU: 2-8                         [2, 32, 64, 64, 40]       --
├─Sequential: 1-3                        [2, 64, 32, 32, 20]       --
│    └─Conv3d: 2-9                       [2, 64, 32, 32, 20]       55,360

In [10]:
def build_conv_block(in_channels, out_channels):
  """Return a ``Conv3d -> ReLU -> Conv3d -> ReLU`` stack.

  Both convolutions use a 3x3x3 kernel with zero padding of 1, so they do not
  change the spatial size of their input.

  Args:
    in_channels: indexable with two entries; entry ``i`` is the input channel
      count of the i-th convolution.
    out_channels: indexable with two entries; entry ``i`` is the output channel
      count of the i-th convolution.

  Returns:
    An ``nn.Sequential`` with four modules. Because the layers are chained,
    ``in_channels[1]`` has to equal ``out_channels[0]`` for a forward pass to work.
  """
  conv_options = dict(kernel_size=3, padding=1, padding_mode='zeros')

  layers = []
  for stage in range(2):
    # Convolutions are created in order so parameter initialisation is unchanged.
    layers.append(nn.Conv3d(in_channels[stage], out_channels[stage], **conv_options))
    layers.append(nn.ReLU())
  return nn.Sequential(*layers)


class DownBlock(nn.Module):
  """Encoder stage: a double-convolution block followed by 2x max pooling.

  Args:
    in_channels: indexable with two entries, the input widths of the two convs.
    out_channels: indexable with two entries, the output widths of the two convs.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    first_in, second_in = in_channels[0], in_channels[1]
    first_out, second_out = out_channels[0], out_channels[1]
    self.conv = build_conv_block(in_channels=[first_in, second_in], out_channels=[first_out, second_out])
    self.down = nn.MaxPool3d(kernel_size=2)

  def forward(self, x):
    """Return ``(features, pooled)``: the conv output and its max-pooled version."""
    features = self.conv(x)
    return features, self.down(features)

class UpBlock(nn.Module):
  """Decoder stage: 2x transposed-conv upsampling, skip concatenation, double conv.

  Args:
    in_channels: indexable with two entries. ``in_channels[0]`` is the channel
      count of the tensor being upsampled and also the input width of the
      first conv (i.e. the width after concatenation); ``in_channels[1]`` is
      the input width of the second conv.
    out_channels: indexable with two entries. ``out_channels[0]`` is the width
      produced by the upsampling layer; ``out_channels[1]`` is the output
      width of both convs.

  For a forward pass to work, the skip tensor's channels plus
  ``out_channels[0]`` must equal ``in_channels[0]``, and ``in_channels[1]``
  must equal ``out_channels[1]``.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.up = nn.ConvTranspose3d(in_channels=in_channels[0], out_channels=out_channels[0], kernel_size=2, stride=2)
    # Alternative kept for reference (no learned weights, keeps the channel count):
    # self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
    self.conv = build_conv_block(in_channels=[in_channels[0], in_channels[1]], out_channels=[out_channels[1], out_channels[1]])

  def forward(self, x, skip):
    """Upsample ``x``, concatenate it after ``skip`` on the channel axis, and convolve.

    Args:
      x: tensor from the deeper stage, laid out as (N, C, D, H, W).
      skip: encoder feature map to concatenate with the upsampled ``x``.

    Returns:
      The output of the double-conv block applied to ``cat([skip, upsampled])``.
    """
    upscaled = self.up(x)

    # 2x pooling floors odd sizes, so doubling the pooled tensor can come out
    # smaller than the skip tensor. Pad (or crop, for a negative difference)
    # the upsampled tensor so torch.cat sees matching spatial sizes.
    if upscaled.shape[2:] != skip.shape[2:]:
      pads = []
      # F.pad expects (before, after) pairs starting from the LAST dimension.
      for up_size, skip_size in zip(reversed(upscaled.shape[2:]), reversed(skip.shape[2:])):
        diff = skip_size - up_size
        pads.extend([diff // 2, diff - diff // 2])
      upscaled = F.pad(upscaled, pads)

    concat = torch.cat([skip, upscaled], dim=1)
    return self.conv(concat)

class Unet(nn.Module):
  """Three-level 3-D U-Net with 2 input channels and a single sigmoid output channel.

  Encoder widths are 32, 64 and 128 and the bottleneck is 256 wide. The fourth
  encoder stage and the first decoder stage are not built in this variant, so
  there is no ``down4`` or ``up1``; the decoder attributes keep the names
  ``up2``..``up4``.
  """

  def __init__(self):
    super().__init__()

    # Contracting path.
    self.down1 = DownBlock([2, 32], [32, 32])
    self.down2 = DownBlock([32, 64], [64, 64])
    self.down3 = DownBlock([64, 128], [128, 128])

    # Deepest stage: no pooling, only the double convolution.
    self.bottle_neck = build_conv_block(in_channels=[128, 256], out_channels=[256, 256])

    # Expanding path.
    self.up2 = UpBlock([256, 128], [128, 128])
    self.up3 = UpBlock([128, 64], [64, 64])
    self.up4 = UpBlock([64, 32], [32, 32])

    # Head.
    head_conv = nn.Conv3d(in_channels=32, out_channels=1, kernel_size=1)
    self.output = nn.Sequential(head_conv, nn.Sigmoid())

  def forward(self, x):
    """Run the encoder, bottleneck and decoder; return the sigmoid output of the head."""
    skip1, pooled = self.down1(x)
    skip2, pooled = self.down2(pooled)
    skip3, pooled = self.down3(pooled)

    features = self.bottle_neck(pooled)

    # Deepest skip connection is consumed first.
    decoder = (
        (self.up2, skip3),
        (self.up3, skip2),
        (self.up4, skip1),
    )
    for up_stage, skip in decoder:
      features = up_stage(features, skip)

    return self.output(features)
  
# model = Unet()
# model

In [12]:
# torchinfo per-layer report for a freshly built Unet on an input of size (2, 2, 40, 128, 128).
# NOTE(review): the recorded table shows 120x120 spatial sizes although input_size
# here ends in 128x128, so the output looks stale from an earlier run — re-run to confirm.
summary(Unet(), input_size=(2, 2, 40, 128, 128))

Layer (type:depth-idx)                   Output Shape              Param #
Unet                                     [2, 1, 40, 120, 120]      --
├─DownBlock: 1-1                         [2, 32, 40, 120, 120]     --
│    └─Sequential: 2-1                   [2, 32, 40, 120, 120]     --
│    │    └─Conv3d: 3-1                  [2, 32, 40, 120, 120]     1,760
│    │    └─ReLU: 3-2                    [2, 32, 40, 120, 120]     --
│    │    └─Conv3d: 3-3                  [2, 32, 40, 120, 120]     27,680
│    │    └─ReLU: 3-4                    [2, 32, 40, 120, 120]     --
│    └─MaxPool3d: 2-2                    [2, 32, 20, 60, 60]       --
├─DownBlock: 1-2                         [2, 64, 20, 60, 60]       --
│    └─Sequential: 2-3                   [2, 64, 20, 60, 60]       --
│    │    └─Conv3d: 3-5                  [2, 64, 20, 60, 60]       55,360
│    │    └─ReLU: 3-6                    [2, 64, 20, 60, 60]       --
│    │    └─Conv3d: 3-7                  [2, 64, 20, 60, 60]       110,656

In [8]:
# Rebinds `model` to a new instance of whichever Unet class the kernel currently holds.
# NOTE(review): this cell's execution count (8) is lower than that of the cell
# defining Unet above it (10), so the notebook was not run top to bottom.
model = Unet()

In [9]:
# Forward pass on an all-zero batch of size (2, 2, 40, 128, 128).
# NOTE(review): the RuntimeError recorded under this cell (384 channels reaching a
# conv that expects 256) does not follow from the UpBlock defined in cell In [10]:
# there the transposed conv emits 128 channels, which with the 128-channel skip
# gives 256. 384 = 256 + 128 is what a channel-preserving upsampler would produce,
# so the error presumably predates the current definition — re-run to confirm.
model(torch.zeros((2, 2, 40, 128, 128)))

RuntimeError: Given groups=1, weight of size [128, 256, 3, 3, 3], expected input[2, 384, 10, 32, 32] to have 256 channels, but got 384 channels instead

In [13]:
class Iteration(nn.Module):
    """Two ``Conv2d -> BatchNorm2d -> ReLU`` stages followed by dropout (p=0.1).

    The convolutions are 3x3 with stride 1 and padding 1 and carry no bias
    term. The first maps ``input_channels`` to ``output_channels``; the second
    keeps ``output_channels``.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.dropout = nn.Dropout(0.1)

        stages = []
        for stage_in in (input_channels, output_channels):
            stages += [
                nn.Conv2d(stage_in, output_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the conv stack, then dropout."""
        return self.dropout(self.block(x))
        
class U_net(nn.Module):
    """2-D U-Net assembled from ``Iteration`` blocks.

    Args:
        input_channels: channel count of the input tensor.
        output_channels: channel count produced by the final 1x1 convolution.
        features: channel widths of the encoder stages, shallowest first. The
            bottleneck uses twice the last width.

    The forward pass returns the raw output of the last convolution; no
    activation is applied.
    """

    # A tuple default avoids sharing one mutable list between all instances.
    def __init__(self, input_channels=3, output_channels=5, features=(64, 128, 256, 512, 1024)):
        super().__init__()
        self.decoders = nn.ModuleList()
        self.encoders = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: each stage consumes the width produced by the previous one.
        for feature in features:
            self.encoders.append(Iteration(input_channels, feature))
            input_channels = feature

        # Decoder: modules are stored in pairs (upsampling at even indices,
        # conv block at odd indices), deepest stage first.
        for feature in reversed(features):
            self.decoders.append(nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2))
            self.decoders.append(Iteration(feature*2, feature))

        # Bottleneck
        self.bottleneck = Iteration(features[-1], features[-1]*2)

        # Last Conv
        self.last_conv = nn.Conv2d(features[0], output_channels, kernel_size=1)

    def forward(self, x):
        """Encode, apply the bottleneck, decode with skip connections, project.

        Args:
            x: input tensor laid out as (N, C, H, W).

        Returns:
            Output of the final 1x1 convolution.
        """
        # Encoder path: remember each stage's features before pooling.
        skip_connections = []
        for encode in self.encoders:
            x = encode(x)
            skip_connections.append(x)
            x = self.pool(x)
        # Reverse so the deepest skip connection comes first.
        skip_connections = skip_connections[::-1]

        x = self.bottleneck(x)

        # Decoder path: upsample, concatenate with the matching skip, convolve.
        for index in range(0, len(self.decoders), 2):
            x = self.decoders[index](x)
            skip_connection = skip_connections[index//2]

            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip map. Pad (or crop, for a negative difference) to
            # the skip's size so the concatenation does not fail.
            if x.shape[2:] != skip_connection.shape[2:]:
                diff_h = skip_connection.shape[2] - x.shape[2]
                diff_w = skip_connection.shape[3] - x.shape[3]
                # F.pad order: (left, right, top, bottom).
                x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                              diff_h // 2, diff_h - diff_h // 2])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.decoders[index+1](concat_skip)

        return self.last_conv(x)

In [14]:
# Build the 2-D U-Net with its default arguments (3 input channels, 5 output
# channels, encoder widths 64..1024).
slimak = U_net()

In [15]:
# Display the decoder ModuleList: alternating ConvTranspose2d / Iteration pairs,
# deepest stage first.
slimak.decoders

ModuleList(
  (0): ConvTranspose2d(2048, 1024, kernel_size=(2, 2), stride=(2, 2))
  (1): Iteration(
    (dropout): Dropout(p=0.1, inplace=False)
    (block): Sequential(
      (0): Conv2d(2048, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (2): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
  (3): Iteration(
    (dropout): Dropout(p=0.1, inplace=False)
    (block): Sequential(
      (0): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d