In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


In [10]:
class Discriminator(nn.Module):
  """Small fully connected critic that scores flattened images.

  Architecture: Linear(dim_img -> 128) -> LeakyReLU(0.1) -> Linear(128 -> 1)
  -> Sigmoid, so every input vector of length `dim_img` is mapped to a single
  probability-like score in [0, 1].

  Args:
    dim_img: length of each flattened input image vector.
  """

  def __init__(self, dim_img):
    super().__init__()
    hidden_units = 128
    layers = [
        nn.Linear(in_features=dim_img, out_features=hidden_units),
        nn.LeakyReLU(negative_slope=0.1),
        nn.Linear(in_features=hidden_units, out_features=1),
        nn.Sigmoid(),
    ]
    # Keep the attribute name `disc` so state_dict keys stay "disc.N.*".
    self.disc = nn.Sequential(*layers)

  def forward(self, x):
    """Return the sigmoid score for each row of `x` (last dim == 1)."""
    scores = self.disc(x)
    return scores


class Generator(nn.Module):
  """Small fully connected generator from latent noise to flattened images.

  Architecture: Linear(z_dim -> 256) -> LeakyReLU(0.1) -> Linear(256 -> dim_img)
  -> Tanh, so every output value lies in [-1, 1].

  Args:
    z_dim: length of each latent noise vector.
    dim_img: length of each generated (flattened) image vector.
  """

  def __init__(self, z_dim, dim_img):
    super().__init__()
    width = 256
    # Keep the attribute name `gen` so state_dict keys stay "gen.N.*".
    self.gen = nn.Sequential(
        nn.Linear(in_features=z_dim, out_features=width),
        nn.LeakyReLU(negative_slope=0.1),
        nn.Linear(in_features=width, out_features=dim_img),
        nn.Tanh(),
    )

  def forward(self, x):
    """Map each latent row of `x` to a `dim_img`-long vector in [-1, 1]."""
    images = self.gen(x)
    return images


# Hyperparameters

In [11]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG (all devices) so weight init, noise draws and the
# DataLoader's shuffling are reproducible under Restart & Run All.
seed = 42
torch.manual_seed(seed)

lr = 3e-4                # shared Adam learning rate for both networks
z_dim = 64               # length of the generator's latent noise vector
image_dim = 28 * 28 * 1  # flattened MNIST image: 28x28 pixels, 1 channel
batch_size = 32
num_epochs = 50
step = 0                 # TensorBoard global step, advanced once per logged batch

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)

optim_disc = optim.Adam(disc.parameters(), lr=lr)
optim_gen = optim.Adam(gen.parameters(), lr=lr)

# Both networks end in Sigmoid/score targets of 0/1, so plain BCE is used.
loss_fn = nn.BCELoss()

# One latent batch reused at every logging step so the "fake" image grids are
# comparable across epochs. Created directly on `device` (no host->device copy).
fixed_noise = torch.randn((batch_size, z_dim), device=device)
writer_fake = SummaryWriter('runs/GAN_MNIST/fake')
writer_real = SummaryWriter('runs/GAN_MNIST/real')

# Dataset

In [12]:
# ToTensor gives pixels in [0, 1]; Normalize(mean=0.5, std=0.5) rescales them
# to [-1, 1], the same range as the Generator's final Tanh.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)
dataset = datasets.MNIST(
    root='dataset/',
    transform=transform,
    download=True,
)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# Training

In [13]:
# Alternating GAN updates: each iteration first trains the discriminator on a
# real batch and a generated batch, then trains the generator against the
# freshly updated discriminator. Uses the models, optimizers, loss, writers and
# `step` counter defined in the cells above.
for epoch in range(num_epochs):
  for batch_idx, (real, _) in enumerate(loader):
    # Flatten each (1, 28, 28) image into an `image_dim`-long vector.
    real = real.view(-1, image_dim).to(device)
    # Local name on purpose: reassigning the global `batch_size` here would
    # silently change the hyperparameter seen by other cells on re-run.
    cur_batch_size = real.shape[0]

    noise = torch.randn(cur_batch_size, z_dim, device=device)
    fake = gen(noise)

    # Discriminator step: BCE against 1 for real and 0 for fake, i.e.
    # maximize log D(x) + log(1 - D(G(z))).
    disc_real = disc(real).view(-1)
    loss_disc_real = loss_fn(disc_real, torch.ones_like(disc_real))
    # detach() stops D's loss from backpropagating into G, so G's graph is
    # still intact for the generator step and retain_graph is unnecessary.
    disc_fake = disc(fake.detach()).view(-1)
    loss_disc_fake = loss_fn(disc_fake, torch.zeros_like(disc_fake))
    loss_disc = (loss_disc_real + loss_disc_fake) / 2

    disc.zero_grad()
    loss_disc.backward()
    optim_disc.step()

    # Generator step (non-saturating loss): push D(G(z)) towards 1, scored by
    # the discriminator as just updated above.
    output = disc(fake).view(-1)
    loss_gen = loss_fn(output, torch.ones_like(output))
    gen.zero_grad()
    loss_gen.backward()
    optim_gen.step()

    # Log once per epoch: losses to stdout, image grids to TensorBoard.
    if batch_idx == 0:
      print(
          f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
          f"Loss D: {loss_disc.item():.4f}, loss G: {loss_gen.item():.4f}"
      )

      with torch.no_grad():
        # Separate name so the training batch `fake` is not shadowed.
        fake_fixed = gen(fixed_noise).reshape(-1, 1, 28, 28)
        data = real.reshape(-1, 1, 28, 28)
        img_grid_fake = torchvision.utils.make_grid(fake_fixed, normalize=True)
        img_grid_real = torchvision.utils.make_grid(data, normalize=True)

        writer_fake.add_image(
            "Mnist Fake Images", img_grid_fake, global_step=step
        )
        writer_real.add_image(
            "Mnist Real Images", img_grid_real, global_step=step
        )
        step += 1

# Make sure every logged image is written to disk before the cell finishes.
writer_fake.flush()
writer_real.flush()

Epoch [0/50] Batch 0/1875                 Loss D: 0.6792, loss G: 0.7016
Epoch [1/50] Batch 0/1875                 Loss D: 0.6110, loss G: 0.9259
Epoch [2/50] Batch 0/1875                 Loss D: 0.8985, loss G: 0.8038
Epoch [3/50] Batch 0/1875                 Loss D: 0.6823, loss G: 0.8772
Epoch [4/50] Batch 0/1875                 Loss D: 0.5659, loss G: 0.9544
Epoch [5/50] Batch 0/1875                 Loss D: 0.7036, loss G: 0.8397
Epoch [6/50] Batch 0/1875                 Loss D: 0.6986, loss G: 1.1053
Epoch [7/50] Batch 0/1875                 Loss D: 0.4132, loss G: 1.4153
Epoch [8/50] Batch 0/1875                 Loss D: 0.3886, loss G: 1.3847
Epoch [9/50] Batch 0/1875                 Loss D: 0.4367, loss G: 1.7541
Epoch [10/50] Batch 0/1875                 Loss D: 0.7652, loss G: 1.1256
Epoch [11/50] Batch 0/1875                 Loss D: 0.9415, loss G: 0.7985
Epoch [12/50] Batch 0/1875                 Loss D: 0.5282, loss G: 1.0730
Epoch [13/50] Batch 0/1875                 Loss 