In [None]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchsummary import summary

## Model

In [None]:
# https://github.com/miguelvr/dropblock/blob/master/dropblock/dropblock.py
class DropBlock2D(nn.Module):
    """DropBlock regularisation for 4-D feature maps (adapted from miguelvr/dropblock).

    In training mode, seed positions are sampled independently with rate
    ``gamma``. Each seed is dilated into a ``block_size x block_size`` square,
    and that square is zeroed in every channel. The surviving activations are
    then multiplied by ``total_positions / kept_positions``. In eval mode, or
    when ``drop_prob == 0``, the input tensor itself is returned.

    Args:
        drop_prob: drop probability used to derive the seed rate ``gamma``.
        block_size: side length of each dropped square.
    """
    def __init__(self, drop_prob, block_size):
        super(DropBlock2D, self).__init__()
        self.drop_prob = drop_prob
        self.block_size = block_size

    def forward(self, x):
        """Apply DropBlock to ``x`` of shape (bsize, channels, height, width).

        Returns a tensor with the same shape and dtype as ``x``.
        """
        assert x.dim() == 4, \
            "Expected input with 4 dimensions (bsize, channels, height, width)"
        if not self.training or self.drop_prob == 0.:
            return x
        gamma = self._compute_gamma(x)
        # Sample seeds directly on x's device (no per-step host->device copy).
        # The mask stays float32 so max-pooling works for any input dtype.
        mask = (torch.rand(x.shape[0], *x.shape[2:], device=x.device) < gamma).float()
        block_mask = self._compute_block_mask(mask)
        # One spatial mask is shared by all channels. Cast it to x.dtype so
        # half/bfloat16 inputs are not promoted to float32 by the multiply.
        out = x * block_mask[:, None, :, :].to(x.dtype)
        # Rescale by total / kept. clamp(min=1) prevents 0/0 -> NaN when every
        # position was dropped (the output is all zeros in that case anyway).
        scale = block_mask.numel() / block_mask.sum().clamp(min=1.0)
        return out * scale

    def _compute_block_mask(self, mask):
        """Turn a (bsize, H, W) 0/1 seed mask into a keep-mask of the same shape.

        Max-pooling with stride 1 dilates every seed into a block. The result
        is inverted so dropped positions are 0 and kept positions are 1.
        """
        block_mask = F.max_pool2d(input=mask[:, None, :, :],
                                  kernel_size=(self.block_size, self.block_size),
                                  stride=(1, 1),
                                  padding=self.block_size // 2)

        # With an even kernel and padding block_size // 2, the pooled map is
        # one pixel larger in each spatial dim; trim it back to H x W.
        if self.block_size % 2 == 0:
            block_mask = block_mask[:, :, :-1, :-1]

        block_mask = 1 - block_mask.squeeze(1)

        return block_mask

    def _compute_gamma(self, x):
        """Seed rate: ``drop_prob`` spread over the block area (``x`` is unused)."""
        return self.drop_prob / (self.block_size ** 2)
class DropBlock3D(DropBlock2D):
    """DropBlock for 5-D volumes: zeros ``block_size``^3 cubes across all channels.

    Reuses DropBlock2D's constructor and attributes. It overrides the seed
    rate, the mask construction and the forward pass for input shaped
    (bsize, channels, depth, height, width).
    """
    def __init__(self, drop_prob, block_size):
        super(DropBlock3D, self).__init__(drop_prob, block_size)

    def forward(self, x):
        """Apply DropBlock to a 5-D ``x``; returns same shape and dtype as ``x``."""
        assert x.dim() == 5, \
            "Expected input with 5 dimensions (bsize, channels, depth, height, width)"
        if not self.training or self.drop_prob == 0.:
            return x
        gamma = self._compute_gamma(x)
        # Seeds are sampled on x's device; kept float32 so max-pooling works
        # for any input dtype.
        mask = (torch.rand(x.shape[0], *x.shape[2:], device=x.device) < gamma).float()
        block_mask = self._compute_block_mask(mask)
        # One volumetric mask shared by all channels, cast to x.dtype so
        # half/bfloat16 activations are not promoted to float32.
        out = x * block_mask[:, None, :, :, :].to(x.dtype)
        # Rescale by total / kept. clamp(min=1) avoids 0/0 -> NaN when every
        # voxel was dropped (output is all zeros in that case).
        scale = block_mask.numel() / block_mask.sum().clamp(min=1.0)
        return out * scale

    def _compute_block_mask(self, mask):
        """Dilate a (bsize, D, H, W) 0/1 seed mask into cubes and invert it to a keep-mask."""
        block_mask = F.max_pool3d(input=mask[:, None, :, :, :],
                                  kernel_size=(self.block_size, self.block_size, self.block_size),
                                  stride=(1, 1, 1),
                                  padding=self.block_size // 2)
        # Even kernels with padding block_size // 2 overshoot each dim by one.
        if self.block_size % 2 == 0:
            block_mask = block_mask[:, :, :-1, :-1, :-1]
        block_mask = 1 - block_mask.squeeze(1)
        return block_mask

    def _compute_gamma(self, x):
        """Seed rate for cubes: ``drop_prob`` spread over ``block_size**3`` voxels.

        Overrides the inherited 2-D formula (``block_size**2``) so every
        caller of ``_compute_gamma`` on a 3-D instance gets the 3-D value.
        """
        return self.drop_prob / (self.block_size ** 3)

In [None]:
class CBAM(nn.Module):
    """Convolutional Block Attention Module for 5-D (N, C, D, H, W) features.

    Channel attention comes first. A shared 1x1x1 bottleneck MLP is applied to
    the global average- and max-pooled descriptors; the two results are summed
    and squashed by a sigmoid. Spatial attention follows: a 7x7x7 conv over the
    channel-wise max and mean maps, then a sigmoid. The output has the input's
    shape.

    Args:
        in_channel: number of input (and output) channels.
        reduction_ratio: channel reduction of the MLP bottleneck; the hidden
            width is ``in_channel // reduction_ratio`` but never below 1.
    """
    def __init__(self, in_channel, reduction_ratio = 8):
        super().__init__()
        self.hid_channel = max(in_channel // reduction_ratio, 1)
        # Global descriptors for the channel branch.
        self.globalAvgPool = nn.AdaptiveAvgPool3d(1)
        self.globalMaxPool = nn.AdaptiveMaxPool3d(1)
        # Shared MLP, implemented as 1x1x1 convolutions.
        squeeze = nn.Conv3d(in_channel, self.hid_channel, 1, bias=False)
        activation = nn.Mish()
        excite = nn.Conv3d(self.hid_channel, in_channel, 1, bias=False)
        self.fc = nn.Sequential(squeeze, activation, excite)
        self.sigmoid = nn.Sigmoid()
        # Spatial branch: 2 input maps (max, mean) -> 1 gate map, "same" padding.
        self.conv1 = nn.Conv3d(2, 1, kernel_size=7, stride=1, padding=3, bias=False)

    def forward(self, x):
        # --- channel attention ---
        from_avg = self.fc(self.globalAvgPool(x))
        from_max = self.fc(self.globalMaxPool(x))
        channel_gate = self.sigmoid(from_avg + from_max)
        refined = channel_gate * x

        # --- spatial attention ---
        mean_map = refined.mean(dim=1, keepdim=True)
        max_map = refined.max(dim=1, keepdim=True).values
        spatial_gate = self.sigmoid(self.conv1(torch.cat([max_map, mean_map], dim=1)))
        return spatial_gate * refined

In [None]:
class ConvBn(nn.Sequential):
    """Conv3d (no bias) -> [optional DropBlock3D] -> BatchNorm3d -> Mish -> CBAM.

    Args:
        in_channel / out_channel: conv input and output channels.
        kernel_size, padding: forwarded to ``nn.Conv3d``.
        drop_block: if True, insert a DropBlock3D right after the conv.
        block_size, drop_prob: DropBlock3D settings (ignored when drop_block is False).
    """
    def __init__(self, in_channel, out_channel, kernel_size = 3,
                 padding = 1, drop_block=False, block_size = 1, drop_prob = 0):
        super().__init__()
        # Build (name, module) pairs in order; submodule names are state_dict keys.
        stages = [("conv", nn.Conv3d(in_channel, out_channel, kernel_size,
                                     padding=padding, bias=False))]
        if drop_block:
            stages.append(("drop_block",
                           DropBlock3D(block_size=block_size, drop_prob=drop_prob)))
        stages.append(("bn", nn.BatchNorm3d(out_channel)))
        stages.append(("mish", nn.Mish()))
        stages.append(("cbam", CBAM(out_channel)))
        for stage_name, stage in stages:
            self.add_module(stage_name, stage)

class DownSampleBlock(nn.Sequential):
    """Halve the channel count, then halve (floor) each spatial dimension.

    Pipeline: 1x1x1 conv (C -> C // 2) -> DropBlock3D -> BN -> Mish -> CBAM
    -> 2x2x2 stride-2 conv -> DropBlock3D.
    """
    def __init__(self, in_channel, block_size = 1, drop_prob = 0):
        super().__init__()
        reduced = in_channel // 2

        def new_drop():
            return DropBlock3D(block_size=block_size, drop_prob=drop_prob)

        # Tuple elements are constructed left to right, matching layer order.
        layers = (
            ("conv1", nn.Conv3d(in_channel, reduced, 1, bias=False)),
            ("drop_block1", new_drop()),
            ("bn", nn.BatchNorm3d(reduced)),
            ("mish", nn.Mish()),
            ("cbam", CBAM(reduced)),
            ("conv2", nn.Conv3d(reduced, reduced, 2, 2, bias=False)),
            ("drop_block2", new_drop()),
        )
        for layer_name, layer in layers:
            self.add_module(layer_name, layer)


class AttentionBlock(nn.Module):
    """Attention gate that reweights a skip connection using the coarser decoder input.

    ``x`` is projected to ``out_channel`` and upsampled 2x with a transposed
    conv. ``skip`` is projected to ``out_channel``. Their sum produces a single
    sigmoid gate map, which is multiplied onto the original ``skip``. The output
    therefore has ``skip``'s shape.
    """
    def __init__(self, in_channel, in_channel_skip, out_channel):
        super().__init__()
        input_path = [
            nn.Conv3d(in_channel, out_channel, 1, padding=0, bias=False),
            nn.BatchNorm3d(out_channel),
            nn.ConvTranspose3d(out_channel, out_channel, 2, 2),
            CBAM(out_channel),
        ]
        self.conv_input = nn.Sequential(*input_path)
        skip_path = [
            nn.Conv3d(in_channel_skip, out_channel, 1, bias=False),
            nn.BatchNorm3d(out_channel),
        ]
        self.conv_skip = nn.Sequential(*skip_path)
        # Collapse to one channel so the gate broadcasts over all skip channels.
        gate_path = [
            nn.Mish(),
            nn.Conv3d(out_channel, 1, 1, bias=False),
            nn.BatchNorm3d(1),
            nn.Sigmoid(),
        ]
        self.mixed_weight = nn.Sequential(*gate_path)

    def forward(self, x, skip):
        gate = self.mixed_weight(self.conv_input(x) + self.conv_skip(skip))
        return gate * skip

class DenseLayer(nn.Module):
    """DenseNet-style bottleneck layer producing ``grow_rate`` new channels.

    A 1x1x1 ConvBn expands to ``4 * grow_rate`` channels, then a 3x3x3 ConvBn
    reduces to ``grow_rate``. The new features are concatenated in front of the
    input, so the output has ``in_channel + grow_rate`` channels.
    """
    def __init__(self, in_channel, grow_rate):
        super().__init__()
        bottleneck_width = 4 * grow_rate
        self.layer = nn.Sequential(
            ConvBn(in_channel, bottleneck_width, kernel_size=1, padding=0),
            ConvBn(bottleneck_width, grow_rate),
        )

    def forward(self, x):
        fresh = self.layer(x)
        return torch.cat((fresh, x), dim=1)

class DenseBlock(nn.Sequential):
    """Stack of ``repetition`` DenseLayers; each one widens the features by ``grow_rate``.

    Output channels = ``in_channel + repetition * grow_rate``.
    """
    def __init__(self, in_channel, grow_rate, repetition):
        super().__init__()
        width = in_channel
        for number in range(1, repetition + 1):
            self.add_module(f"dense_layer_{number}", DenseLayer(width, grow_rate))
            width += grow_rate

class DecoderBlock(nn.Module):
    """One decoder stage: upsample ``x`` 2x, gate the skip, fuse with a ConvBn.

    The transposed-conv output (``out_channel`` channels) is concatenated with
    the attention-gated skip (``in_channel_skip`` channels). The result goes
    through a ConvBn that includes DropBlock3D.
    """
    def __init__(self, in_channel, in_channel_skip, out_channel,
                 block_size = 1, drop_prob = 0):
        super().__init__()
        self.conv_trans = nn.ConvTranspose3d(in_channel, out_channel, 2, 2)
        self.attention = AttentionBlock(in_channel, in_channel_skip, out_channel)
        fused_channels = in_channel_skip + out_channel
        self.convbn = ConvBn(fused_channels, out_channel, drop_block=True,
                             block_size=block_size, drop_prob=drop_prob)

    def forward(self, x, skip):
        upsampled = self.conv_trans(x)
        gated_skip = self.attention(x, skip)
        fused = torch.cat([upsampled, gated_skip], dim=1)
        return self.convbn(fused)

class UpsampleBlock(nn.Sequential):
    """``times`` rounds of (2x transposed conv -> CBAM), i.e. 2**times upsampling.

    The first transposed conv maps ``in_channel`` -> ``out_channel``; later
    ones keep ``out_channel``.
    """
    def __init__(self,  in_channel, out_channel, times):
        super().__init__()
        source_channels = in_channel
        for step in range(1, times + 1):
            self.add_module(f"convtrans{step}",
                            nn.ConvTranspose3d(source_channels, out_channel, 2, 2))
            self.add_module(f"cbam{step}", CBAM(out_channel))
            source_channels = out_channel

In [None]:
class SegNet(nn.Module):
    """3-D segmentation network: dense encoder, attention-gated decoder, fused multi-scale heads.

    Layout (as built below):
      * ``conv1``: two ConvBn layers mapping ``input_channel`` -> ``in_channel``.
      * encoder: ``num_levels`` (= ``len(block_list)``, 4) stages, each a
        DenseBlock followed by a DownSampleBlock (which halves channels and
        spatial size). Each DenseBlock output is kept as a skip connection.
      * ``bottle_neck``: one more DenseBlock.
      * decoder: one DecoderBlock per stage, deepest first, each consuming
        the matching skip.
      * fusion: UpsampleBlock heads map the bottleneck output and each of the
        first three decoder outputs to ``num_classes`` channels at input
        resolution. They are concatenated with the last decoder output and
        passed through BN -> Mish -> 1x1x1 conv -> Softmax over dim 1.

    Each DownSampleBlock floors the spatial size by 2, and each decoder stage
    doubles it before concatenating with its skip. The input's D, H and W
    should therefore be divisible by 2**num_levels = 16.

    NOTE(review): the output is softmax probabilities, not logits. If the model
    is trained with ``nn.CrossEntropyLoss`` (which applies log-softmax itself),
    softmax is applied twice; confirm against the training code.

    Args:
        input_channel: channels of the input volume.
        in_channel: width of the stem (``conv1``) output.
        num_classes: number of output channels / classes.
        drop_prob: DropBlock probability for the downsampling and decoder blocks.
    """
    def __init__(self, input_channel = 2, in_channel = 32,
                 num_classes = 4, drop_prob = 0):
        super().__init__()
        self.conv1 = nn.Sequential(
            ConvBn(input_channel, in_channel),
            ConvBn(in_channel, in_channel)
        )
        # Dense growth rate / layer count per encoder stage; the final
        # (extra) entry is used by the bottleneck.
        grow_list = [16, 16, 32, 32, 32]
        repetition_list = [4, 6, 4, 6, 6]
        # DropBlock block size per encoder stage (used reversed by the decoder).
        block_list = [4, 3, 2, 1]
        # Output channels of the decoder stages, deepest first.
        ch_decoder = [128, 64, 32, 16]
        num_levels = len(block_list)
        in_ch_skip = []
        self.dense_list = nn.ModuleList()
        self.downsample_list = nn.ModuleList()
        self.decoder_list = nn.ModuleList()
        self.up_sample_list = nn.ModuleList()

        for level in range(num_levels):
            self.dense_list.append(DenseBlock(in_channel, grow_list[level], repetition_list[level]))
            in_channel += repetition_list[level] * grow_list[level]
            in_ch_skip.append(in_channel)
            self.downsample_list.append(DownSampleBlock(in_channel, block_list[level], drop_prob))
            in_channel = in_channel // 2

        # The bottleneck takes the stage settings right after the last encoder
        # stage (index num_levels) explicitly, not from a leaked loop variable.
        self.bottle_neck = DenseBlock(in_channel, grow_list[num_levels], repetition_list[num_levels])
        in_channel += repetition_list[num_levels] * grow_list[num_levels]
        for level in range(num_levels):
            self.decoder_list.append(DecoderBlock(in_channel, in_ch_skip[-level - 1], ch_decoder[level],
                                                  block_list[-level - 1], drop_prob))
            # Head for the tensor entering this decoder stage, upsampled by
            # 2 ** (num_levels - level) back to input resolution.
            self.up_sample_list.append(UpsampleBlock(in_channel, num_classes, num_levels - level))
            in_channel = ch_decoder[level]
        # Fused width = last decoder output + one num_classes map per head.
        in_channel += num_levels * num_classes

        self.conv2 = nn.Sequential(
            nn.BatchNorm3d(in_channel),
            nn.Mish(),
            nn.Conv3d(in_channel, num_classes, kernel_size=1, padding=0),
            nn.Softmax(dim=1)
            )

    def forward(self, x):
        """Map (N, input_channel, D, H, W) to per-voxel class probabilities (N, num_classes, D, H, W)."""
        x = self.conv1(x)
        skips = []
        for dense, down in zip(self.dense_list, self.downsample_list):
            x = dense(x)
            skips.append(x)
            x = down(x)
        x = self.bottle_neck(x)
        # heads[k] is up_sample_list[k] applied to the tensor entering decoder k.
        heads = [self.up_sample_list[0](x)]
        for level, decoder in enumerate(self.decoder_list):
            x = decoder(x, skips[-level - 1])
            if level + 1 < len(self.up_sample_list):
                heads.append(self.up_sample_list[level + 1](x))
        heads.append(x)
        fused = torch.cat(heads, dim=1)
        return self.conv2(fused)


In [None]:
# summary(SegNet(), (2,32,32,32))

In [None]:
# S = SegNet(drop_prob=0.5)

In [None]:
# for layer in S.modules():
#     if isinstance(layer, DropBlock3D):
#         layer.drop_prob = 0.4

In [None]:
# for layer in S.modules():
#     if isinstance(layer, DropBlock3D):
#         print(layer.drop_prob)

0.4
0.4
0.4
0.4
0.4
0.4
0.4
0.4
0.4
0.4
0.4
0.4
