-
Notifications
You must be signed in to change notification settings - Fork 62
Support transformers loading quantized moe model #1067
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
1095f74
Support transformers loading quantized moe model
mengniwang95 7196a91
add ut
mengniwang95 147afda
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 173806e
Update test_moe_model.py
mengniwang95 8f784e5
Update test_moe_model.py
mengniwang95 a800891
Update convert_model.py
mengniwang95 d87b64d
fix ut and add cuda ut
mengniwang95 af449c6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 575e76e
Merge branch 'main' into mengni/tf_load
mengniwang95 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| import shutil | ||
|
|
||
| import pytest | ||
| from transformers import AutoConfig, AutoTokenizer, Llama4ForConditionalGeneration | ||
| from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM | ||
|
|
||
| from auto_round import AutoRound | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def setup_gpt_oss(): | ||
| """Fixture to set up the GPT-OSS model and tokenizer.""" | ||
| model_name = "/tf_dataset/auto_round/models/unsloth/gpt-oss-20b-BF16" | ||
| tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) | ||
| config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) | ||
| config.num_hidden_layers = 1 # Reduce layers for testing | ||
| model = GptOssForCausalLM(config) | ||
| output_dir = "/tmp/test_quantized_gpt_oss" | ||
| return model, tokenizer, output_dir, config | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def setup_llama4(): | ||
| """Fixture to set up the llama4 model and tokenizer.""" | ||
| model_name = "/tf_dataset/auto_round/models/meta-llama/Llama-4-Scout-17B-16E-Instruct" | ||
| tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) | ||
| config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) | ||
| config.vision_config.num_hidden_layers = 2 # Reduce layers for testing | ||
| config.text_config.num_hidden_layers = 2 | ||
| model = Llama4ForConditionalGeneration(config) | ||
| output_dir = "/tmp/test_quantized_llama4" | ||
| return model, tokenizer, output_dir, config | ||
|
|
||
|
|
||
| def quantize_model(model, tokenizer, output_dir, scheme, iters=0): | ||
| """Helper function to quantize the model with the given scheme.""" | ||
| autoround = AutoRound( | ||
| model, | ||
| tokenizer, | ||
| scheme=scheme, | ||
| nsamples=2, | ||
| iters=iters, | ||
| fp_layers="self_attn,router,lm_head,mlp.gate", | ||
| ) | ||
| quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir) | ||
| return quantized_model | ||
|
|
||
|
|
||
| def test_gptoss(setup_gpt_oss): | ||
yiliu30 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| model, tokenizer, output_dir, config = setup_gpt_oss | ||
|
|
||
| # Below parameter is set to be same as the full model | ||
| # Remove it to avoid mismatch during quantized model loading | ||
| delattr(model.config, "layer_types") | ||
|
|
||
| quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4") | ||
|
|
||
| # Ensure the quantized model is not None | ||
| assert quantized_model is not None, "Quantized model should not be None." | ||
|
|
||
| loaded_model = GptOssForCausalLM.from_pretrained(output_dir) | ||
| for n, m in quantized_model.named_modules(): | ||
| if m.__class__.__name__ == "QuantLinear": | ||
| loaded_m = loaded_model.get_submodule(n) | ||
| assert (loaded_m.weight_packed.to("cpu") == m.weight_packed.to("cpu")).all() | ||
| # clean the output directory after test | ||
| shutil.rmtree(output_dir, ignore_errors=True) | ||
|
|
||
|
|
||
| def test_llama4(setup_llama4): | ||
| model, tokenizer, output_dir, config = setup_llama4 | ||
|
|
||
| # Below parameters are set to be same as the full model | ||
| # Remove them to avoid mismatch during quantized model loading | ||
| model.config.text_config.no_rope_layers = [] | ||
| delattr(model.config.text_config, "moe_layers") | ||
| delattr(model.config.text_config, "layer_types") | ||
|
|
||
| quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4") | ||
|
|
||
| # Ensure the quantized model is not None | ||
| assert quantized_model is not None, "Quantized model should not be None." | ||
|
|
||
| loaded_model = Llama4ForConditionalGeneration.from_pretrained(output_dir) | ||
| for n, m in quantized_model.named_modules(): | ||
| if m.__class__.__name__ == "QuantLinear": | ||
| loaded_m = loaded_model.get_submodule(n) | ||
| assert (loaded_m.weight_packed.to("cpu") == m.weight_packed.to("cpu")).all() | ||
| # clean the output directory after test | ||
| shutil.rmtree(output_dir, ignore_errors=True) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| import shutil | ||
|
|
||
| import pytest | ||
| import torch | ||
| from transformers import AutoConfig, AutoTokenizer, Llama4ForConditionalGeneration | ||
| from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM | ||
|
|
||
| from auto_round import AutoRound | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def setup_gpt_oss(): | ||
| """Fixture to set up the GPT-OSS model and tokenizer.""" | ||
| model_name = "/models/gpt-oss-20b-BF16" | ||
| tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) | ||
| config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) | ||
| config.num_hidden_layers = 1 # Reduce layers for testing | ||
| model = GptOssForCausalLM(config) | ||
| output_dir = "test_quantized_gpt_oss" | ||
| return model, tokenizer, output_dir, config | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def setup_llama4(): | ||
| """Fixture to set up the llama4 model and tokenizer.""" | ||
| model_name = "/dataset/Llama-4-Scout-17B-16E-Instruct" | ||
| tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) | ||
| config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) | ||
| config.vision_config.num_hidden_layers = 2 # Reduce layers for testing | ||
| config.text_config.num_hidden_layers = 2 | ||
| model = Llama4ForConditionalGeneration(config) | ||
| output_dir = "test_quantized_llama4" | ||
| return model, tokenizer, output_dir, config | ||
|
|
||
|
|
||
| def quantize_model(model, tokenizer, output_dir, scheme, iters=0): | ||
| """Helper function to quantize the model with the given scheme.""" | ||
| autoround = AutoRound( | ||
| model, | ||
| tokenizer, | ||
| scheme=scheme, | ||
| nsamples=2, | ||
| iters=iters, | ||
| fp_layers="self_attn,router,lm_head,mlp.gate", | ||
| ) | ||
| quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir) | ||
| return quantized_model | ||
|
|
||
|
|
||
| def test_gptoss(setup_gpt_oss): | ||
| model, tokenizer, output_dir, config = setup_gpt_oss | ||
|
|
||
| # Below parameter is set to be same as the full model | ||
| # Remove it to avoid mismatch during quantized model loading | ||
| delattr(model.config, "layer_types") | ||
|
|
||
| quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4") | ||
|
|
||
| # Ensure the quantized model is not None | ||
| assert quantized_model is not None, "Quantized model should not be None." | ||
|
|
||
| loaded_model = GptOssForCausalLM.from_pretrained(output_dir) | ||
| quantized_model.to("cuda") | ||
| loaded_model.to("cuda") | ||
| for n, m in quantized_model.named_modules(): | ||
| if m.__class__.__name__ == "QuantLinear": | ||
| loaded_m = loaded_model.get_submodule(n) | ||
| assert (loaded_m.weight_packed == m.weight_packed).all() | ||
|
|
||
| inp = torch.randint(0, 100, (1, 64)).to("cuda") | ||
| with torch.inference_mode(): | ||
| loaded_out = loaded_model(inp) | ||
|
|
||
| # clean the output directory after test | ||
| shutil.rmtree(output_dir, ignore_errors=True) | ||
|
|
||
|
|
||
| def test_llama4(setup_llama4): | ||
| model, tokenizer, output_dir, config = setup_llama4 | ||
|
|
||
| # Below parameters are set to be same as the full model | ||
| # Remove them to avoid mismatch during quantized model loading | ||
| model.config.text_config.no_rope_layers = [] | ||
| delattr(model.config.text_config, "moe_layers") | ||
| delattr(model.config.text_config, "layer_types") | ||
|
|
||
| quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4") | ||
|
|
||
| # Ensure the quantized model is not None | ||
| assert quantized_model is not None, "Quantized model should not be None." | ||
|
|
||
| loaded_model = Llama4ForConditionalGeneration.from_pretrained(output_dir) | ||
| quantized_model.to("cuda") | ||
| loaded_model.to("cuda") | ||
| for n, m in quantized_model.named_modules(): | ||
| if m.__class__.__name__ == "QuantLinear": | ||
| loaded_m = loaded_model.get_submodule(n) | ||
| assert (loaded_m.weight_packed == m.weight_packed).all() | ||
|
|
||
| inp = torch.randint(0, 100, (1, 64)).to("cuda") | ||
| with torch.inference_mode(): | ||
| loaded_out = loaded_model(inp) | ||
|
|
||
| # clean the output directory after test | ||
| shutil.rmtree(output_dir, ignore_errors=True) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.