diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index c3036b8a3973..df64cfedd36c 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1020,6 +1020,8 @@
title: Emu3
- local: model_doc/evolla
title: Evolla
+ - local: model_doc/fast_vlm
+ title: FastVLM
- local: model_doc/flava
title: FLAVA
- local: model_doc/florence2
diff --git a/docs/source/en/model_doc/fast_vlm.md b/docs/source/en/model_doc/fast_vlm.md
new file mode 100644
index 000000000000..25cbe3bff126
--- /dev/null
+++ b/docs/source/en/model_doc/fast_vlm.md
@@ -0,0 +1,175 @@
+
+
+*This model was released on 2025-05-06 and added to Hugging Face Transformers on 2025-10-07.*
+
+# FastVLM
+
+
+
+## Overview
+
+FastVLM is an open-source vision-language model featuring a novel hybrid vision encoder, FastViTHD. Leveraging reparameterizable convolutional layers, scaled input resolution, and a reduced number of visual tokens, FastVLM delivers high accuracy with exceptional efficiency. Its optimized architecture enables deployment even on edge devices, achieving ultra-low TTFT (time to first token) without sacrificing performance.
+
+The model was proposed in [FastVLM: Efficient Vision Encoding for Vision Language Models](https://huggingface.co/papers/2412.13303) by Pavan Kumar Anasosalu Vasu, Fartash Faghri, Chun-Liang Li, Cem Koc, Nate True, Albert Antony, Gokul Santhanam, James Gabriel, Peter Grasch, Oncel Tuzel and Hadi Pouransari.
+
+The abstract from the paper is the following:
+
+*Scaling the input image resolution is essential for enhancing the performance of Vision Language Models (VLMs), particularly in text-rich image understanding tasks. However, popular visual encoders such as ViTs become inefficient at high resolutions due to the large number of tokens and high encoding latency. At different operational resolutions, the vision encoder of a VLM can be optimized along two axes: reducing encoding latency and minimizing the number of visual tokens passed to the LLM, thereby lowering overall latency. Based on a comprehensive efficiency analysis of the interplay between image resolution, vision latency, token count, and LLM size, we introduce FastVLM—a model that achieves an optimized trade-off between resolution, latency, and accuracy. FastVLM incorporates FastViTHD, a novel hybrid vision encoder designed to output fewer tokens and significantly reduce encoding time for high-resolution images. Unlike previous methods, FastVLM achieves the optimal balance between visual token count and image resolution solely by scaling the input image, eliminating the need for additional token pruning and simplifying the model design. In the LLaVA-1.5 setup, FastVLM achieves 3.2× improvement in time-to-first-token (TTFT) while maintaining similar performance on VLM benchmarks compared to prior works. Compared to LLaVa-OneVision at the highest resolution (1152×1152), FastVLM achieves better performance on key benchmarks like SeedBench, MMMU and DocVQA, using the same 0.5B LLM, but with 85× faster TTFT and a vision encoder that is 3.4× smaller.*
+
+This model was contributed by [Kamila](https://github.com/kamila-chay).
+The original code can be found [here](https://github.com/apple/ml-fastvlm).
+
+## Usage tips
+
+- We advise users to use `padding_side="left"` for batched generation, as it leads to more accurate results. Simply call `processor.tokenizer.padding_side = "left"` before generating, as shown in the sketch after this list.
+
+- Note that the model has not been explicitly trained to process multiple images in the same prompt. Although this is technically possible, you may experience inaccurate results.
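+
+A minimal sketch of enabling left padding (reusing the checkpoint name from the examples below):
+
+```python
+from transformers import AutoProcessor
+
+processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B")
+processor.tokenizer.padding_side = "left"  # left padding keeps each prompt adjacent to its generated tokens
+```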
+
+**Important:**
+
+Hugging Face models use SDPA by default; however, this model’s visual backbone supports only eager attention, so it automatically falls back to `"eager"`.
+
+If you want to use a different attention implementation in the language decoder, make sure to set it explicitly, for example:
+
+`model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B", attn_implementation={"text_config": "flash_attention_2"})`
+
+Setting it for the entire model, e.g.
+
+`model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B", attn_implementation="flash_attention_2")`
+
+will result in an error.
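+
+As a quick sanity check (a sketch; `_attn_implementation` is an internal config attribute and may change between versions), you can inspect which implementation each sub-model ended up with:
+
+```python
+print(model.config.vision_config._attn_implementation)  # expected to fall back to "eager"
+print(model.config.text_config._attn_implementation)    # e.g. "flash_attention_2" if set as above
+```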
+
+### Formatting Prompts with Chat Templates
+
+Each **checkpoint** is trained with a specific prompt format, depending on the underlying large language model backbone. To ensure correct formatting, use the processor’s `apply_chat_template` method.
+
+**Important:**
+- You must construct a conversation history — passing a plain string won't work.
+- Each message should be a dictionary with `"role"` and `"content"` keys.
+- The `"content"` should be a list of dictionaries for different modalities like `"text"` and `"image"`.
+
+## Usage examples
+
+### Single input inference
+
+
+```python
+import torch
+from transformers import AutoProcessor, FastVlmForConditionalGeneration
+
+# Load the model in half-precision
+model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B", dtype=torch.bfloat16, device_map="auto")
+processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B")
+
+conversation = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+ {"type": "text", "text": "What is shown in this image?"},
+ ],
+ },
+]
+
+inputs = processor.apply_chat_template(
+ conversation,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt"
+).to(model.device, torch.bfloat16)
+
+# Generate
+generate_ids = model.generate(**inputs, max_new_tokens=30)
+processor.batch_decode(generate_ids, skip_special_tokens=True)
+```
+
+
+### Batched inference
+
+FastVLM also supports batched inference. Here is how you can do it:
+
+```python
+import torch
+from transformers import AutoProcessor, FastVlmForConditionalGeneration
+
+# Load the model in half-precision
+model = FastVlmForConditionalGeneration.from_pretrained("KamilaMila/FastVLM-0.5B", dtype=torch.bfloat16, device_map="auto")
+processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B")
+
+
+# Prepare a batch of two prompts
+conversation_1 = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
+ {"type": "text", "text": "What is shown in this image?"},
+ ],
+ },
+]
+
+conversation_2 = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "image", "url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
+ {"type": "text", "text": "What is shown in this image?"},
+ ],
+ },
+]
+
+inputs = processor.apply_chat_template(
+ [conversation_1, conversation_2],
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ padding=True,
+ return_tensors="pt"
+).to(model.device, torch.bfloat16)
+
+
+# Generate
+generate_ids = model.generate(**inputs, max_new_tokens=30)
+processor.batch_decode(generate_ids, skip_special_tokens=True)
+```
+
+
+## Note on reproducing the original implementation
+
+To match the logits of the [original implementation](https://github.com/apple/ml-fastvlm), you need to use float32. In half precision, the logit differences are larger due to small differences in how some ops are implemented in timm.
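+
+For example, a minimal sketch (same checkpoint as above) of loading the model in full precision for logit comparisons:
+
+```python
+import torch
+from transformers import FastVlmForConditionalGeneration
+
+model = FastVlmForConditionalGeneration.from_pretrained(
+    "KamilaMila/FastVLM-0.5B", dtype=torch.float32, device_map="auto"
+)
+```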
+
+### Using Flash Attention 2
+
+Flash Attention 2 is an even faster, more optimized attention implementation than the default SDPA. For details, refer to the [Flash Attention 2 section of the performance docs](https://huggingface.co/docs/transformers/perf_infer_gpu_one), and remember to enable it only for the language decoder, as described above.
+
+## FastVlmConfig
+
+[[autodoc]] FastVlmConfig
+
+## FastVlmModel
+
+[[autodoc]] FastVlmModel
+
+## FastVlmForConditionalGeneration
+
+[[autodoc]] FastVlmForConditionalGeneration
+ - forward
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 75625aaff80f..d23425ad7c69 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -125,6 +125,7 @@
from .falcon import *
from .falcon_h1 import *
from .falcon_mamba import *
+ from .fast_vlm import *
from .fastspeech2_conformer import *
from .flaubert import *
from .flava import *
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index c55980e471c7..cb92aeb2ebfa 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -147,6 +147,7 @@
("falcon", "FalconConfig"),
("falcon_h1", "FalconH1Config"),
("falcon_mamba", "FalconMambaConfig"),
+ ("fast_vlm", "FastVlmConfig"),
("fastspeech2_conformer", "FastSpeech2ConformerConfig"),
("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGanConfig"),
("flaubert", "FlaubertConfig"),
@@ -580,6 +581,7 @@
("falcon3", "Falcon3"),
("falcon_h1", "FalconH1"),
("falcon_mamba", "FalconMamba"),
+ ("fast_vlm", "FastVlm"),
("fastspeech2_conformer", "FastSpeech2Conformer"),
("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
("flan-t5", "FLAN-T5"),
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index 22985f413341..3d2c49db6703 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -150,6 +150,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("falcon", "FalconModel"),
("falcon_h1", "FalconH1Model"),
("falcon_mamba", "FalconMambaModel"),
+ ("fast_vlm", "FastVlmModel"),
("fastspeech2_conformer", "FastSpeech2ConformerModel"),
("fastspeech2_conformer_with_hifigan", "FastSpeech2ConformerWithHifiGan"),
("flaubert", "FlaubertModel"),
@@ -986,6 +987,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("deepseek_vl_hybrid", "DeepseekVLHybridForConditionalGeneration"),
("emu3", "Emu3ForConditionalGeneration"),
("evolla", "EvollaForProteinText2Text"),
+ ("fast_vlm", "FastVlmForConditionalGeneration"),
("florence2", "Florence2ForConditionalGeneration"),
("fuyu", "FuyuForCausalLM"),
("gemma3", "Gemma3ForConditionalGeneration"),
diff --git a/src/transformers/models/fast_vlm/__init__.py b/src/transformers/models/fast_vlm/__init__.py
new file mode 100644
index 000000000000..949f087650dd
--- /dev/null
+++ b/src/transformers/models/fast_vlm/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_fast_vlm import *
+ from .modeling_fast_vlm import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/fast_vlm/configuration_fast_vlm.py b/src/transformers/models/fast_vlm/configuration_fast_vlm.py
new file mode 100644
index 000000000000..46e5a6ccbf76
--- /dev/null
+++ b/src/transformers/models/fast_vlm/configuration_fast_vlm.py
@@ -0,0 +1,137 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/fast_vlm/modular_fast_vlm.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_fast_vlm.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PreTrainedConfig
+from ..auto import CONFIG_MAPPING, AutoConfig
+
+
+class FastVlmConfig(PreTrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`FastVlmForConditionalGeneration`]. It is used to instantiate a
+ FastVLM model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield the same configuration as the one of FastVLM-7B.
+
+ e.g. [KamilaMila/FastVLM-7B](https://huggingface.co/KamilaMila/FastVLM-7B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `TimmWrapperConfig` for `fastvit_mci3`):
+ The config object or dictionary of the vision backbone.
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
+ The config object or dictionary of the text backbone.
+ image_token_id (`int`, *optional*, defaults to 151646):
+ The image token index to encode the image prompt.
+ projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The activation function used by the multimodal projector.
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Only "full" supported.
+ vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -1):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features. Only -1 supported.
+ multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use bias in the multimodal projector.
+
+ Example:
+
+ ```python
+ >>> from transformers import FastVlmForConditionalGeneration, FastVlmConfig
+
+ >>> # Initializing a FastVLM-7B style configuration
+ >>> configuration = FastVlmConfig()
+
+ >>> # Initializing a model from the FastVLM-7B style configuration
+ >>> model = FastVlmForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "fast_vlm"
+ attribute_map = {
+ "image_token_id": "image_token_index",
+ }
+ sub_configs = {"text_config": AutoConfig, "vision_config": AutoConfig}
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ image_token_id=151646,
+ projector_hidden_act="gelu",
+ vision_feature_select_strategy="full",
+ vision_feature_layer=-1,
+ multimodal_projector_bias=True,
+ **kwargs,
+ ):
+ self.image_token_id = image_token_id
+ self.projector_hidden_act = projector_hidden_act
+
+ if vision_feature_select_strategy != "full":
+ raise ValueError(
+ f"Unexpected select feature strategy: {vision_feature_select_strategy}. Only 'full' is supported in FastVLM."
+ )
+
+ if vision_feature_layer != -1:
+ raise ValueError(
+ f"Unexpected vision feature layer: {vision_feature_layer}. Only -1 is supported in FastVLM."
+ )
+
+ self.vision_feature_select_strategy = vision_feature_select_strategy
+ self.vision_feature_layer = vision_feature_layer
+
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = vision_config.get("model_type", "timm_wrapper")
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+ elif vision_config is None:
+ vision_config = CONFIG_MAPPING["timm_wrapper"](
+ architecture="fastvit_mci3",
+ do_pooling=True,
+ global_pool="avg",
+ hidden_size=3072,
+ initializer_range=0.02,
+ model_args={"inference_mode": True},
+ )
+
+ self.vision_config = vision_config
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "qwen2")
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+ elif text_config is None:
+ text_config = CONFIG_MAPPING["qwen2"](
+ hidden_size=3584,
+ vocab_size=152128,
+ intermediate_size=18944,
+ num_attention_heads=28,
+ num_key_value_heads=4,
+ num_hidden_layers=28,
+ )
+
+ self.text_config = text_config
+ self.multimodal_projector_bias = multimodal_projector_bias
+
+ super().__init__(**kwargs)
+
+
+__all__ = ["FastVlmConfig"]
diff --git a/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py
new file mode 100644
index 000000000000..70fbedb2e2d6
--- /dev/null
+++ b/src/transformers/models/fast_vlm/convert_fastvlm_weights_to_hf.py
@@ -0,0 +1,247 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import glob
+import os
+import re
+
+import requests
+import torch
+from huggingface_hub import snapshot_download
+from PIL import Image
+from safetensors import safe_open
+
+from transformers import (
+ AddedToken,
+ AutoConfig,
+ AutoTokenizer,
+ CLIPImageProcessor,
+ FastVlmConfig,
+ FastVlmForConditionalGeneration,
+ LlavaProcessor,
+)
+
+
+os.environ["TIMM_FUSED_ATTN"] = (
+ "0" # to avoid logits diverging, needed because the original implementation uses regular (not fused) atteniton
+)
+
+KEYS_TO_MODIFY_MAPPING = {
+ "model.vision_tower.vision_tower.model": "model.vision_tower.timm_model",
+ "patch_embed": "stem",
+ "layers": "language_model.layers",
+ "embed_tokens": "language_model.embed_tokens",
+ "layer_scale_1": "layer_scale_1.gamma",
+ "layer_scale_2": "layer_scale_2.gamma",
+ "mm_projector.0": "multi_modal_projector.linear_1",
+ "mm_projector.2": "multi_modal_projector.linear_2",
+ "conv_exp": "final_conv",
+ "se.reduce": "se.fc1",
+ "se.expand": "se.fc2",
+ "convffn": "mlp",
+ "lkb_reparam": "reparam_conv",
+}
+
+
+def map_to_stage(number):
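+    # Map the original checkpoint's flat block index ("network.N") to the corresponding timm stage index.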
+ number = int(number)
+ if number == 0:
+ return 0
+ if number in {1, 2}:
+ return 1
+ if number in {3, 4}:
+ return 2
+ if number in {5, 6, 7}:
+ return 3
+ if number in {8, 9, 10}:
+ return 4
+
+
+def load_original_state_dict(model_id):
+ directory_path = snapshot_download(repo_id=model_id, allow_patterns=["*.safetensors"])
+
+ original_state_dict = {}
+ for path in glob.glob(f"{directory_path}/*"):
+ if path.endswith(".safetensors"):
+ with safe_open(path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ original_state_dict[key] = f.get_tensor(key)
+
+ if "model.vision_tower.vision_tower.model.head.proj" in original_state_dict:
+ del original_state_dict["model.vision_tower.vision_tower.model.head.proj"]
+ return original_state_dict
+
+
+def convert_state_dict_to_hf(state_dict):
+ new_state_dict = {}
+
+ single_pattern = r"network\.(\d{1,2})"
+ double_pattern = r"network\.(\d{1,2})\.(\d{1,2})"
+ pos_embedding_pattern = r"stages\.(\d{1,2})\.reparam_conv"
+
+ for key, value in state_dict.items():
+ if key.endswith("layer_scale"):
+ key = key.replace("layer_scale", "layer_scale.gamma")
+ if key.startswith("model.norm"):
+ key = key.replace("model.norm", "model.language_model.norm")
+ if "token_mixer" not in key:
+ key = key.replace(".proj.", ".downsample.proj.")
+
+ matches = re.findall(double_pattern, key)
+ if len(matches) == 1:
+ match = matches[0]
+ key = key.replace(f"network.{match[0]}.{match[1]}", f"stages.{map_to_stage(match[0])}.blocks.{match[1]}")
+
+ matches = re.findall(single_pattern, key)
+ if len(matches) == 1:
+ match = matches[0]
+ key = key.replace(f"network.{match[0]}", f"stages.{map_to_stage(match[0])}")
+
+ matches = re.findall(pos_embedding_pattern, key)
+ if len(matches) == 1:
+ match = matches[0]
+ key = key.replace(f"stages.{match[0]}", f"stages.{match[0]}.pos_emb")
+
+ for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
+ if key_to_modify in key:
+ key = key.replace(key_to_modify, new_key)
+
+ new_state_dict[key] = value
+ return new_state_dict
+
+
+def convert_fastvlm_to_hf(text_model_id, vision_model_id, output_hub_path, old_state_dict_id):
+ torch.set_default_dtype(torch.bfloat16)
+
+ text_config = AutoConfig.from_pretrained(text_model_id)
+ vision_config = AutoConfig.from_pretrained(vision_model_id)
+ vision_config.model_args = {"inference_mode": True}
+ vision_config.hidden_size = vision_config.num_features
+ vision_config.label2id = {}
+ vision_config.id2label = {}
+ config = FastVlmConfig(
+ text_config=text_config,
+ vision_config=vision_config,
+ )
+ config.vision_feature_select_strategy = "full"
+ config.vision_feature_layer = -1
+ config.image_token_index = 151646
+ config.image_seq_length = 256
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ text_model_id,
+ chat_template="{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n'}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '' }}{% endfor %}{# Render all text next #}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ '\n' + content['text'] }}{% endfor %}{{'<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+ )
+
+    tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
+ image_processor = CLIPImageProcessor(
+ crop_size={"height": 1024, "width": 1024},
+ image_mean=[0.0, 0.0, 0.0],
+ image_std=[1.0, 1.0, 1.0],
+ size={"shortest_edge": 1024},
+ )
+
+ processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor)
+ processor.patch_size = 64 # effective patch size (2^6)
+
+ model = FastVlmForConditionalGeneration(config)
+
+ state_dict = load_original_state_dict(old_state_dict_id)
+ state_dict = convert_state_dict_to_hf(state_dict)
+ model.load_state_dict(state_dict, strict=True, assign=True)
+
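+    # Fit a Gaussian to the existing token embeddings; the newly added rows (image token + padding) are sampled from it below.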
+ pre_expansion_embeddings = model.language_model.embed_tokens.weight.data
+ mu = torch.mean(pre_expansion_embeddings, dim=0).float()
+ n = pre_expansion_embeddings.size()[0]
+ sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
+ dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)
+
+ # We add an image token so we resize the model and pad to 64 for performance reasons
+ pad_shape = 64
+ vocab_size = config.text_config.vocab_size
+ model.resize_token_embeddings(config.text_config.vocab_size + 1, pad_shape)
+ model.language_model.embed_tokens.weight.data[vocab_size:] = torch.stack(
+ tuple(dist.sample() for _ in range(model.language_model.embed_tokens.weight.data[vocab_size:].shape[0])),
+ dim=0,
+ )
+ model.lm_head.weight.data[vocab_size:] = torch.stack(
+ tuple(dist.sample() for _ in range(model.lm_head.weight.data[vocab_size:].shape[0])),
+ dim=0,
+ )
+
+ conversation = [{"role": "user", "content": [{"type": "text", "text": "What are these?"}, {"type": "image"}]}]
+ prompt = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+
+ image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
+ raw_image = Image.open(requests.get(image_file, stream=True).raw)
+ inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to("cuda")
+ inputs = {k: (v.to(torch.bfloat16) if v.dtype == torch.float32 else v) for k, v in inputs.items()}
+
+ model = model.cuda()
+ model.eval()
+ with torch.no_grad():
+ logits = model(**inputs).logits
+
+ # in order to get the same logits as in the Apple repo, we need to manually replace the original (Apple) LayerNorm2D with Timm's LayerNorm2D or vice versa
+ # otherwise numerical errors accumulate
+ if output_hub_path == "KamilaMila/FastVLM-0.5B":
+ expected_shape = torch.Size([1, 280, 152000])
+ expected_slice = torch.tensor([4.1250, 9.6875, 11.1875], device="cuda")
+ elif output_hub_path == "KamilaMila/FastVLM-1.5B":
+ expected_shape = torch.Size([1, 280, 152000])
+ expected_slice = torch.tensor([3.3750, 11.5000, 11.8125], device="cuda")
+ elif output_hub_path == "KamilaMila/FastVLM-7B":
+ expected_shape = torch.Size([1, 280, 152128])
+ expected_slice = torch.tensor([3.8281, 9.0625, 7.9062], device="cuda")
+
+ logits_slice = logits[0, -1, :3]
+ assert torch.allclose(expected_slice, logits_slice, atol=1e-8)
+ assert logits.shape == expected_shape
+
+ model.push_to_hub(output_hub_path)
+ processor.push_to_hub(output_hub_path)
+ print("Successfully pushed to hub!")
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--text_model_id",
+ default="Qwen/Qwen2-1.5B",
+ help="Hub location of the text model",
+ )
+ parser.add_argument(
+ "--vision_model_id",
+ default="timm/fastvit_mci3.apple_mclip2_dfndr2b",
+ help="Hub location of the vision model",
+ )
+ parser.add_argument(
+ "--output_hub_path",
+ default="KamilaMila/FastVLM-1.5B",
+ help="Location on the hub of the converted model",
+ )
+ parser.add_argument(
+ "--old_state_dict_id",
+ default="apple/FastVLM-1.5B",
+ help="Location on the hub of the raw state dict of the original model.",
+ )
+ args = parser.parse_args()
+ convert_fastvlm_to_hf(args.text_model_id, args.vision_model_id, args.output_hub_path, args.old_state_dict_id)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/transformers/models/fast_vlm/modeling_fast_vlm.py b/src/transformers/models/fast_vlm/modeling_fast_vlm.py
new file mode 100644
index 000000000000..82860e029f98
--- /dev/null
+++ b/src/transformers/models/fast_vlm/modeling_fast_vlm.py
@@ -0,0 +1,458 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/fast_vlm/modular_fast_vlm.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_fast_vlm.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ..auto import AutoModel
+from .configuration_fast_vlm import FastVlmConfig
+
+
+class FastVlmMultiModalProjector(nn.Module):
+ def __init__(self, config: FastVlmConfig):
+ super().__init__()
+ self.linear_1 = nn.Linear(
+ config.vision_config.hidden_size,
+ config.text_config.hidden_size,
+ bias=config.multimodal_projector_bias,
+ )
+ self.act = ACT2FN[config.projector_hidden_act]
+ self.linear_2 = nn.Linear(
+ config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
+ )
+
+ def forward(self, image_features):
+ hidden_states = self.linear_1(image_features)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ return hidden_states
+
+
+@auto_docstring
+class FastVlmPreTrainedModel(PreTrainedModel):
+ config: FastVlmConfig
+ base_model_prefix = "model"
+ input_modalities = ["image", "text"]
+ supports_gradient_checkpointing = True
+ _skip_keys_device_placement = "past_key_values"
+
+ _supports_flash_attn = True
+ _supports_sdpa = True
+
+ _can_compile_fullgraph = True
+ _supports_flex_attn = True
+ _supports_attention_backend = True
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for FastVlm outputs, with hidden states and attentions.
+ """
+)
+class FastVlmModelOutputWithPast(BaseModelOutputWithPast):
+ r"""
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@auto_docstring(
+ custom_intro="""
+ The FastVlm model which consists of a vision backbone and a language model, without a language modeling head.
+ """
+)
+class FastVlmModel(FastVlmPreTrainedModel):
+ _checkpoint_conversion_mapping = {}
+
+ def __init__(self, config: FastVlmConfig):
+ super().__init__(config)
+ self.vision_tower = AutoModel.from_config(config.vision_config)
+
+ self.multi_modal_projector = FastVlmMultiModalProjector(config)
+ self.language_model = AutoModel.from_config(config.text_config)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+        Obtains image last hidden states from the vision tower and applies multimodal projection.
+
+ Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+ vision_feature_layer (`Union[int, list[int]]`, *optional*):
+ The index/indices of the layer to select the vision feature. Only -1 supported.
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Only "full" supported.
+ Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
+ image_outputs = self.vision_tower(pixel_values, **kwargs)
+
+ # since the vision tower is hybrid in FastVLM, its output needs to be handled differently from Llava
+ selected_image_feature = image_outputs.last_hidden_state
+ selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1)
+ image_features = self.multi_modal_projector(selected_image_feature)
+ image_features = list(image_features)
+ return image_features
+
+ def get_placeholder_mask(
+ self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
+ ):
+ """
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
+ """
+ if input_ids is None:
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+ )
+ special_image_mask = special_image_mask.all(-1)
+ else:
+ special_image_mask = input_ids == self.config.image_token_id
+
+ n_image_tokens = special_image_mask.sum()
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ n_image_features = image_features.shape[0] * image_features.shape[1]
+ if inputs_embeds[special_image_mask].numel() != image_features.numel():
+ raise ValueError(
+ f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+ )
+ return special_image_mask
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ image_sizes: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, FastVlmModelOutputWithPast]:
+ r"""
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone. Only "full" supported.
+
+ vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*):
+ The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the
+ corresponding indices will be concatenated to form the vision features. Only -1 supported.
+ """
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ if pixel_values is not None:
+ image_features = self.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ image_sizes=image_sizes,
+ )
+ image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = self.get_placeholder_mask(
+ input_ids, inputs_embeds=inputs_embeds, image_features=image_features
+ )
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ return FastVlmModelOutputWithPast(
+ last_hidden_state=outputs.last_hidden_state,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=image_features if pixel_values is not None else None,
+ )
+
+
+@dataclass
+@auto_docstring(
+ custom_intro="""
+ Base class for FastVlm causal language model (or autoregressive) outputs.
+ """
+)
+class FastVlmCausalLMOutputWithPast(ModelOutput):
+ r"""
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ image_hidden_states (`torch.FloatTensor`, *optional*):
+ A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
+ image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
+ """
+
+ loss: Optional[torch.FloatTensor] = None
+ logits: Optional[torch.FloatTensor] = None
+ past_key_values: Optional[Cache] = None
+ hidden_states: Optional[tuple[torch.FloatTensor]] = None
+ attentions: Optional[tuple[torch.FloatTensor]] = None
+ image_hidden_states: Optional[torch.FloatTensor] = None
+
+
+@auto_docstring(
+ custom_intro="""
+ The FastVlm model which consists of a vision backbone and a language model.
+ """
+)
+class FastVlmForConditionalGeneration(FastVlmPreTrainedModel, GenerationMixin):
+ _checkpoint_conversion_mapping = {}
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
+
+ def __init__(self, config: FastVlmConfig):
+ super().__init__(config)
+ self.model = FastVlmModel(config)
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.model.set_input_embeddings(value)
+
+ def get_output_embeddings(self) -> nn.Module:
+ return self.lm_head
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ **kwargs,
+ ):
+ return self.model.get_image_features(
+ pixel_values=pixel_values,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ **kwargs,
+ )
+
+ @can_return_tuple
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ pixel_values: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ labels: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0,
+ image_sizes: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[TransformersKwargs],
+ ) -> Union[tuple, FastVlmCausalLMOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone. Only "full" supported.
+
+ vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*):
+ The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the
+ corresponding indices will be concatenated to form the vision features. Only -1 supported.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, AutoModelForImageTextToText
+ >>> import torch
+
+ >>> device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ >>> model = AutoModelForImageTextToText.from_pretrained("KamilaMila/FastVLM-0.5B").to(device)
+ >>> processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B")
+
+ >>> conversation = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "What are these?"},
+ {"type": "image"}
+ ]
+ }
+ ]
+
+ >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
+
+ >>> # Generate
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=15)
+ >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
+ system\n You are a helpful assistant.\n user\n What are these?\n assistant\n The image depicts a traditional Chinese street...
+ ```"""
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ outputs = self.model(
+ input_ids=input_ids,
+ pixel_values=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ vision_feature_layer=vision_feature_layer,
+ vision_feature_select_strategy=vision_feature_select_strategy,
+ cache_position=cache_position,
+ image_sizes=image_sizes,
+ **kwargs,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(
+ logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
+ )
+
+ return FastVlmCausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ image_hidden_states=outputs.image_hidden_states,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ inputs_embeds=None,
+ pixel_values=None,
+ attention_mask=None,
+ cache_position=None,
+ logits_to_keep=None,
+ **kwargs,
+ ):
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
+
+ model_inputs = super().prepare_inputs_for_generation(
+ input_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ attention_mask=attention_mask,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs,
+ )
+
+ if cache_position[0] == 0:
+ # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
+ # Otherwise we need pixel values to be passed to model
+ model_inputs["pixel_values"] = pixel_values
+
+ return model_inputs
+
+
+__all__ = ["FastVlmForConditionalGeneration", "FastVlmModel", "FastVlmPreTrainedModel"]
diff --git a/src/transformers/models/fast_vlm/modular_fast_vlm.py b/src/transformers/models/fast_vlm/modular_fast_vlm.py
new file mode 100644
index 000000000000..e5d8c0908307
--- /dev/null
+++ b/src/transformers/models/fast_vlm/modular_fast_vlm.py
@@ -0,0 +1,276 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import torch
+from torch import nn
+
+from ...activations import ACT2FN
+from ...configuration_utils import PreTrainedConfig
+from ...utils import auto_docstring
+from ..auto import CONFIG_MAPPING
+from ..llava.configuration_llava import LlavaConfig
+from ..llava.modeling_llava import (
+ LlavaForConditionalGeneration,
+ LlavaModel,
+ LlavaMultiModalProjector,
+ LlavaPreTrainedModel,
+)
+
+
+class FastVlmConfig(LlavaConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`FastVlmForConditionalGeneration`]. It is used to instantiate a
+ FastVLM model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield the same configuration as the one of FastVLM-7B.
+
+ e.g. [KamilaMila/FastVLM-7B](https://huggingface.co/KamilaMila/FastVLM-7B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+ Args:
+ vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `TimmWrapperConfig` for `fastvit_mci3`):
+ The config object or dictionary of the vision backbone.
+ text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
+ The config object or dictionary of the text backbone.
+ image_token_id (`int`, *optional*, defaults to 151646):
+ The image token index to encode the image prompt.
+ projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The activation function used by the multimodal projector.
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Only "full" supported.
+ vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -1):
+ The index of the layer to select the vision feature. If multiple indices are provided,
+ the vision feature of the corresponding indices will be concatenated to form the
+ vision features. Only -1 supported.
+ multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
+ Whether to use bias in the multimodal projector.
+
+ Example:
+
+ ```python
+ >>> from transformers import FastVlmForConditionalGeneration, FastVlmConfig
+
+ >>> # Initializing a FastVLM-7B style configuration
+ >>> configuration = FastVlmConfig()
+
+ >>> # Initializing a model from the FastVLM-7B style configuration
+ >>> model = FastVlmForConditionalGeneration(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "fast_vlm"
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ image_token_id=151646,
+ projector_hidden_act="gelu",
+ vision_feature_select_strategy="full",
+ vision_feature_layer=-1,
+ multimodal_projector_bias=True,
+ **kwargs,
+ ):
+ self.image_token_id = image_token_id
+ self.projector_hidden_act = projector_hidden_act
+
+ if vision_feature_select_strategy != "full":
+ raise ValueError(
+ f"Unexpected select feature strategy: {vision_feature_select_strategy}. Only 'full' is supported in FastVLM."
+ )
+
+ if vision_feature_layer != -1:
+ raise ValueError(
+ f"Unexpected vision feature layer: {vision_feature_layer}. Only -1 is supported in FastVLM."
+ )
+
+ self.vision_feature_select_strategy = vision_feature_select_strategy
+ self.vision_feature_layer = vision_feature_layer
+
+ if isinstance(vision_config, dict):
+ vision_config["model_type"] = vision_config.get("model_type", "timm_wrapper")
+ vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+ elif vision_config is None:
+ vision_config = CONFIG_MAPPING["timm_wrapper"](
+ architecture="fastvit_mci3",
+ do_pooling=True,
+ global_pool="avg",
+ hidden_size=3072,
+ initializer_range=0.02,
+ model_args={"inference_mode": True},
+ )
+
+ self.vision_config = vision_config
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = text_config.get("model_type", "qwen2")
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+ elif text_config is None:
+ text_config = CONFIG_MAPPING["qwen2"](
+ hidden_size=3584,
+ vocab_size=152128,
+ intermediate_size=18944,
+ num_attention_heads=28,
+ num_key_value_heads=4,
+ num_hidden_layers=28,
+ )
+
+ self.text_config = text_config
+ self.multimodal_projector_bias = multimodal_projector_bias
+
+        PreTrainedConfig.__init__(self, **kwargs)
+
+
+class FastVlmMultiModalProjector(LlavaMultiModalProjector):
+ def __init__(self, config: FastVlmConfig):
+        nn.Module.__init__(self)
+ self.linear_1 = nn.Linear(
+ config.vision_config.hidden_size,
+ config.text_config.hidden_size,
+ bias=config.multimodal_projector_bias,
+ )
+ self.act = ACT2FN[config.projector_hidden_act]
+ self.linear_2 = nn.Linear(
+ config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
+ )
+
+
+class FastVlmPreTrainedModel(LlavaPreTrainedModel):
+ pass
+
+
+class FastVlmModel(LlavaModel):
+ _checkpoint_conversion_mapping = {}
+
+ def __init__(self, config: FastVlmConfig):
+ super().__init__(config)
+
+ def get_image_features(
+ self,
+ pixel_values: torch.FloatTensor,
+ vision_feature_layer: Optional[Union[int, list[int]]] = None,
+ vision_feature_select_strategy: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+        Obtains image last hidden states from the vision tower and applies multimodal projection.
+
+ Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+ The tensors corresponding to the input images.
+ vision_feature_layer (`Union[int, list[int]]`, *optional*):
+ The index/indices of the layer to select the vision feature. Only -1 supported.
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone.
+ Only "full" supported.
+ Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+ """
+ vision_feature_layer = (
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+ )
+ vision_feature_select_strategy = (
+ vision_feature_select_strategy
+ if vision_feature_select_strategy is not None
+ else self.config.vision_feature_select_strategy
+ )
+
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
+ image_outputs = self.vision_tower(pixel_values, **kwargs)
+
+ # since the vision tower is hybrid in FastVLM, its output needs to be handled differently from Llava
+ selected_image_feature = image_outputs.last_hidden_state
+ selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1)
+ image_features = self.multi_modal_projector(selected_image_feature)
+ image_features = list(image_features)
+ return image_features
+
+ def forward(self, **super_kwargs):
+ r"""
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone. Only "full" supported.
+
+ vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*):
+ The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the
+ corresponding indices will be concatenated to form the vision features. Only -1 supported.
+ """
+ super().forward(**super_kwargs)
+
+
+@auto_docstring(
+ custom_intro="""
+ The FastVlm model which consists of a vision backbone and a language model.
+ """
+)
+class FastVlmForConditionalGeneration(LlavaForConditionalGeneration):
+ _checkpoint_conversion_mapping = {}
+
+ def forward(self, **super_kwargs):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ vision_feature_select_strategy (`str`, *optional*):
+ The feature selection strategy used to select the vision feature from the vision backbone. Only "full" supported.
+
+ vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*):
+ The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the
+ corresponding indices will be concatenated to form the vision features. Only -1 supported.
+
+ Example:
+
+ ```python
+ >>> from PIL import Image
+ >>> import requests
+ >>> from transformers import AutoProcessor, AutoModelForImageTextToText
+ >>> import torch
+
+ >>> device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ >>> model = AutoModelForImageTextToText.from_pretrained("KamilaMila/FastVLM-0.5B").to(device)
+ >>> processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B")
+
+ >>> conversation = [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": "What are these?"},
+ {"type": "image"}
+ ]
+ }
+ ]
+
+ >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+ >>> image = Image.open(requests.get(url, stream=True).raw)
+
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
+
+ >>> # Generate
+ >>> generated_ids = model.generate(**inputs, max_new_tokens=15)
+ >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
+ system\n You are a helpful assistant.\n user\n What are these?\n assistant\n The image depicts a traditional Chinese street...
+ ```"""
+ super().forward(**super_kwargs)
+
+
+__all__ = ["FastVlmForConditionalGeneration", "FastVlmModel", "FastVlmPreTrainedModel", "FastVlmConfig"]
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index d1f14ad0c9c8..82d198f8ddbb 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1894,6 +1894,13 @@ def test_flash_attention_2_continue_generate_with_position_ids(self):
config.max_position_embeddings = max_new_tokens + dummy_input_ids.shape[1] + 1
model = model_class(config)
+ if not all(
+ getattr(submodel, "_supports_flash_attn")
+ for submodel in model.modules()
+ if isinstance(submodel, PreTrainedModel)
+ ):
+ self.skipTest(f"At least some parts of {model_class.__name__} don't support flash attention")
+
if "position_ids" not in inspect.signature(model.forward).parameters:
self.skipTest("Model does not support position_ids")
@@ -1994,6 +2001,14 @@ def attention_mask_padding_matches_padding_free_with_position_ids(
config.max_position_embeddings = max_new_tokens + dummy_input_ids.shape[1] + 1
model = model_class(config)
+ if attn_implementation != "eager":
+ if not all(
+ getattr(submodel, support_flag[attn_implementation])
+ for submodel in model.modules()
+ if isinstance(submodel, PreTrainedModel)
+ ):
+ self.skipTest(f"At least some parts of {model_class.__name__} don't support {attn_implementation}")
+
if "position_ids" not in inspect.signature(model.forward).parameters:
self.skipTest("Model does not support position_ids")
diff --git a/tests/models/fast_vlm/__init__.py b/tests/models/fast_vlm/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/models/fast_vlm/test_modeling_fast_vlm.py b/tests/models/fast_vlm/test_modeling_fast_vlm.py
new file mode 100644
index 000000000000..19971dfa1649
--- /dev/null
+++ b/tests/models/fast_vlm/test_modeling_fast_vlm.py
@@ -0,0 +1,302 @@
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the FastVLM model."""
+
+import copy
+import unittest
+
+import requests
+
+from transformers import (
+ AutoProcessor,
+ FastVlmConfig,
+ FastVlmForConditionalGeneration,
+ FastVlmModel,
+ is_torch_available,
+ is_vision_available,
+)
+from transformers.testing_utils import (
+ cleanup,
+ require_torch,
+ require_vision,
+ slow,
+ torch_device,
+)
+
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
+
+
+if is_torch_available():
+ import torch
+
+
+if is_vision_available():
+ from PIL import Image
+
+
+class FastVlmVisionText2TextModelTester:
+ def __init__(
+ self,
+ parent,
+ ignore_index=-100,
+ image_token_id=0,
+ projector_hidden_act="gelu",
+ seq_length=7,
+ vision_feature_select_strategy="full",
+ vision_feature_layer=-1,
+ text_config={
+ "model_type": "qwen2",
+ "is_training": True,
+ "vocab_size": 99,
+ "hidden_size": 32,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 4,
+ "num_key_value_heads": 4,
+ "intermediate_size": 37,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "attention_probs_dropout_prob": 0.1,
+ "max_position_embeddings": 512,
+ "initializer_range": 0.02,
+ "pad_token_id": 1,
+ },
+ is_training=True,
+ vision_config={
+ "image_size": 16,
+ "patch_size": 8,
+ "num_channels": 3,
+ "hidden_size": 32,
+ "initializer_range": 0.02,
+ "architecture": "fastvit_mci3",
+ "do_pooling": True,
+ "global_pool": "avg",
+ "model_args": {
+ "inference_mode": True,
+ "layers": (2, 2),
+ "embed_dims": (8, 16),
+ "mlp_ratios": (4, 4),
+ "se_downsamples": (False, False),
+ "downsamples": (False, True),
+ "pos_embs": (None, None),
+ "token_mixers": ("repmixer", "repmixer"),
+ "lkc_use_act": True,
+ "stem_use_scale_branch": False,
+ },
+ },
+ ):
+ self.parent = parent
+ self.ignore_index = ignore_index
+ self.image_token_id = image_token_id
+ self.projector_hidden_act = projector_hidden_act
+ self.vision_feature_select_strategy = vision_feature_select_strategy
+ self.vision_feature_layer = vision_feature_layer
+ self.text_config = text_config
+ self.vision_config = vision_config
+ self.pad_token_id = text_config["pad_token_id"]
+
+ self.num_hidden_layers = text_config["num_hidden_layers"]
+ self.vocab_size = text_config["vocab_size"]
+ self.hidden_size = text_config["hidden_size"]
+ self.num_attention_heads = text_config["num_attention_heads"]
+ self.is_training = is_training
+
+ self.batch_size = 3
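+ # each image contributes (image_size // patch_size) ** 2 placeholder tokens to the input sequence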
+ self.num_image_tokens = (self.vision_config["image_size"] // self.vision_config["patch_size"]) ** 2
+ self.seq_length = seq_length + self.num_image_tokens
+
+ def get_config(self):
+ return FastVlmConfig(
+ text_config=self.text_config,
+ vision_config=self.vision_config,
+ ignore_index=self.ignore_index,
+ image_token_id=self.image_token_id,
+ projector_hidden_act=self.projector_hidden_act,
+ vision_feature_select_strategy=self.vision_feature_select_strategy,
+ vision_feature_layer=self.vision_feature_layer,
+ )
+
+ def prepare_config_and_inputs(self):
+ pixel_values = floats_tensor(
+ [
+ self.batch_size,
+ self.vision_config["num_channels"],
+ self.vision_config["image_size"],
+ self.vision_config["image_size"],
+ ]
+ )
+ config = self.get_config()
+
+ return config, pixel_values
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ config, pixel_values = config_and_inputs
+ input_ids = ids_tensor([self.batch_size, self.seq_length], config.text_config.vocab_size - 1) + 1
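+ # replace any accidental image-token ids with padding, then put the real image placeholders at the start of each sequence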
+ input_ids[input_ids == config.image_token_index] = self.pad_token_id
+ input_ids[:, : self.num_image_tokens] = config.image_token_index
+ attention_mask = input_ids.ne(1).to(torch_device)
+
+ inputs_dict = {
+ "pixel_values": pixel_values,
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ }
+ return config, inputs_dict
+
+
+@require_torch
+class FastVlmForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
+ """
+ Model tester for `FastVlmModel` and `FastVlmForConditionalGeneration`.
+ """
+
+ all_model_classes = (
+ (
+ FastVlmModel,
+ FastVlmForConditionalGeneration,
+ )
+ if is_torch_available()
+ else ()
+ )
+ pipeline_model_mapping = (
+ {"image-to-text": FastVlmForConditionalGeneration, "image-text-to-text": FastVlmForConditionalGeneration}
+ if is_torch_available()
+ else {}
+ )
+
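+ # FastVLM is a composite model: a Timm-based vision backbone plus a language model, each with its own sub-config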
+ _is_composite = True
+
+ def setUp(self):
+ self.model_tester = FastVlmVisionText2TextModelTester(self)
+ common_properties = ["image_token_id"]
+ self.config_tester = ConfigTester(
+ self, config_class=FastVlmConfig, has_text_modality=False, common_properties=common_properties
+ )
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_mismatching_num_image_tokens(self):
+ """
+ Tests that an explicit error is raised when the number of images
+ doesn't match the number of image placeholder tokens in the text.
+ We also test the multi-image case, where one prompt contains several sets of image tokens.
+ """
+ config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ for model_class in self.all_model_classes:
+ model = model_class(config).to(torch_device)
+ model.eval()
+ curr_input_dict = copy.deepcopy(input_dict) # the dict is modified in place below
+ _ = model(**curr_input_dict) # successful forward with no modifications
+
+ # remove one image but leave all the image tokens in text
+ curr_input_dict["pixel_values"] = curr_input_dict["pixel_values"][-2:, ...]
+ with self.assertRaises(ValueError):
+ _ = model(**curr_input_dict)
+
+ # simulate the multi-image / single set of placeholders case by duplicating the image
+ input_ids = curr_input_dict["input_ids"][:1]
+ pixel_values = curr_input_dict["pixel_values"][:1]
+ pixel_values = torch.cat([pixel_values, pixel_values], dim=0)
+
+ # two images and one set of image tokens raise an error
+ with self.assertRaises(ValueError):
+ _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+ # two images and two sets of image tokens don't raise an error
+ input_ids = torch.cat([input_ids, input_ids], dim=0)
+ _ = model(input_ids=input_ids, pixel_values=pixel_values)
+
+ @unittest.skip("Timm wrapper and backbone don't currently support full HF initialization")
+ def test_can_init_all_missing_weights(self):
+ pass
+
+ @unittest.skip("Timm can't be initialized on meta")
+ def test_can_be_initialized_on_meta(self):
+ pass
+
+
+@require_torch
+@slow
+class FastVlmForConditionalGenerationIntegrationTest(unittest.TestCase):
+ def setUp(self):
+ self.processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B")
+
+ def tearDown(self):
+ cleanup(torch_device, gc_collect=True)
+
+ @require_vision
+ def test_small_model_integration_test(self):
+ model = FastVlmForConditionalGeneration.from_pretrained(
+ "KamilaMila/FastVLM-0.5B", device_map=torch_device, dtype=torch.bfloat16
+ )
+
+ prompt = "user\n\nWhat are the things I should be cautious about when I visit this place?\nassistant"
+ image_file = "https://llava-vl.github.io/static/images/view.jpg"
+ raw_image = Image.open(requests.get(image_file, stream=True).raw)
+ inputs = self.processor(images=raw_image, text=prompt, return_tensors="pt").to(torch_device)
+
+ output = model.generate(**inputs, max_new_tokens=20)
+ expected_decoded_texts = "user\n\nWhat are the things I should be cautious about when I visit this place?\nassistant\n\nWhen visiting this place, there are a few things you should be cautious about:\n\n1. **" # fmt: skip
+
+ EXPECTED_DECODED_TEXT = expected_decoded_texts
+
+ self.assertEqual(
+ self.processor.decode(output[0], skip_special_tokens=True),
+ EXPECTED_DECODED_TEXT,
+ )
+
+ @require_vision
+ def test_small_model_integration_test_batch(self):
+ model = FastVlmForConditionalGeneration.from_pretrained(
+ "KamilaMila/FastVLM-0.5B", device_map=torch_device, dtype=torch.bfloat16
+ )
+
+ prompts = [
+ "user\n\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nassistant",
+ "user\n\nWhat is this?\nassistant",
+ ]
+ image1 = Image.open(requests.get("https://llava-vl.github.io/static/images/view.jpg", stream=True).raw)
+ image2 = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
+
+ inputs = self.processor(images=[image1, image2], text=prompts, return_tensors="pt", padding=True).to(
+ torch_device
+ )
+
+ output = model.generate(**inputs, max_new_tokens=20)
+
+ EXPECTED_DECODED_TEXT = [
+ "user\n\nWhat are the things I should be cautious about when I visit this place? What should I bring with me?\nassistant\n\nWhen visiting this serene place, it's essential to be mindful of the following:\n\n1. **",
+ "user\n\nWhat is this?\nassistant\nThe image depicts two cats lying on a pink surface, which could be a couch or a"
+ ] # fmt: skip
+
+ self.assertEqual(
+ self.processor.batch_decode(output, skip_special_tokens=True),
+ EXPECTED_DECODED_TEXT,
+ )
+
+ def test_generation_no_images(self):
+ model_id = "KamilaMila/FastVLM-0.5B"
+ model = FastVlmForConditionalGeneration.from_pretrained(
+ model_id, device_map=torch_device, dtype=torch.bfloat16
+ )
+ processor = AutoProcessor.from_pretrained(model_id)
+
+ # Prepare inputs with no images
+ inputs = processor(text="Hello, I am", return_tensors="pt").to(torch_device)
+
+ # Make sure that `generate` works
+ _ = model.generate(**inputs, max_new_tokens=20)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 297c8cfc2192..11949f740336 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -1309,6 +1309,10 @@ def test_attention_outputs(self):
del inputs_dict["output_attentions"]
config.output_attentions = True
for k in config.sub_configs:
+ if (
+ self._is_composite and k == "vision_config"
+ ): # skip: output_attentions isn't needed on the vision sub-config and setting it causes errors, e.g. with Timm
+ continue
if getattr(config, k) is not None:
getattr(config, k).output_attentions = True
@@ -1468,6 +1472,10 @@ def test_retain_grad_hidden_states_attentions(self):
config.output_attentions = self.has_attentions
for k in config.sub_configs:
+ if (
+ self._is_composite and k == "vision_config"
+ ): # skip: output_attentions isn't needed on the vision sub-config and setting it causes errors, e.g. with Timm
+ continue
if getattr(config, k) is not None:
getattr(config, k).output_attentions = self.has_attentions
@@ -3309,6 +3317,11 @@ def test_flash_attn_2_fp32_ln(self):
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
+ if not all(
+ submodel._supports_flash_attn for submodel in model.modules() if isinstance(submodel, PreTrainedModel)
+ ):
+ self.skipTest(reason="At least some parts of this model do not support flash attention")
+
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
@@ -3403,6 +3416,12 @@ def flash_attn_from_config(self, attn_implementation: str, test_fwd_in_train: bo
self.skipTest(f"{model_class.__name__} does not support {attn_implementation}")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config) # construct the model here to check whether any submodule lacks support for the requested attention implementation
+ if not all(
+ submodel._supports_flash_attn for submodel in model.modules() if isinstance(submodel, PreTrainedModel)
+ ):
+ self.skipTest(reason=f"At least some parts of this model do not support {attn_implementation}")
+
# TODO: to change it in the future with other relevant auto classes
fa_model = model_class._from_config(
config, attn_implementation=attn_implementation, dtype=torch.bfloat16