In [12]:
import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
import tensorflow as tf
import torch.nn.functional as F
import torchvision 
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
import copy
import numpy as np
import matplotlib.pyplot as plt

In [15]:
train_dataset_mnist = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset_mnist = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

train_dataset_mnist, val_dataset_mnist = random_split(train_dataset_mnist, [55000, 5000])

batch_size = 64

train_loader_mnist = DataLoader(train_dataset_mnist,
                        batch_size=batch_size,
                        shuffle=True)

val_loader_mnist = DataLoader(val_dataset_mnist,
                        batch_size=batch_size,
                        shuffle=True)

test_loader_mnist = DataLoader(test_dataset_mnist,
                        batch_size=batch_size,
                        shuffle=False)

X_temp, y_temp = next(iter(train_loader_mnist))

print(X_temp[2], y_temp.shape)

mb_size = 64
Z_dim = 100
X_dim = 28*28
h_dim = 128
c = 0
lr = 1e-3


def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / np.sqrt(in_dim / 2.)
    return Variable(torch.randn(*size) * xavier_stddev, requires_grad=True)


""" ==================== GENERATOR ======================== """

Wzh = xavier_init(size=[Z_dim, h_dim])
bzh = Variable(torch.zeros(h_dim), requires_grad=True)

Whx = xavier_init(size=[h_dim, X_dim])
bhx = Variable(torch.zeros(X_dim), requires_grad=True)


def G(z):
    h = F.relu(z @ Wzh + bzh.repeat(z.size(0), 1))
    X = torch.sigmoid(h @ Whx + bhx.repeat(h.size(0), 1))
    return X


""" ==================== DISCRIMINATOR ======================== """

Wxh = xavier_init(size=[X_dim, h_dim])
bxh = Variable(torch.zeros(h_dim), requires_grad=True)

Why = xavier_init(size=[h_dim, 1])
bhy = Variable(torch.zeros(1), requires_grad=True)


def D(X):
    h = F.relu(X @ Wxh + bxh.repeat(X.size(0), 1))
    y = torch.sigmoid(h @ Why + bhy.repeat(h.size(0), 1))
    return y


G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
params = G_params + D_params


""" ===================== TRAINING ======================== """


def reset_grad():
    for p in params:
        if p.grad is not None:
            data = p.grad.data
            p.grad = Variable(data.new().resize_as_(data).zero_())


G_solver = optim.Adam(G_params, lr=1e-3)
D_solver = optim.Adam(D_params, lr=1e-3)

ones_label = Variable(torch.ones(mb_size, 1))
zeros_label = Variable(torch.zeros(mb_size, 1))


for it in range(100000): 
    # Sample data
    z = Variable(torch.randn(mb_size, Z_dim))
    X, _ = next(iter(train_loader_mnist))
    X = X.view(-1, 784)
    X = Variable(X)

    # Dicriminator forward-loss-backward-update
    G_sample = G(z)
    D_real = D(X)
    D_fake = D(G_sample)

    #D_loss_real = F.binary_cross_entropy(D_real, ones_label)
    #D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)
    #D_loss = D_loss_real + D_loss_fake

    D_loss = -torch.mean(torch.log(D_real) + torch.log(1. - D_fake))
    
    D_loss.backward()
    D_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Generator forward-loss-backward-update
    z = Variable(torch.randn(mb_size, Z_dim))
    G_sample = G(z)
    D_fake = D(G_sample)

    #G_loss = F.binary_cross_entropy(D_fake, ones_label)
    G_loss = -torch.mean(torch.log(D_fake))

    G_loss.backward()
    G_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        print('Iter-{}; D_loss: {}; G_loss: {}'.format(it, D_loss.data.numpy(), G_loss.data.numpy()))

        samples = G(z).data.numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        if not os.path.exists('original100k/'):
            os.makedirs('original100k/')

        plt.savefig('original100k/{}.png'.format(str(c).zfill(3)), bbox_inches='tight')
        c += 1
        plt.close(fig)

tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,

Iter-12000; D_loss: 0.711113452911377; G_loss: 2.401885509490967
Iter-13000; D_loss: 0.7005552649497986; G_loss: 2.4120309352874756
Iter-14000; D_loss: 0.6816340088844299; G_loss: 2.4768097400665283
Iter-15000; D_loss: 0.8199543356895447; G_loss: 2.2225887775421143
Iter-16000; D_loss: 0.6136703491210938; G_loss: 2.3213813304901123
Iter-17000; D_loss: 0.5786684155464172; G_loss: 2.3334426879882812
Iter-18000; D_loss: 0.5918607711791992; G_loss: 1.89707350730896
Iter-19000; D_loss: 0.6433102488517761; G_loss: 2.105242967605591
Iter-20000; D_loss: 0.699147641658783; G_loss: 2.294600486755371
Iter-21000; D_loss: 0.5932392477989197; G_loss: 2.135253667831421
Iter-22000; D_loss: 0.6063427925109863; G_loss: 1.9805829524993896
Iter-23000; D_loss: 0.686589241027832; G_loss: 2.307128667831421
Iter-24000; D_loss: 0.6429980993270874; G_loss: 1.9668774604797363
Iter-25000; D_loss: 0.7193344831466675; G_loss: 2.2326083183288574
Iter-26000; D_loss: 0.7233313918113708; G_loss: 2.3596489429473877
Iter-