Add contrastive sampler #896
Very cool that this is working! We talked offline; for now I'm just submitting the comments from our live discussion.
Let's try to make it so that the model code does not need to special-case a sampling strategy, and the sampler code does not need to assume anything about the structure of the passed state.
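As a rough illustration of that separation (a toy sketch, not the code in this PR): the sampler only ever calls a `next_fn` callable and passes `state` through untouched, so the model side decides what the state contains. The names `next_fn`, `greedy_sample`, and `toy_next` are hypothetical and chosen only for this example.

```python
from typing import Any, Callable, List, Tuple

# Hypothetical contract: (tokens so far, opaque state, index) -> (logits, new state).
NextFn = Callable[[List[int], Any, int], Tuple[List[float], Any]]


def greedy_sample(next_fn: NextFn, prompt: List[int], state: Any, max_length: int) -> List[int]:
    """Greedy decoding that never inspects `state`; it only hands it back to `next_fn`."""
    tokens = list(prompt)
    for index in range(len(prompt), max_length):
        logits, state = next_fn(tokens, state, index)
        # Pick the id with the largest logit.
        tokens.append(max(range(len(logits)), key=logits.__getitem__))
    return tokens


def toy_next(tokens: List[int], state: int, index: int) -> Tuple[List[float], int]:
    """Toy "model": prefers (last token + 1) mod 5; its state is just a call counter."""
    vocab_size = 5
    target = (tokens[-1] + 1) % vocab_size
    logits = [1.0 if i == target else 0.0 for i in range(vocab_size)]
    return logits, state + 1


print(greedy_sample(toy_next, [0], 0, 4))  # [0, 1, 2, 3]
```

A different sampler could reuse `toy_next` unchanged, and a different model could swap in any state type (a cache tensor, a tuple, a dict) without the sampler changing.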
Also, let's benchmark this! I think there are two performance questions of interest.
Here's a colab that might be useful for benchmarking: https://colab.research.google.com/gist/mattdangerw/b31acb4b33ac13e6403fb043c9162657/rough-generation-benchmarking.ipynb
429f0b6 to 78a6dd0
@mattdangerw From the benchmark, we are faster OvO. The time cost I got from your colab:
I cannot share a GitHub gist while I am in China, but here is how I set up HF to use contrastive search:
I am marking this as ready for review.
9796e51 to 2f37199
/gcbrun
c0eaf4e to 47ad7a2
I still want to step through the main logic a bit more, but submitting these comments, as I am almost out of time for the day!
A few more comments as I dig in here!
OK! Finally looked through this in more detail! Great work, this is quite a readable implementation of a complex sampler. The one thing I found a bit confusing was the indexing here. I actually think we may want to change the way we count our indices across all samplers, so that loop index == cache index == index of the token being fed in. Then for this contrastive search we would not need to feed in an "out of bounds" looking index on the final pass (currently the final index is 30 if max_length is 30, which just seems weird to me). I think an update like this could make our overall code more readable. Here's an example commit -> mattdangerw@3b079b7 No need to do this on this PR, we can do this as a follow-up, but what do you think?
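To make the proposed convention concrete, here is a toy sketch (it does not reproduce the linked commit). It assumes a hypothetical `next_token` callable and a plain list standing in for the cache. Because the loop index is the index of the token being fed in, the largest index used is `max_length - 2`, so nothing that looks out of bounds is ever passed.

```python
def decode(prompt, max_length, next_token):
    """Toy decoding loop where loop index == cache index == index of the token fed in."""
    tokens = list(prompt)
    cache = [None] * max_length  # one slot per sequence position
    fed_indices = []
    for index in range(max_length - 1):
        cache[index] = tokens[index]  # "feed" the token at this index into the cache
        fed_indices.append(index)
        if index + 1 == len(tokens):
            # The token just fed was the last known one, so predict position index + 1.
            tokens.append(next_token(cache[: index + 1]))
    return tokens, fed_indices


# Toy predictor (an assumption for this example): next token is last token + 1.
tokens, fed = decode([7], max_length=4, next_token=lambda seen: seen[-1] + 1)
print(tokens)  # [7, 8, 9, 10]
print(fed)     # [0, 1, 2]
```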
6b10de7 to 2bda2d7
2bda2d7 to e5877b3
/gcbrun
Looks good! Just a last few comments!
Any ideas on how we can validate this is actually working correctly?
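One possible sanity check, sketched under assumptions rather than taken from this PR: compare the sampler against a tiny pure-Python reference for a single step. The reference below uses the usual contrastive search score, `(1 - alpha) * p(v) - alpha * max_j cos(h_v, h_j)`, over the top-k candidates. With `alpha = 0` the choice should match greedy search, and with a larger `alpha` a candidate whose hidden state duplicates the context should lose. The function names and toy vectors are hypothetical.

```python
import math


def cosine(a, b):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def contrastive_step(probs, candidate_hidden, context_hidden, k, alpha):
    """Return the token id chosen by one contrastive search step.

    probs: next-token probabilities, indexed by token id.
    candidate_hidden: hidden vector each token would have if appended, indexed by token id.
    context_hidden: hidden vectors of the tokens already in the sequence.
    """
    top_k = sorted(range(len(probs)), key=lambda v: probs[v], reverse=True)[:k]

    def score(v):
        # Degeneration penalty: similarity to the closest previous token.
        penalty = max(cosine(candidate_hidden[v], h) for h in context_hidden)
        return (1 - alpha) * probs[v] - alpha * penalty

    return max(top_k, key=score)


probs = [0.1, 0.5, 0.4]
candidate_hidden = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
context_hidden = [[1.0, 0.0]]

# alpha = 0 reduces to greedy search: the most probable token wins.
print(contrastive_step(probs, candidate_hidden, context_hidden, k=2, alpha=0.0))  # 1
# With a penalty, token 1 (identical to the context) loses to token 2.
print(contrastive_step(probs, candidate_hidden, context_hidden, k=2, alpha=0.6))  # 2
```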
/gcbrun
/gcbrun
This turned out to require lots of design, so I am taking over the work by myself.
To play with the API, use the code below: