Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
c399691
feat: Initial implementation of GraniteDoclingHybrid
gabe-l-hart Mar 4, 2026
20a6df0
feat: Register granite_docling_hybrid in all auto registries
gabe-l-hart Mar 4, 2026
2c6094b
fix: Add missing config forwarding for underlying text model
gabe-l-hart Mar 4, 2026
11d8217
fix: Fix how image_hidden_states are extracted from get_image_feature…
gabe-l-hart Mar 4, 2026
c471581
fix: Update processing kwargs to not have return_row_col_info
gabe-l-hart Mar 4, 2026
f4b8f26
chore: regen modeling
gabe-l-hart Mar 4, 2026
f166660
style: Linting fixes
gabe-l-hart Mar 4, 2026
43b3c7b
fix: No inline imports
gabe-l-hart Mar 5, 2026
9629b3c
fix: Remove unnecessary entry in MODEL_FOR_PRETRAINING_MAPPING_NAMES
gabe-l-hart Mar 5, 2026
4c596a6
style: Fix copyright headers year
gabe-l-hart Mar 5, 2026
d4d327a
fix: Remove dead/unused code/docstrings from review
gabe-l-hart Mar 5, 2026
00ce088
feat: Remove return_dict plumbing
gabe-l-hart Mar 5, 2026
1848d83
feat: Consolidate config + processing into modular
gabe-l-hart Mar 5, 2026
4b73a3e
chore: regen from modular after consolidation
gabe-l-hart Mar 5, 2026
1e5f7bd
fix: Remove cache initialization logic in forward
gabe-l-hart Mar 5, 2026
1945128
fix: Remove unnecessary class-attrs for GraniteDoclingHybridProcessor
gabe-l-hart Mar 5, 2026
c99c53c
fix: Use auto_docstring for GraniteDoclingHybridProcessor
gabe-l-hart Mar 5, 2026
5122f17
feat: Use modern image fetching
gabe-l-hart Mar 9, 2026
2d72415
chore: redo codegen
gabe-l-hart Mar 9, 2026
b6d26d6
fix: Remove image_seq_len as a kwarg
gabe-l-hart Mar 9, 2026
1ed97ab
chore: Redo codegen
gabe-l-hart Mar 9, 2026
bee7b27
fix: Correctly update for no more cache_position
gabe-l-hart Mar 11, 2026
f82599a
chore: regen from modular
gabe-l-hart Mar 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@
("gpt_oss", "GptOssConfig"),
("gptj", "GPTJConfig"),
("granite", "GraniteConfig"),
("granite_docling_hybrid", "GraniteDoclingHybridConfig"),
("granite_speech", "GraniteSpeechConfig"),
("granitemoe", "GraniteMoeConfig"),
("granitemoehybrid", "GraniteMoeHybridConfig"),
Expand Down Expand Up @@ -702,6 +703,7 @@
("gpt_oss", "GptOss"),
("gptj", "GPT-J"),
("granite", "Granite"),
("granite_docling_hybrid", "GraniteDoclingHybrid"),
("granite_speech", "GraniteSpeech"),
("granitemoe", "GraniteMoeMoe"),
("granitemoehybrid", "GraniteMoeHybrid"),
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("gpt_oss", "GptOssModel"),
("gptj", "GPTJModel"),
("granite", "GraniteModel"),
("granite_docling_hybrid", "GraniteDoclingHybridModel"),
("granitemoe", "GraniteMoeModel"),
("granitemoehybrid", "GraniteMoeHybridModel"),
("granitemoeshared", "GraniteMoeSharedModel"),
Expand Down Expand Up @@ -960,6 +961,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin):
("glm4v_moe", "Glm4vMoeForConditionalGeneration"),
("glm_ocr", "GlmOcrForConditionalGeneration"),
("got_ocr2", "GotOcr2ForConditionalGeneration"),
("granite_docling_hybrid", "GraniteDoclingHybridForConditionalGeneration"),
("idefics", "IdeficsForVisionText2Text"),
("idefics2", "Idefics2ForConditionalGeneration"),
("idefics3", "Idefics3ForConditionalGeneration"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
("glm_image", "Glm4vProcessor"),
("glmasr", "GlmAsrProcessor"),
("got_ocr2", "GotOcr2Processor"),
("granite_docling_hybrid", "GraniteDoclingHybridProcessor"),
("granite_speech", "GraniteSpeechProcessor"),
("grounding-dino", "GroundingDinoProcessor"),
("groupvit", "CLIPProcessor"),
Expand Down
28 changes: 28 additions & 0 deletions src/transformers/models/granite_docling_hybrid/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GraniteDoclingHybrid model package.

Lazily exposes the configuration, modeling, and processing submodules so that
importing ``transformers`` does not pay the cost of loading this model's code.
"""
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    # Static type checkers see the real submodules and their public names.
    from .configuration_granite_docling_hybrid import *
    from .modeling_granite_docling_hybrid import *
    from .processing_granite_docling_hybrid import *
else:
    import sys

    # At runtime, replace this module with a lazy proxy that imports the
    # submodules on first attribute access.
    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/granite_docling_hybrid/modular_granite_docling_hybrid.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_granite_docling_hybrid.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2026 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import PreTrainedConfig, PretrainedConfig
from ...modeling_rope_utils import RopeParameters
from ...utils import auto_docstring, logging
from ..auto import CONFIG_MAPPING


# Module-level logger, following the transformers logging convention.
logger = logging.get_logger(__name__)


# NOTE(review): the checkpoint below looks inherited from the Idefics3/SmolVLM lineage —
# confirm it is the intended reference checkpoint for this model.
@auto_docstring(checkpoint="HuggingFaceM4/GraniteDoclingHybrid-8B-Llama3")
class GraniteDoclingHybridVisionConfig(PreTrainedConfig):
    r"""
    Example:

    ```python
    >>> from transformers.models.granite_docling_hybrid.modeling_granite_docling_hybrid import GraniteDoclingHybridVisionTransformer
    >>> from transformers.models.granite_docling_hybrid.configuration_granite_docling_hybrid import GraniteDoclingHybridVisionConfig

    >>> # Build a vision config with google/siglip-base-patch16-224 style defaults
    >>> configuration = GraniteDoclingHybridVisionConfig()

    >>> # Instantiate a randomly initialized vision tower from that configuration
    >>> model = GraniteDoclingHybridVisionTransformer(configuration)

    >>> # Read the configuration back from the model
    >>> configuration = model.config
    ```"""

    model_type = "granite_docling_hybrid_vision"
    base_config_key = "vision_config"

    def __init__(
        self,
        hidden_size=1152,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=16,
        num_channels=3,
        image_size=224,
        patch_size=32,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Transformer dimensions.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Image/patch geometry.
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size
        # Activation, normalization, and initialization settings.
        self.hidden_act = hidden_act
        self.layer_norm_eps = layer_norm_eps
        self.attention_dropout = attention_dropout
        self.initializer_range = initializer_range


# NOTE(review): the checkpoint below looks copied from granite_speech — confirm the
# intended reference checkpoint for this hybrid text config.
@auto_docstring(checkpoint="ibm-granite/granite-speech-3.2-8b")
class GraniteDoclingHybridGraniteMoeHybridConfig(PreTrainedConfig):
    r"""
    embedding_multiplier (`float`, *optional*, defaults to 1.0): embedding multiplier.
    logits_scaling (`float`, *optional*, defaults to 1.0): divisor for output logits.
    residual_multiplier (`float`, *optional*, defaults to 1.0): residual multiplier.
    attention_multiplier (`float`, *optional*, defaults to 1.0): attention multiplier.
    position_embedding_type (`str`, *optional*): Positional embedding type to be used; defaults to None. Allowed options: `[None, "rope"]`
    shared_intermediate_size (`int`, *optional*, defaults to 1024): intermediate size for shared experts.

    Example:

    ```python
    >>> from transformers import GraniteDoclingHybridGraniteMoeHybridModel, GraniteDoclingHybridGraniteMoeHybridConfig

    >>> # Initializing a GraniteDoclingHybridGraniteMoeHybrid config
    >>> configuration = GraniteDoclingHybridGraniteMoeHybridConfig()

    >>> # Initializing a model (with random weights) from the config
    >>> model = GraniteDoclingHybridGraniteMoeHybridModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "granite_docling_hybrid_granite_moe_hybrid"
    attribute_map = {
        "layers_block_type": "layer_types",
    }
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size: int | None = 32000,
        hidden_size: int | None = 4096,
        intermediate_size: int | None = 11008,
        num_hidden_layers: int | None = 32,
        num_attention_heads: int | None = 32,
        num_key_value_heads: int | None = None,
        hidden_act: str | None = "silu",
        max_position_embeddings: int | None = 2048,
        initializer_range: float | None = 0.02,
        rms_norm_eps: float | None = 1e-6,
        use_cache: bool | None = True,
        pad_token_id: int | None = None,
        bos_token_id: int | None = 1,
        eos_token_id: int | None = 2,
        tie_word_embeddings: bool | None = False,
        rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
        attention_bias: bool | None = False,
        attention_dropout: float | None = 0.0,
        embedding_multiplier: float | None = 1.0,
        logits_scaling: float | None = 1.0,
        residual_multiplier: float | None = 1.0,
        attention_multiplier: float | None = 1.0,
        num_local_experts: int | None = 8,
        num_experts_per_tok: int | None = 2,
        output_router_logits: bool | None = False,
        router_aux_loss_coef: float | None = 0.001,
        shared_intermediate_size: int | None = 1024,
        position_embedding_type: str | None = None,
        layer_types: list[str] | None = None,
        mamba_n_heads: int | None = 128,
        mamba_n_groups: int | None = 1,
        mamba_d_state: int | None = 256,
        mamba_d_head: int | str | None = "auto",
        mamba_d_conv: int | None = 4,
        mamba_expand: int | None = 2,
        mamba_chunk_size: int | None = 256,
        mamba_conv_bias: bool | None = True,
        mamba_proj_bias: bool | None = False,
        time_step_min: float | None = 0.001,
        time_step_max: float | None = 0.1,
        time_step_limit: tuple[float, float] | None = (0.0, float("inf")),
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.attention_bias = attention_bias
        self.embedding_multiplier = embedding_multiplier
        self.logits_scaling = logits_scaling
        self.residual_multiplier = residual_multiplier
        self.attention_multiplier = attention_multiplier
        self.attention_dropout = attention_dropout
        self.num_local_experts = num_local_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.shared_intermediate_size = shared_intermediate_size
        self.position_embedding_type = position_embedding_type
        self.rope_parameters = rope_parameters

        # Inner width of the mamba mixer; all head-dimension checks derive from it.
        mamba_intermediate = mamba_expand * hidden_size

        if layer_types is not None and any(layer_type not in ["mamba", "attention"] for layer_type in layer_types):
            raise ValueError("layer_types must be a list of strings in [`mamba`, `attention`]")

        if mamba_intermediate % mamba_n_heads != 0:
            raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size")

        # for the mamba_v2, must satisfy the following
        if mamba_d_head == "auto":
            mamba_d_head = mamba_intermediate // mamba_n_heads

        if mamba_d_head * mamba_n_heads != mamba_intermediate:
            raise ValueError("The dimensions for the Mamba head state do not match the model intermediate_size")

        self.mamba_n_heads = mamba_n_heads
        self.mamba_d_head = mamba_d_head
        self.mamba_n_groups = mamba_n_groups
        self.mamba_d_state = mamba_d_state
        self.mamba_d_conv = mamba_d_conv
        self.mamba_chunk_size = mamba_chunk_size
        self.mamba_conv_bias = mamba_conv_bias
        self.mamba_proj_bias = mamba_proj_bias
        self.time_step_min = time_step_min
        self.time_step_max = time_step_max
        # Normalize to a tuple so the stored value is hashable/immutable regardless of input form.
        self.time_step_limit = tuple(time_step_limit) if time_step_limit is not None else None
        self.mamba_expand = mamba_expand
        self.layer_types = layer_types

        self.tie_word_embeddings = tie_word_embeddings
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        super().__init__(**kwargs)

    # overwrite the function to use in `HybridMambaAttentionDynamicCache`
    @property
    def layers_block_type(self):
        # When layer_types is unset, every layer is treated as a mamba layer.
        return self.layer_types if self.layer_types else ["mamba"] * self.num_hidden_layers


# Subclass `PreTrainedConfig` for consistency with the other config classes in this
# file (`PretrainedConfig` is the legacy alias of the same class).
class GraniteDoclingHybridConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GraniteDoclingHybridModel`]. It is used to instantiate a
    GraniteDoclingHybrid model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the Idefics3 model architecture,
    but with a GraniteMoeHybrid text model.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        image_token_id (`int`, *optional*, defaults to 128257):
            The id of the "image" token.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether or not to tie the word embeddings with the token embeddings.
        vision_config (`GraniteDoclingHybridVisionConfig` or `dict`, *optional*, defaults to `GraniteDoclingHybridVisionConfig`):
            Custom vision config or dict for the vision tower
        text_config (`PretrainedConfig` or `dict`, *optional*, defaults to `GraniteMoeHybridConfig`):
            Custom text config or dict for the text model
        scale_factor (`int`, *optional*, defaults to 2):
            The scale factor for the image encoder.

    Example:
    ```python
    >>> from transformers import GraniteDoclingHybridModel, GraniteDoclingHybridConfig
    >>> # Initializing configuration
    >>> configuration = GraniteDoclingHybridConfig()
    >>> # Initializing a model from the configuration
    >>> model = GraniteDoclingHybridModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "granite_docling_hybrid"
    sub_configs = {"text_config": CONFIG_MAPPING, "vision_config": GraniteDoclingHybridVisionConfig}

    def __init__(
        self,
        image_token_id=128257,
        tie_word_embeddings=False,
        vision_config=None,
        text_config=None,
        scale_factor=2,
        **kwargs,
    ):
        self.image_token_id = image_token_id
        self.tie_word_embeddings = tie_word_embeddings

        # Normalize vision_config to a GraniteDoclingHybridVisionConfig instance.
        if vision_config is None:
            self.vision_config = GraniteDoclingHybridVisionConfig()
            logger.info("vision_config is None, using default vision config")
        elif isinstance(vision_config, dict):
            self.vision_config = GraniteDoclingHybridVisionConfig(**vision_config)
        elif isinstance(vision_config, GraniteDoclingHybridVisionConfig):
            self.vision_config = vision_config
        else:
            # Previously an unsupported type silently left `self.vision_config` unset,
            # surfacing later as an opaque AttributeError; fail fast instead.
            raise TypeError(
                "vision_config must be None, a dict, or a GraniteDoclingHybridVisionConfig, "
                f"got {type(vision_config).__name__}"
            )

        # Normalize text_config, defaulting to a GraniteMoeHybrid text config.
        if isinstance(text_config, dict):
            text_config["model_type"] = text_config.get("model_type", "granitemoehybrid")
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            logger.info("text_config is None, using default GraniteMoeHybrid text config")
            text_config = CONFIG_MAPPING["granitemoehybrid"]()

        self.text_config = text_config
        self.scale_factor = scale_factor

        # Forward tie_word_embeddings explicitly so the base class records it.
        super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings)


__all__ = ["GraniteDoclingHybridConfig"]
Loading
Loading