<p style="font-size: 18px;">
  This is the accompanying code for the post titled "Understanding GAN Training Strategies, Ethical Implications, and Building Your First GAN with PyTorch"<br>
  You can find it <a href="https://pureai.substack.com/p/implement-a-gan-with-pytorch">here</a>.<br>
  Published: January 13, 2024<br>
  <a href="https://pureai.substack.com">https://pureai.substack.com</a>
</p>

Welcome to this Jupyter notebook! If you're new to Python or don't have it installed on your system, don't worry; you can still follow along and explore the code.

Here's a quick guide to getting started:

- Using an Online Platform: You can run this notebook in a web browser using platforms like Google Colab or Binder. These services offer free access to Jupyter notebooks and don't require any installation.
- Installing Python Locally: If you'd prefer to run this notebook on your own machine, you'll need to install Python. A popular distribution for scientific computing is Anaconda, which includes Python, Jupyter, and other useful tools.
  - Download Anaconda from [here](https://www.anaconda.com/download).
  - Follow the installation instructions for your operating system.
  - Launch Jupyter Notebook from Anaconda Navigator, or by typing `jupyter notebook` in your command line or terminal.
- Opening the Notebook: Once you have Jupyter running, navigate to the location of this notebook file (.ipynb) and click on it to open.
- Running the Code: You can run each cell in the notebook by selecting it and pressing Shift + Enter. Feel free to modify the code and experiment with it.
- Need More Help?: If you're new to Python or Jupyter notebooks, you might find these resources helpful:
  - [Python.org's Beginner's Guide](https://docs.python.org/3/tutorial/index.html)
  - [Jupyter Notebook Basics](https://jupyter-notebook.readthedocs.io/en/stable/examples/Notebook/Notebook%20Basics.html)

_Note: this notebook and its code use Poetry, a dependency and virtual environment manager. If you don't have Poetry, please install it via the Python package manager pip. Then change directories to this code and run `poetry install --no-root`, which will install all of the required dependencies for you. Finally, select the Poetry virtual environment as your Python kernel._

Happy coding, and enjoy exploring the fascinating world of GANs with PyTorch!

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import sys

import torch
import torchvision

# Summarise the runtime environment: library versions and whether CUDA is usable.
env_report = (
    f'PyTorch version= {torch.__version__}',
    f'torchvision version= {torchvision.__version__}',
    f'CUDA available= {torch.cuda.is_available()}',
)
for report_line in env_report:
    print(report_line)

In [None]:
# Print CUDA / cuDNN diagnostics. The whole cell is a no-op when PyTorch
# reports that CUDA is not available.
if torch.cuda.is_available():
    # CUDA toolkit version, obtained by shelling out to `nvcc` through the
    # IPython `!` escape (so this cell only runs inside IPython/Jupyter).
    # NOTE(review): presumably `nvcc` must be on PATH for a version to be shown — confirm.
    print('CUDA Version')
    !nvcc --version
    print()

    # cuDNN version and CUDA device details as reported by PyTorch.
    # NOTE(review): the device count and the current device index are each
    # printed twice below, and the device name is always queried for index 0
    # rather than for the current device.
    print(f'CUDNN Version: {torch.backends.cudnn.version()}')
    print(f'Number of CUDA Devices: {torch.cuda.device_count()}')
    print(f'Active CUDA Device: {torch.cuda.current_device()}')
    print(f'Available devices: {torch.cuda.device_count()}, Name: {torch.cuda.get_device_name(0)}')
    print(f'Current CUDA device: {torch.cuda.current_device()}')

In [None]:
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as vtransforms
import torchvision.utils as vutils

import warnings

import torch.backends.cudnn as cudnn

# Turn on cuDNN benchmark mode (might help while the network instance stays the same).
cudnn.benchmark = True

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Silence every Python warning from here on (hides some noisy PyTorch warnings).
warnings.filterwarnings('ignore')

Define the dataloader.

In [None]:
def get_dataloader(_img_size, _bs, _ds, _path):
    """Build shuffled training and testing DataLoaders for one dataset class.

    Args:
        _img_size: target size handed to ``Resize``.
        _bs: batch size used by both loaders.
        _ds: dataset class; it is called as
            ``_ds(root=..., download=True, train=..., transform=...)``.
        _path: root directory the dataset is read from / downloaded to.

    Returns:
        A ``(train_dl, test_dl)`` tuple of ``torch.utils.data.DataLoader``
        objects, both shuffled and both using 4 worker processes.
    """
    def build_loader(is_train):
        # Resize -> tensor -> normalise with mean 0.5 and std 0.5
        pipeline = vtransforms.Compose([
            vtransforms.Resize(_img_size),
            vtransforms.ToTensor(),
            vtransforms.Normalize((0.5,), (0.5,)),
        ])
        dataset = _ds(root=_path, download=True, train=is_train, transform=pipeline)
        # pin_memory is deliberately left at its default here
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=_bs,
            shuffle=True,
            num_workers=4,
        )

    # Training split first, then the test split
    train_dl = build_loader(True)
    test_dl = build_loader(False)
    return train_dl, test_dl

def get_dl_mnist(_img_size, _bs):
    """Return the ``(train, test)`` DataLoaders for MNIST, stored under ``./MNIST``."""
    loaders = get_dataloader(_img_size=_img_size, _bs=_bs, _ds=dset.MNIST, _path='./MNIST')
    return loaders

def get_dl_fashionmnist(_img_size, _bs):
    """Return the ``(train, test)`` DataLoaders for FashionMNIST, stored under ``./fashion-mnist``."""
    loaders = get_dataloader(
        _img_size=_img_size,
        _bs=_bs,
        _ds=dset.FashionMNIST,
        _path='./fashion-mnist',
    )
    return loaders

In [None]:
def init_weights(_m):
    """Initialise the parameters of a single module in place.

    Meant to be passed to ``module.apply(init_weights)``.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weights drawn from N(0, 0.02).
    * ``nn.BatchNorm2d``: weights drawn from N(1, 0.02) and bias set to 0.
    * Any other module type is left untouched.

    Args:
        _m: the module to initialise.
    """
    if isinstance(_m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(_m.weight, mean=0.0, std=0.02)
    elif isinstance(_m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight and bias set to None, so there
        # is nothing to initialise in that case.
        if _m.weight is not None:
            nn.init.normal_(_m.weight, mean=1.0, std=0.02)
        if _m.bias is not None:
            nn.init.zeros_(_m.bias)

### GAN Definition

In [None]:
# Number of image channels and base width of the discriminator's feature maps
IMG_CHANNEL = 1
D_HIDDEN = 64


class Discriminator(nn.Module):
    """Convolutional network that outputs one sigmoid score per input image.

    Four stride-2 convolutions (kernel 4, padding 1) widen the channels from
    ``D_HIDDEN`` to ``8 * D_HIDDEN``; a final kernel-4, stride-1 convolution
    maps to a single channel followed by a sigmoid. For 64x64 inputs this
    yields exactly one value per image.
    """

    def __init__(self):
        super().__init__()

        def down(in_ch, out_ch):
            # Stride-2 convolution without bias
            return nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)

        # First stage has no batch norm
        layers = [down(IMG_CHANNEL, D_HIDDEN), nn.LeakyReLU(0.2, inplace=True)]

        # Three conv + batch-norm + LeakyReLU stages, doubling the width each time
        for factor in (1, 2, 4):
            width_in = D_HIDDEN * factor
            width_out = D_HIDDEN * factor * 2
            layers += [
                down(width_in, width_out),
                nn.BatchNorm2d(width_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        # Collapse to a single-channel score
        layers += [
            nn.Conv2d(D_HIDDEN * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, _input):
        """Return the network output flattened to a 1-D tensor."""
        scores = self.net(_input)
        return scores.view(-1)


# Check the network layers
print(Discriminator())

In [None]:
# Length of the latent noise vector and base width of the generator's feature maps
Z_DIM = 100
G_HIDDEN = 64


class Generator(nn.Module):
    """Transposed-convolution network that turns noise into images.

    A ``(N, Z_DIM, 1, 1)`` input is upsampled through five ``ConvTranspose2d``
    layers (1 -> 4 -> 8 -> 16 -> 32 -> 64 pixels per side) while the channel
    count shrinks from ``8 * G_HIDDEN`` down to ``IMG_CHANNEL``; the output
    passes through ``Tanh``.
    """

    def __init__(self):
        super().__init__()

        def up(in_ch, out_ch, stride, padding):
            # Kernel-4 transposed convolution without bias
            return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4,
                                      stride=stride, padding=padding, bias=False)

        # Project the noise vector to a 4x4 feature map
        layers = [
            up(Z_DIM, G_HIDDEN * 8, 1, 0),
            nn.BatchNorm2d(G_HIDDEN * 8),
            nn.ReLU(inplace=True),
        ]

        # Three upsampling stages, halving the width each time
        for factor in (4, 2, 1):
            width_in = G_HIDDEN * factor * 2
            width_out = G_HIDDEN * factor
            layers += [
                up(width_in, width_out, 2, 1),
                nn.BatchNorm2d(width_out),
                nn.ReLU(inplace=True),
            ]

        # Final upsampling to the image channels
        layers += [up(G_HIDDEN, IMG_CHANNEL, 2, 1), nn.Tanh()]
        self.net = nn.Sequential(*layers)

    def forward(self, _input):
        """Return the generated images for the given noise batch."""
        generated = self.net(_input)
        return generated


# Check the network layers
print(Generator())

Define a helper function to make an output directory, as needed.

In [None]:
import os

def mkdir(dir):
    """Create directory ``dir`` (and any missing parents) if it does not exist.

    Best-effort: calling it on an existing directory is a no-op, and an
    ``OSError`` raised while creating the directory is reported on stderr
    instead of being propagated. Other exceptions are not caught.

    Args:
        dir: path of the directory to create. The name shadows the builtin
            but is kept so that existing keyword callers keep working.
    """
    try:
        # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir
        os.makedirs(dir, exist_ok=True)
    except OSError as err:
        # Printed rather than issued as a warning, because this notebook
        # globally silences warnings with warnings.filterwarnings('ignore').
        print(f'mkdir: could not create {dir!r}: {err}', file=sys.stderr)


mkdir('gan_output')

Now we define the training loop.

In [None]:
%%time

# Train the GAN on FashionMNIST: for every batch, take one discriminator step
# followed by one generator step. Sample grids and checkpoints are written
# under ./gan_output, and per-epoch mean losses are collected in
# gan_d_avg_loss / gan_g_avg_loss.

# Image size passed to the data loader, batch size, and the learning rate
# shared by both Adam optimizers.
X_DIM= 64
BATCH_SIZE= 1000
ETA= 1e-3

# NOTE(review): test_dl is not used anywhere in this cell.
train_dl, test_dl = get_dl_fashionmnist(X_DIM, BATCH_SIZE)

# Create the Discriminator, move it to `device` and initialise its weights
netD = Discriminator().to(device)
netD.apply(init_weights)

# Create the Generator, move it to `device` and initialise its weights
netG = Generator().to(device)
netG.apply(init_weights)

# Binary cross-entropy loss and one Adam optimizer per network
criterion = nn.BCELoss()
optimizerD = optim.Adam(netD.parameters(), lr=ETA, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=ETA, betas=(0.5, 0.999))

# info=True enables the per-batch progress line, the per-epoch sample grid
# and the per-epoch checkpoints below.
info= True
EPOCHS= 30

# Learning the real and fake - reminder this is not a classification problem
REAL_LABEL= 1
FAKE_LABEL= 0

# Mean discriminator / generator loss of each epoch
gan_d_avg_loss = []
gan_g_avg_loss = []

# Fixed noise, reused every epoch so the saved sample grids show the same
# latent points over time.
# NOTE(review): BATCH_SIZE noise vectors are drawn and pushed through netG,
# but only the first 32 generated images are saved below.
viz_noise = torch.randn(BATCH_SIZE, Z_DIM, 1, 1, device=device)

for e in range(EPOCHS):
    # Running sums for this epoch's average losses
    d_loss_accum = 0
    g_loss_accum = 0
    batch_count = 0

    for i, data in enumerate(train_dl):
        # data[0] holds the images; whatever else the loader yields is ignored.
        x_real = data[0].to(device)
        # Target vectors sized to the actual batch
        real_label = torch.full((x_real.size(0),), REAL_LABEL, dtype=torch.float32, device=device)
        fake_label = torch.full((x_real.size(0),), FAKE_LABEL, dtype=torch.float32, device=device)

        # Update D with real data
        netD.zero_grad()
        y_real = netD(x_real)
        loss_D_real = criterion(y_real, real_label)
        loss_D_real.backward()

        # Update D with fake data
        z_noise = torch.randn(x_real.size(0), Z_DIM, 1, 1, device=device)
        x_fake = netG(z_noise)
        # detach(): this backward pass must not propagate into the generator
        y_fake = netD(x_fake.detach())
        loss_D_fake = criterion(y_fake, fake_label)
        loss_D_fake.backward()
        # Single optimizer step on the gradients accumulated from both
        # backward calls above (real + fake)
        optimizerD.step()

        # Combined discriminator loss, used for bookkeeping only
        loss_D = loss_D_real + loss_D_fake
        d_loss_accum += loss_D.item()

        # Update G with fake data
        netG.zero_grad()
        # Score the same fakes again with the just-updated discriminator, this
        # time without detach so gradients reach the generator; the target is
        # the REAL label.
        # NOTE: this backward also leaves gradients on netD; they are cleared
        # by netD.zero_grad() at the start of the next iteration.
        y_fake_r = netD(x_fake)
        loss_G = criterion(y_fake_r, real_label)
        g_loss_accum += loss_G.item()
        batch_count += 1
        loss_G.backward()
        optimizerG.step()

        if info:
            # "\r" returns to the start of the line so each update overwrites
            # the previous progress line
            sys.stderr.write("\r{:03d}/{:3d} | LossDr: {:6.2f} | lossDf: {:6.2f} | lossG: {:6.2f}".format(
                e+1, EPOCHS, loss_D_real.mean().item(), loss_D_fake.mean().item(), loss_G.mean().item()))
            sys.stderr.flush()

            # On the first batch of every epoch, save a grid built from the
            # first 32 images generated from the fixed noise (nrow=4).
            # NOTE(review): netG is not switched to eval() here — presumably
            # BatchNorm still uses batch statistics for these samples; confirm
            # this is intended.
            if i == 0:
                with torch.no_grad():
                    viz_sample = netG(viz_noise)
                    vutils.save_image(vutils.make_grid(viz_sample[:32], nrow=4),
                                      f'./gan_output/fake_samples_{e}.png', normalize=True)

    # Checkpoint both networks after every epoch; file names use the
    # zero-based epoch index.
    if info:
        torch.save(netG.state_dict(), f'./gan_output/netG_{e}.pth')
        torch.save(netD.state_dict(), f'./gan_output/netD_{e}.pth')

    # Record this epoch's mean losses.
    # NOTE(review): divides by zero if train_dl yields no batches.
    gan_d_avg_loss.append(d_loss_accum / batch_count)
    gan_g_avg_loss.append(g_loss_accum / batch_count)

Let's load the saved weights and generate some fake images.

In [None]:
# Rebuild a generator and restore the weights saved after the final epoch.
# map_location=device lets a checkpoint that was written from GPU tensors load
# in a session that runs on a different device (e.g. CPU only).
netG2 = Generator()
netG2.load_state_dict(
    torch.load(f'./gan_output/netG_{EPOCHS-1:d}.pth', map_location=device))
netG2.to(device)

# Generate five images, each from its own fresh noise vector, and show them
# side by side.
plt.figure(1, figsize=(10, 5), dpi=72)
for i in range(5):
    plt.subplot(1, 5, i+1)
    # Inference only, so no gradients are needed
    with torch.no_grad():
        x_fake = netG2(torch.randn(1, Z_DIM, 1, 1, device=device))
    plt.axis('off')
    # Move to the CPU and drop the batch/channel axes to get a 64x64 array
    plt.imshow(x_fake.to('cpu').numpy().reshape(64,64), cmap=plt.cm.gray_r, interpolation='nearest')
plt.show()

After completing the training of the GAN, we can illustrate the evolution of the network's learning process, showcasing how it progressively refined its weights to accurately generate computer-simulated images of clothing.

In [None]:
def draw_strip(_row, _col, _offset, _path):
    """Plot one generated sample at ten training snapshots as a row of subplots.

    For each snapshot epoch, the sample grid ``./<_path>/fake_samples_<epoch>.png``
    is read and the 64x64 tile at grid position (``_row``, ``_col``) is cropped
    out and drawn into row ``_offset`` of a 4x10 subplot layout on the current
    matplotlib figure.

    Args:
        _row: zero-based tile row inside the saved grid image.
        _col: zero-based tile column inside the saved grid image.
        _offset: zero-based subplot row (0-3) to draw the strip into.
        _path: directory, relative to the working directory, holding the
            ``fake_samples_<epoch>.png`` files.
    """
    # Snapshot epochs (zero-based). Each index is clamped to the last trained
    # epoch so that a run with fewer than 30 epochs does not request sample
    # files that were never written; with EPOCHS >= 30 the list is unchanged.
    last_epoch = EPOCHS - 1
    snapshot_epochs = [min(epoch, last_epoch)
                       for epoch in (0, 1, 2, 5, 10, 15, 20, 25, 29, last_epoch)]

    # Top-left pixel of the tile, excluding the grid lines: every tile is
    # 64 px wide and preceded by a 2 px border.
    top = 2 * (_row + 1) + 64 * _row
    left = 2 * (_col + 1) + 64 * _col

    for i, e in enumerate(snapshot_epochs):
        img = plt.imread(f'./{_path}/fake_samples_{e}.png')
        plt.subplot(4, 10, 10 * _offset + i + 1)
        plt.axis('off')
        plt.imshow(img[top:top + 64, left:left + 64], cmap=plt.cm.gray_r, interpolation='nearest')

# One strip per subplot row, each following a different tile of the sample grid:
# (tile row, tile column) for subplot rows 0..3.
plt.figure(1, figsize=(20, 10), dpi=72)
for strip_row, (tile_row, tile_col) in enumerate(((0, 0), (1, 0), (2, 2), (3, 0))):
    draw_strip(tile_row, tile_col, strip_row, 'gan_output')