In [1]:
#!/usr/bin/env python

# Generative Adversarial Networks (GAN) example in PyTorch. Tested with PyTorch 0.4.1, Python 3.6.7 (Nov 2018)
# See related blog post at https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.sch4xgsa9

from py_functions import fivenum

import sys
import numpy as np
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

t = time.time()


# seeds
seed = 123
torch.manual_seed(seed)
np.random.seed(seed=seed)

# Data params
data_mean = 4
data_stddev = 1.25

# ### Uncomment only one of these to define what data is actually sent to the Discriminator
# (name, preprocess, d_input_func) = ("Raw data", lambda data: data, lambda x: x)
# (name, preprocess, d_input_func) = ("Data and variances", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)
# (name, preprocess, d_input_func) = ("Data and diffs", lambda data: decorate_with_diffs(data, 1.0), lambda x: x * 2)
(name, preprocess, d_input_func) = ("Only 4 moments", lambda data: get_moments(data), lambda x: 4)

print("Using data [%s]" % (name))

# ##### DATA: Target data and generator input data


def get_distribution_sampler(mu, sigma):
    # Bell curve
    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian


def get_generator_input_sampler():
    return lambda m, n: torch.rand(m, n)   # Uniform-dist data into generator, _NOT_ Gaussian


# ##### MODELS: Generator model and discriminator model
class Generator(nn.Module):
    # Two hidden layers
    # Three linear maps
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.map1(x)
        x = self.f(x)
        x = self.map2(x)
        x = self.f(x)
        x = self.map3(x)
        return x


class Discriminator(nn.Module):
    # also two hidden layer and three linear maps
    def __init__(self, input_size, hidden_size, output_size, f):
        super(Discriminator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.f(self.map1(x))
        x = self.f(self.map2(x))
        return self.f(self.map3(x))


def extract(v):
    return v.data.storage().tolist()


def stats(d):
    return [np.mean(d), np.std(d)]


def get_moments(d):
    # Return the first 4 moments of the data provided
    mean = torch.mean(d)
    diffs = d - mean
    var = torch.mean(torch.pow(diffs, 2.0))
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0))
    kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0  # excess kurtosis, should be 0 for Gaussian
    final = torch.cat((mean.reshape(1,), std.reshape(1,), skews.reshape(1,), kurtoses.reshape(1,)))
    return final


# def decorate_with_diffs(data, exponent, remove_raw_data=False):
#     mean = torch.mean(data.data, 1, keepdim=True)
#     mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
#     diffs = torch.pow(data - Variable(mean_broadcast), exponent)
#     if remove_raw_data:
#         return torch.cat([diffs], 1)
#     else:
#         return torch.cat([data, diffs], 1)


def train():
    # Model parameters
    g_input_size = 1      # Random noise dimension coming into generator, per output vector
    g_hidden_size = 5     # Generator complexity
    g_output_size = 1     # Size of generated output vector
    d_input_size = 500    # Minibatch size - cardinality of distributions
    d_hidden_size = 10    # Discriminator complexity
    d_output_size = 1     # Single dimension for 'real' vs. 'fake' classification
    minibatch_size = d_input_size

    d_learning_rate = 1e-3
    g_learning_rate = 1e-3
    sgd_momentum = 0.9

    num_epochs = 5000
    print_interval = 100
    d_steps = 10
    g_steps = 10

    dfe, dre, ge = 0, 0, 0
    d_real_data, d_fake_data, g_fake_data = None, None, None

    # Activation functions
    discriminator_activation_function = torch.sigmoid
    generator_activation_function = torch.tanh

    d_sampler = get_distribution_sampler(data_mean, data_stddev)
    gi_sampler = get_generator_input_sampler()
    G = Generator(input_size=g_input_size,
                  hidden_size=g_hidden_size,
                  output_size=g_output_size,
                  f=generator_activation_function)
    D = Discriminator(input_size=d_input_func(d_input_size),
                      hidden_size=d_hidden_size,
                      output_size=d_output_size,
                      f=discriminator_activation_function)
    criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
    d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)
    g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)

    # Now, training loop alternates between Generator and Discriminator modes

    for epoch in range(num_epochs):
        for d_index in range(d_steps):
            # 1. Train D on real+fake
            D.zero_grad()

            #  1A: Train D on real data
            d_real_data = Variable(d_sampler(d_input_size))
            d_real_decision = D(preprocess(d_real_data))
            d_real_error = criterion(d_real_decision, Variable(torch.ones([1,1])))  # ones = true
            d_real_error.backward()  # compute/store gradients, but don't change params

            #  1B: Train D on fake data
            d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
            d_fake_decision = D(preprocess(d_fake_data.t()))
            d_fake_error = criterion(d_fake_decision, Variable(torch.zeros([1,1])))  # zeros = fake
            d_fake_error.backward()
            d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()

            dre, dfe = extract(d_real_error)[0], extract(d_fake_error)[0]

        for g_index in range(g_steps):
            # 2. Train G on D's response (but DO NOT train D on these labels)
            G.zero_grad()

            gen_input = Variable(gi_sampler(minibatch_size, g_input_size))
            g_fake_data = G(gen_input)
            dg_fake_decision = D(preprocess(g_fake_data.t()))
            g_error = criterion(dg_fake_decision, Variable(torch.ones([1,1])))  # Train G to pretend it's genuine

            g_error.backward()
            g_optimizer.step()  # Only optimizes G's parameters
            ge = extract(g_error)[0]

        if epoch % print_interval == 0:
            print("\t Epoch %s: D (%s real_err, %s fake_err) G (%s err); Real Dist (%s),  Fake Dist (%s) " %
                  (epoch, dre, dfe, ge, stats(extract(d_real_data)), stats(extract(d_fake_data))))
            sys.stdout.flush()

        values = extract(g_fake_data)        

    return values


Using data [Only 4 moments]


In [2]:
for i in range(10):
    print("Run: ", i)
    ret_values = train()
    print("Seed: %d" % seed)
    # print("Seed: %d; epochs: %d" % (seed, epochs))
    print(fivenum(ret_values))
    
elapsed = time.time() - t
print(elapsed)    

Run:  0
	 Epoch 0: D (0.6278240084648132 real_err, 0.7675003409385681 fake_err) G (0.6247562170028687 err); Real Dist ([4.06202992349863, 1.2533160032819028]),  Fake Dist ([-0.09698092967271804, 0.04072233797370387]) 


  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


	 Epoch 100: D (0.6799901127815247 real_err, 0.6880636811256409 fake_err) G (0.6978486776351929 err); Real Dist ([4.041919274926186, 1.2134489407354332]),  Fake Dist ([-0.07967908541113138, 0.04207925442362191]) 
	 Epoch 200: D (0.5843085646629333 real_err, 0.5923900008201599 fake_err) G (0.8001345992088318 err); Real Dist ([4.003920792222023, 1.2362263918415264]),  Fake Dist ([0.3841201916337013, 0.05538094748317778]) 
	 Epoch 300: D (0.5718815922737122 real_err, 0.49410480260849 fake_err) G (0.9614760875701904 err); Real Dist ([4.043231249570846, 1.1933573060573708]),  Fake Dist ([4.563140733718872, 0.01781378383368324]) 
	 Epoch 400: D (0.6067327260971069 real_err, 0.8268445134162903 fake_err) G (0.54496830701828 err); Real Dist ([3.909938768684864, 1.2393620604879318]),  Fake Dist ([2.995265905857086, 1.1687877519010184]) 
	 Epoch 500: D (0.6814371943473816 real_err, 0.6658629179000854 fake_err) G (0.736302375793457 err); Real Dist ([3.9193006608188155, 1.187892911602872]),  Fake D

	 Epoch 4100: D (0.6930255889892578 real_err, 0.6949262022972107 fake_err) G (0.6915705800056458 err); Real Dist ([4.023552442312241, 1.2845319029552291]),  Fake Dist ([3.9706951248645783, 1.2483697452819331]) 
	 Epoch 4200: D (0.6917162537574768 real_err, 0.6940514445304871 fake_err) G (0.6912511587142944 err); Real Dist ([4.020993260502816, 1.2478146340817033]),  Fake Dist ([3.8398239325284957, 1.2462603475651888]) 
	 Epoch 4300: D (0.6925825476646423 real_err, 0.6935195326805115 fake_err) G (0.6919646263122559 err); Real Dist ([4.0036144387722015, 1.3054182478785394]),  Fake Dist ([4.028837070584297, 1.2519464495257528]) 
	 Epoch 4400: D (0.6906870007514954 real_err, 0.6932815313339233 fake_err) G (0.691779375076294 err); Real Dist ([3.973287959434092, 1.2332569519790504]),  Fake Dist ([4.057715004444122, 1.1578513302867122]) 
	 Epoch 4500: D (0.6934181451797485 real_err, 0.6936647891998291 fake_err) G (0.6916343569755554 err); Real Dist ([4.164440550327301, 1.2685123823966626]),  F

	 Epoch 3000: D (0.6946260333061218 real_err, 0.694707989692688 fake_err) G (0.6915898323059082 err); Real Dist ([4.028840819597244, 1.1999598462011702]),  Fake Dist ([9.02365535736084, 0.7909852949218943]) 
	 Epoch 3100: D (0.694471001625061 real_err, 0.694623589515686 fake_err) G (0.6916694641113281 err); Real Dist ([3.965825160300359, 1.2209692478065863]),  Fake Dist ([9.09528422164917, 0.7377494467081701]) 
	 Epoch 3200: D (0.6944129467010498 real_err, 0.6945481300354004 fake_err) G (0.6917474269866943 err); Real Dist ([3.962600977540016, 1.2418446255345057]),  Fake Dist ([9.091712211608886, 0.7542622748339288]) 
	 Epoch 3300: D (0.6939448118209839 real_err, 0.6945003867149353 fake_err) G (0.6917994618415833 err); Real Dist ([4.035449884727597, 1.2756499996161013]),  Fake Dist ([9.032387175559997, 0.7836786022129196]) 
	 Epoch 3400: D (0.6944675445556641 real_err, 0.6943895220756531 fake_err) G (0.6918995976448059 err); Real Dist ([3.978204193174839, 1.2235240078017642]),  Fake Dis

	 Epoch 1900: D (0.7021476030349731 real_err, 0.7028763890266418 fake_err) G (0.6839435696601868 err); Real Dist ([3.923489537477493, 1.273962653384794]),  Fake Dist ([6.020181489467621, 1.8395003295747017]) 
	 Epoch 2000: D (0.6982994675636292 real_err, 0.6983721852302551 fake_err) G (0.6880348920822144 err); Real Dist ([3.96634181483835, 1.210174102394083]),  Fake Dist ([6.138939633846283, 1.919123376031786]) 
	 Epoch 2100: D (0.6957194805145264 real_err, 0.6965949535369873 fake_err) G (0.6897014379501343 err); Real Dist ([4.019937011986971, 1.3633513931078984]),  Fake Dist ([6.154907716751099, 1.8285977296966833]) 
	 Epoch 2200: D (0.6949143409729004 real_err, 0.6953930854797363 fake_err) G (0.6909393668174744 err); Real Dist ([4.112532938957214, 1.2036654517522605]),  Fake Dist ([6.172581834793091, 1.7334890925204736]) 
	 Epoch 2300: D (0.6938026547431946 real_err, 0.6945407390594482 fake_err) G (0.6917802095413208 err); Real Dist ([3.9849299409538506, 1.2720454820304543]),  Fake D

	 Epoch 800: D (0.6976097226142883 real_err, 0.5545181632041931 fake_err) G (0.7769040465354919 err); Real Dist ([3.933922771722078, 1.2473138039767286]),  Fake Dist ([4.558685360908508, 1.1329272976003417]) 
	 Epoch 900: D (0.679829478263855 real_err, 0.6097428798675537 fake_err) G (0.7770048975944519 err); Real Dist ([3.9464895789027215, 1.2341922663635396]),  Fake Dist ([5.475476087331772, 2.893727992733288]) 
	 Epoch 1000: D (0.28365638852119446 real_err, 0.3087335228919983 fake_err) G (1.261656403541565 err); Real Dist ([3.917099401049316, 1.2545220966362551]),  Fake Dist ([3.8143048396110535, 1.4373496015216798]) 
	 Epoch 1100: D (0.19899168610572815 real_err, 0.30393561720848083 fake_err) G (1.269384503364563 err); Real Dist ([4.0367746322155, 1.2397491389937538]),  Fake Dist ([1.9333967065811157, 2.193560655765356]) 
	 Epoch 1200: D (0.6500893831253052 real_err, 0.6691963076591492 fake_err) G (0.7444525361061096 err); Real Dist ([3.982315641820431, 1.3289965112370525]),  Fake D

	 Epoch 4800: D (0.6920680999755859 real_err, 0.6937121748924255 fake_err) G (0.6935848593711853 err); Real Dist ([4.029956232249737, 1.226184227471558]),  Fake Dist ([3.9879204605817793, 1.26044324446339]) 
	 Epoch 4900: D (0.6922128796577454 real_err, 0.6942154169082642 fake_err) G (0.6922561526298523 err); Real Dist ([4.031440401434899, 1.2286579805316646]),  Fake Dist ([3.9910743582248687, 1.2019292670811843]) 
Seed: 123
[0.90231991, 3.58445787, 3.68690157, 4.90297127, 6.57914162]
Run:  4
	 Epoch 0: D (1.016071081161499 real_err, 0.4583829939365387 fake_err) G (0.9961672425270081 err); Real Dist ([3.9587636728324, 1.2572639269708668]),  Fake Dist ([0.16958069199323655, 0.012855211445082391]) 
	 Epoch 100: D (0.6380869150161743 real_err, 0.6260032057762146 fake_err) G (0.7653270363807678 err); Real Dist ([4.057902472019196, 1.2097736724334973]),  Fake Dist ([0.42303221333026886, 0.0031274352618359375]) 
	 Epoch 200: D (0.6299542784690857 real_err, 0.5838536620140076 fake_err) G (0.8

	 Epoch 3700: D (0.6952980160713196 real_err, 0.6888859868049622 fake_err) G (0.6895148158073425 err); Real Dist ([4.029732334554195, 1.2645398952644764]),  Fake Dist ([4.127769768238068, 1.2216038600648234]) 
	 Epoch 3800: D (0.6954100728034973 real_err, 0.6918572187423706 fake_err) G (0.6949708461761475 err); Real Dist ([3.9876725640892983, 1.2234417054931985]),  Fake Dist ([3.918433317422867, 1.2959556883478212]) 
	 Epoch 3900: D (0.6885040402412415 real_err, 0.6973931193351746 fake_err) G (0.6913131475448608 err); Real Dist ([4.010326121807099, 1.2330916223688388]),  Fake Dist ([4.028446849822998, 1.3093224790752263]) 
	 Epoch 4000: D (0.6946171522140503 real_err, 0.6897502541542053 fake_err) G (0.6939640641212463 err); Real Dist ([3.8888979904353618, 1.244343609091028]),  Fake Dist ([3.931962909221649, 1.1487342682362724]) 
	 Epoch 4100: D (0.6915404796600342 real_err, 0.6921494007110596 fake_err) G (0.6916167736053467 err); Real Dist ([4.003088918436319, 1.2426896956324232]),  Fa

	 Epoch 2600: D (0.6845526099205017 real_err, 0.7164925336837769 fake_err) G (0.672747790813446 err); Real Dist ([4.0382154150009155, 1.240430511867864]),  Fake Dist ([4.123039381027222, 1.2737095871137927]) 
	 Epoch 2700: D (0.6758033037185669 real_err, 0.6841816902160645 fake_err) G (0.6754249334335327 err); Real Dist ([4.019112104490399, 1.2578236650956405]),  Fake Dist ([4.081520566940307, 1.2131479761498436]) 
	 Epoch 2800: D (0.6812053322792053 real_err, 0.6682953834533691 fake_err) G (0.6996521949768066 err); Real Dist ([4.083778915762902, 1.2400181231934642]),  Fake Dist ([4.089929563522339, 1.25284538419533]) 
	 Epoch 2900: D (0.6536538600921631 real_err, 0.6446532011032104 fake_err) G (0.6738632917404175 err); Real Dist ([4.000757842123509, 1.2679723476436533]),  Fake Dist ([4.1396783576011655, 1.2911815159239521]) 
	 Epoch 3000: D (0.7066645622253418 real_err, 0.6953697800636292 fake_err) G (0.7565325498580933 err); Real Dist ([4.024324470628053, 1.291274450455293]),  Fake D

	 Epoch 1500: D (0.6950238943099976 real_err, 0.6927435994148254 fake_err) G (0.6938021779060364 err); Real Dist ([4.010182991981506, 1.1966373742464336]),  Fake Dist ([4.9288956592082975, 2.203411741600803]) 
	 Epoch 1600: D (0.6936472654342651 real_err, 0.6913600564002991 fake_err) G (0.6945537328720093 err); Real Dist ([4.007645472660661, 1.2416146220270823]),  Fake Dist ([4.838798904895783, 2.295597427416836]) 
	 Epoch 1700: D (0.6897515654563904 real_err, 0.6876867413520813 fake_err) G (0.6980944871902466 err); Real Dist ([3.938443461060524, 1.270804560326215]),  Fake Dist ([4.680428267002106, 2.5346716634744832]) 
	 Epoch 1800: D (0.6760773658752441 real_err, 0.6744865775108337 fake_err) G (0.7116739153862 err); Real Dist ([3.9680220012366774, 1.3182491093815463]),  Fake Dist ([4.90774642086029, 1.868289155752233]) 
	 Epoch 1900: D (0.6880233883857727 real_err, 0.6963095664978027 fake_err) G (0.6874836683273315 err); Real Dist ([4.057291237711906, 1.242691866783385]),  Fake Dist 

	 Epoch 400: D (0.7698032259941101 real_err, 0.6841868758201599 fake_err) G (0.7243028283119202 err); Real Dist ([3.9287285058498385, 1.2948941742055473]),  Fake Dist ([6.446651108264923, 2.1244453875854554]) 
	 Epoch 500: D (0.691801130771637 real_err, 0.5531315207481384 fake_err) G (0.8298667669296265 err); Real Dist ([3.93959041929245, 1.1996342178745114]),  Fake Dist ([5.71083510017395, 1.8686797691327266]) 
	 Epoch 600: D (0.4900931417942047 real_err, 0.49902206659317017 fake_err) G (0.8801273107528687 err); Real Dist ([4.0463398540057245, 1.2254614831959922]),  Fake Dist ([6.964518931388855, 1.5943835238824282]) 
	 Epoch 700: D (0.1854420304298401 real_err, 0.19532136619091034 fake_err) G (1.6721917390823364 err); Real Dist ([4.0167645382881165, 1.2299797832444426]),  Fake Dist ([5.9860596817731855, 2.927929824318766]) 
	 Epoch 800: D (0.998485803604126 real_err, 0.31597939133644104 fake_err) G (0.04849136248230934 err); Real Dist ([3.9739978635311126, 1.2665023999940321]),  Fake

	 Epoch 4400: D (0.6922075152397156 real_err, 0.6949298977851868 fake_err) G (0.6927960515022278 err); Real Dist ([3.959922316133976, 1.2237003462167448]),  Fake Dist ([3.9421669483184814, 1.1790253152629209]) 
	 Epoch 4500: D (0.6892807483673096 real_err, 0.6949189901351929 fake_err) G (0.6939820051193237 err); Real Dist ([4.006684739887715, 1.2659907343824122]),  Fake Dist ([4.095604979753494, 1.1943023690892605]) 
	 Epoch 4600: D (0.6934569478034973 real_err, 0.6927343010902405 fake_err) G (0.6955490708351135 err); Real Dist ([3.981746079444885, 1.2579805746469177]),  Fake Dist ([3.960471501350403, 1.2030680564915497]) 
	 Epoch 4700: D (0.6917465925216675 real_err, 0.6890989542007446 fake_err) G (0.6926774978637695 err); Real Dist ([3.9828487466014924, 1.2682904452048942]),  Fake Dist ([3.9418619830608366, 1.3129163866912625]) 
	 Epoch 4800: D (0.6920742988586426 real_err, 0.6944957375526428 fake_err) G (0.69220370054245 err); Real Dist ([4.0332703897915785, 1.2732048554375488]),  F

	 Epoch 3300: D (0.6915822625160217 real_err, 0.6894323825836182 fake_err) G (0.6992677450180054 err); Real Dist ([4.080684035539627, 1.1955428084508481]),  Fake Dist ([4.263055449008942, 0.9537457857346681]) 
	 Epoch 3400: D (0.6987778544425964 real_err, 0.697763204574585 fake_err) G (0.6883143186569214 err); Real Dist ([4.037419785380363, 1.2442353348118464]),  Fake Dist ([3.6350712950229647, 0.931485385390099]) 
	 Epoch 3500: D (0.6863767504692078 real_err, 0.6917669773101807 fake_err) G (0.6911109089851379 err); Real Dist ([4.005043379247189, 1.2787503023782636]),  Fake Dist ([3.3822673666477203, 1.0168474497016906]) 
	 Epoch 3600: D (0.6917065978050232 real_err, 0.6954320669174194 fake_err) G (0.690969705581665 err); Real Dist ([4.088095718264579, 1.2912463069279314]),  Fake Dist ([3.7860892572402953, 1.2918027057227806]) 
	 Epoch 3700: D (0.6946706771850586 real_err, 0.6948555707931519 fake_err) G (0.6915220022201538 err); Real Dist ([3.9341873797774314, 1.2354558314216717]),  Fa

	 Epoch 2200: D (0.7087257504463196 real_err, 0.6875905394554138 fake_err) G (0.7030124068260193 err); Real Dist ([4.049053099393845, 1.2113979196210334]),  Fake Dist ([4.357091289997101, 1.2919922909946708]) 
	 Epoch 2300: D (0.7060800194740295 real_err, 0.6850689053535461 fake_err) G (0.7023137211799622 err); Real Dist ([3.9929314313083886, 1.2044248427612814]),  Fake Dist ([4.176114661693573, 1.1733370735584407]) 
	 Epoch 2400: D (0.6870611310005188 real_err, 0.6922172904014587 fake_err) G (0.6966923475265503 err); Real Dist ([3.9157642422914507, 1.220494436106193]),  Fake Dist ([4.141368027210236, 1.3032221746046029]) 
	 Epoch 2500: D (0.6883088946342468 real_err, 0.7049814462661743 fake_err) G (0.6873268485069275 err); Real Dist ([4.064892911419272, 1.2885313105896694]),  Fake Dist ([3.838355712413788, 1.3190613121558814]) 
	 Epoch 2600: D (0.6855680346488953 real_err, 0.702788770198822 fake_err) G (0.6927947402000427 err); Real Dist ([3.9123600185364484, 1.2454034246576027]),  Fa