# Informations

The number of input channels for the images of each dataset :

- LSUN: 3
- CIFAR10: 3
- MNIST: 1

# Setting environment

In [1]:
import platform

print(f"Python version: {platform.python_version()}")
assert platform.python_version_tuple() >= ("3", "6")

import numpy as np
print(f"Numpy version: {np.__version__}")

import matplotlib
print(f"Matplotlib version: {matplotlib.__version__}")
import matplotlib.pyplot as plt
import seaborn as sns

import torch
print(f"PyTorch version: {torch.__version__}")
import torch.nn as nn
from torchvision.utils import save_image
from torchvision.datasets import MNIST
from torchvision.transforms.transforms import Compose, Normalize, ToTensor, Resize
from torch.utils.data import DataLoader

# from torch.utils.tensorboard import SummaryWriter

Python version: 3.8.9
Numpy version: 1.21.5
Matplotlib version: 3.5.1
PyTorch version: 1.10.1


In [2]:

%matplotlib inline
plt.rcParams["figure.figsize"] = 10, 8
%config InlineBackend.figure_format = 'retina'
sns.set()
%load_ext tensorboard

# Loading the MNIST dataset

In [18]:
!rm -rf "./data"
!mkdir output

In [19]:
workers = 0

In [20]:
def get_data_loader():
    mnist_data = MNIST(
        root='./data',
        download=True,
        transform=Compose([
            Resize(64),
            ToTensor(),
            Normalize((0.5,), (0.5,))])
    )

    return DataLoader(
        mnist_data,
        batch_size=64,
        shuffle=True,
        num_workers=workers,
    )

In [21]:
data_loader = get_data_loader()

print("Dataset size : ", len(data_loader))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Dataset size :  938


# Building and training the models

In [22]:
# custom weights initialization called on generator and discriminator
def initialize_weights(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

In [23]:
nb_gpu = 1
ngf = 64
ndf = 64
z = 100
n_channels = 1
device = torch.device('cuda:0' if (torch.cuda.is_available() and nb_gpu > 0) else 'cpu')

In [30]:
# noinspection PyTypeChecker

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.nb_gpu = nb_gpu
        self.generator = nn.Sequential(
            nn.ConvTranspose2d(z, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf, n_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        if x.is_cuda and self.nb_gpu > 1:
             return nn.parallel.data_parallel(self.generator, x, range(self.nb_gpu))
        return self.generator(x)

In [31]:
# noinspection PyTypeChecker

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.nb_gpu = nb_gpu
        self.discriminator = nn.Sequential(
            nn.Conv2d(n_channels, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        if x.is_cuda and self.nb_gpu > 1:
            return nn.parallel.data_parallel(self.discriminator, x, range(self.nb_gpu)).view(-1, 1).squeeze(1)
        return self.discriminator(x).view(-1, 1).squeeze(1)

In [32]:
generator = Generator().to(device).apply(initialize_weights)
print(generator)

Generator(
  (generator): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
 

In [33]:
discriminator = Discriminator().to(device).apply(initialize_weights)
print(discriminator)

Discriminator(
  (discriminator): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)


In [34]:
generator_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

batch_size = 32
epochs = 1

criterion = nn.BCELoss()
fixed_noise = torch.randn(batch_size, z, 1, 1, device=device)

real_label = 1
fake_label = 0

In [38]:
for epoch in range(epochs):
    for i, data in enumerate(data_loader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        discriminator.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label,
                           dtype=real_cpu.dtype, device=device)

        output = discriminator(real_cpu)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, z, 1, 1, device=device)
        fake = generator(noise)
        label.fill_(fake_label)
        output = discriminator(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        discriminator_error = errD_real + errD_fake
        discriminator_optimizer.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        generator.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = discriminator(fake)
        generator_error = criterion(output, label)
        generator_error.backward()
        D_G_z2 = output.mean().item()
        generator_optimizer.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, epochs, i, len(data_loader),
                 discriminator_error.item(), generator_error.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            save_image(real_cpu,
                    "output/real_samples.png",
                    normalize=True)
            fake = generator(fixed_noise)
            save_image(fake.detach(),
                    "output/fake_samples_epoch_%03d.png",
                    normalize=True)
    # do checkpointing
    torch.save(generator.state_dict(), f"output/generator_epoch_{epoch}.pth")
    torch.save(discriminator.state_dict(), f"output/discriminator_epoch_{epoch}.pth")

[0/1][0/938] Loss_D: 1.2433 Loss_G: 6.2348 D(x): 0.9872 D(G(z)): 0.6375 / 0.0031
[0/1][1/938] Loss_D: 0.4794 Loss_G: 6.9927 D(x): 0.9632 D(G(z)): 0.2979 / 0.0015


KeyboardInterrupt: 