In [1]:
from utils import Logger

import torch
from torch import nn, optim
from torch.autograd.variable import Variable
from torchvision import transforms, datasets
from tqdm import tqdm_notebook

In [2]:
# Root directory under which mnist_data() stores the downloaded MNIST files.
DATA_FOLDER = './torch_data/VGAN/MNIST'

# Seed torch's global RNG so that DataLoader shuffling and torch.randn noise
# sampling (and, presumably, the networks' weight initialisation) are
# reproducible under Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

## Load Data

In [3]:
def mnist_data(train=True):
    """Return the MNIST dataset, downloading it on first use.

    Each image is converted with ``ToTensor`` (float pixels in [0, 1]) and
    then normalized with mean 0.5 / std 0.5, mapping pixels to [-1, 1].

    Args:
        train: if True (the default, matching the original behavior) load
            the training split; if False load the test split.

    Returns:
        A ``torchvision.datasets.MNIST`` dataset rooted at
        ``<DATA_FOLDER>/dataset``.
    """
    compose = transforms.Compose([
        transforms.ToTensor(),
        # MNIST images have a single (grayscale) channel, so mean/std need
        # exactly one entry each. The previous 3-tuple (.5, .5, .5) tried to
        # broadcast a (1, 28, 28) image against 3 channels, which recent
        # torchvision versions reject with a shape-mismatch RuntimeError.
        transforms.Normalize((.5,), (.5,)),
    ])
    out_dir = '{}/dataset'.format(DATA_FOLDER)
    return datasets.MNIST(root=out_dir, train=train, transform=compose, download=True)

In [4]:
# Mini-batch size used for every training step.
batch_size = 100

# Training split of MNIST, wrapped in a loader that reshuffles every epoch.
data = mnist_data()
data_loader = torch.utils.data.DataLoader(
    dataset=data,
    batch_size=batch_size,
    shuffle=True,
)

# Number of mini-batches the loader yields per pass over the data.
num_batches = len(data_loader)

## Networks

In [5]:
from msgan import GeneratorNet, DiscriminatorNet
# Hidden-layer widths for each network (presumably listed in layer order).
# The discriminator is built first, as before, in case construction draws
# from the RNG.
discriminator = DiscriminatorNet(hidden_sizes=[256, 128, 64])
generator = GeneratorNet(n_input=28 * 28, hidden_sizes=[64, 128, 256])

# nn.Module.cuda() moves parameters in place and returns the module itself,
# so rebinding the names is equivalent to calling .cuda() for its effect.
if torch.cuda.is_available():
    discriminator = discriminator.cuda()
    generator = generator.cuda()

In [6]:
# Report whether a CUDA device is visible, i.e. whether the networks and
# noise will be placed on the GPU by the cells in this notebook.
print(f'Cuda Available? : {torch.cuda.is_available()}')

Cuda Avilable? : False


## Optimization

In [7]:
# Both networks are trained with Adam at the same learning rate; keeping it
# in one named constant lets the two optimizers be tuned together.
LEARNING_RATE = 0.0002

# Optimizers
d_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)
g_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE)

# Loss function: binary cross-entropy on probabilities in [0, 1].
# NOTE(review): `loss` is not passed to GAN(...) in the training cell;
# confirm whether msgan.GAN reads this object or uses its own criterion.
loss = nn.BCELoss()

## Training

In [8]:
from msgan import GAN

def noise(size):
    """Sample a batch of standard-normal latent vectors for the generator.

    Args:
        size: number of noise vectors (rows) to draw.

    Returns:
        A float tensor of shape ``(size, generator.n_input)`` drawn from
        N(0, 1), placed on the GPU when CUDA is available.
    """
    # torch.randn already returns a Tensor; the old Variable wrapper has been
    # a no-op since PyTorch 0.4 and is no longer needed.
    n = torch.randn(size, generator.n_input)
    if torch.cuda.is_available():
        # Must *call* .cuda(): returning the bare attribute `n.cuda` handed
        # the caller a bound method instead of a tensor.
        return n.cuda()
    return n

# Bundle both networks and their optimizers into the project's GAN trainer
# (argument order: generator, discriminator, g_optimizer, d_optimizer).
gan = GAN(generator, discriminator, g_optimizer, d_optimizer)

# NOTE(review): num_epochs is defined here but never passed to gan.fit, and
# the captured output reports "Epoch: [k/100]", so fit appears to use its own
# epoch count. Check GAN.fit's signature and pass this value through if it
# accepts an epoch argument.
num_epochs = 5

# Train: real batches come from data_loader, latent samples from noise().
gan.fit(data_generator=data_loader, noise_generator=noise)

Epoch: [0/100],
Generator Loss: 7.8064
D(x): 0.0095, D(G(z)): 0.0146
Epoch: [1/100],
Generator Loss: 10.0668
D(x): 0.0498, D(G(z)): 0.0342
Epoch: [2/100],
Generator Loss: 5.4425
D(x): 0.2841, D(G(z)): 0.0552
Epoch: [3/100],
Generator Loss: 4.1292
D(x): 0.0973, D(G(z)): 0.0411
Epoch: [4/100],
Generator Loss: 5.1368
D(x): 0.1681, D(G(z)): 0.0112
Epoch: [5/100],
Generator Loss: 7.1353
D(x): 0.0023, D(G(z)): 0.0727
Epoch: [6/100],
Generator Loss: 5.5056
D(x): 0.1361, D(G(z)): 0.0110
Epoch: [7/100],
Generator Loss: 5.0492
D(x): 0.0362, D(G(z)): 0.0738
Epoch: [8/100],
Generator Loss: 4.6698
D(x): 0.0917, D(G(z)): 0.0249
Epoch: [9/100],
Generator Loss: 4.8691
D(x): 0.0858, D(G(z)): 0.0472
Epoch: [10/100],
Generator Loss: 4.3268
D(x): 0.1559, D(G(z)): 0.0567
Epoch: [11/100],
Generator Loss: 3.9806
D(x): 0.0304, D(G(z)): 0.0530
Epoch: [12/100],
Generator Loss: 6.6292
D(x): 0.0896, D(G(z)): 0.0070
Epoch: [13/100],
Generator Loss: 4.8687
D(x): 0.1534, D(G(z)): 0.0252
Epoch: [14/100],
Generator Lo

KeyboardInterrupt: 