In [1]:
import os, sys
sys.path.append(os.getcwd())

import time

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import sklearn.datasets

import tflib as lib
import tflib.save_images
import tflib.mnist
import tflib.plot

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
if use_cuda:
    gpu = 0

DIM = 64 # Model dimensionality
BATCH_SIZE = 50 # Batch size
CRITIC_ITERS = 5 # For WGAN and WGAN-GP, number of critic iters per gen iter
LAMBDA = 10 # Gradient penalty lambda hyperparameter
ITERS = 200000 # How many generator iterations to train for
OUTPUT_DIM = 784 # Number of pixels in MNIST (28*28)

lib.print_model_settings(locals().copy())

Uppercase local vars:
	BATCH_SIZE: 50
	CRITIC_ITERS: 5
	DIM: 64
	F: <module 'torch.nn.functional' from '/home/kadarsh22/miniconda3/lib/python3.6/site-packages/torch/nn/functional.py'>
	ITERS: 200000
	LAMBDA: 10
	OUTPUT_DIM: 784


In [3]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        preprocess = nn.Sequential(
            nn.Linear(128, 4*4*4*DIM),nn.BatchNorm1d(4*4*4*DIM),
            nn.ReLU(True),
        )
        block1 = nn.Sequential(
            nn.ConvTranspose2d(4*DIM, 2*DIM, 5),nn.BatchNorm2d(2*DIM),
            nn.ReLU(True),
        )
        block2 = nn.Sequential(
            nn.ConvTranspose2d(2*DIM, DIM, 5),nn.BatchNorm2d(DIM),
            nn.ReLU(True),
        )
        deconv_out = nn.ConvTranspose2d(DIM, 1, 8, stride=2)
        

        self.block1 = block1
        self.block2 = block2
        self.deconv_out = deconv_out
        self.preprocess = preprocess
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
    def forward(self, input):
        output = self.preprocess(input)
        output = output.view(-1, 4*DIM, 4, 4)
        #print output.size()
        output = self.block1(output)
        #print output.size()
        output = output[:, :, :7, :7]
        #print output.size()
        output = self.block2(output)
        #print output.size()
        output = self.deconv_out(output)
        output = self.tanh(output)
        #print output.size()
        return output.view(-1, OUTPUT_DIM)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        main = nn.Sequential(
            nn.Conv2d(1, DIM, 5, stride=2, padding=2),
            # nn.Linear(OUTPUT_DIM, 4*4*4*DIM),
            nn.LeakyReLU(True),
            nn.Conv2d(DIM, 2*DIM, 5, stride=2, padding=2),
            nn.BatchNorm2d(2*DIM),
            # nn.Linear(4*4*4*DIM, 4*4*4*DIM),
            nn.LeakyReLU(True),
            nn.Conv2d(2*DIM, 4*DIM, 5, stride=2, padding=2),
            # nn.Linear(4*4*4*DIM, 4*4*4*DIM),
            nn.BatchNorm2d(4*DIM),
            # nn.Linear(4*4*4*DIM, 4*4*4*DIM),
            nn.LeakyReLU(True),
            # nn.Linear(4*4*4*DIM, 4*4*4*DIM),
            # nn.LeakyReLU(True),
            # nn.Linear(4*4*4*DIM, 4*4*4*DIM),
            # nn.LeakyReLU(True),
        )
        self.main = main
        self.output = nn.Linear(4*4*4*DIM, 1)

    def forward(self, input):
        input = input.view(-1, 1, 28, 28)
        out = self.main(input)
        out = out.view(-1, 4*4*4*DIM)
        out = self.output(out)
        return out.view(-1)

In [4]:
def generate_image(frame, netG):
    noise = torch.randn(BATCH_SIZE, 128)
    if use_cuda:
        noise = noise.cuda(gpu)
    noisev = autograd.Variable(noise, volatile=True)
    samples = netG(noisev)
    samples = samples.view(BATCH_SIZE, 28, 28)
    # print samples.size()

    samples = samples.cpu().data.numpy()

    lib.save_images.save_images(
        samples,
        'results/mnist/samples/samples_{}.png'.format(frame)
    )


In [5]:
# Dataset iterator
train_gen, dev_gen, test_gen = lib.mnist.load(BATCH_SIZE, BATCH_SIZE)
def inf_train_gen():
    while True:
        for images,targets in train_gen():
            yield images

In [6]:
def calc_gradient_penalty(netD, real_data, fake_data):
    #print real_data.size()
    alpha = torch.rand(BATCH_SIZE, 1)
    alpha = alpha.expand(real_data.size())
    alpha = alpha.cuda(gpu) if use_cuda else alpha

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)

    if use_cuda:
        interpolates = interpolates.cuda(gpu)
    interpolates = autograd.Variable(interpolates, requires_grad=True)

    disc_interpolates = netD(interpolates)

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).cuda(gpu) if use_cuda else torch.ones(
                                  disc_interpolates.size()),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty

In [7]:
netG = Generator()
netD = Discriminator()
print(netG)

Generator(
  (block1): Sequential(
    (0): ConvTranspose2d(256, 128, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (block2): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (deconv_out): ConvTranspose2d(64, 1, kernel_size=(8, 8), stride=(2, 2))
  (preprocess): Sequential(
    (0): Linear(in_features=128, out_features=4096, bias=True)
    (1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (sigmoid): Sigmoid()
  (tanh): Tanh()
)


In [8]:
if use_cuda:
    netD = netD.cuda(gpu)
    netG = netG.cuda(gpu)

optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))

one = torch.ones([])
mone = one * -1
if use_cuda:
    one = one.cuda(gpu)
    mone = mone.cuda(gpu)

data = inf_train_gen()

In [9]:
for iteration in range(ITERS):
    start_time = time.time()
    ############################
    # (1) Update D network
    ###########################
    for p in netD.parameters():  # reset requires_grad
        p.requires_grad = True  # they are set to False below in netG update

    for iter_d in range(CRITIC_ITERS):
        _data = data.__next__()
        real_data = torch.Tensor(_data)
        if use_cuda:
            real_data = real_data.cuda(gpu)
        real_data_v = autograd.Variable(real_data)

        netD.zero_grad()

        # train with real
        D_real = netD(real_data_v)
        D_real = D_real.mean()
        # print D_real
        D_real.backward(mone)

        # train with fake
        noise = torch.randn(BATCH_SIZE, 128)
        if use_cuda:
            noise = noise.cuda(gpu)
        noisev = autograd.Variable(noise, volatile=True)  # totally freeze netG
        fake = autograd.Variable(netG(noisev).data)
        inputv = fake
        D_fake = netD(inputv)
        D_fake = D_fake.mean()
        D_fake.backward(one)

        # train with gradient penalty
        gradient_penalty = calc_gradient_penalty(netD, real_data_v.data, fake.data)
        gradient_penalty.backward()

        D_cost = D_fake - D_real + gradient_penalty
        Wasserstein_D = D_real - D_fake
        optimizerD.step()
    ############################
    # (2) Update G network
    ###########################
    for p in netD.parameters():
        p.requires_grad = False  # to avoid computation
    netG.zero_grad()

    noise = torch.randn(BATCH_SIZE, 128)
    if use_cuda:
        noise = noise.cuda(gpu)
    noisev = autograd.Variable(noise)
    fake = netG(noisev)
    G = netD(fake)
    G = G.mean()
    G.backward(mone)
    G_cost = -G
    optimizerG.step()
#   print('Iteration: {}.  D_Loss: {}. G_Loss: {}'.format(iteration, D_cost, G_cost))
    lib.plot.plot('results/mnist/plots/time_iter', time.time() - start_time)
    lib.plot.plot('results/mnist/plots/train disc cost', D_cost.cpu().data.numpy())
    lib.plot.plot('results/mnist/plots/train gen cost', G_cost.cpu().data.numpy())
    lib.plot.plot('results/mnist/plots/wasserstein distance', Wasserstein_D.cpu().data.numpy())

#   Write logs and save samples
 
#   Calculate dev loss and generate samples every 1000 iters
    if iteration % 1000 == 0:
        dev_disc_costs = []
        for images,_ in dev_gen():
            imgs = torch.Tensor(images)
            if use_cuda:
                imgs = imgs.cuda(gpu)
            imgs_v = autograd.Variable(imgs, volatile=True)

            D = netD(imgs_v)
            _dev_disc_cost = -D.mean().cpu().data.numpy()
            dev_disc_costs.append(_dev_disc_cost)
        dev_loss = sum(dev_disc_costs)/len(dev_disc_costs)
        print('Iteration: {}. Validation_Loss: {}'.format(iteration, dev_loss))        
        lib.plot.plot('results/mnist/plots/dev disc cost', np.mean(dev_disc_costs))
        generate_image(iteration, netG)
        lib.plot.flush()
        
    lib.plot.tick()
print("Training finish!... save training results")



Iteration: 0. Validation_Loss: -4.0468001568317415


  """


Iteration: 1000. Validation_Loss: -2014.830036010742
Iteration: 2000. Validation_Loss: -6074.14125
Iteration: 3000. Validation_Loss: -12162.165717773438
Iteration: 4000. Validation_Loss: -20286.821484375
Iteration: 5000. Validation_Loss: -30441.624033203127
Iteration: 6000. Validation_Loss: -42636.28169921875
Iteration: 7000. Validation_Loss: -56870.359453125
Iteration: 8000. Validation_Loss: -73121.011171875
Iteration: 9000. Validation_Loss: -91423.325625
Iteration: 10000. Validation_Loss: -111748.00078125
Iteration: 11000. Validation_Loss: -134101.76609375
Iteration: 12000. Validation_Loss: -158514.27453125
Iteration: 13000. Validation_Loss: -184960.16828125
Iteration: 14000. Validation_Loss: -213414.247734375
Iteration: 15000. Validation_Loss: -243929.59046875
Iteration: 16000. Validation_Loss: -276469.38203125
Iteration: 17000. Validation_Loss: -311078.89828125
Iteration: 18000. Validation_Loss: -347613.7075
Iteration: 19000. Validation_Loss: -386383.50484375
Iteration: 20000. Vali

Iteration: 168000. Validation_Loss: -28465399.3
Iteration: 169000. Validation_Loss: -28733850.13
Iteration: 170000. Validation_Loss: -28998013.92
Iteration: 171000. Validation_Loss: -29251608.22
Iteration: 172000. Validation_Loss: -29454380.53
Iteration: 173000. Validation_Loss: -29500669.83
Iteration: 174000. Validation_Loss: -29780641.24
Iteration: 175000. Validation_Loss: -30043379.84
Iteration: 176000. Validation_Loss: -30157988.7
Iteration: 177000. Validation_Loss: -30514754.83
Iteration: 178000. Validation_Loss: -30742734.32
Iteration: 179000. Validation_Loss: -30959887.44
Iteration: 180000. Validation_Loss: -31169848.72
Iteration: 181000. Validation_Loss: -31476902.72
Iteration: 182000. Validation_Loss: -31703561.36
Iteration: 183000. Validation_Loss: -31961219.96
Iteration: 184000. Validation_Loss: -32264136.82
Iteration: 185000. Validation_Loss: -32610422.07
Iteration: 186000. Validation_Loss: -32955719.31
Iteration: 187000. Validation_Loss: -33308402.32
Iteration: 188000. Val

In [10]:
torch.save(netG.state_dict(), "results/mnist/model/generator_param.pkl")
torch.save(netD.state_dict(), "results/mnist/model/discriminator_param.pkl")
