In [18]:
import torch
import torch.nn as nn

class DeconvBlock(nn.Module):
    """Upsampling block: ``ConvTranspose2d`` -> ``BatchNorm2d`` -> ``ReLU``.

    With the default ``kernel_size=4, stride=2, padding=1`` the block doubles
    the height and width of its input. For other integer settings the output
    size follows the ``nn.ConvTranspose2d`` rule (dilation 1, no output
    padding)::

        H_out = (H_in - 1) * stride - 2 * padding + kernel_size

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
        kernel_size: Kernel size forwarded to ``nn.ConvTranspose2d``.
        stride: Stride forwarded to ``nn.ConvTranspose2d``.
        padding: Padding forwarded to ``nn.ConvTranspose2d``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
    ) -> None:
        super().__init__()
        # Submodule names are kept stable so existing state_dict keys
        # ("deconv.*", "batchnorm.*") continue to load.
        self.deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.batchnorm = nn.BatchNorm2d(out_channels)
        # In-place ReLU overwrites the batch-norm output instead of allocating
        # a new tensor; safe here because that intermediate is not reused.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x``, then normalize and apply ReLU.

        Args:
            x: Input of shape ``[batch_size, in_channels, height, width]``.

        Returns:
            Non-negative tensor with ``out_channels`` channels and the
            upsampled spatial size.
        """
        x = self.deconv(x)
        x = self.batchnorm(x)
        x = self.relu(x)
        return x

# Example usage:
# The block expects input of shape [batch_size, in_channels, height, width].
in_channels = 64
out_channels = 64
# Batch of 8 feature maps with 64 channels and 16x16 spatial size.
input_tensor = torch.randn(8, in_channels, 16, 16)

deconv_block = DeconvBlock(in_channels, out_channels)
output_tensor = deconv_block(input_tensor)

print(input_tensor.shape)  # torch.Size([8, 64, 16, 16])
# Stride 2 (with kernel 4, padding 1) doubles height and width; the channel
# dimension equals out_channels, so the result is [8, 64, 32, 32].
print(output_tensor.shape)  # torch.Size([8, 64, 32, 32])


torch.Size([8, 64, 16, 16])
torch.Size([8, 64, 32, 32])
