In [29]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [30]:
class Discriminator (nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        return self.disc(x)

In [31]:
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh()
        )
    
    def forward(self, x):
        return self.gen(x)

In [32]:
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
img_dim = 28*28*1
batch_size = 32
num_epochs = 50

In [33]:

disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)
fixed_noise = torch.randn(batch_size, z_dim).to(device)
transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(
    (0.5,),
    (0.5,)
  )
])
dataset = datasets.MNIST(
  root="./dataset/", 
  transform=transform,
  download=True
)
loader= DataLoader(dataset, batch_size, shuffle=True)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen= optim.Adam(gen.parameters(), lr=lr)

criterion = nn.BCELoss()
writer_fake = SummaryWriter(f"./runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"./runs/GAN_MNIST/real")
step = 0

In [34]:
for epoch in range(num_epochs):
  for batch_idx, (real, _) in enumerate(loader):
    real = real.view(-1, 784).to(device)
    batch_size = real.shape[0]

    noise = torch.randn(batch_size, z_dim).to(device)
    fake = gen(noise)

    disc_real = disc(real).view(-1)
    lossD_real = criterion(disc_real, torch.ones_like(disc_real))

    # don't want opt_disc.step() update fake, so create a detached version fake at this point
    # also the gradient that is used to update disc has nothing to do with gen
    # as the graph of fake involve gen, we have to detach fake to avoid affecting gen itself
    
    disc_fake = disc(fake.detach()).view(-1)
    lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

    lossD = (lossD_real + lossD_fake)/2
    
    disc.zero_grad()
    lossD.backward()
    opt_disc.step()

    output = disc(fake).view(-1)
    lossG = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    lossG.backward()
    opt_gen.step()

    if batch_idx == 0:
        print(
            f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                    Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
        )

        with torch.no_grad():
            fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
            data = real.reshape(-1, 1, 28, 28)
            img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
            img_grid_real = torchvision.utils.make_grid(data, normalize=True)

            writer_fake.add_image(
                "Mnist Fake Images", img_grid_fake, global_step=step
            )
            writer_real.add_image(
                "Mnist Real Images", img_grid_real, global_step=step
            )
            step += 1


Epoch [0/50] Batch 0/1875                     Loss D: 0.6798, loss G: 0.6657
Epoch [1/50] Batch 0/1875                     Loss D: 0.6073, loss G: 0.8925
Epoch [2/50] Batch 0/1875                     Loss D: 0.3536, loss G: 1.3072
Epoch [3/50] Batch 0/1875                     Loss D: 0.3749, loss G: 1.4472
Epoch [4/50] Batch 0/1875                     Loss D: 0.6495, loss G: 0.9522
Epoch [5/50] Batch 0/1875                     Loss D: 0.5699, loss G: 1.0106
Epoch [6/50] Batch 0/1875                     Loss D: 0.6386, loss G: 1.1112
Epoch [7/50] Batch 0/1875                     Loss D: 0.5836, loss G: 0.9521
Epoch [8/50] Batch 0/1875                     Loss D: 0.3807, loss G: 1.5441
Epoch [9/50] Batch 0/1875                     Loss D: 0.5309, loss G: 0.9624
Epoch [10/50] Batch 0/1875                     Loss D: 0.5184, loss G: 1.0213
Epoch [11/50] Batch 0/1875                     Loss D: 0.4975, loss G: 1.1149
Epoch [12/50] Batch 0/1875                     Loss D: 0.8079, loss G: 0.8