diff --git a/trax/layers/attention.py b/trax/layers/attention.py
index 820117bc8..8831fb1d0 100644
--- a/trax/layers/attention.py
+++ b/trax/layers/attention.py
@@ -195,23 +195,34 @@ def forward_with_state(self, x, weights, state, rng):
     del weights
     n_heads, dropout, mode = self._n_heads, self._dropout, self._mode
     q, k, v, mask = x
+    batch_size = q.shape[0]
     d_feature = q.shape[-1]
+    assert d_feature % n_heads == 0
     d_head = d_feature // n_heads
-    nbatch = jnp.shape(q)[0]
-    # nbatch, seqlen, d_feature --> nbatch, n_heads, seqlen, d_head
-    def SplitHeads(x):
-      return jnp.transpose(
-          jnp.reshape(x, (nbatch, -1, n_heads, d_head)), (0, 2, 1, 3))
-    # nbatch, n_heads, seqlen, d_head --> nbatch, seqlen, d_feature
-    def JoinHeads(x):  # pylint: disable=invalid-name
-      return jnp.reshape(
-          jnp.transpose(x, (0, 2, 1, 3)), (nbatch, -1, n_heads * d_head))
-    # Split heads, dot-product attention, rejoin heads.
-    res = JoinHeads(
-        DotProductAttention(
-            SplitHeads(q), SplitHeads(k), SplitHeads(v), mask,
-            dropout=dropout, mode=mode, rng=rng))
+
+    def _split_into_heads(x):
+      """Reshapes tensors to prepare for multi-headed computation."""
+      # (b_size, seqlen, d_feature) --> (b_size, n_heads, seqlen, d_head)
+      x = jnp.reshape(x, (batch_size, -1, n_heads, d_head))
+      x = jnp.transpose(x, (0, 2, 1, 3))
+      return x
+
+    def _merge_heads(x):
+      """Undoes splitting, post multi-headed computation."""
+      # (b_size, n_heads, seqlen, d_head) --> (b_size, seqlen, d_feature)
+      x = jnp.transpose(x, (0, 2, 1, 3))
+      x = jnp.reshape(x, (batch_size, -1, n_heads * d_head))
+      return x
+
+    res = _merge_heads(
+        DotProductAttention(_split_into_heads(q),
+                            _split_into_heads(k),
+                            _split_into_heads(v),
+                            mask,
+                            dropout=dropout,
+                            mode=mode,
+                            rng=rng))
     return (res, mask), state  # Keep the mask.
 
 
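Note (not part of the patch): the two new helpers are exact inverses, since _merge_heads applies the transposition and reshape of _split_into_heads in reverse order. A minimal standalone sketch that sanity-checks the round trip; the shape values (batch_size=2, etc.) are toy numbers chosen only for illustration:

import jax.numpy as jnp

batch_size, seq_len, n_heads, d_head = 2, 5, 4, 8  # toy shapes for illustration
d_feature = n_heads * d_head

q = jnp.arange(batch_size * seq_len * d_feature, dtype=jnp.float32)
q = jnp.reshape(q, (batch_size, seq_len, d_feature))

# Split: (b_size, seqlen, d_feature) --> (b_size, n_heads, seqlen, d_head)
split = jnp.transpose(
    jnp.reshape(q, (batch_size, -1, n_heads, d_head)), (0, 2, 1, 3))
assert split.shape == (batch_size, n_heads, seq_len, d_head)

# Merge: (b_size, n_heads, seqlen, d_head) --> (b_size, seqlen, d_feature)
merged = jnp.reshape(
    jnp.transpose(split, (0, 2, 1, 3)), (batch_size, -1, n_heads * d_head))
assert merged.shape == (batch_size, seq_len, d_feature)
assert jnp.array_equal(merged, q)  # exact round trip
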
@@ -360,31 +371,38 @@ def CausalAttention(d_feature, n_heads=1, dropout=0.0, mode='train'):
   assert d_feature % n_heads == 0
   d_head = d_feature // n_heads
 
-  def compute_attention_heads(x):
-    batch_size = x.shape[0]
-    seqlen = x.shape[1]
-    # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
-    x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))
-    # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
-    x = jnp.transpose(x, (0, 2, 1, 3))
-    # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
-    return jnp.reshape(x, (-1, seqlen, d_head))
+  def _split_into_heads():
+    """Layer that reshapes tensors to prepare for multi-headed computation."""
+    def f(x):
+      batch_size = x.shape[0]
+      seq_len = x.shape[1]
 
-  ComputeAttentionHeads = Fn('ComputeAttentionHeads', compute_attention_heads)
+      # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
+      x = jnp.reshape(x, (batch_size, seq_len, n_heads, d_head))
+      x = jnp.transpose(x, (0, 2, 1, 3))
+      x = jnp.reshape(x, (-1, seq_len, d_head))
+      return x
+    return Fn('SplitIntoHeads', f)
 
-  def compute_attention_output(x):
-    seqlen = x.shape[1]
-    x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
-    x = jnp.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
-    return jnp.reshape(x, (-1, seqlen, n_heads * d_head))
+  def _merge_heads():
+    """Layer that undoes splitting, post multi-headed computation."""
+    def f(x):
+      seq_len = x.shape[1]
+
+      # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
+      x = jnp.reshape(x, (-1, n_heads, seq_len, d_head))
+      x = jnp.transpose(x, (0, 2, 1, 3))
+      x = jnp.reshape(x, (-1, seq_len, n_heads * d_head))
+      return x
+    return Fn('MergeHeads', f)
 
   return cb.Serial(
       cb.Branch(
-          [core.Dense(d_feature), ComputeAttentionHeads],
-          [core.Dense(d_feature), ComputeAttentionHeads],
-          [core.Dense(d_feature), ComputeAttentionHeads],
+          [core.Dense(d_feature), _split_into_heads()],
+          [core.Dense(d_feature), _split_into_heads()],
+          [core.Dense(d_feature), _split_into_heads()],
       ),
       DotProductCausalAttention(dropout=dropout, mode=mode),
-      Fn('ComputeAttentionOutput', compute_attention_output),
-      core.Dense(d_feature)
+      _merge_heads(),
+      core.Dense(d_feature),
   )
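Note (not part of the patch): unlike the class-based version in the first hunk, these Fn-wrapped layers fold the head dimension into the batch dimension, so DotProductCausalAttention sees (batch_size*n_heads, seq_len, d_head) activations. A minimal standalone sketch of that round trip, with toy shapes assumed for illustration; the merge recovers batch_size through the leading -1 in its reshapes:

import jax.numpy as jnp

batch_size, seq_len, n_heads, d_head = 2, 5, 4, 8  # toy shapes for illustration

x = jnp.arange(batch_size * seq_len * n_heads * d_head, dtype=jnp.float32)
x = jnp.reshape(x, (batch_size, seq_len, n_heads * d_head))

# Split: (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
h = jnp.reshape(x, (batch_size, seq_len, n_heads, d_head))
h = jnp.transpose(h, (0, 2, 1, 3))
h = jnp.reshape(h, (-1, seq_len, d_head))
assert h.shape == (batch_size * n_heads, seq_len, d_head)

# Merge: the leading -1 recovers batch_size from b_size*n_heads.
m = jnp.reshape(h, (-1, n_heads, seq_len, d_head))
m = jnp.transpose(m, (0, 2, 1, 3))
m = jnp.reshape(m, (-1, seq_len, n_heads * d_head))
assert jnp.array_equal(m, x)  # exact round trip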