In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

Define discriminator class

In [18]:
class Discriminator(nn.Module):
    """Small fully connected discriminator.

    Layer stack: Linear(in_features, 128) -> LeakyReLU(0.1) ->
    Linear(128, 1) -> Sigmoid. Each input row is mapped to a single
    sigmoid-activated value.

    Args:
        in_features: Size of the last dimension of the input batch.
    """

    def __init__(self, in_features):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(in_features, hidden_units),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),
        ]
        # Stored under ``disc`` so parameter names (disc.0.weight, ...) stay the same.
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        scores = self.disc(x)
        return scores

Define generator class

In [19]:
class Generator(nn.Module):
    """Small fully connected generator.

    Layer stack: Linear(z_dim, 256) -> LeakyReLU(0.1) ->
    Linear(256, img_dim) -> Tanh, so outputs lie in [-1, 1].

    Args:
        z_dim: Size of the last dimension of the input noise batch.
        img_dim: Number of values produced per input row.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        hidden_units = 256
        layers = [
            nn.Linear(z_dim, hidden_units),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_units, img_dim),
            nn.Tanh(),
        ]
        # Stored under ``gen`` so parameter names (gen.0.weight, ...) stay the same.
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        generated = self.gen(x)
        return generated

Set hyperparameters

In [20]:
# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

lr = 0.0003              # learning rate
z_dim = 64               # length of each noise vector
image_dim = 1 * 28 * 28  # channels * height * width = 784 values per image
batch_size = 32
num_epochs = 50

Create discriminator and generator instances

In [21]:
# Build the discriminator first, then the generator (same construction order
# as before), and move each onto the selected device.
disc = Discriminator(in_features=image_dim)
disc = disc.to(device)
gen = Generator(z_dim=z_dim, img_dim=image_dim)
gen = gen.to(device)

Preparing dataset

In [22]:
# Fixed noise batch: fed to the generator at every logging step so the
# generated samples can be compared across epochs.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

# Bound to its own name instead of `transforms`: rebinding the imported
# `torchvision.transforms` module to a Compose object makes this cell raise
# AttributeError the second time it is run.
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # Shift/scale with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
        transforms.Normalize((0.5,), (0.5,)),
    ]
)
dataset = datasets.MNIST(root="dataset/", transform=mnist_transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Set optimizers for the discriminator and generator

In [23]:
# One Adam optimizer per network, both using the shared learning rate `lr`.
opt_disc, opt_gen = (
    optim.Adam(network.parameters(), lr=lr) for network in (disc, gen)
)

Loss function

In [24]:
# Binary cross-entropy between predicted probabilities and 0/1 targets.
criterion = nn.BCELoss()

# TensorBoard writers. Both now log under runs/GAN_MNIST (the real-image
# writer previously pointed at the misspelled "GAN_MINST"), so the fake and
# real image grids are grouped under one run directory.
writer_fake = SummaryWriter("runs/GAN_MNIST/fake")
writer_real = SummaryWriter("runs/GAN_MNIST/real")

# Global step passed to add_image; incremented each time images are logged.
step = 0

Training loop

In [25]:
# Alternating GAN training: for every batch, update the discriminator once,
# then update the generator once. Image grids are logged to TensorBoard on the
# first batch of each epoch.
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each image to a vector of image_dim values.
        real = real.view(-1, image_dim).to(device)
        # Separate name so the global `batch_size` hyperparameter is not overwritten.
        cur_batch_size = real.shape[0]

        ## Train Discriminator: max log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): the discriminator update needs no gradients w.r.t. the
        # generator, and leaving the generator graph untouched here means
        # backward() does not need retain_graph=True.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ## Train Generator: min log(1 - D(G(z))) <--> leads to saturating gradients <--> max log(D(G(z)))
        # Fresh forward pass through the just-updated discriminator, this time
        # keeping the graph back to the generator.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # To visualize outputs on Tensorboard
        if batch_idx == 0:
            print(
                f"Epoch: [{epoch}/{num_epochs}]"
                f"Loss D: {lossD: .4f}, Loss G: {lossG: .4f}"
            )
            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                # Use the reshaped (N, 1, 28, 28) batch; passing the flat
                # (N, 784) tensor makes make_grid treat it as a single image.
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "MNIST Fake Images", img_grid_fake, global_step=step
                )

                writer_real.add_image(
                    "MNIST Real Images", img_grid_real, global_step=step
                )

                step += 1

Epoch: [0/50]Loss D:  0.7146, Loss G:  0.6605
Epoch: [1/50]Loss D:  0.6426, Loss G:  0.9203
Epoch: [2/50]Loss D:  0.5099, Loss G:  1.0493
Epoch: [3/50]Loss D:  0.6296, Loss G:  0.7889
Epoch: [4/50]Loss D:  0.4742, Loss G:  1.0969
Epoch: [5/50]Loss D:  0.4355, Loss G:  1.1889
Epoch: [6/50]Loss D:  0.7352, Loss G:  0.8020
Epoch: [7/50]Loss D:  0.6037, Loss G:  0.9161
Epoch: [8/50]Loss D:  0.4775, Loss G:  1.2056
Epoch: [9/50]Loss D:  0.4674, Loss G:  1.3637
Epoch: [10/50]Loss D:  0.5985, Loss G:  1.1166
Epoch: [11/50]Loss D:  0.3602, Loss G:  1.4793
Epoch: [12/50]Loss D:  0.6442, Loss G:  0.9487
Epoch: [13/50]Loss D:  0.8125, Loss G:  0.8827
Epoch: [14/50]Loss D:  0.5606, Loss G:  1.2232
Epoch: [15/50]Loss D:  0.6147, Loss G:  1.6517
Epoch: [16/50]Loss D:  0.6455, Loss G:  1.5277
Epoch: [17/50]Loss D:  0.6503, Loss G:  0.9969
Epoch: [18/50]Loss D:  0.5759, Loss G:  1.0751
Epoch: [19/50]Loss D:  0.5927, Loss G:  1.1048
Epoch: [20/50]Loss D:  0.7123, Loss G:  0.9129
Epoch: [21/50]Loss D:  

Things to try to get better results:
1. Use a larger network
2. Better normalization with BatchNorm
3. Different learning rate
4. Change architecture to a CNN