In [1]:
# Load libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import make_grid, save_image
from tqdm import tqdm
import pdb
import numpy as np
import imageio
from google.colab import drive
import os

In [9]:
# Mount Google Drive and work from the project folder so that the relative
# paths used later (mnist_data/, outputs_WGAN/, model_weights/) resolve there.
# An absolute path keeps this cell re-runnable: a relative chdir would fail
# on a second run because the working directory has already changed.
PROJECT_DIR = '/content/gdrive/MyDrive/Colab Notebooks/masterclass'
drive.mount('/content/gdrive', force_remount=True)
os.chdir(PROJECT_DIR)
print(os.getcwd())

Mounted at /content/gdrive


In [14]:
# Load MNIST training and test data
bs = 100

# ToTensor maps pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them
# to [-1, 1]. mean/std are per-channel sequences, hence the 1-tuples.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# download=True skips the download when the files are already present, so
# the test split no longer depends on the train call having fetched them.
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=True)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/train-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/train-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./mnist_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./mnist_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./mnist_data/MNIST/raw



In [15]:
# Train Wasserstein GAN

class TrainWGAN():
    """Trainer for a Wasserstein GAN with gradient penalty (WGAN-GP).

    On every batch the critic is updated with
        E[C(fake)] - E[C(real)] + gamma * E[(||grad_x C(x_hat)||_2 - 1)^2],
    where x_hat are random interpolations between real and fake samples.
    The generator is updated (loss -E[C(fake)]) once every ``n_critic`` batches.

    Args:
        generator: module mapping ``(n, latent_dim)`` noise to a batch of images.
        critic: module returning critic scores for a batch of images.
        generator_optimizer: optimizer over the generator's parameters.
        critic_optimizer: optimizer over the critic's parameters.
        latent_dim: size of the noise vector fed to the generator.
        n_critic: a generator step is taken when ``batch_index % n_critic == 0``.
        gamma: default gradient-penalty weight.
        device: device that noise and data tensors are moved to.
    """

    def __init__(self, generator, critic, generator_optimizer, critic_optimizer,
                 latent_dim=100, n_critic=5, gamma=10, device='cpu'):

        self.to_pil_image = transforms.ToPILImage()

        self.device = device
        self.generator = generator
        self.critic = critic
        self.G_opt = generator_optimizer
        self.C_opt = critic_optimizer

        self.generator.train()
        self.critic.train()

        self.latent_dim = latent_dim
        self.n_critic = n_critic
        self.gamma = gamma
        self.steps = 0
        # Fixed noise: per-epoch snapshots show the same latent points evolving.
        self.fixed_z = torch.randn(64, self.latent_dim).to(self.device)

    def train(self, data_loader, n_epochs, image_dir='outputs_WGAN', weights_dir='model_weights'):
        """Train generator and critic for ``n_epochs`` epochs.

        After each epoch, a grid of samples from ``self.fixed_z`` is saved as
        ``<image_dir>/gen_img<epoch>.png``. After training, all grids are written
        to ``<image_dir>/generator_images.gif``, and the generator/critic
        state_dicts are saved as ``<weights_dir>/generator`` and
        ``<weights_dir>/critic``. Both networks are left in eval mode.

        Args:
            data_loader: iterable of ``(images, labels)`` batches with a ``len``.
            n_epochs: number of passes over ``data_loader``.
            image_dir: output folder for PNG/GIF snapshots (created if missing).
            weights_dir: output folder for model weights (created if missing).

        Returns:
            ``(generator_loss, critic_loss, gradient_penalty)``: per-epoch lists,
            each entry being the last value observed in that epoch.
        """
        os.makedirs(image_dir, exist_ok=True)
        os.makedirs(weights_dir, exist_ok=True)

        # A previous call leaves both networks in eval mode; without this reset
        # a second call would train with frozen BatchNorm statistics.
        self.generator.train()
        self.critic.train()

        images = []
        generator_loss = []
        critic_loss = []
        gradient_penalty = []
        for epoch in range(1, n_epochs + 1):

            g_loss, c_loss, grad_penalty = self.train_epoch(data_loader)

            print(f'Epoch: {epoch}, g_loss: {g_loss:.3f},', end=' ')
            print(f'c_loss: {c_loss:.3f}, grad_penalty: {grad_penalty:.3f}')

            generator_loss.append(g_loss)
            critic_loss.append(c_loss)
            gradient_penalty.append(grad_penalty)

            generated_img = self._snapshot_fixed_samples()
            images.append(generated_img)
            self.save_generator_image(generated_img, os.path.join(image_dir, f"gen_img{epoch}.png"))

        # Save the per-epoch grids as an animated GIF (nothing to write for 0 epochs).
        imgs = [np.array(self.to_pil_image(img)) for img in images]
        if imgs:
            imageio.mimsave(os.path.join(image_dir, 'generator_images.gif'), imgs)

        torch.save(self.generator.state_dict(), os.path.join(weights_dir, 'generator'))
        torch.save(self.critic.state_dict(), os.path.join(weights_dir, 'critic'))

        self.generator.train(mode=False)
        self.critic.train(mode=False)

        return generator_loss, critic_loss, gradient_penalty

    def _snapshot_fixed_samples(self):
        """Render ``self.fixed_z`` as a detached CPU image grid with values in [0, 1]."""
        was_training = self.generator.training
        self.generator.eval()
        # no_grad: the snapshot is kept for the GIF and must not hold an autograd graph.
        with torch.no_grad():
            samples = self.generator(self.fixed_z)
        self.generator.train(was_training)
        # Assumes training images were normalized with Normalize(0.5, 0.5), i.e.
        # to [-1, 1]; map back to [0, 1] so ToPILImage does not wrap negatives.
        samples = (samples * 0.5 + 0.5).clamp(0, 1)
        return make_grid(samples).cpu()

    def train_epoch(self, data_loader):
        """Run one pass over ``data_loader``.

        The critic is updated on every batch; the generator on batches whose
        index is a multiple of ``n_critic``.

        Returns:
            ``(g_loss, c_loss, grad_penalty)`` from the most recent steps, or NaNs
            if ``data_loader`` yields no batches.
        """
        g_loss = c_loss = grad_penalty = float('nan')
        for bidx, (real_data, _) in tqdm(enumerate(data_loader), total=len(data_loader)):

            current_batch_size = len(real_data)
            real_data = real_data.to(self.device)

            c_loss, grad_penalty = self.critic_train_step(real_data)

            if bidx % self.n_critic == 0:
                g_loss = self.generator_train_step(current_batch_size)

        return g_loss, c_loss, grad_penalty

    def critic_train_step(self, data):
        """Take one critic optimizer step on a batch of real ``data``.

        Returns:
            ``(c_loss, grad_penalty)`` as Python floats.
        """
        batch_size = data.size(0)
        # Fake samples are inputs to the critic update only; generating them
        # without autograd avoids backpropagating into (and leaving stale
        # gradients on) the generator.
        with torch.no_grad():
            generated_data = self.sample(batch_size)
        grad_penalty = self.gradient_penalty(data, generated_data)
        c_loss = self.critic(generated_data).mean() - self.critic(data).mean() + grad_penalty
        self.C_opt.zero_grad()
        c_loss.backward()
        self.C_opt.step()
        return c_loss.item(), grad_penalty.item()

    def generator_train_step(self, batch_size):
        """Take one generator optimizer step on ``batch_size`` fresh samples.

        Returns:
            The generator loss ``-E[C(G(z))]`` as a Python float.
        """
        self.G_opt.zero_grad()
        generated_data = self.sample(batch_size)
        g_loss = -self.critic(generated_data).mean()
        g_loss.backward()
        self.G_opt.step()

        return g_loss.item()

    def gradient_penalty(self, data, generated_data, gamma=None):
        """Compute the WGAN-GP gradient penalty.

        Args:
            data: batch of real samples; dim 0 is the batch dimension.
            generated_data: batch of fake samples with the same shape as ``data``.
            gamma: penalty weight; ``None`` (the default) uses ``self.gamma``.

        Returns:
            Scalar tensor ``gamma * mean((||grad_x C(x_hat)||_2 - 1)^2)``. It is
            differentiable w.r.t. the critic parameters (``create_graph=True``).
        """
        if gamma is None:
            gamma = self.gamma

        batch_size = data.size(0)
        # One mixing coefficient per sample, broadcast over all other dims.
        epsilon = torch.rand(batch_size, *([1] * (data.dim() - 1))).to(self.device)

        interpolation = epsilon * data.detach() + (1 - epsilon) * generated_data.detach()
        interpolation = interpolation.to(self.device).requires_grad_(True)

        interpolation_critic_score = self.critic(interpolation)

        gradients = torch.autograd.grad(outputs=interpolation_critic_score,
                                        inputs=interpolation,
                                        grad_outputs=torch.ones_like(interpolation_critic_score),
                                        create_graph=True)[0]

        gradients = gradients.reshape(batch_size, -1)
        # Small epsilon keeps sqrt differentiable when a gradient is exactly zero.
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
        return gamma * ((gradients_norm - 1) ** 2).mean()

    def sample(self, n_samples):
        """Generate ``n_samples`` fake samples from standard-normal noise."""
        z = torch.randn(n_samples, self.latent_dim).to(self.device)
        return self.generator(z)

    def save_generator_image(self, image, path):
        """Save ``image`` (a tensor or image grid) to ``path`` via ``save_image``."""
        save_image(image, path)

In [16]:
# Define generator
class Generator(nn.Module):
    """Convolutional generator: latent vector -> ``num_out_channels`` x 28 x 28 image.

    A linear layer lifts the noise to a 256 x 7 x 7 feature map. Transposed
    convolutions then resize it 7 -> 14 -> 14 -> 14 -> 28. Each of the first
    three transposed convs is followed by LeakyReLU and then BatchNorm; the
    final layer has no activation, so outputs are unbounded.

    Args:
        latent_dim: size of the input noise vector.
        num_out_channels: number of channels in the generated image.
    """

    def __init__(self, latent_dim, num_out_channels=1):
        super().__init__()
        self.activation = nn.LeakyReLU()
        self.latent_dim = latent_dim

        widths = [256, 128, 64, 32]
        self.hidden_channels = widths
        w0, w1, w2, w3 = widths

        # Submodule names are the state_dict keys of saved weights and the
        # creation order fixes the init RNG stream, so both are kept as-is.
        self.fc = nn.Linear(latent_dim, w0 * 7 * 7)
        self.trans_conv1 = nn.ConvTranspose2d(w0, w1, kernel_size=3, stride=2,
                                              padding=1, output_padding=1)
        self.trans_conv1_bn = nn.BatchNorm2d(w1)
        self.trans_conv2 = nn.ConvTranspose2d(w1, w2, kernel_size=3, stride=1, padding=1)
        self.trans_conv2_bn = nn.BatchNorm2d(w2)
        self.trans_conv3 = nn.ConvTranspose2d(w2, w3, kernel_size=3, stride=1, padding=1)
        self.trans_conv3_bn = nn.BatchNorm2d(w3)
        self.trans_conv4 = nn.ConvTranspose2d(w3, num_out_channels, kernel_size=3, stride=2,
                                              padding=1, output_padding=1, bias=False)

    def forward(self, x):
        """Map a ``(batch, latent_dim)`` noise tensor to a batch of images."""
        feat = self.activation(self.fc(x)).view(-1, self.hidden_channels[0], 7, 7)
        stages = (
            (self.trans_conv1, self.trans_conv1_bn),
            (self.trans_conv2, self.trans_conv2_bn),
            (self.trans_conv3, self.trans_conv3_bn),
        )
        for upconv, norm in stages:
            feat = norm(self.activation(upconv(feat)))
        return self.trans_conv4(feat)

# Define critic
class Discriminator(nn.Module):
    """WGAN critic: ``num_in_channels`` x 28 x 28 image -> one unbounded score per sample.

    Four 3x3 convolutions with LeakyReLU (no normalization layers) are followed
    by a bias-free linear layer. The two stride-2 convs reduce a 28x28 input to
    a 256 x 7 x 7 feature map.

    Args:
        num_in_channels: number of channels in the input images.
    """

    def __init__(self, num_in_channels=1):
        super(Discriminator, self).__init__()
        self.activation = nn.LeakyReLU()

        hidden_channels = [32, 64, 128, 256]

        self.conv0 = nn.Conv2d(num_in_channels, hidden_channels[0], kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(hidden_channels[0], hidden_channels[1], kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(hidden_channels[1], hidden_channels[2], kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(hidden_channels[2], hidden_channels[3], kernel_size=3, stride=2, padding=1)
        # 256 * 7 * 7 = 12544 features for a 28x28 input.
        self.fc = nn.Linear(hidden_channels[3] * 7 * 7, 1, bias=False)

    def forward(self, x):
        """Return critic scores of shape ``(batch, 1)`` for a batch of images."""
        x = self.activation(self.conv0(x))
        x = self.activation(self.conv1(x))
        x = self.activation(self.conv2(x))
        x = self.activation(self.conv3(x))

        # flatten(1) keeps the batch dimension. view(-1, 12544) would silently
        # regroup features across samples for non-28x28 inputs instead of failing.
        x = torch.flatten(x, 1)
        return self.fc(x)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed before building the networks so weight init, noise and data shuffling
# are reproducible under Restart & Run All.
torch.manual_seed(0)

# build network
latent_dim = 100
# `.data` replaces the deprecated `.train_data` attribute of torchvision's MNIST.
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

generator = Generator(latent_dim=latent_dim, num_out_channels=1).to(device)
critic = Discriminator(num_in_channels=1).to(device)

# optimizer
# NOTE(review): the WGAN-GP paper (Gulrajani et al.) uses Adam(lr=1e-4,
# betas=(0.0, 0.9)); here lr=2e-4 with Adam's default betas (0.9, 0.999).
lr = 0.0002
generator_optimizer = optim.Adam(generator.parameters(), lr=lr)
critic_optimizer = optim.Adam(critic.parameters(), lr=lr)

# The trainer writes snapshots and weights into these folders.
os.makedirs('outputs_WGAN', exist_ok=True)
os.makedirs('model_weights', exist_ok=True)

# Trainer
training_params = {
    'generator': generator,
    'critic': critic,
    'generator_optimizer': generator_optimizer,
    'critic_optimizer': critic_optimizer,
    'latent_dim': latent_dim,
    'n_critic': 5,
    'gamma': 10,
    'device': device,
}

trainer = TrainWGAN(**training_params)

# Keep the per-epoch loss history for later inspection / plotting.
g_losses, c_losses, gp_values = trainer.train(
    data_loader=train_loader,
    n_epochs=10,
)

100%|██████████| 600/600 [00:43<00:00, 13.87it/s]


Epoch: 1, g_loss: -19.700, c_loss: -25.115, grad_penalty: 7.797


100%|██████████| 600/600 [00:45<00:00, 13.30it/s]


Epoch: 2, g_loss: -20.013, c_loss: -17.704, grad_penalty: 4.341


100%|██████████| 600/600 [00:43<00:00, 13.80it/s]


Epoch: 3, g_loss: -21.581, c_loss: -13.552, grad_penalty: 2.768


100%|██████████| 600/600 [00:43<00:00, 13.81it/s]


Epoch: 4, g_loss: -19.970, c_loss: -10.569, grad_penalty: 1.725


100%|██████████| 600/600 [00:43<00:00, 13.85it/s]


Epoch: 5, g_loss: -7.644, c_loss: -9.430, grad_penalty: 1.486


100%|██████████| 600/600 [00:43<00:00, 13.82it/s]


Epoch: 6, g_loss: -0.949, c_loss: -8.365, grad_penalty: 1.208


100%|██████████| 600/600 [00:43<00:00, 13.83it/s]


Epoch: 7, g_loss: 3.129, c_loss: -7.943, grad_penalty: 1.246


 11%|█▏        | 68/600 [00:04<00:38, 13.95it/s]