In [1]:
from torch.utils.tensorboard import SummaryWriter
import torch

In [2]:
writer = SummaryWriter()
image1 = torch.randn(1, 3, 256, 256)
image2 = torch.randn(1, 3, 256, 256)
example = [image1, image2]

In [7]:
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import math
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dilation = dilation

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, nInputChannels, block, layers, os=16, pretrained=False):
        self.inplanes = 64
        super(ResNet, self).__init__()
        if os == 8:
            strides = [1, 2, 1, 1]
            dilations = [1, 1, 2, 2]
            blocks = [1, 2, 1]
        else:
            raise NotImplementedError

        self.conv1 = nn.Conv2d(nInputChannels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2])
        self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3])

        self._init_weight()

        if pretrained:
            self._load_pretrained_model()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion))

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_MG_unit(self, block, planes, blocks=[1, 2, 4], stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion))

        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation=blocks[0] * dilation, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, len(blocks)):
            layers.append(block(self.inplanes, planes, stride=1, dilation=blocks[i] * dilation))

        return nn.Sequential(*layers)

    def forward(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        return x1, x2, x3, x4

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                # m.weight.data.normal_(0, math.sqrt(2. / n))
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _load_pretrained_model(self):
        pretrain_dict = model_zoo.load_url(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
                                           model_dir='D:/Projects/Segmentation/pretrained')
        model_dict = {}
        state_dict = self.state_dict()
        for k, v in pretrain_dict.items():
            if k in state_dict:
                model_dict[k] = v
        state_dict.update(model_dict)
        self.load_state_dict(state_dict)


def ResNet101(nInputChannels=3, os=8, pretrained=False):
    model = ResNet(nInputChannels, Bottleneck, [3, 4, 23, 3], os, pretrained=pretrained)
    return model


class ASPP_module(nn.Module):
    def __init__(self, inplanes, planes, dilation):
        super(ASPP_module, self).__init__()
        if dilation == 1:
            kernel_size = 1
            padding = 0
        else:
            kernel_size = 3
            padding = dilation

        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=1, padding=padding,
                                     dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)
        x = self.relu(x)

        return x

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            # torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class APPAP(nn.Module):
    """ Position attention module"""

    # Ref from SAGAN
    def __init__(self, in_dim, os):
        super(APPAP, self).__init__()
        self.chanel_in = in_dim
        self.os = os
        self.query_conv = nn.Conv2d(in_channels=640, out_channels=128, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=640, out_channels=128, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=512, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        if self.os == 8:
            dilations = [1, 2, 3, 6]
        self.aspp1 = ASPP_module(inplanes=512, planes=128, dilation=dilations[0])
        self.aspp2 = ASPP_module(inplanes=512, planes=128, dilation=dilations[1])
        self.aspp3 = ASPP_module(inplanes=512, planes=128, dilation=dilations[2])
        self.aspp4 = ASPP_module(inplanes=512, planes=128, dilation=dilations[3])
        self.aspp5 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Conv2d(512, 128, 1, stride=1, bias=False),
                                   nn.BatchNorm2d(128), nn.ReLU())
        self.conv = nn.Conv2d(in_channels=640, out_channels=512, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        m_batchsize, C, height, width = x.size()  ### [2 512 32 32] m=2 C =  512 h = 32 w = 32
        #proj_query1 = self.query_conv(x)
        p_q1 = self.aspp1(x)
        #print(p_q1.shape)           [2 128 32 32]
        p_q2 = self.aspp2(x)
        #print(p_q2.shape)           [2 128 32 32]
        p_q3 = self.aspp3(x)
        #print(p_q3.shape)           [2 128 32 32]
        p_q4 = self.aspp4(x)
        #print(p_q4.shape)           [2 128 32 32]
        p_q5 = self.aspp5(x)
        #print(p_q5.shape)           [2 128 1 1]
        p_q5 = F.interpolate(p_q5, size=p_q4.size()[2:], mode='bilinear', align_corners=True)
        #print(p_q5.shape)           [2 128 32 32]
        proj_query = torch.cat((p_q1, p_q2, p_q3, p_q4, p_q5), dim=1)
        #print(proj_query.shape)      [2 640 32 32]
        proj_query = self.query_conv(proj_query).view(m_batchsize, -1, width * height).permute(0, 2, 1)
        #print(proj_query.shape)      [2 1024 128]
        #proj_key1 = self.key_conv(x)
        p_k1 = self.aspp1(x)
        p_k2 = self.aspp2(x)
        p_k3 = self.aspp3(x)
        p_k4 = self.aspp4(x)
        p_k5 = self.aspp5(x)
        p_k5 = F.interpolate(p_k5, size=p_k4.size()[2:], mode='bilinear', align_corners=True)
        proj_key = torch.cat((p_k1, p_k2, p_k3, p_k4, p_k5), dim=1)
        proj_key = self.key_conv(proj_key).view(m_batchsize, -1, width * height)
        #print(proj_key.shape) [2 128 1024]
        energy = torch.bmm(proj_query, proj_key)
        #print(energy.shape) [2 1024 1024]
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)
        #print(proj_value.shape) [2 512 1024]
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        #print(out.shape)  [2 512 1024]
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma * out + x
        return out


class APPAC(nn.Module):

    def __init__(self, in_dim, os):
        super(APPAC, self).__init__()
        self.chanel_in = in_dim
        self.os = os
        if self.os == 8:
            dilations = [1, 2, 3, 6]

        self.aspp1 = ASPP_module(inplanes=512, planes=128, dilation=dilations[0])
        self.aspp2 = ASPP_module(inplanes=512, planes=128, dilation=dilations[1])
        self.aspp3 = ASPP_module(inplanes=512, planes=128, dilation=dilations[2])
        self.aspp4 = ASPP_module(inplanes=512, planes=128, dilation=dilations[3])
        self.aspp5 = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Conv2d(512, 128, 1, stride=1, bias=False),
                                   nn.BatchNorm2d(128), nn.ReLU())
        self.conv1 = nn.Conv2d(in_channels=640, out_channels=512, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        m_batchsize, C, height, width = x.size()
        #print(x.shape) [2 512 32 32]
        p_q1 = self.aspp1(x)
        p_q2 = self.aspp2(x)
        p_q3 = self.aspp3(x)
        p_q4 = self.aspp4(x)
        p_q5 = self.aspp5(x)
        p_q5 = F.interpolate(p_q5, size=p_q4.size()[2:], mode='bilinear', align_corners=True)
        proj_query = torch.cat((p_q1, p_q2, p_q3, p_q4, p_q5), dim=1)
        proj_query = self.conv1(proj_query).view(m_batchsize, C, -1)
        #print(proj_query.shape) [2 512 1024]

        p_k1 = self.aspp1(x)
        p_k2 = self.aspp2(x)
        p_k3 = self.aspp3(x)
        p_k4 = self.aspp4(x)
        p_k5 = self.aspp5(x)
        p_k5 = F.interpolate(p_k5, size=p_k4.size()[2:], mode='bilinear', align_corners=True)
        proj_key = torch.cat((p_k1, p_k2, p_k3, p_k4, p_k5), dim=1)
        proj_key = self.conv1(proj_key).view(m_batchsize, C, -1).permute(0, 2, 1)
        #print(proj_key.shape)  [2 1024 612]
        energy = torch.bmm(proj_query, proj_key)
        #print(energy.shape) [2 512 512]
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = self.softmax(energy_new)
        proj_value = x.view(m_batchsize, C, -1)

        out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma * out + x
        #print(out.shape) [2 512 32 32]
        return out


class APPAHead(nn.Module):
    def __init__(self, in_channels, norm_layer, os):
        super(APPAHead, self).__init__()
        inter_channels = in_channels // 4
        self.conv5a = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                                    norm_layer(inter_channels),
                                    nn.ReLU())

        self.conv5c = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False),
                                    norm_layer(inter_channels),
                                    nn.ReLU())

        self.appap = APPAP(inter_channels, os)
        self.appac = APPAC(inter_channels, os)
        self.conv51 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
                                    norm_layer(inter_channels),
                                    nn.ReLU())
        self.conv52 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, padding=1, bias=False),
                                    norm_layer(inter_channels),
                                    nn.ReLU())

    def forward(self, x):
        feat1 = self.conv5a(x)
        ###print(feat1.shape) [2 512 32 32]
        appap_feat = self.appap(feat1)
        appap_conv = self.conv51(appap_feat)

        feat2 = self.conv5c(x)
        appac_feat = self.appac(feat2)
        appac_conv = self.conv52(appac_feat)

        feat_sum = appap_conv + appac_conv

        return feat_sum


class APPANet(nn.Module):
    def __init__(self, nInputChannels=3, n_classes=2, os=8, aux=False, pretrained=False, _print=True):
        if _print:
            print("Constructing APPANet model...")
            print("Backbone: Resnet-101")
            print("Number of classes: {}".format(n_classes))
            print("Output stride: {}".format(os))
            print("Number of Input Channels: {}".format(nInputChannels))
        super(APPANet, self).__init__()
        self.head = APPAHead(2048, nn.BatchNorm2d, os)
        # Atrous Convolution
        self.resnet_features = ResNet101(nInputChannels, os, pretrained=pretrained)

        self.conv1 = nn.Conv2d(512, 256, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(256)
        self.relu = nn.ReLU()
        # adopt [1x1, 48] for channel reduction
        self.conv2 = nn.Conv2d(256, 48, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(48)

        self.conv3 = nn.Conv2d(64, 48, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(48)

        self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(256),
                                       nn.ReLU(),
                                       nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                       nn.BatchNorm2d(256),
                                       nn.ReLU(),
                                       nn.Conv2d(256, n_classes, kernel_size=1, stride=1))

    def forward(self, x_1, x_2):
        x = torch.cat((x_1, x_2), 1)
        ###print(x.shape) [2 6 256 256]
        x1, x2, x3, x4 = self.resnet_features(x)
        ###print(x1.shape,x2.shape,x3.shape,x4.shape)[2 256 64 64] [2 512 32 32] [2 1024 32 32] [2 2048 32 32]

        x_sum = self.head(x4)
        #print(x_sum.shape)   [2 512 32 32]
        x = self.conv1(x_sum)
        x = self.bn1(x)
        x = self.relu(x)
        #print(x.shape)  [2 256 32 32]
        x = F.interpolate(x, size=(int(math.ceil(x_1.size()[-2] / 4)),
                                   int(math.ceil(x_1.size()[-1] / 4))), mode='bilinear', align_corners=True)
        #print(x.shape) #[2 256 64 64]
        x1 = self.conv2(x1)
        x1 = self.bn2(x1)
        x1 = self.relu(x1)
        #print(x1.shape)
        x = torch.cat((x, x1), dim=1)

        x = self.last_conv(x)
        x = F.interpolate(x, size=x_1.size()[2:], mode='bilinear', align_corners=True)
        #x = F.softmax(x,dim=1)
        x = F.sigmoid(x)
        #print(x.shape)
        return x


# if __name__ == "__main__":
#     model = APPANet(nInputChannels=6, n_classes=2, os=8, pretrained=False, _print=True)
#     model.eval()
#     image1 = torch.randn(1, 3, 256, 256)
#     image2 = torch.randn(1, 3, 256, 256)
#     with torch.no_grad():
#         output = model.forward(image1, image2)

from torchsummary import summary

model = APPANet(nInputChannels=6, n_classes=2, os=8, pretrained=False, _print=True)
writer.add_graph(model=model, input_to_model=example)
summary(model, input_size=[(3, 256, 256), (3, 256, 256)], batch_size=2, device="cpu")

Constructing APPANet model...
Backbone: Resnet-101
Number of classes: 2
Output stride: 8
Number of Input Channels: 6


  x = F.interpolate(x, size=(int(math.ceil(x_1.size()[-2] / 4)),
  int(math.ceil(x_1.size()[-1] / 4))), mode='bilinear', align_corners=True)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [2, 64, 128, 128]          18,816
       BatchNorm2d-2          [2, 64, 128, 128]             128
              ReLU-3          [2, 64, 128, 128]               0
         MaxPool2d-4            [2, 64, 64, 64]               0
            Conv2d-5            [2, 64, 64, 64]           4,096
       BatchNorm2d-6            [2, 64, 64, 64]             128
              ReLU-7            [2, 64, 64, 64]               0
            Conv2d-8            [2, 64, 64, 64]          36,864
       BatchNorm2d-9            [2, 64, 64, 64]             128
             ReLU-10            [2, 64, 64, 64]               0
           Conv2d-11           [2, 256, 64, 64]          16,384
      BatchNorm2d-12           [2, 256, 64, 64]             512
           Conv2d-13           [2, 256, 64, 64]          16,384
      BatchNorm2d-14           [2, 256,

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class CDNet(nn.Module):
    def __init__(self, in_ch=6, out_ch=2):
        super(CDNet, self).__init__()
        filters = 64
        self.conv1 = nn.Conv2d(in_ch, filters, kernel_size=7, padding=3, stride=1)
        self.bn = nn.BatchNorm2d(filters)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(filters, filters, kernel_size=7, padding=3, stride=1)

        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.final = nn.Conv2d(filters, out_ch, kernel_size=1, stride=1)
        #self.sigmod = nn.Sigmoid(dim=1)

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), 1)

        x = self.pool(self.relu(self.bn(self.conv1(x))))
        x = self.pool(self.relu(self.bn(self.conv2(x))))
        x = self.pool(self.relu(self.bn(self.conv2(x))))
        x = self.pool(self.relu(self.bn(self.conv2(x))))

        x = self.relu(self.bn(self.conv2(self.up(x))))
        x = self.relu(self.bn(self.conv2(self.up(x))))
        x = self.relu(self.bn(self.conv2(self.up(x))))
        x = self.relu(self.bn(self.conv2(self.up(x))))

        x = self.final(x)
        x = torch.sigmoid(x)
        #print(x.shape)

        return x



from torchsummary import summary
model = CDNet(in_ch = 6,out_ch =2)
writer.add_graph(model=model, input_to_model=example)
summary(model,input_size=[(3,256,256),(3,256,256)],batch_size = 2, device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [2, 64, 256, 256]          18,880
       BatchNorm2d-2          [2, 64, 256, 256]             128
              ReLU-3          [2, 64, 256, 256]               0
         MaxPool2d-4          [2, 64, 128, 128]               0
            Conv2d-5          [2, 64, 128, 128]         200,768
       BatchNorm2d-6          [2, 64, 128, 128]             128
              ReLU-7          [2, 64, 128, 128]               0
         MaxPool2d-8            [2, 64, 64, 64]               0
            Conv2d-9            [2, 64, 64, 64]         200,768
      BatchNorm2d-10            [2, 64, 64, 64]             128
             ReLU-11            [2, 64, 64, 64]               0
        MaxPool2d-12            [2, 64, 32, 32]               0
           Conv2d-13            [2, 64, 32, 32]         200,768
      BatchNorm2d-14            [2, 64,

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class FCN_CD(nn.Module):
    def __init__(self, in_ch=6, out_ch=2):
        super(FCN_CD, self).__init__()
        filters = [64, 128, 256, 512, 4096]
        self.conv1 = nn.Conv2d(in_ch, filters[0], kernel_size=3, padding=1, stride=1)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)
        self.conv2 = nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(filters[0], filters[1], kernel_size=3, padding=1, stride=1)
        self.conv4 = nn.Conv2d(filters[1], filters[1], kernel_size=3, padding=1, stride=1)
        self.conv5 = nn.Conv2d(filters[1], filters[2], kernel_size=3, padding=1, stride=1)
        self.conv6 = nn.Conv2d(filters[2], filters[2], kernel_size=3, padding=1, stride=1)
        self.conv7 = nn.Conv2d(filters[2], filters[3], kernel_size=3, padding=1, stride=1)
        self.conv8 = nn.Conv2d(filters[3], filters[3], kernel_size=3, padding=1, stride=1)
        self.conv9 = nn.Conv2d(filters[3], filters[4], kernel_size=7, padding=3, stride=1)
        self.conv10 = nn.Conv2d(filters[4], filters[4], kernel_size=1, stride=1)
        self.conv11 = nn.Conv2d(filters[4], out_ch, kernel_size=1, stride=1)
        self.deconv1 = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv12 = nn.Conv2d(filters[3], out_ch, kernel_size=1, padding=0, stride=1)
        self.deconv2 = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=4, padding=1, stride=2)
        self.conv13 = nn.Conv2d(filters[2], out_ch, kernel_size=3, padding=1, stride=1)
        self.deconv3 = nn.ConvTranspose2d(out_ch, out_ch, kernel_size=8, stride=8)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), 1)
        #print(x.shape)
        x = self.pool(self.relu(self.conv2(self.relu(self.conv1(x)))))
        #print(x.shape)
        x = self.pool(self.relu(self.conv4(self.relu(self.conv3(x)))))
        #print(x.shape)
        x1 = self.pool(self.relu(self.conv6(self.relu(self.conv5(x)))))
        #print(x1.shape)
        x2 = self.pool(self.relu(self.conv8(self.relu(self.conv7(x1)))))
        #print(x2.shape)
        x3 = self.pool(self.relu(self.conv8(self.relu(self.conv8(x2)))))
        #print(x3.shape)
        x4 = self.dropout(self.relu(self.conv9(x3)))
        #print(x4.shape)
        x5 = self.dropout(self.relu(self.conv10(x4)))
        #print(x5.shape)
        x6 = self.deconv1(self.conv11(x5))
        #print(x6.shape)
        x7 = self.conv12(x2)
        #print(x7.shape)
        x8 = self.deconv2(x7 + x6)
        print(x8.shape)
        x9 = self.conv13(x1)
        print(x9.shape)
        x10 = self.deconv3(x8 + x9)
        final = self.softmax(x10)

        return final



from torchsummary import summary
model = FCN_CD(in_ch = 6,out_ch =2)
writer.add_graph(model=model, input_to_model=example)
summary(model,input_size=[(3,256,256),(3,256,256)],batch_size = 2, device="cpu")

torch.Size([1, 2, 32, 32])
torch.Size([1, 2, 32, 32])
torch.Size([1, 2, 32, 32])
torch.Size([1, 2, 32, 32])
torch.Size([1, 2, 32, 32])
torch.Size([1, 2, 32, 32])
torch.Size([2, 2, 32, 32])
torch.Size([2, 2, 32, 32])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [2, 64, 256, 256]           3,520
              ReLU-2          [2, 64, 256, 256]               0
            Conv2d-3          [2, 64, 256, 256]          36,928
              ReLU-4          [2, 64, 256, 256]               0
         MaxPool2d-5          [2, 64, 128, 128]               0
            Conv2d-6         [2, 128, 128, 128]          73,856
              ReLU-7         [2, 128, 128, 128]               0
            Conv2d-8         [2, 128, 128, 128]         147,584
              ReLU-9         [2, 128, 128, 128]               0
        MaxPool2d-10           [2, 128, 64, 64]               0
           Conv

In [None]:
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F


def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )


class BottleNeck(nn.Module):
    """Residual block for resnet over 50 layers
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
        )

        self.shortcut = nn.Sequential()

        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels * BottleNeck.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))


class DCM(nn.Module):
    def __init__(self, in_C, out_C):
        super(DCM, self).__init__()
        self.ks = [1, 3, 5]
        if in_C == 2048:
            self.mid_C = in_C // 4
        if in_C == 1024:
            self.mid_C = in_C // 2
        if in_C == 512:
            self.mid_C = in_C
        if in_C == 256:
            self.mid_C = in_C
        if in_C == 128:
            self.mid_C = in_C
        if in_C == 64:
            self.mid_C = in_C
        self.ger_kernel_branches = nn.ModuleList()
        for k in self.ks:
            self.ger_kernel_branches.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(k),
                    nn.Conv2d(in_C, self.mid_C, kernel_size=1)
                )
            )

        self.trans_branches = nn.ModuleList()
        self.fuse_inside_branches = nn.ModuleList()
        for i in range(len(self.ks)):
            self.trans_branches.append(
                nn.Conv2d(in_C, self.mid_C, kernel_size=1)
            )
            self.fuse_inside_branches.append(
                nn.Conv2d(self.mid_C, self.mid_C, 1)
            )

        self.fuse_outside = nn.Conv2d(len(self.ks) * self.mid_C + in_C, out_C, 1)

    def forward(self, x, y):
        """
        x: 被卷积的特征
        y: 用来生成卷积核
        """
        feats_branches = [x]
        for i in range(len(self.ks)):
            kernel = self.ger_kernel_branches[i](y)
            kernel_single = kernel.split(1, dim=0)
            x_inside = self.trans_branches[i](x)
            x_inside_single = x_inside.split(1, dim=0)
            feat_single = []
            for kernel_single_item, x_inside_single_item \
                    in zip(kernel_single, x_inside_single):
                feat_inside_single = self.fuse_inside_branches[i](
                    F.conv2d(
                        x_inside_single_item,
                        weight=kernel_single_item.transpose(0, 1),
                        bias=None,
                        stride=1,
                        padding=self.ks[i] // 2,
                        dilation=1,
                        groups=self.mid_C
                    )
                )
                feat_single.append(feat_inside_single)
            feat_single = torch.cat(feat_single, dim=0)
            feats_branches.append(feat_single)
        return self.fuse_outside(torch.cat(feats_branches, dim=1))


class ContextBlock(nn.Module):
    def __init__(self, inplanes, ratio, pooling_type='att',
                 fusion_types=('channel_add',)):
        super(ContextBlock, self).__init__()
        valid_fusion_types = ['channel_add', 'channel_mul']

        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'

        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types

        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_mul_conv = None

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)
        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)
        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term
        return out


class MSAANet(nn.Module):

    def __init__(self, in_channel, out_channel, block, num_block, _print=True):
        super(MSAANet, self).__init__()
        if _print:
            print("Constructing MSAANet model...")
            print("Backbone: Resnet-101")
            print("Number of classes: {}".format(out_channel))
            print("Number of Input Channels: {}".format(in_channel))
        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # we use a different inputsize than the original paper
        # so conv2_x's stride is 1
        self.conv2_x = self._make_layer(block, 64, num_block[0], 2)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        # self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv_1 = nn.Conv2d(3072, 1024, 1)
        self.dconv_up3 = double_conv(1024, 512)
        self.conv_2 = nn.Conv2d(1024, 512, 1)
        self.dconv_up2 = double_conv(512, 256)
        self.conv_3 = nn.Conv2d(512, 256, 1)
        self.dconv_up1 = double_conv(256, 128)
        #self.dconv_up0 = double_conv(192, 128)

        self.dconv_last = nn.Sequential(
            nn.Conv2d(192, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, out_channel, 1)
        )
        self.cb1 = ContextBlock(inplanes=64, ratio=1. / 16., pooling_type='att')
        self.dcm1 = DCM(in_C=64, out_C=64)
        self.cb2 = ContextBlock(inplanes=256, ratio=1. / 16., pooling_type='att')
        self.dcm2 = DCM(in_C=256, out_C=256)
        self.cb3 = ContextBlock(inplanes=512, ratio=1. / 16., pooling_type='att')
        self.dcm3 = DCM(in_C=512, out_C=512)
        self.cb4 = ContextBlock(inplanes=1024, ratio=1. / 16., pooling_type='att')
        self.dcm4 = DCM(in_C=1024, out_C=1024)
        self.cb5 = ContextBlock(inplanes=2048, ratio=1. / 16., pooling_type='att')
        self.dcm5 = DCM(in_C=2048, out_C=2048)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """make resnet layers(by layer i didnt mean this 'layer' was the
        same as a neuron netowork layer, ex. conv layer), one layer may
        contain more than one residual block
        Args:
            block: block type, basic block or bottle neck block
            out_channels: output depth channel number of this layer
            num_blocks: how many blocks per layer
            stride: the stride of the first block of this layer

        Return:
            return a resnet layer
        """

        # we have num_block blocks per layer, the first block
        # could be 1 or 2, other blocks would always be 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), dim=1)
        ###print(x.shape) [1 6 256 256]
        conv1 = self.conv1(x)
        ###print(conv1.shape) [1 64 256 256]
        temp = self.maxpool(conv1)
        ###print(temp.shape)  [1 64 128 128]  stage1
        temp = self.dcm1(temp, temp)
        temp = self.cb1(temp)
        #print(temp.shape)
        conv2 = self.conv2_x(temp)
        ###print(conv2.shape) [1 256 64 64]   stage2
        conv2 = self.dcm2(conv2, conv2)
        conv2 = self.cb2(conv2)
        conv3 = self.conv3_x(conv2)
        ###print(conv3.shape) [1 512 32 32]   stage3
        conv3 = self.dcm3(conv3, conv3)
        conv3 = self.cb3(conv3)
        conv4 = self.conv4_x(conv3)
        ###print(conv4.shape) [1 1024 16 16]  stage4
        conv4 = self.dcm4(conv4, conv4)
        conv4 = self.cb4(conv4)
        bottle = self.conv5_x(conv4)
        ###print(bottle.shape) [1 2048 8 8]   stage5
        # output = self.avg_pool(output)
        # output = output.view(output.size(0), -1)
        # output = self.fc(output)
        bottle = self.dcm5(bottle, bottle)
        bottle = self.cb5(bottle)
        x = self.upsample(bottle)
        ###print(x.shape)  [1 2048 16 16]
        # print(x.shape)
        # print(conv4.shape)
        x = torch.cat([x, conv4], dim=1)
        ###print(x.shape)  [1 3072 16 16]
        x = self.conv_1(x)  ###[1 1024 16 16]
        x = self.dconv_up3(x)  ###[1 512 16 16]
        ###print(x.shape)###
        x = self.upsample(x)

        x = torch.cat([x, conv3], dim=1)
        ###print(x.shape)   [1 512 32 32]
        # print(x.shape)
        # print(conv3.shape) dim=1)###[1 1024 32 32]
        x = self.conv_2(x)  ###[1 512 32 32]
        x = self.dconv_up2(x)  ###[1 256 32 32]
        x = self.upsample(x)  ###[1 256 64 64]
        x = torch.cat([x, conv2], dim=1)  ###[1 512 64 64]
        x = self.conv_3(x)
        x = self.dconv_up1(x)
        x = self.upsample(x)  ###[1 128 128 128]
        x = torch.cat([x, temp], dim=1)  ###[1 192 128 128]
        out = self.dconv_last(x)
        #print(out.shape)
        #x = F.softmax(x,dim=1)
        out = torch.sigmoid(out)
        return out


def Get_MSAANet(in_channel=6, out_channel=2):
    return MSAANet(in_channel, out_channel, block=BottleNeck, num_block=[3, 4, 23, 3], _print=True)


if __name__ == "__main__":
    model = Get_MSAANet(6, 2)
    model.eval()
    image1 = torch.randn(1, 3, 256, 256)
    image2 = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        output = model.forward(image1, image2)



In [None]:
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F


def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )


class BottleNeck(nn.Module):
    """Residual block for resnet over 50 layers
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels * BottleNeck.expansion),
        )

        self.shortcut = nn.Sequential()

        if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels * BottleNeck.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))


class DCM(nn.Module):
    def __init__(self, in_C, out_C):
        super(DCM, self).__init__()
        self.ks = [1, 3, 5]
        if in_C == 2048:
            self.mid_C = in_C // 4
        if in_C == 1024:
            self.mid_C = in_C // 2
        if in_C == 512:
            self.mid_C = in_C
        if in_C == 256:
            self.mid_C = in_C
        if in_C == 128:
            self.mid_C = in_C
        if in_C == 64:
            self.mid_C = in_C
        self.ger_kernel_branches = nn.ModuleList()
        for k in self.ks:
            self.ger_kernel_branches.append(
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(k),
                    nn.Conv2d(in_C, self.mid_C, kernel_size=1)
                )
            )

        self.trans_branches = nn.ModuleList()
        self.fuse_inside_branches = nn.ModuleList()
        for i in range(len(self.ks)):
            self.trans_branches.append(
                nn.Conv2d(in_C, self.mid_C, kernel_size=1)
            )
            self.fuse_inside_branches.append(
                nn.Conv2d(self.mid_C, self.mid_C, 1)
            )

        self.fuse_outside = nn.Conv2d(len(self.ks) * self.mid_C + in_C, out_C, 1)

    def forward(self, x, y):
        """
        x: 被卷积的特征
        y: 用来生成卷积核
        """
        feats_branches = [x]
        for i in range(len(self.ks)):
            kernel = self.ger_kernel_branches[i](y)
            kernel_single = kernel.split(1, dim=0)
            x_inside = self.trans_branches[i](x)
            x_inside_single = x_inside.split(1, dim=0)
            feat_single = []
            for kernel_single_item, x_inside_single_item \
                    in zip(kernel_single, x_inside_single):
                feat_inside_single = self.fuse_inside_branches[i](
                    F.conv2d(
                        x_inside_single_item,
                        weight=kernel_single_item.transpose(0, 1),
                        bias=None,
                        stride=1,
                        padding=self.ks[i] // 2,
                        dilation=1,
                        groups=self.mid_C
                    )
                )
                feat_single.append(feat_inside_single)
            feat_single = torch.cat(feat_single, dim=0)
            feats_branches.append(feat_single)
        return self.fuse_outside(torch.cat(feats_branches, dim=1))
        #return torch.cat((fuse_outside,x),dim=1)


class ContextBlock(nn.Module):
    def __init__(self, inplanes, ratio, pooling_type='att',
                 fusion_types=('channel_add',)):
        super(ContextBlock, self).__init__()
        valid_fusion_types = ['channel_add', 'channel_mul']

        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'

        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types

        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_mul_conv = None

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)
        return context

    def forward(self, x, x1):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)
        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term
        out = torch.cat((out, x1), dim=1)
        return out


class MSAANet(nn.Module):

    def __init__(self, in_channel, out_channel, block, num_block):
        super(MSAANet, self).__init__()

        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # we use a different inputsize than the original paper
        # so conv2_x's stride is 1
        self.conv2_x = self._make_layer(block, 64, num_block[0], 2)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        # self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        # self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv_1 = nn.Conv2d(3584, 2048, 1)
        self.dconv_up3 = double_conv(2048, 1024)
        self.conv_2 = nn.Conv2d(1536, 512, 1)
        self.dconv_up2 = double_conv(512, 256)
        self.conv_3 = nn.Conv2d(512, 256, 1)
        self.dconv_up1 = double_conv(256, 128)
        #self.dconv_up0 = double_conv(192, 128)

        self.dconv_last = nn.Sequential(
            nn.Conv2d(192, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, out_channel, 1)
        )
        self.cb1 = ContextBlock(inplanes=64, ratio=1. / 16., pooling_type='att')
        self.dcm1 = DCM(in_C=64, out_C=64)
        self.cb2 = ContextBlock(inplanes=256, ratio=1. / 16., pooling_type='att')
        self.dcm2 = DCM(in_C=256, out_C=256)
        self.cb3 = ContextBlock(inplanes=512, ratio=1. / 16., pooling_type='att')
        self.dcm3 = DCM(in_C=512, out_C=512)
        self.cb4 = ContextBlock(inplanes=512, ratio=1. / 16., pooling_type='att')
        self.dcm4 = DCM(in_C=1024, out_C=512)
        self.cb5 = ContextBlock(inplanes=512, ratio=1. / 16., pooling_type='att')
        self.dcm5 = DCM(in_C=2048, out_C=512)

        self.conv1_1x = nn.Conv2d(128, 64, 1)
        self.conv2_1x = nn.Conv2d(512, 256, 1)
        self.conv3_1x = nn.Conv2d(1024, 512, 1)
        self.conv4_1x = nn.Conv2d(1536, 1024, 1)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """make resnet layers(by layer i didnt mean this 'layer' was the
        same as a neuron netowork layer, ex. conv layer), one layer may
        contain more than one residual block
        Args:
            block: block type, basic block or bottle neck block
            out_channels: output depth channel number of this layer
            num_blocks: how many blocks per layer
            stride: the stride of the first block of this layer

        Return:
            return a resnet layer
        """

        # we have num_block blocks per layer, the first block
        # could be 1 or 2, other blocks would always be 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), dim=1)
        ###print(x.shape) [1 6 256 256]
        conv1 = self.conv1(x)
        ###print(conv1.shape) [1 64 256 256]
        temp = self.maxpool(conv1)
        temp_1 = temp
        ###print(temp.shape)  [1 64 128 128]  stage1
        temp = self.dcm1(temp, temp)
        temp = self.cb1(temp, temp_1)
        #print(temp.shape)
        temp = self.conv1_1x(temp)
        conv2 = self.conv2_x(temp)
        ###print(conv2.shape)##
        ###print(conv2.shape) [1 256 64 64]   stage2
        conv2_1 = conv2
        conv2 = self.dcm2(conv2, conv2)
        conv2 = self.cb2(conv2, conv2_1)
        ##print(conv2.shape)[1 512 64 64]
        conv2 = self.conv2_1x(conv2)
        conv3 = self.conv3_x(conv2)
        ###print(conv3.shape) ###[1 512 32 32]   stage3
        conv3_1 = conv3
        conv3 = self.dcm3(conv3, conv3)
        conv3 = self.cb3(conv3, conv3_1)
        conv3 = self.conv3_1x(conv3)
        conv4 = self.conv4_x(conv3)
        ##print(conv4.shape) ###[1 1024 16 16]  stage4
        conv4_1 = conv4
        conv4 = self.dcm4(conv4, conv4)
        ###print(conv4.shape)
        conv4 = self.cb4(conv4, conv4_1)
        ##print(conv4.shape)##[1 1536 16 16]
        conv4 = self.conv4_1x(conv4)
        bottle = self.conv5_x(conv4)
        ###print(bottle.shape) [1 2048 8 8]   stage5
        # output = self.avg_pool(output)
        # output = output.view(output.size(0), -1)
        # output = self.fc(output)
        botttle_1 = bottle
        bottle = self.dcm5(bottle, bottle)
        bottle = self.cb5(bottle, botttle_1)
        x = self.upsample(bottle)
        ###print(x.shape)  [1 2048 16 16]
        # print(x.shape)
        # print(conv4.shape)
        x = torch.cat([x, conv4], dim=1)
        ###print(x.shape)  [1 3072 16 16]
        x = self.conv_1(x)  ###[1 1024 16 16]
        x = self.dconv_up3(x)  ###[1 512 16 16]
        ###print(x.shape)###
        x = self.upsample(x)

        x = torch.cat([x, conv3], dim=1)
        ###print(x.shape)   [1 512 32 32]
        # print(x.shape)
        # print(conv3.shape) dim=1)###[1 1024 32 32]
        x = self.conv_2(x)  ###[1 512 32 32]
        x = self.dconv_up2(x)  ###[1 256 32 32]
        x = self.upsample(x)  ###[1 256 64 64]
        x = torch.cat([x, conv2], dim=1)  ###[1 512 64 64]
        x = self.conv_3(x)
        x = self.dconv_up1(x)
        x = self.upsample(x)  ###[1 128 128 128]
        x = torch.cat([x, temp], dim=1)  ###[1 192 128 128]
        out = self.dconv_last(x)
        #print(out.shape)
        #x = F.softmax(x,dim=1)
        out = torch.sigmoid(out)
        return out


def Get_MSAANet(in_channel=6, out_channel=2):
    return MSAANet(in_channel, out_channel, block=BottleNeck, num_block=[3, 4, 23, 3])


if __name__ == "__main__":
    model = Get_MSAANet(6, 2)
    model.eval()
    image1 = torch.randn(1, 3, 256, 256)
    image2 = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        output = model.forward(image1, image2)



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class conv_block_nested(nn.Module):
    def __init__(self, in_ch, mid_ch, out_ch):
        super(conv_block_nested, self).__init__()

        self.act = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act(x)

        return x


class NestedUNet_CD(nn.Module):
    def __init__(self, in_ch=6, out_ch=2):
        super(NestedUNet_CD, self).__init__()

        n1 = 64
        filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.conv0_0 = conv_block_nested(in_ch, filters[0], filters[0])
        self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])
        self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])
        self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])
        self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])

        self.conv0_1 = conv_block_nested(filters[0] + filters[1], filters[0], filters[0])
        self.conv1_1 = conv_block_nested(filters[1] + filters[2], filters[1], filters[1])
        self.conv2_1 = conv_block_nested(filters[2] + filters[3], filters[2], filters[2])
        self.conv3_1 = conv_block_nested(filters[3] + filters[4], filters[3], filters[3])

        self.conv0_2 = conv_block_nested(filters[0] * 2 + filters[1], filters[0], filters[0])
        self.conv1_2 = conv_block_nested(filters[1] * 2 + filters[2], filters[1], filters[1])
        self.conv2_2 = conv_block_nested(filters[2] * 2 + filters[3], filters[2], filters[2])

        self.conv0_3 = conv_block_nested(filters[0] * 3 + filters[1], filters[0], filters[0])
        self.conv1_3 = conv_block_nested(filters[1] * 3 + filters[2], filters[1], filters[1])

        self.conv0_4 = conv_block_nested(filters[0] * 4 + filters[1], filters[0], filters[0])

        self.final1 = nn.Conv2d(filters[0], out_ch, kernel_size=1)
        self.sigmoid = nn.Sigmoid()
        self.final2 = nn.Conv2d(out_ch * 4, out_ch, kernel_size=1)

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), 1)

        x0_0 = self.conv0_0(x)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.Up(x1_0)], 1))

        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.Up(x2_0)], 1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.Up(x1_1)], 1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.Up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.Up(x2_1)], 1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.Up(x1_2)], 1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.Up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.Up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.Up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.Up(x1_3)], 1))

        x0_1 = self.final1(x0_1)
        y0_1 = self.sigmoid(x0_1)
        x0_2 = self.final1(x0_2)
        y0_2 = self.sigmoid(x0_2)
        x0_3 = self.final1(x0_3)
        y0_3 = self.sigmoid(x0_3)
        x0_4 = self.final1(x0_4)
        y0_4 = self.sigmoid(x0_4)

        y0_5 = self.sigmoid(self.final2(torch.cat([y0_1, y0_2, y0_3, y0_4], 1)))

        return y0_5


"""
from torchsummary import summary
model = NestedUNet_CD(in_ch = 6,out_ch =2)
summary(model,input_size=[(3,256,256),(3,256,256)],batch_size = 2, device="cpu")
"""

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.padding import ReplicationPad2d


class SiamUnet_conc(nn.Module):
    """SiamUnet_conc segmentation network."""

    def __init__(self, input_nbr, label_nbr):
        super(SiamUnet_conc, self).__init__()

        self.input_nbr = input_nbr

        self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(16)
        self.do11 = nn.Dropout2d(p=0.2)
        self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(16)
        self.do12 = nn.Dropout2d(p=0.2)

        self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn21 = nn.BatchNorm2d(32)
        self.do21 = nn.Dropout2d(p=0.2)
        self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn22 = nn.BatchNorm2d(32)
        self.do22 = nn.Dropout2d(p=0.2)

        self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn31 = nn.BatchNorm2d(64)
        self.do31 = nn.Dropout2d(p=0.2)
        self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn32 = nn.BatchNorm2d(64)
        self.do32 = nn.Dropout2d(p=0.2)
        self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn33 = nn.BatchNorm2d(64)
        self.do33 = nn.Dropout2d(p=0.2)

        self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn41 = nn.BatchNorm2d(128)
        self.do41 = nn.Dropout2d(p=0.2)
        self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn42 = nn.BatchNorm2d(128)
        self.do42 = nn.Dropout2d(p=0.2)
        self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn43 = nn.BatchNorm2d(128)
        self.do43 = nn.Dropout2d(p=0.2)

        self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv43d = nn.ConvTranspose2d(384, 128, kernel_size=3, padding=1)
        self.bn43d = nn.BatchNorm2d(128)
        self.do43d = nn.Dropout2d(p=0.2)
        self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
        self.bn42d = nn.BatchNorm2d(128)
        self.do42d = nn.Dropout2d(p=0.2)
        self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn41d = nn.BatchNorm2d(64)
        self.do41d = nn.Dropout2d(p=0.2)

        self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv33d = nn.ConvTranspose2d(192, 64, kernel_size=3, padding=1)
        self.bn33d = nn.BatchNorm2d(64)
        self.do33d = nn.Dropout2d(p=0.2)
        self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
        self.bn32d = nn.BatchNorm2d(64)
        self.do32d = nn.Dropout2d(p=0.2)
        self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
        self.bn31d = nn.BatchNorm2d(32)
        self.do31d = nn.Dropout2d(p=0.2)

        self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv22d = nn.ConvTranspose2d(96, 32, kernel_size=3, padding=1)
        self.bn22d = nn.BatchNorm2d(32)
        self.do22d = nn.Dropout2d(p=0.2)
        self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
        self.bn21d = nn.BatchNorm2d(16)
        self.do21d = nn.Dropout2d(p=0.2)

        self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv12d = nn.ConvTranspose2d(48, 16, kernel_size=3, padding=1)
        self.bn12d = nn.BatchNorm2d(16)
        self.do12d = nn.Dropout2d(p=0.2)
        self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)

        self.sm = nn.Softmax(dim=1)

    def forward(self, x1, x2):
        """Forward method."""
        # Stage 1
        x11_1 = self.do11(F.relu(self.bn11(self.conv11(x1))))
        #print(x11.shape)
        x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11_1))))
        #print(x12_1.shape)
        x1p_1 = F.max_pool2d(x12_1, kernel_size=2, stride=2)
        #print(x1p.shape)

        # Stage 2
        x21_1 = self.do21(F.relu(self.bn21(self.conv21(x1p_1))))
        x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21_1))))
        x2p_1 = F.max_pool2d(x22_1, kernel_size=2, stride=2)

        # Stage 3
        x31_1 = self.do31(F.relu(self.bn31(self.conv31(x2p_1))))
        x32_1 = self.do32(F.relu(self.bn32(self.conv32(x31_1))))
        x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32_1))))
        x3p_1 = F.max_pool2d(x33_1, kernel_size=2, stride=2)

        # Stage 4
        x41_1 = self.do41(F.relu(self.bn41(self.conv41(x3p_1))))
        x42_1 = self.do42(F.relu(self.bn42(self.conv42(x41_1))))
        x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42_1))))
        x4p_1 = F.max_pool2d(x43_1, kernel_size=2, stride=2)

        ####################################################
        # Stage 1
        x11_2 = self.do11(F.relu(self.bn11(self.conv11(x2))))
        x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11_2))))
        x1p_2 = F.max_pool2d(x12_2, kernel_size=2, stride=2)

        # Stage 2
        x21_2 = self.do21(F.relu(self.bn21(self.conv21(x1p_2))))
        x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21_2))))
        x2p_2 = F.max_pool2d(x22_2, kernel_size=2, stride=2)

        # Stage 3
        x31_2 = self.do31(F.relu(self.bn31(self.conv31(x2p_2))))
        x32_2 = self.do32(F.relu(self.bn32(self.conv32(x31_2))))
        x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32_2))))
        x3p_2 = F.max_pool2d(x33_2, kernel_size=2, stride=2)

        # Stage 4
        x41_2 = self.do41(F.relu(self.bn41(self.conv41(x3p_2))))
        x42_2 = self.do42(F.relu(self.bn42(self.conv42(x41_2))))
        x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42_2))))
        x4p_2 = F.max_pool2d(x43_2, kernel_size=2, stride=2)

        ####################################################
        # Stage 4d
        x4d = self.upconv4(x4p_2)
        pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))
        x4d = torch.cat((pad4(x4d), x43_1, x43_2), 1)
        x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
        x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
        x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))

        # Stage 3d
        x3d = self.upconv3(x41d)
        pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
        x3d = torch.cat((pad3(x3d), x33_1, x33_2), 1)
        x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
        x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
        x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))

        # Stage 2d
        x2d = self.upconv2(x31d)
        pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
        x2d = torch.cat((pad2(x2d), x22_1, x22_2), 1)
        x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
        x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))

        # Stage 1d
        x1d = self.upconv1(x21d)
        pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
        x1d = torch.cat((pad1(x1d), x12_1, x12_2), 1)
        x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
        x11d = self.conv11d(x12d)

        return self.sm(x11d)


"""
from torchsummary import summary
model = SiamUnet_conc(input_nbr=3, label_nbr=2)
summary(model,input_size=[(3,256,256),(3,256,256)],batch_size = 2, device="cpu")
"""

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.padding import ReplicationPad2d


class SiamUnet_diff(nn.Module):
    """SiamUnet_diff segmentation network."""

    def __init__(self, input_nbr, label_nbr):
        super(SiamUnet_diff, self).__init__()

        self.input_nbr = input_nbr

        self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(16)
        self.do11 = nn.Dropout2d(p=0.2)
        self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(16)
        self.do12 = nn.Dropout2d(p=0.2)

        self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn21 = nn.BatchNorm2d(32)
        self.do21 = nn.Dropout2d(p=0.2)
        self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn22 = nn.BatchNorm2d(32)
        self.do22 = nn.Dropout2d(p=0.2)

        self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn31 = nn.BatchNorm2d(64)
        self.do31 = nn.Dropout2d(p=0.2)
        self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn32 = nn.BatchNorm2d(64)
        self.do32 = nn.Dropout2d(p=0.2)
        self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn33 = nn.BatchNorm2d(64)
        self.do33 = nn.Dropout2d(p=0.2)

        self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn41 = nn.BatchNorm2d(128)
        self.do41 = nn.Dropout2d(p=0.2)
        self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn42 = nn.BatchNorm2d(128)
        self.do42 = nn.Dropout2d(p=0.2)
        self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn43 = nn.BatchNorm2d(128)
        self.do43 = nn.Dropout2d(p=0.2)

        self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
        self.bn43d = nn.BatchNorm2d(128)
        self.do43d = nn.Dropout2d(p=0.2)
        self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
        self.bn42d = nn.BatchNorm2d(128)
        self.do42d = nn.Dropout2d(p=0.2)
        self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn41d = nn.BatchNorm2d(64)
        self.do41d = nn.Dropout2d(p=0.2)

        self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn33d = nn.BatchNorm2d(64)
        self.do33d = nn.Dropout2d(p=0.2)
        self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
        self.bn32d = nn.BatchNorm2d(64)
        self.do32d = nn.Dropout2d(p=0.2)
        self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
        self.bn31d = nn.BatchNorm2d(32)
        self.do31d = nn.Dropout2d(p=0.2)

        self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
        self.bn22d = nn.BatchNorm2d(32)
        self.do22d = nn.Dropout2d(p=0.2)
        self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
        self.bn21d = nn.BatchNorm2d(16)
        self.do21d = nn.Dropout2d(p=0.2)

        self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
        self.bn12d = nn.BatchNorm2d(16)
        self.do12d = nn.Dropout2d(p=0.2)
        self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)

        self.sm = nn.Softmax(dim=1)

    def forward(self, x1, x2):
        """Forward method."""
        # Stage 1
        x11 = self.do11(F.relu(self.bn11(self.conv11(x1))))
        x12_1 = self.do12(F.relu(self.bn12(self.conv12(x11))))
        x1p = F.max_pool2d(x12_1, kernel_size=2, stride=2)

        # Stage 2
        x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
        x22_1 = self.do22(F.relu(self.bn22(self.conv22(x21))))
        x2p = F.max_pool2d(x22_1, kernel_size=2, stride=2)

        # Stage 3
        x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
        x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
        x33_1 = self.do33(F.relu(self.bn33(self.conv33(x32))))
        x3p = F.max_pool2d(x33_1, kernel_size=2, stride=2)

        # Stage 4
        x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
        x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
        x43_1 = self.do43(F.relu(self.bn43(self.conv43(x42))))
        x4p = F.max_pool2d(x43_1, kernel_size=2, stride=2)

        ####################################################
        # Stage 1
        x11 = self.do11(F.relu(self.bn11(self.conv11(x2))))
        x12_2 = self.do12(F.relu(self.bn12(self.conv12(x11))))
        x1p = F.max_pool2d(x12_2, kernel_size=2, stride=2)

        # Stage 2
        x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
        x22_2 = self.do22(F.relu(self.bn22(self.conv22(x21))))
        x2p = F.max_pool2d(x22_2, kernel_size=2, stride=2)

        # Stage 3
        x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
        x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
        x33_2 = self.do33(F.relu(self.bn33(self.conv33(x32))))
        x3p = F.max_pool2d(x33_2, kernel_size=2, stride=2)

        # Stage 4
        x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
        x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
        x43_2 = self.do43(F.relu(self.bn43(self.conv43(x42))))
        x4p = F.max_pool2d(x43_2, kernel_size=2, stride=2)

        # Stage 4d
        x4d = self.upconv4(x4p)
        pad4 = ReplicationPad2d((0, x43_1.size(3) - x4d.size(3), 0, x43_1.size(2) - x4d.size(2)))
        x4d = torch.cat((pad4(x4d), torch.abs(x43_1 - x43_2)), 1)
        x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
        x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
        x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))

        # Stage 3d
        x3d = self.upconv3(x41d)
        pad3 = ReplicationPad2d((0, x33_1.size(3) - x3d.size(3), 0, x33_1.size(2) - x3d.size(2)))
        x3d = torch.cat((pad3(x3d), torch.abs(x33_1 - x33_2)), 1)
        x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
        x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
        x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))

        # Stage 2d
        x2d = self.upconv2(x31d)
        pad2 = ReplicationPad2d((0, x22_1.size(3) - x2d.size(3), 0, x22_1.size(2) - x2d.size(2)))
        x2d = torch.cat((pad2(x2d), torch.abs(x22_1 - x22_2)), 1)
        x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
        x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))

        # Stage 1d
        x1d = self.upconv1(x21d)
        pad1 = ReplicationPad2d((0, x12_1.size(3) - x1d.size(3), 0, x12_1.size(2) - x1d.size(2)))
        x1d = torch.cat((pad1(x1d), torch.abs(x12_1 - x12_2)), 1)
        x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
        x11d = self.conv11d(x12d)

        return self.sm(x11d)


"""
from torchsummary import summary
model = SiamUnet_diff(input_nbr=3, label_nbr=2)
summary(model,input_size=[(3,256,256),(3,256,256)],batch_size = 1, device="cpu")
"""

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.padding import ReplicationPad2d


class Unet(nn.Module):
    """EF segmentation network."""

    def __init__(self, input_nbr, label_nbr):
        super(Unet, self).__init__()

        self.input_nbr = input_nbr

        self.conv11 = nn.Conv2d(input_nbr, 16, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(16)
        self.do11 = nn.Dropout2d(p=0.2)
        self.conv12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(16)
        self.do12 = nn.Dropout2d(p=0.2)

        self.conv21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn21 = nn.BatchNorm2d(32)
        self.do21 = nn.Dropout2d(p=0.2)
        self.conv22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn22 = nn.BatchNorm2d(32)
        self.do22 = nn.Dropout2d(p=0.2)

        self.conv31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn31 = nn.BatchNorm2d(64)
        self.do31 = nn.Dropout2d(p=0.2)
        self.conv32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn32 = nn.BatchNorm2d(64)
        self.do32 = nn.Dropout2d(p=0.2)
        self.conv33 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn33 = nn.BatchNorm2d(64)
        self.do33 = nn.Dropout2d(p=0.2)

        self.conv41 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn41 = nn.BatchNorm2d(128)
        self.do41 = nn.Dropout2d(p=0.2)
        self.conv42 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn42 = nn.BatchNorm2d(128)
        self.do42 = nn.Dropout2d(p=0.2)
        self.conv43 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn43 = nn.BatchNorm2d(128)
        self.do43 = nn.Dropout2d(p=0.2)

        self.upconv4 = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv43d = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
        self.bn43d = nn.BatchNorm2d(128)
        self.do43d = nn.Dropout2d(p=0.2)
        self.conv42d = nn.ConvTranspose2d(128, 128, kernel_size=3, padding=1)
        self.bn42d = nn.BatchNorm2d(128)
        self.do42d = nn.Dropout2d(p=0.2)
        self.conv41d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn41d = nn.BatchNorm2d(64)
        self.do41d = nn.Dropout2d(p=0.2)

        self.upconv3 = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv33d = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn33d = nn.BatchNorm2d(64)
        self.do33d = nn.Dropout2d(p=0.2)
        self.conv32d = nn.ConvTranspose2d(64, 64, kernel_size=3, padding=1)
        self.bn32d = nn.BatchNorm2d(64)
        self.do32d = nn.Dropout2d(p=0.2)
        self.conv31d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
        self.bn31d = nn.BatchNorm2d(32)
        self.do31d = nn.Dropout2d(p=0.2)

        self.upconv2 = nn.ConvTranspose2d(32, 32, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv22d = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
        self.bn22d = nn.BatchNorm2d(32)
        self.do22d = nn.Dropout2d(p=0.2)
        self.conv21d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
        self.bn21d = nn.BatchNorm2d(16)
        self.do21d = nn.Dropout2d(p=0.2)

        self.upconv1 = nn.ConvTranspose2d(16, 16, kernel_size=3, padding=1, stride=2, output_padding=1)

        self.conv12d = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
        self.bn12d = nn.BatchNorm2d(16)
        self.do12d = nn.Dropout2d(p=0.2)
        # self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=1, padding=1)
        self.conv11d = nn.ConvTranspose2d(16, label_nbr, kernel_size=3, padding=1)
        #self.sm = nn.Softmax(dim=1)

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), 1)

        """Forward method."""
        # Stage 1
        x11 = self.do11(F.relu(self.bn11(self.conv11(x))))
        x12 = self.do12(F.relu(self.bn12(self.conv12(x11))))
        x1p = F.max_pool2d(x12, kernel_size=2, stride=2)

        # Stage 2
        x21 = self.do21(F.relu(self.bn21(self.conv21(x1p))))
        x22 = self.do22(F.relu(self.bn22(self.conv22(x21))))
        x2p = F.max_pool2d(x22, kernel_size=2, stride=2)

        # Stage 3
        x31 = self.do31(F.relu(self.bn31(self.conv31(x2p))))
        x32 = self.do32(F.relu(self.bn32(self.conv32(x31))))
        x33 = self.do33(F.relu(self.bn33(self.conv33(x32))))
        x3p = F.max_pool2d(x33, kernel_size=2, stride=2)

        # Stage 4
        x41 = self.do41(F.relu(self.bn41(self.conv41(x3p))))
        x42 = self.do42(F.relu(self.bn42(self.conv42(x41))))
        x43 = self.do43(F.relu(self.bn43(self.conv43(x42))))
        x4p = F.max_pool2d(x43, kernel_size=2, stride=2)

        # Stage 4d
        x4d = self.upconv4(x4p)
        pad4 = ReplicationPad2d((0, x43.size(3) - x4d.size(3), 0, x43.size(2) - x4d.size(2)))
        x4d = torch.cat((pad4(x4d), x43), 1)
        x43d = self.do43d(F.relu(self.bn43d(self.conv43d(x4d))))
        x42d = self.do42d(F.relu(self.bn42d(self.conv42d(x43d))))
        x41d = self.do41d(F.relu(self.bn41d(self.conv41d(x42d))))

        # Stage 3d
        x3d = self.upconv3(x41d)
        pad3 = ReplicationPad2d((0, x33.size(3) - x3d.size(3), 0, x33.size(2) - x3d.size(2)))
        x3d = torch.cat((pad3(x3d), x33), 1)
        x33d = self.do33d(F.relu(self.bn33d(self.conv33d(x3d))))
        x32d = self.do32d(F.relu(self.bn32d(self.conv32d(x33d))))
        x31d = self.do31d(F.relu(self.bn31d(self.conv31d(x32d))))

        # Stage 2d
        x2d = self.upconv2(x31d)
        pad2 = ReplicationPad2d((0, x22.size(3) - x2d.size(3), 0, x22.size(2) - x2d.size(2)))
        x2d = torch.cat((pad2(x2d), x22), 1)
        x22d = self.do22d(F.relu(self.bn22d(self.conv22d(x2d))))
        x21d = self.do21d(F.relu(self.bn21d(self.conv21d(x22d))))

        # Stage 1d
        x1d = self.upconv1(x21d)
        pad1 = ReplicationPad2d((0, x12.size(3) - x1d.size(3), 0, x12.size(2) - x1d.size(2)))
        x1d = torch.cat((pad1(x1d), x12), 1)
        x12d = self.do12d(F.relu(self.bn12d(self.conv12d(x1d))))
        x11d = self.conv11d(x12d)
        x = torch.sigmoid(x11d)
        print(x.shape)
        return x



from torchsummary import summary
model = Unet(input_nbr=6, label_nbr=2)
summary(model,input_size=[(3,256,256),(3,256,256)],batch_size = 1, device="cpu")

torch.Size([2, 2, 254, 254])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [1, 16, 256, 256]             880
       BatchNorm2d-2          [1, 16, 256, 256]              32
         Dropout2d-3          [1, 16, 256, 256]               0
            Conv2d-4          [1, 16, 256, 256]           2,320
       BatchNorm2d-5          [1, 16, 256, 256]              32
         Dropout2d-6          [1, 16, 256, 256]               0
            Conv2d-7          [1, 32, 128, 128]           4,640
       BatchNorm2d-8          [1, 32, 128, 128]              64
         Dropout2d-9          [1, 32, 128, 128]               0
           Conv2d-10          [1, 32, 128, 128]           9,248
      BatchNorm2d-11          [1, 32, 128, 128]              64
        Dropout2d-12          [1, 32, 128, 128]               0
           Conv2d-13            [1, 64, 64, 64]          18,496
      Batc

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class SeparableConv2d(nn.Module):
    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, relu_first=True,
                 bias=False, norm_layer=nn.BatchNorm2d):
        super().__init__()
        depthwise = nn.Conv2d(inplanes, inplanes, kernel_size,
                              stride=stride, padding=dilation,
                              dilation=dilation, groups=inplanes, bias=bias)
        bn_depth = norm_layer(inplanes)
        pointwise = nn.Conv2d(inplanes, planes, 1, bias=bias)
        bn_point = norm_layer(planes)

        if relu_first:
            self.block = nn.Sequential(OrderedDict([('relu', nn.ReLU()),
                                                    ('depthwise', depthwise),
                                                    ('bn_depth', bn_depth),
                                                    ('pointwise', pointwise),
                                                    ('bn_point', bn_point)
                                                    ]))
        else:
            self.block = nn.Sequential(OrderedDict([('depthwise', depthwise),
                                                    ('bn_depth', bn_depth),
                                                    ('relu1', nn.ReLU(inplace=True)),
                                                    ('pointwise', pointwise),
                                                    ('bn_point', bn_point),
                                                    ('relu2', nn.ReLU(inplace=True))
                                                    ]))

    def forward(self, x):
        return self.block(x)


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, dilation=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=2, dilation=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=3, dilation=3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
        diffX = torch.tensor([x2.size()[3] - x1.size()[3]])

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class _ASPP(nn.Module):
    def __init__(self, in_channels=2048, out_channels=256):
        super().__init__()

        dilations = [6, 12, 18]

        self.aspp0 = nn.Sequential(OrderedDict([('conv', nn.Conv2d(in_channels, out_channels, 1, bias=False)),
                                                ('bn', nn.BatchNorm2d(out_channels)),
                                                ('relu', nn.ReLU(inplace=True))]))
        self.aspp1 = SeparableConv2d(in_channels, out_channels, dilation=dilations[0], relu_first=False)
        self.aspp2 = SeparableConv2d(in_channels, out_channels, dilation=dilations[1], relu_first=False)
        self.aspp3 = SeparableConv2d(in_channels, out_channels, dilation=dilations[2], relu_first=False)

        self.image_pooling = nn.Sequential(OrderedDict([('gap', nn.AdaptiveAvgPool2d((1, 1))),
                                                        ('conv', nn.Conv2d(in_channels, out_channels, 1, bias=False)),
                                                        ('bn', nn.BatchNorm2d(out_channels)),
                                                        ('relu', nn.ReLU(inplace=True))]))

        self.conv = nn.Conv2d(out_channels * 5, out_channels, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout2d(p=0.1)

    def forward(self, x):
        pool = self.image_pooling(x)
        pool = F.interpolate(pool, size=x.shape[2:], mode='bilinear', align_corners=True)

        x0 = self.aspp0(x)
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x = torch.cat((pool, x0, x1, x2, x3), dim=1)

        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.dropout(x)

        return x


class UNet_ASPP(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(UNet_ASPP, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.aspp = _ASPP(512, 512)
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x_1, x_2):
        #print(x.shape)
        x = torch.cat((x_1, x_2), dim=1)
        x1 = self.inc(x)
        #print(x1.shape)
        x2 = self.down1(x1)
        #print(x2.shape)
        x3 = self.down2(x2)
        #print(x3.shape)
        x4 = self.down3(x3)
        #print(x4.shape)
        x5 = self.aspp(x4)
        #print(x5.shape)
        x = self.up1(x5, x4)
        #print(x.shape)
        x = self.up2(x, x3)
        #print(x.shape)
        x = self.up3(x, x2)
        #print(x.shape)
        x = self.up4(x, x1)
        #print(x.shape)
        x = self.outc(x)
        #print(x.shape)
        x = F.softmax(x, dim=1)
        print(x.shape)
        return x


"""
from torchsummary import summary
model = UNet_ASPP(n_channels=6,n_classes=2)
summary(model,input_size=[(3,256,256),(3,256,256)],batch_size = 2, device="cpu")
"""