supports chainer
lanpa committed Sep 4, 2017
1 parent 5f7f03f commit 92e2750
Showing 8 changed files with 387 additions and 26 deletions.
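The change set wires Chainer into the Travis test environment, adds a Chainer VAE example under examples/chainer/, and strips the torch-specific path from tensorboardX/summary.py. A minimal sketch of the logging workflow the example relies on (not part of the commit; assumes Chainer and this branch of tensorboardX are installed, and the tag name is illustrative):

import numpy as np
import chainer
from tensorboardX import SummaryWriter

writer = SummaryWriter()  # writes TensorBoard event files
loss = chainer.Variable(np.array(0.42, dtype=np.float32))
writer.add_scalar('demo/loss', float(loss.data), 0)  # unwrap the Variable to a plain float
writer.close()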
8 changes: 4 additions & 4 deletions .travis.yml
@@ -22,7 +22,7 @@ install:
- conda info -a

# Replace dep1 dep2 ... with your dependencies
-  - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION pytorch torchvision -c soumith
+  - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION pytorch torchvision -c soumith chainer
- source activate test-environment
- which python
- conda list
@@ -32,6 +32,6 @@ install:
script:
- pytest
# Your test script goes here
-  - pip uninstall -y tensorboardX
-  - pip install tensorboardX
-  - pytest
+  # - pip uninstall -y tensorboardX
+  # - pip install tensorboardX
+  # - pytest
69 changes: 69 additions & 0 deletions examples/chainer/data.py
@@ -0,0 +1,69 @@
import gzip
import os

import numpy as np
import six
from six.moves.urllib import request

parent = 'http://yann.lecun.com/exdb/mnist'
train_images = 'train-images-idx3-ubyte.gz'
train_labels = 'train-labels-idx1-ubyte.gz'
test_images = 't10k-images-idx3-ubyte.gz'
test_labels = 't10k-labels-idx1-ubyte.gz'
num_train = 60000
num_test = 10000
dim = 784


def load_mnist(images, labels, num):
data = np.zeros(num * dim, dtype=np.uint8).reshape((num, dim))
target = np.zeros(num, dtype=np.uint8).reshape((num, ))

with gzip.open(images, 'rb') as f_images,\
gzip.open(labels, 'rb') as f_labels:
f_images.read(16)
f_labels.read(8)
for i in six.moves.range(num):
target[i] = ord(f_labels.read(1))
for j in six.moves.range(dim):
data[i, j] = ord(f_images.read(1))

return data, target


def download_mnist_data():
print('Downloading {:s}...'.format(train_images))
request.urlretrieve('{:s}/{:s}'.format(parent, train_images), train_images)
print('Done')
print('Downloading {:s}...'.format(train_labels))
request.urlretrieve('{:s}/{:s}'.format(parent, train_labels), train_labels)
print('Done')
print('Downloading {:s}...'.format(test_images))
request.urlretrieve('{:s}/{:s}'.format(parent, test_images), test_images)
print('Done')
print('Downloading {:s}...'.format(test_labels))
request.urlretrieve('{:s}/{:s}'.format(parent, test_labels), test_labels)
print('Done')

print('Converting training data...')
data_train, target_train = load_mnist(train_images, train_labels,
num_train)
print('Done')
print('Converting test data...')
data_test, target_test = load_mnist(test_images, test_labels, num_test)
mnist = {'data': np.append(data_train, data_test, axis=0),
'target': np.append(target_train, target_test, axis=0)}
print('Done')
print('Save output...')
with open('mnist.pkl', 'wb') as output:
six.moves.cPickle.dump(mnist, output, -1)
print('Done')
    print('Conversion completed')


def load_mnist_data():
if not os.path.exists('mnist.pkl'):
download_mnist_data()
with open('mnist.pkl', 'rb') as mnist_pickle:
mnist = six.moves.cPickle.load(mnist_pickle)
return mnist
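A usage sketch for the loader above (not part of the commit): load_mnist_data() returns a dict of flattened images and labels whose shapes follow the constants at the top of the file (60000 train + 10000 test rows of dimension 784).

import data

mnist = data.load_mnist_data()  # downloads and converts MNIST on the first call
print(mnist['data'].shape)      # (70000, 784), dtype uint8
print(mnist['target'].shape)    # (70000,), dtype uint8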
65 changes: 65 additions & 0 deletions examples/chainer/net.py
@@ -0,0 +1,65 @@
import six

import chainer
import chainer.functions as F
from chainer.functions.loss.vae import gaussian_kl_divergence
import chainer.links as L


class VAE(chainer.Chain):
"""Variational AutoEncoder"""

def __init__(self, n_in, n_latent, n_h):
super(VAE, self).__init__()
with self.init_scope():
# encoder
self.le1 = L.Linear(n_in, n_h)
self.le2_mu = L.Linear(n_h, n_latent)
self.le2_ln_var = L.Linear(n_h, n_latent)
# decoder
self.ld1 = L.Linear(n_latent, n_h)
self.ld2 = L.Linear(n_h, n_in)

def __call__(self, x, sigmoid=True):
"""AutoEncoder"""
return self.decode(self.encode(x)[0], sigmoid)

def encode(self, x):
h1 = F.tanh(self.le1(x))
mu = self.le2_mu(h1)
ln_var = self.le2_ln_var(h1) # log(sigma**2)
return mu, ln_var

def decode(self, z, sigmoid=True):
h1 = F.tanh(self.ld1(z))
h2 = self.ld2(h1)
if sigmoid:
return F.sigmoid(h2)
else:
return h2

def get_loss_func(self, C=1.0, k=1):
"""Get loss function of VAE.
The loss value is equal to ELBO (Evidence Lower Bound)
multiplied by -1.
Args:
C (int): Usually this is 1.0. Can be changed to control the
second term of ELBO bound, which works as regularization.
k (int): Number of Monte Carlo samples used in encoded vector.
"""
def lf(x):
mu, ln_var = self.encode(x)
batchsize = len(mu.data)
# reconstruction loss
rec_loss = 0
for l in six.moves.range(k):
z = F.gaussian(mu, ln_var)
rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) \
/ (k * batchsize)
self.rec_loss = rec_loss
self.loss = self.rec_loss + \
C * gaussian_kl_divergence(mu, ln_var) / batchsize
return self.loss
return lf
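The closure lf computes the negative ELBO: loss = rec_loss + C * gaussian_kl_divergence(mu, ln_var) / batchsize, where rec_loss is a Bernoulli negative log-likelihood averaged over k Monte Carlo samples. A minimal driver sketch mirroring how train_vae.py uses it (not part of the commit; assumes Chainer is installed, and the batch size and random input are illustrative):

import numpy as np
import chainer
from chainer import optimizers
import net

model = net.VAE(784, 20, 500)  # n_in, n_latent, n_h
optimizer = optimizers.Adam()
optimizer.setup(model)

x = chainer.Variable(np.random.rand(16, 784).astype(np.float32))
optimizer.update(model.get_loss_func(), x)  # forward, backward, and parameter update in one call
print(float(model.loss.data), float(model.rec_loss.data))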
170 changes: 170 additions & 0 deletions examples/chainer/train_vae.py
@@ -0,0 +1,170 @@
#!/usr/bin/env python
"""Chainer example: train a VAE on MNIST
"""
from __future__ import print_function
import argparse

import matplotlib
# Disable interactive backend
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import six

import chainer
from chainer import computational_graph
from chainer import cuda
from chainer import optimizers
from chainer import serializers
from tensorboardX import SummaryWriter
import data
import net

writer = SummaryWriter()

parser = argparse.ArgumentParser(description='Chainer example: MNIST')
parser.add_argument('--initmodel', '-m', default='',
help='Initialize the model from given file')
parser.add_argument('--resume', '-r', default='',
help='Resume the optimization from snapshot')
parser.add_argument('--gpu', '-g', default=-1, type=int,
help='GPU ID (negative value indicates CPU)')
parser.add_argument('--epoch', '-e', default=100, type=int,
help='number of epochs to learn')
parser.add_argument('--dimz', '-z', default=20, type=int,
                    help='dimension of encoded vector')
parser.add_argument('--batchsize', '-b', type=int, default=100,
help='learning minibatch size')
parser.add_argument('--test', action='store_true',
help='Use tiny datasets for quick tests')
args = parser.parse_args()

batchsize = args.batchsize
n_epoch = args.epoch
n_latent = args.dimz

writer.add_text('config', str(args))

print('GPU: {}'.format(args.gpu))
print('# dim z: {}'.format(args.dimz))
print('# Minibatch-size: {}'.format(args.batchsize))
print('# epoch: {}'.format(args.epoch))
print('')

# Prepare dataset
print('load MNIST dataset')
mnist = data.load_mnist_data()
mnist['data'] = mnist['data'].astype(np.float32)
mnist['data'] /= 255
mnist['target'] = mnist['target'].astype(np.int32)

if args.test:
mnist['data'] = mnist['data'][0:100]
mnist['target'] = mnist['target'][0:100]
N = 30
else:
N = 60000

x_train, x_test = np.split(mnist['data'], [N])
y_train, y_test = np.split(mnist['target'], [N])
N_test = y_test.size

# Prepare VAE model, defined in net.py
model = net.VAE(784, n_latent, 500)
if args.gpu >= 0:
cuda.get_device_from_id(args.gpu).use()
model.to_gpu()
xp = np if args.gpu < 0 else cuda.cupy

# Setup optimizer
optimizer = optimizers.Adam()
optimizer.setup(model)

# Init/Resume
if args.initmodel:
print('Load model from', args.initmodel)
serializers.load_npz(args.initmodel, model)
if args.resume:
print('Load optimizer state from', args.resume)
serializers.load_npz(args.resume, optimizer)

# Learning loop
for epoch in six.moves.range(1, n_epoch + 1):
print('epoch', epoch)

# training
perm = np.random.permutation(N)
sum_loss = 0 # total loss
sum_rec_loss = 0 # reconstruction loss
for i in six.moves.range(0, N, batchsize):
x = chainer.Variable(xp.asarray(x_train[perm[i:i + batchsize]]))
optimizer.update(model.get_loss_func(), x)
if epoch == 1 and i == 0:
with open('graph.dot', 'w') as o:
g = computational_graph.build_computational_graph(
(model.loss, ))
o.write(g.dump())
print('graph generated')
writer.add_scalar('train/loss', model.loss, epoch*N+i)
writer.add_scalar('train/rec_loss', model.rec_loss, epoch*N+i)
sum_loss += float(model.loss.data) * len(x.data)
sum_rec_loss += float(model.rec_loss.data) * len(x.data)

print('train mean loss={}, mean reconstruction loss={}'
.format(sum_loss / N, sum_rec_loss / N))

# evaluation
sum_loss = 0
sum_rec_loss = 0
with chainer.no_backprop_mode():
for i in six.moves.range(0, N_test, batchsize):
x = chainer.Variable(xp.asarray(x_test[i:i + batchsize]))
loss_func = model.get_loss_func(k=10)
loss_func(x)
sum_loss += float(model.loss.data) * len(x.data)
sum_rec_loss += float(model.rec_loss.data) * len(x.data)
writer.add_scalar('test/loss', model.loss, epoch*N_test+i)
writer.add_scalar('test/rec_loss', model.rec_loss, epoch*N_test+i)
writer.add_image('reconstructed', model(x).reshape(-1,1,28,28), epoch*N_test+i)
writer.add_image('input', x.reshape(-1,1,28,28), epoch*N_test+i)
del model.loss
print('test mean loss={}, mean reconstruction loss={}'
.format(sum_loss / N_test, sum_rec_loss / N_test))


# Save the model and the optimizer
print('save the model')
serializers.save_npz('mlp.model', model)
print('save the optimizer')
serializers.save_npz('mlp.state', optimizer)

model.to_cpu()


# original images and reconstructed images
def save_images(x, filename):
fig, ax = plt.subplots(3, 3, figsize=(9, 9), dpi=100)
for ai, xi in zip(ax.flatten(), x):
ai.imshow(xi.reshape(28, 28))
fig.savefig(filename)


train_ind = [1, 3, 5, 10, 2, 0, 13, 15, 17]
x = chainer.Variable(np.asarray(x_train[train_ind]))
with chainer.no_backprop_mode():
x1 = model(x)
save_images(x.data, 'train')
save_images(x1.data, 'train_reconstructed')

test_ind = [3, 2, 1, 18, 4, 8, 11, 17, 61]
x = chainer.Variable(np.asarray(x_test[test_ind]))
with chainer.no_backprop_mode():
x1 = model(x)
save_images(x.data, 'test')
save_images(x1.data, 'test_reconstructed')


# draw images from randomly sampled z
z = chainer.Variable(np.random.normal(0, 1, (9, n_latent)).astype(np.float32))
x = model.decode(z)
save_images(x.data, 'sampled')
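A hedged usage note (the flags come from the argparse block above; the log directory is an assumption based on SummaryWriter's default):

# python train_vae.py --epoch 5 --batchsize 100 --test   # quick CPU smoke run on 100 samples
# tensorboard --logdir runs                              # inspect the logged scalars, images, and text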
11 changes: 2 additions & 9 deletions tensorboardX/summary.py
@@ -150,16 +150,9 @@ def image(tag, tensor):
A scalar `Tensor` of type `string`. The serialized `Summary` protocol
buffer.
"""
-    import torch
tag = _clean_tag(tag)
-    assert isinstance(tensor, np.ndarray) or isinstance(tensor, torch.cuda.FloatTensor) or isinstance(tensor, torch.FloatTensor), 'input tensor should be one of numpy.ndarray, torch.cuda.FloatTensor, torch.FloatTensor'
-    if not isinstance(tensor, np.ndarray):
-        assert tensor.dim()<4 and tensor.dim()>1, 'input tensor should be 3 dimensional.'
-        if tensor.dim()==2:
-            tensor = tensor.unsqueeze(0)
-        tensor = tensor.cpu().permute(1,2,0).numpy()
-    else:
-        tensor = tensor.astype(np.float32)
+    assert isinstance(tensor, np.ndarray), 'input tensor should be numpy.ndarray'
+    tensor = tensor.astype(np.float32)
tensor = (tensor*255).astype(np.uint8)
image = make_image(tensor)
return Summary(value=[Summary.Value(tag=tag, image=image)])
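After this hunk, image() accepts only numpy.ndarray, which is what makes the summary path framework-agnostic; torch or Chainer tensors are presumably converted before this call (the converter is not visible in these hunks). A hedged caller-side sketch (the helper name is illustrative, not part of the commit):

import numpy as np

def to_numpy(t):
    # best-effort conversion of a chainer.Variable or torch tensor to numpy
    if hasattr(t, 'data') and not isinstance(t, np.ndarray):  # chainer.Variable keeps its array in .data
        t = t.data
    if hasattr(t, 'cpu'):  # torch tensors move to host memory via .cpu()
        t = t.cpu().numpy()
    return np.asarray(t)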