LSTM with Batch Normalization
udibr committed Apr 5, 2016
1 parent b587aee commit c8d2015
Showing 3 changed files with 151 additions and 19 deletions.
8 changes: 8 additions & 0 deletions keras/layers/containers.py
@@ -79,6 +79,14 @@ def trainable_weights(self):
                weights += l.get_params()[0]
        return weights

+    @property
+    def non_trainable_weights(self):
+        weights = []
+        for l in self.layers:
+            if l.trainable:
+                weights += l.non_trainable_weights
+        return weights
+
    @property
    def regularizers(self):
        regularizers = []
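The distinction matters because batch normalization's running statistics are updated by the training function's update ops rather than by the optimizer, yet they must still travel with the model when weights are saved or copied. A minimal sketch of the aggregation this property performs (hypothetical stand-in classes, not Keras code):

class FakeLayer(object):
    trainable = True
    trainable_weights = ['W', 'b', 'gamma', 'beta']
    non_trainable_weights = ['running_mean', 'running_std']

class FakeContainer(object):
    def __init__(self, layers):
        self.layers = layers

    @property
    def non_trainable_weights(self):
        # mirrors the property added above
        weights = []
        for l in self.layers:
            if l.trainable:
                weights += l.non_trainable_weights
        return weights

c = FakeContainer([FakeLayer(), FakeLayer()])
assert c.non_trainable_weights == ['running_mean', 'running_std'] * 2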
6 changes: 3 additions & 3 deletions keras/layers/core.py
@@ -623,7 +623,7 @@ def get_weights(self):

    def set_weights(self, weights):
        for i in range(len(self.layers)):
-            nb_param = len(self.layers[i].trainable_weights)
+            nb_param = len(self.layers[i].get_weights())
            self.layers[i].set_weights(weights[:nb_param])
            weights = weights[nb_param:]

@@ -1338,7 +1338,7 @@ def get_weights(self):
        return weights

    def set_weights(self, weights):
-        nb_param = len(self.encoder.trainable_weights)
+        nb_param = len(self.encoder.get_weights())
        self.encoder.set_weights(weights[:nb_param])
        self.decoder.set_weights(weights[nb_param:])

@@ -1642,7 +1642,7 @@ def get_weights(self):

    def set_weights(self, weights):
        for i in range(len(self.layers)):
-            nb_param = len(self.layers[i].trainable_weights) + len(self.non_trainable_weights)
+            nb_param = len(self.layers[i].get_weights())
            self.layers[i].set_weights(weights[:nb_param])
            weights = weights[nb_param:]

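All three set_weights changes fix the same slicing bug: once a layer owns non-trainable weights, get_weights() returns the trainable and non-trainable arrays together, so counting only trainable_weights would split the flat weight list at the wrong offsets and shift every following layer's weights. A sketch of the failure mode (hypothetical weight counts for a batch-norm LSTM followed by a plain Dense layer):

class FakeLayer(object):
    def __init__(self, n_trainable, n_non_trainable):
        self.trainable_weights = [0] * n_trainable
        self._weights = [0] * (n_trainable + n_non_trainable)

    def get_weights(self):
        return list(self._weights)

# e.g. a BN-LSTM (22 trainable + 18 running stats) and a Dense (2 + 0)
layers = [FakeLayer(22, 18), FakeLayer(2, 0)]
flat = sum((l.get_weights() for l in layers), [])

for l in layers:
    nb_param = len(l.get_weights())   # counts both kinds of weights
    chunk, flat = flat[:nb_param], flat[nb_param:]

assert flat == []  # slicing by len(trainable_weights) would misalign here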
156 changes: 140 additions & 16 deletions keras/layers/recurrent.py
@@ -168,6 +168,12 @@ def get_output(self, train=False):
        X = self.get_input(train)
        mask = self.get_input_mask(train)

+        # updates can be added inside self.step, which is run by K.rnn,
+        # and at the end of this method if the layer is stateful.
+        # However, get_output can be called multiple times, so we need to
+        # reset the list before starting to accumulate update expressions.
+        self.updates = []
+
        assert K.ndim(X) == 3
        if K._BACKEND == 'tensorflow':
            if not self.input_shape[1]:
@@ -193,9 +199,8 @@ def get_output(self, train=False):
                                             mask=mask,
                                             constants=constants)
        if self.stateful:
-            self.updates = []
-            for i in range(len(states)):
-                self.updates.append((self.states[i], states[i]))
+            self.updates += [(self.states[i], states[i]) for i in
+                             range(len(states))]

        if self.return_sequences:
            return outputs
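The reset at the top of get_output matters because the method can be invoked more than once on the same layer instance, for example once when building the train graph and once for the test graph; without it, the (variable, new_value) pairs would pile up across calls. An illustrative toy version of the pattern (not Keras code):

class TinyLayer(object):
    updates = None

    def get_output(self):
        self.updates = []                        # reset on every call
        # stand-in for the (state variable, update expression) pairs that
        # self.step and the stateful branch would append
        self.updates += [('state_0', 'new_state_0')]
        return 'output'

layer = TinyLayer()
layer.get_output()
layer.get_output()
assert len(layer.updates) == 1  # without the reset this would be 2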
@@ -510,7 +515,7 @@ def get_config(self):


class LSTM(Recurrent):
-    '''Long-Short Term Memory unit - Hochreiter 1997.
+    """Long-Short Term Memory unit - Hochreiter 1997.
    For a step-by-step description of the algorithm, see
    [this tutorial](http://deeplearning.net/tutorial/lstm.html).
@@ -536,19 +541,27 @@ class LSTM(Recurrent):
            applied to the bias.
        dropout_W: float between 0 and 1. Fraction of the input units to drop for input gates.
        dropout_U: float between 0 and 1. Fraction of the input units to drop for recurrent connections.
+        batch_norm: bool (default False)
+            Perform batch normalization as described in
+            [Recurrent Batch Normalization](http://arxiv.org/abs/1603.09025),
+            with the simplification of using the same BN parameters at all
+            time steps.
+        gamma_init: float (default 0.1)
+            initialization value for all gammas used in batch normalization;
+            the paper recommends a small value such as 0.1 to avoid vanishing
+            gradients through the tanh nonlinearity.

    # References
        - [Long short-term memory](http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf) (original 1997 paper)
        - [Learning to forget: Continual prediction with LSTM](http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015)
        - [Supervised sequence labelling with recurrent neural networks](http://www.cs.toronto.edu/~graves/preprint.pdf)
        - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
-    '''
+        - [Recurrent Batch Normalization](http://arxiv.org/abs/1603.09025)
+    """
    def __init__(self, output_dim,
                 init='glorot_uniform', inner_init='orthogonal',
                 forget_bias_init='one', activation='tanh',
                 inner_activation='hard_sigmoid',
                 W_regularizer=None, U_regularizer=None, b_regularizer=None,
-                 dropout_W=0., dropout_U=0., **kwargs):
+                 dropout_W=0., dropout_U=0.,
+                 batch_norm=False, gamma_init=0.1, **kwargs):
        self.output_dim = output_dim
        self.init = initializations.get(init)
        self.inner_init = initializations.get(inner_init)
@@ -559,6 +572,14 @@ def __init__(self, output_dim,
        self.U_regularizer = regularizers.get(U_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.dropout_W, self.dropout_U = dropout_W, dropout_U
+        self.batch_norm = batch_norm
+        if batch_norm:
+            # close over gamma_init so every gamma variable starts at
+            # the same constant
+            def gamma_init_func(shape, name=None, c=gamma_init):
+                return K.variable(np.ones(shape) * c, name=name)
+            self.gamma_init = gamma_init_func
+            self.beta_init = initializations.get('zero')
+            self.momentum = 0.9
+            self.epsilon = 1e-6
        super(LSTM, self).__init__(**kwargs)

def build(self):
@@ -626,6 +647,47 @@ def build(self):
            self.set_weights(self.initial_weights)
            del self.initial_weights

+        if self.batch_norm:
+            self.non_trainable_weights = []
+            shape = (self.output_dim,)
+            self.gammas = {}
+            self.betas = {}
+            self.running_mean = {}
+            self.running_std = {}
+            # BN is applied to 3 inputs/outputs ('fields') of the cell
+            for fld in ['recurrent', 'input', 'output']:
+                gammas = {}
+                betas = {}
+                running_mean = {}
+                running_std = {}
+                # each field affects 4 locations ('slices') inside the cell
+                # (except 'output'); each location has its own BN
+                for slc in ['i', 'f', 'c', 'o']:
+                    running_mean[slc] = K.zeros(shape,
+                                                name='{}_running_mean_{}_{}'.format(
+                                                    self.name, fld, slc))
+                    running_std[slc] = K.ones(shape,
+                                              name='{}_running_std_{}_{}'.format(
+                                                  self.name, fld, slc))
+                    gammas[slc] = self.gamma_init(shape,
+                                                  name='{}_gamma_{}_{}'.format(
+                                                      self.name, fld, slc))
+                    if fld == 'output':
+                        betas[slc] = self.beta_init(shape,
+                                                    name='{}_beta_{}_{}'.format(
+                                                        self.name, fld, slc))
+                        break  # output has just one slice
+
+                self.gammas[fld] = gammas
+                self.betas[fld] = betas
+                self.running_mean[fld] = running_mean
+                self.running_std[fld] = running_std
+
+                # list(...) keeps this working on Python 3, where
+                # dict.values() is a view
+                self.trainable_weights += (list(gammas.values()) +
+                                           list(betas.values()))
+                self.non_trainable_weights += (list(running_mean.values()) +
+                                               list(running_std.values()))
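
A pure-Python bookkeeping sketch of what the loops above create when batch_norm=True: the 'recurrent' and 'input' fields get one gamma (and one running mean/std pair) per gate slice i, f, c, o, while 'output' gets a single gamma/beta pair for normalizing the cell state:

gammas, betas, running = 0, 0, 0
for fld in ['recurrent', 'input', 'output']:
    for slc in ['i', 'f', 'c', 'o']:
        gammas += 1
        running += 2              # one running mean + one running std
        if fld == 'output':
            betas += 1
            break                 # output has just one slice
assert (gammas, betas, running) == (9, 1, 18)
# the 9 gammas + 1 beta are trainable; the 18 running stats are not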

    def reset_states(self):
        assert self.stateful, 'Layer must be stateful.'
        input_shape = self.input_shape
@@ -650,16 +712,59 @@ def preprocess_input(self, x, train=False):
        input_dim = input_shape[2]
        timesteps = input_shape[1]

-        x_i = time_distributed_dense(x, self.W_i, self.b_i, dropout,
+        # bias is added inside step (after doing BN)
+        x_i = time_distributed_dense(x, self.W_i, None, dropout,
                                     input_dim, self.output_dim, timesteps)
-        x_f = time_distributed_dense(x, self.W_f, self.b_f, dropout,
+        x_f = time_distributed_dense(x, self.W_f, None, dropout,
                                     input_dim, self.output_dim, timesteps)
-        x_c = time_distributed_dense(x, self.W_c, self.b_c, dropout,
+        x_c = time_distributed_dense(x, self.W_c, None, dropout,
                                     input_dim, self.output_dim, timesteps)
-        x_o = time_distributed_dense(x, self.W_o, self.b_o, dropout,
+        x_o = time_distributed_dense(x, self.W_o, None, dropout,
                                     input_dim, self.output_dim, timesteps)
        return K.concatenate([x_i, x_f, x_c, x_o], axis=2)

+    def bn(self, X, fld, slc='i'):
+        if not self.batch_norm:
+            return X
+        gamma = self.gammas[fld][slc]
+        # recurrent and input fields don't have beta
+        beta = self.betas[fld].get(slc)
+        axis = -1  # axis along which to normalize (we have mode=0)
+
+        input_shape = (self.input_shape[0], self.output_dim)
+        reduction_axes = list(range(len(input_shape)))
+        del reduction_axes[axis]
+        broadcast_shape = [1] * len(input_shape)
+        broadcast_shape[axis] = input_shape[axis]
+        if self.train:
+            m = K.mean(X, axis=reduction_axes)
+            broadcast_m = K.reshape(m, broadcast_shape)
+            std = K.mean(K.square(X - broadcast_m) + self.epsilon,
+                         axis=reduction_axes)
+            std = K.sqrt(std)
+            broadcast_std = K.reshape(std, broadcast_shape)
+            mean_update = (self.momentum * self.running_mean[fld][slc] +
+                           (1 - self.momentum) * m)
+            std_update = (self.momentum * self.running_std[fld][slc] +
+                          (1 - self.momentum) * std)
+            if not hasattr(self, 'updates'):
+                self.updates = []
+            self.updates += [(self.running_mean[fld][slc], mean_update),
+                             (self.running_std[fld][slc], std_update)]
+
+            X_normed = (X - broadcast_m) / (broadcast_std + self.epsilon)
+        else:
+            broadcast_m = K.reshape(self.running_mean[fld][slc],
+                                    broadcast_shape)
+            broadcast_std = K.reshape(self.running_std[fld][slc],
+                                      broadcast_shape)
+            X_normed = ((X - broadcast_m) /
+                        (broadcast_std + self.epsilon))
+
+        out = K.reshape(gamma, broadcast_shape) * X_normed
+        if beta is not None:
+            out += K.reshape(beta, broadcast_shape)
+        return out
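
For reference, a minimal NumPy transcription of the two branches above (a hypothetical helper, not part of the commit; expects NumPy arrays): at train time it normalizes with batch statistics and folds them into exponential moving averages, at test time it normalizes with the stored averages:

import numpy as np

def bn_sketch(X, gamma, beta, running_mean, running_std,
              train=True, momentum=0.9, epsilon=1e-6):
    if train:
        m = X.mean(axis=0)
        std = np.sqrt(((X - m) ** 2).mean(axis=0) + epsilon)
        # exponential moving averages, updated in place as a side effect
        running_mean[:] = momentum * running_mean + (1 - momentum) * m
        running_std[:] = momentum * running_std + (1 - momentum) * std
    else:
        m, std = running_mean, running_std
    out = gamma * (X - m) / (std + epsilon)
    return out if beta is None else out + beta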

    def step(self, x, states):
        h_tm1 = states[0]
        c_tm1 = states[1]
@@ -673,12 +778,24 @@ def step(self, x, states):
        x_c = x[:, 2 * self.output_dim: 3 * self.output_dim]
        x_o = x[:, 3 * self.output_dim:]

-        i = self.inner_activation(x_i + K.dot(h_tm1 * B_U[0], self.U_i))
-        f = self.inner_activation(x_f + K.dot(h_tm1 * B_U[1], self.U_f))
-        c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1 * B_U[2], self.U_c))
-        o = self.inner_activation(x_o + K.dot(h_tm1 * B_U[3], self.U_o))
-
-        h = o * self.activation(c)
+        i = self.inner_activation(
+            self.bn(x_i, 'input', 'i') +
+            self.bn(K.dot(h_tm1 * B_U[0], self.U_i), 'recurrent', 'i') +
+            self.b_i)
+        f = self.inner_activation(
+            self.bn(x_f, 'input', 'f') +
+            self.bn(K.dot(h_tm1 * B_U[1], self.U_f), 'recurrent', 'f') +
+            self.b_f)
+        c = f * c_tm1 + i * self.activation(
+            self.bn(x_c, 'input', 'c') +
+            self.bn(K.dot(h_tm1 * B_U[2], self.U_c), 'recurrent', 'c') +
+            self.b_c)
+        o = self.inner_activation(
+            self.bn(x_o, 'input', 'o') +
+            self.bn(K.dot(h_tm1 * B_U[3], self.U_o), 'recurrent', 'o') +
+            self.b_o)
+
+        h = o * self.activation(self.bn(c, 'output'))
        return h, [h, c]
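
Written out, step() now computes the batch-normalized LSTM of Cooijmans et al., with BN applied separately to the input-to-hidden and hidden-to-hidden pre-activations, the bias added afterwards, and a learned beta only on the cell-output normalization (dropout masks B_U omitted for clarity; sigma and tanh stand for the default inner_activation and activation):

\begin{aligned}
i_t &= \sigma\big(\mathrm{BN}(W_i x_t) + \mathrm{BN}(U_i h_{t-1}) + b_i\big) \\
f_t &= \sigma\big(\mathrm{BN}(W_f x_t) + \mathrm{BN}(U_f h_{t-1}) + b_f\big) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tanh\big(\mathrm{BN}(W_c x_t) + \mathrm{BN}(U_c h_{t-1}) + b_c\big) \\
o_t &= \sigma\big(\mathrm{BN}(W_o x_t) + \mathrm{BN}(U_o h_{t-1}) + b_o\big) \\
h_t &= o_t \odot \tanh\big(\mathrm{BN}(c_t)\big)
\end{aligned}

Each BN term has its own gamma (and, for BN(c_t) only, a beta), shared across time steps.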

    def get_constants(self, x, train=False):
@@ -701,5 +818,12 @@ def get_config(self):
"b_regularizer": self.b_regularizer.get_config() if self.b_regularizer else None,
"dropout_W": self.dropout_W,
"dropout_U": self.dropout_U}
if self.batch_norm:
config["momentum"] = self.momentum
base_config = super(LSTM, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

+    def get_output(self, train=False):
+        if self.batch_norm:
+            # remember the phase so bn() can choose between batch
+            # statistics (train) and running averages (test)
+            self.train = train
+        return super(LSTM, self).get_output(train=train)
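
Finally, a minimal usage sketch (assuming the Keras 0.3-era Sequential API of this repository; the shapes and hyperparameters are illustrative):

from keras.models import Sequential
from keras.layers.core import Dense
from keras.layers.recurrent import LSTM

model = Sequential()
# batch_norm=True switches on the per-gate normalization added in this
# commit; gamma_init=0.1 is the paper's recommended default
model.add(LSTM(64, input_shape=(20, 32), batch_norm=True, gamma_init=0.1))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy')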
