From dbe7ec25c6b70f4bbf6b04b880ac11aa92eba4a5 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Tue, 23 Sep 2025 16:09:27 +0200 Subject: [PATCH 01/31] Added an initial conversion script --- .../fastvlm/convert_fastvlm_weights_to_hf.py | 199 ++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 src/transformers/models/fastvlm/convert_fastvlm_weights_to_hf.py diff --git a/src/transformers/models/fastvlm/convert_fastvlm_weights_to_hf.py b/src/transformers/models/fastvlm/convert_fastvlm_weights_to_hf.py new file mode 100644 index 000000000000..8af5d5aa6ac2 --- /dev/null +++ b/src/transformers/models/fastvlm/convert_fastvlm_weights_to_hf.py @@ -0,0 +1,199 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import glob + +import torch +from huggingface_hub import snapshot_download +from safetensors import safe_open +import re + +from transformers import ( + AddedToken, + AutoConfig, + AutoTokenizer, + LlavaConfig, + LlavaForConditionalGeneration, + LlavaProcessor, + CLIPImageProcessor, +) + +KEYS_TO_MODIFY_MAPPING = { + "model.vision_tower.vision_tower.model": "model.vision_tower.timm_model", + "patch_embed": "stem", + "layers": "language_model.layers", + "embed_tokens": "language_model.embed_tokens", + "layer_scale_1": "layer_scale_1.gamma", + "layer_scale_2": "layer_scale_2.gamma", + "mm_projector.0": "multi_modal_projector.linear_1", + "mm_projector.2": "multi_modal_projector.linear_2", + "conv_exp": "final_conv", + "se.reduce": "se.fc1", + "se.expand": "se.fc2", + "convffn": "mlp", + "lkb_reparam": "reparam_conv", +} + +def map_to_stage(number): + number = int(number) + if number == 0: + return 0 + if number in {1, 2}: + return 1 + if number in {3, 4}: + return 2 + if number in {5, 6, 7}: + return 3 + if number in {8, 9, 10}: + return 4 + +def load_original_state_dict(model_id): + directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"]) + + original_state_dict = {} + for path in glob.glob(f"{directory_path}/*"): + if path.endswith(".safetensors"): + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + original_state_dict[key] = f.get_tensor(key) + + if "model.vision_tower.vision_tower.model.head.proj" in original_state_dict: + del original_state_dict["model.vision_tower.vision_tower.model.head.proj"] + return original_state_dict + +def convert_state_dict_to_hf(state_dict): + new_state_dict = {} + + single_pattern = r"network\.(\d{1,2})" + double_pattern = r"network\.(\d{1,2})\.(\d{1,2})" + pos_embedding_pattern = r"stages\.(\d{1,2})\.reparam_conv" + + for key, value in state_dict.items(): + if key.endswith("layer_scale"): + key = key.replace("layer_scale", "layer_scale.gamma") + if key.startswith("model.norm"): + key = key.replace("model.norm", "model.language_model.norm") + if "token_mixer" not in key: + key = key.replace(".proj.", ".downsample.proj.") + + matches = re.findall(double_pattern, key) + if len(matches) == 1: + match = 
matches[0] + key = key.replace(f"network.{match[0]}.{match[1]}", f"stages.{map_to_stage(match[0])}.blocks.{match[1]}") + + matches = re.findall(single_pattern, key) + if len(matches) == 1: + match = matches[0] + key = key.replace(f"network.{match[0]}", f"stages.{map_to_stage(match[0])}") + + matches = re.findall(pos_embedding_pattern, key) + if len(matches) == 1: + match = matches[0] + key = key.replace(f"stages.{match[0]}", f"stages.{match[0]}.pos_emb") + + for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in key: + key = key.replace(key_to_modify, new_key) + + new_state_dict[key] = value + return new_state_dict + + +def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id): + torch.set_default_dtype(torch.bfloat16) + + text_config = AutoConfig.from_pretrained(text_model_id) + vision_config = AutoConfig.from_pretrained(vision_model_id) + vision_config.model_args = {"inference_mode": True} + vision_config.hidden_size = vision_config.num_features + + config = LlavaConfig( + text_config=text_config, + vision_config=vision_config, + ) + config.vision_feature_select_strategy = "full" + config.vision_feature_layer = -1 + config.image_token_id = 151646 + + tokenizer = AutoTokenizer.from_pretrained(text_model_id) + tokenizer.add_tokens(AddedToken("", special=True, normalized=False), special_tokens=True) + image_processor = CLIPImageProcessor(crop_size={"height": 1024, + "width": 1024}, + image_mean=[0.0, 0.0, 0.0], + image_std=[1.0, 1.0, 1.0], + size={"shortest_edge": 1024}) + + processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor) + processor.patch_size = 64 # effective patch size (2^6) + + model = LlavaForConditionalGeneration(config) + + state_dict = load_original_state_dict(old_state_dict_id) + state_dict = convert_state_dict_to_hf(state_dict) + model.load_state_dict(state_dict, strict=True, assign=True) + + pre_expansion_embeddings = model.language_model.embed_tokens.weight.data + mu = torch.mean(pre_expansion_embeddings, dim=0).float() + n = pre_expansion_embeddings.size()[0] + sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n + dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma) + + # We add an image token so we resize the model and pad to 64 for performance reasons + pad_shape = 64 + vocab_size = config.text_config.vocab_size + model.resize_token_embeddings(config.text_config.vocab_size + 1, pad_shape) + model.language_model.embed_tokens.weight.data[vocab_size:] = torch.stack( + tuple(dist.sample() for _ in range(model.language_model.embed_tokens.weight.data[vocab_size:].shape[0])), + dim=0, + ) + model.lm_head.weight.data[vocab_size:] = torch.stack( + tuple(dist.sample() for _ in range(model.lm_head.weight.data[vocab_size:].shape[0])), + dim=0, + ) + + model.push_to_hub(output_hub_path) + processor.push_to_hub(output_hub_path) + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--text_model_id", + default="Qwen/Qwen2-0.5B", + help="Hub location of the text model", + ) + parser.add_argument( + "--vision_model_id", + default="timm/fastvit_mci3.apple_mclip2_dfndr2b", + help="Hub location of the vision model", + ) + parser.add_argument( + "--output_hub_path", + default="KamilaMila/FastVLM-0.5B", + help="Location on the hub of the converted model", + ) + parser.add_argument( + "--old_state_dict_id", + default="apple/FastVLM-0.5B", + 
help="Location on the hub of the raw state dict of the original model. The filename needs to be `model_state_dict.bin`", + ) + args = parser.parse_args() + convert_fastvlm_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id) + + +if __name__ == "__main__": + main() From 977d05ee453cbe4a4dc6d9379c0c99ab6677d43d Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Tue, 23 Sep 2025 21:43:16 +0200 Subject: [PATCH 02/31] Added a modular where FastVLM is different from LlaVA --- src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/fast_vlm/__init__.py | 27 + .../models/fast_vlm/configuration_fast_vlm.py | 125 +++++ .../convert_fastvlm_weights_to_hf.py | 8 +- .../models/fast_vlm/modeling_fast_vlm.py | 494 ++++++++++++++++++ .../models/fast_vlm/modular_fast_vlm.py | 72 +++ 7 files changed, 725 insertions(+), 4 deletions(-) create mode 100644 src/transformers/models/fast_vlm/__init__.py create mode 100644 src/transformers/models/fast_vlm/configuration_fast_vlm.py rename src/transformers/models/{fastvlm => fast_vlm}/convert_fastvlm_weights_to_hf.py (98%) create mode 100644 src/transformers/models/fast_vlm/modeling_fast_vlm.py create mode 100644 src/transformers/models/fast_vlm/modular_fast_vlm.py diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 5630063f92ec..d16581cdb99b 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -125,6 +125,7 @@ from .falcon_h1 import * from .falcon_mamba import * from .fastspeech2_conformer import * + from .fast_vlm import * from .flaubert import * from .flava import * from .flex_olmo import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 7e2e84a445ef..54e6f2279be5 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -150,6 +150,7 @@ ("falcon_mamba", "FalconMambaConfig"), ("fastspeech2_conformer", "FastSpeech2ConformerConfig"), ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGanConfig"), + ("fast_vlm", "FastVlmConfig"), ("flaubert", "FlaubertConfig"), ("flava", "FlavaConfig"), ("flex_olmo", "FlexOlmoConfig"), @@ -593,6 +594,7 @@ ("falcon_mamba", "FalconMamba"), ("fastspeech2_conformer", "FastSpeech2Conformer"), ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"), + ("fast_vlm", "FastVlm"), ("flan-t5", "FLAN-T5"), ("flan-ul2", "FLAN-UL2"), ("flaubert", "FlauBERT"), diff --git a/src/transformers/models/fast_vlm/__init__.py b/src/transformers/models/fast_vlm/__init__.py new file mode 100644 index 000000000000..949f087650dd --- /dev/null +++ b/src/transformers/models/fast_vlm/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_fast_vlm import * + from .modeling_fast_vlm import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py new file mode 100644 index 000000000000..38f65368cb8f --- /dev/null +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -0,0 +1,125 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/fast_vlm/modular_fast_vlm.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_fast_vlm.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 + +from ...configuration_utils import PretrainedConfig +from ..auto import CONFIG_MAPPING, AutoConfig + + +class FastVlmConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`FastVlmForConditionalGeneration`]. It is used to instantiate an + FastVlm model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the FastVlm-9B. + + e.g. [fast_vlm-hf/fast_vlm-9b](https://huggingface.co/fast_vlm-hf/fast_vlm-9b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): + The config object or dictionary of the vision backbone. + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): + The config object or dictionary of the text backbone. + image_token_index (`int`, *optional*, defaults to 32000): + The image token index to encode the image prompt. + projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function used by the multimodal projector. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. + vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -2): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + image_seq_length (`int`, *optional*, defaults to 576): + Sequence length of one image embedding. + multimodal_projector_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the multimodal projector. 
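+
+    The converted FastVLM checkpoints pair a timm FastViTHD vision backbone with a Qwen2 text backbone. The
+    snippet below is an illustrative sketch of such a configuration; the repo ids and overrides are the ones
+    used by `convert_fastvlm_weights_to_hf.py`, not defaults of this class:
+
+    ```python
+    >>> from transformers import AutoConfig, FastVlmConfig
+
+    >>> vision_config = AutoConfig.from_pretrained("timm/fastvit_mci3.apple_mclip2_dfndr2b")
+    >>> vision_config.hidden_size = vision_config.num_features
+    >>> text_config = AutoConfig.from_pretrained("Qwen/Qwen2-0.5B")
+    >>> configuration = FastVlmConfig(
+    ...     vision_config=vision_config,
+    ...     text_config=text_config,
+    ...     image_token_index=151646,
+    ...     vision_feature_select_strategy="full",
+    ...     vision_feature_layer=-1,
+    ... )
+    ```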
+ + Example: + + ```python + >>> from transformers import FastVlmForConditionalGeneration, FastVlmConfig, CLIPVisionConfig, LlamaConfig + + >>> # Initializing a CLIP-vision config + >>> vision_config = CLIPVisionConfig() + + >>> # Initializing a Llama config + >>> text_config = LlamaConfig() + + >>> # Initializing a FastVlm fast_vlm-1.5-7b style configuration + >>> configuration = FastVlmConfig(vision_config, text_config) + + >>> # Initializing a model from the fast_vlm-1.5-7b style configuration + >>> model = FastVlmForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "fast_vlm" + attribute_map = { + "image_token_id": "image_token_index", + } + sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig} + + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=32000, + projector_hidden_act="gelu", + vision_feature_select_strategy="default", + vision_feature_layer=-2, + image_seq_length=576, + multimodal_projector_bias=True, + **kwargs, + ): + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.image_seq_length = image_seq_length + + if vision_feature_select_strategy not in ["default", "full"]: + raise ValueError( + "vision_feature_select_strategy should be one of 'default', 'full'." + f"Got: {vision_feature_select_strategy}" + ) + + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + + if isinstance(vision_config, dict): + vision_config["model_type"] = vision_config.get("model_type", "clip_vision_model") + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["clip_vision_model"]( + intermediate_size=4096, + hidden_size=1024, + patch_size=14, + image_size=336, + num_hidden_layers=24, + num_attention_heads=16, + vocab_size=32000, + projection_dim=768, + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config.get("model_type", "llama") + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["llama"]() + + self.text_config = text_config + self.multimodal_projector_bias = multimodal_projector_bias + + super().__init__(**kwargs) + + +__all__ = ["FastVlmConfig"] diff --git a/src/transformers/models/fastvlm/convert_fastvlm_weights_to_hf.py b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py similarity index 98% rename from src/transformers/models/fastvlm/convert_fastvlm_weights_to_hf.py rename to src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py index 8af5d5aa6ac2..a43a20768662 100644 --- a/src/transformers/models/fastvlm/convert_fastvlm_weights_to_hf.py +++ b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py @@ -23,8 +23,8 @@ AddedToken, AutoConfig, AutoTokenizer, - LlavaConfig, - LlavaForConditionalGeneration, + FastVlmConfig, + FastVlmForConditionalGeneration, LlavaProcessor, CLIPImageProcessor, ) @@ -118,7 +118,7 @@ def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_s vision_config.model_args = {"inference_mode": True} vision_config.hidden_size = vision_config.num_features - config = LlavaConfig( + config = FastVlmConfig( text_config=text_config, vision_config=vision_config, ) @@ -137,7 +137,7 @@ def convert_fastvlm_to_hf(text_model_id, 
vision_model_id, output_hub_path, old_s processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor) processor.patch_size = 64 # effective patch size (2^6) - model = LlavaForConditionalGeneration(config) + model = FastVlmForConditionalGeneration(config) state_dict = load_original_state_dict(old_state_dict_id) state_dict = convert_state_dict_to_hf(state_dict) diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py new file mode 100644 index 000000000000..bedf89903327 --- /dev/null +++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -0,0 +1,494 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/fast_vlm/modular_fast_vlm.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_fast_vlm.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ..auto import AutoModel +from .configuration_fast_vlm import FastVlmConfig + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for FastVlm outputs, with hidden states and attentions. + """ +) +class FastVlmModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. 
+ """ + + image_hidden_states: Optional[torch.FloatTensor] = None + + +class FastVlmMultiModalProjector(nn.Module): + def __init__(self, config: FastVlmConfig): + super().__init__() + # We have hidden_size * the number of vision feature layers + num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) + self.linear_1 = nn.Linear( + config.vision_config.hidden_size * num_feature_layers, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@auto_docstring +class FastVlmPreTrainedModel(PreTrainedModel): + config: FastVlmConfig + base_model_prefix = "" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_flex_attn = True + _supports_attention_backend = True + + +@auto_docstring( + custom_intro=""" + The FastVlm model which consists of a vision backbone and a language model, without a language modeling head. + """ +) +class FastVlmModel(FastVlmPreTrainedModel): + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + + def __init__(self, config: FastVlmConfig): + super().__init__(config) + self.vision_tower = AutoModel.from_config(config.vision_config) + + self.multi_modal_projector = FastVlmMultiModalProjector(config) + self.language_model = AutoModel.from_config(config.text_config) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder): + self.language_model = decoder + + def get_decoder(self): + return self.language_model + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): + The tensors corresponding to the input images. + vision_feature_layer (`Union[int, list[int]]`, *optional*): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
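+
+        Example (illustrative sketch; assumes a converted FastVLM checkpoint loaded as `model`, whose
+        1024x1024 inputs are reduced by the FastViTHD tower by an effective factor of 64, i.e. to
+        16x16 = 256 image tokens per image):
+
+        ```python
+        pixel_values = torch.randn(1, 3, 1024, 1024, dtype=model.dtype, device=model.device)
+        image_features = model.get_image_features(pixel_values)
+        # a list with one tensor per image, each of shape (256, text_hidden_size)
+        ```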
+ """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if vision_feature_select_strategy != "full": + raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}") + + if vision_feature_layer != -1: + raise ValueError(f"Unexpected vision feature layer: {vision_feature_layer}") + + kwargs = {k: v for k, v in kwargs.items() if v is not None} + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) + + # since the vision tower is hybrid in FastVLM, its output needs to be handled differently from Llava + selected_image_feature = image_outputs.last_hidden_state + selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1) + + image_features = self.multi_modal_projector(selected_image_feature) + + if "image_sizes" in kwargs: + split_sizes = [ + (height // self.vision_tower.patch_size) * (width // self.vision_tower.patch_size) + for height, width in kwargs["image_sizes"] + ] + image_features = torch.split(image_features.squeeze(0), split_sizes) + else: + image_features = list(image_features) + return image_features + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. 
+ """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + image_sizes: torch.Tensor = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[tuple, FastVlmModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + image_features = self.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + image_sizes=image_sizes, + ) + image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + return FastVlmModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features 
if pixel_values is not None else None, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for FastVlm causal language model (or autoregressive) outputs. + """ +) +class FastVlmCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`. + image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state. + """ + + loss: Optional[torch.FloatTensor] = None + logits: Optional[torch.FloatTensor] = None + past_key_values: Optional[list[torch.FloatTensor]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +@auto_docstring( + custom_intro=""" + The FAST_VLM model which consists of a vision backbone and a language model. + """ +) +class FastVlmForConditionalGeneration(FastVlmPreTrainedModel, GenerationMixin): + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: FastVlmConfig): + super().__init__(config) + self.model = FastVlmModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_decoder(self, decoder): + self.model.set_decoder(decoder) + + def get_decoder(self): + return self.model.get_decoder() + + def get_image_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + **kwargs, + ): + return self.model.get_image_features( + pixel_values=pixel_values, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + **kwargs, + ) + + # Make modules available through conditional class for BC + @property + def language_model(self): + return self.model.language_model + + @property + def vision_tower(self): + return self.model.vision_tower + + @property + def multi_modal_projector(self): + return self.model.multi_modal_projector + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: torch.FloatTensor = None, + 
attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + vision_feature_layer: Optional[Union[int, list[int]]] = None, + vision_feature_select_strategy: Optional[str] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + image_sizes: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> Union[tuple, FastVlmCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, FastVlmForConditionalGeneration + + >>> model = FastVlmForConditionalGeneration.from_pretrained("fast_vlm-hf/fast_vlm-1.5-7b-hf") + >>> processor = AutoProcessor.from_pretrained("fast_vlm-hf/fast_vlm-1.5-7b-hf") + + >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "USER: \nWhat's the content of the image? 
ASSISTANT: The image features a busy city street with a stop sign prominently displayed" + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + image_sizes=image_sizes, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return FastVlmCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + attention_mask=None, + cache_position=None, + logits_to_keep=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + logits_to_keep=logits_to_keep, + **kwargs, + ) + + if cache_position[0] == 0: + # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore + # Otherwise we need pixel values to be passed to model + model_inputs["pixel_values"] = pixel_values + + return model_inputs + + +__all__ = ["FastVlmForConditionalGeneration", "FastVlmModel", "FastVlmPreTrainedModel"] diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py new file mode 100644 index 000000000000..e6a4ec4ae0b2 --- /dev/null +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -0,0 +1,72 @@ +from ..llava.configuration_llava import LlavaConfig +from ..llava.modeling_llava import LlavaModel, LlavaForConditionalGeneration +import torch +from typing import Optional, Union + +class FastVlmConfig(LlavaConfig): + model_type = "fast_vlm" + +class FastVlmModel(LlavaModel): + def get_image_features( + self, + pixel_values: torch.FloatTensor, + vision_feature_layer: Optional[Union[int, list[int]]] = None, 
+ vision_feature_select_strategy: Optional[str] = None, + **kwargs, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): + The tensors corresponding to the input images. + vision_feature_layer (`Union[int, list[int]]`, *optional*): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + vision_feature_layer = ( + vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer + ) + vision_feature_select_strategy = ( + vision_feature_select_strategy + if vision_feature_select_strategy is not None + else self.config.vision_feature_select_strategy + ) + + if vision_feature_select_strategy != "full": + raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}") + + if vision_feature_layer != -1: + raise ValueError(f"Unexpected vision feature layer: {vision_feature_layer}") + + + kwargs = {k: v for k, v in kwargs.items() if v is not None} + # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) + + # since the vision tower is hybrid in FastVLM, its output needs to be handled differently from Llava + selected_image_feature = image_outputs.last_hidden_state + selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1) + + image_features = self.multi_modal_projector(selected_image_feature) + + if "image_sizes" in kwargs: + split_sizes = [ + (height // self.vision_tower.patch_size) * (width // self.vision_tower.patch_size) + for height, width in kwargs["image_sizes"] + ] + image_features = torch.split(image_features.squeeze(0), split_sizes) + else: + image_features = list(image_features) + return image_features + +class FastVlmForConditionalGeneration(LlavaForConditionalGeneration): + pass + +__all__ = ["FastVlmForConditionalGeneration", "FastVlmModel", "FastVlmPreTrainedModel", "FastVlmConfig"] \ No newline at end of file From 4e3679f1a09c1a1ac6025691490ca044e47a11cb Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Wed, 24 Sep 2025 20:21:29 +0200 Subject: [PATCH 03/31] Improved the conversion script --- .../fast_vlm/convert_fastvlm_weights_to_hf.py | 47 +++++++++++++++++-- .../models/fast_vlm/modeling_fast_vlm.py | 1 + .../models/fast_vlm/modular_fast_vlm.py | 1 + 3 files changed, 45 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py index a43a20768662..c7659ea9f9f4 100644 --- a/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py +++ b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py @@ -13,6 +13,7 @@ # limitations under the License. 
import argparse import glob +import os import torch from huggingface_hub import snapshot_download @@ -29,6 +30,11 @@ CLIPImageProcessor, ) +from PIL import Image +import requests + +os.environ["TIMM_FUSED_ATTN"] = "0" # needed because the original implementation uses regular atteniton (to avoid logits diverging) + KEYS_TO_MODIFY_MAPPING = { "model.vision_tower.vision_tower.model": "model.vision_tower.timm_model", "patch_embed": "stem", @@ -124,7 +130,8 @@ def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_s ) config.vision_feature_select_strategy = "full" config.vision_feature_layer = -1 - config.image_token_id = 151646 + config.image_token_index = 151646 + config.image_seq_length = 256 tokenizer = AutoTokenizer.from_pretrained(text_model_id) tokenizer.add_tokens(AddedToken("", special=True, normalized=False), special_tokens=True) @@ -162,9 +169,42 @@ def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_s dim=0, ) + conversation = [ + { + "role": "user", + "content": "\nWhat are these?" + } + ] + prompt = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) + prompt = prompt.replace("assistant<", "assistant.<") # to make it aligned with the prompt from the old Apple repo + + image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to("cuda") + inputs = {k: (v.to(torch.bfloat16) if v.dtype == torch.float32 else v) for k, v in inputs.items()} + + model = model.cuda() + model.eval() + with torch.no_grad(): + logits = model(**inputs).logits + + expected_shape = torch.Size([1, 280, 152000]) + # in order to get the same logits as in the Apple repo, we need to replace the original LayerNorm2D with Timm2D layer norm or vice versa + # otherwise numerical errors accumulate + if output_hub_path == "KamilaMila/FastVLM-0.5B": + expected_slice = torch.tensor([ 4.1250, 9.6875, 11.1875], device="cuda") + elif output_hub_path == "KamilaMila/FastVLM-1.5B": + expected_slice = torch.tensor([ 3.3750, 11.5000, 11.8125], device="cuda") + elif output_hub_path == "KamilaMila/FastVLM-7B": + expected_slice = torch.tensor([4.0312, 10.0000, 7.9062], device="cuda") + + logits_slice = logits[0, -1, :3] + assert torch.allclose(expected_slice, logits_slice, atol=1e-8) + assert logits.shape == expected_shape + model.push_to_hub(output_hub_path) processor.push_to_hub(output_hub_path) - + print("Successfully pushed to hub!") def main(): parser = argparse.ArgumentParser( @@ -189,11 +229,10 @@ def main(): parser.add_argument( "--old_state_dict_id", default="apple/FastVLM-0.5B", - help="Location on the hub of the raw state dict of the original model. 
The filename needs to be `model_state_dict.bin`", + help="Location on the hub of the raw state dict of the original model.", ) args = parser.parse_args() convert_fastvlm_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id) - if __name__ == "__main__": main() diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py index bedf89903327..60cebf6ac7a6 100644 --- a/src/transformers/models/fast_vlm/modeling_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -141,6 +141,7 @@ def get_image_features( else self.config.vision_feature_select_strategy ) + # only those values make sense in FastVLM if vision_feature_select_strategy != "full": raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}") diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index e6a4ec4ae0b2..569f268a1231 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -39,6 +39,7 @@ def get_image_features( else self.config.vision_feature_select_strategy ) + # only those values make sense in FastVLM if vision_feature_select_strategy != "full": raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}") From dd2da9a0a09e912c54126d4b80a3b0e9f2f67156 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Mon, 6 Oct 2025 23:17:15 +0200 Subject: [PATCH 04/31] Adjusted the conversion script --- .../fast_vlm/convert_fastvlm_weights_to_hf.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py index c7659ea9f9f4..13bbd4baaac1 100644 --- a/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py +++ b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py @@ -33,7 +33,7 @@ from PIL import Image import requests -os.environ["TIMM_FUSED_ATTN"] = "0" # needed because the original implementation uses regular atteniton (to avoid logits diverging) +os.environ["TIMM_FUSED_ATTN"] = "0" # to avoid logits diverging, needed because the original implementation uses regular (not fused) atteniton KEYS_TO_MODIFY_MAPPING = { "model.vision_tower.vision_tower.model": "model.vision_tower.timm_model", @@ -176,7 +176,7 @@ def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_s } ] prompt = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) - prompt = prompt.replace("assistant<", "assistant.<") # to make it aligned with the prompt from the old Apple repo + prompt = prompt.replace("assistant<", "assistant.<") # to make it aligned with the prompt from the old Apple remote code image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" raw_image = Image.open(requests.get(image_file, stream=True).raw) @@ -188,15 +188,15 @@ def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_s with torch.no_grad(): logits = model(**inputs).logits - expected_shape = torch.Size([1, 280, 152000]) - # in order to get the same logits as in the Apple repo, we need to replace the original LayerNorm2D with Timm2D layer norm or vice versa + expected_shape = torch.Size([1, 280, 152128]) + # in order to get the same logits as in the Apple repo, we need to manually replace the original (Apple) 
LayerNorm2D with Timm's LayerNorm2D or vice versa # otherwise numerical errors accumulate if output_hub_path == "KamilaMila/FastVLM-0.5B": expected_slice = torch.tensor([ 4.1250, 9.6875, 11.1875], device="cuda") elif output_hub_path == "KamilaMila/FastVLM-1.5B": expected_slice = torch.tensor([ 3.3750, 11.5000, 11.8125], device="cuda") elif output_hub_path == "KamilaMila/FastVLM-7B": - expected_slice = torch.tensor([4.0312, 10.0000, 7.9062], device="cuda") + expected_slice = torch.tensor([3.8125, 9.0625, 7.9062], device="cuda") logits_slice = logits[0, -1, :3] assert torch.allclose(expected_slice, logits_slice, atol=1e-8) @@ -213,7 +213,7 @@ def main(): parser.add_argument( "--text_model_id", - default="Qwen/Qwen2-0.5B", + default="Qwen/Qwen2-7B", help="Hub location of the text model", ) parser.add_argument( @@ -223,12 +223,12 @@ def main(): ) parser.add_argument( "--output_hub_path", - default="KamilaMila/FastVLM-0.5B", + default="KamilaMila/FastVLM-7B", help="Location on the hub of the converted model", ) parser.add_argument( "--old_state_dict_id", - default="apple/FastVLM-0.5B", + default="apple/FastVLM-7B", help="Location on the hub of the raw state dict of the original model.", ) args = parser.parse_args() From 9715630a3694dee2a886bf399e1e4990c71c0393 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Tue, 7 Oct 2025 18:41:15 +0200 Subject: [PATCH 05/31] Removed redundant labels from FastViT & improved the template --- docs/source/en/_toctree.yml | 2 ++ .../fast_vlm/convert_fastvlm_weights_to_hf.py | 25 +++++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index f2fe366a69fa..38fba41da9d1 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -1050,6 +1050,8 @@ title: Emu3 - local: model_doc/evolla title: Evolla + - local: model_doc/fast_vlm + title: FastVLM - local: model_doc/flava title: FLAVA - local: model_doc/florence2 diff --git a/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py index 13bbd4baaac1..139cc7bfdb2e 100644 --- a/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py +++ b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py @@ -123,7 +123,8 @@ def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_s vision_config = AutoConfig.from_pretrained(vision_model_id) vision_config.model_args = {"inference_mode": True} vision_config.hidden_size = vision_config.num_features - + vision_config.label2id = {} + vision_config.id2label = {} config = FastVlmConfig( text_config=text_config, vision_config=vision_config, @@ -133,7 +134,11 @@ def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_s config.image_token_index = 151646 config.image_seq_length = 256 - tokenizer = AutoTokenizer.from_pretrained(text_model_id) + tokenizer = AutoTokenizer.from_pretrained( + text_model_id, + chat_template="{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n'}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '' }}{% endfor %}{# Render all text next #}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ '\n' + content['text'] }}{% endfor %}{{'<|im_end|>' + '\n'}}{% endfor %}{% if 
add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + ) + tokenizer.add_tokens(AddedToken("", special=True, normalized=False), special_tokens=True) image_processor = CLIPImageProcessor(crop_size={"height": 1024, "width": 1024}, @@ -172,11 +177,13 @@ def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_s conversation = [ { "role": "user", - "content": "\nWhat are these?" + "content": [ + {"type": "text", "text": "What are these?"}, + {"type": "image"} + ] } ] prompt = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) - prompt = prompt.replace("assistant<", "assistant.<") # to make it aligned with the prompt from the old Apple remote code image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" raw_image = Image.open(requests.get(image_file, stream=True).raw) @@ -188,14 +195,16 @@ def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_s with torch.no_grad(): logits = model(**inputs).logits - expected_shape = torch.Size([1, 280, 152128]) # in order to get the same logits as in the Apple repo, we need to manually replace the original (Apple) LayerNorm2D with Timm's LayerNorm2D or vice versa # otherwise numerical errors accumulate if output_hub_path == "KamilaMila/FastVLM-0.5B": + expected_shape = torch.Size([1, 280, 152000]) expected_slice = torch.tensor([ 4.1250, 9.6875, 11.1875], device="cuda") elif output_hub_path == "KamilaMila/FastVLM-1.5B": + expected_shape = torch.Size([1, 280, 152000]) expected_slice = torch.tensor([ 3.3750, 11.5000, 11.8125], device="cuda") elif output_hub_path == "KamilaMila/FastVLM-7B": + expected_shape = torch.Size([1, 280, 152000]) expected_slice = torch.tensor([3.8125, 9.0625, 7.9062], device="cuda") logits_slice = logits[0, -1, :3] @@ -213,7 +222,7 @@ def main(): parser.add_argument( "--text_model_id", - default="Qwen/Qwen2-7B", + default="Qwen/Qwen2-1.5B", help="Hub location of the text model", ) parser.add_argument( @@ -223,12 +232,12 @@ def main(): ) parser.add_argument( "--output_hub_path", - default="KamilaMila/FastVLM-7B", + default="KamilaMila/FastVLM-1.5B", help="Location on the hub of the converted model", ) parser.add_argument( "--old_state_dict_id", - default="apple/FastVLM-7B", + default="apple/FastVLM-1.5B", help="Location on the hub of the raw state dict of the original model.", ) args = parser.parse_args() From a75c141ee97589b04e2dd272bf2de939df09d930 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Wed, 8 Oct 2025 00:57:25 +0200 Subject: [PATCH 06/31] Added docs and changed default config --- docs/source/en/model_doc/fast_vlm.md | 236 ++++++++++++++++++ .../models/fast_vlm/configuration_fast_vlm.py | 55 ++-- .../fast_vlm/convert_fastvlm_weights_to_hf.py | 4 +- .../models/fast_vlm/modeling_fast_vlm.py | 20 +- .../models/fast_vlm/modular_fast_vlm.py | 154 ++++++++++-- 5 files changed, 412 insertions(+), 57 deletions(-) create mode 100644 docs/source/en/model_doc/fast_vlm.md diff --git a/docs/source/en/model_doc/fast_vlm.md b/docs/source/en/model_doc/fast_vlm.md new file mode 100644 index 000000000000..8d1eccb37a0a --- /dev/null +++ b/docs/source/en/model_doc/fast_vlm.md @@ -0,0 +1,236 @@ + + +*This model was released on 2025-05-06 and added to Hugging Face Transformers on 2025-10-07.* + +# FastVLM + +
+PyTorch + +
+ +## Overview + +FastVLM is an open-source vision-language model featuring a novel hybrid vision encoder, FastViTHD. Leveraging reparameterizable convolutional layers, scaled input resolution, and a reduced number of visual tokens, FastVLM delivers high accuracy with exceptional efficiency. Its optimized architecture enables deployment even on edge devices, achieving ultra-low TTFT (time to first token) without sacrificing performance. + +The model was proposed in [FastVLM: Efficient Vision Encoding for Vision Language Models](https://huggingface.co/papers/2412.13303) by Pavan Kumar Anasosalu Vasu, Fartash Faghri, Chun-Liang Li, Cem Koc, Nate True, Albert Antony, Gokul Santhanam, James Gabriel, Peter Grasch, Oncel Tuzel and Hadi Pouransari. + +The abstract from the paper is the following: + +*Scaling the input image resolution is essential for enhancing the performance of Vision Language Models (VLMs), particularly in text-rich image understanding tasks. However, popular visual encoders such as ViTs become inefficient at high resolutions due to the large number of tokens and high encoding latency. At different operational resolutions, the vision encoder of a VLM can be optimized along two axes: reducing encoding latency and minimizing the number of visual tokens passed to the LLM, thereby lowering overall latency. Based on a comprehensive efficiency analysis of the interplay between image resolution, vision latency, token count, and LLM size, we introduce FastVLMβ€”a model that achieves an optimized trade-off between resolution, latency, and accuracy. FastVLM incorporates FastViTHD, a novel hybrid vision encoder designed to output fewer tokens and significantly reduce encoding time for high-resolution images. Unlike previous methods, FastVLM achieves the optimal balance between visual token count and image resolution solely by scaling the input image, eliminating the need for additional token pruning and simplifying the model design. In the LLaVA-1.5 setup, FastVLM achieves 3.2Γ— improvement in time-to-first-token (TTFT) while maintaining similar performance on VLM benchmarks compared to prior works. Compared to LLaVa-OneVision at the highest resolution (1152Γ—1152), FastVLM achieves better performance on key benchmarks like SeedBench, MMMU and DocVQA, using the same 0.5B LLM, but with 85Γ— faster TTFT and a vision encoder that is 3.4Γ— smaller.* + +This model was contributed by [Kamila](https://github.com/kamila-chay). +The original code can be found [here](https://github.com/apple/ml-fastvlm). + +## Usage tips + +- We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating. + +- Note the model has not been explicitly trained to process multiple images in the same prompt, although this is technically possible, you may experience inaccurate results. + +### Formatting Prompts with Chat Templates + +Each **checkpoint** is trained with a specific prompt format, depending on the underlying large language model backbone. To ensure correct formatting, use the processor’s `apply_chat_template` method. + +**Important:** +- You must construct a conversation history β€” passing a plain string won't work. +- Each message should be a dictionary with `"role"` and `"content"` keys. +- The `"content"` should be a list of dictionaries for different modalities like `"text"` and `"image"`. + + +Here’s an example of how to structure your input. 
+We will use a conversation history of text and image. Each content field has to be a list of dicts, as follows: + + +```python +from transformers import AutoProcessor + +processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B") + +conversation = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What’s shown in this image?"}, + ], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "This image shows a red stop sign."},] + }, + { + + "role": "user", + "content": [ + {"type": "text", "text": "Describe the image in more details."}, + ], + }, +] + +text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + +# Note that the template simply formats your prompt, you still have to tokenize it and obtain pixel values for your images +print(text_prompt) +>>> "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n\nWhat’s shown in this image?<|im_end|>\n<|im_start|>assistant\n\nThis image shows a red stop sign.<|im_end|>\n<|im_start|>user\n\nDescribe the image in more details.<|im_end|>\n<|im_start|>assistant\n" +``` + +πŸš€ **Bonus:** If you're using `transformers>=4.49.0`, you can also get a vectorized output from `apply_chat_template`. See the **Usage Examples** below for more details on how to use it. + +## Usage examples + +### Single input inference + + +```python +import torch +from transformers import AutoProcessor, FastVlmForConditionalGeneration + +# Load the model in half-precision +model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B", dtype=torch.bfloat16, device_map="auto") +processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B") + +conversation = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, +] + +inputs = processor.apply_chat_template( + conversation, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt" +).to(model.device, torch.bfloat16) + +# Generate +generate_ids = model.generate(**inputs, max_new_tokens=30) +processor.batch_decode(generate_ids, skip_special_tokens=True) +``` + + +### Batched inference + +FastVLM also supports batched inference. 
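+As noted in the usage tips above, batched generation tends to be more accurate with left padding, so it is worth setting the tokenizer's padding side before calling `generate` (a minimal one-line sketch, assuming the `processor` loaded as in the example below):
+
+```python
+# left-pad batched prompts so generation starts right after each prompt
+processor.tokenizer.padding_side = "left"
+```
+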
Here is how you can do it: + +```python +import torch +from transformers import AutoProcessor, FastVlmForConditionalGeneration + +# Load the model in half-precision +model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B", dtype=torch.bfloat16, device_map="auto") +processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B") + + +# Prepare a batch of two prompts +conversation_1 = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, +] + +conversation_2 = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, +] + +inputs = processor.apply_chat_template( + [conversation_1, conversation_2], + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding=True, + return_tensors="pt" +).to(model.device, torch.bfloat16) + + +# Generate +generate_ids = model.generate(**inputs, max_new_tokens=30) +processor.batch_decode(generate_ids, skip_special_tokens=True) +``` + + +## Note regarding reproducing original implementation + +In order to match the logits of the [original implementation](https://github.com/apple/ml-fastvlm), one needs to set the default timm attention implementation to the most basic version(not fused): + +``` +import os +# at the beginning of your script +os.environ["TIMM_FUSED_ATTN"] = "0" +``` + +In addition, the layer norm used by Apple doesn't use the standard LayerNorm class form Torch and therefore our logits diverge. To get exactly the same values, one needs to manually change timm/layers/norm.py: + +``` +class LayerNorm2d(nn.LayerNorm): + """ LayerNorm for channels of '2D' spatial NCHW tensors """ + _fast_norm: torch.jit.Final[bool] + + def __init__(): + ... # not important + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \ + + self.bias.unsqueeze(-1).unsqueeze(-1) + return x +``` +Please note, that this is only needed in oder to get the exact same numerical values on the output of the model. It's not necessary to make this change to use FastVLM. + + + +## Resources + +A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with image-to-text transformers (here using the example of Llava). + + + +- A [Google Colab demo](https://colab.research.google.com/drive/1qsl6cd2c8gGtEW1xV5io7S8NHh-Cp1TV?usp=sharing) on how to run Llava on a free-tier Google colab instance leveraging 4-bit inference. +- A [similar notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LLaVa/Inference_with_LLaVa_for_multimodal_generation.ipynb) showcasing batched inference. 🌎 + +## FastVlmConfig + +[[autodoc]] FastVlmConfig + +## FastVlmModel + +[[autodoc]] FastVlmModel + +## FastVlmForConditionalGeneration + +[[autodoc]] FastVlmForConditionalGeneration + - forward \ No newline at end of file diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py index 38f65368cb8f..e7a8e8ee82b7 100644 --- a/src/transformers/models/fast_vlm/configuration_fast_vlm.py +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -4,23 +4,22 @@ # the file from the modular. 
If any change should be done, please apply the change to the # modular_fast_vlm.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 - from ...configuration_utils import PretrainedConfig from ..auto import CONFIG_MAPPING, AutoConfig class FastVlmConfig(PretrainedConfig): r""" - This is the configuration class to store the configuration of a [`FastVlmForConditionalGeneration`]. It is used to instantiate an - FastVlm model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the FastVlm-9B. + This is the configuration class to store the configuration of a [`FastVLMForConditionalGeneration`]. It is used to instantiate a + FastVLM model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the FastVLM-7B. - e.g. [fast_vlm-hf/fast_vlm-9b](https://huggingface.co/fast_vlm-hf/fast_vlm-9b) + e.g. [KamilaMila/FastVLM-7B](https://huggingface.co/KamilaMila/FastVLM-7B) Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - Args: + Args: TODO !!!!!!!!!!! vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): The config object or dictionary of the vision backbone. text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): @@ -44,7 +43,7 @@ class FastVlmConfig(PretrainedConfig): Example: ```python - >>> from transformers import FastVlmForConditionalGeneration, FastVlmConfig, CLIPVisionConfig, LlamaConfig + >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig >>> # Initializing a CLIP-vision config >>> vision_config = CLIPVisionConfig() @@ -52,11 +51,11 @@ class FastVlmConfig(PretrainedConfig): >>> # Initializing a Llama config >>> text_config = LlamaConfig() - >>> # Initializing a FastVlm fast_vlm-1.5-7b style configuration - >>> configuration = FastVlmConfig(vision_config, text_config) + >>> # Initializing a Llava llava-1.5-7b style configuration + >>> configuration = LlavaConfig(vision_config, text_config) - >>> # Initializing a model from the fast_vlm-1.5-7b style configuration - >>> model = FastVlmForConditionalGeneration(configuration) + >>> # Initializing a model from the llava-1.5-7b style configuration + >>> model = LlavaForConditionalGeneration(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -72,11 +71,11 @@ def __init__( self, vision_config=None, text_config=None, - image_token_index=32000, + image_token_index=151646, projector_hidden_act="gelu", - vision_feature_select_strategy="default", - vision_feature_layer=-2, - image_seq_length=576, + vision_feature_select_strategy="full", + vision_feature_layer=-1, + image_seq_length=256, multimodal_projector_bias=True, **kwargs, ): @@ -84,9 +83,9 @@ def __init__( self.projector_hidden_act = projector_hidden_act self.image_seq_length = image_seq_length - if vision_feature_select_strategy not in ["default", "full"]: + if vision_feature_select_strategy != "full": raise ValueError( - "vision_feature_select_strategy should be one of 'default', 'full'." + "Only vision_feature_select_strategy='full' supported in FastVLM!" 
f"Got: {vision_feature_select_strategy}" ) @@ -94,27 +93,25 @@ def __init__( self.vision_feature_layer = vision_feature_layer if isinstance(vision_config, dict): - vision_config["model_type"] = vision_config.get("model_type", "clip_vision_model") + vision_config["model_type"] = vision_config.get("model_type", "timm_wrapper") vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) elif vision_config is None: - vision_config = CONFIG_MAPPING["clip_vision_model"]( - intermediate_size=4096, - hidden_size=1024, - patch_size=14, - image_size=336, - num_hidden_layers=24, - num_attention_heads=16, - vocab_size=32000, - projection_dim=768, + vision_config = CONFIG_MAPPING["timm_wrapper"]( + architecture="fastvit_mci3", + do_pooling=True, + global_pool="avg", + hidden_size=3072, + initializer_range=0.02, + model_args={"inference_mode": True}, ) self.vision_config = vision_config if isinstance(text_config, dict): - text_config["model_type"] = text_config.get("model_type", "llama") + text_config["model_type"] = text_config.get("model_type", "qwen2") text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) elif text_config is None: - text_config = CONFIG_MAPPING["llama"]() + text_config = CONFIG_MAPPING["qwen2"]() self.text_config = text_config self.multimodal_projector_bias = multimodal_projector_bias diff --git a/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py index 139cc7bfdb2e..3886fb0ebbc9 100644 --- a/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py +++ b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py @@ -204,8 +204,8 @@ def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_s expected_shape = torch.Size([1, 280, 152000]) expected_slice = torch.tensor([ 3.3750, 11.5000, 11.8125], device="cuda") elif output_hub_path == "KamilaMila/FastVLM-7B": - expected_shape = torch.Size([1, 280, 152000]) - expected_slice = torch.tensor([3.8125, 9.0625, 7.9062], device="cuda") + expected_shape = torch.Size([1, 280, 152128]) + expected_slice = torch.tensor([3.8281, 9.0625, 7.9062], device="cuda") logits_slice = logits[0, -1, :3] assert torch.allclose(expected_slice, logits_slice, atol=1e-8) diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py index 60cebf6ac7a6..d37d9e7c847b 100644 --- a/src/transformers/models/fast_vlm/modeling_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -123,12 +123,10 @@ def get_image_features( pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): The tensors corresponding to the input images. vision_feature_layer (`Union[int, list[int]]`, *optional*): - The index of the layer to select the vision feature. If multiple indices are provided, - the vision feature of the corresponding indices will be concatenated to form the - vision features. + The index/indices of the layer to select the vision feature. vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. - Can be one of `"default"` or `"full"` + Only "full" supported. Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
""" @@ -141,16 +139,14 @@ def get_image_features( else self.config.vision_feature_select_strategy ) - # only those values make sense in FastVLM + # only this value makes sense in FastVLM if vision_feature_select_strategy != "full": - raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}") - - if vision_feature_layer != -1: - raise ValueError(f"Unexpected vision feature layer: {vision_feature_layer}") + raise ValueError( + f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported." + ) kwargs = {k: v for k, v in kwargs.items() if v is not None} - # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. - image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) + image_outputs = self.vision_tower(pixel_values, **kwargs) # add more choice here! # since the vision tower is hybrid in FastVLM, its output needs to be handled differently from Llava selected_image_feature = image_outputs.last_hidden_state @@ -395,7 +391,7 @@ def forward( >>> import requests >>> from transformers import AutoProcessor, FastVlmForConditionalGeneration - >>> model = FastVlmForConditionalGeneration.from_pretrained("fast_vlm-hf/fast_vlm-1.5-7b-hf") + >>> model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/fast_vlm-1.5-7b-hf") #TODO change!!! >>> processor = AutoProcessor.from_pretrained("fast_vlm-hf/fast_vlm-1.5-7b-hf") >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:" diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index 569f268a1231..7f61e93a6163 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -2,10 +2,114 @@ from ..llava.modeling_llava import LlavaModel, LlavaForConditionalGeneration import torch from typing import Optional, Union +from ...configuration_utils import PretrainedConfig +from ..auto import CONFIG_MAPPING class FastVlmConfig(LlavaConfig): + r""" + This is the configuration class to store the configuration of a [`FastVLMForConditionalGeneration`]. It is used to instantiate a + FastVLM model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the FastVLM-7B. + + e.g. [KamilaMila/FastVLM-7B](https://huggingface.co/KamilaMila/FastVLM-7B) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: TODO !!!!!!!!!!! + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): + The config object or dictionary of the vision backbone. + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): + The config object or dictionary of the text backbone. + image_token_index (`int`, *optional*, defaults to 32000): + The image token index to encode the image prompt. + projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): + The activation function used by the multimodal projector. + vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"`. 
+ vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -2): + The index of the layer to select the vision feature. If multiple indices are provided, + the vision feature of the corresponding indices will be concatenated to form the + vision features. + image_seq_length (`int`, *optional*, defaults to 576): + Sequence length of one image embedding. + multimodal_projector_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the multimodal projector. + + Example: + + ```python + >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig + + >>> # Initializing a CLIP-vision config + >>> vision_config = CLIPVisionConfig() + + >>> # Initializing a Llama config + >>> text_config = LlamaConfig() + + >>> # Initializing a Llava llava-1.5-7b style configuration + >>> configuration = LlavaConfig(vision_config, text_config) + + >>> # Initializing a model from the llava-1.5-7b style configuration + >>> model = LlavaForConditionalGeneration(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" model_type = "fast_vlm" + def __init__( + self, + vision_config=None, + text_config=None, + image_token_index=151646, + projector_hidden_act="gelu", + vision_feature_select_strategy="full", + vision_feature_layer=-1, + image_seq_length=256, + multimodal_projector_bias=True, + **kwargs, + ): + self.image_token_index = image_token_index + self.projector_hidden_act = projector_hidden_act + self.image_seq_length = image_seq_length + + if vision_feature_select_strategy != "full": + raise ValueError( + "Only vision_feature_select_strategy='full' supported in FastVLM!" + f"Got: {vision_feature_select_strategy}" + ) + + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + + if isinstance(vision_config, dict): + vision_config["model_type"] = vision_config.get("model_type", "timm_wrapper") + vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config) + elif vision_config is None: + vision_config = CONFIG_MAPPING["timm_wrapper"]( + architecture="fastvit_mci3", + do_pooling=True, + global_pool="avg", + hidden_size=3072, + initializer_range=0.02, + model_args={"inference_mode": True} + ) + + self.vision_config = vision_config + + if isinstance(text_config, dict): + text_config["model_type"] = text_config.get("model_type", "qwen2") + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["qwen2"]() + + self.text_config = text_config + self.multimodal_projector_bias = multimodal_projector_bias + + PretrainedConfig.__init__(**kwargs) + class FastVlmModel(LlavaModel): def get_image_features( self, @@ -21,12 +125,10 @@ def get_image_features( pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): The tensors corresponding to the input images. vision_feature_layer (`Union[int, list[int]]`, *optional*): - The index of the layer to select the vision feature. If multiple indices are provided, - the vision feature of the corresponding indices will be concatenated to form the - vision features. + The index/indices of the layer to select the vision feature. vision_feature_select_strategy (`str`, *optional*): - The feature selection strategy used to select the vision feature from the vision backbone. 
- Can be one of `"default"` or `"full"` + The feature selection strategy used to select the vision feature from the vision backbone. + Only "full" supported. Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). """ @@ -39,17 +141,12 @@ def get_image_features( else self.config.vision_feature_select_strategy ) - # only those values make sense in FastVLM + # only this value makes sense in FastVLM if vision_feature_select_strategy != "full": - raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}") - - if vision_feature_layer != -1: - raise ValueError(f"Unexpected vision feature layer: {vision_feature_layer}") - + raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported.") kwargs = {k: v for k, v in kwargs.items() if v is not None} - # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states. - image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) + image_outputs = self.vision_tower(pixel_values, **kwargs) # add more choice here! # since the vision tower is hybrid in FastVLM, its output needs to be handled differently from Llava selected_image_feature = image_outputs.last_hidden_state @@ -68,6 +165,35 @@ def get_image_features( return image_features class FastVlmForConditionalGeneration(LlavaForConditionalGeneration): - pass + def forward(self, **super_kwargs): + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, FastVlmForConditionalGeneration + + >>> model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/fast_vlm-1.5-7b-hf") #TODO change!!! + >>> processor = AutoProcessor.from_pretrained("fast_vlm-hf/fast_vlm-1.5-7b-hf") + + >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "USER: \nWhat's the content of the image? 
ASSISTANT: The image features a busy city street with a stop sign prominently displayed" + ```""" + super().forward(**super_kwargs) + __all__ = ["FastVlmForConditionalGeneration", "FastVlmModel", "FastVlmPreTrainedModel", "FastVlmConfig"] \ No newline at end of file From 030ad2446c2e864f978822008fecb51b38062843 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Wed, 8 Oct 2025 13:44:45 +0200 Subject: [PATCH 07/31] Fix default config --- .../models/fast_vlm/configuration_fast_vlm.py | 8 +++++++- src/transformers/models/fast_vlm/modular_fast_vlm.py | 6 +++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py index e7a8e8ee82b7..04c1fb38a54c 100644 --- a/src/transformers/models/fast_vlm/configuration_fast_vlm.py +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -111,7 +111,13 @@ def __init__( text_config["model_type"] = text_config.get("model_type", "qwen2") text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) elif text_config is None: - text_config = CONFIG_MAPPING["qwen2"]() + text_config = CONFIG_MAPPING["qwen2"]( + hidden_size=3584, + vocab_size=152128, + intermediate_size=18944, + num_attention_heads=28, + num_key_value_heads=4, + ) self.text_config = text_config self.multimodal_projector_bias = multimodal_projector_bias diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index 7f61e93a6163..e033c1e82449 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -103,7 +103,11 @@ def __init__( text_config["model_type"] = text_config.get("model_type", "qwen2") text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) elif text_config is None: - text_config = CONFIG_MAPPING["qwen2"]() + text_config = CONFIG_MAPPING["qwen2"](hidden_size=3584, + vocab_size=152128, + intermediate_size=18944, + num_attention_heads=28, + num_key_value_heads=4) self.text_config = text_config self.multimodal_projector_bias = multimodal_projector_bias From af251d27325f1f315ae34152d66b7ac03a1e586c Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Wed, 8 Oct 2025 13:51:53 +0200 Subject: [PATCH 08/31] Fix default config --- src/transformers/models/fast_vlm/configuration_fast_vlm.py | 1 + src/transformers/models/fast_vlm/modular_fast_vlm.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py index 04c1fb38a54c..b13f9b4f64d1 100644 --- a/src/transformers/models/fast_vlm/configuration_fast_vlm.py +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -117,6 +117,7 @@ def __init__( intermediate_size=18944, num_attention_heads=28, num_key_value_heads=4, + num_hidden_layers=28, ) self.text_config = text_config diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index e033c1e82449..9c463ebc9529 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -107,7 +107,8 @@ def __init__( vocab_size=152128, intermediate_size=18944, num_attention_heads=28, - num_key_value_heads=4) + num_key_value_heads=4, + num_hidden_layers=28) self.text_config = text_config self.multimodal_projector_bias = multimodal_projector_bias From 
17b9e8960533b165985546b18a861aee3293d38e Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Wed, 8 Oct 2025 16:37:58 +0200 Subject: [PATCH 09/31] Fixed layer feature handling and more docs --- .../models/fast_vlm/configuration_fast_vlm.py | 59 +++++--- .../models/fast_vlm/modeling_fast_vlm.py | 121 +++++++++------ .../models/fast_vlm/modular_fast_vlm.py | 143 ++++++++++++------ 3 files changed, 213 insertions(+), 110 deletions(-) diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py index b13f9b4f64d1..80d2e0ccf8e3 100644 --- a/src/transformers/models/fast_vlm/configuration_fast_vlm.py +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -4,6 +4,22 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_fast_vlm.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + from ...configuration_utils import PretrainedConfig from ..auto import CONFIG_MAPPING, AutoConfig @@ -12,30 +28,30 @@ class FastVlmConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`FastVLMForConditionalGeneration`]. It is used to instantiate a FastVLM model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the FastVLM-7B. + with the defaults will yield the same configurationa as the one of FastVLM-7B. e.g. [KamilaMila/FastVLM-7B](https://huggingface.co/KamilaMila/FastVLM-7B) Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - Args: TODO !!!!!!!!!!! - vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): + Args: + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `TimmWrapperConfig` for `fastvit_mci3`): The config object or dictionary of the vision backbone. - text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`): The config object or dictionary of the text backbone. - image_token_index (`int`, *optional*, defaults to 32000): + image_token_index (`int`, *optional*, defaults to 151646): The image token index to encode the image prompt. projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): The activation function used by the multimodal projector. - vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`): The feature selection strategy used to select the vision feature from the vision backbone. - Can be one of `"default"` or `"full"`. 
- vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -2): + Can only be `"full"`. + vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -1): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the - vision features. - image_seq_length (`int`, *optional*, defaults to 576): + vision features. Must be negative. + image_seq_length (`int`, *optional*, defaults to 256): Sequence length of one image embedding. multimodal_projector_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the multimodal projector. @@ -43,19 +59,13 @@ class FastVlmConfig(PretrainedConfig): Example: ```python - >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig - - >>> # Initializing a CLIP-vision config - >>> vision_config = CLIPVisionConfig() + >>> from transformers import FastVlmForConditionalGeneration, FastVlmConfig - >>> # Initializing a Llama config - >>> text_config = LlamaConfig() + >>> # Initializing a FastVLM-7B style configuration + >>> configuration = FastVlmConfig() - >>> # Initializing a Llava llava-1.5-7b style configuration - >>> configuration = LlavaConfig(vision_config, text_config) - - >>> # Initializing a model from the llava-1.5-7b style configuration - >>> model = LlavaForConditionalGeneration(configuration) + >>> # Initializing a model from the FastVLM-7B style configuration + >>> model = FastVlmForConditionalGeneration(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -82,6 +92,8 @@ def __init__( self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act self.image_seq_length = image_seq_length + if math.isqrt(image_seq_length).pow(2) != image_seq_length: + raise ValueError(f"Inavalid image_seq_length: {image_seq_length}. It needs to be a perfect square.") if vision_feature_select_strategy != "full": raise ValueError( @@ -89,6 +101,11 @@ def __init__( f"Got: {vision_feature_select_strategy}" ) + if (isinstance(vision_feature_layer, int) and vision_feature_layer >= 0) or any( + [layer >= 0 for layer in vision_feature_layer] + ): + raise ValueError(f"Only negative layer values are supported. Got {vision_feature_layer}") + self.vision_feature_select_strategy = vision_feature_select_strategy self.vision_feature_layer = vision_feature_layer diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py index d37d9e7c847b..c314604e5dd8 100644 --- a/src/transformers/models/fast_vlm/modeling_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -4,11 +4,27 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_fast_vlm.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import math from dataclasses import dataclass from typing import Optional, Union import torch from torch import nn +from torch.nn.functional import adaptive_avg_pool2d from ...activations import ACT2FN from ...cache_utils import Cache @@ -22,6 +38,34 @@ from .configuration_fast_vlm import FastVlmConfig +class FastVlmMultiModalProjector(nn.Module): + def __init__(self, config: FastVlmConfig): + super().__init__() + if isinstance(config.vision_feature_layer, int): + layers = [config.vision_feature_layer] + else: + layers = config.vision_feature_layer + # different layers have different hidden sizes that are concatenated + total_hidden_size = 0 + for layer in layers: + total_hidden_size += config.vision_feature_layer // (2).pow(-layer - 1) + self.linear_1 = nn.Linear( + total_hidden_size, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + @dataclass @auto_docstring( custom_intro=""" @@ -44,28 +88,6 @@ class FastVlmModelOutputWithPast(BaseModelOutputWithPast): image_hidden_states: Optional[torch.FloatTensor] = None -class FastVlmMultiModalProjector(nn.Module): - def __init__(self, config: FastVlmConfig): - super().__init__() - # We have hidden_size * the number of vision feature layers - num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer) - self.linear_1 = nn.Linear( - config.vision_config.hidden_size * num_feature_layers, - config.text_config.hidden_size, - bias=config.multimodal_projector_bias, - ) - self.act = ACT2FN[config.projector_hidden_act] - self.linear_2 = nn.Linear( - config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias - ) - - def forward(self, image_features): - hidden_states = self.linear_1(image_features) - hidden_states = self.act(hidden_states) - hidden_states = self.linear_2(hidden_states) - return hidden_states - - @auto_docstring class FastVlmPreTrainedModel(PreTrainedModel): config: FastVlmConfig @@ -123,7 +145,7 @@ def get_image_features( pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): The tensors corresponding to the input images. vision_feature_layer (`Union[int, list[int]]`, *optional*): - The index/indices of the layer to select the vision feature. + The index/indices of the layer to select the vision feature. Must be negative. vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Only "full" supported. @@ -139,29 +161,43 @@ def get_image_features( else self.config.vision_feature_select_strategy ) - # only this value makes sense in FastVLM + # only this value makes sense in FastVLM (we can't have a CLS token in conv layers) if vision_feature_select_strategy != "full": raise ValueError( - f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported." + f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported in FastVLM." 
) + if (isinstance(vision_feature_layer, int) and vision_feature_layer >= 0) or any( + [layer >= 0 for layer in vision_feature_layer] + ): + raise ValueError(f"Only negative layer values are supported. Got {vision_feature_layer}") + kwargs = {k: v for k, v in kwargs.items() if v is not None} - image_outputs = self.vision_tower(pixel_values, **kwargs) # add more choice here! + # this is not memory-efficient at all + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) # since the vision tower is hybrid in FastVLM, its output needs to be handled differently from Llava - selected_image_feature = image_outputs.last_hidden_state - selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1) + desired_shape = math.isqrt(self.config.image_seq_length) + if isinstance(vision_feature_layer, int): + if vision_feature_layer == -1: + selected_image_feature = image_outputs.last_hidden_state + else: + selected_image_feature = image_outputs.hidden_states[vision_feature_layer + 1] + selected_image_feature = adaptive_avg_pool2d(selected_image_feature, (desired_shape, desired_shape)) + else: + hs_pool = [] + for layer_idx in vision_feature_layer: + if layer_idx == -1: + selected_image_feature = image_outputs.last_hidden_state + else: + selected_image_feature = image_outputs.hidden_states[layer_idx + 1] + selected_image_feature = adaptive_avg_pool2d(selected_image_feature, (desired_shape, desired_shape)) + hs_pool.append(selected_image_feature) + selected_image_feature = torch.cat(hs_pool, dim=-1) + selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1) image_features = self.multi_modal_projector(selected_image_feature) - - if "image_sizes" in kwargs: - split_sizes = [ - (height // self.vision_tower.patch_size) * (width // self.vision_tower.patch_size) - for height, width in kwargs["image_sizes"] - ] - image_features = torch.split(image_features.squeeze(0), split_sizes) - else: - image_features = list(image_features) + image_features = list(image_features) return image_features def get_placeholder_mask( @@ -391,19 +427,18 @@ def forward( >>> import requests >>> from transformers import AutoProcessor, FastVlmForConditionalGeneration - >>> model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/fast_vlm-1.5-7b-hf") #TODO change!!! - >>> processor = AutoProcessor.from_pretrained("fast_vlm-hf/fast_vlm-1.5-7b-hf") + >>> model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-7B") + >>> processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-7B") - >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:" + >>> prompt = "<|im_start|>user\n\nWhat's the content of the image?<|im_end|>\n<|im_start|>assistant\n" >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, text=prompt, return_tensors="pt") >>> # Generate - >>> generate_ids = model.generate(**inputs, max_new_tokens=15) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "USER: \nWhat's the content of the image? 
ASSISTANT: The image features a busy city street with a stop sign prominently displayed" + >>> generated_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index 9c463ebc9529..9228446a36ea 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -1,38 +1,59 @@ -from ..llava.configuration_llava import LlavaConfig -from ..llava.modeling_llava import LlavaModel, LlavaForConditionalGeneration -import torch +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math from typing import Optional, Union + +import torch +from torch import nn +from torch.nn.functional import adaptive_avg_pool2d + +from ..llava.configuration_llava import LlavaConfig +from ..llava.modeling_llava import LlavaModel, LlavaForConditionalGeneration, LlavaMultiModalProjector from ...configuration_utils import PretrainedConfig from ..auto import CONFIG_MAPPING +from ...activations import ACT2FN + class FastVlmConfig(LlavaConfig): r""" This is the configuration class to store the configuration of a [`FastVLMForConditionalGeneration`]. It is used to instantiate a FastVLM model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield a similar configuration to that of the FastVLM-7B. + with the defaults will yield the same configurationa as the one of FastVLM-7B. e.g. [KamilaMila/FastVLM-7B](https://huggingface.co/KamilaMila/FastVLM-7B) Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - Args: TODO !!!!!!!!!!! - vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `CLIPVisionConfig`): + Args: + vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `TimmWrapperConfig` for `fastvit_mci3`): The config object or dictionary of the vision backbone. - text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`): + text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`): The config object or dictionary of the text backbone. - image_token_index (`int`, *optional*, defaults to 32000): + image_token_index (`int`, *optional*, defaults to 151646): The image token index to encode the image prompt. projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): The activation function used by the multimodal projector. 
- vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`): + vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`): The feature selection strategy used to select the vision feature from the vision backbone. - Can be one of `"default"` or `"full"`. - vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -2): + Can only be `"full"`. + vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -1): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the - vision features. - image_seq_length (`int`, *optional*, defaults to 576): + vision features. Must be negative. + image_seq_length (`int`, *optional*, defaults to 256): Sequence length of one image embedding. multimodal_projector_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the multimodal projector. @@ -40,19 +61,13 @@ class FastVlmConfig(LlavaConfig): Example: ```python - >>> from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, LlamaConfig - - >>> # Initializing a CLIP-vision config - >>> vision_config = CLIPVisionConfig() - - >>> # Initializing a Llama config - >>> text_config = LlamaConfig() + >>> from transformers import FastVlmForConditionalGeneration, FastVlmConfig - >>> # Initializing a Llava llava-1.5-7b style configuration - >>> configuration = LlavaConfig(vision_config, text_config) + >>> # Initializing a FastVLM-7B style configuration + >>> configuration = FastVlmConfig() - >>> # Initializing a model from the llava-1.5-7b style configuration - >>> model = LlavaForConditionalGeneration(configuration) + >>> # Initializing a model from the FastVLM-7B style configuration + >>> model = FastVlmForConditionalGeneration(configuration) >>> # Accessing the model configuration >>> configuration = model.config @@ -74,12 +89,17 @@ def __init__( self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act self.image_seq_length = image_seq_length - + if math.isqrt(image_seq_length).pow(2) != image_seq_length: + raise ValueError(f"Inavalid image_seq_length: {image_seq_length}. It needs to be a perfect square.") + if vision_feature_select_strategy != "full": raise ValueError( "Only vision_feature_select_strategy='full' supported in FastVLM!" f"Got: {vision_feature_select_strategy}" ) + + if (isinstance(vision_feature_layer, int) and vision_feature_layer >= 0) or any([layer >= 0 for layer in vision_feature_layer]): + raise ValueError(f"Only negative layer values are supported. 
Got {vision_feature_layer}") self.vision_feature_select_strategy = vision_feature_select_strategy self.vision_feature_layer = vision_feature_layer @@ -115,6 +135,26 @@ def __init__( PretrainedConfig.__init__(**kwargs) +class FastVlmMultiModalProjector(LlavaMultiModalProjector): + def __init__(self, config: FastVlmConfig): + nn.Module.__init__() + if isinstance(config.vision_feature_layer, int): + layers = [config.vision_feature_layer] + else: + layers = config.vision_feature_layer + # different layers have different hidden sizes that are concatenated + total_hidden_size = 0 + for layer in layers: + total_hidden_size += config.vision_feature_layer // (2).pow(-layer - 1) + self.linear_1 = nn.Linear( + total_hidden_size, + config.text_config.hidden_size, + bias=config.multimodal_projector_bias, + ) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear( + config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias + ) class FastVlmModel(LlavaModel): def get_image_features( self, @@ -130,7 +170,7 @@ def get_image_features( pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): The tensors corresponding to the input images. vision_feature_layer (`Union[int, list[int]]`, *optional*): - The index/indices of the layer to select the vision feature. + The index/indices of the layer to select the vision feature. Must be negative. vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Only "full" supported. @@ -146,27 +186,39 @@ def get_image_features( else self.config.vision_feature_select_strategy ) - # only this value makes sense in FastVLM + # only this value makes sense in FastVLM (we can't have a CLS token in conv layers) if vision_feature_select_strategy != "full": - raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported.") + raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported in FastVLM.") + + if (isinstance(vision_feature_layer, int) and vision_feature_layer >= 0) or any([layer >= 0 for layer in vision_feature_layer]): + raise ValueError(f"Only negative layer values are supported. Got {vision_feature_layer}") kwargs = {k: v for k, v in kwargs.items() if v is not None} - image_outputs = self.vision_tower(pixel_values, **kwargs) # add more choice here! 
+ # this is not memory-efficient at all + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) # since the vision tower is hybrid in FastVLM, its output needs to be handled differently from Llava - selected_image_feature = image_outputs.last_hidden_state - selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1) + desired_shape = math.isqrt(self.config.image_seq_length) + if isinstance(vision_feature_layer, int): + if vision_feature_layer == -1: + selected_image_feature = image_outputs.last_hidden_state + else: + selected_image_feature = image_outputs.hidden_states[vision_feature_layer + 1] + selected_image_feature = adaptive_avg_pool2d(selected_image_feature, (desired_shape, desired_shape)) + else: + hs_pool = [] + for layer_idx in vision_feature_layer: + if layer_idx == -1: + selected_image_feature = image_outputs.last_hidden_state + else: + selected_image_feature = image_outputs.hidden_states[layer_idx + 1] + selected_image_feature = adaptive_avg_pool2d(selected_image_feature, (desired_shape, desired_shape)) + hs_pool.append(selected_image_feature) + selected_image_feature = torch.cat(hs_pool, dim=-1) + selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1) image_features = self.multi_modal_projector(selected_image_feature) - - if "image_sizes" in kwargs: - split_sizes = [ - (height // self.vision_tower.patch_size) * (width // self.vision_tower.patch_size) - for height, width in kwargs["image_sizes"] - ] - image_features = torch.split(image_features.squeeze(0), split_sizes) - else: - image_features = list(image_features) + image_features = list(image_features) return image_features class FastVlmForConditionalGeneration(LlavaForConditionalGeneration): @@ -184,19 +236,18 @@ def forward(self, **super_kwargs): >>> import requests >>> from transformers import AutoProcessor, FastVlmForConditionalGeneration - >>> model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/fast_vlm-1.5-7b-hf") #TODO change!!! - >>> processor = AutoProcessor.from_pretrained("fast_vlm-hf/fast_vlm-1.5-7b-hf") + >>> model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-7B") + >>> processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-7B") - >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:" + >>> prompt = "<|im_start|>user\n\nWhat's the content of the image?<|im_end|>\n<|im_start|>assistant\n" >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(images=image, text=prompt, return_tensors="pt") >>> # Generate - >>> generate_ids = model.generate(**inputs, max_new_tokens=15) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "USER: \nWhat's the content of the image? 
ASSISTANT: The image features a busy city street with a stop sign prominently displayed" + >>> generated_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] ```""" super().forward(**super_kwargs) From 51010f5b41387e20c83e9a1a82e9b3fbdae30e43 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Wed, 8 Oct 2025 17:17:00 +0200 Subject: [PATCH 10/31] Fixed documentation --- docs/source/en/model_doc/fast_vlm.md | 4 ++-- .../models/fast_vlm/modeling_fast_vlm.py | 15 +++++++++++++++ .../models/fast_vlm/modular_fast_vlm.py | 18 ++++++++++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/docs/source/en/model_doc/fast_vlm.md b/docs/source/en/model_doc/fast_vlm.md index 8d1eccb37a0a..7d72c728c3d1 100644 --- a/docs/source/en/model_doc/fast_vlm.md +++ b/docs/source/en/model_doc/fast_vlm.md @@ -215,7 +215,7 @@ Flash Attention 2 is an even faster, optimized version of the previous optimizat ## Resources -A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with image-to-text transformers (here using the example of Llava). +A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with image-to-text transformers (here using Llava as an example). @@ -233,4 +233,4 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h ## FastVlmForConditionalGeneration [[autodoc]] FastVlmForConditionalGeneration - - forward \ No newline at end of file + - forward diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py index c314604e5dd8..353a36c7e727 100644 --- a/src/transformers/models/fast_vlm/modeling_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -244,6 +244,14 @@ def forward( image_sizes: torch.Tensor = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> Union[tuple, FastVlmModelOutputWithPast]: + r""" + vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. Can only be `"full"`. + + vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*): + The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the + corresponding indices will be concatenated to form the vision features. Must be negative. + """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -420,6 +428,13 @@ def forward( config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. Can only be `"full"`. + + vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*): + The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the + corresponding indices will be concatenated to form the vision features. Must be negative. 
+ Example: ```python diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index 9228446a36ea..99d2dc478ce2 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -221,6 +221,17 @@ def get_image_features( image_features = list(image_features) return image_features + def forward(self, **super_kwargs): + r""" + vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. Can only be `"full"`. + + vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*): + The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the + corresponding indices will be concatenated to form the vision features. Must be negative. + """ + super().forward(**super_kwargs) + class FastVlmForConditionalGeneration(LlavaForConditionalGeneration): def forward(self, **super_kwargs): r""" @@ -229,6 +240,13 @@ def forward(self, **super_kwargs): config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + vision_feature_select_strategy (`str`, *optional*): + The feature selection strategy used to select the vision feature from the vision backbone. Can only be `"full"`. + + vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*): + The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the + corresponding indices will be concatenated to form the vision features. Must be negative. + Example: ```python From 1e92007dd32db1cb493ce6585c4286229ff0b801 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Wed, 8 Oct 2025 17:39:03 +0200 Subject: [PATCH 11/31] Style fixed --- src/transformers/models/__init__.py | 2 +- .../models/auto/configuration_auto.py | 4 +- .../models/fast_vlm/configuration_fast_vlm.py | 2 +- .../fast_vlm/convert_fastvlm_weights_to_hf.py | 50 ++++++------- .../models/fast_vlm/modeling_fast_vlm.py | 32 ++++----- .../models/fast_vlm/modular_fast_vlm.py | 70 ++++++++++++------- 6 files changed, 91 insertions(+), 69 deletions(-) diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index d16581cdb99b..2232bf025a03 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -124,8 +124,8 @@ from .falcon import * from .falcon_h1 import * from .falcon_mamba import * - from .fastspeech2_conformer import * from .fast_vlm import * + from .fastspeech2_conformer import * from .flaubert import * from .flava import * from .flex_olmo import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 54e6f2279be5..2c4889ec07d6 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -148,9 +148,9 @@ ("falcon", "FalconConfig"), ("falcon_h1", "FalconH1Config"), ("falcon_mamba", "FalconMambaConfig"), + ("fast_vlm", "FastVlmConfig"), ("fastspeech2_conformer", "FastSpeech2ConformerConfig"), ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGanConfig"), - ("fast_vlm", "FastVlmConfig"), ("flaubert", "FlaubertConfig"), ("flava", "FlavaConfig"), ("flex_olmo", "FlexOlmoConfig"), @@ -592,9 +592,9 @@ ("falcon3", "Falcon3"), ("falcon_h1", 
"FalconH1"), ("falcon_mamba", "FalconMamba"), + ("fast_vlm", "FastVlm"), ("fastspeech2_conformer", "FastSpeech2Conformer"), ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"), - ("fast_vlm", "FastVlm"), ("flan-t5", "FLAN-T5"), ("flan-ul2", "FLAN-UL2"), ("flaubert", "FlauBERT"), diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py index 80d2e0ccf8e3..06ee51e7e402 100644 --- a/src/transformers/models/fast_vlm/configuration_fast_vlm.py +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -102,7 +102,7 @@ def __init__( ) if (isinstance(vision_feature_layer, int) and vision_feature_layer >= 0) or any( - [layer >= 0 for layer in vision_feature_layer] + layer >= 0 for layer in vision_feature_layer ): raise ValueError(f"Only negative layer values are supported. Got {vision_feature_layer}") diff --git a/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py index 3886fb0ebbc9..84edc8e27101 100644 --- a/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py +++ b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py @@ -14,26 +14,28 @@ import argparse import glob import os +import re +import requests import torch from huggingface_hub import snapshot_download +from PIL import Image from safetensors import safe_open -import re from transformers import ( AddedToken, AutoConfig, AutoTokenizer, + CLIPImageProcessor, FastVlmConfig, FastVlmForConditionalGeneration, LlavaProcessor, - CLIPImageProcessor, ) -from PIL import Image -import requests -os.environ["TIMM_FUSED_ATTN"] = "0" # to avoid logits diverging, needed because the original implementation uses regular (not fused) atteniton +os.environ["TIMM_FUSED_ATTN"] = ( + "0" # to avoid logits diverging, needed because the original implementation uses regular (not fused) atteniton +) KEYS_TO_MODIFY_MAPPING = { "model.vision_tower.vision_tower.model": "model.vision_tower.timm_model", @@ -51,6 +53,7 @@ "lkb_reparam": "reparam_conv", } + def map_to_stage(number): number = int(number) if number == 0: @@ -64,6 +67,7 @@ def map_to_stage(number): if number in {8, 9, 10}: return 4 + def load_original_state_dict(model_id): directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"]) @@ -78,6 +82,7 @@ def load_original_state_dict(model_id): del original_state_dict["model.vision_tower.vision_tower.model.head.proj"] return original_state_dict + def convert_state_dict_to_hf(state_dict): new_state_dict = {} @@ -136,18 +141,19 @@ def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_s tokenizer = AutoTokenizer.from_pretrained( text_model_id, - chat_template="{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n'}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '' }}{% endfor %}{# Render all text next #}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ '\n' + content['text'] }}{% endfor %}{{'<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + chat_template="{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% 
endif %}{{'<|im_start|>' + message['role'] + '\n'}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '' }}{% endfor %}{# Render all text next #}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ '\n' + content['text'] }}{% endfor %}{{'<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", ) tokenizer.add_tokens(AddedToken("", special=True, normalized=False), special_tokens=True) - image_processor = CLIPImageProcessor(crop_size={"height": 1024, - "width": 1024}, - image_mean=[0.0, 0.0, 0.0], - image_std=[1.0, 1.0, 1.0], - size={"shortest_edge": 1024}) - + image_processor = CLIPImageProcessor( + crop_size={"height": 1024, "width": 1024}, + image_mean=[0.0, 0.0, 0.0], + image_std=[1.0, 1.0, 1.0], + size={"shortest_edge": 1024}, + ) + processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor) - processor.patch_size = 64 # effective patch size (2^6) + processor.patch_size = 64 # effective patch size (2^6) model = FastVlmForConditionalGeneration(config) @@ -174,20 +180,12 @@ def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_s dim=0, ) - conversation = [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What are these?"}, - {"type": "image"} - ] - } - ] + conversation = [{"role": "user", "content": [{"type": "text", "text": "What are these?"}, {"type": "image"}]}] prompt = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False) image_file = "http://images.cocodataset.org/val2017/000000039769.jpg" raw_image = Image.open(requests.get(image_file, stream=True).raw) - inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to("cuda") + inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda") inputs = {k: (v.to(torch.bfloat16) if v.dtype == torch.float32 else v) for k, v in inputs.items()} model = model.cuda() @@ -199,10 +197,10 @@ def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_s # otherwise numerical errors accumulate if output_hub_path == "KamilaMila/FastVLM-0.5B": expected_shape = torch.Size([1, 280, 152000]) - expected_slice = torch.tensor([ 4.1250, 9.6875, 11.1875], device="cuda") + expected_slice = torch.tensor([4.1250, 9.6875, 11.1875], device="cuda") elif output_hub_path == "KamilaMila/FastVLM-1.5B": expected_shape = torch.Size([1, 280, 152000]) - expected_slice = torch.tensor([ 3.3750, 11.5000, 11.8125], device="cuda") + expected_slice = torch.tensor([3.3750, 11.5000, 11.8125], device="cuda") elif output_hub_path == "KamilaMila/FastVLM-7B": expected_shape = torch.Size([1, 280, 152128]) expected_slice = torch.tensor([3.8281, 9.0625, 7.9062], device="cuda") @@ -215,6 +213,7 @@ def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_s processor.push_to_hub(output_hub_path) print("Successfully pushed to hub!") + def main(): parser = argparse.ArgumentParser( formatter_class=argparse.RawDescriptionHelpFormatter, @@ -243,5 +242,6 @@ def main(): args = parser.parse_args() convert_fastvlm_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id) + if __name__ == "__main__": main() diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py index 353a36c7e727..08449af9978c 100644 --- a/src/transformers/models/fast_vlm/modeling_fast_vlm.py +++ 
b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -66,6 +66,21 @@ def forward(self, image_features): return hidden_states +@auto_docstring +class FastVlmPreTrainedModel(PreTrainedModel): + config: FastVlmConfig + base_model_prefix = "" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + + _supports_flash_attn = True + _supports_sdpa = True + + _can_compile_fullgraph = True + _supports_flex_attn = True + _supports_attention_backend = True + + @dataclass @auto_docstring( custom_intro=""" @@ -88,21 +103,6 @@ class FastVlmModelOutputWithPast(BaseModelOutputWithPast): image_hidden_states: Optional[torch.FloatTensor] = None -@auto_docstring -class FastVlmPreTrainedModel(PreTrainedModel): - config: FastVlmConfig - base_model_prefix = "" - supports_gradient_checkpointing = True - _skip_keys_device_placement = "past_key_values" - - _supports_flash_attn = True - _supports_sdpa = True - - _can_compile_fullgraph = True - _supports_flex_attn = True - _supports_attention_backend = True - - @auto_docstring( custom_intro=""" The FastVlm model which consists of a vision backbone and a language model, without a language modeling head. @@ -168,7 +168,7 @@ def get_image_features( ) if (isinstance(vision_feature_layer, int) and vision_feature_layer >= 0) or any( - [layer >= 0 for layer in vision_feature_layer] + layer >= 0 for layer in vision_feature_layer ): raise ValueError(f"Only negative layer values are supported. Got {vision_feature_layer}") diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index 99d2dc478ce2..68551cf61561 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -19,11 +19,16 @@ from torch import nn from torch.nn.functional import adaptive_avg_pool2d -from ..llava.configuration_llava import LlavaConfig -from ..llava.modeling_llava import LlavaModel, LlavaForConditionalGeneration, LlavaMultiModalProjector +from ...activations import ACT2FN from ...configuration_utils import PretrainedConfig from ..auto import CONFIG_MAPPING -from ...activations import ACT2FN +from ..llava.configuration_llava import LlavaConfig +from ..llava.modeling_llava import ( + LlavaForConditionalGeneration, + LlavaModel, + LlavaMultiModalProjector, + LlavaPreTrainedModel, +) class FastVlmConfig(LlavaConfig): @@ -72,6 +77,7 @@ class FastVlmConfig(LlavaConfig): >>> # Accessing the model configuration >>> configuration = model.config ```""" + model_type = "fast_vlm" def __init__( @@ -91,14 +97,16 @@ def __init__( self.image_seq_length = image_seq_length if math.isqrt(image_seq_length).pow(2) != image_seq_length: raise ValueError(f"Inavalid image_seq_length: {image_seq_length}. It needs to be a perfect square.") - + if vision_feature_select_strategy != "full": raise ValueError( "Only vision_feature_select_strategy='full' supported in FastVLM!" f"Got: {vision_feature_select_strategy}" ) - - if (isinstance(vision_feature_layer, int) and vision_feature_layer >= 0) or any([layer >= 0 for layer in vision_feature_layer]): + + if (isinstance(vision_feature_layer, int) and vision_feature_layer >= 0) or any( + layer >= 0 for layer in vision_feature_layer + ): raise ValueError(f"Only negative layer values are supported. 
Got {vision_feature_layer}") self.vision_feature_select_strategy = vision_feature_select_strategy @@ -114,7 +122,7 @@ def __init__( global_pool="avg", hidden_size=3072, initializer_range=0.02, - model_args={"inference_mode": True} + model_args={"inference_mode": True}, ) self.vision_config = vision_config @@ -123,18 +131,21 @@ def __init__( text_config["model_type"] = text_config.get("model_type", "qwen2") text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) elif text_config is None: - text_config = CONFIG_MAPPING["qwen2"](hidden_size=3584, - vocab_size=152128, - intermediate_size=18944, - num_attention_heads=28, - num_key_value_heads=4, - num_hidden_layers=28) + text_config = CONFIG_MAPPING["qwen2"]( + hidden_size=3584, + vocab_size=152128, + intermediate_size=18944, + num_attention_heads=28, + num_key_value_heads=4, + num_hidden_layers=28, + ) self.text_config = text_config self.multimodal_projector_bias = multimodal_projector_bias PretrainedConfig.__init__(**kwargs) + class FastVlmMultiModalProjector(LlavaMultiModalProjector): def __init__(self, config: FastVlmConfig): nn.Module.__init__() @@ -155,6 +166,12 @@ def __init__(self, config: FastVlmConfig): self.linear_2 = nn.Linear( config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias ) + + +class FastVlmPreTrainedModel(LlavaPreTrainedModel): + pass + + class FastVlmModel(LlavaModel): def get_image_features( self, @@ -172,7 +189,7 @@ def get_image_features( vision_feature_layer (`Union[int, list[int]]`, *optional*): The index/indices of the layer to select the vision feature. Must be negative. vision_feature_select_strategy (`str`, *optional*): - The feature selection strategy used to select the vision feature from the vision backbone. + The feature selection strategy used to select the vision feature from the vision backbone. Only "full" supported. Returns: image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). @@ -188,14 +205,18 @@ def get_image_features( # only this value makes sense in FastVLM (we can't have a CLS token in conv layers) if vision_feature_select_strategy != "full": - raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported in FastVLM.") - - if (isinstance(vision_feature_layer, int) and vision_feature_layer >= 0) or any([layer >= 0 for layer in vision_feature_layer]): + raise ValueError( + f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported in FastVLM." + ) + + if (isinstance(vision_feature_layer, int) and vision_feature_layer >= 0) or any( + layer >= 0 for layer in vision_feature_layer + ): raise ValueError(f"Only negative layer values are supported. 
Got {vision_feature_layer}") kwargs = {k: v for k, v in kwargs.items() if v is not None} # this is not memory-efficient at all - image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) + image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) # since the vision tower is hybrid in FastVLM, its output needs to be handled differently from Llava desired_shape = math.isqrt(self.config.image_seq_length) @@ -220,18 +241,19 @@ def get_image_features( image_features = self.multi_modal_projector(selected_image_feature) image_features = list(image_features) return image_features - + def forward(self, **super_kwargs): r""" vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Can only be `"full"`. vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*): - The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the + The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. Must be negative. """ super().forward(**super_kwargs) - + + class FastVlmForConditionalGeneration(LlavaForConditionalGeneration): def forward(self, **super_kwargs): r""" @@ -244,7 +266,7 @@ def forward(self, **super_kwargs): The feature selection strategy used to select the vision feature from the vision backbone. Can only be `"full"`. vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*): - The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the + The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. Must be negative. Example: @@ -268,6 +290,6 @@ def forward(self, **super_kwargs): >>> processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] ```""" super().forward(**super_kwargs) - -__all__ = ["FastVlmForConditionalGeneration", "FastVlmModel", "FastVlmPreTrainedModel", "FastVlmConfig"] \ No newline at end of file + +__all__ = ["FastVlmForConditionalGeneration", "FastVlmModel", "FastVlmPreTrainedModel", "FastVlmConfig"] From dc5e83e61ef2dc3fcdad7afbba3549157b3b239e Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Wed, 8 Oct 2025 20:56:20 +0200 Subject: [PATCH 12/31] Some small fixes --- docs/source/en/model_doc/fast_vlm.md | 10 +++---- .../models/fast_vlm/configuration_fast_vlm.py | 10 +++++-- .../models/fast_vlm/modeling_fast_vlm.py | 17 +++++++---- .../models/fast_vlm/modular_fast_vlm.py | 29 ++++++++++++++----- 4 files changed, 46 insertions(+), 20 deletions(-) diff --git a/docs/source/en/model_doc/fast_vlm.md b/docs/source/en/model_doc/fast_vlm.md index 7d72c728c3d1..3afd35d2511c 100644 --- a/docs/source/en/model_doc/fast_vlm.md +++ b/docs/source/en/model_doc/fast_vlm.md @@ -20,8 +20,8 @@ rendered properly in your Markdown viewer.
 PyTorch
-
+FlashAttention
+SDPA
## Overview @@ -189,7 +189,7 @@ import os os.environ["TIMM_FUSED_ATTN"] = "0" ``` -In addition, the layer norm used by Apple doesn't use the standard LayerNorm class form Torch and therefore our logits diverge. To get exactly the same values, one needs to manually change timm/layers/norm.py: +In addition, the layer norm used by Apple doesn't use the standard LayerNorm class form Torch and therefore our logits diverge. To get exactly the same values, one needs to manually change `timm/layers/norm.py`: ``` class LayerNorm2d(nn.LayerNorm): @@ -209,9 +209,9 @@ class LayerNorm2d(nn.LayerNorm): ``` Please note, that this is only needed in oder to get the exact same numerical values on the output of the model. It's not necessary to make this change to use FastVLM. - +Flash Attention 2 is an even faster, optimized version of the previous optimization, please refer to the [Flash Attention 2 section of performance docs](https://huggingface.co/docs/transformers/perf_infer_gpu_one). ## Resources diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py index 06ee51e7e402..3c0ad6753cfd 100644 --- a/src/transformers/models/fast_vlm/configuration_fast_vlm.py +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -19,6 +19,7 @@ # limitations under the License. import math +from collections.abc import Iterable from ...configuration_utils import PretrainedConfig from ..auto import CONFIG_MAPPING, AutoConfig @@ -92,7 +93,7 @@ def __init__( self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act self.image_seq_length = image_seq_length - if math.isqrt(image_seq_length).pow(2) != image_seq_length: + if math.isqrt(image_seq_length) ** 2 != image_seq_length: raise ValueError(f"Inavalid image_seq_length: {image_seq_length}. It needs to be a perfect square.") if vision_feature_select_strategy != "full": @@ -101,8 +102,11 @@ def __init__( f"Got: {vision_feature_select_strategy}" ) - if (isinstance(vision_feature_layer, int) and vision_feature_layer >= 0) or any( - layer >= 0 for layer in vision_feature_layer + if any( + layer >= 0 + for layer in ( + vision_feature_layer if isinstance(vision_feature_layer, Iterable) else [vision_feature_layer] + ) ): raise ValueError(f"Only negative layer values are supported. Got {vision_feature_layer}") diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py index 08449af9978c..34c99606e831 100644 --- a/src/transformers/models/fast_vlm/modeling_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -19,6 +19,7 @@ # limitations under the License. 
import math +from collections.abc import Iterable from dataclasses import dataclass from typing import Optional, Union @@ -48,7 +49,7 @@ def __init__(self, config: FastVlmConfig): # different layers have different hidden sizes that are concatenated total_hidden_size = 0 for layer in layers: - total_hidden_size += config.vision_feature_layer // (2).pow(-layer - 1) + total_hidden_size += config.vision_config.hidden_size // (2 ** (-layer - 1)) self.linear_1 = nn.Linear( total_hidden_size, config.text_config.hidden_size, @@ -109,10 +110,13 @@ class FastVlmModelOutputWithPast(BaseModelOutputWithPast): """ ) class FastVlmModel(FastVlmPreTrainedModel): - _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + _checkpoint_conversion_mapping = {} def __init__(self, config: FastVlmConfig): super().__init__(config) + # Timm models don't support this way of setting attention mode so we set the vision config to eager while keeping the language part + # the same as the user requested + config.vision_config._attn_implementation = "eager" self.vision_tower = AutoModel.from_config(config.vision_config) self.multi_modal_projector = FastVlmMultiModalProjector(config) @@ -167,8 +171,11 @@ def get_image_features( f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported in FastVLM." ) - if (isinstance(vision_feature_layer, int) and vision_feature_layer >= 0) or any( - layer >= 0 for layer in vision_feature_layer + if any( + layer >= 0 + for layer in ( + vision_feature_layer if isinstance(vision_feature_layer, Iterable) else [vision_feature_layer] + ) ): raise ValueError(f"Only negative layer values are supported. Got {vision_feature_layer}") @@ -193,7 +200,7 @@ def get_image_features( selected_image_feature = image_outputs.hidden_states[layer_idx + 1] selected_image_feature = adaptive_avg_pool2d(selected_image_feature, (desired_shape, desired_shape)) hs_pool.append(selected_image_feature) - selected_image_feature = torch.cat(hs_pool, dim=-1) + selected_image_feature = torch.cat(hs_pool, dim=-3) selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1) image_features = self.multi_modal_projector(selected_image_feature) diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index 68551cf61561..04d545021cdf 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -13,6 +13,7 @@ # limitations under the License. import math +from collections.abc import Iterable from typing import Optional, Union import torch @@ -95,7 +96,7 @@ def __init__( self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act self.image_seq_length = image_seq_length - if math.isqrt(image_seq_length).pow(2) != image_seq_length: + if math.isqrt(image_seq_length) ** 2 != image_seq_length: raise ValueError(f"Inavalid image_seq_length: {image_seq_length}. It needs to be a perfect square.") if vision_feature_select_strategy != "full": @@ -104,8 +105,11 @@ def __init__( f"Got: {vision_feature_select_strategy}" ) - if (isinstance(vision_feature_layer, int) and vision_feature_layer >= 0) or any( - layer >= 0 for layer in vision_feature_layer + if any( + layer >= 0 + for layer in ( + vision_feature_layer if isinstance(vision_feature_layer, Iterable) else [vision_feature_layer] + ) ): raise ValueError(f"Only negative layer values are supported. 
Got {vision_feature_layer}") @@ -156,7 +160,7 @@ def __init__(self, config: FastVlmConfig): # different layers have different hidden sizes that are concatenated total_hidden_size = 0 for layer in layers: - total_hidden_size += config.vision_feature_layer // (2).pow(-layer - 1) + total_hidden_size += config.vision_config.hidden_size // (2 ** (-layer - 1)) self.linear_1 = nn.Linear( total_hidden_size, config.text_config.hidden_size, @@ -173,6 +177,14 @@ class FastVlmPreTrainedModel(LlavaPreTrainedModel): class FastVlmModel(LlavaModel): + _checkpoint_conversion_mapping = {} + + def __init__(self, config: FastVlmConfig): + # Timm models don't support this way of setting attention mode so we set the vision config to eager while keeping the language part + # the same as the user requested + config.vision_config._attn_implementation = "eager" + super().__init__(config) + def get_image_features( self, pixel_values: torch.FloatTensor, @@ -209,8 +221,11 @@ def get_image_features( f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported in FastVLM." ) - if (isinstance(vision_feature_layer, int) and vision_feature_layer >= 0) or any( - layer >= 0 for layer in vision_feature_layer + if any( + layer >= 0 + for layer in ( + vision_feature_layer if isinstance(vision_feature_layer, Iterable) else [vision_feature_layer] + ) ): raise ValueError(f"Only negative layer values are supported. Got {vision_feature_layer}") @@ -235,7 +250,7 @@ def get_image_features( selected_image_feature = image_outputs.hidden_states[layer_idx + 1] selected_image_feature = adaptive_avg_pool2d(selected_image_feature, (desired_shape, desired_shape)) hs_pool.append(selected_image_feature) - selected_image_feature = torch.cat(hs_pool, dim=-1) + selected_image_feature = torch.cat(hs_pool, dim=-3) selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1) image_features = self.multi_modal_projector(selected_image_feature) From 64e24aeb937cd1c9f7671e7d0ba5a37ff3e20801 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Wed, 8 Oct 2025 22:30:51 +0200 Subject: [PATCH 13/31] Improved the example script to be more inclusive --- .../models/fast_vlm/modeling_fast_vlm.py | 16 +++++++--------- .../models/fast_vlm/modular_fast_vlm.py | 13 +++++++++---- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py index 34c99606e831..de9d217364d5 100644 --- a/src/transformers/models/fast_vlm/modeling_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -50,6 +50,7 @@ def __init__(self, config: FastVlmConfig): total_hidden_size = 0 for layer in layers: total_hidden_size += config.vision_config.hidden_size // (2 ** (-layer - 1)) + self.linear_1 = nn.Linear( total_hidden_size, config.text_config.hidden_size, @@ -351,12 +352,7 @@ class FastVlmCausalLMOutputWithPast(ModelOutput): """ ) class FastVlmForConditionalGeneration(FastVlmPreTrainedModel, GenerationMixin): - _checkpoint_conversion_mapping = { - "^language_model.model": "model.language_model", - "^vision_tower": "model.vision_tower", - "^multi_modal_projector": "model.multi_modal_projector", - "^language_model.lm_head": "lm_head", - } + _checkpoint_conversion_mapping = {} _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: FastVlmConfig): @@ -449,14 +445,16 @@ def forward( >>> import requests >>> from transformers import AutoProcessor, FastVlmForConditionalGeneration - >>> model = 
FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-7B") - >>> processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-7B") + >>> device = "cuda" if torch.cuda.is_available() else "cpu" + + >>> model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B").to(device) + >>> processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B") >>> prompt = "<|im_start|>user\n\nWhat's the content of the image?<|im_end|>\n<|im_start|>assistant\n" >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device) >>> # Generate >>> generated_ids = model.generate(**inputs, max_new_tokens=15) diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index 04d545021cdf..05dd499e7627 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -161,6 +161,7 @@ def __init__(self, config: FastVlmConfig): total_hidden_size = 0 for layer in layers: total_hidden_size += config.vision_config.hidden_size // (2 ** (-layer - 1)) + self.linear_1 = nn.Linear( total_hidden_size, config.text_config.hidden_size, @@ -270,6 +271,7 @@ def forward(self, **super_kwargs): class FastVlmForConditionalGeneration(LlavaForConditionalGeneration): + _checkpoint_conversion_mapping = {} def forward(self, **super_kwargs): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -290,19 +292,22 @@ def forward(self, **super_kwargs): >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, FastVlmForConditionalGeneration + >>> import torch + + >>> device = "cuda" if torch.cuda.is_available() else "cpu" - >>> model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-7B") - >>> processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-7B") + >>> model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B").to(device) + >>> processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B") >>> prompt = "<|im_start|>user\n\nWhat's the content of the image?<|im_end|>\n<|im_start|>assistant\n" >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device) >>> # Generate >>> generated_ids = model.generate(**inputs, max_new_tokens=15) - >>> processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]) ```""" super().forward(**super_kwargs) From cf6336a584118d0fea5edeee4d8cbf659b9a92b7 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Wed, 8 Oct 2025 23:12:25 +0200 Subject: [PATCH 14/31] Fixes after the rebase --- .../models/fast_vlm/configuration_fast_vlm.py | 7 ++- .../models/fast_vlm/modeling_fast_vlm.py | 50 ++++--------------- .../models/fast_vlm/modular_fast_vlm.py | 8 +-- 3 files changed, 18 insertions(+), 47 deletions(-) diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py index 
3c0ad6753cfd..08e28cb77105 100644 --- a/src/transformers/models/fast_vlm/configuration_fast_vlm.py +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -21,11 +21,11 @@ import math from collections.abc import Iterable -from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PreTrainedConfig from ..auto import CONFIG_MAPPING, AutoConfig -class FastVlmConfig(PretrainedConfig): +class FastVlmConfig(PreTrainedConfig): r""" This is the configuration class to store the configuration of a [`FastVLMForConditionalGeneration`]. It is used to instantiate a FastVLM model according to the specified arguments, defining the model architecture. Instantiating a configuration @@ -90,6 +90,7 @@ def __init__( multimodal_projector_bias=True, **kwargs, ): + super().__init__(**kwargs) self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act self.image_seq_length = image_seq_length @@ -144,7 +145,5 @@ def __init__( self.text_config = text_config self.multimodal_projector_bias = multimodal_projector_bias - super().__init__(**kwargs) - __all__ = ["FastVlmConfig"] diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py index de9d217364d5..4298beb2bfeb 100644 --- a/src/transformers/models/fast_vlm/modeling_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -30,7 +30,6 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin -from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack @@ -92,8 +91,7 @@ class FastVlmPreTrainedModel(PreTrainedModel): class FastVlmModelOutputWithPast(BaseModelOutputWithPast): r""" past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
@@ -236,21 +234,17 @@ def get_placeholder_mask( @auto_docstring def forward( self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - image_sizes: torch.Tensor = None, - **kwargs: Unpack[FlashAttentionKwargs], + image_sizes: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple, FastVlmModelOutputWithPast]: r""" vision_feature_select_strategy (`str`, *optional*): @@ -260,11 +254,6 @@ def forward( The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. Must be negative. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) @@ -298,10 +287,6 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, cache_position=cache_position, **kwargs, ) @@ -328,8 +313,7 @@ class FastVlmCausalLMOutputWithPast(ModelOutput): logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
@@ -340,7 +324,7 @@ class FastVlmCausalLMOutputWithPast(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None - past_key_values: Optional[list[torch.FloatTensor]] = None + past_key_values: Optional[Cache] = None hidden_states: Optional[tuple[torch.FloatTensor]] = None attentions: Optional[tuple[torch.FloatTensor]] = None image_hidden_states: Optional[torch.FloatTensor] = None @@ -407,8 +391,8 @@ def multi_modal_projector(self): @auto_docstring def forward( self, - input_ids: torch.LongTensor = None, - pixel_values: torch.FloatTensor = None, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, @@ -416,10 +400,6 @@ def forward( vision_feature_layer: Optional[Union[int, list[int]]] = None, vision_feature_select_strategy: Optional[str] = None, labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, image_sizes: Optional[torch.Tensor] = None, @@ -444,6 +424,7 @@ def forward( >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, FastVlmForConditionalGeneration + >>> import torch >>> device = "cuda" if torch.cuda.is_available() else "cpu" @@ -458,13 +439,8 @@ def forward( >>> # Generate >>> generated_ids = model.generate(**inputs, max_new_tokens=15) - >>> processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]) ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer ) @@ -483,10 +459,6 @@ def forward( inputs_embeds=inputs_embeds, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, cache_position=cache_position, image_sizes=image_sizes, **kwargs, diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index 05dd499e7627..bcdd52c1fb19 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -21,7 +21,7 @@ from torch.nn.functional import adaptive_avg_pool2d from ...activations import ACT2FN -from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PreTrainedConfig from ..auto import CONFIG_MAPPING from ..llava.configuration_llava import LlavaConfig from ..llava.modeling_llava import ( @@ -93,6 +93,7 @@ def __init__( multimodal_projector_bias=True, **kwargs, ): + PreTrainedConfig.__init__(**kwargs) self.image_token_index = image_token_index self.projector_hidden_act = projector_hidden_act self.image_seq_length 
= image_seq_length @@ -147,8 +148,6 @@ def __init__( self.text_config = text_config self.multimodal_projector_bias = multimodal_projector_bias - PretrainedConfig.__init__(**kwargs) - class FastVlmMultiModalProjector(LlavaMultiModalProjector): def __init__(self, config: FastVlmConfig): @@ -161,7 +160,7 @@ def __init__(self, config: FastVlmConfig): total_hidden_size = 0 for layer in layers: total_hidden_size += config.vision_config.hidden_size // (2 ** (-layer - 1)) - + self.linear_1 = nn.Linear( total_hidden_size, config.text_config.hidden_size, @@ -272,6 +271,7 @@ def forward(self, **super_kwargs): class FastVlmForConditionalGeneration(LlavaForConditionalGeneration): _checkpoint_conversion_mapping = {} + def forward(self, **super_kwargs): r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): From d428d6009aeec09d5b427830c07af2b89bf719b0 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Thu, 9 Oct 2025 01:03:36 +0200 Subject: [PATCH 15/31] Made the code and docs more readable and consistent --- .../models/fast_vlm/configuration_fast_vlm.py | 9 +++--- .../models/fast_vlm/modeling_fast_vlm.py | 16 +++++----- .../models/fast_vlm/modular_fast_vlm.py | 29 +++++++++++-------- 3 files changed, 29 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py index 08e28cb77105..1244a3f2e6af 100644 --- a/src/transformers/models/fast_vlm/configuration_fast_vlm.py +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -29,7 +29,7 @@ class FastVlmConfig(PreTrainedConfig): r""" This is the configuration class to store the configuration of a [`FastVLMForConditionalGeneration`]. It is used to instantiate a FastVLM model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield the same configurationa as the one of FastVLM-7B. + with the defaults will yield the same configuration as the one of FastVLM-7B. e.g. [KamilaMila/FastVLM-7B](https://huggingface.co/KamilaMila/FastVLM-7B) @@ -47,7 +47,7 @@ class FastVlmConfig(PreTrainedConfig): The activation function used by the multimodal projector. vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`): The feature selection strategy used to select the vision feature from the vision backbone. - Can only be `"full"`. + Only "full" supported. vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -1): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the @@ -99,8 +99,7 @@ def __init__( if vision_feature_select_strategy != "full": raise ValueError( - "Only vision_feature_select_strategy='full' supported in FastVLM!" - f"Got: {vision_feature_select_strategy}" + f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported in FastVLM." ) if any( @@ -109,7 +108,7 @@ def __init__( vision_feature_layer if isinstance(vision_feature_layer, Iterable) else [vision_feature_layer] ) ): - raise ValueError(f"Only negative layer values are supported. Got {vision_feature_layer}") + raise ValueError(f"Only negative vision feature layer values are supported. 
Got {vision_feature_layer}") self.vision_feature_select_strategy = vision_feature_select_strategy self.vision_feature_layer = vision_feature_layer diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py index 4298beb2bfeb..a558ab047075 100644 --- a/src/transformers/models/fast_vlm/modeling_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -176,7 +176,7 @@ def get_image_features( vision_feature_layer if isinstance(vision_feature_layer, Iterable) else [vision_feature_layer] ) ): - raise ValueError(f"Only negative layer values are supported. Got {vision_feature_layer}") + raise ValueError(f"Only negative vision feature layer values are supported. Got {vision_feature_layer}") kwargs = {k: v for k, v in kwargs.items() if v is not None} # this is not memory-efficient at all @@ -194,11 +194,11 @@ def get_image_features( hs_pool = [] for layer_idx in vision_feature_layer: if layer_idx == -1: - selected_image_feature = image_outputs.last_hidden_state + partial_image_feature = image_outputs.last_hidden_state else: - selected_image_feature = image_outputs.hidden_states[layer_idx + 1] - selected_image_feature = adaptive_avg_pool2d(selected_image_feature, (desired_shape, desired_shape)) - hs_pool.append(selected_image_feature) + partial_image_feature = image_outputs.hidden_states[layer_idx + 1] + partial_image_feature = adaptive_avg_pool2d(partial_image_feature, (desired_shape, desired_shape)) + hs_pool.append(partial_image_feature) selected_image_feature = torch.cat(hs_pool, dim=-3) selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1) @@ -248,7 +248,7 @@ def forward( ) -> Union[tuple, FastVlmModelOutputWithPast]: r""" vision_feature_select_strategy (`str`, *optional*): - The feature selection strategy used to select the vision feature from the vision backbone. Can only be `"full"`. + The feature selection strategy used to select the vision feature from the vision backbone. Only "full" supported. vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the @@ -332,7 +332,7 @@ class FastVlmCausalLMOutputWithPast(ModelOutput): @auto_docstring( custom_intro=""" - The FAST_VLM model which consists of a vision backbone and a language model. + The FastVlm model which consists of a vision backbone and a language model. """ ) class FastVlmForConditionalGeneration(FastVlmPreTrainedModel, GenerationMixin): @@ -412,7 +412,7 @@ def forward( (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. vision_feature_select_strategy (`str`, *optional*): - The feature selection strategy used to select the vision feature from the vision backbone. Can only be `"full"`. + The feature selection strategy used to select the vision feature from the vision backbone. Only "full" supported. vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*): The index of the layer to select the vision feature. 
If multiple indices are provided, the vision feature of the diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index bcdd52c1fb19..b50042267d5d 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -22,6 +22,7 @@ from ...activations import ACT2FN from ...configuration_utils import PreTrainedConfig +from ...utils import auto_docstring from ..auto import CONFIG_MAPPING from ..llava.configuration_llava import LlavaConfig from ..llava.modeling_llava import ( @@ -36,7 +37,7 @@ class FastVlmConfig(LlavaConfig): r""" This is the configuration class to store the configuration of a [`FastVLMForConditionalGeneration`]. It is used to instantiate a FastVLM model according to the specified arguments, defining the model architecture. Instantiating a configuration - with the defaults will yield the same configurationa as the one of FastVLM-7B. + with the defaults will yield the same configuration as the one of FastVLM-7B. e.g. [KamilaMila/FastVLM-7B](https://huggingface.co/KamilaMila/FastVLM-7B) @@ -54,7 +55,7 @@ class FastVlmConfig(LlavaConfig): The activation function used by the multimodal projector. vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`): The feature selection strategy used to select the vision feature from the vision backbone. - Can only be `"full"`. + Only "full" supported. vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -1): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the @@ -102,8 +103,7 @@ def __init__( if vision_feature_select_strategy != "full": raise ValueError( - "Only vision_feature_select_strategy='full' supported in FastVLM!" - f"Got: {vision_feature_select_strategy}" + f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported in FastVLM." ) if any( @@ -112,7 +112,7 @@ def __init__( vision_feature_layer if isinstance(vision_feature_layer, Iterable) else [vision_feature_layer] ) ): - raise ValueError(f"Only negative layer values are supported. Got {vision_feature_layer}") + raise ValueError(f"Only negative vision feature layer values are supported. Got {vision_feature_layer}") self.vision_feature_select_strategy = vision_feature_select_strategy self.vision_feature_layer = vision_feature_layer @@ -227,7 +227,7 @@ def get_image_features( vision_feature_layer if isinstance(vision_feature_layer, Iterable) else [vision_feature_layer] ) ): - raise ValueError(f"Only negative layer values are supported. Got {vision_feature_layer}") + raise ValueError(f"Only negative vision feature layer values are supported. 
Got {vision_feature_layer}") kwargs = {k: v for k, v in kwargs.items() if v is not None} # this is not memory-efficient at all @@ -245,11 +245,11 @@ def get_image_features( hs_pool = [] for layer_idx in vision_feature_layer: if layer_idx == -1: - selected_image_feature = image_outputs.last_hidden_state + partial_image_feature = image_outputs.last_hidden_state else: - selected_image_feature = image_outputs.hidden_states[layer_idx + 1] - selected_image_feature = adaptive_avg_pool2d(selected_image_feature, (desired_shape, desired_shape)) - hs_pool.append(selected_image_feature) + partial_image_feature = image_outputs.hidden_states[layer_idx + 1] + partial_image_feature = adaptive_avg_pool2d(partial_image_feature, (desired_shape, desired_shape)) + hs_pool.append(partial_image_feature) selected_image_feature = torch.cat(hs_pool, dim=-3) selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1) @@ -260,7 +260,7 @@ def get_image_features( def forward(self, **super_kwargs): r""" vision_feature_select_strategy (`str`, *optional*): - The feature selection strategy used to select the vision feature from the vision backbone. Can only be `"full"`. + The feature selection strategy used to select the vision feature from the vision backbone. Only "full" supported. vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the @@ -269,6 +269,11 @@ def forward(self, **super_kwargs): super().forward(**super_kwargs) +@auto_docstring( + custom_intro=""" + The FastVlm model which consists of a vision backbone and a language model. + """ +) class FastVlmForConditionalGeneration(LlavaForConditionalGeneration): _checkpoint_conversion_mapping = {} @@ -280,7 +285,7 @@ def forward(self, **super_kwargs): (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. vision_feature_select_strategy (`str`, *optional*): - The feature selection strategy used to select the vision feature from the vision backbone. Can only be `"full"`. + The feature selection strategy used to select the vision feature from the vision backbone. Only "full" supported. vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*): The index of the layer to select the vision feature. 
If multiple indices are provided, the vision feature of the From adcea054b6ca69482d2a6b3c3d321d3c85847173 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Thu, 16 Oct 2025 14:40:53 +0200 Subject: [PATCH 16/31] Some fixes from the review --- docs/source/en/model_doc/fast_vlm.md | 39 +---------------- src/transformers/models/auto/modeling_auto.py | 2 + .../models/fast_vlm/configuration_fast_vlm.py | 14 ++----- .../fast_vlm/convert_fastvlm_weights_to_hf.py | 2 +- .../models/fast_vlm/modeling_fast_vlm.py | 31 ++++++-------- .../models/fast_vlm/modular_fast_vlm.py | 42 +++++++------------ 6 files changed, 35 insertions(+), 95 deletions(-) diff --git a/docs/source/en/model_doc/fast_vlm.md b/docs/source/en/model_doc/fast_vlm.md index 3afd35d2511c..48d3146c1d63 100644 --- a/docs/source/en/model_doc/fast_vlm.md +++ b/docs/source/en/model_doc/fast_vlm.md @@ -90,8 +90,6 @@ print(text_prompt) >>> "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n\nWhat’s shown in this image?<|im_end|>\n<|im_start|>assistant\n\nThis image shows a red stop sign.<|im_end|>\n<|im_start|>user\n\nDescribe the image in more details.<|im_end|>\n<|im_start|>assistant\n" ``` -πŸš€ **Bonus:** If you're using `transformers>=4.49.0`, you can also get a vectorized output from `apply_chat_template`. See the **Usage Examples** below for more details on how to use it. - ## Usage examples ### Single input inference @@ -181,47 +179,12 @@ processor.batch_decode(generate_ids, skip_special_tokens=True) ## Note regarding reproducing original implementation -In order to match the logits of the [original implementation](https://github.com/apple/ml-fastvlm), one needs to set the default timm attention implementation to the most basic version(not fused): - -``` -import os -# at the beginning of your script -os.environ["TIMM_FUSED_ATTN"] = "0" -``` - -In addition, the layer norm used by Apple doesn't use the standard LayerNorm class form Torch and therefore our logits diverge. To get exactly the same values, one needs to manually change `timm/layers/norm.py`: - -``` -class LayerNorm2d(nn.LayerNorm): - """ LayerNorm for channels of '2D' spatial NCHW tensors """ - _fast_norm: torch.jit.Final[bool] - - def __init__(): - ... # not important - - def forward(self, x: torch.Tensor) -> torch.Tensor: - u = x.mean(1, keepdim=True) - s = (x - u).pow(2).mean(1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.eps) - x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \ - + self.bias.unsqueeze(-1).unsqueeze(-1) - return x -``` -Please note, that this is only needed in oder to get the exact same numerical values on the output of the model. It's not necessary to make this change to use FastVLM. +In order to match the logits of the [original implementation](https://github.com/apple/ml-fastvlm), one needs to use float32. In half precision the logit difference is higher due to tiny differences in how some ops are implemented in timm. ### Using Flash Attention 2 Flash Attention 2 is an even faster, optimized version of the previous optimization, please refer to the [Flash Attention 2 section of performance docs](https://huggingface.co/docs/transformers/perf_infer_gpu_one). -## Resources - -A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with image-to-text transformers (here using Llava as an example). 
- - - -- A [Google Colab demo](https://colab.research.google.com/drive/1qsl6cd2c8gGtEW1xV5io7S8NHh-Cp1TV?usp=sharing) on how to run Llava on a free-tier Google colab instance leveraging 4-bit inference. -- A [similar notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LLaVa/Inference_with_LLaVa_for_multimodal_generation.ipynb) showcasing batched inference. 🌎 - ## FastVlmConfig [[autodoc]] FastVlmConfig diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 197029464efd..e381d52a7e75 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -151,6 +151,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("falcon", "FalconModel"), ("falcon_h1", "FalconH1Model"), ("falcon_mamba", "FalconMambaModel"), + ("fast_vlm", "FastVLMModel"), ("fastspeech2_conformer", "FastSpeech2ConformerModel"), ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"), ("flaubert", "FlaubertModel"), @@ -1024,6 +1025,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("deepseek_vl_hybrid", "DeepseekVLHybridForConditionalGeneration"), ("emu3", "Emu3ForConditionalGeneration"), ("evolla", "EvollaForProteinText2Text"), + ("fast_vlm", "FastVLMForConditionalGeneration"), ("florence2", "Florence2ForConditionalGeneration"), ("fuyu", "FuyuForCausalLM"), ("gemma3", "Gemma3ForConditionalGeneration"), diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py index 1244a3f2e6af..25b0088df47f 100644 --- a/src/transformers/models/fast_vlm/configuration_fast_vlm.py +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -19,7 +19,6 @@ # limitations under the License. import math -from collections.abc import Iterable from ...configuration_utils import PreTrainedConfig from ..auto import CONFIG_MAPPING, AutoConfig @@ -82,7 +81,7 @@ def __init__( self, vision_config=None, text_config=None, - image_token_index=151646, + image_token_id=151646, projector_hidden_act="gelu", vision_feature_select_strategy="full", vision_feature_layer=-1, @@ -91,9 +90,10 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.image_token_index = image_token_index + self.image_token_id = image_token_id self.projector_hidden_act = projector_hidden_act self.image_seq_length = image_seq_length + if math.isqrt(image_seq_length) ** 2 != image_seq_length: raise ValueError(f"Inavalid image_seq_length: {image_seq_length}. It needs to be a perfect square.") @@ -102,14 +102,6 @@ def __init__( f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported in FastVLM." ) - if any( - layer >= 0 - for layer in ( - vision_feature_layer if isinstance(vision_feature_layer, Iterable) else [vision_feature_layer] - ) - ): - raise ValueError(f"Only negative vision feature layer values are supported. Got {vision_feature_layer}") - self.vision_feature_select_strategy = vision_feature_select_strategy self.vision_feature_layer = vision_feature_layer diff --git a/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py index 84edc8e27101..70fbedb2e2d6 100644 --- a/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py +++ b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Inc. team. All rights reserved. 
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py index a558ab047075..89a048f64f41 100644 --- a/src/transformers/models/fast_vlm/modeling_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -19,7 +19,6 @@ # limitations under the License. import math -from collections.abc import Iterable from dataclasses import dataclass from typing import Optional, Union @@ -113,9 +112,6 @@ class FastVlmModel(FastVlmPreTrainedModel): def __init__(self, config: FastVlmConfig): super().__init__(config) - # Timm models don't support this way of setting attention mode so we set the vision config to eager while keeping the language part - # the same as the user requested - config.vision_config._attn_implementation = "eager" self.vision_tower = AutoModel.from_config(config.vision_config) self.multi_modal_projector = FastVlmMultiModalProjector(config) @@ -164,20 +160,6 @@ def get_image_features( else self.config.vision_feature_select_strategy ) - # only this value makes sense in FastVLM (we can't have a CLS token in conv layers) - if vision_feature_select_strategy != "full": - raise ValueError( - f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported in FastVLM." - ) - - if any( - layer >= 0 - for layer in ( - vision_feature_layer if isinstance(vision_feature_layer, Iterable) else [vision_feature_layer] - ) - ): - raise ValueError(f"Only negative vision feature layer values are supported. Got {vision_feature_layer}") - kwargs = {k: v for k, v in kwargs.items() if v is not None} # this is not memory-efficient at all image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) @@ -431,7 +413,17 @@ def forward( >>> model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B").to(device) >>> processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B") - >>> prompt = "<|im_start|>user\n\nWhat's the content of the image?<|im_end|>\n<|im_start|>assistant\n" + >>> conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What are these?"}, + {"type": "image"} + ] + } + ] + + >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) @@ -440,6 +432,7 @@ def forward( >>> # Generate >>> generated_ids = model.generate(**inputs, max_new_tokens=15) >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]) + system\n You are a helpful assistant.\n user\n What are these?\n assistant\n The image depicts a traditional Chinese street... 
```""" vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index b50042267d5d..fd335106bff4 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -86,7 +86,7 @@ def __init__( self, vision_config=None, text_config=None, - image_token_index=151646, + image_token_id=151646, projector_hidden_act="gelu", vision_feature_select_strategy="full", vision_feature_layer=-1, @@ -95,9 +95,10 @@ def __init__( **kwargs, ): PreTrainedConfig.__init__(**kwargs) - self.image_token_index = image_token_index + self.image_token_id = image_token_id self.projector_hidden_act = projector_hidden_act self.image_seq_length = image_seq_length + if math.isqrt(image_seq_length) ** 2 != image_seq_length: raise ValueError(f"Inavalid image_seq_length: {image_seq_length}. It needs to be a perfect square.") @@ -106,14 +107,6 @@ def __init__( f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported in FastVLM." ) - if any( - layer >= 0 - for layer in ( - vision_feature_layer if isinstance(vision_feature_layer, Iterable) else [vision_feature_layer] - ) - ): - raise ValueError(f"Only negative vision feature layer values are supported. Got {vision_feature_layer}") - self.vision_feature_select_strategy = vision_feature_select_strategy self.vision_feature_layer = vision_feature_layer @@ -182,7 +175,7 @@ class FastVlmModel(LlavaModel): def __init__(self, config: FastVlmConfig): # Timm models don't support this way of setting attention mode so we set the vision config to eager while keeping the language part # the same as the user requested - config.vision_config._attn_implementation = "eager" + # config.vision_config._attn_implementation = "eager" super().__init__(config) def get_image_features( @@ -215,20 +208,6 @@ def get_image_features( else self.config.vision_feature_select_strategy ) - # only this value makes sense in FastVLM (we can't have a CLS token in conv layers) - if vision_feature_select_strategy != "full": - raise ValueError( - f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported in FastVLM." - ) - - if any( - layer >= 0 - for layer in ( - vision_feature_layer if isinstance(vision_feature_layer, Iterable) else [vision_feature_layer] - ) - ): - raise ValueError(f"Only negative vision feature layer values are supported. 
Got {vision_feature_layer}") - kwargs = {k: v for k, v in kwargs.items() if v is not None} # this is not memory-efficient at all image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) @@ -304,7 +283,17 @@ def forward(self, **super_kwargs): >>> model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B").to(device) >>> processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B") - >>> prompt = "<|im_start|>user\n\nWhat's the content of the image?<|im_end|>\n<|im_start|>assistant\n" + >>> conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What are these?"}, + {"type": "image"} + ] + } + ] + + >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) @@ -313,6 +302,7 @@ def forward(self, **super_kwargs): >>> # Generate >>> generated_ids = model.generate(**inputs, max_new_tokens=15) >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]) + system\n You are a helpful assistant.\n user\n What are these?\n assistant\n The image depicts a traditional Chinese street... ```""" super().forward(**super_kwargs) From d8664ec914f34953bccf1b8ccc1072562eda0d1c Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Thu, 16 Oct 2025 17:34:29 +0200 Subject: [PATCH 17/31] Reverted back to last layer only --- .../models/fast_vlm/configuration_fast_vlm.py | 14 ++--- .../models/fast_vlm/modeling_fast_vlm.py | 41 +++----------- .../models/fast_vlm/modular_fast_vlm.py | 54 +++++-------------- 3 files changed, 26 insertions(+), 83 deletions(-) diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py index 25b0088df47f..d10efe490c95 100644 --- a/src/transformers/models/fast_vlm/configuration_fast_vlm.py +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -18,8 +18,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math - from ...configuration_utils import PreTrainedConfig from ..auto import CONFIG_MAPPING, AutoConfig @@ -50,7 +48,7 @@ class FastVlmConfig(PreTrainedConfig): vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -1): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the - vision features. Must be negative. + vision features. Only -1 supported. image_seq_length (`int`, *optional*, defaults to 256): Sequence length of one image embedding. multimodal_projector_bias (`bool`, *optional*, defaults to `True`): @@ -94,12 +92,14 @@ def __init__( self.projector_hidden_act = projector_hidden_act self.image_seq_length = image_seq_length - if math.isqrt(image_seq_length) ** 2 != image_seq_length: - raise ValueError(f"Inavalid image_seq_length: {image_seq_length}. It needs to be a perfect square.") - if vision_feature_select_strategy != "full": raise ValueError( - f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported in FastVLM." + f"Unexpected select feature strategy: {vision_feature_select_strategy}. Only 'full' is supported in FastVLM." + ) + + if vision_feature_layer != -1: + raise ValueError( + f"Unexpected vision feature layer: {vision_feature_select_strategy}. Only -1 is supported in FastVLM." 
) self.vision_feature_select_strategy = vision_feature_select_strategy diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py index 89a048f64f41..1068ab1e1614 100644 --- a/src/transformers/models/fast_vlm/modeling_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -18,13 +18,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from dataclasses import dataclass from typing import Optional, Union import torch from torch import nn -from torch.nn.functional import adaptive_avg_pool2d from ...activations import ACT2FN from ...cache_utils import Cache @@ -40,17 +38,8 @@ class FastVlmMultiModalProjector(nn.Module): def __init__(self, config: FastVlmConfig): super().__init__() - if isinstance(config.vision_feature_layer, int): - layers = [config.vision_feature_layer] - else: - layers = config.vision_feature_layer - # different layers have different hidden sizes that are concatenated - total_hidden_size = 0 - for layer in layers: - total_hidden_size += config.vision_config.hidden_size // (2 ** (-layer - 1)) - self.linear_1 = nn.Linear( - total_hidden_size, + config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias, ) @@ -144,7 +133,7 @@ def get_image_features( pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): The tensors corresponding to the input images. vision_feature_layer (`Union[int, list[int]]`, *optional*): - The index/indices of the layer to select the vision feature. Must be negative. + The index/indices of the layer to select the vision feature. Only -1 supported. vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Only "full" supported. @@ -161,28 +150,10 @@ def get_image_features( ) kwargs = {k: v for k, v in kwargs.items() if v is not None} - # this is not memory-efficient at all - image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) + image_outputs = self.vision_tower(pixel_values, **kwargs) # since the vision tower is hybrid in FastVLM, its output needs to be handled differently from Llava - desired_shape = math.isqrt(self.config.image_seq_length) - if isinstance(vision_feature_layer, int): - if vision_feature_layer == -1: - selected_image_feature = image_outputs.last_hidden_state - else: - selected_image_feature = image_outputs.hidden_states[vision_feature_layer + 1] - selected_image_feature = adaptive_avg_pool2d(selected_image_feature, (desired_shape, desired_shape)) - else: - hs_pool = [] - for layer_idx in vision_feature_layer: - if layer_idx == -1: - partial_image_feature = image_outputs.last_hidden_state - else: - partial_image_feature = image_outputs.hidden_states[layer_idx + 1] - partial_image_feature = adaptive_avg_pool2d(partial_image_feature, (desired_shape, desired_shape)) - hs_pool.append(partial_image_feature) - selected_image_feature = torch.cat(hs_pool, dim=-3) - + selected_image_feature = image_outputs.last_hidden_state selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1) image_features = self.multi_modal_projector(selected_image_feature) image_features = list(image_features) @@ -234,7 +205,7 @@ def forward( vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*): The index of the layer to select the vision feature. 
If multiple indices are provided, the vision feature of the - corresponding indices will be concatenated to form the vision features. Must be negative. + corresponding indices will be concatenated to form the vision features. Only -1 supported. """ vision_feature_layer = ( vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer @@ -398,7 +369,7 @@ def forward( vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the - corresponding indices will be concatenated to form the vision features. Must be negative. + corresponding indices will be concatenated to form the vision features. Only -1 supported. Example: diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index fd335106bff4..735f9946452b 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math -from collections.abc import Iterable from typing import Optional, Union import torch from torch import nn -from torch.nn.functional import adaptive_avg_pool2d from ...activations import ACT2FN from ...configuration_utils import PreTrainedConfig @@ -59,7 +56,7 @@ class FastVlmConfig(LlavaConfig): vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -1): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the - vision features. Must be negative. + vision features. Only -1 supported. image_seq_length (`int`, *optional*, defaults to 256): Sequence length of one image embedding. multimodal_projector_bias (`bool`, *optional*, defaults to `True`): @@ -99,12 +96,14 @@ def __init__( self.projector_hidden_act = projector_hidden_act self.image_seq_length = image_seq_length - if math.isqrt(image_seq_length) ** 2 != image_seq_length: - raise ValueError(f"Inavalid image_seq_length: {image_seq_length}. It needs to be a perfect square.") - if vision_feature_select_strategy != "full": raise ValueError( - f"Unexpected select feature strategy: {vision_feature_select_strategy}, Only 'full' is supported in FastVLM." + f"Unexpected select feature strategy: {vision_feature_select_strategy}. Only 'full' is supported in FastVLM." + ) + + if vision_feature_layer != -1: + raise ValueError( + f"Unexpected vision feature layer: {vision_feature_select_strategy}. Only -1 is supported in FastVLM." 
) self.vision_feature_select_strategy = vision_feature_select_strategy @@ -145,17 +144,8 @@ def __init__( class FastVlmMultiModalProjector(LlavaMultiModalProjector): def __init__(self, config: FastVlmConfig): nn.Module.__init__() - if isinstance(config.vision_feature_layer, int): - layers = [config.vision_feature_layer] - else: - layers = config.vision_feature_layer - # different layers have different hidden sizes that are concatenated - total_hidden_size = 0 - for layer in layers: - total_hidden_size += config.vision_config.hidden_size // (2 ** (-layer - 1)) - self.linear_1 = nn.Linear( - total_hidden_size, + config.vision_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias, ) @@ -192,7 +182,7 @@ def get_image_features( pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`): The tensors corresponding to the input images. vision_feature_layer (`Union[int, list[int]]`, *optional*): - The index/indices of the layer to select the vision feature. Must be negative. + The index/indices of the layer to select the vision feature. Only -1 supported. vision_feature_select_strategy (`str`, *optional*): The feature selection strategy used to select the vision feature from the vision backbone. Only "full" supported. @@ -209,28 +199,10 @@ def get_image_features( ) kwargs = {k: v for k, v in kwargs.items() if v is not None} - # this is not memory-efficient at all - image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) + image_outputs = self.vision_tower(pixel_values, **kwargs) # since the vision tower is hybrid in FastVLM, its output needs to be handled differently from Llava - desired_shape = math.isqrt(self.config.image_seq_length) - if isinstance(vision_feature_layer, int): - if vision_feature_layer == -1: - selected_image_feature = image_outputs.last_hidden_state - else: - selected_image_feature = image_outputs.hidden_states[vision_feature_layer + 1] - selected_image_feature = adaptive_avg_pool2d(selected_image_feature, (desired_shape, desired_shape)) - else: - hs_pool = [] - for layer_idx in vision_feature_layer: - if layer_idx == -1: - partial_image_feature = image_outputs.last_hidden_state - else: - partial_image_feature = image_outputs.hidden_states[layer_idx + 1] - partial_image_feature = adaptive_avg_pool2d(partial_image_feature, (desired_shape, desired_shape)) - hs_pool.append(partial_image_feature) - selected_image_feature = torch.cat(hs_pool, dim=-3) - + selected_image_feature = image_outputs.last_hidden_state selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1) image_features = self.multi_modal_projector(selected_image_feature) image_features = list(image_features) @@ -243,7 +215,7 @@ def forward(self, **super_kwargs): vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the - corresponding indices will be concatenated to form the vision features. Must be negative. + corresponding indices will be concatenated to form the vision features. Only -1 supported. """ super().forward(**super_kwargs) @@ -268,7 +240,7 @@ def forward(self, **super_kwargs): vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the - corresponding indices will be concatenated to form the vision features. Must be negative. 
+ corresponding indices will be concatenated to form the vision features. Only -1 supported. Example: From 065b79dfba47329627e48e26d4eb6921f2d3c13d Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Fri, 17 Oct 2025 11:54:11 +0200 Subject: [PATCH 18/31] Typos fixed --- src/transformers/models/auto/modeling_auto.py | 4 ++-- src/transformers/models/fast_vlm/configuration_fast_vlm.py | 2 +- src/transformers/models/fast_vlm/modular_fast_vlm.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e381d52a7e75..43d3546c3575 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -151,7 +151,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("falcon", "FalconModel"), ("falcon_h1", "FalconH1Model"), ("falcon_mamba", "FalconMambaModel"), - ("fast_vlm", "FastVLMModel"), + ("fast_vlm", "FastVlmModel"), ("fastspeech2_conformer", "FastSpeech2ConformerModel"), ("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"), ("flaubert", "FlaubertModel"), @@ -1025,7 +1025,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("deepseek_vl_hybrid", "DeepseekVLHybridForConditionalGeneration"), ("emu3", "Emu3ForConditionalGeneration"), ("evolla", "EvollaForProteinText2Text"), - ("fast_vlm", "FastVLMForConditionalGeneration"), + ("fast_vlm", "FastVlmForConditionalGeneration"), ("florence2", "Florence2ForConditionalGeneration"), ("fuyu", "FuyuForCausalLM"), ("gemma3", "Gemma3ForConditionalGeneration"), diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py index d10efe490c95..fcee7b744089 100644 --- a/src/transformers/models/fast_vlm/configuration_fast_vlm.py +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -24,7 +24,7 @@ class FastVlmConfig(PreTrainedConfig): r""" - This is the configuration class to store the configuration of a [`FastVLMForConditionalGeneration`]. It is used to instantiate a + This is the configuration class to store the configuration of a [`FastVlmForConditionalGeneration`]. It is used to instantiate a FastVLM model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield the same configuration as the one of FastVLM-7B. diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index 735f9946452b..ed27a686dd0a 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -32,7 +32,7 @@ class FastVlmConfig(LlavaConfig): r""" - This is the configuration class to store the configuration of a [`FastVLMForConditionalGeneration`]. It is used to instantiate a + This is the configuration class to store the configuration of a [`FastVlmForConditionalGeneration`]. It is used to instantiate a FastVLM model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield the same configuration as the one of FastVLM-7B. 
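A quick illustration of the configuration contract settled in the commits above: `vision_feature_select_strategy` must be `"full"` and `vision_feature_layer` must be `-1`; anything else is rejected at construction time. This is a sketch assuming `FastVlmConfig` is exported from `transformers` as in this PR, with the default sub-configs filled in by the class itself.

```python
# Hedged sketch of the validation added in configuration_fast_vlm.py / modular_fast_vlm.py.
from transformers import FastVlmConfig

config = FastVlmConfig()  # defaults: vision_feature_layer=-1, vision_feature_select_strategy="full"
print(config.image_token_id, config.vision_feature_layer)  # 151646 -1

try:
    FastVlmConfig(vision_feature_layer=-2)  # only the last hidden state of the vision tower is supported
except ValueError as err:
    print(err)  # "Unexpected vision feature layer: ... Only -1 is supported in FastVLM."

try:
    FastVlmConfig(vision_feature_select_strategy="default")  # CLS-based strategies make no sense for a conv backbone
except ValueError as err:
    print(err)  # "Unexpected select feature strategy: default. Only 'full' is supported in FastVLM."
```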
From 3b9d90749ebc60dd7fa0bf1adb3c93b0f525f5c0 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Fri, 17 Oct 2025 16:17:38 +0200 Subject: [PATCH 19/31] added initial tests - some still failing --- tests/models/fast_vlm/__init__.py | 0 .../fast_vlm/test_configuration_fast_vlm.py | 14 + .../models/fast_vlm/test_modeling_fast_vlm.py | 327 ++++++++++++++++++ 3 files changed, 341 insertions(+) create mode 100644 tests/models/fast_vlm/__init__.py create mode 100644 tests/models/fast_vlm/test_configuration_fast_vlm.py create mode 100644 tests/models/fast_vlm/test_modeling_fast_vlm.py diff --git a/tests/models/fast_vlm/__init__.py b/tests/models/fast_vlm/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/fast_vlm/test_configuration_fast_vlm.py b/tests/models/fast_vlm/test_configuration_fast_vlm.py new file mode 100644 index 000000000000..134d050fb226 --- /dev/null +++ b/tests/models/fast_vlm/test_configuration_fast_vlm.py @@ -0,0 +1,14 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Testing suite for the FastVLM configuration.""" \ No newline at end of file diff --git a/tests/models/fast_vlm/test_modeling_fast_vlm.py b/tests/models/fast_vlm/test_modeling_fast_vlm.py new file mode 100644 index 000000000000..7351dc89407e --- /dev/null +++ b/tests/models/fast_vlm/test_modeling_fast_vlm.py @@ -0,0 +1,327 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the FastVLM model.""" + +import copy +import unittest + +import requests + +from transformers import ( + AutoProcessor, + BitsAndBytesConfig, + FastVlmConfig, + FastVlmForConditionalGeneration, + FastVlmModel, + is_torch_available, + is_vision_available, +) +from transformers.testing_utils import ( + cleanup, + require_bitsandbytes, + require_torch, + require_vision, + slow, + torch_device, +) + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor + + +if is_torch_available(): + import torch + + +if is_vision_available(): + from PIL import Image + + +class FastVlmVisionText2TextModelTester: + def __init__( + self, + parent, + ignore_index=-100, + image_token_id=0, + projector_hidden_act="gelu", + seq_length=7, + vision_feature_select_strategy="full", + vision_feature_layer=-1, + text_config={ + "model_type": "qwen2", + "is_training": True, + "vocab_size": 99, + "hidden_size": 32, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 4, + "intermediate_size": 37, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "initializer_range": 0.02, + "pad_token_id": 1, + }, + is_training=True, + vision_config={ + "image_size": 16, + "patch_size": 8, + "num_channels": 3, + "hidden_size": 32, + "initializer_range": 0.02, + "architecture": "fastvit_mci3", + "do_pooling": True, + "global_pool": "avg", + "model_args": { + "inference_mode": True, + "layers":(2, 2), + "embed_dims": (8, 16), + "mlp_ratios":(4, 4), + "se_downsamples": (False, False), + "downsamples": (False, True), + "pos_embs": (None, None), + "token_mixers":("repmixer", "repmixer"), + "lkc_use_act": True, + "stem_use_scale_branch": False} + }, + ): + self.parent = parent + self.ignore_index = ignore_index + self.image_token_id = image_token_id + self.projector_hidden_act = projector_hidden_act + self.vision_feature_select_strategy = vision_feature_select_strategy + self.vision_feature_layer = vision_feature_layer + self.text_config = text_config + self.vision_config = vision_config + self.pad_token_id = text_config["pad_token_id"] + + self.num_hidden_layers = text_config["num_hidden_layers"] + self.vocab_size = text_config["vocab_size"] + self.hidden_size = text_config["hidden_size"] + self.num_attention_heads = text_config["num_attention_heads"] + self.is_training = is_training + + self.batch_size = 3 + self.num_image_tokens = (self.vision_config["image_size"] // self.vision_config["patch_size"]) ** 2 + self.seq_length = seq_length + self.num_image_tokens + + def get_config(self): + return FastVlmConfig( + text_config=self.text_config, + vision_config=self.vision_config, + ignore_index=self.ignore_index, + image_token_id=self.image_token_id, + projector_hidden_act=self.projector_hidden_act, + vision_feature_select_strategy=self.vision_feature_select_strategy, + vision_feature_layer=self.vision_feature_layer, + image_seq_length=self.num_image_tokens, + ) + + def prepare_config_and_inputs(self): + pixel_values = floats_tensor( + [ + self.batch_size, + self.vision_config["num_channels"], + self.vision_config["image_size"], + self.vision_config["image_size"], + ] + ) + config = self.get_config() + + return config, pixel_values + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + config, pixel_values = config_and_inputs + 
input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1 + input_ids[input_ids == config.image_token_index] = self.pad_token_id + input_ids[:, : self.num_image_tokens] = config.image_token_index + attention_mask = input_ids.ne(1).to(torch_device) + + inputs_dict = { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": attention_mask, + } + return config, inputs_dict + + +@require_torch +class FastVlmForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): + """ + Model tester for `FastVlmForConditionalGeneration`. + """ + + all_model_classes = ( + ( + FastVlmModel, + FastVlmForConditionalGeneration, + ) + if is_torch_available() + else () + ) + pipeline_model_mapping = ( + {"image-to-text": FastVlmForConditionalGeneration, "image-text-to-text": FastVlmForConditionalGeneration} + if is_torch_available() + else {} + ) + + _is_composite = True + + def setUp(self): + self.model_tester = FastVlmVisionText2TextModelTester(self) + common_properties = ["image_token_index", "vision_feature_layer", "image_seq_length"] + self.config_tester = ConfigTester( + self, config_class=FastVlmConfig, has_text_modality=False, common_properties=common_properties + ) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_mismatching_num_image_tokens(self): + """ + Tests that an explicit error is thrown when the number of image tokens + doesn't match the number of image placeholders in the text. + We also test multi-image cases when one prompt has multiple image tokens. + """ + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_model_classes: + model = model_class(config).to(torch_device) + model.eval() + curr_input_dict = copy.deepcopy(input_dict) # in-place modifications further + _ = model(**curr_input_dict) # successful forward with no modifications + + # remove one image but leave all the image tokens in text + curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-2:, ...] 
+ with self.assertRaises(ValueError): + _ = model(**curr_input_dict) + + # simulate the multi-image/single set of placeholders case by concatenating + input_ids = curr_input_dict["input_ids"][:1] + pixel_values = curr_input_dict["pixel_values"][:1] + pixel_values = torch.cat([pixel_values, pixel_values], dim=0) + + # two images and one set of image tokens raise an error + with self.assertRaises(ValueError): + _ = model(input_ids=input_ids, pixel_values=pixel_values) + + # two images and two sets of image tokens don't raise an error + input_ids = torch.cat([input_ids, input_ids], dim=0) + _ = model(input_ids=input_ids, pixel_values=pixel_values) + + # @unittest.skip( + # reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + # ) + # def test_training_gradient_checkpointing(self): + # pass + + # @unittest.skip( + # reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + # ) + # def test_training_gradient_checkpointing_use_reentrant(self): + # pass + + # @unittest.skip( + # reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + # ) + # def test_training_gradient_checkpointing_use_reentrant_false(self): + # pass + + # @unittest.skip( + # "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" + # ) + # def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + # pass + + +@require_torch +class FastVlmForConditionalGenerationIntegrationTest(unittest.TestCase): + def setUp(self): + self.processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B") + + def tearDown(self): + cleanup(torch_device, gc_collect=True) + + @slow + @require_bitsandbytes + @require_vision + def test_small_model_integration_test(self): + model = FastVlmForConditionalGeneration.from_pretrained( + "KamilaMila/FastVLM-0.5B", quantization_config=BitsAndBytesConfig(load_in_4bit=True) + ) + + prompt = "user\n\nWhat are the things I should be cautious about when I visit this place?\nassistant" + image_file = "https://llava-vl.github.io/static/images/view.jpg" + raw_image = Image.open(requests.get(image_file, stream=True).raw) + inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20) + expected_decoded_texts = """ +user\n\nWhat are the things I should be cautious about when I visit this place?\nassistant\nWhen visiting this place, you should be cautious about the following:\n\n1. Water safety: +""" # fmt: skip + + EXPECTED_DECODED_TEXT = expected_decoded_texts[1:-1] + + self.assertEqual( + self.processor.decode(output[0], skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_bitsandbytes + @require_vision + def test_small_model_integration_test_batch(self): + model = FastVlmForConditionalGeneration.from_pretrained( + "KamilaMila/FastVLM-0.5B", quantization_config=BitsAndBytesConfig(load_in_4bit=True) + ) + + prompts = [ + "user\n\nWhat are the things I should be cautious about when I visit this place? 
What should I bring with me?\nassistant", + "user\n\nWhat is this?\nassistant", + ] + image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw) + image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + + inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True).to( + torch_device + ) + + output = model.generate(**inputs, max_new_tokens=20) + + EXPECTED_DECODED_TEXT = [ + "user\n\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nassistant\n\nWhen visiting this place, you should be cautious of the following:\n\n1. **Weather Conditions**:", + "user\n\nWhat is this?\nassistant\nThe image depicts two cats lying on a pink surface, which appears to be a couch or" + ] # fmt: skip + + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_bitsandbytes + def test_generation_no_images(self): + model_id = "KamilaMila/FastVLM-0.5B" + model = FastVlmForConditionalGeneration.from_pretrained( + model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True) + ) + processor = AutoProcessor.from_pretrained(model_id) + + # Prepare inputs with no images + inputs = processor(text="Hello, I am", return_tensors="pt").to(torch_device) + + # Make sure that `generate` works + _ = model.generate(**inputs, max_new_tokens=20) From 4ec2d234183ecc71e588a571ccd514a72a74ef6e Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Fri, 17 Oct 2025 16:22:17 +0200 Subject: [PATCH 20/31] Style and quality fixes --- .../models/fast_vlm/configuration_fast_vlm.py | 2 +- .../models/fast_vlm/modular_fast_vlm.py | 4 ++-- .../fast_vlm/test_configuration_fast_vlm.py | 2 +- tests/models/fast_vlm/test_modeling_fast_vlm.py | 17 +++++++++-------- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py index fcee7b744089..1fc0c12621b1 100644 --- a/src/transformers/models/fast_vlm/configuration_fast_vlm.py +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -38,7 +38,7 @@ class FastVlmConfig(PreTrainedConfig): The config object or dictionary of the vision backbone. text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`): The config object or dictionary of the text backbone. - image_token_index (`int`, *optional*, defaults to 151646): + image_token_id (`int`, *optional*, defaults to 151646): The image token index to encode the image prompt. projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): The activation function used by the multimodal projector. diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index ed27a686dd0a..ca6bfa63bebc 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -46,7 +46,7 @@ class FastVlmConfig(LlavaConfig): The config object or dictionary of the vision backbone. text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`): The config object or dictionary of the text backbone. - image_token_index (`int`, *optional*, defaults to 151646): + image_token_id (`int`, *optional*, defaults to 151646): The image token index to encode the image prompt. 
projector_hidden_act (`str`, *optional*, defaults to `"gelu"`): The activation function used by the multimodal projector. @@ -100,7 +100,7 @@ def __init__( raise ValueError( f"Unexpected select feature strategy: {vision_feature_select_strategy}. Only 'full' is supported in FastVLM." ) - + if vision_feature_layer != -1: raise ValueError( f"Unexpected vision feature layer: {vision_feature_select_strategy}. Only -1 is supported in FastVLM." diff --git a/tests/models/fast_vlm/test_configuration_fast_vlm.py b/tests/models/fast_vlm/test_configuration_fast_vlm.py index 134d050fb226..c9aefa4e59c6 100644 --- a/tests/models/fast_vlm/test_configuration_fast_vlm.py +++ b/tests/models/fast_vlm/test_configuration_fast_vlm.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Testing suite for the FastVLM configuration.""" \ No newline at end of file +"""Testing suite for the FastVLM configuration.""" diff --git a/tests/models/fast_vlm/test_modeling_fast_vlm.py b/tests/models/fast_vlm/test_modeling_fast_vlm.py index 7351dc89407e..b74d32824c5c 100644 --- a/tests/models/fast_vlm/test_modeling_fast_vlm.py +++ b/tests/models/fast_vlm/test_modeling_fast_vlm.py @@ -87,16 +87,17 @@ def __init__( "global_pool": "avg", "model_args": { "inference_mode": True, - "layers":(2, 2), + "layers": (2, 2), "embed_dims": (8, 16), - "mlp_ratios":(4, 4), + "mlp_ratios": (4, 4), "se_downsamples": (False, False), "downsamples": (False, True), "pos_embs": (None, None), - "token_mixers":("repmixer", "repmixer"), + "token_mixers": ("repmixer", "repmixer"), "lkc_use_act": True, - "stem_use_scale_branch": False} + "stem_use_scale_branch": False, }, + }, ): self.parent = parent self.ignore_index = ignore_index @@ -157,7 +158,7 @@ def prepare_config_and_inputs_for_common(self): "attention_mask": attention_mask, } return config, inputs_dict - + @require_torch class FastVlmForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): @@ -271,7 +272,7 @@ def test_small_model_integration_test(self): output = model.generate(**inputs, max_new_tokens=20) expected_decoded_texts = """ user\n\nWhat are the things I should be cautious about when I visit this place?\nassistant\nWhen visiting this place, you should be cautious about the following:\n\n1. Water safety: -""" # fmt: skip +""" # fmt: skip EXPECTED_DECODED_TEXT = expected_decoded_texts[1:-1] @@ -302,9 +303,9 @@ def test_small_model_integration_test_batch(self): output = model.generate(**inputs, max_new_tokens=20) EXPECTED_DECODED_TEXT = [ - "user\n\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nassistant\n\nWhen visiting this place, you should be cautious of the following:\n\n1. **Weather Conditions**:", + "user\n\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nassistant\n\nWhen visiting this place, you should be cautious of the following:\n\n1. 
**Weather Conditions**:", "user\n\nWhat is this?\nassistant\nThe image depicts two cats lying on a pink surface, which appears to be a couch or" - ] # fmt: skip + ] # fmt: skip self.assertEqual( self.processor.batch_decode(output, skip_special_tokens=True), From 6204dc262981def73a6d26d1e9d77c305b8146a2 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Fri, 17 Oct 2025 17:36:28 +0200 Subject: [PATCH 21/31] Updated modular according to the review --- .../models/fast_vlm/configuration_fast_vlm.py | 3 ++- .../models/fast_vlm/modular_fast_vlm.py | 6 ++---- .../models/fast_vlm/test_modeling_fast_vlm.py | 18 ------------------ 3 files changed, 4 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py index 1fc0c12621b1..23ef3ca253f4 100644 --- a/src/transformers/models/fast_vlm/configuration_fast_vlm.py +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -87,7 +87,6 @@ def __init__( multimodal_projector_bias=True, **kwargs, ): - super().__init__(**kwargs) self.image_token_id = image_token_id self.projector_hidden_act = projector_hidden_act self.image_seq_length = image_seq_length @@ -136,5 +135,7 @@ def __init__( self.text_config = text_config self.multimodal_projector_bias = multimodal_projector_bias + super().__init__(**kwargs) + __all__ = ["FastVlmConfig"] diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index ca6bfa63bebc..67dcbcad8a0b 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -91,7 +91,6 @@ def __init__( multimodal_projector_bias=True, **kwargs, ): - PreTrainedConfig.__init__(**kwargs) self.image_token_id = image_token_id self.projector_hidden_act = projector_hidden_act self.image_seq_length = image_seq_length @@ -140,6 +139,8 @@ def __init__( self.text_config = text_config self.multimodal_projector_bias = multimodal_projector_bias + PreTrainedConfig.__init__(**kwargs) + class FastVlmMultiModalProjector(LlavaMultiModalProjector): def __init__(self, config: FastVlmConfig): @@ -163,9 +164,6 @@ class FastVlmModel(LlavaModel): _checkpoint_conversion_mapping = {} def __init__(self, config: FastVlmConfig): - # Timm models don't support this way of setting attention mode so we set the vision config to eager while keeping the language part - # the same as the user requested - # config.vision_config._attn_implementation = "eager" super().__init__(config) def get_image_features( diff --git a/tests/models/fast_vlm/test_modeling_fast_vlm.py b/tests/models/fast_vlm/test_modeling_fast_vlm.py index b74d32824c5c..5a1889fa6fa9 100644 --- a/tests/models/fast_vlm/test_modeling_fast_vlm.py +++ b/tests/models/fast_vlm/test_modeling_fast_vlm.py @@ -223,24 +223,6 @@ def test_mismatching_num_image_tokens(self): input_ids = torch.cat([input_ids, input_ids], dim=0) _ = model(input_ids=input_ids, pixel_values=pixel_values) - # @unittest.skip( - # reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" - # ) - # def test_training_gradient_checkpointing(self): - # pass - - # @unittest.skip( - # reason="This architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" - # ) - # def test_training_gradient_checkpointing_use_reentrant(self): - # pass - - # @unittest.skip( - # reason="This 
architecture seems to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" - # ) - # def test_training_gradient_checkpointing_use_reentrant_false(self): - # pass - # @unittest.skip( # "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" # ) From c15f4d7d668f6d17bb39780630e3f690299f27e3 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Mon, 20 Oct 2025 19:51:36 +0200 Subject: [PATCH 22/31] Tests passing and some suggested generic improvements --- src/transformers/integrations/accelerate.py | 4 ---- .../models/fast_vlm/configuration_fast_vlm.py | 2 +- .../models/fast_vlm/modular_fast_vlm.py | 2 +- tests/generation/test_utils.py | 15 +++++++++++++ .../fast_vlm/test_configuration_fast_vlm.py | 14 ------------- .../models/fast_vlm/test_modeling_fast_vlm.py | 10 ++++----- tests/test_modeling_common.py | 21 ++++++++++++++++++- 7 files changed, 41 insertions(+), 27 deletions(-) delete mode 100644 tests/models/fast_vlm/test_configuration_fast_vlm.py diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index 79ef98d8a4dc..3e8b12976446 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -105,10 +105,6 @@ def init_on_device(device: "torch.device", include_buffers: bool = False): tst = nn.Linear(100, 100) # on `cuda` device ``` """ - if include_buffers: - with device: - yield - return old_register_parameter = nn.Module.register_parameter if include_buffers: diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py index 23ef3ca253f4..14967b083a62 100644 --- a/src/transformers/models/fast_vlm/configuration_fast_vlm.py +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -98,7 +98,7 @@ def __init__( if vision_feature_layer != -1: raise ValueError( - f"Unexpected vision feature layer: {vision_feature_select_strategy}. Only -1 is supported in FastVLM." + f"Unexpected vision feature layer: {vision_feature_layer}. Only -1 is supported in FastVLM." ) self.vision_feature_select_strategy = vision_feature_select_strategy diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index 67dcbcad8a0b..77f5086f684f 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -102,7 +102,7 @@ def __init__( if vision_feature_layer != -1: raise ValueError( - f"Unexpected vision feature layer: {vision_feature_select_strategy}. Only -1 is supported in FastVLM." + f"Unexpected vision feature layer: {vision_feature_layer}. Only -1 is supported in FastVLM." 
) self.vision_feature_select_strategy = vision_feature_select_strategy diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 4120f0926f0f..a3fe02be6e32 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1894,6 +1894,13 @@ def test_flash_attention_2_continue_generate_with_position_ids(self): config.max_position_embeddings = max_new_tokens + dummy_input_ids.shape[1] + 1 model = model_class(config) + if not all( + getattr(submodel, "_supports_flash_attn") + for submodel in model.modules() + if isinstance(submodel, PreTrainedModel) + ): + self.skipTest(f"At least some parts of {model_class.__name__} don't support flash attention") + if "position_ids" not in inspect.signature(model.forward).parameters: self.skipTest("Model does not support position_ids") @@ -1994,6 +2001,14 @@ def attention_mask_padding_matches_padding_free_with_position_ids( config.max_position_embeddings = max_new_tokens + dummy_input_ids.shape[1] + 1 model = model_class(config) + if attn_implementation != "eager": + if not all( + getattr(submodel, support_flag[attn_implementation]) + for submodel in model.modules() + if isinstance(submodel, PreTrainedModel) + ): + self.skipTest(f"At least some parts of {model_class.__name__} don't support {attn_implementation}") + if "position_ids" not in inspect.signature(model.forward).parameters: self.skipTest("Model does not support position_ids") diff --git a/tests/models/fast_vlm/test_configuration_fast_vlm.py b/tests/models/fast_vlm/test_configuration_fast_vlm.py deleted file mode 100644 index c9aefa4e59c6..000000000000 --- a/tests/models/fast_vlm/test_configuration_fast_vlm.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Testing suite for the FastVLM configuration.""" diff --git a/tests/models/fast_vlm/test_modeling_fast_vlm.py b/tests/models/fast_vlm/test_modeling_fast_vlm.py index 5a1889fa6fa9..e9800feb9e85 100644 --- a/tests/models/fast_vlm/test_modeling_fast_vlm.py +++ b/tests/models/fast_vlm/test_modeling_fast_vlm.py @@ -184,7 +184,7 @@ class FastVlmForConditionalGenerationModelTest(ModelTesterMixin, GenerationTeste def setUp(self): self.model_tester = FastVlmVisionText2TextModelTester(self) - common_properties = ["image_token_index", "vision_feature_layer", "image_seq_length"] + common_properties = ["image_token_id", "image_seq_length"] self.config_tester = ConfigTester( self, config_class=FastVlmConfig, has_text_modality=False, common_properties=common_properties ) @@ -223,11 +223,9 @@ def test_mismatching_num_image_tokens(self): input_ids = torch.cat([input_ids, input_ids], dim=0) _ = model(input_ids=input_ids, pixel_values=pixel_values) - # @unittest.skip( - # "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. 
Can be tested as part of LLM test" - # ) - # def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): - # pass + @unittest.skip("Timm wrapper and backbone don't currently support HF initialization") + def test_can_init_all_missing_weights(self): + pass @require_torch diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 4ba9e1240e48..2546a53f04e7 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -41,6 +41,7 @@ set_seed, ) from transformers.integrations import HfDeepSpeedConfig +from transformers.integrations.accelerate import init_empty_weights from transformers.integrations.deepspeed import ( is_deepspeed_available, is_deepspeed_zero3_enabled, @@ -1251,6 +1252,8 @@ def test_attention_outputs(self): del inputs_dict["output_attentions"] config.output_attentions = True for k in config.sub_configs: + if self._is_composite and k == "vision_config": + continue if getattr(config, k) is not None: getattr(config, k).output_attentions = True @@ -1410,6 +1413,8 @@ def test_retain_grad_hidden_states_attentions(self): config.output_attentions = self.has_attentions for k in config.sub_configs: + if self._is_composite and k == "vision_config": # to be generalized + continue if getattr(config, k) is not None: getattr(config, k).output_attentions = self.has_attentions @@ -3191,6 +3196,11 @@ def test_flash_attn_2_fp32_ln(self): self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) + if not all( + submodel._supports_flex_attn for submodel in model.modules() if isinstance(submodel, PreTrainedModel) + ): + self.skipTest(reason="At least some parts of this model do not support flex attention") + with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) @@ -3285,6 +3295,15 @@ def flash_attn_from_config(self, attn_implementation: str, test_fwd_in_train: bo self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + with init_empty_weights(include_buffers=True): + model = model_class(config) # this model won't be used anywhere so we can initialize on meta + if not all( + submodel._supports_flex_attn + for submodel in model.modules() + if isinstance(submodel, PreTrainedModel) + ): + self.skipTest(reason="At least some parts of this model do not support flex attention") + # TODO: to change it in the future with other relevant auto classes fa_model = model_class._from_config( config, attn_implementation=attn_implementation, dtype=torch.bfloat16 @@ -3649,7 +3668,7 @@ def test_can_be_initialized_on_meta(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: # If it does not raise here, the test passes - with torch.device("meta"): + with init_empty_weights(include_buffers=True): _ = model_class(copy.deepcopy(config)) @require_torch_accelerator From 8d7ebfa79fb767170403aca394c6fddb92da278f Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Mon, 20 Oct 2025 20:46:13 +0200 Subject: [PATCH 23/31] Docs updated with another usage tip and an auto model --- docs/source/en/model_doc/fast_vlm.md | 14 ++++++++++++++ .../models/fast_vlm/modeling_fast_vlm.py | 4 ++-- .../models/fast_vlm/modular_fast_vlm.py | 4 ++-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/docs/source/en/model_doc/fast_vlm.md 
b/docs/source/en/model_doc/fast_vlm.md index 48d3146c1d63..94f765792e2e 100644 --- a/docs/source/en/model_doc/fast_vlm.md +++ b/docs/source/en/model_doc/fast_vlm.md @@ -43,6 +43,20 @@ The original code can be found [here](https://github.com/apple/ml-fastvlm). - Note the model has not been explicitly trained to process multiple images in the same prompt, although this is technically possible, you may experience inaccurate results. +**Important: ** + +Hugging Face models use SDPA by default; however, this model’s visual backbone supports only eager attention, so it automatically falls back to `"eager"`. + +If you want to use a different attention implementation in the language decoder, make sure to set it explicitly, for example: + +`model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B", attn_implementation={"text_config": "flash_attention_2"})` + +Setting it for the entire model, e.g. + +`model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B", attn_implementation="flash_attention_2")` + +will result in an error. + ### Formatting Prompts with Chat Templates Each **checkpoint** is trained with a specific prompt format, depending on the underlying large language model backbone. To ensure correct formatting, use the processor’s `apply_chat_template` method. diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py index 1068ab1e1614..071e41e5457c 100644 --- a/src/transformers/models/fast_vlm/modeling_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -376,12 +376,12 @@ def forward( ```python >>> from PIL import Image >>> import requests - >>> from transformers import AutoProcessor, FastVlmForConditionalGeneration + >>> from transformers import AutoProcessor, AutoModelForImageTextToText >>> import torch >>> device = "cuda" if torch.cuda.is_available() else "cpu" - >>> model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B").to(device) + >>> model = AutoModelForImageTextToText.from_pretrained("KamilaMila/FastVLM-0.5B").to(device) >>> processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B") >>> conversation = [ diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index 77f5086f684f..e5b7c531f3a4 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -245,12 +245,12 @@ def forward(self, **super_kwargs): ```python >>> from PIL import Image >>> import requests - >>> from transformers import AutoProcessor, FastVlmForConditionalGeneration + >>> from transformers import AutoProcessor, AutoModelForImageTextToText >>> import torch >>> device = "cuda" if torch.cuda.is_available() else "cpu" - >>> model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B").to(device) + >>> model = AutoModelForImageTextToText.from_pretrained("KamilaMila/FastVLM-0.5B").to(device) >>> processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B") >>> conversation = [ From b3140d4114913aaaf523866f76ed74bb87ada564 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Tue, 21 Oct 2025 14:37:56 +0200 Subject: [PATCH 24/31] Reversed changes to test_can_intialize_on_meta becuase it's not fully compatible with one existing model --- src/transformers/integrations/accelerate.py | 4 ++++ tests/models/fast_vlm/test_modeling_fast_vlm.py | 5 ++++- tests/test_modeling_common.py | 3 +-- 3 files 
changed, 9 insertions(+), 3 deletions(-) diff --git a/src/transformers/integrations/accelerate.py b/src/transformers/integrations/accelerate.py index 3e8b12976446..79ef98d8a4dc 100644 --- a/src/transformers/integrations/accelerate.py +++ b/src/transformers/integrations/accelerate.py @@ -105,6 +105,10 @@ def init_on_device(device: "torch.device", include_buffers: bool = False): tst = nn.Linear(100, 100) # on `cuda` device ``` """ + if include_buffers: + with device: + yield + return old_register_parameter = nn.Module.register_parameter if include_buffers: diff --git a/tests/models/fast_vlm/test_modeling_fast_vlm.py b/tests/models/fast_vlm/test_modeling_fast_vlm.py index e9800feb9e85..7f30cd5c4ba1 100644 --- a/tests/models/fast_vlm/test_modeling_fast_vlm.py +++ b/tests/models/fast_vlm/test_modeling_fast_vlm.py @@ -223,10 +223,13 @@ def test_mismatching_num_image_tokens(self): input_ids = torch.cat([input_ids, input_ids], dim=0) _ = model(input_ids=input_ids, pixel_values=pixel_values) - @unittest.skip("Timm wrapper and backbone don't currently support HF initialization") + @unittest.skip("Timm wrapper and backbone don't currently support full HF initialization") def test_can_init_all_missing_weights(self): pass + @unittest.skip("Timm can't be initialized on meta") + def test_can_be_initialized_on_meta(self): + pass @require_torch class FastVlmForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 2546a53f04e7..159c401a51d9 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -41,7 +41,6 @@ set_seed, ) from transformers.integrations import HfDeepSpeedConfig -from transformers.integrations.accelerate import init_empty_weights from transformers.integrations.deepspeed import ( is_deepspeed_available, is_deepspeed_zero3_enabled, @@ -3668,7 +3667,7 @@ def test_can_be_initialized_on_meta(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: # If it does not raise here, the test passes - with init_empty_weights(include_buffers=True): + with torch.device("meta"): _ = model_class(copy.deepcopy(config)) @require_torch_accelerator From dd30f1facea6037372f1c9736d127e8707a5af11 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Tue, 21 Oct 2025 15:57:40 +0200 Subject: [PATCH 25/31] Some tweaks --- .../models/fast_vlm/test_modeling_fast_vlm.py | 1 + tests/test_modeling_common.py | 25 ++++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/models/fast_vlm/test_modeling_fast_vlm.py b/tests/models/fast_vlm/test_modeling_fast_vlm.py index 7f30cd5c4ba1..738e0e4092ad 100644 --- a/tests/models/fast_vlm/test_modeling_fast_vlm.py +++ b/tests/models/fast_vlm/test_modeling_fast_vlm.py @@ -231,6 +231,7 @@ def test_can_init_all_missing_weights(self): def test_can_be_initialized_on_meta(self): pass + @require_torch class FastVlmForConditionalGenerationIntegrationTest(unittest.TestCase): def setUp(self): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 159c401a51d9..f34bdfe50e39 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1251,7 +1251,9 @@ def test_attention_outputs(self): del inputs_dict["output_attentions"] config.output_attentions = True for k in config.sub_configs: - if self._is_composite and k == "vision_config": + if ( + self._is_composite and k == "vision_config" + ): # skip because it's not needed and causes errors e.g with Timm 
continue if getattr(config, k) is not None: getattr(config, k).output_attentions = True @@ -1412,7 +1414,9 @@ def test_retain_grad_hidden_states_attentions(self): config.output_attentions = self.has_attentions for k in config.sub_configs: - if self._is_composite and k == "vision_config": # to be generalized + if ( + self._is_composite and k == "vision_config" + ): # # skip because it's not needed and causes errors e.g with Timm continue if getattr(config, k) is not None: getattr(config, k).output_attentions = self.has_attentions @@ -3196,9 +3200,9 @@ def test_flash_attn_2_fp32_ln(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) if not all( - submodel._supports_flex_attn for submodel in model.modules() if isinstance(submodel, PreTrainedModel) + submodel._supports_flash_attn for submodel in model.modules() if isinstance(submodel, PreTrainedModel) ): - self.skipTest(reason="At least some parts of this model do not support flex attention") + self.skipTest(reason="At least some parts of this model do not support flash attention") with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) @@ -3294,14 +3298,11 @@ def flash_attn_from_config(self, attn_implementation: str, test_fwd_in_train: bo self.skipTest(f"{model_class.__name__} does not support {attn_implementation}") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - with init_empty_weights(include_buffers=True): - model = model_class(config) # this model won't be used anywhere so we can initialize on meta - if not all( - submodel._supports_flex_attn - for submodel in model.modules() - if isinstance(submodel, PreTrainedModel) - ): - self.skipTest(reason="At least some parts of this model do not support flex attention") + model = model_class(config) # let's construct it here to see if any submodels can't support flash attn + if not all( + submodel._supports_flash_attn for submodel in model.modules() if isinstance(submodel, PreTrainedModel) + ): + self.skipTest(reason=f"At least some parts of this model do not support {attn_implementation}") # TODO: to change it in the future with other relevant auto classes fa_model = model_class._from_config( From d5b632943ad918797b4e15d9d0020f9b345e9582 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Tue, 21 Oct 2025 16:18:13 +0200 Subject: [PATCH 26/31] Typo fix --- tests/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f34bdfe50e39..a6cefb56548a 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1416,7 +1416,7 @@ def test_retain_grad_hidden_states_attentions(self): for k in config.sub_configs: if ( self._is_composite and k == "vision_config" - ): # # skip because it's not needed and causes errors e.g with Timm + ): # skip because it's not needed and causes errors e.g with Timm continue if getattr(config, k) is not None: getattr(config, k).output_attentions = self.has_attentions From 3ee84e92292a599af16d3a7b6b271e1f064b4cbc Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Tue, 21 Oct 2025 16:48:36 +0200 Subject: [PATCH 27/31] Consistency fixed --- src/transformers/models/fast_vlm/modeling_fast_vlm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py index 071e41e5457c..a427bf1b3d2b 100644 --- 
a/src/transformers/models/fast_vlm/modeling_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -59,6 +59,7 @@ def forward(self, image_features): class FastVlmPreTrainedModel(PreTrainedModel): config: FastVlmConfig base_model_prefix = "" + input_modalities = ["image", "text"] supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" From 46d401d1220f0def6457b2698232ced78051080f Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Wed, 22 Oct 2025 13:50:12 +0200 Subject: [PATCH 28/31] Review comment --- docs/source/en/model_doc/fast_vlm.md | 38 ---------------------------- 1 file changed, 38 deletions(-) diff --git a/docs/source/en/model_doc/fast_vlm.md b/docs/source/en/model_doc/fast_vlm.md index 94f765792e2e..25cbe3bff126 100644 --- a/docs/source/en/model_doc/fast_vlm.md +++ b/docs/source/en/model_doc/fast_vlm.md @@ -66,44 +66,6 @@ Each **checkpoint** is trained with a specific prompt format, depending on the u - Each message should be a dictionary with `"role"` and `"content"` keys. - The `"content"` should be a list of dictionaries for different modalities like `"text"` and `"image"`. - -Here’s an example of how to structure your input. -We will use a conversation history of text and image. Each content field has to be a list of dicts, as follows: - - -```python -from transformers import AutoProcessor - -processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B") - -conversation = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What’s shown in this image?"}, - ], - }, - { - "role": "assistant", - "content": [{"type": "text", "text": "This image shows a red stop sign."},] - }, - { - - "role": "user", - "content": [ - {"type": "text", "text": "Describe the image in more details."}, - ], - }, -] - -text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) - -# Note that the template simply formats your prompt, you still have to tokenize it and obtain pixel values for your images -print(text_prompt) ->>> "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n\nWhat’s shown in this image?<|im_end|>\n<|im_start|>assistant\n\nThis image shows a red stop sign.<|im_end|>\n<|im_start|>user\n\nDescribe the image in more details.<|im_end|>\n<|im_start|>assistant\n" -``` - ## Usage examples ### Single input inference From c6370772809ac964d3ee6d3fbd4d8c07a8e6e156 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Wed, 22 Oct 2025 15:03:03 +0200 Subject: [PATCH 29/31] Redundant config attr deleted --- src/transformers/models/fast_vlm/configuration_fast_vlm.py | 4 ---- src/transformers/models/fast_vlm/modular_fast_vlm.py | 4 ---- tests/models/fast_vlm/test_modeling_fast_vlm.py | 3 +-- 3 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py index 14967b083a62..46e5a6ccbf76 100644 --- a/src/transformers/models/fast_vlm/configuration_fast_vlm.py +++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py @@ -49,8 +49,6 @@ class FastVlmConfig(PreTrainedConfig): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. Only -1 supported. - image_seq_length (`int`, *optional*, defaults to 256): - Sequence length of one image embedding. 
multimodal_projector_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the multimodal projector. @@ -83,13 +81,11 @@ def __init__( projector_hidden_act="gelu", vision_feature_select_strategy="full", vision_feature_layer=-1, - image_seq_length=256, multimodal_projector_bias=True, **kwargs, ): self.image_token_id = image_token_id self.projector_hidden_act = projector_hidden_act - self.image_seq_length = image_seq_length if vision_feature_select_strategy != "full": raise ValueError( diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py index e5b7c531f3a4..e5d8c0908307 100644 --- a/src/transformers/models/fast_vlm/modular_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py @@ -57,8 +57,6 @@ class FastVlmConfig(LlavaConfig): The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the corresponding indices will be concatenated to form the vision features. Only -1 supported. - image_seq_length (`int`, *optional*, defaults to 256): - Sequence length of one image embedding. multimodal_projector_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the multimodal projector. @@ -87,13 +85,11 @@ def __init__( projector_hidden_act="gelu", vision_feature_select_strategy="full", vision_feature_layer=-1, - image_seq_length=256, multimodal_projector_bias=True, **kwargs, ): self.image_token_id = image_token_id self.projector_hidden_act = projector_hidden_act - self.image_seq_length = image_seq_length if vision_feature_select_strategy != "full": raise ValueError( diff --git a/tests/models/fast_vlm/test_modeling_fast_vlm.py b/tests/models/fast_vlm/test_modeling_fast_vlm.py index 738e0e4092ad..60df8df6bc26 100644 --- a/tests/models/fast_vlm/test_modeling_fast_vlm.py +++ b/tests/models/fast_vlm/test_modeling_fast_vlm.py @@ -128,7 +128,6 @@ def get_config(self): projector_hidden_act=self.projector_hidden_act, vision_feature_select_strategy=self.vision_feature_select_strategy, vision_feature_layer=self.vision_feature_layer, - image_seq_length=self.num_image_tokens, ) def prepare_config_and_inputs(self): @@ -184,7 +183,7 @@ class FastVlmForConditionalGenerationModelTest(ModelTesterMixin, GenerationTeste def setUp(self): self.model_tester = FastVlmVisionText2TextModelTester(self) - common_properties = ["image_token_id", "image_seq_length"] + common_properties = ["image_token_id"] self.config_tester = ConfigTester( self, config_class=FastVlmConfig, has_text_modality=False, common_properties=common_properties ) From 631553d2ed4ae3112b526230b3c2c262ee4f1b85 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Mon, 24 Nov 2025 18:58:05 +0100 Subject: [PATCH 30/31] Consistency fixed --- .../models/fast_vlm/modeling_fast_vlm.py | 29 ++----------------- 1 file changed, 2 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py index a427bf1b3d2b..82860e029f98 100644 --- a/src/transformers/models/fast_vlm/modeling_fast_vlm.py +++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py @@ -58,7 +58,7 @@ def forward(self, image_features): @auto_docstring class FastVlmPreTrainedModel(PreTrainedModel): config: FastVlmConfig - base_model_prefix = "" + base_model_prefix = "model" input_modalities = ["image", "text"] supports_gradient_checkpointing = True _skip_keys_device_placement = "past_key_values" @@ -114,12 +114,6 @@ def get_input_embeddings(self): def 
set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def set_decoder(self, decoder): - self.language_model = decoder - - def get_decoder(self): - return self.language_model - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -291,7 +285,7 @@ class FastVlmCausalLMOutputWithPast(ModelOutput): ) class FastVlmForConditionalGeneration(FastVlmPreTrainedModel, GenerationMixin): _checkpoint_conversion_mapping = {} - _tied_weights_keys = ["lm_head.weight"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} def __init__(self, config: FastVlmConfig): super().__init__(config) @@ -308,12 +302,6 @@ def set_input_embeddings(self, value): def get_output_embeddings(self) -> nn.Module: return self.lm_head - def set_decoder(self, decoder): - self.model.set_decoder(decoder) - - def get_decoder(self): - return self.model.get_decoder() - def get_image_features( self, pixel_values: torch.FloatTensor, @@ -328,19 +316,6 @@ def get_image_features( **kwargs, ) - # Make modules available through conditional class for BC - @property - def language_model(self): - return self.model.language_model - - @property - def vision_tower(self): - return self.model.vision_tower - - @property - def multi_modal_projector(self): - return self.model.multi_modal_projector - @can_return_tuple @auto_docstring def forward( From 8e8c12a823cb138fe9e922bb056bf901a6822418 Mon Sep 17 00:00:00 2001 From: Kamila Luchay Date: Mon, 1 Dec 2025 14:45:45 +0100 Subject: [PATCH 31/31] Fixed integration tests after rebase --- .../models/fast_vlm/test_modeling_fast_vlm.py | 25 ++++++------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/tests/models/fast_vlm/test_modeling_fast_vlm.py b/tests/models/fast_vlm/test_modeling_fast_vlm.py index 60df8df6bc26..19971dfa1649 100644 --- a/tests/models/fast_vlm/test_modeling_fast_vlm.py +++ b/tests/models/fast_vlm/test_modeling_fast_vlm.py @@ -20,7 +20,6 @@ from transformers import ( AutoProcessor, - BitsAndBytesConfig, FastVlmConfig, FastVlmForConditionalGeneration, FastVlmModel, @@ -29,7 +28,6 @@ ) from transformers.testing_utils import ( cleanup, - require_bitsandbytes, require_torch, require_vision, slow, @@ -232,6 +230,7 @@ def test_can_be_initialized_on_meta(self): @require_torch +@slow class FastVlmForConditionalGenerationIntegrationTest(unittest.TestCase): def setUp(self): self.processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B") @@ -239,12 +238,10 @@ def setUp(self): def tearDown(self): cleanup(torch_device, gc_collect=True) - @slow - @require_bitsandbytes @require_vision def test_small_model_integration_test(self): model = FastVlmForConditionalGeneration.from_pretrained( - "KamilaMila/FastVLM-0.5B", quantization_config=BitsAndBytesConfig(load_in_4bit=True) + "KamilaMila/FastVLM-0.5B", device_map=torch_device, dtype=torch.bfloat16 ) prompt = "user\n\nWhat are the things I should be cautious about when I visit this place?\nassistant" @@ -253,23 +250,19 @@ def test_small_model_integration_test(self): inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device) output = model.generate(**inputs, max_new_tokens=20) - expected_decoded_texts = """ -user\n\nWhat are the things I should be cautious about when I visit this place?\nassistant\nWhen visiting this place, you should be cautious about the following:\n\n1. 
Water safety: -""" # fmt: skip + expected_decoded_texts = "user\n\nWhat are the things I should be cautious about when I visit this place?\nassistant\n\nWhen visiting this place, there are a few things you should be cautious about:\n\n1. **" # fmt: skip - EXPECTED_DECODED_TEXT = expected_decoded_texts[1:-1] + EXPECTED_DECODED_TEXT = expected_decoded_texts self.assertEqual( self.processor.decode(output[0], skip_special_tokens=True), EXPECTED_DECODED_TEXT, ) - @slow - @require_bitsandbytes @require_vision def test_small_model_integration_test_batch(self): model = FastVlmForConditionalGeneration.from_pretrained( - "KamilaMila/FastVLM-0.5B", quantization_config=BitsAndBytesConfig(load_in_4bit=True) + "KamilaMila/FastVLM-0.5B", device_map=torch_device, dtype=torch.bfloat16 ) prompts = [ @@ -286,8 +279,8 @@ def test_small_model_integration_test_batch(self): output = model.generate(**inputs, max_new_tokens=20) EXPECTED_DECODED_TEXT = [ - "user\n\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nassistant\n\nWhen visiting this place, you should be cautious of the following:\n\n1. **Weather Conditions**:", - "user\n\nWhat is this?\nassistant\nThe image depicts two cats lying on a pink surface, which appears to be a couch or" + "user\n\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nassistant\n\nWhen visiting this serene place, it's essential to be mindful of the following:\n\n1. **", + "user\n\nWhat is this?\nassistant\nThe image depicts two cats lying on a pink surface, which could be a couch or a" ] # fmt: skip self.assertEqual( @@ -295,12 +288,10 @@ def test_small_model_integration_test_batch(self): EXPECTED_DECODED_TEXT, ) - @slow - @require_bitsandbytes def test_generation_no_images(self): model_id = "KamilaMila/FastVLM-0.5B" model = FastVlmForConditionalGeneration.from_pretrained( - model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True) + model_id, device_map=torch_device, dtype=torch.bfloat16 ) processor = AutoProcessor.from_pretrained(model_id)