In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn import init
from torchinfo import summary as summary

In [5]:
def weights_init_normal(m):
    """Initialise a single module in place from a normal distribution.

    Meant to be used as the callback of ``nn.Module.apply``. The layer kind is
    recognised by a substring of the module's class name:

    * ``Conv`` or ``Linear``: weight ~ N(0.0, 0.02)
    * ``BatchNorm``: weight ~ N(1.0, 0.02), bias set to 0

    Biases of conv / linear layers are left untouched.

    Modules that carry no weight tensor are skipped. This covers container
    classes whose name merely contains ``Conv`` (e.g. ``unetConv2``) and
    normalisation layers created with ``affine=False``; both would otherwise
    fail with ``AttributeError`` on ``m.weight.data``.

    Args:
        m: the module to initialise.
    """
    weight = getattr(m, 'weight', None)
    if weight is None:
        return

    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            init.constant_(bias.data, 0.0)


def weights_init_xavier(m):
    """Initialise a single module in place with Xavier (Glorot) normal weights.

    Meant to be used as the callback of ``nn.Module.apply``. The layer kind is
    recognised by a substring of the module's class name:

    * ``Conv`` or ``Linear``: ``init.xavier_normal_`` with ``gain=1``
    * ``BatchNorm``: weight ~ N(1.0, 0.02), bias set to 0

    Biases of conv / linear layers are left untouched.

    Modules that carry no weight tensor are skipped. This covers container
    classes whose name merely contains ``Conv`` (e.g. ``unetConv2``) and
    normalisation layers created with ``affine=False``; both would otherwise
    fail with ``AttributeError`` on ``m.weight.data``.

    Args:
        m: the module to initialise.
    """
    weight = getattr(m, 'weight', None)
    if weight is None:
        return

    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.xavier_normal_(weight.data, gain=1)
    elif 'BatchNorm' in classname:
        init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            init.constant_(bias.data, 0.0)


def weights_init_kaiming(m):
    """Initialise a single module in place with Kaiming (He) normal weights.

    Meant to be used as the callback of ``nn.Module.apply``. The layer kind is
    recognised by a substring of the module's class name:

    * ``Conv`` or ``Linear``: ``init.kaiming_normal_`` with ``a=0``,
      ``mode='fan_in'``
    * ``BatchNorm``: weight ~ N(1.0, 0.02), bias set to 0

    Biases of conv / linear layers are left untouched.

    Modules that carry no weight tensor are skipped. This covers container
    classes whose name merely contains ``Conv`` (e.g. ``unetConv2``) and
    normalisation layers created with ``affine=False``; both would otherwise
    fail with ``AttributeError`` on ``m.weight.data``.

    Args:
        m: the module to initialise.
    """
    weight = getattr(m, 'weight', None)
    if weight is None:
        return

    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.kaiming_normal_(weight.data, a=0, mode='fan_in')
    elif 'BatchNorm' in classname:
        init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            init.constant_(bias.data, 0.0)


def weights_init_orthogonal(m):
    """Initialise a single module in place with orthogonal weights.

    Meant to be used as the callback of ``nn.Module.apply``. The layer kind is
    recognised by a substring of the module's class name:

    * ``Conv`` or ``Linear``: ``init.orthogonal_`` with ``gain=1``
    * ``BatchNorm``: weight ~ N(1.0, 0.02), bias set to 0

    Biases of conv / linear layers are left untouched.

    Modules that carry no weight tensor are skipped. This covers container
    classes whose name merely contains ``Conv`` (e.g. ``unetConv2``) and
    normalisation layers created with ``affine=False``; both would otherwise
    fail with ``AttributeError`` on ``m.weight.data``.

    Args:
        m: the module to initialise.
    """
    weight = getattr(m, 'weight', None)
    if weight is None:
        return

    classname = m.__class__.__name__
    if 'Conv' in classname or 'Linear' in classname:
        init.orthogonal_(weight.data, gain=1)
    elif 'BatchNorm' in classname:
        init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            init.constant_(bias.data, 0.0)


def init_weights(net, init_type='normal'):
    """Run the named weight-initialisation scheme over ``net`` and its submodules.

    The matching ``weights_init_*`` callback is handed to ``net.apply``, which
    visits ``net`` itself and every descendant module.

    Args:
        net: module to initialise in place.
        init_type: one of ``'normal'``, ``'xavier'``, ``'kaiming'`` or
            ``'orthogonal'``.

    Raises:
        NotImplementedError: if ``init_type`` is none of the names above.
    """
    if init_type == 'normal':
        initializer = weights_init_normal
    elif init_type == 'xavier':
        initializer = weights_init_xavier
    elif init_type == 'kaiming':
        initializer = weights_init_kaiming
    elif init_type == 'orthogonal':
        initializer = weights_init_orthogonal
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)

    net.apply(initializer)

In [6]:
class unetConv2(nn.Module):
    """Chain of ``n`` 2-D convolution stages: Conv2d -> (BatchNorm2d) -> ReLU.

    Each stage is an ``nn.Sequential`` registered as ``conv1`` ... ``conv<n>``.
    The first stage maps ``in_size`` to ``out_size`` channels; the remaining
    stages map ``out_size`` to ``out_size``. All stages are re-initialised with
    the ``'kaiming'`` scheme of ``init_weights`` at construction time.

    Args:
        in_size: input channel count of the first convolution.
        out_size: output channel count of every convolution.
        is_batchnorm: insert a ``BatchNorm2d`` between conv and ReLU when truthy.
        n: number of stages.
        ks: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
    """
    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
        super().__init__()
        self.n, self.ks = n, ks
        self.stride, self.padding = stride, padding

        channels_in = in_size
        for idx in range(1, n + 1):
            layers = [nn.Conv2d(channels_in, out_size, ks, stride, padding)]
            if is_batchnorm:
                layers.append(nn.BatchNorm2d(out_size))
            layers.append(nn.ReLU(inplace=True))
            setattr(self, 'conv%d' % idx, nn.Sequential(*layers))
            # every stage after the first consumes the previous stage's output
            channels_in = out_size

        # initialise the freshly registered stages
        for child in self.children():
            init_weights(child, init_type='kaiming')

    def forward(self, inputs):
        """Feed ``inputs`` through ``conv1`` ... ``conv<n>`` in order and return the result."""
        out = inputs
        for idx in range(1, self.n + 1):
            out = getattr(self, 'conv%d' % idx)(out)
        return out
    
    
class unetConv3(nn.Module):
    """Chain of ``n`` 3-D convolution stages: Conv3d -> (BatchNorm3d) -> ReLU.

    Each stage is an ``nn.Sequential`` registered as ``conv1`` ... ``conv<n>``.
    The first stage maps ``in_size`` to ``out_size`` channels; the remaining
    stages map ``out_size`` to ``out_size``. All stages are re-initialised with
    the ``'kaiming'`` scheme of ``init_weights`` at construction time.

    Args:
        in_size: input channel count of the first convolution.
        out_size: output channel count of every convolution.
        is_batchnorm: insert a ``BatchNorm3d`` between conv and ReLU when truthy.
        n: number of stages.
        ks: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
    """
    def __init__(self, in_size, out_size, is_batchnorm, n=2, ks=3, stride=1, padding=1):
        super().__init__()
        self.n, self.ks = n, ks
        self.stride, self.padding = stride, padding

        channels_in = in_size
        for idx in range(1, n + 1):
            layers = [nn.Conv3d(channels_in, out_size, ks, stride, padding)]
            if is_batchnorm:
                layers.append(nn.BatchNorm3d(out_size))
            layers.append(nn.ReLU(inplace=True))
            setattr(self, 'conv%d' % idx, nn.Sequential(*layers))
            # every stage after the first consumes the previous stage's output
            channels_in = out_size

        # initialise the freshly registered stages
        for child in self.children():
            init_weights(child, init_type='kaiming')

    def forward(self, inputs):
        """Feed ``inputs`` through ``conv1`` ... ``conv<n>`` in order and return the result."""
        out = inputs
        for idx in range(1, self.n + 1):
            out = getattr(self, 'conv%d' % idx)(out)
        return out
    
    
class BlockPT(nn.Module):
    '''
    Down-sampling skip branch: MaxPool -> Conv -> BatchNorm -> ReLU.

    spatial_dims   = 2 or 3 (selects the 2-D or 3-D layer classes)
    pooling_params = [kernel_size (also used as stride), is_ceil]; ex) 8, True
    conv_params    = [in_channels, out_channels, kernel_size, padding, is_inplace]; ex) 64, 64*5, 3, 1, True

    The three layers are exposed as ``PT``, ``PT_conv`` and ``PT_norm`` and are
    also chained, followed by the ReLU, in ``PT_Layer``.
    '''
    def __init__(self, spatial_dims: int, conv_params: list, pooling_params: list):
        super().__init__()
        assert spatial_dims in (2, 3), 'invalid dimension input'

        in_channels = conv_params[0]
        out_channels = conv_params[1]
        pool_size = pooling_params[0]
        use_ceil = pooling_params[1]
        kernel_size = conv_params[2]
        padding = conv_params[3]
        is_inplace = conv_params[4]

        # pick the layer classes matching the requested dimensionality
        if spatial_dims == 2:
            pool_cls, conv_cls, norm_cls = nn.MaxPool2d, nn.Conv2d, nn.BatchNorm2d
        elif spatial_dims == 3:
            pool_cls, conv_cls, norm_cls = nn.MaxPool3d, nn.Conv3d, nn.BatchNorm3d

        self.PT = pool_cls(pool_size, pool_size, ceil_mode=use_ceil)
        self.PT_conv = conv_cls(in_channels, out_channels, kernel_size, padding=padding)
        self.PT_norm = norm_cls(out_channels)

        activation = nn.ReLU(inplace=is_inplace)
        self.PT_Layer = nn.Sequential(self.PT, self.PT_conv, self.PT_norm, activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Pool, convolve, normalise and activate ``x``.'''
        return self.PT_Layer(x)
        

class BlockUT(nn.Module):
    '''
    Up-sampling skip branch: Conv -> BatchNorm -> Upsample -> ReLU.

    spatial_dims = 2 or 3 (2 uses 'bilinear' up-sampling, 3 uses 'trilinear')
    scale_params = [scaling_factor]
    conv_params  = [in_channels (filter-map size or UpChannels), out_channels, kernel_size, padding, is_inplace]

    ``is_inplace`` is forwarded to the final ReLU, consistent with BlockPT,
    BlockCAT and BlockFUSION (it used to be read and then ignored in favour of
    a hard-coded ``inplace=True``).

    The conv and norm layers are exposed as ``UT_conv`` and ``UT_norm`` and are
    also chained, followed by the up-sampling and the ReLU, in ``UT_Layer``.
    '''
    def __init__(self, spatial_dims: int, conv_params: list, scale_params: list):
        super(BlockUT, self).__init__()
        assert spatial_dims == 2 or spatial_dims == 3, 'invalid dimension input'
        mode = 'bilinear' if spatial_dims == 2 else 'trilinear'

        scale_factor = scale_params[0]
        filter_size, concat_channels = conv_params[0], conv_params[1]
        kernel_size, padding, is_inplace = conv_params[2], conv_params[3], conv_params[4]

        if spatial_dims == 2:
            self.UT_conv = nn.Conv2d(filter_size, concat_channels, kernel_size, padding=padding)
            self.UT_norm = nn.BatchNorm2d(concat_channels)

        elif spatial_dims == 3:
            self.UT_conv = nn.Conv3d(filter_size, concat_channels, kernel_size, padding=padding)
            self.UT_norm = nn.BatchNorm3d(concat_channels)

        # The convolution runs before the up-sampling, i.e. on the smaller feature map.
        self.UT_Layer = nn.Sequential(self.UT_conv, self.UT_norm,
                                      nn.Upsample(scale_factor=scale_factor, mode=mode),
                                      nn.ReLU(inplace=is_inplace))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Convolve, normalise, up-sample and activate ``x``.'''
        return self.UT_Layer(x)
        
    
class BlockCAT(nn.Module):
    '''
    Same-scale skip branch: Conv -> BatchNorm -> ReLU, with no resampling.

    spatial_dims = 2 or 3 (selects the 2-D or 3-D layer classes)
    conv_params  = [in_channels, out_channels, kernel_size, padding, is_inplace]; ex) 64, 64*5, 3, 1, True

    The conv and norm layers are exposed as ``CAT_conv`` and ``CAT_norm`` and
    are also chained, followed by the ReLU, in ``CAT_Layer``.
    '''
    def __init__(self, spatial_dims: int, conv_params: list):
        super().__init__()
        assert spatial_dims in (2, 3), 'invalid dimension input'

        in_channels = conv_params[0]
        out_channels = conv_params[1]
        kernel_size = conv_params[2]
        padding = conv_params[3]
        is_inplace = conv_params[4]

        # pick the layer classes matching the requested dimensionality
        if spatial_dims == 2:
            conv_cls, norm_cls = nn.Conv2d, nn.BatchNorm2d
        elif spatial_dims == 3:
            conv_cls, norm_cls = nn.Conv3d, nn.BatchNorm3d

        self.CAT_conv = conv_cls(in_channels, out_channels, kernel_size, padding=padding)
        self.CAT_norm = norm_cls(out_channels)

        activation = nn.ReLU(inplace=is_inplace)
        self.CAT_Layer = nn.Sequential(self.CAT_conv, self.CAT_norm, activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Convolve, normalise and activate ``x``.'''
        return self.CAT_Layer(x)
            
        
class BlockFUSION(nn.Module):
    '''
    Fusion stage applied to concatenated skip branches: Conv -> BatchNorm -> ReLU.

    spatial_dims = 2 or 3 (selects the 2-D or 3-D layer classes)
    conv_params  = [in_channels, out_channels, kernel_size, padding, is_inplace]

    The conv and norm layers are exposed as ``FUSION_conv`` and ``FUSION_norm``
    and are also chained, followed by the ReLU, in ``FUSION_Layer``.
    '''
    def __init__(self, spatial_dims: int, conv_params: list):
        super().__init__()
        assert spatial_dims in (2, 3), 'invalid dimension input'

        in_channels = conv_params[0]
        out_channels = conv_params[1]
        kernel_size = conv_params[2]
        padding = conv_params[3]
        is_inplace = conv_params[4]

        # pick the layer classes matching the requested dimensionality
        if spatial_dims == 2:
            conv_cls, norm_cls = nn.Conv2d, nn.BatchNorm2d
        elif spatial_dims == 3:
            conv_cls, norm_cls = nn.Conv3d, nn.BatchNorm3d

        self.FUSION_conv = conv_cls(in_channels, out_channels, kernel_size, padding=padding)
        self.FUSION_norm = norm_cls(out_channels)

        activation = nn.ReLU(inplace=is_inplace)
        self.FUSION_Layer = nn.Sequential(self.FUSION_conv, self.FUSION_norm, activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Fuse ``x``; the caller is expected to pass the five branch tensors already joined with torch.cat along axis 1.'''
        return self.FUSION_Layer(x)

In [7]:
class CustomUNet3Plus(nn.Module):
    """UNet3+-style (per the class name) encoder/decoder for 2-D or 3-D inputs.

    The encoder is five ``unetConv2``/``unetConv3`` stages with channel widths
    ``self.filters`` separated by four kernel-2 max-pools. Each of the four
    decoder stages builds five branches of ``CatChannels`` channels each (one
    per scale: pooled, same-scale or up-sampled), concatenates them along the
    channel axis and passes the result through ``self.decoder``.

    Args:
        in_channels: channel count of the input tensor (first encoder stage).
        n_classes: output channel count of the final convolution.
        spatial_dims: 2 or 3; selects the 2-D or 3-D layer classes. For any
            other value no encoder is built and the first decoder block's
            assertion fails.
        feature_scale: accepted but unused (its assignment is commented out).
        is_deconv: stored on the instance; not read anywhere else in this class.
        is_batchnorm: forwarded to the encoder conv stages. The decoder blocks
            always contain batch normalisation.
        is_sigmoid: when equal to ``True``, ``forward`` returns
            ``torch.sigmoid`` of the final convolution, otherwise its raw output.

    NOTE(review): a single ``BlockFUSION`` instance (``self.decoder``) is
    called at all four decoder stages, so its weights are shared between
    stages — confirm this is intentional rather than one fusion block per stage.

    NOTE(review): the encoder max-pools are built without ``ceil_mode`` while
    the ``BlockPT`` pools are built with ``is_ceil=True``; presumably input
    spatial sizes should be multiples of 16 so that all branches line up for
    ``torch.cat`` — TODO confirm.
    """
    def __init__(self, in_channels: int=1, n_classes: int=1, spatial_dims: int=3, feature_scale: int=4, is_deconv: bool=True, is_batchnorm: bool=True, is_sigmoid: bool=True):
        super(CustomUNet3Plus, self).__init__()
        self.filters = [64, 128, 256, 512, 1024] # channel width of encoder stages 1..5
        self.is_sigmoid = is_sigmoid    # boolean: apply sigmoid to the output or not

        self.CatBlocks = 5                                      # branches concatenated per decoder stage
        self.CatChannels = self.filters[0]                      # channels produced by each branch (64)
        self.UpChannels = self.CatChannels * self.CatBlocks     # channels after concatenation (320)

        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.spatial_dims = spatial_dims
        # self.feature_scale = feature_scale

        ### Encoder ###
        if self.spatial_dims==2: 
            self.conv1 = unetConv2(self.in_channels, self.filters[0], self.is_batchnorm)
            self.conv2 = unetConv2(self.filters[0], self.filters[1], self.is_batchnorm)
            self.conv3 = unetConv2(self.filters[1], self.filters[2], self.is_batchnorm)
            self.conv4 = unetConv2(self.filters[2], self.filters[3], self.is_batchnorm)
            self.conv5 = unetConv2(self.filters[3], self.filters[4], self.is_batchnorm)   # deepest stage (bottleneck)
            
            self.maxpool1 = nn.MaxPool2d(kernel_size=2)
            self.maxpool2 = nn.MaxPool2d(kernel_size=2)
            self.maxpool3 = nn.MaxPool2d(kernel_size=2)
            self.maxpool4 = nn.MaxPool2d(kernel_size=2)
            # final projection from the fused decoder features to n_classes channels
            self.output = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
                
        elif self.spatial_dims==3:
            self.conv1 = unetConv3(self.in_channels, self.filters[0], self.is_batchnorm)
            self.conv2 = unetConv3(self.filters[0], self.filters[1], self.is_batchnorm)
            self.conv3 = unetConv3(self.filters[1], self.filters[2], self.is_batchnorm)
            self.conv4 = unetConv3(self.filters[2], self.filters[3], self.is_batchnorm)
            self.conv5 = unetConv3(self.filters[3], self.filters[4], self.is_batchnorm)   # deepest stage (bottleneck)
            
            self.maxpool1 = nn.MaxPool3d(kernel_size=2)
            self.maxpool2 = nn.MaxPool3d(kernel_size=2)
            self.maxpool3 = nn.MaxPool3d(kernel_size=2)
            self.maxpool4 = nn.MaxPool3d(kernel_size=2)
            # final projection from the fused decoder features to n_classes channels
            self.output = nn.Conv3d(self.UpChannels, n_classes, 3, padding=1)
       
        ### decoder ###
        # Naming: indice<source>_<stage>. <source> is the scale the branch reads
        # from (1 = x1 ... 5 = x5_dec); <stage> counts decoder stages, where
        # stage 1 produces x4_dec, stage 2 x3_dec, stage 3 x2_dec, stage 4 x1_dec.
        # Branches fed by an encoder output take self.filters[k] channels;
        # branches fed by an earlier decoder output take self.UpChannels.

        # stage 1 -> x4_dec: x1 pooled by 8, x2 by 4, x3 by 2, x4 as is, x5 up-sampled by 2
        self.indice1_1 = BlockPT(spatial_dims=self.spatial_dims, conv_params=[self.filters[0], self.CatChannels, 3, 1, True], pooling_params=[8, True])
        self.indice2_1 = BlockPT(spatial_dims=self.spatial_dims, conv_params=[self.filters[1], self.CatChannels, 3, 1, True], pooling_params=[4, True])
        self.indice3_1 = BlockPT(spatial_dims=self.spatial_dims, conv_params=[self.filters[2], self.CatChannels, 3, 1, True], pooling_params=[2, True])
        self.indice4_1 = BlockCAT(spatial_dims=self.spatial_dims, conv_params=[self.filters[3], self.CatChannels, 3, 1, True])
        self.indice5_1 = BlockUT(spatial_dims=self.spatial_dims, scale_params=[2], conv_params=[self.filters[4], self.CatChannels, 3, 1, True])

        # stage 2 -> x3_dec: x1 pooled by 4, x2 by 2, x3 as is, x4_dec up-sampled by 2, x5 by 4
        self.indice1_2 = BlockPT(spatial_dims=self.spatial_dims, conv_params=[self.filters[0], self.CatChannels, 3, 1, True], pooling_params=[4, True]) 
        self.indice2_2 = BlockPT(spatial_dims=self.spatial_dims, conv_params=[self.filters[1], self.CatChannels, 3, 1, True], pooling_params=[2, True])
        self.indice3_2 = BlockCAT(spatial_dims=self.spatial_dims, conv_params=[self.filters[2], self.CatChannels, 3, 1, True])
        self.indice4_2 = BlockUT(spatial_dims=self.spatial_dims, scale_params=[2], conv_params=[self.UpChannels, self.CatChannels, 3, 1, True])
        self.indice5_2 = BlockUT(spatial_dims=self.spatial_dims, scale_params=[4], conv_params=[self.filters[4], self.CatChannels, 3, 1, True])

        # stage 3 -> x2_dec: x1 pooled by 2, x2 as is, x3_dec up-sampled by 2, x4_dec by 4, x5 by 8
        self.indice1_3 = BlockPT(spatial_dims=self.spatial_dims, conv_params=[self.filters[0], self.CatChannels, 3, 1, True], pooling_params=[2, True]) 
        self.indice2_3 = BlockCAT(spatial_dims=self.spatial_dims, conv_params=[self.filters[1], self.CatChannels, 3, 1, True])
        self.indice3_3 = BlockUT(spatial_dims=self.spatial_dims, scale_params=[2], conv_params=[self.UpChannels, self.CatChannels, 3, 1, True])
        self.indice4_3 = BlockUT(spatial_dims=self.spatial_dims, scale_params=[4], conv_params=[self.UpChannels, self.CatChannels, 3, 1, True])
        self.indice5_3 = BlockUT(spatial_dims=self.spatial_dims, scale_params=[8], conv_params=[self.filters[4], self.CatChannels, 3, 1, True])
        
        # stage 4 -> x1_dec: x1 as is, x2_dec up-sampled by 2, x3_dec by 4, x4_dec by 8, x5 by 16
        self.indice1_4 = BlockCAT(spatial_dims=self.spatial_dims, conv_params=[self.filters[0], self.CatChannels, 3, 1, True])
        self.indice2_4 = BlockUT(spatial_dims=self.spatial_dims, scale_params=[2], conv_params=[self.UpChannels, self.CatChannels, 3, 1, True])
        self.indice3_4 = BlockUT(spatial_dims=self.spatial_dims, scale_params=[4], conv_params=[self.UpChannels, self.CatChannels, 3, 1, True])
        self.indice4_4 = BlockUT(spatial_dims=self.spatial_dims, scale_params=[8], conv_params=[self.UpChannels, self.CatChannels, 3, 1, True])
        self.indice5_4 = BlockUT(spatial_dims=self.spatial_dims, scale_params=[16], conv_params=[self.filters[4], self.CatChannels, 3, 1, True])
        
        # one fusion block (UpChannels -> UpChannels), reused by every decoder stage in forward()
        self.decoder = BlockFUSION(spatial_dims=self.spatial_dims, conv_params=[self.UpChannels, self.UpChannels, 3, 1, True])

        # initialise weights: apply the 'kaiming' scheme to every conv and batch-norm layer in the network
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv3d): init_weights(m, init_type='kaiming')
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm3d): init_weights(m, init_type='kaiming')

                
    def forward(self, x):
        """Run the encoder and the four decoder stages; return the final convolution's output, passed through sigmoid when ``is_sigmoid == True``."""
        ### enc ###
        # x1..x4 are the un-pooled stage outputs used as skip inputs;
        # x*_out are their pooled versions fed to the next stage.
        x1 = self.conv1(x)
        x1_out = self.maxpool1(x1)
        
        x2 = self.conv2(x1_out)
        x2_out = self.maxpool2(x2)
        
        x3 = self.conv3(x2_out)
        x3_out = self.maxpool3(x3)
        
        x4 = self.conv4(x3_out)
        x4_out = self.maxpool4(x4)
        
        x5_dec = self.conv5(x4_out)        
        
        ### dec: x<level>_<source> is the branch built from scale <source> for decoder level <level> ###
        x4_1 = self.indice1_1(x1)
        x4_2 = self.indice2_1(x2)
        x4_3 = self.indice3_1(x3)
        x4_4 = self.indice4_1(x4)
        x4_5 = self.indice5_1(x5_dec)
        # concatenate the five branches along the channel axis, then fuse
        x4_dec = self.decoder(torch.cat((x4_1, x4_2, x4_3, x4_4, x4_5), 1))

        x3_1 = self.indice1_2(x1)
        x3_2 = self.indice2_2(x2)
        x3_3 = self.indice3_2(x3)        
        x3_4 = self.indice4_2(x4_dec)
        x3_5 = self.indice5_2(x5_dec)
        x3_dec = self.decoder(torch.cat((x3_1, x3_2, x3_3, x3_4, x3_5), 1))
        
        x2_1 = self.indice1_3(x1)                
        x2_2 = self.indice2_3(x2)                
        x2_3 = self.indice3_3(x3_dec)                
        x2_4 = self.indice4_3(x4_dec)                
        x2_5 = self.indice5_3(x5_dec)                
        x2_dec = self.decoder(torch.cat((x2_1, x2_2, x2_3, x2_4, x2_5), 1))
        
        x1_1 = self.indice1_4(x1)
        x1_2 = self.indice2_4(x2_dec)
        x1_3 = self.indice3_4(x3_dec)
        x1_4 = self.indice4_4(x4_dec)
        x1_5 = self.indice5_4(x5_dec)
        x1_dec = self.decoder(torch.cat((x1_1, x1_2, x1_3, x1_4, x1_5), 1))
        
        x_out = self.output(x1_dec)
        
        if self.is_sigmoid == True: return torch.sigmoid(x_out)
        else: return x_out

In [10]:
# Smoke test: build the network for a small dummy input and print its torchinfo summary.

GPU_INDEX = 6  # preferred CUDA device index
if torch.cuda.is_available():
    # Only use the preferred GPU when it exists on this host; otherwise fall
    # back to GPU 0 instead of naming a device that fails on first use.
    gpu_id = GPU_INDEX if GPU_INDEX < torch.cuda.device_count() else 0
    device = torch.device("cuda:%d" % gpu_id)
else:
    device = torch.device("cpu")
# NOTE(review): `device` is not passed to the model or to summary() in this cell.

DIM = 2       # spatial dimensionality of the dummy input (2 or 3)
HWD = 32      # edge length used for every spatial axis
BATCH = 2
CHANNEL = 1

if DIM == 3: input_size = (BATCH, CHANNEL, HWD, HWD, HWD)
elif DIM == 2: input_size = (BATCH, CHANNEL, HWD, HWD)
else: raise NotImplementedError

model = CustomUNet3Plus(in_channels=CHANNEL, n_classes=1, spatial_dims=DIM, feature_scale=4, is_deconv=True, is_batchnorm=True)
summary(model, input_size=input_size)


Layer (type:depth-idx)                   Output Shape              Param #
CustomUNet3Plus                          [2, 1, 32, 32]            --
├─unetConv2: 1-1                         [2, 64, 32, 32]           --
│    └─Sequential: 2-1                   [2, 64, 32, 32]           --
│    │    └─Conv2d: 3-1                  [2, 64, 32, 32]           640
│    │    └─BatchNorm2d: 3-2             [2, 64, 32, 32]           128
│    │    └─ReLU: 3-3                    [2, 64, 32, 32]           --
│    └─Sequential: 2-2                   [2, 64, 32, 32]           --
│    │    └─Conv2d: 3-4                  [2, 64, 32, 32]           36,928
│    │    └─BatchNorm2d: 3-5             [2, 64, 32, 32]           128
│    │    └─ReLU: 3-6                    [2, 64, 32, 32]           --
├─MaxPool2d: 1-2                         [2, 64, 16, 16]           --
├─unetConv2: 1-3                         [2, 128, 16, 16]          --
│    └─Sequential: 2-3                   [2, 128, 16, 16]          --
│    │  