Conversation

chenmoneygithub
Contributor

Resolve #749

@chenmoneygithub
Contributor Author

chenmoneygithub commented Feb 23, 2023

No unit tests added because the user-side behavior is unchanged.

Also, I will mark this as ready after the cache PR is merged.

@chenmoneygithub chenmoneygithub marked this pull request as draft February 23, 2023 05:59
Member

@mattdangerw mattdangerw left a comment

One meta comment on approach.

Another thing we should keep in mind is that this is, I think, precisely where left padding will become more efficient.

If you left pad your prompt originally, you start generating for all prompts right away, so it's more likely they will all "die" earlier. If you right pad, you might not even start generating on some sequences in the batch for a while.

Not something we need to solve on this PR, but something to think about!
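
For concreteness, here's a minimal sketch of the two layouts (toy token ids with pad id 0, made up for illustration — not code from this PR):

```python
import tensorflow as tf

# Two hypothetical prompts of different lengths, pad id = 0.
prompts = tf.ragged.constant([[11, 12, 13, 14], [21, 22]])

# Right padding: rows end at different indices, so some rows may not
# start emitting new tokens until the decoding loop reaches their last
# prompt token.
right_padded = prompts.to_tensor(default_value=0, shape=[2, 6])
# [[11 12 13 14  0  0]
#  [21 22  0  0  0  0]]

# Left padding: every row ends at the same index, so all samples emit
# new tokens from the very first decoding step and can hit an end token
# (and stop) sooner.
left_padded = tf.reverse(
    tf.reverse(prompts, axis=[1]).to_tensor(default_value=0, shape=[2, 6]),
    axis=[1],
)
# [[ 0  0 11 12 13 14]
#  [ 0  0  0  0 21 22]]
```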

@chenmoneygithub
Contributor Author

@mattdangerw Yeah, actually I tried left padding earlier. My finding: left padding does not work well with GPT2 and other models using absolute position embeddings. In my experiment, the generated text became chaotic when left padding was applied.

@mattdangerw
Member

> @mattdangerw Yeah, actually I tried left padding earlier. My finding: left padding does not work well with GPT2 and other models using absolute position embeddings. In my experiment, the generated text became chaotic when left padding was applied.

I actually think this is likely due to bugs in the attention mask and position embedding setup (both of which are complex in the left pad setup!). But if you do everything correctly, the computation is exactly the same as I understand it (e.g. greedy search output will be one-to-one identical). I can try to put together a colab with huggingface showing this.
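
For what it's worth, the usual fix for the position side under left padding is to derive position indices from the padding mask rather than assuming 0..n-1 — a minimal sketch (toy ids, not this repo's code):

```python
import tensorflow as tf

# Hypothetical left-padded batch, pad id = 0.
token_ids = tf.constant([[0, 0, 11, 12], [0, 21, 22, 23]])
padding_mask = tf.cast(tf.not_equal(token_ids, 0), tf.int32)

# A cumulative sum over the mask gives each real token its position
# within the sequence; the first real token of every row lands on
# position 0, no matter how much padding precedes it.
position_ids = tf.cumsum(padding_mask, axis=-1) - 1
position_ids = tf.maximum(position_ids, 0)  # clamp pad slots to 0
# [[0 0 0 1]
#  [0 0 1 2]]
```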

@mattdangerw
Member

One thing that might be worth noting in the left pad setup is that you really need to switch to a gather op for the position embedding, because your indices for the position embedding start varying per sample.

But overall, I am totally down to look at that as a follow up. Just wanted to point out a place where we are starting to leave performance on the table.
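
Concretely, the gather would look something like this (toy sizes and names, just to illustrate the op):

```python
import tensorflow as tf

max_len, hidden_dim = 8, 4  # toy sizes

# A learned absolute position embedding table.
position_table = tf.random.normal([max_len, hidden_dim])

# With right padding, a single slice position_table[:seq_len] serves the
# whole batch. With left padding, indices differ per sample, so gather
# per row using mask-derived position ids (as sketched above).
position_ids = tf.constant([[0, 0, 0, 1], [0, 0, 1, 2]])
position_embeddings = tf.gather(position_table, position_ids)
# shape: [batch=2, seq_len=4, hidden_dim]
```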

@chenmoneygithub
Contributor Author

@mattdangerw Yeah, I also suspected my code was buggy, and it was based on a much earlier version, so it's definitely worth a second trial.

@chenmoneygithub chenmoneygithub marked this pull request as ready for review February 28, 2023 05:51
Member

@mattdangerw mattdangerw left a comment

This would need tests too before we land.

```python
output_dim=self.feature_size,
),
keras.layers.Dense(self.vocab_size),
keras.layers.Softmax(),
```
Member

Why change this?

Contributor Author

It's because we now set from_logits=True by default (it was False earlier for some mysterious reason), so I am aligning the unit test with it.
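
In other words, the test model now ends in raw logits, roughly like this (a sketch with toy sizes, not the exact test code):

```python
from tensorflow import keras

vocab_size, feature_size = 10, 16  # toy sizes

# With from_logits=True, the loss applies the softmax internally (which
# is also more numerically stable), so the model should not end in a
# Softmax layer.
model = keras.Sequential([
    keras.layers.Embedding(input_dim=vocab_size, output_dim=feature_size),
    keras.layers.Dense(vocab_size),  # raw logits, no Softmax
])
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer="adam",
)
```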

Member

@mattdangerw mattdangerw left a comment

LGTM! approving!

Note that GitHub Actions appears to be down, so make sure to test out any changes locally! Left some comments about some weird XLA compilation errors I was seeing.

```python
# The index of the last non-padding token in prompt. Since all sequences
# are aligned to the right side, the index is the same for all.
current_index = max_length - num_steps
original_padding_mask = tf.cast(tf.identity(mask), dtype=tf.int32)
```
Member

We compute this twice (above as well); should we just pass it through?

Contributor Author

It's fairly cheap to compute. I did it this way because sample already has a confusing arg list, and I would like to keep it shorter (though not by much...).

Successfully merging this pull request may close these issues.

Do an actual stop when sampler sees an end token