Fix Gemma 3 SFT training by detecting dual-registered VLM configs #695
Conversation
Gemma 3 (`google/gemma-3-4b-it`) is dual-registered in transformers: `Gemma3Config` maps to both `MODEL_FOR_CAUSAL_LM_MAPPING` and `MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING`. This caused `is_vlm_with_causal_lm()` to return `False` (because the config IS in the CausalLM mapping), so the model was loaded via `AutoModelForCausalLM`, which resolves to the full `Gemma3ForConditionalGeneration` VLM class, not a text-only CausalLM. The VLM forward pass then crashed during FSDP-wrapped distributed training because the text-only SFT training loop doesn't handle the vision tower.

The fix checks what class `MODEL_FOR_CAUSAL_LM_MAPPING` actually resolves to. If it's a `ForConditionalGeneration` class (a VLM), the model is treated as needing backbone extraction, same as the Ministral/Mistral3 models.

Tested with `model_validation.py`: both gemma-3-4b-it and gemma-3n-E4B-it now train to loss 0.0000 on a 1000-sample overfit dataset across 8x A100s.

Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>
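The detection logic described above can be sketched in isolation. In this sketch the mapping dicts stand in for transformers' `MODEL_FOR_CAUSAL_LM_MAPPING`, and the helper name `needs_backbone_extraction` is hypothetical, not the actual function in `vlm_utils.py`:

```python
# Toy model of the dual-registration check. The fake config/model classes
# and the mapping dict below are stand-ins for transformers' auto-mappings;
# only the class-name heuristic mirrors the fix described in this PR.

class Gemma3Config:
    pass

class LlamaConfig:
    pass

class Gemma3ForConditionalGeneration:  # full VLM class
    pass

class LlamaForCausalLM:  # text-only LM class
    pass

# Dual registration: Gemma3Config resolves to the VLM class even in the
# CausalLM mapping, which is exactly the situation the PR fixes.
MODEL_FOR_CAUSAL_LM_MAPPING = {
    Gemma3Config: Gemma3ForConditionalGeneration,
    LlamaConfig: LlamaForCausalLM,
}

def needs_backbone_extraction(config) -> bool:
    """True when the CausalLM mapping would actually hand back a VLM class."""
    resolved_cls = MODEL_FOR_CAUSAL_LM_MAPPING.get(type(config))
    if resolved_cls is None:
        return False
    # Heuristic from the PR: VLM classes are named *ForConditionalGeneration.
    return "ForConditionalGeneration" in resolved_cls.__name__

assert needs_backbone_extraction(Gemma3Config()) is True
assert needs_backbone_extraction(LlamaConfig()) is False
```

With the real transformers mappings, the same lookup is `MODEL_FOR_CAUSAL_LM_MAPPING[config.__class__]`, as shown in the review diff below.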
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 3 passed
🧹 Nitpick comments (2)
src/instructlab/training/vlm_utils.py (2)
45-52: String-based class name check is a reasonable heuristic.

The check `"ForConditionalGeneration" in resolved_cls.__name__` works well given transformers' consistent naming conventions for VLM classes. However, consider adding a brief inline comment explaining that this pattern is relied upon for VLM detection, so future maintainers know the assumption.

Minor: lines 52 and 57 both evaluate `text_config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING`. The duplication is acceptable here since the control flow is clearer with explicit returns in each branch.

📝 Optional: Add a clarifying comment

```diff
 resolved_cls = MODEL_FOR_CAUSAL_LM_MAPPING[config.__class__]
+# VLM classes in transformers follow the naming convention *ForConditionalGeneration
 is_actually_vlm = "ForConditionalGeneration" in resolved_cls.__name__
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/instructlab/training/vlm_utils.py` around lines 45 - 52, The heuristic using "ForConditionalGeneration" in resolved_cls.__name__ to detect VLMs (inside the function using MODEL_FOR_CAUSAL_LM_MAPPING and the local variable is_actually_vlm) needs a short inline comment clarifying the assumption — add a one-line comment above the is_actually_vlm assignment that states we rely on Transformers' naming convention (classes for conditional generation include "ForConditionalGeneration") and that this is a heuristic used to identify VLMs; keep the existing control flow and returns (including subsequent checks on text_config and text_config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING) unchanged.
77-79: Consider safeguarding the implicit assumption that dual-registered VLMs always have `text_config`.

The check at line 78 returns `False` for any model in `MODEL_FOR_CAUSAL_LM_MAPPING`. Combined with `is_vlm_with_causal_lm` (which also returns `False` at line 51 if a dual-registered VLM lacks `text_config`), this creates a gap: a hypothetical dual-registered VLM without `text_config` would fall through to `AutoModelForCausalLM.from_pretrained`, which would load the full VLM instead of a text-only model.

All known dual-registered VLMs (Gemma 3, Ministral, etc.) include `text_config`, so this is unlikely to occur in practice. However, the assumption is implicit and worth documenting or guarding against for future models.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/instructlab/training/vlm_utils.py` around lines 77 - 79, The current check returns False for any config class in MODEL_FOR_CAUSAL_LM_MAPPING, implicitly assuming dual-registered VLMs have text_config; change the condition to verify the presence of text_config before short-circuiting. Concretely, in the function containing the check for MODEL_FOR_CAUSAL_LM_MAPPING, only return False if config.__class__ is in MODEL_FOR_CAUSAL_LM_MAPPING AND hasattr(config, "text_config") (or equivalent), otherwise fall through so is_vlm_with_causal_lm or the code path that avoids AutoModelForCausalLM.from_pretrained can handle the model; reference symbols: MODEL_FOR_CAUSAL_LM_MAPPING, is_vlm_with_causal_lm, text_config, and AutoModelForCausalLM.from_pretrained.
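The guard suggested in this comment can be sketched as follows. The function name, the fake config classes, and the mapping dict are hypothetical stand-ins for the code in `vlm_utils.py`, not the repo's actual API:

```python
# Hypothetical sketch of the suggested safeguard: only treat a config as a
# plain text-only CausalLM when it is in the CausalLM mapping AND lacks a
# text_config attribute (dual-registered VLMs like Gemma 3 carry text_config).
def is_text_only_causal_lm(config, causal_lm_mapping) -> bool:
    return type(config) in causal_lm_mapping and not hasattr(config, "text_config")


class TextOnlyConfig:  # stand-in for e.g. LlamaConfig
    pass


class DualRegisteredVLMConfig:  # stand-in for e.g. Gemma3Config
    text_config = TextOnlyConfig()


# Both configs are registered in the CausalLM mapping, but only the one
# without text_config is safe to load via AutoModelForCausalLM.
mapping = {TextOnlyConfig: object, DualRegisteredVLMConfig: object}
assert is_text_only_causal_lm(TextOnlyConfig(), mapping) is True
assert is_text_only_causal_lm(DualRegisteredVLMConfig(), mapping) is False
```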
📒 Files selected for processing (1)
src/instructlab/training/vlm_utils.py
Summary
Gemma 3 (`google/gemma-3-4b-it`) fails during SFT distributed training because `is_vlm_with_causal_lm()` returns `False` for it, causing the model to be loaded via `AutoModelForCausalLM`, which resolves to the full `Gemma3ForConditionalGeneration` VLM class instead of a text-only CausalLM.

The root cause is that `Gemma3Config` is dual-registered in transformers: it appears in both `MODEL_FOR_CAUSAL_LM_MAPPING` and `MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING`. The existing check in `is_vlm_with_causal_lm()` bailed out early when the config was found in `MODEL_FOR_CAUSAL_LM_MAPPING`, assuming `AutoModelForCausalLM` would load a text-only model. But for Gemma 3, it loads the full VLM, which then crashes during FSDP-wrapped training because the text-only SFT loop doesn't handle the vision tower.

The fix adds a secondary check: when the config maps to CausalLM, inspect the resolved class name. If it's a `ForConditionalGeneration` class, treat the model as needing VLM backbone extraction (same path as Ministral/Mistral3).

Test plan

- `google/gemma-3-4b-it` SFT training: previously crashed during distributed training, now extracts the text backbone and trains to loss 0.0000 on a 1000-sample overfit dataset (8x A100)
- `google/gemma-3n-E4B-it` SFT training: still passes (loss 0.0000)
- Text-only models still load via `AutoModelForCausalLM`, VLMs with extractable backbones (Gemma 3, Gemma 3n, Ministral, Mistral3-VLM) use `extract_causal_lm`, and VLMs without a CausalLM variant (Qwen3-VL) use `direct_vlm_load`
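The three loading paths named in the test plan can be illustrated with a toy dispatcher. The path names match the test plan, but the function and its inputs are hypothetical simplifications, not the repo's code:

```python
# Toy dispatcher for the three loading paths in the test plan. In the real
# code the inputs come from transformers' auto-mappings and config inspection;
# here they are passed in directly for illustration.
def pick_loading_path(resolved_cls_name: str, has_extractable_backbone: bool) -> str:
    if "ForConditionalGeneration" not in resolved_cls_name:
        return "auto_causal_lm"      # plain text-only model
    if has_extractable_backbone:
        return "extract_causal_lm"   # VLM with an extractable text backbone
    return "direct_vlm_load"         # VLM without a CausalLM variant

assert pick_loading_path("LlamaForCausalLM", True) == "auto_causal_lm"
assert pick_loading_path("Gemma3ForConditionalGeneration", True) == "extract_causal_lm"
assert pick_loading_path("Qwen3VLForConditionalGeneration", False) == "direct_vlm_load"
```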