In [1]:
%matplotlib inline
from IPython import display

import itertools
import math
import time
import matplotlib.pyplot as plt

import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms

In [2]:
# Preprocessing: ToTensor() yields a float tensor scaled to [0, 1]; Normalize
# with mean 0.5 / std 0.5 then maps it to [-1, 1].
# MNIST images have a single channel, so exactly one mean/std value is given.
# A 3-value tuple cannot be broadcast onto a 1x28x28 tensor by current
# torchvision releases (older releases silently ignored the extra values).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

# Training split of MNIST, downloaded into ./data/ if not already present.
train_dataset = datasets.MNIST(root="./data/",
                               train=True,
                               download=True,
                               transform=transform)
# Shuffled mini-batches of 100 images.
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [6]:
class Discriminator(nn.Module):
    """Fully connected classifier over flattened 784-value inputs.

    Three hidden stages (1024, 512 and 256 units), each followed by
    LeakyReLU(0.2) and Dropout(0.3), then a single sigmoid output unit.
    """

    def __init__(self):
        super().__init__()
        widths = (784, 1024, 512, 256)
        layers = []
        # Build the hidden stages in order so that the module indices inside
        # the Sequential (and hence the state_dict keys) stay 0, 3, 6, 9 for
        # the Linear layers.
        for n_in, n_out in zip(widths, widths[1:]):
            layers += [
                nn.Linear(n_in, n_out),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout(0.3),
            ]
        layers += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` to (batch, 784) and return scores shaped (batch, 1)."""
        batch_size = x.size(0)
        scores = self.model(x.view(batch_size, 784))
        return scores.view(scores.size(0), -1)

In [7]:
class Generator(nn.Module):
    """Fully connected network mapping 100-value inputs to 784 outputs.

    Hidden stages of 256, 512 and 1024 units, each followed by
    LeakyReLU(0.2); the final 784-unit layer is squashed with Tanh, so
    every output value lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        widths = (100, 256, 512, 1024)
        layers = []
        # Linear layers land at Sequential indices 0, 2, 4 and 6.
        for n_in, n_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.extend([nn.Linear(widths[-1], 784), nn.Tanh()])
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Reshape ``x`` to (batch, 100) and return a (batch, 784) tensor."""
        latent = x.view(x.size(0), 100)
        return self.model(latent)

In [8]:
# Instantiate both networks; the discriminator is constructed first, then the
# generator (tuple elements are evaluated left to right).
dis, gen = Discriminator(), Generator()

In [9]:
# Shared training configuration: binary cross-entropy loss and one Adam
# optimizer per network, both using the same learning rate.
lr = 2e-4
criterion = nn.BCELoss()
d_optimizer, g_optimizer = (
    torch.optim.Adam(net.parameters(), lr=lr) for net in (dis, gen)
)

In [10]:
def train_discriminator(dis, x_real, y_real, x_fake, y_fake):
    """Run one discriminator update on a real batch and a fake batch.

    Relies on the module-level ``criterion`` (loss function) and
    ``d_optimizer`` (optimizer that is stepped once per call).

    Args:
        dis: Discriminator module; it is called on both batches and its
            gradients are zeroed at the start of the call.
        x_real: Batch scored as the "real" samples.
        y_real: Targets compared against the scores of ``x_real``.
        x_fake: Batch scored as the "fake" samples. It is detached before
            being scored, so no gradient flows back into whatever produced it.
        y_fake: Targets compared against the scores of ``x_fake``.

    Returns:
        Tuple ``(d_loss, real_score, fake_score)``: the summed loss tensor and
        the raw discriminator outputs for the real and the fake batch.
    """
    dis.zero_grad()

    real_score = dis(x_real)
    real_loss = criterion(real_score, y_real)

    # Detach the fake batch: this step only updates the discriminator, so
    # backpropagating through the graph that produced x_fake would waste
    # compute, accumulate stray gradients upstream, and free that graph's
    # buffers. The gradients of the discriminator itself are unaffected.
    fake_score = dis(x_fake.detach())
    fake_loss = criterion(fake_score, y_fake)

    d_loss = real_loss + fake_loss
    d_loss.backward()
    d_optimizer.step()

    return d_loss, real_score, fake_score

In [None]:
def train_generator(gen, dis_outputs, y_real):
    gen.zero_grad()
    g_loss = criterion(dis_outputs, y_real)
    g_loss.backward()
    g_optimizer