# Generative Adversarial Networks (GAN) example in PyTorch.

In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

## Hyperparameters

In [6]:
# Data params
data_mean = 4
data_stddev = 1.25

# Model params
g_input_size = 1     # Random noise dimension coming into generator, per output vector
g_hidden_size = 50   # Generator complexity
g_output_size = 1    # size of generated output vector
d_input_size = 100   # Minibatch size - cardinality of distributions
d_hidden_size = 50   # Discriminator complexity
d_output_size = 1    # Single dimension for 'real' vs. 'fake'
minibatch_size = d_input_size

d_learning_rate = 2e-4  # 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)
num_epochs = 30000
print_interval = 200
d_steps = 1  # k >= 1
g_steps = 1

## Data loader

In [11]:
def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian

def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian

# preprocess, d_input_func = lambda data: data, lambda x: x)
preprocess, d_input_func = lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2

def normal(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n))) # gaussian noise

def entropy():
    return lambda m, n: torch.rand(m, n) # uniform noise

def extract(v):
    return v.data.storage().tolist()

def stats(d):
    return [np.mean(d), np.std(d)]

def decorate_with_diffs(data, exponent):
    mean = torch.mean(data.data, 1)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    return torch.cat([data, diffs], 1)

In [12]:
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.sigmoid(self.map2(x))
        return self.map3(x)

class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = F.elu(self.map1(x))
        x = F.elu(self.map2(x))
        return F.sigmoid(self.map3(x))

In [13]:
d_sampler = get_distribution_sampler(data_mean, data_stddev)
gi_sampler = get_generator_input_sampler()
G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)
D = Discriminator(input_size=d_input_func(d_input_size), hidden_size=d_hidden_size, output_size=d_output_size)
criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)
g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)

In [15]:
for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. Train D on real+fake
        D.zero_grad()

        #  1A: Train D on real
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        print(d_real_decision.data)
        d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))  # ones = true
        d_real_error.backward() # compute/store gradients, but don't change params

        #  1B: Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        d_fake_decision = D(preprocess(d_fake_data.t()))
        d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))  # zeros = fake
        d_fake_error.backward()
        d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()

    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        G.zero_grad()

        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))  # we want to fool, so pretend it's all genuine

        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters

    if epoch % print_interval == 0:
        print("%s: D: %s/%s G: %s (Real: %s, Fake: %s) " % (epoch,
                                                            extract(d_real_error)[0],
                                                            extract(d_fake_error)[0],
                                                            extract(g_error)[0],
                                                            stats(extract(d_real_data)),
                                                            stats(extract(d_fake_data))))


 0.9989
[torch.FloatTensor of size 1x1]

0: D: 0.0011032943148165941/0.031091857701539993 G: 3.8190414905548096 (Real: [4.0715127182006832, 1.2866877137199169], Fake: [0.41409335136413572, 0.17117764882776262]) 

 0.9993
[torch.FloatTensor of size 1x1]


 0.9953
[torch.FloatTensor of size 1x1]


 0.9938
[torch.FloatTensor of size 1x1]


 0.9998
[torch.FloatTensor of size 1x1]


 0.9985
[torch.FloatTensor of size 1x1]


 0.9992
[torch.FloatTensor of size 1x1]


 0.9996
[torch.FloatTensor of size 1x1]


 0.9991
[torch.FloatTensor of size 1x1]


 0.9975
[torch.FloatTensor of size 1x1]


 0.9989
[torch.FloatTensor of size 1x1]


 1.0000
[torch.FloatTensor of size 1x1]


 0.9988
[torch.FloatTensor of size 1x1]


 1.0000
[torch.FloatTensor of size 1x1]


 0.9958
[torch.FloatTensor of size 1x1]


 0.9984
[torch.FloatTensor of size 1x1]


 0.9986
[torch.FloatTensor of size 1x1]


 0.9685
[torch.FloatTensor of size 1x1]


 0.9994
[torch.FloatTensor of size 1x1]


 0.9974
[torch.FloatTensor of 


 0.9687
[torch.FloatTensor of size 1x1]


 0.9985
[torch.FloatTensor of size 1x1]


 0.9999
[torch.FloatTensor of size 1x1]


 1.0000
[torch.FloatTensor of size 1x1]


 0.9992
[torch.FloatTensor of size 1x1]


 1.0000
[torch.FloatTensor of size 1x1]


 0.9987
[torch.FloatTensor of size 1x1]


 0.9999
[torch.FloatTensor of size 1x1]


 0.9994
[torch.FloatTensor of size 1x1]


 0.9999
[torch.FloatTensor of size 1x1]


 0.9999
[torch.FloatTensor of size 1x1]


 1.0000
[torch.FloatTensor of size 1x1]


 0.9996
[torch.FloatTensor of size 1x1]


 0.9999
[torch.FloatTensor of size 1x1]


 0.9999
[torch.FloatTensor of size 1x1]


 0.9999
[torch.FloatTensor of size 1x1]


 0.9821
[torch.FloatTensor of size 1x1]


 0.9999
[torch.FloatTensor of size 1x1]


 1.0000
[torch.FloatTensor of size 1x1]


 0.9994
[torch.FloatTensor of size 1x1]


 0.9998
[torch.FloatTensor of size 1x1]


 0.9972
[torch.FloatTensor of size 1x1]


 0.9796
[torch.FloatTensor of size 1x1]


 0.9954
[torch.FloatTensor of siz


 0.9736
[torch.FloatTensor of size 1x1]


 0.9972
[torch.FloatTensor of size 1x1]


 0.9950
[torch.FloatTensor of size 1x1]


 0.9992
[torch.FloatTensor of size 1x1]


 0.9748
[torch.FloatTensor of size 1x1]


 0.9741
[torch.FloatTensor of size 1x1]


 0.9519
[torch.FloatTensor of size 1x1]


 0.2297
[torch.FloatTensor of size 1x1]


 0.9318
[torch.FloatTensor of size 1x1]


 0.6868
[torch.FloatTensor of size 1x1]


 0.9629
[torch.FloatTensor of size 1x1]


 0.9417
[torch.FloatTensor of size 1x1]


 0.9652
[torch.FloatTensor of size 1x1]


 0.9692
[torch.FloatTensor of size 1x1]


 0.9608
[torch.FloatTensor of size 1x1]


 0.9883
[torch.FloatTensor of size 1x1]


 0.9396
[torch.FloatTensor of size 1x1]


 0.9838
[torch.FloatTensor of size 1x1]


 0.9506
[torch.FloatTensor of size 1x1]


 0.9451
[torch.FloatTensor of size 1x1]


 0.8913
[torch.FloatTensor of size 1x1]


 0.9563
[torch.FloatTensor of size 1x1]


 0.9128
[torch.FloatTensor of size 1x1]


 0.8647
[torch.FloatTensor of siz

KeyboardInterrupt: 