In [26]:
import torch
from torchvision.models import resnet18, vgg16, densenet121
from torch.nn import ConvTranspose2d, Conv2d
from torch.nn.functional import max_pool2d as maxpool2d

import torch.nn as nn

def attention(x, g):
        x = Conv2d(x.size(1), x.size(1), kernel_size=1, stride=1, padding=0)(x)
        g = Conv2d(g.size(1), x.size(1), kernel_size=1, stride=1, padding=0)(g)
        concat=x+g
        r=nn.ReLU(inplace=True)(concat)
        rconv=Conv2d(r.size(1), 1, kernel_size=1, stride=1, padding=0)(r)
        psi = torch.sigmoid(rconv)
        weight=nn.ConvTranspose2d(1, 1, kernel_size=1, stride=1, padding=0)(psi)
        weighted_x=x*weight
        res=weighted_x+g
        return res
class UNetResNet18(nn.Module):
    def __init__(self, num_classes):
        super(UNetResNet18, self).__init__()
        
        # Load ResNet-18 backbone
        resnet = resnet18(pretrained=True)
        vgg= vgg16(pretrained=True)
        densenet= densenet121(pretrained=True)
        self.encoder1 = nn.Sequential(*list(resnet.children())[:-2])  # Remove the fully connected layer and avgpool
        self.encoder2 = nn.Sequential(*list(vgg.children())[:-2])
        self.encoder3 = nn.Sequential(*list(densenet.children())[:-2])
        # Encoder layers
        self.enc11 = nn.Sequential(*list(resnet.children())[:3])  # First conv + BN + ReLU + MaxPool
        self.enc112 = nn.Sequential(*list(resnet.children())[3:4])
        self.enc12 = resnet.layer1
        self.enc13 = resnet.layer2
        self.enc14 = resnet.layer3
        self.enc15 = resnet.layer4
        # self.enc21 = nn.Sequential(*list(vgg.children())[:3])
        # self.enc22 = vgg.features[3:8]
        # self.enc23 = vgg.features[8:15]
        # self.enc24 = vgg.features[15:22]
        # self.enc25 = vgg.features[22:30]
        # self.enc31 = nn.Sequential(*list(densenet.children())[:3])
        # self.enc32 = densenet.denseblock1
        # self.enc33 = densenet.transition1
        # self.enc34 = densenet.denseblock2
        # self.enc35 = densenet.transition2
        # self.enc36 = densenet.denseblock3
        # self.enc37 = densenet.transition3
        # self.enc38 = densenet.denseblock4
        # self.enc39 = densenet.norm5     
        
        # Decoder layers
        self.dec4 = self._decoder_block(512, 256)
        self.dec3 = self._decoder_block(256, 128)
        self.dec2 = self._decoder_block(128, 64)
        self.dec1 = self._decoder_block(64, 64)
        
        # Final output layer
        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)
        
    def _decoder_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )
    
    
    def forward(self, x):
        # Encoder
        enc1 = self.enc11(x)
        enc = self.enc112(enc1)
        enc2 = self.enc12(enc)
        enc3 = self.enc13(enc2)
        enc4 = self.enc14(enc3)
        enc5 = self.enc15(enc4)
        
        # Encoder
        # enc21 = self.enc21(x)
        # enc22 = self.enc22(enc21)
        # enc23 = self.enc23(enc22)
        # enc24 = self.enc24(enc23)
        # enc25 = self.enc25(enc24)
        
        # # Encoder
        # enc31 = self.enc31(x)
        # enc32 = self.enc32(enc31)
        # enc33 = self.enc33(enc32)
        # enc34 = self.enc34(enc33)
        # enc35 = self.enc35(enc34)
        # enc36 = self.enc36(enc35)
        # enc37 = self.enc37(enc36)
        # enc38 = self.enc38(enc37)
        # enc39 = self.enc39(enc38)
        # Concatenate encoder outputs
        # encf1 = torch.cat((enc1, enc21), dim=1)
        # encf2 = torch.cat((enc2, enc22), dim=1)
        # encf3 = torch.cat((enc3, enc23), dim=1)
        # encf4 = torch.cat((enc4, enc24), dim=1)
        # encf5 = torch.cat((enc5, enc25), dim=1)
        
        
        # Decoder
        # skip_enc4=Conv2d(256, 512, kernel_size=3, stride=1, padding='same')(enc4)
        # skip_enc4=maxpool2d(skip_enc4, kernel_size=2, stride=2)(skip_enc4)
        # skip_enc3=Conv2d(128, 256, kernel_size=3, stride=1, padding='same')
        # skip_enc3=maxpool2d(skip_enc3, kernel_size=2, stride=2)(skip_enc3)
        # skip_enc2=Conv2d(64, 128, kernel_size=3, stride=1, padding='same')
        # skip_enc2=maxpool2d(skip_enc2, kernel_size=2, stride=2)(skip_enc2)
        # skip_enc1=Conv2d(64, 64, kernel_size=3, stride=1, padding='same')
        # skip_enc1=maxpool2d(skip_enc1, kernel_size=2, stride=2)(skip_enc1)

        # gate_enc4=Conv2d(512, 512, kernel_size=3, stride=1, padding='same')
        # cat_enc4=skip_enc4+gate_enc4
        # l=nn.ReLU(inplace=True)(cat_enc4)

        print((self.dec4(enc5)).size(), enc4.size())
        dec4 = attention(self.dec4(enc5), enc4)
        print((self.dec3(dec4)).size(), enc3.size())
        dec3 = attention(self.dec3(dec4), enc3)
        print((self.dec2(dec3)).size(), enc2.size())
        dec2 = attention(self.dec2(dec3), enc2)
        print((self.dec1(dec2)).size(), enc1.size())
        dec1 = attention(self.dec1(dec2), enc1)
        
        # Final output
        out = self.final_conv(dec1)
        return out
    
if __name__ == "__main__":
    # Example usage
    model = UNetResNet18(num_classes=2)
    x = torch.randn(1, 3, 224, 224)  # Batch size of 1, 3 channels (RGB), 224x224 image
    output = model.forward(x)
    print(output.shape)  # Should be [1, num_classes, 224, 224]

torch.Size([1, 256, 14, 14]) torch.Size([1, 256, 14, 14])
torch.Size([1, 128, 28, 28]) torch.Size([1, 128, 28, 28])
torch.Size([1, 64, 56, 56]) torch.Size([1, 64, 56, 56])
torch.Size([1, 64, 112, 112]) torch.Size([1, 64, 112, 112])
torch.Size([1, 2, 112, 112])
