In [1]:
import torch
import numpy as np
from torch import nn
from torch import optim
from torch.autograd import Variable
import torch.nn.functional as F

In [2]:
# [gan_pytorch.py](https://github.com/devnag/pytorch-generative-adversarial-networks/blob/master/gan_pytorch.py)

In [3]:
class Generator(nn.Module):
    """Three-layer MLP that maps noise inputs to generated samples.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear. The output layer
    has no activation, so generated values are unbounded reals.

    Args:
        input_size: size of the last dimension of the noise input.
        hidden_size: width of both hidden layers.
        output_size: size of the last dimension of each generated sample.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        # Creation order (fc1, fc2, fc3) fixes how parameter init consumes the
        # RNG and keeps the state_dict keys stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        # Linear head: no squashing, so any real value can be produced.
        return self.fc3(hidden)

In [4]:
class Discriminator(nn.Module):
    """Three-layer MLP that scores inputs with values in (0, 1).

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear -> sigmoid. The
    sigmoid output is meant to be read as the probability that the input is
    real (as opposed to generated/fake) data.

    Args:
        input_size: size of the last dimension of each input sample.
        hidden_size: width of both hidden layers.
        output_size: size of the last dimension of the score output.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(Discriminator, self).__init__()
        # Creation order (fc1, fc2, fc3) fixes how parameter init consumes the
        # RNG and keeps the state_dict keys stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        logits = self.fc3(hidden)
        # Sigmoid turns the score into a real-vs-fake probability.
        return torch.sigmoid(logits)

In [5]:
# Generator dimensions: 1 input feature, two 50-unit hidden layers, 1 output feature.
g_input_size, g_hidden_size, g_output_size = 1, 50, 1

In [6]:
# Discriminator dimensions: 1 input feature, two 50-unit hidden layers, 1 output score.
d_input_size, d_hidden_size, d_output_size = 1, 50, 1

In [7]:
G = Generator(g_input_size, g_hidden_size, g_output_size)  # positional order: input, hidden, output

In [8]:
D = Discriminator(d_input_size, d_hidden_size, d_output_size)  # positional order: input, hidden, output

In [9]:
criterion = torch.nn.BCELoss()  # binary cross-entropy; inputs must already be probabilities in [0, 1]

In [10]:
g_optimizer = optim.Adam(params=G.parameters(), lr=1e-4)  # updates the generator's weights only

In [11]:
d_optimizer = optim.Adam(params=D.parameters(), lr=1e-4)  # updates the discriminator's weights only

In [12]:
def d_sampler_f(n, mu=4, sigma=1.25):
    """Draw "real" data from a normal distribution N(mu, sigma**2).

    Args:
        n: number of values to draw (the width of the returned row).
        mu: mean of the distribution (defaults to the original 4).
        sigma: standard deviation of the distribution (defaults to the
            original 1.25).

    Returns:
        A ``torch.Tensor`` of shape ``(1, n)`` in the default floating dtype.
        Values come from NumPy's global RNG, so ``np.random.seed`` controls
        reproducibility.
    """
    samples = np.random.normal(mu, sigma, (1, n))
    # torch.Tensor(...) copies and casts NumPy's float64 draws to the default dtype.
    return torch.Tensor(samples)

In [13]:
d_sampler_f(n=5)  # preview five draws from the "real" data distribution


 3.9810  2.9058  3.1202  4.0703  4.2688
[torch.FloatTensor of size 1x5]

In [14]:
def g_sampler_f(n):
    """Draw generator input noise from the standard normal N(0, 1).

    Args:
        n: number of values to draw (the width of the returned row).

    Returns:
        A tensor of shape ``(1, n)`` in the default floating dtype, drawn from
        torch's global RNG (controlled by ``torch.manual_seed``).
    """
    shape = (1, n)
    return torch.randn(*shape)

In [15]:
g_sampler_f(n=5)  # preview five noise values of the kind fed to the generator


-0.3719  0.1446  0.8508 -0.0931 -0.3209
[torch.FloatTensor of size 1x5]

In [16]:
# The original cell ran one iteration; the %500 logging suggests longer runs
# were intended — raise num_epochs to actually train.
num_epochs = 1
print_interval = 500  # log losses every `print_interval` epochs

for epoch in range(num_epochs):
    # ---- Train D: label real samples 1 and generated samples 0 ----
    d_optimizer.zero_grad()

    # Real data
    d_real_sampler = d_sampler_f(d_input_size)
    d_real_pred = D(d_real_sampler)
    d_real_error = criterion(d_real_pred, torch.ones_like(d_real_pred))
    d_real_error.backward()

    # Fake data. detach() keeps D's loss from back-propagating into G's
    # parameters; otherwise those gradients would pile up in G's .grad and be
    # applied (for the wrong objective) by g_optimizer.step() below.
    g_input = g_sampler_f(g_input_size)
    d_fake_sampler = G(g_input).detach()
    d_fake_pred = D(d_fake_sampler)
    d_fake_error = criterion(d_fake_pred, torch.zeros_like(d_fake_pred))
    d_fake_error.backward()

    d_optimizer.step()

    # ---- Train G: try to make D label generated samples as real (1) ----
    # Clear G's gradients first; without this they accumulate across epochs.
    g_optimizer.zero_grad()
    g_input = g_sampler_f(g_input_size)
    d_sampler = G(g_input)
    d_pred = D(d_sampler)
    g_error = criterion(d_pred, torch.ones_like(d_pred))
    # This also writes gradients into D's parameters; d_optimizer.zero_grad()
    # clears them before D's next update, so they are never applied.
    g_error.backward()
    g_optimizer.step()

    if epoch % print_interval == 0:
        # .item() handles the 0-dim loss tensors of modern PyTorch, where
        # `.data.numpy()[0]` raises IndexError. A single formatted string
        # prints the same under Python 2 and 3.
        print('epoch: {}, d_fake_error: {:.8f}, d_real_error: {:.8f}, g_error: {:.8f}'.format(
            epoch, d_fake_error.item(), d_real_error.item(), g_error.item()))
    

('epoch:', 0, ',d_fake_error:', 0.69772047, ',d_real_error:', 0.68049753, ',g_error:', 0.6862129)
