In [1]:
# Generative Adversarial Networks (GAN) example in PyTorch.
# See related blog post at 
# https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.sch4xgsa9
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [2]:
# Data params
data_mean = 4
data_stddev = 1.25

# Model params
g_input_size = 1     # Random noise dimension coming into generator, per output vector
g_hidden_size = 50   # Generator complexity
g_output_size = 1    # size of generated output vector
d_input_size = 100   # Minibatch size - cardinality of distributions
d_hidden_size = 50   # Discriminator complexity
d_output_size = 1    # Single dimension for 'real' vs. 'fake'
minibatch_size = d_input_size

d_learning_rate = 2e-4  # 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)

num_epochs = 3000
#num_epochs = 30000
print_interval = 200

d_steps = 1  # 'k' steps in the original GAN paper. Can put the discriminator on higher training freq than generator
g_steps = 1

In [3]:
# ##### DATA: Target data and generator input data

dtype = torch.cuda.FloatTensor
#dtype = torch.FloatTensor


def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian

def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian


In [4]:
# ##### MODELS: Generator model and discriminator model
#%time

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.sigmoid(self.map2(x))
        return self.map3(x)

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        return F.sigmoid(self.map3(x))

def extract(v):
    return v.data.storage().tolist()

def stats(d):
    return [np.mean(d), np.std(d)]

def decorate_with_diffs(data, exponent):
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast).cuda(), exponent)
    return torch.cat([data, diffs], 1)


In [5]:
# ### Uncomment only one of these
#(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
(name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)

print("Using data [%s]" % (name))

Using data [Data and variances]


In [6]:
d_sampler = get_distribution_sampler(data_mean, data_stddev)
gi_sampler = get_generator_input_sampler()
G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)
D = Discriminator(input_size=d_input_func(d_input_size), hidden_size=d_hidden_size, output_size=d_output_size)
criterion = nn.BCELoss().cuda()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)
g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)

Run the networks on GPU

In [7]:
G = G.cuda()
D = D.cuda()
criterion = criterion.cuda()

In [None]:
%%time

def train_GAN(num_epochs):
    '''
    '''
    for epoch in range(num_epochs):
        for d_index in range(d_steps):
            # 1. Train D on real+fake
            D.zero_grad()

            #  1A: Train D on real
            d_real_data = Variable(d_sampler(d_input_size)).cuda()
        
            d_real_decision = D(preprocess(d_real_data))
            d_real_error = criterion(d_real_decision, Variable(torch.ones(1).cuda()))  # ones = true
            d_real_error.backward() # compute/store gradients, but don't change params

            #  1B: Train D on fake
            d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size)).cuda()
            d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
            d_fake_decision = D(preprocess(d_fake_data.t()))
            d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)).cuda()).cuda()  # zeros = fake
            d_fake_error.backward()
            d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()

        for g_index in range(g_steps):
            # 2. Train G on D's response (but DO NOT train D on these labels)
            G.zero_grad()

            gen_input = Variable(gi_sampler(minibatch_size, g_input_size)).cuda()
            g_fake_data = G(gen_input)
            dg_fake_decision = D(preprocess(g_fake_data.t()))
            g_error = criterion(dg_fake_decision, Variable(torch.ones(1)).cuda())  # we want to fool, so pretend it's all genuine

            g_error.backward()
            g_optimizer.step()  # Only optimizes G's parameters

    if epoch % print_interval == 0:
        print("%s: D: %s/%s G: %s (Real: %s, Fake: %s) " % (epoch,
                                                            extract(d_real_error)[0],
                                                            extract(d_fake_error)[0],
                                                            extract(g_error)[0],
                                                            stats(extract(d_real_data)),
                                                            stats(extract(d_fake_data))))

  "Please ensure they have the same size.".format(target.size(), input.size()))


0: D: 0.8314700126647949/0.7408788800239563 G: 0.6517988443374634 (Real: [3.5847136580944063, 1.3028425320120185], Fake: [-0.19148351490497589, 0.0026327345021920287]) 
200: D: 0.005149307660758495/0.6713951826095581 G: 0.7198851704597473 (Real: [4.0853431904315949, 1.2638330799754618], Fake: [-0.13201068341732025, 0.012496495198185742]) 
400: D: 4.4466054532676935e-05/0.4138678014278412 G: 1.1039412021636963 (Real: [3.7992736539244651, 1.45441843807528], Fake: [-0.062396438270807264, 0.042444088267637958]) 
600: D: 7.271793037944008e-06/0.4078928232192993 G: 0.9893813133239746 (Real: [3.7780758392810823, 1.30437147225858], Fake: [-0.22115120703354477, 0.17351046334565301]) 
800: D: 0.001397214480675757/0.15718115866184235 G: 2.0131983757019043 (Real: [3.9881101787090301, 1.1478168989297544], Fake: [0.4706003931351006, 0.56432951749273874]) 
1000: D: 0.07969240099191666/0.11074899137020111 G: 2.802726984024048 (Real: [4.1164034307003021, 1.3332585979281273], Fake: [2.7175288861989975, 

10000: D: 0.4605080187320709/0.6545906066894531 G: 0.7222508788108826 (Real: [3.9682945823669433, 1.2388837634512304], Fake: [4.125325989127159, 1.3054842383967098]) 
10200: D: 0.6009887456893921/0.5211098194122314 G: 0.8292682766914368 (Real: [4.1526535423845052, 1.1840650140518765], Fake: [3.8940234857797624, 1.3016225625412801]) 
10400: D: 0.623485803604126/0.7103880047798157 G: 0.5747432708740234 (Real: [3.8938691937923431, 1.2831633272956378], Fake: [3.994795436859131, 1.3455326163296097]) 
10600: D: 0.7250586748123169/0.7088103890419006 G: 0.9190853238105774 (Real: [3.6585712096095087, 1.2236852592118925], Fake: [4.0267258596420286, 1.3601611757301317]) 
