In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torchvision import transforms

from context import rf_pool

**Load Data**

In [None]:
# MNIST preprocessing: ToTensor, then Normalize with mean 0.5 and std 0.5,
# i.e. (x - 0.5) / 0.5, which maps ToTensor's [0, 1] pixels onto [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
# Train and test splits share root, download flag and transform; only the
# `train` flag differs. The train split is constructed first.
trainset, testset = (
    torchvision.datasets.MNIST(root='../data', train=is_train, download=True,
                               transform=transform)
    for is_train in (True, False)
)

In [None]:
# Wrap both splits in DataLoaders with identical settings: batches of 100,
# shuffling enabled, 2 worker processes. The train loader is built first.
trainloader, testloader = (
    torch.utils.data.DataLoader(split, batch_size=100, shuffle=True,
                                num_workers=2)
    for split in (trainset, testset)
)

**Build Model**

In [None]:
# Start from an empty GAN container and two empty feed-forward networks.
# Layers are appended to the networks, and the networks attached to the GAN,
# in the cells that follow.
model = rf_pool.models.GAN()
generator, discriminator = (rf_pool.models.FeedForwardNetwork() for _ in range(2))

In [None]:
def _gen_hidden_layer(n_in, n_out):
    """Build a FeedForward stage from Linear(n_in, n_out) and LeakyReLU(0.2)."""
    return rf_pool.modules.FeedForward(linear=torch.nn.Linear(n_in, n_out),
                                       activation=torch.nn.LeakyReLU(0.2))


# Generator stages are registered under string keys '0'..'5'
# (presumably applied in that order -- TODO confirm against FeedForwardNetwork).

# '0': noise source. The sampler reads only the batch size of its input and
# returns standard-normal noise of shape (batch, 100).
generator.append('0', rf_pool.modules.FeedForward(
    random_sampler=lambda x: torch.randn(x.shape[0], 100)))

# '1'-'3': hidden stack widening 100 -> 256 -> 512 -> 1024.
generator.append('1', _gen_hidden_layer(100, 256))
generator.append('2', _gen_hidden_layer(256, 512))
generator.append('3', _gen_hidden_layer(512, 1024))

# '4': project to 28*28 values; Tanh bounds them to [-1, 1], the same range
# the Normalize((0.5,), (0.5,)) transform gives the real images.
generator.append('4', rf_pool.modules.FeedForward(
    linear=torch.nn.Linear(1024, 28*28),
    activation=torch.nn.Tanh()))

# '5': input_shape=(-1, 1, 28, 28) -- presumably reshapes the flat output
# into single-channel 28x28 images; TODO confirm against FeedForward.
generator.append('5', rf_pool.modules.FeedForward(input_shape=(-1, 1, 28, 28)))

In [None]:
def _disc_hidden_layer(n_in, n_out, **leading_kwargs):
    """Build a FeedForward stage from Linear(n_in, n_out), LeakyReLU(0.2) and
    Dropout(0.3).

    Any ``leading_kwargs`` (e.g. ``input_shape``) are passed to FeedForward
    ahead of the linear/activation/dropout keywords.
    """
    return rf_pool.modules.FeedForward(**leading_kwargs,
                                       linear=torch.nn.Linear(n_in, n_out),
                                       activation=torch.nn.LeakyReLU(0.2),
                                       dropout=torch.nn.Dropout(0.3))


# Discriminator stages are registered under string keys '0'..'3'
# (presumably applied in that order -- TODO confirm against FeedForwardNetwork).

# '0'-'2': hidden stack narrowing 28*28 -> 1024 -> 512 -> 256. Only the first
# stage is given input_shape=(-1, 28*28), presumably to flatten incoming
# images before the linear layer -- TODO confirm.
discriminator.append('0', _disc_hidden_layer(28*28, 1024, input_shape=(-1, 28*28)))
discriminator.append('1', _disc_hidden_layer(1024, 512))
discriminator.append('2', _disc_hidden_layer(512, 256))

# '3': single output unit; Sigmoid bounds it to (0, 1).
discriminator.append('3', rf_pool.modules.FeedForward(
    linear=torch.nn.Linear(256, 1),
    activation=torch.nn.Sigmoid()))

In [None]:
# Hand the two networks built above to the GAN container, generator first.
# NOTE(review): the optimizer cell below collects model.parameters();
# presumably the GAN exposes the networks' weights only after both are
# attached, so this cell must run first -- confirm against rf_pool.models.GAN.
model.add_generator(generator)
model.add_discriminator(discriminator)

In [None]:
# A single Adam optimizer (learning rate 2e-4, all other settings at their
# defaults) over everything model.parameters() yields.
# NOTE(review): presumably that covers both generator and discriminator
# weights -- confirm against rf_pool.models.GAN.
optim = torch.optim.Adam(params=model.parameters(), lr=2e-4)

**Set Metrics for Monitoring**

In [None]:
class Metrics(object):
    """Monitoring callbacks passed to ``model.train_model`` via ``metrics=``.

    Both methods read the notebook-level ``generator`` / ``discriminator``
    objects directly instead of receiving them as arguments, so those names
    must be defined before the methods are called.
    """

    def show_samples(self, n_samples=1):
        """Display images produced by the generator.

        The generator is called on a zero tensor of shape ``(n_samples,)``.
        Its first layer's ``random_sampler`` reads only ``x.shape[0]``, so the
        zeros serve purely to convey how many samples to draw.

        Args:
            n_samples: number of images to generate.

        Returns:
            Whatever ``rf_pool.utils.visualize.show_images`` returns.
        """
        return rf_pool.utils.visualize.show_images(generator(torch.zeros(n_samples,)),
                                                   cmap='gray')

    def discriminator_prob(self, dataloader):
        """Mean discriminator output on the first batch drawn from ``dataloader``.

        Args:
            dataloader: iterable yielding batches whose element ``[0]`` is the
                input tensor (e.g. a ``(images, labels)`` pair).

        Returns:
            The mean of ``discriminator(x)`` over that batch, as a Python float.

        NOTE(review): the discriminator contains Dropout layers; if it is in
        training mode when this runs, the value is stochastic -- confirm
        whether ``train_model`` switches modes around metric calls.
        """
        # Use the built-in next(): Python 3 iterators only guarantee __next__,
        # and a `.next()` method raises AttributeError on them.
        x = next(iter(dataloader))[0]
        # Monitoring only: no autograd graph is needed for this forward pass.
        with torch.no_grad():
            return torch.mean(discriminator(x)).item()

**Train GAN**

In [None]:
# Per-metric keyword arguments. Each key names a method on Metrics and each
# value holds that method's arguments (presumably train_model calls
# getattr(metrics, key)(**value) when monitoring -- TODO confirm).
metric_kwargs = {
    'show_samples': {'n_samples': 10},
    'discriminator_prob': {'dataloader': testloader},
}

# Run training and keep the returned loss history. The leading 2 is
# presumably the epoch count and monitor=100 the reporting interval -- confirm
# against rf_pool's train_model signature.
loss_history = model.train_model(2, trainloader, optimizer=optim, monitor=100,
                                 metrics=Metrics(), **metric_kwargs)