Add back flstm cell and LSTM with peepholes option
okuchaiev committed Feb 7, 2018
1 parent 1248b3e commit 85095e7
Showing 2 changed files with 128 additions and 2 deletions.
118 changes: 118 additions & 0 deletions flstm.py
@@ -0,0 +1,118 @@
"""Module for constructing RNN Cells."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs

# pylint: disable=protected-access
_Linear = core_rnn_cell._Linear # pylint: disable=invalid-name
# pylint: enable=protected-access

class FLSTMCell(rnn_cell_impl.RNNCell):
"""Group LSTM cell (G-LSTM).
The implementation is based on:
https://arxiv.org/abs/1703.10722
O. Kuchaiev and B. Ginsburg
"Factorization Tricks for LSTM Networks", ICLR 2017 workshop.
"""

def __init__(self, num_units, fact_size, initializer=None, num_proj=None,
forget_bias=1.0, activation=math_ops.tanh,
reuse=None):
"""Initialize the parameters of G-LSTM cell.
Args:
num_units: int, The number of units in the G-LSTM cell
initializer: (optional) The initializer to use for the weight and
projection matrices.
num_proj: (optional) int, The output dimensionality for the projection
matrices. If None, no projection is performed.
number_of_groups: (optional) int, number of groups to use.
If `number_of_groups` is 1, then it should be equivalent to LSTM cell
forget_bias: Biases of the forget gate are initialized by default to 1
in order to reduce the scale of forgetting at the beginning of
the training.
activation: Activation function of the inner states.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already
has the given variables, an error is raised.
Raises:
ValueError: If `num_units` or `num_proj` is not divisible by
`number_of_groups`.
"""
super(FLSTMCell, self).__init__(_reuse=reuse)
self._num_units = num_units
self._initializer = initializer
self._fact_size = fact_size
self._forget_bias = forget_bias
self._activation = activation
self._num_proj = num_proj

if num_proj:
self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
self._output_size = num_proj
else:
self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
self._output_size = num_units
self._linear1 = None
self._linear2 = None
self._linear3 = None

@property
def state_size(self):
return self._state_size

@property
def output_size(self):
return self._output_size

  def call(self, inputs, state):
    """Run one step of the F-LSTM cell.
    Args:
      inputs: 2-D Tensor of shape [batch_size, input_size].
      state: an `LSTMStateTuple` of state tensors (c_prev, m_prev).
    Returns:
      A tuple (m, new_state): the (possibly projected) output and an
      `LSTMStateTuple` holding the new (c, m) state.
    """
(c_prev, m_prev) = state
self._batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
scope = vs.get_variable_scope()
with vs.variable_scope(scope, initializer=self._initializer):
x = array_ops.concat([inputs, m_prev], axis=1)
with vs.variable_scope("first_gemm"):
if self._linear1 is None:
self._linear1 = _Linear(x, self._fact_size, False) # no bias for bottleneck
R_fact = self._linear1(x)
with vs.variable_scope("second_gemm"):
if self._linear2 is None:
self._linear2 = _Linear(R_fact, 4*self._num_units, True)
R = self._linear2(R_fact)
      i, j, f, o = array_ops.split(R, 4, 1)  # input, new input, forget, output gates

c = (math_ops.sigmoid(f + self._forget_bias) * c_prev +
math_ops.sigmoid(i) * math_ops.tanh(j))
m = math_ops.sigmoid(o) * self._activation(c)

if self._num_proj is not None:
with vs.variable_scope("projection"):
if self._linear3 is None:
self._linear3 = _Linear(m, self._num_proj, False)
m = self._linear3(m)

new_state = rnn_cell_impl.LSTMStateTuple(c, m)
return m, new_state
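
For readers who want to try the new cell, here is a minimal usage sketch. It is not part of the commit: it assumes TensorFlow 1.x-era contrib APIs with flstm.py on the import path, and the sizes are purely illustrative.

import numpy as np
import tensorflow as tf

from flstm import FLSTMCell

batch_size, time_steps, input_dim = 8, 5, 64
inputs = tf.placeholder(tf.float32, [batch_size, time_steps, input_dim])

# fact_size is the bottleneck width: the single large input-to-gates GEMM
# is replaced by two smaller GEMMs through a fact_size-dim intermediate.
cell = FLSTMCell(num_units=256, fact_size=32, num_proj=128)

outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(outputs,
                   {inputs: np.zeros([batch_size, time_steps, input_dim],
                                     np.float32)})
    print(out.shape)  # (8, 5, 128): num_proj sets the output size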
12 changes: 10 additions & 2 deletions language_model.py
@@ -5,7 +5,7 @@
from hparams import HParams
#from tensorflow.contrib.rnn import LSTMCell
from glstm import GLSTMCell

from flstm import FLSTMCell

class LM(object):
def __init__(self, hps, mode="train", ps_device="/gpu:0"):
@@ -98,10 +98,17 @@ def _forward(self, gpu, x, y):
cell = GLSTMCell(num_units=hps.state_size,
num_proj=hps.projected_size,
number_of_groups=hps.num_of_groups)
elif hps.fact_size is not None:
print("Using FLSTM")
cell = FLSTMCell(num_units=hps.state_size,
fact_size=hps.fact_size,
num_proj=hps.projected_size)
else:
print("Using LSTMP")
print("Using peepholes: %s" % hps.use_peepholes)
cell = tf.nn.rnn_cell.LSTMCell(num_units=hps.state_size,
num_proj=hps.projected_size)
num_proj=hps.projected_size,
use_peepholes=hps.use_peepholes)

state = tf.contrib.rnn.LSTMStateTuple(self.initial_states[i][0],
self.initial_states[i][1])
@@ -223,6 +230,7 @@ def get_default_hparams():
fact_size=None,
fnon_linearity="none",
num_of_groups=0,
use_peepholes=False,

save_model_every_min=30,
save_summary_every_min=16,
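
A configuration sketch (also not part of the commit) of how these flags select the cell in _forward, assuming the repo's HParams object allows attribute assignment after construction:

hps = get_default_hparams()

# Take the new FLSTM branch: any non-None fact_size selects FLSTMCell.
hps.fact_size = 32

# Or leave fact_size=None and num_of_groups=0 to fall through to plain
# LSTMP, where the new use_peepholes flag takes effect:
# hps.use_peepholes = True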
