In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [None]:
class Discriminator(nn.Module):
  """Fully connected GAN discriminator over flattened images.

  A two-layer MLP (Linear -> LeakyReLU(0.1) -> Linear -> Sigmoid) that maps
  each flattened input to a single score in [0, 1], read as the probability
  that the sample is real.

  Args:
    img_dim: number of features in each flattened input image.
  """

  def __init__(self, img_dim):
    super().__init__()
    hidden_units = 128
    layers = [
        nn.Linear(img_dim, hidden_units),
        nn.LeakyReLU(0.1),
        nn.Linear(hidden_units, 1),
        nn.Sigmoid(),
    ]
    # The attribute name determines the state_dict keys ("discriminator.N.*").
    self.discriminator = nn.Sequential(*layers)

  def forward(self, x):
    """Score a batch: for x of shape (batch, img_dim), returns (batch, 1)."""
    scores = self.discriminator(x)
    return scores

class Generator(nn.Module):
  """Fully connected GAN generator mapping noise vectors to flattened images.

  A two-layer MLP (Linear -> LeakyReLU(0.1) -> Linear -> Tanh).

  Args:
    noise_dim: length of each input noise vector (the latent dimension).
    img_dim: number of features in each generated flattened image.
  """

  def __init__(self, noise_dim, img_dim):
    super().__init__()
    # Size the input layer from the constructor argument, not a notebook
    # global, so the class does not depend on hidden kernel state.
    self.generator = nn.Sequential(
        nn.Linear(in_features=noise_dim, out_features=256),
        nn.LeakyReLU(negative_slope=0.1),
        nn.Linear(in_features=256, out_features=img_dim),
        # Output range [-1, 1], matching the Normalize((0.5,), (0.5,))
        # scaling applied to the real images in the training cell.
        nn.Tanh(),
    )

  def forward(self, x):
    """Map noise of shape (batch, noise_dim) to (batch, img_dim) values in [-1, 1]."""
    return self.generator(x)

In [None]:
# ---- Configuration -------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
lr_rate = 3e-04
z_dim = 64              # length of the generator's input noise vector
img_dim = 28 * 28 * 1   # flattened 1x28x28 MNIST image
batch_size = 32
epochs = 50
seed = 0

# Seed torch's RNGs (on all devices) so weight init, noise and shuffling are
# reproducible under Restart & Run All.
torch.manual_seed(seed)

# ---- Models, data, optimizers --------------------------------------------
discriminator = Discriminator(img_dim).to(device)
generator = Generator(z_dim, img_dim).to(device)

# Fixed noise reused at every logging step, so the TensorBoard image series
# shows how the generator's output for the same inputs evolves.
noise_fixed = torch.randn((batch_size, z_dim), device=device)

# ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the output range of the generator's final Tanh.
transformations = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)),]
)

# Do not assign to `datasets`: that would shadow the torchvision.datasets
# module and make re-running this cell fail.
mnist_dataset = datasets.MNIST(root="dataset/", transform=transformations, download=True)
loader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)

opt_discriminator = optim.Adam(discriminator.parameters(), lr=lr_rate)
opt_generator = optim.Adam(generator.parameters(), lr=lr_rate)
# Binary cross-entropy on the discriminator's Sigmoid output.
criterion = nn.BCELoss()

writer_fake = SummaryWriter("fake_GAN/")
writer_real = SummaryWriter("real_GAN/")

# ---- Training loop -------------------------------------------------------
step = 0  # TensorBoard global_step; incremented once per logged epoch
try:
  for epoch in range(epochs):
    for batch_idx, (real, _) in enumerate(loader):
      real = real.view(-1, img_dim).to(device)
      # The final batch may be smaller; keep the configured batch_size intact.
      cur_batch_size = real.shape[0]

      # Discriminator step: real -> label 1, generated -> label 0.
      noise = torch.randn((cur_batch_size, z_dim), device=device)
      fake = generator(noise)
      disc_real = discriminator(real).view(-1)
      loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
      # detach(): the discriminator loss must not backprop into the generator.
      disc_fake = discriminator(fake.detach()).view(-1)
      loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
      loss_disc = (loss_disc_real + loss_disc_fake) / 2

      opt_discriminator.zero_grad()
      # No retain_graph: this graph is not reused. The generator step below
      # runs a fresh discriminator forward pass on the non-detached `fake`.
      loss_disc.backward()
      opt_discriminator.step()

      # Generator step: push D(G(z)) toward label 1.
      output = discriminator(fake).view(-1)
      loss_generator = criterion(output, torch.ones_like(output))
      opt_generator.zero_grad()
      loss_generator.backward()
      opt_generator.step()

      if batch_idx == 0:
        # "\\ " prints a literal backslash; a bare "\ " is an invalid escape.
        print(
            f"Epoch [{epoch}/{epochs}] \\ "
            f"Loss Discriminator: {loss_disc.item(): .4f}, Loss Generator: {loss_generator.item(): .4f}"
        )
        with torch.no_grad():
          samples = generator(noise_fixed).reshape(-1, 1, 28, 28)
          data = real.reshape(-1, 1, 28, 28)
          img_grid_fake = torchvision.utils.make_grid(samples, normalize=True)
          img_grid_real = torchvision.utils.make_grid(data, normalize=True)

          writer_fake.add_image(
              "Mnist Fake Images", img_grid_fake, global_step=step
          )
          writer_real.add_image(
              "Mnist Real Images", img_grid_real, global_step=step
          )
          step += 1
finally:
  # Flush pending events and release the writers' file handles, even if
  # training is interrupted.
  writer_fake.close()
  writer_real.close()

Epoch [0/50] \ Loss Discriminator:  0.6779, Loss Generator:  0.6651
Epoch [1/50] \ Loss Discriminator:  0.5781, Loss Generator:  0.9516
Epoch [2/50] \ Loss Discriminator:  0.3800, Loss Generator:  1.4256
Epoch [3/50] \ Loss Discriminator:  0.9031, Loss Generator:  0.6232
Epoch [4/50] \ Loss Discriminator:  0.9864, Loss Generator:  0.6459
Epoch [5/50] \ Loss Discriminator:  0.4381, Loss Generator:  1.3398
Epoch [6/50] \ Loss Discriminator:  0.3738, Loss Generator:  1.3335
Epoch [7/50] \ Loss Discriminator:  0.7732, Loss Generator:  0.8697
Epoch [8/50] \ Loss Discriminator:  0.4661, Loss Generator:  1.8147
Epoch [9/50] \ Loss Discriminator:  0.6080, Loss Generator:  0.9118
Epoch [10/50] \ Loss Discriminator:  0.8070, Loss Generator:  0.7885
Epoch [11/50] \ Loss Discriminator:  0.6315, Loss Generator:  1.1614
Epoch [12/50] \ Loss Discriminator:  0.7285, Loss Generator:  1.2860
Epoch [13/50] \ Loss Discriminator:  0.5687, Loss Generator:  1.1223
Epoch [14/50] \ Loss Discriminator:  0.5841,