Skip to content

Commit 98445c3

Browse files
timesfm minicpmv46 canine colqwen2
1 parent b37d192 commit 98445c3

12 files changed

Lines changed: 218 additions & 215 deletions

File tree

src/transformers/exporters/exporter_dynamo.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def _patch_chunked_vision_attention(module):
154154
)
155155
src = inspect.getsource(module.forward) if has_attention else ""
156156
if has_attention and "zip(*splits)" in src:
157-
returns_tuple = "return attn_output, attn_weight" in src
157+
returns_tuple = "return attn_output, attn_weight" in src or "return attn_output, None" in src
158158
return ("forward", functools.partial(_reshaped_vision_attention_forward, module, returns_tuple=returns_tuple))
159159

160160

@@ -169,6 +169,13 @@ def _reshaped_vision_attention_forward(
169169
):
170170
"""Export-safe vision attention: reshape segments into batch dim, single SDPA call."""
171171

172+
# Normalise NaViT-style `(1, T, D)` packing (minicpmv4_6) to the flat `(T, D)` layout
173+
# the rest of this wrapper assumes. The leading dim is always 1 — multi-image batches
174+
# are packed along the sequence dim.
175+
needs_batch_restore = hidden_states.ndim == 3
176+
if needs_batch_restore:
177+
hidden_states = hidden_states.squeeze(0)
178+
172179
seq_length = hidden_states.shape[0]
173180
num_segments = cu_seqlens.shape[0] - 1
174181
torch_compilable_check(
@@ -232,6 +239,10 @@ def _to_batched(t):
232239
attn_output = attn_output.transpose(1, 2).reshape(seq_length, -1).contiguous()
233240
out_proj = self.proj if hasattr(self, "proj") else self.out_proj
234241
attn_output = out_proj(attn_output)
242+
243+
if needs_batch_restore:
244+
attn_output = attn_output.unsqueeze(0)
245+
235246
return (attn_output, None) if returns_tuple else attn_output
236247

237248

src/transformers/exporters/utils.py

Lines changed: 69 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -259,17 +259,19 @@ def prepare_for_export(
259259
"visual",
260260
)
261261
_MULTIMODAL_SUBMODULE_NAMES = _MULTIMODAL_ENCODER_NAMES + _MULTIMODAL_PROJECTOR_NAMES + _MULTIMODAL_LM_NAMES
262+
_WRAPPER_ATTRS = ("model", "vlm")
262263

263264

264265
def _find_multimodal_submodules(model: PreTrainedModel) -> dict[str, torch.nn.Module]:
265266
"""Return `{attr_name: module}` for all known multi-modal submodule names found on the model.
266267
267-
Checks `model` first, then `model.model` (common wrapper pattern).
268+
Checks `model` first, then known wrapper attributes (`model.model`, `model.vlm`, …).
268269
Only returns results when at least one modal encoder AND one language model are
269270
found — otherwise the model is not multi-modal and should be exported as a single unit.
270271
"""
272+
roots = [model] + [getattr(model, attr, None) for attr in _WRAPPER_ATTRS]
271273
found: dict[str, torch.nn.Module] = {}
272-
for root in (model, getattr(model, "model", None)):
274+
for root in roots:
273275
if root is None:
274276
continue
275277
for name in _MULTIMODAL_SUBMODULE_NAMES:
@@ -316,47 +318,69 @@ def _precompute_vision_inputs(model: torch.nn.Module, inputs: dict[str, Any]) ->
316318
position_ids, _ = model.get_rope_index(**rope_inputs)
317319
inputs["position_ids"] = position_ids
318320

319-
# Vision submodule level: precompute from grid_thw
321+
modeling_module = sys.modules[type(model).__module__]
322+
323+
# NaViT-style packed encoders carry per-image `(h, w)` as `target_sizes` instead of `grid_thw`.
324+
# Run the nearest-position-id / window-index / merged-shape helpers on the synthesised
325+
# `grid_thw = [1, h, w]` so the per-image Python loops move outside the traced graph.
326+
target_sizes = inputs.get("target_sizes")
327+
if target_sizes is not None:
328+
device = target_sizes.device
329+
num_patches_per_side = _find_submodule_attr(model, "num_patches_per_side")
330+
if hasattr(modeling_module, "get_vision_nearest_position_ids") and num_patches_per_side is not None:
331+
inputs["position_ids"] = modeling_module.get_vision_nearest_position_ids(
332+
target_sizes, num_patches_per_side
333+
).to(device)
334+
335+
window_kernel_size = _find_submodule_attr(model, "window_kernel_size")
336+
if hasattr(modeling_module, "get_vision_window_index") and window_kernel_size is not None:
337+
grid_thw = torch.nn.functional.pad(target_sizes, (1, 0), value=1)
338+
window_index, cu_window_seqlens = modeling_module.get_vision_window_index(
339+
grid_thw, spatial_merge_size=1, window_size=window_kernel_size[0], patch_size=1
340+
)
341+
inputs["window_index"] = window_index.to(device)
342+
inputs["cu_window_seqlens"] = cu_window_seqlens.to(device)
343+
inputs["merged_shape"] = modeling_module.get_vision_merged_shape(target_sizes, window_kernel_size)
344+
345+
# Vision submodule level: precompute from grid_thw. Vision config attributes can live
346+
# anywhere in the submodule tree (encoder, transformer, embeddings, …) — walk to find
347+
# them rather than asking models to mirror state on the outer module just so the
348+
# exporter can read it.
320349
grid_thw = inputs.get("grid_thw")
321-
if grid_thw is None:
322-
return
323-
324-
model_mod = sys.modules[type(model).__module__]
325-
326-
if hasattr(model_mod, "get_vision_cu_seqlens"):
327-
inputs["cu_seqlens"] = model_mod.get_vision_cu_seqlens(grid_thw)
328-
329-
# Vision config attributes can live anywhere in the submodule tree (encoder, transformer,
330-
# embeddings, …) — walk to find them rather than asking models to mirror state on the
331-
# outer module just so the exporter can read it.
332-
spatial_merge_size = _find_submodule_attr(model, "spatial_merge_size")
333-
if spatial_merge_size is None:
334-
# Video-Llama-3 carries per-image merge sizes as an input tensor; PaddleOCR-VL has none
335-
# (the encoder hard-codes `1` because spatial merging happens in the projector).
336-
spatial_merge_size = inputs.get("merge_sizes", 1)
337-
338-
if hasattr(model_mod, "get_vision_position_ids"):
339-
inputs["position_ids"] = model_mod.get_vision_position_ids(grid_thw, spatial_merge_size)
340-
341-
window_size = _find_submodule_attr(model, "window_size")
342-
patch_size = _find_submodule_attr(model, "patch_size")
343-
if hasattr(model_mod, "get_vision_window_index") and window_size is not None and patch_size is not None:
344-
inputs["window_index"], inputs["cu_window_seqlens"] = model_mod.get_vision_window_index(
345-
grid_thw, spatial_merge_size, window_size, patch_size
346-
)
350+
if grid_thw is not None:
351+
spatial_merge_size = _find_submodule_attr(model, "spatial_merge_size")
352+
if spatial_merge_size is None:
353+
# Video-Llama-3 carries per-image merge sizes as an input tensor; PaddleOCR-VL has
354+
# none (its encoder hard-codes `1` because spatial merging happens in the projector).
355+
spatial_merge_size = inputs.get("merge_sizes", 1)
356+
357+
if hasattr(modeling_module, "get_vision_cu_seqlens"):
358+
inputs["cu_seqlens"] = modeling_module.get_vision_cu_seqlens(grid_thw)
359+
360+
if hasattr(modeling_module, "get_vision_position_ids"):
361+
inputs["position_ids"] = modeling_module.get_vision_position_ids(grid_thw, spatial_merge_size)
362+
363+
window_size = _find_submodule_attr(model, "window_size")
364+
patch_size = _find_submodule_attr(model, "patch_size")
365+
if hasattr(modeling_module, "get_vision_window_index") and window_size is not None and patch_size is not None:
366+
inputs["window_index"], inputs["cu_window_seqlens"] = modeling_module.get_vision_window_index(
367+
grid_thw, spatial_merge_size, window_size, patch_size
368+
)
347369

348-
num_grid_per_side = _find_submodule_attr(model, "num_grid_per_side")
349-
if hasattr(model_mod, "get_vision_bilinear_indices_and_weights") and num_grid_per_side is not None:
350-
inputs["bilinear_indices"], inputs["bilinear_weights"] = model_mod.get_vision_bilinear_indices_and_weights(
351-
grid_thw, num_grid_per_side, spatial_merge_size
352-
)
370+
num_grid_per_side = _find_submodule_attr(model, "num_grid_per_side")
371+
if hasattr(modeling_module, "get_vision_bilinear_indices_and_weights") and num_grid_per_side is not None:
372+
inputs["bilinear_indices"], inputs["bilinear_weights"] = (
373+
modeling_module.get_vision_bilinear_indices_and_weights(
374+
grid_thw, num_grid_per_side, spatial_merge_size
375+
)
376+
)
353377

354378

355379
def _precompute_audio_inputs(model: torch.nn.Module, inputs: dict[str, Any]) -> None:
356380
"""Precompute audio encoder inputs that use untraceable ops (.tolist(), nonzero(), loops)."""
357-
model_mod = sys.modules[type(model).__module__]
381+
modeling_module = sys.modules[type(model).__module__]
358382

359-
if not hasattr(model_mod, "chunk_and_pad_features"):
383+
if not hasattr(modeling_module, "chunk_and_pad_features"):
360384
return
361385

362386
if "input_features" not in inputs or "feature_lens" not in inputs:
@@ -365,23 +389,25 @@ def _precompute_audio_inputs(model: torch.nn.Module, inputs: dict[str, Any]) ->
365389
feature_lens = inputs.pop("feature_lens")
366390
input_features = inputs.pop("input_features")
367391

368-
padded_feature, chunk_lengths = model_mod.chunk_and_pad_features(input_features, feature_lens, model.n_window)
392+
padded_feature, chunk_lengths = modeling_module.chunk_and_pad_features(
393+
input_features, feature_lens, model.n_window
394+
)
369395
inputs["padded_feature"] = padded_feature
370396
inputs["chunk_lengths"] = chunk_lengths
371397

372-
if hasattr(model_mod, "get_audio_cu_seqlens"):
373-
fn = model_mod.get_audio_cu_seqlens
398+
if hasattr(modeling_module, "get_audio_cu_seqlens"):
399+
fn = modeling_module.get_audio_cu_seqlens
374400
fn_params = set(inspect.signature(fn).parameters)
375401
if "feature_lens" in fn_params:
376402
inputs["cu_seqlens"] = fn(chunk_lengths, feature_lens, model.n_window_infer, model.n_window)
377403
else:
378404
inputs["cu_seqlens"] = fn(chunk_lengths)
379405

380-
if hasattr(model_mod, "get_valid_indices"):
381-
inputs["valid_indices"] = model_mod.get_valid_indices(chunk_lengths)
406+
if hasattr(modeling_module, "get_valid_indices"):
407+
inputs["valid_indices"] = modeling_module.get_valid_indices(chunk_lengths)
382408

383-
if hasattr(model_mod, "get_pool_indices"):
384-
inputs["pool_indices"] = model_mod.get_pool_indices(feature_lens)
409+
if hasattr(modeling_module, "get_pool_indices"):
410+
inputs["pool_indices"] = modeling_module.get_pool_indices(feature_lens)
385411

386412

387413
@contextlib.contextmanager

src/transformers/models/canine/modeling_canine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -916,7 +916,7 @@ def forward(
916916
molecule_attention_mask = create_bidirectional_mask(
917917
config=self.config,
918918
inputs_embeds=init_molecule_encoding[:, 0:1, :], # force q_len == 1
919-
attention_mask=molecule_attention_mask.squeeze(1), # 3D mask at times due to custom fn
919+
attention_mask=molecule_attention_mask,
920920
)
921921

922922
# Deep BERT encoder

src/transformers/models/minicpmv4_6/modeling_minicpmv4_6.py

Lines changed: 33 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from ...utils.generic import can_return_tuple, is_flash_attention_requested, merge_with_config_defaults
4343
from ...utils.import_utils import torch_compilable_check
4444
from ...utils.output_capturing import capture_outputs
45+
from ...vision_utils import get_vision_merged_shape, get_vision_nearest_position_ids, get_vision_window_index
4546
from ..auto import AutoModel
4647
from .configuration_minicpmv4_6 import MiniCPMV4_6Config, MiniCPMV4_6VisionConfig
4748

@@ -80,32 +81,14 @@ def forward(
8081
self,
8182
pixel_values: torch.FloatTensor,
8283
target_sizes: torch.IntTensor | None = None,
84+
**kwargs: Unpack[TransformersKwargs],
8385
) -> torch.Tensor:
8486
patch_embeds = self.patch_embedding(pixel_values)
8587
embeddings = patch_embeds.flatten(2).transpose(1, 2)
8688

87-
boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side)
88-
89-
position_embeddings = []
90-
for target_size in target_sizes:
91-
nb_patches_h = target_size[0]
92-
nb_patches_w = target_size[1]
93-
94-
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
95-
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
96-
97-
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
98-
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
99-
100-
pos_ids = (
101-
(bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w)
102-
.flatten()
103-
.to(self.position_embedding.weight.device)
104-
)
105-
106-
position_embeddings.append(self.position_embedding(pos_ids))
107-
108-
position_embeddings = torch.concat(position_embeddings, dim=0).unsqueeze(0)
89+
pos_ids = get_vision_nearest_position_ids(target_sizes, self.num_patches_per_side, kwargs=kwargs)
90+
pos_ids = pos_ids.to(self.position_embedding.weight.device)
91+
position_embeddings = self.position_embedding(pos_ids).unsqueeze(0)
10992
embeddings = embeddings + position_embeddings
11093
return embeddings
11194

@@ -358,55 +341,27 @@ def _init_weights(self):
358341
init.normal_(self.linear_2.weight, std=0.25)
359342
init.normal_(self.linear_2.bias, std=1e-6)
360343

361-
def get_window_index(self, target_sizes):
344+
def get_window_index(self, target_sizes, kwargs=None):
362345
window_h, window_w = self.window_kernel_size
363-
max_seqlens = window_h * window_w
364-
365-
window_index_list = []
366-
cu_seqlens = [0]
367-
token_offset = 0
368-
369-
for height, width in target_sizes:
370-
# Cast 0-d device tensors to Python ints so that the whole function
371-
# stays CPU-side integer arithmetic. `torch.arange` without `device=`
372-
# always returns on CPU; mixing with a device-bound `token_offset`
373-
# raises in strict PyTorch versions (2.10+).
374-
height, width = int(height), int(width)
375-
if height % window_h != 0 or width % window_w != 0:
376-
raise ValueError(
377-
f"height={height}, width={width} must be divisible by window size ({window_h}, {window_w})"
378-
)
379-
index = torch.arange(height * width).reshape(height, width)
380-
num_windows_h = height // window_h
381-
num_windows_w = width // window_w
382-
num_windows = num_windows_h * num_windows_w
383-
384-
index = index.reshape(num_windows_h, window_h, num_windows_w, window_w)
385-
index = index.permute(0, 2, 1, 3).reshape(num_windows, window_h * window_w)
386-
387-
window_index_list.append(index.reshape(-1) + token_offset)
388-
389-
cu_this = torch.arange(1, num_windows + 1) * (window_h * window_w) + cu_seqlens[-1]
390-
cu_seqlens.extend(cu_this.tolist())
391-
392-
token_offset += height * width
393-
394-
window_index = torch.cat(window_index_list)
395-
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32)
396-
397-
return window_index, cu_seqlens, max_seqlens
346+
if window_h != window_w:
347+
raise ValueError(f"window_kernel_size must be square; got ({window_h}, {window_w})")
348+
grid_thw = F.pad(target_sizes, (1, 0), value=1)
349+
window_index, cu_seqlens = get_vision_window_index(
350+
grid_thw, spatial_merge_size=1, window_size=window_h, patch_size=1, kwargs=kwargs
351+
)
352+
return window_index, cu_seqlens, window_h * window_w
398353

399354
def forward(
400355
self,
401356
hidden_states: torch.Tensor,
402357
target_sizes: torch.IntTensor,
403-
cu_seqlens: torch.Tensor | None = None,
358+
**kwargs: Unpack[TransformersKwargs],
404359
):
405360
residual = hidden_states
406361
hidden_states = self.layer_norm1(hidden_states)
407362
device = hidden_states.device
408363

409-
window_index, window_cu_seqlens, window_max_seqlens = self.get_window_index(target_sizes)
364+
window_index, window_cu_seqlens, window_max_seqlens = self.get_window_index(target_sizes, kwargs=kwargs)
410365
window_index = window_index.to(device)
411366

412367
hidden_states = hidden_states[:, window_index, :]
@@ -418,28 +373,26 @@ def forward(
418373
hidden_states = hidden_states[:, torch.argsort(window_index), :]
419374
hidden_states = residual + hidden_states
420375

421-
batch_size, _ = target_sizes.shape
376+
# Vectorised window merge: reshape (1, batch*seq_per_img, D) → (batch, seq_per_img, D)
377+
# and lift per-image (h, w) from target_sizes[0]. This assumes the input batch was
378+
# packed with uniform per-image sizes (the standard NaViT preprocessing output).
379+
batch_size = target_sizes.shape[0]
422380
window_h, window_w = self.window_kernel_size
423-
all_pixel_values = []
424-
for batch_idx in range(batch_size):
425-
height, width = target_sizes[batch_idx]
426-
patch = hidden_states[0, cu_seqlens[batch_idx] : cu_seqlens[batch_idx + 1], :].squeeze(0)
427-
428-
embed_dim = patch.shape[-1]
429-
merged_h, merged_w = height // window_h, width // window_w
430-
patch_5d = patch.view(merged_h, window_h, merged_w, window_w, embed_dim).permute(0, 2, 1, 3, 4)
431-
hidden_state = patch_5d.reshape(merged_h * merged_w, window_h * window_w * embed_dim)
432-
residual = patch_5d.reshape(merged_h * merged_w, window_h * window_w, embed_dim).mean(dim=1)
381+
embed_dim = hidden_states.shape[-1]
382+
seq_per_img = hidden_states.shape[1] // batch_size
383+
patch = hidden_states.view(batch_size, seq_per_img, embed_dim)
384+
merged_h, merged_w = get_vision_merged_shape(target_sizes, self.window_kernel_size, kwargs=kwargs)
433385

434-
hidden_state = self.pre_norm(hidden_state)
435-
hidden_state = self.linear_1(hidden_state)
436-
hidden_state = self.act(hidden_state)
437-
hidden_state = self.linear_2(hidden_state)
386+
patch_5d = patch.view(batch_size, merged_h, window_h, merged_w, window_w, embed_dim).permute(0, 1, 3, 2, 4, 5)
387+
flat = patch_5d.reshape(batch_size * merged_h * merged_w, window_h * window_w * embed_dim)
388+
residual = patch_5d.reshape(batch_size * merged_h * merged_w, window_h * window_w, embed_dim).mean(dim=1)
438389

439-
all_pixel_values.append(hidden_state + residual)
390+
hidden_state = self.pre_norm(flat)
391+
hidden_state = self.linear_1(hidden_state)
392+
hidden_state = self.act(hidden_state)
393+
hidden_state = self.linear_2(hidden_state)
440394

441-
new_hidden_states = torch.concat(all_pixel_values, dim=0).unsqueeze(0)
442-
return new_hidden_states
395+
return (hidden_state + residual).unsqueeze(0)
443396

444397

445398
class MiniCPMV4_6VisionPreTrainedModel(PreTrainedModel):
@@ -503,7 +456,7 @@ def forward(
503456
Whether to apply the ViT window-attention merger after the encoder.
504457
"""
505458

506-
hidden_states = self.embeddings(pixel_values, target_sizes=target_sizes)
459+
hidden_states = self.embeddings(pixel_values, target_sizes=target_sizes, **kwargs)
507460

508461
cu_seqlens = F.pad(
509462
torch.cumsum(target_sizes[:, 0] * target_sizes[:, 1], dim=0, dtype=torch.int32).to(hidden_states.device),
@@ -523,7 +476,7 @@ def forward(
523476
for layer_index, encoder_layer in enumerate(self.encoder.layers):
524477
hidden_states = encoder_layer(hidden_states, **attn_kwargs)
525478
if layer_index == insert_layer_id:
526-
hidden_states = self.vit_merger(hidden_states, target_sizes, cu_seqlens)
479+
hidden_states = self.vit_merger(hidden_states, target_sizes, **kwargs)
527480

528481
# NOTE: Downsampled hidden states, and therefore other kwargs should also!
529482
attn_kwargs, target_sizes, cu_seqlens = self.get_downsampled_inputs(

0 commit comments

Comments
 (0)