## Vanilla GAN Applied on MNIST

In [None]:
# Import stuff
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, utils
from torch.utils.tensorboard import SummaryWriter

# Seed NumPy and PyTorch with the same value so runs are repeatable
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# TensorBoard event files for this run are written under log_dir
log_dir = './logs/vanilla_gan_03'
writer = SummaryWriter(log_dir=log_dir)

# Prefer CUDA when PyTorch reports it as available, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Preprocessing pipeline: convert each image to a tensor, then normalize it
# with mean 0.5 and std 0.5 (imshow undoes this with `image * 0.5 + 0.5`)
compose = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split, downloaded into DATA_PATH when it is not there yet
DATA_PATH = './data/MNIST'
data = datasets.MNIST(root=DATA_PATH, train=True, transform=compose, download=True)

# Shuffled mini-batch loader used by the training loop
batch_size = 100
data_loader = torch.utils.data.DataLoader(
    dataset=data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)
num_batches = len(data_loader)

In [None]:
# Utility function to display an image
def imshow(image):
    """Display a normalized image tensor with matplotlib.

    The tensor is rescaled with ``image * 0.5 + 0.5`` (the inverse of a
    mean-0.5 / std-0.5 normalization), converted to a NumPy array and shown
    with the ``gray`` colormap.

    Args:
        image: Tensor that supports ``.numpy()`` and is laid out the way
            ``plt.imshow`` expects.
    """
    restored = image * 0.5 + 0.5
    plt.imshow(restored.numpy(), cmap='gray')
    plt.show()
    
# Utility function to display random images from the dataset
def sampleshow(n_samples=10, samples_per_row=10):
    """Show one random batch of dataset images as a grid and print the labels.

    Args:
        n_samples: Number of images to draw from the module-level ``data``.
        samples_per_row: Number of images per grid row (passed to
            ``make_grid`` as ``nrow``).
    """
    loader = torch.utils.data.DataLoader(data, batch_size=n_samples, shuffle=True, num_workers=1)
    # Use the built-in next(); the Python-2-style ``.next()`` method is not
    # available on DataLoader iterators in current PyTorch releases.
    images, labels = next(iter(loader))
    # make_grid yields (channels, height, width); permute to (height, width,
    # channels) for plotting.
    images_grid = utils.make_grid(images, nrow=samples_per_row).permute(1, 2, 0)
    imshow(images_grid)
    print(labels)
    
# Preview a random sample of the dataset using the default arguments
sampleshow()

In [None]:
# Define the discriminator's network
# Here, we will use a vanilla neural net
class DiscriminatorNet(nn.Module):
    """Fully connected discriminator for flattened 28x28 images.

    Three hidden stages (Linear -> LeakyReLU(0.2) -> Dropout(0.3)) take the
    784 input features through 1024, 512 and 256 units. The output stage is
    a Linear layer followed by a Sigmoid, giving one value per sample.
    """

    def __init__(self):
        super().__init__()
        n_features = 28 * 28  # one flattened MNIST image
        n_out = 1  # single score: probability that the input is real

        def hidden_stage(in_dim, out_dim):
            # Every hidden stage shares the same activation and dropout rate.
            return nn.Sequential(
                nn.Linear(in_dim, out_dim),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            )

        # Attribute names and creation order are kept stable so saved
        # state dicts keep the same keys.
        self.hidden0 = hidden_stage(n_features, 1024)
        self.hidden1 = hidden_stage(1024, 512)
        self.hidden2 = hidden_stage(512, 256)
        self.out = nn.Sequential(nn.Linear(256, n_out), nn.Sigmoid())

    def forward(self, x):
        """Pass ``x`` through the hidden stages and the output stage in order."""
        for stage in (self.hidden0, self.hidden1, self.hidden2, self.out):
            x = stage(x)
        return x

In [None]:
# Define the generator's network
# We will also use a vanilla neural network for this one
class GeneratorNet(nn.Module):
    """Fully connected generator mapping 100 input features to 784 outputs.

    Three hidden stages (Linear -> LeakyReLU(0.2)) widen the input through
    256, 512 and 1024 units. The output stage is a Linear layer followed by
    Tanh, producing 28 * 28 values per sample.
    """

    @staticmethod
    def _hidden_stage(in_dim, out_dim):
        """Build one Linear + LeakyReLU(0.2) stage."""
        return nn.Sequential(nn.Linear(in_dim, out_dim), nn.LeakyReLU(0.2))

    def __init__(self):
        super().__init__()
        n_features = 100  # size of the input vector
        n_out = 28 * 28  # one flattened 28x28 image

        # Attribute names and creation order are kept stable so saved
        # state dicts keep the same keys.
        self.hidden0 = self._hidden_stage(n_features, 256)
        self.hidden1 = self._hidden_stage(256, 512)
        self.hidden2 = self._hidden_stage(512, 1024)
        self.out = nn.Sequential(nn.Linear(1024, n_out), nn.Tanh())

    def forward(self, x):
        """Pass ``x`` through the hidden stages and the output stage in order."""
        stages = (self.hidden0, self.hidden1, self.hidden2, self.out)
        for stage in stages:
            x = stage(x)
        return x

In [None]:
# Instantiate both networks (discriminator first) and move them to `device`
discriminator = DiscriminatorNet()
discriminator.to(device)
generator = GeneratorNet()
generator.to(device)

# One Adam optimizer per network, both using the same learning rate
learning_rate = 0.0002
d_optimizer = optim.Adam(params=discriminator.parameters(), lr=learning_rate)
g_optimizer = optim.Adam(params=generator.parameters(), lr=learning_rate)

# Binary cross-entropy, shared by the discriminator and generator updates
loss = nn.BCELoss()

In [None]:
# Various utility functions
def images_to_vectors(images):
    """Flatten a batch of images into rows of 784 values.

    Args:
        images: Tensor whose first dimension is the batch and which holds
            784 elements per sample.

    Returns:
        A view of ``images`` with shape ``(batch, 784)``.
    """
    batch = images.shape[0]
    return images.view(batch, 784)

def vectors_to_images(vectors):
    """Reshape rows of 784 values into single-channel 28x28 images.

    Args:
        vectors: Tensor whose first dimension is the batch and which holds
            784 elements per sample.

    Returns:
        A view of ``vectors`` with shape ``(batch, 1, 28, 28)``.
    """
    batch = vectors.shape[0]
    return vectors.view(batch, 1, 28, 28)

def generate_noise(size):
    """Draw ``size`` standard-normal vectors of length 100 on ``device``.

    Args:
        size: Number of vectors (rows) to sample.

    Returns:
        Tensor of shape ``(size, 100)`` created on the module-level ``device``.
    """
    shape = (size, 100)
    return torch.randn(shape, device=device)

def real_target(size):
    """Build the label column used for "real" samples.

    Args:
        size: Number of labels (rows).

    Returns:
        Tensor of ones with shape ``(size, 1)`` on the module-level ``device``.
    """
    shape = (size, 1)
    return torch.ones(shape, device=device)

def fake_target(size):
    """Build the label column used for "fake" samples.

    Args:
        size: Number of labels (rows).

    Returns:
        Tensor of zeros with shape ``(size, 1)`` on the module-level ``device``.
    """
    shape = (size, 1)
    return torch.zeros(shape, device=device)

In [None]:
# Function to train the discriminator once
def train_discriminator(optimizer, real_data, fake_data):
    """Run one discriminator update on a real batch and a fake batch.

    Gradients from the real-batch loss and the fake-batch loss are
    accumulated and applied in a single optimizer step.

    Args:
        optimizer: Optimizer whose ``zero_grad``/``step`` bracket the update.
        real_data: Batch of real samples, scored against a target of ones.
        fake_data: Batch of generated samples, scored against a target of
            zeros.

    Returns:
        A tuple ``(error, prediction_real, prediction_fake, accuracy)``:
        the summed loss, the discriminator outputs for both batches, and the
        fraction of samples classified correctly at a 0.5 threshold. All
        returned tensors are detached from the autograd graph, so holding on
        to them for logging does not keep the graph alive.
    """
    # Reset gradients
    optimizer.zero_grad()

    # Real batch: forward pass, loss against ones, backward pass.
    # retain_graph is not needed here because the fake pass below builds its
    # own graph and nothing backpropagates through this one a second time.
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, real_target(real_data.size(0)))
    error_real.backward()

    # Fake batch: forward pass, loss against zeros, backward pass
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, fake_target(fake_data.size(0)))
    error_fake.backward()

    # Apply the accumulated gradients
    optimizer.step()

    # Accuracy is computed from the same forward passes used for the update
    n_correct = (prediction_real >= 0.5).sum() + (prediction_fake < 0.5).sum()
    accuracy = n_correct.double() / (len(prediction_real) + len(prediction_fake))

    # Detach everything handed back to the caller
    total_error = error_real.detach() + error_fake.detach()
    return total_error, prediction_real.detach(), prediction_fake.detach(), accuracy

# Function to train the generator once
def train_generator(optimizer, fake_data):
    """Run one generator update and return its loss.

    The discriminator scores ``fake_data`` and the loss compares those scores
    with the "real" target, so the loss falls as the fakes are rated as real.

    Args:
        optimizer: Optimizer whose ``zero_grad``/``step`` bracket the update.
        fake_data: Batch of generated samples to be scored.

    Returns:
        The loss tensor that was backpropagated.
    """
    optimizer.zero_grad()

    # Score the fakes and measure how far the scores are from "real"
    verdict = discriminator(fake_data)
    wanted = real_target(verdict.size(0))
    g_loss = loss(verdict, wanted)

    # Backpropagate and update
    g_loss.backward()
    optimizer.step()

    return g_loss

In [None]:
# Training parameters
n_epochs = 1000

# Autosave settings
d_save_path = 'models/vgan_discriminator_02.pth'
g_save_path = 'models/vgan_generator_02.pth'
load_parameters_before_training = True

# Resume from earlier checkpoints when they exist. map_location remaps the
# stored tensors onto the current device, so a checkpoint written on a GPU
# machine can still be restored on a CPU-only one.
if load_parameters_before_training:
    if os.path.exists(d_save_path):
        discriminator.load_state_dict(torch.load(d_save_path, map_location=device))
        print("Discriminator loaded successfully")
    if os.path.exists(g_save_path):
        generator.load_state_dict(torch.load(g_save_path, map_location=device))
        print("Generator loaded successfully")

# Make sure the checkpoint directories exist before training starts:
# torch.save does not create missing parent directories, and failing here is
# cheaper than failing at the first autosave.
for save_path in (d_save_path, g_save_path):
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

# Training loop
for epoch in range(n_epochs):
    for batch_id, (real_batch,_) in enumerate(data_loader):
        # Flatten the real images and move them to the training device
        real_data = images_to_vectors(real_batch).to(device)
        
        # Generate fake data; detached so the discriminator update does not
        # backpropagate into the generator
        noise = generate_noise(real_batch.size(0))
        fake_data = generator(noise).detach()
        
        # Train discriminator first
        d_error, prediction_real, prediction_fake, acc = train_discriminator(d_optimizer, real_data, fake_data)
        
        # Generate a new batch of fake data, this time attached to the
        # generator's graph so its loss can update the generator
        noise = generate_noise(real_batch.size(0))
        fake_data = generator(noise)
        
        # Train generator
        g_error = train_generator(g_optimizer, fake_data)
        
    # Log data (values come from the last batch of the epoch)
    print("Epoch", epoch, "\tAccuracy", acc.item())
    writer.add_scalar("discriminator_accuracy", acc, epoch)
    writer.add_scalar("discriminator_error", d_error, epoch)
    writer.add_scalar("generator_error", g_error, epoch)
        
    # Every 10th epoch, see how the generator is doing
    if epoch % 10 == 9:
        with torch.no_grad():
            noise = generate_noise(10)
            fake_data = generator(noise)
            fake_images = vectors_to_images(fake_data).cpu()
            grid = utils.make_grid(fake_images, nrow=10).permute(1, 2, 0)
            imshow(grid)
            
        # Save models' parameters
        torch.save(discriminator.state_dict(), d_save_path)
        torch.save(generator.state_dict(), g_save_path)