In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

In [0]:
# Seed torch's global RNG first so random inputs and layer initialisations below
# are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Batch size shared by every playground cell below.
bs = 16

## Layers

In [0]:
def conv1d(ni, nf, k=3, s=2, p=1):
    """Create a 1-D convolution layer.

    Args:
        ni: input channels.
        nf: output channels.
        k: kernel size.
        s: stride (defaults to 2, i.e. downsampling).
        p: zero padding on each side.

    Returns:
        An ``nn.Conv1d`` module.
    """
    layer = nn.Conv1d(ni, nf, kernel_size=k, stride=s, padding=p)
    return layer

In [0]:
def conv2d(ni, nf, k=3, s=2, p=1):
    """Create a square-kernel 2-D convolution layer.

    With the defaults (k=3, s=2, p=1) the layer halves H and W; pass ``s=1``
    to keep the spatial size.

    Args:
        ni: input channels.
        nf: output channels.
        k: kernel size.
        s: stride.
        p: zero padding on each side.

    Returns:
        An ``nn.Conv2d`` module.
    """
    layer = nn.Conv2d(ni, nf, kernel_size=k, stride=s, padding=p)
    return layer

In [0]:
def upsample(ni, nf, k=3, s=2, p=1, op=0):
    """Create a 2-D transposed convolution used for upsampling.

    With k=3, s=2, p=1 the output size is ``2*H - 1 + op``, so ``op=1``
    exactly doubles H and W.

    Args:
        ni: input channels.
        nf: output channels.
        k: kernel size.
        s: stride.
        p: padding.
        op: ``output_padding`` (extra rows/cols added on one side of the output).

    Returns:
        An ``nn.ConvTranspose2d`` module.
    """
    layer = nn.ConvTranspose2d(
        ni,
        nf,
        kernel_size=k,
        stride=s,
        padding=p,
        output_padding=op,
    )
    return layer

In [0]:
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) axis: ``(N, *) -> (N, prod(*))``.

    Uses ``reshape`` instead of ``view`` so non-contiguous inputs (for example
    the result of ``transpose``/``permute``) are flattened instead of raising a
    RuntimeError. For contiguous inputs ``reshape`` still returns a view, as
    ``view`` did.
    """

    def forward(self, x):
        return x.reshape(x.size(0), -1)

In [0]:
class ResnetBlock(nn.Module):
    """Residual block that preserves both channel count and spatial size.

    Computes ``x + conv2(relu(batchnorm(conv1(x))))``. Both convolutions come
    from ``conv2d`` with ``s=1`` (3x3 kernel, padding 1), so the branch output
    has the same shape as ``x`` and the addition is well-defined.

    Args:
        nf: number of channels of the input and output.
    """

    def __init__(self, nf):
        super().__init__()
        # Attribute names define the state_dict keys; keep them stable.
        self.conv1 = conv2d(nf, nf, s=1)
        self.batchnorm = nn.BatchNorm2d(nf)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv2d(nf, nf, s=1)

    def forward(self, x):
        branch = self.conv2(self.relu(self.batchnorm(self.conv1(x))))
        return x + branch

In [0]:
def conv_res(ni, nf):
    """Downsampling stage: stride-2 ``conv2d`` (ni -> nf channels) then ``ResnetBlock(nf)``.

    With ``conv2d``'s defaults (k=3, s=2, p=1) the stage halves H and W.
    """
    down = conv2d(ni, nf)
    block = ResnetBlock(nf)
    return nn.Sequential(down, block)

In [0]:
def up_res(ni, nf):
    """Upsampling stage: transposed conv (ni -> nf channels, output_padding=1) then ``ResnetBlock(nf)``.

    With ``upsample``'s defaults (k=3, s=2, p=1) and ``op=1`` the stage
    exactly doubles H and W.
    """
    up = upsample(ni, nf, op=1)
    block = ResnetBlock(nf)
    return nn.Sequential(up, block)

In [0]:
def create_encoder(image_size=64, latent_dim=4, channels=(1, 4, 8, 16, 32)):
    """Build the convolutional encoder as a stack of ``conv_res`` stages.

    Stage ``i`` maps ``channels[i]`` -> ``channels[i + 1]`` and halves H and W
    (stride-2 ``conv2d``). With the default channels (4 stages) the shapes are:
    (bs, 1, 64, 64) -> (bs, 4, 32, 32) -> (bs, 8, 16, 16) -> (bs, 16, 8, 8)
    -> (bs, 32, 4, 4). The output is NOT flattened.

    Args:
        image_size: accepted for backward compatibility; currently unused.
        latent_dim: accepted for backward compatibility; currently unused.
        channels: channel count at each stage boundary, input channels first.
            Needs at least two entries.

    Returns:
        ``nn.Sequential`` of ``len(channels) - 1`` ``conv_res`` stages.

    Raises:
        ValueError: if ``channels`` has fewer than two entries.
    """
    # NOTE(review): image_size / latent_dim do not influence the network; the
    # output spatial size is input_size / 2 ** (len(channels) - 1).
    channels = tuple(channels)
    if len(channels) < 2:
        raise ValueError("channels needs at least an input and an output entry")
    stages = [conv_res(c_in, c_out) for c_in, c_out in zip(channels, channels[1:])]
    return nn.Sequential(*stages)

In [0]:
def create_decoder(image_size=64, latent_dim=4, channels=(32, 16, 8, 4, 1)):
    """Build the convolutional decoder as a stack of ``up_res`` stages.

    Stage ``i`` maps ``channels[i]`` -> ``channels[i + 1]`` and doubles H and W
    (stride-2 transposed conv with output_padding=1). With the default channels
    (4 stages) the shapes are:
    (bs, 32, 4, 4) -> (bs, 16, 8, 8) -> (bs, 8, 16, 16) -> (bs, 4, 32, 32)
    -> (bs, 1, 64, 64).

    Args:
        image_size: accepted for backward compatibility; currently unused.
        latent_dim: accepted for backward compatibility; currently unused.
        channels: channel count at each stage boundary, latent channels first.
            Needs at least two entries.

    Returns:
        ``nn.Sequential`` of ``len(channels) - 1`` ``up_res`` stages.

    Raises:
        ValueError: if ``channels`` has fewer than two entries.
    """
    # NOTE(review): the last stage ends in a ResnetBlock with no output
    # activation, so reconstructions are unbounded; confirm this is intended.
    channels = tuple(channels)
    if len(channels) < 2:
        raise ValueError("channels needs at least an input and an output entry")
    stages = [up_res(c_in, c_out) for c_in, c_out in zip(channels, channels[1:])]
    return nn.Sequential(*stages)

## Playground

#### Conv1D

In [0]:
# Conv1d, kernel 1, stride 2, no padding: channels 32 -> 16 and length 32 -> 16.
x_1d = torch.randn(1, 32, 32)
y_1d = conv1d(32, 16, k=1, s=2, p=0)(x_1d)
assert y_1d.shape == (1, 16, 16), y_1d.shape
print(y_1d.shape)

torch.Size([1, 16, 16])


#### Conv2D

In [0]:
# Each case: (input shape, conv2d kwargs, expected output shape).
# With k=3 and p=1, stride 2 halves H and W while stride 1 preserves them.
conv2d_cases = [
    ((bs, 1, 64, 64), dict(ni=1, nf=1, k=3, s=2), (bs, 1, 32, 32)),
    ((bs, 1, 64, 64), dict(ni=1, nf=1, k=3, s=1), (bs, 1, 64, 64)),
    ((bs, 1, 32, 32), dict(ni=1, nf=8, k=3, s=2), (bs, 8, 16, 16)),
    ((bs, 8, 16, 16), dict(ni=8, nf=16, k=3, s=2), (bs, 16, 8, 8)),
]
for in_shape, conv_kwargs, expected_shape in conv2d_cases:
    x_2d = torch.randn(*in_shape)
    out_shape = conv2d(**conv_kwargs)(x_2d).shape
    assert out_shape == expected_shape, (in_shape, conv_kwargs, out_shape)
    print(out_shape)

torch.Size([16, 1, 32, 32])
torch.Size([16, 1, 64, 64])
torch.Size([16, 8, 16, 16])
torch.Size([16, 16, 8, 8])


#### ResNet

In [0]:
# A ResnetBlock preserves the input shape, so the residual addition is valid.
x_res = torch.randn(bs, 1, 64, 64)
y_res = ResnetBlock(1)(x_res)
assert y_res.shape == x_res.shape, y_res.shape
print(y_res.shape)

torch.Size([16, 1, 64, 64])


#### Encoder

In [0]:
# Four stride-2 stages: (bs, 1, 64, 64) -> (bs, 32, 4, 4). The encoder does not flatten.
x_img = torch.randn(bs, 1, 64, 64)
z_enc = create_encoder()(x_img)
assert z_enc.shape == (bs, 32, 4, 4), z_enc.shape
print(z_enc.shape)

torch.Size([16, 32, 4, 4])


#### Flatten

In [0]:
# Flatten keeps the batch axis and collapses the rest: 1 * 32 * 32 = 1024 features.
x_flat = torch.randn(bs, 1, 32, 32)
y_flat = Flatten()(x_flat)
assert y_flat.shape == (bs, 1 * 32 * 32), y_flat.shape
print(y_flat.shape)

torch.Size([16, 1024])


#### Upsample

In [0]:
# Without output_padding, a k=3/s=2/p=1 transposed conv maps 32 -> 63. Passing
# output_size lets PyTorch add the extra row/column needed to reach 64.
x_up = torch.randn(bs, 64, 32, 32)
y_up = upsample(64, 32)(x_up, output_size=(64, 64))
assert y_up.shape == (bs, 32, 64, 64), y_up.shape
print(y_up.shape)

torch.Size([16, 32, 64, 64])


In [0]:
# Same result with output_padding=1 baked into the layer, so no output_size is needed.
x_up = torch.randn(bs, 64, 32, 32)
y_up = upsample(64, 32, op=1)(x_up)
assert y_up.shape == (bs, 32, 64, 64), y_up.shape
print(y_up.shape)

torch.Size([16, 32, 64, 64])


#### Reshape

In [0]:
# reshape takes the target dims either as a tuple or as separate arguments.
torch.randn(bs, 4 * 4).reshape(bs, 4, 4).shape

torch.Size([16, 4, 4])

In [0]:
# view needs a layout-compatible (here contiguous) tensor; reshape copies only when it must.
flat16 = torch.randn(bs, 16)
flat16.view(bs, 4, 4).shape

torch.Size([16, 4, 4])

#### Decoder

In [0]:
# The decoder mirrors the encoder: (bs, 32, 4, 4) -> (bs, 1, 64, 64).
z_dec = torch.randn(bs, 32, 4, 4)
x_rec = create_decoder()(z_dec)
assert x_rec.shape == (bs, 1, 64, 64), x_rec.shape
print(x_rec.shape)

torch.Size([16, 1, 64, 64])
