Merge pull request #194 from data61/feature/#140
Feature/#140
lmccalman committed Aug 24, 2018
2 parents e973bde + a7101dd commit f0418d9
Showing 16 changed files with 123 additions and 107 deletions.
4 changes: 2 additions & 2 deletions aboleth/__init__.py
@@ -13,7 +13,7 @@
from .kernels import RBF, Matern, RBFVariational
from .distributions import (norm_prior, norm_posterior, gaus_posterior)
from .prediction import sample_mean, sample_percentiles, sample_model
from .util import (batch, pos, batch_prediction)
from .util import (batch, pos_variable, batch_prediction)
from .random import set_hyperseed

__all__ = (
@@ -40,7 +40,7 @@
'sample_percentiles',
'sample_model',
'batch',
'pos',
'pos_variable',
'batch_prediction',
'set_hyperseed',
'InputLayer',
49 changes: 25 additions & 24 deletions aboleth/distributions.py
@@ -1,11 +1,12 @@
"""Helper functions for model parameter distributions."""
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from tensorflow.contrib.distributions import MultivariateNormalTriL
from aboleth.util import pos_variable, summary_histogram

from aboleth.util import pos, summary_histogram
from aboleth.random import seedgen

JIT = 1e-15 # cholesky jitter


#
@@ -54,22 +55,23 @@ def norm_posterior(dim, std0, suffix=None):
Note
----
This will make tf.Variables on the randomly initialised mean and standard
deviation of the posterior. The initialisation of the mean is from a Normal
with zero mean, and ``std0`` standard deviation, and the initialisation of
the standard deviation is from a gamma distribution with an alpha of
``std0`` and a beta of 1.
This will make tf.Variables on the mean and standard deviation of the
posterior. The initialisation of the mean is zero, and the initialisation of
the standard deviation is simply ``std0`` for each element.
"""
# we have different values for each dimension on the first axis
mu_0 = np.zeros(dim, dtype=np.float32)
assert (np.ndim(std0) == 0) or (np.shape(std0) == dim)
mu_0 = tf.zeros(dim)
mu = tf.Variable(mu_0, name=_add_suffix("W_mu_q", suffix))
std = tf.Variable(std0, name=_add_suffix("W_std_q", suffix))

if np.ndim(std0) == 0:
std0 = tf.ones(dim) * std0

std = pos_variable(std0, name=_add_suffix("W_std_q", suffix))
summary_histogram(mu)
summary_histogram(std)

Q = tf.distributions.Normal(loc=mu, scale=pos(std))
Q = tf.distributions.Normal(loc=mu, scale=std)
return Q


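Aside: a minimal sketch of the pattern the new ``norm_posterior`` follows, written against the TF 1.x API used in this diff; the shape and ``std0`` value are hypothetical, and the inverse softplus is inlined here rather than calling ``pos_variable``:

import numpy as np
import tensorflow as tf

dim = (5, 3)                                       # hypothetical weight shape
std0 = 0.1                                         # hypothetical initial posterior std

# zero-initialised mean, softplus-constrained standard deviation
mu = tf.Variable(tf.zeros(dim), name="W_mu_q")
rho0 = np.log(np.expm1(std0)).astype(np.float32)   # inverse softplus of std0
std = tf.nn.softplus(tf.Variable(tf.fill(dim, rho0), name="W_std_q"))
Q = tf.distributions.Normal(loc=mu, scale=std)     # the variational posterior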
@@ -98,31 +100,29 @@ def gaus_posterior(dim, std0, suffix=None):
Note
----
This will make tf.Variables on the randomly initialised mean and covariance
of the posterior. The initialisation of the mean is from a Normal with zero
mean, and ``std0`` standard deviation, and the initialisation of the (lower
triangular of the) covariance is from a gamma distribution with an alpha of
``std0`` and a beta of 1.
This will make tf.Variables on the mean and covariance of the posterior.
The initialisation of the mean is zero, and the initialisation of the
(lower triangular of the) covariance is from diagonal matrices with
diagonal elements taking the value of `std0`.
"""
o, i = dim

# Optimize only values in lower triangular
u, v = np.tril_indices(i)
indices = (u * i + v)[:, np.newaxis]
l0 = np.tile(np.eye(i), [o, 1, 1])[:, u, v].T
l0 = l0 * tf.random_gamma(alpha=std0, shape=l0.shape, seed=next(seedgen))
l0 = (np.tile(np.eye(i) * std0, [o, 1, 1])[:, u, v].T).astype(np.float32)
lflat = tf.Variable(l0, name=_add_suffix("W_cov_q", suffix))
Lt = tf.transpose(tf.scatter_nd(indices, lflat, shape=(i * i, o)))
L = tf.reshape(Lt, (o, i, i))

mu_0 = tf.random_normal((o, i), stddev=std0, seed=next(seedgen))
mu_0 = tf.zeros((o, i))
mu = tf.Variable(mu_0, name=_add_suffix("W_mu_q", suffix))

summary_histogram(mu)
summary_histogram(lflat)

Q = MultivariateNormalTriL(mu, L)
Q = tfp.distributions.MultivariateNormalTriL(mu, L)
return Q


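Aside: the lower-triangular bookkeeping above is easy to check in plain numpy. The sketch below (with made-up dims and ``std0``) mirrors the same index arithmetic, though not the exact ``tf.scatter_nd`` call:

import numpy as np

o, i = 2, 3                                 # hypothetical output/input dims
std0 = 0.5                                  # hypothetical initial diagonal value

u, v = np.tril_indices(i)
flat_index = u * i + v                      # positions within a flattened (i, i) matrix

L0 = np.tile(np.eye(i) * std0, (o, 1, 1))   # diagonal initialisation
lflat = L0[:, u, v]                         # only the lower-triangular entries are stored

# scatter the flat entries back into full (o, i, i) lower-triangular matrices
scattered = np.zeros((o, i * i))
scattered[:, flat_index] = lflat
L = scattered.reshape(o, i, i)
print(np.allclose(L, L0))                   # True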
@@ -153,13 +153,14 @@ def kl_sum(q, p):
return kl


@tf.distributions.RegisterKL(MultivariateNormalTriL, tf.distributions.Normal)
@tf.distributions.RegisterKL(tfp.distributions.MultivariateNormalTriL,
tf.distributions.Normal)
def _kl_gaussian_normal(q, p, name=None):
"""Gaussian-Normal Kullback Leibler divergence calculation.
Parameters
----------
q : tf.contrib.distributions.MultivariateNormalTriL
q : tfp.distributions.MultivariateNormalTriL
the approximating 'q' distribution(s).
p : tf.distributions.Normal
the prior 'p' distribution(s), ``p.scale`` should be a *scalar* value!
@@ -196,7 +197,7 @@ def _kl_gaussian_normal(q, p, name=None):

def _chollogdet(L):
"""Log det of a cholesky, where L is (..., D, D)."""
ldiag = pos(tf.matrix_diag_part(L))  # keep > 0, and no vanishing gradient
ldiag = tf.maximum(tf.abs(tf.matrix_diag_part(L)), JIT) # keep > 0
logdet = 2. * tf.reduce_sum(tf.log(ldiag))
return logdet

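Aside: the identity behind ``_chollogdet`` is that for a covariance Sigma = L L^T with lower-triangular L, log|Sigma| = 2 * sum(log(diag(L))). A quick numpy check (the random matrix is illustrative only):

import numpy as np

rng = np.random.RandomState(0)
A = rng.randn(4, 4)
Sigma = A @ A.T + 4.0 * np.eye(4)           # a positive-definite covariance
L = np.linalg.cholesky(Sigma)

lhs = np.linalg.slogdet(Sigma)[1]           # log determinant of Sigma
rhs = 2.0 * np.sum(np.log(np.diag(L)))      # 2 * sum of log Cholesky diagonals
print(np.allclose(lhs, rhs))                # True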
15 changes: 5 additions & 10 deletions aboleth/impute.py
@@ -3,7 +3,7 @@

from aboleth.baselayers import MultiLayer
from aboleth.random import seedgen
from aboleth.util import pos, summary_histogram
from aboleth.util import pos_variable, summary_histogram


class MaskInputLayer(MultiLayer):
@@ -252,18 +252,13 @@ def _initialise_variables(self, X):
tf.random_normal(shape=(datadim,), seed=next(seedgen)),
name="impute_means"
)
impute_var = tf.Variable(
tf.random_gamma(alpha=1., shape=(datadim,), seed=next(seedgen)),
name="impute_vars"
)
std0 = tf.random_gamma(alpha=1., shape=(datadim,), seed=next(seedgen))
impute_std = pos_variable(std0, name="impute_std")

summary_histogram(impute_means)
summary_histogram(impute_var)
summary_histogram(impute_std)

self.normal = tf.distributions.Normal(
impute_means,
tf.sqrt(pos(impute_var))
)
self.normal = tf.distributions.Normal(impute_means, impute_std)

def _impute_columns(self, X_2D_zero):
"""Return random draws from an iid Normal for imputation."""
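Aside: conceptually, the imputation layer fills missing entries with random draws from per-column Normals. A rough numpy analogue follows; the data, mask, and moment estimates here are hypothetical, whereas the layer learns its means and standard deviations as tf.Variables:

import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(6, 3)
mask = rng.rand(6, 3) < 0.2                    # True where a value is "missing"

observed = np.where(mask, np.nan, X)
col_means = np.nanmean(observed, axis=0)       # per-column mean of observed values
col_stds = np.nanstd(observed, axis=0)         # per-column std of observed values

draws = rng.normal(col_means, col_stds, size=X.shape)
X_imputed = np.where(mask, draws, X)           # random draws fill the missing entries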
15 changes: 7 additions & 8 deletions aboleth/initialisers.py
@@ -1,10 +1,9 @@
"""Functions for initialising weights or distributions."""
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.init_ops import VarianceScaling

from aboleth.random import seedgen
from aboleth.util import pos, summary_histogram
from aboleth.util import pos_variable, summary_histogram


def _glorot_std(n_in, n_out):
@@ -30,9 +29,11 @@ def _autonorm_std(n_in, n_out):

_INIT_DICT = {"glorot": tf.glorot_uniform_initializer(seed=next(seedgen)),
"glorot_trunc": tf.glorot_normal_initializer(seed=next(seedgen)),
"autonorm": VarianceScaling(scale=1.0, mode="fan_in",
distribution="normal",
seed=next(seedgen))}
"autonorm": tf.variance_scaling_initializer(
scale=1.0,
mode="fan_in",
distribution="untruncated_normal",
seed=next(seedgen))}

_PRIOR_DICT = {"glorot": _glorot_std,
"autonorm": _autonorm_std}
@@ -91,8 +92,6 @@ def initialise_stds(n_in, n_out, init_val, learn_prior, suffix):
The initial value of the standard deviation
"""
# assert len(shape) == 2

if isinstance(init_val, str):
fn = _PRIOR_DICT[init_val]
std0 = fn(n_in, n_out)
@@ -101,7 +100,7 @@ def initialise_stds(n_in, n_out, init_val, learn_prior, suffix):
std0 = np.array(std0).astype(np.float32)

if learn_prior:
std = tf.Variable(pos(std0), name="prior_std_{}".format(suffix))
std = pos_variable(std0, name="prior_std_{}".format(suffix))
summary_histogram(std)
else:
std = std0
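Aside: the ``autonorm`` entry now uses the stock variance-scaling initialiser directly. A minimal usage sketch; the layer sizes and variable name are made up, and ``distribution="untruncated_normal"`` assumes a TF release that supports it, as this diff does:

import tensorflow as tf

n_in, n_out = 64, 32                        # hypothetical layer sizes
init = tf.variance_scaling_initializer(
    scale=1.0,
    mode="fan_in",
    distribution="untruncated_normal",      # a plain (non-truncated) Normal
    seed=0,
)
W = tf.get_variable("W_example", shape=(n_in, n_out), initializer=init)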
4 changes: 2 additions & 2 deletions aboleth/kernels.py
@@ -4,7 +4,7 @@

from aboleth.random import seedgen
from aboleth.distributions import norm_posterior, kl_sum
from aboleth.util import pos, summary_histogram
from aboleth.util import pos_variable, summary_histogram


#
@@ -270,7 +270,7 @@ def _init_lenscale(given_lenscale, learn_lenscale, input_dim):
np.float32)

if learn_lenscale:
lenscale = tf.Variable(pos(given_lenscale), name="kernel_lenscale")
lenscale = pos_variable(given_lenscale, name="kernel_lenscale")
summary_histogram(lenscale)
else:
lenscale = given_lenscale
8 changes: 4 additions & 4 deletions aboleth/losses.py
@@ -33,8 +33,8 @@ def elbo(log_likelihood, KL, N):
.. code-block:: python
noise = tf.Variable(1.0)
likelihood = tf.distributions.Normal(loc=NN, scale=ab.pos(noise))
noise = ab.pos_variable(1.0)
likelihood = tf.distributions.Normal(loc=NN, scale=noise)
log_likelihood = likelihood.log_prob(Y)
where ``NN`` is our neural network, and ``Y`` are our targets.
@@ -92,8 +92,8 @@ def max_posterior(log_likelihood, regulariser):
.. code-block:: python
noise = tf.Variable(1.0)
likelihood = tf.distributions.Normal(loc=NN, scale=ab.pos(noise))
noise = ab.pos_variable(1.0)
likelihood = tf.distributions.Normal(loc=NN, scale=noise)
log_likelihood = likelihood.log_prob(Y)
where ``NN`` is our neural network, and ``Y`` are our targets.
72 changes: 42 additions & 30 deletions aboleth/util.py
@@ -5,44 +5,28 @@
from aboleth.random import endless_permutations


def pos(X, minval=1e-15):
r"""Constrain a ``tf.Variable`` to be positive only.
At the moment this is implemented as:
:math:`\max(|\mathbf{X}|, \text{minval})`
This is fast and does not result in vanishing gradients, but will lead to
non-smooth gradients and more local minima. In practice we haven't noticed
this being a problem.
def pos_variable(initial_value, name=None, **kwargs):
"""Make a tf.Variable that will remain positive.
Parameters
----------
X : Tensor
any Tensor in which all elements will be made positive.
minval : float
the minimum "positive" value the resulting tensor will have.
initial_value : float, np.array, tf.Tensor
the initial value of the Variable.
name : string
the name to give the returned tensor.
kwargs : dict
optional arguments to give the created ``tf.Variable``.
Returns
-------
X : Tensor
a tensor the same shape as the input ``X`` but positively constrained.
Examples
--------
>>> X = tf.constant(np.array([1.0, -1.0, 0.0]))
>>> Xp = pos(X)
>>> with tf.Session():
... xp = Xp.eval()
>>> all(xp == np.array([1., 1., 1.e-15]))
True
var : tf.Tensor
a Tensor wrapping an underlying tf.Variable; it will remain positive
through training.
"""
# Other alternatives could be:
# Xp = tf.exp(X) # Medium speed, but gradients tend to explode
# Xp = tf.nn.softplus(X) # Slow but well behaved!
Xp = tf.maximum(tf.abs(X), minval) # Faster, but more local optima
return Xp
var0 = tf.Variable(_inverse_softplus(initial_value), **kwargs)
var = tf.nn.softplus(var0, name=name)
return var


def batch(feed_dict, batch_size, n_iter=10000, N_=None):
@@ -141,3 +125,31 @@ def summary_histogram(values):
def __data_len(feed_dict):
N = feed_dict[list(feed_dict.keys())[0]].shape[0]
return N


def _inverse_softplus(x):
r"""Inverse softplus function for initialising values.
This is useful when we want to constrain a value to be positive using a
softplus function, but wish to specify an exact value for initialisation.
Examples
--------
Say we wish a variable to be positive, and have an initial value of 1.,
>>> var = tf.nn.softplus(tf.Variable(1.0))
>>> with tf.Session() as sess:
... sess.run(tf.global_variables_initializer())
... print(var.eval())
1.3132616
If we use this function,
>>> var = tf.nn.softplus(tf.Variable(_inverse_softplus(1.0)))
>>> with tf.Session() as sess:
... sess.run(tf.global_variables_initializer())
... print(var.eval())
1.0
"""
x_prime = tf.log(tf.exp(x) - 1.)
return x_prime
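As a quick check of the round trip ``softplus(_inverse_softplus(x)) == x`` that this initialisation relies on, here is a small numpy sketch (the value 0.37 is arbitrary):

import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def inverse_softplus(x):
    return np.log(np.expm1(x))               # defined for x > 0

x0 = 0.37                                    # arbitrary positive initial value
print(np.isclose(softplus(inverse_softplus(x0)), x0))   # True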
6 changes: 3 additions & 3 deletions demos/imputation.py
@@ -17,7 +17,7 @@

RSEED = 666
ab.set_hyperseed(RSEED)
CONFIG = tf.ConfigProto(device_count={'GPU': 1}) # Use GPU ?
CONFIG = tf.ConfigProto(device_count={'GPU': 0}) # Use GPU ?

FRAC_TEST = 0.1 # Fraction of data to use for hold-out testing
FRAC_MISSING = 0.2 # Fraction of data that is missing
@@ -34,7 +34,7 @@
# Optimization
NEPOCHS = 5 # Number of times to see the data in training
BSIZE = 100 # Mini batch size
LSAMPLES = 5 # Number of samples for training
LSAMPLES = 3 # Number of samples for training
PSAMPLES = 50 # Number of predictions samples


@@ -96,7 +96,7 @@ def main():

net = (
ab.Concat(cat_layers, con_layers) >>
ab.Activation(tf.nn.elu) >>
ab.Activation(tf.nn.selu) >>
ab.DenseVariational(output_dim=NCLASSES)
)

5 changes: 3 additions & 2 deletions demos/regression.py
@@ -35,7 +35,7 @@
config = tf.ConfigProto(device_count={'GPU': 0}) # Use GPU? 0 is no

# Model initialisation
noise = tf.Variable(1.) # Likelihood st. dev. initialisation, and learning
NOISE = 1. # Likelihood st. dev. initialisation, and learning

# Random Fourier Features
kern = ab.RBF(learn_lenscale=True) # keep the length scale positive
@@ -91,7 +91,8 @@ def main():
# This is where we build the actual GP model
with tf.name_scope("Deepnet"):
phi, kl = net(X=X_)
ll = tf.distributions.Normal(loc=phi, scale=ab.pos(noise)).log_prob(Y_)
noise = ab.pos_variable(NOISE)
ll = tf.distributions.Normal(loc=phi, scale=noise).log_prob(Y_)
loss = ab.elbo(ll, kl, N)

# Set up the training graph
