In [15]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets 
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import numpy as np

In [2]:
#Hyperparams
batch_size = 128
learning_rate = 0.0003
epochs =  200
latent_space_size = 64

#### Define dataset transform and load MNIST dataset

In [3]:
transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), 
                                     std=(0.5, 0.5, 0.5))])

In [4]:
mnist = datasets.MNIST(root='./data/',
                       train=True,
                       transform=transform,
                       download=True)

data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size, 
                                          shuffle=True)

#### Create Discriminator class

In [5]:
class Discriminator(nn.Module):
    
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.l1 = nn.Linear(784, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)
        
    def forward(self, X):
        out = F.leaky_relu(self.l1(X))
        out = F.leaky_relu(self.l2(out))
        out = F.sigmoid(self.l3(out))
        return out

#### Create Generator class

In [6]:
class Generator(nn.Module):
    
    def __init__(self):
        super(Generator, self).__init__()
        
        self.l1 = nn.Linear(64, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 784)
        
    def forward(self, X):
        out = F.leaky_relu(self.l1(X))
        out = F.leaky_relu(self.l2(out))
        out = F.tanh(self.l3(out))
        return out

In [7]:
D = Discriminator()

In [8]:
D.cuda()

Discriminator(
  (l1): Linear(in_features=784, out_features=256)
  (l2): Linear(in_features=256, out_features=256)
  (l3): Linear(in_features=256, out_features=1)
)

In [9]:
G = Generator()

In [10]:
G.cuda()

Generator(
  (l1): Linear(in_features=64, out_features=256)
  (l2): Linear(in_features=256, out_features=256)
  (l3): Linear(in_features=256, out_features=784)
)

#### Define loss funciton

In [11]:
criterion = nn.BCELoss()

#### Define optimizers for generator and discriminator

In [12]:
g_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate)
d_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate)

#### Create a training loop

In [28]:
for epoch in range(epochs):

    g_epoch_loss = []
    d_epoch_loss = []
    for images, _ in data_loader:
        
        real_images = Variable(images.view(batch_size, -1)).cuda()
        fake_images = G(Variable(torch.randn(batch_size, latent_space_size).cuda()))
        
        if real_images.cpu().data.numpy().shape[1] == 28*28:
            real_images_dis_outputs = D(real_images)
            fake_images_dis_outputs = D(fake_images)

            D_real_loss = criterion(real_images_dis_outputs, Variable(torch.ones(batch_size)).cuda())
            D_fake_loss = criterion(fake_images_dis_outputs, Variable(torch.zeros(batch_size)).cuda())

            #Calculate losses
            D_loss = D_real_loss + D_fake_loss
            G_loss = criterion(fake_images_dis_outputs, Variable(torch.ones(batch_size)).cuda())
            
            #Log
            g_epoch_loss.append(G_loss.cpu().data.numpy())
            d_epoch_loss.append(D_loss.cpu().data.numpy())
            
            #Update Discriminator
            d_optimizer.zero_grad()
            D_loss.backward(retain_graph=True)
            d_optimizer.step()

            #Update Generator
            
            g_optimizer.zero_grad()
            d_optimizer.zero_grad()
            G_loss.backward()
            g_optimizer.step()


    print("G loss: {}".format(np.mean(g_epoch_loss)), " | D loss: {}".format(np.mean(d_epoch_loss))) 

  "Please ensure they have the same size.".format(target.size(), input.size()))


G loss: 1.8818649053573608  | D loss: 0.7400017380714417
G loss: 1.9479326009750366  | D loss: 0.6165869832038879
G loss: 2.1690335273742676  | D loss: 0.6458476781845093
G loss: 2.345785617828369  | D loss: 0.5358943939208984
G loss: 2.250821352005005  | D loss: 0.6544641256332397
G loss: 2.5245144367218018  | D loss: 0.7299686670303345
G loss: 2.5776891708374023  | D loss: 0.7506179809570312
G loss: 2.0646724700927734  | D loss: 0.864018976688385
G loss: 2.286578893661499  | D loss: 0.6973211765289307
G loss: 2.89401912689209  | D loss: 0.7400039434432983
G loss: 2.9077513217926025  | D loss: 0.6944419741630554
G loss: 2.5862698554992676  | D loss: 0.6557679176330566
G loss: 2.6391549110412598  | D loss: 0.4914310574531555
G loss: 2.6979963779449463  | D loss: 0.5106393098831177
G loss: 2.719245672225952  | D loss: 0.37330761551856995
G loss: 3.0988714694976807  | D loss: 0.4231935143470764
G loss: 3.421299695968628  | D loss: 0.4121059775352478
G loss: 2.8859806060791016  | D loss: 

G loss: 1.7154077291488647  | D loss: 0.7943461537361145
G loss: 1.6909401416778564  | D loss: 0.8006151914596558
G loss: 1.7036904096603394  | D loss: 0.8036582469940186
G loss: 1.7027660608291626  | D loss: 0.7974878549575806
G loss: 1.7097532749176025  | D loss: 0.7967267632484436
G loss: 1.6965079307556152  | D loss: 0.7972111105918884
G loss: 1.7019339799880981  | D loss: 0.804136335849762
G loss: 1.683194875717163  | D loss: 0.8030058741569519
G loss: 1.7300807237625122  | D loss: 0.7951854467391968
G loss: 1.7232540845870972  | D loss: 0.7982699275016785
G loss: 1.7247685194015503  | D loss: 0.8010149598121643
G loss: 1.7099969387054443  | D loss: 0.7988106608390808
G loss: 1.698622226715088  | D loss: 0.7973784804344177
G loss: 1.730819821357727  | D loss: 0.792112410068512
G loss: 1.7010807991027832  | D loss: 0.793914794921875
G loss: 1.7061855792999268  | D loss: 0.7904404997825623
G loss: 1.712510108947754  | D loss: 0.7987266182899475
G loss: 1.709913730621338  | D loss: 0