# GAN example with PyTorch
This is intended to be a concise example of Generative Adversarial Networks, therefore no deep explanation is given.

In GAN architecture, there are two Neural Networks competing against each other. The first is called generator (we'll give it the $G$ label) and it simply generate new data. The second is a discriminative network (we'll label it as $D$) that is a common classificator.

The generator $G$ goal is to fool the discriminator $D$, trying to make it think the generated data are real. The discriminator will try to classify incoming data as real or fake. This is how the $G$ and $D$ compete. 

Code taken from [this article](https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f)

In [28]:
# importing libraries
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

**1.** First we create the function to generate our dataset.

In [29]:
def get_distribution_sampler(mu, sigma):
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))

**2.** Next we prepare function that creates the input that goes to the generator.

In [30]:
def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)

**3.** Next, we create our generator (which is a standard FeedFoward network).

Here, `nn.Linear` is a linear layer from PyTorch that applies a linear transformation in the data: $y = xA^T + b$.

In [31]:
class Generator(nn.Module):
    
    """
    Defines the generator network, G.
    """
    
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f
        
    def forward(self, x):
        x = self.f(self.map1(x))
        x = self.f(self.map2(x))
        return self.f(self.map3(x))

**4.** Now we create our discriminator (also, a FeedFoward network). Note that the code is the same.

In [32]:
class Discriminator(nn.Module):
    
    """
    Defines the discriminator network, D.
    """
    
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f
        
    def forward(self, x):
        x = self.f(self.map1(x))
        x = self.f(self.map2(x))
        return self.f(self.map3(x))

**5.** Finally, we create the training loop.

To start the loop, we need to set some things up:

In [33]:
def get_moments(d):
    """
    Return the first 4 moments of the data provided
    For more information about moments check the link:
    https://www.thoughtco.com/what-are-moments-in-statistics-3126234
    """
    
    mean = torch.mean(d)
    diffs = d - mean
    var = torch.mean(torch.pow(diffs, 2.0))
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0))
    kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0  # excess kurtosis, should be 0 for Gaussian
    final = torch.cat((mean.reshape(1,), std.reshape(1,), skews.reshape(1,), kurtoses.reshape(1,)))
    return final

def stats(d):
    """
    Returns mean and stdev of given data vector 
    """
    
    return [np.mean(d), np.std(d)]

def extract(v):
    """
    """
    
    return v.data.storage().tolist()

# Data params
data_mean = 4
data_stdev = 1.25

g_input_size = 1                # Random noise dimension coming into generator, per output vector
g_hidden_size = 5               # Generator hidden layers size
g_output_size = 1               # Size of generated output vector

d_input_size = 500              # Input size coming into discriminator
d_hidden_size = 10              # Discriminator hidden layers size
d_output_size = 1               # Single dimension for 'real' vs. 'fake' classification

epochs = 5000                   # Training epochs
d_steps = 20                    # Train steps for D
g_steps = 20                    # Train steps for G

minibatch_size = d_input_size   #

# Get samplers
g_sampler = get_generator_input_sampler()
d_sampler = get_distribution_sampler(data_mean, data_stdev)

# Instantiates generator
G = Generator(input_size=g_input_size,
                  hidden_size=g_hidden_size,
                  output_size=g_output_size,
                  f=torch.tanh)

# Instantiates discriminator
D = Discriminator(input_size=d_input_size,
                  hidden_size=d_hidden_size,
                  output_size=d_output_size,
                  f=torch.sigmoid)

d_learning_rate = 1e-3          # D learning rate
g_learning_rate = 1e-3          # G learning rate
sgd_momentum = 0.9              #
    
# Define optimizers for D and G
d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)
g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)

# Define the cost function as the
# binary cross-entropy function.
cost = nn.BCELoss()

dfe, dre, ge = 0, 0, 0
d_real_data, d_fake_data, g_fake_data = None, None, None

print_interval = 100            # defines a regular interval to print info and stats

Then, we perform the actually training:

In [36]:
for epoch in range(epochs):
    
    # Discriminator loop
    for d_index in range(d_steps):
        D.zero_grad()
        
        # Train D on real data
        d_real_data = Variable(d_sampler(d_input_size))
        d_real_decision = D(d_real_data)
        d_real_error = cost(d_real_decision, Variable(torch.ones([1, 1])))
        d_real_error.backward()
        
        # Train D on fake data
        d_gen_input = Variable(g_sampler(minibatch_size, g_input_size))
        d_fake_data = G(d_gen_input).detach()  # detach to avoid G training with these data
        d_fake_decision = D(d_fake_data.t())
        d_fake_error = cost(d_fake_decision, Variable(torch.ones([1, 1])))
        d_fake_error.backward()
        d_optimizer.step() # chages D's parameters based on backpropagation results
        dre = extract(d_real_error)[0]
        dfe = extract(d_fake_error)[0]
        
    # Generator loop
    for g_index in range(g_steps):
        G.zero_grad()
        
        gen_input = Variable(g_sampler(minibatch_size, g_input_size))
        g_fake_data = G(gen_input)
        dg_fake_decision = D(g_fake_data.t())
        g_error = cost(dg_fake_decision, Variable(torch.ones([1, 1])))
        
        g_error.backward()
        g_optimizer.step()
        ge = extract(g_error.data)[0]
        
    if epoch % print_interval == 0:
        print("Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s),  Fake Dist (%s) " %
             (epoch, dre, dfe, ge, stats(extract(d_real_data)), stats(extract(d_fake_data))))
        

Epoch 0: D (0.20092250406742096 real_err, 0.2159106582403183 fake_err) G (0.21251614391803741 err); Real Dist ([4.00015239238739, 1.2839091444746191]),  Fake Dist ([0.1546692524254322, 0.058713997862042185]) 
Epoch 100: D (0.0032095503993332386 real_err, 0.004487195517867804 fake_err) G (0.004470251966267824 err); Real Dist ([4.075904984116554, 1.2635852427961882]),  Fake Dist ([0.3809716682434082, 0.04731612615002367]) 
Epoch 200: D (0.0015005397144705057 real_err, 0.0020566259045153856 fake_err) G (0.002046412555500865 err); Real Dist ([3.996270431637764, 1.1868401026613342]),  Fake Dist ([0.44569210189580916, 0.04370137079643329]) 
Epoch 300: D (0.000965942395851016 real_err, 0.0013009423855692148 fake_err) G (0.0013062540674582124 err); Real Dist ([3.967964541912079, 1.2266390775962484]),  Fake Dist ([0.47950993180274964, 0.0407742171730421]) 
Epoch 400: D (0.0007097855559550226 real_err, 0.0009428533958271146 fake_err) G (0.0009446432231925428 err); Real Dist ([4.01204094479233, 1

Epoch 3800: D (6.687864515697584e-05 real_err, 8.023106784094125e-05 fake_err) G (8.023106784094125e-05 err); Real Dist ([4.004149026274681, 1.3157129248810608]),  Fake Dist ([0.6236022375822067, 0.03168649866891332]) 
Epoch 3900: D (6.520960596390069e-05 real_err, 7.820434984751046e-05 fake_err) G (7.820434984751046e-05 err); Real Dist ([4.076516931891441, 1.251320375266635]),  Fake Dist ([0.6250120311975479, 0.03193548149316131]) 
Epoch 4000: D (6.342135020531714e-05 real_err, 7.605842256452888e-05 fake_err) G (7.593919872306287e-05 err); Real Dist ([4.008521069377661, 1.2833013158426176]),  Fake Dist ([0.6237312965393066, 0.03171582538349023]) 
Epoch 4100: D (6.1871534853708e-05 real_err, 7.39124880055897e-05 fake_err) G (7.41509284125641e-05 err); Real Dist ([4.022280661225319, 1.2325043479738893]),  Fake Dist ([0.6261633501052857, 0.031325350940422]) 
Epoch 4200: D (6.032171950209886e-05 real_err, 7.212421769509092e-05 fake_err) G (7.224344153655693e-05 err); Real Dist ([4.0165290