In [1]:
import torch
# Report the installed PyTorch version (the saved output below was produced with it).
print("{}".format(torch.__version__))

0.4.1


In [3]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image




In [4]:
# Create the output directory for sample images. exist_ok=True makes this cell
# idempotent and avoids the race between an exists() check and makedirs().
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)

In [5]:
# Loading and preprocessing MNIST.
#
# ToTensor() scales pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to
# [-1, 1], the same range as the generator's final Tanh. MNIST is grayscale
# (1 channel), so mean/std must have exactly one entry each: 3-tuples do not
# match the 1-channel tensor and fail on torchvision versions that broadcast
# the statistics per channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

batch_size = 32  # images per training step

# MNIST training split (downloaded into root if not already present)
mnist = torchvision.datasets.MNIST(root='../../data/',
                                   train=True,
                                   transform=transform,
                                   download=True)

# Data loader: shuffled mini-batches of (image, label) pairs
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size,
                                          shuffle=True)

In [6]:
# MNIST images are 28x28 = 784 pixels, flattened into a vector for these MLPs.
image_size = 28 * 28   # discriminator input width / generator output width
hidden_size = 256      # width of each hidden layer
latent_size = 64       # length of the noise vector z fed to the generator

# Discriminator: flattened image -> single score in (0, 1) via Sigmoid.
# Trained toward 1 for real images and 0 for generated ones.
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(negative_slope=0.2, inplace=True),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(negative_slope=0.2, inplace=True),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())

# Generator: noise vector z -> flattened image with values in (-1, 1) via Tanh,
# matching the [-1, 1] range real images are normalized to.
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(inplace=True),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(inplace=True),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())



In [13]:
# Binary cross-entropy on the discriminator's Sigmoid output, and one Adam
# optimizer per network sharing the same learning rate.
learning_rate = 0.0002
criterion = nn.BCELoss()
d_opt = torch.optim.Adam(D.parameters(), lr=learning_rate)
g_opt = torch.optim.Adam(G.parameters(), lr=learning_rate)

In [17]:
# 
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

def reset_grad():
    """Zero the accumulated gradients of both the discriminator and generator.

    Relies on the module-level optimizers ``d_opt`` (for D) and ``g_opt``
    (for G); referencing any other names here raises NameError when called.
    """
    d_opt.zero_grad()
    g_opt.zero_grad()
    

In [23]:
# training
n_epochs = 20
total_step = len(data_loader)
# Noise dimensionality is read from the generator's first Linear layer so it
# always matches G's architecture instead of repeating the literal 64.
latent_size = G[0].in_features

for epoch in range(n_epochs):
    for i, (images, _) in enumerate(data_loader):
        # Size labels and noise from the actual batch rather than a hard-coded
        # 32: a smaller final batch or a different DataLoader batch_size would
        # otherwise cause a shape mismatch inside the BCE loss.
        cur_batch_size = images.size(0)
        images = images.view(cur_batch_size, -1)  # flatten images for MLP

        # create labels (1 for real img, 0 for fake img)
        real_labels = torch.ones(cur_batch_size, 1)
        fake_labels = torch.zeros(cur_batch_size, 1)

        ########################
        # training discriminator
        ########################

        # zeroing gradients
        d_opt.zero_grad()

        # loss for real images
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # loss for fake images; detach() stops this backward pass at G's output,
        # so no generator gradients are computed here (they would be discarded
        # by g_opt.zero_grad() before the generator step anyway)
        z = torch.randn(cur_batch_size, latent_size)
        fake_img = G(z).detach()
        outputs = D(fake_img)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # total discriminator loss
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_opt.step()

        ########################
        # training generator
        ########################

        # zeroing gradients. g_loss.backward() below also writes gradients into
        # D's parameters; only g_opt.step() is applied, and D's gradients are
        # cleared by d_opt.zero_grad() at the start of the next iteration.
        g_opt.zero_grad()

        # generate a fresh batch of fake images
        z = torch.randn(cur_batch_size, latent_size)
        fake_img = G(z)
        outputs = D(fake_img)

        # loss for generator: targets are the *real* labels, i.e. G is trained
        # to make D score its samples as real
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        g_opt.step()

        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch+1, n_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))

    # Optional: save real/generated samples once per epoch (uncomment to enable).
    # fake_img here is the last generator batch, so detach it before saving.
    # if (epoch+1) == 1:
    #     real_images = images.reshape(images.size(0), 1, 28, 28)
    #     save_image(denorm(real_images), os.path.join(sample_dir, 'real_images.png'))
    # fake_images = fake_img.detach().reshape(fake_img.size(0), 1, 28, 28)
    # save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

# Save the model checkpoints (uncomment to enable)
# torch.save(G.state_dict(), 'G.ckpt')
# torch.save(D.state_dict(), 'D.ckpt')


Epoch [1/20], Step [200/1875], d_loss: 1.1758, g_loss: 1.6279, D(x): 0.72, D(G(z)): 0.43
Epoch [1/20], Step [400/1875], d_loss: 0.7457, g_loss: 1.8307, D(x): 0.81, D(G(z)): 0.28
Epoch [1/20], Step [600/1875], d_loss: 0.6718, g_loss: 2.4484, D(x): 0.75, D(G(z)): 0.19
Epoch [1/20], Step [800/1875], d_loss: 1.2043, g_loss: 1.5941, D(x): 0.56, D(G(z)): 0.31
Epoch [1/20], Step [1000/1875], d_loss: 0.9384, g_loss: 1.6775, D(x): 0.62, D(G(z)): 0.22
Epoch [1/20], Step [1200/1875], d_loss: 0.9176, g_loss: 1.8270, D(x): 0.67, D(G(z)): 0.26
Epoch [1/20], Step [1400/1875], d_loss: 0.9829, g_loss: 1.8166, D(x): 0.58, D(G(z)): 0.21
Epoch [1/20], Step [1600/1875], d_loss: 0.8997, g_loss: 1.5854, D(x): 0.61, D(G(z)): 0.23
Epoch [1/20], Step [1800/1875], d_loss: 0.7610, g_loss: 1.1318, D(x): 0.80, D(G(z)): 0.28
Epoch [2/20], Step [200/1875], d_loss: 1.0621, g_loss: 1.3745, D(x): 0.68, D(G(z)): 0.33
Epoch [2/20], Step [400/1875], d_loss: 1.0164, g_loss: 1.5736, D(x): 0.64, D(G(z)): 0.26
Epoch [2/20], St