Do an actual stop when sampler sees an end token #749

@chenmoneygithub

Description

Currently we always generate until `max_length`, but we could stop as soon as every sequence in the batch has produced the end token.

To do this, we can maintain a mask over the batch marking which sequences are finished. Before calling `token_probability_fn`, use `tf.boolean_mask` to keep only the unfinished sequences, then use `tf.tensor_scatter_nd_update` to scatter the predicted tokens back to the right batch indices in the token updates.
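A minimal sketch of the idea (names like `END_TOKEN`, `MAX_LENGTH`, and the toy `token_probability_fn` are hypothetical stand-ins, not the actual sampler code):

```python
import tensorflow as tf

END_TOKEN = 2   # hypothetical end-token id
MAX_LENGTH = 8  # hypothetical decode budget

def token_probability_fn(seqs):
    # Stand-in for a real model: always predicts the end token,
    # so every sequence finishes on its first generated step.
    batch = tf.shape(seqs)[0]
    return tf.one_hot(tf.fill([batch], END_TOKEN), depth=10)

def generate(prompt):
    seqs = prompt  # [batch, prompt_len], int32
    finished = tf.zeros([tf.shape(seqs)[0]], dtype=tf.bool)
    for _ in range(MAX_LENGTH - prompt.shape[1]):
        if tf.reduce_all(finished):
            break  # every sequence has seen the end token: stop early
        # Run the model only on the unfinished sequences.
        live_idx = tf.where(~finished)                 # [num_live, 1]
        live_seqs = tf.boolean_mask(seqs, ~finished)   # [num_live, cur_len]
        probs = token_probability_fn(live_seqs)
        live_next = tf.argmax(probs, axis=-1, output_type=tf.int32)
        # Scatter live predictions back to their original batch rows;
        # finished rows are padded with the end token.
        next_tokens = tf.fill([tf.shape(seqs)[0]], END_TOKEN)
        next_tokens = tf.tensor_scatter_nd_update(next_tokens, live_idx, live_next)
        seqs = tf.concat([seqs, next_tokens[:, None]], axis=1)
        finished = finished | (next_tokens == END_TOKEN)
    return seqs
```

With the toy model above, `generate(tf.constant([[1], [1]], dtype=tf.int32))` returns a `[2, 2]` tensor and exits after a single step rather than looping to `MAX_LENGTH`.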

This should significantly improve decoding performance when `max_length` is large.
