In [2]:
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch

In [3]:
# --- Training schedule and Adam optimizer settings ---
epochs = 100
batch_size = 64
lr = 2e-4
b1 = 0.5
b2 = 0.999

# --- Data / model geometry ---
latent_dim = 100   # length of the noise vector fed to the generator
img_size = 28
channels = 1
img_shape = (channels, img_size, img_size)

# --- WGAN-GP settings ---
n_critic = 5       # the generator is updated on every n_critic-th batch
lambda_gp = 10     # weight of the gradient-penalty term in the critic loss

# Remember whether a GPU is usable; later cells branch on this flag.
cuda = torch.cuda.is_available()
if cuda:
    print("Train on GPU \nCUDA is available")
else:
    print("Train on the CPU \nCUDA is not available")

Train on GPU 
CUDA is available


In [4]:
# Make sure the download directory exists before the dataset is fetched.
os.makedirs("data/mnist", exist_ok=True)

# Preprocessing: resize to img_size, convert to a tensor, then normalize
# with mean 0.5 and std 0.5.
mnist_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# MNIST training split, downloaded into data/mnist if it is not there yet.
mnist_train = datasets.MNIST(
    "data/mnist",
    train=True,
    download=True,
    transform=mnist_transform,
)

# Shuffled mini-batches consumed by the training loop.
dataloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

In [5]:
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to images.

    The network is a stack of Linear -> (BatchNorm1d) -> LeakyReLU(0.2)
    stages with widths 128, 256, 512 and 1024 (the first stage has no batch
    norm), followed by a Linear projection to ``prod(img_shape)`` values and
    a Tanh. It reads the module-level ``latent_dim`` and ``img_shape``.
    """

    def __init__(self):
        super().__init__()

        hidden_widths = (128, 256, 512, 1024)
        layers = []
        in_feat = latent_dim
        for stage, out_feat in enumerate(hidden_widths):
            layers.append(nn.Linear(in_feat, out_feat))
            if stage > 0:
                # NOTE(review): the second positional argument of BatchNorm1d
                # is `eps`, so this sets eps=0.8 (not momentum). Kept as-is to
                # preserve behavior -- confirm whether momentum was intended.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_feat = out_feat

        # Project to one value per pixel and squash into [-1, 1].
        layers.append(nn.Linear(in_feat, int(np.prod(img_shape))))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Generate a batch of images from the latent batch ``z``.

        Returns a tensor of shape ``(z.shape[0], *img_shape)``.
        """
        flat = self.model(z)
        return flat.view(flat.size(0), *img_shape)

In [6]:
class Discriminator(nn.Module):
    """Fully connected critic producing one raw score per image.

    Images are flattened to ``prod(img_shape)`` features and passed through
    Linear(512) -> LeakyReLU(0.2) -> Linear(256) -> LeakyReLU(0.2) ->
    Linear(1). There is no final activation. It reads the module-level
    ``img_shape``.
    """

    def __init__(self):
        super().__init__()

        n_pixels = int(np.prod(img_shape))
        layers = [
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Score a batch of images; returns a tensor of shape ``(batch, 1)``."""
        n_samples = img.size(0)
        return self.model(img.view(n_samples, -1))

In [7]:
# Build both networks (generator first, then critic) and move them to the
# GPU when the `cuda` flag is set.
G, D = Generator(), Discriminator()

if cuda:
    G, D = G.cuda(), D.cuda()

In [8]:
# Float tensor type matching the device selected by the `cuda` flag.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor

def gradient_penalty(D, real_img, fake_img):
    """Compute the WGAN-GP gradient penalty for the critic ``D``.

    Each real sample is mixed with the matching fake sample using one random
    coefficient per sample. The penalty is the mean over the batch of
    ``(||grad_x D(x_hat)||_2 - 1) ** 2``, where ``x_hat`` is the mixed sample
    and the norm is taken over all non-batch dimensions.

    Args:
        D: callable mapping a batch of samples to a tensor of critic scores.
        real_img: batch of real samples; the first dimension is the batch.
        fake_img: batch of generated samples, same shape as ``real_img``.

    Returns:
        A scalar tensor. It is built with ``create_graph=True`` so it can be
        backpropagated through to the parameters of ``D``.
    """
    n_samples = real_img.size(0)

    # One mixing coefficient per sample, broadcast over every remaining
    # dimension. Shape, dtype and device follow the inputs, so the function
    # does not depend on a module-level tensor type or a fixed input rank.
    # Coefficients are drawn from torch's RNG.
    alpha_shape = (n_samples,) + (1,) * (real_img.dim() - 1)
    alpha = torch.rand(alpha_shape, dtype=real_img.dtype,
                       device=real_img.device)

    interpolates = (alpha * real_img
                    + (1 - alpha) * fake_img).requires_grad_(True)
    d_interpolates = D(interpolates)

    # Ones as grad_outputs: differentiates the sum of the critic outputs,
    # which yields the per-sample input gradients in a single call.
    grad_outputs = torch.ones_like(d_interpolates)
    gradients = autograd.grad(outputs=d_interpolates,
                              inputs=interpolates,
                              grad_outputs=grad_outputs,
                              create_graph=True,
                              retain_graph=True)[0]

    # Flatten per sample so the L2 norm covers all non-batch dimensions.
    gradients = gradients.reshape(gradients.size(0), -1)
    GP = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return GP

In [12]:
# Adam optimizers for the generator and the critic, sharing lr and betas.
optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(b1, b2))
current_iters = 0
# Per-step loss histories. They are created here so the cell runs on a fresh
# kernel (Restart & Run All) and starts clean when it is re-run.
loss_d, loss_g = [], []
os.makedirs("WGAN-GP_results", exist_ok=True)

for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Cast/move the batch to the tensor type selected for this run.
        real_imgs = imgs.type(Tensor)

        ## Train Discriminator ##
        optimizer_D.zero_grad()
        # Sample the noise directly on the same device as the real batch.
        z = torch.randn(imgs.shape[0], latent_dim, device=real_imgs.device)
        # Detach: the critic update must not backpropagate into the generator.
        fake_imgs = G(z).detach()
        real_pred = D(real_imgs)
        fake_pred = D(fake_imgs)
        GP = gradient_penalty(D, real_imgs, fake_imgs)
        # Wasserstein critic loss plus the weighted gradient penalty.
        d_loss = -torch.mean(real_pred) + torch.mean(fake_pred) + lambda_gp * GP
        loss_d.append(d_loss.item())
        d_loss.backward()
        optimizer_D.step()

        if i % n_critic == 0:
            ## Train Generator ##
            optimizer_G.zero_grad()
            # Reuse the same noise, this time keeping the graph through G.
            fake_imgs = G(z)
            fake_pred = D(fake_imgs)
            g_loss = -torch.mean(fake_pred)
            loss_g.append(g_loss.item())
            g_loss.backward()
            optimizer_G.step()
            current_iters += n_critic
    # End of epoch: report the most recent losses and save a 5x5 sample grid
    # built from the last generated batch.
    print("[Epoch %d/%d]  [D loss: %f] [G loss: %f]"
          % (epoch, epochs, d_loss.item(), g_loss.item()))
    save_image(fake_imgs.detach()[:25], "WGAN-GP_results/epoch_%d.png"
               % epoch, nrow=5, normalize=True)

# Persist the trained weights of both networks.
torch.save(G.state_dict(), './Generator.pth')
torch.save(D.state_dict(), './Discriminator.pth')

KeyboardInterrupt: 