In [12]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch import Tensor

from tqdm import tqdm_notebook

import matplotlib.pyplot as plt
# Figures open in a separate Tk GUI window (interactive) rather than inline;
# use `%matplotlib inline` instead if the plots should be embedded in the notebook.
%matplotlib tk
# Apply the 'ggplot' matplotlib style to every figure drawn below.
plt.style.use('ggplot')

In [190]:
class DataDistribution(object):
    """Target ("real") distribution: a 1-D Gaussian N(mu, sigma**2).

    Args:
        mu: Mean of the Gaussian. Defaults to 4.
        sigma: Standard deviation of the Gaussian (must be >= 0). Defaults to 0.5.
    """

    def __init__(self, mu=4, sigma=0.5):
        self.mu = mu
        self.sigma = sigma

    def sample(self, N):
        """Draw ``N`` samples and return them as a numpy array sorted ascending.

        The sort is done in place on the freshly drawn array. The critic in
        this notebook consumes a whole batch as one input vector, so the order
        of elements within a batch is part of what it sees.
        """
        samples = np.random.normal(self.mu, self.sigma, N)
        samples.sort()
        return samples

class GeneratorDistribution(object):
    """Noise source fed to the generator: a lightly jittered uniform grid.

    ``sample(N)`` returns ``N`` evenly spaced points on ``[-range, range]``,
    each shifted by independent uniform noise drawn from ``[0, 0.01)``.
    """

    def __init__(self, range):
        # The parameter name shadows the builtin ``range`` inside this method
        # only; it is kept so keyword callers (``range=...``) keep working.
        self.range = range

    def sample(self, N):
        """Return ``N`` jittered grid points spanning ``[-self.range, self.range]``."""
        lo, hi = -self.range, self.range
        grid = np.linspace(lo, hi, N)
        jitter = 0.01 * np.random.random(N)
        return grid + jitter
    
# Shared distributions used by the training and plotting cells below:
# real data from DataDistribution's defaults, generator noise over [-8, 8].
gen = GeneratorDistribution(range=8)
data = DataDistribution()

In [207]:
# Hyperparameters and layer sizes.
batch_size = 100   # samples per real / generated batch

g_in = 1           # generator input: one noise scalar per sample
g_hid = 4          # generator hidden width
d_in = batch_size  # the critic scores an entire batch as a single input vector
d_hid = 4          # critic hidden layers are 2 * d_hid wide

# Fix RNG seeds so weight initialisation and sampling are reproducible
# on Restart & Run All (must run before G and D are constructed).
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# Generator: maps each noise scalar (input shape (batch, g_in)) to one output
# sample through a single Softplus hidden layer of width g_hid.
G = nn.Sequential(*[
    nn.Linear(in_features=g_in, out_features=g_hid),
    nn.Softplus(),
    nn.Linear(in_features=g_hid, out_features=1),
])

# Critic (WGAN "discriminator"): takes a whole batch of d_in samples as one row
# vector and returns a single score, via three Tanh hidden layers of width
# 2 * d_hid. There is deliberately no final Sigmoid: the WGAN losses below
# average the raw scores instead of treating them as probabilities.
D = nn.Sequential(
    nn.Linear(in_features=d_in, out_features=2 * d_hid),
    nn.Tanh(),
    nn.Linear(in_features=2 * d_hid, out_features=2 * d_hid),
    nn.Tanh(),
    nn.Linear(in_features=2 * d_hid, out_features=2 * d_hid),
    nn.Tanh(),
    nn.Linear(in_features=2 * d_hid, out_features=1),
)

# One RMSprop optimizer per network, both at learning rate 5e-5.
G_opt = optim.RMSprop(params=G.parameters(), lr=5e-5)
D_opt = optim.RMSprop(params=D.parameters(), lr=5e-5)

# WGAN objectives on raw critic scores. The non-saturating GAN alternative,
# which would require a Sigmoid on D's output, was:
#   loss_d = mean(-log(r) - log(1 - f)),   loss_g = mean(-log(y))
def loss_d(r, f):
    """Critic loss: -(mean(r) - mean(f)) for real scores ``r`` and fake scores ``f``.

    Minimising this maximises the gap between average real and fake scores.
    """
    return -(torch.mean(r) - torch.mean(f))


def loss_g(y):
    """Generator loss: -mean(y), where ``y`` are the critic's scores on fakes."""
    return -torch.mean(y)

In [208]:
def train(n_epochs=5000, n_critic=5, clip=0.01):
    """Train the generator ``G`` against the WGAN critic ``D`` with weight clipping.

    Each epoch performs ``n_critic`` critic updates followed by one generator
    update. Relies on module-level names defined in earlier cells: ``data``,
    ``gen``, ``G``, ``D``, ``G_opt``, ``D_opt``, ``loss_d``, ``loss_g`` and
    ``batch_size``.

    Args:
        n_epochs: Number of generator updates (default 5000).
        n_critic: Critic updates per generator update (default 5); must be >= 1.
        clip: After every critic step, each critic parameter is clamped to
            ``[-clip, clip]`` (default 0.01).

    Returns:
        ``(D_losses, G_losses)`` as lists of Python floats: one critic loss per
        critic step and one generator loss per generator step.

    Raises:
        ValueError: If ``n_critic < 1`` (there would be no critic loss to report).
    """
    if n_critic < 1:
        raise ValueError('n_critic must be >= 1, got {}'.format(n_critic))

    def as_float_tensor(samples):
        # numpy float64 -> float32 tensor, matching the default nn.Linear dtype.
        return torch.as_tensor(samples, dtype=torch.float32)

    def train_D():
        x = as_float_tensor(data.sample(batch_size))
        z = as_float_tensor(gen.sample(batch_size))
        # detach() returns a NEW tensor, so its result must be kept; otherwise
        # the critic loss also backpropagates into G's parameters.
        fake = G(z.unsqueeze(1)).detach()   # (batch_size, 1)
        D1 = D(x.unsqueeze(0))              # real batch as a (1, batch_size) row
        D2 = D(fake.t())                    # fake batch as a (1, batch_size) row
        D_err = loss_d(D1, D2)
        # backward() accumulates into .grad, so clear it first. This also drops
        # the critic gradients left over from the previous generator step.
        D_opt.zero_grad()
        D_err.backward()
        D_opt.step()
        # WGAN weight clipping: bound every critic weight to [-clip, clip].
        with torch.no_grad():
            for p in D.parameters():
                p.clamp_(-clip, clip)
        return D_err.item()

    def train_G():
        z = as_float_tensor(gen.sample(batch_size))
        D3 = D(G(z.unsqueeze(1)).t())
        G_err = loss_g(D3)
        G_opt.zero_grad()
        # This backward also writes into D's .grad. Only G_opt steps here, and
        # train_D zeroes D's gradients before its next update.
        G_err.backward()
        G_opt.step()
        return G_err.item()

    D_losses = []
    G_losses = []
    e_bar = tqdm_notebook(range(n_epochs))
    for _ in e_bar:
        for _ in range(n_critic):
            D_loss = train_D()
            D_losses.append(D_loss)
        G_loss = train_G()
        G_losses.append(G_loss)
        e_bar.set_postfix(
            D_loss=D_loss,
            G_loss=G_loss
        )
    return D_losses, G_losses

In [209]:
# Run the full training loop (5000 generator steps, 5 critic steps each; slow).
[D_losses, G_losses] = train()

In [212]:
# Negated critic loss per critic step, i.e. mean D(real) - mean D(fake): the
# critic's (constant-scaled) estimate of the Wasserstein distance between real
# and generated batches. The first 10 critic steps are left out of the plot.
fig, ax = plt.subplots()
ax.plot(np.arange(10, len(D_losses)), -np.asarray(D_losses[10:]),
        label='Discriminator')
ax.set_title('Critic estimate of the Wasserstein distance')
ax.set_xlabel('Critic step')
ax.set_ylabel('mean D(real) - mean D(fake)')
ax.legend()  # display the series label
plt.show()

In [217]:
# Push one noise batch through the trained generator and compare the histogram
# of its outputs with a batch drawn from the real distribution.
with torch.no_grad():  # inference only: no autograd graph needed
    z = torch.as_tensor(gen.sample(batch_size), dtype=torch.float32)
    g = G(z.unsqueeze(1)).numpy().flatten()

fig, ax = plt.subplots()
ax.hist(g, bins=20, alpha=0.6, density=True, label='Generated')
ax.hist(data.sample(batch_size), bins=20, alpha=0.6, density=True, label='Real')
ax.set_title('Generated vs. real samples')
ax.set_xlabel('x')
ax.set_ylabel('Density')
ax.legend()
plt.show()
