
[model] support gemma4 mixed data#9180

Merged
Jintao-Huang merged 19 commits into
modelscope:mainfrom
Jintao-Huang:support_gemma4_mixed_data
Apr 28, 2026

Conversation

@Jintao-Huang
Collaborator

No description provided.

Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces the _post_encode method to Gemma4Template to handle mixed-modality training and prevent hangs in DeepSpeed ZeRO-2/3 environments by ensuring all model towers participate in the forward pass. Feedback includes fixing potential AttributeError exceptions when accessing .dtype on nn.Module objects, optimizing performance by caching dummy inputs and restricting dummy passes to DeepSpeed environments, and verifying the logic for ZeRO-3 participation.
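The "all towers participate" requirement can be sketched as follows. This is a minimal illustration of the zero-weight dummy-forward pattern, not the PR's actual code; `VisionTower` and `forward_with_dummy` are hypothetical names:

```python
import torch
import torch.nn as nn


class VisionTower(nn.Module):
    """Stand-in for a vision tower whose parameters must join every forward."""

    def __init__(self, dim: int = 4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def forward_with_dummy(text_emb: torch.Tensor, tower: VisionTower,
                       pixel_values: torch.Tensor = None) -> torch.Tensor:
    if pixel_values is None:
        # Text-only batch on this rank: run the tower on a dummy input and
        # mix its output in with zero weight, so the tower's parameters still
        # participate in the autograd graph (as ZeRO-2/3 expects) without
        # changing the loss.
        dummy = text_emb.new_zeros(1, text_emb.shape[-1])
        return text_emb + 0.0 * tower(dummy).mean()
    # Real multimodal path: merge actual image features (details elided).
    return text_emb + tower(pixel_values).mean()
```

Because the dummy contribution is multiplied by zero, the output equals the text embeddings exactly, yet gradients (all zeros) still flow to every tower parameter, which avoids the hang/unused-parameter failures in sharded setups.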

5 comment threads on swift/template/templates/gemma.py (Outdated)
@Jintao-Huang
Collaborator Author

/gemini review

Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request adds support for multi-modal token types in the base template and implements a _post_encode hook for Gemma templates to process image, video, and audio inputs during training. It also introduces a helper for dummy image forwarding. Feedback was provided regarding the use of boolean tensors in dist.all_reduce, which can cause runtime errors on certain distributed backends like NCCL; switching to torch.long is recommended.

Comment thread swift/template/templates/gemma.py Outdated
Comment on lines +348 to +352
```python
state = input_ids.new_tensor(
    [pixel_values is not None, pixel_values_videos is not None, input_features is not None], dtype=torch.bool)
if dist.is_initialized():
    dist.all_reduce(state)
has_image, has_video, has_audio = state.tolist()
```
Contributor


high

Using torch.bool as the data type for dist.all_reduce can lead to a RuntimeError on certain distributed backends (like NCCL), as they do not support the default SUM reduction operation on boolean tensors. It is safer to use torch.long for the state tensor. Since non-zero integers are truthy in Python, the subsequent unpacking and boolean checks will continue to work as expected.

Suggested change

```diff
 state = input_ids.new_tensor(
-    [pixel_values is not None, pixel_values_videos is not None, input_features is not None], dtype=torch.bool)
+    [pixel_values is not None, pixel_values_videos is not None, input_features is not None], dtype=torch.long)
 if dist.is_initialized():
     dist.all_reduce(state)
 has_image, has_video, has_audio = state.tolist()
```
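The SUM-as-OR behavior the suggestion relies on can be simulated without a distributed backend. In this pure-Python sketch, `all_reduce_sum` is a hypothetical stand-in for `dist.all_reduce` with the default `ReduceOp.SUM`:

```python
def all_reduce_sum(per_rank_states):
    """Element-wise sum across ranks, mimicking dist.all_reduce with ReduceOp.SUM."""
    return [sum(col) for col in zip(*per_rank_states)]

# Each rank reports [has_image, has_video, has_audio] as 0/1 integers
# (the torch.long encoding the review suggests).
rank_states = [
    [1, 0, 0],  # rank 0: image batch
    [1, 1, 0],  # rank 1: image + video batch
    [0, 0, 0],  # rank 2: text-only batch
]
has_image, has_video, has_audio = all_reduce_sum(rank_states)
# Non-zero sums are truthy, so summing behaves like a logical OR:
# has_image == 2 (truthy), has_video == 1 (truthy), has_audio == 0 (falsy).
```

This is why the subsequent `has_image, has_video, has_audio` boolean checks keep working after switching the tensor dtype from `torch.bool` to `torch.long`.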

@Jintao-Huang
Collaborator Author

/gemini review

Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request renames the Tencent Hunyuan 'hy3' model and template identifiers to 'hy_v3' across documentation and code. It introduces a forward pass patch for Gemma4 models to handle multimodal inputs and DeepSpeed compatibility, while also refining the template hook to include 'mm_token_type_ids' and cleaning up unused code and imports. Feedback focuses on optimizing the DeepSpeed dummy forward logic in the Gemma4 patch to avoid performance bottlenecks, correcting inconsistent error messages, and restoring the '*args, **kwargs' signature in the model loader to maintain interface compatibility.

Comment thread swift/model/models/gemma.py
```python
torch_compilable_check(
    inputs_embeds[audio_mask].numel() == audio_features.numel(),
    f'Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features:'
    f' {audio_features.shape[0] * audio_features.shape[1]}',
```
Contributor


medium

The error message for audio features is inconsistent with the image features error message at line 281. It reports the total number of elements instead of the token count, making it confusing when compared against n_audio_tokens.

Suggested change

```diff
-    f' {audio_features.shape[0] * audio_features.shape[1]}',
+    f' {audio_features.shape[0]}',
```

```diff
 class Gemma4Loader(ModelLoader):

-    def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
+    def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
```
Contributor


medium

The signature of get_model has been changed to remove *args, **kwargs. This is a breaking change for the ModelLoader interface and may cause TypeError if the loader is called with additional keyword arguments. It is safer to retain *args, **kwargs to maintain compatibility with the base class.

Suggested change

```diff
-    def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
+    def get_model(self, model_dir: str, config, processor, model_kwargs, *args, **kwargs) -> PreTrainedModel:
```
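The compatibility concern can be reproduced in isolation. `BaseLoader`, `StrictLoader`, and `CompatLoader` below are illustrative stand-ins for `ModelLoader` and its subclasses, not the PR's actual classes:

```python
class BaseLoader:
    def get_model(self, model_dir, *args, **kwargs):
        return f'model from {model_dir}'

class StrictLoader(BaseLoader):
    # Narrowed signature: callers passing extra keyword arguments that the
    # base class would accept now get a TypeError.
    def get_model(self, model_dir, config, processor, model_kwargs):
        return super().get_model(model_dir)

class CompatLoader(BaseLoader):
    # Retaining *args/**kwargs keeps the base interface contract intact.
    def get_model(self, model_dir, config, processor, model_kwargs, *args, **kwargs):
        return super().get_model(model_dir)
```

Any call site that forwards extra keyword arguments, e.g. `loader.get_model('dir', config, processor, {}, revision='main')`, fails with `StrictLoader` but works with `CompatLoader`.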

@Jintao-Huang
Collaborator Author

/gemini review

Contributor

@gemini-code-assist (Bot) left a comment


Code Review

This pull request renames Hunyuan Hy3 components to hy_v3 across the codebase and documentation for consistency, cleans up unused imports, and introduces a monkey-patch for the Gemma4 model to support multimodal inputs. Feedback identifies critical issues in the Gemma4 implementation, specifically regarding incorrect distributed synchronization gating that could cause hangs in standard DDP environments, a breaking change in the get_model method signature, and the need for a safety check on the processor object during patching.


```python
state = input_ids.new_tensor(
    [pixel_values is not None or pixel_values_videos is not None, input_features is not None], dtype=torch.bool)
if dist.is_initialized() and is_deepspeed_enabled():
```
Contributor


high

The synchronization logic using all_reduce should not be gated by is_deepspeed_enabled(). In a standard DDP environment, if some ranks have multimodal data and others don't, they will diverge in their control flow if this synchronization is skipped. This can lead to "unused parameter" errors or hangs if collectives are called within the vision/audio towers. It should be enabled whenever dist.is_initialized() is true.

Suggested change

```diff
-if dist.is_initialized() and is_deepspeed_enabled():
+if dist.is_initialized():
```
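Why ungated synchronization matters can be sketched in plain Python. This is a simulation of the control flow, not the PR's code; `decide_tower_runs` is a hypothetical helper showing which ranks would enter the vision tower:

```python
def decide_tower_runs(per_rank_has_image, synchronize):
    """Return, per rank, whether the vision tower runs this step."""
    if synchronize:
        # Mimic dist.all_reduce(SUM): every rank observes the global flag
        # and takes the same branch, keeping any collectives inside the
        # tower aligned across ranks.
        any_image = sum(per_rank_has_image) > 0
        return [any_image] * len(per_rank_has_image)
    # Without sync, each rank follows only its local batch and the ranks
    # diverge -- the failure mode the review warns about under plain DDP.
    return [bool(f) for f in per_rank_has_image]

flags = [1, 0, 1, 0]  # mixed image and text-only batches across 4 ranks
```

With `synchronize=False` two ranks run the tower and two skip it, which is exactly the divergence that triggers "unused parameter" errors or hangs; with `synchronize=True` all four ranks agree.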

Comment on lines +397 to +402
```diff
 def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
     from transformers import Gemma4ForConditionalGeneration
     self.auto_model_cls = self.auto_model_cls or Gemma4ForConditionalGeneration
-    return super().get_model(model_dir, *args, **kwargs)
+    model = super().get_model(model_dir, config, processor, model_kwargs)
+    _patch_gemma4_forward(model.model, processor)
+    return model
```
Contributor


high

The signature of get_model has been changed from *args, **kwargs to specific positional arguments, which reduces flexibility and breaks compatibility with the base ModelLoader. Additionally, the call to super().get_model passes model_kwargs as a positional argument instead of unpacking it, and _patch_gemma4_forward is called on model.model instead of the top-level model. The patched forward function relies on attributes like language_model which are typically found on the ForConditionalGeneration model.

Suggested change

```diff
-def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
+def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
     from transformers import Gemma4ForConditionalGeneration
     self.auto_model_cls = self.auto_model_cls or Gemma4ForConditionalGeneration
-    model = super().get_model(model_dir, config, processor, model_kwargs)
-    _patch_gemma4_forward(model.model, processor)
+    model = super().get_model(model_dir, *args, **kwargs)
+    _patch_gemma4_forward(model, kwargs.get('processor'))
     return model
```

Comment on lines +210 to +211
```python
if hasattr(model, 'origin_forward'):
    return
```
Contributor


medium

A check should be added to ensure processor is not None before proceeding with the patch, as the dummy forward logic depends on it.

Suggested change

```diff
-if hasattr(model, 'origin_forward'):
-    return
+if processor is None or hasattr(model, 'origin_forward'):
+    return
```

@Jintao-Huang Jintao-Huang merged commit 44c92c7 into modelscope:main Apr 28, 2026
2 of 3 checks passed
