In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

In [32]:
# Building blocks: transposed-convolution (deconv) block and convolution block

def deconv_block(in_channels, out_channels, kernel_size, stride, padding=1, activation=None):
    """Build a ConvTranspose2d -> BatchNorm2d -> activation block.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the transposed convolution.
        kernel_size: Kernel size (int or pair) forwarded to ``nn.ConvTranspose2d``.
        stride: Stride (int or pair) forwarded to ``nn.ConvTranspose2d``.
        padding: Padding (int or pair) forwarded to ``nn.ConvTranspose2d``.
        activation: Optional module used as the last layer. When ``None``
            (the default) the block ends in ``nn.Sigmoid()`` if
            ``out_channels == 3`` and in ``nn.LeakyReLU(0.02)`` otherwise.

    Returns:
        An ``nn.Sequential`` holding the three layers in order.
    """
    # Treat `2` and `(2, 2)` alike: comparing a tuple stride against the int 2
    # would silently fall through to output_padding=0 and yield an output that
    # is one pixel short of an exact doubling.
    if isinstance(stride, (tuple, list)):
        strides = tuple(stride)
    else:
        strides = (stride, stride)

    # A stride-2 dimension needs one extra output row/column so that, e.g.,
    # kernel 5 / padding 2 maps H -> 2H instead of 2H - 1.
    output_padding = tuple(1 if s == 2 else 0 for s in strides)

    if activation is None:
        # Default heuristic: a 3-channel output ends in Sigmoid,
        # every other block uses a leaky ReLU.
        activation = nn.Sigmoid() if out_channels == 3 else nn.LeakyReLU(0.02)

    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, output_padding),
        nn.BatchNorm2d(out_channels),
        activation,
    )

def conv_block(in_channels, out_channels, kernel_size, stride, padding):
    """Build a Conv2d -> (MaxPool2d | Identity) -> BatchNorm2d -> LeakyReLU block.

    Blocks with fewer than 512 output channels include a 2x2 max-pool with
    stride 2; blocks with 512 or more output channels use ``nn.Identity``
    in that slot instead.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size forwarded to ``nn.Conv2d``.
        stride: Stride forwarded to ``nn.Conv2d``.
        padding: Padding forwarded to ``nn.Conv2d``.

    Returns:
        An ``nn.Sequential`` holding the four layers in order.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)

    if out_channels >= 512:
        downsample = nn.Identity()
    else:
        downsample = nn.MaxPool2d(kernel_size=2, stride=2)

    layers = [
        conv,
        downsample,
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(negative_slope=0.02),
    ]
    return nn.Sequential(*layers)

In [33]:
# Shape walk-through: push one random 3x128x128 input through seven conv
# blocks followed by two deconv blocks and print the shape after each stage.
random_noise = torch.randn(1, 3, 128, 128)

# (block factory, in_channels, out_channels), applied in order.
stages = [
    (conv_block, 3, 32),
    (conv_block, 32, 64),
    (conv_block, 64, 128),
    (conv_block, 128, 256),
    (conv_block, 256, 512),
    (conv_block, 512, 1024),
    (conv_block, 1024, 2048),
    (deconv_block, 2048, 1024),
    (deconv_block, 1024, 512),
]

output = random_noise
for make_block, n_in, n_out in stages:
    # Every stage uses the same 5x5 kernel, stride 1 and padding 2.
    block = make_block(in_channels=n_in, out_channels=n_out, kernel_size=(5,5), stride=1, padding=2)
    output = block(output)
    print(output.shape)


torch.Size([1, 32, 64, 64])
torch.Size([1, 64, 32, 32])
torch.Size([1, 128, 16, 16])
torch.Size([1, 256, 8, 8])
torch.Size([1, 512, 8, 8])
torch.Size([1, 1024, 8, 8])
torch.Size([1, 2048, 8, 8])
torch.Size([1, 1024, 8, 8])
torch.Size([1, 512, 8, 8])


In [34]:
# Shape walk-through of a stack of deconv blocks starting from a
# (32, 512, 8, 8) random tensor; the shape is printed after each stage.
random_noise = torch.randn(32, 512, 8, 8)

# Stride-1 stages: only the channel count changes.
output = random_noise
for n_in, n_out in [(512, 1024), (1024, 2048), (2048, 1024), (1024, 512), (512, 256)]:
    block = deconv_block(n_in, n_out, (5,5), stride=1, padding=2)
    output = block(output)
    print(output.shape)

# Stride-2 stages: each one is kept under its own name.
block_2 = deconv_block(256, 128, (5,5), stride=2, padding=2)
output_2 = block_2(output)
print(output_2.shape)

block_3 = deconv_block(128, 64, (5,5), stride=2, padding=2)
output_3 = block_3(output_2)
print(output_3.shape)

# Both remaining blocks have out_channels == 3, so deconv_block ends each
# of them with a Sigmoid.
block_4 = deconv_block(64, 3, (5,5), stride=2, padding=2)
output_4 = block_4(output_3)
print(output_4.shape)

block_5 = deconv_block(3, 3, (5,5), stride=2, padding=2)
output_5 = block_5(output_4)
print(output_5.shape)

torch.Size([32, 1024, 8, 8])
torch.Size([32, 2048, 8, 8])
torch.Size([32, 1024, 8, 8])
torch.Size([32, 512, 8, 8])
torch.Size([32, 256, 8, 8])
torch.Size([32, 128, 16, 16])
torch.Size([32, 64, 32, 32])
torch.Size([32, 3, 64, 64])
torch.Size([32, 3, 128, 128])


In [11]:
# Single-block check: run one stride-1 deconv block (512 -> 512 channels,
# 5x5 kernel, padding 2) on a (32, 512, 8, 8) input and print the output shape.
random_noise = torch.randn(32, 512, 8, 8)
block = deconv_block(
    in_channels=512,
    out_channels=512,
    kernel_size=(5,5),
    stride=1,
    padding=2,
)
output = block(random_noise)
print(output.shape)



torch.Size([32, 512, 8, 8])


In [36]:
# Scratch calculation. NOTE(review): presumably number of samples divided by
# batch size (i.e. batches per epoch) — confirm, or remove this cell.
10000 / 64

156.25