In [1]:
## Code referenced from https://github.com/yunjey/pytorch-tutorial

In [14]:
import os
import torch
import torchvision
import torch.nn as nn
import sys
from torchvision import transforms
from torchvision.utils import save_image

ModuleNotFoundError: No module named 'tensorflow'

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'
logger = Logger('./logs')

In [4]:
# Create a directory if not exists
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

In [5]:
# Image processing
# Data --> torch Tensor
# For normalization, use transform.Compse
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5),   # 3 for RGB channels
                                     std=(0.5, 0.5, 0.5))])


# MNIST dataset
mnist = torchvision.datasets.MNIST(root='../../data/',
                                   train=True,
                                   transform=transform,
                                   download=True)

# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size, 
                                          shuffle=True)

In [6]:
# Discriminator
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())

In [7]:
# Generator 
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())

In [8]:
# Device setting
D = D.to(device)
G = G.to(device)

In [9]:
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

In [10]:
## Why?
def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)

def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

In [11]:
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images,_) in enumerate(data_loader):
        images = images.reshape(batch_size, -1).to(device)
        
        ## Discriminator loss
        reset_grad()
        
        # 1> Create the labels 
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        
        # 2> Loss with real image
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels) # (y, yhat)
        real_score = outputs

        # 3> Loss with fake image
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        # 4> Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()
        
        

        ## Generator loss
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        
        g_loss = criterion(outputs, real_labels) ## reference : ppt 
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        # See Loss
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' 
                  .format(epoch, num_epochs, 
                          i+1, total_step, 
                          d_loss.item(), 
                          g_loss.item(), 
                          real_score.mean().item(), 
                          fake_score.mean().item()))

    # Save real images
    images = images.reshape(images.size(0), 1, 28, 28)
    save_image(denorm(images),
    os.path.join(sample_dir, 'real_images.png'))
    
    # Save sampled images
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images),
               os.path.join(sample_dir, 
                            'fake_images-{}.png'.format(epoch+1)))

Epoch [0/200], Step [200/600], d_loss: 0.0288, g_loss: 4.4131, D(x): 1.00, D(G(z)): 0.03
Epoch [0/200], Step [400/600], d_loss: 0.0909, g_loss: 4.9666, D(x): 0.98, D(G(z)): 0.07
Epoch [0/200], Step [600/600], d_loss: 0.0366, g_loss: 5.2619, D(x): 0.99, D(G(z)): 0.02
Epoch [1/200], Step [200/600], d_loss: 0.1198, g_loss: 5.2746, D(x): 0.95, D(G(z)): 0.03
Epoch [1/200], Step [400/600], d_loss: 0.1451, g_loss: 5.2458, D(x): 0.96, D(G(z)): 0.08
Epoch [1/200], Step [600/600], d_loss: 0.3847, g_loss: 4.1914, D(x): 0.93, D(G(z)): 0.20
Epoch [2/200], Step [200/600], d_loss: 0.2585, g_loss: 3.5728, D(x): 0.88, D(G(z)): 0.09
Epoch [2/200], Step [400/600], d_loss: 0.4122, g_loss: 4.8896, D(x): 0.83, D(G(z)): 0.08
Epoch [2/200], Step [600/600], d_loss: 0.2568, g_loss: 3.6162, D(x): 0.90, D(G(z)): 0.09
Epoch [3/200], Step [200/600], d_loss: 0.8881, g_loss: 1.9446, D(x): 0.82, D(G(z)): 0.36
Epoch [3/200], Step [400/600], d_loss: 0.6094, g_loss: 3.1839, D(x): 0.78, D(G(z)): 0.05
Epoch [3/200], Step [

Epoch [30/200], Step [600/600], d_loss: 0.5373, g_loss: 4.7193, D(x): 0.81, D(G(z)): 0.07
Epoch [31/200], Step [200/600], d_loss: 0.2947, g_loss: 3.3523, D(x): 0.93, D(G(z)): 0.11
Epoch [31/200], Step [400/600], d_loss: 0.3133, g_loss: 3.1470, D(x): 0.88, D(G(z)): 0.08
Epoch [31/200], Step [600/600], d_loss: 0.3217, g_loss: 3.6339, D(x): 0.89, D(G(z)): 0.05
Epoch [32/200], Step [200/600], d_loss: 0.4175, g_loss: 2.2607, D(x): 0.91, D(G(z)): 0.15
Epoch [32/200], Step [400/600], d_loss: 0.5014, g_loss: 3.7074, D(x): 0.84, D(G(z)): 0.13
Epoch [32/200], Step [600/600], d_loss: 0.3079, g_loss: 3.9317, D(x): 0.90, D(G(z)): 0.09
Epoch [33/200], Step [200/600], d_loss: 0.2807, g_loss: 4.5807, D(x): 0.89, D(G(z)): 0.10
Epoch [33/200], Step [400/600], d_loss: 0.3976, g_loss: 3.7185, D(x): 0.85, D(G(z)): 0.08
Epoch [33/200], Step [600/600], d_loss: 0.4252, g_loss: 2.9929, D(x): 0.88, D(G(z)): 0.12
Epoch [34/200], Step [200/600], d_loss: 0.4765, g_loss: 3.0790, D(x): 0.81, D(G(z)): 0.09
Epoch [34/

Epoch [61/200], Step [400/600], d_loss: 0.6174, g_loss: 2.2682, D(x): 0.83, D(G(z)): 0.23
Epoch [61/200], Step [600/600], d_loss: 0.6306, g_loss: 2.7271, D(x): 0.78, D(G(z)): 0.17
Epoch [62/200], Step [200/600], d_loss: 0.6077, g_loss: 3.3684, D(x): 0.77, D(G(z)): 0.15
Epoch [62/200], Step [400/600], d_loss: 0.6048, g_loss: 2.4838, D(x): 0.87, D(G(z)): 0.28
Epoch [62/200], Step [600/600], d_loss: 0.6128, g_loss: 1.7513, D(x): 0.83, D(G(z)): 0.27
Epoch [63/200], Step [200/600], d_loss: 0.5178, g_loss: 2.2123, D(x): 0.81, D(G(z)): 0.18
Epoch [63/200], Step [400/600], d_loss: 0.7546, g_loss: 1.8113, D(x): 0.83, D(G(z)): 0.34
Epoch [63/200], Step [600/600], d_loss: 0.6793, g_loss: 2.3018, D(x): 0.80, D(G(z)): 0.26
Epoch [64/200], Step [200/600], d_loss: 0.6268, g_loss: 2.3983, D(x): 0.79, D(G(z)): 0.21
Epoch [64/200], Step [400/600], d_loss: 0.6935, g_loss: 2.2382, D(x): 0.75, D(G(z)): 0.20
Epoch [64/200], Step [600/600], d_loss: 0.8910, g_loss: 2.4979, D(x): 0.73, D(G(z)): 0.25
Epoch [65/

Epoch [92/200], Step [200/600], d_loss: 0.7642, g_loss: 2.2667, D(x): 0.80, D(G(z)): 0.28
Epoch [92/200], Step [400/600], d_loss: 0.6146, g_loss: 2.6838, D(x): 0.78, D(G(z)): 0.21
Epoch [92/200], Step [600/600], d_loss: 0.7533, g_loss: 1.9899, D(x): 0.69, D(G(z)): 0.20
Epoch [93/200], Step [200/600], d_loss: 0.8772, g_loss: 1.5552, D(x): 0.64, D(G(z)): 0.24
Epoch [93/200], Step [400/600], d_loss: 0.6123, g_loss: 2.7702, D(x): 0.82, D(G(z)): 0.24
Epoch [93/200], Step [600/600], d_loss: 0.7333, g_loss: 2.1066, D(x): 0.76, D(G(z)): 0.24
Epoch [94/200], Step [200/600], d_loss: 0.5941, g_loss: 2.6966, D(x): 0.76, D(G(z)): 0.17
Epoch [94/200], Step [400/600], d_loss: 0.8837, g_loss: 1.9123, D(x): 0.70, D(G(z)): 0.26
Epoch [94/200], Step [600/600], d_loss: 0.9359, g_loss: 1.7021, D(x): 0.75, D(G(z)): 0.34
Epoch [95/200], Step [200/600], d_loss: 0.7828, g_loss: 1.6714, D(x): 0.74, D(G(z)): 0.27
Epoch [95/200], Step [400/600], d_loss: 0.7792, g_loss: 1.6691, D(x): 0.80, D(G(z)): 0.32
Epoch [95/

Epoch [122/200], Step [400/600], d_loss: 0.9501, g_loss: 1.4248, D(x): 0.63, D(G(z)): 0.25
Epoch [122/200], Step [600/600], d_loss: 0.7884, g_loss: 1.4355, D(x): 0.80, D(G(z)): 0.33
Epoch [123/200], Step [200/600], d_loss: 0.8611, g_loss: 1.4936, D(x): 0.67, D(G(z)): 0.25
Epoch [123/200], Step [400/600], d_loss: 0.9639, g_loss: 1.5361, D(x): 0.65, D(G(z)): 0.26
Epoch [123/200], Step [600/600], d_loss: 0.8196, g_loss: 1.4241, D(x): 0.75, D(G(z)): 0.30
Epoch [124/200], Step [200/600], d_loss: 0.7848, g_loss: 1.4920, D(x): 0.76, D(G(z)): 0.30
Epoch [124/200], Step [400/600], d_loss: 0.8743, g_loss: 1.8394, D(x): 0.79, D(G(z)): 0.36
Epoch [124/200], Step [600/600], d_loss: 0.7807, g_loss: 1.8249, D(x): 0.73, D(G(z)): 0.26
Epoch [125/200], Step [200/600], d_loss: 0.8938, g_loss: 1.5263, D(x): 0.65, D(G(z)): 0.23
Epoch [125/200], Step [400/600], d_loss: 0.8837, g_loss: 2.0101, D(x): 0.70, D(G(z)): 0.28
Epoch [125/200], Step [600/600], d_loss: 0.8999, g_loss: 1.9190, D(x): 0.61, D(G(z)): 0.17

Epoch [152/200], Step [600/600], d_loss: 0.9493, g_loss: 1.6350, D(x): 0.71, D(G(z)): 0.36
Epoch [153/200], Step [200/600], d_loss: 0.9698, g_loss: 1.4344, D(x): 0.66, D(G(z)): 0.29
Epoch [153/200], Step [400/600], d_loss: 1.0129, g_loss: 1.5356, D(x): 0.75, D(G(z)): 0.40
Epoch [153/200], Step [600/600], d_loss: 0.8403, g_loss: 1.4316, D(x): 0.74, D(G(z)): 0.32
Epoch [154/200], Step [200/600], d_loss: 0.8685, g_loss: 1.5248, D(x): 0.69, D(G(z)): 0.29
Epoch [154/200], Step [400/600], d_loss: 1.2033, g_loss: 1.4398, D(x): 0.58, D(G(z)): 0.30
Epoch [154/200], Step [600/600], d_loss: 1.3156, g_loss: 1.4230, D(x): 0.75, D(G(z)): 0.47
Epoch [155/200], Step [200/600], d_loss: 1.0085, g_loss: 1.1850, D(x): 0.75, D(G(z)): 0.41
Epoch [155/200], Step [400/600], d_loss: 0.7806, g_loss: 1.9571, D(x): 0.75, D(G(z)): 0.29
Epoch [155/200], Step [600/600], d_loss: 0.9663, g_loss: 1.0692, D(x): 0.72, D(G(z)): 0.36
Epoch [156/200], Step [200/600], d_loss: 1.0292, g_loss: 1.5207, D(x): 0.66, D(G(z)): 0.33

Epoch [183/200], Step [200/600], d_loss: 1.0409, g_loss: 1.3429, D(x): 0.61, D(G(z)): 0.29
Epoch [183/200], Step [400/600], d_loss: 1.0064, g_loss: 1.1911, D(x): 0.70, D(G(z)): 0.37
Epoch [183/200], Step [600/600], d_loss: 0.8484, g_loss: 1.5477, D(x): 0.70, D(G(z)): 0.28
Epoch [184/200], Step [200/600], d_loss: 0.8819, g_loss: 1.2382, D(x): 0.71, D(G(z)): 0.33
Epoch [184/200], Step [400/600], d_loss: 0.9753, g_loss: 1.5610, D(x): 0.66, D(G(z)): 0.32
Epoch [184/200], Step [600/600], d_loss: 1.1531, g_loss: 1.7962, D(x): 0.57, D(G(z)): 0.26
Epoch [185/200], Step [200/600], d_loss: 1.0448, g_loss: 1.3861, D(x): 0.62, D(G(z)): 0.29
Epoch [185/200], Step [400/600], d_loss: 1.0297, g_loss: 1.3407, D(x): 0.63, D(G(z)): 0.29
Epoch [185/200], Step [600/600], d_loss: 1.0676, g_loss: 1.2347, D(x): 0.73, D(G(z)): 0.42
Epoch [186/200], Step [200/600], d_loss: 1.0481, g_loss: 1.5795, D(x): 0.55, D(G(z)): 0.19
Epoch [186/200], Step [400/600], d_loss: 0.7654, g_loss: 1.2834, D(x): 0.74, D(G(z)): 0.29

In [12]:
# Save the model checkpoints 
torch.save(G.state_dict(), 'checkpoint/G.ckpt')
torch.save(D.state_dict(), 'checkpoint/D.ckpt')