Skip to content

Commit

Permalink
Merge pull request #1055 from dmitriy-serdyuk/lstm-activations
Browse files Browse the repository at this point in the history
Unify LSTM arguments
  • Loading branch information
rizar committed Apr 7, 2016
2 parents 063dcdb + 231cf0f commit e0b2be4
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions blocks/bricks/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,19 +364,28 @@ class LSTM(BaseRecurrent, Initializable):
activation : :class:`.Brick`, optional
The activation function. The default and by far the most popular
is :class:`.Tanh`.
gate_activation : :class:`.Brick` or None
The brick to apply as activation for gates (input/output/forget).
If ``None``, a :class:`.Logistic` brick is used.
Notes
-----
See :class:`.Initializable` for initialization parameters.
"""
@lazy(allocation=['dim'])
def __init__(self, dim, activation=None, gate_activation=None, **kwargs):
    """Initialize the LSTM brick.

    Parameters
    ----------
    dim : int
        The dimension of the hidden state.
    activation : :class:`.Brick`, optional
        The activation applied to the cell input and the cell output.
        If ``None``, a :class:`.Tanh` brick is used.
    gate_activation : :class:`.Brick`, optional
        The activation for the input/output/forget gates.
        If ``None``, a :class:`.Logistic` brick is used.

    """
    self.dim = dim

    # `not activation` treats both None and falsy bricks the same way;
    # this mirrors the existing convention in this module.
    if not activation:
        activation = Tanh()
    if not gate_activation:
        gate_activation = Logistic()
    self.activation = activation
    self.gate_activation = gate_activation

    # Register both activation bricks as children so that their
    # allocation/initialization is managed by this brick, while still
    # honouring any children the caller passed explicitly.
    children = ([self.activation, self.gate_activation] +
                kwargs.get('children', []))
    super(LSTM, self).__init__(children=children, **kwargs)

def get_dim(self, name):
Expand Down Expand Up @@ -457,18 +466,17 @@ def apply(self, inputs, states, cells, mask=None):
def slice_last(x, no):
return x[:, no*self.dim: (no+1)*self.dim]

nonlinearity = self.children[0].apply

activation = tensor.dot(states, self.W_state) + inputs
in_gate = tensor.nnet.sigmoid(slice_last(activation, 0) +
cells * self.W_cell_to_in)
forget_gate = tensor.nnet.sigmoid(slice_last(activation, 1) +
cells * self.W_cell_to_forget)
next_cells = (forget_gate * cells +
in_gate * nonlinearity(slice_last(activation, 2)))
out_gate = tensor.nnet.sigmoid(slice_last(activation, 3) +
next_cells * self.W_cell_to_out)
next_states = out_gate * nonlinearity(next_cells)
in_gate = self.gate_activation.apply(
slice_last(activation, 0) + cells * self.W_cell_to_in)
forget_gate = self.gate_activation.apply(
slice_last(activation, 1) + cells * self.W_cell_to_forget)
next_cells = (
forget_gate * cells +
in_gate * self.activation.apply(slice_last(activation, 2)))
out_gate = self.gate_activation.apply(
slice_last(activation, 3) + next_cells * self.W_cell_to_out)
next_states = out_gate * self.activation.apply(next_cells)

if mask:
next_states = (mask[:, None] * next_states +
Expand Down

0 comments on commit e0b2be4

Please sign in to comment.