In [None]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchsummary import summary
from torch.nn import init

In [None]:
# https://github.com/miguelvr/dropblock/blob/master/dropblock/dropblock.py
class DropBlock2D(nn.Module):
    """DropBlock regularisation for 4-D feature maps.

    While the module is in training mode, square ``block_size`` x ``block_size``
    regions of the spatial map are zeroed (one mask per sample, shared by all
    channels) and the result is rescaled by ``numel / kept``.  In eval mode, or
    when ``drop_prob`` is 0, the input tensor is returned unchanged.

    Args:
        drop_prob: drop strength; each spatial position seeds a dropped block
            with probability ``drop_prob / block_size ** 2``.
        block_size: side length of each dropped square.
    """

    def __init__(self, drop_prob, block_size):
        super(DropBlock2D, self).__init__()
        self.drop_prob = drop_prob
        self.block_size = block_size

    def forward(self, x):
        """Apply DropBlock to ``x`` of shape (bsize, channels, height, width).

        Returns a tensor with the same shape as ``x``.
        """
        # shape: (bsize, channels, height, width)
        assert x.dim() == 4, \
            "Expected input with 4 dimensions (bsize, channels, height, width)"
        if not self.training or self.drop_prob == 0.:
            return x

        # per-position probability of seeding a dropped block
        gamma = self.drop_prob / (self.block_size ** 2)
        # sample the seed mask directly on the input's device: (bsize, height, width)
        mask = (torch.rand(x.shape[0], *x.shape[2:], device=x.device) < gamma).float()
        # grow every seed into a block; 1 = keep, 0 = drop
        block_mask = self._compute_block_mask(mask)
        # broadcast the mask over the channel dimension
        out = x * block_mask[:, None, :, :]
        # Rescale by numel / kept.  The kept count is clamped to 1 so that a
        # fully dropped batch yields zeros instead of 0 / 0 = NaN; whenever at
        # least one position is kept the clamp is a no-op.
        kept = block_mask.sum().clamp(min=1.0)
        out = out * block_mask.numel() / kept
        return out

    def _compute_block_mask(self, mask):
        """Expand a (bsize, height, width) seed mask into a keep mask of the same shape."""
        # a stride-1 max pool spreads each seed over a block_size x block_size window
        block_mask = F.max_pool2d(input=mask[:, None, :, :],
                                  kernel_size=(self.block_size, self.block_size),
                                  stride=(1, 1),
                                  padding=self.block_size // 2)

        # an even kernel with this padding produces one extra row/column; trim it
        if self.block_size % 2 == 0:
            block_mask = block_mask[:, :, :-1, :-1]
        # invert so that dropped positions are 0 and kept positions are 1
        block_mask = 1 - block_mask.squeeze(1)
        return block_mask

class DropBlock3D(DropBlock2D):
    """DropBlock regularisation for 5-D (volumetric) feature maps.

    While the module is in training mode, cubic ``block_size``-sided regions
    are zeroed (one mask per sample, shared by all channels) and the result is
    rescaled by ``numel / kept``.  In eval mode, or when ``drop_prob`` is 0, the
    input tensor is returned unchanged.

    Args:
        drop_prob: drop strength; each position seeds a dropped block with
            probability ``drop_prob / block_size ** 3``.
        block_size: side length of each dropped cube.
    """

    def __init__(self, drop_prob, block_size):
        super(DropBlock3D, self).__init__(drop_prob, block_size)

    def forward(self, x):
        """Apply DropBlock to ``x`` of shape (bsize, channels, depth, height, width).

        Returns a tensor with the same shape as ``x``.
        """
        # shape: (bsize, channels, depth, height, width)
        assert x.dim() == 5, \
            "Expected input with 5 dimensions (bsize, channels, depth, height, width)"
        if not self.training or self.drop_prob == 0.:
            return x

        # per-position probability of seeding a dropped block
        gamma = self.drop_prob / (self.block_size ** 3)
        # sample the seed mask directly on the input's device (no CPU round-trip)
        mask = (torch.rand(x.shape[0], *x.shape[2:], device=x.device) < gamma).float()
        # grow every seed into a block; 1 = keep, 0 = drop
        block_mask = self._compute_block_mask(mask)
        # broadcast the mask over the channel dimension
        out = x * block_mask[:, None, :, :, :]
        # Rescale by numel / kept.  The kept count is clamped to 1 so that a
        # fully dropped batch yields zeros instead of 0 / 0 = NaN; whenever at
        # least one position is kept the clamp is a no-op.
        kept = block_mask.sum().clamp(min=1.0)
        out = out * block_mask.numel() / kept
        return out

    def _compute_block_mask(self, mask):
        """Expand a (bsize, depth, height, width) seed mask into a keep mask of the same shape."""
        # a stride-1 max pool spreads each seed over a block_size-sided cube
        block_mask = F.max_pool3d(input=mask[:, None, :, :, :],
                                  kernel_size=(self.block_size, self.block_size, self.block_size),
                                  stride=(1, 1, 1),
                                  padding=self.block_size // 2)
        # an even kernel with this padding produces one extra slice per axis; trim it
        if self.block_size % 2 == 0:
            block_mask = block_mask[:, :, :-1, :-1, :-1]
        # invert so that dropped positions are 0 and kept positions are 1
        block_mask = 1 - block_mask.squeeze(1)
        return block_mask

In [None]:
class conv_block(nn.Sequential):
    """Double convolution stage: (conv -> BN -> ReLU) followed by (conv -> [DropBlock] -> BN -> ReLU).

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels of both convolutions.
        kernel_size: kernel size shared by both convolutions.
        padding: padding shared by both convolutions.
        drop_block: when true, a DropBlock2D layer is inserted between the
            second convolution and its batch norm.
        block_size: block size handed to DropBlock2D.
        drop_prob: drop probability handed to DropBlock2D.
    """

    def __init__(self, ch_in, ch_out, kernel_size=3,
                 padding=1, drop_block=False, block_size=1, drop_prob=0):
        super().__init__()

        # Layers are listed in execution order; the names are the sub-module
        # names (and therefore the state_dict prefixes).
        stages = [
            ("conv1", nn.Conv2d(ch_in, ch_out, kernel_size, padding=padding, bias=False)),
            ("bn1", nn.BatchNorm2d(ch_out)),
            ("relu1", nn.ReLU(inplace=True)),
            ("conv2", nn.Conv2d(ch_out, ch_out, kernel_size, padding=padding, bias=False)),
        ]
        if drop_block:
            stages.append(
                ("drop_block", DropBlock2D(block_size=block_size, drop_prob=drop_prob))
            )
        stages.append(("bn2", nn.BatchNorm2d(ch_out)))
        stages.append(("relu2", nn.ReLU(inplace=True)))

        for layer_name, layer in stages:
            self.add_module(layer_name, layer)

class up_conv(nn.Module):
    """Decoder up-sampling stage: 2x upsample -> 2x2 'same' conv -> BN -> ReLU.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2)
        # bias is omitted because the following batch norm supplies its own shift
        project = nn.Conv2d(ch_in, ch_out, kernel_size=2, stride=1,
                            padding="same", bias=False)
        normalise = nn.BatchNorm2d(ch_out)
        activate = nn.ReLU(inplace=True)
        self.up = nn.Sequential(enlarge, project, normalise, activate)

    def forward(self, x):
        """Return ``x`` with doubled height/width and ``ch_out`` channels."""
        return self.up(x)

class U_Net(nn.Module):
    """Five-level U-Net with DropBlock in the two deepest encoder stages.

    The encoder widths are 64, 128, 256, 512 and 1024 channels.  Four 2x max
    poolings are mirrored by four 2x up-samplings, so the input height and
    width should be multiples of 16 for the skip concatenations to line up.
    The head is a 1x1 convolution followed by a softmax over the channel axis.

    Args:
        img_ch: number of input image channels.
        output_ch: number of output channels (classes).
        drop_prob: drop probability forwarded to the DropBlock layers of the
            fourth (block size 5) and fifth (block size 3) encoder stages.
    """

    def __init__(self, img_ch=3, output_ch=2, drop_prob=0):
        super().__init__()

        base = 64
        # channel widths of the five encoder levels: 64, 128, 256, 512, 1024
        c1, c2, c3, c4, c5 = (base * 2 ** level for level in range(5))

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # contracting path
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=c1)
        self.Conv2 = conv_block(ch_in=c1, ch_out=c2)
        self.Conv3 = conv_block(ch_in=c2, ch_out=c3)
        self.Conv4 = conv_block(ch_in=c3, ch_out=c4,
                                drop_block=True, block_size=5, drop_prob=drop_prob)
        self.Conv5 = conv_block(ch_in=c4, ch_out=c5,
                                drop_block=True, block_size=3, drop_prob=drop_prob)

        # expanding path: every fusion block receives skip + up-sampled
        # features, hence twice the width it emits
        self.Up5 = up_conv(ch_in=c5, ch_out=c4)
        self.Up_conv5 = conv_block(ch_in=c5, ch_out=c4)

        self.Up4 = up_conv(ch_in=c4, ch_out=c3)
        self.Up_conv4 = conv_block(ch_in=c4, ch_out=c3)

        self.Up3 = up_conv(ch_in=c3, ch_out=c2)
        self.Up_conv3 = conv_block(ch_in=c3, ch_out=c2)

        self.Up2 = up_conv(ch_in=c2, ch_out=c1)
        self.Up_conv2 = conv_block(ch_in=c2, ch_out=c1)

        # per-pixel class probabilities
        self.Conv_1x1 = nn.Sequential(
            nn.Conv2d(c1, output_ch, kernel_size=1, stride=1, padding=0),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return the softmax map of shape (batch, output_ch, H, W) for input (batch, img_ch, H, W)."""
        # encoding path: remember each level's output before pooling it
        skips = []
        feat = self.Conv1(x)
        for encoder in (self.Conv2, self.Conv3, self.Conv4, self.Conv5):
            skips.append(feat)
            feat = encoder(self.Maxpool(feat))

        # decoding + concat path: deepest skip first, skip features before
        # the up-sampled ones on the channel axis
        decoder = (
            (self.Up5, self.Up_conv5),
            (self.Up4, self.Up_conv4),
            (self.Up3, self.Up_conv3),
            (self.Up2, self.Up_conv2),
        )
        for (upsample, fuse), skip in zip(decoder, reversed(skips)):
            feat = fuse(torch.cat((skip, upsample(feat)), dim=1))

        return self.Conv_1x1(feat)

class Attention_block(nn.Module):
    """Additive attention gate that re-weights skip features ``x`` using a gating signal ``g``.

    Args:
        F_g: number of channels of the gating signal ``g``.
        F_l: number of channels of the skip features ``x``.
        F_int: number of channels of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise(ch_in, ch_out):
            # 1x1 convolution followed by batch norm
            return [
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(ch_out),
            ]

        self.W_g = nn.Sequential(*pointwise(F_g, F_int))
        self.W_x = nn.Sequential(*pointwise(F_l, F_int))
        # collapse to a single-channel map squashed into (0, 1)
        self.psi = nn.Sequential(*pointwise(F_int, 1), nn.Sigmoid())

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` multiplied by the single-channel attention map computed from ``g`` and ``x``."""
        combined = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(combined)
        return x * attention

class AttU_Net(nn.Module):
    """Five-level U-Net whose skip connections pass through attention gates.

    The encoder widths are 64, 128, 256, 512 and 1024 channels.  Four 2x max
    poolings are mirrored by four 2x up-samplings, so the input height and
    width should be multiples of 16 for the skip concatenations to line up.
    The head is a 1x1 convolution followed by a softmax over the channel axis.

    Args:
        img_ch: number of input image channels.
        output_ch: number of output channels (classes).
        drop_prob: drop probability forwarded to the DropBlock layers
            (block size 5) of the fourth and fifth encoder stages.
    """

    def __init__(self, img_ch=3, output_ch=2, drop_prob=0):
        super().__init__()

        # channel widths of the five encoder levels
        c1, c2, c3, c4, c5 = 64, 128, 256, 512, 1024

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # contracting path
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=c1)
        self.Conv2 = conv_block(ch_in=c1, ch_out=c2)
        self.Conv3 = conv_block(ch_in=c2, ch_out=c3)
        self.Conv4 = conv_block(ch_in=c3, ch_out=c4,
                                drop_block=True, block_size=5, drop_prob=drop_prob)
        self.Conv5 = conv_block(ch_in=c4, ch_out=c5,
                                drop_block=True, block_size=5, drop_prob=drop_prob)

        # expanding path: up-sample, gate the skip (intermediate width is half
        # the level width), then fuse the concatenation
        self.Up5 = up_conv(ch_in=c5, ch_out=c4)
        self.Att5 = Attention_block(F_g=c4, F_l=c4, F_int=c3)
        self.Up_conv5 = conv_block(ch_in=c5, ch_out=c4)

        self.Up4 = up_conv(ch_in=c4, ch_out=c3)
        self.Att4 = Attention_block(F_g=c3, F_l=c3, F_int=c2)
        self.Up_conv4 = conv_block(ch_in=c4, ch_out=c3)

        self.Up3 = up_conv(ch_in=c3, ch_out=c2)
        self.Att3 = Attention_block(F_g=c2, F_l=c2, F_int=c1)
        self.Up_conv3 = conv_block(ch_in=c3, ch_out=c2)

        self.Up2 = up_conv(ch_in=c2, ch_out=c1)
        self.Att2 = Attention_block(F_g=c1, F_l=c1, F_int=32)
        self.Up_conv2 = conv_block(ch_in=c2, ch_out=c1)

        # per-pixel class probabilities
        self.Conv_1x1 = nn.Sequential(
            nn.Conv2d(c1, output_ch, kernel_size=1, stride=1, padding=0),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return the softmax map of shape (batch, output_ch, H, W) for input (batch, img_ch, H, W)."""
        # encoding path: remember each level's output before pooling it
        skips = []
        feat = self.Conv1(x)
        for encoder in (self.Conv2, self.Conv3, self.Conv4, self.Conv5):
            skips.append(feat)
            feat = encoder(self.Maxpool(feat))

        # decoding + concat path: the up-sampled features act as the gating
        # signal for the skip, and the gated skip goes first on the channel axis
        decoder = (
            (self.Up5, self.Att5, self.Up_conv5),
            (self.Up4, self.Att4, self.Up_conv4),
            (self.Up3, self.Att3, self.Up_conv3),
            (self.Up2, self.Att2, self.Up_conv2),
        )
        for (upsample, gate, fuse), skip in zip(decoder, reversed(skips)):
            coarse = upsample(feat)
            gated_skip = gate(coarse, skip)
            feat = fuse(torch.cat((gated_skip, coarse), dim=1))

        return self.Conv_1x1(feat)


In [None]:
# summary(AttU_Net(), (3,192,256))