In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import numpy as np

In [2]:
from torchsummary import summary

In [3]:
# LeakyReLU negative slope, presumably intended for the residual blocks.
# NOTE(review): no visible cell reads this constant; the model classes below
# hard-code 0.01 instead -- confirm whether it should be wired in or removed.
resnet_leaky_relu = 0.01

In [4]:
class ResNetBlock(nn.Module):
    """Residual block: conv -> BN -> LeakyReLU -> conv -> BN, plus identity skip.

    Both convolutions are 3x3 with stride 1 and padding 1 and keep the number
    of channels, so the output has the same shape as the input and the two can
    be added element-wise.

    Args:
        batch_size: Unused; accepted only so existing call sites keep working.
        channels: Number of input (and output) feature channels.
        alpha: Negative slope of the LeakyReLU between the two convolutions.
    """

    def __init__(self, batch_size="pog", channels=256, alpha=0.01):
        super(ResNetBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, stride=1, padding=1)

        self.bnorm1 = nn.BatchNorm2d(channels)
        self.bnorm2 = nn.BatchNorm2d(channels)

        self.lrelu1 = nn.LeakyReLU(alpha)

    def forward(self, _input):
        """Return ``_input + F(_input)`` where ``F`` is the conv/BN stack.

        Args:
            _input: Tensor of shape (N, channels, H, W).

        Returns:
            Tensor with the same shape as ``_input``.
        """
        x = self.conv1(_input)
        # todo: try relu before batch norm
        x = self.bnorm1(x)
        x = self.lrelu1(x)
        x = self.conv2(x)
        x = self.bnorm2(x)

        # Identity skip connection: without it this is just two stacked convs,
        # not a residual block, and gradients cannot bypass the conv stack.
        return _input + x

In [5]:
class Generator(nn.Module):
    """Encoder / residual-trunk / decoder image-to-image generator.

    Layout: one 7x7 stem convolution, two stride-2 3x3 convolutions that
    downsample by 4x overall, nine ``ResNetBlock`` modules at 256 channels,
    two stride-2 transposed convolutions that upsample back to the input
    resolution, and a final 7x7 convolution followed by ``tanh``.

    Args:
        batch_size: Unused by this class; forwarded to ``ResNetBlock``.
        channels: Number of image channels consumed and produced
            (default 3, which matches the previously hard-coded value).
        img_height: Unused; kept for interface compatibility.
        img_width: Unused; kept for interface compatibility.
    """

    # Negative slope shared by every LeakyReLU applied in ``forward``.
    LEAKY_SLOPE = 0.01

    def __init__(self, batch_size="pog", channels=3, img_height=256, img_width=256):
        super(Generator, self).__init__()
        # Honour ``channels`` instead of hard-coding 3 input channels.
        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=64, kernel_size=7, stride=1, padding=3)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)

        self.bnorm1 = nn.BatchNorm2d(64)
        self.bnorm2 = nn.BatchNorm2d(128)
        self.bnorm3 = nn.BatchNorm2d(256)

        # Attribute names are kept as-is so existing state_dict keys still load.
        self.rblock1 = ResNetBlock(batch_size, 256)
        self.rblock2 = ResNetBlock(batch_size, 256)
        self.rblock3 = ResNetBlock(batch_size, 256)
        self.rblock4 = ResNetBlock(batch_size, 256)
        self.rblock5 = ResNetBlock(batch_size, 256)
        self.rblock6 = ResNetBlock(batch_size, 256)
        self.rblock7 = ResNetBlock(batch_size, 256)
        self.rblock8 = ResNetBlock(batch_size, 256)
        self.rblock9 = ResNetBlock(batch_size, 256)

        self.conv_trans1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv_trans2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1)
        # The generated image has the same number of channels as the input.
        self.final_conv = nn.Conv2d(in_channels=64, out_channels=channels, kernel_size=7, stride=1, padding=3)

        self.bnorm4 = nn.BatchNorm2d(128)
        self.bnorm5 = nn.BatchNorm2d(64)

    def forward(self, _input):
        """Map an image batch to a generated image batch.

        Args:
            _input: Tensor of shape (N, channels, H, W).

        Returns:
            Tensor with values in [-1, 1] (``tanh`` output). For the
            256x256 input used in the notebook's sanity check the spatial
            size is preserved.
        """
        slope = self.LEAKY_SLOPE

        # Encoder: stem + two stride-2 downsampling convolutions.
        x = self.conv1(_input)
        x = F.leaky_relu(self.bnorm1(x), negative_slope=slope)

        x = self.conv2(x)
        x = F.leaky_relu(self.bnorm2(x), negative_slope=slope)

        x = self.conv3(x)
        x = F.leaky_relu(self.bnorm3(x), negative_slope=slope)

        # Residual trunk.
        x = self.rblock1(x)
        x = self.rblock2(x)
        x = self.rblock3(x)
        x = self.rblock4(x)
        x = self.rblock5(x)
        x = self.rblock6(x)
        x = self.rblock7(x)
        x = self.rblock8(x)
        x = self.rblock9(x)

        # Decoder: two stride-2 transposed convolutions.
        x = self.conv_trans1(x)
        x = F.leaky_relu(self.bnorm4(x), negative_slope=slope)

        x = self.conv_trans2(x)
        x = F.leaky_relu(self.bnorm5(x), negative_slope=slope)

        x = self.final_conv(x)

        # torch.tanh replaces the deprecated F.tanh alias; same result.
        res = torch.tanh(x)

        return res


In [6]:
# Sanity check: build a Generator with default arguments and print a per-layer
# table of output shapes and parameter counts for a 3x256x256 input.
sanity_check_generator_model = Generator()
summary(sanity_check_generator_model, input_size=(3, 256, 256))



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]           9,472
       BatchNorm2d-2         [-1, 64, 256, 256]             128
            Conv2d-3        [-1, 128, 128, 128]          73,856
       BatchNorm2d-4        [-1, 128, 128, 128]             256
            Conv2d-5          [-1, 256, 64, 64]         295,168
       BatchNorm2d-6          [-1, 256, 64, 64]             512
            Conv2d-7          [-1, 256, 64, 64]         590,080
       BatchNorm2d-8          [-1, 256, 64, 64]             512
         LeakyReLU-9          [-1, 256, 64, 64]               0
           Conv2d-10          [-1, 256, 64, 64]         590,080
      BatchNorm2d-11          [-1, 256, 64, 64]             512
      ResNetBlock-12          [-1, 256, 64, 64]               0
           Conv2d-13          [-1, 256, 64, 64]         590,080
      BatchNorm2d-14          [-1, 256,

In [7]:
class Discriminator(nn.Module):
    """
    pix2pix discriminator

    Five 4x4 convolutions: three with stride 2 (each halving the spatial
    size) followed by two with stride 1. The first four are each followed by
    batch norm and LeakyReLU; the last one emits a single-channel map that is
    returned as-is (no sigmoid), i.e. one raw score per spatial location.

    Args:
        batch_size: Unused; kept for interface compatibility.
        channels: Number of channels of the input image (default 3, which
            matches the previously hard-coded value).
        img_height: Unused; kept for interface compatibility.
        img_width: Unused; kept for interface compatibility.
    """

    # Negative slope shared by every LeakyReLU applied in ``forward``.
    LEAKY_SLOPE = 0.01

    def __init__(self, batch_size="pog", channels=3, img_height=256, img_width=256):
        super(Discriminator, self).__init__()
        # Honour ``channels`` instead of hard-coding 3 input channels.
        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1)
        self.conv5 = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1)

        self.bnorm1 = nn.BatchNorm2d(64)
        self.bnorm2 = nn.BatchNorm2d(128)
        self.bnorm3 = nn.BatchNorm2d(256)
        self.bnorm4 = nn.BatchNorm2d(512)

    def forward(self, _input):
        """Score an image batch.

        Args:
            _input: Tensor of shape (N, channels, H, W).

        Returns:
            Tensor of shape (N, 1, H', W') holding unnormalised scores; for a
            256x256 input the map is 30x30.
        """
        slope = self.LEAKY_SLOPE

        x = self.conv1(_input)
        x = F.leaky_relu(self.bnorm1(x), negative_slope=slope)

        x = self.conv2(x)
        x = F.leaky_relu(self.bnorm2(x), negative_slope=slope)

        x = self.conv3(x)
        x = F.leaky_relu(self.bnorm3(x), negative_slope=slope)

        x = self.conv4(x)
        x = F.leaky_relu(self.bnorm4(x), negative_slope=slope)

        # No activation on the final layer: callers receive raw scores.
        res = self.conv5(x)
        return res

In [8]:
# Sanity check: build a Discriminator with default arguments and print a
# per-layer table of output shapes and parameter counts for a 3x256x256 input.
sanity_check_discriminator_model = Discriminator()
summary(sanity_check_discriminator_model, input_size=(3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           3,136
       BatchNorm2d-2         [-1, 64, 128, 128]             128
            Conv2d-3          [-1, 128, 64, 64]         131,200
       BatchNorm2d-4          [-1, 128, 64, 64]             256
            Conv2d-5          [-1, 256, 32, 32]         524,544
       BatchNorm2d-6          [-1, 256, 32, 32]             512
            Conv2d-7          [-1, 512, 31, 31]       2,097,664
       BatchNorm2d-8          [-1, 512, 31, 31]           1,024
            Conv2d-9            [-1, 1, 30, 30]           8,193
Total params: 2,766,657
Trainable params: 2,766,657
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.75
Forward/backward pass size (MB): 35.51
Params size (MB): 10.55
Estimated Total Size (MB): 46.82
------------------------------------