- Create module directory structure
- Add Gemma3nConfig dataclass with E4B defaults from HuggingFace
- Add is_global_layer() helper for 4:1 local:global ratio
- Create test directory with initial config tests
- All tests passing (4/4); ~200 LOC implementation + 73 LOC tests
- Add interop.py with HF→fairseq2 state dict key mappings
- Add checkpoint.py with stub loader (to be implemented in Phase 2-3)
- Export conversion functions from __init__.py
- Add tests for stub validation
- All tests passing (6/6); ~130 LOC implementation + 17 LOC new tests
- Add hub.py with TODO for model/tokenizer hub accessors
- Add tokenizer.py stub (will reuse Gemma tokenizer)
- Add sharder.py with tensor parallelism sharding plan
- Update __init__.py with stub comments
- All tests still passing (6/6); ~90 LOC stubs
- Implement dual-frequency RoPE with theta=10,000 and dual_theta=100,000
- Split head dimension in half, apply different RoPE to each half
- Add validation for encoding_dim divisibility by 4
- Add tests for functionality and input validation
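The dual-frequency split can be sketched as below; a minimal, self-contained illustration, not the fairseq2 API (function and argument names are assumptions). Each half of the head dimension gets its own inverse-frequency table, which is why the head dimension must be divisible by 4 (two halves, each rotating pairs of dims):

```python
import torch

def dual_rope_angles(positions: torch.Tensor, head_dim: int,
                     theta: float = 10_000.0,
                     dual_theta: float = 100_000.0) -> torch.Tensor:
    """Build rotation angles where each half of the head dimension
    uses a different base frequency (illustrative sketch)."""
    if head_dim % 4 != 0:
        raise ValueError("head_dim must be divisible by 4")
    half = head_dim // 2

    def inv_freqs(base: float) -> torch.Tensor:
        # RoPE rotates pairs of dims, so each half has half//2 frequencies.
        exponents = torch.arange(0, half, 2, dtype=torch.float32) / half
        return 1.0 / (base ** exponents)

    inv = torch.cat([inv_freqs(theta),        # first half: standard theta
                     inv_freqs(dual_theta)])  # second half: slower rotation
    return torch.outer(positions.float(), inv)  # (seq_len, half)

angles = dual_rope_angles(torch.arange(8), head_dim=16)
print(angles.shape)  # torch.Size([8, 8])
```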
- Implement soft-capped SDPA: tanh(logits / cap) * cap
- Apply capping between scaled dot-product and softmax
- Add tests for basic functionality and logit bounding
- Default soft_cap=30.0 matching Gemma3n spec
- Implement GLU-style FFN with GELU activation for local layers
- Simpler interface than GLUFeedForwardNetwork
- Add tests for functionality and GELU activation verification
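A GLU-style FFN with GELU ("GeGLU") gates the up-projection with a GELU-activated branch. A minimal sketch under the usual gate/up/down naming (an assumption, mirroring the HF key names mentioned later in this PR):

```python
import torch
from torch import nn

class GeluGLUFeedForward(nn.Module):
    """GLU-style FFN sketch: gelu(gate(x)) * up(x), then down-project."""
    def __init__(self, model_dim: int, inner_dim: int) -> None:
        super().__init__()
        self.gate_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.up_proj = nn.Linear(model_dim, inner_dim, bias=False)
        self.down_proj = nn.Linear(inner_dim, model_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = nn.functional.gelu(self.gate_proj(x), approximate="tanh")
        return self.down_proj(gated * self.up_proj(x))

ffn = GeluGLUFeedForward(model_dim=32, inner_dim=64)
out = ffn(torch.randn(2, 5, 32))
print(out.shape)  # torch.Size([2, 5, 32])
```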
- Implement create_gemma3n_decoder_layer factory function
- Local layers: DualRotaryEncoder + SoftCappedSDPA + AltUpFFN
- Global layers: SoftCappedSDPA + GLU FFN with GELU
- Add altup_hidden_dim config field (5376)
- Add tests for local and global layer creation
- Removed 5 narrator comments from factory.py (SDPA, multi-head attention, FFN, layer norms, decoder layer)
- Added docstring to AltUpFeedForwardNetwork.__init__ per FAIR standards
- All tests passing (14/14)
- Implement create_gemma3n_model() factory function
- Add Gemma3nFactory class with component creation methods
- Add tests for model creation and forward pass
- All tests passing (16/16); ~150 LOC
- Add test_parity.sh for manual execution on compute nodes
- Verified model creation and forward pass on H100 GPUs
- Test confirms: 8 GPUs accessible, forward pass working, logits in valid range
- Phase 3 basic validation complete
- Update get_gemma3n_e2b_config() to use 30 layers instead of 35
- All config values now match google/gemma-3n-E2B-it exactly
- Verified: layers, hidden_dim, heads, vocab, max_seq_len, sliding_window, rms_eps
- Model creates successfully with 2.05B parameters
- Fix key mappings: model.language_model.* instead of model.*
- Add support for actual Gemma3n-E2B checkpoint structure
- Note: LAuReL, PLE features deferred to Phase 4/5 (advanced features)
- Basic text model conversion now working
- Verified HF→fairseq2 key conversion for language_model
- Tested with actual Gemma3n-E2B weights from HuggingFace
- Conversion correctly maps: input_layernorm, post_attention_layernorm, mlp.{down,up}_proj
- Note: LAuReL, PLE, AltUp features deferred to Phase 4/5 (advanced features)
- Basic text transformer conversion complete and validated
- Move all test scripts to scripts/gemma3n_validation/
- Add README documenting validation process and results
- Clean up repository root
- Implement LAuReLResidualConnect with low-rank learned residual
  - linear_left: (rank, model_dim) down-projection
  - linear_right: (model_dim, rank) up-projection
  - post_laurel_norm: RMSNorm after residual connection
  - rank=64 for all layers
- Add QK normalization to attention
  - RMSNorm applied to Q and K projections
  - Dimension: head_dim (256)
  - Replaces soft-capping from Gemma 2
- Update checkpoint conversion mappings
  - laurel.linear_left → self_attn_residual.linear_left
  - laurel.linear_right → self_attn_residual.linear_right
  - laurel.post_laurel_norm → self_attn_residual.layer_norm
  - self_attn.q_norm → self_attn.q_norm
  - self_attn.k_norm → self_attn.k_norm
- Add validation scripts
  - test_laurel.sh: verify LAuReL implementation
  - test_laurel_conversion.sh: verify checkpoint conversion
  - inspect_qk_norm.sh: inspect QK norm parameters
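The low-rank residual path can be sketched as follows; this is one plausible wiring, not the fairseq2 implementation (class and method names are assumptions, and the RMSNorm is inlined so the sketch is self-contained):

```python
import torch
from torch import nn

class RMSNorm(nn.Module):
    """Minimal RMSNorm so the sketch is self-contained."""
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class LaurelResidual(nn.Module):
    """LAuReL-style residual sketch: augment the plain skip connection
    with a learned low-rank path (model_dim -> rank -> model_dim)."""
    def __init__(self, model_dim: int, rank: int = 64) -> None:
        super().__init__()
        self.linear_left = nn.Linear(model_dim, rank, bias=False)   # down-project
        self.linear_right = nn.Linear(rank, model_dim, bias=False)  # up-project
        self.post_laurel_norm = RMSNorm(model_dim)

    def forward(self, residual: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        low_rank = self.post_laurel_norm(self.linear_right(self.linear_left(residual)))
        return x + residual + low_rank

laurel = LaurelResidual(model_dim=256)
h = torch.randn(1, 4, 256)
out = laurel(h, torch.zeros_like(h))
print(out.shape)  # torch.Size([1, 4, 256])
```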
- Create Gemma3nDecoderLayer inheriting from TransformerLMDecoderLayer
  - Clean implementation with explicit forward pass
  - Composition-based design for flexibility
  - Supports LAuReL residual, pre/post feedforward norms
- Layer structure:
  1. input_layernorm (pre-attention)
  2. self-attention with QK norm
  3. LAuReL residual (with post_laurel_norm)
  4. post_attention_layernorm (pre-FFN)
  5. pre_feedforward_layernorm (optional, before FFN)
  6. FFN (AltUp or standard GLU)
  7. post_feedforward_layernorm (optional, after FFN)
  8. FFN residual
- Update checkpoint conversion
  - Map all 4 normalization layers correctly
  - input_layernorm, post_attention_layernorm remain flat
  - pre_feedforward_layernorm, post_feedforward_layernorm added
- Replace StandardTransformerLMDecoderLayer with Gemma3nDecoderLayer
  - Cleaner abstraction for Gemma3n-specific features
  - Easier to extend for PLE and AltUp routing
- Create PerLayerEmbedding module with gating mechanism
  - Shared embed_tokens_per_layer lookup table
  - Per-layer input gate for dynamic contribution
  - Per-layer projection to model dimension
  - Post-PLE normalization
- Update Gemma3nDecoderLayer to support optional PLE
  - Add ple parameter (optional Module)
  - Register PLE module when provided
- Structure matches HF checkpoint:
  - vocab_size_per_layer_input: 262_144
  - hidden_size_per_layer_input: 256
  - Projects to model_dim: 2048

Note: PLE forward integration deferred - needs token_ids threading through the layer interface or an alternative design.
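The lookup/gate/project pipeline can be sketched as below. The gating details here are an assumption (a sigmoid gate computed from the hidden state); only the dimensions mirror the commit message, and LayerNorm stands in for the checkpoint's RMSNorm to keep the sketch short:

```python
import torch
from torch import nn

class PerLayerEmbeddingSketch(nn.Module):
    """Per-layer embedding sketch: a small per-layer lookup table whose
    contribution is gated by the hidden state and projected to model_dim."""
    def __init__(self, vocab_size: int, ple_dim: int, model_dim: int) -> None:
        super().__init__()
        self.embed_tokens_per_layer = nn.Embedding(vocab_size, ple_dim)
        self.input_gate = nn.Linear(model_dim, ple_dim, bias=False)
        self.proj = nn.Linear(ple_dim, model_dim, bias=False)
        # LayerNorm as a stand-in for the checkpoint's post-PLE RMSNorm.
        self.post_ple_norm = nn.LayerNorm(model_dim)

    def forward(self, hidden: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(self.input_gate(hidden))         # dynamic contribution
        ple = self.embed_tokens_per_layer(token_ids) * gate   # gated lookup
        return self.post_ple_norm(self.proj(ple))             # project to model_dim

ple = PerLayerEmbeddingSketch(vocab_size=1_000, ple_dim=16, model_dim=64)
out = ple(torch.randn(1, 3, 64), torch.tensor([[1, 2, 3]]))
print(out.shape)  # torch.Size([1, 3, 64])
```

Note how the forward pass needs `token_ids` alongside the hidden state; this is exactly the threading problem the commit defers.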
Key fixes:
- Set num_kv_shared_layers = 10 (was 15) to match HF config
- Implement KV projection sharing with enum-based slots (LOCAL/GLOBAL)
- SOURCE layers (18 local, 19 global) store K/V for consumers
- CONSUMER layers (20-29) retrieve shared K/V instead of computing their own

Parity results:
- 100% token prediction agreement
- Max absolute diff: 1.39e-04
- Max relative diff: 2.67e-03

Components added:
- KVProjectionType/KVProjectionRole enums for type-safe sharing
- Gemma3nDecoder with AltUp 4D processing and PLE support
- Gemma3nFrontend for embedding and PLE generation
- Gemma3nLM model wrapper with softcapping
- AltUp predict/correct mechanism
- LAuReL augmented residual connections
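The source/consumer flow can be sketched as a small registry; strings stand in for the KVProjectionType enums, and the class is an illustration rather than the fairseq2 implementation:

```python
import torch

class KVShareRegistry:
    """Sketch of KV projection sharing: source layers store K/V per slot
    (LOCAL/GLOBAL); later consumer layers reuse them instead of projecting."""
    def __init__(self) -> None:
        self._slots = {}  # slot name -> (k, v)

    def store(self, slot: str, k: torch.Tensor, v: torch.Tensor) -> None:
        self._slots[slot] = (k, v)

    def retrieve(self, slot: str) -> tuple:
        if slot not in self._slots:
            raise KeyError(f"no K/V stored for slot {slot!r}")
        return self._slots[slot]

registry = KVShareRegistry()
k, v = torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8)
registry.store("LOCAL", k, v)        # e.g. source layer 18
k2, v2 = registry.retrieve("LOCAL")  # e.g. a consumer layer in 20-29
print(torch.equal(k, k2))  # True
```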
- Removed 150+ debug and test scripts used during development
- Updated PARITY_STATUS.md with comprehensive implementation docs
- Kept single end-to-end parity test: test_parity.py
- Retained KV sharing analysis documentation

Final state:
- scripts/gemma3n_validation/test_parity.py - Complete parity test
- scripts/gemma3n_validation/PARITY_STATUS.md - Implementation guide
- scripts/gemma3n_validation/HF_KV_SHARING_ANALYSIS.md - KV analysis
- scripts/gemma3n_validation/KV_SHARING_IMPLEMENTATION.md - Design docs
Add 50 unit tests covering all parity-critical Gemma3n components:
- test_kv_sharing.py (7 tests): KV projection sharing registry
  - Store/retrieve K/V tensors
  - Multiple source layers handling
  - Error handling and registry lifecycle
- test_activation_sparsity.py (12 tests): Gaussian top-k sparsification
  - Zero/high sparsity behavior validation
  - Sparse vs dense output comparison
  - Various sparsity levels (0.0 to 0.99)
- test_altup.py (10 tests): AltUp 4D predict/correct mechanism
  - 4D tensor shape handling [num_inputs, B, S, D]
  - Predict/correct cycle verification
  - Router modalities and coefficient clipping
- test_laurel.py (9 tests): LAuReL low-rank residual layer
  - Low-rank bottleneck (D → rank → D)
  - Residual connection verification
  - Numerical stability across rank values
- test_ple.py (12 tests): Per-layer embeddings
  - Embedding lookup and projection
  - Gating mechanism
  - Token-specific outputs

All tests run in <1s without checkpoints using minimal dimensions.

Bug fixes:
- Fix StandardEmbedding parameter: embedding_dim → embed_dim
- Move imports to module level in decoder_layer.py
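Gaussian top-k sparsification can be sketched as below: derive a cutoff from the activation statistics via a standard-normal quantile and zero everything below it. The statistics-based cutoff is the commonly used formulation; the exact fairseq2 details are an assumption:

```python
import torch

def gaussian_topk_sparsify(x: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Sketch of Gaussian top-k sparsity: keep roughly the top
    (1 - sparsity) fraction of activations, zeroing the rest."""
    if sparsity <= 0.0:
        return x
    # Standard-normal quantile at the target sparsity level.
    mult = torch.distributions.Normal(0.0, 1.0).icdf(torch.tensor(sparsity))
    mean = x.mean(dim=-1, keepdim=True)
    std = x.std(dim=-1, keepdim=True)
    cutoff = mean + std * mult
    return torch.relu(x - cutoff)

x = torch.randn(2, 4, 64)
sparse = gaussian_topk_sparsify(x, sparsity=0.95)
dense = gaussian_topk_sparsify(x, sparsity=0.0)
print(bool((sparse == 0).float().mean() > (dense == 0).float().mean()))  # True
```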
- Changed Gemma3nModel to inherit from CausalLM instead of @Final TransformerLM
- Added scale parameter to create_default_sdpa() for models with QK normalization
- Created SoftcappedProjection wrapper to avoid subclassing @Final classes
- Fixed all basedpyright type errors (0 errors, 123 warnings)
- Removed dead code: KVSharedLayerRegistry, test_kv_sharing.py, get_kv_sharing_config()
- Cleaned up stale docstring parameters in decoder_layer.py

Parity verified: 100% token agreement maintained
Asset System Integration:
- Register gemma3n model family in composition/models.py
- Register gemma3n tokenizer family in composition/tokenizers.py
- Create hub.py with get_gemma3n_model_hub and get_gemma3n_tokenizer_hub
- Add asset cards for all 4 variants (E2B/E4B base and instruct)
- Register config variants (e2b, e4b) with ConfigRegistrar
Model Configuration:
- Fix E4B altup_hidden_dim: 5376 → 16384 to match checkpoint
- E2B uses 8192 for both ffn_inner_dim and altup_hidden_dim
- E4B uses 16384 for both (all layers use same dimension)
Checkpoint Loading:
- Filter multimodal components from HF checkpoints (vision_tower, audio_tower, embed_audio, embed_vision)
- Remove checkpoint.py stub (not needed - SafetensorsCheckpointLoader handles it)
- State dict conversion happens in interop.py via convert_gemma3n_state_dict
Tokenizer Implementation:
- Create Gemma3nTokenizer wrapping HuggingFaceTokenModel
- Add BOS prefix (not EOS suffix) matching HF behavior
- Implement apply_chat_template() for instruct formatting
- Expose chat_template property
- Support modes: default (BOS prefix), prompt (BOS prefix), as_is (no special tokens)
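The mode handling above reduces to a small dispatch; this sketch uses a hypothetical BOS id purely for illustration:

```python
BOS_ID = 2  # assumed id for illustration only

def encode_with_mode(token_ids: list, mode: str = "default") -> list:
    """Sketch of the tokenizer modes: BOS is prepended (never an EOS
    suffix), matching the HF behavior described above."""
    if mode in ("default", "prompt"):
        return [BOS_ID] + list(token_ids)
    if mode == "as_is":
        return list(token_ids)  # no special tokens
    raise ValueError(f"unknown mode: {mode!r}")

print(encode_with_mode([10, 11]))            # [2, 10, 11]
print(encode_with_mode([10, 11], "as_is"))   # [10, 11]
```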
Validation Scripts:
- discover_variants.py: Find all Gemma3n variants on HuggingFace
- compare_tokenizers.py: Compare base vs instruct tokenizer differences
- test_asset_loading.py: Test tokenizer loading for all variants
- test_tokenizer_parity.py: Verify encoding/decoding matches HF exactly
- test_chat_template.py: Verify chat template formatting matches HF
- test_torchsdpa_parity.py: Existing SDPA parity validation
Testing:
- ✅ Model loading works for E2B and E4B (base and instruct)
- ✅ Tokenizer encoding matches HuggingFace exactly (BOS prefix)
- ✅ Special token IDs match HuggingFace
- ✅ Chat template formatting matches HuggingFace
- ✅ All 4 variants load successfully
API Usage:

```python
from fairseq2.models import load_model
from fairseq2.data.tokenizers import load_tokenizer

model = load_model("gemma3n_e2b")
tokenizer = load_tokenizer("gemma3n_e2b_instruct")

# Chat template support
conversation = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there!"},
]
formatted = tokenizer.apply_chat_template(conversation, tokenize=False)
```
torchaudio.load requires torchcodec which isn't installed. fairseq2's AudioDecoder uses libsndfile via fairseq2n bindings.
Avoids separate tokenizer download that fails on gated repos. AutoTokenizer uses the same auth/cache path as AutoModelForCausalLM.
- Try fairseq2 hub (load_model/load_tokenizer) first
- Fall back to direct safetensors loading from HF cache
- Load tokenizer via AutoTokenizer only as last resort
- Remove all direct HF model imports from main path
Uses fairseq2 hub for model and tokenizer loading, fairseq2 AudioDecoder for audio I/O, custom mel extraction. Zero HF dependencies in the inference path.
- HF Gemma3nTextConfig uses 15 for both E2B and E4B
- Wrong value caused 5 layers to compute independent KV instead of sharing, breaking E4B inference (checkpoint weight role mismatch)
- Also converts docstrings to :param:/:returns: style
- subsample.py: Replace GroupNorm with CumulativeGroupNorm (cumulative running stats along time), SiLU with ReLU, and symmetric conv padding with reverse-causal F.pad matching the HF reference
- sdpa.py: Always create a validity mask for zero-padded block context positions (previously only when an explicit mask was passed, leaving padding positions unmasked and diluting attention)
- sdpa.py: Add reset_non_persistent_buffers() so q_scale, softcap, local_causal_valid_mask, and inv_timescales are re-initialized after checkpoint loading (fairseq2 zeros non-persistent buffers on load)
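The padding change can be sketched with `F.pad`: instead of splitting the `kernel_size - 1` padding frames symmetrically around the sequence, all of them go on one side of the time axis. Which side is "reverse-causal" here follows the HF reference; the sketch below pads the trailing end and is illustrative only:

```python
import torch
import torch.nn.functional as F

def reverse_causal_pad(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
    """Sketch: put all kernel_size - 1 conv padding frames on one end
    of the time axis instead of splitting them symmetrically."""
    # x: (batch, channels, time); F.pad pads the last dim as (left, right).
    return F.pad(x, (0, kernel_size - 1))

x = torch.randn(1, 4, 10)
padded = reverse_causal_pad(x, kernel_size=3)
print(padded.shape)  # torch.Size([1, 4, 12])
```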
… KV sharing
- Add FSDP wrap policy (layer/stack granularity) following TransformerLM pattern
- Add activation checkpointing (layerwise) following TransformerLM pattern
- Implement HuggingFace export (reverse key map + config conversion)
- Add chat template with {%- generation %} blocks for SFT assistant mask
- Move KV projection slot management from decoder_layer to the decoder forward loop so dict mutation side effects work through FSDP-wrapped layers
- Register fsdp_applier, layerwise_ac_applier, hg_exporter in composition
- Add E2B GSM8K SFT config for smoke testing (batch_size=4, chat_mode=true)
- Fix isort, black, flake8, mypy issues in new code
- scripts/gemma3n_validation/ was development-only parity testing
- tmp-docs/ was implementation planning artifacts
FSDP support is implemented via fsdp.py; sharder is not needed.
Gemma3n uses HTK-style mel extraction incompatible with fairseq2 defaults.
- Run isort/black on all gemma3n source and test files
- Remove unused imports (PositionEncoder, Final, GELU)
- Add type: ignore comments for buffer-as-tensor false positives
- Fix missing Projection import in ffn.py
- Add TYPE_CHECKING import for Gemma3nAudioTower in factory
- Remove dead code block in GLUFeedForwardNetwork
- Remove unused test variables
rsyue (Contributor) reviewed Mar 19, 2026:
Really nice to see a clean, proper multimodal implementation via an HG model! Here are some suggestions:
Training config could be a bit more tuned
I do know this is for a quickstart, but wanted to give some input here for a realistic training scenario and for users who might want to start from a baseline.
- fp32 might be a bit overkill. Example: Olmo train config and default Gemma3 causal use bf16 with static mixed precision.
- Default learning rate (3e-4) is quite high for a pretrained model; catastrophic forgetting is a possibility.
- LR warmup steps would be useful early in the training process (Olmo config uses 100)
- SDPA default (none) is torch_math, slower and more memory hungry. Consider FlashAttention2 here as a default
- For reproducibility, common.seed would be useful (if training reproducibility is the goal, ignore the non-deterministic fa2 comment above)
- Setting max_seq_len or max_num_tokens to correspond to the model defaults or thereabouts would be good in an experimental setting
- Not 100% sure if the checkpoint is still in native hf format, but export_hugging_face should be set to False if so. Gemma3 config disables this (defaults to True)
General questions
- Coordinating chat_mode, currently True. This is a good fairseq2-esque setting, but dataset.py:207 checks for HuggingFaceTokenEncoder and applies the chat template, assuming assistant_masks in output. --> Check if gemma3n tokenizer reliably produces the masks. If I recall, chat_mode is {"src": "Instruction here", "tgt": "Assistant answer here"}
- See comment in the config yaml: I got a KeyError. It seems like dataset.py:181 does a raw dict lookup without error handling. Perhaps chat with team about this, but in the meantime, error handling might be good
Minor but good to have in config
- keep_last_n_checkpoints: 10 with ckpt every 100 steps could miss out on better previous ckpts, if this is used in a baseline setting
artemru approved these changes Mar 24, 2026
zyaoj approved these changes Mar 24, 2026
- Removed dead code path from mha
- Enabled huggingface export on checkpoint (regression + bug)
- Changed interop to not modify caller vars
- Residual str expr using correct members
Author (Contributor): Thanks to everybody for contributing!
cirquit added a commit that referenced this pull request Mar 25, 2026
**Summary**
- Fix test and lint regressions introduced in v0.8 development
- Add backward compatibility shims for breaking API changes
- Update CHANGELOG and README for v0.8 release

**Regressions**
- Fix `test_get_shard_dims_work` device mismatch when running with `--device cuda`
- Replace deprecated `datetime.utcnow()` with `datetime.now(timezone.utc)`
- Bump `black` to `~=26.3` (CVE fix) and reformat lines that the new parser rejects
- Fix `Flash3SDPA` to support `flash-attn-3` v3.0.0 API (#1495)
- Pin `pandas~=2.2` for Python 3.12 compatibility

**Backward compatibility shims**
- Add re-export shims for `fairseq2.recipe.validator` and `fairseq2.recipe.task` (#1417)
- Add deprecated `resolve_optional()` on `DependencyResolver` (#1462)
- Add deprecated `ModelCheckpointError` alias for `CorruptModelCheckpointError` (#1475)

**Release prep**
- Update CHANGELOG with missing entries, PR references, and new features (#1479, #1496)
- Add v0.7 and v0.8 rows to README version matrix
YunchaoYang pushed a commit that referenced this pull request Mar 31, 2026
…T training (#1496)

Full fairseq2 implementation of [Google's Gemma3n model family (E2B and E4B variants)](https://ai.google.dev/gemma/docs/gemma-3n), achieving inference parity with HuggingFace on audio and text. Vision tower is not included.

**Text Tower**
- Gemma3nModel (CausalLM) with AltUp 4D predict/correct, Per-Layer Embeddings (PLE), LAuReL low-rank residual connections, QK normalization, and KV projection sharing across layers
- Local/global layer pattern with sliding window and full attention
- Activation sparsity (Gaussian top-k) for local layers

**Audio Tower**
- USM Conformer encoder (12 layers) with chunked local attention, Shaw relative position embeddings, and per-dimension scaling
- Subsample convolution projection with CumulativeGroupNorm (4x downsampling)
- Multimodal embedder projecting audio features to text space

**Integration**
- Asset cards for all 4 variants (E2B/E4B, base/instruct)
- HF checkpoint loading with full key mapping (text + audio)
- HF export with nested Gemma3nTextConfig and `rope_parameters` mapping
- Tokenizer with chat template support (generation blocks for SFT masking)
- FSDP wrap policy and activation checkpointing
- SFT recipe config (`gemma3n_e2b_gsm8k.yaml`)

**Validation**
- 100% token prediction agreement with HuggingFace (text inference)
- HG export → transformers round-trip loading verified
- 64 unit tests covering KV sharing, AltUp, LAuReL, PLE, activation sparsity, audio components
YunchaoYang pushed a commit that referenced this pull request Mar 31, 2026
What does this PR do? Please describe:
Adds full Gemma3n (E2B/E4B) support: model architecture, audio tower, tokenizer, and training infrastructure. No vision.
Test:
torchrun --nproc-per-node=8 -m recipes.lm.sft /tmp/gemma3n_sft_test \
  --config-file recipes/lm/sft/configs/gemma3n_e2b_gsm8k.yaml \
  --config trainer.data_parallelism=fsdp \
  --config regime.num_steps=200 \
  --config regime.checkpoint_every_n_steps=100 \
  --config regime.validate_every_n_steps=100

Does your PR introduce any breaking changes? If yes, please list them: