In [1]:
import os
import torch
import torchvision
from torch import optim, nn
from torchvision import transforms, datasets
from torchvision.utils import save_image
from torch.utils.data import DataLoader

In [2]:
# Device Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = './GAN_Code/samples'

In [4]:
# Create a directory if not exists
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
# Image processing
transform = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5],std = [0.5])
                                ])
# MNIST dataset
mnist = datasets.MNIST(root='./GAN_Code/data/', train = True, transform = transform, download= True)

# Data Loader
data_loader = DataLoader(dataset = mnist, batch_size=batch_size, shuffle=True)

# Discriminator

D = nn.Sequential(nn.Linear(image_size, hidden_size),
                 nn.LeakyReLU(0.2),
                 nn.Linear(hidden_size, hidden_size),
                 nn.LeakyReLU(0.2),
                 nn.Linear(hidden_size,1),
                 nn.Sigmoid())

G = nn.Sequential(nn.Linear(latent_size, hidden_size),
                 nn.ReLU(),
                 nn.Linear(hidden_size, hidden_size),
                 nn.ReLU(),
                 nn.Linear(hidden_size, image_size),
                 nn.Tanh())
# Device Setting
D = D.to(device)
G = G.to(device)

# Binary cross entrpy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr = 0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr = 0.0002)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./GAN_Code/data/MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting ./GAN_Code/data/MNIST/raw/train-images-idx3-ubyte.gz to ./GAN_Code/data/MNIST/raw


28.4%

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./GAN_Code/data/MNIST/raw/train-labels-idx1-ubyte.gz


0.5%5%

Extracting ./GAN_Code/data/MNIST/raw/train-labels-idx1-ubyte.gz to ./GAN_Code/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./GAN_Code/data/MNIST/raw/t10k-images-idx3-ubyte.gz


100.4%

Extracting ./GAN_Code/data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./GAN_Code/data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./GAN_Code/data/MNIST/raw/t10k-labels-idx1-ubyte.gz


180.4%

Extracting ./GAN_Code/data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./GAN_Code/data/MNIST/raw
Processing...
Done!


In [5]:
def denorm(x):
    out = ( x + 1 ) / 2
    return out.clamp(0,1)

def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()

In [6]:
# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        images = images.view(batch_size, -1).to(device)
        
        # Create the labels which are later used as input for the BCE loss
        real_labels = torch.ones(batch_size,1).to(device)
        fake_labels = torch.zeros(batch_size,1).to(device)
        
        # =====================================================  #
        #                 Train the discriminator 
        # =====================================================  #
        
        # Compute BCE_Loss using real images where BCE_Loss(x,y): - y * log(x) - (1 - y)*log(1-x)
        
        # the second term of the loss is always zero since real_labels == 1
        
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs
        
        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs
        
        # Backprop and optimize
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # =====================================================  #
        #                 Train the generator 
        # =====================================================  #
        
        # Compute loss with fake images
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        
        # We train G to maximize log(D(G(z))) instead of minimizing log(1-D(G(z))) 
        # For the reason, see the last paragraph of section 3. https: 
        
        g_loss = criterion(outputs, real_labels)
        
        # Backprop and optimize
        reset_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss:{:.2f}, D(G(z)): {:.2f}'.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), real_score.mean().item(), fake_score.mean().item()))
                  
        # Save sampled images
        fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
        save_image(denorm(fake_images), os.path.join(sample_dir,'fake_images-{}.png'.format(epoch+1)))

# Save the model checkpoints

torch.save(G.state_dict(),'G.ckpt')
torch.save(D.state_dict(),'D.ckpt')
        
        

Epoch [0/200], Step [200/600], d_loss: 0.0338, g_loss:4.21, D(G(z)): 1.00
Epoch [0/200], Step [400/600], d_loss: 0.0442, g_loss:5.51, D(G(z)): 1.00
Epoch [0/200], Step [600/600], d_loss: 0.0604, g_loss:4.45, D(G(z)): 0.98
Epoch [1/200], Step [200/600], d_loss: 0.0322, g_loss:6.15, D(G(z)): 0.98
Epoch [1/200], Step [400/600], d_loss: 0.4956, g_loss:3.21, D(G(z)): 0.77
Epoch [1/200], Step [600/600], d_loss: 0.2517, g_loss:5.48, D(G(z)): 0.90
Epoch [2/200], Step [200/600], d_loss: 0.2637, g_loss:3.40, D(G(z)): 0.91
Epoch [2/200], Step [400/600], d_loss: 0.2064, g_loss:5.03, D(G(z)): 0.89
Epoch [2/200], Step [600/600], d_loss: 0.9163, g_loss:2.53, D(G(z)): 0.73
Epoch [3/200], Step [200/600], d_loss: 0.4480, g_loss:2.82, D(G(z)): 0.84
Epoch [3/200], Step [400/600], d_loss: 0.4560, g_loss:3.83, D(G(z)): 0.83
Epoch [3/200], Step [600/600], d_loss: 0.3966, g_loss:2.44, D(G(z)): 0.84
Epoch [4/200], Step [200/600], d_loss: 1.4533, g_loss:1.79, D(G(z)): 0.75
Epoch [4/200], Step [400/600], d_loss:

Epoch [36/200], Step [600/600], d_loss: 0.4499, g_loss:3.15, D(G(z)): 0.84
Epoch [37/200], Step [200/600], d_loss: 0.4505, g_loss:2.61, D(G(z)): 0.84
Epoch [37/200], Step [400/600], d_loss: 0.3335, g_loss:4.26, D(G(z)): 0.90
Epoch [37/200], Step [600/600], d_loss: 0.4578, g_loss:2.45, D(G(z)): 0.90
Epoch [38/200], Step [200/600], d_loss: 0.4055, g_loss:2.86, D(G(z)): 0.91
Epoch [38/200], Step [400/600], d_loss: 0.3329, g_loss:2.66, D(G(z)): 0.87
Epoch [38/200], Step [600/600], d_loss: 0.2708, g_loss:4.35, D(G(z)): 0.90
Epoch [39/200], Step [200/600], d_loss: 0.3967, g_loss:3.11, D(G(z)): 0.89
Epoch [39/200], Step [400/600], d_loss: 0.4825, g_loss:2.97, D(G(z)): 0.84
Epoch [39/200], Step [600/600], d_loss: 0.4475, g_loss:3.08, D(G(z)): 0.90
Epoch [40/200], Step [200/600], d_loss: 0.4182, g_loss:2.38, D(G(z)): 0.87
Epoch [40/200], Step [400/600], d_loss: 0.6991, g_loss:2.29, D(G(z)): 0.81
Epoch [40/200], Step [600/600], d_loss: 0.4688, g_loss:2.56, D(G(z)): 0.81
Epoch [41/200], Step [200

Epoch [73/200], Step [400/600], d_loss: 0.7177, g_loss:2.55, D(G(z)): 0.76
Epoch [73/200], Step [600/600], d_loss: 0.7350, g_loss:2.28, D(G(z)): 0.72
Epoch [74/200], Step [200/600], d_loss: 0.7482, g_loss:2.04, D(G(z)): 0.71
Epoch [74/200], Step [400/600], d_loss: 1.1011, g_loss:1.46, D(G(z)): 0.64
Epoch [74/200], Step [600/600], d_loss: 0.7256, g_loss:2.34, D(G(z)): 0.83
Epoch [75/200], Step [200/600], d_loss: 0.8003, g_loss:2.15, D(G(z)): 0.66
Epoch [75/200], Step [400/600], d_loss: 0.8224, g_loss:1.82, D(G(z)): 0.72
Epoch [75/200], Step [600/600], d_loss: 0.9792, g_loss:1.89, D(G(z)): 0.60
Epoch [76/200], Step [200/600], d_loss: 0.6739, g_loss:2.19, D(G(z)): 0.81
Epoch [76/200], Step [400/600], d_loss: 0.6782, g_loss:2.56, D(G(z)): 0.84
Epoch [76/200], Step [600/600], d_loss: 0.9239, g_loss:1.87, D(G(z)): 0.67
Epoch [77/200], Step [200/600], d_loss: 0.8250, g_loss:2.40, D(G(z)): 0.78
Epoch [77/200], Step [400/600], d_loss: 0.9877, g_loss:1.67, D(G(z)): 0.78
Epoch [77/200], Step [600

Epoch [109/200], Step [600/600], d_loss: 0.8924, g_loss:1.54, D(G(z)): 0.72
Epoch [110/200], Step [200/600], d_loss: 0.8140, g_loss:1.60, D(G(z)): 0.74
Epoch [110/200], Step [400/600], d_loss: 1.0020, g_loss:2.28, D(G(z)): 0.68
Epoch [110/200], Step [600/600], d_loss: 0.7258, g_loss:2.22, D(G(z)): 0.74
Epoch [111/200], Step [200/600], d_loss: 0.7637, g_loss:1.99, D(G(z)): 0.80
Epoch [111/200], Step [400/600], d_loss: 0.7991, g_loss:1.81, D(G(z)): 0.76
Epoch [111/200], Step [600/600], d_loss: 0.9547, g_loss:1.46, D(G(z)): 0.62
Epoch [112/200], Step [200/600], d_loss: 1.0916, g_loss:1.66, D(G(z)): 0.73
Epoch [112/200], Step [400/600], d_loss: 0.7835, g_loss:1.80, D(G(z)): 0.71
Epoch [112/200], Step [600/600], d_loss: 0.8290, g_loss:1.45, D(G(z)): 0.71
Epoch [113/200], Step [200/600], d_loss: 0.8298, g_loss:2.34, D(G(z)): 0.75
Epoch [113/200], Step [400/600], d_loss: 0.6681, g_loss:2.17, D(G(z)): 0.75
Epoch [113/200], Step [600/600], d_loss: 0.8242, g_loss:1.55, D(G(z)): 0.70
Epoch [114/2

Epoch [145/200], Step [600/600], d_loss: 0.9539, g_loss:1.77, D(G(z)): 0.60
Epoch [146/200], Step [200/600], d_loss: 0.9226, g_loss:1.50, D(G(z)): 0.68
Epoch [146/200], Step [400/600], d_loss: 0.7605, g_loss:1.79, D(G(z)): 0.77
Epoch [146/200], Step [600/600], d_loss: 0.9280, g_loss:1.69, D(G(z)): 0.64
Epoch [147/200], Step [200/600], d_loss: 1.0747, g_loss:1.52, D(G(z)): 0.64
Epoch [147/200], Step [400/600], d_loss: 0.9666, g_loss:1.30, D(G(z)): 0.74
Epoch [147/200], Step [600/600], d_loss: 0.9792, g_loss:1.30, D(G(z)): 0.69
Epoch [148/200], Step [200/600], d_loss: 0.8734, g_loss:1.66, D(G(z)): 0.73
Epoch [148/200], Step [400/600], d_loss: 1.1097, g_loss:1.74, D(G(z)): 0.59
Epoch [148/200], Step [600/600], d_loss: 0.8209, g_loss:1.76, D(G(z)): 0.73
Epoch [149/200], Step [200/600], d_loss: 0.9033, g_loss:1.48, D(G(z)): 0.70
Epoch [149/200], Step [400/600], d_loss: 0.8500, g_loss:1.57, D(G(z)): 0.72
Epoch [149/200], Step [600/600], d_loss: 0.9283, g_loss:1.27, D(G(z)): 0.73
Epoch [150/2

Epoch [181/200], Step [600/600], d_loss: 0.7911, g_loss:1.37, D(G(z)): 0.73
Epoch [182/200], Step [200/600], d_loss: 0.9173, g_loss:1.75, D(G(z)): 0.64
Epoch [182/200], Step [400/600], d_loss: 0.8422, g_loss:1.59, D(G(z)): 0.72
Epoch [182/200], Step [600/600], d_loss: 0.9487, g_loss:1.50, D(G(z)): 0.73
Epoch [183/200], Step [200/600], d_loss: 0.8257, g_loss:1.84, D(G(z)): 0.75
Epoch [183/200], Step [400/600], d_loss: 0.9584, g_loss:1.39, D(G(z)): 0.76
Epoch [183/200], Step [600/600], d_loss: 0.9278, g_loss:1.17, D(G(z)): 0.71
Epoch [184/200], Step [200/600], d_loss: 0.9548, g_loss:1.38, D(G(z)): 0.67
Epoch [184/200], Step [400/600], d_loss: 1.0977, g_loss:1.38, D(G(z)): 0.57
Epoch [184/200], Step [600/600], d_loss: 0.8558, g_loss:1.47, D(G(z)): 0.69
Epoch [185/200], Step [200/600], d_loss: 0.9788, g_loss:1.72, D(G(z)): 0.65
Epoch [185/200], Step [400/600], d_loss: 0.9386, g_loss:1.70, D(G(z)): 0.63
Epoch [185/200], Step [600/600], d_loss: 0.9502, g_loss:1.42, D(G(z)): 0.74
Epoch [186/2