Commit b824fc6

Merge pull request #189 from data61/feature/#185
Feature/#185
dsteinberg committed Aug 3, 2018
2 parents 77d0582 + c09a77b commit b824fc6
Showing 5 changed files with 227 additions and 57 deletions.
100 changes: 100 additions & 0 deletions aboleth/initialisers.py
@@ -0,0 +1,100 @@
"""Functions for initialising weights or distributions."""
import numpy as np
import tensorflow as tf

from aboleth.random import seedgen
from aboleth.util import pos, summary_histogram


_INIT_DICT = {"glorot": tf.glorot_uniform_initializer(seed=next(seedgen)),
"glorot_trunc": tf.glorot_normal_initializer(seed=next(seedgen))}


def _glorot_std(n_in, n_out):
"""
Compute the standard deviation for initialising weights.
See Glorot and Bengio, AISTATS2010.
"""
std = 1. / np.sqrt(3 * (n_in + n_out))
return std


def _autonorm_std(n_in, n_out):
"""
Compute the auto-normalizing NN initialisation.
To be used with SELU nonlinearities. See Klambaur et. al. 2017
(https://arxiv.org/pdf/1706.02515.pdf)
"""
std = 1. / np.sqrt(n_in + n_out)
return std


_PRIOR_DICT = {"glorot": _glorot_std,
"autonorm": _autonorm_std}


def initialise_weights(shape, init_fn):
"""
Draw random initial weights using the specified function or method.
Parameters
----------
shape : tuple, list
The shape of the weight matrix to initialise. Typically this is
3D ie of size (samples, input_size, output_size).
init_fn : str, callable
The function to use to initialise the weights. The default is
'glorot_trunc', the truncated normal glorot function. If supplied,
the callable takes a shape (input_dim, output_dim) as an argument
and returns the weight matrix.
"""
if isinstance(init_fn, str):
fn = _INIT_DICT[init_fn]
else:
fn = init_fn
W = fn(shape)
return W
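
For illustration, initialise_weights takes either a key of _INIT_DICT or any callable of the shape. A minimal usage sketch (the zeros initialiser is a made-up stand-in, not part of the library):

import tensorflow as tf
from aboleth.initialisers import initialise_weights

W = initialise_weights((3, 128, 64), 'glorot')                    # built-in Glorot uniform
W0 = initialise_weights((3, 128, 64), lambda shp: tf.zeros(shp))  # custom callable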


def initialise_stds(shape, init_val, learn_prior, suffix):
"""
Initialise the prior standard devation and initial poststerior.
Parameters
----------
shape : tuple, list
The shape of the matrix to initialise.
init_val : str, float
If a string, must be one of "glorot" or "autonorm", which will use
these methods to initialise a value. Otherwise, will use the provided
float to initialise.
learn_prior : bool
Whether to learn the prior or not. If true, will make the prior
a variable.
suffix : str
A string used to name the variable so Tensorboard can track it.
Returns
-------
std : tf.Variable, np.array
The standard deviation value/variable
std0 :
The initial value of the standard deviation
"""
if isinstance(init_val, str):
fn = _PRIOR_DICT[init_val]
std0 = fn(shape[-2], shape[-1])
else:
std0 = init_val
std0 = np.array(std0).astype(np.float32)

if learn_prior:
std = tf.Variable(pos(std0), name="prior_std_{}".format(suffix))
summary_histogram(std)
else:
std = std0
return std, std0
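
A usage sketch of initialise_stds covering both branches (shape and suffix illustrative):

from aboleth.initialisers import initialise_stds

# Fixed prior: std is simply the float32 initial value (std == std0)
std, std0 = initialise_stds((3, 128, 64), 1.0, learn_prior=False, suffix="dense")

# Learned prior: std becomes a tf.Variable initialised at the Glorot value
std, std0 = initialise_stds((3, 128, 64), "glorot", learn_prior=True, suffix="dense")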
110 changes: 67 additions & 43 deletions aboleth/layers.py
@@ -7,7 +7,8 @@
from aboleth.distributions import (norm_prior, norm_posterior, gaus_posterior,
kl_sum)
from aboleth.baselayers import Layer, MultiLayer
from aboleth.util import summary_histogram, pos
from aboleth.util import summary_histogram
from aboleth.initialisers import initialise_weights, initialise_stds


#
@@ -380,30 +381,34 @@ class Conv2DVariational(SampleLayer):
padding : str
One of 'SAME' or 'VALID'. Defaults to 'SAME'. The type of padding
algorithm to use.
prior_std : float
the value of the weight prior standard deviation (:math:`\sigma` above)
prior_std : str, float
the value of the weight prior standard deviation
(:math:`\sigma` above). The user can also provide a string to specify
an initialisation function. Defaults to 'glorot'. If a string,
must be one of 'glorot' or 'autonorm'.
    learn_prior : bool, optional
Whether to learn the prior standard deviation.
use_bias : bool
If true, also learn a bias weight, e.g. a constant offset weight.
"""

def __init__(self, filters, kernel_size, strides=(1, 1), padding='SAME',
prior_std=1., learn_prior=False, use_bias=True):
prior_std='glorot', learn_prior=False, use_bias=True):
"""Create and instance of a variational Conv2D layer."""
self.pstd = _make_prior_std(prior_std, learn_prior, "conv2d")
self.qstd = prior_std
self.filters = filters
self.kernel_size = kernel_size
self.strides = [1] + list(strides) + [1]
self.padding = padding
self.use_bias = use_bias
self.prior_std0 = prior_std
self.learn_prior = learn_prior

def _build(self, X):
"""Build the graph of this layer."""
n_samples, (height, width, channels) = self._get_X_dims(X)
W_shp, b_shp = self._weight_shapes(channels)

self.pstd, self.qstd = initialise_stds(W_shp, self.prior_std0,
self.learn_prior, "conv2d")
# Layer weights
self.pW = _make_prior(self.pstd, W_shp)
self.qW = _make_posterior(self.qstd, W_shp, False, "conv")
@@ -499,8 +504,11 @@ class DenseVariational(SampleLayer3):
----------
output_dim : int
the dimension of the output of this layer
prior_std : float
the value of the weight prior standard deviation (:math:`\sigma` above)
prior_std : str, float
the value of the weight prior standard deviation
(:math:`\sigma` above). The user can also provide a string to specify
an initialisation function. Defaults to 'glorot'. If a string,
must be one of 'glorot' or 'autonorm'.
learn_prior : bool, optional
Whether to learn the prior
full : bool
@@ -516,16 +524,19 @@ def __init__(self, output_dim, prior_std=1., learn_prior=False, full=False,
use_bias=True):
"""Create and instance of a variational dense layer."""
self.output_dim = output_dim
self.pstd = _make_prior_std(prior_std, learn_prior, "dense")
self.qstd = prior_std
self.full = full
self.use_bias = use_bias
self.prior_std0 = prior_std
self.learn_prior = learn_prior

def _build(self, X):
"""Build the graph of this layer."""
n_samples, input_dim = self._get_X_dims(X)
W_shp, b_shp = self._weight_shapes(input_dim)

self.pstd, self.qstd = initialise_stds(W_shp, self.prior_std0,
self.learn_prior, "dense")

# Layer weights
self.pW = _make_prior(self.pstd, W_shp)
self.qW = _make_posterior(self.qstd, W_shp, self.full, "dense")
@@ -617,8 +628,11 @@ class EmbedVariational(DenseVariational):
the dimension of the output (embedding) of this layer
n_categories : int
the number of categories in the input variable
prior_std : float
the value of the weight prior standard deviation (:math:`\sigma` above)
prior_std : str, float
the value of the weight prior standard deviation
(:math:`\sigma` above). The user can also provide a string to specify
an initialisation function. Defaults to 'glorot'. If a string,
must be one of 'glorot' or 'autonorm'.
learn_prior : bool, optional
Whether to learn the prior
full : bool
@@ -633,17 +647,20 @@ def __init__(self, output_dim, n_categories, prior_std=1.,
"""Create and instance of a variational dense embedding layer."""
assert n_categories >= 2, "Need 2 or more categories for embedding!"
self.output_dim = output_dim
self.pstd = _make_prior_std(prior_std, learn_prior, "embed")
self.qstd = prior_std
self.n_categories = n_categories
self.full = full
self.prior_std0 = prior_std
self.learn_prior = learn_prior

def _build(self, X):
"""Build the graph of this layer."""
n_samples, input_dim = self._get_X_dims(X)
W_shape, _ = self._weight_shapes(self.n_categories)
n_batch = tf.shape(X)[1]

self.pstd, self.qstd = initialise_stds(W_shape, self.prior_std0,
self.learn_prior, "embed")

# Layer weights
self.pW = _make_prior(self.pstd, W_shape)
self.qW = _make_posterior(self.qstd, W_shape, self.full, "embed")
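
With this change the variational layers accept either a float or an initialiser name ('glorot' or 'autonorm') for prior_std. A hedged usage sketch of the constructors touched by this PR (dimensions illustrative):

from aboleth.layers import (Conv2DVariational, DenseVariational,
                            EmbedVariational)

conv = Conv2DVariational(filters=8, kernel_size=(3, 3), prior_std='glorot')
dense = DenseVariational(output_dim=64, prior_std='autonorm', learn_prior=True)
embed = EmbedVariational(output_dim=16, n_categories=100, prior_std=0.1)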
@@ -692,11 +709,16 @@ class Conv2DMAP(SampleLayer):
:math:`\frac{1}{2} \text{l2_reg} \times \|\mathbf{W}\|^2_2`
use_bias : bool
If true, also learn a bias weight, e.g. a constant offset weight.
init_fn : str, callable
The function to use to initialise the weights. The default is
'glorot_trunc', the truncated normal glorot function. If supplied,
the callable takes a shape (input_dim, output_dim) as an argument
and returns the weight matrix.
"""

def __init__(self, filters, kernel_size, strides=(1, 1), padding='SAME',
l1_reg=0., l2_reg=0., use_bias=True):
l1_reg=0., l2_reg=0., use_bias=True, init_fn='glorot_trunc'):
"""Create and instance of a variational Conv2D layer."""
self.filters = filters
self.kernel_size = kernel_size
@@ -705,17 +727,15 @@ def __init__(self, filters, kernel_size, strides=(1, 1), padding='SAME',
self.l1 = l1_reg
self.l2 = l2_reg
self.use_bias = use_bias
self.init_fn = init_fn

def _build(self, X):
"""Build the graph of this layer."""
n_samples, (height, width, channels) = self._get_X_dims(X)
W_shape, b_shape = self._weight_shapes(channels)

W = tf.Variable(tf.truncated_normal(
shape=W_shape,
seed=next(seedgen)),
name="W_map"
)
W_init = initialise_weights(W_shape, self.init_fn)
W = tf.Variable(W_init, name="W_map")
summary_histogram(W)

Net = tf.map_fn(
@@ -727,11 +747,8 @@ def _build(self, X):

# Optional Bias
if self.use_bias:
b = tf.Variable(tf.truncated_normal(
shape=b_shape,
seed=next(seedgen)),
name="b_map"
)
b_init = initialise_weights(b_shape, self.init_fn)
b = tf.Variable(b_init, name="b_map")
summary_histogram(b)

Net = tf.nn.bias_add(Net, b)
@@ -773,23 +790,31 @@ class DenseMAP(SampleLayer):
:math:`\frac{1}{2} \text{l2_reg} \times \|\mathbf{W}\|^2_2`
use_bias : bool
If true, also learn a bias weight, e.g. a constant offset weight.
init_fn : str, callable
The function to use to initialise the weights. The default is
'glorot', the uniform glorot function. If supplied,
the callable takes a shape (input_dim, output_dim) as an argument
and returns the weight matrix.
"""

def __init__(self, output_dim, l1_reg=0., l2_reg=0., use_bias=True):
def __init__(self, output_dim, l1_reg=0., l2_reg=0., use_bias=True,
init_fn='glorot'):
"""Create and instance of a dense layer with MAP regularizers."""
self.output_dim = output_dim
self.l1 = l1_reg
self.l2 = l2_reg
self.use_bias = use_bias
self.init_fn = init_fn

def _build(self, X):
"""Build the graph of this layer."""
n_samples, input_shape = self._get_X_dims(X)
Wdim = tuple(input_shape) + (self.output_dim,)

W = tf.Variable(tf.random_normal(shape=Wdim, seed=next(seedgen)),
name="W_map")
W_init = initialise_weights(Wdim, self.init_fn)
W = tf.Variable(W_init, name="W_map")
summary_histogram(W)

# We don't want to copy tf.Variable W so map over X
@@ -800,8 +825,8 @@ def _build(self, X):

# Optional Bias
if self.use_bias is True:
b = tf.Variable(tf.random_normal(shape=(1, self.output_dim),
seed=next(seedgen)), name="b_map")
b_init = initialise_weights((1, self.output_dim), self.init_fn)
b = tf.Variable(b_init, name="b_map")
summary_histogram(b)

Net += b
@@ -839,25 +864,33 @@ class EmbedMAP(SampleLayer3):
l2_reg : float
the value of the l2 weight regularizer,
:math:`\frac{1}{2} \text{l2_reg} \times \|\mathbf{W}\|^2_2`
init_fn : str, callable
The function to use to initialise the weights. The default is
'glorot', the uniform glorot function. If supplied,
the callable takes a shape (input_dim, output_dim) as an argument
and returns the weight matrix.
"""

def __init__(self, output_dim, n_categories, l1_reg=0., l2_reg=0.):
def __init__(self, output_dim, n_categories, l1_reg=0., l2_reg=0.,
init_fn='glorot'):
"""Create and instance of a MAP embedding layer."""
assert n_categories >= 2, "Need 2 or more categories for embedding!"
self.output_dim = output_dim
self.n_categories = n_categories
self.l1 = l1_reg
self.l2 = l2_reg
self.init_fn = init_fn

def _build(self, X):
"""Build the graph of this layer."""
n_samples, input_dim = self._get_X_dims(X)
Wdim = (self.n_categories, self.output_dim)
n_batch = tf.shape(X)[1]

W = tf.Variable(tf.random_normal(shape=Wdim, seed=next(seedgen)),
name="W_map")
W_init = initialise_weights(Wdim, self.init_fn)
W = tf.Variable(W_init, name="W_map")
summary_histogram(W)

# Index into the relevant weights rather than using sparse matmul
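
Similarly, the MAP layers now expose init_fn, which is forwarded to initialise_weights for the weights (and biases where present). A hedged usage sketch (the lambda is a made-up custom initialiser; dimensions illustrative):

import tensorflow as tf
from aboleth.layers import Conv2DMAP, DenseMAP, EmbedMAP

conv = Conv2DMAP(filters=8, kernel_size=(3, 3), init_fn='glorot_trunc')
dense = DenseMAP(output_dim=64, l2_reg=0.1, init_fn='glorot')
embed = EmbedMAP(output_dim=16, n_categories=100,
                 init_fn=lambda shp: tf.zeros(shp))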
@@ -917,12 +950,3 @@ def _make_posterior(std, weight_shape, full, suffix=None):
assert _is_dim(post_W, weight_shape), \
"Posterior inconsistent dimension!"
return post_W


def _make_prior_std(std, learn_prior, suffix=None):
if learn_prior:
x = tf.Variable(pos(std), name="prior_std_{}".format(suffix))
summary_histogram(x)
else:
x = std
return x
