In [1]:
import torch
import torch.nn as nn
from torchsummary import summary

![nn](./vnet_architecture.png)

# Encoding convolution

## Equal scale (stride 1 with padding = (kernel_size - 1) / 2 keeps depth, height and width unchanged)

In [2]:
# Dummy input laid out as (batch, channel, depth, height, width).
x = torch.rand((4, 1, 128, 128, 128))

# kernel 3, stride 1, padding 1: depth/height/width stay at 128; channels go 1 -> 16.
conv3d_block = nn.Sequential(
    nn.Conv3d(in_channels=x.shape[1], out_channels=16, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm3d(num_features=16),
    nn.ReLU(),
)

# summary() takes the per-sample shape, so the batch axis is dropped.
summary(conv3d_block, x.shape[1:], device='cpu')
conv3d_block

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1    [-1, 16, 128, 128, 128]             448
       BatchNorm3d-2    [-1, 16, 128, 128, 128]              32
              ReLU-3    [-1, 16, 128, 128, 128]               0
Total params: 480
Trainable params: 480
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 8.00
Forward/backward pass size (MB): 768.00
Params size (MB): 0.00
Estimated Total Size (MB): 776.00
----------------------------------------------------------------


Sequential(
  (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)

In [3]:
# Dummy input laid out as (batch, channel, depth, height, width).
x = torch.rand((4, 1, 128, 128, 128))

# kernel 5, stride 1, padding 2: still size-preserving, with a wider receptive field.
conv3d_block = nn.Sequential(
    nn.Conv3d(in_channels=x.shape[1], out_channels=16, kernel_size=5, stride=1, padding=2),
    nn.BatchNorm3d(num_features=16),
    nn.ReLU(),
)

# summary() takes the per-sample shape, so the batch axis is dropped.
summary(conv3d_block, x.shape[1:], device='cpu')
conv3d_block

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1    [-1, 16, 128, 128, 128]           2,016
       BatchNorm3d-2    [-1, 16, 128, 128, 128]              32
              ReLU-3    [-1, 16, 128, 128, 128]               0
Total params: 2,048
Trainable params: 2,048
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 8.00
Forward/backward pass size (MB): 768.00
Params size (MB): 0.01
Estimated Total Size (MB): 776.01
----------------------------------------------------------------


Sequential(
  (0): Conv3d(1, 16, kernel_size=(5, 5, 5), stride=(1, 1, 1), padding=(2, 2, 2))
  (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)

## Extract (downsampling: a 2×2×2 kernel with stride 2 halves each spatial dimension)

In [4]:
# Dummy input laid out as (batch, channel, depth, height, width).
x = torch.rand((4, 1, 128, 128, 128))

# Non-overlapping 2x2x2 windows (kernel 2, stride 2, no padding): 128 -> 64 per spatial axis.
conv3d_block = nn.Sequential(
    nn.Conv3d(in_channels=x.shape[1], out_channels=16, kernel_size=2, stride=2),
    nn.BatchNorm3d(num_features=16),
    nn.ReLU(),
)

# summary() takes the per-sample shape, so the batch axis is dropped.
summary(conv3d_block, x.shape[1:], device='cpu')
conv3d_block

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 16, 64, 64, 64]             144
       BatchNorm3d-2       [-1, 16, 64, 64, 64]              32
              ReLU-3       [-1, 16, 64, 64, 64]               0
Total params: 176
Trainable params: 176
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 8.00
Forward/backward pass size (MB): 96.00
Params size (MB): 0.00
Estimated Total Size (MB): 104.00
----------------------------------------------------------------


Sequential(
  (0): Conv3d(1, 16, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)

# Decoding convolution (upsampling: a 2×2×2 transposed kernel with stride 2 doubles each spatial dimension)

In [5]:
# Dummy input laid out as (batch, channel, depth, height, width).
x = torch.rand((4, 1, 128, 128, 128))

# Transposed conv with kernel 2, stride 2: 128 -> 256 per spatial axis.
conv3d_block = nn.Sequential(
    nn.ConvTranspose3d(in_channels=x.shape[1], out_channels=16, kernel_size=2, stride=2),
    nn.BatchNorm3d(num_features=16),
    nn.ReLU(),
)

# summary() takes the per-sample shape, so the batch axis is dropped.
summary(conv3d_block, x.shape[1:], device='cpu')
conv3d_block

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose3d-1    [-1, 16, 256, 256, 256]             144
       BatchNorm3d-2    [-1, 16, 256, 256, 256]              32
              ReLU-3    [-1, 16, 256, 256, 256]               0
Total params: 176
Trainable params: 176
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 8.00
Forward/backward pass size (MB): 6144.00
Params size (MB): 0.00
Estimated Total Size (MB): 6152.00
----------------------------------------------------------------


Sequential(
  (0): ConvTranspose3d(1, 16, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)

# Mini V-Net

In [7]:
class MiniVNet(nn.Module):
    """Two-level residual encoder/decoder used to prototype the V-Net wiring.

    Full resolution: one 16-channel conv block plus a residual from the input.
    Half resolution (after a stride-2 Conv3d to 32 channels): two conv blocks
    with a residual around them. Back at full resolution (after a stride-2
    ConvTranspose3d): one conv block with a residual, then a 1x1x1 conv to 3
    output channels. No encoder-to-decoder skip connection is used here.

    Expects a single-channel input (``enc1_1`` takes 1 input channel).
    """

    def __init__(self):
        super().__init__()

        def conv_bn_prelu(in_ch, out_ch, ker_size=2, stride=2, padding=1):
            # Conv3d -> BatchNorm3d -> PReLU, registered as children 0, 1, 2.
            return nn.Sequential(
                nn.Conv3d(in_ch, out_ch, ker_size, stride, padding),
                nn.BatchNorm3d(out_ch),
                nn.PReLU(),
            )

        # Full-resolution encoder stage, then downsample 16 -> 32 channels.
        self.enc1_1 = conv_bn_prelu(1, 16, ker_size=3, stride=1)
        self.down1 = nn.Conv3d(16, 32, kernel_size=2, stride=2)

        # Half-resolution stage, then upsample back to full resolution.
        self.enc2_1 = conv_bn_prelu(32, 32, ker_size=3, stride=1)
        self.enc2_2 = conv_bn_prelu(32, 32, ker_size=3, stride=1)
        self.up2 = nn.ConvTranspose3d(32, 32, kernel_size=2, stride=2)

        # Full-resolution decoder stage and 1x1x1 output projection.
        self.dec1_1 = conv_bn_prelu(32, 32, ker_size=3, stride=1)

        self.conv = nn.Conv3d(32, 3, 1, 1)

    def forward(self, x):
        # Stage 1: the 1-channel input broadcasts over the 16 feature channels.
        level1 = self.enc1_1(x)
        level1 += x

        # Stage 2 (half resolution): residual around two conv blocks.
        down = self.down1(level1)
        level2 = self.enc2_2(self.enc2_1(down))
        level2 += down

        # Decoder (full resolution): residual around one conv block.
        up = self.up2(level2)
        decoded = self.dec1_1(up)
        decoded += up

        return self.conv(decoded)

# Build the prototype and report per-layer output shapes for a
# (batch, channel, depth, height, width) = (4, 1, 64, 128, 128) input.
model = MiniVNet()
x = torch.rand((4, 1, 64, 128, 128))
summary(model, x.shape[1:], device='cpu')  # per-sample shape: batch axis dropped
model

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1     [-1, 16, 64, 128, 128]             448
       BatchNorm3d-2     [-1, 16, 64, 128, 128]              32
             PReLU-3     [-1, 16, 64, 128, 128]               1
            Conv3d-4       [-1, 32, 32, 64, 64]           4,128
            Conv3d-5       [-1, 32, 32, 64, 64]          27,680
       BatchNorm3d-6       [-1, 32, 32, 64, 64]              64
             PReLU-7       [-1, 32, 32, 64, 64]               1
            Conv3d-8       [-1, 32, 32, 64, 64]          27,680
       BatchNorm3d-9       [-1, 32, 32, 64, 64]              64
            PReLU-10       [-1, 32, 32, 64, 64]               1
  ConvTranspose3d-11     [-1, 32, 64, 128, 128]           8,224
           Conv3d-12     [-1, 32, 64, 128, 128]          27,680
      BatchNorm3d-13     [-1, 32, 64, 128, 128]              64
            PReLU-14     [-1, 32, 64, 1

MiniVNet(
  (enc1_1): Sequential(
    (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
  (down1): Conv3d(16, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  (enc2_1): Sequential(
    (0): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
  (enc2_2): Sequential(
    (0): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
  (up2): ConvTranspose3d(32, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  (dec1_1): Sequential(
    (0): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, mome

# V-Net

In [2]:
class VNet(nn.Module):
    """Residual 3-D encoder/decoder (V-Net style) with additive skip connections.

    Encoder stages hold 1, 2, 3, 3 and 3 conv blocks at 16, 32, 64, 128 and 256
    channels. Each stage output is its conv chain plus a residual from the stage
    input, and stages are joined by stride-2 ``Conv3d`` downsampling.

    Each decoder stage:
      1. upsamples the previous decoder output with a stride-2 ``ConvTranspose3d``;
      2. adds the matching encoder output after projecting it with
         ``expand_ch*`` to the same channel count (addition, not concatenation);
      3. runs 3, 3, 2 and 1 residual conv blocks respectively.

    A final 1x1x1 conv produces 3 output channels.

    Expects a single-channel input: the first residual broadcasts it over 16
    channels. Depth, height and width should be divisible by 16; otherwise the
    four down/up-samplings produce mismatched shapes in the skip additions.
    """

    def __init__(self):
        super().__init__()

        def EncodingBlock(in_ch, out_ch, ker_size=2, stride=2, padding=1):
            # Conv3d -> BatchNorm3d -> PReLU. Every call below passes ker_size=3,
            # stride=1, so the default padding=1 keeps the spatial size unchanged.
            layers = []
            layers += [nn.Conv3d(in_ch, out_ch, ker_size, stride, padding)]
            layers += [nn.BatchNorm3d(out_ch)]
            layers += [nn.PReLU()]
            return nn.Sequential(*layers)

        # expand_chK projects encoder stage K to the channel count of decoder
        # stage K; downK halves the resolution and doubles the channels.
        self.enc1_1 = EncodingBlock(in_ch=1, out_ch=16, ker_size=3, stride=1)
        self.expand_ch1 = nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1)
        self.down1 = nn.Conv3d(16, 32, kernel_size=2, stride=2)

        self.enc2_1 = EncodingBlock(in_ch=32, out_ch=32, ker_size=3, stride=1)
        self.enc2_2 = EncodingBlock(in_ch=32, out_ch=32, ker_size=3, stride=1)
        self.expand_ch2 = nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1)
        self.down2 = nn.Conv3d(32, 64, kernel_size=2, stride=2)

        self.enc3_1 = EncodingBlock(in_ch=64, out_ch=64, ker_size=3, stride=1)
        self.enc3_2 = EncodingBlock(in_ch=64, out_ch=64, ker_size=3, stride=1)
        self.enc3_3 = EncodingBlock(in_ch=64, out_ch=64, ker_size=3, stride=1)
        self.expand_ch3 = nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=1)
        self.down3 = nn.Conv3d(64, 128, kernel_size=2, stride=2)

        self.enc4_1 = EncodingBlock(in_ch=128, out_ch=128, ker_size=3, stride=1)
        self.enc4_2 = EncodingBlock(in_ch=128, out_ch=128, ker_size=3, stride=1)
        self.enc4_3 = EncodingBlock(in_ch=128, out_ch=128, ker_size=3, stride=1)
        self.expand_ch4 = nn.Conv3d(128, 256, kernel_size=3, stride=1, padding=1)
        self.down4 = nn.Conv3d(128, 256, kernel_size=2, stride=2)

        self.enc5_1 = EncodingBlock(in_ch=256, out_ch=256, ker_size=3, stride=1)
        self.enc5_2 = EncodingBlock(in_ch=256, out_ch=256, ker_size=3, stride=1)
        self.enc5_3 = EncodingBlock(in_ch=256, out_ch=256, ker_size=3, stride=1)
        self.up5 = nn.ConvTranspose3d(256, 256, kernel_size=2, stride=2)

        self.dec4_1 = EncodingBlock(in_ch=256, out_ch=256, ker_size=3, stride=1)
        self.dec4_2 = EncodingBlock(in_ch=256, out_ch=256, ker_size=3, stride=1)
        self.dec4_3 = EncodingBlock(in_ch=256, out_ch=256, ker_size=3, stride=1)
        self.up4 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)

        self.dec3_1 = EncodingBlock(in_ch=128, out_ch=128, ker_size=3, stride=1)
        self.dec3_2 = EncodingBlock(in_ch=128, out_ch=128, ker_size=3, stride=1)
        self.dec3_3 = EncodingBlock(in_ch=128, out_ch=128, ker_size=3, stride=1)
        self.up3 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)

        self.dec2_1 = EncodingBlock(in_ch=64, out_ch=64, ker_size=3, stride=1)
        self.dec2_2 = EncodingBlock(in_ch=64, out_ch=64, ker_size=3, stride=1)
        self.up2 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)

        self.dec1_1 = EncodingBlock(in_ch=32, out_ch=32, ker_size=3, stride=1)

        self.conv = nn.Conv3d(32, 3, 1, 1)

    def forward(self, x):
        # ---- Encoder ----
        # Stage 1: the single input channel broadcasts over the 16 feature channels.
        enc1_1 = self.enc1_1(x)
        enc1_1 += x

        enc2_res = self.down1(enc1_1)
        enc2_2 = self.enc2_2(self.enc2_1(enc2_res))
        enc2_2 += enc2_res

        enc3_res = self.down2(enc2_2)
        enc3_3 = self.enc3_3(self.enc3_2(self.enc3_1(enc3_res)))
        enc3_3 += enc3_res

        enc4_res = self.down3(enc3_3)
        enc4_3 = self.enc4_3(self.enc4_2(self.enc4_1(enc4_res)))
        enc4_3 += enc4_res

        enc5_res = self.down4(enc4_3)
        enc5_3 = self.enc5_3(self.enc5_2(self.enc5_1(enc5_res)))
        enc5_3 += enc5_res

        # ---- Decoder ----
        # Each stage upsamples the previous *decoder* output (not the encoder
        # skip), so the bottleneck and every decoder stage reach the output.
        # Each conv block of a stage is applied exactly once.
        dec4_res = self.up5(enc5_3)
        dec4 = dec4_res + self.expand_ch4(enc4_3)
        dec4_3 = self.dec4_3(self.dec4_2(self.dec4_1(dec4)))
        dec4_3 += dec4_res

        dec3_res = self.up4(dec4_3)
        dec3 = dec3_res + self.expand_ch3(enc3_3)
        dec3_3 = self.dec3_3(self.dec3_2(self.dec3_1(dec3)))
        dec3_3 += dec3_res

        dec2_res = self.up3(dec3_3)
        dec2 = dec2_res + self.expand_ch2(enc2_2)
        dec2_2 = self.dec2_2(self.dec2_1(dec2))
        dec2_2 += dec2_res

        dec1_res = self.up2(dec2_2)
        dec1 = dec1_res + self.expand_ch1(enc1_1)
        dec1_1 = self.dec1_1(dec1)
        dec1_1 += dec1_res

        return self.conv(dec1_1)

# Build the explicit V-Net and report per-layer output shapes for a
# (batch, channel, depth, height, width) = (4, 1, 64, 128, 128) input.
model = VNet()
x = torch.rand((4, 1, 64, 128, 128))
summary(model, x.shape[1:], device='cpu')  # per-sample shape: batch axis dropped
model

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1     [-1, 16, 64, 128, 128]             448
       BatchNorm3d-2     [-1, 16, 64, 128, 128]              32
             PReLU-3     [-1, 16, 64, 128, 128]               1
            Conv3d-4       [-1, 32, 32, 64, 64]           4,128
            Conv3d-5       [-1, 32, 32, 64, 64]          27,680
       BatchNorm3d-6       [-1, 32, 32, 64, 64]              64
             PReLU-7       [-1, 32, 32, 64, 64]               1
            Conv3d-8       [-1, 32, 32, 64, 64]          27,680
       BatchNorm3d-9       [-1, 32, 32, 64, 64]              64
            PReLU-10       [-1, 32, 32, 64, 64]               1
           Conv3d-11       [-1, 64, 16, 32, 32]          16,448
           Conv3d-12       [-1, 64, 16, 32, 32]         110,656
      BatchNorm3d-13       [-1, 64, 16, 32, 32]             128
            PReLU-14       [-1, 64, 16,

VNet(
  (enc1_1): Sequential(
    (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
  (expand_ch1): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (down1): Conv3d(16, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  (enc2_1): Sequential(
    (0): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
  (enc2_2): Sequential(
    (0): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
  (expand_ch2): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (down2): Conv3d(32, 64, kernel_size=(2, 2,

# Another version (the same V-Net rebuilt from reusable `ConvBlock` / `BigBlock` modules)

In [3]:
class ConvBlock(nn.Module):
    """Conv3d -> BatchNorm3d -> PReLU, stored as the ``net`` Sequential.

    With the defaults (kernel_size=3, stride=1, padding=1) the spatial size is
    preserved and only the channel count changes from ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        stages = [
            nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding),
            nn.BatchNorm3d(out_channels),
            nn.PReLU(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out
    

In [10]:
# One ConvBlock on a small 32^3 volume: channels 1 -> 16, spatial size unchanged.
x = torch.rand(4, 1, 32, 32, 32)
tmp = ConvBlock(in_channels=1, out_channels=16)
summary(tmp, x.shape[1:], device='cpu')  # per-sample shape: batch axis dropped
tmp

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 16, 32, 32, 32]             448
       BatchNorm3d-2       [-1, 16, 32, 32, 32]              32
             PReLU-3       [-1, 16, 32, 32, 32]               1
Total params: 481
Trainable params: 481
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.12
Forward/backward pass size (MB): 12.00
Params size (MB): 0.00
Estimated Total Size (MB): 12.13
----------------------------------------------------------------


ConvBlock(
  (net): Sequential(
    (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): PReLU(num_parameters=1)
  )
)

In [4]:
class BigBlock(nn.Module):
    """Sequential chain of ``depth`` ConvBlocks held in the ``layers`` ModuleList.

    The first block maps ``in_channels`` -> ``out_channels``; every later block
    maps ``out_channels`` -> ``out_channels``.
    """

    def __init__(self, depth, in_channels, out_channels):
        super().__init__()
        # Input channel count for each block: only the first one differs.
        block_inputs = [in_channels if idx == 0 else out_channels for idx in range(depth)]
        self.layers = nn.ModuleList([ConvBlock(c_in, out_channels) for c_in in block_inputs])

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        return out
    

In [11]:
# Three chained ConvBlocks on a small 32^3 volume: 1 -> 16 -> 16 -> 16 channels.
x = torch.rand(4, 1, 32, 32, 32)
tmp = BigBlock(depth=3, in_channels=1, out_channels=16)
summary(tmp, x.shape[1:], device='cpu')  # per-sample shape: batch axis dropped
tmp

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 16, 32, 32, 32]             448
       BatchNorm3d-2       [-1, 16, 32, 32, 32]              32
             PReLU-3       [-1, 16, 32, 32, 32]               1
         ConvBlock-4       [-1, 16, 32, 32, 32]               0
            Conv3d-5       [-1, 16, 32, 32, 32]           6,928
       BatchNorm3d-6       [-1, 16, 32, 32, 32]              32
             PReLU-7       [-1, 16, 32, 32, 32]               1
         ConvBlock-8       [-1, 16, 32, 32, 32]               0
            Conv3d-9       [-1, 16, 32, 32, 32]           6,928
      BatchNorm3d-10       [-1, 16, 32, 32, 32]              32
            PReLU-11       [-1, 16, 32, 32, 32]               1
        ConvBlock-12       [-1, 16, 32, 32, 32]               0
Total params: 14,403
Trainable params: 14,403
Non-trainable params: 0
---------------------------------

BigBlock(
  (layers): ModuleList(
    (0): ConvBlock(
      (net): Sequential(
        (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): PReLU(num_parameters=1)
      )
    )
    (1): ConvBlock(
      (net): Sequential(
        (0): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): PReLU(num_parameters=1)
      )
    )
    (2): ConvBlock(
      (net): Sequential(
        (0): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): PReLU(num_parameters=1)
      )
    )
  )
)

In [5]:
class VNet(nn.Module):
    """Residual 3-D encoder/decoder (V-Net style) assembled from BigBlock stages.

    Encoder stages are BigBlocks of depth 1, 2, 3, 3 and 3 at 16, 32, 64, 128
    and 256 channels. Each stage output is its BigBlock output plus a residual
    from the stage input, and stages are joined by stride-2 ``Conv3d``
    downsampling.

    Each decoder stage:
      1. upsamples the previous decoder output with a stride-2 ``ConvTranspose3d``;
      2. adds the matching encoder output after projecting it with
         ``expand_ch*`` to the same channel count (addition, not concatenation);
      3. runs a residual BigBlock of depth 3, 3, 2 or 1.

    A final 1x1x1 conv produces 3 output channels.

    Expects a single-channel input: the first residual broadcasts it over 16
    channels. Depth, height and width should be divisible by 16; otherwise the
    four down/up-samplings produce mismatched shapes in the skip additions.
    """

    def __init__(self):
        super().__init__()

        # expand_chK projects encoder stage K to the channel count of decoder
        # stage K; downK halves the resolution and doubles the channels.
        self.enc1 = BigBlock(depth=1, in_channels=1, out_channels=16)
        self.expand_ch1 = nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1)
        self.down1 = nn.Conv3d(16, 32, kernel_size=2, stride=2)

        self.enc2 = BigBlock(depth=2, in_channels=32, out_channels=32)
        self.expand_ch2 = nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1)
        self.down2 = nn.Conv3d(32, 64, kernel_size=2, stride=2)

        self.enc3 = BigBlock(depth=3, in_channels=64, out_channels=64)
        self.expand_ch3 = nn.Conv3d(64, 128, kernel_size=3, stride=1, padding=1)
        self.down3 = nn.Conv3d(64, 128, kernel_size=2, stride=2)

        self.enc4 = BigBlock(depth=3, in_channels=128, out_channels=128)
        self.expand_ch4 = nn.Conv3d(128, 256, kernel_size=3, stride=1, padding=1)
        self.down4 = nn.Conv3d(128, 256, kernel_size=2, stride=2)

        self.enc5 = BigBlock(depth=3, in_channels=256, out_channels=256)
        self.up5 = nn.ConvTranspose3d(256, 256, kernel_size=2, stride=2)

        self.dec4 = BigBlock(depth=3, in_channels=256, out_channels=256)
        self.up4 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)

        self.dec3 = BigBlock(depth=3, in_channels=128, out_channels=128)
        self.up3 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)

        self.dec2 = BigBlock(depth=2, in_channels=64, out_channels=64)
        self.up2 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)

        self.dec1 = BigBlock(depth=1, in_channels=32, out_channels=32)

        self.conv = nn.Conv3d(32, 3, 1, 1)

    def forward(self, x):
        # ---- Encoder ----
        # Stage 1: the single input channel broadcasts over the 16 feature channels.
        enc1 = self.enc1(x)
        enc1 += x

        enc2_res = self.down1(enc1)
        enc2 = self.enc2(enc2_res)
        enc2 += enc2_res

        enc3_res = self.down2(enc2)
        enc3 = self.enc3(enc3_res)
        enc3 += enc3_res

        enc4_res = self.down3(enc3)
        enc4 = self.enc4(enc4_res)
        enc4 += enc4_res

        enc5_res = self.down4(enc4)
        enc5 = self.enc5(enc5_res)
        enc5 += enc5_res

        # ---- Decoder ----
        # Each stage upsamples the previous *decoder* output (not the encoder
        # skip), so the bottleneck and every decoder stage reach the output.
        dec4_res = self.up5(enc5)
        dec4 = self.dec4(dec4_res + self.expand_ch4(enc4))
        dec4 += dec4_res

        dec3_res = self.up4(dec4)
        dec3 = self.dec3(dec3_res + self.expand_ch3(enc3))
        dec3 += dec3_res

        dec2_res = self.up3(dec3)
        dec2 = self.dec2(dec2_res + self.expand_ch2(enc2))
        dec2 += dec2_res

        dec1_res = self.up2(dec2)
        dec1 = self.dec1(dec1_res + self.expand_ch1(enc1))
        dec1 += dec1_res

        return self.conv(dec1)

# Build the BigBlock-based V-Net and report per-layer output shapes for a
# (batch, channel, depth, height, width) = (4, 1, 64, 128, 128) input.
model = VNet()
x = torch.rand((4, 1, 64, 128, 128))
summary(model, x.shape[1:], device='cpu')  # per-sample shape: batch axis dropped
model

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1     [-1, 16, 64, 128, 128]             448
       BatchNorm3d-2     [-1, 16, 64, 128, 128]              32
             PReLU-3     [-1, 16, 64, 128, 128]               1
         ConvBlock-4     [-1, 16, 64, 128, 128]               0
          BigBlock-5     [-1, 16, 64, 128, 128]               0
            Conv3d-6       [-1, 32, 32, 64, 64]           4,128
            Conv3d-7       [-1, 32, 32, 64, 64]          27,680
       BatchNorm3d-8       [-1, 32, 32, 64, 64]              64
             PReLU-9       [-1, 32, 32, 64, 64]               1
        ConvBlock-10       [-1, 32, 32, 64, 64]               0
           Conv3d-11       [-1, 32, 32, 64, 64]          27,680
      BatchNorm3d-12       [-1, 32, 32, 64, 64]              64
            PReLU-13       [-1, 32, 32, 64, 64]               1
        ConvBlock-14       [-1, 32, 32,

VNet(
  (enc1): BigBlock(
    (layers): ModuleList(
      (0): ConvBlock(
        (net): Sequential(
          (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): PReLU(num_parameters=1)
        )
      )
    )
  )
  (expand_ch1): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (down1): Conv3d(16, 32, kernel_size=(2, 2, 2), stride=(2, 2, 2))
  (enc2): BigBlock(
    (layers): ModuleList(
      (0): ConvBlock(
        (net): Sequential(
          (0): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): PReLU(num_parameters=1)
        )
      )
      (1): ConvBlock(
        (net): Sequential(
          (0): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1):