13 changes: 12 additions & 1 deletion src/transformers/quantizers/quantizer_torchao.py
@@ -297,6 +297,8 @@ def create_quantized_param(

# handle ModuleFqnToConfig, introduced in torchao 0.12.0+
if self.quantization_config._get_ao_version() >= version.Version("0.12.0"):
import re

from torchao.quantization import ModuleFqnToConfig

config = self.quantization_config.get_apply_tensor_subclass()
@@ -306,7 +308,16 @@
if module_fqn in config.module_fqn_to_config:
c = config.module_fqn_to_config[module_fqn]
else:
c = config.module_fqn_to_config.get("_default", None)
for maybe_module_fqn_pattern in config.module_fqn_to_config:
if not maybe_module_fqn_pattern.startswith("re:"):
continue
elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
# apply the config for the first fully matched pattern
c = config.module_fqn_to_config[maybe_module_fqn_pattern]
break
else:
c = config.module_fqn_to_config.get("_default", None)

if c is not None:
# filter_fn: not filtering out any modules
quantize_(module, c, filter_fn=lambda x, fqn: True)
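The hunk above resolves a module's config in three steps: an exact module FQN match wins, otherwise the first `re:`-prefixed key whose pattern fully matches the FQN is used, otherwise the `"_default"` entry (if any). A minimal standalone sketch of that lookup order is below; `resolve_config` and the example config values are illustrative only, not part of transformers or torchao.

```python
import re
from typing import Optional


def resolve_config(module_fqn: str, module_fqn_to_config: dict) -> Optional[object]:
    # 1) an exact FQN entry has highest precedence
    if module_fqn in module_fqn_to_config:
        return module_fqn_to_config[module_fqn]
    # 2) otherwise, the first "re:<pattern>" key whose pattern fully matches the FQN
    for key in module_fqn_to_config:
        if key.startswith("re:") and re.fullmatch(key[3:], module_fqn):
            return module_fqn_to_config[key]
    # 3) otherwise, fall back to the catch-all "_default" entry, if present
    return module_fqn_to_config.get("_default", None)


# Example: skip q_proj in every layer, apply the default config everywhere else
configs = {r"re:model\.layers\..+\.self_attn\.q_proj": None, "_default": "int8"}
assert resolve_config("model.layers.0.self_attn.q_proj", configs) is None
assert resolve_config("model.layers.0.self_attn.k_proj", configs) == "int8"
```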
85 changes: 85 additions & 0 deletions tests/quantization/torchao_integration/test_torchao.py
@@ -46,6 +46,8 @@
TensorCoreTiledLayout,
)
from torchao.quantization import (
Float8Tensor,
Float8WeightOnlyConfig,
Int8WeightOnlyConfig,
IntxWeightOnlyConfig,
MappingType,
@@ -278,6 +280,89 @@ def test_per_module_config_skip(self):
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)


@require_torchao_version_greater_or_equal("0.13.0")
def test_module_fqn_to_config_regex_basic(self):
linear_config = Int8WeightOnlyConfig()
config = ModuleFqnToConfig({"_default": linear_config, r"re:model\.layers\..+\.self_attn\.q_proj": None})
quant_config = TorchAoConfig(quant_type=config)
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=self.device,
quantization_config=quant_config,
)
# making sure `model.layers.0.self_attn.q_proj` is skipped
self.assertTrue(not isinstance(quantized_model.model.layers[0].self_attn.q_proj.weight, AffineQuantizedTensor))
tokenizer = AutoTokenizer.from_pretrained(self.model_name)

input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)

output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
EXPECTED_OUTPUT = [
"What are we having for dinner?\n\nJessica: (smiling)",
"What are we having for dinner?\n\nJess: (smiling) I",
]
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)

@require_torchao_version_greater_or_equal("0.13.0")
def test_module_fqn_to_config_regex_fullmatch(self):
"""Testing that we will only match the fqns that fully
matches the regex
"""
linear1_config = Int8WeightOnlyConfig()
linear2_config = Float8WeightOnlyConfig()
# intentionally dropping the trailing `j` of `q_proj` so the pattern is not a full match
config = ModuleFqnToConfig({r"re:model\.layers\.+\.self_attn\.q_pro": linear1_config, "model.layers.3.self_attn.q_proj": linear2_config})
quant_config = TorchAoConfig(quant_type=config)
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=self.device,
quantization_config=quant_config,
)
# highest precedence is fully specified module fqn
self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor))
# because the regex `model\.layers\..+\.self_attn\.q_pro` doesn't fully match `model.layers.1.self_attn.q_proj` (missing the last `j`),
# this layer is not expected to be quantized to int8
self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
tokenizer = AutoTokenizer.from_pretrained(self.model_name)

input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)

output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
EXPECTED_OUTPUT = [
"What are we having for dinner?\n\nJessica: (smiling)",
"What are we having for dinner?\n\nJess: (smiling) I",
]
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)

@require_torchao_version_greater_or_equal("0.13.0")
def test_module_fqn_to_config_regex_precedence(self):
linear1_config = Int8WeightOnlyConfig()
linear2_config = Float8WeightOnlyConfig()
config = ModuleFqnToConfig({r"re:model\.layers\..+\.self_attn\.q_proj": None, "model.layers.3.self_attn.q_proj": linear2_config, "_default": linear1_config})
quant_config = TorchAoConfig(quant_type=config)
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=self.device,
quantization_config=quant_config,
)
# highest precedence is fully specified module fqn
self.assertTrue(isinstance(quantized_model.model.layers[3].self_attn.q_proj.weight, Float8Tensor))
# second precedence: regex
self.assertTrue(not isinstance(quantized_model.model.layers[1].self_attn.q_proj.weight, AffineQuantizedTensor))
# last precedence: _default
self.assertTrue(isinstance(quantized_model.model.layers[1].self_attn.k_proj.weight, AffineQuantizedTensor))
tokenizer = AutoTokenizer.from_pretrained(self.model_name)

input_ids = tokenizer(self.input_text, return_tensors="pt").to(self.device)

output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
EXPECTED_OUTPUT = [
"What are we having for dinner?\n\nJessica: (smiling)",
"What are we having for dinner?\n\nJess: (smiling) I",
]
self.assertTrue(tokenizer.decode(output[0], skip_special_tokens=True) in EXPECTED_OUTPUT)


@require_torch_accelerator
class TorchAoAcceleratorTest(TorchAoTest):
device = torch_device