In [1]:
import torch
import torch.nn as nn

In [2]:
class DownSampling(nn.Module):
    """Convolutional block: ``Conv2d -> BatchNorm2d -> ReLU``.

    The convolution uses no padding, so when ``kernel_size`` equals ``stride``
    each spatial dimension is divided by the stride (for sizes that are
    multiples of it).

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size forwarded to ``nn.Conv2d``.
        stride: Stride forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride)
        norm = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1)
        # The layers stay inside a Sequential registered as ``net`` so the
        # parameter names (``net.0.*``, ``net.1.*``) remain stable.
        self.net = nn.Sequential(conv, norm, nn.ReLU())

    def forward(self, x):
        """Run ``x`` through convolution, batch norm and ReLU."""
        return self.net(x)

class UpSampling(nn.Module):
    """Transposed-convolution block: ``ConvTranspose2d -> BatchNorm2d -> ReLU``.

    The transposed convolution uses no padding, so when ``kernel_size`` equals
    ``stride`` each spatial dimension is multiplied by the stride.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the transposed convolution.
        kernel_size: Kernel size forwarded to ``nn.ConvTranspose2d``.
        stride: Stride forwarded to ``nn.ConvTranspose2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride)
        norm = nn.BatchNorm2d(out_channels, eps=1e-05, momentum=0.1)
        # The layers stay inside a Sequential registered as ``net`` so the
        # parameter names (``net.0.*``, ``net.1.*``) remain stable.
        self.net = nn.Sequential(deconv, norm, nn.ReLU())

    def forward(self, x):
        """Run ``x`` through transposed convolution, batch norm and ReLU."""
        return self.net(x)

In [3]:
# Sizes of the random test tensors used in the shape checks below.
g = 32                # channel count (the blocks are built with g and 2*g channels)
F, T = 1024, 256      # the two spatial sizes; presumably frequency bins x time frames -- confirm

# Kernel size and stride handed to the sampling blocks; equal values give an
# exact 2x change in each spatial dimension.
scale = (2, 2)
stride = (2, 2)

In [4]:
# Shape check: DownSampling should map (g, F, T) -> (2g, F/2, T/2).
ds = DownSampling(g, 2*g, scale, stride)
x_down = torch.rand(1, g, F, T)  # not named `input`, which would shadow the builtin
print(" g x F   x T   :", x_down.shape)
print("processing DownSampling")

# Only the output shape is inspected, so skip autograd bookkeeping.
with torch.no_grad():
    output = ds(x_down)

# Fail loudly if the block stops halving the spatial dims / doubling channels.
assert output.shape == (1, 2*g, F//2, T//2), output.shape
print("2g x F/2 x T/2 :", output.shape)

 g x F   x T   : torch.Size([1, 32, 1024, 256])
processing DownSampling
2g x F/2 x T/2 : torch.Size([1, 64, 512, 128])


In [5]:
# Shape check: UpSampling should map (2g, F/2, T/2) -> (g, F, T).
us = UpSampling(2*g, g, scale, stride)
x_up = torch.rand(1, 2*g, F//2, T//2)  # not named `input`, which would shadow the builtin
print("2g x F/2 x T/2 :", x_up.shape)
print("processing UpSampling")

# Only the output shape is inspected, so skip autograd bookkeeping.
with torch.no_grad():
    output = us(x_up)

# Fail loudly if the block stops doubling the spatial dims / halving channels.
assert output.shape == (1, g, F, T), output.shape
print(" g x F   x T   :", output.shape)

2g x F/2 x T/2 : torch.Size([1, 64, 512, 128])
processing UpSampling
 g x F   x T   : torch.Size([1, 32, 1024, 256])
