86 changes: 52 additions & 34 deletions trax/layers/attention.py
@@ -195,23 +195,34 @@ def forward_with_state(self, x, weights, state, rng):
     del weights
     n_heads, dropout, mode = self._n_heads, self._dropout, self._mode
     q, k, v, mask = x
+    batch_size = q.shape[0]
     d_feature = q.shape[-1]
+
     assert d_feature % n_heads == 0
     d_head = d_feature // n_heads
-    nbatch = jnp.shape(q)[0]
-    # nbatch, seqlen, d_feature --> nbatch, n_heads, seqlen, d_head
-    def SplitHeads(x):
-      return jnp.transpose(
-          jnp.reshape(x, (nbatch, -1, n_heads, d_head)), (0, 2, 1, 3))
-    # nbatch, n_heads, seqlen, d_head --> nbatch, seqlen, d_feature
-    def JoinHeads(x):  # pylint: disable=invalid-name
-      return jnp.reshape(
-          jnp.transpose(x, (0, 2, 1, 3)), (nbatch, -1, n_heads * d_head))
-    # Split heads, dot-product attention, rejoin heads.
-    res = JoinHeads(
-        DotProductAttention(
-            SplitHeads(q), SplitHeads(k), SplitHeads(v), mask,
-            dropout=dropout, mode=mode, rng=rng))
+
+    def _split_into_heads(x):
+      """Reshapes tensors to prepare for multi-headed computation."""
+      # (b_size, seqlen, d_feature) --> (b_size, n_heads, seqlen, d_head)
+      x = jnp.reshape(x, (batch_size, -1, n_heads, d_head))
+      x = jnp.transpose(x, (0, 2, 1, 3))
+      return x
+
+    def _merge_heads(x):
+      """Undoes splitting, post multi-headed computation."""
+      # (b_size, n_heads, seqlen, d_head) --> (b_size, seqlen, d_feature)
+      x = jnp.transpose(x, (0, 2, 1, 3))
+      x = jnp.reshape(x, (batch_size, -1, n_heads * d_head))
+      return x
+
+    res = _merge_heads(
+        DotProductAttention(_split_into_heads(q),
+                            _split_into_heads(k),
+                            _split_into_heads(v),
+                            mask,
+                            dropout=dropout,
+                            mode=mode,
+                            rng=rng))
     return (res, mask), state  # Keep the mask.
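
Note that the new helper pair is an exact inverse: splitting into heads and then merging returns the original tensor, so the rewrite preserves the old SplitHeads/JoinHeads behavior exactly. Below is a minimal standalone sketch of that round trip; the toy dimensions are illustrative and not part of this diff.

import jax.numpy as jnp

batch_size, seq_len, n_heads, d_head = 2, 5, 4, 8
d_feature = n_heads * d_head

q = jnp.ones((batch_size, seq_len, d_feature))

# (b_size, seqlen, d_feature) --> (b_size, n_heads, seqlen, d_head)
split = jnp.transpose(
    jnp.reshape(q, (batch_size, -1, n_heads, d_head)), (0, 2, 1, 3))
assert split.shape == (batch_size, n_heads, seq_len, d_head)

# Inverse: (b_size, n_heads, seqlen, d_head) --> (b_size, seqlen, d_feature)
merged = jnp.reshape(
    jnp.transpose(split, (0, 2, 1, 3)), (batch_size, -1, n_heads * d_head))
assert merged.shape == q.shape
assert jnp.array_equal(merged, q)  # Split followed by merge is the identity.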


@@ -360,31 +371,38 @@ def CausalAttention(d_feature, n_heads=1, dropout=0.0, mode='train'):
   assert d_feature % n_heads == 0
   d_head = d_feature // n_heads
 
-  def compute_attention_heads(x):
-    batch_size = x.shape[0]
-    seqlen = x.shape[1]
-    # n_batch, seqlen, n_heads*d_head -> n_batch, seqlen, n_heads, d_head
-    x = jnp.reshape(x, (batch_size, seqlen, n_heads, d_head))
-    # n_batch, seqlen, n_heads, d_head -> n_batch, n_heads, seqlen, d_head
-    x = jnp.transpose(x, (0, 2, 1, 3))
-    # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
-    return jnp.reshape(x, (-1, seqlen, d_head))
+  def _split_into_heads():
+    """Layer that reshapes tensors to prepare for multi-headed computation."""
+    def f(x):
+      batch_size = x.shape[0]
+      seq_len = x.shape[1]
 
-  ComputeAttentionHeads = Fn('ComputeAttentionHeads', compute_attention_heads)
+      # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
+      x = jnp.reshape(x, (batch_size, seq_len, n_heads, d_head))
+      x = jnp.transpose(x, (0, 2, 1, 3))
+      x = jnp.reshape(x, (-1, seq_len, d_head))
+      return x
+    return Fn('SplitIntoHeads', f)
 
-  def compute_attention_output(x):
-    seqlen = x.shape[1]
-    x = jnp.reshape(x, (-1, n_heads, seqlen, d_head))
-    x = jnp.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
-    return jnp.reshape(x, (-1, seqlen, n_heads * d_head))
+  def _merge_heads():
+    """Layer that undoes splitting, post multi-headed computation."""
+    def f(x):
+      seq_len = x.shape[1]
+
+      # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
+      x = jnp.reshape(x, (-1, n_heads, seq_len, d_head))
+      x = jnp.transpose(x, (0, 2, 1, 3))
+      x = jnp.reshape(x, (-1, seq_len, n_heads * d_head))
+      return x
+    return Fn('MergeHeads', f)
 
   return cb.Serial(
       cb.Branch(
-          [core.Dense(d_feature), ComputeAttentionHeads],
-          [core.Dense(d_feature), ComputeAttentionHeads],
-          [core.Dense(d_feature), ComputeAttentionHeads],
+          [core.Dense(d_feature), _split_into_heads()],
+          [core.Dense(d_feature), _split_into_heads()],
+          [core.Dense(d_feature), _split_into_heads()],
      ),
      DotProductCausalAttention(dropout=dropout, mode=mode),
-      Fn('ComputeAttentionOutput', compute_attention_output),
-      core.Dense(d_feature)
+      _merge_heads(),
+      core.Dense(d_feature),
  )
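
Unlike the hunk above, the CausalAttention variant folds heads into the batch dimension, (b_size*n_heads, seq_len, d_head), so DotProductCausalAttention sees a plain batch of single-head inputs, and Branch feeds the same activations through three Dense+split stacks to produce q, k, and v. A hedged usage sketch follows, assuming a trax version of this era exposing tl.CausalAttention, trax.shapes.signature, and direct layer calls; the dimensions are illustrative, not from this PR.

import numpy as np
from trax import layers as tl
from trax import shapes

# Build the refactored layer; 'eval' mode disables dropout.
layer = tl.CausalAttention(d_feature=32, n_heads=4, mode='eval')
x = np.ones((2, 5, 32), dtype=np.float32)

# Trax layers are initialized from an input signature before first use.
layer.init(shapes.signature(x))
y = layer(x)
assert y.shape == (2, 5, 32)  # d_feature is preserved end to end.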