Skip to content

Commit

Permalink
Merge 8d75414 into 098f5cf
Browse files Browse the repository at this point in the history
  • Loading branch information
trax-robot committed Apr 29, 2020
2 parents 098f5cf + 8d75414 commit 7095f0c
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions trax/layers/attention.py
Expand Up @@ -109,16 +109,16 @@ def forward_with_state(self, x, weights, state, rng):

def _split_into_heads(x):
"""Reshapes tensors to prepare for multi-headed computation."""
# (b_size, seqlen, d_feature) --> (b_size, n_heads, seqlen, d_head)
x = jnp.reshape(x, (batch_size, -1, n_heads, d_head))
x = jnp.transpose(x, (0, 2, 1, 3))
# (b_size, seq_len, d_feature) --> (b_size, n_heads, seq_len, d_head)
x = x.reshape((batch_size, -1, n_heads, d_head))
x = x.transpose((0, 2, 1, 3))
return x

def _merge_heads(x):
"""Undoes splitting, post multi-headed computation."""
# (b_size, n_heads, seqlen, d_head) --> (b_size, seqlen, d_feature)
x = jnp.transpose(x, (0, 2, 1, 3))
x = jnp.reshape(x, (batch_size, -1, n_heads * d_head))
# (b_size, n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
x = x.transpose((0, 2, 1, 3))
x = x.reshape((batch_size, -1, n_heads * d_head))
return x

res = _merge_heads(
Expand Down Expand Up @@ -191,9 +191,9 @@ def f(x):
seq_len = x.shape[1]

# (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
x = jnp.reshape(x, (batch_size, seq_len, n_heads, d_head))
x = jnp.transpose(x, (0, 2, 1, 3))
x = jnp.reshape(x, (-1, seq_len, d_head))
x = x.reshape((batch_size, seq_len, n_heads, d_head))
x = x.transpose((0, 2, 1, 3))
x = x.reshape((-1, seq_len, d_head))
return x
return Fn('SplitIntoHeads', f)

Expand All @@ -203,9 +203,9 @@ def f(x):
seq_len = x.shape[1]

# (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
x = jnp.reshape(x, (-1, n_heads, seq_len, d_head))
x = jnp.transpose(x, (0, 2, 1, 3))
x = jnp.reshape(x, (-1, seq_len, n_heads * d_head))
x = x.reshape((-1, n_heads, seq_len, d_head))
x = x.transpose((0, 2, 1, 3))
x = x.reshape((-1, seq_len, n_heads * d_head))
return x
return Fn('MergeHeads', f)

Expand Down Expand Up @@ -293,11 +293,12 @@ def f(x):

def EncoderDecoderMask():
"""Makes encoder-decoder mask from decoder input and a padding mask."""
def f(decoder_input, padding_mask):
padding_mask = jnp.reshape(
padding_mask, (padding_mask.shape[0], 1, 1, padding_mask.shape[-1]))
def f(decoder_input, mask):
batch_size = mask.shape[0]
d_feature = mask.shape[-1]
mask = mask.reshape((batch_size, 1, 1, d_feature))
# Final mask shape is [batch, 1 for heads, decoder-len, encoder-len].
return padding_mask + jnp.zeros((1, 1, decoder_input.shape[1], 1))
return mask + jnp.zeros((1, 1, decoder_input.shape[1], 1))
return Fn('EncoderDecoderMask', f)


Expand Down

0 comments on commit 7095f0c

Please sign in to comment.