In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from matplotlib import pyplot as plt
import torchvision.utils as vutils

import numpy as np

In [2]:
%config InlineBackend.figure_format='retina'
plt.rc('font', size=16)

In [3]:
def output_dimension(i: int, p: int, k: int, s: int, d: int = 1) -> int:
    """Spatial output size of a Conv2d layer along one axis.

    Implements the formula from the ``torch.nn.Conv2d`` documentation:
    ``floor((i + 2*p - d*(k - 1) - 1) / s) + 1``. Floor division matters when
    the stride does not divide evenly: e.g. i=28, p=0, k=3, s=2 gives 13
    (what PyTorch produces), not 13.5.

    Args:
        i: input size along the axis (height or width).
        p: zero padding added on each side.
        k: kernel size.
        s: stride.
        d: dilation (defaults to 1, the ``nn.Conv2d`` default).

    Returns:
        The output size; an ``int`` when all arguments are ints.
    """
    o = (i + 2 * p - d * (k - 1) - 1) // s + 1
    return o

## Check the output dimension after each convolution

In [4]:
# One random sample shaped like a single-channel 28x28 MNIST image.
x = torch.rand(1, 28, 28)
print('x original dimension :', x.size())

# Prepend a batch axis so x has the (N, C, H, W) layout used by the layers below.
x = x.unsqueeze(0)
print('expand the x tensor with a dummy axis :', x.size())

x original dimension : torch.Size([1, 28, 28])
expand the x tensor with a dummy axis : torch.Size([1, 1, 28, 28])


In [5]:
# First conv: 28x28 -> 16x16. Padding/kernel/stride are read back from the layer
# so the predicted size cannot drift from the hyper-parameters actually used.
cov1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=4, stride=2, padding=3)
expected = output_dimension(x.size(2), p=cov1.padding[0], k=cov1.kernel_size[0], s=cov1.stride[0])
print(f'expected output dimension = {expected}')
xp = cov1(x)
print('output dimension = ', xp.size())
assert xp.size(2) == expected, 'output_dimension() disagrees with nn.Conv2d'

expected output dimension = 16.0
output dimension =  torch.Size([1, 4, 16, 16])


In [6]:
# Second conv: 16x16 -> 8x8. Hyper-parameters are read back from the layer so the
# prediction and the layer definition stay in sync.
cov2 = nn.Conv2d(in_channels=4, out_channels=8, kernel_size=4, stride=2, padding=1)
expected = output_dimension(xp.size(2), p=cov2.padding[0], k=cov2.kernel_size[0], s=cov2.stride[0])
print(f'expected output dimension = {expected}')
xp = cov2(xp)
print('output dimension = ', xp.size())
assert xp.size(2) == expected, 'output_dimension() disagrees with nn.Conv2d'

expected output dimension = 8.0
output dimension =  torch.Size([1, 8, 8, 8])


In [7]:
# Third conv: 8x8 -> 4x4. Hyper-parameters are read back from the layer so the
# prediction and the layer definition stay in sync.
cov3 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=4, stride=2, padding=1)
expected = output_dimension(xp.size(2), p=cov3.padding[0], k=cov3.kernel_size[0], s=cov3.stride[0])
print(f'expected output dimension = {expected}')
xp = cov3(xp)
print('output dimension = ', xp.size())
assert xp.size(2) == expected, 'output_dimension() disagrees with nn.Conv2d'

expected output dimension = 4.0
output dimension =  torch.Size([1, 16, 4, 4])


In [8]:
# Fourth conv: 4x4 -> 1x1 with 64 channels. Hyper-parameters are read back from
# the layer so the prediction and the layer definition stay in sync.
cov4 = nn.Conv2d(in_channels=16, out_channels=64, kernel_size=4, stride=2, padding=0)
expected = output_dimension(xp.size(2), p=cov4.padding[0], k=cov4.kernel_size[0], s=cov4.stride[0])
print(f'expected output dimension = {expected}')
xp = cov4(xp)
print('output dimension = ', xp.size())
assert xp.size(2) == expected, 'output_dimension() disagrees with nn.Conv2d'

expected output dimension = 1.0
output dimension =  torch.Size([1, 64, 1, 1])


In [9]:
# Global average pooling down to 1x1. After cov4 the map is already 1x1, so the
# shape is unchanged here; AdaptiveAvgPool2d(1) would still yield 1x1 if the
# conv stack produced a larger spatial map.
pl = nn.AdaptiveAvgPool2d(1)
xp = pl(xp)
xp.size()

torch.Size([1, 64, 1, 1])

In [10]:
# Collapse every axis after the batch axis: (1, 64, 1, 1) -> (1, 64).
flat = nn.Flatten(start_dim=1, end_dim=-1)
xp = flat(xp)
xp.shape

torch.Size([1, 64])

In [11]:
# Fully connected layer: 64 pooled features -> 128 hidden units.
fc1 = nn.Linear(in_features=64, out_features=128)
xp = fc1(xp)
xp.shape

torch.Size([1, 128])

In [12]:
# Output head: 2 * 20 = 40 values, presumably two length-20 latent statistics
# (e.g. mean and log-variance) -- NOTE(review): confirm the intended split.
fc2 = nn.Linear(in_features=128, out_features=2 * 20)
xp = fc2(xp)
xp.shape

torch.Size([1, 40])

In [13]:
# Length of each of the two vectors produced by fc2; presumably the latent size
# consumed by the decoder below.
d = 20
# Reshape the actual encoder output (previously `xp` was overwritten with fresh
# random data, discarding the tensor computed above).
assert xp.size(1) == 2 * d, f'expected {2 * d} encoder outputs, got {xp.size(1)}'
xp.view(-1, 2, d).size()

torch.Size([1, 2, 20])

## Decoder

In [14]:
# Decoder input: one random vector of size d (the latent size defined above),
# lifted to 64 features. NOTE(review): this rebinds `x`, which held the input image.
x = torch.rand(1, d)
fc3 = nn.Linear(d, 64)
xp = fc3(x)
xp.size()

torch.Size([1, 64])

In [15]:
# Restore spatial axes: (1, 64) -> (1, 64, 1, 1), a 64-channel 1x1 feature map.
unflat = nn.Unflatten(dim=1, unflattened_size=(64, 1, 1))
xp = unflat(xp)
xp.shape

torch.Size([1, 64, 1, 1])

In [16]:
# Transposed conv, stride 1, no padding: 1x1 -> (1-1)*1 + 4 = 4x4, 32 channels.
covT1 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=1, padding=0)
xp = covT1(xp)
xp.shape

torch.Size([1, 32, 4, 4])

In [17]:
# 4x4 -> 8x8: with kernel 4, stride 2, padding 1 the transposed conv doubles H and W.
covT2 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=4, stride=2, padding=1)
xp = covT2(xp)
xp.shape

torch.Size([1, 16, 8, 8])

In [18]:
# 8x8 -> 16x16 (kernel 4, stride 2, padding 1 doubles H and W).
# NOTE(review): this rebinds `covT2`, so the 32->16 layer defined earlier is no
# longer reachable by name; consider unique names if later cells need both.
covT2 = nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=4, stride=2, padding=1)
xp = covT2(xp)
xp.shape

torch.Size([1, 8, 16, 16])

In [19]:
# 16x16 -> 28x28: (16-1)*2 + 4 = 34, minus 2*padding (2*3) = 28, one output channel.
covT3 = nn.ConvTranspose2d(in_channels=8, out_channels=1, kernel_size=4, stride=2, padding=3)
xp = covT3(xp)
xp.shape

torch.Size([1, 1, 28, 28])