# Wasserstein GAN - WGAN
- Build a Wasserstein GAN with Gradient Penalty (WGAN-GP)
- It addresses the training-stability issues of standard GANs
- Its special loss function (W-loss) and the gradient penalty help prevent mode collapse

*Fun Fact: Wasserstein is named after a mathematician at Penn State, Leonid Vaseršteĭn. You'll see it abbreviated to W (e.g. WGAN, W-loss, W-distance).*

## Generator and Critic


In [None]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# Seed the global RNG so weight initialization and noise sampling are reproducible.
torch.manual_seed(0) # Set for testing purposes, please do not change!

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: given a batch of images, plots the first
    `num_images` of them in a uniform grid with 5 images per row.

    Parameters:
        image_tensor: a batch of images expected to hold values in [-1, 1]
            (the range produced by the generator's Tanh and by the
            Normalize((0.5,), (0.5,)) transform). Either shaped (N, C, H, W)
            or flattened to (N, C*H*W).
        num_images: how many images from the start of the batch to show
        size: the (C, H, W) shape of a single image; used to un-flatten a
            2-D (flattened) batch before building the grid
    '''
    # Map [-1, 1] back to [0, 1] for display.
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    if image_unflat.dim() == 2:
        # Flattened batch: restore the per-image (C, H, W) layout, otherwise
        # make_grid would treat the whole batch as a single 2-D image.
        image_unflat = image_unflat.view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

In [None]:
def make_grad_hook():
    '''
    Function to keep track of gradients for visualization purposes.

    Returns:
        grads: a list that the hook appends to; each call of the hook on a
            Conv2d / ConvTranspose2d module appends that module's
            `weight.grad` (None if no backward pass has populated it yet)
        grad_hook: a function meant to be passed to `model.apply(grad_hook)`
    '''
    grads = []

    def grad_hook(m):
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            grads.append(m.weight.grad)

    # Return the list together with the hook so the caller can read the
    # collected gradients after `model.apply(grad_hook)`.
    return grads, grad_hook

In [None]:
class Generator(nn.Module):
    '''
    DCGAN-style generator: maps a batch of noise vectors to images.

    With the default kernel sizes and strides the spatial size grows
    1 -> 3 -> 6 -> 13 -> 28, so the output is (n_samples, im_chan, 28, 28).

    Parameters:
        z_dim: the dimension of the noise vector
        im_chan: the number of channels of the output image
        hidden_dim: the base number of feature maps in the hidden layers
    '''
    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(
        self, input_channels, output_channels,
        kernel_size=3, stride=2, final_layer=False
    ):
        '''
        Build one upsampling block of the generator.

        Hidden blocks: ConvTranspose2d -> BatchNorm2d -> ReLU.
        Final block:   ConvTranspose2d -> Tanh.

        The final block has no BatchNorm. Normalizing the generated image
        itself is known to destabilize GAN training (DCGAN guidelines). Tanh
        keeps the pixel values in [-1, 1], the same range the
        Normalize((0.5,), (0.5,)) transform gives the real images.
        '''
        conv = nn.ConvTranspose2d(
            input_channels,
            output_channels,
            kernel_size=(kernel_size, kernel_size),
            stride=stride
        )
        if final_layer:
            return nn.Sequential(conv, nn.Tanh())
        return nn.Sequential(conv, nn.BatchNorm2d(output_channels), nn.ReLU())

    def unsqueeze_noise(self, noise):
        '''
        Reshape a batch of noise vectors into 1x1 "images": given a noise
        tensor, returns a view of it with width and height = 1 and
        channels = z_dim, as expected by the first ConvTranspose2d.

        Parameters:
            noise: a noise tensor with dimensions (n_samples, z_dim)
        '''
        return noise.view(len(noise), self.z_dim, 1, 1)

    def forward(self, noise):
        '''
        Generate images from a batch of noise vectors of shape (n_samples, z_dim).
        '''
        x = self.unsqueeze_noise(noise)
        return self.gen(x)
    
    
def get_noise(n_samples, z_dim, device='cpu'):
    '''
    Create a batch of noise vectors for the generator.

    Parameters:
        n_samples: the number of noise vectors to create
        z_dim: the length of each noise vector
        device: the device on which to allocate the tensor
    Returns:
        a tensor of shape (n_samples, z_dim) filled with samples from the
        standard normal distribution
    '''
    shape = (n_samples, z_dim)
    return torch.randn(*shape, device=device)

In [None]:
class Critic(nn.Module):
    '''
    Convolutional critic: assigns every input image one real-valued score.
    The final layer is a plain Conv2d with no activation, so the scores are
    unbounded, as the W-loss requires.

    With the default 4x4 / stride-2 blocks a 28x28 input shrinks
    28 -> 13 -> 5 -> 1, so forward returns a (batch, 1) tensor.

    Parameters:
        im_chan: the number of channels of the input image
        hidden_dim: the base number of feature maps in the hidden layers

    NOTE(review): the WGAN-GP paper (Gulrajani et al., 2017) advises against
    batch normalization in the critic. The gradient penalty assumes each score
    depends only on its own input, and the paper suggests layer normalization
    instead. BatchNorm2d is kept here to preserve current behavior.
    '''
    def __init__(self, im_chan=1, hidden_dim=16):
        super().__init__()
        blocks = [
            self.make_critic_block(im_chan, hidden_dim),
            self.make_critic_block(hidden_dim, hidden_dim * 2),
            self.make_critic_block(hidden_dim * 2, 1, final_layer=True),
        ]
        self.critic = nn.Sequential(*blocks)

    def make_critic_block(
        self, input_channels, output_channels,
        kernel_size=4, stride=2, final_layer=False
    ):
        '''
        Build one downsampling block of the critic.

        Hidden blocks: Conv2d -> BatchNorm2d -> LeakyReLU(0.2).
        Final block:   Conv2d only (raw score output).
        '''
        layers = [
            nn.Conv2d(
                input_channels,
                output_channels,
                kernel_size=(kernel_size, kernel_size),
                stride=stride
            )
        ]
        if not final_layer:
            layers += [nn.BatchNorm2d(output_channels), nn.LeakyReLU(0.2)]
        return nn.Sequential(*layers)

    def forward(self, image):
        '''
        Score a batch of images; returns a tensor of shape (batch, -1).
        '''
        scores = self.critic(image)
        batch = len(scores)
        return scores.view(batch, -1)

## Training Initializations
Start by setting the parameters:
  *   n_epochs: the number of times you iterate through the entire dataset when training
  *   z_dim: the dimension of the noise vector
  *   display_step: how often to display/visualize the images
  *   batch_size: the number of images per forward/backward pass
  *   lr: the learning rate
  *   beta_1, beta_2: the momentum terms
  *   c_lambda: weight of the gradient penalty
  *   crit_repeats: number of times to update the critic per generator update - there are more details about this in the *Putting It All Together* section
  *   device: the device type

You will also load and transform the MNIST dataset to tensors.

In [None]:
n_epochs = 100
z_dim = 64
display_step = 50
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
c_lambda = 10
crit_repeats = 5
# Fall back to the CPU when no GPU is available, instead of failing later at
# `.to(device)` / tensor creation.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ToTensor maps pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to
# [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST('.', download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True)

In [None]:
# Create the models and their optimizers. The configured momentum terms are
# passed as `betas`; without them Adam would use its default (0.9, 0.999).
gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
crit = Critic().to(device)
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(beta_1, beta_2))

# Initialize the weights to normal distribution
# with mean 0 and standard deviation 0.02

def weights_init(m):
    '''
    In-place initializer intended for `model.apply(weights_init)`.

    - Conv2d / ConvTranspose2d: weight ~ N(0, 0.02); bias left untouched.
    - BatchNorm2d: weight ~ N(0, 0.02), bias set to 0.
    - Any other module is left unchanged.

    NOTE(review): DCGAN reference implementations commonly center BatchNorm
    weights at 1.0 rather than 0.0; confirm that 0.0 is intended here.
    '''
    is_conv = isinstance(m, (nn.Conv2d, nn.ConvTranspose2d))
    is_batchnorm = isinstance(m, nn.BatchNorm2d)
    if is_conv or is_batchnorm:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if is_batchnorm:
        torch.nn.init.constant_(m.bias, 0)
        
# Apply the initializer recursively to every submodule; Module.apply modifies
# the parameters in place and returns the module itself.
gen = gen.apply(weights_init)
crit = crit.apply(weights_init)

## Gradient Penalty
The GP is calculated in 2 steps:  
 * Compute the gradient with respect to the images
 * Compute the GP given the gradient
 
a. To compute the gradient, first create a mixed (interpolated) image.  
b. This is done by weighting the real and fake images with epsilon and (1 - epsilon), respectively, and adding them together.  
c. Once the intermediate image is available, obtain the critic's scores on it.  
d. Compute the gradient of the critic's scores on the mixed images (the output) with respect to the pixels of the mixed images.  


In [None]:
def get_gradient(crit, real, fake, epsilon):
    '''
    Return the gradient of the critic's scores with respect to images
    interpolated between `real` and `fake`.

    Parameters:
        crit: the critic model
        real: a batch of real images
        fake: a batch of fake images, same shape as `real`
        epsilon: interpolation weights, broadcastable against the images
            (e.g. shape (N, 1, 1, 1)). At least one of epsilon/real/fake must
            require grad so that the mixed images are part of the autograd graph.
    Returns:
        gradient: a tensor with the same shape as the mixed images
    '''
    # Interpolate: epsilon weights the real image, (1 - epsilon) the fake one.
    mixed = real * epsilon + fake * (1 - epsilon)
    scores = crit(mixed)

    # create_graph=True keeps the resulting gradient differentiable, so a
    # penalty built from it can be backpropagated into the critic's weights.
    (grad,) = torch.autograd.grad(
        outputs=scores,
        inputs=mixed,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )
    return grad

In [None]:
def test_get_gradient(image_shape):
    '''
    Sanity-check get_gradient on random data, using the global `crit` and `device`.

    epsilon is sampled uniformly from [0, 1) with torch.rand, the same
    distribution the training loop uses. This keeps each mixed image on the
    segment between its real/fake pair (a normal sample could be negative
    or greater than 1, i.e. an extrapolation rather than an interpolation).

    Returns the computed gradient so other tests can reuse it.
    '''
    real = torch.randn(*image_shape, device=device) + 1
    fake = torch.randn(*image_shape, device=device) + 1
    # One interpolation weight per image: shape (N, 1, 1, ...).
    epsilon_shape = [1 for _ in image_shape]
    epsilon_shape[0] = image_shape[0]
    epsilon = torch.rand(epsilon_shape, device=device).requires_grad_()
    gradient = get_gradient(crit, real, fake, epsilon)
    assert tuple(gradient.shape) == image_shape
    assert gradient.max() > 0
    assert gradient.min() < 0
    return gradient

gradient = test_get_gradient((256, 1, 28, 28))
print('Success')

**Gradient Penalty**.  
- Calculate the magnitude of each image's gradient
- The magnitude of a gradient is also called the norm
- Calculate the penalty by squaring the distance between each magnitude and the ideal norm of 1 and taking the mean of all the squared distances


In [None]:
def gradient_penalty(gradient):
    '''
    Return the WGAN-GP gradient penalty: the mean over the batch of the
    squared distance between each image's gradient L2 norm and 1.

    Parameters:
        gradient: the critic's gradient with respect to the mixed images, with
            one image per entry along dim 0 (e.g. the output of get_gradient)
    Returns:
        penalty: a scalar tensor
    '''
    # Flatten so each row holds one image's full gradient; reshape (unlike
    # view) also accepts non-contiguous tensors.
    gradient = gradient.reshape(len(gradient), -1)
    # Calculate the magnitude of every row
    gradient_norm = gradient.norm(2, dim=1)
    # Square each distance BEFORE averaging. Squaring the mean would let norms
    # above and below 1 cancel out (e.g. norms 0 and 2 would give no penalty).
    penalty = ((gradient_norm - 1) ** 2).mean()

    return penalty

In [None]:
def test_gradient_penalty(image_shape):
    '''
    Sanity-check gradient_penalty:
      * an all-zero gradient (norm 0) must give a penalty of 1,
      * a gradient whose per-image L2 norm is exactly 1 must give 0,
      * the untrained global critic's gradient should give a penalty close
        to 1 (presumably because the small-std initialization yields
        near-zero gradient norms).
    '''
    bad_gradient = torch.zeros(*image_shape)
    bad_gradient_penalty = gradient_penalty(bad_gradient)
    assert torch.isclose(bad_gradient_penalty, torch.tensor(1.0))

    # Constant entries 1/sqrt(C*H*W) give each image a gradient norm of exactly 1.
    image_size = torch.prod(torch.Tensor(image_shape[1:]))
    good_gradient = torch.ones(*image_shape) / torch.sqrt(image_size)
    good_gradient_penalty = gradient_penalty(good_gradient)
    assert torch.isclose(good_gradient_penalty, torch.tensor(0.))

    random_gradient = test_get_gradient(image_shape)
    random_gradient_penalty = gradient_penalty(random_gradient)

    assert torch.abs(random_gradient_penalty - 1) < 0.1

# MNIST images (and the get_gradient test above) are 1x28x28.
test_gradient_penalty((256, 1, 28, 28))
print('Success')

### Losses
- Calculate the loss for the generator and critic
- For the generator, the loss is defined so that minimizing it maximizes the critic's prediction on the generator's fake images
- The argument holds the critic's scores for all fake images in the batch; the loss uses their mean

In [None]:
def get_gen_loss(crit_fake_pred):
    '''
    Generator W-loss: the negated mean critic score on the generated images.
    Minimizing it pushes the generator to maximize the critic's scores.

    Parameters:
        crit_fake_pred: the critic's scores on a batch of fake images
    Returns:
        a scalar tensor
    '''
    mean_fake_score = crit_fake_pred.mean()
    return -mean_fake_score

In [None]:
# Sanity checks for get_gen_loss.
# A single score of 1 must give a loss of exactly -1.
assert torch.isclose(
    get_gen_loss(torch.tensor(1.0)), torch.tensor(-1.0)
)

# Uniform [0, 1) scores have mean ~0.5, so the loss should be ~-0.5
# (the third positional argument of torch.isclose is rtol = 5%).
assert torch.isclose(
    get_gen_loss(torch.rand(10000)), torch.tensor(-0.5), 0.05
)
print('Success')

In [None]:
def get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):
    '''
    Critic W-loss with gradient penalty:
        mean(fake scores - real scores) + c_lambda * gp
    Minimizing it widens the gap between real and fake scores while keeping
    the critic's gradient norm close to 1.

    Parameters:
        crit_fake_pred: the critic's scores on a batch of fake images
        crit_real_pred: the critic's scores on a batch of real images
        gp: the gradient penalty (e.g. from gradient_penalty)
        c_lambda: the weight of the gradient penalty
    Returns:
        a scalar tensor
    '''
    score_gap = (crit_fake_pred - crit_real_pred).mean()
    return score_gap + c_lambda * gp

In [None]:
# Sanity checks for get_crit_loss (argument order: fake, real, gp, c_lambda).
# (1 - 2) + 0.1 * 3 = -0.7
assert torch.isclose(
    get_crit_loss(torch.tensor(1.), torch.tensor(2.), torch.tensor(3.), 0.1),
    torch.tensor(-0.7)
)
# (20 - (-20)) + 10 * 2 = 60
assert torch.isclose(
    get_crit_loss(torch.tensor(20.), torch.tensor(-20.), torch.tensor(2.), 10),
    torch.tensor(60.)
)

print("Success!")

In [None]:
import matplotlib.pyplot as plt

cur_step = 0
generator_losses = []  # one generator loss per batch
critic_losses = []     # one (averaged) critic loss per batch

for epoch in range(n_epochs):
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        ## Update Critic (crit_repeats critic steps per generator step)
        mean_iteration_critic_loss = 0
        for _ in range(crit_repeats):
            crit_opt.zero_grad()
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            # The critic update never backpropagates into the generator, so
            # skip recording the generator's autograd graph altogether.
            with torch.no_grad():
                fake = gen(fake_noise)
            crit_fake_pred = crit(fake)
            crit_real_pred = crit(real)

            # One interpolation weight per image, broadcast over (C, H, W).
            epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
            gradient = get_gradient(crit, real, fake, epsilon)
            gp = gradient_penalty(gradient)
            crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

            # Keep track of the average critic loss in this batch
            mean_iteration_critic_loss += crit_loss.item() / crit_repeats

            # Compute gradients. The graph is rebuilt on every step, so
            # retain_graph=True is unnecessary (it would only hold memory).
            crit_loss.backward()

            # Update the critic's weights
            crit_opt.step()
        critic_losses += [mean_iteration_critic_loss]

        ## Update Generator
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        crit_fake_pred = crit(fake_2)

        gen_loss = get_gen_loss(crit_fake_pred)
        # This also writes gradients into the critic's parameters; they are
        # cleared by crit_opt.zero_grad() before the next critic step.
        gen_loss.backward()

        gen_opt.step()

        generator_losses += [gen_loss.item()]

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            crit_mean = sum(critic_losses[-display_step:]) / display_step
            print(f"Step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)
            # Plot the loss curves averaged over bins of step_bins batches.
            step_bins = 20
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss"
            )
            plt.plot(
                range(num_examples // step_bins),
                torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Critic Loss"
            )
            plt.legend()
            plt.show()

        cur_step += 1