
Fix Gemma 3 SFT training by detecting dual-registered VLM configs#695

Merged
RobotSail merged 1 commit into instructlab:main from RobotSail:fix/gemma3-vlm-extraction
Mar 24, 2026

Conversation


@RobotSail RobotSail commented Mar 20, 2026

Summary

Gemma 3 (google/gemma-3-4b-it) fails during SFT distributed training because is_vlm_with_causal_lm() returns False for it, causing the model to be loaded via AutoModelForCausalLM — which resolves to the full Gemma3ForConditionalGeneration VLM class instead of a text-only CausalLM.

The root cause is that Gemma3Config is dual-registered in transformers: it appears in both MODEL_FOR_CAUSAL_LM_MAPPING and MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING. The existing check in is_vlm_with_causal_lm() bailed out early when the config was found in MODEL_FOR_CAUSAL_LM_MAPPING, assuming AutoModelForCausalLM would load a text-only model. But for Gemma 3, it loads the full VLM, which then crashes during FSDP-wrapped training because the text-only SFT loop doesn't handle the vision tower.

The fix adds a secondary check: when the config maps to CausalLM, inspect the resolved class name. If it's a ForConditionalGeneration class, treat the model as needing VLM backbone extraction (same path as Ministral/Mistral3).
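The detection logic described above can be sketched as follows. This is an illustrative approximation, not the actual `vlm_utils.py` implementation: plain dicts and stub classes stand in for transformers' real auto-mappings and config classes, and the stub class names are hypothetical.

```python
# Sketch of the dual-registration check described above, using stub classes in
# place of transformers' real auto-mappings. Not the actual vlm_utils.py code.

class LlamaConfig: pass
class LlamaForCausalLM: pass
class Gemma3Config: pass
class Gemma3TextConfig: pass
class Gemma3ForConditionalGeneration: pass

# In transformers this is an auto-mapping object; a plain dict stands in here.
MODEL_FOR_CAUSAL_LM_MAPPING = {
    LlamaConfig: LlamaForCausalLM,
    Gemma3Config: Gemma3ForConditionalGeneration,  # dual-registered VLM
    Gemma3TextConfig: LlamaForCausalLM,            # stand-in text backbone class
}

def is_vlm_with_causal_lm(config) -> bool:
    """True when the model is a VLM whose text backbone should be extracted."""
    if config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING:
        resolved_cls = MODEL_FOR_CAUSAL_LM_MAPPING[config.__class__]
        # Dual-registered VLMs (e.g. Gemma3Config) resolve to a
        # *ForConditionalGeneration class here, not a text-only CausalLM.
        if "ForConditionalGeneration" not in resolved_cls.__name__:
            return False  # genuinely text-only; AutoModelForCausalLM is safe
        # Otherwise fall through and check the text backbone below.
    text_config = getattr(config, "text_config", None)
    return (
        text_config is not None
        and text_config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING
    )

llama = LlamaConfig()
gemma3 = Gemma3Config()
gemma3.text_config = Gemma3TextConfig()
print(is_vlm_with_causal_lm(llama))   # False: plain text-only model
print(is_vlm_with_causal_lm(gemma3))  # True: dual-registered VLM, extract backbone
```

Without the `ForConditionalGeneration` check, the Gemma3Config branch would return False at the first membership test, which is exactly the bug this PR fixes.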

Test plan

  • google/gemma-3-4b-it SFT training — previously crashed during distributed training, now extracts text backbone and trains to loss 0.0000 on 1000-sample overfit dataset (8x A100)
  • google/gemma-3n-E4B-it SFT training — still passes (loss 0.0000)
  • Verified routing for all models in validation suite: text-only models (Qwen2, Llama, Granite, Mistral, Qwen3.5) still use AutoModelForCausalLM, VLMs with extractable backbones (Gemma 3, Gemma 3n, Ministral, Mistral3-VLM) use extract_causal_lm, VLMs without CausalLM variant (Qwen3-VL) use direct_vlm_load
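The three routing outcomes verified above can be summarized as a small dispatch. The strategy names (`extract_causal_lm`, `direct_vlm_load`) come from the PR description; the selection function below is a sketch, not the project's actual dispatch code.

```python
# Illustrative dispatch mirroring the three routing outcomes listed above.
# Strategy names come from the PR; this selection logic is a sketch.

def choose_load_strategy(is_vlm: bool, has_causal_lm_backbone: bool) -> str:
    if not is_vlm:
        return "AutoModelForCausalLM"  # text-only: Qwen2, Llama, Granite, Mistral
    if has_causal_lm_backbone:
        return "extract_causal_lm"     # Gemma 3, Gemma 3n, Ministral, Mistral3-VLM
    return "direct_vlm_load"           # no CausalLM variant: Qwen3-VL

print(choose_load_strategy(False, False))  # AutoModelForCausalLM
print(choose_load_strategy(True, True))    # extract_causal_lm
print(choose_load_strategy(True, False))   # direct_vlm_load
```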

Summary by CodeRabbit

  • Bug Fixes
    • Enhanced model detection logic to accurately identify vision-language models with causal language components, improving model configuration recognition during training initialization.

Gemma 3 (google/gemma-3-4b-it) is dual-registered in transformers:
Gemma3Config maps to both MODEL_FOR_CAUSAL_LM_MAPPING and
MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING. This caused is_vlm_with_causal_lm()
to return False (because the config IS in the CausalLM mapping), so the
model was loaded via AutoModelForCausalLM — which resolves to the full
Gemma3ForConditionalGeneration VLM class, not a text-only CausalLM.

The VLM forward pass then crashed during FSDP-wrapped distributed training
because the text-only SFT training loop doesn't handle the vision tower.

The fix checks what class MODEL_FOR_CAUSAL_LM_MAPPING actually resolves to.
If it's a ForConditionalGeneration class (a VLM), the model is treated as
needing backbone extraction, same as Ministral/Mistral3 models.

Tested with model_validation.py: both gemma-3-4b-it and gemma-3n-E4B-it
now train to loss 0.0000 on 1000-sample overfit dataset across 8x A100s.

Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com>

coderabbitai bot commented Mar 20, 2026

📝 Walkthrough

Walkthrough

Modified is_vlm_with_causal_lm() function in vlm_utils.py to refine VLM detection logic. When a config class maps to MODEL_FOR_CAUSAL_LM_MAPPING, the function now resolves the mapped class, checks if it's a ForConditionalGeneration VLM, and conditionally evaluates text_config instead of immediately returning False.

Changes

Cohort / File(s) Summary
VLM Detection Logic Refinement
src/instructlab/training/vlm_utils.py
Enhanced is_vlm_with_causal_lm() to resolve mapped classes and inspect for ForConditionalGeneration VLMs before evaluating text config. Relocated text_config assignment for improved control flow.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

🐰 A VLM detection tale refined,
Where mapped classes now align,
We check for generation's gentle way,
And text configs come to play!
Logic flows where once it didn't—
The model's secrets now are written.

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Description check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check: ✅ Passed. The title directly describes the main fix: detecting dual-registered VLM configs to resolve Gemma 3 SFT training issues, which aligns with the primary change in the changeset.
  • Docstring coverage: ✅ Passed. Docstring coverage is 100.00%, above the required threshold of 80.00%.



@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (2)
src/instructlab/training/vlm_utils.py (2)

45-52: String-based class name check is a reasonable heuristic.

The check "ForConditionalGeneration" in resolved_cls.__name__ works well given transformers' consistent naming conventions for VLM classes. However, consider adding a brief inline comment explaining that this pattern is relied upon for VLM detection, so future maintainers know the assumption.

Minor: Lines 52 and 57 both evaluate text_config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING. The duplication is acceptable here since the control flow is clearer with explicit returns in each branch.

📝 Optional: Add a clarifying comment
         resolved_cls = MODEL_FOR_CAUSAL_LM_MAPPING[config.__class__]
+        # VLM classes in transformers follow the naming convention *ForConditionalGeneration
         is_actually_vlm = "ForConditionalGeneration" in resolved_cls.__name__
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/instructlab/training/vlm_utils.py` around lines 45 - 52, The heuristic
using "ForConditionalGeneration" in resolved_cls.__name__ to detect VLMs (inside
the function using MODEL_FOR_CAUSAL_LM_MAPPING and the local variable
is_actually_vlm) needs a short inline comment clarifying the assumption — add a
one-line comment above the is_actually_vlm assignment that states we rely on
Transformers' naming convention (classes for conditional generation include
"ForConditionalGeneration") and that this is a heuristic used to identify VLMs;
keep the existing control flow and returns (including subsequent checks on
text_config and text_config.__class__ in MODEL_FOR_CAUSAL_LM_MAPPING) unchanged.

77-79: Consider safeguarding the implicit assumption that dual-registered VLMs always have text_config.

The check at line 78 returns False for any model in MODEL_FOR_CAUSAL_LM_MAPPING. Combined with is_vlm_with_causal_lm (which also returns False at line 51 if a dual-registered VLM lacks text_config), this creates a gap: a hypothetical dual-registered VLM without text_config would fall through to AutoModelForCausalLM.from_pretrained, which would load the full VLM instead of a text-only model.

All known dual-registered VLMs (Gemma 3, Ministral, etc.) include text_config, so this is unlikely to occur in practice. However, the assumption is implicit and worth documenting or guarding against for future models.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/instructlab/training/vlm_utils.py` around lines 77 - 79, The current
check returns False for any config class in MODEL_FOR_CAUSAL_LM_MAPPING,
implicitly assuming dual-registered VLMs have text_config; change the condition
to verify the presence of text_config before short-circuiting. Concretely, in
the function containing the check for MODEL_FOR_CAUSAL_LM_MAPPING, only return
False if config.__class__ is in MODEL_FOR_CAUSAL_LM_MAPPING AND hasattr(config,
"text_config") (or equivalent), otherwise fall through so is_vlm_with_causal_lm
or the code path that avoids AutoModelForCausalLM.from_pretrained can handle the
model; reference symbols: MODEL_FOR_CAUSAL_LM_MAPPING, is_vlm_with_causal_lm,
text_config, and AutoModelForCausalLM.from_pretrained.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 729efb30-bca3-477e-a0bd-ba536431ab3b

📥 Commits

Reviewing files that changed from the base of the PR and between 0c6614a and c3a9222.

📒 Files selected for processing (1)
  • src/instructlab/training/vlm_utils.py

@mergify mergify bot added the one-approval label Mar 21, 2026
@RobotSail RobotSail merged commit 75425ca into instructlab:main Mar 24, 2026
12 checks passed
