
Conversation

@chenmoneygithub (Contributor) commented Mar 21, 2023

This turned out to require a lot of design work, so I am taking over the work myself.

To play with the API, use the code below:

!pip install -q -U git+https://github.com/chenmoneygithub/keras-nlp.git@contrastive

import keras_nlp

gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.compile(run_eagerly=False, jit_compile=True, sampler="contrastive")

print(
    gpt2_lm.generate(
        ["that's weird", "that's even weirder"],
        max_length=10,
    )
)

@mattdangerw (Member) left a comment


Very cool that this is working! We can talk offline; for now I'm just submitting the comments from our live discussion.

Let's try to make it so that the model code does not need to special case a sampling strategy, and the sampler code does not need to assume anything about the structure of the passed state.
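A rough sketch of the kind of decoupling being suggested here. The names below (next_fn, pick_fn, state) are hypothetical and not the actual KerasNLP interface; the point is only that the sampler threads the model's state through without inspecting it:

import numpy as np

def greedy_pick(logits):
    # Sampler-specific token choice; plain argmax as a stand-in.
    return int(np.argmax(logits))

def sample(next_fn, prompt, state, num_steps, pick_fn=greedy_pick):
    # `next_fn(tokens, state) -> (logits, state)` is supplied by the model.
    # The sampler treats `state` as opaque, so a model that also returns
    # hidden states (for contrastive search) and one that does not can
    # share the same loop.
    tokens = list(prompt)
    for _ in range(num_steps):
        logits, state = next_fn(tokens, state)
        tokens.append(pick_fn(logits))
    return tokens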

@mattdangerw (Member) commented:

Also, let's benchmark this! I think there are two performance questions of interest.

  • First, we should make sure our implementation of contrastive search is competitive with what else is out there (a rough timing sketch follows this list).
  • Second, we should make sure that if we always return the hidden state from the model, we will not slow down our current "top_p" or "top_k" search, etc. I'm pretty sure this will be fine, as the unused hidden_state that we return should "compile out" when we trace.
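A rough sketch of how such a benchmark might be run, reusing the preset from the snippet above; the prompt, max_length, warm-up call, and run count here are arbitrary assumptions rather than the setup of the actual benchmark:

import time

import keras_nlp

gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

def time_sampler(sampler, runs=25):
    # Recompile with the sampler under test, then time repeated generate calls.
    gpt2_lm.compile(run_eagerly=False, jit_compile=True, sampler=sampler)
    gpt2_lm.generate(["warm up"], max_length=64)  # exclude tracing/compile time
    start = time.perf_counter()
    for _ in range(runs):
        gpt2_lm.generate(["that's weird"], max_length=64)
    return time.perf_counter() - start

for sampler in ("top_k", "contrastive"):
    print(sampler, time_sampler(sampler))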


@chenmoneygithub (Contributor, Author) commented:

@mattdangerw From the benchmark, we are faster OvO. The time cost I got from your Colab:

  • our contrastive search: 55.54s for 25 runs.
  • HuggingFace: 62.98s for 25 runs.

I cannot share a GitHub gist while I am in China, but here is how I set HuggingFace to use contrastive search:

generation_kwargs = {"penalty_alpha": 0.2, "max_new_tokens": 256, "top_k": 5}
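For reference, a sketch of the HuggingFace side of the comparison; only generation_kwargs above comes from this comment, while the checkpoint, prompt, and timing loop are assumptions:

import time

from transformers import AutoTokenizer, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# penalty_alpha > 0 together with top_k > 1 selects contrastive search in HF.
generation_kwargs = {"penalty_alpha": 0.2, "max_new_tokens": 256, "top_k": 5}

inputs = tokenizer("that's weird", return_tensors="pt")
start = time.perf_counter()
for _ in range(25):
    model.generate(**inputs, **generation_kwargs)
print(f"{time.perf_counter() - start:.2f}s for 25 runs")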

I am marking this as ready for review.

@chenmoneygithub marked this pull request as ready for review on March 26, 2023.
@mattdangerw self-assigned this on Mar 29, 2023.
@chenmoneygithub (Contributor, Author) commented:

/gcbrun

@mattdangerw (Member) left a comment


I still want to step through the main logic a bit more, but submitting these comments, as I am almost out of time for the day!

@mattdangerw (Member) left a comment


A few more comments as I dig in here!

@mattdangerw (Member) commented:

OK! I finally looked through this in more detail! Great work, this is quite a readable implementation of a complex sampler.

The one thing I found a bit confusing was the indexing here. I actually think we may want to change the way we count our indices across all samplers, so that loop index == cache index == index of the token being fed in. Then for contrastive search we would not need to feed in an "out of bounds"-looking index on the final pass (currently the final index is 30 if max_length is 30, which just seems weird to me).

I think an update like this could make our overall code more readable. Here's an example commit -> mattdangerw@3b079b7
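A toy illustration of the proposed convention (simplified and hypothetical, not the actual sampler loop): at step index, the token at position index is fed in, cache slot index is updated, and the token for position index + 1 is produced, so the loop never touches an index equal to max_length:

prompt = [101, 7592, 2088]  # already-known token ids
max_length = 6
tokens = prompt + [0] * (max_length - len(prompt))

for index in range(len(prompt) - 1, max_length - 1):
    current_token = tokens[index]   # loop index == index of the token fed in
    # cache[index] would be updated here with this token's key/value state.
    next_token = current_token + 1  # stand-in for the model forward pass + sampler
    tokens[index + 1] = next_token

print(tokens)  # the final loop index is max_length - 2, never out of bounds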

No need to do this on this PR, we can do this as a follow up, but what do you think?

@chenmoneygithub (Contributor, Author) commented:

/gcbrun

@mattdangerw (Member) left a comment


Looks good! Just a few last comments!

Any ideas on how we can validate this is actually working correctly?

@chenmoneygithub (Contributor, Author) commented:

/gcbrun

@chenmoneygithub (Contributor, Author) commented:

/gcbrun

@chenmoneygithub merged commit 12549ce into keras-team:master on Apr 7, 2023.