In [3]:
# Print the installed PyTorch version so the environment this notebook was
# run against is recorded alongside the results.
import torch
print(torch.__version__)

0.4.1


In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image




In [2]:
# Create the output directory for generated sample images.
# exist_ok=True keeps this cell safe to re-run and avoids the race between a
# separate os.path.exists() check and the directory creation itself.
sample_dir = 'samples'
os.makedirs(sample_dir, exist_ok=True)

In [4]:
# Loading and preprocessing MNIST.
#
# MNIST images are grayscale, so ToTensor() produces a single-channel tensor
# and Normalize must be given exactly one mean/std entry. Passing three values
# (as for RGB) only worked on old torchvision releases that silently ignored
# the extras; current releases raise a broadcast-shape RuntimeError.
# (x - 0.5) / 0.5 rescales pixels from [0, 1] to [-1, 1], the same range the
# generator's final Tanh produces.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# MNIST dataset (training split; downloaded into `root` on first use).
mnist = torchvision.datasets.MNIST(root='../../data/',
                                   train=True,
                                   transform=transform,
                                   download=True)

# Data loader: reshuffled mini-batches of 32 images every epoch.
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=32,
                                          shuffle=True)

In [5]:
# Network dimensions, named once instead of being repeated as bare literals.
IMAGE_SIZE = 28 * 28   # an MNIST image is 28x28 = 784 pixels, flattened for the MLPs
HIDDEN_SIZE = 256      # width of both hidden layers in D and G
LATENT_SIZE = 64       # length of the noise vector fed to the generator

# Discriminator: flattened image -> single value in (0, 1) via the final
# Sigmoid, trained with target 1 for real images and 0 for generated ones.
D = nn.Sequential(
        nn.Linear(IMAGE_SIZE, HIDDEN_SIZE),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
        nn.Linear(HIDDEN_SIZE, 1),
        nn.Sigmoid())

# Generator: noise vector -> flattened image. The final Tanh bounds outputs to
# [-1, 1]; denorm() maps that range back to [0, 1] before images are saved.
G = nn.Sequential(
        nn.Linear(LATENT_SIZE, HIDDEN_SIZE),
        nn.ReLU(inplace=True),
        nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
        nn.ReLU(inplace=True),
        nn.Linear(HIDDEN_SIZE, IMAGE_SIZE),
        nn.Tanh())



In [6]:
# Loss and optimizers.
# Binary cross-entropy compares the discriminator's sigmoid output with the
# 1 (real) / 0 (fake) targets built in the training loop.
criterion = nn.BCELoss()

# One Adam optimizer per network, discriminator first, both at lr = 2e-4.
d_opt, g_opt = (
    torch.optim.Adam(net.parameters(), lr=2e-4) for net in (D, G)
)

In [7]:
# denorm for plotting
def denorm(x):
    """Map values from the [-1, 1] range back to [0, 1] for image output.

    Computes (x + 1) / 2 and clamps the result to [0, 1], so any input that
    lies outside [-1, 1] saturates at the nearest bound.
    """
    return ((x + 1) / 2).clamp(0, 1)


In [17]:
# Training: alternate one discriminator update and one generator update per
# mini-batch.
n_epochs = 20
total_step = len(data_loader)
# Noise width is read from the generator's first layer so it cannot drift out
# of sync with the model definition.
latent_dim = G[0].in_features

for epoch in range(n_epochs):
    for i, (images, _) in enumerate(data_loader):
        # Size labels and noise from the batch actually delivered rather than
        # a hard-coded 32: a different DataLoader batch size or a short final
        # batch would otherwise make BCELoss fail on mismatched shapes.
        batch_size = images.size(0)
        images = images.view(batch_size, -1)  # flatten images for the MLP

        # Targets: 1 for real images, 0 for generated ones.
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        ########################
        # training discriminator
        ########################

        # Also clears the gradients that the previous generator step
        # accumulated in D while backpropagating through it.
        d_opt.zero_grad()

        # loss for real images
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # loss for fake images
        z = torch.randn(batch_size, latent_dim)
        # detach(): only D is updated here, so there is no reason to
        # backpropagate through the generator.
        fake_img = G(z).detach()
        outputs = D(fake_img)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # total loss
        total_loss = d_loss_real + d_loss_fake

        total_loss.backward()
        d_opt.step()

        ########################
        # training generator
        ########################
        g_opt.zero_grad()

        # Fresh noise; this batch stays attached to the graph so gradients
        # reach G's parameters through D.
        z = torch.randn(batch_size, latent_dim)
        fake_img = G(z)
        outputs = D(fake_img)

        # The generator is rewarded when D labels its samples as real, hence
        # the loss against real_labels.
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        g_opt.step()

        if (i+1) % 1000 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch+1, n_epochs, i+1, total_step, total_loss.item(), g_loss.item(), 
                          real_score.mean().item(), fake_score.mean().item()))


# # Save the model checkpoints 
# #torch.save(G.state_dict(), 'G.ckpt')
# #torch.save(D.state_dict(), 'D.ckpt')


Epoch [1/20], Step [1000/1875], d_loss: 0.9252, g_loss: 1.5333, D(x): 0.69, D(G(z)): 0.36
Epoch [2/20], Step [1000/1875], d_loss: 1.0656, g_loss: 1.1845, D(x): 0.61, D(G(z)): 0.37
Epoch [3/20], Step [1000/1875], d_loss: 1.2036, g_loss: 1.6312, D(x): 0.62, D(G(z)): 0.39
Epoch [4/20], Step [1000/1875], d_loss: 0.7774, g_loss: 1.4891, D(x): 0.71, D(G(z)): 0.28
Epoch [5/20], Step [1000/1875], d_loss: 0.9766, g_loss: 1.2439, D(x): 0.65, D(G(z)): 0.31
Epoch [6/20], Step [1000/1875], d_loss: 1.0137, g_loss: 1.5641, D(x): 0.68, D(G(z)): 0.37
Epoch [7/20], Step [1000/1875], d_loss: 1.2967, g_loss: 1.3159, D(x): 0.66, D(G(z)): 0.44
Epoch [8/20], Step [1000/1875], d_loss: 0.9878, g_loss: 1.2165, D(x): 0.60, D(G(z)): 0.29
Epoch [9/20], Step [1000/1875], d_loss: 0.8312, g_loss: 1.1466, D(x): 0.75, D(G(z)): 0.35
Epoch [10/20], Step [1000/1875], d_loss: 0.8462, g_loss: 1.3450, D(x): 0.65, D(G(z)): 0.24
Epoch [11/20], Step [1000/1875], d_loss: 1.0895, g_loss: 1.4595, D(x): 0.66, D(G(z)): 0.36
Epoch [1

In [18]:
# Save the most recently generated batch as an image grid.
# The flat generator output is reshaped to (N, 1, 28, 28); detach() drops the
# autograd graph, which is not needed for writing an image.
fake_img = fake_img.detach().reshape(fake_img.size(0), 1, 28, 28)
# The file name needs a '{}' placeholder: without one .format() is a no-op and
# the name never reflects the epoch that produced the samples.
save_image(denorm(fake_img), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
