-
Notifications
You must be signed in to change notification settings - Fork 862
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
387 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import gzip | ||
import os | ||
|
||
import numpy as np | ||
import six | ||
from six.moves.urllib import request | ||
|
||
# MNIST download source and archive names.
# NOTE(review): yann.lecun.com has been known to reject plain HTTP
# downloads -- confirm this mirror still serves the files.
parent = 'http://yann.lecun.com/exdb/mnist'
train_images = 'train-images-idx3-ubyte.gz'
train_labels = 'train-labels-idx1-ubyte.gz'
test_images = 't10k-images-idx3-ubyte.gz'
test_labels = 't10k-labels-idx1-ubyte.gz'
num_train = 60000  # samples in the training split
num_test = 10000   # samples in the test split
dim = 784          # flattened image size (28 * 28 pixels)
def load_mnist(images, labels, num, dim=784):
    """Read a gzipped IDX image/label file pair into numpy arrays.

    Args:
        images (str): Path to a gzipped IDX3 image file.
        labels (str): Path to a gzipped IDX1 label file.
        num (int): Number of samples to read.
        dim (int): Flattened size of one image (784 = 28*28 for MNIST).

    Returns:
        tuple: ``(data, target)`` -- uint8 arrays of shape ``(num, dim)``
        and ``(num,)``.
    """
    with gzip.open(images, 'rb') as f_images,\
            gzip.open(labels, 'rb') as f_labels:
        # Skip the IDX headers (magic number + dimension fields).
        f_images.read(16)
        f_labels.read(8)
        # Bulk-read and decode with numpy instead of the original
        # byte-at-a-time ord() loop -- orders of magnitude faster.
        # .copy() keeps the arrays writable (frombuffer views are read-only).
        data = np.frombuffer(f_images.read(num * dim),
                             dtype=np.uint8).reshape(num, dim).copy()
        target = np.frombuffer(f_labels.read(num), dtype=np.uint8).copy()

    return data, target
|
||
|
||
def download_mnist_data():
    """Download the four MNIST archives and pickle them as ``mnist.pkl``.

    Side effects: writes the four ``.gz`` archives and ``mnist.pkl`` into
    the current working directory.
    """
    # The four archives get identical treatment, so loop instead of
    # repeating the download stanza four times (same print/download order).
    for name in (train_images, train_labels, test_images, test_labels):
        print('Downloading {:s}...'.format(name))
        request.urlretrieve('{:s}/{:s}'.format(parent, name), name)
        print('Done')

    print('Converting training data...')
    data_train, target_train = load_mnist(train_images, train_labels,
                                          num_train)
    print('Done')
    print('Converting test data...')
    data_test, target_test = load_mnist(test_images, test_labels, num_test)
    # Concatenate the splits; consumers re-split by index (first 60000 train).
    mnist = {'data': np.append(data_train, data_test, axis=0),
             'target': np.append(target_train, target_test, axis=0)}
    print('Done')
    print('Save output...')
    with open('mnist.pkl', 'wb') as output:
        six.moves.cPickle.dump(mnist, output, -1)
    print('Done')
    print('Convert completed')
|
||
|
||
def load_mnist_data():
    """Return the MNIST dict ``{'data': ..., 'target': ...}``.

    Downloads and converts the dataset first if ``mnist.pkl`` is not
    present in the current working directory.
    """
    # stdlib pickle reads/writes the same format as the six.moves.cPickle
    # shim used elsewhere in this file; the shim is only a py2 compat alias.
    import pickle

    if not os.path.exists('mnist.pkl'):
        download_mnist_data()
    with open('mnist.pkl', 'rb') as mnist_pickle:
        mnist = pickle.load(mnist_pickle)
    return mnist
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import six | ||
|
||
import chainer | ||
import chainer.functions as F | ||
from chainer.functions.loss.vae import gaussian_kl_divergence | ||
import chainer.links as L | ||
|
||
|
||
class VAE(chainer.Chain):
    """Variational AutoEncoder with one hidden layer on each side."""

    def __init__(self, n_in, n_latent, n_h):
        super(VAE, self).__init__()
        with self.init_scope():
            # encoder: input -> hidden -> (mean, log-variance)
            self.le1 = L.Linear(n_in, n_h)
            self.le2_mu = L.Linear(n_h, n_latent)
            self.le2_ln_var = L.Linear(n_h, n_latent)
            # decoder: latent -> hidden -> reconstruction
            self.ld1 = L.Linear(n_latent, n_h)
            self.ld2 = L.Linear(n_h, n_in)

    def __call__(self, x, sigmoid=True):
        """Reconstruct ``x`` by decoding the latent mean of its encoding."""
        mu, _ = self.encode(x)
        return self.decode(mu, sigmoid)

    def encode(self, x):
        """Map input to the latent Gaussian parameters ``(mu, ln_var)``."""
        hidden = F.tanh(self.le1(x))
        # ln_var is log(sigma**2)
        return self.le2_mu(hidden), self.le2_ln_var(hidden)

    def decode(self, z, sigmoid=True):
        """Map a latent vector back to input space.

        When ``sigmoid`` is true the output is squashed into (0, 1);
        otherwise raw logits are returned (used by the Bernoulli NLL).
        """
        logits = self.ld2(F.tanh(self.ld1(z)))
        return F.sigmoid(logits) if sigmoid else logits

    def get_loss_func(self, C=1.0, k=1):
        """Get loss function of VAE.

        The returned closure evaluates the ELBO (Evidence Lower Bound)
        multiplied by -1.

        Args:
            C (int): Usually this is 1.0. Can be changed to control the
                second term of ELBO bound, which works as regularization.
            k (int): Number of Monte Carlo samples used in encoded vector.
        """
        def lf(x):
            mu, ln_var = self.encode(x)
            batchsize = len(mu.data)
            # Monte Carlo estimate of the reconstruction term.
            rec_loss = 0
            for _ in six.moves.range(k):
                z = F.gaussian(mu, ln_var)
                rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) \
                    / (k * batchsize)
            self.rec_loss = rec_loss
            # Negative ELBO: reconstruction term plus weighted KL divergence.
            self.loss = self.rec_loss + \
                C * gaussian_kl_divergence(mu, ln_var) / batchsize
            return self.loss
        return lf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
#!/usr/bin/env python | ||
"""Chainer example: train a VAE on MNIST | ||
""" | ||
from __future__ import print_function | ||
import argparse | ||
|
||
import matplotlib | ||
# Disable interactive backend | ||
matplotlib.use('Agg') | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import six | ||
|
||
import chainer | ||
from chainer import computational_graph | ||
from chainer import cuda | ||
from chainer import optimizers | ||
from chainer import serializers | ||
from tensorboardX import SummaryWriter | ||
import data | ||
import net | ||
|
||
writer = SummaryWriter()

# Command-line configuration.
parser = argparse.ArgumentParser(description='Chainer example: MNIST')
parser.add_argument('--initmodel', '-m', default='',
                    help='Initialize the model from given file')
parser.add_argument('--resume', '-r', default='',
                    help='Resume the optimization from snapshot')
parser.add_argument('--gpu', '-g', default=-1, type=int,
                    help='GPU ID (negative value indicates CPU)')
parser.add_argument('--epoch', '-e', default=100, type=int,
                    help='number of epochs to learn')
# Fix user-facing help-string typo: 'dimention' -> 'dimension'.
parser.add_argument('--dimz', '-z', default=20, type=int,
                    help='dimension of encoded vector')
parser.add_argument('--batchsize', '-b', type=int, default=100,
                    help='learning minibatch size')
parser.add_argument('--test', action='store_true',
                    help='Use tiny datasets for quick tests')
args = parser.parse_args()

batchsize = args.batchsize
n_epoch = args.epoch
n_latent = args.dimz

# Record the full configuration in TensorBoard for reproducibility.
writer.add_text('config', str(args))

print('GPU: {}'.format(args.gpu))
print('# dim z: {}'.format(args.dimz))
print('# Minibatch-size: {}'.format(args.batchsize))
print('# epoch: {}'.format(args.epoch))
print('')
|
||
# Prepare dataset: 70000 flattened 784-pixel digits, scaled into [0, 1].
print('load MNIST dataset')
mnist = data.load_mnist_data()
mnist['data'] = mnist['data'].astype(np.float32)
mnist['data'] /= 255
mnist['target'] = mnist['target'].astype(np.int32)

if args.test:
    # Tiny subset for quick smoke tests (--test flag).
    mnist['data'] = mnist['data'][0:100]
    mnist['target'] = mnist['target'][0:100]
    N = 30
else:
    N = 60000

# First N samples are the training split, the remainder is the test split.
x_train, x_test = np.split(mnist['data'], [N])
y_train, y_test = np.split(mnist['target'], [N])
N_test = y_test.size
|
||
# Prepare VAE model, defined in net.py (784 inputs, n_latent code, 500 hidden).
model = net.VAE(784, n_latent, 500)
if args.gpu >= 0:
    cuda.get_device_from_id(args.gpu).use()
    model.to_gpu()
# xp is the array module matching the selected device (numpy or cupy).
xp = np if args.gpu < 0 else cuda.cupy

# Setup optimizer
optimizer = optimizers.Adam()
optimizer.setup(model)

# Init/Resume from serialized snapshots, if requested on the command line.
if args.initmodel:
    print('Load model from', args.initmodel)
    serializers.load_npz(args.initmodel, model)
if args.resume:
    print('Load optimizer state from', args.resume)
    serializers.load_npz(args.resume, optimizer)
|
||
# Learning loop: one pass over shuffled training data per epoch,
# followed by an evaluation pass over the test split.
for epoch in six.moves.range(1, n_epoch + 1):
    print('epoch', epoch)

    # training
    perm = np.random.permutation(N)
    sum_loss = 0  # total loss
    sum_rec_loss = 0  # reconstruction loss
    for i in six.moves.range(0, N, batchsize):
        x = chainer.Variable(xp.asarray(x_train[perm[i:i + batchsize]]))
        optimizer.update(model.get_loss_func(), x)
        if epoch == 1 and i == 0:
            # Dump the computational graph once, from the first update.
            with open('graph.dot', 'w') as o:
                g = computational_graph.build_computational_graph(
                    (model.loss, ))
                o.write(g.dump())
            print('graph generated')
        # NOTE(review): model.loss / model.rec_loss are chainer Variables,
        # not floats -- confirm tensorboardX accepts them; passing
        # float(model.loss.data) would be safer. Global step epoch*N+i
        # starts at N (epoch is 1-based), not 0.
        writer.add_scalar('train/loss', model.loss, epoch * N + i)
        writer.add_scalar('train/rec_loss', model.rec_loss, epoch * N + i)
        sum_loss += float(model.loss.data) * len(x.data)
        sum_rec_loss += float(model.rec_loss.data) * len(x.data)

    print('train mean loss={}, mean reconstruction loss={}'
          .format(sum_loss / N, sum_rec_loss / N))

    # evaluation (no gradients needed, so skip graph construction)
    sum_loss = 0
    sum_rec_loss = 0
    with chainer.no_backprop_mode():
        for i in six.moves.range(0, N_test, batchsize):
            x = chainer.Variable(xp.asarray(x_test[i:i + batchsize]))
            # k=10 Monte Carlo samples for a lower-variance test estimate.
            loss_func = model.get_loss_func(k=10)
            loss_func(x)
            sum_loss += float(model.loss.data) * len(x.data)
            sum_rec_loss += float(model.rec_loss.data) * len(x.data)
            writer.add_scalar('test/loss', model.loss, epoch * N_test + i)
            writer.add_scalar('test/rec_loss', model.rec_loss, epoch * N_test + i)
            # NOTE(review): model(x) returns a chainer Variable -- confirm
            # tensorboardX add_image handles Variable.reshape output.
            writer.add_image('reconstructed', model(x).reshape(-1, 1, 28, 28), epoch * N_test + i)
            writer.add_image('input', x.reshape(-1, 1, 28, 28), epoch * N_test + i)
            # Drop the retained graph reference before the next batch.
            del model.loss
    print('test mean loss={}, mean reconstruction loss={}'
          .format(sum_loss / N_test, sum_rec_loss / N_test))
|
||
|
||
# Save the model and the optimizer as npz snapshots (for --initmodel/--resume).
print('save the model')
serializers.save_npz('mlp.model', model)
print('save the optimizer')
serializers.save_npz('mlp.state', optimizer)

# Move parameters back to host memory for the numpy-based plotting below.
model.to_cpu()
|
||
|
||
# original images and reconstructed images
def save_images(x, filename):
    """Plot the first nine 28x28 images of ``x`` in a 3x3 grid.

    Args:
        x: Array-like of flattened images; each row reshapes to 28x28.
        filename (str): Path passed to ``fig.savefig``.
    """
    fig, ax = plt.subplots(3, 3, figsize=(9, 9), dpi=100)
    for ai, xi in zip(ax.flatten(), x):
        ai.imshow(xi.reshape(28, 28))
    fig.savefig(filename)
    # Close the figure: nothing displays it (Agg backend), and figures
    # left open accumulate memory across the repeated calls below.
    plt.close(fig)
|
||
|
||
# Reconstruct a fixed set of training digits for visual inspection.
train_ind = [1, 3, 5, 10, 2, 0, 13, 15, 17]
x = chainer.Variable(np.asarray(x_train[train_ind]))
with chainer.no_backprop_mode():
    x1 = model(x)
save_images(x.data, 'train')
save_images(x1.data, 'train_reconstructed')

# Same for a fixed set of test digits.
test_ind = [3, 2, 1, 18, 4, 8, 11, 17, 61]
x = chainer.Variable(np.asarray(x_test[test_ind]))
with chainer.no_backprop_mode():
    x1 = model(x)
save_images(x.data, 'test')
save_images(x1.data, 'test_reconstructed')

# draw images from randomly sampled z (standard normal prior over latents)
z = chainer.Variable(np.random.normal(0, 1, (9, n_latent)).astype(np.float32))
x = model.decode(z)
save_images(x.data, 'sampled')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.