4242from ...utils .generic import can_return_tuple , is_flash_attention_requested , merge_with_config_defaults
4343from ...utils .import_utils import torch_compilable_check
4444from ...utils .output_capturing import capture_outputs
45+ from ...vision_utils import get_vision_merged_shape , get_vision_nearest_position_ids , get_vision_window_index
4546from ..auto import AutoModel
4647from .configuration_minicpmv4_6 import MiniCPMV4_6Config , MiniCPMV4_6VisionConfig
4748
@@ -80,32 +81,14 @@ def forward(
8081 self ,
8182 pixel_values : torch .FloatTensor ,
8283 target_sizes : torch .IntTensor | None = None ,
84+ ** kwargs : Unpack [TransformersKwargs ],
8385 ) -> torch .Tensor :
8486 patch_embeds = self .patch_embedding (pixel_values )
8587 embeddings = patch_embeds .flatten (2 ).transpose (1 , 2 )
8688
87- boundaries = torch .arange (1 / self .num_patches_per_side , 1.0 , 1 / self .num_patches_per_side )
88-
89- position_embeddings = []
90- for target_size in target_sizes :
91- nb_patches_h = target_size [0 ]
92- nb_patches_w = target_size [1 ]
93-
94- fractional_coords_h = torch .arange (0 , 1 - 1e-6 , 1 / nb_patches_h )
95- fractional_coords_w = torch .arange (0 , 1 - 1e-6 , 1 / nb_patches_w )
96-
97- bucket_coords_h = torch .bucketize (fractional_coords_h , boundaries , right = True )
98- bucket_coords_w = torch .bucketize (fractional_coords_w , boundaries , right = True )
99-
100- pos_ids = (
101- (bucket_coords_h [:, None ] * self .num_patches_per_side + bucket_coords_w )
102- .flatten ()
103- .to (self .position_embedding .weight .device )
104- )
105-
106- position_embeddings .append (self .position_embedding (pos_ids ))
107-
108- position_embeddings = torch .concat (position_embeddings , dim = 0 ).unsqueeze (0 )
89+ pos_ids = get_vision_nearest_position_ids (target_sizes , self .num_patches_per_side , kwargs = kwargs )
90+ pos_ids = pos_ids .to (self .position_embedding .weight .device )
91+ position_embeddings = self .position_embedding (pos_ids ).unsqueeze (0 )
10992 embeddings = embeddings + position_embeddings
11093 return embeddings
11194
@@ -358,55 +341,27 @@ def _init_weights(self):
358341 init .normal_ (self .linear_2 .weight , std = 0.25 )
359342 init .normal_ (self .linear_2 .bias , std = 1e-6 )
360343
361- def get_window_index (self , target_sizes ):
344+ def get_window_index (self , target_sizes , kwargs = None ):
362345 window_h , window_w = self .window_kernel_size
363- max_seqlens = window_h * window_w
364-
365- window_index_list = []
366- cu_seqlens = [0 ]
367- token_offset = 0
368-
369- for height , width in target_sizes :
370- # Cast 0-d device tensors to Python ints so that the whole function
371- # stays CPU-side integer arithmetic. `torch.arange` without `device=`
372- # always returns on CPU; mixing with a device-bound `token_offset`
373- # raises in strict PyTorch versions (2.10+).
374- height , width = int (height ), int (width )
375- if height % window_h != 0 or width % window_w != 0 :
376- raise ValueError (
377- f"height={ height } , width={ width } must be divisible by window size ({ window_h } , { window_w } )"
378- )
379- index = torch .arange (height * width ).reshape (height , width )
380- num_windows_h = height // window_h
381- num_windows_w = width // window_w
382- num_windows = num_windows_h * num_windows_w
383-
384- index = index .reshape (num_windows_h , window_h , num_windows_w , window_w )
385- index = index .permute (0 , 2 , 1 , 3 ).reshape (num_windows , window_h * window_w )
386-
387- window_index_list .append (index .reshape (- 1 ) + token_offset )
388-
389- cu_this = torch .arange (1 , num_windows + 1 ) * (window_h * window_w ) + cu_seqlens [- 1 ]
390- cu_seqlens .extend (cu_this .tolist ())
391-
392- token_offset += height * width
393-
394- window_index = torch .cat (window_index_list )
395- cu_seqlens = torch .tensor (cu_seqlens , dtype = torch .int32 )
396-
397- return window_index , cu_seqlens , max_seqlens
346+ if window_h != window_w :
347+ raise ValueError (f"window_kernel_size must be square; got ({ window_h } , { window_w } )" )
348+ grid_thw = F .pad (target_sizes , (1 , 0 ), value = 1 )
349+ window_index , cu_seqlens = get_vision_window_index (
350+ grid_thw , spatial_merge_size = 1 , window_size = window_h , patch_size = 1 , kwargs = kwargs
351+ )
352+ return window_index , cu_seqlens , window_h * window_w
398353
399354 def forward (
400355 self ,
401356 hidden_states : torch .Tensor ,
402357 target_sizes : torch .IntTensor ,
403- cu_seqlens : torch . Tensor | None = None ,
358+ ** kwargs : Unpack [ TransformersKwargs ] ,
404359 ):
405360 residual = hidden_states
406361 hidden_states = self .layer_norm1 (hidden_states )
407362 device = hidden_states .device
408363
409- window_index , window_cu_seqlens , window_max_seqlens = self .get_window_index (target_sizes )
364+ window_index , window_cu_seqlens , window_max_seqlens = self .get_window_index (target_sizes , kwargs = kwargs )
410365 window_index = window_index .to (device )
411366
412367 hidden_states = hidden_states [:, window_index , :]
@@ -418,28 +373,26 @@ def forward(
418373 hidden_states = hidden_states [:, torch .argsort (window_index ), :]
419374 hidden_states = residual + hidden_states
420375
421- batch_size , _ = target_sizes .shape
376+ # Vectorised window merge: reshape (1, batch*seq_per_img, D) → (batch, seq_per_img, D)
377+ # and lift per-image (h, w) from target_sizes[0]. This assumes the input batch was
378+ # packed with uniform per-image sizes (the standard NaViT preprocessing output).
379+ batch_size = target_sizes .shape [0 ]
422380 window_h , window_w = self .window_kernel_size
423- all_pixel_values = []
424- for batch_idx in range (batch_size ):
425- height , width = target_sizes [batch_idx ]
426- patch = hidden_states [0 , cu_seqlens [batch_idx ] : cu_seqlens [batch_idx + 1 ], :].squeeze (0 )
427-
428- embed_dim = patch .shape [- 1 ]
429- merged_h , merged_w = height // window_h , width // window_w
430- patch_5d = patch .view (merged_h , window_h , merged_w , window_w , embed_dim ).permute (0 , 2 , 1 , 3 , 4 )
431- hidden_state = patch_5d .reshape (merged_h * merged_w , window_h * window_w * embed_dim )
432- residual = patch_5d .reshape (merged_h * merged_w , window_h * window_w , embed_dim ).mean (dim = 1 )
381+ embed_dim = hidden_states .shape [- 1 ]
382+ seq_per_img = hidden_states .shape [1 ] // batch_size
383+ patch = hidden_states .view (batch_size , seq_per_img , embed_dim )
384+ merged_h , merged_w = get_vision_merged_shape (target_sizes , self .window_kernel_size , kwargs = kwargs )
433385
434- hidden_state = self .pre_norm (hidden_state )
435- hidden_state = self .linear_1 (hidden_state )
436- hidden_state = self .act (hidden_state )
437- hidden_state = self .linear_2 (hidden_state )
386+ patch_5d = patch .view (batch_size , merged_h , window_h , merged_w , window_w , embed_dim ).permute (0 , 1 , 3 , 2 , 4 , 5 )
387+ flat = patch_5d .reshape (batch_size * merged_h * merged_w , window_h * window_w * embed_dim )
388+ residual = patch_5d .reshape (batch_size * merged_h * merged_w , window_h * window_w , embed_dim ).mean (dim = 1 )
438389
439- all_pixel_values .append (hidden_state + residual )
390+ hidden_state = self .pre_norm (flat )
391+ hidden_state = self .linear_1 (hidden_state )
392+ hidden_state = self .act (hidden_state )
393+ hidden_state = self .linear_2 (hidden_state )
440394
441- new_hidden_states = torch .concat (all_pixel_values , dim = 0 ).unsqueeze (0 )
442- return new_hidden_states
395+ return (hidden_state + residual ).unsqueeze (0 )
443396
444397
445398class MiniCPMV4_6VisionPreTrainedModel (PreTrainedModel ):
@@ -503,7 +456,7 @@ def forward(
503456 Whether to apply the ViT window-attention merger after the encoder.
504457 """
505458
506- hidden_states = self .embeddings (pixel_values , target_sizes = target_sizes )
459+ hidden_states = self .embeddings (pixel_values , target_sizes = target_sizes , ** kwargs )
507460
508461 cu_seqlens = F .pad (
509462 torch .cumsum (target_sizes [:, 0 ] * target_sizes [:, 1 ], dim = 0 , dtype = torch .int32 ).to (hidden_states .device ),
@@ -523,7 +476,7 @@ def forward(
523476 for layer_index , encoder_layer in enumerate (self .encoder .layers ):
524477 hidden_states = encoder_layer (hidden_states , ** attn_kwargs )
525478 if layer_index == insert_layer_id :
526- hidden_states = self .vit_merger (hidden_states , target_sizes , cu_seqlens )
479+ hidden_states = self .vit_merger (hidden_states , target_sizes , ** kwargs )
527480
528481 # NOTE: Downsampled hidden states, and therefore other kwargs should also!
529482 attn_kwargs , target_sizes , cu_seqlens = self .get_downsampled_inputs (
0 commit comments