Commit
Merge pull request #193 from data61/feature/#192
Feature/#192
dsteinberg committed Aug 22, 2018
2 parents 3a1cd7b + 1f94dc4 commit e973bde
Showing 8 changed files with 89 additions and 121 deletions.
4 changes: 2 additions & 2 deletions aboleth/__init__.py
@@ -3,7 +3,7 @@
from .version import __version__
from .losses import elbo, max_posterior
from .baselayers import stack
from .layers import (Activation, DropOut, MaxPool2D, Reshape, DenseVariational,
from .layers import (Activation, DropOut, MaxPool2D, Flatten, DenseVariational,
EmbedVariational, Conv2DVariational, DenseMAP, EmbedMAP,
Conv2DMAP, InputLayer, RandomFourier, RandomArcCosine)
from .hlayers import Concat, Sum, PerFeature
@@ -24,7 +24,7 @@
'Activation',
'DropOut',
'MaxPool2D',
'Reshape',
'Flatten',
'Conv2DVariational',
'DenseVariational',
'EmbedVariational',
66 changes: 16 additions & 50 deletions aboleth/initialisers.py
@@ -1,6 +1,7 @@
"""Functions for initialising weights or distributions."""
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.init_ops import VarianceScaling

from aboleth.random import seedgen
from aboleth.util import pos, summary_histogram
@@ -23,55 +24,15 @@ def _autonorm_std(n_in, n_out):
To be used with SELU nonlinearities. See Klambauer et al. 2017
(https://arxiv.org/pdf/1706.02515.pdf)
"""
std = 1. / np.sqrt(n_in + n_out)
std = 1. / np.sqrt(n_in)
return std


class _autonorm_initializer:
"""
Implements the auto-normalizing NN initialisation for regular (MAP) layers.
To be used with SELU nonlinearities. See Klambaur et. al. 2017
(https://arxiv.org/pdf/1706.02515.pdf)
Parameters
----------
seed : None, int
A seed for the random initialization.
dtype : tf.dtype
The numerical type for weight initialization.
"""

def __init__(self, seed=None, dtype=tf.float32):
"""Create an instance of the autonorm initializer."""
self.seed = seed
self.dtype = dtype

def __call__(self, shape):
"""
Call the autonorm initalizer.
Parameters
----------
shape : tuple, tf.TensorShape
The shape of the weight matrix to initialize.
Returns
-------
W : Tensor
The initial values of the weight matrix.
"""
std = 1. / np.sqrt(np.product(shape))
W = tf.random_normal(shape, mean=0., stddev=std, dtype=self.dtype,
seed=self.seed)
return W


_INIT_DICT = {"glorot": tf.glorot_uniform_initializer(seed=next(seedgen)),
"glorot_trunc": tf.glorot_normal_initializer(seed=next(seedgen)),
"autonorm": _autonorm_initializer(seed=next(seedgen))}

"autonorm": VarianceScaling(scale=1.0, mode="fan_in",
distribution="normal",
seed=next(seedgen))}

_PRIOR_DICT = {"glorot": _glorot_std,
"autonorm": _autonorm_std}
@@ -84,8 +45,9 @@ def initialise_weights(shape, init_fn):
Parameters
----------
shape : tuple, list
The shape of the weight matrix to initialise. Typically this is
3D ie of size (samples, input_size, output_size).
The shape of the weight matrix ``W``. This uses the same convention as
TensorFlow for weight shapes (see their initializers in
tensorflow.python.ops.init_ops).
init_fn : str, callable
The function to use to initialise the weights. The default is
'glorot_trunc', the truncated normal glorot function. If supplied,
@@ -101,14 +63,16 @@ def initialise_weights(shape, init_fn):
return W


def initialise_stds(shape, init_val, learn_prior, suffix):
def initialise_stds(n_in, n_out, init_val, learn_prior, suffix):
"""
Initialise the prior standard deviation and initial posterior.
Parameters
----------
shape : tuple, list
The shape of the matrix to initialise.
n_in : int
The total number of input units in the layer.
n_out : int
The total number of output units in the layer.
init_val : str, float
If a string, must be one of "glorot" or "autonorm", which will use
these methods to initialise a value. Otherwise, will use the provided
@@ -127,9 +91,11 @@ def initialise_stds(shape, init_val, learn_prior, suffix):
The initial value of the standard deviation
"""
# assert len(shape) == 2

if isinstance(init_val, str):
fn = _PRIOR_DICT[init_val]
std0 = fn(shape[-2], shape[-1])
std0 = fn(n_in, n_out)
else:
std0 = init_val
std0 = np.array(std0).astype(np.float32)
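
An illustrative sketch (not part of this commit) of the updated signature's string/float branch, mirroring the new call sites in layers.py further down, e.g. initialise_stds(input_dim, output_dim, prior_std0, learn_prior, "dense"). It skips the learn_prior/TensorFlow-variable handling and includes only the "autonorm" prior rule shown in this diff.

import numpy as np

def initialise_stds_sketch(n_in, n_out, init_val):
    # string -> look up a prior rule; float -> use the value directly
    priors = {"autonorm": lambda n_in, n_out: 1.0 / np.sqrt(n_in)}
    std0 = priors[init_val](n_in, n_out) if isinstance(init_val, str) else init_val
    return np.array(std0).astype(np.float32)

print(initialise_stds_sketch(64, 32, "autonorm"))  # 0.125
print(initialise_stds_sketch(64, 32, 0.1))         # 0.1
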
69 changes: 38 additions & 31 deletions aboleth/layers.py
@@ -15,6 +15,7 @@
# Layer type base classes
#


class InputLayer(MultiLayer):
r"""Create an input layer.
@@ -46,8 +47,10 @@ def __init__(self, name, n_samples=1):
def _build(self, **kwargs):
"""Build the tiling input layer."""
X = kwargs[self.name]
# (n_samples, N, D)
Xs = tf.tile(tf.expand_dims(X, 0), [self.n_samples, 1, 1])
ndims = len(X.shape)
# tile like (n_samples, ...)
new_shape = [self.n_samples] + ([1] * ndims)
Xs = tf.tile(tf.expand_dims(X, 0), new_shape)
return Xs, 0.0
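
An illustrative numpy sketch (not part of this commit) of the rank-agnostic tiling above, showing why computing new_shape from ndims lets InputLayer handle image inputs as well as flat feature matrices.

import numpy as np

def tile_samples(X, n_samples):
    # prepend a samples axis and repeat along it, whatever the rank of X
    reps = [n_samples] + [1] * X.ndim
    return np.tile(np.expand_dims(X, 0), reps)

print(tile_samples(np.ones((100, 5)), 3).shape)          # (3, 100, 5)
print(tile_samples(np.ones((100, 28, 28, 1)), 3).shape)  # (3, 100, 28, 28, 1)
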


@@ -116,17 +119,11 @@ def __call__(self, X):
Net, KL = super(SampleLayer3, self).__call__(X)
return Net, KL

@staticmethod
def _get_X_dims(X):
"""Get the dimensions of the rank 3 input tensor, X."""
n_samples, (input_dim,) = SampleLayer._get_X_dims(X)
return n_samples, input_dim


#
# Activation and Transformation Layers
#


class Activation(Layer):
"""Activation function layer.
@@ -165,21 +162,25 @@ class DropOut(Layer):
This is so we can repeat the dropout pattern over observations, which
has the effect of dropping out weights consistently, thereby sampling
the "latent function" of the layer.
alpha : bool
Use alpha dropout (tf.contrib.nn.alpha_dropout), which maintains the
self-normalising property of SNNs.
"""

def __init__(self, keep_prob, observation_axis=1):
def __init__(self, keep_prob, observation_axis=1, alpha=False):
"""Create an instance of a Dropout layer."""
self.keep_prob = keep_prob
self.obsax = observation_axis
self.dropout = tf.contrib.nn.alpha_dropout if alpha else tf.nn.dropout

def _build(self, X):
"""Build the graph of this layer."""
# Set noise shape to equivalent to different samples from posterior
# i.e. share the samples along the data-observations axis
noise_shape = tf.concat([tf.shape(X)[:self.obsax], [1],
tf.shape(X)[(self.obsax + 1):]], axis=0)
Net = tf.nn.dropout(X, self.keep_prob, noise_shape, seed=next(seedgen))
Net = self.dropout(X, self.keep_prob, noise_shape, seed=next(seedgen))
KL = 0.
return Net, KL
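
An illustrative numpy sketch (not part of this commit) of the noise_shape trick used above: one Bernoulli mask is drawn per posterior sample and broadcast along the observation axis, so every observation within a sample drops the same units. With alpha=True the layer simply swaps tf.nn.dropout for tf.contrib.nn.alpha_dropout with the same noise shape.

import numpy as np

def shared_dropout_mask(shape, keep_prob, obs_axis=1, seed=None):
    noise_shape = list(shape)
    noise_shape[obs_axis] = 1  # share the mask along the observation axis
    rng = np.random.RandomState(seed)
    mask = rng.uniform(size=noise_shape) < keep_prob
    return np.broadcast_to(mask, shape)

mask = shared_dropout_mask((3, 100, 64), keep_prob=0.9, seed=0)
print(mask.shape)                               # (3, 100, 64)
print(bool((mask[0, 0] == mask[0, 99]).all()))  # True: same units dropped for all observations
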

@@ -217,25 +218,21 @@ def _build(self, X):
return Net, KL


class Reshape(Layer):
"""Reshape layer.
class Flatten(Layer):
"""Flattening layer.
Reshape and output an tensor to a specified shape.
Reshape and output a tensor so it is always rank 3 (keeping the first
dimension, which is samples, and the second dimension, which is
observations).
Parameters
----------
target_shape : tuple of ints
Does not include the samples or batch axes.
I.e. if ``X.shape`` is ``(3, 100, 5, 5, 3)`` this flattens the last
dimensions to give ``(3, 100, 75)``.
"""

def __init__(self, target_shape):
"""Initialize instance of a Reshape layer."""
self.target_shape = target_shape

def _build(self, X):
"""Build the graph of this layer."""
new_shape = (int(X.shape[0]), tf.shape(X)[1]) + self.target_shape
flat_dim = np.product(X.shape[2:])
new_shape = tf.concat([tf.shape(X)[0:2], [flat_dim]], 0)
Net = tf.reshape(X, new_shape)
KL = 0.
return Net, KL
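
An illustrative numpy sketch (not part of this commit) of what Flatten computes, matching the ``(3, 100, 5, 5, 3)`` to ``(3, 100, 75)`` example in the docstring above.

import numpy as np

def flatten_keep_rank3(X):
    # collapse everything after the (samples, observations) axes
    flat_dim = int(np.product(X.shape[2:]))
    return X.reshape(X.shape[0], X.shape[1], flat_dim)

X = np.zeros((3, 100, 5, 5, 3))
print(flatten_keep_rank3(X).shape)  # (3, 100, 75)
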
@@ -273,7 +270,7 @@ def __init__(self, n_features, kernel):
def _build(self, X):
"""Build the graph of this layer."""
# Random weights
n_samples, input_dim = self._get_X_dims(X)
n_samples, (input_dim,) = self._get_X_dims(X)
dtype = X.dtype.as_numpy_dtype
P, KL = self.kernel.weights(input_dim, self.n_features, dtype)
Ps = tf.tile(tf.expand_dims(P, 0), [n_samples, 1, 1])
@@ -390,6 +387,7 @@ class Conv2DVariational(SampleLayer):
Whether to learn the prior standard deviation.
use_bias : bool
If true, also learn a bias weight, e.g. a constant offset weight.
"""

def __init__(self, filters, kernel_size, strides=(1, 1), padding='SAME',
@@ -407,7 +405,14 @@ def _build(self, X):
"""Build the graph of this layer."""
n_samples, (height, width, channels) = self._get_X_dims(X)
W_shp, b_shp = self._weight_shapes(channels)
self.pstd, self.qstd = initialise_stds(W_shp, self.prior_std0,

# get effective IO shapes, DAN's fault if this is wrong
receptive_field = np.product(W_shp[:-2])
n_inputs = receptive_field * channels
n_outputs = receptive_field * self.filters

self.pstd, self.qstd = initialise_stds(n_inputs, n_outputs,
self.prior_std0,
self.learn_prior, "conv2d")
# Layer weights
self.pW = _make_prior(self.pstd, W_shp)
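
An illustrative sketch (not part of this commit) of the effective fan-in/fan-out computed above, assuming the usual (kernel_h, kernel_w, channels, filters) weight layout so that W_shp[:-2] is the kernel size.

import numpy as np

def conv_fans(kernel_size, channels, filters):
    receptive_field = int(np.product(kernel_size))
    n_inputs = receptive_field * channels   # fan-in: inputs feeding one output unit
    n_outputs = receptive_field * filters   # fan-out: outputs fed by one input unit
    return n_inputs, n_outputs

print(conv_fans((5, 5), channels=1, filters=32))   # (25, 800)
print(conv_fans((5, 5), channels=32, filters=64))  # (800, 1600)
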
@@ -531,10 +536,11 @@ def __init__(self, output_dim, prior_std=1., learn_prior=False, full=False,

def _build(self, X):
"""Build the graph of this layer."""
n_samples, input_dim = self._get_X_dims(X)
n_samples, (input_dim,) = self._get_X_dims(X)
W_shp, b_shp = self._weight_shapes(input_dim)

self.pstd, self.qstd = initialise_stds(W_shp, self.prior_std0,
self.pstd, self.qstd = initialise_stds(input_dim, self.output_dim,
self.prior_std0,
self.learn_prior, "dense")

# Layer weights
@@ -654,11 +660,12 @@ def __init__(self, output_dim, n_categories, prior_std=1.,

def _build(self, X):
"""Build the graph of this layer."""
n_samples, input_dim = self._get_X_dims(X)
n_samples, (input_dim,) = self._get_X_dims(X)
W_shape, _ = self._weight_shapes(self.n_categories)
n_batch = tf.shape(X)[1]

self.pstd, self.qstd = initialise_stds(W_shape, self.prior_std0,
self.pstd, self.qstd = initialise_stds(input_dim, self.output_dim,
self.prior_std0,
self.learn_prior, "embed")

# Layer weights
@@ -810,7 +817,7 @@ def __init__(self, output_dim, l1_reg=0., l2_reg=0., use_bias=True,
def _build(self, X):
"""Build the graph of this layer."""
n_samples, input_shape = self._get_X_dims(X)
Wdim = tuple(input_shape) + (self.output_dim,)
Wdim = input_shape + [self.output_dim]

W_init = initialise_weights(Wdim, self.init_fn)
W = tf.Variable(W_init, name="W_map")
@@ -884,7 +891,7 @@ def __init__(self, output_dim, n_categories, l1_reg=0., l2_reg=0.,

def _build(self, X):
"""Build the graph of this layer."""
n_samples, input_dim = self._get_X_dims(X)
n_samples, (input_dim,) = self._get_X_dims(X)
Wdim = (self.n_categories, self.output_dim)
n_batch = tf.shape(X)[1]

11 changes: 7 additions & 4 deletions demos/classification.py
@@ -33,13 +33,16 @@
n_samples_ = tf.placeholder_with_default(LSAMPLES, [])
net = ab.stack(
ab.InputLayer(name='X', n_samples=n_samples_),
ab.DropOut(0.95),
ab.DenseMAP(output_dim=64, l2_reg=REG, init_fn="autonorm"),
ab.DropOut(0.95, alpha=True),
ab.DenseMAP(output_dim=128, l2_reg=REG, init_fn="autonorm"),
ab.Activation(h=tf.nn.selu),
ab.DropOut(0.9),
ab.DropOut(0.9, alpha=True),
ab.DenseMAP(output_dim=64, l2_reg=REG, init_fn="autonorm"),
ab.Activation(h=tf.nn.selu),
ab.DropOut(0.9),
ab.DropOut(0.9, alpha=True),
ab.DenseMAP(output_dim=32, l2_reg=REG, init_fn="autonorm"),
ab.Activation(h=tf.nn.selu),
ab.DropOut(0.9, alpha=True),
ab.DenseMAP(output_dim=1, l2_reg=REG, init_fn="autonorm"),
)

8 changes: 3 additions & 5 deletions demos/mnist_softmax_regression.py
@@ -23,8 +23,6 @@
# Network architecture
net = ab.stack(
ab.InputLayer(name='X', n_samples=l_samples), # LSAMPLES,BATCH_SIZE,28*28
ab.Reshape(target_shape=(28, 28, 1)), # LSAMPLES, BATCH_SIZE, 28, 28, 1

ab.Conv2DMAP(filters=32,
kernel_size=(5, 5),
l2_reg=reg), # LSAMPLES, BATCH_SIZE, 28, 28, 32
@@ -39,7 +37,7 @@
ab.MaxPool2D(pool_size=(2, 2),
strides=(2, 2)), # LSAMPLES, BATCH_SIZE, 7, 7, 64

ab.Reshape(target_shape=(7*7*64,)), # LSAMPLES, BATCH_SIZE, 7*7*64
ab.Flatten(), # LSAMPLES, BATCH_SIZE, 7*7*64

ab.DenseMAP(output_dim=1024,
l2_reg=reg), # LSAMPLES, BATCH_SIZE, 1024
@@ -55,9 +53,9 @@ def main():

# Dataset
mnist_data = tf.contrib.learn.datasets.mnist.read_data_sets(
'./mnist_demo', reshape=True)
'./mnist_demo', reshape=False)
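
(Not part of this commit: with reshape=False the tf.contrib.learn MNIST reader keeps the image axes, so inputs arrive as (N, 28, 28, 1) rather than flat (N, 784) vectors. That is why the explicit Reshape layer after InputLayer is no longer needed, and ab.Flatten() restores a rank-3 tensor before the dense layers.)
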

N, D = mnist_data.train.images.shape
N = mnist_data.train.images.shape[0]

X, Y = tf.data.Dataset.from_tensor_slices(
(np.asarray(mnist_data.train.images, dtype=np.float32),
1 change: 1 addition & 0 deletions setup.py
@@ -34,6 +34,7 @@
'pytest-cov>=2.5.1',
'pytest-flake8>=0.8.1',
'flake8-docstrings>=1.1.0',
'scikit-learn>=0.18.1',
],
'demos': [
'bokeh>=0.12.4',
