In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

# Encoder based on ResNet
class ResNetEncoder(nn.Module):
    """Convolutional feature extractor built from a torchvision ResNet.

    The last two child modules of the backbone (in torchvision's ResNet
    these are the global average pool and the fully connected classifier)
    are dropped, so ``forward`` returns a spatial feature map instead of
    class logits.

    Args:
        resnet_type: Backbone to build, ``'resnet50'`` or ``'resnet152'``.
        pretrained: Forwarded unchanged to the torchvision constructor.

    Raises:
        ValueError: If ``resnet_type`` is not one of the supported names.
    """

    _SUPPORTED_TYPES = ('resnet50', 'resnet152')

    def __init__(self, resnet_type='resnet50', pretrained=True):
        super(ResNetEncoder, self).__init__()

        # Fail early with a clear message; previously an unknown name fell
        # through and surfaced later as a missing `resnet` attribute.
        if resnet_type not in self._SUPPORTED_TYPES:
            raise ValueError(
                f"Unsupported resnet_type {resnet_type!r}; "
                f"expected one of {self._SUPPORTED_TYPES}"
            )

        # The full backbone stays registered as `self.resnet` so that
        # attribute access and state_dict keys remain as they were.
        self.resnet = getattr(models, resnet_type)(pretrained=pretrained)

        # Keep everything except the final two children (pooling + classifier).
        self.encoder = nn.Sequential(*list(self.resnet.children())[:-2])

    def forward(self, x):
        """Return the backbone feature map computed for ``x``."""
        return self.encoder(x)

# Decoder with Dropout
class Decoder(nn.Module):
    """Upsampling head that turns a feature map into an image-like tensor.

    The stack alternates stride-2 transposed convolutions (each doubling
    height and width) with 3x3 refinement convolutions. Every stage is
    followed by batch norm and ReLU; the first five stages additionally
    apply dropout. A final 1x1 convolution maps to ``output_channels`` and
    a sigmoid squashes the result into (0, 1).

    Args:
        input_channels: Channel count of the incoming feature map.
        output_channels: Channel count of the produced tensor.
        dropout_rate: Probability used by every dropout layer.
    """

    def __init__(self, input_channels, output_channels, dropout_rate=0.3):
        super().__init__()

        # (upsample?, in_channels, out_channels, dropout?) per stage, in order.
        stage_plan = (
            (True, input_channels, 512, True),
            (False, 512, 512, True),
            (True, 512, 256, True),
            (False, 256, 256, True),
            (True, 256, 128, True),
            (False, 128, 128, False),
            (True, 128, 64, False),
        )

        layers = []
        for upsample, c_in, c_out, use_dropout in stage_plan:
            if upsample:
                conv = nn.ConvTranspose2d(
                    c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1
                )
            else:
                conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
            layers += [conv, nn.BatchNorm2d(c_out), nn.ReLU(inplace=True)]
            if use_dropout:
                layers.append(nn.Dropout(dropout_rate))

        # Output projection; swap the activation if a different range is needed.
        layers.append(nn.Conv2d(64, output_channels, kernel_size=1))
        layers.append(nn.Sigmoid())

        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full upsampling stack."""
        return self.decoder(x)

# Main Model with Two ResNet Encoders and One Decoder
class DualEncoderDecoder(nn.Module):
    """Two-branch network: a ResNet50 and a ResNet152 encoder feed one decoder.

    Each input goes through its own encoder; the two feature maps are
    joined along the channel axis and decoded together.

    Args:
        output_channels: Channel count of the decoder output.
    """

    def __init__(self, output_channels):
        super().__init__()

        # One encoder per input stream, each with a different backbone depth.
        self.encoder1 = ResNetEncoder(resnet_type='resnet50')
        self.encoder2 = ResNetEncoder(resnet_type='resnet152')

        # The decoder is sized for 4096 input channels, i.e. the channel
        # concatenation of both encoder outputs (presumably 2048 each).
        self.decoder = Decoder(input_channels=4096, output_channels=output_channels)

    def forward(self, x1, x2):
        """Encode ``x1`` and ``x2`` separately, fuse them, and decode.

        Both feature maps must agree in every dimension except channels,
        since they are concatenated along ``dim=1``.
        """
        features_a = self.encoder1(x1)
        features_b = self.encoder2(x2)

        # Channel-wise fusion of the two branches.
        fused = torch.cat([features_a, features_b], dim=1)
        return self.decoder(fused)



In [None]:
# Build the dual-encoder model with a 3-channel (e.g. RGB) output. Both
# encoders are created with their default pretrained=True setting.
model = DualEncoderDecoder(output_channels=3)  # Example output for RGB images

In [None]:
# Smoke test: push one random batch through the model and check the shape.
# NOTE: the model is never switched to eval(), so dropout and batch norm run
# in training mode here.
input1 = torch.randn(2, 3, 256, 256)  # Input for ResNet50
input2 = torch.randn(2, 3, 256, 256)  # Input for ResNet152

output = model(input1, input2)

# Batch size is 2; the encoders reduce 256 -> 8 and the four stride-2
# transposed convolutions in the decoder bring that back up to 128.
print(output.shape)  # Expected: torch.Size([2, 3, 128, 128])