
[FEATURE] T5NMTInference model and incremental decoding (#1498)
* implement T5 incremental decoding

* add T5NMTInference model and test cases; debug incremental decoding

* add docs for relative position computation; use t5-small for nmt test
Yongyi (Ethan) Wu committed Jan 21, 2021
1 parent 2cfa894 commit dfdfb6d
Showing 5 changed files with 306 additions and 103 deletions.
48 changes: 48 additions & 0 deletions src/gluonnlp/attention_cell.py
@@ -602,6 +602,54 @@ def __repr__(self):
dtype=self._dtype)


def gen_rel_position(data, past_data=None, dtype=np.int32, layout='NT'):
    """Create a matrix of relative positions for RelAttentionScoreCell.

    The relative position is defined as the index difference: `mem_i` - `query_j`.
    Note that this implementation is only meaningful for self-attention, not for
    cross-attention: both `mem_i` and `query_j` are time indices into `data` (or,
    under incremental decoding, into the concatenation of the previous steps'
    `past_data` and the current stepwise `data`).

    Parameters
    ----------
    data
        The data. Under incremental decoding, seq_length = 1.

        - layout = 'NT'
            Shape (batch_size, seq_length, C)
        - layout = 'TN'
            Shape (seq_length, batch_size, C)
    past_data
        Only used under incremental decoding: the stacked data from previous steps.
    dtype
        Data type of the relative position matrix.
    layout
        Layout of `data` and `past_data`.

    Returns
    -------
    relative_position :
        Shape (total_seq_length, total_seq_length), where total_seq_length covers
        `data` plus, under incremental decoding, `past_data`.
    """
    time_axis = 1 if layout == 'NT' else 0
    if past_data is None:
        position = npx.arange_like(data, axis=time_axis)
    else:
        # Incremental decoding only. The cached states are fused back into the
        # data layout (inner_dim = num_heads * n_kv):
        #   NT(NTK): (B, L_seq, num_heads, n_kv) -> (B, L_seq, inner_dim)
        #   TN(TNK): (L_seq, B, num_heads, n_kv) -> (L_seq, B, inner_dim)
        # In npx.reshape, -2 copies a dimension and -5 fuses two consecutive dims.
        past_data = npx.reshape(past_data, (-2, -2, -5))
        position = npx.arange_like(
            np.concatenate([past_data, data], axis=time_axis),
            axis=time_axis
        )
    query_position = np.expand_dims(position, axis=-1)
    mem_position = np.expand_dims(position, axis=0)
    relative_position = mem_position - query_position
    return relative_position.astype(dtype)  # shape (total_seq_length, total_seq_length)


class RelAttentionScoreCell(HybridBlock):
    r"""Get the score based on the query and relative position index. This is used for
    implementing relative attention.
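
For context, here is a minimal usage sketch of the new gen_rel_position helper. It is not part of the commit; it assumes MXNet 2.x with the numpy interface enabled, and the shapes (batch_size=2, num_heads=2, n_kv=4) are illustrative choices for this example.

from mxnet import np, npx
npx.set_np()

from gluonnlp.attention_cell import gen_rel_position

# Full-sequence case: batch_size=2, seq_length=4, C=8, layout='NT'.
data = np.zeros((2, 4, 8))
print(gen_rel_position(data))
# [[ 0  1  2  3]
#  [-1  0  1  2]
#  [-2 -1  0  1]
#  [-3 -2 -1  0]]
# Row j, column i holds mem_i - query_j: positions to the right of the
# query get positive offsets, positions to the left get negative ones.

# Incremental decoding: one new step plus a cache of three past steps,
# stored as (batch_size, past_len, num_heads, n_kv); inner_dim = 2 * 4 = 8.
step = np.zeros((2, 1, 8))
cache = np.zeros((2, 3, 2, 4))
print(gen_rel_position(step, past_data=cache).shape)
# (4, 4) -- the matrix covers the full concatenated sequence.

Even when the stepwise query has length 1, the returned matrix spans the whole concatenated sequence; a caller that only needs the current step's offsets can slice out the corresponding row.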
