## Wasserstein GANs

### Importing the necessary libraries   

In [None]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Critic, Generator, initialize_weights


### Setting the hyperparameters (following the recommendations in the WGAN paper)

In [None]:
# --- Data ---------------------------------------------------------------
BATCH_SIZE = 64        # samples per DataLoader batch
IMAGE_SIZE = 64        # target size used by the Resize / CenterCrop transforms
IMG_CHANNELS = 1       # channel count passed to both Generator and Critic

# --- Model --------------------------------------------------------------
Z_DIM = 128            # channel dimension of the (N, Z_DIM, 1, 1) latent noise
GEN_FEAT = 64          # feature-width argument forwarded to Generator
CRIT_FEAT = 64         # feature-width argument forwarded to Critic

# --- Optimisation -------------------------------------------------------
NUM_EPOCHS = 5         # full passes over the dataloader
LEARNING_RATE = 5e-5   # shared by both RMSprop optimizers
NUM_CRITIC_ITERS = 5   # critic updates performed per generator update
WEIGHT_CLIP = 0.01     # critic parameters are clamped to [-WEIGHT_CLIP, WEIGHT_CLIP]

### Seeding the random number generators for reproducibility

In [None]:
# Seed every random number generator this notebook imports so that repeated
# runs start from the same state.
MANUAL_SEED = 42
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)  # NumPy's global RNG is separate from Python's and PyTorch's
torch.manual_seed(MANUAL_SEED)

# Ask PyTorch to use deterministic kernels (ops without a deterministic
# implementation raise an error instead of silently varying between runs) ...
torch.use_deterministic_algorithms(mode=True)
# ... and stop cuDNN from benchmarking convolution algorithms, since the
# algorithm it picks can differ from one run to the next.
torch.backends.cudnn.benchmark = False

### Loading the MNIST dataset and preprocessing

In [None]:
# The name `transforms` is rebound below from the torchvision module to the
# composed pipeline (the dataset cell passes it as `transform=transforms`).
# Importing the module under its own alias here keeps this cell re-runnable:
# without it, a second execution would call `.Compose` on the pipeline object
# instead of on the module and fail with AttributeError.
from torchvision import transforms as tv_transforms

# Preprocessing applied to every image, in order:
#   1. Resize, then centre-crop, to IMAGE_SIZE.
#   2. Convert to a tensor.
#   3. Normalize with mean 0.5 and std 0.5 (one value, i.e. a single channel).
transforms = tv_transforms.Compose(
    [
        tv_transforms.Resize(IMAGE_SIZE),
        tv_transforms.CenterCrop(IMAGE_SIZE),
        tv_transforms.ToTensor(),
        tv_transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    dev = "cuda"
else:
    dev = "cpu"

# MNIST stored under dataset/ (download=True asks torchvision to fetch the
# files when they are missing), with the preprocessing pipeline applied to
# every image.
dataset = datasets.MNIST(
    root="dataset/",
    download=True,
    transform=transforms,
)

# Shuffled mini-batches of BATCH_SIZE samples.
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=BATCH_SIZE)

### Initializing the Critic and the Generator

In [None]:
# Generator: construct, move to the target device, then run
# `initialize_weights` over it (and its sub-modules) via `.apply`.
G = Generator(Z_DIM, IMG_CHANNELS, GEN_FEAT)
G = G.to(dev)
G.apply(initialize_weights)

# Critic: same sequence. The per-network order (create -> move -> init) is
# kept as-is so that random-number consumption is unchanged.
C = Critic(IMG_CHANNELS, CRIT_FEAT)
C = C.to(dev)
C.apply(initialize_weights)

### Optimizers

In [None]:
# A fixed batch of 32 latent vectors, reused at every logging step so the
# generated image grids are comparable over the course of training.
# The noise is sampled first and then moved to `dev`.
fixed_z = torch.randn(32, Z_DIM, 1, 1).to(device=dev)

# One RMSprop optimizer per network, both using the same learning rate.
opt_C = optim.RMSprop(params=C.parameters(), lr=LEARNING_RATE)
opt_G = optim.RMSprop(params=G.parameters(), lr=LEARNING_RATE)

### Config for TensorBoard

In [None]:
# Separate writers (and log directories) for the real and the generated
# image grids. Plain string literals: the paths contain no placeholders, so
# an f-string prefix is unnecessary.
writer_data = SummaryWriter("logs/MNIST")
writer_fake = SummaryWriter("logs/fake")

# Passed as `global_step` to `add_image`; incremented once per logging event
# in the training loop.
step = 0

### Training the Wasserstein GAN

In [None]:
# WGAN training with weight clipping.
#
# For every batch of real images the critic is updated NUM_CRITIC_ITERS times
# and the generator once. Every 100 batches the current losses are printed and
# image grids (real batch + samples from `fixed_z`) are written to TensorBoard.

# The generator step below reuses the `fake` batch produced by the critic
# loop, so at least one critic iteration is required.
if NUM_CRITIC_ITERS < 1:
    raise ValueError("NUM_CRITIC_ITERS must be at least 1")

G.train()
C.train()

for epoch in range(NUM_EPOCHS):
    for batch_id, (data, _) in enumerate(tqdm(dataloader)):
        data = data.to(dev)
        # The last batch of an epoch may be smaller than BATCH_SIZE.
        curr_bs = data.shape[0]

        # Train the Critic for the specified number of iterations.
        # Objective: max E[C(real)] - E[C(fake)]
        for _ in range(NUM_CRITIC_ITERS):
            C.zero_grad()
            z = torch.randn(curr_bs, Z_DIM, 1, 1).to(dev)
            fake = G(z)
            C_real = C(data).reshape(-1)
            # detach(): the critic update must not back-propagate into G.
            C_fake = C(fake.detach()).reshape(-1)
            # Negated because the optimizer minimises.
            loss_C = -(torch.mean(C_real) - torch.mean(C_fake))
            loss_C.backward()
            opt_C.step()

            # Enforce the Lipschitz constraint by weight clipping. Done under
            # no_grad() rather than through `.data`, so the in-place update is
            # excluded from autograd without bypassing its safety checks.
            with torch.no_grad():
                for param in C.parameters():
                    param.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        # Train the Generator.
        # Objective: max E[C(fake)] <--> min - E[C(fake)]
        # `fake` is the (non-detached) batch from the final critic iteration,
        # so gradients flow through the critic back into G. Gradients that
        # accumulate on the critic's parameters here are discarded by
        # C.zero_grad() at the start of the next critic step.
        G.zero_grad()
        G_fake = C(fake).reshape(-1)
        loss_G = -torch.mean(G_fake)
        loss_G.backward()
        opt_G.step()

        # Print losses and log image grids to TensorBoard.
        if batch_id % 100 == 0 and batch_id > 0:
            # Enter evaluation mode for sampling.
            G.eval()
            C.eval()
            print(f"Epoch [{epoch+1} / {NUM_EPOCHS}] Batch [{batch_id}/ {len(dataloader)}] Loss C: {loss_C.item():.4f} Loss G: {loss_G.item():.4f}")

            with torch.no_grad():
                # Use a dedicated name so the training batch `fake` is not
                # overwritten by a no-grad preview tensor.
                preview = G(fixed_z)
                im_grid_real = vutils.make_grid(data[:32], normalize=True)
                im_grid_fake = vutils.make_grid(preview[:32], normalize=True)

                writer_data.add_image("MNIST", im_grid_real, global_step=step)
                writer_fake.add_image("Generated", im_grid_fake, global_step=step)
            step += 1
            # Back to training mode.
            G.train()
            C.train()

# Push any buffered summaries to disk so the final grids are visible in
# TensorBoard as soon as training ends.
writer_data.flush()
writer_fake.flush()

