In [None]:
import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
import torchvision
from torchvision import datasets
from torchvision.transforms import ToTensor
# from tensorflow.examples.tutorials.mnist import input_data


# ---- Hyper-parameters -------------------------------------------------------
# Declared before the data loaders so the batch size has a single source of
# truth: the training labels and noise batches further down are sized from
# mb_size, so the loaders must use the same value.
mb_size = 60   # mini-batch size
Z_dim = 100    # length of the noise vector fed to the generator
X_dim = 784    # flattened image size (28 * 28, matching the reshape used when plotting)
y_dim = 10     # not referenced by any of the visible cells
h_dim = 128    # width of the single hidden layer in both networks
c = 0          # running index used to name the saved sample images
lr = 1e-3      # learning rate (same value the Adam optimizers below are configured with)

# ---- Data -------------------------------------------------------------------
# MNIST is downloaded into ./data on first use; ToTensor() converts each image
# to a float tensor.
train_data = datasets.MNIST(root='data', train=True, transform=ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=mb_size, shuffle=True, num_workers=4)

# The test split reuses the files fetched above, so no download flag is passed.
test_data = datasets.MNIST(root='data', train=False, transform=ToTensor())
test_loader = torch.utils.data.DataLoader(test_data, batch_size=mb_size, shuffle=True, num_workers=4)


def xavier_init(size):
    """Create a trainable weight tensor drawn from a scaled normal distribution.

    Args:
        size: sequence of dimensions; ``size[0]`` is treated as the input
            dimension that determines the scale.

    Returns:
        A leaf tensor of shape ``size`` with ``requires_grad=True`` whose
        entries are standard-normal samples multiplied by
        ``1 / sqrt(size[0] / 2)``.
    """
    fan_in = size[0]
    scale = 1. / np.sqrt(fan_in / 2.)
    weights = torch.randn(*size) * scale
    # The product is a plain tensor with no autograd history, so switching the
    # flag on in place yields a leaf tensor that optimizers can update.
    return weights.requires_grad_(True)


""" ==================== GENERATOR ======================== """

# Layer 1 (noise -> hidden): weight matrix and zero-initialised bias.
Wzh = xavier_init(size=[Z_dim, h_dim])
bzh = torch.zeros(h_dim, requires_grad=True)

# Layer 2 (hidden -> flattened image): weight matrix and zero-initialised bias.
Whx = xavier_init(size=[h_dim, X_dim])
bhx = torch.zeros(X_dim, requires_grad=True)


def G(z):
    """Generator forward pass: noise -> ReLU hidden layer -> sigmoid output.

    Args:
        z: noise tensor whose last dimension matches the first dimension of
            ``Wzh`` (typically shaped ``(batch, Z_dim)``).

    Returns:
        Tensor with the same leading dimensions as ``z`` and a last dimension
        equal to the second dimension of ``Whx``; every entry lies in (0, 1)
        because of the final sigmoid.
    """
    # The 1-D biases broadcast across the batch dimension, so there is no need
    # to materialise a (batch, dim) copy with .repeat() on every call.
    h = nn.relu(z @ Wzh + bzh)
    X = torch.sigmoid(h @ Whx + bhx)
    return X


""" ==================== DISCRIMINATOR ======================== """

# Layer 1 (flattened image -> hidden): weight matrix and zero-initialised bias.
Wxh = xavier_init(size=[X_dim, h_dim])
bxh = torch.zeros(h_dim, requires_grad=True)

# Layer 2 (hidden -> single output unit): weight matrix and zero-initialised bias.
Why = xavier_init(size=[h_dim, 1])
bhy = torch.zeros(1, requires_grad=True)


def D(X):
    """Discriminator forward pass: input -> ReLU hidden layer -> sigmoid score.

    Args:
        X: input tensor whose last dimension matches the first dimension of
            ``Wxh`` (typically flattened images shaped ``(batch, X_dim)``).

    Returns:
        Tensor with the same leading dimensions as ``X`` and a trailing
        dimension of 1; each entry lies in (0, 1) because of the final
        sigmoid. The training loop pushes it towards 1 for real data and
        towards 0 for generated data.
    """
    # The 1-D biases broadcast across the batch dimension, so there is no need
    # to materialise a (batch, dim) copy with .repeat() on every call.
    h = nn.relu(X @ Wxh + bxh)
    y = torch.sigmoid(h @ Why + bhy)
    return y


# Trainable tensors grouped per network (one list per optimizer), plus the
# combined list that reset_grad() iterates over.
G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
params = [*G_params, *D_params]


""" ===================== TRAINING ======================== """


def reset_grad():
    """Replace every populated gradient in ``params`` with a fresh zero tensor.

    Entries whose ``.grad`` is still ``None`` (no backward pass has reached
    them yet) are left untouched.
    """
    populated = (tensor for tensor in params if tensor.grad is not None)
    for tensor in populated:
        # A new all-zero tensor with the gradient's shape, dtype and device.
        tensor.grad = torch.zeros_like(tensor.grad)


# One Adam optimizer per network so each update step only touches its own weights.
G_solver = optim.Adam(G_params, lr=lr)
D_solver = optim.Adam(D_params, lr=lr)

# Binary cross-entropy targets for a full mini-batch. They are sliced to the
# actual batch length inside the loop so a shorter final batch still works
# (this relies on the loader never yielding more than mb_size items).
ones_label = torch.ones(mb_size, 1)
zeros_label = torch.zeros(mb_size, 1)

# Upper bound on passes over the data; training normally stops earlier, once
# max_iterations parameter updates have been performed.
num_epochs = 300000

iterations = 0

# Use a tiny value (e.g. 5) for a smoke test or a larger one (e.g. 100000)
# for a longer run.
max_iterations = 20000

sample_every = 600   # log and save a sample grid whenever the batch index is a multiple of this
n_samples = 16       # number of generated images per grid (fills the 4 x 4 layout)


def _save_sample_grid(samples, index, out_dir='out/'):
    """Draw ``samples`` on a 4 x 4 grid and save it as ``<out_dir>/<index>.png``.

    Args:
        samples: iterable of at most 16 arrays, each reshapeable to (28, 28).
        index: integer used for the zero-padded, three-digit file name.
        out_dir: destination directory; created if it does not exist.
    """
    fig = plt.figure(figsize=(4, 4))
    try:
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for j, sample in enumerate(samples):
            ax = fig.add_subplot(gs[j])
            ax.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, '{}.png'.format(str(index).zfill(3))), bbox_inches='tight')
    finally:
        # Always release the figure, even if saving fails, so figures do not
        # accumulate over a long run.
        plt.close(fig)


for epoch in range(num_epochs):
    if iterations >= max_iterations:
        break
    for i, (images, _) in enumerate(train_loader):
        # Check the budget before counting the step so that exactly
        # max_iterations updates are performed.
        if iterations >= max_iterations:
            break
        iterations += 1

        # Flatten each image to a vector, using the actual batch length
        # instead of a hard-coded 60.
        batch = images.size(0)
        X = images.reshape(batch, -1)
        real_target = ones_label[:batch]
        fake_target = zeros_label[:batch]

        # ---- Discriminator: forward, loss, backward, update ----
        z = torch.randn(batch, Z_dim)
        # Detach the fake batch: this step must not backpropagate into the generator.
        G_sample = G(z).detach()
        D_real = D(X)
        D_fake = D(G_sample)

        # Logistic (BCE) form of the original GAN objective
        #   -mean(log(D_real) + log(1 - D_fake))
        D_loss_real = nn.binary_cross_entropy(D_real, real_target)
        D_loss_fake = nn.binary_cross_entropy(D_fake, fake_target)
        D_loss = D_loss_real + D_loss_fake

        D_loss.backward()
        D_solver.step()

        # Housekeeping - clear gradients before the next backward pass.
        reset_grad()

        # ---- Generator: forward, loss, backward, update ----
        z = torch.randn(batch, Z_dim)
        G_sample = G(z)
        D_fake = D(G_sample)

        # Logistic (BCE) form of the original generator loss -mean(log(D_fake)).
        G_loss = nn.binary_cross_entropy(D_fake, real_target)

        G_loss.backward()
        G_solver.step()

        # Housekeeping - this also clears the gradients that the generator
        # loss left on the discriminator weights.
        reset_grad()

        # ---- Periodic logging and sample dump ----
        if i % sample_every == 0:
            # Report the global step count; the epoch number would repeat for
            # every log line emitted within the same epoch.
            print('Iter-{}; D_loss: {}; G_loss: {}'.format(iterations, D_loss.item(), G_loss.item()))

            # No autograd graph is needed just to render samples.
            with torch.no_grad():
                samples = G(z[:n_samples]).numpy()

            _save_sample_grid(samples, c)
            c += 1

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



  cpuset_checked))


Iter-0; D_loss: 1.4182612895965576; G_loss: 2.3820767402648926
Iter-0; D_loss: 0.005114791914820671; G_loss: 6.779626369476318
Iter-1; D_loss: 0.02093389257788658; G_loss: 6.862981796264648
Iter-1; D_loss: 0.01241007074713707; G_loss: 7.49234676361084
Iter-2; D_loss: 0.04508471488952637; G_loss: 5.199045181274414
Iter-2; D_loss: 0.07008163630962372; G_loss: 5.072390079498291
Iter-3; D_loss: 0.060113683342933655; G_loss: 5.187243461608887
Iter-3; D_loss: 0.18031595647335052; G_loss: 5.2485456466674805
Iter-4; D_loss: 0.11321131885051727; G_loss: 4.099161148071289
Iter-4; D_loss: 0.29186373949050903; G_loss: 4.003844738006592
Iter-5; D_loss: 0.2055610716342926; G_loss: 3.9275574684143066
Iter-5; D_loss: 0.24596476554870605; G_loss: 5.049686908721924
Iter-6; D_loss: 0.4847845435142517; G_loss: 3.660613536834717
Iter-6; D_loss: 0.3791835606098175; G_loss: 3.9514334201812744
Iter-7; D_loss: 0.6734856963157654; G_loss: 3.1086151599884033
Iter-7; D_loss: 0.5885113477706909; G_loss: 2.48404741

In [None]:
# Bundle the sample images saved by the training cell into file.zip.
# Paths are relative to the notebook's working directory, matching the relative
# 'out/' directory the training cell writes to and the cleanup cell removes.
!zip -r file.zip out

  adding: content/out/ (stored 0%)
  adding: content/out/015.png (deflated 6%)
  adding: content/out/034.png (deflated 6%)
  adding: content/out/003.png (deflated 9%)
  adding: content/out/000.png (deflated 8%)
  adding: content/out/021.png (deflated 7%)
  adding: content/out/013.png (deflated 6%)
  adding: content/out/035.png (deflated 6%)
  adding: content/out/016.png (deflated 6%)
  adding: content/out/023.png (deflated 6%)
  adding: content/out/006.png (deflated 6%)
  adding: content/out/010.png (deflated 6%)
  adding: content/out/019.png (deflated 6%)
  adding: content/out/001.png (deflated 8%)
  adding: content/out/030.png (deflated 6%)
  adding: content/out/002.png (deflated 9%)
  adding: content/out/020.png (deflated 6%)
  adding: content/out/026.png (deflated 6%)
  adding: content/out/033.png (deflated 7%)
  adding: content/out/029.png (deflated 6%)
  adding: content/out/017.png (deflated 6%)
  adding: content/out/036.png (deflated 6%)
  adding: content/out/025.png (deflated 6

In [None]:
# Recursively delete the `out` directory in the current working directory.
# NOTE(review): presumably meant to run only after the images have been
# archived or downloaded - confirm before executing, the deletion is not reversible.
!rm -rf out