In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MDCBlock(nn.Module):
    def __init__(self, in_channels, out_channels, r1, r2, r3, d1, d2, d3, d4):
        super(MDCBlock, self).__init__()

        # Pre-Channel Mixer (Point-wise convolution)
        self.pre_mixer = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # Dilated Convolutional Layers
        self.dcl1 = nn.Conv2d(out_channels//4, out_channels//4, kernel_size=3, dilation=d1, padding=d1)
				# dilation => d2
        self.dcl2 = nn.Conv2d(out_channels//4, out_channels//4, kernel_size=3, dilation=d2, padding=d2)
        # dilation => d3 
        self.dcl3 = nn.Conv2d(out_channels//4, out_channels//4, kernel_size=3, dilation=d3, padding=d3)
        # dilation => d4 
        self.dcl4 = nn.Conv2d(out_channels//4, out_channels//4, kernel_size=3, dilation=d4, padding=d4)

        # Post-Channel Mixer
        self.post_mixer = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
				
				# channel wise concat x 
	
        # Pre-Channel Mixer
        x = self.pre_mixer(x)
        print("x shape ", x.shape)
        #x shape  torch.Size([1, 64, 32, 32])
        x1 = x[:,0:out_channels//4 * 1 ,:,:] 
        x2 = x[:,out_channels//4 * 1 : out_channels//4 * 2,:,:]
        x3 = x[:,out_channels//4 * 2 : out_channels//4 * 3,:,:]
        x4 = x[:,out_channels//4 * 3 : out_channels//4 * 4,:,:]

        print(x1.shape)
        print(x2.shape)
        print(x3.shape)
        print(x4.shape)
        # Dilated Convolutional Layers
        dcl1_out = self.dcl1(x1)
        dcl2_out = self.dcl2(x2)
        dcl3_out = self.dcl3(x3)
        dcl4_out = self.dcl4(x4)

        # Concatenate the outputs along the channel dimension
        x = torch.cat([dcl1_out, dcl2_out, dcl3_out, dcl4_out], dim=1)

        print("concat",x.shape)
        # Post-Channel Mixer
        x = self.post_mixer(x)

        return x

# Example usage:
# Assuming input tensor x has shape (batch_size, in_channels, height, width)
# and you want to obtain output channels out_channels with receptive fields r1, r2, r3 and dilation rates d1, d2, d3
#in_channels = 32 # encoder output == ( 1, 56, 28, 1024 ) x 4 
#out_channels = 64 
r1, r2, r3, r4 = 3, 5, 7, 9
d1, d2, d3, d4 = 1, 2, 3, 4
#x = torch.randn(1, in_channels, 896, 448)  # Example input tensor


x = torch.randn(1, 1024, 56, 28)
print("Input shape:", x.shape)
mdc_block = MDCBlock(in_channels, out_channels, r1, r2, r3, d1, d2, d3, d4)
output = mdc_block(x)
print("Output shape:", output.shape)  # Output shape will depend on the input size and parameters of the block

x shape  torch.Size([1, 64, 896, 448])
torch.Size([1, 16, 896, 448])
torch.Size([1, 16, 896, 448])
torch.Size([1, 16, 896, 448])
torch.Size([1, 16, 896, 448])
concat torch.Size([1, 64, 896, 448])
Output shape: torch.Size([1, 64, 896, 448])
