In [12]:
import os 
import numpy as np
import matplotlib.pyplot as plt
import torch 
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision import models, datasets, transforms
import torch.optim as optim 
import torch.functional as F
from torch.utils.tensorboard import SummaryWriter

In [17]:
class Discriminator(nn.Module):
    """DCGAN discriminator: maps an image batch to a per-image real/fake probability.

    Args:
        channels_img: number of channels in the input images (e.g. 3 for RGB).
        features_d: base channel width; the hidden layers use features_d,
            2*features_d, 4*features_d and 8*features_d channels.

    For an input of shape N x channels_img x 64 x 64 the output has shape
    N x 1 x 1 x 1, with values in (0, 1) produced by the final Sigmoid.
    """

    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        # Following DCGAN, BatchNorm is not used in the first layer of the
        # discriminator (nor in the last layer of the generator).
        self.discriminator = nn.Sequential(
            # N x channels_img x 64 x 64 -> N x features_d x 32 x 32
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            self._block(features_d, features_d * 2, 4, 2, 1),      # -> 16 x 16
            self._block(features_d * 2, features_d * 4, 4, 2, 1),  # -> 8 x 8
            self._block(features_d * 4, features_d * 8, 4, 2, 1),  # -> 4 x 4
            # A 4x4 kernel with no padding collapses the 4 x 4 map to 1 x 1
            # (one prediction per image); the stride has no effect here.
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage.

        The conv has no bias because the following BatchNorm adds its own shift.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Return the discriminator's probability that each image in ``x`` is real."""
        return self.discriminator(x)

    def _init_weight(self):
        """Apply DCGAN-style initialization to every conv and BatchNorm layer.

        Conv / ConvTranspose weights are drawn from N(0, 0.02). BatchNorm scale
        (gamma) is drawn from N(1, 0.02) and its shift (beta) is set to 0, so each
        BatchNorm starts close to a plain normalization. A zero-centered gamma
        would shrink every normalized channel toward zero and randomly flip its sign.
        """
        # self.modules() iterates recursively over this module and all submodules.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.zeros_(m.bias)
                
class Generator(nn.Module):
    """DCGAN generator: maps latent noise to an image batch in [-1, 1].

    Args:
        z_dim: number of channels of the latent input (its shape is
            N x z_dim x 1 x 1).
        channels_img: number of channels in the generated images (e.g. 3 for RGB).
        features_g: base channel width; the hidden layers use 16, 8, 4 and 2
            times features_g channels.

    The output has shape N x channels_img x 64 x 64.
    """

    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()

        self.net = nn.Sequential(
            # Input: N x z_dim x 1 x 1
            self._block(z_dim, features_g * 16, 4, 1, 0),           # N x f_g*16 x 4 x 4
            self._block(features_g * 16, features_g * 8, 4, 2, 1),  # N x f_g*8 x 8 x 8
            self._block(features_g * 8, features_g * 4, 4, 2, 1),   # N x f_g*4 x 16 x 16
            self._block(features_g * 4, features_g * 2, 4, 2, 1),   # N x f_g*2 x 32 x 32
            # Last layer has no BatchNorm (DCGAN): N x channels_img x 64 x 64
            nn.ConvTranspose2d(features_g * 2, channels_img, 4, 2, 1),
            nn.Tanh(),  # squashes pixels into [-1, 1]
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a ConvTranspose2d -> BatchNorm2d -> ReLU upsampling stage.

        The transposed conv has no bias because the following BatchNorm adds its own shift.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Generate an image batch from latent noise ``x``."""
        return self.net(x)

    def _init_weight(self):
        """Apply DCGAN-style initialization to every conv and BatchNorm layer.

        Conv / ConvTranspose weights are drawn from N(0, 0.02). BatchNorm scale
        (gamma) is drawn from N(1, 0.02) and its shift (beta) is set to 0, so each
        BatchNorm starts close to a plain normalization. A zero-centered gamma
        would shrink every normalized channel toward zero and randomly flip its sign.
        """
        # self.modules() iterates recursively over this module and all submodules.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.zeros_(m.bias)
            
# def test():
#     N, in_channels, H, W = 8, 3, 64, 64
#     z_dim = 100
#     X = torch.randn((N, in_channels, H, W))
#     disc = Discriminator(in_channels,8)
#     assert disc(X).shape == (N, 1, 1, 1) # One Value per example
#     gen = Generator(z_dim, in_channels, 64)
#     z = torch.randn((N, z_dim, 1, 1))
#     assert gen(z).shape == (N, in_channels, H, W) # Output: generated image
#     print("Success")
# test()

Success


In [19]:
## The training setup

# Use the default GPU when one is available ("cuda" is PyTorch's GPU device
# type), otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # Adam learning rate used in the DCGAN paper
BATCH_SIZE = 128
IMAGE_SIZE = 64  # the Discriminator/Generator layer stacks assume 64 x 64 images


0.0001
