Commit

Fix behaviour of unfrozen BatchNormalization layer (#47)
Previously, if `BatchNormalization` had been initialized with `BatchNormalization(freeze=False)`, its behaviour was not equivalent to that of the standard `BatchNormalization` layer, as one would expect. Instead, the layer was always forced into training mode, which produced wrong validation results.

This PR does not change the behaviour for `freeze=True`, but makes the layer equivalent to the standard `BatchNormalization` layer from Keras for `freeze=False`.
Callidior authored and 0x00b1 committed Nov 28, 2018
1 parent 55828cf commit 7e2e67b
Showing 1 changed file with 4 additions and 2 deletions.
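For context on the bug: with `training=True`, `BatchNormalization` normalizes each batch with that batch's own statistics instead of the tracked moving averages, so a layer stuck in training mode reports misleading numbers from `evaluate()`/`predict()`. A minimal sketch of the difference, assuming `tf.keras` 2.x (the toy model and data here are illustrative, not part of this repository):

import numpy as np
import tensorflow as tf

# Toy data that is clearly not zero-mean/unit-variance.
x = np.random.normal(loc=5.0, scale=2.0, size=(32, 8)).astype("float32")

inputs = tf.keras.Input(shape=(8,))
outputs = tf.keras.layers.BatchNormalization()(inputs)
model = tf.keras.Model(inputs, outputs)

# Inference mode: normalizes with the moving mean/variance (initially 0/1),
# so the output is still roughly the raw input.
y_infer = model(x, training=False)

# Training mode: normalizes with the *batch* statistics, giving a roughly
# zero-mean/unit-variance output regardless of the moving averages.
y_train = model(x, training=True)

print(float(tf.reduce_mean(y_infer)), float(tf.reduce_mean(y_train)))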
6 changes: 4 additions & 2 deletions keras_resnet/layers/_batch_normalization.py
@@ -13,8 +13,10 @@ def __init__(self, freeze, *args, **kwargs):
         self.trainable = not self.freeze
 
     def call(self, *args, **kwargs):
-        # return super.call, but set training
-        return super(BatchNormalization, self).call(training=(not self.freeze), *args, **kwargs)
+        # Force test mode if frozen, otherwise use default behaviour (i.e., training=None).
+        if self.freeze:
+            kwargs['training'] = False
+        return super(BatchNormalization, self).call(*args, **kwargs)
 
     def get_config(self):
         config = super(BatchNormalization, self).get_config()
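With this change, using the layer looks as follows (a sketch; it assumes `keras_resnet.layers` re-exports `BatchNormalization` from `keras_resnet/layers/_batch_normalization.py`):

from keras_resnet.layers import BatchNormalization

# Unfrozen: now equivalent to the stock keras.layers.BatchNormalization.
# Training mode is decided per call (fit() runs it in training mode,
# evaluate()/predict() in inference mode).
bn = BatchNormalization(freeze=False)

# Frozen: always runs in inference mode and keeps its weights non-trainable,
# e.g. when fine-tuning on top of a pretrained backbone.
bn_frozen = BatchNormalization(freeze=True)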
