
Commit bf4dab3
Update core layers
fchollet committed Apr 1, 2016
1 parent a066cf8 commit bf4dab3
Showing 9 changed files with 151 additions and 468 deletions.
5 changes: 2 additions & 3 deletions keras/constraints.py

@@ -89,7 +89,6 @@ def get_config(self):
 unitnorm = UnitNorm

 from .utils.generic_utils import get_from_module
-
-
 def get(identifier, kwargs=None):
-    return get_from_module(identifier, globals(), 'constraint', instantiate=True, kwargs=kwargs)
+    return get_from_module(identifier, globals(), 'constraint',
+                           instantiate=True, kwargs=kwargs)
2 changes: 1 addition & 1 deletion keras/engine/topology.py

@@ -1014,7 +1014,7 @@ class Merge(Layer):
             a list of layer instances. Must be more
             than one layer/tensor.
         mode: string or lambda/function. If string, must be one
-            of: 'sum', 'mul', 'concat', 'ave', 'join', 'cos', 'dot'.
+            of: 'sum', 'mul', 'concat', 'ave', 'cos', 'dot'.
             If lambda/function, it should take as input a list of tensors
             and return a single tensor.
         concat_axis: integer, axis to use in mode `concat`.
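The 'join' merge mode is dropped from the docstring here. For context, a minimal sketch of Merge with one of the remaining modes, in the Keras 1.x Sequential style (branch sizes are illustrative, not from this commit):

    # Hedged sketch: elementwise sum of two branches, Keras 1.x API assumed
    from keras.models import Sequential
    from keras.layers import Dense
    from keras.engine.topology import Merge

    left = Sequential()
    left.add(Dense(16, input_dim=8))
    right = Sequential()
    right.add(Dense(16, input_dim=8))

    model = Sequential()
    model.add(Merge([left, right], mode='sum'))  # 'mul', 'concat', 'ave', 'cos', 'dot' also valid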
15 changes: 6 additions & 9 deletions keras/layers/core.py

@@ -820,17 +820,15 @@ def build(self, input_shape):
             self.set_weights(self.initial_weights)
             del self.initial_weights

-    def call(self, train=False):
-        X = self.get_input(train)
-        transform_weight = activations.sigmoid(K.dot(X, self.W_carry) + self.b_carry)
-        act = self.activation(K.dot(X, self.W) + self.b)
+    def call(self, x, mask=None):
+        transform_weight = activations.sigmoid(K.dot(x, self.W_carry) + self.b_carry)
+        act = self.activation(K.dot(x, self.W) + self.b)
         act *= transform_weight
-        output = act + (1 - transform_weight) * X
+        output = act + (1 - transform_weight) * x
         return output

     def get_config(self):
-        config = {'name': self.__class__.__name__,
-                  'init': self.init.__name__,
+        config = {'init': self.init.__name__,
                   'transform_bias': self.transform_bias,
                   'activation': self.activation.__name__,
                   'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
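For reference, the Highway call() above computes the standard gating mix y = T(x) * H(x) + (1 - T(x)) * x, with T the sigmoid transform gate. A minimal numpy sketch of the same math (illustrative only, not the layer's backend code; weight matrices are square as in the layer):

    import numpy as np

    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    def highway(x, W, b, W_carry, b_carry,
                activation=lambda z: np.maximum(z, 0)):
        t = sigmoid(x.dot(W_carry) + b_carry)  # transform gate T(x)
        h = activation(x.dot(W) + b)           # candidate activation H(x)
        return t * h + (1.0 - t) * x           # T*H + (1-T)*x, as in call()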
@@ -968,8 +966,7 @@ def call(self, x, mask=None):
         return y

     def get_config(self):
-        config = {'name': self.__class__.__name__,
-                  'output_dim': self.output_dim,
+        config = {'output_dim': self.output_dim,
                   'init': self.init.__name__,
                   'activation': self.activation.__name__,
                   'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
8 changes: 8 additions & 0 deletions keras/utils/generic_utils.py

@@ -18,6 +18,14 @@ def get_from_module(identifier, module_params, module_name,
             return res(**kwargs)
         else:
             return res
+    elif type(identifier) is dict:
+        name = identifier.pop('name')
+        res = module_params.get(name)
+        if res:
+            return res(**identifier)
+        else:
+            raise Exception('Invalid ' + str(module_name) + ': ' +
+                            str(identifier))
     return identifier
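With the new dict branch, an identifier can carry its own constructor kwargs: 'name' selects the class from the module and the remaining keys are passed through. A hedged sketch (assuming keras.constraints of this era exposes maxnorm with an m argument):

    from keras import constraints

    c1 = constraints.get('unitnorm')                   # string form, as before
    c2 = constraints.get({'name': 'maxnorm', 'm': 2})  # dict form: name + kwargs

Note that the branch pops 'name' out of the dict in place, so the caller's dict is mutated.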
4 changes: 2 additions & 2 deletions keras/utils/test_utils.py

@@ -34,7 +34,7 @@ def get_test_data(nb_train=1000, nb_test=500, input_shape=(10,),
     return (X[:nb_train], y[:nb_train]), (X[nb_train:], y[nb_train:])


-def test_layer(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
+def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
               input_data=None, expected_output=None):
    '''Test routine for a layer with a single input tensor
    and single output tensor.

@@ -63,7 +63,7 @@ def test_layer(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
     x = Input(shape=input_shape[1:], dtype=input_dtype)
     y = layer(x)
     model = Model(input=x, output=y)
-    model.compile('rmsprop', 'mse')
+    model.compile('rmsprop', 'mse', mode='FAST_COMPILE')

     expected_output_shape = layer.get_output_shape_for(input_shape)
     actual_output = model.predict(input_data)
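The rename from test_layer to layer_test presumably keeps pytest from collecting the helper itself as a test, and mode='FAST_COMPILE' looks like a Theano compilation-mode flag forwarded through compile()'s **kwargs to make test builds faster. A hedged sketch of calling the renamed helper (Dense's Keras 1.x output_dim argument assumed):

    from keras.utils.test_utils import layer_test
    from keras.layers import Dense

    # builds a one-layer model around Dense and checks model.predict's output
    # against layer.get_output_shape_for, per the routine above
    layer_test(Dense, kwargs={'output_dim': 3}, input_shape=(2, 4))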
27 changes: 0 additions & 27 deletions tests/keras/engine/test_topology.py

@@ -9,33 +9,6 @@
 from keras.models import model_from_json, model_from_yaml


-def test_lambda_serialization():
-    from keras.layers import Lambda
-    from keras.utils.layer_utils import layer_from_config
-    ld = Lambda(lambda x: x + 1)
-    config = ld.get_config()
-    ld = Lambda.from_config(config)
-
-    def f(x):
-        return x + 1
-    ld = Lambda(f)
-    config = ld.get_config()
-    ld = layer_from_config({'class_name': 'Lambda', 'config': config})
-
-    ld = Lambda(lambda x: K.concatenate([K.square(x), x]),
-                output_shape=lambda s: tuple(list(s)[:-1] + [2 * s[-1]]))
-    config = ld.get_config()
-    ld = Lambda.from_config(config)
-
-    def f(x):
-        return K.concatenate([K.square(x), x])
-    def f_shape(s):
-        return tuple(list(s)[:-1] + [2 * s[-1]])
-    ld = Lambda(f, output_shape=f_shape)
-    config = ld.get_config()
-    ld = layer_from_config({'class_name': 'Lambda', 'config': config})
-
-
 def test_learning_phase():
     a = Input(shape=(32,), name='input_a')
     b = Input(shape=(32,), name='input_b')
16 changes: 8 additions & 8 deletions tests/keras/layers/test_advanced_activations.py

@@ -1,51 +1,51 @@
 import pytest
-from keras.utils.test_utils import test_layer
+from keras.utils.test_utils import layer_test


 def test_leaky_relu():
     from keras.layers.advanced_activations import LeakyReLU
     for alpha in [0., .5, -1.]:
-        test_layer(LeakyReLU, kwargs={'alpha': alpha},
+        layer_test(LeakyReLU, kwargs={'alpha': alpha},
                    input_shape=(2, 3, 4))


 def test_prelu():
     from keras.layers.advanced_activations import PReLU
-    test_layer(PReLU, kwargs={},
+    layer_test(PReLU, kwargs={},
                input_shape=(2, 3, 4))


 def test_elu():
     from keras.layers.advanced_activations import ELU
     for alpha in [0., .5, -1.]:
-        test_layer(ELU, kwargs={'alpha': alpha},
+        layer_test(ELU, kwargs={'alpha': alpha},
                    input_shape=(2, 3, 4))


 def test_parametric_softplus():
     from keras.layers.advanced_activations import ParametricSoftplus
     for alpha in [0., .5, -1.]:
-        test_layer(ParametricSoftplus,
+        layer_test(ParametricSoftplus,
                    kwargs={'alpha_init': 1.,
                            'beta_init': -1},
                    input_shape=(2, 3, 4))


 def test_thresholded_linear():
     from keras.layers.advanced_activations import ThresholdedLinear
-    test_layer(ThresholdedLinear, kwargs={'theta': 0.5},
+    layer_test(ThresholdedLinear, kwargs={'theta': 0.5},
                input_shape=(2, 3, 4))


 def test_thresholded_relu():
     from keras.layers.advanced_activations import ThresholdedReLU
-    test_layer(ThresholdedReLU, kwargs={'theta': 0.5},
+    layer_test(ThresholdedReLU, kwargs={'theta': 0.5},
                input_shape=(2, 3, 4))


 def test_srelu():
     from keras.layers.advanced_activations import SReLU
-    test_layer(SReLU, kwargs={},
+    layer_test(SReLU, kwargs={},
                input_shape=(2, 3, 4))
212 changes: 0 additions & 212 deletions tests/keras/layers/test_call.py

This file was deleted.
