
Commit

Merge a339231 into 44a5066
trax-robot committed Jun 30, 2020
2 parents 44a5066 + a339231 commit be9e59d
Showing 5 changed files with 74 additions and 61 deletions.
76 changes: 40 additions & 36 deletions trax/layers/research/efficient_attention.py
@@ -258,44 +258,48 @@ class EfficientAttentionBase(base.Layer):
def __init__(self, n_heads, n_in=1, n_parallel_heads=None,
incremental=False, predict_mem_len=None, predict_drop_len=None,
use_python_loop=False, use_reference_code=False):
"""Construct an EfficientAttentionBase instance.
"""Constructs an EfficientAttentionBase instance.
Args:
n_heads: int: Number of attention heads
n_in: int: Number of inputs to the layer (default 1)
n_parallel_heads: int: Number of attention heads to compute in parallel.
if n_parallel_heads is None (default): The entire layer is computed with
maximum parallelism. This mode is the fastest, but also uses the most
memory. Start with this mode, but switch to one of the others if
memory runs out.
if n_parallel_heads is 1: Attention is computed one head at a time, and
one example at a time. This mode uses the least memory but is not as
fast as batched attention. Use this mode when working with very long
sequences, such that any amount of parallelism won't fit in memory.
if n_parallel_heads is a multiple of n_heads: Attention is computed for
sub-batches of (n_parallel_heads // n_heads) examples at a time.
if 1 < n_parallel_heads < n_heads: Attention is computed for several
heads at a time, but only within a single example. It must be the case
that n_heads is a multiple of n_parallel_heads. Use this mode for long
sequences, to strike a balance between parallelism and memory usage.
incremental: bool: Enables fast inference for self-attention types. Note
that this flag should *not* be set when doing encoder-decoder attention,
but only when doing self-attention.
predict_mem_len: int: Number of input positions to remember in a cache
when doing fast inference. Whenever the cache fills up, some input
elements will be forgotten.
predict_drop_len: int: Number of input elements to drop once the fast
inference input cache fills up.
use_python_loop: bool: Set to True to use a Python loop when iterating
over sub-batches of examples/heads (as opposed to a JAX/XLA loop). This
option will increase compilation time and jitted code size, potentially
drastically. Using it is not recommended except for testing/debugging.
In particular, note that enabling this option on TPU can decrease the
maximum model size that will fit in memory.
use_reference_code: bool: Set to True to fall back to the reference
implementation of batched attention. This option will increase
compilation time and jitted code size, potentially drastically. Using it
is not recommended except for testing/debugging.
n_heads: Number of attention heads.
n_in: Number of inputs to the layer (default 1).
n_parallel_heads: Number of attention heads to compute in parallel.
- If `n_parallel_heads` is None (default), the entire layer is
computed with maximum parallelism. This mode is the fastest, but
also uses the most memory. Start with this mode, but switch to one
of the others if memory runs out.
- If `n_parallel_heads` is 1, attention is computed one head at a
time, and one example at a time. This mode uses the least memory
but is not as fast as batched attention. Use this mode when working
with very long sequences, such that any amount of parallelism won't
fit in memory.
- If `n_parallel_heads` is a multiple of `n_heads`, attention is
computed for sub-batches of (`n_parallel_heads // n_heads`)
examples at a time.
- If `1 < n_parallel_heads < n_heads`, attention is computed for
several heads at a time, but only within a single example. It must
be the case that `n_heads` is a multiple of `n_parallel_heads`. Use
this mode for long sequences, to strike a balance between
parallelism and memory usage.
incremental: If `True`, enable fast inference for self-attention types.
Note that this flag should *not* be set when doing encoder-decoder
attention, but only when doing self-attention.
predict_mem_len: Number of input positions to remember in a cache
when doing fast inference. Whenever the cache fills up, some input
elements will be forgotten.
predict_drop_len: Number of input elements to drop once the fast
inference input cache fills up.
use_python_loop: Set to True to use a Python loop when iterating over
sub-batches of examples/heads (as opposed to a JAX/XLA loop).
This option will increase compilation time and jitted code size,
potentially drastically. Using it is not recommended except for
testing/debugging. In particular, note that enabling this option on
TPU can decrease the maximum model size that will fit in memory.
use_reference_code: Set to True to fall back to the reference
implementation of batched attention. This option will increase
compilation time and jitted code size, potentially drastically. Using
it is not recommended except for testing/debugging.
"""
super().__init__(n_in=n_in, n_out=1)
self.n_heads = n_heads
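
As an editorial aside, the four `n_parallel_heads` modes described in the new docstring can be summarized with a small illustrative sketch; `plan_attention_chunks` and its arguments are hypothetical and not part of the Trax API:

def plan_attention_chunks(batch_size, n_heads, n_parallel_heads=None):
  # Hypothetical helper (not Trax code): how many (example, head) units of
  # attention work run in a single inner step under each documented mode.
  if n_parallel_heads is None:
    return batch_size * n_heads       # Maximum parallelism, maximum memory.
  if n_parallel_heads == 1:
    return 1                          # One head of one example at a time.
  if n_parallel_heads % n_heads == 0:
    # Sub-batches of (n_parallel_heads // n_heads) examples, all heads each.
    return n_parallel_heads
  assert n_heads % n_parallel_heads == 0
  return n_parallel_heads             # Several heads of a single example.

# 8 heads with n_parallel_heads=16 -> sub-batches of 2 examples (16 units/step).
print(plan_attention_chunks(batch_size=32, n_heads=8, n_parallel_heads=16))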
14 changes: 8 additions & 6 deletions trax/rl_trainer.py
@@ -19,12 +19,14 @@
Sample invocation:
TRAIN_BATCH_SIZE=32
python trax/rl_trainer.py \
--config_file=trax/rl/configs/ppo_acrobot.gin \
--train_batch_size=${TRAIN_BATCH_SIZE} \
--output_dir=${HOME}/ppo_acrobot \
--alsologtostderr
.. code-block:: bash
TRAIN_BATCH_SIZE=32
python trax/rl_trainer.py \
--config_file=trax/rl/configs/ppo_acrobot.gin \
--train_batch_size=${TRAIN_BATCH_SIZE} \
--output_dir=${HOME}/ppo_acrobot \
--alsologtostderr
"""

import faulthandler
18 changes: 10 additions & 8 deletions trax/supervised/inputs.py
@@ -412,18 +412,20 @@ def add_loss_weights(generator, id_to_mask=None):
"""Add weights to inputs without weights and masks by id if requested.
The generator stream is augmented in the following way:
* if the stream consists of pairs (inputs, targets), a loss mask is added
that is creates as a tensor of ones of the same shape as targets
* if id_to_mask is not None, and the stream (after the previous point) has
triples (inputs, targets, weights), the weights are multipled by a 0/1 mask
that is 0 iff targets is equal to id_to_mask (1 otherwise).
- If the stream consists of pairs `(inputs, targets)`, a loss mask is added
that is created as a tensor of ones of the same shape as `targets`.
- If `id_to_mask` is not `None`, and the stream (after the previous point)
has triples `(inputs, targets, weights)`, the weights are multiplied by a
0/1 mask that is 0 iff `targets` equals `id_to_mask` (1 otherwise).
Args:
generator: a python stream of tuples
id_to_mask: int or None, id to pad in targets if not None
generator: Stream of tuples.
id_to_mask: If not None, int-valued id that represents padding, as opposed
to true target ids.
Yields:
examples from the augmented stream
Examples from the augmented stream.
"""
for example in generator:
if len(example) > 3 or len(example) < 2:
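
As an editorial aside, the documented behavior amounts to roughly the following sketch (assuming numpy arrays in the stream; this is not the exact Trax implementation):

import numpy as np

def add_loss_weights_sketch(generator, id_to_mask=None):
  # Illustrative restatement of the documented behavior, not the real code.
  for example in generator:
    if len(example) == 2:                          # (inputs, targets)
      inputs, targets = example
      weights = np.ones_like(targets, dtype=np.float32)
    else:                                          # (inputs, targets, weights)
      inputs, targets, weights = example
    if id_to_mask is not None:
      weights = weights * (targets != id_to_mask)  # zero weight on padding ids
    yield (inputs, targets, weights)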
13 changes: 9 additions & 4 deletions trax/supervised/lr_schedules.py
@@ -14,12 +14,17 @@
# limitations under the License.

# Lint as: python3
"""Learning rate (LR) schedules as functions of time (step number).
r"""Learning rate (LR) schedules.
In Trax a learning rate schedule is a function: step -> learning_rate.
In Trax a learning rate schedule is a function:
:math:`\text{step} \mapsto \text{learning_rate}`.
This module provides helpers for constructing such functions. For example,
constant(0.001)
returns a function that takes each step --> 0.001.
.. code-block:: python
constant(0.001)
returns a function that always returns `0.001`.
"""

import math
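
As an editorial aside, the schedule-as-function idea can be sketched in a few lines; `constant_sketch` is a hypothetical stand-in for the module's `constant` helper, not its actual implementation:

def constant_sketch(value):
  # Illustrative: build a schedule mapping every step to the same value.
  def schedule(step):
    del step  # A constant schedule ignores the step number.
    return value
  return schedule

lr = constant_sketch(0.001)
assert lr(0) == lr(10000) == 0.001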
14 changes: 7 additions & 7 deletions trax/supervised/trainer_lib.py
@@ -357,15 +357,15 @@ def evaluation_round(self, inputs_stream, weights, state, rng):
"""Evaluate.
Args:
inputs_stream: iterable of inputs to evaluate on.
weights: weights for each f in eval_fns.
state: state for each f in eval_fns.
rng: random number generator.
inputs_stream: Iterable of inputs to evaluate on.
weights: Weights for each f in eval_fns.
state: State for each f in eval_fns.
rng: Single-use random number generator (JAX PRNG key).
Returns:
metrics: dict from metric name to metric value averaged over the number of
inputs.
state: end state for `predict_fn`.
Tuple of `(metrics, state)`. `metrics` is a dict from metric name to
metric value averaged over the number of inputs, and `state` is the end
state returned by this trainer's `predict_fn`.
"""
metrics = collections.defaultdict(float)
count = 0
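
As an editorial aside, the metric averaging described in the new return-value text can be sketched as follows; `average_metrics_sketch` and `metric_fns` are hypothetical and not the trainer's real code:

import collections

def average_metrics_sketch(inputs_stream, metric_fns):
  # Illustrative: average each named metric over the evaluation inputs,
  # as the docstring above describes. `metric_fns` maps names to callables.
  totals = collections.defaultdict(float)
  count = 0
  for batch in inputs_stream:
    count += 1
    for name, fn in metric_fns.items():
      totals[name] += float(fn(batch))
  return {name: total / max(count, 1) for name, total in totals.items()}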
