🚨 [generate] update paligemma mask updates (and other assisted generation-related fixes) #40917
Conversation
```python
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Paligemma2 does not seem to be compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
```
removing one skip at a time 🫡
great work!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
looks good to me, thanks a lot! left some minor comments
Great, happy to see all Gemmas updated with the new API. Left a few questions 👇🏻
Before we merge, could you also check if the correct dtype is used for mask creation with the general API? The reproducer is linked to this PR (#40912) and was fixed just yesterday. I believe we should use the text config's dtype, just for safety :)
run-slow: gemma, gemma3, gemma3n, paligemma, paligemma2
Most PR comments addressed :) The only open one is this one (@zucchini-nlp)
This comment contains run-slow, running the specified jobs: models: ['models/gemma', 'models/gemma3', 'models/gemma3n', 'models/paligemma', 'models/paligemma2']
The dtype fix in an upstream issue has been preserved:

```python
from transformers import ColPaliForRetrieval, ColPaliProcessor
import torch
import numpy as np
from PIL import Image

device = "cuda"

model = ColPaliForRetrieval.from_pretrained(
    "vidore/colpali-v1.3-hf",
    dtype=torch.float16,  # can also be bfloat16
).to(device)
processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.3-hf")

image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]

with torch.no_grad():
    image_inputs = processor(images=image_inputs)
    image_inputs = image_inputs.to(model.device, model.dtype)
    image_outputs = model(**image_inputs)
    image_embeddings_torch = image_outputs.embeddings

print(image_embeddings_torch.dtype)
# torch.float16
```
Thanks, replied under the comment. Agreed with your suggestion :)
run-slow: gemma, gemma3, gemma3n, helium, paligemma, paligemma2
This comment contains run-slow, running the specified jobs: models: ['models/gemma', 'models/gemma3', 'models/gemma3n', 'models/helium', 'models/paligemma', 'models/paligemma2']
run-slow: gemma, gemma3, gemma3n, helium, paligemma, paligemma2
This comment contains run-slow, running the specified jobs: models: ['models/gemma', 'models/gemma3', 'models/gemma3n', 'models/helium', 'models/paligemma', 'models/paligemma2']
force-pushed from fe16966 to 0ee61ee
("rocm", (9, 5)): "detect shoe\n<loc0051><loc0309><loc0708><loc0644> shoe", | ||
(None, None): "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe", | ||
("cuda", 8): "detect shoe\n<loc0045><loc0309><loc0708><loc0646> shoe", | ||
("cuda", 8): "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe", |
this test was failing on `main` as well, with the same output
@molbap @zucchini-nlp requesting a full re-review :) the final steps required a few chains of changes (change base model -> modular kicks in -> more models need changes).
Huuge work, thanks for getting to the root of the problem. I can feel the pain of refactoring these models' attentions.
Overall I agree with the changes. In unrelated models, I think we can and should delete attributes in modular, so we don't keep unused config attributes.
```python
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = not getattr(config, "use_bidirectional_attention", False)

self.q_proj = nn.Linear(
    config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
    config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
    config.num_key_value_heads * self.head_dim if False else config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
    config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
```
I think we can indicate only `self.is_causal`; the other attributes will be copied by modular. Same for the decoder layer.
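A minimal sketch of what that could look like in the modular file, assuming the attention module inherits from `GemmaAttention` (the class name is illustrative; the modular converter expands the inherited attributes automatically):

```python
# Illustrative modular-file sketch: only the attribute that differs is spelled out,
# everything else (projections, scaling, dropout, ...) is inherited from the base class.
from transformers.models.gemma.modeling_gemma import GemmaAttention


class MyTextAttention(GemmaAttention):  # hypothetical class name
    def __init__(self, config, layer_idx: int):
        super().__init__(config, layer_idx)
        # the only line that needs to live in the modular file
        self.is_causal = not getattr(config, "use_bidirectional_attention", False)
```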
```python
# NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
# (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
# means). Determining prefill in that case requires checking data values, which is not compile-compatible.
may_have_image_input = past_key_values is None or not past_key_values.is_initialized or pixel_values is not None
```
I think we should pass `token_type_ids` in all cases when it is present, since the logic isn't perfect. The only issue with passing it in all cases that I can think of is that `token_type_ids` does not grow together with the attention mask.
So we can do: if `token_type_ids.shape == attention_mask.shape`, apply the optional bidirectional mask.
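A self-contained toy sketch of that guard (the mask helper below is made up for illustration; it is not the actual modeling code):

```python
import torch


def bidirectional_over_image_tokens(causal_mask: torch.Tensor, token_type_ids: torch.Tensor) -> torch.Tensor:
    """Toy helper: let image tokens (token_type_ids == 1) attend to each other bidirectionally."""
    image = token_type_ids.bool()
    # allow attention between any pair of image tokens, on top of the causal mask
    return causal_mask | (image[:, None, :, None] & image[:, None, None, :])


batch, seq_len = 1, 6
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)
token_type_ids = torch.tensor([[1, 1, 1, 0, 0, 0]])  # 3 image tokens followed by text

causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))[None, None]
# only true at prefill: during decoding the attention mask keeps growing while token_type_ids does not
if token_type_ids.shape == attention_mask.shape:
    mask = bidirectional_over_image_tokens(causal, token_type_ids)
else:
    mask = causal
```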
```python
self.mlp = HeliumMLP(config)
self.input_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attention_type = config.layer_types[layer_idx]
```
I think we should avoid using attention types with Helium. The changes in modular are done for layers using Llama as the base model, so we shouldn't need them, no?
```python
# The logic below was originally written for gemma3, where `token_type_ids` is reversed. Let's reverse it to
# then use exactly the same logic.
token_type_ids = 1 - token_type_ids
```
I saw that we need to break BC to get this working; maybe we can do as follows to keep the model as on the main branch:

```python
if is_training:
    token_type_ids = 1 - token_type_ids
else:
    token_type_ids = torch.ones_like(attention_mask)
```
```
    scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
    scaling factor when applying tanh softcapping on the attention scores.
use_bidirectional_attention (`bool`, *optional*):
```
What if we `del use_bidirectional_attention` in the modular file's config?
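A rough sketch of that pattern, assuming the config inherits from `Gemma2Config` in the modular file (the class name is a placeholder, and the attribute is assumed to exist on the parent config, as added in this PR); the generated configuration file would then simply not contain the attribute:

```python
from transformers import Gemma2Config


# Illustrative modular-config sketch: remove an inherited attribute so it never
# ends up in the auto-generated configuration file.
class MyTextConfig(Gemma2Config):  # hypothetical class name
    def __init__(self, **super_kwargs):
        super().__init__(**super_kwargs)
        # assumes the parent config sets this attribute; otherwise `del` would raise
        del self.use_bidirectional_attention
```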
```
    scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
    scaling factor when applying tanh softcapping on the attention scores.
use_bidirectional_attention (`bool`, *optional*):
```
Same here, I'd prefer to explicitly `del` it from modular.
```
use_bidirectional_attention (`bool`, *optional*):
    If True, the model will attend to all text tokens instead of using a causal mask.
```
wow, this makes so much sense. I wonder how gemma3 worked prev, afair we didn't have a flag for defining bidirectional attention at release time
I actually took it from gemma3 🤗 Most of the changes here are gemma3-inspired
looks like it was added recently. Previously it used `is_causal = True` 🙈
@zucchini-nlp comments addressed and CI green 💚 Summary of the changes:
```python
elif may_have_image_input:
    logger.warning_once(
        "There may be an image in the input to Gemma3 but `token_type_ids` is not provided. We recommend "
        "passing `token_type_ids` to the model to prevent bad attention masking."
    )
```
I think this message will fire a false warning when the model is used with text-only input. Can we keep the logic but without the warning?
I'm going to delete it then.
(I'm not super happy about it -- in general, if a model needs some input to correctly do some operation, and we can't safely detect whether we need that operation, then the input should be required. Otherwise, it's prone to silent bugs, which are the worst kind of bugs 😢 )
Agreed, we need to keep the token types growing in the correct way so that they are always used, without us checking for prefill etc.
Thanks, approving so it can be merged when CI is green :)
[For maintainers] Suggested jobs to run (before merge): run-slow: bark, chameleon, colpali, colqwen2, gemma, gemma2, gemma3, gemma3n, helium, idefics2, paligemma, t5gemma, vaultgemma
…tion-related fixes) (huggingface#40917) * tmp * fix modular inheritance * nit * paligemma 1 doesn't have swa * use same pattern as in models with hybrid layers * PR comments * helium also needs layer_typed (bc it relies on gemma) * paligemma/gemma3: same mask creation fn in fwd and generate * propagate changes to helium (gemma-based) * tmp commit * slow paligemma tests passing, let's see what breaks * fix test_left_padding_compatibility * tmp commit * tmp commit * rebase error * docs * reduce diff * like this? * t5gemma * better comment * shorter diff * exception * ffs type * optional * shorter modular_gemma.py * helium model actually needs no changes -- the tester is the issue * t5gemma modular config * a few more modular; paligemma BC * fix processor issues? * rm config exception * lift warning in gemma
`token_type_ids` is required as a model input when training (huggingface/trl#4142)
Thanks for the PR! A pity no core maintainers were pinged as adding a flag that is not part of the original model release is rarely something we are gonna agree on 😓
```python
self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.layer_types = layer_types
self.use_bidirectional_attention = use_bidirectional_attention
```
this should not have been added, it goes against our philosophy 😢
```python
# BC: `use_bidirectional_attention` was originally unset in PaliGemma1 (backbone = Gemma1) AND PaliGemma2
# (backbone = Gemma2). Both PaliGemmas want to default to True.
if self.text_config.use_bidirectional_attention is None:
    self.text_config.use_bidirectional_attention = True
```
I don't think that's what we want TBH
What does this PR do?

🚨 BC-breaking: the `paligemma` processor now returns `token_type_ids` by default. This is required to disambiguate forward passes, due to the bidirectional attention mask in the prompt. Advanced generation methods may run forward passes with prompt + generated tokens, so they will fail without `token_type_ids`.

This PR is originally aimed at fixing two flaky tests:
- `imageGPT` + test_prompt_lookup_decoding_matches_greedy_search -> skip the test, imageGPT has dodgy layer initialization. This is better documented in the skip;
- `paligemma2` + test_prompt_lookup_decoding_matches_greedy_search -> upstreams attention mask creation from gemma3 to paligemma, since their masking strategy is the same. This also improves standardization, as we got rid of some legacy code 💛 Fixing this actually required a cascade of changes (changes in `gemma` for `paligemma` -> `gemma`-dependent models also needed updates)

✅ slow paligemma tests passing
✅ slow paligemma2 tests passing (but there are no integration tests ⚠️)
✅ no regressions on slow gemma tests (i.e. some failures, same as in `main`)
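A minimal usage sketch of the BC-breaking processor change described above (checkpoint id, prompt, and placeholder image are illustrative; the point is that the processor output now includes `token_type_ids`, which should be forwarded to the model):

```python
# Sketch only, assuming a PaliGemma(2) checkpoint is available.
import torch
from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

model_id = "google/paligemma2-3b-pt-224"  # illustrative checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, dtype=torch.bfloat16)

image = Image.new("RGB", (224, 224))  # placeholder image
inputs = processor(images=image, text="detect shoe", return_tensors="pt")

# After this PR, `token_type_ids` is returned by default and marks image vs. text tokens,
# so the bidirectional prefix mask can be rebuilt on any forward pass (including the ones
# run by advanced decoding methods over prompt + generated tokens).
print("token_type_ids" in inputs)  # True

with torch.no_grad():
    generated = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(generated, skip_special_tokens=True))
```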