# GANs for MNIST

In this notebook, we fit vanilla GANs to the MNIST dataset.

## 1. Environment

In [3]:
#@markdown Let's start by installing a couple of dependencies.

!pip -q install torch torchvision lightning pandas seaborn git+https://github.com/oelin/valkyrie

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [14]:
#@markdown Let's import everything we need.

from typing import Any, Callable, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.utils import save_image

import torchvision
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Lambda, ToPILImage, PILToTensor, ToTensor, Normalize

import lightning as L

import valkyrie as vk

In [5]:
#@markdown Let's initialize the MNIST dataset.

# ToTensor produces pixel values in [0, 1]; Normalize(mean=0.5, std=0.5) then
# rescales them to [-1, 1], the same range as the generator's Tanh output.
transform = Compose([
    ToTensor(),
    Normalize(mean=0.5, std=0.5),
])

train_dataset = MNIST(root='.', train=True, transform=transform, download=True)
# train=False selects the held-out test split; with train=True this would just
# be a second copy of the training data.
test_dataset = MNIST(root='.', train=False, transform=transform, download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 31142458.29it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 39873500.27it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 39820192.52it/s]


Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 20752209.99it/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



## 2. Models

In [6]:
#@markdown

# Quick error fix

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP

# On Windows platform, the torch.distributed package only
# supports Gloo backend, FileStore and TcpStore.
# For FileStore, set init_method parameter in init_process_group
# to a local file. Example as follow:
# init_method="file:///f:/libtmp/some_file"
# dist.init_process_group(
#    "gloo",
#    rank=rank,
#    init_method=init_method,
#    world_size=world_size)
# For TcpStore, same way as on Linux.

def setup(rank, world_size):
    """Initialise the default ``gloo`` process group on localhost.

    Safe to call more than once: if a default process group already exists
    (for example because the notebook cell was re-run), nothing is changed.

    Args:
        rank: Rank of the current process within the group.
        world_size: Total number of processes in the group.
    """
    # The default group can only be initialised once per process; a second
    # call to init_process_group would raise.
    if dist.is_initialized():
        return

    # Rendezvous address used by the default env:// initialisation.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    """Destroy the default process group if one is currently initialised.

    Calling this when no group exists is a no-op instead of an error, so the
    function can be used unconditionally at the end of a notebook session.
    """
    if dist.is_initialized():
        dist.destroy_process_group()

# Initialise a single-process group: this process is rank 0 and the world size is 1.
setup(0, 1)

In [7]:
#@markdown

import lightning as L


# Multi-GPU alternative, kept for reference (not executed):
#fabric = L.Fabric(accelerator="cuda", devices=8, strategy="ddp")

# Configure Fabric with the CPU accelerator, 10 devices and the notebook DDP
# strategy, then launch it.
# NOTE(review): `fabric` is not referenced again in the cells visible here, and
# the models are placed on their device directly rather than through Fabric —
# confirm whether this cell is still needed.
fabric = L.Fabric(accelerator="cpu", devices=10, strategy="ddp_notebook")
fabric.launch()

In [20]:
#@markdown The generator.

# Size of the latent noise vector fed to the generator.
latent_dimension = 64
# A 28x28 image flattened into a single vector.
input_dimension = 28 * 28

# Maps a latent vector to a flattened image. The final Tanh keeps outputs in
# [-1, 1]. The model is placed on the GPU when one is available and on the CPU
# otherwise, so the cell also runs on machines without CUDA.
generator = nn.Sequential(
    nn.Linear(latent_dimension, 256),
    nn.ReLU(),
    nn.Linear(256, 512),
    nn.ReLU(),
    nn.Linear(512, input_dimension),
    nn.Tanh(),
).to('cuda' if torch.cuda.is_available() else 'cpu')

In [21]:
#@markdown The discriminator.

# Maps a flattened image to a single probability (via the final Sigmoid) that
# the image is real. The model is placed on the GPU when one is available and
# on the CPU otherwise, so the cell also runs on machines without CUDA.
discriminator = nn.Sequential(
    nn.Linear(input_dimension, 512),
    nn.ReLU(),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 1),
    nn.Sigmoid(),
).to('cuda' if torch.cuda.is_available() else 'cpu')

In [22]:
def unnormalize_image(x):
    """Map values from the [-1, 1] range back to [0, 1].

    Anything that falls outside [0, 1] after rescaling is clipped to the
    nearest bound.
    """
    shifted = x + 1
    rescaled = shifted / 2
    return rescaled.clamp(min=0, max=1)

In [23]:
#@markdown Train them.

batch_size = 64
epochs = 300
# Number of full batches per epoch; a trailing partial batch is not counted.
batches = len(train_dataset) // batch_size

# drop_last=True discards the trailing partial batch, so every batch handed to
# the training loop holds exactly `batch_size` samples.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# BCELoss expects probabilities, which the discriminator's final Sigmoid provides.
criterion = nn.BCELoss()
discriminator_optimizer = Adam(discriminator.parameters(), lr=2e-4)
generator_optimizer = Adam(generator.parameters(), lr=2e-4)


def generate_fake(batch_size: int) -> torch.Tensor:
    """Sample a batch of images from the generator.

    Args:
        batch_size: Number of latent vectors to draw.

    Returns:
        The generator's output for ``batch_size`` standard-normal latent
        vectors, on the same device as the generator. The result is still
        attached to the autograd graph; callers detach it when needed.
    """
    # Follow the generator's own device instead of assuming CUDA is present.
    device = next(generator.parameters()).device
    z = torch.randn(batch_size, latent_dimension).to(device)

    return generator(z)


# Per-batch losses, accumulated over the whole run.
discriminator_losses = []
generator_losses = []

# Run on whichever device the discriminator lives on instead of assuming CUDA.
training_device = next(discriminator.parameters()).device

# The target labels are identical for every batch, so build them once.
labels_real = torch.ones(batch_size, 1, device=training_device)
labels_fake = torch.zeros(batch_size, 1, device=training_device)

# save_image does not create missing directories, so make sure the output
# directory exists before the first epoch finishes.
samples_directory = 'samples'
os.makedirs(samples_directory, exist_ok=True)

for epoch in range(epochs):
    for i, (inputs_real, _) in enumerate(train_dataloader):

        # Skip a partial batch: the label tensors have a fixed size.
        if inputs_real.size(0) != batch_size:
            continue

        # Train the discriminator.

        # Flatten each image to a vector; the fakes are detached so this step
        # does not backpropagate into the generator.
        inputs_real = inputs_real.view(batch_size, -1).to(training_device)
        inputs_fake = generate_fake(batch_size).detach().to(training_device)

        discriminator_optimizer.zero_grad()
        discriminator_loss_real = criterion(discriminator(inputs_real), labels_real)
        discriminator_loss_fake = criterion(discriminator(inputs_fake), labels_fake)
        discriminator_loss = discriminator_loss_real + discriminator_loss_fake
        discriminator_loss.backward()
        discriminator_optimizer.step()

        # Train the generator.

        inputs_fake = generate_fake(batch_size).to(training_device)

        # The generator is rewarded when the discriminator labels fakes as real.
        generator_optimizer.zero_grad()
        generator_loss = criterion(discriminator(inputs_fake), labels_real)
        generator_loss.backward()
        generator_optimizer.step()

        # Record losses.

        discriminator_losses.append(discriminator_loss.item())
        generator_losses.append(generator_loss.item())

        if (i % 100) == 0:
            # These means run over every batch seen so far, not just the current epoch.
            mean_discriminator_loss = sum(discriminator_losses) / len(discriminator_losses)
            mean_generator_loss = sum(generator_losses) / len(generator_losses)

            print(f'Epoch {epoch}/{epochs}, Batch {i}/{batches}, generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}')

    # Save results.

    print('Saving results...')

    # Write the last generated batch of the epoch as a grid of 1x28x28 images.
    images_to_save = inputs_fake.detach().cpu().view(inputs_fake.size(0), 1, 28, 28)
    save_image(unnormalize_image(images_to_save), os.path.join(samples_directory, f'epoch-{epoch}.png'))

Epoch 0/300, Batch 0/937, generator loss: 0.7021292448043823, discriminator loss: 1.4290435314178467
Epoch 0/300, Batch 100/937, generator loss: 1.5424737511294904, discriminator loss: 0.7002204756925602
Epoch 0/300, Batch 200/937, generator loss: 2.4499487085128897, discriminator loss: 0.4083510832166049
Epoch 0/300, Batch 300/937, generator loss: 2.780550401670196, discriminator loss: 0.3540844147548426
Epoch 0/300, Batch 400/937, generator loss: 3.31599023588875, discriminator loss: 0.33587620733886436
Epoch 0/300, Batch 500/937, generator loss: 3.5412213979604954, discriminator loss: 0.34543886923504447
Epoch 0/300, Batch 600/937, generator loss: 3.851242894638398, discriminator loss: 0.3241525048734047
Epoch 0/300, Batch 700/937, generator loss: 3.9208222960099346, discriminator loss: 0.3076201768355516
Epoch 0/300, Batch 800/937, generator loss: 3.948273557923111, discriminator loss: 0.3095491712868735
Epoch 0/300, Batch 900/937, generator loss: 3.9190756919381355, discriminator 

In [26]:
# Persist each trained network under its own file name.
torch.save(discriminator, 'vanilla-gan-mnist-discriminator.pt')
torch.save(generator, 'vanilla-gan-mnist-generator.pt')