Skip to content

Commit

Permalink
pip8
Browse files Browse the repository at this point in the history
  • Loading branch information
udibr committed Apr 4, 2016
1 parent e8ce471 commit 12d69a2
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions keras/layers/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,9 @@ def get_config(self):
class LSTMBN(LSTM):
'''Long-Short Term Memory unit with Batch Normalization
This implementation uses the same BN parameters for all steps.
The paper recomended using different parameters for each training step.
# Arguments
LSTM arguments + batch normalization parameters:
epsilon: small float > 0. Fuzz parameter.
Expand All @@ -731,6 +734,7 @@ def __init__(self, output_dim,
**kwargs):
if isinstance(gamma_init, float):
c = gamma_init

def gamma_init(shape, name=None, c=c):
return K.variable(np.ones(shape)*c, name=name)
self.gamma_init = initializations.get(gamma_init)
Expand Down Expand Up @@ -810,13 +814,14 @@ def bn(self, X, fld, slc='i'):
axis=reduction_axes)
std = K.sqrt(std)
brodcast_std = K.reshape(std, broadcast_shape)
mean_update = self.momentum * self.running_mean[fld][slc] + (
1 - self.momentum) * m
std_update = self.momentum * self.running_std[fld][slc] + (
1 - self.momentum) * std
mean_update = (self.momentum * self.running_mean[fld][slc] +
(1 - self.momentum) * m)
std_update = (self.momentum * self.running_std[fld][slc] +
(1 - self.momentum) * std)
self.extend_attr('updates',
[(self.running_mean[fld][slc], mean_update),
(self.running_std[fld][slc], std_update)])
[(self.running_mean[fld][slc], mean_update),
(self.running_std[fld][slc], std_update)
])

X_normed = (X - brodcast_m) / (brodcast_std + self.epsilon)
else:
Expand Down

0 comments on commit 12d69a2

Please sign in to comment.