From ff62eb251b04b8301e71aee970bdb157f2649fa9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Chollet?=
Date: Wed, 14 Dec 2016 13:41:24 -0800
Subject: [PATCH] Refactor regularizers and add add_weight method. (#4703)

* Refactor regularizers, introduce layer.add_weight

* Fix BN add_update syntax

* Fix eigenvalue regularizer

* Style fixes.
---
 keras/backend/theano_backend.py      |   2 +-
 keras/engine/topology.py             | 180 ++++++++++++++++---
 keras/engine/training.py             |   7 +-
 keras/layers/convolutional.py        | 155 ++++++-----------
 keras/layers/core.py                 | 164 ++++++------------
 keras/layers/embeddings.py           |  21 +--
 keras/layers/local.py                |  66 +++----
 keras/layers/normalization.py        |  36 ++--
 keras/layers/recurrent.py            | 249 ++++++++++++++-------------
 keras/layers/wrappers.py             |  13 +-
 keras/models.py                      |   7 +
 keras/regularizers.py                | 146 ++++++----------
 tests/keras/layers/test_recurrent.py |   6 +
 tests/keras/layers/test_wrappers.py  |   9 +
 tests/keras/test_regularizers.py     |   2 +
 15 files changed, 524 insertions(+), 539 deletions(-)

diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py
index 71ca0c45943..67b2a85cd27 100644
--- a/keras/backend/theano_backend.py
+++ b/keras/backend/theano_backend.py
@@ -57,7 +57,7 @@ def to_dense(tensor):
 
 
 def variable(value, dtype=_FLOATX, name=None):
-    '''Instantiate a tensor variable.
+    '''Instantiates a variable.
     '''
     if hasattr(value, 'tocoo'):
         _assert_sparse_module()
diff --git a/keras/engine/topology.py b/keras/engine/topology.py
index ee7e059a4a4..d728e4a1f9d 100644
--- a/keras/engine/topology.py
+++ b/keras/engine/topology.py
@@ -13,6 +13,7 @@
 from six.moves import zip
 
 from .. import backend as K
+from .. import initializations
 from ..utils.io_utils import ask_to_proceed_with_overwrite
 from ..utils.generic_utils import func_dump, func_load
 
@@ -28,6 +29,11 @@ def to_list(x):
     return [x]
 
 
+def object_list_uid(object_list):
+    object_list = to_list(object_list)
+    return ', '.join([str(abs(id(x))) for x in object_list])
+
+
 class InputSpec(object):
     '''This specifies the ndim, dtype and shape of every input to a layer.
     Every layer should expose (if appropriate) an `input_spec` attribute:
@@ -239,7 +245,6 @@ class Layer(object):
         non_trainable_weights: List of variables.
         weights: The concatenation of the lists trainable_weights and
             non_trainable_weights (in this order).
-        regularizers: List of regularizers.
         constraints: Dict mapping weights to constraints.
 
     # Methods
@@ -294,8 +299,8 @@ def __init__(self, **kwargs):
             self.trainable_weights = []
         if not hasattr(self, 'non_trainable_weights'):
             self.non_trainable_weights = []
-        if not hasattr(self, 'regularizers'):
-            self.regularizers = []
+        if not hasattr(self, 'losses'):
+            self.losses = []
         if not hasattr(self, 'constraints'):
             self.constraints = {}  # dict {tensor: constraint instance}
         self.built = False
@@ -354,6 +359,19 @@ def non_trainable_weights(self):
     def non_trainable_weights(self, weights):
         self._non_trainable_weights = weights
 
+    @property
+    def regularizers(self):
+        warnings.warn('The `regularizers` property of layers/models is deprecated. '
+                      'Regularization losses are now managed via the `losses` '
+                      'layer/model property.')
+        return []
+
+    @regularizers.setter
+    def regularizers(self, _):
+        warnings.warn('The `regularizers` property of layers/models is deprecated. '
+                      'Regularization losses are now managed via the `losses` '
+                      'layer/model property.')
+
     def create_input_layer(self, batch_input_shape,
                            input_dtype=None, name=None):
         if not name:
@@ -373,6 +391,32 @@ def create_input_layer(self, batch_input_shape,
         # to the input layer we just created.
         self(x)
 
+    def add_weight(self, shape, initializer, name=None,
+                   trainable=True,
+                   regularizer=None,
+                   constraint=None):
+        '''Adds a weight variable to the layer.
+
+        # Arguments:
+            shape: The shape tuple of the weight.
+            initializer: An Initializer instance (callable).
+            trainable: A boolean, whether the weight should
+                be trained via backprop or not (assuming
+                that the layer itself is also trainable).
+            regularizer: An optional Regularizer instance.
+        '''
+        initializer = initializations.get(initializer)
+        weight = initializer(shape, name=name)
+        if regularizer is not None:
+            self.add_loss(regularizer(weight))
+        if constraint is not None:
+            self.constraints[weight] = constraint
+        if trainable:
+            self.trainable_weights.append(weight)
+        else:
+            self.non_trainable_weights.append(weight)
+        return weight
+
     def assert_input_compatibility(self, input):
         '''This checks that the tensor(s) `input`
         verify the input assumptions of the layer
@@ -519,15 +563,21 @@ def __call__(self, x, mask=None):
             self.add_inbound_node(inbound_layers, node_indices, tensor_indices)
             # Outputs were already computed when calling self.add_inbound_node.
             outputs = self.inbound_nodes[-1].output_tensors
-            # If single output tensor: return it,
-            # else return a list (at least 2 elements).
-            if len(outputs) == 1:
-                return outputs[0]
-            else:
-                return outputs
         else:
             # This case appears if the input was not a Keras tensor.
-            return self.call(x, mask)
+            outputs = to_list(self.call(x, mask))
+
+        # Apply activity regularizer if any:
+        if hasattr(self, 'activity_regularizer') and self.activity_regularizer is not None:
+            regularization_losses = [self.activity_regularizer(x) for x in outputs]
+            self.add_loss(regularization_losses, input_tensors)
+
+        # If single output tensor: return it,
+        # else return a list (at least 2 elements).
+        if len(outputs) == 1:
+            return outputs[0]
+        else:
+            return outputs
 
     def add_inbound_node(self, inbound_layers,
                          node_indices=None, tensor_indices=None):
@@ -806,20 +856,58 @@ def output_shape(self):
                                  'ill-defined for the layer. ' +
                                  'Use `get_output_shape_at(node_index)` instead.')
 
-    def add_updates(self, updates, inputs):
+    def add_loss(self, losses, inputs=None):
+        if losses is None:
+            return
+        # Update self.losses
+        losses = to_list(losses)
+        if not hasattr(self, 'losses'):
+            self.losses = []
+        try:
+            self.losses += losses
+        except AttributeError:
+            # In case self.losses isn't settable
+            # (i.e. it's a getter method).
+            # In that case the `losses` property is
+            # auto-computed and shouldn't be set.
+            pass
+        # Update self._per_input_losses
+        if not hasattr(self, '_per_input_losses'):
+            self._per_input_losses = {}
+        if inputs is not None:
+            inputs_hash = object_list_uid(inputs)
+        else:
+            # Updates indexed by None are unconditional
+            # rather than input-dependent
+            inputs_hash = None
+        if inputs_hash not in self._per_input_losses:
+            self._per_input_losses[inputs_hash] = []
+        self._per_input_losses[inputs_hash] += losses
+
+    def add_update(self, updates, inputs=None):
+        if updates is None:
+            return
         # Update self.updates
+        updates = to_list(updates)
         if not hasattr(self, 'updates'):
             self.updates = []
         try:
             self.updates += updates
         except AttributeError:
+            # In case self.updates isn't settable
+            # (i.e. it's a getter method).
+            # In that case the `updates` property is
+            # auto-computed and shouldn't be set.
             pass
         # Update self._per_input_updates
         if not hasattr(self, '_per_input_updates'):
             self._per_input_updates = {}
-        inputs = to_list(inputs)
-        updates = to_list(updates)
-        inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
+        if inputs is not None:
+            inputs_hash = object_list_uid(inputs)
+        else:
+            # Updates indexed by None are unconditional
+            # rather than input-dependent
+            inputs_hash = None
         if inputs_hash not in self._per_input_updates:
             self._per_input_updates[inputs_hash] = []
         self._per_input_updates[inputs_hash] += updates
 
@@ -827,12 +915,19 @@ def get_updates_for(self, inputs):
         if not hasattr(self, '_per_input_updates'):
             return []
-        inputs = to_list(inputs)
-        inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
+        inputs_hash = object_list_uid(inputs)
         if inputs_hash in self._per_input_updates:
             return self._per_input_updates[inputs_hash]
         return []
 
+    def get_losses_for(self, inputs):
+        if not hasattr(self, '_per_input_losses'):
+            return []
+        inputs_hash = object_list_uid(inputs)
+        if inputs_hash in self._per_input_losses:
+            return self._per_input_losses[inputs_hash]
+        return []
+
     @property
     def weights(self):
         return self.trainable_weights + self.non_trainable_weights
@@ -950,7 +1045,6 @@ def __init__(self, input_shape=None, batch_input_shape=None,
 
         self.trainable_weights = []
         self.non_trainable_weights = []
-        self.regularizers = []
         self.constraints = {}
 
         self.sparse = sparse
@@ -1151,7 +1245,6 @@ def __init__(self, layers=None, mode='sum', concat_axis=-1,
         self.inbound_nodes = []
         self.outbound_nodes = []
         self.constraints = {}
-        self.regularizers = []
         self.trainable_weights = []
         self.non_trainable_weights = []
         self.supports_masking = True
@@ -1587,7 +1680,6 @@ class Container(Layer):
         supports_masking (boolean)
         trainable_weights (list of variables)
         non_trainable_weights (list of variables)
-        regularizers (list of regularizers)
         constraints (list of tuples (weight, constraint))
 
     # Methods
@@ -1901,7 +1993,6 @@ def build_map_of_graph(tensor, seen_nodes=set(), depth=0,
         self.supports_masking = False
         # The following are implemented as property functions:
         # self.constraints
-        # self.regularizers
        # self.trainable_weights
         # self.non_trainable_weights
         # self.input_spec
@@ -1946,14 +2037,38 @@ def updates(self):
                 if len(layer.inbound_nodes) == 1:
                     updates += layer.updates
                 else:
+                    # Collect updates that are dependent on inputs
+                    # that are part of the model.
                     for node_index, node in enumerate(layer.inbound_nodes):
                         node_key = layer.name + '_ib-' + str(node_index)
                         if node_key in self.container_nodes:
                             # The model owns this layer node.
                             inputs = node.input_tensors
                             updates += layer.get_updates_for(inputs)
+                # Collect unconditional updates.
+                updates += layer.get_updates_for(None)
         return updates
 
+    @property
+    def losses(self):
+        losses = []
+        for layer in self.layers:
+            if hasattr(layer, 'losses'):
+                if len(layer.inbound_nodes) == 1:
+                    losses += layer.losses
+                else:
+                    # Collect losses that are dependent on inputs
+                    # that are part of the model.
+                    for node_index, node in enumerate(layer.inbound_nodes):
+                        node_key = layer.name + '_ib-' + str(node_index)
+                        if node_key in self.container_nodes:
+                            # The model owns this layer node.
+                            inputs = node.input_tensors
+                            losses += layer.get_losses_for(inputs)
+                # Collect unconditional losses.
+                losses += layer.get_losses_for(None)
+        return losses
+
     @property
     def stateful(self):
         return any([(hasattr(layer, 'stateful') and layer.stateful) for layer in self.layers])
@@ -1990,10 +2105,13 @@ def constraints(self):
 
     @property
     def regularizers(self):
-        regs = []
-        for layer in self.layers:
-            regs += layer.regularizers
-        return regs
+        warnings.warn('The `regularizers` attribute of layers/models '
+                      'is deprecated. '
+                      'Regularization losses are now managed via the `losses` '
+                      'layer/model property.\n'
+                      'The `regularizers` attribute will be removed '
+                      'after 06/2017.')
+        return []
 
     @property
     def trainable_weights(self):
@@ -2061,8 +2179,7 @@ def uses_learning_phase(self):
         '''True if any layer in the graph uses it.
         '''
         layers_learning_phase = any([layer.uses_learning_phase for layer in self.layers])
-        regs_learning_phase = any([reg.uses_learning_phase for reg in self.regularizers])
-        return layers_learning_phase or regs_learning_phase
+        return layers_learning_phase
 
     def call(self, input, mask=None):
         '''`call` just reapplies all ops in the graph to the new inputs
@@ -2239,9 +2356,16 @@ def run_internal_graph(self, inputs, masks=None):
                     output_tensors = to_list(layer.call(computed_tensors, computed_masks))
                     output_masks = to_list(layer.compute_mask(computed_tensors, computed_masks))
 
-                    # update model updates
+                    # Update model updates and losses:
                     layer_inputs = [x[0] for x in computed_data]
-                    self.add_updates(layer.get_updates_for(layer_inputs), inputs)
+                    # Keep track of updates that depend on the inputs (e.g. BN updates).
+                    self.add_update(layer.get_updates_for(layer_inputs), inputs)
+                    # Keep track of unconditional updates (e.g. a counter).
+                    self.add_update(layer.get_updates_for(None), None)
+                    # Keep track of losses that depend on the inputs (e.g. activity regularizers).
+                    self.add_loss(layer.get_losses_for(layer_inputs), inputs)
+                    # Keep track of unconditional losses (e.g. weight regularizers).
+                    self.add_loss(layer.get_losses_for(None), None)
 
                     # Update _keras_shape.
                     if all([hasattr(x, '_keras_shape') for x in computed_tensors]):
diff --git a/keras/engine/training.py b/keras/engine/training.py
index 862bfed76da..a2821f9fd5d 100644
--- a/keras/engine/training.py
+++ b/keras/engine/training.py
@@ -611,9 +611,10 @@ def compile(self, optimizer, loss, metrics=[], loss_weights=None,
             else:
                 total_loss += loss_weight * output_loss
 
-        # add regularization penalties to the loss
-        for r in self.regularizers:
-            total_loss = r(total_loss)
+        # add regularization penalties
+        # and other layer-specific losses
+        for loss_tensor in self.losses:
+            total_loss += loss_tensor
 
         # list of same size as output_names.
        # contains tuples (metrics for output, names of metrics)
diff --git a/keras/layers/convolutional.py b/keras/layers/convolutional.py
index 487cdc83120..e4f347bc8cc 100644
--- a/keras/layers/convolutional.py
+++ b/keras/layers/convolutional.py
@@ -113,31 +113,20 @@ def __init__(self, nb_filter, filter_length,
     def build(self, input_shape):
         input_dim = input_shape[2]
         self.W_shape = (self.filter_length, 1, input_dim, self.nb_filter)
-        self.W = self.init(self.W_shape, name='{}_W'.format(self.name))
+
+        self.W = self.add_weight(self.W_shape,
+                                 initializer=self.init,
+                                 name='{}_W'.format(self.name),
+                                 regularizer=self.W_regularizer,
+                                 constraint=self.W_constraint)
         if self.bias:
-            self.b = K.zeros((self.nb_filter,), name='{}_b'.format(self.name))
-            self.trainable_weights = [self.W, self.b]
+            self.b = self.add_weight((self.nb_filter,),
+                                     initializer='zero',
+                                     name='{}_b'.format(self.name),
+                                     regularizer=self.b_regularizer,
+                                     constraint=self.b_constraint)
         else:
-            self.trainable_weights = [self.W]
-        self.regularizers = []
-
-        if self.W_regularizer:
-            self.W_regularizer.set_param(self.W)
-            self.regularizers.append(self.W_regularizer)
-
-        if self.bias and self.b_regularizer:
-            self.b_regularizer.set_param(self.b)
-            self.regularizers.append(self.b_regularizer)
-
-        if self.activity_regularizer:
-            self.activity_regularizer.set_layer(self)
-            self.regularizers.append(self.activity_regularizer)
-
-        self.constraints = {}
-        if self.W_constraint:
-            self.constraints[self.W] = self.W_constraint
-        if self.bias and self.b_constraint:
-            self.constraints[self.b] = self.b_constraint
+            self.b = None
 
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
@@ -406,32 +395,20 @@ def build(self, input_shape):
             stack_size = input_shape[3]
             self.W_shape = (self.nb_row, self.nb_col, stack_size, self.nb_filter)
         else:
-            raise ValueError('Invalid dim_ordering:', self.dim_ordering)
-        self.W = self.init(self.W_shape, name='{}_W'.format(self.name))
+            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
+        self.W = self.add_weight(self.W_shape,
+                                 initializer=self.init,
+                                 name='{}_W'.format(self.name),
+                                 regularizer=self.W_regularizer,
+                                 constraint=self.W_constraint)
         if self.bias:
-            self.b = K.zeros((self.nb_filter,), name='{}_b'.format(self.name))
-            self.trainable_weights = [self.W, self.b]
+            self.b = self.add_weight((self.nb_filter,),
+                                     initializer='zero',
+                                     name='{}_b'.format(self.name),
+                                     regularizer=self.b_regularizer,
+                                     constraint=self.b_constraint)
         else:
-            self.trainable_weights = [self.W]
-        self.regularizers = []
-
-        if self.W_regularizer:
-            self.W_regularizer.set_param(self.W)
-            self.regularizers.append(self.W_regularizer)
-
-        if self.bias and self.b_regularizer:
-            self.b_regularizer.set_param(self.b)
-            self.regularizers.append(self.b_regularizer)
-
-        if self.activity_regularizer:
-            self.activity_regularizer.set_layer(self)
-            self.regularizers.append(self.activity_regularizer)
-
-        self.constraints = {}
-        if self.W_constraint:
-            self.constraints[self.W] = self.W_constraint
-        if self.bias and self.b_constraint:
-            self.constraints[self.b] = self.b_constraint
+            self.b = None
 
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
@@ -957,40 +934,26 @@ def build(self, input_shape):
             depthwise_shape = (self.nb_row, self.nb_col, stack_size, self.depth_multiplier)
             pointwise_shape = (1, 1, self.depth_multiplier * stack_size, self.nb_filter)
         else:
-            raise ValueError('Invalid dim_ordering:', self.dim_ordering)
-        self.depthwise_kernel = self.init(depthwise_shape,
-                                          name='{}_depthwise_kernel'.format(self.name))
-        self.pointwise_kernel = self.init(pointwise_shape,
-                                          name='{}_pointwise_kernel'.format(self.name))
+            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)
+
+        self.depthwise_kernel = self.add_weight(depthwise_shape,
+                                                initializer=self.init,
+                                                regularizer=self.depthwise_regularizer,
+                                                constraint=self.depthwise_constraint,
+                                                name='{}_depthwise_kernel'.format(self.name))
+        self.pointwise_kernel = self.add_weight(pointwise_shape,
+                                                initializer=self.init,
+                                                regularizer=self.pointwise_regularizer,
+                                                constraint=self.pointwise_constraint,
+                                                name='{}_pointwise_kernel'.format(self.name))
         if self.bias:
-            self.b = K.zeros((self.nb_filter,), name='{}_b'.format(self.name))
-            self.trainable_weights = [self.depthwise_kernel,
-                                      self.pointwise_kernel,
-                                      self.b]
+            self.b = self.add_weight((self.nb_filter,),
+                                     initializer='zero',
+                                     name='{}_b'.format(self.name),
+                                     regularizer=self.b_regularizer,
+                                     constraint=self.b_constraint)
         else:
-            self.trainable_weights = [self.depthwise_kernel,
-                                      self.pointwise_kernel]
-        self.regularizers = []
-        if self.depthwise_regularizer:
-            self.depthwise_regularizer.set_param(self.depthwise_kernel)
-            self.regularizers.append(self.depthwise_regularizer)
-        if self.pointwise_regularizer:
-            self.pointwise_regularizer.set_param(self.pointwise_kernel)
-            self.regularizers.append(self.pointwise_regularizer)
-        if self.bias and self.b_regularizer:
-            self.b_regularizer.set_param(self.b)
-            self.regularizers.append(self.b_regularizer)
-        if self.activity_regularizer:
-            self.activity_regularizer.set_layer(self)
-            self.regularizers.append(self.activity_regularizer)
-
-        self.constraints = {}
-        if self.depthwise_constraint:
-            self.constraints[self.depthwise_kernel] = self.depthwise_constraint
-        if self.pointwise_constraint:
-            self.constraints[self.pointwise_kernel] = self.pointwise_constraint
-        if self.bias and self.b_constraint:
-            self.constraints[self.b] = self.b_constraint
+            self.b = None
 
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
@@ -1165,31 +1128,19 @@ def build(self, input_shape):
         else:
             raise ValueError('Invalid dim_ordering:', self.dim_ordering)
 
-        self.W = self.init(self.W_shape, name='{}_W'.format(self.name))
+        self.W = self.add_weight(self.W_shape,
+                                 initializer=self.init,
+                                 name='{}_W'.format(self.name),
+                                 regularizer=self.W_regularizer,
+                                 constraint=self.W_constraint)
         if self.bias:
-            self.b = K.zeros((self.nb_filter,), name='{}_b'.format(self.name))
-            self.trainable_weights = [self.W, self.b]
+            self.b = self.add_weight((self.nb_filter,),
+                                     initializer='zero',
+                                     name='{}_b'.format(self.name),
+                                     regularizer=self.b_regularizer,
+                                     constraint=self.b_constraint)
         else:
-            self.trainable_weights = [self.W]
-
-        self.regularizers = []
-        if self.W_regularizer:
-            self.W_regularizer.set_param(self.W)
-            self.regularizers.append(self.W_regularizer)
-
-        if self.bias and self.b_regularizer:
-            self.b_regularizer.set_param(self.b)
-            self.regularizers.append(self.b_regularizer)
-
-        if self.activity_regularizer:
-            self.activity_regularizer.set_layer(self)
-            self.regularizers.append(self.activity_regularizer)
-
-        self.constraints = {}
-        if self.W_constraint:
-            self.constraints[self.W] = self.W_constraint
-        if self.bias and self.b_constraint:
-            self.constraints[self.b] = self.b_constraint
+            self.b = None
 
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
diff --git a/keras/layers/core.py b/keras/layers/core.py
index 5b287f9bb19..345c7bdd010 100644
--- a/keras/layers/core.py
+++ b/keras/layers/core.py
@@ -125,8 +125,8 @@ def _get_noise_shape(self, x):
         input_shape = K.shape(x)
         noise_shape = (input_shape[0], 1, input_shape[2])
         return noise_shape
-
-
+
+
 class SpatialDropout2D(Dropout):
     '''This version performs the same function as Dropout, however it drops
     entire 2D feature maps instead of individual elements. If adjacent pixels
@@ -728,33 +728,19 @@ def build(self, input_shape):
         self.input_spec = [InputSpec(dtype=K.floatx(),
                                      shape=(None, input_dim))]
 
-        self.W = self.init((input_dim, self.output_dim),
-                           name='{}_W'.format(self.name))
+        self.W = self.add_weight((input_dim, self.output_dim),
+                                 initializer=self.init,
+                                 name='{}_W'.format(self.name),
+                                 regularizer=self.W_regularizer,
+                                 constraint=self.W_constraint)
         if self.bias:
-            self.b = K.zeros((self.output_dim,),
-                             name='{}_b'.format(self.name))
-            self.trainable_weights = [self.W, self.b]
+            self.b = self.add_weight((self.output_dim,),
+                                     initializer='zero',
+                                     name='{}_b'.format(self.name),
+                                     regularizer=self.b_regularizer,
+                                     constraint=self.b_constraint)
         else:
-            self.trainable_weights = [self.W]
-
-        self.regularizers = []
-        if self.W_regularizer:
-            self.W_regularizer.set_param(self.W)
-            self.regularizers.append(self.W_regularizer)
-
-        if self.bias and self.b_regularizer:
-            self.b_regularizer.set_param(self.b)
-            self.regularizers.append(self.b_regularizer)
-
-        if self.activity_regularizer:
-            self.activity_regularizer.set_layer(self)
-            self.regularizers.append(self.activity_regularizer)
-
-        self.constraints = {}
-        if self.W_constraint:
-            self.constraints[self.W] = self.W_constraint
-        if self.bias and self.b_constraint:
-            self.constraints[self.b] = self.b_constraint
+            self.b = None
 
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
@@ -808,9 +794,8 @@ def __init__(self, l1=0., l2=0., **kwargs):
         self.l2 = l2
         super(ActivityRegularization, self).__init__(**kwargs)
 
-        activity_regularizer = ActivityRegularizer(l1=l1, l2=l2)
-        activity_regularizer.set_layer(self)
-        self.regularizers = [activity_regularizer]
+        self.activity_regularizer = regularizers.L1L2Regularizer(l1=l1, l2=l2)
+        self.regularizers = [self.activity_regularizer]
 
     def get_config(self):
         config = {'l1': self.l1,
@@ -897,33 +882,19 @@ def build(self, input_shape):
         self.input_spec = [InputSpec(dtype=K.floatx(),
                                      shape=(None, input_dim))]
 
-        self.W = self.init((self.nb_feature, input_dim, self.output_dim),
-                           name='{}_W'.format(self.name))
+        self.W = self.add_weight((self.nb_feature, input_dim, self.output_dim),
+                                 initializer=self.init,
+                                 name='{}_W'.format(self.name),
+                                 regularizer=self.W_regularizer,
+                                 constraint=self.W_constraint)
         if self.bias:
-            self.b = K.zeros((self.nb_feature, self.output_dim),
-                             name='{}_b'.format(self.name))
-            self.trainable_weights = [self.W, self.b]
+            self.b = self.add_weight((self.nb_feature, self.output_dim,),
+                                     initializer='zero',
+                                     name='{}_b'.format(self.name),
+                                     regularizer=self.b_regularizer,
+                                     constraint=self.b_constraint)
         else:
-            self.trainable_weights = [self.W]
-
-        self.regularizers = []
-        if self.W_regularizer:
-            self.W_regularizer.set_param(self.W)
-            self.regularizers.append(self.W_regularizer)
-
-        if self.bias and self.b_regularizer:
-            self.b_regularizer.set_param(self.b)
-            self.regularizers.append(self.b_regularizer)
-
-        if self.activity_regularizer:
-            self.activity_regularizer.set_layer(self)
-            self.regularizers.append(self.activity_regularizer)
-
-        self.constraints = {}
-        if self.W_constraint:
-            self.constraints[self.W] = self.W_constraint
-        if self.bias and self.b_constraint:
-            self.constraints[self.b] = self.b_constraint
+            self.b = None
 
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
@@ -1030,38 +1001,25 @@ def build(self, input_shape):
         self.input_spec = [InputSpec(dtype=K.floatx(),
                                      shape=(None, input_dim))]
 
-        self.W = self.init((input_dim, input_dim),
-                           name='{}_W'.format(self.name))
-        self.W_carry = self.init((input_dim, input_dim),
-                                 name='{}_W_carry'.format(self.name))
-
+        self.W = self.add_weight((input_dim, input_dim),
+                                 initializer=self.init,
+                                 name='{}_W'.format(self.name),
+                                 regularizer=self.W_regularizer,
+                                 constraint=self.W_constraint)
+        self.W_carry = self.add_weight((input_dim, input_dim),
+                                       initializer=self.init,
+                                       name='{}_W_carry'.format(self.name))
         if self.bias:
-            self.b = K.zeros((input_dim,), name='{}_b'.format(self.name))
-            # initialize with a vector of values `transform_bias`
-            self.b_carry = K.variable(np.ones((input_dim,)) * self.transform_bias,
-                                      name='{}_b_carry'.format(self.name))
-            self.trainable_weights = [self.W, self.b, self.W_carry, self.b_carry]
+            self.b = self.add_weight((input_dim,),
+                                     initializer='zero',
+                                     name='{}_b'.format(self.name),
+                                     regularizer=self.b_regularizer,
+                                     constraint=self.b_constraint)
+            self.b_carry = self.add_weight((input_dim,),
+                                           initializer='one',
+                                           name='{}_b_carry'.format(self.name))
         else:
-            self.trainable_weights = [self.W, self.W_carry]
-
-        self.regularizers = []
-        if self.W_regularizer:
-            self.W_regularizer.set_param(self.W)
-            self.regularizers.append(self.W_regularizer)
-
-        if self.bias and self.b_regularizer:
-            self.b_regularizer.set_param(self.b)
-            self.regularizers.append(self.b_regularizer)
-
-        if self.activity_regularizer:
-            self.activity_regularizer.set_layer(self)
-            self.regularizers.append(self.activity_regularizer)
-
-        self.constraints = {}
-        if self.W_constraint:
-            self.constraints[self.W] = self.W_constraint
-        if self.bias and self.b_constraint:
-            self.constraints[self.b] = self.b_constraint
+            self.b_carry = None
 
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
@@ -1178,31 +1136,19 @@ def build(self, input_shape):
                                      shape=(None,) + input_shape[1:])]
         input_dim = input_shape[2]
 
-        self.W = self.init((input_dim, self.output_dim),
-                           name='{}_W'.format(self.name))
+        self.W = self.add_weight((input_dim, self.output_dim),
+                                 initializer=self.init,
+                                 name='{}_W'.format(self.name),
+                                 regularizer=self.W_regularizer,
+                                 constraint=self.W_constraint)
         if self.bias:
-            self.b = K.zeros((self.output_dim,),
-                             name='{}_b'.format(self.name))
-            self.trainable_weights = [self.W, self.b]
-        self.regularizers = []
-
-        if self.W_regularizer:
-            self.W_regularizer.set_param(self.W)
-            self.regularizers.append(self.W_regularizer)
-
-        if self.bias and self.b_regularizer:
-            self.b_regularizer.set_param(self.b)
-            self.regularizers.append(self.b_regularizer)
-
-        if self.activity_regularizer:
-            self.activity_regularizer.set_layer(self)
-            self.regularizers.append(self.activity_regularizer)
-
-        self.constraints = {}
-        if self.W_constraint:
-            self.constraints[self.W] = self.W_constraint
-        if self.bias and self.b_constraint:
-            self.constraints[self.b] = self.b_constraint
+            self.b = self.add_weight((self.output_dim,),
+                                     initializer='zero',
+                                     name='{}_b'.format(self.name),
+                                     regularizer=self.b_regularizer,
+                                     constraint=self.b_constraint)
+        else:
+            self.b = None
 
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
diff --git a/keras/layers/embeddings.py b/keras/layers/embeddings.py
index 3679b8b4716..1d22c051879 100644
--- a/keras/layers/embeddings.py
+++ b/keras/layers/embeddings.py
@@ -91,22 +91,11 @@ def __init__(self, input_dim, output_dim,
         super(Embedding, self).__init__(**kwargs)
 
     def build(self, input_shape):
-        self.W = self.init((self.input_dim, self.output_dim),
-                           name='{}_W'.format(self.name))
-        self.trainable_weights = [self.W]
-
-        self.constraints = {}
-        if self.W_constraint:
-            self.constraints[self.W] = self.W_constraint
-
-        self.regularizers = []
-        if self.W_regularizer:
-            self.W_regularizer.set_param(self.W)
-            self.regularizers.append(self.W_regularizer)
-
-        if self.activity_regularizer:
-            self.activity_regularizer.set_layer(self)
-            self.regularizers.append(self.activity_regularizer)
+        self.W = self.add_weight((self.input_dim, self.output_dim),
+                                 initializer=self.init,
+                                 name='{}_W'.format(self.name),
+                                 regularizer=self.W_regularizer,
+                                 constraint=self.W_constraint)
 
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
diff --git a/keras/layers/local.py b/keras/layers/local.py
index 3cc90f12651..f6f51d9ea09 100644
--- a/keras/layers/local.py
+++ b/keras/layers/local.py
@@ -110,31 +110,21 @@ def __init__(self, nb_filter, filter_length,
     def build(self, input_shape):
         input_dim = input_shape[2]
         _, output_length, nb_filter = self.get_output_shape_for(input_shape)
         self.W_shape = (output_length, self.filter_length * input_dim, nb_filter)
-        self.W = self.init(self.W_shape, name='{}_W'.format(self.name))
+
+        self.W = self.add_weight(self.W_shape,
+                                 initializer=self.init,
+                                 name='{}_W'.format(self.name),
+                                 regularizer=self.W_regularizer,
+                                 constraint=self.W_constraint)
         if self.bias:
-            self.b = K.zeros((output_length, self.nb_filter), name='{}_b'.format(self.name))
-            self.trainable_weights = [self.W, self.b]
+            self.b = self.add_weight((output_length, self.nb_filter),
+                                     initializer='zero',
+                                     name='{}_b'.format(self.name),
+                                     regularizer=self.b_regularizer,
+                                     constraint=self.b_constraint)
         else:
-            self.trainable_weights = [self.W]
-
-        self.regularizers = []
-        if self.W_regularizer:
-            self.W_regularizer.set_param(self.W)
-            self.regularizers.append(self.W_regularizer)
-        if self.b_regularizer:
-            self.b_regularizer.set_param(self.b)
-            self.regularizers.append(self.b_regularizer)
-        if self.activity_regularizer:
-            self.activity_regularizer.set_layer(self)
-            self.regularizers.append(self.activity_regularizer)
-
-        self.constraints = {}
-        if self.W_constraint:
-            self.constraints[self.W] = self.W_constraint
-        if self.b_constraint:
-            self.constraints[self.b] = self.b_constraint
+            self.b = None
 
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
@@ -306,30 +296,20 @@ def build(self, input_shape):
         self.output_row = output_row
         self.output_col = output_col
         self.W_shape = (output_row * output_col, self.nb_row * self.nb_col * input_filter, nb_filter)
-        self.W = self.init(self.W_shape, name='{}_W'.format(self.name))
+        self.W = self.add_weight(self.W_shape,
+                                 initializer=self.init,
+                                 name='{}_W'.format(self.name),
+                                 regularizer=self.W_regularizer,
+                                 constraint=self.W_constraint)
         if self.bias:
-            self.b = K.zeros((output_row, output_col, nb_filter), name='{}_b'.format(self.name))
-            self.trainable_weights = [self.W, self.b]
+            self.b = self.add_weight((output_row, output_col, nb_filter),
+                                     initializer='zero',
+                                     name='{}_b'.format(self.name),
+                                     regularizer=self.b_regularizer,
+                                     constraint=self.b_constraint)
         else:
-            self.trainable_weights = [self.W]
-
-        self.regularizers = []
-        if self.W_regularizer:
-            self.W_regularizer.set_param(self.W)
-            self.regularizers.append(self.W_regularizer)
-        if self.bias and self.b_regularizer:
-            self.b_regularizer.set_param(self.b)
-            self.regularizers.append(self.b_regularizer)
-        if self.activity_regularizer:
-            self.activity_regularizer.set_layer(self)
-            self.regularizers.append(self.activity_regularizer)
-
-        self.constraints = {}
-        if self.W_constraint:
-            self.constraints[self.W] = self.W_constraint
-        if self.bias and self.b_constraint:
-            self.constraints[self.b] = self.b_constraint
+            self.b = None
 
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
diff --git a/keras/layers/normalization.py b/keras/layers/normalization.py
index c72e6d56cc4..69b208eeccf 100644
--- a/keras/layers/normalization.py
+++ b/keras/layers/normalization.py
@@ -82,24 +82,20 @@ def build(self, input_shape):
         self.input_spec = [InputSpec(shape=input_shape)]
         shape = (input_shape[self.axis],)
 
-        self.gamma = self.gamma_init(shape, name='{}_gamma'.format(self.name))
-        self.beta = self.beta_init(shape, name='{}_beta'.format(self.name))
-        self.trainable_weights = [self.gamma, self.beta]
-
-        self.regularizers = []
-        if self.gamma_regularizer:
-            self.gamma_regularizer.set_param(self.gamma)
-            self.regularizers.append(self.gamma_regularizer)
-
-        if self.beta_regularizer:
-            self.beta_regularizer.set_param(self.beta)
-            self.regularizers.append(self.beta_regularizer)
-
-        self.running_mean = K.zeros(shape,
-                                    name='{}_running_mean'.format(self.name))
-        self.running_std = K.ones(shape,
-                                  name='{}_running_std'.format(self.name))
-        self.non_trainable_weights = [self.running_mean, self.running_std]
+        self.gamma = self.add_weight(shape,
+                                     initializer=self.gamma_init,
+                                     regularizer=self.gamma_regularizer,
+                                     name='{}_gamma'.format(self.name))
+        self.beta = self.add_weight(shape,
+                                    initializer=self.beta_init,
+                                    regularizer=self.beta_regularizer,
+                                    name='{}_beta'.format(self.name))
+        self.running_mean = self.add_weight(shape, initializer='zero',
+                                            name='{}_running_mean'.format(self.name),
+                                            trainable=False)
+        self.running_std = self.add_weight(shape, initializer='one',
+                                           name='{}_running_std'.format(self.name),
+                                           trainable=False)
 
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
@@ -121,8 +117,8 @@ def call(self, x, mask=None):
                                 epsilon=self.epsilon)
 
         if self.mode == 0:
-            self.add_updates([K.moving_average_update(self.running_mean, mean, self.momentum),
-                              K.moving_average_update(self.running_std, std, self.momentum)], x)
+            self.add_update([K.moving_average_update(self.running_mean, mean, self.momentum),
+                             K.moving_average_update(self.running_std, std, self.momentum)], x)
 
             if sorted(reduction_axes) == range(K.ndim(x))[:-1]:
                 x_normed_running = K.batch_normalization(
diff --git a/keras/layers/recurrent.py b/keras/layers/recurrent.py
index 2f4bf12723d..beb1c733b42 100644
--- a/keras/layers/recurrent.py
+++ b/keras/layers/recurrent.py
@@ -229,7 +229,7 @@ def call(self, x, mask=None):
             updates = []
             for i in range(len(states)):
                 updates.append((self.states[i], states[i]))
-            self.add_updates(updates, x)
+            self.add_update(updates, x)
 
         if self.return_sequences:
             return outputs
@@ -288,7 +288,8 @@ def __init__(self, output_dim,
         self.W_regularizer = regularizers.get(W_regularizer)
         self.U_regularizer = regularizers.get(U_regularizer)
         self.b_regularizer = regularizers.get(b_regularizer)
-        self.dropout_W, self.dropout_U = dropout_W, dropout_U
+        self.dropout_W = dropout_W
+        self.dropout_U = dropout_U
 
         if self.dropout_W or self.dropout_U:
             self.uses_learning_phase = True
@@ -304,24 +305,18 @@ def build(self, input_shape):
         input_dim = input_shape[2]
         self.input_dim = input_dim
 
-        self.W = self.init((input_dim, self.output_dim),
-                           name='{}_W'.format(self.name))
-        self.U = self.inner_init((self.output_dim, self.output_dim),
-                                 name='{}_U'.format(self.name))
-        self.b = K.zeros((self.output_dim,), name='{}_b'.format(self.name))
-
-        self.regularizers = []
-        if self.W_regularizer:
-            self.W_regularizer.set_param(self.W)
-            self.regularizers.append(self.W_regularizer)
-        if self.U_regularizer:
-            self.U_regularizer.set_param(self.U)
-            self.regularizers.append(self.U_regularizer)
-        if self.b_regularizer:
-            self.b_regularizer.set_param(self.b)
-            self.regularizers.append(self.b_regularizer)
-
-        self.trainable_weights = [self.W, self.U, self.b]
+        self.W = self.add_weight((input_dim, self.output_dim),
+                                 initializer=self.init,
+                                 name='{}_W'.format(self.name),
+                                 regularizer=self.W_regularizer)
+        self.U = self.add_weight((self.output_dim, self.output_dim),
+                                 initializer=self.inner_init,
+                                 name='{}_U'.format(self.name),
+                                 regularizer=self.U_regularizer)
+        self.b = self.add_weight((self.output_dim,),
+                                 initializer='zero',
+                                 name='{}_b'.format(self.name),
+                                 regularizer=self.b_regularizer)
 
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
@@ -446,7 +441,8 @@ def __init__(self, output_dim,
         self.W_regularizer = regularizers.get(W_regularizer)
         self.U_regularizer = regularizers.get(U_regularizer)
         self.b_regularizer = regularizers.get(b_regularizer)
-        self.dropout_W, self.dropout_U = dropout_W, dropout_U
+        self.dropout_W = dropout_W
+        self.dropout_U = dropout_U
 
         if self.dropout_W or self.dropout_U:
             self.uses_learning_phase = True
@@ -463,57 +459,59 @@ def build(self, input_shape):
         self.states = [None]
 
         if self.consume_less == 'gpu':
-
-            self.W = self.init((self.input_dim, 3 * self.output_dim),
-                               name='{}_W'.format(self.name))
-            self.U = self.inner_init((self.output_dim, 3 * self.output_dim),
-                                     name='{}_U'.format(self.name))
-
-            self.b = K.variable(np.hstack((np.zeros(self.output_dim),
-                                           np.zeros(self.output_dim),
-                                           np.zeros(self.output_dim))),
-                                name='{}_b'.format(self.name))
-
-            self.trainable_weights = [self.W, self.U, self.b]
+            self.W = self.add_weight((self.input_dim, 3 * self.output_dim),
+                                     initializer=self.init,
+                                     name='{}_W'.format(self.name),
+                                     regularizer=self.W_regularizer)
+            self.U = self.add_weight((self.output_dim, 3 * self.output_dim),
+                                     initializer=self.inner_init,
+                                     name='{}_U'.format(self.name),
+                                     regularizer=self.U_regularizer)
+            self.b = self.add_weight((self.output_dim * 3,),
+                                     initializer='zero',
+                                     name='{}_b'.format(self.name),
+                                     regularizer=self.b_regularizer)
         else:
-
-            self.W_z = self.init((self.input_dim, self.output_dim),
-                                 name='{}_W_z'.format(self.name))
-            self.U_z = self.inner_init((self.output_dim, self.output_dim),
-                                       name='{}_U_z'.format(self.name))
-            self.b_z = K.zeros((self.output_dim,), name='{}_b_z'.format(self.name))
-
-            self.W_r = self.init((self.input_dim, self.output_dim),
-                                 name='{}_W_r'.format(self.name))
-            self.U_r = self.inner_init((self.output_dim, self.output_dim),
-                                       name='{}_U_r'.format(self.name))
-            self.b_r = K.zeros((self.output_dim,), name='{}_b_r'.format(self.name))
-
-            self.W_h = self.init((self.input_dim, self.output_dim),
-                                 name='{}_W_h'.format(self.name))
-            self.U_h = self.inner_init((self.output_dim, self.output_dim),
-                                       name='{}_U_h'.format(self.name))
-            self.b_h = K.zeros((self.output_dim,), name='{}_b_h'.format(self.name))
-
-            self.trainable_weights = [self.W_z, self.U_z, self.b_z,
-                                      self.W_r, self.U_r, self.b_r,
-                                      self.W_h, self.U_h, self.b_h]
-
+            self.W_z = self.add_weight((self.input_dim, self.output_dim),
+                                       initializer=self.init,
+                                       name='{}_W_z'.format(self.name),
+                                       regularizer=self.W_regularizer)
+            self.U_z = self.add_weight((self.output_dim, self.output_dim),
+                                       initializer=self.init,
+                                       name='{}_U_z'.format(self.name),
+                                       regularizer=self.W_regularizer)
+            self.b_z = self.add_weight((self.output_dim,),
+                                       initializer='zero',
+                                       name='{}_b_z'.format(self.name),
+                                       regularizer=self.b_regularizer)
+            self.W_r = self.add_weight((self.input_dim, self.output_dim),
+                                       initializer=self.init,
+                                       name='{}_W_r'.format(self.name),
+                                       regularizer=self.W_regularizer)
+            self.U_r = self.add_weight((self.output_dim, self.output_dim),
+                                       initializer=self.init,
+                                       name='{}_U_r'.format(self.name),
+                                       regularizer=self.W_regularizer)
+            self.b_r = self.add_weight((self.output_dim,),
+                                       initializer='zero',
+                                       name='{}_b_r'.format(self.name),
+                                       regularizer=self.b_regularizer)
+            self.W_h = self.add_weight((self.input_dim, self.output_dim),
+                                       initializer=self.init,
+                                       name='{}_W_h'.format(self.name),
+                                       regularizer=self.W_regularizer)
+            self.U_h = self.add_weight((self.output_dim, self.output_dim),
+                                       initializer=self.init,
+                                       name='{}_U_h'.format(self.name),
+                                       regularizer=self.W_regularizer)
+            self.b_h = self.add_weight((self.output_dim,),
+                                       initializer='zero',
+                                       name='{}_b_h'.format(self.name),
+                                       regularizer=self.b_regularizer)
             self.W = K.concatenate([self.W_z, self.W_r, self.W_h])
             self.U = K.concatenate([self.U_z, self.U_r, self.U_h])
             self.b = K.concatenate([self.b_z, self.b_r, self.b_h])
 
-        self.regularizers = []
-        if self.W_regularizer:
-            self.W_regularizer.set_param(self.W)
-            self.regularizers.append(self.W_regularizer)
-        if self.U_regularizer:
-            self.U_regularizer.set_param(self.U)
-            self.regularizers.append(self.U_regularizer)
-        if self.b_regularizer:
-            self.b_regularizer.set_param(self.b)
-            self.regularizers.append(self.b_regularizer)
-
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
@@ -671,7 +669,8 @@ def __init__(self, output_dim,
         self.W_regularizer = regularizers.get(W_regularizer)
         self.U_regularizer = regularizers.get(U_regularizer)
         self.b_regularizer = regularizers.get(b_regularizer)
-        self.dropout_W, self.dropout_U = dropout_W, dropout_U
+        self.dropout_W = dropout_W
+        self.dropout_U = dropout_U
 
         if self.dropout_W or self.dropout_U:
             self.uses_learning_phase = True
@@ -688,63 +687,83 @@ def build(self, input_shape):
         self.states = [None, None]
 
         if self.consume_less == 'gpu':
-            self.W = self.init((self.input_dim, 4 * self.output_dim),
-                               name='{}_W'.format(self.name))
-            self.U = self.inner_init((self.output_dim, 4 * self.output_dim),
-                                     name='{}_U'.format(self.name))
-
-            self.b = K.variable(np.hstack((np.zeros(self.output_dim),
-                                           K.get_value(self.forget_bias_init((self.output_dim,))),
-                                           np.zeros(self.output_dim),
-                                           np.zeros(self.output_dim))),
-                                name='{}_b'.format(self.name))
-            self.trainable_weights = [self.W, self.U, self.b]
+            self.W = self.add_weight((self.input_dim, 4 * self.output_dim),
+                                     initializer=self.init,
+                                     name='{}_W'.format(self.name),
+                                     regularizer=self.W_regularizer)
+            self.U = self.add_weight((self.output_dim, 4 * self.output_dim),
+                                     initializer=self.inner_init,
+                                     name='{}_U'.format(self.name),
+                                     regularizer=self.U_regularizer)
+
+            def b_reg(shape, name=None):
+                return K.variable(np.hstack((np.zeros(self.output_dim),
+                                             K.get_value(self.forget_bias_init((self.output_dim,))),
+                                             np.zeros(self.output_dim),
+                                             np.zeros(self.output_dim))),
+                                  name='{}_b'.format(self.name))
+            self.b = self.add_weight((self.output_dim * 4,),
+                                     initializer=b_reg,
+                                     name='{}_b'.format(self.name),
+                                     regularizer=self.b_regularizer)
         else:
-            self.W_i = self.init((self.input_dim, self.output_dim),
-                                 name='{}_W_i'.format(self.name))
-            self.U_i = self.inner_init((self.output_dim, self.output_dim),
-                                       name='{}_U_i'.format(self.name))
-            self.b_i = K.zeros((self.output_dim,), name='{}_b_i'.format(self.name))
-
-            self.W_f = self.init((self.input_dim, self.output_dim),
-                                 name='{}_W_f'.format(self.name))
-            self.U_f = self.inner_init((self.output_dim, self.output_dim),
-                                       name='{}_U_f'.format(self.name))
-            self.b_f = self.forget_bias_init((self.output_dim,),
-                                             name='{}_b_f'.format(self.name))
-
-            self.W_c = self.init((self.input_dim, self.output_dim),
-                                 name='{}_W_c'.format(self.name))
-            self.U_c = self.inner_init((self.output_dim, self.output_dim),
-                                       name='{}_U_c'.format(self.name))
-            self.b_c = K.zeros((self.output_dim,), name='{}_b_c'.format(self.name))
-
-            self.W_o = self.init((self.input_dim, self.output_dim),
-                                 name='{}_W_o'.format(self.name))
-            self.U_o = self.inner_init((self.output_dim, self.output_dim),
-                                       name='{}_U_o'.format(self.name))
-            self.b_o = K.zeros((self.output_dim,), name='{}_b_o'.format(self.name))
+            self.W_i = self.add_weight((self.input_dim, self.output_dim),
+                                       initializer=self.init,
+                                       name='{}_W_i'.format(self.name),
+                                       regularizer=self.W_regularizer)
+            self.U_i = self.add_weight((self.output_dim, self.output_dim),
+                                       initializer=self.init,
+                                       name='{}_U_i'.format(self.name),
+                                       regularizer=self.W_regularizer)
+            self.b_i = self.add_weight((self.output_dim,),
+                                       initializer='zero',
+                                       name='{}_b_i'.format(self.name),
+                                       regularizer=self.b_regularizer)
+            self.W_f = self.add_weight((self.input_dim, self.output_dim),
+                                       initializer=self.init,
+                                       name='{}_W_f'.format(self.name),
+                                       regularizer=self.W_regularizer)
+            self.U_f = self.add_weight((self.output_dim, self.output_dim),
+                                       initializer=self.init,
+                                       name='{}_U_f'.format(self.name),
+                                       regularizer=self.W_regularizer)
+            self.b_f = self.add_weight((self.output_dim,),
+                                       initializer=self.forget_bias_init,
+                                       name='{}_b_f'.format(self.name),
+                                       regularizer=self.b_regularizer)
+            self.W_c = self.add_weight((self.input_dim, self.output_dim),
+                                       initializer=self.init,
+                                       name='{}_W_c'.format(self.name),
+                                       regularizer=self.W_regularizer)
+            self.U_c = self.add_weight((self.output_dim, self.output_dim),
+                                       initializer=self.init,
+                                       name='{}_U_c'.format(self.name),
+                                       regularizer=self.W_regularizer)
+            self.b_c = self.add_weight((self.output_dim,),
+                                       initializer='zero',
+                                       name='{}_b_c'.format(self.name),
+                                       regularizer=self.b_regularizer)
+            self.W_o = self.add_weight((self.input_dim, self.output_dim),
+                                       initializer=self.init,
+                                       name='{}_W_o'.format(self.name),
+                                       regularizer=self.W_regularizer)
+            self.U_o = self.add_weight((self.output_dim, self.output_dim),
+                                       initializer=self.init,
+                                       name='{}_U_o'.format(self.name),
+                                       regularizer=self.W_regularizer)
+            self.b_o = self.add_weight((self.output_dim,),
+                                       initializer='zero',
+                                       name='{}_b_o'.format(self.name),
+                                       regularizer=self.b_regularizer)
 
             self.trainable_weights = [self.W_i, self.U_i, self.b_i,
                                       self.W_c, self.U_c, self.b_c,
                                       self.W_f, self.U_f, self.b_f,
                                       self.W_o, self.U_o, self.b_o]
-
             self.W = K.concatenate([self.W_i, self.W_f, self.W_c, self.W_o])
             self.U = K.concatenate([self.U_i, self.U_f, self.U_c, self.U_o])
             self.b = K.concatenate([self.b_i, self.b_f, self.b_c, self.b_o])
 
-        self.regularizers = []
-        if self.W_regularizer:
-            self.W_regularizer.set_param(self.W)
-            self.regularizers.append(self.W_regularizer)
-        if self.U_regularizer:
-            self.U_regularizer.set_param(self.U)
-            self.regularizers.append(self.U_regularizer)
-        if self.b_regularizer:
-            self.b_regularizer.set_param(self.b)
-            self.regularizers.append(self.b_regularizer)
-
         if self.initial_weights is not None:
             self.set_weights(self.initial_weights)
             del self.initial_weights
diff --git a/keras/layers/wrappers.py b/keras/layers/wrappers.py
index ac48cc05224..32e28d118f0 100644
--- a/keras/layers/wrappers.py
+++ b/keras/layers/wrappers.py
@@ -17,7 +17,7 @@ def build(self, input_shape=None):
         self.trainable_weights = getattr(self.layer, 'trainable_weights', [])
         self.non_trainable_weights = getattr(self.layer, 'non_trainable_weights', [])
         self.updates = getattr(self.layer, 'updates', [])
-        self.regularizers = getattr(self.layer, 'regularizers', [])
+        self.losses = getattr(self.layer, 'losses', [])
         self.constraints = getattr(self.layer, 'constraints', {})
 
         # properly attribute the current layer to
@@ -130,6 +130,11 @@ def step(x, states):
             # (nb_samples, timesteps, ...)
             output_shape = self.get_output_shape_for(input_shape)
             y = K.reshape(y, (-1, input_length) + output_shape[2:])
+
+        # Apply activity regularizer if any:
+        if hasattr(self.layer, 'activity_regularizer') and self.layer.activity_regularizer is not None:
+            regularization_loss = self.layer.activity_regularizer(y)
+            self.add_loss(regularization_loss, X)
         return y
 
@@ -246,9 +251,9 @@ def updates(self):
         return []
 
     @property
-    def regularizers(self):
-        if hasattr(self.forward_layer, 'regularizers'):
-            return self.forward_layer.regularizers + self.backward_layer.regularizers
+    def losses(self):
+        if hasattr(self.forward_layer, 'losses'):
+            return self.forward_layer.losses + self.backward_layer.losses
         return []
 
     @property
diff --git a/keras/models.py b/keras/models.py
index c747afc4615..3373cedf2de 100644
--- a/keras/models.py
+++ b/keras/models.py
@@ -497,6 +497,13 @@ def state_updates(self):
     def get_updates_for(self, inputs):
         return self.model.get_updates_for(inputs)
 
+    @property
+    def losses(self):
+        return self.model.losses
+
+    def get_losses_for(self, inputs):
+        return self.model.get_losses_for(inputs)
+
     @property
     def regularizers(self):
         # support for legacy behavior
diff --git a/keras/regularizers.py b/keras/regularizers.py
index c6464dfd7b7..4829a11408d 100644
--- a/keras/regularizers.py
+++ b/keras/regularizers.py
@@ -1,22 +1,27 @@
 from __future__ import absolute_import
 from . import backend as K
 from .utils.generic_utils import get_from_module
+import warnings
 
 
 class Regularizer(object):
-    def set_param(self, p):
-        self.p = p
-
-    def set_layer(self, layer):
-        self.layer = layer
-
-    def __call__(self, loss):
-        return loss
+    def __call__(self, x):
+        return 0
 
     def get_config(self):
         return {'name': self.__class__.__name__}
 
+    def set_param(self, _):
+        warnings.warn('The `set_param` method on regularizers is deprecated. '
+                      'It no longer does anything, '
+                      'and it will be removed after 06/2017.')
+
+    def set_layer(self, _):
+        warnings.warn('The `set_layer` method on regularizers is deprecated. '
+                      'It no longer does anything, '
+                      'and it will be removed after 06/2017.')
+
 
 class EigenvalueRegularizer(Regularizer):
     '''This takes a constant that controls
@@ -28,71 +33,43 @@ class EigenvalueRegularizer(Regularizer):
     '''
     def __init__(self, k):
         self.k = k
-        self.uses_learning_phase = True
-
-    def set_param(self, p):
-        if hasattr(self, 'p'):
-            raise Exception('Regularizers cannot be reused. '
-                            'Instantiate one regularizer per layer.')
-        self.p = p
-
-    def __call__(self, loss):
-        power = 9  # number of iterations of the power method
-        W = self.p
-        if K.ndim(W) > 2:
-            raise Exception('Eigenvalue Decay regularizer '
-                            'is only available for dense '
-                            'and embedding layers.')
-        WW = K.dot(K.transpose(W), W)
-        dim1, dim2 = K.eval(K.shape(WW))  # number of neurons in the layer
-
-        # power method for approximating the dominant eigenvector:
-        o = K.ones([dim1, 1])  # initial values for the dominant eigenvector
-        main_eigenvect = K.dot(WW, o)
-        for n in range(power - 1):
-            main_eigenvect = K.dot(WW, main_eigenvect)
-        WWd = K.dot(WW, main_eigenvect)
+    def __call__(self, x):
+        if K.ndim(x) > 2:
+            raise Exception('EigenvalueRegularizer '
+                            'is only available for tensors of rank 2.')
+        covariance = K.dot(K.transpose(x), x)
+        dim1, dim2 = K.eval(K.shape(covariance))
+
+        # Power method for approximating the dominant eigenvector:
+        power = 9  # Number of iterations of the power method.
+        o = K.ones([dim1, 1])  # Initial values for the dominant eigenvector.
+        main_eigenvect = K.dot(covariance, o)
+        for n in range(power - 1):
+            main_eigenvect = K.dot(covariance, main_eigenvect)
+        covariance_d = K.dot(covariance, main_eigenvect)
 
-        # the corresponding dominant eigenvalue:
-        main_eigenval = (K.dot(K.transpose(WWd), main_eigenvect) /
-                         K.dot(K.transpose(main_eigenvect), main_eigenvect))
-        # multiplied by the given regularization gain
-        regularized_loss = loss + (main_eigenval ** 0.5) * self.k
+        # The corresponding dominant eigenvalue:
+        main_eigenval = (K.dot(K.transpose(covariance_d), main_eigenvect) /
+                         K.dot(K.transpose(main_eigenvect), main_eigenvect))
+        # Multiply by the given regularization gain.
+        regularization = (main_eigenval ** 0.5) * self.k
+        return K.sum(regularization)
 
-        return K.in_train_phase(regularized_loss[0, 0], loss)
 
-
-class WeightRegularizer(Regularizer):
+class L1L2Regularizer(Regularizer):
     def __init__(self, l1=0., l2=0.):
         self.l1 = K.cast_to_floatx(l1)
         self.l2 = K.cast_to_floatx(l2)
-        self.uses_learning_phase = True
-        self.p = None
-
-    def set_param(self, p):
-        if self.p is not None:
-            raise Exception('Regularizers cannot be reused. '
-                            'Instantiate one regularizer per layer.')
-        self.p = p
-
-    def __call__(self, loss):
-        if self.p is None:
-            raise Exception('Need to call `set_param` on '
-                            'WeightRegularizer instance '
-                            'before calling the instance. '
-                            'Check that you are not passing '
-                            'a WeightRegularizer instead of an '
-                            'ActivityRegularizer '
-                            '(i.e. activity_regularizer="l2" instead '
-                            'of activity_regularizer="activity_l2".')
-        regularized_loss = loss
+
+    def __call__(self, x):
+        regularization = 0
         if self.l1:
-            regularized_loss += K.sum(self.l1 * K.abs(self.p))
+            regularization += K.sum(self.l1 * K.abs(x))
         if self.l2:
-            regularized_loss += K.sum(self.l2 * K.square(self.p))
-        return K.in_train_phase(regularized_loss, loss)
+            regularization += K.sum(self.l2 * K.square(x))
+        return regularization
 
     def get_config(self):
         return {'name': self.__class__.__name__,
@@ -100,61 +77,34 @@ def get_config(self):
                 'l1': float(self.l1),
                 'l2': float(self.l2)}
 
 
-class ActivityRegularizer(Regularizer):
+# Aliases.
-    def __init__(self, l1=0., l2=0.):
-        self.l1 = K.cast_to_floatx(l1)
-        self.l2 = K.cast_to_floatx(l2)
-        self.uses_learning_phase = True
-        self.layer = None
-
-    def set_layer(self, layer):
-        if self.layer is not None:
-            raise Exception('Regularizers cannot be reused')
-        self.layer = layer
-
-    def __call__(self, loss):
-        if self.layer is None:
-            raise Exception('Need to call `set_layer` on '
-                            'ActivityRegularizer instance '
-                            'before calling the instance.')
-        regularized_loss = loss
-        for i in range(len(self.layer.inbound_nodes)):
-            output = self.layer.get_output_at(i)
-            if self.l1:
-                regularized_loss += K.sum(self.l1 * K.abs(output))
-            if self.l2:
-                regularized_loss += K.sum(self.l2 * K.square(output))
-        return K.in_train_phase(regularized_loss, loss)
-
-    def get_config(self):
-        return {'name': self.__class__.__name__,
-                'l1': float(self.l1),
-                'l2': float(self.l2)}
+WeightRegularizer = L1L2Regularizer
+ActivityRegularizer = L1L2Regularizer
 
 
 def l1(l=0.01):
-    return WeightRegularizer(l1=l)
+    return L1L2Regularizer(l1=l)
 
 
 def l2(l=0.01):
-    return WeightRegularizer(l2=l)
+    return L1L2Regularizer(l2=l)
 
 
 def l1l2(l1=0.01, l2=0.01):
-    return WeightRegularizer(l1=l1, l2=l2)
+    return L1L2Regularizer(l1=l1, l2=l2)
 
 
 def activity_l1(l=0.01):
-    return ActivityRegularizer(l1=l)
+    return L1L2Regularizer(l1=l)
 
 
 def activity_l2(l=0.01):
-    return ActivityRegularizer(l2=l)
+    return L1L2Regularizer(l2=l)
 
 
 def activity_l1l2(l1=0.01, l2=0.01):
-    return ActivityRegularizer(l1=l1, l2=l2)
+    return L1L2Regularizer(l1=l1, l2=l2)
 
 
 def get(identifier, kwargs=None):
diff --git a/tests/keras/layers/test_recurrent.py b/tests/keras/layers/test_recurrent.py
index f25eced4799..74eead78ed8 100644
--- a/tests/keras/layers/test_recurrent.py
+++ b/tests/keras/layers/test_recurrent.py
@@ -132,6 +132,12 @@ def test_regularizer(layer_class):
     layer.build(shape)
     output = layer(K.variable(np.ones(shape)))
     K.eval(output)
+    if layer_class == recurrent.SimpleRNN:
+        assert len(layer.losses) == 3
+    if layer_class == recurrent.GRU:
+        assert len(layer.losses) == 9
+    if layer_class == recurrent.LSTM:
+        assert len(layer.losses) == 12
 
 
 @keras_test
diff --git a/tests/keras/layers/test_wrappers.py b/tests/keras/layers/test_wrappers.py
index 27063e6608f..8c523e114bd 100644
--- a/tests/keras/layers/test_wrappers.py
+++ b/tests/keras/layers/test_wrappers.py
@@ -76,6 +76,15 @@ def test_TimeDistributed():
     outer_model.fit(np.random.random((10, 3, 2)), np.random.random((10, 3, 3)),
                     nb_epoch=1, batch_size=10)
 
+
+@keras_test
+def test_regularizers():
+    model = Sequential()
+    model.add(wrappers.TimeDistributed(core.Dense(2, W_regularizer='l1'), input_shape=(3, 4)))
+    model.add(core.Activation('relu'))
+    model.compile(optimizer='rmsprop', loss='mse')
+    assert len(model.losses) == 1
+
+
 @keras_test
 def test_Bidirectional():
     rnn = recurrent.SimpleRNN
diff --git a/tests/keras/test_regularizers.py b/tests/keras/test_regularizers.py
index b7d9a1abf7f..996236d7a47 100644
--- a/tests/keras/test_regularizers.py
+++ b/tests/keras/test_regularizers.py
@@ -67,6 +67,7 @@ def test_W_reg():
                 regularizers.l1l2()]:
         model = create_model(weight_reg=reg)
         model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
+        assert len(model.losses) == 1
         model.fit(X_train, Y_train, batch_size=batch_size,
                   nb_epoch=nb_epoch, verbose=0)
         model.evaluate(X_test[test_ids, :], Y_test[test_ids, :], verbose=0)
@@ -77,6 +78,7 @@ def test_A_reg():
     for reg in [regularizers.activity_l1(), regularizers.activity_l2()]:
         model = create_model(activity_reg=reg)
        model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
+        assert len(model.losses) == 1
        model.fit(X_train, Y_train, batch_size=batch_size,
                  nb_epoch=nb_epoch, verbose=0)
        model.evaluate(X_test[test_ids, :], Y_test[test_ids, :], verbose=0)
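
The central new API in this patch is Layer.add_weight, which folds variable creation,
regularization, constraint registration and trainable bookkeeping into one call. As a
review aid, here is a minimal sketch (not part of the patch) of a custom layer written
against it; the layer name MyDense is hypothetical, and it assumes the Keras 1.x layer
protocol (build / call / get_output_shape_for) that the rest of this diff targets.

    from keras import backend as K
    from keras.engine.topology import Layer

    class MyDense(Layer):
        # Hypothetical example layer, not part of this patch.
        def __init__(self, output_dim, **kwargs):
            self.output_dim = output_dim
            super(MyDense, self).__init__(**kwargs)

        def build(self, input_shape):
            # One call creates the variable and registers its regularizer,
            # constraint and trainable status with the layer.
            self.W = self.add_weight((input_shape[1], self.output_dim),
                                     initializer='glorot_uniform',
                                     name='{}_W'.format(self.name),
                                     regularizer=None,
                                     constraint=None)
            self.built = True

        def call(self, x, mask=None):
            return K.dot(x, self.W)

        def get_output_shape_for(self, input_shape):
            return (input_shape[0], self.output_dim)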
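
add_loss/get_losses_for mirror add_update/get_updates_for: both key their dictionaries
by object_list_uid(inputs) (a string of Python id()s), and both reserve the key None
for unconditional entries such as weight penalties. That split is what lets Container
collect only the losses belonging to nodes it owns, plus all unconditional ones. A
sketch of the bookkeeping in isolation, assuming the bare Layer base class can be
instantiated directly (its __init__ in this patch permits it):

    from keras import backend as K
    from keras.engine.topology import Layer

    layer = Layer()
    w = K.variable([1.0, -2.0])
    x = K.placeholder(shape=(None, 2))

    layer.add_loss(K.sum(K.abs(w)))        # unconditional: keyed by None
    layer.add_loss(K.sum(K.square(x)), x)  # conditional on the input x

    assert len(layer.get_losses_for(None)) == 1
    assert len(layer.get_losses_for(x)) == 1
    assert len(layer.losses) == 2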
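
BatchNormalization is the motivating case for the add_updates -> add_update rename and
for conditional bookkeeping: its moving-average updates depend on the input batch, so
they are registered against x. If the consumer-facing behavior holds as the patch
suggests, a mode-0 BN layer should contribute exactly two updates; a sketch under that
assumption:

    from keras.models import Sequential
    from keras.layers import Dense
    from keras.layers.normalization import BatchNormalization

    model = Sequential()
    model.add(Dense(4, input_dim=8))
    model.add(BatchNormalization(mode=0))
    model.compile(optimizer='sgd', loss='mse')

    # Two moving-average updates: running_mean and running_std.
    assert len(model.updates) == 2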
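
The underlying contract change is in Regularizer itself: __call__ used to wrap the
total loss (loss -> regularized loss) and is now a pure penalty function (tensor ->
scalar penalty), which is why compile() can simply sum model.losses. A custom
regularizer under the new protocol is a small sketch; SumAbsRegularizer below is
hypothetical and uses only backend calls that appear elsewhere in this diff:

    from keras import backend as K
    from keras.regularizers import Regularizer

    class SumAbsRegularizer(Regularizer):
        # Hypothetical L1-style penalty under the new protocol.
        def __init__(self, gain=0.01):
            self.gain = K.cast_to_floatx(gain)

        def __call__(self, x):
            # x is the weight (or output) tensor; return a scalar penalty.
            return self.gain * K.sum(K.abs(x))

        def get_config(self):
            return {'name': self.__class__.__name__, 'gain': float(self.gain)}

Because add_weight only requires the regularizer to be callable on the weight tensor,
such an object can be passed anywhere a W_regularizer or b_regularizer was accepted.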
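
From user code, the visible effect is that every penalty shows up as one tensor in
model.losses before compile() folds them into total_loss. Extrapolating slightly from
the updated tests (which assert one loss for a weight regularizer and one for an
activity regularizer separately), a layer carrying both should contribute two; a
sketch under that assumption:

    from keras.models import Sequential
    from keras.layers import Dense

    model = Sequential()
    model.add(Dense(10, input_dim=20,
                    W_regularizer='l2',
                    activity_regularizer='activity_l2'))
    model.compile(optimizer='sgd', loss='mse')

    # One unconditional weight penalty + one input-conditional
    # activity penalty.
    assert len(model.losses) == 2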
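
The rewritten EigenvalueRegularizer penalizes k times the square root of the dominant
eigenvalue of x^T x, i.e. k times the spectral norm (largest singular value) of x. Its
fixed 9-iteration power method can be sanity-checked against numpy independently of
any Keras backend; a self-contained sketch:

    import numpy as np

    def dominant_eigenvalue(W, power=9):
        # Same scheme as the patch: iterate v <- C v on the covariance
        # C = W^T W, then take the Rayleigh quotient as the eigenvalue.
        C = W.T.dot(W)
        v = np.ones((C.shape[0], 1))
        for _ in range(power):
            v = C.dot(v)
        Cv = C.dot(v)
        return float(v.T.dot(Cv) / v.T.dot(v))

    np.random.seed(1234)
    W = np.random.randn(50, 10)
    approx = np.sqrt(dominant_eigenvalue(W))
    exact = np.linalg.norm(W, 2)  # largest singular value of W
    assert abs(approx - exact) / exact < 0.05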