This source code is a modified version of a source code from https://github.com/wiseodd/generative-models.git
directory of the source code: wiseodd/generative-models/GAN/WGAN/wgan_pytorch.py
What's modified:
- to put in a .ipynb file.
- to download the MNIST dataset by using torchvision.datasets because tensorflow_datasets was not working for some reason. Accordingly, the way how tloading the dataset, a
- the algorithm remains the same. 

In [44]:
#!pip3 install tensorflow
!pip3 install torchvision



In [68]:
import torch
import torch.nn
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
#from tensorflow.examples.tutorials.mnist import input_data

mb_size = 32
z_dim = 10

train_data = datasets.MNIST(
    root = 'data',
    train = True,                         
    transform = ToTensor(), 
    download = True,            
)
test_data = datasets.MNIST(
    root = 'data', 
    train = False, 
    transform = ToTensor()
)

loaders = {
    'train' : torch.utils.data.DataLoader(train_data, 
                                          batch_size=mb_size, 
                                          shuffle=True, 
                                          num_workers=1),
    
    #'test'  : torch.utils.data.DataLoader(test_data, 
    #                                      batch_size=mb_size, 
    #                                      shuffle=True, 
    #                                      num_workers=1),
}

for _, (image, label) in enumerate(loaders['train']):
    X_dim = image.shape[2] **2
    break

h_dim = 128
cnt = 0
lr = 1e-4

In [69]:
# Neural Networks

# Generator
G = torch.nn.Sequential(
    torch.nn.Linear(z_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, X_dim),
    torch.nn.Sigmoid()
)

# Discriminator
D = torch.nn.Sequential(
    torch.nn.Linear(X_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, 1),
)


def reset_grad():
    G.zero_grad()
    D.zero_grad()

# Optimizer
G_solver = optim.RMSprop(G.parameters(), lr=lr)
D_solver = optim.RMSprop(D.parameters(), lr=lr)

In [74]:
# Training 
epoch = 50
for it in range(epoch):
    for i, (images, labels) in enumerate(loaders['train']):
        
        # Sample data
        z = Variable(torch.randn(mb_size, z_dim))
        X = Variable(images.reshape(-1,X_dim))
        
        if i % 5 != 4:
            
            # Dicriminator forward-loss-backward-update
            G_sample = G(z)
            D_real = D(X)
            D_fake = D(G_sample)

            D_loss = -(torch.mean(D_real) - torch.mean(D_fake))

            D_loss.backward()
            D_solver.step()

            # Weight clipping
            for p in D.parameters():
                p.data.clamp_(-0.01, 0.01)

            # Housekeeping - reset gradient
            reset_grad()

        else:

            G_sample = G(z)
            D_fake = D(G_sample)

            G_loss = -torch.mean(D_fake)

            G_loss.backward()
            G_solver.step()

            # Housekeeping - reset gradient
            reset_grad()

            # Print and plot every now and then
            if i  % 1000 == 4:
                print('Iter-{}; D_loss: {}; G_loss: {}'
                      .format(i + it*(1875)+ 1, D_loss.data.numpy(), G_loss.data.numpy()))

                samples = G(z).data.numpy()[:16]

                fig = plt.figure(figsize=(4, 4))
                gs = gridspec.GridSpec(4, 4)
                gs.update(wspace=0.05, hspace=0.05)

                for i, sample in enumerate(samples):
                    ax = plt.subplot(gs[i])
                    plt.axis('off')
                    ax.set_xticklabels([])
                    ax.set_yticklabels([])
                    ax.set_aspect('equal')
                    plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

                if not os.path.exists('out/'):
                    os.makedirs('out/')

                plt.savefig('out/{}.png'.format(str(cnt).zfill(3)), bbox_inches='tight')
                cnt += 1
                plt.close(fig)


Iter-5; D_loss: -0.02831856533885002; G_loss: -0.027151141315698624
Iter-1005; D_loss: -0.026682481169700623; G_loss: -0.0255219005048275
Iter-1880; D_loss: -0.025646761059761047; G_loss: -0.026162322610616684
Iter-2880; D_loss: -0.023585669696331024; G_loss: -0.024375615641474724
Iter-3755; D_loss: -0.021438511088490486; G_loss: -0.02396930567920208
Iter-4755; D_loss: -0.026750894263386726; G_loss: -0.02500130981206894
Iter-5630; D_loss: -0.026280775666236877; G_loss: -0.023443173617124557
Iter-6630; D_loss: -0.024505916982889175; G_loss: -0.02435639314353466
Iter-7505; D_loss: -0.025206351652741432; G_loss: -0.02089271880686283
Iter-8505; D_loss: -0.026752745732665062; G_loss: -0.020642220973968506
Iter-9380; D_loss: -0.026810502633452415; G_loss: -0.022096745669841766
Iter-10380; D_loss: -0.02128356322646141; G_loss: -0.02303585410118103
Iter-11255; D_loss: -0.0246658306568861; G_loss: -0.022723644971847534
Iter-12255; D_loss: -0.019124532118439674; G_loss: -0.022257039323449135
Ite