In [None]:
# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import os
import matplotlib.pyplot as plt
import numpy as np

# Number of passes over the training set.
EPOCHS = 200

# Every artifact of this notebook (weights, loss curves, sample image) goes here.
FOLDER = "./Data/"
# Create the output folder up front: torch.save / np.save / savefig below all
# fail when the target directory does not exist.
os.makedirs(FOLDER, exist_ok=True)

# Choose one
#LOSS = "BCE" 
LOSS = "Wasserstein"

# NOTE(review): the checkpoint names hard-code "200epochs" independently of EPOCHS.
G_DIR = FOLDER + "Generator_200epochs" + LOSS
D_DIR = FOLDER + "Discriminator_200epochs" + LOSS
DIGIT_SAVE_DIR = FOLDER + 'sample_digits_GAN_' + LOSS + '.png'
# Dimension of the latent noise vector fed to the generator.
z_dim = 4
# File holding a fixed batch of latent vectors (written with torch.save,
# despite the .txt suffix) so sample images can be regenerated from the same noise.
RANDOM_VECTOR = FOLDER + "random_vector.txt"

TRAIN_G_LOSS = FOLDER + "TRAIN_G_LOSS_GAN_" + LOSS + ".txt.npy"
TRAIN_D_LOSS = FOLDER + "TRAIN_D_LOSS_GAN_" + LOSS + ".txt.npy"

LOSS_PLOT = FOLDER + "LOSS_PLOT_GAN" + LOSS + ".png"
KLD_PLOT = FOLDER + "KLD_PLOT_GAN" + LOSS + ".png"
# Weight-clipping bound applied to the discriminator in Wasserstein mode.
WCL = 0.01
# Only draw the fixed latent batch once; later runs reuse the saved file.
if not os.path.exists(RANDOM_VECTOR):
    z = torch.randn(100, z_dim)
    torch.save(z, RANDOM_VECTOR)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Mini-batch size shared by both loaders (and read by the training code).
bs = 100

# MNIST train / test splits, downloaded on first use and converted to tensors.
train_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# Input pipeline: only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=bs, shuffle=False)

In [None]:
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to flat output vectors.

    Args:
        g_input_dim: size of the latent vector fed to the first layer.
        g_output_dim: size of each generated output vector.
    """

    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(g_input_dim, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, g_output_dim)

    def forward(self, x):
        """Map a batch of latent vectors to outputs squashed into [-1, 1] by tanh."""
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        # NOTE(review): there is no activation after fc3, so fc3 and fc4 compose
        # into a single affine map. Left as-is so previously saved checkpoints keep
        # producing the same images -- confirm whether this was intended.
        x = self.fc3(x)
        return torch.tanh(self.fc4(x))

    def G_train(self, x, D):
        """Run one generator update against ``D`` and return the loss as a float.

        Args:
            x: not used by this method; kept so existing callers keep working.
            D: discriminator (callable) that scores the generated batch.

        Reads the module-level ``bs``, ``z_dim``, ``device``, ``criterion`` and
        ``G_optimizer``.
        """
        self.zero_grad()

        # Draw noise and the all-ones target directly on the target device.
        z = torch.randn(bs, z_dim, device=device)
        y = torch.ones(bs, 1, device=device)

        # Call the module (not .forward) so registered hooks run.
        D_output = D(self(z))
        G_loss = criterion(D_output, y)

        # gradient backprop & optimize ONLY G's parameters
        G_loss.backward()
        G_optimizer.step()

        return G_loss.item()
    
    
    
class Discriminator(nn.Module):
    """Fully connected discriminator / critic producing one score per sample.

    Args:
        d_input_dim: size of each flattened input vector.
    """

    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.fc4 = nn.Linear(256, 1)

    # forward method
    def forward(self, x):
        """Score a batch of flattened inputs.

        Returns a sigmoid probability when ``LOSS == "BCE"`` and the raw linear
        output when ``LOSS == "Wasserstein"``.

        Raises:
            ValueError: if ``LOSS`` is neither of the two supported values.
        """
        # F.dropout defaults to training=True; pass the module flag so that
        # dropout is actually switched off by .eval().
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.2, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.2, training=self.training)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.2, training=self.training)
        score = self.fc4(x)
        if LOSS == "BCE":
            return torch.sigmoid(score)
        if LOSS == "Wasserstein":
            return score
        raise ValueError("Unsupported LOSS: %r" % (LOSS,))

    def D_train(self, x):
        """Run one discriminator update on a real batch and return the loss as a float.

        Args:
            x: batch of real samples; it is flattened to ``(-1, mnist_dim)``.

        Reads the module-level ``G``, ``criterion``, ``D_optimizer``, ``device``,
        ``mnist_dim``, ``z_dim``, ``LOSS`` and ``WCL``.
        """
        x_real = x.view(-1, mnist_dim).to(device)
        # Size labels and noise from the actual batch so a final partial batch
        # does not produce a shape mismatch against a fixed global batch size.
        n = x_real.size(0)
        y_real = torch.ones(n, 1, device=device)
        # Fake target is 0 for BCE and -1 otherwise (Wasserstein).
        fake_target = 0.0 if LOSS == "BCE" else -1.0
        y_fake = fake_target * torch.ones(n, 1, device=device)

        D_real_score = self(x_real)

        z = torch.randn(n, z_dim, device=device)
        # The generator is not updated here, so skip building its graph.
        with torch.no_grad():
            x_fake = G(z)
        D_fake_score = self(x_fake)

        D_real_loss = criterion(D_real_score, y_real)
        D_fake_loss = criterion(D_fake_score, y_fake)
        D_loss = D_real_loss + D_fake_loss

        self.zero_grad()
        D_loss.backward()
        D_optimizer.step()
        if LOSS == "Wasserstein":
            # Clip every parameter into [-WCL, WCL] after the optimizer step.
            with torch.no_grad():
                for p in self.parameters():
                    p.clamp_(-WCL, WCL)

        return D_loss.item()

In [None]:
# build network
z_dim = 4
# Flattened image size (height * width), read from the dataset tensor.
# `train_data` is a deprecated alias in torchvision; `data` is the current attribute.
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)

In [None]:
def wasserLoss(yp, yt):
    """Return mean(yp) multiplied by the mean of the negated targets yt."""
    score_mean = yp.mean()
    negated_target_mean = yt.neg().mean()
    return score_mean * negated_target_mean
# Loss function and optimizers are both selected by LOSS.
# Learning rate used by the Adam optimizers in the BCE setup.
lr = 0.0002 

if LOSS == "BCE":
    criterion = nn.BCELoss() 
    G_optimizer = optim.Adam(G.parameters(), lr = lr)
    D_optimizer = optim.Adam(D.parameters(), lr = lr)
elif LOSS == "Wasserstein":
    criterion = wasserLoss
    G_optimizer = optim.RMSprop(G.parameters(), lr = 5e-5)
    D_optimizer = optim.RMSprop(D.parameters(), lr = 5e-5)
else:
    # Fail here instead of leaving criterion / optimizers undefined for later cells.
    raise ValueError("Unsupported LOSS: %r (expected 'BCE' or 'Wasserstein')" % (LOSS,))

In [None]:
# Per-epoch mean losses, stored as plain Python floats.
D_train_loss_history = []
G_train_loss_history = []
G.train()
D.train()
for epoch in range(1, EPOCHS+1):
    D_losses, G_losses = [], []
    # Labels are ignored; each batch drives one D step followed by one G step.
    for x, _ in train_loader:
        D_losses.append(D.D_train(x))
        G_losses.append(G.G_train(x, D))
    # Compute each epoch mean once and reuse it for both the history and the log line.
    D_epoch_loss = float(np.mean(D_losses))
    G_epoch_loss = float(np.mean(G_losses))
    D_train_loss_history.append(D_epoch_loss)
    G_train_loss_history.append(G_epoch_loss)
    print('[%d/%d]: Discriminator loss: %.3f, Generator loss: %.3f' % (
            epoch, EPOCHS, D_epoch_loss, G_epoch_loss))

In [None]:
# Persist the generator and discriminator loss histories as .npy arrays.
np.save(file=TRAIN_G_LOSS, arr=np.array(G_train_loss_history))
np.save(file=TRAIN_D_LOSS, arr=np.array(D_train_loss_history))

In [None]:
# Write the generator and discriminator weights to their checkpoint paths.
torch.save(obj=G.state_dict(), f=G_DIR)
torch.save(obj=D.state_dict(), f=D_DIR)

In [None]:
# Read the saved loss histories back from disk.
G_train_loss_history = np.load(file=TRAIN_G_LOSS)
D_train_loss_history = np.load(file=TRAIN_D_LOSS)

In [None]:
# Draw both loss histories against the 1-based epoch index on figure 1 and save it.
loss_ax = plt.figure(1).gca()
loss_ax.plot(range(1, len(G_train_loss_history) + 1), G_train_loss_history)
loss_ax.plot(range(1, len(D_train_loss_history) + 1), D_train_loss_history)
loss_ax.set_xlabel("EPOCHS")
loss_ax.set_ylabel("Average loss")
loss_ax.legend(["G Train", "D Train"])
plt.savefig(LOSS_PLOT)

In [None]:
# Rebuild both networks and restore their weights from the saved checkpoints.
z_dim = 4
# `train_data` is a deprecated alias in torchvision; `data` is the current attribute.
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)
G.load_state_dict(torch.load(G_DIR, map_location=device))
D.load_state_dict(torch.load(D_DIR, map_location=device))
# Freshly constructed modules start in training mode; switch to inference mode.
G.eval()
D.eval()
with torch.no_grad():
    # Reuse the fixed latent batch saved on disk.
    test_z = torch.load(RANDOM_VECTOR).to(device)
    generated = G(test_z)

    # Reshape the flat outputs to 1x28x28 images and write them as one image file.
    save_image(generated.view(generated.size(0), 1, 28, 28), DIGIT_SAVE_DIR)