In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from scipy.misc import imsave
from IPython.display import clear_output

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data.sampler import Sampler, BatchSampler
from torch.nn.modules.loss import MSELoss

In [None]:
from model import Generator, Discriminator

In [None]:
# Sizes used by the MNIST pipeline.
input_size = 784
num_classes = 10
batch_size = 64
test_batch_size = 64

# Shared preprocessing: ToTensor followed by the affine map x -> (x - 0.5) * 2
# (i.e. [0, 1] -> [-1, 1], assuming ToTensor yields values in [0, 1]).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: (x - 0.5) * 2),
])

# Only the training split asks for a download; both splits read from the same root.
train_dataset = dsets.MNIST(root='./MNIST/',
                            train=True,
                            transform=mnist_transform,
                            download=True)
test_dataset = dsets.MNIST(root='./MNIST/',
                           train=False,
                           transform=mnist_transform)

# Both loaders keep dataset order and drop the final partial batch, so every
# batch has exactly `batch_size` / `test_batch_size` samples.
# NOTE(review): shuffle=False on the training loader means every epoch replays
# the same batches in the same order -- confirm this is intended.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    drop_last=True,
)

In [None]:
import torch
from torch import autograd
from norms import sobolev_norm, lp_norm

# Exponent p of the primal L^p norm (named "l2" because p = 2 here).
l2_norm = 2

# Conjugate exponent q satisfying 1/p + 1/q = 1; p = 1 maps to q = inf.
if l2_norm == 1:
    dual_norm = np.inf
else:
    dual_norm = 1 / (1 - 1 / l2_norm)

# Forwarded to sobolev_norm as its `c` and `s` arguments
# (`s` is negated when the dual norm is taken).
c_for_sobolev = 5.0
s_for_sobolev = 0

def calculate_gradient_penalty(discriminator, images, gen_images):
    """Compute the Sobolev-norm gradient penalty on real/generated interpolates.

    One weight ``epsilon ~ U(0, 1)`` per sample mixes ``images`` and
    ``gen_images`` into ``x_hat``. The gradient of ``discriminator(x_hat)``
    with respect to ``x_hat`` is passed through ``sobolev_norm`` (with
    ``s = -s_for_sobolev``) and ``lp_norm`` (with ``p = dual_norm``).

    The returned penalty is::

        lambda_ * mean((grad_norm / gamma_ - 1) ** 2)
            + 1e-5 * mean(discriminator(images) ** 2)

    where ``lambda_`` and ``gamma_`` are the batch means of the primal and
    dual norms of ``images``.

    All new tensors are created on ``images.device``, so the function follows
    its inputs instead of assuming the default CUDA device.

    Args:
        discriminator: Callable mapping an image batch to scores.
        images: Batch of real images, indexed as a 4-D tensor
            (presumably (batch, channels, height, width) -- confirm with caller).
        gen_images: Batch of generated images with the same shape as ``images``.
            Neither input is differentiated through; the interpolate is
            detached before use.

    Returns:
        Tuple ``(grad_penalty, gamma_)``: the scalar penalty (differentiable
        with respect to the discriminator parameters) and the mean dual norm
        of ``images``.
    """
    device = images.device

    # Shape (batch, 1, 1, 1) broadcasts one mixing weight over each sample.
    epsilon = torch.rand(images.size(0), 1, 1, 1, dtype=torch.float32, device=device)

    # Fresh leaf tensor: gradients are taken w.r.t. x_hat only.
    x_hat = (epsilon * images + (1 - epsilon) * gen_images).detach().requires_grad_(True)

    prob_x_hat = discriminator(x_hat)
    # create_graph=True keeps the gradient itself differentiable so the
    # penalty can be backpropagated into the discriminator parameters.
    gradients = autograd.grad(outputs=prob_x_hat, inputs=x_hat,
                              grad_outputs=torch.ones_like(prob_x_hat),
                              create_graph=True, retain_graph=True)[0]

    dual_sobolev_gradients = sobolev_norm(gradients, s=-s_for_sobolev, c=c_for_sobolev)
    gradients_norm = lp_norm(dual_sobolev_gradients, p=dual_norm)

    lambda_ = lp_norm(sobolev_norm(images, s=s_for_sobolev, c=c_for_sobolev),
                      p=l2_norm).mean()
    gamma_ = lp_norm(sobolev_norm(images, s=-s_for_sobolev, c=c_for_sobolev),
                     p=dual_norm).mean()

    prob_images = discriminator(images)

    def _on_device(tensor):
        # The norm helpers are opaque from here, so keep the explicit float32
        # cast and device move that the penalty arithmetic relied on.
        return tensor.float().to(device)

    norm_ratio = _on_device(gradients_norm) / _on_device(gamma_)
    grad_penalty = ((norm_ratio - 1) ** 2).mean() * _on_device(lambda_) + \
                   1e-5 * (_on_device(prob_images) ** 2).mean()
    return grad_penalty, gamma_

In [None]:
def test(generator, discriminator, test_batch_generator):
    """Evaluate the logged losses on one batch from ``test_batch_generator``.

    Puts ``generator`` in eval mode (and leaves it there), draws one batch,
    generates as many fake images as the batch contains, and computes the same
    quantities that the training loop logs.

    Args:
        generator: Model mapping a ``(batch, noise_size)`` noise tensor to images.
        discriminator: Model mapping an image batch to scores.
        test_batch_generator: Iterator yielding ``(batch_index, images)`` pairs.

    Returns:
        Tuple of Python floats ``(g_loss, d_loss, gradient_penalty)``.
    """
    generator.eval()
    _, images = next(test_batch_generator)
    images = images.cuda()

    # Size the noise from the batch actually drawn, not the training batch
    # size, so the interpolation in the gradient penalty always lines up.
    z = torch.randn((images.size(0), noise_size)).cuda()

    # These forward passes only feed logged scalars; no graph is needed.
    with torch.no_grad():
        gen_images = generator(z)
        fake_loss = discriminator(gen_images).mean()
        real_loss = discriminator(images).mean()

    # The penalty differentiates internally, so it runs outside no_grad.
    gradient_penalty, gamma = calculate_gradient_penalty(discriminator, images, gen_images)

    wasserstein_loss = (fake_loss - real_loss) / gamma

    # NOTE(review): this is the mean of the generated pixels, not a
    # discriminator score; kept as-is to stay comparable with the training
    # log -- confirm whether -fake_loss / gamma was intended.
    g_loss = gen_images.mean() / gamma
    d_loss = -wasserstein_loss + gradient_penalty

    return g_loss.item(), d_loss.item(), gradient_penalty.item()

  
def plot_history(train_history, val_history, title='loss'):
    """Plot training curves with validation markers in a two-panel figure.

    Left panel: generator and discriminator training losses with validation
    points. Right panel: training gradient penalty with its own validation
    points.

    Args:
        train_history: Dict with keys ``'g_loss'``, ``'d_loss'`` and
            ``'grad_pen'``, each a sequence of per-step values.
        val_history: Dict with the same keys, each a sequence of
            ``(train_step, value)`` pairs. Empty sequences are allowed.
        title: Title of the left panel.
    """
    def _val_points(key):
        # reshape(-1, 2) keeps an empty history indexable as a (0, 2) array.
        return np.array(val_history[key], dtype=float).reshape(-1, 2)

    points_g = _val_points('g_loss')
    points_d = _val_points('d_loss')
    points_gp = _val_points('grad_pen')

    plt.figure(figsize=(15, 5))
    plt.subplot(1, 2, 1)
    plt.title('{}'.format(title))
    plt.plot(train_history['g_loss'], label='train_generator_loss', zorder=1)
    plt.plot(train_history['d_loss'], label='train_discrimintor_loss', zorder=1)
    plt.scatter(points_g[:, 0], points_g[:, 1], marker='+', s=180, c='orange', label='val_generator', zorder=2)
    plt.scatter(points_d[:, 0], points_d[:, 1], marker='+', s=180, c='red', label='val_discriminator', zorder=2)
    plt.xlabel('train steps')
    plt.legend(loc='best')
    plt.grid()

    plt.subplot(1, 2, 2)
    plt.title('Gradient Penalty')
    plt.plot(train_history['grad_pen'], label='gradient_penalty', zorder=1)
    # Scatter the gradient-penalty validation points (not the generator ones).
    plt.scatter(points_gp[:, 0], points_gp[:, 1], marker='+', s=180, c='orange', label='val_gradient_penalty', zorder=2)
    plt.xlabel('train steps')
    plt.legend(loc='best')
    plt.grid()

    plt.show()

In [None]:
from tqdm import trange
from torch.autograd import Variable

# Number of outer training iterations; each one runs `num_disc_iters`
# discriminator updates followed by a single generator update.
max_iters = 100000
# Validation and plotting run on every `val_freq`-th outer iteration.
val_freq = 10
# Discriminator updates per outer iteration.
num_disc_iters = 5
# Length of the noise vector z fed to the generator.
noise_size = 128
# Passed to the Generator / Discriminator constructors; img_size matches the
# 28x28 reshape used when plotting generated samples.
channels = 1
img_size = 28

def generate_batches(train_loader):
    """Yield ``(batch_index, float_batch)`` pairs from ``train_loader`` forever.

    Each pass over the loader is paired with a fresh ``trange`` progress bar.
    The second element of every loader item is discarded, and once a pass
    ends the loader is iterated again from the start.
    """
    while True:
        progress = trange(len(train_loader))
        for step, batch in zip(progress, train_loader):
            inputs, _unused = batch
            yield step, inputs.float()
            
# Models
generator = Generator(noise_size, channels, img_size).cuda()
discriminator = Discriminator(noise_size, channels).cuda()

lr = 1e-4
beta1 = 0.
beta2 = 0.9

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

# Endless batch generators over the train / test loaders
data = generate_batches(train_loader)
test_data = generate_batches(test_loader)

train_log = {'g_loss': [], 'd_loss': [], 'grad_pen': []}
val_log = {'g_loss': [], 'd_loss': [], 'grad_pen': []}
for global_iter in range(max_iters):
    # ---- Discriminator phase: num_disc_iters updates with the generator frozen ----
    generator.eval()
    for d_iter in range(num_disc_iters):
        batch_num, images = next(data)
        images = images.cuda()
        z = torch.randn((images.size(0), noise_size)).cuda()

        discriminator.zero_grad()
        # The discriminator update must not backpropagate into the generator,
        # so the fake batch is produced without building a graph.
        with torch.no_grad():
            gen_images = generator(z)

        fake_loss = discriminator(gen_images).mean()
        fake_loss.backward()

        real_loss = discriminator(images).mean()
        # Negate the scalar instead of passing a shape-[1] gradient to
        # backward(): the loss is 0-dim, and a mismatched gradient shape is
        # rejected by current PyTorch.
        (-real_loss).backward()

        gradient_penalty, gamma = calculate_gradient_penalty(discriminator, images, gen_images)
        gradient_penalty.backward()

        # Logged quantities only; the optimised objective is the sum of the
        # three backward() calls above.
        # NOTE(review): the logged Wasserstein term is divided by gamma while
        # the optimised one is not -- confirm which is intended.
        wasserstein_loss = (fake_loss - real_loss) / gamma

        # NOTE(review): this logs the mean generated pixel value, not a
        # discriminator score -- confirm whether -fake_loss / gamma was intended.
        g_loss = gen_images.mean() / gamma
        d_loss = -wasserstein_loss + gradient_penalty

        optimizer_D.step()

        train_log['g_loss'].append(g_loss.item())
        train_log['d_loss'].append(d_loss.item())
        train_log['grad_pen'].append(gradient_penalty.item())

    # ---- Generator phase: one update maximising the discriminator score ----
    generator.train()
    generator.zero_grad()

    z = torch.randn(batch_size, noise_size).cuda()
    gen_images = generator(z)
    g_loss = discriminator(gen_images).mean()
    (-g_loss).backward()

    optimizer_G.step()

    # ---- Periodic validation and visualisation ----
    if global_iter % val_freq == val_freq - 1:
        g_loss_log, d_loss_log, gradient_penalty_log = test(generator, discriminator, test_data)

        # Index of the latest logged training step, so the validation markers
        # line up with the 'train steps' axis of the training curves.
        val_step = len(train_log['d_loss']) - 1
        val_log['g_loss'].append((val_step, g_loss_log))
        val_log['d_loss'].append((val_step, d_loss_log))
        val_log['grad_pen'].append((val_step, gradient_penalty_log))

        clear_output()
        plot_history(train_log, val_log)

        # Show the first five samples of the latest generator batch.
        numpy_images = gen_images.cpu().detach().numpy()
        plt.figure(figsize=(15, 3))

        for i in range(5):
            plt.subplot(1, 5, i + 1)
            plt.imshow(numpy_images[i].reshape((img_size, img_size)), cmap="gray")
        plt.show()

In [None]:
def get_big_image(images_64, grid_size=8):
    """Tile the first ``grid_size ** 2`` images into one square grid.

    Images are laid out row-major: image ``r * grid_size + c`` lands in grid
    row ``r``, column ``c``. Tiling is done with ``np.block``, so images are
    joined along their last two axes. For images of shape ``(1, H, W)`` the
    result has shape ``(1, grid_size * H, grid_size * W)``.

    Args:
        images_64: Sliceable batch of equally shaped arrays, with at least
            ``grid_size ** 2`` entries (64 for the default grid).
        grid_size: Number of rows and columns in the grid. Defaults to 8.

    Returns:
        The tiled array produced by ``np.block``.

    Raises:
        ValueError: If fewer than ``grid_size ** 2`` images are supplied.
    """
    needed = grid_size * grid_size
    if len(images_64) < needed:
        raise ValueError(
            "get_big_image needs at least {} images, got {}".format(needed, len(images_64))
        )

    # Read from the argument, never from a same-named notebook global.
    rows = [list(images_64[r * grid_size:(r + 1) * grid_size])
            for r in range(grid_size)]
    return np.block(rows)

In [None]:
# Save and show a grid built from generated samples.
# `gen_images` is the notebook global left by the most recently executed cell
# (in the training loop it is the batch from the last generator update).
numpy_gen_images = gen_images.cpu().detach().numpy()
# squeeze() drops size-1 axes so the grid can be displayed as a 2-D image.
grid_images = get_big_image(numpy_gen_images[:64]).squeeze()
# NOTE(review): scipy.misc.imsave is deprecated/removed in newer SciPy
# releases -- confirm the pinned SciPy version still provides it.
imsave("fake.png", grid_images)
plt.imshow(grid_images, cmap="gray")

In [None]:
# Save and show a grid built from real samples.
# `images` is the notebook global left by the most recently executed cell
# (in the training loop it is the last real batch fed to the discriminator).
numpy_images = images.cpu().detach().numpy()
# squeeze() drops size-1 axes so the grid can be displayed as a 2-D image.
grid_images = get_big_image(numpy_images[:64]).squeeze()
# NOTE(review): scipy.misc.imsave is deprecated/removed in newer SciPy
# releases -- confirm the pinned SciPy version still provides it.
imsave("real.png", grid_images)
plt.imshow(grid_images, cmap="gray")

In [None]:
# Persist only the parameters (state_dict) of both networks to the working
# directory; the commented-out cell below shows how to load them back.
torch.save(generator.state_dict(), "generator.weights")
torch.save(discriminator.state_dict(), "discriminator.weights")

In [None]:
# To restore weights
# generator = Generator(noise_size, channels, img_size).cuda()
# generator.load_state_dict(torch.load("generator.weights"))
# generator.eval()
# discriminator = Discriminator(noise_size, channels).cuda()
# discriminator.load_state_dict(torch.load("discriminator.weights"))