From 8d7541431dde1724b231be20d9170ee4407398ab Mon Sep 17 00:00:00 2001
From: Jonni Kanerva
Date: Wed, 29 Apr 2020 14:38:55 -0700
Subject: [PATCH] Micro-cleanup: use reshape and transpose *methods* for cleaner code.

Also:
- "seqlen" --> "seq_len"
- slight rewrite of EncoderDecoderMask

PiperOrigin-RevId: 309099470
---
 trax/layers/attention.py | 33 +++++++++++++++++----------------
 1 file changed, 17 insertions(+), 16 deletions(-)

diff --git a/trax/layers/attention.py b/trax/layers/attention.py
index 616e9f360..51da63dcb 100644
--- a/trax/layers/attention.py
+++ b/trax/layers/attention.py
@@ -109,16 +109,16 @@ def forward_with_state(self, x, weights, state, rng):
 
     def _split_into_heads(x):
       """Reshapes tensors to prepare for multi-headed computation."""
-      # (b_size, seqlen, d_feature) --> (b_size, n_heads, seqlen, d_head)
-      x = jnp.reshape(x, (batch_size, -1, n_heads, d_head))
-      x = jnp.transpose(x, (0, 2, 1, 3))
+      # (b_size, seq_len, d_feature) --> (b_size, n_heads, seq_len, d_head)
+      x = x.reshape((batch_size, -1, n_heads, d_head))
+      x = x.transpose((0, 2, 1, 3))
       return x
 
     def _merge_heads(x):
       """Undoes splitting, post multi-headed computation."""
-      # (b_size, n_heads, seqlen, d_head) --> (b_size, seqlen, d_feature)
-      x = jnp.transpose(x, (0, 2, 1, 3))
-      x = jnp.reshape(x, (batch_size, -1, n_heads * d_head))
+      # (b_size, n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
+      x = x.transpose((0, 2, 1, 3))
+      x = x.reshape((batch_size, -1, n_heads * d_head))
       return x
 
     res = _merge_heads(
@@ -191,9 +191,9 @@ def f(x):
     seq_len = x.shape[1]
 
     # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
-    x = jnp.reshape(x, (batch_size, seq_len, n_heads, d_head))
-    x = jnp.transpose(x, (0, 2, 1, 3))
-    x = jnp.reshape(x, (-1, seq_len, d_head))
+    x = x.reshape((batch_size, seq_len, n_heads, d_head))
+    x = x.transpose((0, 2, 1, 3))
+    x = x.reshape((-1, seq_len, d_head))
     return x
   return Fn('SplitIntoHeads', f)
 
@@ -203,9 +203,9 @@ def f(x):
     seq_len = x.shape[1]
 
     # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
-    x = jnp.reshape(x, (-1, n_heads, seq_len, d_head))
-    x = jnp.transpose(x, (0, 2, 1, 3))
-    x = jnp.reshape(x, (-1, seq_len, n_heads * d_head))
+    x = x.reshape((-1, n_heads, seq_len, d_head))
+    x = x.transpose((0, 2, 1, 3))
+    x = x.reshape((-1, seq_len, n_heads * d_head))
     return x
   return Fn('MergeHeads', f)
 
@@ -293,11 +293,12 @@ def f(x):
 
 def EncoderDecoderMask():
   """Makes encoder-decoder mask from decoder input and a padding mask."""
-  def f(decoder_input, padding_mask):
-    padding_mask = jnp.reshape(
-        padding_mask, (padding_mask.shape[0], 1, 1, padding_mask.shape[-1]))
+  def f(decoder_input, mask):
+    batch_size = mask.shape[0]
+    d_feature = mask.shape[-1]
+    mask = mask.reshape((batch_size, 1, 1, d_feature))
     # Final mask shape is [batch, 1 for heads, decoder-len, encoder-len].
-    return padding_mask + jnp.zeros((1, 1, decoder_input.shape[1], 1))
+    return mask + jnp.zeros((1, 1, decoder_input.shape[1], 1))
   return Fn('EncoderDecoderMask', f)
 
 
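
Note on the change itself: jax.numpy arrays expose numpy-style reshape and
transpose methods, so the method calls above are drop-in replacements for the
jnp.reshape/jnp.transpose function calls they remove. A minimal sketch of that
equivalence (illustrative values only; none of these names come from
attention.py):

    import jax.numpy as jnp

    batch_size, seq_len, d_feature = 2, 6, 8
    n_heads, d_head = 4, 2
    x = jnp.arange(batch_size * seq_len * d_feature).reshape(
        (batch_size, seq_len, d_feature))

    # Function style, as in the old code.
    a = jnp.transpose(
        jnp.reshape(x, (batch_size, -1, n_heads, d_head)), (0, 2, 1, 3))
    # Method style, as in the new code.
    b = x.reshape((batch_size, -1, n_heads, d_head)).transpose((0, 2, 1, 3))

    assert a.shape == b.shape == (batch_size, n_heads, seq_len, d_head)
    assert (a == b).all()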