In [None]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

In [None]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG (CPU and all CUDA devices) so weight init, the
# torch.randn noise samples and the DataLoader's default shuffling order are
# reproducible under Restart & Run All.
seed = 0
torch.manual_seed(seed)

# Hyper-parameters
latent_size = 64        # length of the generator's noise vector z
hidden_size = 256       # width of the hidden Linear layers in G and D
image_size = 784        # flattened MNIST image: 28 * 28 pixels
num_epochs = 200        # number of training epochs
batch_size = 100
sample_dir = 'gen_samples'

# Create the sample directory; exist_ok=True avoids the check-then-create race
# and keeps the cell idempotent when re-run.
os.makedirs(sample_dir, exist_ok=True)

In [None]:
# Image processing: ToTensor turns each MNIST image into a float tensor in
# [0, 1]; Normalize with mean 0.5 / std 0.5 then maps it to [-1, 1], the same
# range as the generator's Tanh output. MNIST is single-channel (grayscale),
# so mean/std are one-element tuples (an RGB dataset would need three values).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST training split, downloaded into ./data/ if not already present.
mnist = torchvision.datasets.MNIST(
    root='./data/',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of (image, label) pairs.
data_loader = torch.utils.data.DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
)

0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|█████████▉| 9871360/9912422 [00:46<00:00, 214244.23it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw



0it [00:00, ?it/s][A

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz



  0%|          | 0/28881 [00:00<?, ?it/s][A
 57%|█████▋    | 16384/28881 [00:00<00:00, 87407.45it/s][A
32768it [00:00, 40563.69it/s]                           [A
0it [00:00, ?it/s][A

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz



  0%|          | 0/1648877 [00:00<?, ?it/s][A
  1%|          | 16384/1648877 [00:00<00:21, 75816.01it/s][A
  2%|▏         | 40960/1648877 [00:00<00:18, 86998.99it/s][A
  4%|▍         | 73728/1648877 [00:00<00:15, 102642.90it/s][A
  7%|▋         | 114688/1648877 [00:01<00:12, 122286.48it/s][A
  9%|▉         | 147456/1648877 [00:01<00:11, 134648.84it/s][A
 11%|█▏        | 188416/1648877 [00:01<00:09, 152477.60it/s][A
 14%|█▍        | 229376/1648877 [00:01<00:08, 168045.47it/s][A
 16%|█▋        | 270336/1648877 [00:01<00:07, 180935.82it/s][A
 19%|█▉        | 319488/1648877 [00:02<00:06, 199892.54it/s][A
 22%|██▏       | 360448/1648877 [00:02<00:06, 205665.79it/s][A
 25%|██▍       | 409600/1648877 [00:02<00:05, 220547.49it/s][A
 27%|██▋       | 442368/1648877 [00:02<00:05, 204036.73it/s][A
 30%|███       | 499712/1648877 [00:02<00:05, 228537.42it/s][A
 32%|███▏      | 532480/1648877 [00:03<00:05, 210355.83it/s][A
 35%|███▍      | 573440/1648877 [00:03<00:05, 213605.97it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz




  0%|          | 0/4542 [00:00<?, ?it/s][A[A

8192it [00:00, 21678.38it/s]            [A[A

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


In [None]:
# Critic (the WGAN "discriminator"): an MLP mapping a flattened image to a
# single real-valued score. There is deliberately no sigmoid at the end — the
# WGAN critic output is an unbounded score, not a probability.
wD = nn.Sequential(
    nn.Linear(image_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
).to(device)

# Generator: maps a latent_size noise vector to a flattened image. The final
# Tanh keeps pixels in [-1, 1], matching the normalised real images.
# NOTE: wD must be built before wG so parameter initialisation consumes the
# RNG in the same order as before.
wG = nn.Sequential(
    nn.Linear(latent_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, image_size), nn.Tanh(),
).to(device)

# WGAN settings taken from the paper: critic weights are clipped to
# [-weight_cliping_limit, weight_cliping_limit] and the critic is updated
# critic_iter times per generator update.
weight_cliping_limit = 0.01
critic_iter = 5

# WGAN with *weight* clipping is trained with RMSprop (not Adam); both
# networks use the same small learning rate.
wg_optimizer = torch.optim.RMSprop(wG.parameters(), lr=5e-5)
wd_optimizer = torch.optim.RMSprop(wD.parameters(), lr=5e-5)

def denorm(x):
    """Map a tensor from the normalised [-1, 1] range back to [0, 1].

    Undoes Normalize(mean=0.5, std=0.5) via (x + 1) / 2, then clamps so any
    value that lands outside [0, 1] is saturated before saving as an image.

    Args:
        x: a torch.Tensor of any shape.

    Returns:
        A new tensor of the same shape with values in [0, 1].
    """
    shifted = x.add(1).div(2)
    return shifted.clamp(min=0, max=1)

In [None]:
def get_infinite_batches(data_loader):
    """Yield image batches from ``data_loader`` forever, discarding labels.

    When one pass over the loader is exhausted it is iterated again; for a
    DataLoader built with ``shuffle=True`` each pass uses a new ordering. This
    lets the training loop pull a fixed number of batches per step without
    tracking epoch boundaries.

    Args:
        data_loader: a re-iterable of ``(images, labels)`` pairs, e.g. a
            ``torch.utils.data.DataLoader``.

    Yields:
        The ``images`` element of each pair.

    Raises:
        ValueError: if a full pass over ``data_loader`` produces no batches
            (otherwise the ``while True`` loop would spin forever without
            ever yielding).
    """
    while True:
        produced_any = False
        for images, _ in data_loader:
            produced_any = True
            yield images
        if not produced_any:
            raise ValueError("data_loader produced no batches; cannot cycle over it")

In [None]:
# Start training
# One epoch = one full pass over data_loader (600 batches for MNIST's 60000
# training images at batch_size=100). Derive the step budget and the
# per-epoch sampling interval from the loader and num_epochs instead of
# hard-coding 600 and 200.
steps_per_epoch = len(data_loader)
total_step = steps_per_epoch * num_epochs
data = get_infinite_batches(data_loader)
for currstep in range(total_step):

    # ================================================================== #
    #                      Train the discriminator                       #
    # ================================================================== #

    # Re-enable critic gradients (they are switched off for the G step below).
    for p in wD.parameters():
        p.requires_grad = True

    for _ in range(critic_iter):
        wd_optimizer.zero_grad()

        # Flatten to (batch, image_size). Use the batch's own length rather
        # than batch_size so a short final batch cannot break the reshape.
        images = next(data).to(device)
        images = images.reshape(images.size(0), -1)

        # Critic score on real images
        outputs = wD(images)
        wd_loss_real = outputs.mean(0).view(1)

        # Critic score on fake images. detach() keeps the critic step from
        # building a graph through G and accumulating gradients in wG's
        # .grad buffers; only the critic is being updated here.
        z = torch.randn(batch_size, latent_size, device=device)
        fake_images = wG(z).detach()
        outputs = wD(fake_images)
        wd_loss_fake = outputs.mean(0).view(1)

        # Critic minimises D(G(z)) - D(x), i.e. maximises the Wasserstein
        # estimate D(x) - D(G(z)).
        wd_loss = wd_loss_fake - wd_loss_real
        wd_loss.backward()
        Wasserstein_D = wd_loss_real - wd_loss_fake  # kept for inspection

        # Backprop and optimize
        wd_optimizer.step()

        # Weight clipping: clamp every critic parameter to [-c, c],
        # c = weight_cliping_limit (WGAN's Lipschitz constraint).
        with torch.no_grad():
            for p in wD.parameters():
                p.clamp_(-weight_cliping_limit, weight_cliping_limit)

    # ================================================================== #
    #                        Train the generator                         #
    # ================================================================== #

    # Freeze the critic so the generator's backward pass skips its gradients.
    for p in wD.parameters():
        p.requires_grad = False

    # Generator maximises the critic's score on its samples.
    z = torch.randn(batch_size, latent_size, device=device)
    fake_images = wG(z)
    outputs = wD(fake_images)
    wg_loss = -outputs.mean().view(1)

    # Backprop and optimize
    wg_optimizer.zero_grad()
    wg_loss.backward()
    wg_optimizer.step()

    if (currstep + 1) % 200 == 0:
        print('Curr_step [{}/{}], wd_loss: {:.4f}, wg_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
              .format(currstep + 1, total_step, wd_loss.item(), wg_loss.item(),
                      wd_loss_real.item(), wd_loss_fake.item()))

    if (currstep + 1) % steps_per_epoch == 0:
        epoch = (currstep + 1) // steps_per_epoch
        # Save sampled images (28x28 = image_size, single channel)
        fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
        save_image(denorm(fake_images.detach()),
                   os.path.join(sample_dir, 'wgan_fake_images-{}.png'.format(epoch)))
        print('Saved generated images..!')

    # Save real images once, from the last critic batch of the first step
    if (currstep + 1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
        print('Saved real images..!')


# Save the model checkpoints
torch.save(wG.state_dict(), 'wG.ckpt')
torch.save(wD.state_dict(), 'wD.ckpt')

Saved real images..!



1654784it [00:22, 249646.42it/s]                             [A

Curr_step [200/120000], wd_loss: -0.0877, wg_loss: 6.2192, D(x): 6.32, D(G(z)): 6.23
Curr_step [400/120000], wd_loss: 0.0014, wg_loss: 4.2816, D(x): 4.29, D(G(z)): 4.29
Curr_step [600/120000], wd_loss: -0.1408, wg_loss: 1.3731, D(x): 1.49, D(G(z)): 1.35
Saved generated images..!
Curr_step [800/120000], wd_loss: -0.2073, wg_loss: 2.1858, D(x): 2.39, D(G(z)): 2.18
Curr_step [1000/120000], wd_loss: -0.2183, wg_loss: 1.9515, D(x): 2.19, D(G(z)): 1.97
Curr_step [1200/120000], wd_loss: -0.1442, wg_loss: 1.6466, D(x): 1.79, D(G(z)): 1.65
Saved generated images..!
Curr_step [1400/120000], wd_loss: -0.1834, wg_loss: 1.0261, D(x): 1.22, D(G(z)): 1.03
Curr_step [1600/120000], wd_loss: -0.1572, wg_loss: 0.7849, D(x): 0.95, D(G(z)): 0.80
Curr_step [1800/120000], wd_loss: -0.2074, wg_loss: 0.8336, D(x): 1.03, D(G(z)): 0.83
Saved generated images..!
Curr_step [2000/120000], wd_loss: -0.1764, wg_loss: 0.8617, D(x): 1.04, D(G(z)): 0.87
Curr_step [2200/120000], wd_loss: -0.1797, wg_loss: 0.7067, D(x): 0