Refactor regularizers and add add_weight method. (#4703)
* Refactor regularizers, introduce layer.add_weight

* Fix BN add_update syntax

* Fix eigenvalue regularizer

* Style fixes.
fchollet committed Dec 14, 2016
1 parent 2b33675 commit ff62eb2
Showing 15 changed files with 524 additions and 539 deletions.
2 changes: 1 addition & 1 deletion keras/backend/theano_backend.py
@@ -57,7 +57,7 @@ def to_dense(tensor):
 
 
 def variable(value, dtype=_FLOATX, name=None):
-    '''Instantiate a tensor variable.
+    '''Instantiates a variable.
     '''
     if hasattr(value, 'tocoo'):
         _assert_sparse_module()
180 changes: 152 additions & 28 deletions keras/engine/topology.py
@@ -13,6 +13,7 @@
 from six.moves import zip
 
 from .. import backend as K
+from .. import initializations
 from ..utils.io_utils import ask_to_proceed_with_overwrite
 from ..utils.generic_utils import func_dump, func_load

@@ -28,6 +29,11 @@ def to_list(x):
     return [x]
 
 
+def object_list_uid(object_list):
+    object_list = to_list(object_list)
+    return ', '.join([str(abs(id(x))) for x in object_list])
+
+
 class InputSpec(object):
     '''This specifies the ndim, dtype and shape of every input to a layer.
     Every layer should expose (if appropriate) an `input_spec` attribute:
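For illustration, a minimal sketch of what the new `object_list_uid` helper returns (a hedged example: the printed ids are process-dependent, and the import assumes the post-commit module layout):

```python
from keras.engine.topology import object_list_uid

x = object()
# Non-list inputs are wrapped via to_list, so these two calls agree:
assert object_list_uid(x) == object_list_uid([x])
print(object_list_uid([x, x]))  # e.g. '140318264, 140318264'
```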
@@ -239,7 +245,6 @@ class Layer(object):
         non_trainable_weights: List of variables.
         weights: The concatenation of the lists trainable_weights and
             non_trainable_weights (in this order).
-        regularizers: List of regularizers.
         constraints: Dict mapping weights to constraints.
 
     # Methods
@@ -294,8 +299,8 @@ def __init__(self, **kwargs):
             self.trainable_weights = []
         if not hasattr(self, 'non_trainable_weights'):
             self.non_trainable_weights = []
-        if not hasattr(self, 'regularizers'):
-            self.regularizers = []
+        if not hasattr(self, 'losses'):
+            self.losses = []
         if not hasattr(self, 'constraints'):
             self.constraints = {}  # dict {tensor: constraint instance}
         self.built = False
@@ -354,6 +359,19 @@ def non_trainable_weights(self):
     def non_trainable_weights(self, weights):
         self._non_trainable_weights = weights
 
+    @property
+    def regularizers(self):
+        warnings.warn('The `regularizers` property of layers/models is deprecated. '
+                      'Regularization losses are now managed via the `losses` '
+                      'layer/model property.')
+        return []
+
+    @regularizers.setter
+    def regularizers(self, _):
+        warnings.warn('The `regularizers` property of layers/models is deprecated. '
+                      'Regularization losses are now managed via the `losses` '
+                      'layer/model property.')
+
     def create_input_layer(self, batch_input_shape,
                            input_dtype=None, name=None):
         if not name:
@@ -373,6 +391,32 @@ def create_input_layer(self, batch_input_shape,
         # to the input layer we just created.
         self(x)
 
+    def add_weight(self, shape, initializer, name=None,
+                   trainable=True,
+                   regularizer=None,
+                   constraint=None):
+        '''Adds a weight variable to the layer.
+
+        # Arguments:
+            shape: The shape tuple of the weight.
+            initializer: An Initializer instance (callable).
+            trainable: A boolean, whether the weight should
+                be trained via backprop or not (assuming
+                that the layer itself is also trainable).
+            regularizer: An optional Regularizer instance.
+        '''
+        initializer = initializations.get(initializer)
+        weight = initializer(shape, name=name)
+        if regularizer is not None:
+            self.add_loss(regularizer(weight))
+        if constraint is not None:
+            self.constraints[weight] = constraint
+        if trainable:
+            self.trainable_weights.append(weight)
+        else:
+            self.non_trainable_weights.append(weight)
+        return weight
+
     def assert_input_compatibility(self, input):
         '''This checks that the tensor(s) `input`
         verify the input assumptions of the layer
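For context, a minimal sketch of how a custom layer can use the new `add_weight` API (the `SimpleDense` layer and its parameters are illustrative only, assuming the Keras 1.x custom-layer conventions of `build`/`call`/`get_output_shape_for`):

```python
from keras import backend as K
from keras.engine.topology import Layer
from keras.regularizers import l2

class SimpleDense(Layer):
    '''Hypothetical dense layer illustrating `add_weight`.'''
    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(SimpleDense, self).__init__(**kwargs)

    def build(self, input_shape):
        # One call creates the variable, registers its L2 penalty via
        # add_loss, and tracks the variable as a trainable weight.
        self.W = self.add_weight((input_shape[1], self.output_dim),
                                 initializer='glorot_uniform',
                                 name='{}_W'.format(self.name),
                                 regularizer=l2(1e-4),
                                 trainable=True)
        self.built = True

    def call(self, x, mask=None):
        return K.dot(x, self.W)

    def get_output_shape_for(self, input_shape):
        return (input_shape[0], self.output_dim)
```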
Expand Down Expand Up @@ -519,15 +563,21 @@ def __call__(self, x, mask=None):
self.add_inbound_node(inbound_layers, node_indices, tensor_indices)
# Outputs were already computed when calling self.add_inbound_node.
outputs = self.inbound_nodes[-1].output_tensors
# If single output tensor: return it,
# else return a list (at least 2 elements).
if len(outputs) == 1:
return outputs[0]
else:
return outputs
else:
# This case appears if the input was not a Keras tensor.
return self.call(x, mask)
outputs = to_list(self.call(x, mask))

# Apply activity regularizer if any:
if hasattr(self, 'activity_regularizer') and self.activity_regularizer is not None:
regularization_losses = [self.activity_regularizer(x) for x in outputs]
self.add_loss(regularization_losses, input_tensors)

# If single output tensor: return it,
# else return a list (at least 2 elements).
if len(outputs) == 1:
return outputs[0]
else:
return outputs

def add_inbound_node(self, inbound_layers,
node_indices=None, tensor_indices=None):
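To sketch the effect of this change (assuming the `Input`, `Dense`, and `activity_l2` APIs of Keras 1.x at the time): an activity regularization penalty is now recorded by `__call__` as a loss conditional on the layer's input tensors, rather than wrapped around the total loss at compile time.

```python
from keras.layers import Input, Dense
from keras.regularizers import activity_l2

x = Input(shape=(32,))
layer = Dense(16, activity_regularizer=activity_l2(0.01))
y = layer(x)  # __call__ runs add_loss(regularization_losses, input_tensors)

# The penalty is keyed to this specific input tensor:
print(layer.get_losses_for([x]))  # -> [a scalar loss tensor]
print(layer.losses)               # -> the same tensor in the flat list
```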
@@ -806,33 +856,78 @@ def output_shape(self):
                             'ill-defined for the layer. ' +
                             'Use `get_output_shape_at(node_index)` instead.')
 
-    def add_updates(self, updates, inputs):
+    def add_loss(self, losses, inputs=None):
+        if losses is None:
+            return
+        # Update self.losses
+        losses = to_list(losses)
+        if not hasattr(self, 'losses'):
+            self.losses = []
+        try:
+            self.losses += losses
+        except AttributeError:
+            # In case self.losses isn't settable
+            # (i.e. it's a getter method).
+            # In that case the `losses` property is
+            # auto-computed and shouldn't be set.
+            pass
+        # Update self._per_input_losses
+        if not hasattr(self, '_per_input_losses'):
+            self._per_input_losses = {}
+        if inputs is not None:
+            inputs_hash = object_list_uid(inputs)
+        else:
+            # Losses indexed by None are unconditional
+            # rather than input-dependent
+            inputs_hash = None
+        if inputs_hash not in self._per_input_losses:
+            self._per_input_losses[inputs_hash] = []
+        self._per_input_losses[inputs_hash] += losses
+
+    def add_update(self, updates, inputs=None):
+        if updates is None:
+            return
         # Update self.updates
+        updates = to_list(updates)
         if not hasattr(self, 'updates'):
             self.updates = []
         try:
             self.updates += updates
         except AttributeError:
             # In case self.updates isn't settable
             # (i.e. it's a getter method).
             # In that case the `updates` property is
             # auto-computed and shouldn't be set.
             pass
         # Update self._per_input_updates
         if not hasattr(self, '_per_input_updates'):
             self._per_input_updates = {}
-        inputs = to_list(inputs)
-        updates = to_list(updates)
-        inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
+        if inputs is not None:
+            inputs_hash = object_list_uid(inputs)
+        else:
+            # Updates indexed by None are unconditional
+            # rather than input-dependent
+            inputs_hash = None
         if inputs_hash not in self._per_input_updates:
             self._per_input_updates[inputs_hash] = []
         self._per_input_updates[inputs_hash] += updates
 
     def get_updates_for(self, inputs):
         if not hasattr(self, '_per_input_updates'):
             return []
-        inputs = to_list(inputs)
-        inputs_hash = ', '.join([str(abs(id(x))) for x in inputs])
+        if inputs is not None:
+            inputs_hash = object_list_uid(inputs)
+        else:
+            inputs_hash = None
         if inputs_hash in self._per_input_updates:
             return self._per_input_updates[inputs_hash]
         return []
 
+    def get_losses_for(self, inputs):
+        if not hasattr(self, '_per_input_losses'):
+            return []
+        if inputs is not None:
+            inputs_hash = object_list_uid(inputs)
+        else:
+            inputs_hash = None
+        if inputs_hash in self._per_input_losses:
+            return self._per_input_losses[inputs_hash]
+        return []
+
     @property
     def weights(self):
         return self.trainable_weights + self.non_trainable_weights
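A hedged sketch of the bookkeeping these methods implement (the bare `Layer` and the variables are placeholders, not from the diff): losses registered with `inputs` are retrievable only for those exact tensors, while losses registered without inputs live under the `None` key as unconditional.

```python
from keras import backend as K
from keras.engine.topology import Layer

layer = Layer()
x = K.placeholder(shape=(None, 4))

layer.add_loss(K.variable(0.5), inputs=[x])  # input-conditional
layer.add_loss(K.variable(1.0))              # unconditional (inputs=None)

print(layer.get_losses_for([x]))   # -> [the 0.5 variable]
print(layer.get_losses_for(None))  # -> [the 1.0 variable]
print(len(layer.losses))           # -> 2; both appear in the flat list
```

`add_update`/`get_updates_for` follow the same pattern with `self._per_input_updates`.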
@@ -950,7 +1045,6 @@ def __init__(self, input_shape=None, batch_input_shape=None,
 
         self.trainable_weights = []
         self.non_trainable_weights = []
-        self.regularizers = []
         self.constraints = {}
 
         self.sparse = sparse
@@ -1151,7 +1245,6 @@ def __init__(self, layers=None, mode='sum', concat_axis=-1,
         self.inbound_nodes = []
         self.outbound_nodes = []
         self.constraints = {}
-        self.regularizers = []
         self.trainable_weights = []
         self.non_trainable_weights = []
         self.supports_masking = True
@@ -1587,7 +1680,6 @@ class Container(Layer):
         supports_masking (boolean)
         trainable_weights (list of variables)
         non_trainable_weights (list of variables)
-        regularizers (list of regularizers)
         constraints (list of tuples (weight, constraint))
 
     # Methods
@@ -1901,7 +1993,6 @@ def build_map_of_graph(tensor, seen_nodes=set(), depth=0,
         self.supports_masking = False
         # The following are implemented as property functions:
         # self.constraints
-        # self.regularizers
         # self.trainable_weights
         # self.non_trainable_weights
         # self.input_spec
@@ -1946,14 +2037,38 @@ def updates(self):
                 if len(layer.inbound_nodes) == 1:
                     updates += layer.updates
                 else:
+                    # Collect updates that are dependent on inputs
+                    # that are part of the model.
                     for node_index, node in enumerate(layer.inbound_nodes):
                         node_key = layer.name + '_ib-' + str(node_index)
                         if node_key in self.container_nodes:
                             # The model owns this layer node.
                             inputs = node.input_tensors
                             updates += layer.get_updates_for(inputs)
+                # Collect unconditional updates.
+                updates += layer.get_updates_for(None)
         return updates
 
+    @property
+    def losses(self):
+        losses = []
+        for layer in self.layers:
+            if hasattr(layer, 'losses'):
+                if len(layer.inbound_nodes) == 1:
+                    losses += layer.losses
+                else:
+                    # Collect losses that are dependent on inputs
+                    # that are part of the model.
+                    for node_index, node in enumerate(layer.inbound_nodes):
+                        node_key = layer.name + '_ib-' + str(node_index)
+                        if node_key in self.container_nodes:
+                            # The model owns this layer node.
+                            inputs = node.input_tensors
+                            losses += layer.get_losses_for(inputs)
+                # Collect unconditional losses.
+                losses += layer.get_losses_for(None)
+        return losses
+
     @property
     def stateful(self):
         return any([(hasattr(layer, 'stateful') and layer.stateful) for layer in self.layers])
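As an illustrative sketch (the model and layer here are assumed, not part of the diff): a `Model` now surfaces every owned layer's registered penalties through this `losses` property, which `compile` consumes below.

```python
from keras.layers import Input, Dense
from keras.models import Model
from keras.regularizers import l2

x = Input(shape=(8,))
y = Dense(4, W_regularizer=l2(0.01))(x)
model = Model(input=x, output=y)

# Dense registered its weight penalty via add_weight -> add_loss;
# the Container property gathers it across all owned layer nodes.
print(model.losses)  # -> [a scalar tensor for the L2 penalty]
```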
@@ -1990,10 +2105,13 @@ def constraints(self):
 
     @property
     def regularizers(self):
-        regs = []
-        for layer in self.layers:
-            regs += layer.regularizers
-        return regs
+        warnings.warn('The `regularizers` attribute of layers/models '
+                      'is deprecated. '
+                      'Regularization losses are now managed via the `losses` '
+                      'layer/model property.\n'
+                      'The `regularizers` attribute will be removed '
+                      'after 06/2017.')
+        return []
 
     @property
     def trainable_weights(self):
@@ -2061,8 +2179,7 @@ def uses_learning_phase(self):
         '''True if any layer in the graph uses it.
         '''
         layers_learning_phase = any([layer.uses_learning_phase for layer in self.layers])
-        regs_learning_phase = any([reg.uses_learning_phase for reg in self.regularizers])
-        return layers_learning_phase or regs_learning_phase
+        return layers_learning_phase
 
     def call(self, input, mask=None):
         '''`call` just reapplies all ops in the graph to the new inputs
@@ -2239,9 +2356,16 @@ def run_internal_graph(self, inputs, masks=None):
                     output_tensors = to_list(layer.call(computed_tensors, computed_masks))
                     output_masks = to_list(layer.compute_mask(computed_tensors, computed_masks))
 
-                    # update model updates
+                    # Update model updates and losses:
                     layer_inputs = [x[0] for x in computed_data]
-                    self.add_updates(layer.get_updates_for(layer_inputs), inputs)
+                    # Keep track of updates that depend on the inputs (e.g. BN updates).
+                    self.add_update(layer.get_updates_for(layer_inputs), inputs)
+                    # Keep track of unconditional updates (e.g. a counter).
+                    self.add_update(layer.get_updates_for(None), None)
+                    # Keep track of losses that depend on the inputs (e.g. activity regularizers).
+                    self.add_loss(layer.get_losses_for(layer_inputs), inputs)
+                    # Keep track of unconditional losses (e.g. weight regularizers).
+                    self.add_loss(layer.get_losses_for(None), None)
 
                     # Update _keras_shape.
                     if all([hasattr(x, '_keras_shape') for x in computed_tensors]):
7 changes: 4 additions & 3 deletions keras/engine/training.py
@@ -611,9 +611,10 @@ def compile(self, optimizer, loss, metrics=[], loss_weights=None,
             else:
                 total_loss += loss_weight * output_loss
 
-        # add regularization penalties to the loss
-        for r in self.regularizers:
-            total_loss = r(total_loss)
+        # add regularization penalties
+        # and other layer-specific losses
+        for loss_tensor in self.losses:
+            total_loss += loss_tensor
 
         # list of same size as output_names.
         # contains tuples (metrics for output, names of metrics)
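In short, regularizers used to be callables that wrapped and returned a transformed total loss, whereas each entry in `model.losses` is now already a scalar tensor, so penalties combine by plain addition. A schematic sketch with stand-in variables (not the full `compile` body):

```python
from keras import backend as K

total_loss = K.variable(1.25)                     # stand-in: weighted output loss
penalties = [K.variable(0.01), K.variable(0.02)]  # stand-in: model.losses
for loss_tensor in penalties:
    total_loss = total_loss + loss_tensor  # old API: total_loss = r(total_loss)
```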