diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py
index 344c9e3534ed..bb8b6673ac01 100644
--- a/src/transformers/quantizers/quantizer_torchao.py
+++ b/src/transformers/quantizers/quantizer_torchao.py
@@ -297,6 +297,8 @@ def create_quantized_param(

         # handle ModuleFqnToConfig, introduced in torchao 0.12.0+
         if self.quantization_config._get_ao_version() >= version.Version("0.12.0"):
+            import re
+
             from torchao.quantization import ModuleFqnToConfig

             config = self.quantization_config.get_apply_tensor_subclass()
@@ -306,7 +308,16 @@ def create_quantized_param(
                 if module_fqn in config.module_fqn_to_config:
                     c = config.module_fqn_to_config[module_fqn]
                 else:
-                    c = config.module_fqn_to_config.get("_default", None)
+                    for maybe_module_fqn_pattern in config.module_fqn_to_config:
+                        if not maybe_module_fqn_pattern.startswith("re:"):
+                            continue
+                        elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
+                            # we'll apply the config for the first fully matched pattern
+                            c = config.module_fqn_to_config[maybe_module_fqn_pattern]
+                            break
+                    else:
+                        c = config.module_fqn_to_config.get("_default", None)
+
                 if c is not None:
                     # filter_fn: not filtering out any modules
                     quantize_(module, c, filter_fn=lambda x, fqn: True)
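To make the new lookup order easier to review, here is a minimal, self-contained sketch of the resolution the hunk above implements: an exact module fqn wins, otherwise the first `re:`-prefixed key (in dict insertion order) whose pattern fully matches the fqn, otherwise `_default`. The dict contents and the `resolve_config` helper are illustrative stand-ins, not code from this patch:

```python
import re

# Stand-in for ModuleFqnToConfig.module_fqn_to_config; the string values
# are placeholders for torchao config objects.
module_fqn_to_config = {
    "model.layers.3.self_attn.q_proj": "float8_config",  # exact fqn: highest precedence
    r"re:model\.layers\..+\.self_attn\.q_proj": None,    # regex key: skip quantization
    "_default": "int8_config",                           # fallback for everything else
}

def resolve_config(module_fqn):
    # 1) an exact fqn match has highest precedence
    if module_fqn in module_fqn_to_config:
        return module_fqn_to_config[module_fqn]
    # 2) first "re:" pattern that fully matches, in insertion order
    for pattern in module_fqn_to_config:
        if pattern.startswith("re:") and re.fullmatch(pattern[3:], module_fqn):
            return module_fqn_to_config[pattern]
    # 3) fall back to "_default" (None if absent)
    return module_fqn_to_config.get("_default", None)

assert resolve_config("model.layers.3.self_attn.q_proj") == "float8_config"  # exact
assert resolve_config("model.layers.1.self_attn.q_proj") is None             # regex
assert resolve_config("model.layers.1.self_attn.k_proj") == "int8_config"    # default
```

Since the caller skips quantization when the resolved config is `None`, mapping a regex key to `None` behaves the same as having no `_default`: matching modules are left unquantized.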
diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py
index 1ddc2de0801f..0d1ace7ede2a 100644
--- a/tests/quantization/torchao_integration/test_torchao.py
+++ b/tests/quantization/torchao_integration/test_torchao.py
@@ -46,6 +46,8 @@
         TensorCoreTiledLayout,
     )
     from torchao.quantization import (
+        Float8Tensor,
+        Float8WeightOnlyConfig,
         Int8WeightOnlyConfig,
         IntxWeightOnlyConfig,
         MappingType,
@@ -278,6 +280,89 @@ def test_per_module_config_skip(self):
         self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)


+    @require_torchao_version_greater_or_equal("0.13.0")
+    def test_module_fqn_to_config_regex_basic(self):
+        linear_config = Int8WeightOnlyConfig()
+        config = ModuleFqnToConfig({"_default": linear_config, r"re:model\.layers\..+\.self_attn\.q_proj": None})
+        quant_config = TorchAoConfig(quant_type=config)
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=self.device,
+            quantization_config=quant_config,
+        )
+        # making sure `model.layers.0.self_attn.q_proj` is skipped
+        self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, AffineQuantizedTensor))
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+        EXPECTED_OUTPUT = [
+            "What are we having for dinner?\n\nJessica: (smiling)",
+            "What are we having for dinner?\n\nJess: (smiling) I",
+        ]
+        self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
+
+    @require_torchao_version_greater_or_equal("0.13.0")
+    def test_module_fqn_to_config_regex_fullmatch(self):
+        """Test that we only apply a regex config when the fqn
+        fully matches the pattern
+        """
+        linear1_config = Int8WeightOnlyConfig()
+        linear2_config = Float8WeightOnlyConfig()
+        # intentionally removing the `j` after `q_proj` so it's not a full match
+        config = ModuleFqnToConfig({r"re:model\.layers\.+\.self_attn\.q_pro": linear1_config, "model.layers.3.self_attn.q_proj": linear2_config})
+        quant_config = TorchAoConfig(quant_type=config)
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=self.device,
+            quantization_config=quant_config,
+        )
+        # highest precedence is the fully specified module fqn
+        self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor))
+        # because the regex `model\.layers\.+\.self_attn\.q_pro` didn't fully match `model.layers.1.self_attn.q_proj` (missing the last `j`),
+        # this layer is not expected to be quantized to int8
+        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+        EXPECTED_OUTPUT = [
+            "What are we having for dinner?\n\nJessica: (smiling)",
+            "What are we having for dinner?\n\nJess: (smiling) I",
+        ]
+        self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
+
+    @require_torchao_version_greater_or_equal("0.13.0")
+    def test_module_fqn_to_config_regex_precedence(self):
+        linear1_config = Int8WeightOnlyConfig()
+        linear2_config = Float8WeightOnlyConfig()
+        config = ModuleFqnToConfig({r"re:model\.layers\..+\.self_attn\.q_proj": None, "model.layers.3.self_attn.q_proj": linear2_config, "_default": linear1_config})
+        quant_config = TorchAoConfig(quant_type=config)
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            device_map=self.device,
+            quantization_config=quant_config,
+        )
+        # highest precedence is the fully specified module fqn
+        self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor))
+        # second precedence: regex
+        self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
+        # last precedence: _default
+        self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, AffineQuantizedTensor))
+        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+        EXPECTED_OUTPUT = [
+            "What are we having for dinner?\n\nJessica: (smiling)",
+            "What are we having for dinner?\n\nJess: (smiling) I",
+        ]
+        self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)
+
+
 @require_torch_accelerator
 class TorchAoAcceleratorTest(TorchAoTest):
     device = torch_device
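A note on the fullmatch semantics exercised by `test_module_fqn_to_config_regex_fullmatch`: `re.fullmatch` only succeeds when the pattern consumes the entire fqn, so a pattern that stops at `q_pro` cannot match `q_proj`, whereas a prefix-based `re.match` would. A quick standalone check, using a cleaned-up variant of the test's pattern (the fqn below is illustrative):

```python
import re

pattern = r"model\.layers\..+\.self_attn\.q_pro"  # note the missing final `j`
fqn = "model.layers.1.self_attn.q_proj"

assert re.match(pattern, fqn) is not None  # a prefix match would (wrongly) succeed
assert re.fullmatch(pattern, fqn) is None  # fullmatch requires the whole fqn to match
```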