FEAT [Generation]: Introduce a centralized API to switch between cache implementations #29030
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Looks very clean at a glance! I have some small nits regarding phrasing & the method name.
Thanks for tackling this 🤗
src/transformers/generation/utils.py (outdated)

```diff
@@ -351,6 +359,43 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
                 "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
             )
 
+    def switch_cache_implementation(self, cache_implementation: Union[CacheImplementation, str], **kwargs):
```
In practice we are obviously always switching from one cache implementation to another, but for users it's likely more intuitive to simply set a new cache implementation, so `set_cache_implementation` might be a better method name.
I agree! Thanks for pointing that out!
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Some more small nits regarding the phrasing now that we're using `set_...`.
Much in favor of this! 👍
```diff
@@ -351,6 +359,43 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
                 "A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
             )
 
+    def set_cache_implementation(self, cache_implementation: Union[CacheImplementation, str], **kwargs):
```
I'd rather have this function as `set_cache`, with `past_key_values` as an optional parameter. But this is a tougher discussion: I've opened the discussion on our internal Slack here 🤗
Looks great! Thanks for adding.
Agreed `set_` is preferable to `switch_`.
Only request is for tests!
src/transformers/generation/utils.py (outdated)

```python
self.generation_config.sink_window_length = kwargs.get("window_length")
self.generation_config.num_sink_tokens = kwargs.get("num_sink_tokens")
```
Just for my own understanding: if the currently set cache implementation is sink, with e.g. `window_length=508, num_sink_tokens=4`, should I be allowed to call `set_cache_implementation` with updated parameters, i.e. `config.set_cache_implementation(CacheImplementation.SINK, window_length=1_016, num_sink_tokens=8)`?
Yes, I think so!
Per my understanding, the way the API is currently designed for sink cache is to pass a `SinkCache()` object to `generate()` through `past_key_values`. If `past_key_values` is already passed to `generate()`, the way I designed things in this PR we only use the cache from `set_cache_implementation` in case `past_key_values` is not passed to `generate()`.
TLDR: if one calls `model.set_cache_implementation(CacheImplementation.SINK)`, they shouldn't call `model.generate(xxx, past_key_values=SinkCache())` --> maybe we should raise a warning saying that they already called `model.set_cache_impl()`, wdyt?

Commit is here: 2810ffa
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Down for that! Let's not add enums where we don't need them; let's make it easy to use custom cache classes and init them with a generation config, no?
```python
class CacheImplementation(str, Enum):
    DYNAMIC = "dynamic"
    STATIC = "static"
    SINK = "sink"
```
This prevents anyone from adding or using a custom implementation. Why not just use strings?
I think in general it is good practice to use enums to avoid silent behaviours/errors. E.g. if one passes
`model.generation_config.cache_implementation = "Static"`
the code will silently work out of the box, as it will use the dynamic cache, and can potentially lead to silent errors.
When using enums,
`model.set_cache_implementation("Static")`
will also work out of the box, but this time will correctly use the static cache implementation and not the dynamic cache, as opposed to the snippet above.
We can also use a mapping with hardcoded strings, but I found it clearer to have enums.
cc @amyeroberts @gante @tomaarsen wdyt?
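To make the contrast concrete, here is a hedged sketch of how an enum-backed setter could validate its input; the case-insensitive normalization is an assumption for illustration, not necessarily what the PR does. Plain string assignment accepts any value silently, while coercing through the enum rejects unknown values loudly.

```python
from enum import Enum

class CacheImplementation(str, Enum):
    DYNAMIC = "dynamic"
    STATIC = "static"
    SINK = "sink"

def set_cache_implementation(value):
    # Hypothetical: normalize case, then validate through the enum.
    # Anything not in the enum raises ValueError instead of being
    # silently treated as the default dynamic cache later on.
    return CacheImplementation(str(value).lower())
```

With this, `set_cache_implementation("Static")` resolves to `CacheImplementation.STATIC`, while a typo like `"statik"` raises immediately.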
Good practice does not mean it should always be used. Here it's:
- useless: we need a mapping with keys from `"string_cls": cls` anyways.
- cumbersome: anywhere you compare `generation_config.cache_cls` (a string) you need to use the enum. Why?
- not restrictive; and let me be clear here, I do NOT want users to pass anything and expect it to work. If the class is not in the mapping, error out and we are done with it. We should raise an error.

TLDR: why we would use them when we don't need them, when it only adds extra calls (`CacheImplementation.DYNAMIC` vs `"dynamic"`), is something I don't understand.
@ArthurZucker I don't understand your objections here. For one, best practice does mean it should be applied wherever possible.
I don't believe using enums is useless or cumbersome (aside from these not being well defined)? As @younesbelkada highlights, they provide better guarantees of selecting valid types. The explicit enum means the user can pass a string, but the code uses stricter checks than string matching, which is error prone.

> not restrictive, and let me be clear here I do NOT want users to pass anything and expect it to work. If the class is not in the mapping, error out and we are done with it. We should raise an error.

Enums are restrictive? In fact, they're far more restrictive than doing string checks, and have stronger guarantees than checking a mapping, which is mutable.
But that is exactly what we want: the mapping should be mutable, because we want it to be easily adapted for custom code on the Hub.
Down for stricter checks, but in this specific case IMO it is cumbersome and useless, as we don't have the notion of "safety" here, and we basically use mappings for `model_type`, `config_type`, etc.
Anyway it's not that important
Ah, OK, I understand better now. I'm not sure how we can both let users modify a dict and have guarantees about not working if users can pass anything. If I've understood correctly, is the check you're wanting on the membership within the dictionary rather than e.g. how it's handled within the setting logic? In this case, I agree dictionaries are probably the simplest solution.
Yes exactly 😉
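The dictionary-based approach being agreed on here could look roughly like this; the names are illustrative, and the actual mapping in transformers may differ.

```python
# Toy cache classes standing in for the real implementations.
class DynamicCache: ...
class StaticCache: ...
class SinkCache: ...

# A plain, mutable mapping: membership is the validity check, and
# custom implementations (e.g. from code on the Hub) just add a key.
CACHE_CLASSES = {
    "dynamic": DynamicCache,
    "static": StaticCache,
    "sink": SinkCache,
}

def get_cache_class(name):
    if name not in CACHE_CLASSES:
        raise ValueError(
            f"Unknown cache implementation {name!r}; "
            f"expected one of {sorted(CACHE_CLASSES)}"
        )
    return CACHE_CLASSES[name]
```

This keeps the strict membership check Arthur asked for while leaving the registry open for extension, the same pattern used for `model_type`/`config_type` mappings.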
```python
if getattr(generation_config, "cache_implementation", "") == CacheImplementation.SINK:
    if "past_key_values" not in kwargs:
        kwargs["past_key_values"] = SinkCache(
            window_length=generation_config.sink_window_length,
            num_sink_tokens=generation_config.num_sink_tokens,
        )
```
Why not pass the class and use the same init scheme, with the generation config passed in, so these parameters are taken from the generation config? That would allow custom classes.
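One way to read this suggestion, as a sketch: each cache class knows how to build itself from a generation config, so generic code can instantiate any class, including custom ones, uniformly. The `from_generation_config` hook and the minimal `GenerationConfig` stand-in are hypothetical names, not transformers API.

```python
from dataclasses import dataclass

@dataclass
class GenerationConfig:  # minimal stand-in
    sink_window_length: int = 508
    num_sink_tokens: int = 4

class SinkCache:
    def __init__(self, window_length, num_sink_tokens):
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens

    @classmethod
    def from_generation_config(cls, generation_config):
        # The class itself pulls the parameters it needs from the
        # generation config, so callers never hardcode per-class kwargs.
        return cls(
            window_length=generation_config.sink_window_length,
            num_sink_tokens=generation_config.num_sink_tokens,
        )

def build_cache(cache_cls, generation_config):
    # Works for any class implementing the hook, including custom ones.
    return cache_cls.from_generation_config(generation_config)
```

The caller passes the class rather than an enum value, which is exactly what makes custom cache classes drop-in.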
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
I would like to introduce a new API before the release to centralize switching between cache implementations!
Right now, to load `SinkCache` one needs to do:
For static cache:
With this PR:
What do you think @gante @tomaarsen @ArthurZucker @amyeroberts?
If you are happy with the design and idea, I can move forward with adding tests and docs!
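As a rough sketch of the before/after usage being proposed, using toy stand-ins (the real classes live in transformers, and the final names may differ from this PR):

```python
class SinkCache:
    def __init__(self, window_length=508, num_sink_tokens=4):
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens

class Model:  # toy stand-in for a transformers model
    def __init__(self):
        self._cache = None

    def set_cache_implementation(self, name, **kwargs):
        # With this PR: one call centralizes the cache choice.
        if name == "sink":
            self._cache = SinkCache(**kwargs)
        else:
            raise ValueError(f"unknown cache implementation {name!r}")

    def generate(self, inputs, past_key_values=None):
        # An explicitly passed cache wins; otherwise use the one set above.
        return past_key_values if past_key_values is not None else self._cache

model = Model()

# Before: build the cache manually and thread it through every generate() call.
out_before = model.generate("hi", past_key_values=SinkCache(window_length=1016, num_sink_tokens=8))

# After: select it once, then call generate() normally.
model.set_cache_implementation("sink", window_length=1016, num_sink_tokens=8)
out_after = model.generate("hi")
```

The point of the PR is the second style: the cache choice lives in one place instead of being re-plumbed on every `generate()` call.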