In [154]:
pip install -r requirements.txt

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [155]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

In [156]:
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.disc(x)

What is the difference between nn.LeakyReLU and nn.ReLU?
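nn.LeakyReLU allows a small, non-zero gradient when the unit is not active (x < 0), while nn.ReLU has a gradient of 0 there. With nn.LeakyReLU(0.1), as used here, negative inputs are scaled by 0.1 instead of being set to 0, which helps avoid units that stop learning.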
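To make the difference concrete, here is a minimal plain-Python sketch of the two activations. It is an illustration of the math only, not PyTorch's implementation; the helper names `relu`, `leaky_relu`, and `slope` are made up for this example, and the slope 0.1 mirrors the value passed to nn.LeakyReLU above.

```python
def relu(x: float) -> float:
    """ReLU: pass positive inputs through, clamp negative inputs to zero."""
    return max(0.0, x)


def leaky_relu(x: float, negative_slope: float = 0.1) -> float:
    """Leaky ReLU: scale negative inputs by a small slope instead of zeroing them."""
    return x if x >= 0 else negative_slope * x


def slope(f, x: float, eps: float = 1e-6) -> float:
    """Central-difference estimate of the derivative of f at x."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


# Compare values and estimated gradients on both sides of zero.
for x in (-2.0, -0.5, 1.5):
    print(
        f"x={x:+.1f}  relu={relu(x):+.2f}  leaky={leaky_relu(x):+.2f}  "
        f"d_relu={slope(relu, x):.2f}  d_leaky={slope(leaky_relu, x):.2f}"
    )
```

For negative inputs the estimated ReLU gradient is 0, while the leaky variant keeps a gradient of about 0.1, so some learning signal still flows through inactive units.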

In [157]:
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)
    
# Why tanh in the generator and sigmoid in the discriminator?
# tanh outputs values in (-1, 1), matching the range of the MNIST images after they are normalized to [-1, 1].
# sigmoid outputs values in (0, 1), so the discriminator output can be read as the probability that the input is real, which is what nn.BCELoss expects.
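
The two output ranges can be checked with a small standard-library sketch. This only illustrates the functions themselves; the `sigmoid` helper and the sample inputs are assumptions made for the example, not part of the notebook.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid, the function applied by nn.Sigmoid."""
    return 1.0 / (1.0 + math.exp(-x))


# A few sample pre-activation values, from strongly negative to strongly positive.
inputs = [-10.0, -1.0, 0.0, 1.0, 10.0]

tanh_out = [math.tanh(x) for x in inputs]    # stays strictly inside (-1, 1)
sigmoid_out = [sigmoid(x) for x in inputs]   # stays strictly inside (0, 1)

for x, t, s in zip(inputs, tanh_out, sigmoid_out):
    print(f"x={x:+5.1f}  tanh={t:+.4f}  sigmoid={s:.4f}")
```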

In [158]:
# Hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 5e-4 # learning rate for both Adam optimizers; in the same range as the commonly cited 3e-4 (Andrej Karpathy's well-known suggestion for Adam)
z_dim = 128 # dimension of the latent noise vector that the generator turns into an image
img_dim = 28 * 28 * 1
batch_size = 16
num_epochs = 25

In [159]:
# Initialize the models
disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)

In [160]:
# Noise
# Why do we need noise?
# The generator turns a noise vector sampled from a standard normal distribution into a fake image.
# fixed_noise is sampled once and reused at every logging step, so the logged samples are comparable across training.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

In [161]:
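# Convert the images to tensors and rescale them to [-1, 1]
# Note: this assignment rebinds the name `transforms` from the torchvision module to the
# pipeline object, so re-running this cell fails unless the import cell is run again.
transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)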

transforms.Compose belongs to the torchvision.transforms module (torchvision is a companion package to PyTorch). It defines a sequence of transformations that is applied to each input image before it reaches the network. Each part does the following:

transforms.Compose: Chains several transformations into a single pipeline. It takes a list of transformations and applies them in order.

transforms.ToTensor(): Converts a PIL image or NumPy array into a PyTorch tensor and, for 8-bit images such as MNIST, rescales pixel values from [0, 255] to [0.0, 1.0].

transforms.Normalize((0.5,), (0.5,)): Subtracts the given mean and divides by the given standard deviation, per channel. Here both are 0.5, so each pixel x becomes (x - 0.5) / 0.5, which maps the [0, 1] range produced by ToTensor to [-1, 1]. This matches the output range of the generator's tanh layer.

In summary, the pipeline converts each image to a tensor and rescales it to [-1, 1].
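
The arithmetic of the normalization step can be sketched in plain Python. This is only a worked example of the formula, not torchvision's code; `normalize` and `denormalize` are hypothetical helper names, and the inverse is included because it is what you would apply to view generated images in the [0, 1] range.

```python
def normalize(pixel: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Same formula as Normalize: (x - mean) / std, applied to one pixel value."""
    return (pixel - mean) / std


def denormalize(value: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Inverse mapping, taking a value in [-1, 1] back to [0, 1]."""
    return value * std + mean


# Black, mid-gray, and white pixels as ToTensor would produce them.
for pixel in (0.0, 0.5, 1.0):
    scaled = normalize(pixel)
    print(f"pixel={pixel:.1f} -> normalized={scaled:+.1f} -> restored={denormalize(scaled):.1f}")
```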

In [162]:
# Download and Load the dataset
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [163]:
# Optimizers and loss function
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

loss_fn = nn.BCELoss()
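
How binary cross-entropy turns into the two GAN objectives can be sketched with scalar arithmetic. This is a simplified, single-sample illustration in plain Python (nn.BCELoss averages over the batch by default); the helper names `bce`, `discriminator_loss`, and `generator_loss` are assumptions for this example.

```python
import math


def bce(prediction: float, target: float) -> float:
    """Binary cross-entropy for one prediction in (0, 1) and a target of 0 or 1."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))


def discriminator_loss(d_real: float, d_fake: float) -> float:
    """Real samples are labeled 1 and fake samples 0; the two terms are averaged."""
    return (bce(d_real, 1.0) + bce(d_fake, 0.0)) / 2


def generator_loss(d_fake: float) -> float:
    """The generator labels its own fakes as 1, i.e. it minimizes -log(D(G(z)))."""
    return bce(d_fake, 1.0)


# An undecided discriminator outputs 0.5 everywhere, giving ln(2) for both losses.
print(f"undecided:  D={discriminator_loss(0.5, 0.5):.4f}  G={generator_loss(0.5):.4f}")
# A confident, correct discriminator lowers its own loss and raises the generator's.
print(f"confident:  D={discriminator_loss(0.9, 0.1):.4f}  G={generator_loss(0.1):.4f}")
```

When the discriminator is undecided, both losses equal ln 2 ≈ 0.6931, which is in the neighborhood of the first values logged during training (Loss D: 0.6468, loss G: 0.6745).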

In [164]:
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")

These lines create two SummaryWriter objects, one for logging generated (fake) images and one for logging real images. SummaryWriter is the utility in torch.utils.tensorboard that writes metrics and visualizations to disk so that TensorBoard can display them.

Each line does the following:

writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake"): Creates a SummaryWriter named writer_fake that stores its logs under runs/GAN_MNIST/fake. During training the generator produces fake images, and logging them over time shows how the generator improves.

writer_real = SummaryWriter(f"runs/GAN_MNIST/real"): Creates a second SummaryWriter named writer_real for real images, stored under runs/GAN_MNIST/real. Real images are the MNIST samples the discriminator is trained to recognize as real. Logging them next to the fake images gives a visual reference for judging how close the generated digits are to the data.

A SummaryWriter can log images, scalar values (e.g., losses), histograms, and more. The logs can then be viewed by running tensorboard --logdir runs and opening the address it prints.

In [165]:
step = 0

for epoch in tqdm(range(num_epochs), desc="Training"):
    for batch_idx, (real, _) in enumerate(loader):
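        real = real.view(-1, img_dim).to(device)  # flatten each 28x28 image to a 784-vector
        cur_batch_size = real.shape[0]  # separate name so the batch_size hyperparameter is not overwritten

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim).to(device)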
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        loss_disc_real = loss_fn(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).view(-1)
        loss_disc_fake = loss_fn(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
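        disc.zero_grad()
        # retain_graph=True keeps the generator's graph alive so that `fake`
        # can be backpropagated through again in the generator step below
        loss_disc.backward(retain_graph=True)
        opt_disc.step()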

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        output = disc(fake).view(-1)
        loss_gen = loss_fn(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image("Mnist Fake Images", img_grid_fake, global_step=step)
                writer_real.add_image("Mnist Real Images", img_grid_real, global_step=step)

                step += 1

# Save the model
torch.save(gen.state_dict(), "gen.pth")
torch.save(disc.state_dict(), "disc.pth")

# Close the tensorboard writer
writer_fake.close()
writer_real.close()

Training:   0%|          | 0/25 [00:00<?, ?it/s]

Epoch [0/25] Batch 0/3750                   Loss D: 0.6468, loss G: 0.6745
Epoch [0/25] Batch 100/3750                   Loss D: 0.2958, loss G: 1.8331
Epoch [0/25] Batch 200/3750                   Loss D: 0.0519, loss G: 3.3469
Epoch [0/25] Batch 300/3750                   Loss D: 0.1241, loss G: 1.9476
Epoch [0/25] Batch 400/3750                   Loss D: 0.4436, loss G: 0.9435
Epoch [0/25] Batch 500/3750                   Loss D: 0.2625, loss G: 2.0015
Epoch [0/25] Batch 600/3750                   Loss D: 0.3151, loss G: 1.7031
Epoch [0/25] Batch 700/3750                   Loss D: 0.1762, loss G: 2.1765
Epoch [0/25] Batch 800/3750                   Loss D: 0.3037, loss G: 1.5143
Epoch [0/25] Batch 900/3750                   Loss D: 0.2701, loss G: 1.5341
Epoch [0/25] Batch 1000/3750                   Loss D: 0.7370, loss G: 0.8800
Epoch [0/25] Batch 1100/3750                   Loss D: 0.5510, loss G: 1.0745
Epoch [0/25] Batch 1200/3750                   Loss D: 0.2983, loss G: 1.806

Training:   4%|▍         | 1/25 [00:10<04:03, 10.13s/it]

Epoch [0/25] Batch 3700/3750                   Loss D: 0.5463, loss G: 1.3157
Epoch [1/25] Batch 0/3750                   Loss D: 0.6691, loss G: 0.8155
Epoch [1/25] Batch 100/3750                   Loss D: 0.4979, loss G: 1.3029
Epoch [1/25] Batch 200/3750                   Loss D: 0.4386, loss G: 1.2609
Epoch [1/25] Batch 300/3750                   Loss D: 0.3659, loss G: 1.6706
Epoch [1/25] Batch 400/3750                   Loss D: 0.2952, loss G: 1.5839
Epoch [1/25] Batch 500/3750                   Loss D: 0.8863, loss G: 0.6742
Epoch [1/25] Batch 600/3750                   Loss D: 0.5706, loss G: 1.3934
Epoch [1/25] Batch 700/3750                   Loss D: 0.4638, loss G: 1.2288
Epoch [1/25] Batch 800/3750                   Loss D: 0.9025, loss G: 1.2371
Epoch [1/25] Batch 900/3750                   Loss D: 0.2491, loss G: 2.0880
Epoch [1/25] Batch 1000/3750                   Loss D: 0.8176, loss G: 1.3991
Epoch [1/25] Batch 1100/3750                   Loss D: 0.4790, loss G: 1.228

Training:   8%|▊         | 2/25 [00:20<03:59, 10.40s/it]

Epoch [1/25] Batch 3700/3750                   Loss D: 0.9619, loss G: 1.5446
Epoch [2/25] Batch 0/3750                   Loss D: 0.6724, loss G: 1.1137
Epoch [2/25] Batch 100/3750                   Loss D: 0.5742, loss G: 1.1210
Epoch [2/25] Batch 200/3750                   Loss D: 0.3529, loss G: 1.7735
Epoch [2/25] Batch 300/3750                   Loss D: 0.3062, loss G: 1.5966
Epoch [2/25] Batch 400/3750                   Loss D: 0.6324, loss G: 1.3508
Epoch [2/25] Batch 500/3750                   Loss D: 0.4993, loss G: 1.0952
Epoch [2/25] Batch 600/3750                   Loss D: 0.9527, loss G: 0.8554
Epoch [2/25] Batch 700/3750                   Loss D: 0.6399, loss G: 1.1161
Epoch [2/25] Batch 800/3750                   Loss D: 1.0542, loss G: 0.7926
Epoch [2/25] Batch 900/3750                   Loss D: 0.4855, loss G: 1.4128
Epoch [2/25] Batch 1000/3750                   Loss D: 0.5370, loss G: 1.0762
Epoch [2/25] Batch 1100/3750                   Loss D: 0.7781, loss G: 0.978

Training:  12%|█▏        | 3/25 [00:30<03:40, 10.04s/it]

Epoch [2/25] Batch 3700/3750                   Loss D: 0.8413, loss G: 0.8695
Epoch [3/25] Batch 0/3750                   Loss D: 0.6480, loss G: 1.2173
Epoch [3/25] Batch 100/3750                   Loss D: 0.6052, loss G: 1.2127
Epoch [3/25] Batch 200/3750                   Loss D: 0.5153, loss G: 1.1759
Epoch [3/25] Batch 300/3750                   Loss D: 0.8061, loss G: 0.8360
Epoch [3/25] Batch 400/3750                   Loss D: 0.5117, loss G: 1.2916
Epoch [3/25] Batch 500/3750                   Loss D: 0.5570, loss G: 1.6051
Epoch [3/25] Batch 600/3750                   Loss D: 0.7126, loss G: 0.7527
Epoch [3/25] Batch 700/3750                   Loss D: 0.5319, loss G: 1.4827
Epoch [3/25] Batch 800/3750                   Loss D: 0.6524, loss G: 1.3179
Epoch [3/25] Batch 900/3750                   Loss D: 0.6546, loss G: 1.4641
Epoch [3/25] Batch 1000/3750                   Loss D: 0.4969, loss G: 1.1597
Epoch [3/25] Batch 1100/3750                   Loss D: 0.4416, loss G: 1.591

Training:  16%|█▌        | 4/25 [00:39<03:27,  9.87s/it]

Epoch [3/25] Batch 3700/3750                   Loss D: 0.7234, loss G: 0.8513
Epoch [4/25] Batch 0/3750                   Loss D: 0.3655, loss G: 1.4970
Epoch [4/25] Batch 100/3750                   Loss D: 0.4928, loss G: 1.8031
Epoch [4/25] Batch 200/3750                   Loss D: 0.5679, loss G: 1.4200
Epoch [4/25] Batch 300/3750                   Loss D: 0.6315, loss G: 0.9452
Epoch [4/25] Batch 400/3750                   Loss D: 0.5368, loss G: 1.6045
Epoch [4/25] Batch 500/3750                   Loss D: 0.6904, loss G: 0.9592
Epoch [4/25] Batch 600/3750                   Loss D: 1.1982, loss G: 2.0825
Epoch [4/25] Batch 700/3750                   Loss D: 0.5181, loss G: 1.2081
Epoch [4/25] Batch 800/3750                   Loss D: 0.7980, loss G: 0.9139
Epoch [4/25] Batch 900/3750                   Loss D: 0.6579, loss G: 1.2857
Epoch [4/25] Batch 1000/3750                   Loss D: 0.6724, loss G: 1.1946
Epoch [4/25] Batch 1100/3750                   Loss D: 0.4045, loss G: 1.842

Training:  20%|██        | 5/25 [00:49<03:15,  9.76s/it]

Epoch [4/25] Batch 3700/3750                   Loss D: 0.9600, loss G: 0.8701
Epoch [5/25] Batch 0/3750                   Loss D: 0.9952, loss G: 0.6802
Epoch [5/25] Batch 100/3750                   Loss D: 0.4752, loss G: 1.6775
Epoch [5/25] Batch 200/3750                   Loss D: 0.6334, loss G: 1.5713
Epoch [5/25] Batch 300/3750                   Loss D: 0.5650, loss G: 1.3555
Epoch [5/25] Batch 400/3750                   Loss D: 0.7662, loss G: 1.0766
Epoch [5/25] Batch 500/3750                   Loss D: 0.6018, loss G: 0.9384
Epoch [5/25] Batch 600/3750                   Loss D: 0.5819, loss G: 1.0105
Epoch [5/25] Batch 700/3750                   Loss D: 0.5338, loss G: 1.4564
Epoch [5/25] Batch 800/3750                   Loss D: 0.8504, loss G: 1.3952
Epoch [5/25] Batch 900/3750                   Loss D: 0.6296, loss G: 1.0240
Epoch [5/25] Batch 1000/3750                   Loss D: 0.6776, loss G: 1.5358
Epoch [5/25] Batch 1100/3750                   Loss D: 0.5715, loss G: 1.160

Training:  24%|██▍       | 6/25 [00:59<03:03,  9.68s/it]

Epoch [5/25] Batch 3700/3750                   Loss D: 0.5408, loss G: 1.2917
Epoch [6/25] Batch 0/3750                   Loss D: 0.4347, loss G: 1.5121
Epoch [6/25] Batch 100/3750                   Loss D: 0.5062, loss G: 1.1855
Epoch [6/25] Batch 200/3750                   Loss D: 0.4220, loss G: 1.4630
Epoch [6/25] Batch 300/3750                   Loss D: 0.9079, loss G: 0.7184
Epoch [6/25] Batch 400/3750                   Loss D: 0.6710, loss G: 0.8537
Epoch [6/25] Batch 500/3750                   Loss D: 0.5077, loss G: 1.0230
Epoch [6/25] Batch 600/3750                   Loss D: 0.7428, loss G: 0.7120
Epoch [6/25] Batch 700/3750                   Loss D: 0.4076, loss G: 1.5134
Epoch [6/25] Batch 800/3750                   Loss D: 0.5379, loss G: 1.0972
Epoch [6/25] Batch 900/3750                   Loss D: 0.5727, loss G: 1.6362
Epoch [6/25] Batch 1000/3750                   Loss D: 0.8180, loss G: 1.7822
Epoch [6/25] Batch 1100/3750                   Loss D: 0.4958, loss G: 1.454

Training:  28%|██▊       | 7/25 [01:08<02:52,  9.57s/it]

Epoch [6/25] Batch 3700/3750                   Loss D: 0.5400, loss G: 1.0036
Epoch [7/25] Batch 0/3750                   Loss D: 0.6756, loss G: 1.5966
Epoch [7/25] Batch 100/3750                   Loss D: 0.6118, loss G: 1.0994
Epoch [7/25] Batch 200/3750                   Loss D: 0.6756, loss G: 1.1520
Epoch [7/25] Batch 300/3750                   Loss D: 0.6401, loss G: 1.1314
Epoch [7/25] Batch 400/3750                   Loss D: 0.6751, loss G: 1.0866
Epoch [7/25] Batch 500/3750                   Loss D: 0.5333, loss G: 0.9458
Epoch [7/25] Batch 600/3750                   Loss D: 0.6194, loss G: 1.4365
Epoch [7/25] Batch 700/3750                   Loss D: 0.6426, loss G: 1.4427
Epoch [7/25] Batch 800/3750                   Loss D: 0.6427, loss G: 1.1533
Epoch [7/25] Batch 900/3750                   Loss D: 0.5828, loss G: 0.9993
Epoch [7/25] Batch 1000/3750                   Loss D: 0.7066, loss G: 0.9012
Epoch [7/25] Batch 1100/3750                   Loss D: 0.6869, loss G: 1.483

Training:  32%|███▏      | 8/25 [01:17<02:42,  9.54s/it]

Epoch [7/25] Batch 3700/3750                   Loss D: 0.7670, loss G: 0.8943
Epoch [8/25] Batch 0/3750                   Loss D: 0.6616, loss G: 1.1052
Epoch [8/25] Batch 100/3750                   Loss D: 0.6719, loss G: 0.9351
Epoch [8/25] Batch 200/3750                   Loss D: 0.7665, loss G: 0.8554
Epoch [8/25] Batch 300/3750                   Loss D: 0.4910, loss G: 1.3197
Epoch [8/25] Batch 400/3750                   Loss D: 0.6261, loss G: 1.2127
Epoch [8/25] Batch 500/3750                   Loss D: 0.6294, loss G: 1.3016
Epoch [8/25] Batch 600/3750                   Loss D: 0.6799, loss G: 0.9111
Epoch [8/25] Batch 700/3750                   Loss D: 0.5648, loss G: 1.1434
Epoch [8/25] Batch 800/3750                   Loss D: 0.6043, loss G: 1.1359
Epoch [8/25] Batch 900/3750                   Loss D: 0.5366, loss G: 1.0819
Epoch [8/25] Batch 1000/3750                   Loss D: 0.6421, loss G: 1.0512
Epoch [8/25] Batch 1100/3750                   Loss D: 0.6850, loss G: 0.882

Training:  36%|███▌      | 9/25 [01:26<02:30,  9.41s/it]

Epoch [8/25] Batch 3700/3750                   Loss D: 0.5777, loss G: 1.0062
Epoch [9/25] Batch 0/3750                   Loss D: 0.5418, loss G: 0.9807
Epoch [9/25] Batch 100/3750                   Loss D: 0.6868, loss G: 0.9299
Epoch [9/25] Batch 200/3750                   Loss D: 0.4843, loss G: 1.6231
Epoch [9/25] Batch 300/3750                   Loss D: 0.6059, loss G: 0.7045
Epoch [9/25] Batch 400/3750                   Loss D: 0.6648, loss G: 1.1562
Epoch [9/25] Batch 500/3750                   Loss D: 0.6217, loss G: 0.9998
Epoch [9/25] Batch 600/3750                   Loss D: 0.5498, loss G: 1.4431
Epoch [9/25] Batch 700/3750                   Loss D: 0.6243, loss G: 1.0757
Epoch [9/25] Batch 800/3750                   Loss D: 0.5622, loss G: 1.2558
Epoch [9/25] Batch 900/3750                   Loss D: 0.5196, loss G: 1.4987
Epoch [9/25] Batch 1000/3750                   Loss D: 0.7088, loss G: 1.4638
Epoch [9/25] Batch 1100/3750                   Loss D: 0.6788, loss G: 1.644

Training:  40%|████      | 10/25 [01:36<02:20,  9.37s/it]

Epoch [9/25] Batch 3700/3750                   Loss D: 0.5145, loss G: 1.0146
Epoch [10/25] Batch 0/3750                   Loss D: 0.5897, loss G: 1.1859
Epoch [10/25] Batch 100/3750                   Loss D: 0.6235, loss G: 1.3769
Epoch [10/25] Batch 200/3750                   Loss D: 0.7759, loss G: 0.9731
Epoch [10/25] Batch 300/3750                   Loss D: 0.5845, loss G: 0.9945
Epoch [10/25] Batch 400/3750                   Loss D: 0.6219, loss G: 1.1863
Epoch [10/25] Batch 500/3750                   Loss D: 0.4286, loss G: 1.4201
Epoch [10/25] Batch 600/3750                   Loss D: 0.6057, loss G: 1.0814
Epoch [10/25] Batch 700/3750                   Loss D: 0.4122, loss G: 1.6156
Epoch [10/25] Batch 800/3750                   Loss D: 0.7174, loss G: 0.7481
Epoch [10/25] Batch 900/3750                   Loss D: 0.8176, loss G: 0.8635
Epoch [10/25] Batch 1000/3750                   Loss D: 0.5455, loss G: 1.2872
Epoch [10/25] Batch 1100/3750                   Loss D: 0.4814, l

Training:  44%|████▍     | 11/25 [01:45<02:09,  9.27s/it]

Epoch [10/25] Batch 3700/3750                   Loss D: 0.6126, loss G: 1.1253
Epoch [11/25] Batch 0/3750                   Loss D: 0.6169, loss G: 0.9918
Epoch [11/25] Batch 100/3750                   Loss D: 0.6279, loss G: 0.9882
Epoch [11/25] Batch 200/3750                   Loss D: 0.5855, loss G: 1.1575
Epoch [11/25] Batch 300/3750                   Loss D: 0.5232, loss G: 1.3997
Epoch [11/25] Batch 400/3750                   Loss D: 0.5148, loss G: 1.3260
Epoch [11/25] Batch 500/3750                   Loss D: 0.7348, loss G: 0.7210
Epoch [11/25] Batch 600/3750                   Loss D: 0.4689, loss G: 1.1605
Epoch [11/25] Batch 700/3750                   Loss D: 0.5595, loss G: 1.2696
Epoch [11/25] Batch 800/3750                   Loss D: 0.6129, loss G: 0.9134
Epoch [11/25] Batch 900/3750                   Loss D: 0.6053, loss G: 1.0263
Epoch [11/25] Batch 1000/3750                   Loss D: 0.5710, loss G: 0.7995
Epoch [11/25] Batch 1100/3750                   Loss D: 0.7246, 

Training:  48%|████▊     | 12/25 [01:54<02:00,  9.28s/it]

Epoch [11/25] Batch 3700/3750                   Loss D: 0.5897, loss G: 1.0875
Epoch [12/25] Batch 0/3750                   Loss D: 0.6085, loss G: 1.0869
Epoch [12/25] Batch 100/3750                   Loss D: 0.6027, loss G: 0.9794
Epoch [12/25] Batch 200/3750                   Loss D: 0.6855, loss G: 0.9724
Epoch [12/25] Batch 300/3750                   Loss D: 0.7178, loss G: 0.7701
Epoch [12/25] Batch 400/3750                   Loss D: 0.6848, loss G: 0.8936
Epoch [12/25] Batch 500/3750                   Loss D: 0.5484, loss G: 1.0939
Epoch [12/25] Batch 600/3750                   Loss D: 0.6348, loss G: 1.3700
Epoch [12/25] Batch 700/3750                   Loss D: 0.6633, loss G: 1.0861
Epoch [12/25] Batch 800/3750                   Loss D: 0.6864, loss G: 1.0796
Epoch [12/25] Batch 900/3750                   Loss D: 0.6112, loss G: 1.2199
Epoch [12/25] Batch 1000/3750                   Loss D: 0.8535, loss G: 1.1362
Epoch [12/25] Batch 1100/3750                   Loss D: 0.5398, 

Training:  52%|█████▏    | 13/25 [02:03<01:51,  9.28s/it]

Epoch [12/25] Batch 3700/3750                   Loss D: 0.6367, loss G: 0.8038
Epoch [13/25] Batch 0/3750                   Loss D: 0.5895, loss G: 0.9661
Epoch [13/25] Batch 100/3750                   Loss D: 0.4915, loss G: 1.2131
Epoch [13/25] Batch 200/3750                   Loss D: 0.5815, loss G: 1.0800
Epoch [13/25] Batch 300/3750                   Loss D: 0.6864, loss G: 1.0537
Epoch [13/25] Batch 400/3750                   Loss D: 0.4923, loss G: 1.1321
Epoch [13/25] Batch 500/3750                   Loss D: 0.6054, loss G: 0.7865
Epoch [13/25] Batch 600/3750                   Loss D: 0.5698, loss G: 1.0046
Epoch [13/25] Batch 700/3750                   Loss D: 0.5041, loss G: 0.9409
Epoch [13/25] Batch 800/3750                   Loss D: 0.6578, loss G: 1.2841
Epoch [13/25] Batch 900/3750                   Loss D: 0.6264, loss G: 0.7948
Epoch [13/25] Batch 1000/3750                   Loss D: 0.4862, loss G: 1.2410
Epoch [13/25] Batch 1100/3750                   Loss D: 0.6161, 

Training:  56%|█████▌    | 14/25 [02:13<01:42,  9.33s/it]

Epoch [13/25] Batch 3700/3750                   Loss D: 0.4948, loss G: 1.3964
Epoch [14/25] Batch 0/3750                   Loss D: 0.5371, loss G: 0.9874
Epoch [14/25] Batch 100/3750                   Loss D: 0.5785, loss G: 0.9214
Epoch [14/25] Batch 200/3750                   Loss D: 0.6144, loss G: 0.9380
Epoch [14/25] Batch 300/3750                   Loss D: 0.6392, loss G: 1.1171
Epoch [14/25] Batch 400/3750                   Loss D: 0.7440, loss G: 0.8538
Epoch [14/25] Batch 500/3750                   Loss D: 0.5294, loss G: 0.9201
Epoch [14/25] Batch 600/3750                   Loss D: 0.5891, loss G: 1.2449
Epoch [14/25] Batch 700/3750                   Loss D: 0.5997, loss G: 0.9521
Epoch [14/25] Batch 800/3750                   Loss D: 0.6508, loss G: 0.9622
Epoch [14/25] Batch 900/3750                   Loss D: 0.5977, loss G: 1.4541
Epoch [14/25] Batch 1000/3750                   Loss D: 0.7008, loss G: 1.1042
Epoch [14/25] Batch 1100/3750                   Loss D: 0.5804, 

Training:  60%|██████    | 15/25 [02:22<01:33,  9.31s/it]

Epoch [14/25] Batch 3700/3750                   Loss D: 0.5985, loss G: 1.0172
Epoch [15/25] Batch 0/3750                   Loss D: 0.5240, loss G: 1.0012
Epoch [15/25] Batch 100/3750                   Loss D: 0.6318, loss G: 0.7791
Epoch [15/25] Batch 200/3750                   Loss D: 0.6422, loss G: 0.9100
Epoch [15/25] Batch 300/3750                   Loss D: 0.6815, loss G: 1.1187
Epoch [15/25] Batch 400/3750                   Loss D: 0.6766, loss G: 1.1085
Epoch [15/25] Batch 500/3750                   Loss D: 0.5482, loss G: 1.0182
Epoch [15/25] Batch 600/3750                   Loss D: 0.6104, loss G: 0.7224
Epoch [15/25] Batch 700/3750                   Loss D: 0.5699, loss G: 1.0137
Epoch [15/25] Batch 800/3750                   Loss D: 0.5668, loss G: 1.1353
Epoch [15/25] Batch 900/3750                   Loss D: 0.5332, loss G: 0.9316
Epoch [15/25] Batch 1000/3750                   Loss D: 0.5881, loss G: 1.2782
Epoch [15/25] Batch 1100/3750                   Loss D: 0.5469, 

Training:  64%|██████▍   | 16/25 [02:31<01:23,  9.32s/it]

Epoch [15/25] Batch 3700/3750                   Loss D: 0.5731, loss G: 1.2290
Epoch [16/25] Batch 0/3750                   Loss D: 0.4815, loss G: 1.0831
Epoch [16/25] Batch 100/3750                   Loss D: 0.6443, loss G: 1.0066
Epoch [16/25] Batch 200/3750                   Loss D: 0.6939, loss G: 0.8966
Epoch [16/25] Batch 300/3750                   Loss D: 0.6295, loss G: 1.0950
Epoch [16/25] Batch 400/3750                   Loss D: 0.6422, loss G: 0.6363
Epoch [16/25] Batch 500/3750                   Loss D: 0.6099, loss G: 1.2413
Epoch [16/25] Batch 600/3750                   Loss D: 0.6058, loss G: 0.9200
Epoch [16/25] Batch 700/3750                   Loss D: 0.7047, loss G: 1.3110
Epoch [16/25] Batch 800/3750                   Loss D: 0.6154, loss G: 0.7832
Epoch [16/25] Batch 900/3750                   Loss D: 0.6587, loss G: 0.8612
Epoch [16/25] Batch 1000/3750                   Loss D: 0.6180, loss G: 0.8585
Epoch [16/25] Batch 1100/3750                   Loss D: 0.8078, 

Training:  68%|██████▊   | 17/25 [02:41<01:14,  9.34s/it]

Epoch [16/25] Batch 3700/3750                   Loss D: 0.6163, loss G: 0.9950
Epoch [17/25] Batch 0/3750                   Loss D: 0.6865, loss G: 1.0259
Epoch [17/25] Batch 100/3750                   Loss D: 0.5924, loss G: 1.0335
Epoch [17/25] Batch 200/3750                   Loss D: 0.7535, loss G: 0.7114
Epoch [17/25] Batch 300/3750                   Loss D: 0.6282, loss G: 1.3393
Epoch [17/25] Batch 400/3750                   Loss D: 0.5804, loss G: 1.3853
Epoch [17/25] Batch 500/3750                   Loss D: 0.6072, loss G: 1.0733
Epoch [17/25] Batch 600/3750                   Loss D: 0.6788, loss G: 0.8376
Epoch [17/25] Batch 700/3750                   Loss D: 0.5692, loss G: 1.0356
Epoch [17/25] Batch 800/3750                   Loss D: 0.7471, loss G: 1.0959
Epoch [17/25] Batch 900/3750                   Loss D: 0.5569, loss G: 1.0393
Epoch [17/25] Batch 1000/3750                   Loss D: 0.5436, loss G: 1.0431
Epoch [17/25] Batch 1100/3750                   Loss D: 0.6419, 

Training:  72%|███████▏  | 18/25 [02:50<01:05,  9.36s/it]

Epoch [17/25] Batch 3700/3750                   Loss D: 0.6091, loss G: 0.9783
Epoch [18/25] Batch 0/3750                   Loss D: 0.6965, loss G: 0.8427
Epoch [18/25] Batch 100/3750                   Loss D: 0.6643, loss G: 1.1268
Epoch [18/25] Batch 200/3750                   Loss D: 0.6653, loss G: 0.9363
Epoch [18/25] Batch 300/3750                   Loss D: 0.6335, loss G: 0.9822
Epoch [18/25] Batch 400/3750                   Loss D: 0.7738, loss G: 0.7702
Epoch [18/25] Batch 500/3750                   Loss D: 0.4808, loss G: 1.1019
Epoch [18/25] Batch 600/3750                   Loss D: 0.5017, loss G: 1.0583
Epoch [18/25] Batch 700/3750                   Loss D: 0.5173, loss G: 0.9509
Epoch [18/25] Batch 800/3750                   Loss D: 0.6433, loss G: 0.9762
Epoch [18/25] Batch 900/3750                   Loss D: 0.7093, loss G: 0.9843
Epoch [18/25] Batch 1000/3750                   Loss D: 0.5258, loss G: 1.1192
Epoch [18/25] Batch 1100/3750                   Loss D: 0.6606, 

Training:  76%|███████▌  | 19/25 [03:00<00:56,  9.40s/it]

Epoch [18/25] Batch 3700/3750                   Loss D: 0.6559, loss G: 0.8894
Epoch [19/25] Batch 0/3750                   Loss D: 0.7555, loss G: 1.0920
Epoch [19/25] Batch 100/3750                   Loss D: 0.5651, loss G: 1.0705
Epoch [19/25] Batch 200/3750                   Loss D: 0.5580, loss G: 0.9448
Epoch [19/25] Batch 300/3750                   Loss D: 0.6064, loss G: 1.0934
Epoch [19/25] Batch 400/3750                   Loss D: 0.7137, loss G: 0.7108
Epoch [19/25] Batch 500/3750                   Loss D: 0.5059, loss G: 1.0633
Epoch [19/25] Batch 600/3750                   Loss D: 0.6726, loss G: 0.9384
Epoch [19/25] Batch 700/3750                   Loss D: 0.6504, loss G: 0.8004
Epoch [19/25] Batch 800/3750                   Loss D: 0.6519, loss G: 1.3815
Epoch [19/25] Batch 900/3750                   Loss D: 0.5245, loss G: 0.9862
Epoch [19/25] Batch 1000/3750                   Loss D: 0.5676, loss G: 1.0437
Epoch [19/25] Batch 1100/3750                   Loss D: 0.6126, 

Training:  80%|████████  | 20/25 [03:09<00:47,  9.42s/it]

Epoch [19/25] Batch 3700/3750                   Loss D: 0.7248, loss G: 0.7940
Epoch [20/25] Batch 0/3750                   Loss D: 0.5788, loss G: 0.8988
Epoch [20/25] Batch 100/3750                   Loss D: 0.6057, loss G: 0.7141
Epoch [20/25] Batch 200/3750                   Loss D: 0.7146, loss G: 0.9886
Epoch [20/25] Batch 300/3750                   Loss D: 0.5897, loss G: 1.0066
Epoch [20/25] Batch 400/3750                   Loss D: 0.7064, loss G: 0.8351
Epoch [20/25] Batch 500/3750                   Loss D: 0.7389, loss G: 0.9002
Epoch [20/25] Batch 600/3750                   Loss D: 0.6300, loss G: 1.1518
Epoch [20/25] Batch 700/3750                   Loss D: 0.4565, loss G: 1.3424
Epoch [20/25] Batch 800/3750                   Loss D: 0.5949, loss G: 0.9915
Epoch [20/25] Batch 900/3750                   Loss D: 0.6688, loss G: 0.9455
Epoch [20/25] Batch 1000/3750                   Loss D: 0.5865, loss G: 1.0337
Epoch [20/25] Batch 1100/3750                   Loss D: 0.6762, 

Training:  84%|████████▍ | 21/25 [03:19<00:37,  9.41s/it]

Epoch [20/25] Batch 3700/3750                   Loss D: 0.5989, loss G: 1.0006
Epoch [21/25] Batch 0/3750                   Loss D: 0.6627, loss G: 0.7979
Epoch [21/25] Batch 100/3750                   Loss D: 0.6807, loss G: 0.7549
Epoch [21/25] Batch 200/3750                   Loss D: 0.5836, loss G: 1.0523
Epoch [21/25] Batch 300/3750                   Loss D: 0.8149, loss G: 0.9764
Epoch [21/25] Batch 400/3750                   Loss D: 0.6121, loss G: 0.7088
Epoch [21/25] Batch 500/3750                   Loss D: 0.5024, loss G: 0.9153
Epoch [21/25] Batch 600/3750                   Loss D: 0.5255, loss G: 1.2007
Epoch [21/25] Batch 700/3750                   Loss D: 0.7127, loss G: 1.2965
Epoch [21/25] Batch 800/3750                   Loss D: 0.5196, loss G: 0.9901
Epoch [21/25] Batch 900/3750                   Loss D: 0.6363, loss G: 1.1141
Epoch [21/25] Batch 1000/3750                   Loss D: 0.7292, loss G: 1.0389
Epoch [21/25] Batch 1100/3750                   Loss D: 0.5151, 

Training:  88%|████████▊ | 22/25 [03:28<00:28,  9.40s/it]

Epoch [21/25] Batch 3700/3750                   Loss D: 0.5899, loss G: 1.2478
Epoch [22/25] Batch 0/3750                   Loss D: 0.6734, loss G: 1.0553
Epoch [22/25] Batch 100/3750                   Loss D: 0.5727, loss G: 1.0505
Epoch [22/25] Batch 200/3750                   Loss D: 0.6035, loss G: 0.8890
Epoch [22/25] Batch 300/3750                   Loss D: 0.5602, loss G: 1.0957
Epoch [22/25] Batch 400/3750                   Loss D: 0.7702, loss G: 0.9612
Epoch [22/25] Batch 500/3750                   Loss D: 0.5790, loss G: 1.1169
Epoch [22/25] Batch 600/3750                   Loss D: 0.6654, loss G: 1.1630
Epoch [22/25] Batch 700/3750                   Loss D: 0.5769, loss G: 1.1353
Epoch [22/25] Batch 800/3750                   Loss D: 0.5702, loss G: 1.3810
Epoch [22/25] Batch 900/3750                   Loss D: 0.5242, loss G: 1.0639
Epoch [22/25] Batch 1000/3750                   Loss D: 0.7377, loss G: 1.0022
Epoch [22/25] Batch 1100/3750                   Loss D: 0.6781, 

Training:  92%|█████████▏| 23/25 [03:37<00:18,  9.41s/it]

Epoch [22/25] Batch 3700/3750                   Loss D: 0.6446, loss G: 0.8073
Epoch [23/25] Batch 0/3750                   Loss D: 0.8034, loss G: 1.0869
Epoch [23/25] Batch 100/3750                   Loss D: 0.5340, loss G: 1.1153
Epoch [23/25] Batch 200/3750                   Loss D: 0.7574, loss G: 0.5899
Epoch [23/25] Batch 300/3750                   Loss D: 0.5897, loss G: 1.0983
Epoch [23/25] Batch 400/3750                   Loss D: 0.6635, loss G: 1.1795
Epoch [23/25] Batch 500/3750                   Loss D: 0.5694, loss G: 0.8838
Epoch [23/25] Batch 600/3750                   Loss D: 0.5371, loss G: 1.1507
Epoch [23/25] Batch 700/3750                   Loss D: 0.5938, loss G: 0.9892
Epoch [23/25] Batch 800/3750                   Loss D: 0.6167, loss G: 0.9651
Epoch [23/25] Batch 900/3750                   Loss D: 0.6413, loss G: 1.0285
Epoch [23/25] Batch 1000/3750                   Loss D: 0.4880, loss G: 1.2923
Epoch [23/25] Batch 1100/3750                   Loss D: 0.5595, 

Training:  96%|█████████▌| 24/25 [03:47<00:09,  9.44s/it]

Epoch [23/25] Batch 3700/3750                   Loss D: 0.6908, loss G: 0.9810
Epoch [24/25] Batch 0/3750                   Loss D: 0.5925, loss G: 1.2191
Epoch [24/25] Batch 100/3750                   Loss D: 0.5675, loss G: 0.8029
Epoch [24/25] Batch 200/3750                   Loss D: 0.6189, loss G: 1.1092
Epoch [24/25] Batch 300/3750                   Loss D: 0.6507, loss G: 1.0058
Epoch [24/25] Batch 400/3750                   Loss D: 0.5396, loss G: 1.2730
Epoch [24/25] Batch 500/3750                   Loss D: 0.4628, loss G: 1.5038
Epoch [24/25] Batch 600/3750                   Loss D: 0.5984, loss G: 0.9617
Epoch [24/25] Batch 700/3750                   Loss D: 0.4828, loss G: 1.2575
Epoch [24/25] Batch 800/3750                   Loss D: 0.5365, loss G: 1.6192
Epoch [24/25] Batch 900/3750                   Loss D: 0.4754, loss G: 1.4256
Epoch [24/25] Batch 1000/3750                   Loss D: 0.6167, loss G: 0.9050
Epoch [24/25] Batch 1100/3750                   Loss D: 0.6254, 

Training: 100%|██████████| 25/25 [03:57<00:00,  9.48s/it]

Epoch [24/25] Batch 3700/3750                   Loss D: 0.6496, loss G: 1.0464





Results on changing hyperparameters:
1. lr = 5e-4, z_dim = 64, img_dim = 28 * 28 * 1, batch_size = 16, num_epochs = 25
   1. Results: D Loss = 0.5637 | G Loss = 0.9129
2. lr = 7e-4, z_dim = 32, img_dim = 28 * 28 * 1, batch_size = 16, num_epochs = 25
   1. Results: D Loss = 0.4870 | G Loss = 1.0671
3. lr = 3e-4, z_dim = 64, img_dim = 28 * 28 * 1, batch_size = 32, num_epochs = 25
   1. Results: D Loss = 0.5484 | G Loss = 0.9143
4. lr = 3e-4, z_dim = 64, img_dim = 28 * 28 * 1, batch_size = 16, num_epochs = 30
   1. Results: D Loss = 0.5484 | G Loss = 0.9143

Results: 2 > 1 > 3
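
The notebook does not say how the summary losses above were computed. Since the per-batch losses in the training log fluctuate a lot, one possible approach is to average the last few logged values instead of reading a single line. The sketch below assumes the log format printed above; `mean_tail_losses` and `last_n` are hypothetical names introduced for illustration, and truncated lines (those missing `loss G`) are skipped.

```python
import re

# Matches complete log lines such as:
# "Epoch [24/25] Batch 900/3750    Loss D: 0.4754, loss G: 1.4256"
LOG_LINE = re.compile(r"Loss D:\s*([0-9.]+),\s*loss G:\s*([0-9.]+)")


def mean_tail_losses(log_text, last_n=50):
    """Average the last `last_n` complete (D, G) loss pairs found in a log."""
    pairs = [(float(d), float(g)) for d, g in LOG_LINE.findall(log_text)]
    if not pairs:
        raise ValueError("no complete loss lines found")
    tail = pairs[-last_n:]
    d_mean = sum(d for d, _ in tail) / len(tail)
    g_mean = sum(g for _, g in tail) / len(tail)
    return d_mean, g_mean


# A few lines copied from the log above, including one truncated line.
sample = """\
Epoch [24/25] Batch 900/3750                   Loss D: 0.4754, loss G: 1.4256
Epoch [24/25] Batch 1000/3750                   Loss D: 0.6167, loss G: 0.9050
Epoch [24/25] Batch 1100/3750                   Loss D: 0.6254, 
Epoch [24/25] Batch 3700/3750                   Loss D: 0.6496, loss G: 1.0464
"""
d_mean, g_mean = mean_tail_losses(sample, last_n=3)
print(f"D Loss = {d_mean:.4f} | G Loss = {g_mean:.4f}")
```

Averaging makes runs easier to compare, but GAN losses are only a rough proxy for image quality, so the generated samples are still worth inspecting.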




How to improve the GAN?
1. Use different architectures for the generator and discriminator (e.g. the convolutional layers of DCGAN)
2. Tune the hyperparameters; GANs are very sensitive to them
3. Use different loss functions (e.g. the Wasserstein loss of WGAN)
4. Train for longer
5. Add regularization techniques (e.g. weight clipping in WGAN)
6. Use different types of normalization
7. Use different types of noise
8. Use different types of optimizers

What is DCGAN?<br>
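DCGAN stands for Deep Convolutional Generative Adversarial Network. It is a type of GAN that uses convolutional layers in both the generator and discriminator instead of the fully connected layers used in this notebook.
Convolutions exploit the spatial structure of images, so the model can learn local patterns such as strokes and edges and generate higher quality images. The DCGAN design also replaces pooling with strided convolutions and uses batch normalization. DCGANs are commonly used for image generation tasks and have been shown to produce realistic images in a variety of domains.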
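
As a rough illustration of the idea, here is a minimal DCGAN-style sketch sized for 28 x 28 MNIST images. The class names `ConvGenerator` and `ConvDiscriminator`, the channel counts, and the kernel sizes are assumptions chosen so the shapes work out. They are not taken from this notebook or from a reference implementation.

```python
import torch
import torch.nn as nn


class ConvGenerator(nn.Module):
    """Maps noise of shape (N, z_dim) to images of shape (N, 1, 28, 28)."""

    def __init__(self, z_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim, 128, kernel_size=7, stride=1, padding=0),  # 1x1 -> 7x7
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7x7 -> 14x14
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),  # 14x14 -> 28x28
            nn.Tanh(),  # outputs in [-1, 1], like the generator above
        )

    def forward(self, z):
        # Treat each noise vector as a 1x1 "image" with z_dim channels.
        return self.net(z.view(z.size(0), -1, 1, 1))


class ConvDiscriminator(nn.Module):
    """Maps images of shape (N, 1, 28, 28) to probabilities of shape (N, 1)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),  # 28x28 -> 14x14
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 14x14 -> 7x7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, kernel_size=7, stride=1, padding=0),  # 7x7 -> 1x1
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x).view(-1, 1)


conv_gen = ConvGenerator(z_dim=64)
conv_disc = ConvDiscriminator()
fake = conv_gen(torch.randn(4, 64))
scores = conv_disc(fake)
print(fake.shape, scores.shape)
```

Note that these models expect image tensors of shape (N, 1, 28, 28), so the flattening to `img_dim` used with the linear models would have to be dropped in the training loop.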

What is WGAN?<br>
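WGAN stands for Wasserstein Generative Adversarial Network. It is a type of GAN that replaces the traditional binary cross-entropy loss with an estimate of the Wasserstein (earth mover's) distance between the real and generated distributions. The discriminator, usually called the critic in this setting, outputs an unbounded score instead of a probability, so it has no final sigmoid. The Wasserstein loss keeps giving the generator useful gradients even when the critic separates real and fake images well, which helps stabilize training. WGANs have been shown to be more stable and produce better results than traditional GANs in many cases.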
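
A minimal sketch of the two WGAN losses, assuming the critic's raw scores for a batch of real and fake images are already available as tensors. The function names `critic_loss` and `generator_loss` are illustrative, not from this notebook.

```python
import torch


def critic_loss(critic_real, critic_fake):
    # The critic tries to maximize E[f(real)] - E[f(fake)].
    # Optimizers minimize, so we return the negative of that quantity.
    return -(critic_real.mean() - critic_fake.mean())


def generator_loss(critic_fake):
    # The generator tries to make the critic score its samples highly.
    return -critic_fake.mean()


# Toy scores standing in for critic outputs on a batch of two images.
real_scores = torch.tensor([1.0, 3.0])
fake_scores = torch.tensor([-1.0, 1.0])

loss_c = critic_loss(real_scores, fake_scores)
loss_g = generator_loss(fake_scores)
print(loss_c.item(), loss_g.item())
```

Unlike the BCE losses used in this notebook, these values are not bounded below, so a more negative critic loss is expected rather than a sign of a bug.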

What is weight clipping?<br>
Weight clipping is the technique the original WGAN uses to enforce a Lipschitz constraint on the critic (discriminator). After each critic update, every weight is clamped to a small fixed range. The constraint is needed for the critic's objective to be a valid estimate of the Wasserstein distance. Clipping is a crude way to enforce it: if the range is too large the constraint is weak, and if it is too small the gradients can vanish. For this reason later variants such as WGAN-GP replace clipping with a gradient penalty.
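
A minimal sketch of weight clipping in PyTorch. The helper name `clip_weights` is hypothetical, the critic below is an assumed stand-in that mirrors this notebook's discriminator without the final sigmoid, and the range of 0.01 is the default from the WGAN paper rather than a value tuned here.

```python
import torch
import torch.nn as nn


def clip_weights(model, clip_value=0.01):
    """Clamp every parameter of `model` into [-clip_value, clip_value] in place."""
    with torch.no_grad():
        for p in model.parameters():
            p.clamp_(-clip_value, clip_value)


# Assumed critic: same layer sizes as the Discriminator above, but no Sigmoid.
critic = nn.Sequential(
    nn.Linear(28 * 28, 128),
    nn.LeakyReLU(0.1),
    nn.Linear(128, 1),
)

# In a WGAN training loop this call would follow each critic optimizer step.
clip_weights(critic, clip_value=0.01)
largest = max(p.abs().max().item() for p in critic.parameters())
print(largest)
```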

How is weight clipping different from learning rate?<br>
Weight clipping constrains the weights themselves: after each training step, every weight of the discriminator (critic) is clamped to a small range. The learning rate, on the other hand, controls how far the weights move in each update: it scales the step taken along the gradient direction during optimization. The two act at different points of a training step. The learning rate is applied inside the optimizer update, and clipping is applied to the result afterwards. Both the learning rate and the clipping range are hyperparameters that can affect the performance of a GAN, but they serve different purposes.
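
The difference can be seen on a single scalar weight. The sketch below uses plain SGD (an assumption made to keep the arithmetic easy to follow; it is not the optimizer setup of this notebook) with an illustrative learning rate of 0.1 and a clipping range of 0.5.

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)  # the learning rate scales the update step

loss = (w ** 2).sum()  # gradient is 2 * w = 2.0
opt.zero_grad()
loss.backward()
opt.step()  # w <- 1.0 - 0.1 * 2.0 = 0.8
after_step = w.item()

with torch.no_grad():
    w.clamp_(-0.5, 0.5)  # clipping bounds the weight value itself
after_clip = w.item()

print(after_step, after_clip)
```

The learning rate decided how far the weight moved (from 1.0 to about 0.8), while clipping then forced the result into the allowed range (0.5) regardless of the step size.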

In [1]:
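# Generate 9 random images in a 3 x 3 grid.
# This cell relies on names defined in earlier cells (torch, torchvision, plt,
# z_dim, device, gen). Rerun the import, hyperparameter, and training cells
# first; in a fresh kernel the cell fails with a NameError.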
with torch.no_grad():
    noise = torch.randn(9, z_dim).to(device)
    img = gen(noise).view(-1, 1, 28, 28)
    img_grid = torchvision.utils.make_grid(img, nrow=3, normalize=True)
    plt.imshow(img_grid.permute(1, 2, 0).cpu().numpy())
    plt.show()


NameError: name 'torch' is not defined