Right now we always generate until `max_length`, but we could stop as soon as every sequence in the batch has produced the end token.

To do this, we can maintain a mask of unfinished sequences. Before calling `token_probability_fn`, use `tf.boolean_mask` to keep only the unfinished sequences, then use `tf.tensor_scatter_nd_update` to scatter the newly generated tokens back to their original batch indices.

This should significantly improve decoding performance when `max_length` is long.
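A minimal sketch of the idea, assuming eager execution, a statically known prompt length, greedy decoding, and that `token_probability_fn` maps a `[batch, length]` prompt to next-token probabilities of shape `[batch, vocab_size]` (the function and argument names here are illustrative, not the exact library API):

```python
import tensorflow as tf


def greedy_decode(token_probability_fn, prompt, max_length, end_token_id):
    """Greedy decoding that stops once every sequence has emitted `end_token_id`."""
    batch_size = tf.shape(prompt)[0]
    # True for sequences that have not produced the end token yet.
    unfinished = tf.ones([batch_size], dtype=tf.bool)

    # Assumes prompt.shape[1] is statically known (eager-mode sketch).
    for _ in range(max_length - prompt.shape[1]):
        # Early exit: every sequence in the batch has already finished.
        if not tf.reduce_any(unfinished):
            break

        # Run the model only on the unfinished sequences.
        unfinished_indices = tf.where(unfinished)  # shape [num_unfinished, 1]
        active_prompt = tf.boolean_mask(prompt, unfinished)
        probs = token_probability_fn(active_prompt)
        next_tokens = tf.cast(tf.argmax(probs, axis=-1), prompt.dtype)

        # Scatter the new tokens back to their original batch rows; finished
        # rows are simply padded with the end token.
        all_next = tf.fill([batch_size], tf.cast(end_token_id, prompt.dtype))
        all_next = tf.tensor_scatter_nd_update(
            all_next, unfinished_indices, next_tokens
        )

        prompt = tf.concat([prompt, all_next[:, tf.newaxis]], axis=1)
        unfinished = tf.logical_and(
            unfinished, all_next != tf.cast(end_token_id, prompt.dtype)
        )

    return prompt
```

With this, the per-step compute shrinks as sequences finish, and the loop exits entirely once the whole batch is done instead of always running `max_length` iterations.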