
Conversation

gante
Member

@gante gante commented Sep 16, 2025

What does this PR do?

🚨 BC-breaking: paligemma processor now returns token_type_ids by default. This is required to disambiguate forward passes, due to the bidirectional attention mask in the prompt. Advanced generation methods may run forward passes with prompt + generated tokens, so they will fail without token_type_ids.
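A minimal sketch of the new default from the user side (checkpoint, image, and prompt are illustrative, not taken from this PR):

from transformers import PaliGemmaProcessor
from PIL import Image
import numpy as np

processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224")
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))

inputs = processor(text="answer en Where is the cow standing?", images=image, return_tensors="pt")
# `token_type_ids` is now part of the output by default, next to `input_ids`, `attention_mask` and
# `pixel_values`, so downstream calls like `model.generate(**inputs)` receive it automatically.
print(inputs.keys())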


This PR was originally aimed at fixing two flaky tests:

  • imageGPT + test_prompt_lookup_decoding_matches_greedy_search -> skip the test; imageGPT has dodgy layer initialization, and this is better documented in the skip reason;
  • ⚠️ paligemma2 + test_prompt_lookup_decoding_matches_greedy_search -> upstreams attention mask creation from gemma3 to paligemma, since their masking strategy is the same. This also improves standardization, as we got rid of some legacy code 💛 Fixing this actually required a cascade of changes (changes in gemma for paligemma -> gemma-dependent models also needed updates). A rough sketch of the masking idea is included after the checklist below.

✅ slow paligemma tests passing
✅ slow paligemma2 tests passing (but there are no integration tests ⚠️ )
✅ no regressions on slow gemma tests (i.e. some failures, same as in main)
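A rough sketch of the masking idea in plain torch (illustrative only, not the library's mask-creation code; note that paligemma marks prefix tokens with token_type_ids == 0, while gemma3 uses the reversed convention):

import torch

# hypothetical 6-token sequence: 4 prefix tokens (image + prompt), 2 suffix/generated tokens
token_type_ids = torch.tensor([0, 0, 0, 0, 1, 1])
seq_len = token_type_ids.shape[0]

causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
prefix = token_type_ids == 0
# prefix tokens attend to each other bidirectionally; suffix/generated tokens stay causal
allowed = causal | (prefix.unsqueeze(0) & prefix.unsqueeze(1))
print(allowed.int())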

@gante gante marked this pull request as ready for review September 16, 2025 17:36
@parameterized.expand([("random",), ("same",)])
@pytest.mark.generate
@unittest.skip("Paligemma2 does not seem to be compatible with assisted decoding")
def test_assisted_decoding_matches_greedy_search(self, assistant_type):
Member Author

removing one skip at a time 🫡

Contributor

great work!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@molbap molbap left a comment

looks good to me, thanks a lot! left some minor comments

Member

@zucchini-nlp zucchini-nlp left a comment

Great, happy to see all Gemmas updated with the new API. Left a few questions 👇🏻

@zucchini-nlp
Member

zucchini-nlp commented Sep 17, 2025

Before we merge, could you also check that the correct dtype is used for mask creation with the general API? The reproducer is linked to this PR (#40912) and was fixed just yesterday.

I believe we should use the text config's dtype, just for safety :)
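For reference, the dtype concern in plain torch (illustrative only, not the actual mask-creation code):

import torch

# building the additive mask in float32 and adding it to float16 scores would silently upcast;
# building it in the backbone's compute dtype keeps everything consistent
dtype = torch.float16  # e.g. taken from the text config
mask = torch.zeros(1, 1, 4, 4, dtype=dtype)
mask[..., 2:] = torch.finfo(dtype).min  # masked positions get the most negative representable value
scores = torch.randn(1, 1, 4, 4).to(dtype) + mask
print(scores.dtype)  # torch.float16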

@gante
Member Author

gante commented Sep 17, 2025

run-slow: gemma, gemma3, gemma3n, paligemma, paligemma2

@gante
Member Author

gante commented Sep 17, 2025

most PR comments addressed :)

The only open one is this one (@zucchini-nlp)

Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/gemma', 'models/gemma3', 'models/gemma3n', 'models/paligemma', 'models/paligemma2']
quantizations: [] ...

@gante
Member Author

gante commented Sep 17, 2025

The dtype fix from the upstream issue has been preserved:

from transformers import ColPaliForRetrieval, ColPaliProcessor
import torch
import numpy as np
from PIL import Image

device = "cuda"
model = ColPaliForRetrieval.from_pretrained(
    "vidore/colpali-v1.3-hf",
    dtype=torch.float16, # can also be bfloat16
).to(device)

processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.3-hf")

image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)]
image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs]

with torch.no_grad():
    image_inputs = processor(images=image_inputs)
    image_inputs = image_inputs.to(model.device, model.dtype)
    image_outputs = model(**image_inputs)
    image_embeddings_torch = image_outputs.embeddings
    print(image_embeddings_torch.dtype)
    # torch.float16

@zucchini-nlp
Member

Thanks, replied under the comment. Agreed with your suggestion :)

@gante
Member Author

gante commented Sep 17, 2025

run-slow: gemma, gemma3, gemma3n, helium, paligemma, paligemma2

Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/gemma', 'models/gemma3', 'models/gemma3n', 'models/helium', 'models/paligemma', 'models/paligemma2']
quantizations: [] ...

@gante
Member Author

gante commented Sep 17, 2025

run-slow: gemma, gemma3, gemma3n, helium, paligemma, paligemma2

Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/gemma', 'models/gemma3', 'models/gemma3n', 'models/helium', 'models/paligemma', 'models/paligemma2']
quantizations: [] ...

@gante gante force-pushed the flaky_assisted_gen_tests branch from fe16966 to 0ee61ee Compare September 17, 2025 18:18
@gante gante added the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Sep 18, 2025
("rocm", (9, 5)): "detect shoe\n<loc0051><loc0309><loc0708><loc0644> shoe",
(None, None): "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe",
("cuda", 8): "detect shoe\n<loc0045><loc0309><loc0708><loc0646> shoe",
("cuda", 8): "detect shoe\n<loc0051><loc0309><loc0708><loc0646> shoe",
Member Author

this test was failing on main as well, with the same output

@gante
Member Author

gante commented Sep 22, 2025

@molbap @zucchini-nlp requesting a full re-review :) the final steps required a few chains of changes (change base model -> modular kicks in -> more models need changes).

  • I've rewritten the PR header, make sure to check it
  • I've added comments in the GH diff, in an effort to make review easier 🤗

@gante gante removed the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Sep 22, 2025
Member

@zucchini-nlp zucchini-nlp left a comment

Huge work, thanks for getting to the root of the problem. I can feel the pain of refactoring these models' attentions.

Overall I agree with the changes. In unrelated models, I think we can and should delete attributes in modular, so we don't keep unused config attributes.

Comment on lines 391 to 411
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = not getattr(config, "use_bidirectional_attention", False)

self.q_proj = nn.Linear(
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.v_proj = nn.Linear(
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
)
self.o_proj = nn.Linear(
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
)
Member

I think we only need to set self.is_causal; the other attributes will be copied by modular. Same for the decoder layer.

# NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
# (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
# means). Determining prefill in that case requires checking data values, which is not compile-compatible.
may_have_image_input = past_key_values is None or not past_key_values.is_initialized or pixel_values is not None
Member

I think we should pass token_type_ids in all cases when it is present, since the logic isn't perfect. The only issue with passing it in all cases that I can think of is that token_type_ids does not grow together with the attention mask.

So we can do: if token_type_ids.shape == attention_mask.shape, build the optional bidirectional mask.
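A runnable illustration of that guard (values are made up; the helper is defined here for illustration, not taken from the PR):

import torch

attention_mask = torch.ones(1, 6, dtype=torch.long)
token_type_ids = torch.tensor([[0, 0, 0, 0, 1, 1]])

def can_use_token_type_ids(token_type_ids, attention_mask):
    # only trust token_type_ids while it still lines up with the attention mask (i.e. during prefill)
    return token_type_ids is not None and token_type_ids.shape == attention_mask.shape

print(can_use_token_type_ids(token_type_ids, attention_mask))  # True -> build the optional bidirectional mask
attention_mask = torch.cat([attention_mask, torch.ones(1, 1, dtype=torch.long)], dim=-1)  # one decoding step later
print(can_use_token_type_ids(token_type_ids, attention_mask))  # False -> fall back to the causal mask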

self.mlp = HeliumMLP(config)
self.input_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attention_type = config.layer_types[layer_idx]
Member

I think we should avoid using attention types with Helium. The changes in modular are done for layers using Llama as the base model, so we shouldn't need them, no?

Comment on lines 167 to 169
# The logic below was originally written for gemma3, where `token_type_ids` is reversed. Let's reverse it to
# then use exactly the same logic.
token_type_ids = 1 - token_type_ids
Member

I saw that we need to break BC to get this working; maybe we can do as follows to keep the model as in the main branch:

if is_training: token_type_ids = 1 - token_type_ids
else: token_type_ids = torch.ones_like(attention_mask)

scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
scaling factor when applying tanh softcapping on the attention scores.
use_bidirectional_attention (`bool`, *optional*):
Member

what if we del use_bidirectional_attention in modular file config?
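Something along these lines, presumably (a sketch only; the class name is made up and the exact modular mechanics are an assumption):

from transformers import Gemma2Config

class MyModelTextConfig(Gemma2Config):  # illustrative name, base class chosen as an example
    def __init__(self, **super_kwargs):
        super().__init__(**super_kwargs)
        # drop the inherited flag so the generated config doesn't expose it
        del self.use_bidirectional_attention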

scaling factor when applying tanh softcapping on the logits.
attn_logit_softcapping (`float`, *optional*, defaults to 50.0):
scaling factor when applying tanh softcapping on the attention scores.
use_bidirectional_attention (`bool`, *optional*):
Member

same here, I'd prefer to explicitly del from modular

Comment on lines +83 to +85
use_bidirectional_attention (`bool`, *optional*):
If True, the model will attend to all text tokens instead of using a causal mask.
Member

wow, this makes so much sense. I wonder how gemma3 worked previously; AFAIR we didn't have a flag for defining bidirectional attention at release time

Member Author

I actually took it from gemma3 🤗 Most of the changes here are gemma3-inspired

Member

looks like it was added recently. Previously it used is_causal = True 🙈

@gante
Member Author

gante commented Sep 23, 2025

@zucchini-nlp comments addressed and CI green 💚

Summary of the changes:

  1. Modular updates -- minimal rewrites where possible, no extra config parameters. Thank you for the suggestions, I was unaware it was possible to simply add/delete attributes 🫶
  2. Helium actually needs no changes. The issue was in the tester, which got updated 😢
  3. Added a BC path for paligemma, as you suggested. I've also added an exception when token_type_ids is missing at train time (present in the original code, but missing in gemma3), and a warning when it may be missing at inference time (to nudge users away from the bug-prone situation);
  4. Last mile of CI issues: propagate processor-related changes to paligemma-related models

Comment on lines 808 to 812
elif may_have_image_input:
logger.warning_once(
"There may be an image in the input to Gemma3 but `token_type_ids` is not provided. We recommend "
"passing `token_type_ids` to the model to prevent bad attention masking."
)
Member

I think this message will log a false warning when the model is used with text-only input. Can we keep it without the warning?

Member Author

I'm going to delete it then.

(I'm not super happy about it -- in general, if a model needs some input to correctly do some operation, and we can't safely detect whether we need that operation, then the input should be required. Otherwise, it's prone to silent bugs, which are the worst kind of bugs 😢 )

Member

Agree, we need to keep the token types growing in the correct way so that they can always be used, without us checking for prefill etc.

Member

@zucchini-nlp zucchini-nlp left a comment

Thanks, approving so it can be merged when CI is green :)

@gante gante enabled auto-merge (squash) September 23, 2025 16:11
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: bark, chameleon, colpali, colqwen2, gemma, gemma2, gemma3, gemma3n, helium, idefics2, paligemma, t5gemma, vaultgemma

@gante gante merged commit 869735d into huggingface:main Sep 23, 2025
25 checks passed
@gante gante deleted the flaky_assisted_gen_tests branch September 23, 2025 16:55
ErfanBaghaei pushed a commit to ErfanBaghaei/transformers that referenced this pull request Sep 25, 2025
…tion-related fixes) (huggingface#40917)

* tmp
* fix modular inheritance
* nit
* paligemma 1 doesn't have swa
* use same pattern as in models with hybrid layers
* PR comments
* helium also needs layer_typed (bc it relies on gemma)
* paligemma/gemma3: same mask creation fn in fwd and generate
* propagate changes to helium (gemma-based)
* tmp commit
* slow paligemma tests passing, let's see what breaks
* fix test_left_padding_compatibility
* tmp commit
* tmp commit
* rebase error
* docs
* reduce diff
* like this?
* t5gemma
* better comment
* shorter diff
* exception
* ffs type
* optional
* shorter modular_gemma.py
* helium model actually needs no changes -- the tester is the issue
* t5gemma modular config
* a few more modular; paligemma BC
* fix processor issues?
* rm config exception
* lift warning in gemma
vijayabhaskar-ev pushed a commit to vijayabhaskar-ev/transformers that referenced this pull request Oct 2, 2025
…tion-related fixes) (huggingface#40917)

Collaborator

@ArthurZucker ArthurZucker left a comment

Thanks for the PR! A pity no core maintainers were pinged, as adding a flag that is not part of the original model release is rarely something we are going to agree on 😓

self.final_logit_softcapping = final_logit_softcapping
self.attn_logit_softcapping = attn_logit_softcapping
self.layer_types = layer_types
self.use_bidirectional_attention = use_bidirectional_attention
Collaborator

this should not have been added, it goes against our philosophy 😢

# BC: `use_bidirectional_attention` was originally unset in PaliGemma1 (backbone = Gemma1) AND PaliGemma2
# (backbone = Gemma2). Both PaliGemmas want to default to True.
if self.text_config.use_bidirectional_attention is None:
self.text_config.use_bidirectional_attention = True
Collaborator

I don't think that's what we want TBH

yuchenxie4645 pushed a commit to yuchenxie4645/transformers that referenced this pull request Oct 4, 2025
…tion-related fixes) (huggingface#40917)
