
FEAT [Generation]: Introduce a centralized API to switch between cache implementations #29030

Closed
wants to merge 11 commits

Conversation

Contributor

@younesbelkada commented Feb 15, 2024

What does this PR do?

I would like to introduce a new API before the release to centralize switching between cache implementations!

Right now, to use SinkCache one needs to do:

from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")
model = AutoModelForCausalLM.from_pretrained("TheBloke/LLaMa-7B-GPTQ", device_map="auto")

cache = SinkCache(window_length=508, num_sink_tokens=4)

inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=300, past_key_values=cache)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)

For the static cache:

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")
model = AutoModelForCausalLM.from_pretrained("TheBloke/LLaMa-7B-GPTQ", device_map="auto")

model.generation_config.cache_implementation = "static"

inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=300)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)

With this PR:

from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")
model = AutoModelForCausalLM.from_pretrained("TheBloke/LLaMa-7B-GPTQ", device_map="auto")

- cache = SinkCache(window_length=508, num_sink_tokens=4)
+ model.set_cache_implementation("sink", sink_window_length=508, num_sink_tokens=4)

inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt").to(model.device)
- gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=300, past_key_values=cache)
+ gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=300)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
And for the static cache:

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")
model = AutoModelForCausalLM.from_pretrained("TheBloke/LLaMa-7B-GPTQ", device_map="auto")

- model.generation_config.cache_implementation = "static"
+ model.set_cache_implementation("static")

inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt").to(model.device)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=300)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
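
For reference, here is a minimal sketch of roughly how such a setter could be wired under the hood, assuming it only records the choice on generation_config for generate() to read later (the attribute and keyword names are illustrative, not necessarily what this PR ends up implementing):

def set_cache_implementation(self, cache_implementation, **kwargs):
    # Record which cache implementation generate() should build.
    self.generation_config.cache_implementation = cache_implementation
    if cache_implementation == "sink":
        # Store the SinkCache arguments so generate() can instantiate the cache lazily.
        self.generation_config.sink_window_length = kwargs.get("sink_window_length")
        self.generation_config.num_sink_tokens = kwargs.get("num_sink_tokens")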

What do you think @gante @tomaarsen @ArthurZucker @amyeroberts?

If you are happy with the design and the idea, I can move forward with adding tests and docs!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@tomaarsen left a comment

Looks very clean at a glance! I have some small nits regarding phrasing & the method name.
Thanks for tackling this 🤗

@@ -351,6 +359,43 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
)

def switch_cache_implementation(self, cache_implementation: Union[CacheImplementation, str], **kwargs):
Member

In practice, we are obviously always switching from one cache implementation to another, but for users it's likely more intuitive to simply set a new cache implementation, so set_cache_implementation might be a better method name.

Contributor Author

I agree! Thanks for pointing that out!

younesbelkada and others added 3 commits February 15, 2024 14:55
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Member

@tomaarsen left a comment

Some more small nits regarding the phrasing now that we're using set_...

Member

@gante left a comment

Much in favor of this! 👍

@@ -351,6 +359,43 @@ def prepare_inputs_for_generation(self, *args, **kwargs):
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `.generate()`."
)

def set_cache_implementation(self, cache_implementation: Union[CacheImplementation, str], **kwargs):
Member

I'd rather have this function as set_cache, with past_key_values as an optional parameter. But this is a tougher discussion: I've opened a discussion on our internal Slack here 🤗
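
A hypothetical signature along those lines (illustrative only, not what this PR implements):

from typing import Optional
from transformers.cache_utils import Cache

def set_cache(self, cache_implementation: Optional[str] = None, past_key_values: Optional[Cache] = None, **cache_kwargs):
    # Either name a built-in implementation or hand over an already-built Cache object.
    ...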

Collaborator

@amyeroberts left a comment

Looks great! Thanks for adding this.

Agreed, set_ is preferable to switch_.

Only request is for tests!

Comment on lines 394 to 395
self.generation_config.sink_window_length = kwargs.get("window_length")
self.generation_config.num_sink_tokens = kwargs.get("num_sink_tokens")
Collaborator

Just for my own understanding: if the currently set cache implementation is sink, with e.g. window_length=508, num_sink_tokens=4, should I be allowed to call set_cache_implementation with updated parameters, i.e.

config.set_cache_implementation(CacheImplementation.SINK, window_length=1_016, num_sink_tokens=8)

?

Contributor Author

Yes, I think so!
Per my understanding, the way the API is currently designed for the sink cache is to pass a SinkCache() object to generate() through past_key_values. The way I designed things in this PR, we only use the cache from set_cache_implementation when past_key_values is not passed to generate().

TLDR: if one calls model.set_cache_implementation(CacheImplementation.SINK), they shouldn't also call model.generate(xxx, past_key_values=SinkCache()) --> maybe we should raise a warning saying that they already called model.set_cache_implementation(), wdyt? Commit is here: 2810ffa
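
A rough sketch of what such a guard inside generate() could look like (illustrative only; the attribute names and the choice to prefer the explicitly passed cache are assumptions based on the description above):

# Warn when the user both configured a cache via set_cache_implementation()
# and passed an explicit cache to generate(); the explicit cache takes precedence.
if kwargs.get("past_key_values") is not None and getattr(self.generation_config, "cache_implementation", None) is not None:
    logger.warning(
        "`past_key_values` was passed to `generate()` although a cache implementation was already "
        "configured with `set_cache_implementation()`; the explicitly passed cache will be used."
    )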

younesbelkada and others added 7 commits February 19, 2024 04:31
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Collaborator

@ArthurZucker left a comment

Down for that! Let's not add enums where we don't need them; let's make it easy to use custom Cache classes and init them with a generation config, no?

Comment on lines +97 to +100
class CacheImplementation(str, Enum):
DYNAMIC = "dynamic"
STATIC = "static"
SINK = "sink"
Collaborator

This prevents anyone from adding or using a custom implementation. Why not just use strings?

Contributor Author

I think in general it is a good practice to use enums to avoid silent behaviours / errors.
E.g. if one passes

model.generation_config.cache_implementation = "Static"

The code will silently work out of the box, as it will fall back to the dynamic cache, which can potentially lead to silent errors.

When using enums,

model.set_cache_implementation("Static")

will also work out of the box, but this time it will correctly use the static cache implementation rather than the dynamic cache, unlike the snippet above.

We can also use a mapping with hardcoded strings, but I found it clearer to have enums.

cc @amyeroberts @gante @tomaarsen wdyt?
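
To make the trade-off concrete, here is a small self-contained example (not from the PR) of how the two checks behave, assuming the setter coerces the user's string through the enum:

from enum import Enum

class CacheImplementation(str, Enum):
    DYNAMIC = "dynamic"
    STATIC = "static"
    SINK = "sink"

requested = "Static"

# Plain string comparison: the capitalized value matches nothing and would
# silently fall through to whatever the default cache is.
is_static = requested == CacheImplementation.STATIC  # False, no error raised

# Coercing through the enum (here after lowercasing) either resolves to the
# intended member or raises a loud ValueError instead of failing silently.
impl = CacheImplementation(requested.lower())  # CacheImplementation.STATIC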

Collaborator

Good practice does not mean it should always be used.

Here it's:

  • useless: we need a mapping with "string_cls": cls keys anyway (a rough sketch of that mapping approach follows below).
  • cumbersome: anywhere you compare generation_config.cache_cls (a string) you need to use the enum. Why?
  • not restrictive, and let me be clear here: I do NOT want users to pass anything and expect it to work. If the class is not in the mapping, error out and we are done with it; we should raise an error.

TLDR: why we would use enums when we don't need them, when they only add extra typing (CacheImplementation.DYNAMIC vs "dynamic"), is something I don't understand.
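
A rough sketch of the mapping-based alternative (not code from this PR; CACHE_CLASSES and resolve_cache_class are illustrative names):

# Assuming these classes are importable from transformers.cache_utils at this point.
from transformers.cache_utils import DynamicCache, SinkCache, StaticCache

# A plain, mutable mapping that custom code (e.g. on the Hub) can extend.
CACHE_CLASSES = {"dynamic": DynamicCache, "static": StaticCache, "sink": SinkCache}

def resolve_cache_class(name: str):
    # Hard error for unknown keys instead of silently falling back to a default.
    if name not in CACHE_CLASSES:
        raise ValueError(f"Unknown cache implementation {name!r}; expected one of {list(CACHE_CLASSES)}")
    return CACHE_CLASSES[name]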

Collaborator

@ArthurZucker I don't understand your objections here. For one, best practice does mean it should be followed wherever possible.

I don't believe using enums is useless or cumbersome (aside from those terms not being well defined)? As @younesbelkada highlights, they provide better guarantees that a valid type is selected. The explicit enum means the user can pass a string, but the code uses stricter checks than string matching, which is error prone.

not restrictive, and let me be clear here: I do NOT want users to pass anything and expect it to work. If the class is not in the mapping, error out and we are done with it; we should raise an error.

Enums are restrictive? In fact, they're far more restrictive than doing string checks, and they have stronger guarantees than checking a mapping, which is mutable.

Collaborator

But that is exactly what we want: the mapping should be mutable, because we want it to be easily adapted for custom code on the Hub.
Down for stricter checks, but in this specific case IMO it is cumbersome and useless, as we don't have a notion of "safety" here and we basically use mappings for model_type, config_type, etc.

Collaborator

Anyway it's not that important

Collaborator

Ah, OK, I understand better now. I'm not sure how we can both let users modify a dict and still guarantee that arbitrary values won't work. If I've understood correctly, the check you want is on membership within the dictionary, rather than e.g. on how the value is handled within the setting logic? In that case, I agree dictionaries are probably the simplest solution.

Collaborator

Yes exactly 😉

Comment on lines +1245 to +1250
if getattr(generation_config, "cache_implementation", "") == CacheImplementation.SINK:
if "past_key_values" not in kwargs:
kwargs["past_key_values"] = SinkCache(
window_length=generation_config.sink_window_length,
num_sink_tokens=generation_config.num_sink_tokens,
)
Collaborator

Why not pass the class and use the same init scheme, with the generation config passed in and these arguments taken from the generation config, to allow custom classes?
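
A rough sketch of that suggestion, reusing the hypothetical CACHE_CLASSES mapping from earlier in the thread (from_generation_config is an illustrative constructor, not an existing transformers API):

# Variant of the snippet above: look up whichever cache class is configured,
# so custom classes registered in the mapping work the same way.
cache_cls = CACHE_CLASSES[generation_config.cache_implementation]
if "past_key_values" not in kwargs:
    # Hypothetical init scheme: the cache class pulls its own arguments
    # (window length, number of sink tokens, ...) out of the generation config.
    kwargs["past_key_values"] = cache_cls.from_generation_config(generation_config)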


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
