93 changes: 72 additions & 21 deletions src/transformers/modeling_utils.py
@@ -2205,26 -2205,78 @@ def disable_input_require_grads(self):
"""
self._require_grads_hook.remove()

def get_encoder(self, modality: Optional[str] = None):
"""
Best-effort lookup of the *encoder* module. If a `modality` argument is provided,
it looks for a modality-specific encoder on multimodal models (e.g. the vision tower
for images). By default the method returns the model's text encoder if there is one,
and otherwise returns `self`.

Possible `modality` values are "image", "video" and "audio".
"""
# NOTE: new models need to use existing names for layers if possible, so this list doesn't grow infinitely
if modality in ["image", "video"]:
possible_module_names = ["vision_tower", "visual", "vision_model", "vision_encoder", "image_tower"]
elif modality == "audio":
possible_module_names = ["audio_tower", "audio_encoder", "speech_encoder"]
elif modality is None:
possible_module_names = ["text_encoder", "encoder"]
else:
raise ValueError(f'Unrecognized modality, has to be "image", "video" or "audio" but found {modality}')

for name in possible_module_names:
if hasattr(self, name):
return getattr(self, name)

if self.base_model is not self and hasattr(self.base_model, "get_encoder"):
return self.base_model.get_encoder(modality=modality)

# If this is a base transformer model (no encoder/model attributes), return self
return self
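
# Illustrative, runnable sketch of the modality lookup above (toy modules; the
# class below is invented for the example and is not part of this diff):
import torch.nn as nn

class _ToyVLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Identity()    # matched for modality="image"/"video"
        self.language_model = nn.Identity()  # a decoder name, ignored by get_encoder

    def get_encoder(self, modality=None):
        # trimmed restatement of the base-class lookup for the toy case
        names = {"image": ["vision_tower"], "video": ["vision_tower"],
                 "audio": ["audio_tower"], None: ["text_encoder", "encoder"]}[modality]
        for name in names:
            if hasattr(self, name):
                return getattr(self, name)
        return self  # same fallback as above: return the model itself

toy = _ToyVLM()
assert toy.get_encoder(modality="image") is toy.vision_tower
assert toy.get_encoder() is toy  # no text_encoder/encoder attribute -> falls back to self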

def set_encoder(self, encoder, modality: Optional[str] = None):
"""
Symmetric setter. Mirrors the lookup logic used in `get_encoder`.
"""

# NOTE: new models need to use existing names for layers if possible, so this list doesn't grow infinitely
Contributor:
To note, this should be enforced in make fixup, in the code-consistency part, to save ourselves the hassle.

Member Author:
Hmm, isn't it going to be a huge limitation for contributors if we force it and auto-rename with fix-copies? IMO we need to communicate it when reviewing and explain why it's important. It's only a few people reviewing VLMs currently, so it might be easier.

Contributor:
I was thinking the updated make fixup message (or rather the code-quality check on the CI, same thing) would be informative enough, saying "decoder layer names should be part of this list: ..." rather than auto-renaming. Could be a ruff warning if we think it's too restrictive as an error?

Member Author:
Hmm, let me see where I can fit this in a non-disruptive way. Not sure users actually read the warnings; we should be stricter in the review process in any case imo 😆
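A rough sketch of what that check could look like (purely hypothetical: nothing like this exists in make fixup today, and every name below is invented for illustration):

from typing import Optional

CANONICAL_DECODER_NAMES = {"language_model", "text_model", "decoder", "text_decoder"}

def warn_on_noncanonical_decoder_name(attr_name: str) -> Optional[str]:
    # Warn when a model stores its text decoder under a name that get_decoder()
    # will not find, rather than auto-renaming the attribute with fix-copies.
    if attr_name not in CANONICAL_DECODER_NAMES:
        return (
            f"decoder attribute '{attr_name}' is not one of {sorted(CANONICAL_DECODER_NAMES)}; "
            "get_decoder() will not pick it up"
        )
    return None

assert warn_on_noncanonical_decoder_name("language_model") is None
assert warn_on_noncanonical_decoder_name("llm_backbone") is not None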
if modality in ["image", "video"]:
possible_module_names = ["vision_tower", "visual", "vision_model", "vision_encoder", "image_tower"]
if modality == "audio":
possible_module_names = ["audio_tower", "audio_encoder"]
elif modality is None:
possible_module_names = ["text_encoder", "encoder"]
else:
raise ValueError(f'Unrecognized modality, has to be "image", "video" or "audio" but found {modality}')

for name in possible_module_names:
if hasattr(self, name):
setattr(self, name, encoder)
return

if self.base_model is not self:
if hasattr(self.base_model, "set_encoder"):
self.base_model.set_encoder(encoder, modality=modality)
else:
self.model = encoder
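
# Matching sketch for the setter above (same toy flavour, invented names; a real
# model would inherit this behaviour from PreTrainedModel rather than define it):
import torch.nn as nn

class _ToyAudioLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.audio_tower = nn.Identity()  # audio encoder slot

    def set_encoder(self, encoder, modality=None):
        # trimmed to the audio branch of the base-class logic
        if modality == "audio" and hasattr(self, "audio_tower"):
            self.audio_tower = encoder

toy = _ToyAudioLM()
replacement = nn.Linear(8, 8)  # stand-in for e.g. a pretrained speech encoder
toy.set_encoder(replacement, modality="audio")
assert toy.audio_tower is replacement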

def get_decoder(self):
"""
Best-effort lookup of the *decoder* module.

Order of attempts (covers roughly 85% of current usages):

1. `self.decoder`
2. `self.model` (many wrappers store the decoder here)
3. `self.model.get_decoder()` (nested wrappers)
1. `self.decoder/self.language_model/self.text_model`
2. `self.base_model` (many wrappers store the decoder here)
3. `self.base_model.get_decoder()` (nested wrappers)
4. fallback: raise for the few exotic models that need a bespoke rule
"""
if hasattr(self, "decoder"):
return self.decoder
possible_module_names = ["language_model", "text_model", "decoder", "text_decoder"]
for name in possible_module_names:
if hasattr(self, name):
return getattr(self, name)

if hasattr(self, "model"):
inner = self.model
# See: https://github.com/huggingface/transformers/issues/40815
if hasattr(inner, "get_decoder") and type(inner) is not type(self):
return inner.get_decoder()
return inner
if self.base_model is not self and hasattr(self.base_model, "get_decoder"):
return self.base_model.get_decoder()

# If this is a base transformer model (no decoder/model attributes), return self
# This handles cases like MistralModel which is itself the decoder
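
# Runnable sketch of the delegation chain above (toy wrapper, invented names): the
# head class has no decoder attribute of its own, so the lookup falls through to the
# inner model, whose `language_model` attribute is matched by name.
import torch.nn as nn

class _ToyBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = nn.Identity()  # step 1: found by name

    def get_decoder(self):
        return self.language_model

class _ToyForCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = _ToyBackbone()  # plays the role of base_model

    def get_decoder(self):
        return self.model.get_decoder()  # step 3: nested delegation

head = _ToyForCausalLM()
assert head.get_decoder() is head.model.language_model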
@@ -2235,19 +2287,18 @@ def set_decoder(self, decoder):
Symmetric setter. Mirrors the lookup logic used in `get_decoder`.
"""

if hasattr(self, "decoder"):
self.decoder = decoder
return
possible_module_names = ["language_model", "text_model", "decoder"]
for name in possible_module_names:
if hasattr(self, name):
setattr(self, name, decoder)
return

if hasattr(self, "model"):
inner = self.model
if hasattr(inner, "set_decoder"):
inner.set_decoder(decoder)
if self.base_model is not self:
if hasattr(self.base_model, "set_decoder"):
self.base_model.set_decoder(decoder)
else:
self.model = decoder
return

return
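
# Setter counterpart of the get_decoder sketch earlier (same toy layout, invented
# names): the head forwards to its inner model, so the rebound module is what
# get_decoder() returns afterwards.
import torch.nn as nn

class _ToyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = nn.Identity()

class _ToyHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = _ToyLM()

    def set_decoder(self, decoder):
        # what the base-class fallback amounts to for this layout
        self.model.language_model = decoder

head = _ToyHead()
head.set_decoder(nn.Linear(4, 4))
assert isinstance(head.model.language_model, nn.Linear)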

@torch.no_grad()
def _init_weights(self, module):
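For context on the per-model diffs that follow: with the centralised accessors above, composite models no longer need bespoke get_encoder/get_decoder/set_decoder overrides. A hedged sketch of the layout these fallbacks assume (class names invented for illustration; the real models differ in detail):

import torch.nn as nn

class ExampleVLModel(nn.Module):  # stands in for a backbone such as AriaModel
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Identity()
        self.language_model = nn.Identity()
        # no get_decoder()/set_decoder() needed: the base-class versions find
        # `language_model` and `vision_tower` by name

class ExampleVLForConditionalGeneration(nn.Module):  # stands in for the generation head
    def __init__(self):
        super().__init__()
        self.model = ExampleVLModel()
        self.lm_head = nn.Identity()
        # no overrides needed either: the base-class lookup recurses into
        # base_model (here `self.model`) and reuses its accessors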
25 changes: 0 additions & 25 deletions src/transformers/models/aria/modeling_aria.py
@@ -910,12 +910,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)

def set_decoder(self, decoder):
self.language_model = decoder

def get_decoder(self):
return self.language_model

def get_image_features(
self,
pixel_values: torch.FloatTensor,
@@ -1075,12 +1069,6 @@ def set_input_embeddings(self, value):
def get_output_embeddings(self) -> nn.Module:
return self.lm_head

def set_decoder(self, decoder):
self.model.set_decoder(decoder)

def get_decoder(self):
return self.model.get_decoder()

def get_image_features(
self,
pixel_values: torch.FloatTensor,
@@ -1093,19 +1081,6 @@ def get_image_features(
vision_feature_layer=vision_feature_layer,
)

# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model

@property
def vision_tower(self):
return self.model.vision_tower

@property
def multi_modal_projector(self):
return self.model.multi_modal_projector

@can_return_tuple
@auto_docstring
def forward(
9 changes: 0 additions & 9 deletions src/transformers/models/autoformer/modeling_autoformer.py
@@ -1342,9 +1342,6 @@ def create_network_inputs(
)
return reshaped_lagged_sequence, features, loc, scale, static_feat

def get_encoder(self):
return self.encoder

@auto_docstring
def forward(
self,
@@ -1588,12 +1585,6 @@ def __init__(self, config: AutoformerConfig):
def output_params(self, decoder_output):
return self.parameter_projection(decoder_output[:, -self.config.prediction_length :, :])

def get_encoder(self):
return self.model.get_encoder()

def get_decoder(self):
return self.model.get_decoder()

@torch.jit.ignore
def output_distribution(self, params, loc=None, scale=None, trailing_n=None) -> torch.distributions.Distribution:
sliced_params = params
25 changes: 0 additions & 25 deletions src/transformers/models/aya_vision/modeling_aya_vision.py
@@ -181,12 +181,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)

def set_decoder(self, decoder):
self.language_model = decoder

def get_decoder(self):
return self.language_model

def get_image_features(
self,
pixel_values: torch.FloatTensor,
@@ -357,12 +351,6 @@ def set_input_embeddings(self, value):
def get_output_embeddings(self) -> nn.Module:
return self.lm_head

def set_decoder(self, decoder):
self.model.set_decoder(decoder)

def get_decoder(self):
return self.model.get_decoder()

def get_image_features(
self,
pixel_values: torch.FloatTensor,
@@ -377,19 +365,6 @@ def get_image_features(
**kwargs,
)

# Make modules available through conditional class for BC
@property
def language_model(self):
return self.model.language_model

@property
def vision_tower(self):
return self.model.vision_tower

@property
def multi_modal_projector(self):
return self.model.multi_modal_projector

@can_return_tuple
@auto_docstring
def forward(
15 changes: 0 additions & 15 deletions src/transformers/models/bart/modeling_bart.py
@@ -905,9 +905,6 @@ def set_input_embeddings(self, value):
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared

def get_encoder(self):
return self.encoder

@auto_docstring
def forward(
self,
@@ -1037,12 +1034,6 @@ def __init__(self, config: BartConfig):
# Initialize weights and apply final processing
self.post_init()

def get_encoder(self):
return self.model.get_encoder()

def get_decoder(self):
return self.model.get_decoder()

def resize_token_embeddings(
self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
) -> nn.Embedding:
@@ -1498,12 +1489,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value

def set_decoder(self, decoder):
self.model.decoder = decoder

def get_decoder(self):
return self.model.decoder

@auto_docstring
def forward(
self,
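The encoder-decoder deletions (Autoformer and BART above; BigBirdPegasus and the Blenderbot variants below) rely on the same fallbacks for the text-only case: with no modality argument, get_encoder() matches the bare encoder attribute on the backbone, and the generation heads reach both halves through base_model. A minimal sketch with invented class names:

import torch.nn as nn

class ExampleSeq2SeqModel(nn.Module):  # stands in for BartModel-style backbones
    def __init__(self):
        super().__init__()
        self.encoder = nn.Identity()  # matched by the base-class get_encoder()
        self.decoder = nn.Identity()  # matched by the base-class get_decoder()

class ExampleSeq2SeqForConditionalGeneration(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ExampleSeq2SeqModel()
        # the base-class get_encoder()/get_decoder() fall back to base_model here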
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -2083,9 +2083,6 @@ def set_input_embeddings(self, value):
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared

def get_encoder(self):
return self.encoder

@auto_docstring
def forward(
self,
@@ -2205,12 +2202,6 @@ def __init__(self, config: BigBirdPegasusConfig):
# Initialize weights and apply final processing
self.post_init()

def get_encoder(self):
return self.model.get_encoder()

def get_decoder(self):
return self.model.get_decoder()

def resize_token_embeddings(
self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
) -> nn.Embedding:
@@ -2609,12 +2600,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value

def set_decoder(self, decoder):
self.model.decoder = decoder

def get_decoder(self):
return self.model.decoder

@auto_docstring
def forward(
self,
15 changes: 0 additions & 15 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -869,9 +869,6 @@ def set_input_embeddings(self, value):
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared

def get_encoder(self):
return self.encoder

@auto_docstring
def forward(
self,
@@ -1009,12 +1006,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

def get_encoder(self):
return self.model.get_encoder()

def get_decoder(self):
return self.model.get_decoder()

def resize_token_embeddings(
self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
) -> nn.Embedding:
@@ -1189,12 +1180,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value

def set_decoder(self, decoder):
self.model.decoder = decoder

def get_decoder(self):
return self.model.decoder

@auto_docstring
def forward(
self,
src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -842,9 +842,6 @@ def set_input_embeddings(self, value):
self.encoder.embed_tokens = self.shared
self.decoder.embed_tokens = self.shared

def get_encoder(self):
return self.encoder

@auto_docstring
def forward(
self,
@@ -969,12 +966,6 @@ def __init__(self, config: BlenderbotSmallConfig):
# Initialize weights and apply final processing
self.post_init()

def get_encoder(self):
return self.model.get_encoder()

def get_decoder(self):
return self.model.get_decoder()

def resize_token_embeddings(
self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
) -> nn.Embedding:
@@ -1149,12 +1140,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value

def set_decoder(self, decoder):
self.model.decoder = decoder

def get_decoder(self):
return self.model.decoder

@auto_docstring
def forward(
self,