In [2]:
# Generative Adversarial Networks (GAN) example in PyTorch.
# See related blog post at https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.sch4xgsa9
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

# --- Target distribution: the Gaussian that the generator must learn to imitate ---
data_mean = 4
data_stddev = 1.25

# --- Network sizes ---
g_input_size = 1               # width of the noise vector fed to G for each generated sample
g_hidden_size = 50             # hidden units per generator layer
g_output_size = 1              # width of each generated sample
d_input_size = 100             # number of samples D judges at once (one whole minibatch)
d_hidden_size = 50             # hidden units per discriminator layer
d_output_size = 1              # a single score: 'real' vs. 'fake'
minibatch_size = d_input_size  # G must emit exactly as many samples as D consumes

# --- Optimisation schedule ---
d_learning_rate = 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)     # handed to Adam as its `betas` argument
num_epochs = 30000
print_interval = 200           # a progress line is printed every this-many epochs
d_steps = 1                    # 'k' in the original GAN paper: discriminator updates per epoch
g_steps = 1                    # generator updates per epoch

# --- What the discriminator gets to see: keep exactly ONE option active ---
# Option "Raw data": the samples are passed through unchanged.
#name, preprocess, d_input_func = "Raw data", (lambda data: data), (lambda x: x)
# Option "Data and variances": each sample row is followed by its squared deviations
# from the mean, which doubles the width of D's input (hence d_input_func = x * 2).
name = "Data and variances"
preprocess = lambda data: decorate_with_diffs(data, 2.0)
d_input_func = lambda x: x * 2

print("Using data [%s]" % name)

# ##### DATA: Target data and generator input data

def get_distribution_sampler(mu, sigma):
    """Build the sampler for 'real' data: Gaussian draws with mean `mu` and std-dev `sigma`.

    The returned callable takes a sample count `n` and returns a tensor of
    shape (1, n) filled with `n` draws from NumPy's global random generator.
    """
    def draw(n):
        gaussian = np.random.normal(loc=mu, scale=sigma, size=(1, n))
        return torch.Tensor(gaussian)

    return draw

def get_generator_input_sampler():
    """Build the noise source for the generator.

    The returned callable takes (m, n) and returns an m-by-n tensor of
    uniform samples from [0, 1) -- deliberately uniform, _NOT_ Gaussian.
    """
    def draw(m, n):
        return torch.rand(m, n)

    return draw

# ##### MODELS: Generator model and discriminator model

class Generator(nn.Module):
    """Three-layer fully connected network that turns noise into generated samples.

    Args:
        input_size: width of each noise row fed to the network.
        hidden_size: number of units in both hidden layers.
        output_size: width of each generated row.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map a batch of noise rows to generated rows.

        Layers: linear -> ELU -> linear -> sigmoid -> linear. The last layer is
        left un-squashed so the output is not confined to a fixed range.
        """
        x = F.elu(self.map1(x))
        # torch.sigmoid computes the same values as F.sigmoid, but several
        # PyTorch releases emit a deprecation warning on every F.sigmoid call.
        x = torch.sigmoid(self.map2(x))
        return self.map3(x)

class Discriminator(nn.Module):
    """Three-layer fully connected network that scores its input as real or fake.

    Args:
        input_size: width of each input row.
        hidden_size: number of units in both hidden layers.
        output_size: width of the score produced for each input row.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Score a batch of input rows.

        Layers: linear -> ELU -> linear -> ELU -> linear -> sigmoid, so every
        score lies strictly between 0 and 1.
        """
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        # torch.sigmoid computes the same values as F.sigmoid, but several
        # PyTorch releases emit a deprecation warning on every F.sigmoid call.
        return torch.sigmoid(self.map3(x))

def extract(v):
    """Return the elements of tensor `v` as a flat Python list, in row-major order.

    A 0-dimensional tensor (e.g. a scalar loss) yields a one-element list, so
    `extract(loss)[0]` gives the loss value.

    The values are read through the tensor itself rather than through its
    backing storage: the storage holds the whole underlying buffer, so for a
    slice or other view it would return elements that do not belong to `v`.
    """
    # detach() drops autograd tracking; reshape(-1) flattens (and turns a
    # 0-dim tensor into a 1-element one) before conversion to a list.
    return v.detach().reshape(-1).tolist()

def stats(d):
    """Summarise a collection of numbers as the two-element list [mean, standard deviation].

    Both figures come from NumPy's defaults (`np.mean` and `np.std`).
    """
    return [summary(d) for summary in (np.mean, np.std)]

def decorate_with_diffs(data, exponent):
    """Append to every row its element-wise deviations from the row mean, raised to `exponent`.

    Args:
        data: 2-D tensor; each row is treated as one set of samples.
        exponent: power applied to each (value - row mean) difference
            (2.0 gives squared deviations).

    Returns:
        A tensor with the same number of rows and twice the columns:
        the original values followed by the powered deviations.

    Each row is centred on its OWN mean via broadcasting. Reading a single
    scalar out of the mean tensor would apply row 0's mean to every row and
    give wrong deviations for multi-row input; for single-row input the
    result is the same either way.
    """
    # The mean is taken on a detached copy so it acts as a constant: no
    # gradient flows back through the mean, only through `data` itself.
    mean = torch.mean(data.detach(), 1, keepdim=True)
    # (rows, 1) broadcasts against (rows, cols) and inherits data's dtype/device.
    diffs = torch.pow(data - mean, exponent)
    return torch.cat([data, diffs], 1)

# Data sources: Gaussian 'real' samples for D, uniform noise for G.
d_sampler = get_distribution_sampler(data_mean, data_stddev)
gi_sampler = get_generator_input_sampler()

# The two competing networks. D's input width is passed through d_input_func
# so that it matches whatever `preprocess` produces.
G = Generator(g_input_size, g_hidden_size, g_output_size)
D = Discriminator(d_input_func(d_input_size), d_hidden_size, d_output_size)

# Binary cross entropy (http://pytorch.org/docs/nn.html#bceloss) scores D's
# real/fake verdicts; each network gets its own Adam optimizer.
criterion = nn.BCELoss()
d_optimizer = optim.Adam(D.parameters(), betas=optim_betas, lr=d_learning_rate)
g_optimizer = optim.Adam(G.parameters(), betas=optim_betas, lr=g_learning_rate)

# Adversarial training: each epoch runs d_steps discriminator updates followed
# by g_steps generator updates, and prints a progress line every print_interval
# epochs. The progress line reads values left behind by both inner loops, so
# d_steps and g_steps must each be at least 1.
#
# BCELoss expects its target to have the same size as its input. D's decision
# here is 2-D (one row, one score) while a target built with torch.ones(1) is
# 1-D; that mismatch is what triggers the "Please ensure they have the same
# size." warning, and newer PyTorch releases raise an error for it. Every
# target below is therefore built with the size of the decision it is scored
# against. The loss values themselves are unchanged.
for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. Train D on real+fake
        D.zero_grad()

        #  1A: Train D on real
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        real_target = Variable(torch.ones(d_real_decision.size()))  # ones = true
        d_real_error = criterion(d_real_decision, real_target)
        d_real_error.backward()  # compute/store gradients, but don't change params

        #  1B: Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        # G emits one sample per row; transpose so D sees the batch as a single row.
        d_fake_decision = D(preprocess(d_fake_data.t()))
        fake_target = Variable(torch.zeros(d_fake_decision.size()))  # zeros = fake
        d_fake_error = criterion(d_fake_decision, fake_target)
        d_fake_error.backward()
        d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()

    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        G.zero_grad()

        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        # we want to fool, so pretend it's all genuine
        fool_target = Variable(torch.ones(dg_fake_decision.size()))
        g_error = criterion(dg_fake_decision, fool_target)

        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters

    if epoch % print_interval == 0:
        print("%s: D: %s/%s G: %s (Real: %s, Fake: %s) " % (epoch,
                                                            extract(d_real_error)[0],
                                                            extract(d_fake_error)[0],
                                                            extract(g_error)[0],
                                                            stats(extract(d_real_data)),
                                                            stats(extract(d_fake_data))))

Using data [Data and variances]
0: D: 0.7591767311096191/0.6492283940315247 G: 0.7404654622077942 (Real: [3.7744895476102829, 1.2127879487501048], Fake: [-0.080349415764212603, 0.0040373059949561415]) 


  "Please ensure they have the same size.".format(target.size(), input.size()))


200: D: 0.000754163833335042/0.6369110345840454 G: 0.7691926956176758 (Real: [4.0486575376987455, 1.2539147953408929], Fake: [-0.39132367789745331, 0.0049814184192758003]) 
400: D: 5.614915062324144e-05/0.5187681913375854 G: 0.9356982707977295 (Real: [3.8863593947887423, 1.3437744502831639], Fake: [-0.38369377285242079, 0.0085378762138881292]) 
600: D: 6.675741587969242e-06/0.35793426632881165 G: 1.2138944864273071 (Real: [3.9529964721202848, 1.3199351706128957], Fake: [-0.38372327983379362, 0.056591286684082628]) 
800: D: 6.437321189878276e-06/0.19419550895690918 G: 1.9797745943069458 (Real: [3.9778831040859224, 1.0985574696608611], Fake: [-0.38926272854208949, 0.30557722775850882]) 
1000: D: 0.04351293295621872/0.015439733862876892 G: 3.9179484844207764 (Real: [3.9195833837985994, 1.1447051885462327], Fake: [1.303229666352272, 0.50094379356635221]) 
1200: D: 0.12069693207740784/0.19963468611240387 G: 2.3182003498077393 (Real: [4.237438430786133, 1.3047203523687139], Fake: [3.06352782

10000: D: 0.5729203224182129/0.49057695269584656 G: 1.0445129871368408 (Real: [4.145910779237747, 1.3416634060166293], Fake: [4.1204122048616405, 1.1044755617378577]) 
10200: D: 1.1420321464538574/0.7699006795883179 G: 0.7112780809402466 (Real: [4.0053596019744875, 1.1305459219223297], Fake: [3.9794211101531984, 1.3490466391922549]) 
10400: D: 0.7340941429138184/0.2624286115169525 G: 1.2078440189361572 (Real: [3.8517177972197532, 1.4292914224937217], Fake: [4.0255674493312839, 1.3381280381062433]) 
10600: D: 0.4878266155719757/0.4451640248298645 G: 0.9934609532356262 (Real: [4.1101989614963532, 1.3230360417616684], Fake: [4.0249018996953962, 1.2920780006085295]) 
10800: D: 0.8899770975112915/0.5163384675979614 G: 1.0826796293258667 (Real: [3.9148611101508139, 1.2682766245917632], Fake: [3.9188432061672209, 1.2643077708311175]) 
11000: D: 0.5683136582374573/0.6140703558921814 G: 0.9674574136734009 (Real: [4.0380168223381041, 1.1813273217305131], Fake: [4.0974157047271724, 1.169389793849

19800: D: 0.19312016665935516/0.4549640715122223 G: 1.6338828802108765 (Real: [4.0336783683300022, 1.0857848212896839], Fake: [3.9983774733543398, 1.1859897767005405]) 
20000: D: 0.33223336935043335/0.24320967495441437 G: 1.179128885269165 (Real: [3.9047317101061343, 1.2826183678407914], Fake: [4.0308008933067319, 1.0580101947826166]) 
20200: D: 0.2904634475708008/0.9970558285713196 G: 1.829633116722107 (Real: [3.9558619594573976, 1.2575629152182586], Fake: [4.111473888754845, 1.188649399466666]) 
20400: D: 0.2684444189071655/0.5881476402282715 G: 1.1585514545440674 (Real: [4.0228100979328154, 1.1084062458011388], Fake: [4.3453722763061524, 1.2790166034832122]) 
20600: D: 0.12829643487930298/0.7798513770103455 G: 0.41847172379493713 (Real: [3.9155438077449798, 1.3614451838848478], Fake: [4.4139680230617522, 1.3405579231725515]) 
20800: D: 0.07240618020296097/0.5143404603004456 G: 0.07385839521884918 (Real: [3.7377527683973311, 1.2162414621025703], Fake: [4.175955717563629, 1.1753697089

KeyboardInterrupt: 