In [13]:
import os
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.utils import save_image

In [14]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters.
num_epochs = 200
batch_size = 100
latent_size = 64    # dimensionality of the generator's noise input
hidden_size = 256   # width of the hidden layers in both networks
image_size = 784    # flattened image length (the training loop reshapes samples to 1x28x28)

# Directory that receives the sample image grids written during training.
# exist_ok=True makes this safe to re-run and avoids the check-then-create race.
sample_dir = './samples'
os.makedirs(sample_dir, exist_ok=True)

# MNIST images are single-channel, so Normalize takes one mean/std pair.
# (x - 0.5) / 0.5 maps ToTensor's output range onto [-1, 1], which matches the
# Tanh output of the generator.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])])

train_dataset = torchvision.datasets.MNIST(root='./data',
                                           train=True,
                                           transform=transform,
                                           download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

# Number of mini-batches per epoch.
print(len(train_loader))

600


In [15]:
# Discriminator: flattened image -> probability that the image is real.
# Built before the generator so parameter initialisation happens in the same order.
discriminator_layers = [
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(negative_slope=0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid(),
]
D = nn.Sequential(*discriminator_layers)

# Generator: latent noise vector -> flattened image with values in [-1, 1] (Tanh).
generator_layers = [
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh(),
]
G = nn.Sequential(*generator_layers)

# Move both networks to the selected device before their parameters are handed
# to the optimizers.
D, G = D.to(device), G.to(device)

# Binary cross-entropy on the discriminator's sigmoid output; one Adam optimizer
# per network, both with the same learning rate.
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(params=D.parameters(), lr=2e-4)
g_optimizer = torch.optim.Adam(params=G.parameters(), lr=2e-4)

def denorm(x):
    """Map values from the [-1, 1] range back to [0, 1].

    Computes ``(x + 1) / 2`` and clamps the result to ``[0, 1]``, so inputs
    outside ``[-1, 1]`` saturate at the nearest bound.

    Args:
        x: tensor to rescale.

    Returns:
        A new tensor of the same shape with every element in ``[0, 1]``.
    """
    return ((x + 1) / 2).clamp(min=0, max=1)

# Adversarial training loop.
# Each iteration performs one discriminator update followed by one generator
# update. A grid of generated samples is written to `sample_dir` after every
# epoch, and a grid of real samples once after the first epoch.
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, _labels) in enumerate(train_loader):
        # Size everything from the batch actually received: the last batch is
        # smaller than `batch_size` whenever the dataset size is not a multiple
        # of it, and reshaping with the nominal size would mis-shape the data.
        current_batch = images.size(0)
        images = images.reshape(current_batch, -1).to(device)

        # Targets for the BCE loss, created directly on the compute device.
        real_labels = torch.ones(current_batch, 1, device=device)
        fake_labels = torch.zeros(current_batch, 1, device=device)

        ##
        # train the discriminator
        ##
        # BCELoss(x, y): -y * log(D(x)) - (1 - y) * log(1 - D(x))
        # With y = 1 only the first term remains: push D(real) towards 1.
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # With y = 0 only the second term remains: push D(fake) towards 0.
        # The fake batch is detached because this step only updates D, so
        # back-propagating into G would be wasted work.
        z = torch.randn(current_batch, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake

        # zero_grad here also clears the gradients that the previous generator
        # step accumulated in D's parameters.
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        ##
        # train the generator
        ##
        # The generator is scored against the *real* labels: it minimises
        # -log(D(G(z))), i.e. it is rewarded when D classifies its output as real.
        z = torch.randn(current_batch, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))

    # Save real images (once, from the last batch of the first epoch)
    if epoch == 0:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    # Save sampled images (the batch produced by the epoch's last generator step)
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

# Save the model checkpoints
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')

Epoch [0/200], Step [200/600], d_loss: 0.0456, g_loss: 4.1334, D(x): 0.99, D(G(z)): 0.04
Epoch [0/200], Step [400/600], d_loss: 0.0341, g_loss: 5.4638, D(x): 1.00, D(G(z)): 0.03
Epoch [0/200], Step [600/600], d_loss: 0.0507, g_loss: 5.2047, D(x): 0.98, D(G(z)): 0.02
Epoch [1/200], Step [200/600], d_loss: 0.1295, g_loss: 5.6997, D(x): 0.96, D(G(z)): 0.08
Epoch [1/200], Step [400/600], d_loss: 0.0866, g_loss: 5.5478, D(x): 0.97, D(G(z)): 0.05
Epoch [1/200], Step [600/600], d_loss: 0.3095, g_loss: 4.5735, D(x): 0.90, D(G(z)): 0.07
Epoch [2/200], Step [200/600], d_loss: 0.1137, g_loss: 4.1576, D(x): 0.94, D(G(z)): 0.04
Epoch [2/200], Step [400/600], d_loss: 0.7999, g_loss: 2.7373, D(x): 0.69, D(G(z)): 0.17
Epoch [2/200], Step [600/600], d_loss: 0.5965, g_loss: 4.6036, D(x): 0.80, D(G(z)): 0.11
Epoch [3/200], Step [200/600], d_loss: 0.9501, g_loss: 1.8834, D(x): 0.72, D(G(z)): 0.33
Epoch [3/200], Step [400/600], d_loss: 0.6029, g_loss: 3.7393, D(x): 0.80, D(G(z)): 0.18
Epoch [3/200], Step [

Epoch [30/200], Step [600/600], d_loss: 0.3587, g_loss: 4.3738, D(x): 0.86, D(G(z)): 0.07
Epoch [31/200], Step [200/600], d_loss: 0.6168, g_loss: 2.4980, D(x): 0.76, D(G(z)): 0.08
Epoch [31/200], Step [400/600], d_loss: 0.3601, g_loss: 2.9472, D(x): 0.92, D(G(z)): 0.19
Epoch [31/200], Step [600/600], d_loss: 0.3821, g_loss: 3.6261, D(x): 0.89, D(G(z)): 0.09
Epoch [32/200], Step [200/600], d_loss: 0.3852, g_loss: 4.5451, D(x): 0.86, D(G(z)): 0.09
Epoch [32/200], Step [400/600], d_loss: 0.4801, g_loss: 4.4227, D(x): 0.86, D(G(z)): 0.11
Epoch [32/200], Step [600/600], d_loss: 0.3675, g_loss: 2.8905, D(x): 0.88, D(G(z)): 0.14
Epoch [33/200], Step [200/600], d_loss: 0.2645, g_loss: 3.1049, D(x): 0.95, D(G(z)): 0.14
Epoch [33/200], Step [400/600], d_loss: 0.4927, g_loss: 3.5728, D(x): 0.92, D(G(z)): 0.22
Epoch [33/200], Step [600/600], d_loss: 0.2850, g_loss: 5.2783, D(x): 0.90, D(G(z)): 0.07
Epoch [34/200], Step [200/600], d_loss: 0.5819, g_loss: 2.9279, D(x): 0.84, D(G(z)): 0.16
Epoch [34/

Epoch [61/200], Step [400/600], d_loss: 0.5619, g_loss: 2.3150, D(x): 0.90, D(G(z)): 0.28
Epoch [61/200], Step [600/600], d_loss: 0.3090, g_loss: 2.8950, D(x): 0.87, D(G(z)): 0.09
Epoch [62/200], Step [200/600], d_loss: 0.5908, g_loss: 2.3681, D(x): 0.83, D(G(z)): 0.22
Epoch [62/200], Step [400/600], d_loss: 0.5786, g_loss: 3.1380, D(x): 0.80, D(G(z)): 0.14
Epoch [62/200], Step [600/600], d_loss: 0.5373, g_loss: 2.6153, D(x): 0.80, D(G(z)): 0.13
Epoch [63/200], Step [200/600], d_loss: 0.5519, g_loss: 2.5578, D(x): 0.80, D(G(z)): 0.17
Epoch [63/200], Step [400/600], d_loss: 0.5159, g_loss: 3.0044, D(x): 0.83, D(G(z)): 0.16
Epoch [63/200], Step [600/600], d_loss: 0.5887, g_loss: 3.2207, D(x): 0.77, D(G(z)): 0.11
Epoch [64/200], Step [200/600], d_loss: 0.5523, g_loss: 2.1948, D(x): 0.85, D(G(z)): 0.22
Epoch [64/200], Step [400/600], d_loss: 0.6964, g_loss: 3.0254, D(x): 0.86, D(G(z)): 0.32
Epoch [64/200], Step [600/600], d_loss: 0.4147, g_loss: 2.5507, D(x): 0.85, D(G(z)): 0.14
Epoch [65/

Epoch [92/200], Step [200/600], d_loss: 0.8039, g_loss: 1.8048, D(x): 0.79, D(G(z)): 0.30
Epoch [92/200], Step [400/600], d_loss: 0.6304, g_loss: 1.8145, D(x): 0.79, D(G(z)): 0.23
Epoch [92/200], Step [600/600], d_loss: 0.8682, g_loss: 2.2134, D(x): 0.71, D(G(z)): 0.25
Epoch [93/200], Step [200/600], d_loss: 0.9773, g_loss: 1.5822, D(x): 0.69, D(G(z)): 0.29
Epoch [93/200], Step [400/600], d_loss: 0.6673, g_loss: 1.9182, D(x): 0.73, D(G(z)): 0.16
Epoch [93/200], Step [600/600], d_loss: 0.6006, g_loss: 2.0082, D(x): 0.79, D(G(z)): 0.20
Epoch [94/200], Step [200/600], d_loss: 0.6493, g_loss: 1.9713, D(x): 0.81, D(G(z)): 0.27
Epoch [94/200], Step [400/600], d_loss: 0.5696, g_loss: 2.2121, D(x): 0.79, D(G(z)): 0.21
Epoch [94/200], Step [600/600], d_loss: 0.7944, g_loss: 1.3370, D(x): 0.79, D(G(z)): 0.33
Epoch [95/200], Step [200/600], d_loss: 0.8868, g_loss: 1.7004, D(x): 0.66, D(G(z)): 0.24
Epoch [95/200], Step [400/600], d_loss: 0.9226, g_loss: 1.8715, D(x): 0.73, D(G(z)): 0.31
Epoch [95/

Epoch [122/200], Step [400/600], d_loss: 0.9068, g_loss: 1.7738, D(x): 0.68, D(G(z)): 0.21
Epoch [122/200], Step [600/600], d_loss: 0.9505, g_loss: 1.6063, D(x): 0.70, D(G(z)): 0.29
Epoch [123/200], Step [200/600], d_loss: 0.7971, g_loss: 2.3491, D(x): 0.76, D(G(z)): 0.29
Epoch [123/200], Step [400/600], d_loss: 0.7396, g_loss: 1.8173, D(x): 0.79, D(G(z)): 0.28
Epoch [123/200], Step [600/600], d_loss: 1.0366, g_loss: 1.1611, D(x): 0.70, D(G(z)): 0.33
Epoch [124/200], Step [200/600], d_loss: 1.1354, g_loss: 1.4475, D(x): 0.64, D(G(z)): 0.31
Epoch [124/200], Step [400/600], d_loss: 0.7999, g_loss: 1.8420, D(x): 0.72, D(G(z)): 0.26
Epoch [124/200], Step [600/600], d_loss: 0.8870, g_loss: 1.8781, D(x): 0.67, D(G(z)): 0.25
Epoch [125/200], Step [200/600], d_loss: 0.8566, g_loss: 1.7060, D(x): 0.75, D(G(z)): 0.32
Epoch [125/200], Step [400/600], d_loss: 0.7711, g_loss: 1.4947, D(x): 0.81, D(G(z)): 0.36
Epoch [125/200], Step [600/600], d_loss: 0.8155, g_loss: 1.6181, D(x): 0.76, D(G(z)): 0.29

Epoch [152/200], Step [600/600], d_loss: 0.9489, g_loss: 1.3232, D(x): 0.73, D(G(z)): 0.37
Epoch [153/200], Step [200/600], d_loss: 0.8805, g_loss: 1.9609, D(x): 0.72, D(G(z)): 0.29
Epoch [153/200], Step [400/600], d_loss: 0.8632, g_loss: 1.5145, D(x): 0.68, D(G(z)): 0.24
Epoch [153/200], Step [600/600], d_loss: 0.7259, g_loss: 1.7041, D(x): 0.76, D(G(z)): 0.26
Epoch [154/200], Step [200/600], d_loss: 0.8249, g_loss: 2.0340, D(x): 0.68, D(G(z)): 0.22
Epoch [154/200], Step [400/600], d_loss: 0.9603, g_loss: 1.3238, D(x): 0.73, D(G(z)): 0.39
Epoch [154/200], Step [600/600], d_loss: 0.8933, g_loss: 1.4106, D(x): 0.70, D(G(z)): 0.30
Epoch [155/200], Step [200/600], d_loss: 0.6985, g_loss: 1.5735, D(x): 0.74, D(G(z)): 0.23
Epoch [155/200], Step [400/600], d_loss: 0.8361, g_loss: 1.3675, D(x): 0.72, D(G(z)): 0.30
Epoch [155/200], Step [600/600], d_loss: 1.1947, g_loss: 1.2038, D(x): 0.73, D(G(z)): 0.44
Epoch [156/200], Step [200/600], d_loss: 0.9323, g_loss: 1.5401, D(x): 0.66, D(G(z)): 0.27

Epoch [183/200], Step [200/600], d_loss: 0.9355, g_loss: 1.4513, D(x): 0.74, D(G(z)): 0.33
Epoch [183/200], Step [400/600], d_loss: 0.9016, g_loss: 1.3154, D(x): 0.69, D(G(z)): 0.30
Epoch [183/200], Step [600/600], d_loss: 1.0332, g_loss: 1.4579, D(x): 0.67, D(G(z)): 0.35
Epoch [184/200], Step [200/600], d_loss: 0.8075, g_loss: 1.7763, D(x): 0.71, D(G(z)): 0.26
Epoch [184/200], Step [400/600], d_loss: 0.9105, g_loss: 1.6573, D(x): 0.66, D(G(z)): 0.23
Epoch [184/200], Step [600/600], d_loss: 1.0164, g_loss: 1.7425, D(x): 0.61, D(G(z)): 0.27
Epoch [185/200], Step [200/600], d_loss: 0.9923, g_loss: 1.5860, D(x): 0.69, D(G(z)): 0.34
Epoch [185/200], Step [400/600], d_loss: 1.0041, g_loss: 1.5478, D(x): 0.76, D(G(z)): 0.39
Epoch [185/200], Step [600/600], d_loss: 1.0282, g_loss: 1.4608, D(x): 0.76, D(G(z)): 0.39
Epoch [186/200], Step [200/600], d_loss: 0.8312, g_loss: 1.6103, D(x): 0.70, D(G(z)): 0.28
Epoch [186/200], Step [400/600], d_loss: 0.7807, g_loss: 2.0369, D(x): 0.70, D(G(z)): 0.23