add Scaled Exponential Linear Unit activation (#6924)
* add Scaled Exponential Linear Unit activation

* selu: hardcode alpha and scale variable

* add AlphaDropout (from SELU), K.floor backend function, and tests

* move AlphaDropout from core layers to noise layers

* fix pep8 and tensorflow backend failure

* undo add (delete) K.floor on backends

* undo add (delete): selu in check_single_tensor_operation

* [skip ci] edit docstring

remove `alpha` and `scale` from docstring

* add(initializers): selu_normal

* fix: use 'q' instead of 'rate'

* [skip ci] update comment

* feat(AlphaDropout): add 'noise_shape' param back

* add SpatialAlphaDropout1D layer

* [skip ci] update comment

* [skip ci] remove unnecessary check

* update equation

* remove spatialalphadropout test

* fix(AlphaDropout): add get_config method

* [skip ci] s/selu_normal/lecun_normal

* [skip ci] fix docstring to LeCun normal init
luthfianto authored and fchollet committed Jun 14, 2017
1 parent 4a6f06f commit 21cf507
Showing 6 changed files with 146 additions and 1 deletion.
14 changes: 14 additions & 0 deletions keras/activations.py
@@ -34,6 +34,20 @@ def elu(x, alpha=1.0):
    return K.elu(x, alpha)


def selu(x):
    """Scaled Exponential Linear Unit. (Klambauer et al., 2017)

    # Arguments
        x: A tensor or variable to compute the activation function for.

    # References
        - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
    """
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946
    return scale * K.elu(x, alpha)


def softplus(x):
    return K.softplus(x)

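For readers checking the constants above: `alpha` and `scale` are the fixed-point values derived in the SNN paper, and `selu(x)` is simply `scale * elu(x, alpha)`. A quick sanity check against a plain NumPy reference (a sketch, not part of this commit) might look like:

import numpy as np
from keras import activations
from keras import backend as K

alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946

x = K.placeholder(ndim=2)
f = K.function([x], [activations.selu(x)])

values = np.array([[-2.0, -0.5, 0.0, 0.5, 2.0]], dtype=K.floatx())
# Reference: scale * x for x > 0, scale * alpha * (exp(x) - 1) otherwise
expected = scale * np.where(values > 0, values, alpha * (np.exp(values) - 1))
np.testing.assert_allclose(f([values])[0], expected, rtol=1e-5)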
23 changes: 23 additions & 0 deletions keras/initializers.py
@@ -371,6 +371,29 @@ def he_normal(seed=None):
                           seed=seed)


def lecun_normal(seed=None):
    """LeCun normal initializer.

    It draws samples from a truncated normal distribution centered on 0
    with `stddev = sqrt(1 / fan_in)`
    where `fan_in` is the number of input units in the weight tensor.

    # Arguments
        seed: A Python integer. Used to seed the random generator.

    # Returns
        An initializer.

    # References
        - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
        - [Efficient Backprop](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
    """
    return VarianceScaling(scale=1.,
                           mode='fan_in',
                           distribution='normal',
                           seed=seed)


def he_uniform(seed=None):
    """He uniform variance scaling initializer.
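Usage note (a sketch, not part of the diff): the SNN paper pairs `lecun_normal` with the `selu` activation so that pre-activations start near zero mean and unit variance. A typical layer definition might look like the following; the width and seed are illustrative only:

from keras import initializers
from keras.layers import Dense

layer = Dense(128,
              activation='selu',
              kernel_initializer=initializers.lecun_normal(seed=1337))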
68 changes: 68 additions & 0 deletions keras/layers/noise.py
@@ -5,6 +5,7 @@
from .. import backend as K
import numpy as np
from ..legacy import interfaces
from ..engine import InputSpec


class GaussianNoise(Layer):
@@ -90,3 +91,70 @@ def get_config(self):
        config = {'rate': self.rate}
        base_config = super(GaussianDropout, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class AlphaDropout(Layer):
    """Applies Alpha Dropout to the input.

    Alpha Dropout is a `Dropout` that keeps mean and variance of inputs
    to their original values, in order to ensure the self-normalizing property
    even after this dropout.
    Alpha Dropout fits well to Scaled Exponential Linear Units
    by randomly setting activations to the negative saturation value.

    # Arguments
        rate: float, drop probability (as with `Dropout`).
            The multiplicative noise will have
            standard deviation `sqrt(rate / (1 - rate))`.
        seed: A Python integer to use as random seed.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
    """
    def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
        super(AlphaDropout, self).__init__(**kwargs)
        self.rate = rate
        self.noise_shape = noise_shape
        self.seed = seed
        self.supports_masking = True

    def _get_noise_shape(self, inputs):
        return self.noise_shape if self.noise_shape else K.shape(inputs)

    def call(self, inputs, training=None):
        if 0. < self.rate < 1.:
            noise_shape = self._get_noise_shape(inputs)

            def dropped_inputs(inputs=inputs, rate=self.rate, seed=self.seed):
                alpha = 1.6732632423543772848170429916717
                scale = 1.0507009873554804934193349852946
                alpha_p = -alpha * scale

                kept_idx = K.greater_equal(K.random_uniform(noise_shape, seed=seed), rate)
                kept_idx = K.cast(kept_idx, K.floatx())

                # Get affine transformation params
                a = ((1 - rate) * (1 + rate * alpha_p ** 2)) ** -0.5
                b = -a * alpha_p * rate

                # Apply mask
                x = inputs * kept_idx + alpha_p * (1 - kept_idx)

                # Do affine transformation
                return a * x + b

            return K.in_train_phase(dropped_inputs, inputs, training=training)
        return inputs

    def get_config(self):
        config = {'rate': self.rate}
        base_config = super(AlphaDropout, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
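The affine parameters `a` and `b` in `AlphaDropout.call` above are chosen so that, for inputs with zero mean and unit variance, setting dropped units to `alpha_p` and then applying `a * x + b` leaves the mean and variance unchanged. A quick Monte Carlo check of that claim in plain NumPy (a sketch, independent of the Keras code) could be:

import numpy as np

rate = 0.1
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
alpha_p = -alpha * scale  # negative saturation value of selu

# Same affine correction as in AlphaDropout.call
a = ((1 - rate) * (1 + rate * alpha_p ** 2)) ** -0.5
b = -a * alpha_p * rate

x = np.random.standard_normal(1000000)          # zero mean, unit variance input
kept = np.random.uniform(size=x.shape) >= rate  # keep mask, P(keep) = 1 - rate
z = np.where(kept, x, alpha_p)                  # dropped units saturate at alpha_p
y = a * z + b

print(y.mean(), y.var())  # both should stay close to 0 and 1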
25 changes: 24 additions & 1 deletion tests/keras/activations_test.py
@@ -15,7 +15,7 @@ def get_standard_values():
def test_serialization():
    all_activations = ['softmax', 'relu', 'elu', 'tanh',
                       'sigmoid', 'hard_sigmoid', 'linear',
                       'softplus', 'softsign']
                       'softplus', 'softsign', 'selu']
    for name in all_activations:
        fn = activations.get(name)
        ref_fn = getattr(activations, name)
@@ -146,6 +146,29 @@ def test_elu():
    assert_allclose(result, true_result)


def test_selu():
    x = K.placeholder(ndim=2)
    f = K.function([x], [activations.selu(x)])
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946

    positive_values = get_standard_values()
    result = f([positive_values])[0]
    assert_allclose(result, positive_values * scale, rtol=1e-05)

    negative_values = np.array([[-1, -2]], dtype=K.floatx())

    # cntk can't rebind the input shape, so create the model again to test different batch size
    if (K.backend() == 'cntk'):
        x2 = K.placeholder(ndim=2)
        f = K.function([x2], [activations.selu(x2)])

    result = f([negative_values])[0]
    true_result = (np.exp(negative_values) - 1) * scale * alpha

    assert_allclose(result, true_result)


def test_tanh():
    test_values = get_standard_values()

8 changes: 8 additions & 0 deletions tests/keras/initializers_test.py
@@ -90,6 +90,14 @@ def test_he_normal(tensor_shape):
            target_mean=0., target_std=None, target_max=2 * scale)


@pytest.mark.parametrize('tensor_shape', [FC_SHAPE, CONV_SHAPE], ids=['FC', 'CONV'])
def test_lecun_normal(tensor_shape):
    fan_in, _ = initializers._compute_fans(tensor_shape)
    scale = np.sqrt(1. / fan_in)
    _runner(initializers.lecun_normal(), tensor_shape,
            target_mean=0., target_std=scale)


@pytest.mark.parametrize('tensor_shape', [FC_SHAPE, CONV_SHAPE], ids=['FC', 'CONV'])
def test_orthogonal(tensor_shape):
    _runner(initializers.orthogonal(), tensor_shape,
9 changes: 9 additions & 0 deletions tests/keras/layers/noise_test.py
@@ -23,5 +23,14 @@ def test_GaussianDropout():
               input_shape=(3, 2, 3))


@keras_test
@pytest.mark.skipif((K.backend() == 'cntk'),
                    reason="cntk does not support it yet")
def test_AlphaDropout():
    layer_test(noise.AlphaDropout,
               kwargs={'rate': 0.1},
               input_shape=(3, 2, 3))


if __name__ == '__main__':
    pytest.main([__file__])
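Putting the pieces of this commit together, a self-normalizing MLP combines `lecun_normal` initialization, the `selu` activation, and `AlphaDropout` between layers. A minimal sketch follows; the layer sizes, dropout rate, optimizer, and input shape are illustrative, not part of the commit:

from keras import initializers
from keras.layers import Dense
from keras.layers.noise import AlphaDropout
from keras.models import Sequential

model = Sequential()
model.add(Dense(64, input_shape=(20,),
                activation='selu',
                kernel_initializer=initializers.lecun_normal()))
model.add(AlphaDropout(0.1))
model.add(Dense(10, activation='softmax'))
model.compile(optimizer='sgd', loss='categorical_crossentropy')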
