In [0]:
# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

from torchsummary import torchsummary
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [101]:
print('Using \'{}\' in this session'.format(device))

Using 'cuda' in this session


In [0]:
bs = 100

# MNIST Dataset
transform = transforms.Compose([
    transforms.ToTensor()])
#     transforms.Normalize(mean=0.5, 
#                          std=0.5)])

train_dataset = datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist/', train=False, transform=transform, download=False)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [0]:
class Generator(nn.Module):
    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(g_input_dim, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)
        
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return torch.tanh(self.fc4(x))
        

class Discriminator(nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)
        
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)
        return torch.sigmoid(self.fc4(x))

In [104]:
z_dim = 100
mnist_dim = train_dataset.train_data.size(1)*train_dataset.train_data.size(2)


G = Generator(g_input_dim=z_dim, g_output_dim=mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)



In [105]:
torchsummary.summary(G, (1, z_dim))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1               [-1, 1, 256]          25,856
            Linear-2               [-1, 1, 512]         131,584
            Linear-3              [-1, 1, 1024]         525,312
            Linear-4               [-1, 1, 784]         803,600
Total params: 1,486,352
Trainable params: 1,486,352
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Params size (MB): 5.67
Estimated Total Size (MB): 5.69
----------------------------------------------------------------


In [106]:
D

Discriminator(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)

In [0]:
# Loss
criterion = nn.BCELoss()

# optimizer
lr = 0.0005
G_optimizer = optim.Adam(G.parameters(), lr=0.0001)
D_optimizer = optim.Adam(D.parameters(), lr=0.00005)

In [0]:
def D_train(x):
    #------------------------------------
    # Train the Discriminator
    #------------------------------------
    D.zero_grad()
    
    # Train on real image
    x_real, y_real = x.view(-1, mnist_dim), torch.ones(bs, 1)
    x_real, y_real = Variable(x_real.to(device)), Variable(y_real.to(device))
    
    
    D_output = D(x_real)
    D_real_loss = criterion(D_output, y_real)
    D_real_score = D_output
    
    # Train on fake image
    z = Variable(torch.randn(bs, z_dim).to(device))
    x_fake, y_fake = G(z), Variable(torch.zeros(bs, 1).to(device))
    
    D_output = D(x_fake)
    D_fake_loss = criterion(D_output, y_fake)
    D_fake_score = D_output
    
    # BackProb & Optimize only D paramters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()
    
    return D_loss.data.item()

In [0]:
def G_train(x):
    #------------------------------------
    # Train the Generator
    #------------------------------------
    
    G.zero_grad()
    z = Variable(torch.randn(bs, z_dim).to(device))
    y = Variable(torch.ones(bs, 1).to(device))
    
    G_output = G(z)
    D_output = D(G_output)
    G_loss = criterion(D_output, y)
    
    # backprob & optimize only G parameters
    G_loss.backward()
    G_optimizer.step()
    
    return G_loss.data.item()

In [0]:
n_epochs = 200


for epoch in range(1, n_epochs + 1):
    D_losses, G_losses = [], []
    for batch_idx, (x, y) in enumerate(train_loader):
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))
        G_losses.append(G_train(x))
        

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
        (epoch), n_epochs, 
        torch.mean(torch.FloatTensor(D_losses)), 
        torch.mean(torch.FloatTensor(G_losses))))

[1/200]: loss_d: 1.286, loss_g: 0.983
[2/200]: loss_d: 1.341, loss_g: 0.837
[3/200]: loss_d: 1.392, loss_g: 0.786
[4/200]: loss_d: 1.400, loss_g: 0.769
[5/200]: loss_d: 1.381, loss_g: 0.727
[6/200]: loss_d: 1.413, loss_g: 0.726
[7/200]: loss_d: 1.407, loss_g: 0.714
[8/200]: loss_d: 1.405, loss_g: 0.711
[9/200]: loss_d: 1.399, loss_g: 0.697
[10/200]: loss_d: 1.401, loss_g: 0.699
[11/200]: loss_d: 1.394, loss_g: 0.703
[12/200]: loss_d: 1.398, loss_g: 0.707
[13/200]: loss_d: 1.395, loss_g: 0.696
[14/200]: loss_d: 1.397, loss_g: 0.702
[15/200]: loss_d: 1.398, loss_g: 0.695
[16/200]: loss_d: 1.390, loss_g: 0.699
[17/200]: loss_d: 1.398, loss_g: 0.696
[18/200]: loss_d: 1.391, loss_g: 0.696
[19/200]: loss_d: 1.393, loss_g: 0.697
[20/200]: loss_d: 1.394, loss_g: 0.697
[21/200]: loss_d: 1.392, loss_g: 0.702
[22/200]: loss_d: 1.392, loss_g: 0.695
[23/200]: loss_d: 1.392, loss_g: 0.698
[24/200]: loss_d: 1.391, loss_g: 0.700
[25/200]: loss_d: 1.392, loss_g: 0.695
[26/200]: loss_d: 1.396, loss_g: 0

In [0]:
with torch.no_grad():
    test_z = Variable(torch.randn(bs, z_dim).to(device))
    generated = G(test_z)

    save_image(generated.view(generated.size(0), 1, 28, 28), './samples/sample_' + '.png')