In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MDCBlock(nn.Module):
    """Multi-dilation convolution block.

    Pipeline:
        1. Point-wise (1x1) pre-mixer: in_channels -> out_channels.
        2. Split the channels into three groups and run each through its own
           3x3 convolution with dilation d1 / d2 / d3 (padding == dilation, so
           H and W are preserved).
        3. Concatenate the three groups back to out_channels channels.
        4. Post-mixer: 1x1 conv -> BN -> ReLU -> 3x3 conv (padding 1) -> BN -> ReLU.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the block. It does not
            have to be divisible by 3: the first ``out_channels % 3`` branches
            receive one extra channel so the concatenation always has exactly
            ``out_channels`` channels (which the post-mixer requires).
        d1, d2, d3: dilation rates of the three branches.
        verbose: if True, ``forward`` prints intermediate shapes (debug aid;
            on by default so existing notebook output is unchanged).

    Raises:
        ValueError: if ``out_channels < 3`` (a branch would get no channels).
    """

    def __init__(self, in_channels, out_channels, d1, d2, d3, verbose=True):
        super().__init__()

        if out_channels < 3:
            raise ValueError(
                f"out_channels must be >= 3 to feed three branches, got {out_channels}"
            )

        # Keep the width on the instance; forward() must not depend on a
        # notebook-level global of the same name.
        self.out_channels = out_channels
        self.verbose = verbose

        # Near-equal channel split; for out_channels divisible by 3 every
        # branch gets exactly out_channels // 3 channels.
        base, rem = divmod(out_channels, 3)
        self.branch_channels = [base + (1 if i < rem else 0) for i in range(3)]
        c1, c2, c3 = self.branch_channels

        # Pre-Channel Mixer (point-wise convolution)
        self.pre_mixer = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # Dilated convolutional layers; padding == dilation keeps H and W
        # unchanged for a 3x3 kernel.
        self.dcl1 = nn.Conv2d(c1, c1, kernel_size=3, dilation=d1, padding=d1)
        self.dcl2 = nn.Conv2d(c2, c2, kernel_size=3, dilation=d2, padding=d2)
        self.dcl3 = nn.Conv2d(c3, c3, kernel_size=3, dilation=d3, padding=d3)

        # Post-Channel Mixer
        self.post_mixer = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the block.

        Args:
            x: input tensor of shape (N, in_channels, H, W) (BatchNorm2d in the
               post-mixer requires a 4-D input).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        # Pre-Channel Mixer
        x = self.pre_mixer(x)
        if self.verbose:
            print("x shape ", x.shape)

        # Split along the channel dimension into the three branch inputs.
        x1, x2, x3 = torch.split(x, self.branch_channels, dim=1)
        if self.verbose:
            print(x1.shape)
            print(x2.shape)
            print(x3.shape)

        # Dilated Convolutional Layers
        dcl1_out = self.dcl1(x1)
        dcl2_out = self.dcl2(x2)
        dcl3_out = self.dcl3(x3)

        # Concatenate the outputs along the channel dimension
        x = torch.cat([dcl1_out, dcl2_out, dcl3_out], dim=1)
        if self.verbose:
            print("concat", x.shape)

        # Post-Channel Mixer
        return self.post_mixer(x)

# Example usage:
# The input tensor is expected to be (batch_size, in_channels, height, width).
# d1, d2, d3 are the dilation rates of the three parallel 3x3 branches.
# Earlier note (unverified): encoder output == (1, 56, 28, 1024) x 4.
torch.manual_seed(0)  # reproducible random input and weight initialisation

d1, d2, d3 = 3, 5, 7
in_channels = 768
out_channels = 384

x = torch.randn(1, in_channels, 28, 14)
print("Input shape:", x.shape)
mdc_block = MDCBlock(in_channels, out_channels, d1, d2, d3)
output = mdc_block(x)
print("#---------------------")
print("Output shape:", output.shape)  # (batch, out_channels, H, W): every conv is padded to keep H and W

Input shape: torch.Size([1, 768, 28, 14])
x shape  torch.Size([1, 384, 28, 14])
torch.Size([1, 128, 28, 14])
torch.Size([1, 128, 28, 14])
torch.Size([1, 128, 28, 14])
concat torch.Size([1, 384, 28, 14])
#---------------------
Output shape: torch.Size([1, 384, 28, 14])


In [4]:
# Same experiment with every branch using dilation 3.
d1 = d2 = d3 = 3

x = torch.randn(1, 768, 28, 14)
print("Input shape:", x.shape)

mdc_block = MDCBlock(in_channels, out_channels, d1=d1, d2=d2, d3=d3)
output = mdc_block(x)

print("#---------------------")
print("Output shape:", output.shape)  # spatial size is preserved; channels become out_channels


Input shape: torch.Size([1, 768, 28, 14])
x shape  torch.Size([1, 384, 28, 14])
torch.Size([1, 128, 28, 14])
torch.Size([1, 128, 28, 14])
torch.Size([1, 128, 28, 14])
concat torch.Size([1, 384, 28, 14])
#---------------------
Output shape: torch.Size([1, 384, 28, 14])


In [5]:
# Same experiment with the first branch undilated (1) and the others at 3.
d1 = 1
d2 = d3 = 3

x = torch.randn(1, 768, 28, 14)
print("Input shape:", x.shape)

mdc_block = MDCBlock(in_channels, out_channels, d1=d1, d2=d2, d3=d3)
output = mdc_block(x)

print("#---------------------")
print("Output shape:", output.shape)  # spatial size is preserved; channels become out_channels


Input shape: torch.Size([1, 768, 28, 14])
x shape  torch.Size([1, 384, 28, 14])
torch.Size([1, 128, 28, 14])
torch.Size([1, 128, 28, 14])
torch.Size([1, 128, 28, 14])
concat torch.Size([1, 384, 28, 14])
#---------------------
Output shape: torch.Size([1, 384, 28, 14])


In [4]:
#----- 
import torch.nn as nn
import torch

class MDCBlock(nn.Module):
    """Multi-dilation convolution block.

    Pipeline:
        1. Point-wise (1x1) pre-mixer: in_channels -> out_channels.
        2. Split the channels into three groups and run each through its own
           3x3 convolution with dilation d1 / d2 / d3 (padding == dilation, so
           H and W are preserved).
        3. Concatenate the three groups back to out_channels channels.
        4. Post-mixer: 1x1 conv -> BN -> ReLU -> 3x3 conv (padding 1) -> BN -> ReLU.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the block. It does not
            have to be divisible by 3: the first ``out_channels % 3`` branches
            receive one extra channel so the concatenation always has exactly
            ``out_channels`` channels (which the post-mixer requires).
        d1, d2, d3: dilation rates of the three branches.
        verbose: if True, ``forward`` prints the pre-mixer output shape and the
            three branch-input shapes (debug aid; on by default so existing
            notebook output is unchanged).

    Raises:
        ValueError: if ``out_channels < 3`` (a branch would get no channels).
    """

    def __init__(self, in_channels, out_channels, d1, d2, d3, verbose=True):
        super().__init__()

        if out_channels < 3:
            raise ValueError(
                f"out_channels must be >= 3 to feed three branches, got {out_channels}"
            )

        self.out_channels = out_channels
        self.verbose = verbose

        # Near-equal channel split; for out_channels divisible by 3 every
        # branch gets exactly out_channels // 3 channels.
        base, rem = divmod(out_channels, 3)
        self.branch_channels = [base + (1 if i < rem else 0) for i in range(3)]
        c1, c2, c3 = self.branch_channels

        # Pre-Channel Mixer (point-wise convolution)
        self.pre_mixer = nn.Conv2d(in_channels, self.out_channels, kernel_size=1)

        # Dilated convolutional layers; padding == dilation keeps H and W
        # unchanged for a 3x3 kernel.
        self.dcl1 = nn.Conv2d(c1, c1, kernel_size=3, dilation=d1, padding=d1)
        self.dcl2 = nn.Conv2d(c2, c2, kernel_size=3, dilation=d2, padding=d2)
        self.dcl3 = nn.Conv2d(c3, c3, kernel_size=3, dilation=d3, padding=d3)

        # Post-Channel Mixer
        self.post_mixer = nn.Sequential(
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the block.

        Args:
            x: input tensor of shape (N, in_channels, H, W) (BatchNorm2d in the
               post-mixer requires a 4-D input).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        # Pre-Channel Mixer
        x = self.pre_mixer(x)
        if self.verbose:
            print("x shape ", x.shape)

        # Split along the channel dimension into the three branch inputs.
        x1, x2, x3 = torch.split(x, self.branch_channels, dim=1)
        if self.verbose:
            print(x1.shape)
            print(x2.shape)
            print(x3.shape)

        # Dilated Convolutional Layers
        dcl1_out = self.dcl1(x1)
        dcl2_out = self.dcl2(x2)
        dcl3_out = self.dcl3(x3)

        # Concatenate the outputs along the channel dimension
        x = torch.cat([dcl1_out, dcl2_out, dcl3_out], dim=1)

        # Post-Channel Mixer
        return self.post_mixer(x)

In [9]:
d1, d2, d3 = 3, 5, 7

x = torch.randn(1, 768, 28, 14)
print("Input shape:", x.shape)

# 768 -> 768 channels. Pass the dilation variables defined above instead of
# repeating them as literals, so the two cannot silently drift apart.
MDC_Block_1 = MDCBlock(768, 768, d1, d2, d3)
output = MDC_Block_1(x)
print("#---------------------")
print("Output shape:", output.shape)  # (batch, out_channels, H, W): every conv is padded to keep H and W


Input shape: torch.Size([1, 768, 28, 14])
x shape  torch.Size([1, 768, 28, 14])
torch.Size([1, 256, 28, 14])
torch.Size([1, 256, 28, 14])
torch.Size([1, 256, 28, 14])
#---------------------
Output shape: torch.Size([1, 768, 28, 14])
