In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets 
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable

In [2]:
def to_var(x):
    return Variable(x)

def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

In [3]:
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), 
                                     std=(0.5, 0.5, 0.5))])
# MNIST dataset
mnist = datasets.MNIST(root='./data/',
                       train=True,
                       transform=transform,
                       download=True)
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=100, 
                                          shuffle=True)

In [4]:
D = nn.Sequential(
    nn.Linear(784, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 1),
    nn.Sigmoid())

# Generator 
G = nn.Sequential(
    nn.Linear(64, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 784),
    nn.Tanh())

In [8]:
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)

In [9]:
for epoch in range(200):
    for i, (images, _) in enumerate(data_loader):
        # Build mini-batch dataset
        batch_size = images.size(0)
        images = to_var(images.view(batch_size, -1))
       
        
        # Create the labels which are later used as input for the BCE loss
        real_labels = to_var(torch.ones(batch_size))
        fake_labels = to_var(torch.zeros(batch_size))

        #============= Train the discriminator =============#
        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels == 1
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = to_var(torch.randn(batch_size, 64))
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        # Backprop + Optimize
        d_loss = d_loss_real + d_loss_fake
        D.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        #=============== Train the generator ===============#
        # Compute loss with fake images
        z = to_var(torch.randn(batch_size, 64))
        fake_images = G(z)
        outputs = D(fake_images)
        
        # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
        # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
        g_loss = criterion(outputs, real_labels)
        
        # Backprop + Optimize
        D.zero_grad()
        G.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 300 == 0:
            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, '
                  'g_loss: %.4f, D(x): %.2f, D(G(z)): %.2f' 
                  %(epoch, 200, i+1, 600, d_loss.data[0], g_loss.data[0],
                    real_score.data.mean(), fake_score.data.mean()))
    
    # Save real images
    if (epoch+1) == 1:
        images = images.view(images.size(0), 1, 28, 28)
        save_image(denorm(images.data), './data/real_images.png')
    
    # Save sampled images
    fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images.data), './data/fake_images-%d.png' %(epoch+1))

# Save the trained parameters 
torch.save(G.state_dict(), './generator.pkl')
torch.save(D.state_dict(), './discriminator.pkl')

  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [0/200], Step[300/600], d_loss: 0.2789, g_loss: 5.0400, D(x): 0.95, D(G(z)): 0.17
Epoch [0/200], Step[600/600], d_loss: 0.4416, g_loss: 2.9956, D(x): 0.87, D(G(z)): 0.16
Epoch [1/200], Step[300/600], d_loss: 0.2265, g_loss: 3.8500, D(x): 0.96, D(G(z)): 0.15
Epoch [1/200], Step[600/600], d_loss: 0.6512, g_loss: 2.7178, D(x): 0.83, D(G(z)): 0.28
Epoch [2/200], Step[300/600], d_loss: 0.1247, g_loss: 3.5515, D(x): 0.96, D(G(z)): 0.07
Epoch [2/200], Step[600/600], d_loss: 0.5509, g_loss: 2.7527, D(x): 0.82, D(G(z)): 0.21
Epoch [3/200], Step[300/600], d_loss: 2.2102, g_loss: 0.5968, D(x): 0.49, D(G(z)): 0.68
Epoch [3/200], Step[600/600], d_loss: 0.7991, g_loss: 3.6314, D(x): 0.68, D(G(z)): 0.17
Epoch [4/200], Step[300/600], d_loss: 3.6753, g_loss: 0.4076, D(x): 0.21, D(G(z)): 0.70
Epoch [4/200], Step[600/600], d_loss: 1.2039, g_loss: 2.1944, D(x): 0.66, D(G(z)): 0.40
Epoch [5/200], Step[300/600], d_loss: 1.0808, g_loss: 1.5716, D(x): 0.70, D(G(z)): 0.37
Epoch [5/200], Step[600/600], d_

Epoch [46/200], Step[600/600], d_loss: 0.7099, g_loss: 2.3821, D(x): 0.80, D(G(z)): 0.25
Epoch [47/200], Step[300/600], d_loss: 0.8218, g_loss: 1.4865, D(x): 0.71, D(G(z)): 0.26
Epoch [47/200], Step[600/600], d_loss: 0.9060, g_loss: 1.8449, D(x): 0.68, D(G(z)): 0.24
Epoch [48/200], Step[300/600], d_loss: 0.8488, g_loss: 1.9408, D(x): 0.70, D(G(z)): 0.26
Epoch [48/200], Step[600/600], d_loss: 1.1781, g_loss: 1.7131, D(x): 0.57, D(G(z)): 0.28
Epoch [49/200], Step[300/600], d_loss: 0.7507, g_loss: 1.7769, D(x): 0.79, D(G(z)): 0.31
Epoch [49/200], Step[600/600], d_loss: 0.7293, g_loss: 2.2002, D(x): 0.82, D(G(z)): 0.28
Epoch [50/200], Step[300/600], d_loss: 0.8951, g_loss: 1.5988, D(x): 0.67, D(G(z)): 0.25
Epoch [50/200], Step[600/600], d_loss: 0.7707, g_loss: 1.5193, D(x): 0.78, D(G(z)): 0.30
Epoch [51/200], Step[300/600], d_loss: 1.0064, g_loss: 1.6417, D(x): 0.67, D(G(z)): 0.30
Epoch [51/200], Step[600/600], d_loss: 0.7817, g_loss: 2.0206, D(x): 0.74, D(G(z)): 0.27
Epoch [52/200], Step[

Epoch [93/200], Step[300/600], d_loss: 0.7732, g_loss: 1.7514, D(x): 0.75, D(G(z)): 0.28
Epoch [93/200], Step[600/600], d_loss: 0.8493, g_loss: 1.4584, D(x): 0.73, D(G(z)): 0.29
Epoch [94/200], Step[300/600], d_loss: 0.8068, g_loss: 1.4537, D(x): 0.76, D(G(z)): 0.32
Epoch [94/200], Step[600/600], d_loss: 0.7325, g_loss: 2.1571, D(x): 0.72, D(G(z)): 0.23
Epoch [95/200], Step[300/600], d_loss: 0.7916, g_loss: 1.7870, D(x): 0.68, D(G(z)): 0.20
Epoch [95/200], Step[600/600], d_loss: 0.8044, g_loss: 1.6687, D(x): 0.73, D(G(z)): 0.25
Epoch [96/200], Step[300/600], d_loss: 0.9523, g_loss: 1.7640, D(x): 0.69, D(G(z)): 0.32
Epoch [96/200], Step[600/600], d_loss: 1.0614, g_loss: 1.5829, D(x): 0.74, D(G(z)): 0.40
Epoch [97/200], Step[300/600], d_loss: 0.9284, g_loss: 1.5215, D(x): 0.69, D(G(z)): 0.28
Epoch [97/200], Step[600/600], d_loss: 0.7755, g_loss: 1.4854, D(x): 0.78, D(G(z)): 0.32
Epoch [98/200], Step[300/600], d_loss: 0.9370, g_loss: 1.8482, D(x): 0.62, D(G(z)): 0.21
Epoch [98/200], Step[

Epoch [139/200], Step[300/600], d_loss: 0.7342, g_loss: 1.6314, D(x): 0.75, D(G(z)): 0.25
Epoch [139/200], Step[600/600], d_loss: 0.7890, g_loss: 1.7200, D(x): 0.68, D(G(z)): 0.23
Epoch [140/200], Step[300/600], d_loss: 0.8105, g_loss: 1.9028, D(x): 0.70, D(G(z)): 0.24
Epoch [140/200], Step[600/600], d_loss: 0.9061, g_loss: 1.6050, D(x): 0.68, D(G(z)): 0.26
Epoch [141/200], Step[300/600], d_loss: 0.9346, g_loss: 1.9920, D(x): 0.72, D(G(z)): 0.30
Epoch [141/200], Step[600/600], d_loss: 0.8147, g_loss: 1.4573, D(x): 0.80, D(G(z)): 0.33
Epoch [142/200], Step[300/600], d_loss: 0.8426, g_loss: 2.1531, D(x): 0.75, D(G(z)): 0.28
Epoch [142/200], Step[600/600], d_loss: 0.8465, g_loss: 1.4514, D(x): 0.75, D(G(z)): 0.30
Epoch [143/200], Step[300/600], d_loss: 0.7279, g_loss: 1.8481, D(x): 0.76, D(G(z)): 0.27
Epoch [143/200], Step[600/600], d_loss: 0.7361, g_loss: 2.0338, D(x): 0.72, D(G(z)): 0.23
Epoch [144/200], Step[300/600], d_loss: 0.7057, g_loss: 1.6406, D(x): 0.77, D(G(z)): 0.25
Epoch [144

Epoch [185/200], Step[300/600], d_loss: 0.7400, g_loss: 1.5224, D(x): 0.77, D(G(z)): 0.28
Epoch [185/200], Step[600/600], d_loss: 0.8155, g_loss: 1.9356, D(x): 0.72, D(G(z)): 0.26
Epoch [186/200], Step[300/600], d_loss: 1.0106, g_loss: 1.8558, D(x): 0.82, D(G(z)): 0.42
Epoch [186/200], Step[600/600], d_loss: 0.6964, g_loss: 2.0779, D(x): 0.74, D(G(z)): 0.23
Epoch [187/200], Step[300/600], d_loss: 0.8673, g_loss: 1.9550, D(x): 0.67, D(G(z)): 0.24
Epoch [187/200], Step[600/600], d_loss: 0.9257, g_loss: 1.7528, D(x): 0.63, D(G(z)): 0.22
Epoch [188/200], Step[300/600], d_loss: 0.7990, g_loss: 1.3966, D(x): 0.80, D(G(z)): 0.31
Epoch [188/200], Step[600/600], d_loss: 0.7851, g_loss: 2.2225, D(x): 0.71, D(G(z)): 0.20
Epoch [189/200], Step[300/600], d_loss: 0.8461, g_loss: 2.1860, D(x): 0.66, D(G(z)): 0.20
Epoch [189/200], Step[600/600], d_loss: 0.7641, g_loss: 1.7276, D(x): 0.76, D(G(z)): 0.27
Epoch [190/200], Step[300/600], d_loss: 0.7211, g_loss: 1.5631, D(x): 0.73, D(G(z)): 0.22
Epoch [190