In [0]:
!pip install tensorboardX
!pip install -q tf-nightly-2.0-preview
%load_ext tensorboard

Collecting tensorboardX
[?25l  Downloading https://files.pythonhosted.org/packages/5c/76/89dd44458eb976347e5a6e75eb79fecf8facd46c1ce259bad54e0044ea35/tensorboardX-1.6-py2.py3-none-any.whl (129kB)
[K     |████████████████████████████████| 133kB 41.8MB/s 
Installing collected packages: tensorboardX
Successfully installed tensorboardX-1.6
[K     |████████████████████████████████| 87.5MB 507kB/s 
[K     |████████████████████████████████| 430kB 50.4MB/s 
[K     |████████████████████████████████| 61kB 26.6MB/s 
[K     |████████████████████████████████| 3.1MB 49.1MB/s 


In [0]:
# Heavily inspired by: https://github.com/pytorch/examples/blob/master/mnist/main.py
# And https://medium.com/ai-society/gans-from-scratch-1-a-deep-introduction-with-code-in-pytorch-and-tensorflow-cb03cdcdba0f
import os
import torch
from torch import nn
from torch.autograd import Variable
import torchvision.datasets as dset
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter
from torchvision.utils import make_grid

In [0]:
# Device selection: run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print("Device is %s." % device)

In [0]:
class Discriminator(nn.Module):
    """Fully connected GAN discriminator.

    Flattens each sample of the input batch and maps it to a single probability
    (via a final Sigmoid) that the sample is real.

    Args:
        in_features: number of values per flattened sample. Defaults to 28*28,
            i.e. one MNIST image.
    """

    def __init__(self, in_features=28*28):
        super(Discriminator, self).__init__()

        # Three Linear -> LeakyReLU(0.2) -> Dropout(0.3) stages narrowing 1024 -> 512 -> 256.
        # The activation class is `nn.LeakyReLU` (capital "LU"); `nn.LeakyReLu` does not exist.
        self.hidden1 = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )

        self.hidden2 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )

        self.hidden3 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )

        # Sigmoid squashes the single output into [0, 1] so it can be fed to nn.BCELoss.
        self.out = nn.Sequential(
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Return a (batch, 1) tensor of real-vs-fake probabilities for `img`."""
        # Flatten everything after the batch dimension, e.g. (N, 1, 28, 28) -> (N, 784).
        x = img.view(img.size(0), -1)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.hidden3(x)
        return self.out(x)

In [0]:
class Generator(nn.Module):
    """Fully connected GAN generator.

    Maps a (batch, noise_dim) batch of latent vectors to (batch, out_features)
    samples whose values lie in [-1, 1] (final Tanh).

    Args:
        noise_dim: width of the latent input vectors. Defaults to 100.
        out_features: number of values per generated sample. Defaults to 28*28,
            i.e. one flattened MNIST image.
    """

    def __init__(self, noise_dim=100, out_features=28*28):
        super(Generator, self).__init__()

        # Three Linear -> LeakyReLU(0.2) -> Dropout(0.3) stages widening 256 -> 512 -> 1024.
        # The activation class is `nn.LeakyReLU` (capital "LU"); `nn.LeakyReLu` does not exist.
        self.hidden1 = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )

        self.hidden2 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )

        self.hidden3 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )

        # Tanh keeps outputs in [-1, 1], the same range the MNIST loader normalises real pixels to.
        self.out = nn.Sequential(
            nn.Linear(1024, out_features),
            nn.Tanh()
        )

    def forward(self, z):
        """Map latent batch `z` of shape (batch, noise_dim) to (batch, out_features)."""
        x = self.hidden1(z)
        x = self.hidden2(x)
        x = self.hidden3(x)
        return self.out(x)

In [0]:
# Build both networks on the selected device (discriminator first), give each its own
# Adam optimizer with lr=0.0002, and share a single binary cross-entropy criterion.
discriminator = Discriminator().to(device)
generator = Generator().to(device)

d_optimizer, g_optimizer = (
    torch.optim.Adam(net.parameters(), lr=0.0002) for net in (discriminator, generator)
)

# nn.BCELoss averages over the batch by default (reduction='mean').
loss = nn.BCELoss()

In [0]:
def train_discriminator(optimizer, real_data, fake_data, model=None, criterion=None):
    """Run one optimisation step of the discriminator.

    Real samples are given target 1 and fake samples target 0. The gradients of
    the two loss terms are accumulated before a single `optimizer.step()`.

    Args:
        optimizer: optimizer over the discriminator's parameters.
        real_data: batch of real samples in the layout `model` expects.
        fake_data: batch of generated samples. Pass it detached (as the training
            loop does) so this step does not backpropagate into the generator.
        model: discriminator to train; defaults to the notebook-level `discriminator`.
        criterion: loss module; defaults to the notebook-level `loss`.

    Returns:
        A detached scalar tensor: the mean of the real and fake losses.
    """
    model = discriminator if model is None else model
    criterion = loss if criterion is None else criterion

    optimizer.zero_grad()

    # Real images: target 1.
    predictions_real = model(real_data)
    loss_real = criterion(predictions_real, torch.ones_like(predictions_real))
    # Backpropagate the loss tensor, not the criterion module (which has no .backward()).
    loss_real.backward()

    # Fake images: target 0. These gradients add onto the ones from the real batch.
    predictions_fake = model(fake_data)
    loss_fake = criterion(predictions_fake, torch.zeros_like(predictions_fake))
    loss_fake.backward()

    optimizer.step()
    # Detached so callers that accumulate the result across batches do not keep every graph alive.
    return (loss_real.detach() + loss_fake.detach()) / 2.0

In [0]:
def train_generator(optimizer, fake_data, model=None, criterion=None):
    """Run one optimisation step of the generator.

    The discriminator's predictions on `fake_data` are scored against target 1,
    so the generator is rewarded for samples the discriminator calls real.

    Args:
        optimizer: optimizer over the generator's parameters only. Backward also
            writes gradients into the discriminator, but this step does not apply them.
        fake_data: generated batch, still attached to the generator's graph so
            gradients can reach the generator.
        model: discriminator used for scoring; defaults to the notebook-level `discriminator`.
        criterion: loss module; defaults to the notebook-level `loss`.

    Returns:
        A detached scalar tensor holding the generator loss.
    """
    model = discriminator if model is None else model
    criterion = loss if criterion is None else criterion

    optimizer.zero_grad()

    predictions_fake = model(fake_data)
    loss_fake = criterion(predictions_fake, torch.ones_like(predictions_fake))
    # Backpropagate the loss tensor, not the criterion module (which has no .backward()).
    loss_fake.backward()

    optimizer.step()
    # Detached so callers that accumulate the result across batches do not keep every graph alive.
    return loss_fake.detach()

In [0]:
def images_to_vectors(images):
    """Flatten a batch of 28x28 images (e.g. shape (N, 1, 28, 28)) into an (N, 784) tensor."""
    # The original referenced an undefined `img` and returned nothing.
    return images.view(images.size(0), 28*28)
    
def vectors_to_images(vectors):
    """Reshape an (N, 784) batch of flat vectors into (N, 1, 28, 28) single-channel images."""
    # The original computed the view but discarded it, returning None.
    return vectors.view(vectors.size(0), 1, 28, 28)
    
def noise(size):
    """Sample a (size, 100) batch of standard-normal latent vectors on the notebook's `device`."""
    # Allocate directly on the target device. The deprecated Variable wrapper is a no-op on
    # modern PyTorch. The original assigned the result to a local and returned None.
    return torch.randn(size, 100, device=device)

In [0]:
def load_mnist(batch_size, root='./data'):
    """Build MNIST train/test DataLoaders with pixel values scaled to [-1, 1].

    Downloads the dataset into `root` if it is not already present.

    Args:
        batch_size: samples per batch for both loaders.
        root: directory torchvision stores MNIST in; created if missing.

    Returns:
        `(train_loader, test_loader)`. Only the training loader shuffles.
    """
    # exist_ok avoids the check-then-create race and also creates missing parent directories.
    os.makedirs(root, exist_ok=True)

    # ToTensor maps pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1],
    # the same range as the generator's Tanh output.
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_set = dset.MNIST(root=root, train=True, transform=trans, download=True)
    test_set = dset.MNIST(root=root, train=False, transform=trans, download=True)

    train_loader = torch.utils.data.DataLoader(
                     dataset=train_set,
                     batch_size=batch_size,
                     shuffle=True)
    test_loader = torch.utils.data.DataLoader(
                    dataset=test_set,
                    batch_size=batch_size,
                    shuffle=False)
    return train_loader, test_loader

In [0]:
num_epochs = 100
batch_size = 32

# Fixed latent vectors, so the logged sample grids are comparable across training.
test_noise = noise(16)

train_loader, test_loader = load_mnist(batch_size)

writer = SummaryWriter('logs/1')
# Global step for the image log. It is not reset per epoch, so later grids do not overwrite earlier ones.
iterations_gen = 0
for epoch in range(num_epochs):
    loss_d = 0.0
    loss_g = 0.0
    iterations = 0
    # The MNIST DataLoader yields (images, labels); the GAN ignores the labels.
    for n_batch, (real_images, _) in enumerate(train_loader):
        real_data = images_to_vectors(real_images).to(device)
        # Match the fake batch size to the real one, so a short final batch still lines up.
        n_samples = real_data.size(0)

        # Discriminator step: fakes are detached so no gradients reach the generator here.
        fake_data = generator(noise(n_samples)).detach()
        loss_d += float(train_discriminator(d_optimizer, real_data, fake_data))

        # Generator step on a fresh batch of fakes.
        fake_data = generator(noise(n_samples))
        loss_g += float(train_generator(g_optimizer, fake_data))

        if n_batch % 100 == 0:
            # Sample with dropout disabled and without tracking gradients, then restore training mode.
            generator.eval()
            with torch.no_grad():
                test_gen = vectors_to_images(generator(test_noise))
            generator.train()
            # normalize=True min-max rescales the Tanh output into [0, 1] for display.
            writer.add_image('test_gen', make_grid(test_gen, normalize=True), iterations_gen)
            iterations_gen += 1

        iterations += 1
    writer.add_scalar('loss_d', loss_d / iterations, epoch)
    writer.add_scalar('loss_g', loss_g / iterations, epoch)

# Flush pending events to disk.
writer.close()