In [43]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from utils.visdom_utils import VisFunc

In [44]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [45]:
# Seed torch's global RNG so weight init, noise draws and DataLoader shuffling
# are reproducible under "Restart & Run All".
seed = 0
torch.manual_seed(seed)

# Mini-batch size for the DataLoader and for each batch of generator noise.
mb_size = 100

X_dim = 28 * 28   # flattened MNIST image size
h_dim = 128       # hidden width of both the generator and the critic
z_dim = 10        # dimension of the generator's input noise vector

cnt = 0           # index of the next sample-grid image written to out/
lr = 1e-4         # learning rate shared by both RMSprop optimizers

# MNIST Data

In [46]:
# Convert each MNIST image to a tensor; the training loop flattens it later.
transform = transforms.ToTensor()
mnist = datasets.MNIST(root='./dataset2',
                       train=True,
                       transform=transform,
                       download=True)
# shuffle=True: draw a fresh batch order every epoch instead of replaying the
#   dataset in the same fixed order (which pairs each generator step with the
#   same real batches every epoch).
# drop_last=True: every batch holds exactly mb_size images (no short final batch).
mnist_loader = DataLoader(mnist, batch_size=mb_size, shuffle=True, drop_last=True)

In [47]:
len(mnist)  # number of images in the MNIST training split

60000

In [48]:
len(mnist_loader)  # mini-batches per epoch; the training loop's `it` counter runs up to this

600

# Define Models: Generator and Critic

In [49]:
# Generator: MLP mapping a z_dim noise vector to a flattened X_dim image.
# The final Sigmoid squashes every output pixel into [0, 1].
G = nn.Sequential(
    nn.Linear(z_dim, h_dim),
    nn.ReLU(),
    nn.Linear(h_dim, X_dim),
    nn.Sigmoid(),
)

# Critic: MLP mapping a flattened X_dim image to a single scalar score.
# There is no output activation, so the score is unbounded (not a probability).
D = nn.Sequential(
    nn.Linear(X_dim, h_dim),
    nn.ReLU(),
    nn.Linear(h_dim, 1),
)

In [50]:
def reset_grad():
    """Clear the accumulated gradients of the generator ``G`` and the critic ``D``.

    The training loop calls this after each optimizer step so the next
    backward pass starts from clean gradients.
    """
    for net in (G, D):
        net.zero_grad()

In [51]:
# One RMSprop optimizer per network. Each optimizer owns only its own
# network's parameters, so D_solver.step() never moves G and vice versa.
D_solver = optim.RMSprop(params=D.parameters(), lr=lr)
G_solver = optim.RMSprop(params=G.parameters(), lr=lr)

# Plotting client from utils.visdom_utils (presumably a Visdom wrapper).
# NOTE(review): `vf` is not used by any active code visible here — confirm it is still needed.
env_name = 'WGAN'
vf = VisFunc(enval=env_name)

In [None]:
# WGAN training (weight-clipping variant): n_critic critic updates, each on one
# real mini-batch, followed by one generator update.
n_epochs = 200
n_critic = 5       # critic updates per generator update
clip_value = 0.01  # critic parameters are clamped to [-clip_value, clip_value]
log_every = 50     # log and save samples when the batch counter is a multiple of this


def _save_sample_grid(samples, path, nrow=10):
    """Save up to ``nrow * nrow`` flattened 28x28 samples to ``path`` as an image grid.

    Uses the explicit Figure API and always closes the figure, even when saving
    fails, so repeated calls inside the training loop do not leak figures.
    """
    fig = plt.figure(figsize=(10, 10))
    try:
        gs = gridspec.GridSpec(nrow, nrow)
        gs.update(wspace=0.05, hspace=0.05)
        for i, sample in enumerate(samples[:nrow * nrow]):
            ax = fig.add_subplot(gs[i])
            ax.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            ax.imshow(sample.reshape(28, 28), cmap='Greys_r')
        fig.savefig(path, bbox_inches='tight')
    finally:
        plt.close(fig)


os.makedirs('out/', exist_ok=True)

for epoch in range(n_epochs):
    n_batches = len(mnist_loader)
    mnist_iter = iter(mnist_loader)
    it = 0
    # Start a critic cycle only if n_critic batches remain; otherwise next()
    # would raise StopIteration whenever n_batches is not a multiple of n_critic.
    while it + n_critic <= n_batches:

        ################
        # Train Critic #
        ################
        for _ in range(n_critic):
            z = torch.randn(mb_size, z_dim)
            X, _labels = next(mnist_iter)
            # Flatten with the actual batch size so a short final batch still works.
            X = X.view(X.size(0), -1)

            # detach(): the critic loss must not backpropagate into the generator.
            G_sample = G(z).detach()
            D_real = D(X)
            D_fake = D(G_sample)

            # Minimizing this maximizes mean(D(real)) - mean(D(fake)).
            D_loss = -(torch.mean(D_real) - torch.mean(D_fake))

            D_loss.backward()
            D_solver.step()

            # WGAN weight clipping: clamp every critic parameter in place.
            with torch.no_grad():
                for p in D.parameters():
                    p.clamp_(-clip_value, clip_value)

            reset_grad()
            it += 1

        ###################
        # Train Generator #
        ###################
        z = torch.randn(mb_size, z_dim)
        G_sample = G(z)
        D_fake = D(G_sample)
        G_loss = -torch.mean(D_fake)

        G_loss.backward()
        G_solver.step()
        # G_loss.backward() also populated the critic's gradients; clear both.
        reset_grad()

        # Print progress and save a grid of samples from the updated generator.
        if it % log_every == 0:
            print('Epoch {} | [{}/{}] | D_loss: {:.2f} | G_loss: {:.2f}'
                  .format(epoch, it, n_batches, D_loss.item(), G_loss.item()))

            with torch.no_grad():
                samples = G(z).numpy()
            _save_sample_grid(samples, 'out/{}.png'.format(str(cnt).zfill(3)))
            cnt += 1

Epoch 0 | [50/600] | D_loss: -0.96 | G_loss: 0.79
Epoch 0 | [100/600] | D_loss: -0.94 | G_loss: 0.84
Epoch 0 | [150/600] | D_loss: -0.88 | G_loss: 0.86
Epoch 0 | [200/600] | D_loss: -0.84 | G_loss: 0.89
Epoch 0 | [250/600] | D_loss: -0.74 | G_loss: 0.88
Epoch 0 | [300/600] | D_loss: -0.68 | G_loss: 0.89
Epoch 0 | [350/600] | D_loss: -0.65 | G_loss: 0.90
Epoch 0 | [400/600] | D_loss: -0.62 | G_loss: 0.90
Epoch 0 | [450/600] | D_loss: -0.61 | G_loss: 0.88
Epoch 0 | [500/600] | D_loss: -0.58 | G_loss: 0.87
Epoch 0 | [550/600] | D_loss: -0.52 | G_loss: 0.87
Epoch 0 | [600/600] | D_loss: -0.53 | G_loss: 0.87
Epoch 1 | [50/600] | D_loss: -0.48 | G_loss: 0.84
Epoch 1 | [100/600] | D_loss: -0.51 | G_loss: 0.82
Epoch 1 | [150/600] | D_loss: -0.46 | G_loss: 0.80
Epoch 1 | [200/600] | D_loss: -0.49 | G_loss: 0.80
Epoch 1 | [250/600] | D_loss: -0.42 | G_loss: 0.80
Epoch 1 | [300/600] | D_loss: -0.35 | G_loss: 0.79
