In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus an additive shortcut.

    The shortcut is an empty ``nn.Sequential`` (identity) when the stride is 1
    and the channel count is unchanged; otherwise it is a strided 1x1
    convolution followed by batch normalisation so both branches agree in shape.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch: conv-bn-relu-conv-bn.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Shortcut branch: projection only when the shapes would otherwise differ.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)


class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1) with 4x channel expansion.

    Every convolution preserves the spatial size except for the optional
    stride on the 3x3 convolution, and the shortcut applies the same stride,
    so the main branch and the shortcut always have matching shapes when they
    are added.

    Args:
        in_planes: Number of input channels.
        planes: Width of the bottleneck; the block outputs
            ``expansion * planes`` channels.
        stride: Stride of the 3x3 convolution and of the projection shortcut.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes

        # 1x1 reduce -> 3x3 (padding 1 keeps H and W at stride 1) -> 1x1 expand.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Identity shortcut unless the stride or channel count changes, in
        # which case a strided 1x1 projection brings x to the output shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes)
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    """ResNet-style backbone with a width-wise linear head and a 1x1 channel collapse.

    Pipeline: stem conv (kernel (1, 3), padding 1) -> four residual stages
    (all built with stride 1) -> ``avg_pool2d`` with a (6, 4) window ->
    ``nn.Linear(16, num_classes)`` applied to the last (width) axis ->
    1x1 convolution that reduces the channel axis to 1.

    Because the linear layer has 16 input features, the pooled feature map
    must be exactly 16 wide (e.g. a 64-pixel-wide input when the blocks keep
    the spatial size); any other width raises a shape error in ``forward``.

    Args:
        block: Residual block class. It is called as
            ``block(in_planes, planes, stride)`` and must expose an integer
            ``expansion`` class attribute.
        num_blocks: Sequence of four ints, the number of blocks per stage.
        num_classes: Output features of the linear head.
    """

    def __init__(self, block, num_blocks, num_classes=1000):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=(1,3), stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=1)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=1)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=1)
        self.linear = nn.Linear(16, num_classes)
        # layer4 emits 512 * block.expansion channels, so the collapse conv
        # must accept that many (512 only when expansion == 1).
        self.flat = nn.Conv2d(in_channels=512 * block.expansion, out_channels=1, kernel_size=1, stride=1)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one receives ``stride``."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            # The next block consumes what this one produced.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network and return a tensor of shape (N, 1, H', num_classes).

        ``H'`` is the height left after the (6, 4) average pool. The pooled
        size is printed on every call.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, (6,4))
        print("avg_pool2d ===>", out.size())
        # The tensor stays 4-D: the linear layer maps the width axis
        # (which must be 16) to num_classes, then the 1x1 conv collapses channels.
        out = self.linear(out)
        out = self.flat(out)
        return out


def ResNet18():
    """Build a ``ResNet`` of ``BasicBlock`` units with stage depths [2, 2, 2, 2]."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths)


def ResNet34():
    """Build a ``ResNet`` of ``BasicBlock`` units with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(BasicBlock, stage_depths)


def ResNet50():
    """Build a ``ResNet`` of ``Bottleneck`` units with stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths)


def ResNet101():
    """Build a ``ResNet`` of ``Bottleneck`` units with stage depths [3, 4, 23, 3]."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths)


def ResNet152():
    """Build a ``ResNet`` of ``Bottleneck`` units with stage depths [3, 8, 36, 3]."""
    stage_depths = [3, 8, 36, 3]
    return ResNet(Bottleneck, stage_depths)


def test():
    """Smoke-test ``ResNet18`` on one random batch and print the output size.

    The input is 64 pixels wide because the network head applies
    ``nn.Linear(16, ...)`` to the width axis left by a 4-wide average pool.
    A 32-pixel-wide input pools to width 8 and raises a shape error.
    """
    net = ResNet18()
    y = net(torch.randn(1, 3, 64, 64))
    print(y.size())




In [16]:

# End-to-end shape check: the UNet output is fed straight into ResNet34 and
# the final tensor size is printed.
# NOTE(review): `UNet` is defined in a later cell of this notebook, so this
# cell only works after that cell has been executed; it fails on a fresh
# top-to-bottom run.
if __name__ == "__main__":
    image = torch.rand(1, 3, 64, 64)
    unet = UNet(img_ch=3, output_ch=3)
    model = ResNet34()
    unet_out = unet(image)
    final = model(unet_out)
    print(final.size())

3x2Conv =>       torch.Size([62, 64])
max_pool_2x1 =>  torch.Size([31, 32])
3x3Conv =>       torch.Size([31, 32])
max_pool_2x1 =>  torch.Size([15, 16])
3x3Conv =>       torch.Size([15, 16])
up_trans_3x2 =>  torch.Size([31, 32])
up_conv_3x3 =>   torch.Size([31, 32])
up_trans_2x2 =>  torch.Size([62, 64])
up_conv_2x3 =>   torch.Size([64, 64])
Final =>         torch.Size([64, 64])
avg_pool2d ===> torch.Size([1, 512, 11, 16])
torch.Size([1, 1, 11, 1000])


In [17]:
import torch
import torch.nn as nn
from torch.nn import init
import logging
from torchvision import models

def double_conv(in_c, out_c):
    """Return two (3, 2)-kernel Conv -> BatchNorm -> ReLU stages, both with padding 1."""
    def stage(channels_in):
        # One conv stage; stride 1, bias enabled.
        return [
            nn.Conv2d(channels_in, out_c, kernel_size=(3, 2), stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ]

    return nn.Sequential(*stage(in_c), *stage(out_c))

def double_conv1(in_c, out_c):
    """Return two (3, 2)-kernel Conv -> BatchNorm -> ReLU stages.

    The first convolution uses padding 1, the second uses no padding.
    """
    layers = []
    channels = in_c
    for pad in (1, 0):
        layers.append(nn.Conv2d(channels, out_c, kernel_size=(3, 2), stride=1, padding=pad, bias=True))
        layers.append(nn.BatchNorm2d(out_c))
        layers.append(nn.ReLU(inplace=True))
        channels = out_c
    return nn.Sequential(*layers)

def double_conv2(in_c, out_c):
    """Return two size-preserving 3x3 Conv -> BatchNorm -> ReLU stages (padding 1)."""
    first = nn.Conv2d(in_c, out_c, kernel_size=(3, 3), stride=1, padding=1, bias=True)
    first_bn = nn.BatchNorm2d(out_c)
    second = nn.Conv2d(out_c, out_c, kernel_size=(3, 3), stride=1, padding=1, bias=True)
    second_bn = nn.BatchNorm2d(out_c)
    return nn.Sequential(
        first, first_bn, nn.ReLU(inplace=True),
        second, second_bn, nn.ReLU(inplace=True),
    )

def double_conv3(in_c, out_c):
    """Return two Conv(kernel 3, padding 1) -> BatchNorm -> ReLU stages as one ``nn.Sequential``."""
    conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True)
    layers = [
        nn.Conv2d(in_c, out_c, **conv_kwargs),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, **conv_kwargs),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

def double_conv4(in_c, out_c):
    """Return two (2, 3)-kernel Conv -> BatchNorm -> ReLU stages, both with padding 1."""
    layers = []
    for channels_in in (in_c, out_c):
        layers.extend([
            nn.Conv2d(channels_in, out_c, kernel_size=(2, 3), stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ])
    return nn.Sequential(*layers)

def up_conv(in_c, out_c):
    """Wrap a kernel-2, stride-2 transposed convolution in an ``nn.Sequential``."""
    upsample = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)
    return nn.Sequential(upsample)

def up_conv1(in_c, out_c):
    """Wrap a (3, 2)-kernel, stride-2 transposed convolution in an ``nn.Sequential``."""
    upsample = nn.ConvTranspose2d(in_c, out_c, kernel_size=(3, 2), stride=2)
    return nn.Sequential(upsample)

class UNet(nn.Module):
    """Two-level encoder/decoder with skip connections built from asymmetric kernels.

    The encoder runs ``double_conv1`` -> pool -> ``double_conv2`` -> pool ->
    ``double_conv3``; the decoder upsamples twice and concatenates the matching
    encoder feature map on the channel axis before each double convolution.
    ``forward`` prints the spatial size after every stage.

    The pooling layer is stored as ``max_pool_2x2`` but is configured with a
    (2, 1) window and stride 2.

    NOTE(review): the helper factories are looked up by name when the model is
    constructed, and a later cell of this notebook redefines ``double_conv1``
    and ``up_conv1`` (the latter with a different kernel), so construct this
    model before running that cell.
    """

    def __init__(self, img_ch=1, output_ch=1):
        super().__init__()

        # Shared down-sampling layer for both encoder levels.
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=(2, 1), stride=2)

        # Encoder stages.
        self.down_conv_1 = double_conv1(img_ch, 64)
        self.down_conv_2 = double_conv2(64, 128)
        self.down_conv_3 = double_conv3(128, 256)

        # Decoder level fed by the bottleneck (256 = 128 upsampled + 128 skip).
        self.up_trans_3 = up_conv1(256, 128)
        self.up_conv_3 = double_conv2(256, 128)

        # Decoder level restoring the first encoder resolution (128 = 64 + 64).
        self.up_trans_4 = up_conv(128, 64)
        self.up_conv_4 = double_conv4(128, 64)

        # 1x1 projection to the requested number of output channels.
        self.out = nn.Conv2d(in_channels=64, out_channels=output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, image):
        """Run the network on ``image`` and return the projected feature map."""
        # ---- encoder ----
        enc1 = self.down_conv_1(image)
        print("3x2Conv =>      ", enc1.shape[2:4])
        pooled1 = self.max_pool_2x2(enc1)
        print("max_pool_2x1 => ", pooled1.shape[2:4])
        enc2 = self.down_conv_2(pooled1)
        print("3x3Conv =>      ", enc2.shape[2:4])
        pooled2 = self.max_pool_2x2(enc2)
        print("max_pool_2x1 => ", pooled2.shape[2:4])
        bottleneck = self.down_conv_3(pooled2)
        print("3x3Conv =>      ", bottleneck.shape[2:4])

        # ---- decoder ----
        dec = self.up_trans_3(bottleneck)
        print("up_trans_3x2 => ", dec.shape[2:4])

        # Skip connection from the second encoder stage.
        dec = self.up_conv_3(torch.cat([dec, enc2], 1))
        print("up_conv_3x3 =>  ", dec.shape[2:4])

        dec = self.up_trans_4(dec)
        print("up_trans_2x2 => ", dec.shape[2:4])

        # Skip connection from the first encoder stage.
        dec = self.up_conv_4(torch.cat([dec, enc1], 1))
        print("up_conv_2x3 =>  ", dec.shape[2:4])

        # ---- output ----
        result = self.out(dec)
        print("Final =>        ", result.shape[2:4])
        return result






# model = models.vgg16()

# Smoke test: print the UNet module tree, then run one forward pass on a
# random 10-channel 100x100 batch (the forward pass prints per-stage sizes).
if __name__ == "__main__":
    image = torch.rand(1, 10, 100, 100)
    model = UNet(img_ch=10,output_ch=3)
    # Show the layer structure before running the forward pass.
    print(model)
    model(image)

UNet(
  (max_pool_2x2): MaxPool2d(kernel_size=(2, 1), stride=2, padding=0, dilation=1, ceil_mode=False)
  (down_conv_1): Sequential(
    (0): Conv2d(10, 64, kernel_size=(3, 2), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 2), stride=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (down_conv_2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (down_conv_3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3)

In [10]:
import torch
import torch.nn as nn
from torch.nn import init


def single_conv(in_c, out_c):
    """Return one size-preserving 3x3 Conv -> BatchNorm -> ReLU stage."""
    layers = [
        nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, stride=1, bias=True),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

def double_conv1(in_c, out_c):
    """Return two (3, 2)-kernel Conv -> BatchNorm -> ReLU stages.

    Only the first convolution is padded (padding 1); the second has none.
    """
    padded = nn.Conv2d(in_c, out_c, kernel_size=(3, 2), stride=1, padding=1, bias=True)
    padded_bn = nn.BatchNorm2d(out_c)
    unpadded = nn.Conv2d(out_c, out_c, kernel_size=(3, 2), stride=1, bias=True)
    unpadded_bn = nn.BatchNorm2d(out_c)
    return nn.Sequential(
        padded, padded_bn, nn.ReLU(inplace=True),
        unpadded, unpadded_bn, nn.ReLU(inplace=True),
    )

def double_conv2(in_c, out_c):
    """Return two size-preserving 3x3 Conv -> BatchNorm -> ReLU stages (padding 1)."""
    def stage(channels_in):
        # One 3x3 conv stage; stride 1, bias enabled.
        return [
            nn.Conv2d(channels_in, out_c, kernel_size=(3, 3), stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ]

    return nn.Sequential(*stage(in_c), *stage(out_c))


def double_upconv2(in_c, out_c):
    """Return two size-preserving 3x3 Conv -> BatchNorm -> ReLU stages for a decoder level."""
    layers = []
    for channels_in in (in_c, out_c):
        layers.append(nn.Conv2d(channels_in, out_c, kernel_size=(3, 3), stride=1, padding=1, bias=True))
        layers.append(nn.BatchNorm2d(out_c))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

def double_upconv1(in_c, out_c):
    """Return two (2, 3)-kernel Conv -> BatchNorm -> ReLU stages, both with padding 1."""
    conv_kwargs = dict(kernel_size=(2, 3), stride=1, padding=1, bias=True)
    layers = [
        nn.Conv2d(in_c, out_c, **conv_kwargs),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, **conv_kwargs),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

def up_conv1(in_c, out_c):
    """Wrap a 2x2, stride-2 transposed convolution in an ``nn.Sequential``."""
    upsample = nn.ConvTranspose2d(in_c, out_c, kernel_size=(2, 2), stride=2)
    return nn.Sequential(upsample)
    
def up_conv2(in_c, out_c):
    """Return an ``nn.Sequential`` holding one 2x2 transposed convolution with stride 2."""
    layers = [nn.ConvTranspose2d(in_c, out_c, kernel_size=(2, 2), stride=2)]
    return nn.Sequential(*layers)

class Attention_block(nn.Module):
    """Additive attention gate that rescales ``x`` by a single-channel mask.

    ``g`` and ``x`` are each projected to ``F_int`` channels by a 1x1
    convolution with batch normalisation, summed, passed through ReLU, and
    reduced to one channel through a 1x1 convolution, batch normalisation and
    a sigmoid. The result multiplies ``x``.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(channels_in, channels_out):
            # 1x1 conv + BN, no spatial change.
            return [
                nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(channels_out),
            ]

        self.W_g = nn.Sequential(*project(F_g, F_int))
        self.W_x = nn.Sequential(*project(F_l, F_int))
        self.psi = nn.Sequential(*project(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` multiplied by the mask computed from ``g`` and ``x``."""
        combined = self.relu(self.W_g(g) + self.W_x(x))
        mask = self.psi(combined)
        return x * mask

class Att_UNet(nn.Module):
    """Two-level U-Net whose skip connections pass through attention gates.

    Each decoder level upsamples, resizes the matching encoder feature map to
    the upsampled spatial size with ``interpolate``, gates it with an
    ``Attention_block``, concatenates on the channel axis and applies a double
    convolution. ``forward`` prints the tensor size after every stage.

    NOTE(review): several printed labels (e.g. "max_pool_2x1",
    "up_trans_1x18, S3, P0") do not match the layer configuration built in
    ``__init__``; they are kept unchanged so the printed output is identical.
    """

    def __init__(self, img_ch=1, output_ch=1):
        super().__init__()

        # Encoder.
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.down_conv_1 = double_conv1(img_ch, 64)
        self.down_conv_2 = double_conv2(64, 128)
        self.down_conv_3 = double_conv2(128, 256)

        # Decoder level 1: upsample, gate the skip, fuse (256 = 128 + 128).
        self.up_trans_1 = up_conv1(256, 128)
        self.Att1 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.up_conv_1 = double_upconv1(256, 128)

        # Decoder level 2: upsample, gate the skip, fuse (128 = 64 + 64).
        self.up_trans_2 = up_conv2(128, 64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.up_conv_2 = double_upconv2(128, 64)

        # 1x1 projection to the requested number of output channels.
        self.out = nn.Conv2d(in_channels=64, out_channels=output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, image):
        """Run the network on ``image`` and return the projected feature map."""
        # ---- encoder ----
        print("Input Image            => ", image.shape)
        print("Encoder =================")
        enc1 = self.down_conv_1(image)
        print("Conv3x2, S1, P1        => ", enc1.shape)
        pooled1 = self.max_pool_2x2(enc1)
        print("max_pool_2x1           => ", pooled1.shape)
        enc2 = self.down_conv_2(pooled1)
        print("Conv3x3, S1, P1        => ", enc2.shape)
        pooled2 = self.max_pool_2x2(enc2)
        print("max_pool_2x1           => ", pooled2.shape)
        bottleneck = self.down_conv_3(pooled2)
        print("Conv3x3, S1, P1        => ", bottleneck.shape)

        # ---- decoder ----
        print("Decoder =================")
        dec = self.up_trans_1(bottleneck)
        print("up_trans_1x18, S3, P0  => ", dec.shape[2:4])
        # Resize the skip tensor to the upsampled size before gating it.
        skip2 = nn.functional.interpolate(enc2, (dec.size(2), dec.size(3)))
        skip2 = self.Att1(g=dec, x=skip2)
        dec = self.up_conv_1(torch.cat([dec, skip2], 1))
        print("up_conv_3x3, S1, P1    => ", dec.shape[2:4])

        dec = self.up_trans_2(dec)
        print("up_trans_2x2, S2, P0   => ", dec.shape[2:4])
        skip1 = nn.functional.interpolate(enc1, (dec.size(2), dec.size(3)))
        skip1 = self.Att2(g=dec, x=skip1)
        dec = self.up_conv_2(torch.cat([dec, skip1], 1))
        print("up_conv_2x3, s1, p1    => ", dec.shape[2:4])

        # ---- output ----
        result = self.out(dec)
        print(result.shape)
        return result


# Smoke test: one forward pass of Att_UNet on a random 3-channel 512x512
# batch; the forward pass itself prints the size after every stage.
if __name__ == "__main__":
    image = torch.rand(1, 3, 512, 512)
    model = Att_UNet(img_ch=3,output_ch=1)
    model(image)

Input Image            =>  torch.Size([1, 3, 512, 512])
Conv3x2, S1, P1        =>  torch.Size([1, 64, 510, 512])
max_pool_2x1           =>  torch.Size([1, 64, 255, 256])
Conv3x3, S1, P1        =>  torch.Size([1, 128, 255, 256])
max_pool_2x1           =>  torch.Size([1, 128, 127, 128])
Conv3x3, S1, P1        =>  torch.Size([1, 256, 127, 128])
up_trans_1x18, S3, P0  =>  torch.Size([254, 256])
up_conv_3x3, S1, P1    =>  torch.Size([256, 256])
up_trans_2x2, S2, P0   =>  torch.Size([512, 512])
up_conv_2x3, s1, p1    =>  torch.Size([512, 512])
torch.Size([1, 1, 512, 512])


In [11]:
import torch
import torch.nn as nn
from torch.nn import init

def single_conv(in_c, out_c):
    """Return a 3x3 convolution (stride 1, padding 1) followed by BatchNorm and ReLU."""
    conv = nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1, bias=True)
    norm = nn.BatchNorm2d(out_c)
    act = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, act)

def double_conv1(in_c, out_c):
    """Return two (3, 2)-kernel Conv -> BatchNorm -> ReLU stages.

    The first convolution uses padding 1 and the second uses none.
    """
    stages = []
    # (input channels, padding) for each of the two convolutions.
    for channels_in, pad in ((in_c, 1), (out_c, 0)):
        stages += [
            nn.Conv2d(channels_in, out_c, kernel_size=(3, 2), stride=1, padding=pad, bias=True),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ]
    return nn.Sequential(*stages)

def double_conv2(in_c, out_c):
    """Return two 3x3 Conv -> BatchNorm -> ReLU stages that keep the spatial size."""
    conv_kwargs = dict(kernel_size=(3, 3), stride=1, padding=1, bias=True)
    layers = [
        nn.Conv2d(in_c, out_c, **conv_kwargs),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, **conv_kwargs),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)


def up_conv1(in_c, out_c):
    """Return an ``nn.Sequential`` holding one 2x2, stride-2 transposed convolution."""
    layers = [nn.ConvTranspose2d(in_c, out_c, kernel_size=(2, 2), stride=2)]
    return nn.Sequential(*layers)
    
def up_conv2(in_c, out_c):
    """Wrap a 2x2 transposed convolution with stride 2 in an ``nn.Sequential``."""
    upsample = nn.ConvTranspose2d(in_c, out_c, kernel_size=(2, 2), stride=2)
    return nn.Sequential(upsample)


class Recurrent_block(nn.Module):
    """Recurrent convolution: one shared Conv -> BatchNorm -> ReLU applied repeatedly.

    ``forward`` computes ``x1 = conv(x)`` and then repeats
    ``x1 = conv(x + x1)`` ``t`` times, so the shared convolution runs
    ``t + 1`` times in total and the output keeps the shape of ``x``.

    Args:
        ch_out: Number of channels of both the input and the output.
        t: Number of recurrent refinement steps; must be at least 1.

    Raises:
        ValueError: If ``t`` is smaller than 1.
    """

    def __init__(self,ch_out,t=2):
        super(Recurrent_block,self).__init__()
        # With t < 1 the recurrence never runs and forward has nothing to
        # return, so reject it here with a clear message.
        if t < 1:
            raise ValueError("t must be >= 1, got {!r}".format(t))
        self.t = t
        self.ch_out = ch_out
        self.conv = nn.Sequential(
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        """Return the recurrently refined feature map for ``x``."""
        # Initial estimate, then t refinements that re-inject the input.
        x1 = self.conv(x)
        for _ in range(self.t):
            x1 = self.conv(x + x1)
        return x1

class RRCNN_block(nn.Module):
    """Recurrent residual block: 1x1 projection plus two stacked ``Recurrent_block`` units.

    The input is first mapped from ``ch_in`` to ``ch_out`` channels by a 1x1
    convolution; the output is that projection added to the result of running
    it through the two recurrent units.
    """

    def __init__(self, ch_in, ch_out, t=2):
        super().__init__()
        recurrent_units = [Recurrent_block(ch_out, t=t) for _ in range(2)]
        self.RCNN = nn.Sequential(*recurrent_units)
        self.Conv_1x1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return ``projected + RCNN(projected)`` where ``projected = Conv_1x1(x)``."""
        projected = self.Conv_1x1(x)
        refined = self.RCNN(projected)
        return projected + refined


class Attention_block(nn.Module):
    """Attention gate: scales ``x`` by a sigmoid mask derived from ``g`` and ``x``.

    Both inputs go through their own 1x1 convolution + batch normalisation to
    ``F_int`` channels; the sum is passed through ReLU and then through a 1x1
    convolution to one channel, batch normalisation and a sigmoid. The
    resulting one-channel mask multiplies ``x``.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        # Projection of the gating signal g.
        gate_conv = nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True)
        self.W_g = nn.Sequential(gate_conv, nn.BatchNorm2d(F_int))

        # Projection of the skip features x.
        skip_conv = nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True)
        self.W_x = nn.Sequential(skip_conv, nn.BatchNorm2d(F_int))

        # Collapse to a single-channel mask in (0, 1).
        mask_conv = nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True)
        self.psi = nn.Sequential(mask_conv, nn.BatchNorm2d(1), nn.Sigmoid())

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` weighted by the attention mask."""
        gate_features = self.W_g(g)
        skip_features = self.W_x(x)
        mask = self.psi(self.relu(gate_features + skip_features))
        return x * mask



class Att_R2U(nn.Module):
    """Two-level U-Net built from ``RRCNN_block`` stages with attention-gated skips.

    The encoder alternates ``RRCNN_block`` stages with 2x2 max pooling. Each
    decoder level upsamples, resizes the matching encoder feature map to the
    upsampled spatial size with ``interpolate``, gates it with an
    ``Attention_block``, concatenates on the channel axis and applies another
    ``RRCNN_block``. ``forward`` prints the tensor size after every stage.

    NOTE(review): the printed labels (e.g. "Conv3x2, S1, P1",
    "up_trans_1x18, S3, P0") do not describe the layers built in ``__init__``;
    they are kept unchanged so the printed output is identical.
    """

    def __init__(self, img_ch=1, output_ch=1, t=2):
        super().__init__()

        # Encoder.
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.RCNN1 = RRCNN_block(img_ch, 64, t=t)
        self.RCNN2 = RRCNN_block(64, 128, t=t)
        self.RCNN3 = RRCNN_block(128, 256, t=t)

        # Decoder level 1 (256 = 128 upsampled + 128 gated skip).
        self.up_trans_1 = up_conv1(256, 128)
        self.Att1 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.Up_RRCNN1 = RRCNN_block(256, 128, t=t)

        # Decoder level 2 (128 = 64 upsampled + 64 gated skip).
        self.up_trans_2 = up_conv2(128, 64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.Up_RRCNN2 = RRCNN_block(128, 64, t=t)

        # 1x1 projection to the requested number of output channels.
        self.out = nn.Conv2d(in_channels=64, out_channels=output_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, image):
        """Run the network on ``image`` and return the projected feature map."""
        # ---- encoder ----
        print("Input Image            => ", image.shape)
        print("Encoder =================")
        enc1 = self.RCNN1(image)
        print("Conv3x2, S1, P1        => ", enc1.shape)
        pooled1 = self.max_pool_2x2(enc1)
        print("max_pool_2x1           => ", pooled1.shape)
        enc2 = self.RCNN2(pooled1)
        print("Conv3x3, S1, P1        => ", enc2.shape)
        pooled2 = self.max_pool_2x2(enc2)
        print("max_pool_2x1           => ", pooled2.shape)
        bottleneck = self.RCNN3(pooled2)
        print("Conv3x3, S1, P1        => ", bottleneck.shape)

        # ---- decoder ----
        print("Decoder =================")
        dec = self.up_trans_1(bottleneck)
        print("up_trans_1x18, S3, P0  => ", dec.shape)
        # Resize the skip tensor to the upsampled size before gating it.
        skip2 = nn.functional.interpolate(enc2, (dec.size(2), dec.size(3)))
        skip2 = self.Att1(g=dec, x=skip2)
        dec = self.Up_RRCNN1(torch.cat([dec, skip2], 1))
        print("up_conv_3x3, S1, P1    => ", dec.shape)

        dec = self.up_trans_2(dec)
        print("up_trans_2x2, S2, P0   => ", dec.shape)
        skip1 = nn.functional.interpolate(enc1, (dec.size(2), dec.size(3)))
        skip1 = self.Att2(g=dec, x=skip1)
        dec = self.Up_RRCNN2(torch.cat([dec, skip1], 1))
        print("up_conv_2x3, s1, p1    => ", dec.shape)

        # ---- output ----
        result = self.out(dec)
        print(result.shape)
        return result



# Smoke test: one forward pass of Att_R2U on a random 3-channel 512x512
# batch; output_ch is left at its default and the forward pass prints the
# tensor size after every stage.
if __name__ == "__main__":
    print("start")
    image = torch.rand(1, 3, 512, 512)
    model = Att_R2U(img_ch=3)
    model(image)

start
Input Image            =>  torch.Size([1, 3, 512, 512])
Conv3x2, S1, P1        =>  torch.Size([1, 64, 512, 512])
max_pool_2x1           =>  torch.Size([1, 64, 256, 256])
Conv3x3, S1, P1        =>  torch.Size([1, 128, 256, 256])
max_pool_2x1           =>  torch.Size([1, 128, 128, 128])
Conv3x3, S1, P1        =>  torch.Size([1, 256, 128, 128])
up_trans_1x18, S3, P0  =>  torch.Size([1, 128, 256, 256])
up_conv_3x3, S1, P1    =>  torch.Size([1, 128, 256, 256])
up_trans_2x2, S2, P0   =>  torch.Size([1, 64, 512, 512])
up_conv_2x3, s1, p1    =>  torch.Size([1, 64, 512, 512])
torch.Size([1, 1, 512, 512])
