In [1]:
#!/usr/bin/env python

# Generative Adversarial Networks (GAN) example in PyTorch.
# See related blog post at https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.sch4xgsa9
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [2]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  
os.environ["CUDA_VISIBLE_DEVICES"]="6" # change 0  with whatever card is available

In [3]:
# Data params
data_mean = 4
data_stddev = 1.25

# Model params
g_input_size = 1     # Random noise dimension coming into generator, per output vector
g_hidden_size = 50   # Generator complexity
g_output_size = 1    # size of generated output state vector

m_input_size = 1   # State dimension coming into model, per output vector
m_hidden_size = 50   # Model complexity
m_output_size = 1    # size of generated output vector

minibatch_size = 200

g_learning_rate = 2e-4  # 2e-4
m_learning_rate = 2e-4  # 2e-4

optim_betas = (0.9, 0.999)
num_epochs = 50000
print_interval = 200
m_steps = 1  # 'k' steps in the original GAN paper. Can put the discriminator on higher training freq than generator
g_steps = 1

# # ### Uncomment only one of these
# #(name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
# (name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)

# print("Using data [%s]" % (name))

In [4]:
def l1(x, y):
    loss = (x-y).pow(2).sum()
    return loss

def nl1(x, y):
    loss = -((x-y).pow(2).sum())
    return loss

In [None]:
# ##### DATA: Target data and generator input data
def exact_model(mu, sigma):
    return lambda x: torch.exp(-((x-mu)**2)/sigma**2)/np.sqrt(6.28)/sigma  # Gaussian

def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian

# ##### MODELS: Generator model and discriminator model

class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.map1(x)
        x = F.elu(self.map2(x))
        return self.map3(x)

class Model(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Model, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        
        return self.map3(x)

# class Discriminator(nn.Module):
#     def __init__(self, input_size, hidden_size, output_size):
#         super(Discriminator, self).__init__()
#         self.map1 = nn.Linear(input_size, hidden_size)
#         self.map2 = nn.Linear(hidden_size, hidden_size)
#         self.map3 = nn.Linear(hidden_size, output_size)

#     def forward(self, x):
#         x = F.elu(self.map1(x))
#         x = F.elu(self.map2(x))
#         return F.sigmoid(self.map3(x))

def extract(v):
    return v.data.storage().tolist()


def stats(d):
    return [np.mean(d), np.std(d)]

def decorate_with_diffs(data, exponent):
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    return torch.cat([data, diffs], 1)

In [None]:
exact_sampler = exact_model(data_mean, data_stddev)
gi_sampler = get_generator_input_sampler()

G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)
M = Model(input_size=m_input_size, hidden_size=m_hidden_size, output_size=m_output_size)

g_criterion = nl1  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
consta = Variable(torch.Tensor([-1]).double())
m_criterion = l1  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss

g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)
m_optimizer = optim.Adam(M.parameters(), lr=m_learning_rate, betas=optim_betas)

for epoch in range(num_epochs):

    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        G.zero_grad()
        M.zero_grad()
        
        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        
        states = G(gen_input)
        fake_data = M(states)
        true_data = exact_sampler(states)
        
        g_vars = G.parameters()
        g_error = g_criterion(fake_data, true_data)  # we want to fool, so pretend it's all genuine
        g_error.backward(retain_graph=True)
        g_optimizer.step()  # Only optimizes G's parameters

        m_vars = M.parameters()
        m_error = m_criterion(fake_data, true_data)  # we want to fool, so pretend it's all genuine
        m_error.backward()
        m_optimizer.step()  # Only optimizes G's parameters

    if epoch % print_interval == 0:
        print("%s L1: G:%s M:%s \n%s \n mu, std (Real: %s, Fake: %s) " % (epoch,
                                                            g_error,
                                                            m_error,
                                                            states,
                                                            stats(extract(true_data)),
                                                            stats(extract(fake_data))))

0 L1: G:Variable containing:
-0.5067
[torch.FloatTensor of size 1]
 M:Variable containing:
 0.5067
[torch.FloatTensor of size 1]
 
Variable containing:
 3.2136e-02
 3.4872e-02
 8.1771e-02
 6.4766e-03
 2.8789e-02
 2.8675e-02
 9.8288e-02
-8.3785e-03
 6.0016e-03
 1.6382e-02
-2.3044e-03
 8.6832e-02
-2.1289e-03
 7.0945e-02
-5.7542e-03
-3.4859e-04
 4.5356e-02
 8.4326e-02
 1.0606e-02
-3.9366e-03
 5.0774e-02
 3.4312e-02
 4.9893e-02
 1.0243e-01
-4.0676e-03
 1.0016e-01
 3.9299e-02
 4.7989e-02
 8.6637e-02
 8.8312e-02
 1.9611e-02
 5.6707e-02
 1.6018e-03
 4.6439e-02
 2.4045e-02
 7.5301e-02
 1.6362e-02
 9.0811e-03
 3.7990e-02
 1.9693e-02
-6.4602e-03
 4.1787e-02
 5.4031e-02
 7.4609e-02
 7.9838e-02
-4.2650e-03
 5.3019e-02
 2.7853e-02
 2.1519e-02
 1.1428e-02
-1.0780e-03
 5.3239e-02
 8.0709e-02
 2.6485e-02
 2.6820e-02
 7.8316e-02
 3.0489e-02
 4.7395e-02
 2.3149e-02
 9.5153e-02
 4.5277e-02
 8.6480e-02
 2.3381e-02
 3.7431e-02
-6.7275e-03
 4.7602e-03
 9.2864e-02
-5.7677e-03
 6.5141e-02
 3.2125e-02
-1.8179e

800 L1: G:Variable containing:
-9.4138e+10
[torch.FloatTensor of size 1]
 M:Variable containing:
 9.4138e+10
[torch.FloatTensor of size 1]
 
Variable containing:
1.00000e+05 *
 -2.0411
 -2.1077
 -1.4068
 -1.9833
 -1.3233
 -1.4714
 -1.9495
 -1.9078
 -1.3715
 -2.1699
 -1.8244
 -1.6546
 -1.7960
 -1.5254
 -2.1345
 -2.1576
 -2.1247
 -1.9277
 -2.0676
 -1.5711
 -1.9934
 -2.1516
 -2.0431
 -2.0349
 -1.1380
 -1.2359
 -1.7919
 -1.9470
 -2.0496
 -1.6674
 -1.1244
 -1.3446
 -1.2297
 -2.1261
 -1.7902
 -1.9844
 -2.1802
 -1.3345
 -2.0702
 -1.9779
 -1.3827
 -2.0065
 -1.6533
 -1.3882
 -1.4580
 -2.1789
 -1.2747
 -1.3620
 -2.1036
 -1.5773
 -1.2143
 -2.0437
 -1.7857
 -1.3518
 -1.6010
 -1.9098
 -1.8655
 -1.4963
 -1.3217
 -1.2702
 -1.9190
 -1.8112
 -1.3193
 -1.6637
 -1.9139
 -1.3479
 -1.4621
 -1.6083
 -1.3814
 -2.0784
 -1.2702
 -1.6491
 -1.6593
 -1.5887
 -1.6525
 -1.6197
 -1.7240
 -1.6077
 -1.8567
 -1.9744
 -2.0444
 -1.4992
 -1.6504
 -2.0772
 -1.9902
 -2.0002
 -1.9325
 -1.9743
 -1.7966
 -1.7645
 -1.1829
 -1.8

1600 L1: G:Variable containing:
-4.8214e+12
[torch.FloatTensor of size 1]
 M:Variable containing:
 4.8214e+12
[torch.FloatTensor of size 1]
 
Variable containing:
1.00000e+06 *
 -1.4358
 -1.2467
 -1.5715
 -1.2286
 -1.0606
 -1.3005
 -1.1837
 -0.8223
 -0.9894
 -1.3156
 -1.5535
 -1.1641
 -0.9413
 -1.5635
 -0.9916
 -1.0026
 -1.3409
 -0.8130
 -1.5647
 -0.8975
 -1.4550
 -0.9710
 -1.1281
 -1.5275
 -1.4763
 -1.5698
 -1.1033
 -1.3851
 -1.0718
 -1.5819
 -1.3922
 -1.5134
 -1.4612
 -1.5783
 -0.8437
 -1.0167
 -1.4394
 -1.2346
 -0.8665
 -1.0067
 -0.8659
 -0.9103
 -0.9194
 -1.3324
 -1.5288
 -1.0671
 -1.1260
 -1.5906
 -1.1770
 -1.0244
 -1.4655
 -1.5420
 -0.9299
 -1.1839
 -1.3414
 -0.9065
 -1.4973
 -1.0336
 -1.2095
 -0.9050
 -1.5280
 -1.0451
 -1.5217
 -0.9621
 -1.1796
 -0.8927
 -1.1898
 -0.8207
 -1.4983
 -1.5312
 -1.1484
 -1.6012
 -1.5454
 -0.9877
 -0.8668
 -0.9957
 -0.9879
 -1.1279
 -1.1531
 -0.8438
 -1.1082
 -1.2743
 -1.1016
 -1.2849
 -1.0066
 -1.0370
 -1.4428
 -1.3839
 -1.1098
 -1.5978
 -1.2868
 -0.

2400 L1: G:Variable containing:
-3.9879e+13
[torch.FloatTensor of size 1]
 M:Variable containing:
 3.9879e+13
[torch.FloatTensor of size 1]
 
Variable containing:
1.00000e+06 *
 -4.4171
 -3.9589
 -4.3732
 -2.8904
 -2.3382
 -4.5675
 -3.6888
 -4.2594
 -3.2786
 -3.7460
 -2.6871
 -3.6091
 -2.7878
 -3.8046
 -4.2516
 -2.7675
 -2.6098
 -4.3603
 -4.2507
 -2.7544
 -2.5701
 -4.6084
 -2.6247
 -3.9484
 -3.8011
 -3.1214
 -2.9959
 -3.7121
 -2.5220
 -3.3506
 -2.5286
 -3.3555
 -3.2577
 -3.1453
 -3.4511
 -2.7718
 -3.8294
 -2.7342
 -4.1029
 -4.0799
 -4.5810
 -4.4249
 -2.8681
 -3.0946
 -4.4434
 -4.4272
 -3.0816
 -2.9079
 -3.2002
 -4.0232
 -4.3605
 -2.8410
 -3.3017
 -2.9750
 -2.8164
 -3.1785
 -3.3715
 -3.0615
 -4.6197
 -4.4058
 -4.1391
 -3.7012
 -2.6272
 -2.8374
 -4.5097
 -3.6285
 -3.2314
 -2.4806
 -3.4954
 -3.4135
 -3.2549
 -4.0561
 -2.3418
 -3.0599
 -4.3606
 -2.6539
 -2.9594
 -2.5776
 -2.3897
 -4.0651
 -2.6727
 -3.0532
 -2.8701
 -4.4987
 -4.0043
 -3.0527
 -3.3474
 -2.8460
 -3.1386
 -3.6287
 -4.5408
 -3.

In [None]:
import matplotlib.pyplot as plt
x = G(Variable(torch.FloatTensor(np.array(np.linspace(-30, 30, num=500))).view(-1,1)))
y = G(x)
plt.plot(y.data.numpy())

In [None]:
import matplotlib.pyplot as plt
plt.plot(M(Variable(torch.Tensor(np.linspace(-3,10,1000)).view(-1,1).float())).data.numpy())
plt.plot(exact_sampler(Variable(torch.Tensor(np.linspace(-3,10,1000)).view(-1,1).float())).data.numpy())
plt.show()