In [8]:
import torch
import torch.nn as nn

# Minimal module wrapping a single 1x1 (pointwise) convolution.
class OneByOneConvModel(nn.Module):
    """Apply one ``kernel_size=1`` 2-D convolution to the input.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # kernel_size=1 with default stride/padding: only the channel count changes,
        # the spatial dimensions pass through untouched.
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return ``x`` after the 1x1 convolution."""
        return self.conv1x1(x)

# Demo: push one random, image-sized sample through the 1x1 convolution model.
input_shape = (3, 32, 32)  # (channels, height, width)
in_channels, out_channels = input_shape[0], 16

# One random sample; the leading 1 is the batch dimension.
input_tensor = torch.randn((1, *input_shape))

# Build the model so that it accepts exactly the channel count of the sample.
model = OneByOneConvModel(in_channels, out_channels)

# Run the forward pass and report the resulting shape.
output_tensor = model(input_tensor)
print("Output shape:", output_tensor.shape)


Output shape: torch.Size([1, 16, 32, 32])


In [9]:
class middle_stem(nn.Module):
    """1x1 convolution followed by one transposed-conv upsampling stage.

    The pipeline is ``conv1x1 -> deconv -> batchnorm -> relu``. The transposed
    convolution uses ``kernel_size=4, stride=2, padding=1``, which doubles the
    height and width of its input.

    NOTE: a later cell of this notebook defines another class under the same
    name ``middle_stem``; once that cell has run, it shadows this definition.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the 1x1 convolution and kept by
            every following layer.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Channel projection first, so the upsampling works on `out_channels` maps.
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.deconv = nn.ConvTranspose2d(
            out_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
        )
        self.batchnorm = nn.BatchNorm2d(out_channels)
        # In-place ReLU overwrites the batch-norm output instead of allocating a new tensor.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Project the channels, upsample once, then normalise and activate."""
        upsampled = self.deconv(self.conv1x1(x))
        return self.relu(self.batchnorm(upsampled))

In [10]:
# Smoke test: run the stem on one random feature map of shape (1, 1024, 56, 28).
input_tensor = torch.randn(1, 1024, 56, 28)

middle_ = middle_stem(1024, 512)
# Call the module itself instead of .forward(): nn.Module.__call__ is what runs
# any registered hooks before/after delegating to forward().
out = middle_(input_tensor)

# A single ConvTranspose2d(kernel_size=4, stride=2, padding=1) doubles H and W once:
# (56, 28) -> (112, 56).
# NOTE(review): the output saved under this cell reads [1, 512, 224, 112], which
# neither `middle_stem` definition in this notebook produces for these arguments;
# it looks like a leftover from an earlier run — re-run from a fresh kernel.
assert out.shape == (1, 512, 112, 56), out.shape
print(out.shape)

torch.Size([1, 512, 224, 112])


In [11]:
# Stack the upsampling stage `num_layers` times (custom hyperparameter).

class middle_stem(nn.Module):
    """1x1 convolution followed by ``num_layers`` stacked upsampling stages.

    Every stage is ``ConvTranspose2d -> BatchNorm2d -> ReLU``. The transposed
    convolution uses ``kernel_size=4, stride=2, padding=1``, so each stage
    doubles the height and width of its input.

    Each stage owns its own ``BatchNorm2d``. The previous version applied one
    shared instance after every deconv layer, so for ``num_layers > 1`` all
    stages shared the same affine parameters and running statistics. For
    ``num_layers`` of 0 or 1 the computation is unchanged; only the parameter
    names differ (``batchnorm.*`` is now ``batchnorm_layers.<i>.*``).

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the 1x1 convolution and kept by
            every following stage.
        num_layers: Number of upsampling stages to stack. With 0 the module
            reduces to the 1x1 convolution alone.
    """

    def __init__(self, in_channels, out_channels, num_layers=1):
        # Zero-argument super() does not look the class up by its global name,
        # which matters here because this notebook defines `middle_stem` twice.
        super().__init__()
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # One transposed convolution per stage.
        self.deconv_layers = nn.ModuleList([
            nn.ConvTranspose2d(out_channels, out_channels, kernel_size=4, stride=2, padding=1)
            for _ in range(num_layers)
        ])
        # One batch norm per stage, paired by index with `deconv_layers`.
        self.batchnorm_layers = nn.ModuleList([
            nn.BatchNorm2d(out_channels)
            for _ in range(num_layers)
        ])
        # ReLU holds no parameters or buffers, so a single instance serves every stage.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Project the channels, then run every upsampling stage in order."""
        x = self.conv1x1(x)

        for deconv_layer, batchnorm in zip(self.deconv_layers, self.batchnorm_layers):
            x = self.relu(batchnorm(deconv_layer(x)))

        return x

In [16]:
# Smoke test: run the stacked stem with a single stage on a random
# feature map of shape (1, 1024, 56, 28).
input_tensor = torch.randn(1, 1024, 56, 28)

middle_ = middle_stem(1024, 512, 1)
# Call the module itself instead of .forward(): nn.Module.__call__ is what runs
# any registered hooks before/after delegating to forward().
out = middle_(input_tensor)

# One stage with kernel_size=4, stride=2, padding=1 doubles H and W: (56, 28) -> (112, 56).
assert out.shape == (1, 512, 112, 56), out.shape
print(out.shape)

torch.Size([1, 512, 112, 56])
