From e255abb109dd778d1f7230bbbb2462c1e3575d95 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 10 Oct 2025 03:52:14 -0400 Subject: [PATCH 01/10] add gpt oss Signed-off-by: yiliu30 --- auto_round/modelling/__init__.py | 13 +++ auto_round/modelling/gpt_oss.py | 145 ++++++++++++++++++++++++++++ auto_round/special_model_handler.py | 7 +- auto_round/utils.py | 2 +- 4 files changed, 165 insertions(+), 2 deletions(-) create mode 100644 auto_round/modelling/__init__.py create mode 100644 auto_round/modelling/gpt_oss.py diff --git a/auto_round/modelling/__init__.py b/auto_round/modelling/__init__.py new file mode 100644 index 000000000..0d2740cdb --- /dev/null +++ b/auto_round/modelling/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/auto_round/modelling/gpt_oss.py b/auto_round/modelling/gpt_oss.py new file mode 100644 index 000000000..628ab2499 --- /dev/null +++ b/auto_round/modelling/gpt_oss.py @@ -0,0 +1,145 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import contextlib +import gc +import os + +import torch +import transformers.models.gpt_oss as transformers_gpt_oss +from torch import nn +from transformers import AutoModelForCausalLM, AutoTokenizer + + +@contextlib.contextmanager +def align_module_device(module: torch.nn.Module): + device = next(module.parameters()).device + # return with torch.device(device) + try: + yield device + except: + pass + + +from transformers.modeling_utils import no_init_weights as skip_weights_initialize + + +def update_offload_parameter( + module: torch.nn.Module, + name: str, + data: torch.Tensor, +) -> None: + param: torch.nn.Parameter = getattr(module, name) + param.data.copy_(data) + + +def _get_top_k(config): + # GPT-OSS MoE: experts per token + return getattr(config, "num_experts_per_tok", None) or getattr(config, "num_experts_per_token", 1) + + +class GPTOSSMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, dtype=None): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.alpha = 1.702 + self.limit = 7.0 + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True, dtype=dtype) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True, dtype=dtype) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True, dtype=dtype) + + def forward(self, x): + gate = self.gate_proj(x) + up = self.up_proj(x) + gate = gate.clamp(max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + act = (up + 1) * glu + return self.down_proj(act) + + +class SequentialGPTOSSMoE(nn.Module): + """ + Replaces GPT-OSS fused-expert MoE with per-expert GPTOSSMLP modules. + Copies weights from fused tensors and reuses the original router and optional shared_expert. 
+ """ + + def __init__(self, config, original): + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + dtype_str = getattr(config, "torch_dtype", None) or getattr(config, "dtype", None) + dtype = torch.bfloat16 if str(dtype_str).endswith("bfloat16") else torch.float32 + top_k = _get_top_k(config) + self.hidden_size = hidden_size + self.intermediate = intermediate_size + self.top_k = top_k + self.router = original.router + self.shared_expert = getattr(original, "shared_expert", None) + + # Number of experts + E = original.experts.gate_up_proj.shape[0] + self.num_experts = E + + # Build per-expert MLPs + self.experts = nn.ModuleList() + with skip_weights_initialize(), align_module_device(original.experts): + for _ in range(E): + self.experts.append(GPTOSSMLP(hidden_size, intermediate_size, dtype=dtype)) + + gup = original.experts.gate_up_proj # [E, H, 2I] + gup_b = original.experts.gate_up_proj_bias # [E, 2I] + dwn = original.experts.down_proj # [E, I, H] + dwn_b = original.experts.down_proj_bias # [E, H] + + with align_module_device(self.experts): + for i, mlp in enumerate(self.experts): + update_offload_parameter(mlp.gate_proj, "weight", original.experts.gate_up_proj[i, :, ::2].T) + update_offload_parameter(mlp.up_proj, "weight", original.experts.gate_up_proj[i, :, 1::2].T) + update_offload_parameter(mlp.down_proj, "weight", original.experts.down_proj[i].T) + + update_offload_parameter(mlp.gate_proj, "bias", original.experts.gate_up_proj_bias[i, ::2]) + update_offload_parameter(mlp.up_proj, "bias", original.experts.gate_up_proj_bias[i, 1::2]) + update_offload_parameter(mlp.down_proj, "bias", original.experts.down_proj_bias[i]) # [H] + + def forward(self, hidden_states): + B, T, H = hidden_states.shape + x = hidden_states.reshape(-1, H) + + # Use the original router (it returns scores and indices already softmaxed over top-k) + router_scores, router_indices = self.router(x) # scores: [tokens, E], indices: [tokens, k] + + out = self.shared_expert(x) if self.shared_expert is not None else torch.zeros_like(x) + + # Accumulate expert outputs for chosen experts only + for j in range(self.top_k): + idx = router_indices[:, j] + w = router_scores[torch.arange(idx.size(0), device=idx.device), idx].unsqueeze(-1) + unique_experts = torch.unique(idx) + for e in unique_experts: + mask = idx == e + out[mask] += self.experts[e](x[mask]) * w[mask] + + out = out.view(B, T, H) + router_scores = router_scores.view(B * T, -1) # shape doesn't matter much; it’s ignored by the decoder + return out, router_scores + + +def get_replacement_info(config): + return ( + SequentialGPTOSSMoE, + config.get_text_config(), + transformers_gpt_oss.modeling_gpt_oss.GptOssMLP.__name__, + ) diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py index 69e2932c7..f05c00941 100644 --- a/auto_round/special_model_handler.py +++ b/auto_round/special_model_handler.py @@ -36,13 +36,18 @@ } SPECIAL_SHARED_CACHE_KEYS["MiniMaxText01ForCausalLM"] = ("slope_rate",) -CONVERT_EXPERT_TO_LINEAR_MODELS = ["llama4"] +CONVERT_EXPERT_TO_LINEAR_MODELS = ["llama4", "gpt_oss"] def _get_moe_converter(config): import torch from transformers.modeling_utils import no_init_weights + if config.model_type == "gpt_oss": + from auto_round.modelling.gpt_oss import get_replacement_info + + return get_replacement_info(config) + # https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py if config.model_type == "llama4": from 
transformers.models.llama4.modeling_llama4 import Llama4TextMLP diff --git a/auto_round/utils.py b/auto_round/utils.py index 26ec5f996..d4ac23b91 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -1083,7 +1083,7 @@ def get_fp_layer_names(model, fp_layers): for name in all_layer_names: if fp_layer in name: not_to_quantized_layers.append(name) - + logger.trace(f"not_to_quantized_layers: {not_to_quantized_layers}") return not_to_quantized_layers From 4340b35f378466c8c7e78082ccece0df415192e3 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 10 Oct 2025 03:56:43 -0400 Subject: [PATCH 02/10] refine code Signed-off-by: yiliu30 --- auto_round/modelling/gpt_oss.py | 38 ++++++++++++++++----------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/auto_round/modelling/gpt_oss.py b/auto_round/modelling/gpt_oss.py index 628ab2499..52a6f0bf9 100644 --- a/auto_round/modelling/gpt_oss.py +++ b/auto_round/modelling/gpt_oss.py @@ -13,13 +13,15 @@ # limitations under the License. import contextlib -import gc -import os import torch import transformers.models.gpt_oss as transformers_gpt_oss from torch import nn -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.modeling_utils import no_init_weights as skip_weights_initialize +from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP + +__all__ = ["get_replacement_info"] @contextlib.contextmanager @@ -32,15 +34,12 @@ def align_module_device(module: torch.nn.Module): pass -from transformers.modeling_utils import no_init_weights as skip_weights_initialize - - -def update_offload_parameter( +def _update_parameter( module: torch.nn.Module, name: str, data: torch.Tensor, ) -> None: - param: torch.nn.Parameter = getattr(module, name) + param = getattr(module, name) param.data.copy_(data) @@ -49,7 +48,7 @@ def _get_top_k(config): return getattr(config, "num_experts_per_tok", None) or getattr(config, "num_experts_per_token", 1) -class GPTOSSMLP(nn.Module): +class _GPTOSSMLP(nn.Module): def __init__(self, hidden_size, intermediate_size, dtype=None): super().__init__() self.hidden_size = hidden_size @@ -72,7 +71,7 @@ def forward(self, x): class SequentialGPTOSSMoE(nn.Module): """ - Replaces GPT-OSS fused-expert MoE with per-expert GPTOSSMLP modules. + Replaces GPT-OSS fused-expert MoE with per-expert _GPTOSSMLP modules. Copies weights from fused tensors and reuses the original router and optional shared_expert. 
""" @@ -97,22 +96,21 @@ def __init__(self, config, original): self.experts = nn.ModuleList() with skip_weights_initialize(), align_module_device(original.experts): for _ in range(E): - self.experts.append(GPTOSSMLP(hidden_size, intermediate_size, dtype=dtype)) + self.experts.append(_GPTOSSMLP(hidden_size, intermediate_size, dtype=dtype)) gup = original.experts.gate_up_proj # [E, H, 2I] gup_b = original.experts.gate_up_proj_bias # [E, 2I] dwn = original.experts.down_proj # [E, I, H] dwn_b = original.experts.down_proj_bias # [E, H] - with align_module_device(self.experts): - for i, mlp in enumerate(self.experts): - update_offload_parameter(mlp.gate_proj, "weight", original.experts.gate_up_proj[i, :, ::2].T) - update_offload_parameter(mlp.up_proj, "weight", original.experts.gate_up_proj[i, :, 1::2].T) - update_offload_parameter(mlp.down_proj, "weight", original.experts.down_proj[i].T) + for i, mlp in enumerate(self.experts): + _update_parameter(mlp.gate_proj, "weight", original.experts.gate_up_proj[i, :, ::2].T) + _update_parameter(mlp.up_proj, "weight", original.experts.gate_up_proj[i, :, 1::2].T) + _update_parameter(mlp.down_proj, "weight", original.experts.down_proj[i].T) - update_offload_parameter(mlp.gate_proj, "bias", original.experts.gate_up_proj_bias[i, ::2]) - update_offload_parameter(mlp.up_proj, "bias", original.experts.gate_up_proj_bias[i, 1::2]) - update_offload_parameter(mlp.down_proj, "bias", original.experts.down_proj_bias[i]) # [H] + _update_parameter(mlp.gate_proj, "bias", original.experts.gate_up_proj_bias[i, ::2]) + _update_parameter(mlp.up_proj, "bias", original.experts.gate_up_proj_bias[i, 1::2]) + _update_parameter(mlp.down_proj, "bias", original.experts.down_proj_bias[i]) # [H] def forward(self, hidden_states): B, T, H = hidden_states.shape @@ -141,5 +139,5 @@ def get_replacement_info(config): return ( SequentialGPTOSSMoE, config.get_text_config(), - transformers_gpt_oss.modeling_gpt_oss.GptOssMLP.__name__, + GptOssMLP.__name__, ) From 18827332141d3bb4bb927ef51faac50cad6a02da Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 10 Oct 2025 04:06:11 -0400 Subject: [PATCH 03/10] refator llama4 Signed-off-by: yiliu30 --- auto_round/modelling/gpt_oss.py | 15 ++---- auto_round/modelling/llama4.py | 73 +++++++++++++++++++++++++++++ auto_round/special_model_handler.py | 59 ++--------------------- 3 files changed, 81 insertions(+), 66 deletions(-) create mode 100644 auto_round/modelling/llama4.py diff --git a/auto_round/modelling/gpt_oss.py b/auto_round/modelling/gpt_oss.py index 52a6f0bf9..e576f7a8a 100644 --- a/auto_round/modelling/gpt_oss.py +++ b/auto_round/modelling/gpt_oss.py @@ -43,12 +43,7 @@ def _update_parameter( param.data.copy_(data) -def _get_top_k(config): - # GPT-OSS MoE: experts per token - return getattr(config, "num_experts_per_tok", None) or getattr(config, "num_experts_per_token", 1) - - -class _GPTOSSMLP(nn.Module): +class GPTOssSingleExpert(nn.Module): def __init__(self, hidden_size, intermediate_size, dtype=None): super().__init__() self.hidden_size = hidden_size @@ -71,17 +66,17 @@ def forward(self, x): class SequentialGPTOSSMoE(nn.Module): """ - Replaces GPT-OSS fused-expert MoE with per-expert _GPTOSSMLP modules. + Replaces GPT-OSS fused-expert MoE with per-expert `GPTOssSingleExpert` modules. Copies weights from fused tensors and reuses the original router and optional shared_expert. 
""" - def __init__(self, config, original): + def __init__(self, config: GptOssConfig, original: GptOssMLP): super().__init__() hidden_size = config.hidden_size intermediate_size = config.intermediate_size dtype_str = getattr(config, "torch_dtype", None) or getattr(config, "dtype", None) dtype = torch.bfloat16 if str(dtype_str).endswith("bfloat16") else torch.float32 - top_k = _get_top_k(config) + top_k = config.num_experts_per_tok self.hidden_size = hidden_size self.intermediate = intermediate_size self.top_k = top_k @@ -96,7 +91,7 @@ def __init__(self, config, original): self.experts = nn.ModuleList() with skip_weights_initialize(), align_module_device(original.experts): for _ in range(E): - self.experts.append(_GPTOSSMLP(hidden_size, intermediate_size, dtype=dtype)) + self.experts.append(GPTOssSingleExpert(hidden_size, intermediate_size, dtype=dtype)) gup = original.experts.gate_up_proj # [E, H, 2I] gup_b = original.experts.gate_up_proj_bias # [E, 2I] diff --git a/auto_round/modelling/llama4.py b/auto_round/modelling/llama4.py new file mode 100644 index 000000000..02cb97cc1 --- /dev/null +++ b/auto_round/modelling/llama4.py @@ -0,0 +1,73 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from transformers.modeling_utils import no_init_weights +from transformers.models.llama4.modeling_llama4 import Llama4TextMLP + + +class SequentialLlama4TextExperts(torch.nn.ModuleList): + def __init__(self, config, original): + self.num_experts = original.gate_up_proj.shape[0] + with no_init_weights(): + super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)]) + intermediate_size = original.down_proj.shape[1] + + for i in range(self.num_experts): + gate_up = original.gate_up_proj[i] + down = original.down_proj[i] + gate_proj = gate_up[:, :intermediate_size] + up_proj = gate_up[:, intermediate_size:] + + self[i].gate_proj.weight.data = gate_proj.t().contiguous() + self[i].up_proj.weight.data = up_proj.t().contiguous() + self[i].down_proj.weight.data = down.t().contiguous() + + +class SequentialLlama4TextMoe(torch.nn.Module): + def __init__(self, config, original): + super().__init__() + self.top_k = config.num_experts_per_tok + self.hidden_dim = config.hidden_size + self.num_experts = config.num_local_experts + self.experts = SequentialLlama4TextExperts(config, original.experts) + self.router = original.router + self.shared_expert = original.shared_expert + + def forward(self, hidden_states: torch.Tensor): + hidden_states = hidden_states.reshape(-1, self.hidden_dim) + router_logits = self.router(hidden_states) + if isinstance(router_logits, tuple): + router_scores, router_logits = router_logits + router_scores = router_scores.t() + else: + # transformers < 4.54.0 only returns router_logits + router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1) + + router_scores = ( + torch.full_like(router_logits, float("-inf")) + .scatter_(1, router_indices, router_top_value) + .transpose(0, 1) + ) + router_scores = 
torch.sigmoid(router_scores.float()).to(hidden_states.dtype) + + out = self.shared_expert(hidden_states) + for i in range(self.num_experts): + out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1) + + return out, router_logits + + +def get_replacement_info(config): + return SequentialLlama4TextMoe, config.get_text_config(), "Llama4TextMoe" diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py index f05c00941..7eae43b93 100644 --- a/auto_round/special_model_handler.py +++ b/auto_round/special_model_handler.py @@ -40,8 +40,6 @@ def _get_moe_converter(config): - import torch - from transformers.modeling_utils import no_init_weights if config.model_type == "gpt_oss": from auto_round.modelling.gpt_oss import get_replacement_info @@ -49,61 +47,10 @@ def _get_moe_converter(config): return get_replacement_info(config) # https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py - if config.model_type == "llama4": - from transformers.models.llama4.modeling_llama4 import Llama4TextMLP - - class SequentialLlama4TextExperts(torch.nn.ModuleList): - def __init__(self, config, original): - self.num_experts = original.gate_up_proj.shape[0] - with no_init_weights(): - super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)]) - intermediate_size = original.down_proj.shape[1] - - for i in range(self.num_experts): - gate_up = original.gate_up_proj[i] - down = original.down_proj[i] - gate_proj = gate_up[:, :intermediate_size] - up_proj = gate_up[:, intermediate_size:] - - self[i].gate_proj.weight.data = gate_proj.t().contiguous() - self[i].up_proj.weight.data = up_proj.t().contiguous() - self[i].down_proj.weight.data = down.t().contiguous() - - class SequentialLlama4TextMoe(torch.nn.Module): - def __init__(self, config, original): - super().__init__() - self.top_k = config.num_experts_per_tok - self.hidden_dim = config.hidden_size - self.num_experts = config.num_local_experts - self.experts = SequentialLlama4TextExperts(config, original.experts) - self.router = original.router - self.shared_expert = original.shared_expert - - def forward(self, hidden_states: torch.Tensor): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = self.router(hidden_states) - if isinstance(router_logits, tuple): - router_scores, router_logits = router_logits - router_scores = router_scores.t() - else: - # transformers < 4.54.0 only returns router_logits - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1) - - router_scores = ( - torch.full_like(router_logits, float("-inf")) - .scatter_(1, router_indices, router_top_value) - .transpose(0, 1) - ) - router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype) - - out = self.shared_expert(hidden_states) - for i in range(self.num_experts): - out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1) - - return out, router_logits - - return SequentialLlama4TextMoe, config.get_text_config(), "Llama4TextMoe" + elif config.model_type == "llama4": + from auto_round.modelling.llama4 import get_replacement_info + return get_replacement_info(config) else: raise ValueError(f"Currently moe converter only supports llama4 model_type, but get {config.model_type}") From eb55c5481fe025831871fc80c727b99b04ea7779 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 10 Oct 2025 04:07:39 -0400 Subject: [PATCH 04/10] clean Signed-off-by: yiliu30 --- auto_round/modelling/gpt_oss.py | 15 ++------------- 1 file changed, 2 
insertions(+), 13 deletions(-) diff --git a/auto_round/modelling/gpt_oss.py b/auto_round/modelling/gpt_oss.py index e576f7a8a..f6f8ec2d4 100644 --- a/auto_round/modelling/gpt_oss.py +++ b/auto_round/modelling/gpt_oss.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib import torch -import transformers.models.gpt_oss as transformers_gpt_oss from torch import nn from transformers.modeling_utils import no_init_weights as skip_weights_initialize from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig @@ -24,16 +22,6 @@ __all__ = ["get_replacement_info"] -@contextlib.contextmanager -def align_module_device(module: torch.nn.Module): - device = next(module.parameters()).device - # return with torch.device(device) - try: - yield device - except: - pass - - def _update_parameter( module: torch.nn.Module, name: str, @@ -89,7 +77,8 @@ def __init__(self, config: GptOssConfig, original: GptOssMLP): # Build per-expert MLPs self.experts = nn.ModuleList() - with skip_weights_initialize(), align_module_device(original.experts): + target_device = next(original.experts.parameters()).device + with skip_weights_initialize(), torch.device(target_device): for _ in range(E): self.experts.append(GPTOssSingleExpert(hidden_size, intermediate_size, dtype=dtype)) From a4bd97fd80961655e80563cab5e1f42999fa8667 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 10 Oct 2025 04:08:31 -0400 Subject: [PATCH 05/10] fix Signed-off-by: yiliu30 --- auto_round/modelling/llama4.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/auto_round/modelling/llama4.py b/auto_round/modelling/llama4.py index 02cb97cc1..c6cc7ea91 100644 --- a/auto_round/modelling/llama4.py +++ b/auto_round/modelling/llama4.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +__all__ = ["get_replacement_info"] + + import torch from transformers.modeling_utils import no_init_weights from transformers.models.llama4.modeling_llama4 import Llama4TextMLP From 2b9c0150c1a19646dd95f7f9c5c99aeed6f7f9e7 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Fri, 10 Oct 2025 04:15:17 -0400 Subject: [PATCH 06/10] refine code Signed-off-by: yiliu30 --- auto_round/modelling/llama4.py | 1 + auto_round/special_model_handler.py | 28 +++++++++++++++------------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/auto_round/modelling/llama4.py b/auto_round/modelling/llama4.py index c6cc7ea91..f7e85f15b 100644 --- a/auto_round/modelling/llama4.py +++ b/auto_round/modelling/llama4.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# Note: adapted from # https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py __all__ = ["get_replacement_info"] diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py index 7eae43b93..de5f8b2a4 100644 --- a/auto_round/special_model_handler.py +++ b/auto_round/special_model_handler.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from auto_round.utils import logger +import auto_round.modelling as auto_round_modelling +from auto_round.utils import LazyImport, logger mllms_with_limited_bs = ("llava", "qwen2_vl", "phi3_v", "mllama") # Limitations on batch_size @@ -40,19 +41,20 @@ def _get_moe_converter(config): - - if config.model_type == "gpt_oss": - from auto_round.modelling.gpt_oss import get_replacement_info - - return get_replacement_info(config) - - # https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py - elif config.model_type == "llama4": - from auto_round.modelling.llama4 import get_replacement_info - - return get_replacement_info(config) + # Dispatch table for model_type to replacement_info functions + moe_converters = { + "gpt_oss": LazyImport("auto_round.modelling.gpt_oss.get_replacement_info"), + "llama4": LazyImport("auto_round.modelling.llama4.get_replacement_info"), + } + + # Retrieve the appropriate function based on model_type + if config.model_type in moe_converters: + return moe_converters[config.model_type](config) else: - raise ValueError(f"Currently moe converter only supports llama4 model_type, but get {config.model_type}") + raise ValueError( + f"Unsupported model_type '{config.model_type}'. " + f"Currently, MoE converter only supports: {', '.join(moe_converters.keys())}." + ) def _handle_special_model(model): From 30a560e41e2efe02944d3bc8dfbf10c47ab24e83 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 12 Oct 2025 22:21:26 -0400 Subject: [PATCH 07/10] add ut Signed-off-by: yiliu30 --- test/test_cpu/test_gpt_oss.py | 70 +++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 test/test_cpu/test_gpt_oss.py diff --git a/test/test_cpu/test_gpt_oss.py b/test/test_cpu/test_gpt_oss.py new file mode 100644 index 000000000..34031aa8f --- /dev/null +++ b/test/test_cpu/test_gpt_oss.py @@ -0,0 +1,70 @@ +import pytest +from transformers import AutoConfig, AutoTokenizer +from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM + +from auto_round import AutoRound + + +@pytest.fixture +def setup_gpt_oss(): + """Fixture to set up the GPT-OSS model and tokenizer.""" + model_name = "/data5/yliu7/HF_HOME/unsloth/gpt-oss-20b-BF16/" + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + config.num_hidden_layers = 1 # Reduce layers for testing + model = GptOssForCausalLM(config) + output_dir = "/tmp/test_quantized_gpt_oss" + return model, tokenizer, output_dir + + +def quantize_model(model, tokenizer, output_dir, scheme, iters=0): + """Helper function to quantize the model with the given scheme.""" + autoround = AutoRound( + model, + tokenizer, + scheme=scheme, + nsamples=2, + iters=iters, + fp_layers="self_attn,router,lm_head,mlp.gate", + ) + quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir) + return quantized_model + + +def count_modules_by_type(model, target_module_name_or_class): + """Helper function to count modules of a specific type in the model.""" + cnt = 0 + for name, module in model.named_modules(): + if isinstance(target_module_name_or_class, str): + if target_module_name_or_class == module.__class__.__name__: + cnt += 1 + else: + if isinstance(module, target_module_name_or_class): + cnt += 1 + return cnt + + +@pytest.mark.parametrize("scheme", ["MXFP4", "MXFP8"]) +@pytest.mark.parametrize("quantize_model", [0, 4]) +def 
test_quantization_with_mxfp4(setup_gpt_oss, scheme): + """Test quantization with the scheme.""" + model, tokenizer, output_dir = setup_gpt_oss + quantized_model = quantize_model(model, tokenizer, output_dir, scheme) + + # Ensure the quantized model is not None + assert quantized_model is not None, "Quantized model should not be None." + + # Count specific modules + single_expert_cnt = count_modules_by_type(quantized_model, "GPTOssSingleExpert") + quant_linear_cnt = count_modules_by_type(quantized_model, "QuantLinear") + + # Assertions + assert single_expert_cnt >= 0, "GPTOssSingleExpert count should be non-negative." + assert quant_linear_cnt >= 0, "QuantLinear count should be non-negative." + + print(f"[{scheme}] Total GPTOssSingleExpert modules: {single_expert_cnt}") + print(f"[{scheme}] Total QuantLinear modules: {quant_linear_cnt}") + # clean the output directory after test + import shutil + + shutil.rmtree(output_dir, ignore_errors=True) From 6707c348afbbef869f8f77a2cea990e0ca0a3e27 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 12 Oct 2025 22:53:21 -0400 Subject: [PATCH 08/10] fix ut Signed-off-by: yiliu30 --- test/test_cpu/test_gpt_oss.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/test/test_cpu/test_gpt_oss.py b/test/test_cpu/test_gpt_oss.py index 34031aa8f..6e7ceef13 100644 --- a/test/test_cpu/test_gpt_oss.py +++ b/test/test_cpu/test_gpt_oss.py @@ -14,7 +14,7 @@ def setup_gpt_oss(): config.num_hidden_layers = 1 # Reduce layers for testing model = GptOssForCausalLM(config) output_dir = "/tmp/test_quantized_gpt_oss" - return model, tokenizer, output_dir + return model, tokenizer, output_dir, config def quantize_model(model, tokenizer, output_dir, scheme, iters=0): @@ -45,25 +45,27 @@ def count_modules_by_type(model, target_module_name_or_class): @pytest.mark.parametrize("scheme", ["MXFP4", "MXFP8"]) -@pytest.mark.parametrize("quantize_model", [0, 4]) -def test_quantization_with_mxfp4(setup_gpt_oss, scheme): +def test_quantization(setup_gpt_oss, scheme): """Test quantization with the scheme.""" - model, tokenizer, output_dir = setup_gpt_oss + model, tokenizer, output_dir, config = setup_gpt_oss quantized_model = quantize_model(model, tokenizer, output_dir, scheme) # Ensure the quantized model is not None assert quantized_model is not None, "Quantized model should not be None." + from auto_round.export.export_to_autoround.qlinear_fp import QuantLinear + from auto_round.modelling.gpt_oss import GPTOssSingleExpert - # Count specific modules - single_expert_cnt = count_modules_by_type(quantized_model, "GPTOssSingleExpert") - quant_linear_cnt = count_modules_by_type(quantized_model, "QuantLinear") + single_expert_cnt = count_modules_by_type(quantized_model, GPTOssSingleExpert) + quant_linear_cnt = count_modules_by_type(quantized_model, QuantLinear) + assert ( + single_expert_cnt == config.num_local_experts + ), f"Expected {config.num_local_experts} GPTOssSingleExpert modules, found {single_expert_cnt}." + assert ( + quant_linear_cnt == config.num_hidden_layers * 3 * config.num_local_experts + ), f"Expected {config.num_hidden_layers * 3 * config.num_local_experts} QuantLinear modules, found {quant_linear_cnt}." - # Assertions - assert single_expert_cnt >= 0, "GPTOssSingleExpert count should be non-negative." - assert quant_linear_cnt >= 0, "QuantLinear count should be non-negative." 
- -    print(f"[{scheme}] Total GPTOssSingleExpert modules: {single_expert_cnt}") -    print(f"[{scheme}] Total QuantLinear modules: {quant_linear_cnt}") +    print(f"[{scheme}] Total {GPTOssSingleExpert.__name__} modules: {single_expert_cnt}") +    print(f"[{scheme}] Total {QuantLinear.__name__} modules: {quant_linear_cnt}") # clean the output directory after test import shutil From 03272f32ddf0c28d12e1eabd6665445618450d4e Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Sun, 12 Oct 2025 22:57:10 -0400 Subject: [PATCH 09/10] fix Signed-off-by: yiliu30 --- test/test_cpu/test_gpt_oss.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_cpu/test_gpt_oss.py b/test/test_cpu/test_gpt_oss.py index 34031aa8f..546818507 100644 --- a/test/test_cpu/test_gpt_oss.py +++ b/test/test_cpu/test_gpt_oss.py @@ -8,7 +8,8 @@ @pytest.fixture def setup_gpt_oss(): """Fixture to set up the GPT-OSS model and tokenizer.""" -    model_name = "/data5/yliu7/HF_HOME/unsloth/gpt-oss-20b-BF16/" +    model_name = "unsloth/gpt-oss-20b-BF16" +    # model_name = "/data5/yliu7/HF_HOME/unsloth/gpt-oss-20b-BF16/" tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) config.num_hidden_layers = 1 # Reduce layers for testing From 595ebfbfedbab01906032acfb4110313203bfd39 Mon Sep 17 00:00:00 2001 From: yiliu30 Date: Mon, 13 Oct 2025 21:40:54 -0400 Subject: [PATCH 10/10] fix Signed-off-by: yiliu30 --- auto_round/modelling/gpt_oss.py | 6 +++--- test/test_cpu/test_gpt_oss.py | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/auto_round/modelling/gpt_oss.py b/auto_round/modelling/gpt_oss.py index f6f8ec2d4..78f73075c 100644 --- a/auto_round/modelling/gpt_oss.py +++ b/auto_round/modelling/gpt_oss.py @@ -32,7 +32,7 @@ def _update_parameter( class GPTOssSingleExpert(nn.Module): -    def __init__(self, hidden_size, intermediate_size, dtype=None): +    def __init__(self, hidden_size: int, intermediate_size: int, dtype: torch.dtype | None = None): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size @@ -42,7 +42,7 @@ def __init__(self, hidden_size, intermediate_size, dtype=None): self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True, dtype=dtype) self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True, dtype=dtype) -    def forward(self, x): +    def forward(self, x: torch.Tensor) -> torch.Tensor: gate = self.gate_proj(x) up = self.up_proj(x) gate = gate.clamp(max=self.limit) @@ -96,7 +96,7 @@ def __init__(self, config: GptOssConfig, original: GptOssMLP): _update_parameter(mlp.up_proj, "bias", original.experts.gate_up_proj_bias[i, 1::2]) _update_parameter(mlp.down_proj, "bias", original.experts.down_proj_bias[i]) # [H] -    def forward(self, hidden_states): +    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: B, T, H = hidden_states.shape x = hidden_states.reshape(-1, H) diff --git a/test/test_cpu/test_gpt_oss.py b/test/test_cpu/test_gpt_oss.py index 546818507..ccc997eba 100644 --- a/test/test_cpu/test_gpt_oss.py +++ b/test/test_cpu/test_gpt_oss.py @@ -8,8 +8,7 @@ @pytest.fixture def setup_gpt_oss(): """Fixture to set up the GPT-OSS model and tokenizer.""" -    model_name = "unsloth/gpt-oss-20b-BF16" -    # model_name = "/data5/yliu7/HF_HOME/unsloth/gpt-oss-20b-BF16/" +    model_name = "/tf_dataset/auto_round/models/unsloth/gpt-oss-20b-BF16" tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) config = AutoConfig.from_pretrained(model_name,
trust_remote_code=True) config.num_hidden_layers = 1 # Reduce layers for testing