Combine ARRNN and LRRNN models with a rate kwarg.
Leif Johnson committed Nov 13, 2015
1 parent 41281c4 commit de5d1d5
Showing 5 changed files with 91 additions and 126 deletions.
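
For orientation, a minimal usage sketch of the combined layer follows; it is not part of the commit, and the ``theanets.recurrent.Regressor`` constructor, layer sizes, and activations are assumptions about the theanets API of this era rather than anything taken from the diff below::

    import theanets

    net = theanets.recurrent.Regressor([
        1,
        # input-dependent rates (the old ARRNN behaviour); 'matrix' is the
        # default value of the new ``rate`` keyword introduced here
        dict(form='rrnn', size=64, activation='relu', rate='matrix'),
        # one learned rate per hidden unit (the old LRRNN behaviour)
        dict(form='rrnn', size=64, activation='relu', rate='vector'),
        1,
    ])
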
5 changes: 2 additions & 3 deletions docs/api/layers.rst
@@ -84,12 +84,11 @@ Recurrent
:toctree: generated/

RNN
MRNN
LRRNN
ARRNN
RRNN
MUT1
GRU
LSTM
MRNN
Clockwork
Bidirectional

3 changes: 1 addition & 2 deletions docs/api/reference.rst
@@ -28,15 +28,14 @@ Reference
theanets.layers.feedforward.Classifier
theanets.layers.feedforward.Feedforward
theanets.layers.feedforward.Tied
theanets.layers.recurrent.ARRNN
theanets.layers.recurrent.Bidirectional
theanets.layers.recurrent.Clockwork
theanets.layers.recurrent.GRU
theanets.layers.recurrent.LRRNN
theanets.layers.recurrent.LSTM
theanets.layers.recurrent.MRNN
theanets.layers.recurrent.MUT1
theanets.layers.recurrent.RNN
theanets.layers.recurrent.RRNN
theanets.losses.CrossEntropy
theanets.losses.GaussianLogLikelihood
theanets.losses.Hinge
2 changes: 1 addition & 1 deletion examples/recurrent-sinusoid.py
@@ -57,7 +57,7 @@
# predicted output.
for i, layer in enumerate((
dict(form='rnn', activation='relu', diagonal=0.5),
dict(form='lrrnn', activation='relu', diagonal=0.5),
dict(form='rrnn', activation='relu', rate='vector', diagonal=0.5),
dict(form='gru', activation='relu'),
dict(form='lstm', activation='tanh'),
dict(form='clockwork', activation='linear', periods=(1, 4, 16, 64)))):
28 changes: 23 additions & 5 deletions test/layers_test.py
@@ -294,7 +294,7 @@ def test_transform(self):

class TestARRNN(BaseRecurrent):
def _build(self):
return theanets.layers.ARRNN(
return theanets.layers.RRNN(
inputs=self.NUM_INPUTS, size=self.NUM_HIDDEN, name='l')

def test_create(self):
@@ -310,8 +310,9 @@ def test_transform(self):

class TestLRRNN(BaseRecurrent):
def _build(self):
return theanets.layers.LRRNN(
inputs=self.NUM_INPUTS, size=self.NUM_HIDDEN, name='l')
return theanets.layers.RRNN(
inputs=self.NUM_INPUTS, size=self.NUM_HIDDEN, name='l',
rate='vector')

def test_create(self):
self.assert_param_names(['b', 'hh', 'r', 'xh'])
@@ -324,6 +325,23 @@ def test_transform(self):
assert not upd


class TestRRNN(BaseRecurrent):
def _build(self):
return theanets.layers.RRNN(
inputs=self.NUM_INPUTS, size=self.NUM_HIDDEN, name='l',
rate='uniform')

def test_create(self):
self.assert_param_names(['b', 'hh', 'xh'])
self.assert_count(
(1 + self.NUM_INPUTS + self.NUM_HIDDEN) * self.NUM_HIDDEN)

def test_transform(self):
out, upd = self.l.transform(dict(out=self.x))
assert len(out) == 4
assert not upd
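
The count asserted for the ``'uniform'`` variant reflects that its rates live in a non-learnable shared variable, so only ``xh``, ``hh``, and ``b`` are counted. A rough arithmetic sketch of how the three modes compare (the sizes are made up; the ``'uniform'`` formula mirrors the assertion above, and the other two follow from the parameters added conditionally in ``setup()`` in theanets/layers/recurrent.py below)::

    NUM_INPUTS, NUM_HIDDEN = 12, 6                                 # hypothetical sizes
    uniform = (NUM_INPUTS + NUM_HIDDEN + 1) * NUM_HIDDEN           # xh + hh + b
    vector = (NUM_INPUTS + NUM_HIDDEN + 1 + 1) * NUM_HIDDEN        # ... plus the 'r' bias
    matrix = (2 * NUM_INPUTS + NUM_HIDDEN + 1 + 1) * NUM_HIDDEN    # ... plus 'xr' and 'r'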


class TestMRNN(BaseRecurrent):
def _build(self):
return theanets.layers.MRNN(
@@ -413,7 +431,7 @@ def test_spec(self):
class TestBidirectional(BaseRecurrent):
def _build(self):
return theanets.layers.Bidirectional(
inputs=self.NUM_INPUTS, size=self.NUM_HIDDEN, worker='arrnn', name='l')
inputs=self.NUM_INPUTS, size=self.NUM_HIDDEN, worker='rrnn', name='l')

def test_create(self):
self.assert_param_names(
Expand All @@ -428,4 +446,4 @@ def test_transform(self):
assert not upd

def test_spec(self):
self.assert_spec(size=self.NUM_HIDDEN, form='bidirectional', worker='arrnn')
self.assert_spec(size=self.NUM_HIDDEN, form='bidirectional', worker='rrnn')
179 changes: 64 additions & 115 deletions theanets/layers/recurrent.py
@@ -15,15 +15,14 @@
logging = climate.get_logger(__name__)

__all__ = [
'ARRNN',
'Bidirectional',
'Clockwork',
'GRU',
'LRRNN',
'LSTM',
'MRNN',
'MUT1',
'RNN',
'RRNN',
]


@@ -250,8 +249,20 @@ def fn(x_t, h_tm1):
return dict(pre=pre, out=out), updates


class LRRNN(Recurrent):
r'''An RNN with a learned rate for each unit.
class RRNN(Recurrent):
r'''An RNN with an update rate for each unit.
Parameters
----------
rate : str, optional
This parameter controls how rates are represented in the layer. If this
is ``'matrix'`` (the default), then rates are computed as a function of the
input at each time step. If this parameter is ``'vector'``, then rates
are represented as a single vector of learnable rates. If this parameter
is ``'uniform'``, then rates are chosen uniformly at random from the open
interval (0, 1). If this parameter is ``'log'``, then rates are drawn from
a log-uniform distribution, so that most rates are small (near 0) and only
a few are near 1.
Notes
-----
@@ -266,25 +277,30 @@ class LRRNN(Recurrent):
where :math:`\odot` indicates elementwise multiplication.
Rates might be defined in a number of ways, spanning a continuum between
vanilla RNNs (i.e., all rate parameters are fixed at 1), fixed but
non-uniform rates for each hidden unit [Ben12]_, parametric rates that are
dependent only on the input (i.e., the :class:`ARRNN`), all the way to
parametric rates that are computed as a function of the inputs and the
hidden state at each time step (i.e., something more like the :class:`gated
recurrent unit <GRU>`).
This class represents rates as a single learnable vector of parameters. This
representation uses the fewest parameters for learnable rates, but
the simplicity of the model comes at the cost of effectively fixing the rate
for each unit as a constant value across time.
vanilla RNNs (i.e., all rate parameters are effectively fixed at 1), fixed
but non-uniform rates for each hidden unit [Ben12]_, parametric rates that
are dependent only on the input, all the way to parametric rates that are
computed as a function of the inputs and the hidden state at each time step
(i.e., something more like the :class:`gated recurrent unit <GRU>`).
This class represents rates in different ways depending on the value of the
``rate`` parameter at initialization.
*Parameters*
- ``b`` --- vector of bias values for each hidden unit
- ``r`` --- vector of rates for each hidden unit
- ``xh`` --- matrix connecting inputs to hidden units
- ``hh`` --- matrix connecting hiddens to hiddens
If ``rate`` is initialized to the string ``'vector'``, we define:
- ``r`` --- vector of rates for each hidden unit
If ``rate`` is initialized to ``'matrix'``, we define:
- ``r`` --- vector of rate bias values for each hidden unit
- ``xr`` --- matrix connecting inputs to rate values for each hidden unit
*Outputs*
- ``out`` --- the post-activation state of the layer
@@ -300,12 +316,27 @@ class LRRNN(Recurrent):
http://arxiv.org/abs/1212.0901
'''

def __init__(self, rate='matrix', **kwargs):
self.rate = rate.lower().strip()
super(RRNN, self).__init__(**kwargs)
self._rates = None
if self.rate == 'uniform':
z = np.random.uniform(0, 1, size=self.size).astype(util.FLOAT)
self._rates = theano.shared(z, name=self._fmt('rate'))
if self.rate == 'log':
z = np.random.uniform(-6, 0, size=self.size).astype(util.FLOAT)
self._rates = theano.shared(np.exp(z), name=self._fmt('rate'))
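
As an aside (not part of the diff), a small NumPy sketch of what these two random initializations produce; the medians make the difference clear::

    import numpy as np

    uniform_rates = np.random.uniform(0, 1, size=10000)        # spread evenly over (0, 1)
    log_rates = np.exp(np.random.uniform(-6, 0, size=10000))   # log-uniform over (e^-6, 1)
    print(np.median(uniform_rates))   # about 0.5
    print(np.median(log_rates))       # about exp(-3), roughly 0.05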

def setup(self):
'''Set up the parameters and initial values for this layer.'''
self.add_weights('xh', self.input_size, self.size)
self.add_weights('hh', self.size, self.size)
self.add_bias('b', self.size)
self.add_bias('r', self.size, mean=2, std=1)

if self.rate == 'vector' or self.rate == 'matrix':
self.add_bias('r', self.size, mean=2, std=1)
if self.rate == 'matrix':
self.add_weights('xr', self.input_size, self.size)

def transform(self, inputs):
'''Transform the inputs for this layer into an output for the layer.
@@ -332,113 +363,31 @@ def transform(self, inputs):
# scan wants: (time, batch, input)
x = self._only_input(inputs).dimshuffle(1, 0, 2)
h = TT.dot(x, self.find('xh')) + self.find('b')
r = TT.nnet.sigmoid(self.find('r'))
r = self._rates

def fn(x_t, h_tm1):
def fn_dynamic(x_t, r_t, h_tm1):
pre = x_t + TT.dot(h_tm1, self.find('hh'))
h_t = self.activate(pre)
return [pre, h_t, (1 - r) * h_tm1 + r * h_t]

# output is: (time, batch, output)
# we want: (batch, time, output)
(p, h, o), updates = self._scan(fn, [h], [None, None, x])
pre = p.dimshuffle(1, 0, 2)
hid = h.dimshuffle(1, 0, 2)
out = o.dimshuffle(1, 0, 2)

return dict(pre=pre, hid=hid, rate=r, out=out), updates


class ARRNN(Recurrent):
r'''An RNN with adaptive rate per unit.
Notes
-----
In a normal RNN, a hidden unit is updated completely at each time step,
:math:`h_t = f(x_t, h_{t-1})`. With an explicit update rate, the state of a
hidden unit is computed as a mixture of the new and old values,
.. math::
h_t = (1 - z_t) \odot h_{t-1} + z_t \odot f(x_t, h_{t-1})
where :math:`\odot` indicates elementwise multiplication.
Rates might be defined in a number of ways, spanning a continuum between
vanilla RNNs (i.e., all rate parameters are fixed at 1) all the way to
parametric rates that are computed as a function of the inputs and the
hidden state at each time step (i.e., something more like the :class:`GRU`).
In the ARRNN model, the rate values are computed at each time step as a
logistic sigmoid applied to an affine transform of the input:
.. math::
z_t = \frac{1}{1 + \exp(-x_t W_{xr} - b_r)}.
This representation of the rates uses more parameters than the
:class:`LRRNN` but is able to adapt rates to the input at each time step.
However, in this model, rates are not able to adapt to the state of the
hidden units at each time step.
*Parameters*
- ``b`` --- vector of bias values for each hidden unit
- ``r`` --- vector of rate biases for each hidden unit
- ``xh`` --- matrix connecting inputs to hidden units
- ``xr`` --- matrix connecting inputs to rate "gates"
- ``hh`` --- matrix connecting hiddens to hiddens
*Outputs*
- ``out`` --- the post-activation state of the layer
- ``pre`` --- the pre-activation state of the layer
- ``hid`` --- the pre-rate-mixing hidden state
- ``rate`` --- the rate values
'''

def setup(self):
'''Set up the parameters and initial values for this layer.'''
self.add_weights('xh', self.input_size, self.size)
self.add_weights('xr', self.input_size, self.size)
self.add_weights('hh', self.size, self.size)
self.add_bias('b', self.size)
self.add_bias('r', self.size)

def transform(self, inputs):
'''Transform the inputs for this layer into an output for the layer.
Parameters
----------
inputs : dict of theano expressions
Symbolic inputs to this layer, given as a dictionary mapping string
names to Theano expressions. See :func:`base.Layer.connect`.
Returns
-------
outputs : theano expression
A map from string output names to Theano expressions for the outputs
from this layer. This layer type generates a "pre" output that gives
the unit activity before applying the layer's activation function,
a "hid" output that gives the rate-independent, post-activation
hidden state, a "rate" output that gives the rate value for each
hidden unit, and an "out" output that gives the hidden output.
updates : list of update pairs
A sequence of updates to apply inside a theano function.
'''
# input is: (batch, time, input)
# scan wants: (time, batch, input)
x = self._only_input(inputs).dimshuffle(1, 0, 2)
r = TT.nnet.sigmoid(TT.dot(x, self.find('xr')) + self.find('r'))
h = TT.dot(x, self.find('xh')) + self.find('b')
return [pre, h_t, (1 - r_t) * h_tm1 + r_t * h_t]

def fn(x_t, r_t, h_tm1):
def fn_static(x_t, h_tm1):
pre = x_t + TT.dot(h_tm1, self.find('hh'))
h_t = self.activate(pre)
return [pre, h_t, (1 - r_t) * h_tm1 + r_t * h_t]
return [pre, h_t, (1 - r) * h_tm1 + r * h_t]

fn = fn_static
seqs = [h]

if self.rate == 'matrix':
fn = fn_dynamic
r = TT.nnet.sigmoid(TT.dot(x, self.find('xr')) + self.find('r'))
seqs.append(r)
elif self.rate == 'vector':
r = TT.nnet.sigmoid(self.find('r'))

# output is: (time, batch, output)
# we want: (batch, time, output)
(p, h, o), updates = self._scan(fn, [h, r], [None, None, x])
(p, h, o), updates = self._scan(fn, seqs, [None, None, x])
pre = p.dimshuffle(1, 0, 2)
hid = h.dimshuffle(1, 0, 2)
out = o.dimshuffle(1, 0, 2)
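
To make the shared update rule concrete, here is a small NumPy sketch (illustrative only, not part of the commit) of the rate-mixing step that both ``fn_dynamic`` and ``fn_static`` return as their third output, h_t = (1 - r) * h_{t-1} + r * f(x_t, h_{t-1})::

    import numpy as np

    rng = np.random.RandomState(13)
    size = 4
    r = rng.uniform(0, 1, size)               # per-unit rates, as with rate='uniform'
    W_hh = 0.1 * rng.randn(size, size)        # stand-in recurrent weights
    h = np.zeros(size)

    for t in range(5):
        x_t = rng.randn(size)                 # stand-in for x_t W_xh + b
        h_cand = np.tanh(x_t + h.dot(W_hh))   # candidate state f(x_t, h_{t-1})
        h = (1 - r) * h + r * h_cand          # units with small r change slowly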
