Import libraries

In [3]:
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter 

Build Discriminator

In [22]:
class Discriminator(nn.Module):
    """Fully connected classifier scoring each input as real (target 1) or fake (target 0).

    Args:
        in_features: number of features per input sample (the flattened image size).

    Architecture: Linear(in_features, 128) -> LeakyReLU(0.01) -> Linear(128, 1) -> Sigmoid.
    The final Sigmoid keeps every score in [0, 1], which is what nn.BCELoss expects.
    """

    def __init__(self, in_features):
        super(Discriminator, self).__init__()
        hidden_units = 128
        layers = [
            nn.Linear(in_features, hidden_units),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),
        ]
        # The attribute name `disc` determines the state_dict keys ("disc.0.weight", ...),
        # so it must stay stable for saved checkpoints.
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return one probability per sample; x's last dimension must equal in_features."""
        scores = self.disc(x)
        return scores


Build Generator

In [23]:
class Generator(nn.Module):
    """Fully connected generator mapping latent noise vectors to flattened images.

    Args:
        z_dim: size of the latent noise vector fed to the network.
        img_dim: number of output features (the flattened image size).

    Architecture: Linear(z_dim, 256) -> LeakyReLU(0.01) -> Linear(256, img_dim) -> Tanh.
    Real images are normalized to [-1, 1], so Tanh bounds the outputs to the same range.
    """

    def __init__(self, z_dim, img_dim):
        super(Generator, self).__init__()
        hidden_units = 256
        layers = [
            nn.Linear(z_dim, hidden_units),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_units, img_dim),
            nn.Tanh(),
        ]
        # The attribute name `gen` determines the state_dict keys ("gen.0.weight", ...).
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Map noise x (last dimension z_dim) to outputs of size img_dim in [-1, 1]."""
        generated = self.gen(x)
        return generated

In [24]:
# Hyperparameters etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4                # Adam learning rate (used for both networks' optimizers)
z_dim = 64               # size of the latent noise vector fed to the generator
image_dim = 28 * 28 * 1  # 784: one flattened 1x28x28 MNIST image
batch_size = 32
num_epochs = 50

# Seed torch's RNGs (torch.manual_seed covers CPU and CUDA) so weight
# initialization and sampled noise are reproducible on Restart & Run All.
seed = 42
torch.manual_seed(seed)

In [29]:
# Build both networks on the selected device. Keep this order (disc, gen,
# then fixed_noise): each step draws from torch's RNG.
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)

# One latent batch reused at every logging step, so the fake-image grids
# written to TensorBoard are comparable across training.
fixed_noise = torch.randn(batch_size, z_dim).to(device)

# ToTensor maps PIL pixel values to [0, 1]; Normalize then computes
# (x - 0.5) / 0.5, giving [-1, 1] to match the generator's Tanh output range.
# NOTE(review): this rebinds `transforms`, shadowing the
# `torchvision.transforms` module alias imported at the top. The data-loading
# cell reads this name, so it is kept as-is.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

Load MNIST data and set up the optimizers, loss, and TensorBoard writers

In [30]:
# MNIST under ./datasets/ (downloaded on first run), preprocessed with the
# `transforms` pipeline built in the previous cell.
dataset = datasets.MNIST(root="datasets/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Separate optimizers so each network is updated only by its own loss.
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
# Binary cross-entropy on the discriminator's sigmoid probabilities.
criterion = nn.BCELoss()

# Two writers under logs/ so real and generated image grids show up as
# separate runs in TensorBoard; `step` is the shared logging counter.
writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")
step = 0

In [31]:
# Alternating GAN training: one discriminator update, then one generator
# update per batch. Real/fake samples are logged once per epoch (batch 0).
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each image into an image_dim-long vector for the MLPs.
        real = real.view(-1, image_dim).to(device)
        # Local name on purpose: reassigning the global `batch_size` here would
        # silently change the hyperparameter other cells (fixed_noise) read.
        cur_batch_size = real.shape[0]

        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)

        # --- Discriminator: push D(real) -> 1 and D(G(z)) -> 0 ---
        # `fake.detach()` keeps this backward pass out of the generator, so no
        # generator gradients are computed and retain_graph is not needed.
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # --- Generator: push the (just updated) D(G(z)) -> 1 ---
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            # Implicit string concatenation: a backslash continuation inside
            # the literal would embed the source indentation in the message.
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
            )
            with torch.no_grad():
                fake_samples = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_samples = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_samples, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_samples, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
            step += 1

# Push any buffered events to disk so TensorBoard shows the final step;
# the writers stay open for later cells.
writer_fake.flush()
writer_real.flush()



Epoch [0/50] Batch 0/1875                       Loss D: 0.6628, loss G: 0.7239
Epoch [1/50] Batch 0/1875                       Loss D: 0.7037, loss G: 0.6988
Epoch [2/50] Batch 0/1875                       Loss D: 0.4396, loss G: 0.9919
Epoch [3/50] Batch 0/1875                       Loss D: 0.2622, loss G: 1.4652
Epoch [4/50] Batch 0/1875                       Loss D: 0.9185, loss G: 0.8080
Epoch [5/50] Batch 0/1875                       Loss D: 0.3769, loss G: 1.6426
Epoch [6/50] Batch 0/1875                       Loss D: 0.7855, loss G: 0.7495
Epoch [7/50] Batch 0/1875                       Loss D: 0.5725, loss G: 1.2666
Epoch [8/50] Batch 0/1875                       Loss D: 0.4778, loss G: 1.3271
Epoch [9/50] Batch 0/1875                       Loss D: 0.7277, loss G: 0.8378
Epoch [10/50] Batch 0/1875                       Loss D: 0.7234, loss G: 0.8935
Epoch [11/50] Batch 0/1875                       Loss D: 0.6834, loss G: 1.1836
Epoch [12/50] Batch 0/1875                       L