In [8]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from tqdm import tqdm

In [9]:
# Training configuration. Everything a reader might want to tune lives here.
epochs      = 100
batch_size  = 64
sample_size = 100    # Number of random values to sample
g_lr        = 1.0e-4 # Generator's learning rate
d_lr        = 1.0e-4 # Discriminator's learning rate
seed        = 42     # Seed for PyTorch's random number generator

# Weight initialisation and the generator's noise both draw from PyTorch's
# global RNG; seeding it makes a Restart & Run All reproducible.
torch.manual_seed(seed)

In [10]:
# MNIST training split, converted to tensors by ToTensor.
transform = transforms.ToTensor()
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# shuffle=True: without it every epoch iterates the exact same batches in the
# exact same order.
# drop_last=True: the training loop pairs each batch with target tensors of a
# fixed (batch_size, 1) shape, so a smaller final batch must be discarded.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [11]:
# Generator Network
class Generator(nn.Sequential):
    """Fully connected generator that maps random noise to 1x28x28 images.

    The network is Linear(sample_size, 128) -> LeakyReLU -> Linear(128, 784)
    -> Sigmoid, so every output pixel lies in the range [0, 1].

    Args:
        sample_size: Length of the random noise vector fed to the first layer.
    """

    def __init__(self, sample_size: int):
        super().__init__(
            nn.Linear(sample_size, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 784),
            nn.Sigmoid()
        )

        # Random value vector size
        self.sample_size = sample_size

    def forward(self, batch_size: int):
        """Sample noise internally and return a batch of generated images.

        Args:
            batch_size: Number of images to generate.

        Returns:
            Tensor of shape (batch_size, 1, 28, 28).
        """
        # Generate random values on the same device and with the same dtype as
        # the network weights, so the module keeps working after
        # `.to(device)` / `.double()` instead of failing on a mismatch.
        reference = next(self.parameters())
        z = torch.randn(
            batch_size,
            self.sample_size,
            device=reference.device,
            dtype=reference.dtype,
        )

        # Generator output
        output = super().forward(z)

        # Convert the output into a greyscale image (1x28x28)
        generated_images = output.reshape(batch_size, 1, 28, 28)
        return generated_images


# Discriminator Network
class Discriminator(nn.Sequential):
    """Fully connected discriminator that returns its own training loss.

    The stack is Linear(784, 128) -> LeakyReLU -> Linear(128, 1); the final
    layer emits one raw logit per image (no sigmoid is applied here).
    """

    def __init__(self):
        layers = (
            nn.Linear(784, 128),
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),
        )
        super().__init__(*layers)

    def forward(self, images: torch.Tensor, targets: torch.Tensor):
        """Score `images` and return the BCE-with-logits loss against `targets`.

        Args:
            images: Batch of images; flattened to rows of 784 values.
            targets: Target values passed to the loss together with the logits.

        Returns:
            The loss value produced by `F.binary_cross_entropy_with_logits`.
        """
        flattened = images.reshape(-1, 784)
        logits = super().forward(flattened)
        return F.binary_cross_entropy_with_logits(logits, targets)

In [12]:
def save_image_grid(epoch: int, images: torch.Tensor, ncol: int):
    """Tile a batch of images into one grid and save it as a JPEG.

    The file is written to the current working directory as
    `generated_<epoch>.jpg`, with the epoch zero-padded to three digits.

    Args:
        epoch: Epoch number used in the output file name.
        images: Batch of image tensors to arrange in the grid.
        ncol: Number of images per grid row (passed to `make_grid` as `nrow`).
    """
    # Detach first: the caller may pass generator output that is still attached
    # to the autograd graph, and `.numpy()` refuses tensors that require grad.
    image_grid = make_grid(images.detach(), nrow=ncol)  # Images in a grid
    image_grid = image_grid.permute(1, 2, 0)            # Move channel last
    image_grid = image_grid.cpu().numpy()               # To Numpy

    # Use a dedicated figure so we never draw on, or close, a figure that some
    # other cell left open as pyplot's "current" figure.
    fig, ax = plt.subplots()
    ax.imshow(image_grid)
    ax.set_xticks([])
    ax.set_yticks([])
    fig.savefig(f'generated_{epoch:03d}.jpg')
    plt.close(fig)


# Loss targets, one row per image in a batch:
# zeros label a batch as "fake", ones label it as "real".
fake_targets = torch.zeros(batch_size, 1)
real_targets = torch.ones_like(fake_targets)


In [13]:
# Generator and Discriminator networks
generator = Generator(sample_size)
discriminator = Discriminator()


# Optimizers
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=g_lr)


In [7]:
# Size the targets from `batch_size` rather than a hard-coded 64, so they
# cannot silently go out of sync with the DataLoader if the config changes.
real_targets = torch.ones(batch_size, 1)
fake_targets = torch.zeros(batch_size, 1)

In [14]:
# Alternating GAN training: for every batch, first update the discriminator,
# then update the generator.
for epoch in range(epochs):

    # Per-batch losses, averaged and printed at the end of the epoch
    d_losses = []
    g_losses = []

    # The MNIST class labels are not used by the GAN
    for images, labels in tqdm(dataloader):
        #===============================
        # Discriminator Network Training
        #===============================

        # Loss with MNIST image inputs and real_targets as labels
        discriminator.train()
        d_loss = discriminator(images, real_targets)

        # Generate images in eval mode; no_grad because the generator is not
        # being updated in this step
        generator.eval()
        with torch.no_grad():
            generated_images = generator(batch_size)

        # Loss with generated image inputs and fake_targets as labels
        d_loss += discriminator(generated_images, fake_targets)

        # Optimizer updates the discriminator parameters.
        # zero_grad() also clears the discriminator gradients left behind by
        # the previous iteration's g_loss.backward().
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        #===============================
        # Generator Network Training
        #===============================

        # Generate images in train mode
        generator.train()
        generated_images = generator(batch_size)

        # Loss with generated image inputs and real_targets as labels:
        # the generator is rewarded when the discriminator calls its images real
        discriminator.eval() # eval but we still need gradients
        g_loss = discriminator(generated_images, real_targets)

        # Optimizer updates the generator parameters
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # Keep losses for logging
        d_losses.append(d_loss.item())
        g_losses.append(g_loss.item())

    # Print average losses
    print(epoch, np.mean(d_losses), np.mean(g_losses))

    # Save images. The preview batch is for inspection only, so generate it in
    # eval mode and without building an autograd graph.
    generator.eval()
    with torch.no_grad():
        preview_images = generator(batch_size)
    save_image_grid(epoch, preview_images, ncol=8)

100%|██████████| 937/937 [00:24<00:00, 38.18it/s]


0 0.3941705447314516 2.8679530865481975


100%|██████████| 937/937 [00:18<00:00, 50.93it/s]


1 0.291881683033075 2.297138937636653


100%|██████████| 937/937 [00:18<00:00, 51.29it/s]


2 0.479583153030025 1.5103637691241318


100%|██████████| 937/937 [00:17<00:00, 52.29it/s]


3 0.4501350437437076 1.3760129184071512


100%|██████████| 937/937 [00:25<00:00, 37.25it/s]


4 0.3015219208047891 1.9597076687767003


100%|██████████| 937/937 [00:20<00:00, 46.52it/s]


5 0.4188115142898091 2.0971147498809666


100%|██████████| 937/937 [00:18<00:00, 50.53it/s]


6 0.5727681798767 2.0034566370883646


100%|██████████| 937/937 [00:19<00:00, 47.97it/s]


7 0.6132460226689485 1.9189736031417277


100%|██████████| 937/937 [00:18<00:00, 50.86it/s]


8 0.5210979998811459 1.9498007599037765


100%|██████████| 937/937 [00:25<00:00, 36.83it/s]


9 0.5243897172786383 1.9591540747034892


100%|██████████| 937/937 [00:24<00:00, 38.54it/s]


10 0.5697235787245609 1.9062810028120956


100%|██████████| 937/937 [00:20<00:00, 45.04it/s]


11 0.5322815350941685 1.9341877053870462


100%|██████████| 937/937 [00:18<00:00, 50.79it/s]


12 0.5398858386780689 1.9581929450355702


100%|██████████| 937/937 [00:18<00:00, 51.73it/s]


13 0.6119047231264532 1.8258598793532577


100%|██████████| 937/937 [00:21<00:00, 43.27it/s]


14 0.44428044808839656 2.139992334672201


100%|██████████| 937/937 [00:17<00:00, 53.87it/s]


15 0.4363362607541181 2.235077979852194


100%|██████████| 937/937 [00:16<00:00, 56.73it/s]


16 0.4324555545409181 2.291784570209499


100%|██████████| 937/937 [00:17<00:00, 52.33it/s]


17 0.45224838033557063 2.2820102239749223


100%|██████████| 937/937 [00:21<00:00, 43.51it/s]


18 0.36967774540217735 2.469987015714004


100%|██████████| 937/937 [00:18<00:00, 51.80it/s]


19 0.44147701087539964 2.3953729016803753


100%|██████████| 937/937 [00:17<00:00, 52.16it/s]


20 0.48873715288738306 2.3145834295319645


100%|██████████| 937/937 [00:17<00:00, 52.59it/s]


21 0.44757267870923373 2.2687828084066304


100%|██████████| 937/937 [00:19<00:00, 47.37it/s]


22 0.4624371603051779 2.242186425335252


100%|██████████| 937/937 [00:19<00:00, 48.23it/s]


23 0.4629346132437473 2.2655014355638112


100%|██████████| 937/937 [00:18<00:00, 49.62it/s]


24 0.4168343710034227 2.500656989875191


100%|██████████| 937/937 [00:20<00:00, 45.94it/s]


25 0.43643975418494374 2.431606576689534


100%|██████████| 937/937 [00:20<00:00, 46.85it/s]


26 0.4938416966538292 2.329688063935002


  2%|▏         | 23/937 [00:00<00:20, 45.59it/s]


KeyboardInterrupt: 