
[feat] Allow loading T5Gemma2Encoder with AutoModel #43559

Merged
tomaarsen merged 5 commits into huggingface:main from tomaarsen:feat/t5gemma2_encoder
Feb 3, 2026

Conversation

@tomaarsen (Member) commented Jan 28, 2026

What does this PR do?

  • Allow the encoder of T5Gemma2 to be loaded standalone

Details

This is valuable for Sentence Transformers, which may want to load the encoder only (see huggingface/sentence-transformers#3604). Here, we grab and train the encoder only, resulting in e.g.: https://huggingface.co/tomaarsen/t5gemma2-270m-gooaq-cmnrl

Usage:

from transformers import T5Gemma2Encoder, AutoTokenizer
import torch

model_name = "tomaarsen/t5gemma2-270m-gooaq-cmnrl"
model = T5Gemma2Encoder.from_pretrained(model_name)
processor = AutoTokenizer.from_pretrained(model_name)

queries = [
    "Which planet is known as the Red Planet?",
]
documents = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet.",
]

query_inputs = processor(text=queries, truncation=True, padding=True, return_tensors="pt")
document_inputs = processor(text=documents, truncation=True, padding=True, return_tensors="pt")

def mean_pooling(model_output, attention_mask):
    # Average the token embeddings, masking out padding tokens via the attention mask
    token_embeddings = model_output.last_hidden_state
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

with torch.no_grad():
    query_embeddings = mean_pooling(model(**query_inputs), query_inputs["attention_mask"])
    document_embeddings = mean_pooling(model(**document_inputs), document_inputs["attention_mask"])

similarities = torch.nn.functional.cosine_similarity(
    query_embeddings.unsqueeze(1), document_embeddings.unsqueeze(0), dim=-1
)
print(similarities.tolist())
# [[0.37183186411857605, 0.8092442750930786, 0.6081508994102478, 0.7218592762947083]]
# As expected: The second document is most similar to the query.

I've not added the decoder as I only have weights for the encoder.

P.s., equivalent in Sentence Transformers:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("tomaarsen/t5gemma2-270m-gooaq-cmnrl")
queries = [
    "Which planet is known as the Red Planet?",
]
documents = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet.",
]
query_embedding = model.encode_query(queries)
document_embeddings = model.encode_document(documents)

similarities = model.similarity(query_embedding, document_embeddings)
print(similarities)
# tensor([[0.3722, 0.8101, 0.6088, 0.7216]])

This Sentence Transformers usage also relies on this PR.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

cc @Cyrilvallez @zucchini-nlp

P.s. let me know if you'd like to see new tests or docs for this.

  • Tom Aarsen

This is valuable for Sentence Transformers, which may want to load the encoder only. I've added the decoder only to mirror the changes I need for the encoder.
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@zucchini-nlp (Member) left a comment

IIUC we want to be able to load a complete T5Gemma or only its encoder module in ST, so we can't do the same as in T5 with self._load_t5_module?

In any case, I think there is no strong objection to keeping the module private, and we can make it available through the Auto API. Let's also see if the core maintainers agree.
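
For reference, exposing an otherwise-private class through the Auto API can also be done from user code. A rough sketch of that idea (the import paths and the "t5gemma2_encoder" model_type string are assumptions here, not taken from this PR's diff):

from transformers import AutoConfig, AutoModel
from transformers.models.t5gemma2.configuration_t5gemma2 import T5Gemma2EncoderConfig
from transformers.models.t5gemma2.modeling_t5gemma2 import T5Gemma2Encoder

# Register the encoder config and model so the Auto classes can resolve them.
# The model_type string is assumed to match T5Gemma2EncoderConfig.model_type.
AutoConfig.register("t5gemma2_encoder", T5Gemma2EncoderConfig, exist_ok=True)
AutoModel.register(T5Gemma2EncoderConfig, T5Gemma2Encoder, exist_ok=True)

# Afterwards, checkpoints whose config declares that model_type load via AutoModel:
model = AutoModel.from_pretrained("tomaarsen/t5gemma2-270m-gooaq-cmnrl")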

@tomaarsen marked this pull request as draft on January 28, 2026 15:56
@tomaarsen (Member Author)

IIUC we want to be able to load a complete T5Gemma or only its encoder module in ST, so we can't do the same as in T5 with self._load_t5_module?

Hmm, looks like I might be able to change some things for T5Gemma2 on ST's side. With T5 I can import T5EncoderModel, as it's exported in __all__:

from transformers import T5EncoderModel

model = T5EncoderModel.from_pretrained("sentence-transformers/gtr-t5-base")
print(type(model))
# <class 'transformers.models.t5.modeling_t5.T5EncoderModel'>

T5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
model = T5EncoderModel.from_pretrained("google-t5/t5-base")
print(type(model))
# <class 'transformers.models.t5.modeling_t5.T5EncoderModel'>

If I can import T5Gemma2Encoder from transformers, then perhaps I can load it directly like that without having to update AutoModel. I'll run some tests. It seems that T5Gemma also wasn't nicely supported in ST.

Will send more details when I know what'll work best.

  • Tom Aarsen

@tomaarsen (Member Author) commented Jan 28, 2026

Update: I can get it working like so for the various T5 variants:

from transformers import T5EncoderModel, T5Gemma2Encoder, T5GemmaEncoderModel, AutoConfig

# T5:
T5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
# Encoder only:
model = T5EncoderModel.from_pretrained("sentence-transformers/gtr-t5-base")
print(type(model))
# <class 'transformers.models.t5.modeling_t5.T5EncoderModel'>

# Encoder-decoder:
model = T5EncoderModel.from_pretrained("google-t5/t5-base")
print(type(model))
# <class 'transformers.models.t5.modeling_t5.T5EncoderModel'>

# T5Gemma:
config = AutoConfig.from_pretrained("google/t5gemma-s-s-prefixlm")
config.is_encoder_decoder = False
T5GemmaEncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
# Encoder only (still training)
# model = T5GemmaEncoderModel.from_pretrained("tomaarsen/t5gemma-s-gooaq-cmnrl")
model = T5GemmaEncoderModel.from_pretrained(r"C:\code\sentence-transformers\models\t5gemma-s-gooaq-cmnrl\checkpoint-27")
print(type(model))
# <class 'transformers.models.t5gemma.modeling_t5gemma.T5GemmaEncoderModel'>

# Encoder-decoder
model = T5GemmaEncoderModel.from_pretrained("google/t5gemma-s-s-prefixlm", config=config)
print(type(model))
# <class 'transformers.models.t5gemma.modeling_t5gemma.T5GemmaEncoderModel'>



# T5Gemma2:
T5Gemma2Encoder._keys_to_ignore_on_load_unexpected = ["decoder.*"]
T5Gemma2Encoder.base_model_prefix = "model.encoder"
# Encoder only:
model = T5Gemma2Encoder.from_pretrained("tomaarsen/t5gemma2-270m-gooaq-cmnrl")
print(type(model))
# <class 'transformers.models.t5gemma2.modeling_t5gemma2.T5Gemma2Encoder'>

# Encoder-decoder:
model = T5Gemma2Encoder.from_pretrained("google/t5gemma-2-270m-270m")
print(type(model))
# <class 'transformers.models.t5gemma2.modeling_t5gemma2.T5Gemma2Encoder'>

This even works on main if I import T5Gemma2Encoder like

from transformers.models.t5gemma2.modeling_t5gemma2 import T5Gemma2Encoder

In short, I reverted the t5gemma2_encoder changes on this PR.


However, one issue does remain: T5Gemma2Config's __setattr__ is responsible for tying some attributes between the text and vision configs in T5Gemma2EncoderConfig, but it would be much preferable if T5Gemma2EncoderConfig were responsible for this itself. Without this fix, config.text_config._attn_implementation ends up as None, because the config (T5Gemma2EncoderConfig) is updated but the change isn't correctly propagated to the subconfigs.

from transformers.models.t5gemma2.modeling_t5gemma2 import T5Gemma2Encoder

encoder = T5Gemma2Encoder.from_pretrained("tomaarsen/t5gemma2-270m-gooaq-cmnrl")
print(f"{encoder.config._attn_implementation=}")
print(f"{encoder.config.text_config._attn_implementation=}")
print(f"{encoder.config.vision_config._attn_implementation=}")

Main:

Loading weights: 100%|████████████████████████████████| 676/676 [00:00<00:00, 4255.35it/s, Materializing param=vision_tower.vision_model.post_layernorm.weight]
encoder.config._attn_implementation='sdpa'
encoder.config.text_config._attn_implementation=None
encoder.config.vision_config._attn_implementation='sdpa'

This PR:

Loading weights: 100%|████████████████████████████████| 676/676 [00:00<00:00, 4324.80it/s, Materializing param=vision_tower.vision_model.post_layernorm.weight]
encoder.config._attn_implementation='sdpa'
encoder.config.text_config._attn_implementation='sdpa'
encoder.config.vision_config._attn_implementation='sdpa'

(I do feel like there should be a more fundamental solution to this; nested configs are pretty common, and it seems important to propagate these attributes correctly.)

I think this is ready for review again, although I'm getting awkward issues with modular_model_converter.py:

  File "/mnt/c/code/transformers/./utils/modular_model_converter.py", line 365, in leave_FunctionDef
    original_modeling_method_body = self.original_modeling_methods[func_name].body.body
                                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^
KeyError: '__setattr__'

Will have to look into it later.

  • Tom Aarsen

@tomaarsen marked this pull request as ready for review on January 28, 2026 16:36
@zucchini-nlp (Member)

The attn implementation usually gets propagated in PreTrainedConfig, but we only do it for one level of nested configs. It doesn't run recursively until all subconfigs are updated, so IMO we need to fix that part of the code.
Also, I'm surprised to see __setattr__ overridden; in other nested models we don't do such a thing, because each config is responsible for its own "model" and thus has its own set of fields. Lemme check why it appeared in the first place.

@_attn_implementation.setter
def _attn_implementation(self, value: str | dict | None):
    """We set it recursively on the sub-configs as well"""
    # Set it for current config
    current_attn = getattr(self, "_attn_implementation", None)
    attn_implementation = value if not isinstance(value, dict) else value.get("", current_attn)
    self._attn_implementation_internal = attn_implementation
    # Set it recursively on the subconfigs
    for subconfig_key in self.sub_configs:
        subconfig = getattr(self, subconfig_key, None)
        if subconfig is not None:
            current_subconfig_attn = getattr(subconfig, "_attn_implementation", None)
            sub_implementation = (
                value if not isinstance(value, dict) else value.get(subconfig_key, current_subconfig_attn)
            )
            subconfig._attn_implementation = sub_implementation

@tomaarsen (Member Author)

Any luck @zucchini-nlp?
I'll start preparing Sentence Transformers for this config-propagation issue being resolved, so that I can import transformers.models.t5gemma2.modeling_t5gemma2.T5Gemma2Encoder and use it in ST.

  • Tom Aarsen

@tomaarsen (Member Author) commented Feb 2, 2026

#43633 has superseded part of this PR. I'll instead focus on allowing t5gemma2_encoder to work with AutoConfig and AutoModel.

  • Tom Aarsen

@tomaarsen force-pushed the feat/t5gemma2_encoder branch from 45bc000 to deecb86 on February 2, 2026 13:36
@tomaarsen (Member Author)

I've removed some commits that mirrored #43633. Now, all this PR does is allow for:

from transformers import T5Gemma2Encoder, AutoModel, AutoConfig

model_name = "tomaarsen/t5gemma2-270m-gooaq-cmnrl"
config = AutoConfig.from_pretrained(model_name)
print(type(config))
# <class 'transformers.models.t5gemma2.configuration_t5gemma2.T5Gemma2EncoderConfig'>
model = AutoModel.from_pretrained(model_name)
print(type(model))
# <class 'transformers.models.t5gemma2.modeling_t5gemma2.T5Gemma2Encoder'>
model = T5Gemma2Encoder.from_pretrained(model_name)
print(type(model))
# <class 'transformers.models.t5gemma2.modeling_t5gemma2.T5Gemma2Encoder'>

Especially the first one (loading via AutoConfig) is required in Sentence Transformers.
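
For context, the change itself boils down to new entries in the auto mappings. A rough sketch of the effect (the real change edits the mapping definitions in the auto module; this is not the literal diff):

from transformers.models.auto.configuration_auto import CONFIG_MAPPING_NAMES
from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES

# Map the "t5gemma2_encoder" model_type to its config class and to the
# standalone encoder class, so AutoConfig and AutoModel can resolve it.
CONFIG_MAPPING_NAMES["t5gemma2_encoder"] = "T5Gemma2EncoderConfig"
MODEL_MAPPING_NAMES["t5gemma2_encoder"] = "T5Gemma2Encoder"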

  • Tom Aarsen

@zucchini-nlp (Member) left a comment

LGTM!

@vasqu (Contributor) left a comment

I think t5gemma2 is a special case, so it's fine. However, I'm a bit concerned about whether this will become a recurring pattern and whether we should instead update things in a general way for all encoder-decoder models, e.g. Bart, T5.

If this is indeed a unique one, I'm fine with making an exception - just wanna hear your opinion on this and whether we should focus on generalizing instead

Comment on lines 1628 to 1629
"T5Gemma2Decoder",
"T5Gemma2Encoder",
Contributor

We allow both encoder and decoder; is this intentional? Looking at the auto mappings, it is only focusing on the encoder.

Member Author

I can also exclusively allow importing the Encoder; that's also fine, but I imagined it might be useful to allow importing the decoder as well? That's not really my field, for the most part, so I can't say for sure.

Contributor

Wouldn't we need to update the auto mappings for the decoder as well?

@tomaarsen (Member Author) Feb 3, 2026

Only if we want to support loading a decoder with AutoModel/AutoConfig as well, and I'm not sure that happens. I do know that it happens for encoders, so I added deecb86 so the mappings aren't updated for decoders. I'm fine with excluding or including them either way.

Contributor

Let's keep it encoder only then. We should not add more than we need to

Member Author

Agreed, done in e3e1f0f

@tomaarsen (Member Author) commented Feb 2, 2026

I think t5gemma2 is a special case, so it's fine. However, I'm a bit concerned about whether this will become a recurring pattern and whether we should instead update things in a general way for all encoder-decoder models, e.g. Bart, T5.

If this is indeed a unique one, I'm fine with making an exception - just wanna hear your opinion on this and whether we should focus on generalizing instead

Looking into this now. My understanding was that T5Gemma2 was, at this time, the only architecture whose configs have different model_types, but apparently that's not the case: T5Gemma also has a subconfig with model_type = "t5_gemma_module". I think T5Gemma would then have the same issue that this PR is fixing for T5Gemma2.

That does make things a bit more awkward; perhaps there are more encoder-decoder architectures whose encoders can't be separately loaded with the Auto classes because of this. I think I'll need to do more research re. T5Gemma and its model_type = "t5_gemma_module".

  • Tom Aarsen

@tomaarsen (Member Author) commented Feb 3, 2026

Okay, I've figured out why T5Gemma works fine without the changes in this PR:
With T5Gemma I'm loading a T5GemmaEncoderModel, a class that accepts a T5GemmaConfig, and then initializes a T5GemmaEncoder(config.encoder). For Sentence Transformers, I can AutoConfig.from_pretrained the config, recognize that it's T5GemmaConfig, and choose to load with T5GemmaEncoderModel instead of AutoModel. This gives me the encoder nicely.

With T5Gemma2 I'm loading the T5Gemma2Encoder directly, a class that accepts the T5Gemma2EncoderConfig. There is no T5Gemma2EncoderModel class that nicely wraps the T5Gemma2Encoder and accepts a T5Gemma2Config. For Sentence Transformers, I cannot use AutoConfig.from_pretrained to get the config, because it isn't registered.
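
To make the contrast concrete, a rough sketch of the two loading paths described above (simplified, not verbatim Sentence Transformers code; checkpoints as used earlier in this thread):

from transformers import AutoConfig, T5GemmaConfig, T5GemmaEncoderModel

# T5Gemma: the top-level T5GemmaConfig is known to AutoConfig, and the exported
# T5GemmaEncoderModel accepts it, so ST can branch on the config type.
config = AutoConfig.from_pretrained("google/t5gemma-s-s-prefixlm")
if isinstance(config, T5GemmaConfig):
    config.is_encoder_decoder = False
    # Silence warnings about the unused decoder weights, as in the earlier snippet
    T5GemmaEncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
    model = T5GemmaEncoderModel.from_pretrained("google/t5gemma-s-s-prefixlm", config=config)

# T5Gemma2: a standalone encoder checkpoint carries a T5Gemma2EncoderConfig, which
# AutoConfig cannot resolve without this PR, and there is no T5Gemma2EncoderModel
# wrapper accepting a T5Gemma2Config, hence registering the encoder itself.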

In short, I think T5Gemma2 is a bit of an edge case, and it should be fine to register the encoder (and decoder?) as well.

  • Tom Aarsen

@vasqu (Contributor) left a comment

LGTM, should we update the title too? It's only about the config now

Also, let's wait for #43633 first? Kinda dependent on that one 👀 but feel free to merge if not


@github-actions bot commented Feb 3, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto, t5gemma2

@tomaarsen (Member Author)

LGTM, should we update the title too? It's only about the config now

The PR title is still correct as it stands: this PR currently allows loading AutoConfig, AutoProcessor, and AutoModel for t5gemma2_encoder:

>>> from transformers import AutoModel
>>> model = AutoModel.from_pretrained("tomaarsen/t5gemma2-270m-gooaq-cmnrl")
Loading weights: 100%|████████████████████████████████| 676/676 [00:00<00:00, 2783.63it/s, Materializing param=vision_tower.vision_model.post_layernorm.weight]
>>> type(model)
<class 'transformers.models.t5gemma2.modeling_t5gemma2.T5Gemma2Encoder'>

Also, let's wait for #43633 first? Kinda dependent on that one 👀 but feel free to merge if not

They're related, but I think it doesn't matter which is merged first.
Seems fine to wait, though. P.s. I can't merge PRs with failing tests (anymore?)

  • Tom Aarsen

@vasqu (Contributor) commented Feb 3, 2026

No worries, rerunning CI (it's a flaky test). And you should not be able to merge without green CI 😬

@tomaarsen (Member Author)

Makes sense! Especially with the weekly releases. I'll wait for #43633 as it grew quite a bit larger and I'll have to test to see whether it still works with my T5Gemma2 integration from huggingface/sentence-transformers#3644

  • Tom Aarsen

@tomaarsen merged commit 8099619 into huggingface:main on Feb 3, 2026
25 checks passed