<a href="https://colab.research.google.com/github/guswns3396/ICME-2023/blob/main/GANs_Exercise_3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Link for the dataset used: https://drive.google.com/file/d/1ByPqKC5f9F8ZiJHR5uPMLuCoELdWhwKz/view?usp=sharing

In [None]:
import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

import torchvision
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

import numpy as np
import matplotlib.pyplot as plt

import zipfile
from pathlib import Path
import os
import requests

In [None]:
def get_data(file_name):
  """Extract a local zip archive into the ``Data`` directory.

  The archive is looked up relative to the current working directory and is
  always extracted into ``./Data``, whether or not that directory already
  exists. Nothing is downloaded here; the archive must already be present.

  Args:
    file_name: Name (or relative path) of the zip archive to extract.

  Raises:
    FileNotFoundError: If the archive does not exist. The check runs before
      ``Data`` is created, so a failed call leaves no empty directory behind.
  """
  data_path = Path('.')
  image_path = data_path/'Data'
  archive_path = data_path / file_name

  # Fail early: creating Data first would leave an empty directory behind,
  # and the next call would then report "Data Directory Exists".
  if not archive_path.exists():
    raise FileNotFoundError(
        f"Archive not found: {archive_path}. Copy it into the working directory first.")

  if image_path.is_dir():
    print("Data Directory Exists. Skipping Download.")
  else:
    print("Data Directory being created. Downloading.")
    image_path.mkdir(parents=True, exist_ok=True)

  with zipfile.ZipFile(archive_path, "r") as zip_ref:
    print(archive_path)
    print("Unzipping")
    zip_ref.extractall(image_path)

In [None]:
def get_device():
  """Return the CUDA device when one is available, otherwise the CPU device."""
  device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
  return torch.device(device_name)

In [None]:
# Output directory for the generated image grids written by save_images();
# created up front so saving never fails on a missing folder.
gen_dir = 'GenerationHistory'
os.makedirs(gen_dir, exist_ok=True)

In [None]:
# Unzip the dataset archive into ./Data; the zip file must already be in the
# working directory (get_data does not download anything).
get_data("AnimeFacesDatasetKaggle.zip")

Data Directory being created. Downloading.


FileNotFoundError: ignored

In [None]:
# Mount Google Drive at /content/drive (requires the google.colab package).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Device used for the models and all training tensors below.
device = get_device()
print(device)

In [None]:
# Side length (pixels) used by the Resize / CenterCrop transforms.
IMAGE_SIZE = 64
# Number of samples per DataLoader batch (also the generator step's batch size).
BATCH_SIZE = 64
# (mean, std) triples handed to transforms.Normalize, one value per channel.
SCALER = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
# os.cpu_count() may return None when the count cannot be determined; fall
# back to 0 so DataLoader always receives a valid integer.
NUM_WORKERS = os.cpu_count() or 0
# Number of channels in the noise vector fed to the generator.
LATENT_DIM = 128

In [None]:
# Preprocessing applied to every image: resize, centre-crop to
# IMAGE_SIZE x IMAGE_SIZE, convert to a tensor, then normalise with SCALER.
transform = transforms.Compose([
    transforms.Resize(size=IMAGE_SIZE),
    transforms.CenterCrop(size=IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=SCALER[0], std=SCALER[1]),
])

In [None]:
# Root folder for ImageFolder; this is the same ./Data directory that
# get_data() extracts the archive into.
path = "./Data"
dataset = datasets.ImageFolder(path, transform=transform)

In [None]:
# Shuffled mini-batch loader over the image dataset.
dl = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

In [None]:
def unscale(scaled_images):
    """Undo the Normalize step of the input transform.

    The first-channel mean and std from ``SCALER`` are applied to every
    channel.
    """
    mean, std = SCALER[0][0], SCALER[1][0]
    return scaled_images * std + mean

In [None]:
def show_images(images, n_max=64):
    """Display up to ``n_max`` images as a grid with 8 images per row.

    Args:
        images: Batch of normalised image tensors (as produced by the
            DataLoader or the generator). May live on any device.
        n_max: Maximum number of images taken from the front of the batch.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.set_xticks([])
    ax.set_yticks([])
    # Detach from autograd and copy to host memory so matplotlib can read the
    # data even when the caller passes a tensor that is still on the GPU.
    normal_images = unscale(images.detach().cpu()[:n_max])
    # make_grid yields a channels-first image; imshow expects channels last.
    ax.imshow(make_grid(normal_images, nrow=8).permute(1, 2, 0))

def show_batch(dl, n_max=64):
  """Plot up to ``n_max`` images from the first batch yielded by ``dl``."""
  batch_iter = iter(dl)
  images, _labels = next(batch_iter)
  show_images(images, n_max)

In [None]:
# Visual check: plot one batch of real training images.
show_batch(dl)

In [None]:
def _build_discriminator():
    """Assemble the convolutional discriminator as a flat ``nn.Sequential``.

    Maps a (3, 64, 64) image to a single sigmoid score per sample.
    """
    # Four stride-2 stages, each halving the spatial size:
    # (3,64,64) -> (64,32,32) -> (128,16,16) -> (256,8,8) -> (512,4,4)
    channels = (3, 64, 128, 256, 512)
    layers = []
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False))
        layers.append(nn.BatchNorm2d(c_out))
        layers.append(nn.ReLU())
    # Final 4x4 convolution collapses (512,4,4) -> (1,1,1); flatten + sigmoid
    # turn that into one score per image.
    layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False))
    layers.append(nn.Flatten())
    layers.append(nn.Sigmoid())
    return nn.Sequential(*layers)


discriminator = _build_discriminator()

In [None]:
def _build_generator():
    """Assemble the transposed-convolution generator as a flat ``nn.Sequential``.

    Maps a (LATENT_DIM, 1, 1) noise vector to a (3, 64, 64) image in [-1, 1].
    """
    # Project the noise vector to a 4x4 feature map: (LATENT_DIM,1,1) -> (512,4,4)
    layers = [
        nn.ConvTranspose2d(LATENT_DIM, 512, kernel_size=4, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(512),
        nn.ReLU(),
    ]
    # Three stride-2 stages, each doubling the spatial size:
    # (512,4,4) -> (256,8,8) -> (128,16,16) -> (64,32,32)
    channels = (512, 256, 128, 64)
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False))
        layers.append(nn.BatchNorm2d(c_out))
        layers.append(nn.ReLU())
    # Output stage: (64,32,32) -> (3,64,64), squashed to [-1, 1] by tanh.
    layers.append(nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False))
    layers.append(nn.Tanh())
    return nn.Sequential(*layers)


generator = _build_generator()

In [None]:
# Move both networks to the selected device before any forward pass.
discriminator = discriminator.to(device)
generator = generator.to(device)

In [None]:
# Sanity check of the untrained networks: push 64 random latent vectors
# through the generator, score the result with the discriminator, and print
# both output shapes.
latent_vector_batch = torch.randn((64,LATENT_DIM,1,1)).to(device)
fake_images = generator(latent_vector_batch)
print(fake_images.shape)
fake_preds = discriminator(fake_images)
print(fake_preds.shape)
# Move to CPU before plotting.
show_images(fake_images.cpu())

In [None]:
def train_discriminator(real_images, opt_d):
  """Run one discriminator optimisation step on a real and a fake batch.

  A fake batch of the same size as ``real_images`` is sampled from the
  generator. The discriminator is scored on the fake batch and on the real
  batch in two separate forward passes, and the two binary cross-entropy
  losses are summed.

  Args:
    real_images: Batch of real images; expected to be on ``device`` already
      (``train`` moves each batch before calling).
    opt_d: Optimiser that updates the discriminator's parameters.

  Returns:
    The summed real + fake loss for this step as a Python float.
  """
  batch_size = real_images.shape[0]
  # Labels: 1 for real images, 0 for generated ones.
  real_targets = torch.ones(batch_size, 1, device=device)
  fake_targets = torch.zeros(batch_size, 1, device=device)

  latent_vectors = torch.randn(batch_size, LATENT_DIM, 1, 1).to(device)
  # Detach the fakes: only the discriminator is updated in this step, so
  # back-propagating into the generator would compute gradients that are
  # never used.
  fake_images = generator(latent_vectors).detach()

  opt_d.zero_grad()
  # Real and fake batches are deliberately passed separately (fake first);
  # see the exercise at the end of the notebook.
  fake_preds = discriminator(fake_images)
  real_preds = discriminator(real_images)

  fake_loss = nn.functional.binary_cross_entropy(fake_preds, fake_targets)
  real_loss = nn.functional.binary_cross_entropy(real_preds, real_targets)

  loss = real_loss + fake_loss
  loss.backward()
  opt_d.step()

  return loss.item()

In [None]:
def train_generator(opt_g):
  """Run one generator optimisation step on a fresh batch of noise.

  Args:
    opt_g: Optimiser that updates the generator's parameters.

  Returns:
    The generator's binary cross-entropy loss for this step as a Python float.
  """
  noise = torch.randn(BATCH_SIZE, LATENT_DIM, 1, 1).to(device)
  generated = generator(noise)
  # Targets are all ones: the loss is low when the discriminator scores the
  # generated images as real.
  wanted_labels = torch.ones(BATCH_SIZE, 1).to(device)

  opt_g.zero_grad()
  scores = discriminator(generated)
  g_loss = nn.functional.binary_cross_entropy(scores, wanted_labels)
  g_loss.backward()
  opt_g.step()

  return g_loss.detach().cpu().item()

In [None]:
def save_images(idx, latent_vectors):
  """Generate images from ``latent_vectors`` and save them as one PNG grid.

  The grid is written to ``gen_dir`` as ``FakeImages_{idx}.png`` with 8
  images per row. The generator's train/eval mode is left unchanged.

  Args:
    idx: Label inserted into the file name (the notebook passes the epoch number).
    latent_vectors: Batch of latent vectors fed to the generator.
  """
  # Inference only: skip building an autograd graph for images that are
  # written to disk and never back-propagated through.
  with torch.no_grad():
    fake_images = generator(latent_vectors)
  fake_name = f"FakeImages_{idx}.png"
  save_image(unscale(fake_images), os.path.join(gen_dir, fake_name), nrow=8)

In [None]:
# One fixed batch of 64 latent vectors, reused by every save_images() call so
# the saved grids show the same latent points after each epoch.
fixed_latent = torch.randn(64, LATENT_DIM, 1, 1, device=device)

In [None]:
# Snapshot with index 0: generator output before any training.
save_images(0, fixed_latent)

In [None]:
def train(num_epochs, lr):
  """Alternate discriminator and generator updates for ``num_epochs`` epochs.

  For every batch from ``dl`` the discriminator is updated first, then the
  generator. After each epoch the mean losses are printed and a grid
  generated from ``fixed_latent`` is saved.

  Args:
    num_epochs: Number of passes over the DataLoader.
    lr: Learning rate used for both Adam optimisers.

  Returns:
    A ``(gen_losses, disc_losses)`` pair of lists holding one mean loss per
    epoch.
  """
  torch.cuda.empty_cache()

  adam_betas = (0.5, 0.999)
  opt_d = Adam(discriminator.parameters(), lr=lr, betas=adam_betas)
  opt_g = Adam(generator.parameters(), lr=lr, betas=adam_betas)

  gen_losses, disc_losses = [], []

  for epoch_no in range(1, num_epochs + 1):
    batch_gen, batch_disc = [], []
    for real_images, _ in dl:
      real_images = real_images.to(device)
      # Discriminator step first, then generator step.
      d_loss = train_discriminator(real_images, opt_d)
      g_loss = train_generator(opt_g)
      batch_gen.append(g_loss)
      batch_disc.append(d_loss)

    # Per-epoch averages over all batches.
    gen_mean = torch.tensor(batch_gen).mean().item()
    disc_mean = torch.tensor(batch_disc).mean().item()

    print(f"Epoch: {epoch_no} \t GenLoss: {gen_mean} \t DiscLoss: {disc_mean}")

    gen_losses.append(gen_mean)
    disc_losses.append(disc_mean)

    save_images(epoch_no, fixed_latent)

  return gen_losses, disc_losses

In [None]:
# Training hyperparameters passed to train(): the learning rate shared by
# both optimisers and the number of epochs.
lr = 0.0002
num_epochs = 10

In [None]:
# Run training; keeps the per-epoch mean losses of both networks.
gen_losses, disc_losses = train(num_epochs, lr)

In [None]:
# Scratch cell for the exercise below: build one real batch and one fake
# batch of the same size, with their targets (1 = real, 0 = fake).
real_images, _ = next(iter(dl))
real_images = real_images.to(device)
batch_size = real_images.shape[0]
real_targets = torch.ones(batch_size, 1).to(device)

latent_vectors = torch.randn(batch_size, LATENT_DIM, 1, 1).to(device)
fake_images = generator(latent_vectors)
fake_targets = torch.zeros(batch_size, 1).to(device)

# Concatenate real and fake along the batch dimension, real samples first.
all_images = torch.cat([real_images, fake_images], dim=0)
all_targets = torch.cat([real_targets, fake_targets],  dim=0)
# Last expression of the cell: the two resulting shapes.
all_images.shape, all_targets.shape

## Exercise

We have heard so much about how sensitive GANs are while training. Let's see a simpler example of this in real life.

In the discriminator training step, I'm passing all my fake examples and fake targets and getting the loss over this set... then I'm doing the same for the real images and real targets, and getting the loss.

Any self-respecting Pythonista would be scandalized by this profligacy. There's a Python (and programming in general) rule, DRY... Don't Repeat Yourself.

In keeping with this, I should concatenate the real and fake images and the real and fake targets and pass this big batch through the discriminator...thus, simplifying the operation to just one single pass.

Please try this. Look at what happens in terms of the generator output and the losses...and explain them.