<a href="https://colab.research.google.com/github/boppana-tejkiran/Genarative-Adversarial-Networks-Practice/blob/main/GAN_Model_basic_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip uninstall -q tensorboard tb-nightly
!pip install tb-nightly

Proceed (y/n)? y
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tb-nightly
  Using cached tb_nightly-2.12.0a20230101-py3-none-any.whl (5.7 MB)
Installing collected packages: tb-nightly
Successfully installed tb-nightly-2.12.0a20230101


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
%load_ext tensorboard
from torch.utils.tensorboard import SummaryWriter 

In [4]:
import tempfile
log_dir = tempfile.mkdtemp() 
%tensorboard --logdir {log_dir} --reload_interval 1

<IPython.core.display.Javascript object>

The Generator and Discriminator

In [5]:
class Discriminator(nn.Module):
  def __init__(self, img_dim):
    super().__init__()
    self.disc = nn.Sequential(
        nn.Linear(img_dim,128),
        nn.LeakyReLU(0.1),
        nn.Linear(128,1),
        nn.Sigmoid(),
    )

  def forward(self,x):
    return self.disc(x)

class Generator(nn.Module):
  def __init__(self,z_dim, img_dim):
    super().__init__()
    self.gen = nn.Sequential(
        nn.Linear(z_dim,256),
        nn.LeakyReLU(0.1),
        nn.Linear(256, img_dim),
        nn.Tanh(),
    )

  def forward(self,z):
    return self.gen(z)

GANs are sensitive to hyperparameters

In [6]:
#Hyperparameters
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64 # 128, 256
image_dim = 28 * 28 * 1 # 784
batch_size = 32
num_epochs = 15

In [7]:
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5),(0.5))] # transforms.Normalize takes mean, std
)

Load the dataset

In [8]:
dataset = datasets.MNIST(root="dataset/", transform=transforms, download = True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Initialise optimizers for generator & descriminator
and loss function

In [9]:
opt_disc = optim.Adam(disc.parameters(), lr = lr)
opt_gen = optim.Adam(gen.parameters(), lr = lr)
criterian = nn.BCELoss()
# WRITER_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
# WRITER_real = SummaryWriter(f"runs/GAN_MNIST/real")
WRITER_fake = SummaryWriter(log_dir + '/fake', flush_secs=1)
WRITER_real = SummaryWriter(log_dir + '/real', flush_secs=1)
step = 0

Train loop

In [None]:
for epoch in range(num_epochs):
  for batch_idx, (real,_) in enumerate(loader):
    real = real.view(-1,784).to(device)
    batch_size = real.shape[0]

    ### Train Discriminator: max log(D(real)) + log(1-D(G(z)))
    noise = torch.randn(batch_size, z_dim).to(device)
    fake = gen(noise)
    disc_real = disc(real).view(-1)
    lossD_real = criterian(disc_real, torch.ones_like(disc_real))
    disc_fake = disc(fake.detach()).view(-1)
    lossD_fake = criterian(disc_fake, torch.zeros_like(disc_fake))
    lossD = (lossD_real + lossD_fake)/2
    lossD.backward() # or lossD.backward(retain_graph = True)
    opt_disc.step()

    ### train Generator min log(1-D(G(z))) ==> max log(D(G(z)))
    output = disc(fake).view(-1)
    lossG = criterian(output, torch.ones_like(output))
    gen.zero_grad()
    lossG.backward()
    opt_gen.step()

    ### End of training setup ###
    ### Code for TensorBoard ###
    if batch_idx == 0:
      print(
          f"Epoch [{epoch}/{num_epochs}] "\
          f"Loss D: {lossD: .4f}, Loss G: {lossG: .4f}"
      )

      with torch.no_grad():
        fake = gen(fixed_noise).reshape(-1,1,28,28)
        data = real.reshape(-1,1,28,28)
        img_grid_fake = torchvision.utils.make_grid(fake, normalize = True)
        img_grid_real = torchvision.utils.make_grid(data, normailze = True)

        WRITER_fake.add_image(
            "Mnist Fake Images", img_grid_fake, global_step = step
        )
        WRITER_real.add_image(
            "Mnist real images", img_grid_real, global_step = step
        )
        step += 1

Epoch [0/15] Loss D:  0.7244, Loss G:  0.6658
Epoch [1/15] Loss D:  100.0000, Loss G:  0.0000
Epoch [2/15] Loss D:  100.0000, Loss G:  0.0000
Epoch [3/15] Loss D:  100.0000, Loss G:  0.0000
Epoch [4/15] Loss D:  100.0000, Loss G:  0.0000
Epoch [5/15] Loss D:  100.0000, Loss G:  0.0000
Epoch [6/15] Loss D:  100.0000, Loss G:  0.0000
Epoch [7/15] Loss D:  100.0000, Loss G:  0.0000
Epoch [8/15] Loss D:  100.0000, Loss G:  0.0000
Epoch [9/15] Loss D:  100.0000, Loss G:  0.0000
