Rewrite the example of VAE using Chainer distributions #5356

Merged · merged 19 commits on Nov 21, 2018
Changes from 5 commits · +137 −226
@@ -1,67 +1,103 @@
import six
import numpy as np

import chainer
from chainer import cuda
import chainer.distributions as D
import chainer.functions as F
from chainer.functions.loss.vae import gaussian_kl_divergence
import chainer.links as L
from chainer import reporter


class VAE(chainer.Chain):
"""Variational AutoEncoder"""
class AvgELBOLoss(chainer.Chain):
This conversation was marked as resolved by toslunar

toslunar (Member), Nov 12, 2018:
Could you add a docstring to explain the arguments? The docstring of get_loss_func in the previous example could be used here.

ganow (Author, Contributor), Nov 12, 2018:
I've added it: 3caaae1
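A sketch of the kind of docstring requested above, adapted from the get_loss_func docstring in the removed code further down (the text actually added in 3caaae1 is not shown on this page; the argument descriptions are paraphrased):

```python
class AvgELBOLoss(chainer.Chain):
    """Loss function of VAE: the ELBO (Evidence Lower Bound) multiplied by -1,
    averaged over the minibatch.

    Args:
        encoder (chainer.Chain): Chain that returns the approximate posterior
            distribution q(z|x).
        decoder (chainer.Chain): Chain that returns the conditional
            distribution p(x|z).
        prior (chainer.Chain): Chain that returns the prior distribution p(z).
        beta (float): Usually this is 1.0. Can be changed to control the
            second term of the ELBO bound, which works as regularization.
        k (int): Number of Monte Carlo samples used in encoded vector.
    """
```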

    def __init__(self, encoder, decoder, prior, beta=1.0, k=1):
        super(AvgELBOLoss, self).__init__()
        self.beta = beta
        self.k = k

        with self.init_scope():
            self.encoder = encoder
            self.decoder = decoder
            self.prior = prior

        self.loss = None
        self.rec = None
        self.penalty = None
This conversation was marked as resolved by toslunar

toslunar (Member), Nov 12, 2018:
It seems these loss values don't need to be saved as attributes of the chain. How about reporter.report({'loss': loss}, self) etc.?

ganow (Author, Contributor), Nov 12, 2018:
Fixed them: d20a9b9


    def __call__(self, x):
        q_z = self.encoder(x)
        z = q_z.sample(self.k)
        p_x = self.decoder(z)
        p_z = self.prior()

        self.loss = None
        self.rec = None
        self.penalty = None
        self.rec = F.mean(F.sum(p_x.log_prob(
            F.broadcast_to(x[None, :], (self.k,) + x.shape)), axis=-1))
        self.penalty = F.mean(F.sum(chainer.kl_divergence(q_z, p_z), axis=-1))
        self.loss = - (self.rec - self.beta * self.penalty)
        reporter.report({'loss': self.loss}, self)
        reporter.report({'rec': self.rec}, self)
        reporter.report({'penalty': self.penalty}, self)
        return self.loss
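For reference, a sketch of the averaged ELBO objective that __call__ computes, written in the code's notation (N is the minibatch size and z_n^(i) are the k Monte Carlo samples drawn from the encoder):

```latex
\mathrm{rec}     = \frac{1}{N k}\sum_{n=1}^{N}\sum_{i=1}^{k} \log p\bigl(x_n \mid z_n^{(i)}\bigr),
                   \qquad z_n^{(i)} \sim q(z \mid x_n) \\
\mathrm{penalty} = \frac{1}{N}\sum_{n=1}^{N} D_{\mathrm{KL}}\bigl(q(z \mid x_n)\,\big\|\,p(z)\bigr) \\
\mathrm{loss}    = -\bigl(\mathrm{rec} - \beta \cdot \mathrm{penalty}\bigr)
```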


class Encoder(chainer.Chain):

    def __init__(self, n_in, n_latent, n_h):
        super(Encoder, self).__init__()
        with self.init_scope():
            self.linear = L.Linear(n_in, n_h)
            self.mu = L.Linear(n_h, n_latent)
            self.ln_sigma = L.Linear(n_h, n_latent)

    def forward(self, x):
        h = F.tanh(self.linear(x))
        mu = self.mu(h)
        ln_sigma = self.ln_sigma(h)  # log(sigma)
        return D.Normal(loc=mu, log_scale=ln_sigma)
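A quick usage sketch of the distribution the Encoder returns (assuming this module is importable as net; the dummy input and shapes are illustrative only):

```python
import numpy as np

import net

encoder = net.make_encoder(784, 20, 500)
x = np.random.rand(16, 784).astype(np.float32)  # dummy minibatch

q_z = encoder(x)       # a chainer.distributions.Normal over the latent code
z_mean = q_z.mean      # shape (16, 20); used later for deterministic reconstructions
z_mc = q_z.sample(3)   # shape (3, 16, 20); Monte Carlo samples, as in AvgELBOLoss
```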


class Decoder(chainer.Chain):

    def __init__(self, n_in, n_latent, n_h):
        super(VAE, self).__init__()
        super(Decoder, self).__init__()
        with self.init_scope():
            # encoder
            self.le1 = L.Linear(n_in, n_h)
            self.le2_mu = L.Linear(n_h, n_latent)
            self.le2_ln_var = L.Linear(n_h, n_latent)
            # decoder
            self.ld1 = L.Linear(n_latent, n_h)
            self.ld2 = L.Linear(n_h, n_in)

    def forward(self, x, sigmoid=True):
        """AutoEncoder"""
        return self.decode(self.encode(x)[0], sigmoid)

    def encode(self, x):
        h1 = F.tanh(self.le1(x))
        mu = self.le2_mu(h1)
        ln_var = self.le2_ln_var(h1)  # log(sigma**2)
        return mu, ln_var

    def decode(self, z, sigmoid=True):
        h1 = F.tanh(self.ld1(z))
        h2 = self.ld2(h1)
        if sigmoid:
            return F.sigmoid(h2)
        else:
            return h2

    def get_loss_func(self, beta=1.0, k=1):
        """Get loss function of VAE.
        The loss value is equal to ELBO (Evidence Lower Bound)
        multiplied by -1.
        Args:
            beta (int): Usually this is 1.0. Can be changed to control the
                second term of ELBO bound, which works as regularization.
            k (int): Number of Monte Carlo samples used in encoded vector.
        """
        def lf(x):
            mu, ln_var = self.encode(x)
            batchsize = len(mu.data)
            # reconstruction loss
            rec_loss = 0
            for l in six.moves.range(k):
                z = F.gaussian(mu, ln_var)
                rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) \
                    / (k * batchsize)
            self.rec_loss = rec_loss
            self.loss = self.rec_loss + \
                beta * gaussian_kl_divergence(mu, ln_var) / batchsize
            chainer.report(
                {'rec_loss': rec_loss, 'loss': self.loss}, observer=self)
            return self.loss
        return lf
            self.linear = L.Linear(n_latent, n_h)
            self.output = L.Linear(n_h, n_in)

    def forward(self, z, inference=False):
        n_batch_axes = 1 if inference else 2
        h = F.tanh(self.linear(z, n_batch_axes=n_batch_axes))
        h = self.output(h, n_batch_axes=n_batch_axes)
        return D.Bernoulli(logit=h)


class Prior(chainer.Chain):

    def __init__(self, n_latent, dtype=np.float32, device=-1):
This conversation was marked as resolved by toslunar

toslunar (Member), Nov 12, 2018:
Please do not specify device here. Why don't you make Prior inherit chainer.Link and add loc and scale as chainer.Parameters?

ganow (Author, Contributor), Nov 12, 2018:
I fixed it: 39f8620

ganow (Author, Contributor), Nov 12, 2018:
I have just realized that the prior distribution in the original VAE paper is not trainable. My current implementation (based on your advice) makes loc and scale chainer.Parameters, which makes them trainable. I couldn't find a way to allocate a non-trainable variable without using the device option. Is there a way to do that?

toslunar (Member), Nov 12, 2018:
Could you try register_persistent?

ganow (Author, Contributor), Nov 12, 2018:
The parameters of the prior distribution now seem to be non-trainable, thank you. 8294d6f

        super(Prior, self).__init__()

        loc = np.zeros(n_latent, dtype=dtype)
        scale = np.ones(n_latent, dtype=dtype)
        if device != -1:
            loc = cuda.to_gpu(loc, device=device)
            scale = cuda.to_gpu(scale, device=device)

        self.loc = chainer.Variable(loc)
        self.scale = chainer.Variable(scale)

    def forward(self):
        return D.Normal(self.loc, scale=self.scale)
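Following the register_persistent suggestion in the thread above, a minimal sketch of how the Prior could drop the device argument while keeping loc and scale non-trainable (an illustration, not necessarily identical to commit 8294d6f):

```python
import numpy as np

import chainer
import chainer.distributions as D


class Prior(chainer.Link):

    def __init__(self, n_latent, dtype=np.float32):
        super(Prior, self).__init__()
        # Plain arrays registered as persistent values: to_gpu()/to_cpu()
        # transfers them between devices, but the optimizer never updates them.
        self.loc = np.zeros(n_latent, dtype=dtype)
        self.scale = np.ones(n_latent, dtype=dtype)
        self.register_persistent('loc')
        self.register_persistent('scale')

    def forward(self):
        return D.Normal(self.loc, scale=self.scale)
```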


def make_encoder(n_in, n_latent, n_h):
    return Encoder(n_in, n_latent, n_h)


def make_decoder(n_in, n_latent, n_h):
    return Decoder(n_in, n_latent, n_h)


def make_prior(n_latent, dtype=np.float32, device=-1):
    return Prior(n_latent, dtype=dtype, device=device)
@@ -1,13 +1,14 @@
#!/usr/bin/env python
"""Chainer example: train a VAE on MNIST
"""Chainer example: train a VAE on Binarized MNIST
"""
import argparse
import os

import numpy as np

import chainer
from chainer import training
from chainer.training import extensions
import numpy as np

import net

@@ -20,60 +21,79 @@ def main():
                        help='Resume the optimization from snapshot')
    parser.add_argument('--gpu', '-g', default=-1, type=int,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
    parser.add_argument('--out', '-o', default='results',
                        help='Directory to output the result')
    parser.add_argument('--epoch', '-e', default=100, type=int,
                        help='number of epochs to learn')
    parser.add_argument('--dimz', '-z', default=20, type=int,
    parser.add_argument('--dim-z', '-z', default=20, type=int,
                        help='dimension of encoded vector')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
    parser.add_argument('--dim-h', default=500, type=int,
                        help='dimension of hidden layer')
    parser.add_argument('--beta', default=1.0, type=float,
                        help='Regularization coefficient for '
                             'the second term of ELBO bound')
    parser.add_argument('--k', '-k', default=1, type=int,
                        help='Number of Monte Carlo samples used in '
                             'encoded vector')
    parser.add_argument('--batch-size', '-b', type=int, default=100,
                        help='learning minibatch size')
    parser.add_argument('--test', action='store_true',
                        help='Use tiny datasets for quick tests')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# dim z: {}'.format(args.dimz))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# dim z: {}'.format(args.dim_z))
    print('# Minibatch-size: {}'.format(args.batch_size))
    print('# epoch: {}'.format(args.epoch))
    print('')

    # Prepare VAE model, defined in net.py
    model = net.VAE(784, args.dimz, 500)
    encoder = net.make_encoder(784, args.dim_z, args.dim_h)
    decoder = net.make_decoder(784, args.dim_z, args.dim_h)
    prior = net.make_prior(args.dim_z, device=args.gpu)
    avg_elbo_loss = net.AvgELBOLoss(encoder, decoder, prior,
                                    beta=args.beta, k=args.k)
This conversation was marked as resolved by ganow

toslunar (Member), Nov 13, 2018:
Suggested change (whitespace only): re-align the continuation line beta=args.beta, k=args.k).
    if args.gpu >= 0:
        avg_elbo_loss.to_gpu(args.gpu)

    # Setup an optimizer
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)
    optimizer.setup(avg_elbo_loss)

    # Initialize
    if args.initmodel:
        chainer.serializers.load_npz(args.initmodel, model)
        chainer.serializers.load_npz(args.initmodel, avg_elbo_loss)

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist(withlabel=False)

    # Binarize dataset
    train[train >= 0.5] = 1.0
    train[train < 0.5] = 0.0
This conversation was marked as resolved by toslunar

toslunar (Member), Nov 12, 2018:
How about train = (train >= 0.5).astype(np.float32)?

ganow (Author, Contributor), Nov 12, 2018:
Thank you. It seems smarter than the previous way. I've changed it: 099a98c

    test[test >= 0.5] = 1.0
    test[test < 0.5] = 0.0
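As suggested in the thread above, the in-place thresholding of both arrays can be written as a single vectorized comparison per array (a sketch; the committed change 099a98c itself is not shown on this page):

```python
import numpy as np

import chainer

# get_mnist(withlabel=False) returns plain float32 arrays of shape (N, 784).
train, test = chainer.datasets.get_mnist(withlabel=False)

# Binarize pixels at 0.5 in one step per array.
train = (train >= 0.5).astype(np.float32)
test = (test >= 0.5).astype(np.float32)
```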

YoshikawaMasashi (Member), Sep 17, 2018:
I'll implement a non-strict option for D.Bernoulli.

YoshikawaMasashi (Member), Sep 26, 2018:
That PR has been merged, so you can use D.Bernoulli without binarization.


    if args.test:
        train, _ = chainer.datasets.split_dataset(train, 100)
        test, _ = chainer.datasets.split_dataset(test, 100)

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
    train_iter = chainer.iterators.SerialIterator(train, args.batch_size)
    test_iter = chainer.iterators.SerialIterator(test, args.batch_size,
                                                 repeat=False, shuffle=False)

    # Set up an updater. StandardUpdater can explicitly specify a loss function
    # used in the training with 'loss_func' option
    updater = training.updaters.StandardUpdater(
        train_iter, optimizer,
        device=args.gpu, loss_func=model.get_loss_func())
        device=args.gpu, loss_func=avg_elbo_loss)

    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu,
                                        eval_func=model.get_loss_func(k=10)))
    trainer.extend(extensions.Evaluator(
        test_iter, avg_elbo_loss, device=args.gpu))
    trainer.extend(extensions.dump_graph('main/loss'))
    trainer.extend(extensions.snapshot(), trigger=(args.epoch, 'epoch'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/rec_loss', 'validation/main/rec_loss', 'elapsed_time']))
         'main/rec', 'main/penalty', 'elapsed_time']))
This conversation was marked as resolved by toslunar

toslunar (Member), Nov 12, 2018:
rec might be too short to guess its meaning. I'd write reconstr and kl_div. What do you think?

ganow (Author, Contributor), Nov 12, 2018:
Thank you for the recommendation. I think kl_div is still vague because the name does not say which KL divergence it is. Since the name of the KL term should imply that it is the penalty term of the loss function, I have changed rec and penalty to reconstr and kl_penalty: add8bfe

    trainer.extend(extensions.ProgressBar())

    if args.resume:
@@ -90,25 +110,26 @@ def save_images(x, filename):
            ai.imshow(xi.reshape(28, 28))
        fig.savefig(filename)

    model.to_cpu()
    encoder.to_cpu()
    decoder.to_cpu()
    prior.to_cpu()
This conversation was marked as resolved by ganow

toslunar (Member), Nov 13, 2018:
avg_elbo_loss.to_cpu() works. If we call to_cpu for each model, AvgELBOLoss doesn't need to be a Chain.
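A minimal sketch of the single call toslunar suggests; it assumes the avg_elbo_loss chain built earlier in this script, which registers encoder, decoder, and prior as child links via init_scope():

```python
# Equivalent to the three per-model to_cpu() calls above, because a Chain
# propagates to_cpu()/to_gpu() to all of its child links.
avg_elbo_loss.to_cpu()
```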
    train_ind = [1, 3, 5, 10, 2, 0, 13, 15, 17]
    x = chainer.Variable(np.asarray(train[train_ind]))
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        x1 = model(x)
        x1 = decoder(encoder(x).mean, inference=True).mean
    save_images(x.data, os.path.join(args.out, 'train'))
    save_images(x1.data, os.path.join(args.out, 'train_reconstructed'))

    test_ind = [3, 2, 1, 18, 4, 8, 11, 17, 61]
    x = chainer.Variable(np.asarray(test[test_ind]))
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        x1 = model(x)
        x1 = decoder(encoder(x).mean, inference=True).mean
    save_images(x.data, os.path.join(args.out, 'test'))
    save_images(x1.data, os.path.join(args.out, 'test_reconstructed'))

    # draw images from randomly sampled z
    z = chainer.Variable(
        np.random.normal(0, 1, (9, args.dimz)).astype(np.float32))
    x = model.decode(z)
    z = prior().sample(9)
    x = decoder(z, inference=True).mean
    save_images(x.data, os.path.join(args.out, 'sampled'))

