# Generative Adversarial Networks (GAN) example in PyTorch.

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

## Hyperparameters

In [2]:
# --- Target distribution: the Gaussian the generator must learn to imitate ---
data_mean = 4
data_stddev = 1.25

# --- Generator dimensions ---
g_input_size = 1     # width of the random noise vector fed in per generated value
g_hidden_size = 50   # hidden-layer width (generator capacity)
g_output_size = 1    # width of each generated output vector

# --- Discriminator dimensions ---
# The discriminator looks at a whole minibatch of samples at once, so the
# number of samples it receives and the minibatch size are the same number.
minibatch_size = 100
d_input_size = minibatch_size
d_hidden_size = 50   # hidden-layer width (discriminator capacity)
d_output_size = 1    # one value: 'real' vs. 'fake'

# --- Optimisation ---
d_learning_rate = 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)  # passed as `betas` to Adam

# --- Training schedule ---
num_epochs = 30000
print_interval = 200  # report losses/statistics every this many epochs
d_steps = 1           # discriminator updates per epoch (k >= 1)
g_steps = 1           # generator updates per epoch

In [None]:
## Data: preprocessing choice, real-data sampler and generator-input sampler

In [3]:
# Choose ONE (preprocess, d_input_func) pair.
#   preprocess   -- applied to every batch before it is handed to the discriminator.
#   d_input_func -- maps the raw sample count to the discriminator's input width
#                   after preprocessing.
#
# Option A: raw data only (input width unchanged).
# preprocess, d_input_func = (lambda data: data, lambda x: x)
#
# Option B: raw data plus each sample's squared deviation from the batch mean;
# the extra columns double the input width.
preprocess, d_input_func = (lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)

def get_distribution_sampler(mu, sigma):
    """Build a sampler for the 'real' data distribution.

    Returns a function that takes a sample count ``n`` and produces a
    ``(1, n)`` tensor of draws from a Gaussian with mean ``mu`` and standard
    deviation ``sigma`` (drawn with NumPy's global random state).
    """
    def sample(n):
        draws = np.random.normal(mu, sigma, (1, n))
        return torch.Tensor(draws)

    return sample

def get_generator_input_sampler():
    """Build the noise source that feeds the generator.

    Returns a function ``(m, n) -> tensor`` producing an ``m x n`` tensor of
    uniform samples from ``torch.rand``. The noise is deliberately uniform,
    _NOT_ Gaussian.
    """
    def sample(m, n):
        return torch.rand(m, n)

    return sample

Using data [Data and variances]


In [4]:
# ##### MODELS: Generator model and discriminator model

class Generator(nn.Module):
    """Generator network: three fully connected layers.

    The first hidden layer uses an ELU activation, the second a sigmoid, and
    the final layer is linear (no activation), so outputs are unbounded.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        # Layers are created in map1 -> map2 -> map3 order.
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map noise ``x`` (last dimension ``input_size``) to generated values."""
        hidden = self.map1(x)
        hidden = F.elu(hidden)
        hidden = self.map2(hidden)
        hidden = F.sigmoid(hidden)
        return self.map3(hidden)

class Discriminator(nn.Module):
    """Discriminator network: three fully connected layers.

    Both hidden layers use ELU activations; the output layer is squashed with
    a sigmoid, so every output lies in (0, 1).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        # Layers are created in map1 -> map2 -> map3 order.
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Score ``x`` (last dimension ``input_size``) with a sigmoid output."""
        hidden = self.map1(x)
        hidden = F.elu(hidden)
        hidden = self.map2(hidden)
        hidden = F.elu(hidden)
        logits = self.map3(hidden)
        return F.sigmoid(logits)

def extract(v):
    """Return the elements of ``v`` as a flat Python list.

    ``v`` is a tensor (or Variable); its values are read through ``v.data`` so
    no autograd history is involved. Elements are listed in logical
    (row-major) order, and a zero-dimensional tensor yields a one-element list.

    Only the elements that belong to ``v`` are returned. Reading the backing
    storage instead would also return neighbouring elements whenever ``v`` is
    a slice or view of a larger tensor.
    """
    # contiguous() makes the flattening valid for non-contiguous views too.
    return v.data.contiguous().view(-1).tolist()

def stats(d):
    """Summarise a collection of numbers as ``[mean, standard deviation]``.

    Both values come from NumPy (``np.mean`` and ``np.std``, i.e. the
    population standard deviation with NumPy's default ``ddof``).
    """
    mean_value = np.mean(d)
    std_value = np.std(d)
    return [mean_value, std_value]

def decorate_with_diffs(data, exponent):
    """Append each value's powered deviation from its row mean.

    Args:
        data: 2-D tensor/Variable with samples laid out along dimension 1.
        exponent: power applied to ``data - row_mean`` (2.0 gives squared
            deviations).

    Returns:
        ``data`` with the deviation columns concatenated along dimension 1,
        so the result is twice as wide as the input.

    The mean is taken from ``data.data``, so it is treated as a constant and
    gradients flow only through ``data`` itself.
    """
    # torch.mean(..., 1) may or may not keep the reduced dimension depending
    # on the PyTorch version; view(-1, 1) normalises both cases to one
    # column of per-row means. (Indexing the result as [0][0] fails when the
    # dimension is dropped, and would use row 0's mean for every row.)
    row_means = torch.mean(data.data, 1).view(-1, 1)
    mean_broadcast = row_means.expand_as(data)
    diffs = torch.pow(data - Variable(mean_broadcast), exponent)
    return torch.cat([data, diffs], 1)

In [5]:
# Data sources: real samples for the discriminator, noise for the generator.
gi_sampler = get_generator_input_sampler()
d_sampler = get_distribution_sampler(data_mean, data_stddev)

# Networks. The discriminator's input width is the sample count passed
# through d_input_func, which accounts for any columns added by preprocess.
G = Generator(g_input_size, g_hidden_size, g_output_size)
D = Discriminator(d_input_func(d_input_size), d_hidden_size, d_output_size)

# Binary cross entropy between the discriminator's output and the real/fake label.
criterion = nn.BCELoss()

# One Adam optimizer per network, so each step() only touches its own parameters.
g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)
d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)

In [6]:
# Adversarial training: each epoch runs d_steps discriminator updates followed
# by g_steps generator updates, and reports progress every print_interval epochs.
#
# BCELoss requires the target to have the same shape as the prediction. D is
# always called on a single (1, features) row and ends in a d_output_size-wide
# layer, so the labels are built once as (1, d_output_size) rather than (1,).
real_label = Variable(torch.ones(1, d_output_size))   # ones = true
fake_label = Variable(torch.zeros(1, d_output_size))  # zeros = fake

for epoch in range(num_epochs):
    for d_index in range(d_steps):
        # 1. Train D on real+fake
        # Also clears gradients that the generator step left on D's parameters.
        D.zero_grad()

        #  1A: Train D on real
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(preprocess(d_real_data))
        d_real_error = criterion(d_real_decision, real_label)
        d_real_error.backward()  # compute/store gradients, but don't change params

        #  1B: Train D on fake
        d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
        # G emits (minibatch, 1); transpose to the (1, minibatch) row D expects.
        d_fake_decision = D(preprocess(d_fake_data.t()))
        d_fake_error = criterion(d_fake_decision, fake_label)
        d_fake_error.backward()
        # Only optimizes D's parameters; uses the gradients accumulated by both backward() calls.
        d_optimizer.step()

    for g_index in range(g_steps):
        # 2. Train G on D's response (but DO NOT train D on these labels)
        G.zero_grad()

        gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(preprocess(g_fake_data.t()))
        # We want to fool D, so the generated batch is labelled as genuine.
        g_error = criterion(dg_fake_decision, real_label)

        g_error.backward()
        g_optimizer.step()  # Only optimizes G's parameters

    if epoch % print_interval == 0:
        # Losses from the last D/G step of this epoch, plus [mean, std] of the
        # last real batch and the last generated batch shown to D.
        print("%s: D: %s/%s G: %s (Real: %s, Fake: %s) " % (epoch,
                                                            extract(d_real_error)[0],
                                                            extract(d_fake_error)[0],
                                                            extract(g_error)[0],
                                                            stats(extract(d_real_data)),
                                                            stats(extract(d_fake_data))))

0: D: 0.6981781125068665/0.755477786064148 G: 0.6416189074516296 (Real: [4.0547098523378375, 1.2314252959340828], Fake: [-0.28511309772729876, 0.0099985404669034675]) 
200: D: 0.0009525781497359276/0.451774924993515 G: 1.0012849569320679 (Real: [3.6725147613883018, 1.3844103867078239], Fake: [0.14209336966276168, 0.063836321614812092]) 
400: D: 0.006096851080656052/0.33512821793556213 G: 1.2866592407226562 (Real: [4.0221354687213902, 1.2193185105708324], Fake: [0.080465802438557152, 0.11968505218068264]) 
600: D: 0.0010952985612675548/0.3185677230358124 G: 1.268117070198059 (Real: [4.0592756074666978, 1.307170498415837], Fake: [-0.045618432685732839, 0.26553063354745349]) 
800: D: 0.04345858842134476/0.030353741720318794 G: 3.019766092300415 (Real: [4.1049610841274262, 1.0628556763952564], Fake: [0.87872730329632764, 0.41226892261950321]) 
1000: D: 0.013587034307420254/0.14907453954219818 G: 2.6867165565490723 (Real: [4.2136442089080814, 1.1570395098039838], Fake: [2.0033200972899796, 

10000: D: 0.5003601312637329/1.172639012336731 G: 0.37165704369544983 (Real: [3.9852352273464202, 1.1590077681765674], Fake: [4.0687204885482791, 1.229554758161314]) 
10200: D: 0.9755654335021973/0.7196236252784729 G: 0.8616246581077576 (Real: [4.0175358653068542, 1.1447951729589665], Fake: [3.9511641836166382, 1.3466961627958391]) 
10400: D: 0.37218981981277466/0.7939143180847168 G: 0.7537536025047302 (Real: [3.8593139111995698, 1.2168616469643259], Fake: [3.9173891496658326, 1.1199848083709121]) 
10600: D: 0.8577403426170349/0.6049784421920776 G: 0.618677020072937 (Real: [4.0312767395377156, 1.2784476278844978], Fake: [4.3731682640314098, 1.2742238975508162]) 
10800: D: 0.6766371130943298/0.6346055865287781 G: 0.7929967641830444 (Real: [3.855501775741577, 1.3379375899211765], Fake: [4.2419640529155735, 0.97599873996567366]) 
11000: D: 0.8232480883598328/0.685746967792511 G: 0.8170490264892578 (Real: [3.9466886144876479, 1.222878085185122], Fake: [3.9818111205101014, 1.149805283807306

19800: D: 0.5665174722671509/0.25150051712989807 G: 1.6446659564971924 (Real: [4.0396297061443329, 1.2025047058152705], Fake: [3.8883231937885285, 1.3786591640294428]) 
20000: D: 0.12129618227481842/0.1923273652791977 G: 2.1840627193450928 (Real: [4.0129013103246685, 1.3621970283850597], Fake: [4.1223703467845914, 1.3159434614145862]) 
20200: D: 0.055870864540338516/0.11196769028902054 G: 2.5396525859832764 (Real: [3.9606452202796936, 1.1675076246315217], Fake: [3.8837533509731292, 1.3205521437912644]) 
20400: D: 0.7079916000366211/0.058159612119197845 G: 1.886824607849121 (Real: [3.9738567399978639, 1.1051662760210235], Fake: [4.0049975240230564, 1.2540673828517888]) 
20600: D: 0.03515109792351723/0.05685963109135628 G: 1.3269304037094116 (Real: [4.0202750995755192, 1.3124802611963766], Fake: [4.098667097091675, 1.2341350539348426]) 
20800: D: 0.0010486376704648137/0.575573742389679 G: 1.452439785003662 (Real: [3.905985063314438, 1.2034626258637202], Fake: [4.2762358474731448, 1.10084

29600: D: 0.0010859303874894977/0.0009414429077878594 G: 10.873520851135254 (Real: [3.7867905431985855, 1.3014709660666599], Fake: [11.122827396392822, 1.5762867256918673]) 
29800: D: 4.202215495752171e-05/9.479149554181276e-08 G: 15.285253524780273 (Real: [4.0204957526922227, 1.3140525536356058], Fake: [13.194394207000732, 1.881389211628145]) 
