4 changes: 4 additions & 0 deletions auto_round/inference/convert_model.py
@@ -31,6 +31,7 @@
 from auto_round.inference.utils import _expand_regex_config
 from auto_round.logger import logger
 from auto_round.schemes import QuantizationScheme
+from auto_round.special_model_handler import _handle_moe_model
 from auto_round.utils import (
     SUPPORTED_LAYER_TYPES,
     check_start_with_block_name,
@@ -582,6 +583,9 @@ def convert_hf_model(model: nn.Module, target_device: str = "cpu") -> tuple[nn.M
     elif packing_format == "auto_round:gptq":
         packing_format = "auto_round:auto_gptq"
 
+    # Preprocess model before replace layers
+    model = _handle_moe_model(model)
+
     # Replace layers with quantized versions
     layer_configs = get_layer_config(model, quantization_config)
     used_backends = _replace_by_quant_layers(model, layer_configs, backend, target_device, packing_format)
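The functional change is the single `_handle_moe_model` call placed ahead of `_replace_by_quant_layers`. Ordering matters because the fused MoE experts (see the gpt_oss and llama4 diffs below) hold their weights as 3-D parameters on one container module rather than as `nn.Linear` layers, so the replacement pass would otherwise have nothing to match. A self-contained toy illustration (the `FusedExperts` class below is a stand-in, not the real transformers module):

```python
# Toy illustration only: FusedExperts is a stand-in, not the real transformers module.
import torch
import torch.nn as nn

class FusedExperts(nn.Module):
    """Fused-expert container: one 3-D parameter per projection, no nn.Linear children."""

    def __init__(self, E=4, H=8, I=16):
        super().__init__()
        self.gate_up_proj = nn.Parameter(torch.randn(E, H, 2 * I))
        self.down_proj = nn.Parameter(torch.randn(E, I, H))

fused = FusedExperts()
# Nothing here is an nn.Linear, so a pass that swaps Linear layers for quantized
# ones would silently skip the experts.
print(any(isinstance(m, nn.Linear) for m in fused.modules()))  # False

# After a _handle_moe_model-style rewrite, each expert exposes ordinary Linear layers:
per_expert = nn.ModuleList(
    nn.ModuleDict(
        {
            "gate_proj": nn.Linear(8, 16, bias=False),
            "up_proj": nn.Linear(8, 16, bias=False),
            "down_proj": nn.Linear(16, 8, bias=False),
        }
    )
    for _ in range(4)
)
print(sum(isinstance(m, nn.Linear) for m in per_expert.modules()))  # 12
```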
22 changes: 10 additions & 12 deletions auto_round/modelling/gpt_oss.py
@@ -19,6 +19,8 @@
 from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
 from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP
 
+from auto_round.utils import unsupported_meta_device
+
 __all__ = ["get_replacement_info"]
 
 
@@ -82,19 +84,15 @@ def __init__(self, config: GptOssConfig, original: GptOssMLP):
         for _ in range(E):
             self.experts.append(GPTOssSingleExpert(hidden_size, intermediate_size, dtype=dtype))
 
-        gup = original.experts.gate_up_proj  # [E, H, 2I]
-        gup_b = original.experts.gate_up_proj_bias  # [E, 2I]
-        dwn = original.experts.down_proj  # [E, I, H]
-        dwn_b = original.experts.down_proj_bias  # [E, H]
-
-        for i, mlp in enumerate(self.experts):
-            _update_parameter(mlp.gate_proj, "weight", original.experts.gate_up_proj[i, :, ::2].T)
-            _update_parameter(mlp.up_proj, "weight", original.experts.gate_up_proj[i, :, 1::2].T)
-            _update_parameter(mlp.down_proj, "weight", original.experts.down_proj[i].T)
+        if not unsupported_meta_device(original):
+            for i, mlp in enumerate(self.experts):
+                _update_parameter(mlp.gate_proj, "weight", original.experts.gate_up_proj[i, :, ::2].T)
+                _update_parameter(mlp.up_proj, "weight", original.experts.gate_up_proj[i, :, 1::2].T)
+                _update_parameter(mlp.down_proj, "weight", original.experts.down_proj[i].T)
 
-            _update_parameter(mlp.gate_proj, "bias", original.experts.gate_up_proj_bias[i, ::2])
-            _update_parameter(mlp.up_proj, "bias", original.experts.gate_up_proj_bias[i, 1::2])
-            _update_parameter(mlp.down_proj, "bias", original.experts.down_proj_bias[i])  # [H]
+                _update_parameter(mlp.gate_proj, "bias", original.experts.gate_up_proj_bias[i, ::2])
+                _update_parameter(mlp.up_proj, "bias", original.experts.gate_up_proj_bias[i, 1::2])
+                _update_parameter(mlp.down_proj, "bias", original.experts.down_proj_bias[i])  # [H]
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         B, T, H = hidden_states.shape
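The change here is about meta-device safety: when `original` holds only meta tensors there is no data to copy, so the existing `_update_parameter` calls move under the `unsupported_meta_device` guard (hence the re-indentation); the slicing logic itself is unchanged. For readers new to the fused layout, this self-contained snippet (toy sizes, plain tensors instead of the real `GptOssExperts`) shows what the `[i, :, ::2]` / `[i, :, 1::2]` split does: even columns of the `[E, H, 2I]` `gate_up_proj` are the gate projection, odd columns the up projection, and the transpose yields the `[out_features, in_features]` layout an `nn.Linear` weight expects.

```python
# Self-contained check of the interleaved gate/up split (toy sizes, plain tensors).
import torch

E, H, I = 2, 4, 3                        # experts, hidden size, intermediate size
gate_up_proj = torch.randn(E, H, 2 * I)  # fused weight, gate/up columns interleaved

i = 0                                    # one expert
gate_w = gate_up_proj[i, :, ::2].T       # even columns -> gate projection, [I, H]
up_w = gate_up_proj[i, :, 1::2].T        # odd columns  -> up projection,   [I, H]

x = torch.randn(5, H)                    # a few token activations
fused_out = x @ gate_up_proj[i]          # [5, 2I], still interleaved
assert torch.allclose(fused_out[:, ::2], x @ gate_w.T)   # gate half matches
assert torch.allclose(fused_out[:, 1::2], x @ up_w.T)    # up half matches
print(gate_w.shape, up_w.shape)          # torch.Size([3, 4]) torch.Size([3, 4])
```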
24 changes: 14 additions & 10 deletions auto_round/modelling/llama4.py
@@ -20,23 +20,27 @@
 from transformers.modeling_utils import no_init_weights
 from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
 
+from auto_round.utils import unsupported_meta_device
+
 
 class SequentialLlama4TextExperts(torch.nn.ModuleList):
     def __init__(self, config, original):
         self.num_experts = original.gate_up_proj.shape[0]
         with no_init_weights():
             super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
-        intermediate_size = original.down_proj.shape[1]
-
-        for i in range(self.num_experts):
-            gate_up = original.gate_up_proj[i]
-            down = original.down_proj[i]
-            gate_proj = gate_up[:, :intermediate_size]
-            up_proj = gate_up[:, intermediate_size:]
-
-            self[i].gate_proj.weight.data = gate_proj.t().contiguous()
-            self[i].up_proj.weight.data = up_proj.t().contiguous()
-            self[i].down_proj.weight.data = down.t().contiguous()
+        if not unsupported_meta_device(original):
+            intermediate_size = original.down_proj.shape[1]
+
+            for i in range(self.num_experts):
+                gate_up = original.gate_up_proj[i]
+                down = original.down_proj[i]
+                gate_proj = gate_up[:, :intermediate_size]
+                up_proj = gate_up[:, intermediate_size:]
+
+                self[i].gate_proj.weight.data.copy_(gate_proj.t())
+                self[i].up_proj.weight.data.copy_(up_proj.t())
+                self[i].down_proj.weight.data.copy_(down.t())
 
 
 class SequentialLlama4TextMoe(torch.nn.Module):
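Besides the same meta-device guard, the copy style changes from rebinding `.data` to `copy_`. Rebinding swaps in whatever tensor the right-hand side happens to be, while `copy_` writes into the storage the freshly constructed `Llama4TextMLP` already owns, keeping its dtype and device and accepting the non-contiguous transpose view without an explicit `.contiguous()`. A small standalone comparison (toy layers; names not taken from the repo):

```python
# Toy comparison; layer sizes and names are illustrative only.
import torch
import torch.nn as nn

src = torch.randn(4, 3, dtype=torch.float32)  # e.g. one slice of a fused expert weight

# Rebinding .data replaces the parameter's storage and silently adopts src's dtype.
rebound = nn.Linear(4, 3, bias=False).to(torch.bfloat16)
rebound.weight.data = src.t().contiguous()
print(rebound.weight.dtype)                   # torch.float32

# copy_ writes into the existing storage: the module keeps its dtype/device, and the
# non-contiguous transpose view is accepted without an explicit .contiguous().
copied = nn.Linear(4, 3, bias=False).to(torch.bfloat16)
copied.weight.data.copy_(src.t())
print(copied.weight.dtype)                    # torch.bfloat16
```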
10 changes: 4 additions & 6 deletions auto_round/special_model_handler.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import auto_round.modelling as auto_round_modelling
-from auto_round.utils import LazyImport, logger
+from auto_round.utils import LazyImport, logger, unsupported_meta_device
 
 mllms_with_limited_bs = ("llava", "qwen2_vl", "phi3_v", "mllama")  # Limitations on batch_size
 
@@ -76,8 +76,9 @@ def _handle_moe_model(model, formats=None):
     from auto_round.utils import clear_memory
 
     new_moe_class, convert_config, orig_cls_name = _get_moe_converter(model.config)
-    model = model.to("cpu")
-    clear_memory()
+    if not unsupported_meta_device(model):
+        model = model.to("cpu")
+        clear_memory()
 
     for name, module in tqdm(model.named_modules(), desc="Converting model"):
         cls_name = module.__class__.__name__
@@ -87,9 +88,6 @@
             parent = model.get_submodule(parent)
             setattr(parent, child, new_module)
 
-    logger.warning(
-        f"{model.config.model_type} experts are converted, the quantized model can not run on transformers."
-    )
     return model
 
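The guard mirrors the ones added in the modelling files: if the model's parameters live on the meta device, `model.to("cpu")` cannot move anything and would fail, so the CPU move and `clear_memory()` are skipped. Dropping the `logger.warning` also lines up with the new tests below, which load the saved checkpoints back through `from_pretrained`. A rough, self-contained sketch of the condition being guarded against (`has_meta_params` is a hypothetical stand-in; the real `unsupported_meta_device` lives in `auto_round.utils` and may check more than this):

```python
# has_meta_params is a hypothetical stand-in for the check unsupported_meta_device()
# performs; the real helper lives in auto_round.utils and may do more than this.
import torch
import torch.nn as nn

def has_meta_params(model: nn.Module) -> bool:
    return any(p.is_meta for p in model.parameters())

with torch.device("meta"):
    meta_model = nn.Linear(8, 8)     # shape-only parameters, no backing storage

print(has_meta_params(meta_model))   # True
# meta_model.to("cpu") would fail here (meta tensors have no data to move),
# which is exactly why the .to("cpu") / clear_memory() pair is now skipped.
```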
90 changes: 90 additions & 0 deletions test/test_cpu/test_moe_model.py
@@ -0,0 +1,90 @@
import shutil

import pytest
from transformers import AutoConfig, AutoTokenizer, Llama4ForConditionalGeneration
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM

from auto_round import AutoRound


@pytest.fixture
def setup_gpt_oss():
"""Fixture to set up the GPT-OSS model and tokenizer."""
model_name = "/tf_dataset/auto_round/models/unsloth/gpt-oss-20b-BF16"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config.num_hidden_layers = 1 # Reduce layers for testing
model = GptOssForCausalLM(config)
output_dir = "/tmp/test_quantized_gpt_oss"
return model, tokenizer, output_dir, config


@pytest.fixture
def setup_llama4():
"""Fixture to set up the llama4 model and tokenizer."""
model_name = "/tf_dataset/auto_round/models/meta-llama/Llama-4-Scout-17B-16E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config.vision_config.num_hidden_layers = 2 # Reduce layers for testing
config.text_config.num_hidden_layers = 2
model = Llama4ForConditionalGeneration(config)
output_dir = "/tmp/test_quantized_llama4"
return model, tokenizer, output_dir, config


def quantize_model(model, tokenizer, output_dir, scheme, iters=0):
"""Helper function to quantize the model with the given scheme."""
autoround = AutoRound(
model,
tokenizer,
scheme=scheme,
nsamples=2,
iters=iters,
fp_layers="self_attn,router,lm_head,mlp.gate",
)
quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir)
return quantized_model


def test_gptoss(setup_gpt_oss):
model, tokenizer, output_dir, config = setup_gpt_oss

# Below parameter is set to be same as the full model
# Remove it to avoid mismatch during quantized model loading
delattr(model.config, "layer_types")

quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4")

# Ensure the quantized model is not None
assert quantized_model is not None, "Quantized model should not be None."

loaded_model = GptOssForCausalLM.from_pretrained(output_dir)
for n, m in quantized_model.named_modules():
if m.__class__.__name__ == "QuantLinear":
loaded_m = loaded_model.get_submodule(n)
assert (loaded_m.weight_packed.to("cpu") == m.weight_packed.to("cpu")).all()
# clean the output directory after test
shutil.rmtree(output_dir, ignore_errors=True)


def test_llama4(setup_llama4):
model, tokenizer, output_dir, config = setup_llama4

# Below parameters are set to be same as the full model
# Remove them to avoid mismatch during quantized model loading
model.config.text_config.no_rope_layers = []
delattr(model.config.text_config, "moe_layers")
delattr(model.config.text_config, "layer_types")

quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4")

# Ensure the quantized model is not None
assert quantized_model is not None, "Quantized model should not be None."

loaded_model = Llama4ForConditionalGeneration.from_pretrained(output_dir)
for n, m in quantized_model.named_modules():
if m.__class__.__name__ == "QuantLinear":
loaded_m = loaded_model.get_submodule(n)
assert (loaded_m.weight_packed.to("cpu") == m.weight_packed.to("cpu")).all()
# clean the output directory after test
shutil.rmtree(output_dir, ignore_errors=True)
105 changes: 105 additions & 0 deletions test/test_cuda/test_moe_model.py
@@ -0,0 +1,105 @@
import shutil

import pytest
import torch
from transformers import AutoConfig, AutoTokenizer, Llama4ForConditionalGeneration
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM

from auto_round import AutoRound


@pytest.fixture
def setup_gpt_oss():
"""Fixture to set up the GPT-OSS model and tokenizer."""
model_name = "/models/gpt-oss-20b-BF16"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config.num_hidden_layers = 1 # Reduce layers for testing
model = GptOssForCausalLM(config)
output_dir = "test_quantized_gpt_oss"
return model, tokenizer, output_dir, config


@pytest.fixture
def setup_llama4():
"""Fixture to set up the llama4 model and tokenizer."""
model_name = "/dataset/Llama-4-Scout-17B-16E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config.vision_config.num_hidden_layers = 2 # Reduce layers for testing
config.text_config.num_hidden_layers = 2
model = Llama4ForConditionalGeneration(config)
output_dir = "test_quantized_llama4"
return model, tokenizer, output_dir, config


def quantize_model(model, tokenizer, output_dir, scheme, iters=0):
"""Helper function to quantize the model with the given scheme."""
autoround = AutoRound(
model,
tokenizer,
scheme=scheme,
nsamples=2,
iters=iters,
fp_layers="self_attn,router,lm_head,mlp.gate",
)
quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir)
return quantized_model


def test_gptoss(setup_gpt_oss):
model, tokenizer, output_dir, config = setup_gpt_oss

# Below parameter is set to be same as the full model
# Remove it to avoid mismatch during quantized model loading
delattr(model.config, "layer_types")

quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4")

# Ensure the quantized model is not None
assert quantized_model is not None, "Quantized model should not be None."

loaded_model = GptOssForCausalLM.from_pretrained(output_dir)
quantized_model.to("cuda")
loaded_model.to("cuda")
for n, m in quantized_model.named_modules():
if m.__class__.__name__ == "QuantLinear":
loaded_m = loaded_model.get_submodule(n)
assert (loaded_m.weight_packed == m.weight_packed).all()

inp = torch.randint(0, 100, (1, 64)).to("cuda")
with torch.inference_mode():
loaded_out = loaded_model(inp)

# clean the output directory after test
shutil.rmtree(output_dir, ignore_errors=True)


def test_llama4(setup_llama4):
model, tokenizer, output_dir, config = setup_llama4

# Below parameters are set to be same as the full model
# Remove them to avoid mismatch during quantized model loading
model.config.text_config.no_rope_layers = []
delattr(model.config.text_config, "moe_layers")
delattr(model.config.text_config, "layer_types")

quantized_model = quantize_model(model, tokenizer, output_dir, "MXFP4")

# Ensure the quantized model is not None
assert quantized_model is not None, "Quantized model should not be None."

loaded_model = Llama4ForConditionalGeneration.from_pretrained(output_dir)
quantized_model.to("cuda")
loaded_model.to("cuda")
for n, m in quantized_model.named_modules():
if m.__class__.__name__ == "QuantLinear":
loaded_m = loaded_model.get_submodule(n)
assert (loaded_m.weight_packed == m.weight_packed).all()

inp = torch.randint(0, 100, (1, 64)).to("cuda")
with torch.inference_mode():
loaded_out = loaded_model(inp)

# clean the output directory after test
shutil.rmtree(output_dir, ignore_errors=True)