Merge pull request #191 from data61/feature/#190
implements autonorm initialization for MAP layers, closes #190
dsteinberg committed Aug 14, 2018
2 parents b824fc6 + d59895f commit 3a1cd7b
Showing 4 changed files with 70 additions and 13 deletions.
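In short: MAP layers gain an init_fn="autonorm" option meant to be paired with SELU activations. A minimal usage sketch, modelled on the updated demo further below; the constant n_samples and the l2_reg value here are illustrative choices, not part of the commit:

import tensorflow as tf
import aboleth as ab

# "autonorm" initialisation paired with SELU activations yields a
# self-normalizing MAP network (Klambauer et al. 2017).
net = ab.stack(
    ab.InputLayer(name='X', n_samples=5),
    ab.DenseMAP(output_dim=64, l2_reg=0.01, init_fn="autonorm"),
    ab.Activation(h=tf.nn.selu),
    ab.DenseMAP(output_dim=1, l2_reg=0.01, init_fn="autonorm"),
)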
50 changes: 46 additions & 4 deletions aboleth/initialisers.py
@@ -6,10 +6,6 @@
from aboleth.util import pos, summary_histogram


-_INIT_DICT = {"glorot": tf.glorot_uniform_initializer(seed=next(seedgen)),
-              "glorot_trunc": tf.glorot_normal_initializer(seed=next(seedgen))}


def _glorot_std(n_in, n_out):
"""
Compute the standard deviation for initialising weights.
@@ -31,6 +27,52 @@ def _autonorm_std(n_in, n_out):
    return std


+class _autonorm_initializer:
+    """
+    Implements the auto-normalizing NN initialisation for regular (MAP) layers.
+
+    To be used with SELU nonlinearities. See Klambauer et al. 2017
+    (https://arxiv.org/pdf/1706.02515.pdf).
+
+    Parameters
+    ----------
+    seed : None, int
+        A seed for the random initialization.
+    dtype : tf.dtype
+        The numerical type for weight initialization.
+    """
+
+    def __init__(self, seed=None, dtype=tf.float32):
+        """Create an instance of the autonorm initializer."""
+        self.seed = seed
+        self.dtype = dtype
+
+    def __call__(self, shape):
+        """
+        Call the autonorm initializer.
+
+        Parameters
+        ----------
+        shape : tuple, tf.TensorShape
+            The shape of the weight matrix to initialize.
+
+        Returns
+        -------
+        W : Tensor
+            The initial values of the weight matrix.
+        """
+        std = 1. / np.sqrt(np.product(shape))
+        W = tf.random_normal(shape, mean=0., stddev=std, dtype=self.dtype,
+                             seed=self.seed)
+        return W
+
+
+_INIT_DICT = {"glorot": tf.glorot_uniform_initializer(seed=next(seedgen)),
+              "glorot_trunc": tf.glorot_normal_initializer(seed=next(seedgen)),
+              "autonorm": _autonorm_initializer(seed=next(seedgen))}


_PRIOR_DICT = {"glorot": _glorot_std,
               "autonorm": _autonorm_std}

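The rule implemented above is small: weights are drawn from a zero-mean normal whose standard deviation is 1 / sqrt(prod(shape)). A framework-free sketch of the same draw (NumPy stands in for TensorFlow here; this helper is illustrative, not part of the commit):

import numpy as np

def autonorm_init(shape, seed=None):
    # Mirror _autonorm_initializer: W ~ N(0, 1/prod(shape)).
    rnd = np.random.RandomState(seed)
    std = 1. / np.sqrt(np.prod(shape))
    return rnd.normal(loc=0., scale=std, size=shape)

W = autonorm_init((1000, 20, 3), seed=42)
print(np.mean(W), np.std(W))  # approximately 0 and 1/sqrt(60000)

This is the same statistical check the new unit test below performs against the TensorFlow implementation.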
1 change: 0 additions & 1 deletion aboleth/layers.py
@@ -796,7 +796,6 @@ class DenseMAP(SampleLayer):
    the callable takes a shape (input_dim, output_dim) as an argument
    and returns the weight matrix.
    """
-
    def __init__(self, output_dim, l1_reg=0., l2_reg=0., use_bias=True,
14 changes: 7 additions & 7 deletions demos/classification.py
@@ -34,13 +34,13 @@
net = ab.stack(
    ab.InputLayer(name='X', n_samples=n_samples_),
    ab.DropOut(0.95),
-    ab.DenseMAP(output_dim=64, l2_reg=REG),
-    ab.Activation(h=tf.nn.relu),
-    ab.DropOut(0.5),
-    ab.DenseMAP(output_dim=64, l2_reg=REG),
-    ab.Activation(h=tf.nn.relu),
-    ab.DropOut(0.5),
-    ab.DenseMAP(output_dim=1, l2_reg=REG),
+    ab.DenseMAP(output_dim=64, l2_reg=REG, init_fn="autonorm"),
+    ab.Activation(h=tf.nn.selu),
+    ab.DropOut(0.9),
+    ab.DenseMAP(output_dim=64, l2_reg=REG, init_fn="autonorm"),
+    ab.Activation(h=tf.nn.selu),
+    ab.DropOut(0.9),
+    ab.DenseMAP(output_dim=1, l2_reg=REG, init_fn="autonorm"),
)


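For context on the demo change: SELU's two constants are derived so that zero-mean, unit-variance pre-activations keep those statistics through the nonlinearity, which is why the relu layers are swapped for selu and given the matching "autonorm" initialisation, and why the dropout keep probability rises from 0.5 to 0.9 (heavy dropout would disturb the self-normalization). A quick numerical check of that fixed point, in plain NumPy rather than anything from the repository:

import numpy as np

# Constants from Klambauer et al. 2017; tf.nn.selu uses the same values.
alpha, scale = 1.6732632423543772, 1.0507009873554805

def selu(x):
    return scale * np.where(x > 0, x, alpha * (np.exp(x) - 1))

x = np.random.randn(1000000)   # zero-mean, unit-variance input
y = selu(x)
print(np.mean(y), np.var(y))   # close to 0 and 1 respectively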
18 changes: 17 additions & 1 deletion tests/test_initialisers.py
@@ -16,6 +16,20 @@ def test_autonorm_std():
    assert np.allclose(result, 1. / np.sqrt(31))


+def test_autonorm_initializer():
+    init_fn = ab.initialisers._autonorm_initializer()
+    shape = (1000, 20, 3)
+    std = 1. / np.sqrt(np.product(shape))
+    W_init = init_fn(shape)
+
+    tc = tf.test.TestCase()
+    with tc.test_session():
+        W = W_init.eval()
+
+    assert np.allclose(0., np.mean(W), atol=1e-4)
+    assert np.allclose(std, np.std(W), atol=1e-4)


def test_initialise_weights(mocker):
    mocker.patch.dict("aboleth.initialisers._INIT_DICT",
                      {"foo": lambda x: "bar"})
@@ -44,6 +58,8 @@ def test_initialise_stds(mocker):
    std, std0 = ab.initialisers.initialise_stds(shape, init_val, learn_prior,
                                                suffix)
    assert std.name == 'prior_std_bar:0'
-    with tf.Session():
+
+    tc = tf.test.TestCase()
+    with tc.test_session():
        assert std.initial_value.eval() == 10.0
