In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data.sampler import Sampler, BatchSampler
from torch.nn.modules.loss import MSELoss

In [4]:
from models import Generator, Discriminator

ModuleNotFoundError: No module named 'models'

In [2]:
input_size = 784   # flattened image size (28 * 28); not used by the GAN cells below
num_classes = 10
batch_size = 256

train_dataset = dsets.MNIST(root='./MNIST/',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)

# No download=True here: downloading the train split above also fetches the
# test files into the same root (see the download log in the cell output).
test_dataset = dsets.MNIST(root='./MNIST/',
                           train=False,
                           transform=transforms.ToTensor())

# shuffle=True: the training loop cycles over this loader indefinitely; without
#   shuffling every epoch presents the batches in exactly the same order.
# drop_last=True: the training loop draws noise of size `batch_size`, and the
#   gradient penalty mixes real and generated batches element-wise, so every
#   real batch must be full. Without this the final batch is smaller whenever
#   len(train_dataset) is not a multiple of batch_size.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           drop_last=True)

# Evaluation loader: deterministic order, keep every sample.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [None]:
def calculate_gradient_penalty(discriminator, images, gen_images):
    """Critic regulariser evaluated at random real/fake interpolates.

    The structure appears to follow the Banach Wasserstein GAN penalty. It
    returns a scalar tensor:

        lambda_ * mean((||grad_x D(x_hat)||_dual / gamma_ - 1) ** 2)
            + 1e-5 * mean(D(images) ** 2)

    where
      * ``x_hat = eps * images + (1 - eps) * gen_images`` with one
        ``eps ~ U[0, 1)`` per sample;
      * the dual norm of the input gradient is
        ``stable_norm(sobolev_filter(grad, c=SOBOLEV_C, s=-SOBOLEV_S), ord=DUAL_EXPONENT)``;
      * ``lambda_`` and ``gamma_`` are batch means of norms of the *real* images
        (``s=+SOBOLEV_S, ord=EXPONENT`` and ``s=-SOBOLEV_S, ord=DUAL_EXPONENT``).

    Args:
        discriminator: critic module; called on ``x_hat`` and on ``images``.
        images: batch of real images, batch dimension first.
        gen_images: batch of generated images, same shape as ``images``.

    Returns:
        Scalar tensor. The input-gradient graph is kept (``create_graph=True``)
        so ``.backward()`` on the result reaches the critic's parameters.

    NOTE(review): ``sobolev_filter``, ``stable_norm``, ``SOBOLEV_C``,
    ``SOBOLEV_S``, ``EXPONENT`` and ``DUAL_EXPONENT`` are not defined in any
    visible cell. They must exist at notebook scope before this is called.
    ``stable_norm`` presumably returns one norm per sample — TODO confirm.
    """
    # One mixing weight per sample, broadcast over every remaining dimension.
    # Sizing by the actual batch (not the global `batch_size`) keeps a short
    # final batch working, and matching device/dtype keeps GPU inputs working.
    num_samples = images.size(0)
    epsilon_shape = (num_samples,) + (1,) * (images.dim() - 1)
    epsilon = torch.rand(epsilon_shape, device=images.device, dtype=images.dtype)

    # Fresh leaf tensor that autograd can differentiate with respect to;
    # detach() cuts any history the inputs may carry (replaces `Variable`).
    x_hat = (epsilon * images + (1 - epsilon) * gen_images).detach().requires_grad_(True)

    prob_x_hat = discriminator(x_hat)
    gradients = torch.autograd.grad(outputs=prob_x_hat, inputs=x_hat,
                                    grad_outputs=torch.ones_like(prob_x_hat),
                                    create_graph=True, retain_graph=True)[0]

    # Dual Sobolev norm of the critic's input gradient.
    dual_sobolev_gradients = sobolev_filter(gradients, c=SOBOLEV_C, s=-SOBOLEV_S)
    gradients_stable_norm = stable_norm(dual_sobolev_gradients, ord=DUAL_EXPONENT)

    # Data-dependent scales estimated from the real batch.
    lambda_ = stable_norm(sobolev_filter(images, c=SOBOLEV_C, s=SOBOLEV_S),
                          ord=EXPONENT).mean()
    gamma_ = stable_norm(sobolev_filter(images, c=SOBOLEV_C, s=-SOBOLEV_S),
                         ord=DUAL_EXPONENT).mean()

    # Small penalty on the critic's raw outputs for real images.
    prob_images = discriminator(images)

    grad_penalty = (((gradients_stable_norm / gamma_ - 1) ** 2).mean() * lambda_
                    + 1e-5 * (prob_images ** 2).mean())
    return grad_penalty

In [5]:
from tqdm import trange

max_iters = 100000    # number of generator updates
num_disc_iters = 5    # critic updates per generator update
noise_size = 128      # latent channels; noise is drawn as (N, noise_size, 1, 1)


def generate_batches(train_loader):
    """Cycle over ``train_loader`` forever, yielding ``(batch_num, images)``.

    Each pass over the loader is paired with a fresh ``trange`` progress bar
    whose counter provides ``batch_num`` (restarting at 0 every pass). Labels
    are discarded and the images are converted with ``.float()``.
    """
    while True:
        progress = trange(len(train_loader))
        for batch_num, batch in zip(progress, train_loader):
            pixels, _labels = batch
            yield batch_num, pixels.float()
            
# Models
# NOTE(review): Generator / Discriminator come from the `models` import cell,
# which failed with ModuleNotFoundError in the saved outputs.
generator = Generator(noise_size)
discriminator = Discriminator()

lr = 1e-4
beta1 = 0.5
beta2 = 0.999

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

# Endless stream of real image batches
data = generate_batches(train_loader)

for global_iter in range(max_iters):

    # ---- Critic updates -------------------------------------------------
    # Minimise  D(fake) - D(real) + penalty  in a single backward pass. This is
    # the gradient the original sequence fake_loss.backward(),
    # real_loss.backward(mone), gradient_penalty.backward() was aiming for;
    # `mone` was never defined and is conventionally -1 in WGAN code.
    for _ in range(num_disc_iters):
        _, images = next(data)
        # Match the noise batch to the real batch so they can be mixed.
        z = torch.randn((images.size(0), noise_size, 1, 1))

        optimizer_D.zero_grad()
        # detach(): the critic step must not backprop into the generator.
        gen_images = generator(z).detach()

        fake_loss = discriminator(gen_images).mean()
        real_loss = discriminator(images).mean()
        gradient_penalty = calculate_gradient_penalty(discriminator, images, gen_images)

        # NOTE(review): the original cell also computed
        # `(fake_loss - real_loss) / gamma`, but `gamma` is not defined at this
        # scope (only `gamma_` inside calculate_gradient_penalty), and that
        # value never reached backward(). If a 1/gamma scaling of the
        # Wasserstein term is intended, compute gamma here and divide — TODO confirm.
        d_loss = fake_loss - real_loss + gradient_penalty
        d_loss.backward()
        optimizer_D.step()

    # ---- Generator update -----------------------------------------------
    # Maximise D(G(z)) by minimising its negation (the original used
    # backward(mone) on D(G(z)).mean(), so the logged sign is flipped here).
    # Gradients this leaves on the critic are cleared by optimizer_D.zero_grad()
    # at the start of the next critic step.
    optimizer_G.zero_grad()
    z = torch.randn((batch_size, noise_size, 1, 1))
    gen_images = generator(z)
    g_loss = -discriminator(gen_images).mean()
    g_loss.backward()
    optimizer_G.step()

NameError: name 'Generator' is not defined