diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index cd5b09bee..3a7c1f823 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -1651,9 +1651,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
         if self.device_map is not None:
             accelerate.hooks.remove_hook_from_submodules(block)
 
-        if (
-            is_nv_fp(self.act_data_type) and any("nv_fp" in format_ for format_ in self.formats)
-        ) or is_static_wfp8afp8(self):
+        if is_nv_fp(self.act_data_type) or is_static_wfp8afp8(self):
             # enable moe experts act_max automatic generation for Linear
             set_amax_for_all_moe_layers(block, attr_name="act_max")
         # Normalize imatrix and quantize layers
@@ -2911,11 +2909,7 @@ def _quantize_block(
         with torch.no_grad():
             unwrapper_block(block, best_params)
 
-        if (
-            is_nv_fp(self.act_data_type)
-            and hasattr(self, "formats")
-            and any("nv_fp" in format_ for format_ in self.formats)
-        ):
+        if is_nv_fp(self.act_data_type):
             # enable moe experts act_max automatic generation for WrapperWALayer
             set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")
 
diff --git a/auto_round/utils.py b/auto_round/utils.py
index 82b426186..0b819e381 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -2481,7 +2481,7 @@ def set_nested_attr(module, attr_name: str, value):
     attrs = attr_name.split(".")
     for attr in attrs[:-1]:
         if not hasattr(module, attr):
-            raise AttributeError(f"{module} has no attribute '{attr}'")
+            return None  # No need to set act_max for fp layers
         module = getattr(module, attr)
     setattr(module, attrs[-1], value)
 
@@ -2546,7 +2546,7 @@ def set_amax_for_all_moe_layers(model: torch.nn.Module, layer_name=None, attr_na
             # For other MoE models (like Mixtral) with iterable experts
             try:
                 set_amax_for_uncalibrated_experts(
-                    [getattr(expert, linear_name) for expert in sub_module.experts], attr_name=attr_name
+                    [getattr(expert, linear_name, None) for expert in sub_module.experts], attr_name=attr_name
                 )
             except AttributeError as e:
                 # Provide more helpful debugging information
diff --git a/test/test_cpu/test_export.py b/test/test_cpu/test_export.py
index 382f5651b..180fd8f2f 100644
--- a/test/test_cpu/test_export.py
+++ b/test/test_cpu/test_export.py
@@ -532,6 +532,45 @@ def test_nvfp4_autoround_save_quantized(self):
             ), "Illegal NVFP4 packing name or data_type or shape"
         shutil.rmtree("./saved", ignore_errors=True)
 
+    def test_nvfp4_moe_actmax_rtn(self):
+        model_name = "/tf_dataset/auto_round/models/deepseek-ai/DeepSeek-V2-Lite"
+        layer_config = {
+            "self_attn": {"bits": 16, "act_bits": 16},
+            "mlp.shared_experts": {"bits": 16, "act_bits": 16},
+        }
+        scheme = "nvfp4"
+        autoround = AutoRound(
+            model_name,
+            scheme=scheme,
+            iters=0,
+            seqlen=2,
+            nsamples=2,
+            dataset=self.llm_dataloader,
+            layer_config=layer_config,
+        )
+        compressed_model, _ = autoround.quantize()
+        assert hasattr(compressed_model.model.layers[1].mlp.experts[0].gate_proj.orig_layer, "act_max")
+
+    def test_nvfp4_moe_actmax_ar(self):
+        model_name = "/tf_dataset/auto_round/models/deepseek-ai/DeepSeek-V2-Lite"
+        layer_config = {
+            "q_proj": {"bits": 16, "act_bits": 16},
+            "mlp.shared_experts": {"bits": 16, "act_bits": 16},
+            "experts.*2": {"bits": 16, "act_bits": 16},
+            "experts.*5": {"bits": 16, "act_bits": 16},
+        }
+        scheme = "nvfp4"
+        autoround = AutoRound(
+            model_name,
+            scheme=scheme,
+            iters=1,
+            seqlen=2,
+            nsamples=2,
+            dataset=self.llm_dataloader,
+            layer_config=layer_config,
+        )
+        autoround.quantize_and_save(output_dir=self.save_dir, inplace=True, format="auto_round")
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/test_cuda/test_export.py b/test/test_cuda/test_export.py
index e9699da52..0ab05134f 100644
--- a/test/test_cuda/test_export.py
+++ b/test/test_cuda/test_export.py
@@ -402,6 +402,36 @@ def test_nvfp4_llmcompressor_format(self):
         # if "France" in prompt:
         #     assert "Paris" in generated_text
 
+    def test_nvfp4_moe_actmax_rtn(self):
+        model_name = "/data0/deepseek-ai/DeepSeek-V2-Lite"
+        scheme = "nvfp4"
+        autoround = AutoRound(
+            model_name,
+            scheme=scheme,
+            iters=0,
+            seqlen=2,
+            nsamples=2,
+            dataset=self.llm_dataloader,
+        )
+        autoround.quantize()
+        quantized_model_path = self.save_dir
+        autoround.save_quantized(output_dir=quantized_model_path, inplace=False, format="auto_round")
+
+    def test_nvfp4_moe_actmax_ar(self):
+        model_name = "/data0/deepseek-ai/DeepSeek-V2-Lite"
+        scheme = "nvfp4"
+        autoround = AutoRound(
+            model_name,
+            scheme=scheme,
+            iters=1,
+            seqlen=2,
+            nsamples=2,
+            dataset=self.llm_dataloader,
+        )
+        autoround.quantize()
+        quantized_model_path = self.save_dir
+        autoround.save_quantized(output_dir=quantized_model_path, inplace=False, format="auto_round")
+
 
 if __name__ == "__main__":
     unittest.main()
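
Note on the `set_nested_attr` change: when some MoE experts are excluded from quantization via `layer_config` (kept at 16-bit), they are never wrapped, so a nested path such as `orig_layer.act_max` does not exist on them, and the old `raise AttributeError` aborted `set_amax_for_all_moe_layers` for the whole block. Below is a minimal standalone sketch of the patched behavior, not the library's code; the toy `Expert` class and the tensor value are hypothetical stand-ins.

```python
import torch


def set_nested_attr(module, attr_name: str, value):
    # Patched helper from auto_round/utils.py: walk the dotted path and
    # bail out quietly if an intermediate attribute is missing.
    attrs = attr_name.split(".")
    for attr in attrs[:-1]:
        if not hasattr(module, attr):
            return None  # No need to set act_max for fp layers
        module = getattr(module, attr)
    setattr(module, attrs[-1], value)


class Expert(torch.nn.Module):
    """Hypothetical stand-in for a MoE expert."""

    def __init__(self, quantized: bool):
        super().__init__()
        self.gate_proj = torch.nn.Linear(4, 4)
        if quantized:
            # Quantized experts carry a wrapper exposing .orig_layer.
            self.gate_proj.orig_layer = torch.nn.Linear(4, 4)


quantized_expert = Expert(quantized=True)
fp_expert = Expert(quantized=False)  # excluded via layer_config, never wrapped

# Nested path exists: act_max is attached as before.
set_nested_attr(quantized_expert.gate_proj, "orig_layer.act_max", torch.tensor(1.0))
assert hasattr(quantized_expert.gate_proj.orig_layer, "act_max")

# Nested path missing: previously raised AttributeError, now a silent no-op.
set_nested_attr(fp_expert.gate_proj, "orig_layer.act_max", torch.tensor(1.0))
```

The `getattr(expert, linear_name, None)` change in the Mixtral-style branch follows the same logic: an expert that lacks the projection contributes `None` to the list instead of raising mid-comprehension.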