In [2]:
import torch
import torch.nn as nn

class CustomBlock(nn.Module):
    """Stack of four Conv2d -> BatchNorm2d -> ReLU stages.

    Stage 1 is a 3x3 convolution with stride 2 and padding 1 that maps
    ``in_channels`` to ``out_channels`` and roughly halves height and width.
    Stages 2 and 3 are 3x3 convolutions with padding 1, and stage 4 is a 1x1
    convolution; these three keep both the channel count and the spatial
    size unchanged.

    Args:
        in_channels: Number of channels of the input tensor (default 3).
        out_channels: Number of channels produced by every stage (default 512).
    """

    def __init__(self, in_channels=3, out_channels=512):
        super().__init__()

        # Each stage is unpacked into the same attribute names, in the same
        # order, so parameter names and registration order stay stable.
        self.conv1, self.norm1, self.activation1 = self._make_stage(
            in_channels, out_channels, kernel_size=3, stride=2, padding=1
        )
        self.conv2, self.norm2, self.activation2 = self._make_stage(
            out_channels, out_channels, kernel_size=3, padding=1
        )
        self.conv3, self.norm3, self.activation3 = self._make_stage(
            out_channels, out_channels, kernel_size=3, padding=1
        )
        self.conv4, self.norm4, self.activation4 = self._make_stage(
            out_channels, out_channels, kernel_size=1
        )

    @staticmethod
    def _make_stage(in_channels, out_channels, kernel_size, stride=1, padding=0):
        """Build the (convolution, batch norm, ReLU) triple for one stage."""
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        normalization = nn.BatchNorm2d(out_channels)
        return convolution, normalization, nn.ReLU(inplace=True)

    def forward(self, x):
        """Run ``x`` through the four stages in order and return the result.

        Args:
            x: Input tensor laid out as (batch, in_channels, height, width).

        Returns:
            Tensor with ``out_channels`` channels; height and width are
            reduced once by the stride-2 convolution of the first stage.
        """
        stages = (
            (self.conv1, self.norm1, self.activation1),
            (self.conv2, self.norm2, self.activation2),
            (self.conv3, self.norm3, self.activation3),
            (self.conv4, self.norm4, self.activation4),
        )
        for convolution, normalization, activation in stages:
            x = activation(normalization(convolution(x)))
        return x

# Example usage: push one random batch through the block and print its shape.
in_channels = 3
out_channels = 512

# The input is laid out as (batch_size, in_channels, height, width); its
# channel dimension must match the in_channels the block is built with.
x = torch.randn(1, in_channels, 896, 448)  # Example input tensor

custom_block = CustomBlock(in_channels, out_channels)
output = custom_block(x)
# Only the first convolution has stride 2, so height and width are halved once
# while the channel count becomes out_channels: (1, 512, 448, 224) for this input.
print("Output shape:", output.shape)


Output shape: torch.Size([1, 512, 448, 224])
