<a href="https://colab.research.google.com/github/naveedilyas/CASA-Crowd/blob/main/Nested_UNet_Densnet_Resnet_Dialted_Conv_MD_01.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import torch
import torch.nn as nn
class batchnorm_relu(nn.Module):
    def __init__(self, in_c):
        super().__init__()

        self.bn = nn.BatchNorm2d(in_c)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        x = self.bn(inputs)
        x = self.relu(x)
        return x

class conv_block_nested(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch,stride=1):
        super().__init__()
        """ Convolutional layer """
        self.b1 = batchnorm_relu(in_ch)
        self.c1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, stride=stride)
        self.b2 = batchnorm_relu(mid_ch)
        self.c2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, stride=1)

        """ Shortcut Connection (Identity Mapping) """


        # Adjust the dimensions if the input channels or stride is different from the output
        self.shortcut = nn.Sequential()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_ch))
        else:
          self.shortcut = nn.Conv2d(in_ch, out_ch, kernel_size=1, padding=0, stride=stride)


    def forward(self, inputs):
        x = self.b1(inputs)
        # print("1", x.shape)
        x = self.c1(x)
        # print("c1", x.shape)
        x = self.b2(x)
        # print("3", x.shape)
        x = self.c2(x)
        # print("c2", x.shape)
        s = self.shortcut(inputs)

        skip = x + s
        # print("skip", skip.shape)
        return skip


class DDCB0_0(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(DDCB0_0, self).__init__()
        # self.conv1 = nn.Sequential(nn.Conv2d(in_ch, mid_ch, 1), nn.ReLU(True), nn.Conv2d(mid_ch, out_ch, 3, padding=1,dilation=1),nn.ReLU(True))
        # self.conv2 = nn.Sequential(nn.Conv2d(in_ch + in_ch, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 448, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, out_ch, 3, padding=4, dilation=4), nn.ReLU(True))

        # self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 256, 1), nn.ReLU(True), nn.Conv2d(256, 64, 3, padding=1,dilation=1),nn.ReLU(True))
        # self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, 512, 3, padding=4, dilation=4), nn.ReLU(True))

        self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 128, 1),nn.BatchNorm2d(128),nn.ReLU(True),nn.Dropout(0.2),
                                   nn.Conv2d(128, 64, 3, padding=1, dilation=1),nn.BatchNorm2d(64),nn.ReLU(True),nn.Dropout(0.2),)

        self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, 128, 1),nn.BatchNorm2d(128),nn.ReLU(True),nn.Dropout(0.2),
                                   nn.Conv2d(128, 64, 3, padding=2, dilation=2),nn.BatchNorm2d(64),nn.ReLU(True),nn.Dropout(0.2),)

        self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, 128, 1),nn.BatchNorm2d(128),nn.ReLU(True),nn.Dropout(0.2),
                                   nn.Conv2d(128, 64, 3, padding=3, dilation=3),nn.BatchNorm2d(64),nn.ReLU(True),nn.Dropout(0.2),)

        self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, out_ch, 3, padding=4, dilation=4),nn.ReLU(True),)

    def forward(self, x):
        print('x', x.shape)
        x1_raw = self.conv1(x)
        print('x1_raw', x1_raw.shape)
        x1 = torch.cat([x, x1_raw], 1)
        print('x1', x1.shape)
        x2_raw = self.conv2(x1)
        print('x2_raw', x2_raw.shape)
        x2 = torch.cat([x, x1_raw, x2_raw], 1)
        print('x2', x2.shape)
        x3_raw = self.conv3(x2)
        print('x3_raw', x3_raw.shape)
        x3 = torch.cat([x, x2_raw, x3_raw], 1)
        print('x3', x3.shape)
        output = self.conv4(x3)
        print('output', output.shape)
        return output




class DDCB1(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(DDCB1, self).__init__()
        # self.conv1 = nn.Sequential(nn.Conv2d(in_ch, mid_ch, 1), nn.ReLU(True), nn.Conv2d(mid_ch, in_ch, 3, padding=1,dilation=1),nn.ReLU(True))
        # self.conv2 = nn.Sequential(nn.Conv2d(in_ch + in_ch, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 448, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, out_ch, 3, padding=4, dilation=4), nn.ReLU(True))

        # self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 256, 1), nn.ReLU(True), nn.Conv2d(256, 64, 3, padding=1,dilation=1),nn.ReLU(True))
        # self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, 512, 3, padding=4, dilation=4), nn.ReLU(True))

        self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 128, 1),nn.BatchNorm2d(128),nn.ReLU(True),nn.Dropout(0.2),
                                   nn.Conv2d(128, 64, 3, padding=1, dilation=1),nn.BatchNorm2d(64),nn.ReLU(True),nn.Dropout(0.2),)

        self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, 128, 1),nn.BatchNorm2d(128),nn.ReLU(True),nn.Dropout(0.2),
                                   nn.Conv2d(128, 64, 3, padding=2, dilation=2),nn.BatchNorm2d(64),nn.ReLU(True),nn.Dropout(0.2),)

        self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, 128, 1),nn.BatchNorm2d(128),nn.ReLU(True),nn.Dropout(0.2),
                                   nn.Conv2d(128, 64, 3, padding=3, dilation=3),nn.BatchNorm2d(64),nn.ReLU(True),nn.Dropout(0.2),)

        self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, out_ch, 3, padding=4, dilation=4),nn.ReLU(True),)

    def forward(self, x):
        print('x', x.shape)
        x1_raw = self.conv1(x)
        print('x1_raw', x1_raw.shape)
        x1 = torch.cat([x, x1_raw], 1)
        print('x1', x1.shape)
        x2_raw = self.conv2(x1)
        print('x2_raw', x2_raw.shape)
        x2 = torch.cat([x, x1_raw, x2_raw], 1)
        print('x2', x2.shape)
        x3_raw = self.conv3(x2)
        print('x3_raw', x3_raw.shape)
        x3 = torch.cat([x, x2_raw, x3_raw], 1)
        print('x3', x3.shape)
        output = self.conv4(x3)
        print('output', output.shape)
        return output


# class conv_block_nested(nn.Module):
#     def __init__(self, in_ch, mid_ch, out_ch):
#         super(conv_block_nested, self).__init__()
#         self.conv1 = nn.Sequential(nn.Conv2d(in_ch, mid_ch, 1), nn.ReLU(True), nn.Conv2d(mid_ch, out_ch, 3, padding=1,dilation=1),nn.ReLU(True))
#         self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=2, dilation=2), nn.ReLU(True))
#         self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=3, dilation=3), nn.ReLU(True))
#         self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, out_ch, 3, padding=4, dilation=4), nn.ReLU(True))

#         # self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 256, 1), nn.ReLU(True), nn.Conv2d(256, 64, 3, padding=1,dilation=1),nn.ReLU(True))
#         # self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=2, dilation=2), nn.ReLU(True))
#         # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=3, dilation=3), nn.ReLU(True))
#         # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, 512, 3, padding=4, dilation=4), nn.ReLU(True))

#     def forward(self, x):
#         x1_raw = self.conv1(x)
#         x1 = torch.cat([x, x1_raw], 1)
#         x2_raw = self.conv2(x1)
#         x2 = torch.cat([x, x1_raw, x2_raw], 1)
#         x3_raw = self.conv3(x2)
#         x3 = torch.cat([x, x2_raw, x3_raw], 1)
#         output = self.conv4(x3)
#         return output

# class down_sample_Conv(nn.Module):
#   def __init__(self, in_ch1, out_ch1):
#     super(down_sample_Conv,self).__init__()
#     self.pool = nn.Conv2d(in_ch1, out_ch1, kernel_size=3, stride=1, padding=1, dilation=2)
#   def forward(self,x):
#     x = self.pool(x)
#     return x

In [7]:
class DDCB_5_64_64(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(DDCB_5_64_64, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_ch, mid_ch, 1), nn.ReLU(True), nn.Conv2d(mid_ch, out_ch, 3, padding=1,dilation=1),nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=2, dilation=2), nn.ReLU(True))
        self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=3, dilation=3), nn.ReLU(True))
        self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, out_ch, 3, padding=4, dilation=4), nn.ReLU(True))

        # self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 256, 1), nn.ReLU(True), nn.Conv2d(256, 64, 3, padding=1,dilation=1),nn.ReLU(True))
        # self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, 512, 3, padding=4, dilation=4), nn.ReLU(True))

    def forward(self, x):
        print('x', x.shape)
        x1_raw = self.conv1(x)
        print('x1_raw', x1_raw.shape)
        x1 = torch.cat([x, x1_raw], 1)
        print('x1', x1.shape)
        x2_raw = self.conv2(x1)
        print('x2_raw', x2_raw.shape)
        x2 = torch.cat([x, x1_raw, x2_raw], 1)
        print('x2', x2.shape)
        x3_raw = self.conv3(x2)
        print('x3_raw', x3_raw.shape)
        x3 = torch.cat([x, x2_raw, x3_raw], 1)
        print('x3', x3.shape)
        output = self.conv4(x3)
        # print('output', output.shape)
        return output

class DDCB_64_128_128(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(DDCB_64_128_128, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_ch, mid_ch, 1), nn.ReLU(True), nn.Conv2d(mid_ch, out_ch, 3, padding=1,dilation=1),nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 128, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=2, dilation=2), nn.ReLU(True))
        self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 256, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=3, dilation=3), nn.ReLU(True))
        self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 256, out_ch, 3, padding=4, dilation=4), nn.ReLU(True))

        # self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 256, 1), nn.ReLU(True), nn.Conv2d(256, 64, 3, padding=1,dilation=1),nn.ReLU(True))
        # self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, 512, 3, padding=4, dilation=4), nn.ReLU(True))

    def forward(self, x):
        print('x', x.shape)
        x1_raw = self.conv1(x)
        print('x1_raw', x1_raw.shape)
        x1 = torch.cat([x, x1_raw], 1)
        print('x1', x1.shape)
        x2_raw = self.conv2(x1)
        print('x2_raw', x2_raw.shape)
        x2 = torch.cat([x, x1_raw, x2_raw], 1)
        print('x2', x2.shape)
        x3_raw = self.conv3(x2)
        print('x3_raw', x3_raw.shape)
        x3 = torch.cat([x, x2_raw, x3_raw], 1)
        print('x3', x3.shape)
        output = self.conv4(x3)
        # print('output', output.shape)
        return output

class DDCB_128_256_256(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(DDCB_128_256_256, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_ch, mid_ch, 1), nn.ReLU(True), nn.Conv2d(mid_ch, out_ch, 3, padding=1,dilation=1),nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 256, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=2, dilation=2), nn.ReLU(True))
        self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 512, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=3, dilation=3), nn.ReLU(True))
        self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 512, out_ch, 3, padding=4, dilation=4), nn.ReLU(True))

        # self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 256, 1), nn.ReLU(True), nn.Conv2d(256, 64, 3, padding=1,dilation=1),nn.ReLU(True))
        # self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, 512, 3, padding=4, dilation=4), nn.ReLU(True))

    def forward(self, x):
        print('x', x.shape)
        x1_raw = self.conv1(x)
        print('x1_raw', x1_raw.shape)
        x1 = torch.cat([x, x1_raw], 1)
        print('x1', x1.shape)
        x2_raw = self.conv2(x1)
        print('x2_raw', x2_raw.shape)
        x2 = torch.cat([x, x1_raw, x2_raw], 1)
        print('x2', x2.shape)
        x3_raw = self.conv3(x2)
        print('x3_raw', x3_raw.shape)
        x3 = torch.cat([x, x2_raw, x3_raw], 1)
        print('x3', x3.shape)
        output = self.conv4(x3)
        # print('output', output.shape)
        return output


class DDCB_256_512_512(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(DDCB_256_512_512, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_ch, mid_ch, 1), nn.ReLU(True), nn.Conv2d(mid_ch, out_ch, 3, padding=1,dilation=1),nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 512, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, 512, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 1024, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 1024, out_ch, 3, padding=4, dilation=4), nn.ReLU(True))

        # self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 256, 1), nn.ReLU(True), nn.Conv2d(256, 64, 3, padding=1,dilation=1),nn.ReLU(True))
        # self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, 512, 3, padding=4, dilation=4), nn.ReLU(True))

    def forward(self, x):
        print('x', x.shape)
        x1_raw = self.conv1(x)
        print('x1_raw', x1_raw.shape)
        x1 = torch.cat([x, x1_raw], 1)
        print('x1', x1.shape)
        x2_raw = self.conv2(x1)
        print('x2_raw', x2_raw.shape)
        x2 = torch.cat([x, x1_raw, x2_raw], 1)
        print('x2', x2.shape)
        # x3_raw = self.conv3(x2)
        # print('x3_raw', x3_raw.shape)
        # x3 = torch.cat([x, x2_raw, x3_raw], 1)
        # print('x3', x3.shape)
        # output = self.conv4(x3)
        # # print('output', output.shape)
        return x2

In [8]:
class NestedUNet(nn.Module):
    def __init__(self, in_ch=5, out_ch=6):
        super(NestedUNet, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])
        self.conv0_0 = DDCB_5_64_64(in_ch, filters[0], filters[0])
        self.conv1_0 = DDCB_64_128_128(filters[0], filters[1], filters[1])
        self.conv2_0 = DDCB_128_256_256(filters[1], filters[2], filters[2])
        self.conv3_0 = DDCB1(filters[2], filters[3], filters[3])
        self.conv4_0 = DDCB1(filters[3], filters[4], filters[4])

        self.conv0_1 = conv_block_nested(filters[0] + filters[1], filters[0], filters[0])
        self.conv1_1 = conv_block_nested(filters[1] + filters[2], filters[1], filters[1])
        self.conv2_1 = conv_block_nested(filters[2] + filters[3], filters[2], filters[2])
        self.conv3_1 = conv_block_nested(filters[3] + filters[4], filters[3], filters[3])

        self.conv0_2 = conv_block_nested(filters[0]*2 + filters[1], filters[0], filters[0])
        self.conv1_2 = conv_block_nested(filters[1]*2 + filters[2], filters[1], filters[1])
        self.conv2_2 = conv_block_nested(filters[2]*2 + filters[3], filters[2], filters[2])

        self.conv0_3 = conv_block_nested(filters[0]*3 + filters[1], filters[0], filters[0])
        self.conv1_3 = conv_block_nested(filters[1]*3 + filters[2], filters[1], filters[1])

        self.conv0_4 = conv_block_nested(filters[0]*4 + filters[1], filters[0], filters[0])

        self.final = nn.Conv2d(filters[0], out_ch, kernel_size=1)


    def forward(self, x):

        x0_0 = self.conv0_0(x)
        print("x0_0",x0_0.shape)
        x1_0 = self.conv1_0(self.pool(x0_0))
        # print("x1_0", x0_0.shape)
        x0_1 = self.conv0_1(torch.cat([x0_0, self.Up(x1_0)], 1))
        # print("x0_1", x0_1.shape)

        x2_0 = self.conv2_0(self.pool(x1_0))
        # print("x2_0", x2_0.shape)
        x1_1 = self.conv1_1(torch.cat([x1_0, self.Up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.Up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        # print("x3_0", x3_0.shape)
        x2_1 = self.conv2_1(torch.cat([x2_0, self.Up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.Up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.Up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        # print("x4_0", x4_0.shape)
        x3_1 = self.conv3_1(torch.cat([x3_0, self.Up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.Up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.Up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)], 1))

        output = self.final(x0_4)
        return output

import torch
import torch.nn as nn

class DDCB(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(DDCB, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_ch, mid_ch, 1), nn.ReLU(True), nn.Conv2d(mid_ch, out_ch, 3, padding=1,dilation=1),nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=2, dilation=2), nn.ReLU(True))
        self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=3, dilation=3), nn.ReLU(True))
        self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, out_ch, 3, padding=4, dilation=4), nn.ReLU(True))

        # self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 256, 1), nn.ReLU(True), nn.Conv2d(256, 64, 3, padding=1,dilation=1),nn.ReLU(True))
        # self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, 512, 3, padding=4, dilation=4), nn.ReLU(True))

    def forward(self, x):
        print('x', x.shape)
        x1_raw = self.conv1(x)
        print('x1_raw', x1_raw.shape)
        x1 = torch.cat([x, x1_raw], 1)
        print('x1', x1.shape)
        x2_raw = self.conv2(x1)
        print('x2_raw', x2_raw.shape)
        x2 = torch.cat([x, x1_raw, x2_raw], 1)
        print('x2', x2.shape)
        x3_raw = self.conv3(x2)
        print('x3_raw', x3_raw.shape)
        x3 = torch.cat([x, x2_raw, x3_raw], 1)
        print('x3', x3.shape)
        output = self.conv4(x3)
        print('output', output.shape)
        return output


In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if __name__ == "__main__":
    # Instantiate the NestedUNet model
    model = NestedUNet().to(device)
    # Generate a sample input tensor
    input_tensor = torch.randn(1, 5, 320, 320).to(device)
    output = model(input_tensor)
    # print(output.shape)


x torch.Size([1, 5, 320, 320])
x1_raw torch.Size([1, 64, 320, 320])
x1 torch.Size([1, 69, 320, 320])
x2_raw torch.Size([1, 64, 320, 320])
x2 torch.Size([1, 133, 320, 320])
x3_raw torch.Size([1, 64, 320, 320])
x3 torch.Size([1, 133, 320, 320])
x0_0 torch.Size([1, 64, 320, 320])
x torch.Size([1, 64, 160, 160])
x1_raw torch.Size([1, 128, 160, 160])
x1 torch.Size([1, 192, 160, 160])
x2_raw torch.Size([1, 128, 160, 160])
x2 torch.Size([1, 320, 160, 160])
x3_raw torch.Size([1, 128, 160, 160])
x3 torch.Size([1, 320, 160, 160])
x torch.Size([1, 128, 80, 80])
x1_raw torch.Size([1, 256, 80, 80])
x1 torch.Size([1, 384, 80, 80])
x2_raw torch.Size([1, 256, 80, 80])
x2 torch.Size([1, 640, 80, 80])
x3_raw torch.Size([1, 256, 80, 80])
x3 torch.Size([1, 640, 80, 80])


In [47]:
import torch
import torch.nn as nn

class DDCB_64_128_128(nn.Module):
    def __init__(self, in_ch=64, mid_ch=128, out_ch):
        super(DDCB_64_128_128, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 128, 1), nn.ReLU(True), nn.Conv2d(128, 64, 3, padding=1,dilation=1),nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=2, dilation=2), nn.ReLU(True))
        self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=3, dilation=3), nn.ReLU(True))
        self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, out_ch, 3, padding=4, dilation=4), nn.ReLU(True))

        # self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 256, 1), nn.ReLU(True), nn.Conv2d(256, 64, 3, padding=1,dilation=1),nn.ReLU(True))
        # self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, 512, 3, padding=4, dilation=4), nn.ReLU(True))

    def forward(self, x):
        print('x', x.shape)
        x1_raw = self.conv1(x)
        print('x1_raw', x1_raw.shape)
        x1 = torch.cat([x, x1_raw], 1)
        print('x1', x1.shape)
        x2_raw = self.conv2(x1)
        print('x2_raw', x2_raw.shape)
        x2 = torch.cat([x, x1_raw, x2_raw], 1)
        print('x2', x2.shape)
        x3_raw = self.conv3(x2)
        print('x3_raw', x3_raw.shape)
        x3 = torch.cat([x, x2_raw, x3_raw], 1)
        print('x3', x3.shape)
        output = self.conv4(x3)
        print('output', output.shape)
        return output

if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Instantiate the NestedUNet model
    model = DDCB_64_128_128(64,128,128).to(device)
    # Generate a sample input tensor
    input_tensor = torch.randn(1,64,224,224).to(device)
    output = model(input_tensor)
    print(output.shape)

x torch.Size([1, 64, 224, 224])
x1_raw torch.Size([1, 128, 224, 224])
x1 torch.Size([1, 192, 224, 224])


RuntimeError: ignored

In [58]:
import torch
import torch.nn as nn

class DDCB_64_128_128(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(DDCB_64_128_128, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_ch, mid_ch, 1), nn.ReLU(True), nn.Conv2d(mid_ch, in_ch, 3, padding=1,dilation=1),nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=2, dilation=2), nn.ReLU(True))
        self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128 + 64, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=3, dilation=3), nn.ReLU(True))
        self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 256, out_ch, 3, padding=4, dilation=4), nn.ReLU(True))

        # self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 256, 1), nn.ReLU(True), nn.Conv2d(256, 64, 3, padding=1,dilation=1),nn.ReLU(True))
        # self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, 512, 3, padding=4, dilation=4), nn.ReLU(True))

    def forward(self, x):
        print('x', x.shape)
        x1_raw = self.conv1(x)
        print('x1_raw', x1_raw.shape)
        x1 = torch.cat([x, x1_raw], 1)
        print('x1', x1.shape)
        x2_raw = self.conv2(x1)
        print('x2_raw', x2_raw.shape)
        x2 = torch.cat([x, x1_raw, x2_raw], 1)
        print('x2', x2.shape)
        x3_raw = self.conv3(x2)
        print('x3_raw', x3_raw.shape)
        x3 = torch.cat([x, x2_raw, x3_raw], 1)
        print('x3', x3.shape)
        output = self.conv4(x3)
        # print('output', output.shape)
        return output

if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Instantiate the NestedUNet model
    model = DDCB_64_128_128(64,128,128).to(device)
    # Generate a sample input tensor
    input_tensor = torch.randn(1,64,224,224).to(device)
    output = model(input_tensor)
    print(output.shape)

x torch.Size([1, 64, 224, 224])
x1_raw torch.Size([1, 64, 224, 224])
x1 torch.Size([1, 128, 224, 224])
x2_raw torch.Size([1, 128, 224, 224])
x2 torch.Size([1, 256, 224, 224])
x3_raw torch.Size([1, 128, 224, 224])
x3 torch.Size([1, 320, 224, 224])
torch.Size([1, 128, 224, 224])


In [61]:
import torch
import torch.nn as nn

class DDCB_64_128_128(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(DDCB_64_128_128, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_ch, mid_ch, 1), nn.ReLU(True), nn.Conv2d(mid_ch, out_ch, 3, padding=1,dilation=1),nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 128, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=2, dilation=2), nn.ReLU(True))
        self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 256, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=3, dilation=3), nn.ReLU(True))
        self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 256, out_ch, 3, padding=4, dilation=4), nn.ReLU(True))

        # self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 256, 1), nn.ReLU(True), nn.Conv2d(256, 64, 3, padding=1,dilation=1),nn.ReLU(True))
        # self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, 512, 3, padding=4, dilation=4), nn.ReLU(True))

    def forward(self, x):
        print('x', x.shape)
        x1_raw = self.conv1(x)
        print('x1_raw', x1_raw.shape)
        x1 = torch.cat([x, x1_raw], 1)
        print('x1', x1.shape)
        x2_raw = self.conv2(x1)
        print('x2_raw', x2_raw.shape)
        x2 = torch.cat([x, x1_raw, x2_raw], 1)
        print('x2', x2.shape)
        x3_raw = self.conv3(x2)
        print('x3_raw', x3_raw.shape)
        x3 = torch.cat([x, x2_raw, x3_raw], 1)
        print('x3', x3.shape)
        output = self.conv4(x3)
        # print('output', output.shape)
        return output

if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Instantiate the NestedUNet model
    model = DDCB_64_128_128(64,128,128).to(device)
    # Generate a sample input tensor
    input_tensor = torch.randn(1,64,224,224).to(device)
    output = model(input_tensor)
    print(output.shape)

x torch.Size([1, 64, 224, 224])
x1_raw torch.Size([1, 128, 224, 224])
x1 torch.Size([1, 192, 224, 224])
x2_raw torch.Size([1, 128, 224, 224])
x2 torch.Size([1, 320, 224, 224])
x3_raw torch.Size([1, 128, 224, 224])
x3 torch.Size([1, 320, 224, 224])
torch.Size([1, 128, 224, 224])


In [91]:
import torch
import torch.nn as nn

class DDCB_128_256_256(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(DDCB_128_256_256, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_ch, mid_ch, 1), nn.ReLU(True), nn.Conv2d(mid_ch, out_ch, 3, padding=1,dilation=1),nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 256, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=2, dilation=2), nn.ReLU(True))
        self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 512, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=3, dilation=3), nn.ReLU(True))
        self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 512, out_ch, 3, padding=4, dilation=4), nn.ReLU(True))

        # self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 256, 1), nn.ReLU(True), nn.Conv2d(256, 64, 3, padding=1,dilation=1),nn.ReLU(True))
        # self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, 512, 3, padding=4, dilation=4), nn.ReLU(True))

    def forward(self, x):
        print('x', x.shape)
        x1_raw = self.conv1(x)
        print('x1_raw', x1_raw.shape)
        x1 = torch.cat([x, x1_raw], 1)
        print('x1', x1.shape)
        x2_raw = self.conv2(x1)
        print('x2_raw', x2_raw.shape)
        x2 = torch.cat([x, x1_raw, x2_raw], 1)
        print('x2', x2.shape)
        x3_raw = self.conv3(x2)
        print('x3_raw', x3_raw.shape)
        x3 = torch.cat([x, x2_raw, x3_raw], 1)
        print('x3', x3.shape)
        output = self.conv4(x3)
        # print('output', output.shape)
        return output

if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Instantiate the NestedUNet model
    model = DDCB_128_256_256(128,256,256).to(device)
    # Generate a sample input tensor
    input_tensor = torch.randn(1,128,224,224).to(device)
    output = model(input_tensor)
    print(output.shape)

x torch.Size([1, 128, 224, 224])
x1_raw torch.Size([1, 256, 224, 224])
x1 torch.Size([1, 384, 224, 224])
x2_raw torch.Size([1, 256, 224, 224])
x2 torch.Size([1, 640, 224, 224])
x3_raw torch.Size([1, 256, 224, 224])
x3 torch.Size([1, 640, 224, 224])
torch.Size([1, 256, 224, 224])


In [1]:
import torch
import torch.nn as nn

class DDCB_256_512_512(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(DDCB_256_512_512, self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(in_ch, mid_ch, 1), nn.ReLU(True), nn.Conv2d(mid_ch, out_ch, 3, padding=1,dilation=1),nn.ReLU(True))
        self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 512, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, 512, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 1024, mid_ch, 1), nn.ReLU(True),nn.Conv2d(mid_ch, out_ch, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 1024, out_ch, 3, padding=4, dilation=4), nn.ReLU(True))

        # self.conv1 = nn.Sequential(nn.Conv2d(in_ch, 256, 1), nn.ReLU(True), nn.Conv2d(256, 64, 3, padding=1,dilation=1),nn.ReLU(True))
        # self.conv2 = nn.Sequential(nn.Conv2d(in_ch + 64, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=2, dilation=2), nn.ReLU(True))
        # self.conv3 = nn.Sequential(nn.Conv2d(in_ch + 128, 256, 1), nn.ReLU(True),nn.Conv2d(256, 64, 3, padding=3, dilation=3), nn.ReLU(True))
        # self.conv4 = nn.Sequential(nn.Conv2d(in_ch + 128, 512, 3, padding=4, dilation=4), nn.ReLU(True))

    def forward(self, x):
        print('x', x.shape)
        x1_raw = self.conv1(x)
        print('x1_raw', x1_raw.shape)
        x1 = torch.cat([x, x1_raw], 1)
        print('x1', x1.shape)
        x2_raw = self.conv2(x1)
        print('x2_raw', x2_raw.shape)
        x2 = torch.cat([x, x1_raw, x2_raw], 1)
        print('x2', x2.shape)
        # x3_raw = self.conv3(x2)
        # print('x3_raw', x3_raw.shape)
        # x3 = torch.cat([x, x2_raw, x3_raw], 1)
        # print('x3', x3.shape)
        # output = self.conv4(x3)
        # # print('output', output.shape)
        return x2

if __name__ == "__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Instantiate the NestedUNet model
    model = DDCB_256_512_512(256,512,512).to(device)
    # Generate a sample input tensor
    input_tensor = torch.randn(1,256,512,512).to(device)
    output = model(input_tensor)
    print(output.shape)

x torch.Size([1, 256, 512, 512])
x1_raw torch.Size([1, 512, 512, 512])
x1 torch.Size([1, 768, 512, 512])
x2_raw torch.Size([1, 512, 512, 512])
x2 torch.Size([1, 1280, 512, 512])
torch.Size([1, 1280, 512, 512])
