[model] support gemma4 mixed data #9180
Conversation
Code Review
This pull request introduces the _post_encode method to Gemma4Template to handle mixed-modality training and prevent hangs in DeepSpeed ZeRO-2/3 environments by ensuring all model towers participate in the forward pass. Feedback includes fixing potential AttributeError exceptions when accessing .dtype on nn.Module objects, optimizing performance by caching dummy inputs and restricting dummy passes to DeepSpeed environments, and verifying the logic for ZeRO-3 participation.
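The technique behind the hang fix is worth spelling out: if one rank has image data and another does not, the ranks execute different parameter sets, and ZeRO-2/3's sharded collectives deadlock. Below is a minimal sketch of the dummy-forward idea, assuming hypothetical module names (`vision_tower`, `audio_tower`), dummy input shapes, and tensor-valued tower outputs rather than the real Gemma4 internals:

```python
import torch

def dummy_tower_forward(model, inputs_embeds: torch.Tensor) -> torch.Tensor:
    """Sketch only: keep unused towers in the autograd graph so every rank
    runs the same parameters, as DeepSpeed ZeRO-2/3 collectives require.
    Module names, dummy shapes, and tensor-valued outputs are assumptions."""
    device, dtype = inputs_embeds.device, inputs_embeds.dtype
    if getattr(model, 'vision_tower', None) is not None:
        # Per the review feedback, this dummy tensor should be cached
        # rather than re-allocated on every step.
        dummy_pixels = torch.zeros(1, 3, 224, 224, device=device, dtype=dtype)
        vision_out = model.vision_tower(dummy_pixels)
        # Zero-weighted term: parameters participate, real output unchanged.
        inputs_embeds = inputs_embeds + vision_out.mean() * 0.0
    if getattr(model, 'audio_tower', None) is not None:
        dummy_audio = torch.zeros(1, 128, 3000, device=device, dtype=dtype)
        audio_out = model.audio_tower(dummy_audio)
        inputs_embeds = inputs_embeds + audio_out.mean() * 0.0
    return inputs_embeds
```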
/gemini review
Code Review
This pull request adds support for multi-modal token types in the base template and implements a _post_encode hook for Gemma templates to process image, video, and audio inputs during training. It also introduces a helper for dummy image forwarding. Feedback was provided regarding the use of boolean tensors in dist.all_reduce, which can cause runtime errors on certain distributed backends like NCCL; switching to torch.long is recommended.
```python
state = input_ids.new_tensor(
    [pixel_values is not None, pixel_values_videos is not None, input_features is not None], dtype=torch.bool)
if dist.is_initialized():
    dist.all_reduce(state)
has_image, has_video, has_audio = state.tolist()
```
Using torch.bool as the data type for dist.all_reduce can lead to a RuntimeError on certain distributed backends (like NCCL), as they do not support the default SUM reduction operation on boolean tensors. It is safer to use torch.long for the state tensor. Since non-zero integers are truthy in Python, the subsequent unpacking and boolean checks will continue to work as expected.
Suggested change:

```diff
 state = input_ids.new_tensor(
-    [pixel_values is not None, pixel_values_videos is not None, input_features is not None], dtype=torch.bool)
+    [pixel_values is not None, pixel_values_videos is not None, input_features is not None], dtype=torch.long)
 if dist.is_initialized():
     dist.all_reduce(state)
 has_image, has_video, has_audio = state.tolist()
```
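For reference, the recommended pattern as a self-contained helper (the function name and signature are illustrative, not from the PR):

```python
import torch
import torch.distributed as dist

def sync_modality_flags(has_image: bool, has_video: bool, has_audio: bool,
                        device: torch.device) -> tuple:
    # torch.long supports SUM on all common backends, including NCCL; the
    # reduced value is the count of ranks holding each modality, and any
    # non-zero count stays truthy in the boolean checks downstream.
    state = torch.tensor([has_image, has_video, has_audio],
                         dtype=torch.long, device=device)
    if dist.is_initialized():
        dist.all_reduce(state)
    return tuple(bool(x) for x in state.tolist())
```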
/gemini review
Code Review
This pull request renames the Tencent Hunyuan 'hy3' model and template identifiers to 'hy_v3' across documentation and code. It introduces a forward pass patch for Gemma4 models to handle multimodal inputs and DeepSpeed compatibility, while also refining the template hook to include 'mm_token_type_ids' and cleaning up unused code and imports. Feedback focuses on optimizing the DeepSpeed dummy forward logic in the Gemma4 patch to avoid performance bottlenecks, correcting inconsistent error messages, and restoring the '*args, **kwargs' signature in the model loader to maintain interface compatibility.
```python
torch_compilable_check(
    inputs_embeds[audio_mask].numel() == audio_features.numel(),
    f'Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features:'
    f' {audio_features.shape[0] * audio_features.shape[1]}',
```
The error message for audio features is inconsistent with the image features error message at line 281. It reports the total number of elements instead of the token count, making it confusing when compared against n_audio_tokens.
Suggested change:

```diff
-    f' {audio_features.shape[0] * audio_features.shape[1]}',
+    f' {audio_features.shape[0]}',
```
```diff
 class Gemma4Loader(ModelLoader):

-    def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
+    def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
```
The signature of get_model has been changed to remove *args, **kwargs. This is a breaking change for the ModelLoader interface and may cause TypeError if the loader is called with additional keyword arguments. It is safer to retain *args, **kwargs to maintain compatibility with the base class.
Suggested change:

```diff
-    def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
+    def get_model(self, model_dir: str, config, processor, model_kwargs, *args, **kwargs) -> PreTrainedModel:
```
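Retaining `*args, **kwargs` keeps the override call-compatible with the base `ModelLoader` interface, so callers that pass extra keyword arguments through the loader chain do not raise `TypeError`.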
/gemini review
Code Review
This pull request renames Hunyuan Hy3 components to hy_v3 across the codebase and documentation for consistency, cleans up unused imports, and introduces a monkey-patch for the Gemma4 model to support multimodal inputs. Feedback identifies critical issues in the Gemma4 implementation, specifically regarding incorrect distributed synchronization gating that could cause hangs in standard DDP environments, a breaking change in the get_model method signature, and the need for a safety check on the processor object during patching.
```python
state = input_ids.new_tensor(
    [pixel_values is not None or pixel_values_videos is not None, input_features is not None], dtype=torch.bool)
if dist.is_initialized() and is_deepspeed_enabled():
```
The synchronization logic using all_reduce should not be gated by is_deepspeed_enabled(). In a standard DDP environment, if some ranks have multimodal data and others don't, they will diverge in their control flow if this synchronization is skipped. This can lead to "unused parameter" errors or hangs if collectives are called within the vision/audio towers. It should be enabled whenever dist.is_initialized() is true.
Suggested change:

```diff
-if dist.is_initialized() and is_deepspeed_enabled():
+if dist.is_initialized():
```
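The fix decouples two concerns: the flag synchronization must run on every rank whenever a process group exists, even under plain DDP, while the earlier suggestion to restrict the compensating dummy passes to DeepSpeed environments is a separate optimization.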
```diff
 def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
     from transformers import Gemma4ForConditionalGeneration
     self.auto_model_cls = self.auto_model_cls or Gemma4ForConditionalGeneration
-    return super().get_model(model_dir, *args, **kwargs)
+    model = super().get_model(model_dir, config, processor, model_kwargs)
+    _patch_gemma4_forward(model.model, processor)
+    return model
```
The signature of get_model has been changed from *args, **kwargs to specific positional arguments, which reduces flexibility and breaks compatibility with the base ModelLoader. Additionally, the call to super().get_model passes model_kwargs as a positional argument instead of unpacking it, and _patch_gemma4_forward is called on model.model instead of the top-level model. The patched forward function relies on attributes like language_model which are typically found on the ForConditionalGeneration model.
Suggested change:

```diff
-def get_model(self, model_dir: str, config, processor, model_kwargs) -> PreTrainedModel:
+def get_model(self, model_dir: str, *args, **kwargs) -> PreTrainedModel:
     from transformers import Gemma4ForConditionalGeneration
     self.auto_model_cls = self.auto_model_cls or Gemma4ForConditionalGeneration
-    model = super().get_model(model_dir, config, processor, model_kwargs)
-    _patch_gemma4_forward(model.model, processor)
+    model = super().get_model(model_dir, *args, **kwargs)
+    _patch_gemma4_forward(model, kwargs.get('processor'))
     return model
```
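The restored variadic signature also explains the two other corrections: `_patch_gemma4_forward` now receives the top-level model, where attributes such as `language_model` actually live, and the processor is fetched from `kwargs` rather than bound positionally.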
```python
if hasattr(model, 'origin_forward'):
    return
```
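This guard is the usual idempotency check for a monkey-patch: it saves the original forward once and makes repeated patch calls no-ops. One plausible shape of the surrounding function, sketched with assumed internals (only `origin_forward` and the function name come from the PR):

```python
import types
from functools import wraps

def _patch_gemma4_forward(model, processor):
    # Second and later calls are no-ops thanks to the guard.
    if hasattr(model, 'origin_forward'):
        return
    model.origin_forward = model.forward  # keep the original bound method

    @wraps(model.forward)
    def forward(self, *args, **kwargs):
        # Mixed-modality handling (flag sync, dummy tower passes) would
        # sit here before delegating to the original implementation.
        return self.origin_forward(*args, **kwargs)

    model.forward = types.MethodType(forward, model)
```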