# Generative Adversarial Network (GAN) Tutorial in PyTorch
---
### By David Reiman (dreiman@ucsc.edu)

<img src="GAN.png">
<div style="text-align: right">(Image credit: <a href="https://medium.freecodecamp.org/an-intuitive-introduction-to-generative-adversarial-networks-gans-7a2264a81394">Thalles Silva</a>)</div>

---

Let's make some imports first.

In [1]:
%matplotlib inline

import os
import math
import warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from torch.utils import data
from torch.autograd import Variable
from torchvision import datasets
from torchvision.utils import save_image
from tqdm import tqdm_notebook as tqdm
from IPython import display

Change the matplotlib style to something a little prettier, and silence warnings to keep the notebook output clean.

In [2]:
# Matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' and 3.8
# removed the old name (it now raises OSError), so prefer the new name when this
# Matplotlib ships it and fall back to the legacy name on older versions.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
# Blanket-suppress warnings to keep notebook output tidy (note: this also hides
# deprecation warnings from PyTorch/torchvision).
warnings.filterwarnings('ignore')

Check for a GPU.

In [3]:
# True when PyTorch can see a CUDA-capable GPU; drives device placement below.
cuda = torch.cuda.is_available()

Specify our generator network architecture.

In [5]:
class Generator(nn.Module):
    """Map latent vectors to 3-channel 64x64 images with values in [-1, 1].

    A linear layer projects each latent vector to a 32x8x8 feature map. Three
    (upsample x2 -> 3x3 conv -> LeakyReLU) stages then grow it to 64x64 while
    narrowing channels 32 -> 32 -> 16 -> 8. A final 3x3 conv followed by
    ``tanh`` produces the 3-channel image.

    Args:
        latent_dim: Length of the latent input vector ``z`` (default 100).
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        self.latent_dim = latent_dim

        self.linear = nn.Linear(latent_dim, 8*8*32)
        # Every BatchNorm2d gets only ``num_features``. A second positional
        # argument would set ``eps`` (not ``momentum``), and eps=0.8 is far
        # larger than the default 1e-5, so it can dominate the batch variance.
        self.conv = nn.Sequential(
            nn.BatchNorm2d(32),
            nn.Upsample(scale_factor=2),  # 8x8 -> 16x16
            nn.Conv2d(
                in_channels=32,
                out_channels=32,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(32),
            nn.Upsample(scale_factor=2),  # 16x16 -> 32x32
            nn.Conv2d(
                in_channels=32,
                out_channels=16,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(16),
            nn.Upsample(scale_factor=2),  # 32x32 -> 64x64
            nn.Conv2d(
                in_channels=16,
                out_channels=8,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(8),
            nn.Conv2d(
                in_channels=8,
                out_channels=3,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.Tanh()  # squash pixel values into [-1, 1]
        )

    def forward(self, z):
        """Generate a batch of images.

        Args:
            z: Latent batch of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, 3, 64, 64) with values in [-1, 1].
        """
        z = self.linear(z)
        z = z.view(z.shape[0], 32, 8, 8)
        z = self.conv(z)
        return z

Now our discriminator network architecture.

In [7]:
class Discriminator(nn.Module):
    """Score 3-channel 64x64 images with the probability that each is real.

    Four 3x3 stride-2 convolutions, each followed by LeakyReLU(0.2), halve the
    spatial size 64 -> 32 -> 16 -> 8 -> 4. Along the way they widen channels
    3 -> 32 -> 64 -> 128 -> 256. A linear layer plus sigmoid then maps the
    flattened 4x4x256 features to one value in (0, 1) per image.
    """

    def __init__(self):
        super().__init__()
        widths = (3, 32, 64, 128, 256)
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages.append(
                nn.Conv2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                )
            )
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        self.features = nn.Sequential(*stages)
        # 4*4*256 assumes 64x64 inputs (four stride-2 halvings -> 4x4).
        self.classifier = nn.Sequential(nn.Linear(4*4*256, 1), nn.Sigmoid())

    def forward(self, x):
        """Return a (batch, 1) tensor of real-vs-fake probabilities for ``x``."""
        feats = self.features(x)
        return self.classifier(feats.view(feats.size(0), -1))

Let's define our file paths and a few hyperparameters.

In [8]:
# Locations: training images are read from data_path; generated sample grids
# are written to save_path.
data_path = './data'
save_path = './images'

# Training settings.
epochs = 200            # full passes over the dataset
batch_size = 64         # images per batch (real and generated)
lr = 2e-4               # Adam learning rate for both networks
sample_interval = 200   # save a sample grid every this many batches

If the save path doesn't exist, we'll create it.

In [9]:
# Create the output directory for generated samples. exist_ok=True makes this a
# no-op when the directory already exists and avoids the check-then-create race.
os.makedirs(save_path, exist_ok=True)

Define the data preprocessing steps for the data loader.

In [10]:
# Resize to the 64x64 resolution both networks are built for and augment with
# random horizontal/vertical flips. Then convert to a tensor and map pixel values
# from [0, 1] to [-1, 1], the same range as the generator's tanh output.
data_transforms = transforms.Compose(
    [
        transforms.Resize([64, 64]),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

Define an image-folder dataset using our data path and transforms.

In [11]:
# Images are loaded from data_path via ImageFolder; the class labels it yields
# are ignored by the training loop.
dataset = datasets.ImageFolder(data_path, transform=data_transforms)

Create a data loader from the dataset.

In [12]:
# Shuffle every epoch and drop the final partial batch so every batch has
# exactly batch_size images (matching the fixed-size real/fake targets). Pin host
# memory only when a GPU is in use.
data_loader = data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,
    pin_memory=cuda,
)

We'll make our generator/discriminator instances.

In [13]:
# Instantiate both networks (generator first, so parameter-init RNG order is unchanged).
generator, discriminator = Generator(), Discriminator()

And define our loss function: binary cross entropy.

In [14]:
# Binary cross-entropy between the discriminator's sigmoid outputs and the
# 0/1 real/fake targets, averaged over the batch.
adversarial_loss = nn.BCELoss()

If PyTorch found a GPU, we'll store our model parameters on the GPU.

In [15]:
# Move both networks (and the loss module) to the GPU when one is available.
if cuda:
    for module in (generator, discriminator, adversarial_loss):
        module.cuda()

Now we'll define our optimizers.

In [16]:
# One Adam optimizer per network. beta1 = 0.5 (instead of Adam's default 0.9)
# is a setting commonly used for GAN training.
opt_g = torch.optim.Adam(params=generator.parameters(), lr=lr, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(params=discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

Here we define `Tensor` as either the CUDA or the CPU float-tensor type, depending on whether PyTorch found a GPU.

In [17]:
# Legacy typed-tensor class used to allocate/cast float32 tensors on the
# selected device: GPU when available, otherwise CPU.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

These are the ground truth "labels" for the discriminator.

In [18]:
# Fixed discriminator targets: 1 for real images, 0 for generated ones.
# Shape (batch_size, 1) matches the discriminator's output; this relies on the
# data loader dropping its last partial batch so every batch is full-sized.
# Plain tensors need no ``Variable`` wrapper (deprecated since PyTorch 0.4) and
# do not require gradients by default.
ones = Tensor(batch_size, 1).fill_(1.0)
zeros = Tensor(batch_size, 1).fill_(0.0)

Now we set up the training loop and get to learning.

In [19]:
for epoch in tqdm(range(epochs)):
    for i, (images, _) in tqdm(enumerate(data_loader), total=len(data_loader), leave=False):

        # Create a latent input for the generator: a draw from a standard
        # (spherical) Gaussian, sized to match the generator's input layer.
        z = torch.randn(batch_size, generator.linear.in_features).type(Tensor)

        # Make a batch of fake images from the latent input
        fake_images = generator(z)

        # Get a batch of real images, cast to the device-appropriate float type
        # (plain tensors; the deprecated Variable wrapper is unnecessary)
        real_images = images.type(Tensor)

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Clear gradients from previous update steps
        opt_d.zero_grad()

        # Measure the discriminator's ability to tell real from generated
        # samples; detach() keeps this loss from back-propagating into G.
        real_loss = adversarial_loss(discriminator(real_images), ones)
        fake_loss = adversarial_loss(discriminator(fake_images.detach()), zeros)
        d_loss = (real_loss + fake_loss) / 2.

        # Backpropagate and make an update step on D
        d_loss.backward()
        opt_d.step()

        # -----------------
        #  Train Generator
        # -----------------

        # Clear gradients from previous update steps
        opt_g.zero_grad()

        # Non-saturating generator loss: push D's output on fake images towards 1
        g_loss = adversarial_loss(discriminator(fake_images), ones)

        # Backpropagate and make an update step on G. This also accumulates
        # gradients in D's parameters; they are cleared by opt_d.zero_grad()
        # before the next discriminator update.
        g_loss.backward()
        opt_g.step()

        # Save a 5x5 grid of generated samples every sample_interval batches
        batches_done = epoch * len(data_loader) + i
        if batches_done % sample_interval == 0:
            filename = os.path.join(save_path, '%d.png' % batches_done)
            save_image(fake_images.detach()[:25], filename, nrow=5, normalize=True)

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))

HBox(children=(IntProgress(value=0, max=156), HTML(value='')))




All done training. Let's check the last sample generated.

<img src="./images/31000.png">

### References 
The majority of this code was adapted from Erik Linder-Norén's excellent repository of <a href="https://github.com/eriklindernoren/PyTorch-GAN">GAN implementations in PyTorch</a>.