Add SarvamMLA model (sarvamai/sarvam-105b)#44569

Open
aashay-sarvam wants to merge 8 commits into huggingface:main from aashay-sarvam:add-sarvam-mla-model

Conversation

@aashay-sarvam

What does this PR do?

Adds native support for the sarvam_mla model type (sarvamai/sarvam-105b) to HuggingFace Transformers using the modular pattern, inheriting from DeepSeek V3.

Model Architecture

SarvamMLA is a 105B parameter Mixture of Experts (MoE) language model developed by Sarvam AI. It uses:

  • Multi-head Latent Attention (MLA): Low-rank KV compression with decoupled RoPE
  • Sparse MoE: 128 routed experts, 8 active per token, plus 1 shared expert
  • First layer dense: first_k_dense_replace=1 (Layer 0 = dense MLP, Layer 1+ = MoE)
  • DeepSeek YaRN RoPE: Extended context up to 131K tokens
  • Sigmoid routing with group-based top-k
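
As a rough illustration of the routing described above (sigmoid expert scores with group-based top-k selection), here is a minimal sketch. The function name, grouping scheme, and dimensions are assumptions for illustration, not the PR's actual code:

```python
import torch

def group_topk_route(hidden, gate_weight, n_groups=8, topk_groups=4, top_k=8):
    """Toy sigmoid router with group-based top-k (illustrative only).

    Scores every expert with a sigmoid, keeps only the best-scoring
    expert groups, then picks the top-k experts inside those groups.
    """
    # (tokens, n_experts) affinity scores via sigmoid, DeepSeek V3-style
    scores = torch.sigmoid(hidden @ gate_weight)            # (T, E)
    T, E = scores.shape
    grouped = scores.view(T, n_groups, E // n_groups)       # (T, G, E/G)
    # rank groups by their best expert score, keep topk_groups of them
    group_scores = grouped.max(dim=-1).values               # (T, G)
    keep = group_scores.topk(topk_groups, dim=-1).indices   # (T, kept_G)
    mask = torch.zeros_like(group_scores).scatter(1, keep, 1.0)
    masked = (grouped * mask.unsqueeze(-1)).view(T, E)
    # final top-k experts per token; weights renormalized over the chosen k
    topk_vals, topk_idx = masked.topk(top_k, dim=-1)
    weights = topk_vals / topk_vals.sum(dim=-1, keepdim=True)
    return topk_idx, weights

# Example: 2 tokens, hidden size 16, 128 experts (as in the config above)
hidden = torch.randn(2, 16)
gate = torch.randn(16, 128)
idx, w = group_topk_route(hidden, gate)
```

With 128 experts and top_k=8 this mirrors the "8 active per token" setup; the shared expert would be applied to every token outside the router.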

Files Added

  • src/transformers/models/sarvam_mla/__init__.py — Lazy loading module
  • src/transformers/models/sarvam_mla/configuration_sarvam_mla.py — Config with Hub compatibility (head_dim, rope_type normalization)
  • src/transformers/models/sarvam_mla/modular_sarvam_mla.py — 48-line modular file inheriting DeepSeek V3
  • src/transformers/models/sarvam_mla/modeling_sarvam_mla.py — Auto-generated from modular (736 lines)
  • tests/models/sarvam_mla/test_modeling_sarvam_mla.py — Unit tests
  • docs/source/en/model_doc/sarvam_mla.md — Documentation with usage examples

Files Modified

  • src/transformers/models/auto/configuration_auto.py — CONFIG_MAPPING_NAMES, MODEL_NAMES_MAPPING
  • src/transformers/models/auto/modeling_auto.py — MODEL_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM, SEQUENCE_CLASSIFICATION, TOKEN_CLASSIFICATION
  • src/transformers/models/__init__.py — import
  • src/transformers/conversion_mapping.py — "sarvam_mla": "qwen2_moe" (per-expert → batched weight conversion)
  • docs/source/en/_toctree.yml — docs index
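
The per-expert → batched weight conversion mentioned above can be pictured like this; the key names and tensor shapes are illustrative, not the exact Hub checkpoint layout or the transformers conversion-mapping machinery:

```python
import torch

def batch_expert_weights(state_dict, n_experts, prefix="model.layers.1.mlp"):
    """Stack per-expert ModuleList weights into one batched tensor.

    Illustrative sketch of the qwen2_moe-style conversion: real checkpoints
    go through transformers' conversion mapping, and key names may differ.
    """
    out = dict(state_dict)
    for proj in ("gate_proj", "up_proj", "down_proj"):
        keys = [f"{prefix}.experts.{i}.{proj}.weight" for i in range(n_experts)]
        # (n_experts, out_features, in_features) batched weight
        out[f"{prefix}.experts.{proj}"] = torch.stack([out.pop(k) for k in keys])
    return out

# Example with 4 toy experts and tiny 8x4 projection weights
sd = {
    f"model.layers.1.mlp.experts.{i}.{p}.weight": torch.randn(8, 4)
    for i in range(4)
    for p in ("gate_proj", "up_proj", "down_proj")
}
converted = batch_expert_weights(sd, n_experts=4)
```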

Hub Compatibility Fixes (in config)

  1. head_dim override: Hub config has head_dim: 576 (for vLLM MLA compat), but internally the model uses qk_rope_head_dim = 64 for RoPE. Popped from kwargs.
  2. deepseek_yarn rope type: Hub config uses "type": "deepseek_yarn" but ROPE_INIT_FUNCTIONS only has "yarn". Normalized in config __init__.
  3. Weight conversion: Per-expert ModuleList weights on Hub need conversion to batched format. Handled via qwen2_moe conversion pattern.
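
A minimal sketch of the kind of config-side normalization points 1–2 describe, operating on a raw config.json dict. The function and the exact dict layout are assumptions for illustration, not the PR's final code:

```python
def normalize_hub_config(cfg: dict) -> dict:
    """Illustrative cleanup of a Hub config.json dict before building the
    transformers config. Mirrors the fixes above; not the PR's actual code."""
    cfg = dict(cfg)
    # 1. head_dim on the Hub is kv_lora_rank + qk_rope_head_dim (for vLLM
    #    MLA compat); transformers derives the RoPE head dim itself, so drop it.
    cfg.pop("head_dim", None)
    cfg.pop("q_head_dim", None)
    # 2. ROPE_INIT_FUNCTIONS only knows "yarn", so map deepseek_yarn onto it.
    if "rope_scaling" in cfg:
        rope = dict(cfg["rope_scaling"])
        if rope.get("type") == "deepseek_yarn":
            rope["type"] = "yarn"
        cfg["rope_scaling"] = rope
    return cfg

# Toy input echoing the Hub values mentioned above
hub_cfg = {
    "head_dim": 576,
    "qk_rope_head_dim": 64,
    "rope_scaling": {"type": "deepseek_yarn", "factor": 4.0},
}
clean = normalize_hub_config(hub_cfg)
```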

Test Results

  • Unit tests: 140 passed, 92 skipped (on GPU node)
  • End-to-end test: Full 105B model loaded in bf16 across 8× H100 80GB GPUs
    • Config loads as SarvamMLAConfig with model_type=sarvam_mla
    • Layer 0 MLP = SarvamMLAMLP (dense), Layer 1+ = SarvamMLAMoE
    • Generation produces coherent text ✓
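
The end-to-end check above corresponds roughly to this usage pattern (the repo id is from the PR title; the dtype and device settings are illustrative, and loading the full 105B checkpoint needs multi-GPU sharding, e.g. 8× H100):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "sarvamai/sarvam-105b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # bf16, as in the end-to-end test
    device_map="auto",           # shard across available GPUs
)
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```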

Who can review?

@ArthurZucker

@Rocketknight1
Member

Hi @aashay-sarvam, it looks like the architecture is identical to Deepseek V3! Can you just upload your checkpoints with that model type instead?

@vasqu (Contributor) left a comment

Hey there 👋

I have left a few smaller comments, but essentially we can already use the existing code. I assume that you want to keep the model type, which is what my comments expect. It could also be updated to work with DeepSeek V3 directly (like Matt mentioned before me), but then it loses its identity, I guess (no model type, no default config).

@vasqu (Contributor) commented Mar 13, 2026

#44569 (comment) might've gotten lost

@aashay-sarvam
Author

#44569 (comment) might've gotten lost

Missed this - will make the changes

@vasqu (Contributor) left a comment

Some new comments, because things changed a bit on main - sorry 😓

For the CI, running make fix-repo should fix most of the smaller things

Comment on lines +83 to +90
# Hub config.json uses num_experts/num_shared_experts; map to parent names
n_routed_experts = kwargs.pop("num_experts", n_routed_experts)
n_shared_experts = kwargs.pop("num_shared_experts", n_shared_experts)

# head_dim in Hub config.json is kv_lora_rank + qk_rope_head_dim (for vLLM
# MLA compat), but DeepseekV3Config computes it as qk_rope_head_dim.
kwargs.pop("head_dim", None)
kwargs.pop("q_head_dim", None)
Contributor

Would it be possible to change the remote config instead of adding workarounds?

**kwargs,
)

def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: set | None = None, **kwargs):
Contributor

Same here, we could properly change the remote config instead to have a proper attribute for this

Author

Have made the changes, though I still need to push the updated config to the model repo (I have tested locally)

Author

Also, a question: sglang uses SarvamMLA; will that break with this change?

Contributor

Hmm, not too familiar with sglang's integration here tbh. Imo, it can and should be able to use the (native) deepseek architecture; might need a nudge to respect the architecture 🤔

Contributor

cc @adarshxs if you have any insights re sglang

Add native support for the sarvam_mla model type using the modular
pattern, inheriting from DeepSeek V3. The model uses Multi-head Latent
Attention (MLA) with Mixture of Experts (MoE), supporting 105B parameters
with 128 routed experts and 8 active per token.

New files:
- configuration_sarvam_mla.py: Config with attribute mapping, rope
  normalization, and head_dim handling for Hub compatibility
- modular_sarvam_mla.py: 48-line modular file inheriting DeepSeek V3
- modeling_sarvam_mla.py: Auto-generated from modular (736 lines)
- test_modeling_sarvam_mla.py: 140 passing unit tests
- sarvam_mla.md: Documentation with usage examples

Modified files:
- Auto-registration in configuration_auto.py, modeling_auto.py
- Model import in models/__init__.py
- Weight conversion mapping (qwen2_moe pattern) in conversion_mapping.py
- Documentation index in _toctree.yml

Made-with: Cursor
Per vasqu's review:
- Remove modular_sarvam_mla.py and modeling_sarvam_mla.py (no need
  to re-implement identical DeepSeek V3 architecture)
- Point auto mappings directly to DeepseekV3 model classes
- Move rope type normalization (deepseek_yarn -> yarn) to
  convert_rope_params_to_dict override
- Remove test file (DeepseekV3 tests cover the architecture)
- Slim down docs to config-only autodoc

Made-with: Cursor
Move SarvamMLAConfig definition into modular_sarvam_mla.py and
auto-generate configuration_sarvam_mla.py from it, following the
canonical transformers modular pattern.

Made-with: Cursor
- Remove torch_dtype="auto" from docs (now default)
- Simplify modular_sarvam_mla.py to only override defaults that differ
  from DeepseekV3Config (no __init__, no workarounds)
- Add @strict(accept_kwargs=True) for config validation (huggingface#41250)
- Regenerate configuration_sarvam_mla.py with dataclass fields and
  __post_init__ pattern
- Hub config.json changes needed: remove head_dim/q_head_dim, change
  rope_scaling.type to "yarn", update architectures

Made-with: Cursor
@github-actions

[For maintainers] Suggested jobs to run (before merge)

run-slow: auto

@github-actions

View the CircleCI Test Summary for this PR:

https://huggingface.co/spaces/transformers-community/circle-ci-viz?pr=44569&sha=3d9969
