19 changes: 16 additions & 3 deletions src/transformers/models/llama4/configuration_llama4.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 
 from ...configuration_utils import PretrainedConfig, layer_type_validation
 from ...utils import logging
@@ -56,7 +57,6 @@ class Llama4VisionConfig(PretrainedConfig):
             The size (resolution) of each patch.
         norm_eps (`float`, *optional*, defaults to 1e-05):
             The epsilon used by the layer normalization layers.
-        vision_feature_layer (``, *optional*, defaults to -1): TODO
         vision_feature_select_strategy (`int`, *optional*, defaults to `"default"`): TODO
         initializer_range (`float`, *optional*, defaults to 0.02):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
@@ -93,7 +93,6 @@ def __init__(
         image_size: int = 448,
         patch_size: int = 14,
         norm_eps: float = 1e-5,
-        vision_feature_layer=-1,
         vision_feature_select_strategy="default",
         initializer_range: float = 0.02,
         pixel_shuffle_ratio=0.5,
@@ -122,9 +121,23 @@ def __init__(
         self.multi_modal_projector_bias = multi_modal_projector_bias
         self.projector_dropout = projector_dropout
         self.attention_dropout = attention_dropout
-        self.vision_feature_layer = vision_feature_layer
         self.vision_feature_select_strategy = vision_feature_select_strategy
         self.rope_theta = rope_theta
+
+        self._vision_feature_layer = kwargs.get("vision_feature_layer", -1)
+
+    @property
+    def vision_feature_layer(self):
+        warnings.warn(
+            "The `vision_feature_layer` attribute is deprecated and will be removed in v4.58.0.",
+            FutureWarning,
+        )
+        return self._vision_feature_layer
+
+    @vision_feature_layer.setter
+    def vision_feature_layer(self, value):
+        self._vision_feature_layer = value
+
         super().__init__(**kwargs)
 
 
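For readers skimming the config hunk above, here is a minimal, self-contained sketch of the deprecation pattern it introduces: the old attribute stays readable through a property that emits a `FutureWarning`, while the value itself lives on a private backing field. `SimpleVisionConfig` is a stand-in class for illustration only, not the real `Llama4VisionConfig`/`PretrainedConfig` hierarchy, and the sketch assumes the property sits at class level. Note that, as in the diff, only the getter warns; the setter is silent.

```python
# Minimal sketch of the deprecation pattern, assuming a class-level property;
# `SimpleVisionConfig` is a stand-in, not the real Llama4VisionConfig.
import warnings


class SimpleVisionConfig:
    def __init__(self, vision_feature_select_strategy: str = "default", **kwargs):
        self.vision_feature_select_strategy = vision_feature_select_strategy
        # The legacy kwarg is still accepted but only stored privately.
        self._vision_feature_layer = kwargs.get("vision_feature_layer", -1)

    @property
    def vision_feature_layer(self):
        # Reads keep working, but callers are nudged off the old attribute.
        warnings.warn(
            "The `vision_feature_layer` attribute is deprecated and will be removed in v4.58.0.",
            FutureWarning,
        )
        return self._vision_feature_layer

    @vision_feature_layer.setter
    def vision_feature_layer(self, value):
        # Writes are redirected to the private backing field without a warning.
        self._vision_feature_layer = value


config = SimpleVisionConfig(vision_feature_layer=-1)
print(config.vision_feature_layer)  # -1, emitted together with a FutureWarning
```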
14 changes: 1 addition & 13 deletions src/transformers/models/llama4/modeling_llama4.py
@@ -1173,7 +1173,6 @@ def get_decoder(self):
     def get_image_features(
         self,
         pixel_values: torch.FloatTensor,
-        vision_feature_layer: Union[int, list[int]],
         vision_feature_select_strategy: str,
         **kwargs,
     ):
@@ -1183,10 +1182,6 @@
         Args:
             pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
                 The tensors corresponding to the input images.
-            vision_feature_layer (`Union[int, list[int]]`):
-                The index of the layer to select the vision feature. If multiple indices are provided,
-                the vision feature of the corresponding indices will be concatenated to form the
-                vision features.
             vision_feature_select_strategy (`str`):
                 The feature selection strategy used to select the vision feature from the vision backbone.
                 Can be one of `"default"` or `"full"`
@@ -1224,6 +1219,7 @@ def get_placeholder_mask(
         return special_image_mask
 
     @auto_docstring
+    @deprecate_kwarg("vision_feature_layer", version="4.58")
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -1241,7 +1237,6 @@
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        image_sizes: Optional[torch.Tensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> Union[tuple, Llama4CausalLMOutputWithPast]:
         r"""
@@ -1277,11 +1272,6 @@
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        vision_feature_layer = (
-            vision_feature_layer
-            if vision_feature_layer is not None
-            else self.config.vision_config.vision_feature_layer
-        )
         vision_feature_select_strategy = (
             vision_feature_select_strategy
             if vision_feature_select_strategy is not None
@@ -1302,9 +1292,7 @@
         if pixel_values is not None:
             image_features = self.get_image_features(
                 pixel_values=pixel_values,
-                vision_feature_layer=vision_feature_layer,
                 vision_feature_select_strategy=vision_feature_select_strategy,
-                image_sizes=image_sizes,
             )
 
             vision_flat = image_features.view(-1, image_features.size(-1))
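On the modeling side, the only addition is the `@deprecate_kwarg("vision_feature_layer", version="4.58")` decorator on `forward`. Below is a hand-rolled sketch of what that kind of decorator does; it is not the transformers implementation (the helper name `deprecate_kwarg_sketch` and the toy `forward` are invented for illustration), only the mechanism of warning on a soon-to-be-removed keyword argument.

```python
# Sketch of a `deprecate_kwarg`-style decorator; not the transformers helper,
# just the idea behind `@deprecate_kwarg("vision_feature_layer", version="4.58")`.
import functools
import warnings


def deprecate_kwarg_sketch(name: str, version: str):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if name in kwargs:
                # Emit a FutureWarning, then strip the kwarg so the wrapped
                # function never sees it (sketch behaviour under our assumptions).
                warnings.warn(
                    f"`{name}` is deprecated and will be removed in v{version}.",
                    FutureWarning,
                )
                kwargs.pop(name)
            return func(*args, **kwargs)

        return wrapper

    return decorator


@deprecate_kwarg_sketch("vision_feature_layer", version="4.58")
def forward(pixel_values, vision_feature_select_strategy="default", **kwargs):
    # Toy stand-in for the model forward: the deprecated kwarg never reaches it.
    return pixel_values, vision_feature_select_strategy


forward("pixels", vision_feature_layer=-1)  # warns, then runs without the kwarg
```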