Skip to content

Commit

Permalink
Merge pull request #51 from google-research/seq2seq
Browse files Browse the repository at this point in the history
Adds utility functions and unroll_reinject.
  • Loading branch information
ramasesh committed Jan 14, 2021
2 parents f5b3fb6 + 8313f4e commit 5e51864
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 11 deletions.
52 changes: 41 additions & 11 deletions renn/rnn/unroll.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Recurrent neural network (RNN) cells."""
"""Unroll functions for recurrent neural network (RNN) cells."""

import jax
import jax.numpy as jnp

__all__ = ['unroll_rnn']
from renn.utils import identity

__all__ = ['unroll_rnn', 'unroll_reinject']

def identity(x):
"""Identity function f(x) = x."""
return x


def unroll_rnn(initial_states, input_sequences, rnn_update, readout=identity):
def unroll_rnn(initial_states,
input_sequences,
apply_rnn,
apply_readout=identity):
"""Unrolls an RNN on a batch of input sequences.
Given a batch of initial RNN states, and a batch of input sequences, this
function unrolls application of the RNN along the sequence. The RNN state
is updated using the `rnn_update` function, and the `readout` is used to
is updated using the `apply_rnn` function, and the `readout` is used to
convert the RNN state to outputs (defaults to the identity function).
B: batch size.
Expand All @@ -39,7 +39,7 @@ def unroll_rnn(initial_states, input_sequences, rnn_update, readout=identity):
Args:
initial_states: batch of initial states, with shape (B, N).
input_sequences: batch of inputs, with shape (B, T, N).
rnn_update: updates the RNN hidden state, given (inputs, current_states).
apply_rnn: updates the RNN hidden state, given (inputs, current_states).
readout: applies the readout, given current states. If this is the identity
function, then no readout is applied (returns the hidden states).
Expand All @@ -48,11 +48,41 @@ def unroll_rnn(initial_states, input_sequences, rnn_update, readout=identity):
"""

def _step(state, inputs):
next_state = rnn_update(inputs, state)
outputs = readout(next_state)
next_state = apply_rnn(inputs, state)
outputs = apply_readout(next_state)
return next_state, outputs

input_sequences = jnp.swapaxes(input_sequences, 0, 1)
_, outputs = jax.lax.scan(_step, initial_states, input_sequences)

return jnp.swapaxes(outputs, 0, 1)


def unroll_reinject(initial_states, initial_token, sequence_length,
apply_embedding, apply_rnn, apply_readout):
"""Unrolls an RNN, reinjecting the output back into the RNN."""

def _step(state, _):

# Unpack loop state.
tokens, rnn_state = state

# Apply embedding, RNN, and readout.
rnn_inputs = apply_embedding(tokens)
rnn_state = apply_rnn(rnn_inputs, rnn_state)
logits = apply_readout(rnn_state)

# Pack new loop state
next_state = (jnp.argmax(logits, axis=-1), rnn_state)

return next_state, logits

# Format scan arguments.
batch_size = initial_states.shape[0]
batch_inputs = initial_token * jnp.ones(batch_size).astype(jnp.int32)
dummy_inputs = jnp.zeros((sequence_length, 1))

# Unroll loop via scan.
_, outputs = jax.lax.scan(_step, (batch_inputs, initial_states), dummy_inputs)

return jnp.swapaxes(outputs, 0, 1)
35 changes: 35 additions & 0 deletions renn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,41 @@ def wrapper(x):
return wrapper


def build_mask(max_length: int):
"""Builds a function that generates a binary mask.
For example, `f = build_mask(5)` returns a function that generates masks of
total length 5. Calling this function with an array of integers, e.g.
f(jnp.array([2, 3])), will return a binary (mask) array:
[[1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0.]]
This is useful for generating binary arrays for masking out elements of a
padded batch, such as computing a loss over a batch of padded sequences.
Args:
max_length: int, The total length of the mask/sequence.
Returns:
mask_fun: function, Takes an array of indices (lengths) and returns
a binary mask where
"""

def mask_fun(index: jnp.array) -> jnp.array:
"""Builds a binary mask."""
return jnp.where(
jnp.arange(max_length) < index, jnp.ones(max_length),
jnp.zeros(max_length))

return jax.vmap(mask_fun)


def select(sequences, indices):
"""Selects a particular timestep from a batch of sequences."""
last_index = jnp.array(indices)[:, jnp.newaxis, jnp.newaxis]
return jnp.squeeze(jnp.take_along_axis(sequences, last_index, axis=1))


def optimize(loss_fun, x0, optimizer, steps, stop_tol=-np.inf):
"""Run an optimizer on a given loss function.
Expand Down
20 changes: 20 additions & 0 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,23 @@ def test_batch_mean():
data = jnp.array([0, 1, 2])
result = square_fun(data)
assert result == jnp.mean(data**2)

def test_build_mask():
"""Tests the build_mask function."""

max_length = 5
test_function = utils.build_mask(max_length)
result = test_function(jnp.array([0,1,2,3,4,5,6,7,8,9,10]))
ideal_result = jnp.array([[0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0.],
[1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]])
assert np.allclose(result, ideal_result)

0 comments on commit 5e51864

Please sign in to comment.