In [6]:
!pip install natsort

Collecting natsort
  Downloading natsort-8.1.0-py3-none-any.whl (37 kB)
Installing collected packages: natsort
Successfully installed natsort-8.1.0
[0m

In [31]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import math
import itertools
import imageio
import natsort
from glob import glob


import logging
import time
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image

In [32]:
def get_data_loader(batch_size):
    # MNIST Dataset
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))])

    train_dataset = datasets.MNIST(root='.', train=True, transform=transform, download=True)

    # Data Loader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    return train_loader

In [49]:
def generate_images(epoch, path, fixed_noise, num_test_samples, netG, device, use_fixed=False):
    z = torch.randn(num_test_samples, 100, 1, 1, device=device)
    size_figure_grid = int(math.sqrt(num_test_samples))
    title = None
  
    if use_fixed:
        generated_fake_images = netG(fixed_noise)
        path += 'fixed_noise/'
        title = 'Fixed Noise'
    else:
        generated_fake_images = netG(z)
        path += 'variable_noise/'
        title = 'Variable Noise'
  
    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(6,6))
    for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
        ax[i,j].get_xaxis().set_visible(False)
        ax[i,j].get_yaxis().set_visible(False)
    for k in range(num_test_samples):
        i = k//4
        j = k%4
        ax[i,j].cla()
        ax[i,j].imshow(generated_fake_images[k].data.cpu().numpy().reshape(28,28), cmap='Greys')
    label = 'Epoch_{}'.format(epoch+1)
    fig.text(0.5, 0.04, label, ha='center')
    fig.suptitle(title)
    fig.savefig(path+label+'.png')
    
def save_gif(path, fps, fixed_noise=False):
    if fixed_noise==True:
        path += 'fixed_noise/'
    else:
        path += 'variable_noise/'
    images = glob(path + '*.png')
    images = natsort.natsorted(images)
    gif = []

    for image in images:
        gif.append(imageio.imread(image))
    imageio.mimsave(path+'animated.gif', gif, fps=fps)

In [41]:
class Generator(nn.Module):
    def __init__(self, nc, nz, ngf):
      super(Generator, self).__init__()
      self.network = nn.Sequential(
          nn.ConvTranspose2d(nz, ngf*4, 4, 1, 0, bias=False),
          nn.BatchNorm2d(ngf*4),
          nn.ReLU(True),
  
          nn.ConvTranspose2d(ngf*4, ngf*2, 3, 2, 1, bias=False),
          nn.BatchNorm2d(ngf*2),
          nn.ReLU(True),
  
          nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
          nn.BatchNorm2d(ngf),
          nn.ReLU(True),
  
          nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
          nn.Tanh()
      )
  
    def forward(self, input):
      output = self.network(input)
      return output

In [42]:
class Discriminator(nn.Module):
    def __init__(self, nc, ndf):
        super(Discriminator, self).__init__()
        self.network = nn.Sequential(
                
                nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                
                nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 2),
                nn.LeakyReLU(0.2, inplace=True),
                
                nn.Conv2d(ndf * 2, ndf * 4, 3, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 4),
                nn.LeakyReLU(0.2, inplace=True),
                
                nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
                nn.Sigmoid()
            )
    def forward(self, input):
        output = self.network(input)
        return output.view(-1, 1).squeeze(1)

# TRAINING CONFIG

In [51]:
batch_size = 128
num_epochs = 100
ndf, ngf = 32, 32  # num disc features / num gen features
nz = 100 #noise size
d_lr = 0.0002
g_lr = 0.0002
nc = 1 #num of channels

num_test_samples = 16
output_path = '/notebooks/results/'
fps = 5

In [52]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [53]:
train_loader = get_data_loader(batch_size)

netG = Generator(nc, nz, ngf).to(device)
netD = Discriminator(nc, ndf).to(device)

criterion = nn.BCELoss()

optimizerD = optim.Adam(netD.parameters(), lr=d_lr)
optimizerG = optim.Adam(netG.parameters(), lr=g_lr)

In [54]:
# initialize other variables
real_label = 1
fake_label = 0
num_batches = len(train_loader)
fixed_noise = torch.randn(num_test_samples, 100, 1, 1, device=device)


In [None]:
for epoch in range(num_epochs):
    for i, (real_images, _) in enumerate(train_loader):
        bs = real_images.shape[0]
        #training discriminator

        netD.zero_grad()
        real_images = real_images.to(device)
        label = torch.full((bs,), real_label, device=device)

        output = netD(real_images)
        lossD_real = criterion(output.to(torch.float32), label.to(torch.float32))
        lossD_real.backward()
        D_x = output.mean().item()

        noise = torch.randn(bs, nz, 1, 1, device=device)
        fake_images = netG(noise)
        label.fill_(fake_label)
        output = netD(fake_images.detach())
        lossD_fake = criterion(output.to(torch.float32), label.to(torch.float32))
        lossD_fake.backward()
        D_G_z1 = output.mean().item()
        lossD = lossD_real + lossD_fake
        optimizerD.step()

        # training generator

        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake_images)
        lossG = criterion(output.to(torch.float32), label.to(torch.float32))
        lossG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        if (i+1)%100 == 0:
            print('Epoch [{}/{}], step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, Discriminator - D(G(x)): {:.2f}, Generator - D(G(x)): {:.2f}'.format(epoch+1, num_epochs, 
                                                        i+1, num_batches, lossD.item(), lossG.item(), D_x, D_G_z1, D_G_z2))
    netG.eval()
    generate_images(epoch, output_path, fixed_noise, num_test_samples, netG, device, use_fixed=True)
    netG.train()

    # Save gif:
    save_gif(output_path, fps, fixed_noise=True)

Epoch [1/100], step [100/469], d_loss: 0.2074, g_loss: 3.6349, D(x): 0.90, Discriminator - D(G(x)): 0.08, Generator - D(G(x)): 0.03
Epoch [1/100], step [200/469], d_loss: 0.0267, g_loss: 5.6536, D(x): 0.99, Discriminator - D(G(x)): 0.01, Generator - D(G(x)): 0.00
Epoch [1/100], step [300/469], d_loss: 0.0492, g_loss: 5.7374, D(x): 0.98, Discriminator - D(G(x)): 0.03, Generator - D(G(x)): 0.00
Epoch [1/100], step [400/469], d_loss: 0.0171, g_loss: 6.2142, D(x): 0.99, Discriminator - D(G(x)): 0.01, Generator - D(G(x)): 0.00


  gif.append(imageio.imread(image))


Epoch [2/100], step [100/469], d_loss: 0.0148, g_loss: 6.3939, D(x): 0.99, Discriminator - D(G(x)): 0.00, Generator - D(G(x)): 0.00
Epoch [2/100], step [200/469], d_loss: 0.0050, g_loss: 6.3683, D(x): 1.00, Discriminator - D(G(x)): 0.00, Generator - D(G(x)): 0.00
Epoch [2/100], step [300/469], d_loss: 0.0092, g_loss: 7.1093, D(x): 0.99, Discriminator - D(G(x)): 0.00, Generator - D(G(x)): 0.00
Epoch [2/100], step [400/469], d_loss: 0.0070, g_loss: 7.3391, D(x): 1.00, Discriminator - D(G(x)): 0.00, Generator - D(G(x)): 0.00
Epoch [3/100], step [100/469], d_loss: 0.0086, g_loss: 6.3264, D(x): 0.99, Discriminator - D(G(x)): 0.00, Generator - D(G(x)): 0.00
Epoch [3/100], step [200/469], d_loss: 0.0053, g_loss: 6.8762, D(x): 1.00, Discriminator - D(G(x)): 0.00, Generator - D(G(x)): 0.00
Epoch [3/100], step [300/469], d_loss: 0.0062, g_loss: 6.8208, D(x): 1.00, Discriminator - D(G(x)): 0.01, Generator - D(G(x)): 0.00
Epoch [3/100], step [400/469], d_loss: 0.0033, g_loss: 7.3472, D(x): 1.00, D

  fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(6,6))


Epoch [22/100], step [100/469], d_loss: 0.1382, g_loss: 9.0342, D(x): 0.90, Discriminator - D(G(x)): 0.00, Generator - D(G(x)): 0.00
Epoch [22/100], step [200/469], d_loss: 0.0098, g_loss: 7.0442, D(x): 1.00, Discriminator - D(G(x)): 0.00, Generator - D(G(x)): 0.00
Epoch [22/100], step [300/469], d_loss: 0.0039, g_loss: 7.8708, D(x): 1.00, Discriminator - D(G(x)): 0.00, Generator - D(G(x)): 0.00
Epoch [22/100], step [400/469], d_loss: 0.0173, g_loss: 7.7114, D(x): 0.99, Discriminator - D(G(x)): 0.00, Generator - D(G(x)): 0.00
Epoch [23/100], step [100/469], d_loss: 0.0608, g_loss: 6.8116, D(x): 0.95, Discriminator - D(G(x)): 0.00, Generator - D(G(x)): 0.00
Epoch [23/100], step [200/469], d_loss: 0.0099, g_loss: 7.5996, D(x): 1.00, Discriminator - D(G(x)): 0.01, Generator - D(G(x)): 0.00
Epoch [23/100], step [300/469], d_loss: 0.0191, g_loss: 6.2982, D(x): 0.99, Discriminator - D(G(x)): 0.01, Generator - D(G(x)): 0.01
Epoch [23/100], step [400/469], d_loss: 0.0074, g_loss: 7.5107, D(x):