<a href="https://colab.research.google.com/github/kaiju8/GANs-Implemented/blob/main/GANs_Simple.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [None]:
class Discriminator(nn.Module):
  """Two-layer fully connected classifier over flattened images.

  Maps each ``img_dim``-long input vector to a single sigmoid score.
  """

  def __init__(self, img_dim: int) -> None:
    super().__init__()
    hidden_units = 128
    # Layer order is kept fixed so parameter names (disc.0, disc.2) and the
    # order in which weights are initialised stay stable.
    self.disc = nn.Sequential(
        nn.Linear(in_features=img_dim, out_features=hidden_units),
        nn.LeakyReLU(negative_slope=0.1),
        nn.Linear(in_features=hidden_units, out_features=1),
        nn.Sigmoid(),
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return the sigmoid score produced by the network for ``x``."""
    scores = self.disc(x)
    return scores

class Generator(nn.Module):
  """Two-layer fully connected network from latent noise to a flat image.

  Maps each ``z_dim``-long input vector to an ``img_dim``-long vector whose
  entries are squashed by Tanh.
  """

  def __init__(self, z_dim: int, img_dim: int) -> None:
    super().__init__()
    hidden_units = 256
    # Layer order is kept fixed so parameter names (gen.0, gen.2) and the
    # order in which weights are initialised stay stable.
    self.gen = nn.Sequential(
        nn.Linear(in_features=z_dim, out_features=hidden_units),
        nn.LeakyReLU(negative_slope=0.1),
        nn.Linear(in_features=hidden_units, out_features=img_dim),
        nn.Tanh(),
    )

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return the Tanh-activated output of the network for ``x``."""
    generated = self.gen(x)
    return generated

In [None]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

GANs are very sensitive to hyperparameters; the learning rate and the latent dimension below are the values most worth varying.

In [None]:
# Training configuration.
num_epochs = 50
batch_size = 32
img_dim = 1 * 28 * 28  # length of one flattened 1x28x28 image
z_dim = 64  # length of the generator's input noise vector; vary
lr = 0.0003  # learning rate; vary

In [None]:
# Build both networks and move them to the selected device. The discriminator
# is constructed first, so its weights are initialised before the generator's.
disc = Discriminator(img_dim)
disc = disc.to(device)
gen = Generator(z_dim, img_dim)
gen = gen.to(device)

In [None]:
# Fixed latent batch, reused whenever sample images are logged so the
# generated grids are comparable from one logging step to the next.
fixed_noise = torch.randn(batch_size, z_dim).to(device)

# Convert each image to a tensor, then normalise with mean 0.5 and std 0.5,
# i.e. (x - 0.5) / 0.5 on the single channel.
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST is downloaded into ./dataset on first use and served in shuffled
# mini-batches.
dataset = datasets.MNIST(root="dataset", transform=train_transforms, download=True)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [None]:
# One Adam optimiser per network, both using the same learning rate.
opt_disc = optim.Adam(params=disc.parameters(), lr=lr)
opt_gen = optim.Adam(params=gen.parameters(), lr=lr)

In [None]:
# Binary cross-entropy on probabilities, averaged over the batch (the
# default reduction, spelled out here).
criterion = nn.BCELoss(reduction="mean")

In [None]:
# Separate TensorBoard writers for the generated and the real image grids.
writer_fake = SummaryWriter(log_dir="runs/GAN_MNIST/fake")
writer_real = SummaryWriter(log_dir="runs/GAN_MNIST/real")
step = 0  # global_step value attached to each logged pair of image grids

In [None]:
# Adversarial training loop: each mini-batch performs one discriminator update
# followed by one generator update. On the first batch of every epoch the
# current losses are printed and a grid of generated / real images is written
# to TensorBoard.
for epoch in range(num_epochs):
  for batch_idx, (real, _) in enumerate(loader):
    # Flatten every image into an img_dim-long vector for the linear layers.
    real = real.view(-1, img_dim).to(device)
    # Size the noise batch from the data actually received, so a shorter
    # final batch is matched by an equally short batch of fakes.
    cur_batch_size = real.shape[0]

    # Discriminator: max log(D(real)) + log(1 - D(G(z)))
    noise = torch.randn((cur_batch_size, z_dim)).to(device)
    fake = gen(noise)

    disc_real = disc(real).view(-1)
    loss_real = criterion(disc_real, torch.ones_like(disc_real))

    # Detach the fakes so this backward pass stops at the discriminator: no
    # gradient is pushed through the generator, and the generator's graph is
    # left intact for its own update below without needing retain_graph.
    disc_fake = disc(fake.detach()).view(-1)
    loss_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

    loss_D = (loss_real + loss_fake) / 2

    disc.zero_grad()
    loss_D.backward()
    opt_disc.step()

    # Generator: min log(1 - D(G(z))), but better to max log(D(G(z)))
    # Re-score the (non-detached) fakes with the just-updated discriminator.
    output = disc(fake).view(-1)
    loss_G = criterion(output, torch.ones_like(output))

    gen.zero_grad()
    loss_G.backward()
    opt_gen.step()

    # Once per epoch: report losses and log sample images.
    if batch_idx == 0:
      print(
        f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {loss_D:.4f}, loss G: {loss_G:.4f}"
      )

      # No gradients are needed for visualisation.
      with torch.no_grad():
        # fixed_noise is reused so successive grids are directly comparable.
        fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
        data = real.reshape(-1, 1, 28, 28)
        img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
        img_grid_real = torchvision.utils.make_grid(data, normalize=True)

        writer_fake.add_image(
          "Mnist Fake Images", img_grid_fake, global_step=step
        )
        writer_real.add_image(
          "Mnist Real Images", img_grid_real, global_step=step
        )
        step += 1

Epoch [0/50] Batch 0/1875                       Loss D: 0.6723, loss G: 0.6845
Epoch [1/50] Batch 0/1875                       Loss D: 0.1768, loss G: 1.9352
Epoch [2/50] Batch 0/1875                       Loss D: 0.4047, loss G: 1.2208
Epoch [3/50] Batch 0/1875                       Loss D: 0.4969, loss G: 1.0770
Epoch [4/50] Batch 0/1875                       Loss D: 0.8932, loss G: 0.6541
Epoch [5/50] Batch 0/1875                       Loss D: 0.5001, loss G: 1.0841
Epoch [6/50] Batch 0/1875                       Loss D: 0.7088, loss G: 1.0133
Epoch [7/50] Batch 0/1875                       Loss D: 0.8211, loss G: 0.7327
Epoch [8/50] Batch 0/1875                       Loss D: 0.5968, loss G: 0.9356
Epoch [9/50] Batch 0/1875                       Loss D: 0.5716, loss G: 0.9195
Epoch [10/50] Batch 0/1875                       Loss D: 0.6963, loss G: 0.9466
Epoch [11/50] Batch 0/1875                       Loss D: 0.5493, loss G: 1.2704
Epoch [12/50] Batch 0/1875                       L