21 changes: 15 additions & 6 deletions src/transformers/modeling_utils.py
@@ -727,11 +727,12 @@ def _load_state_dict_into_meta_model(
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])

is_quantized = hf_quantizer is not None
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
QuantizationMethod.TORCHAO,
}
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_ao
file_pointer = None
if is_meta_state_dict:
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
@@ -873,7 +874,7 @@ def load_shard_file(args):
shard_file,
state_dict,
disk_only_shard_files,
is_hqq_or_bnb,
is_hqq_or_bnb_or_ao,
is_quantized,
device_map,
hf_quantizer,
@@ -899,7 +900,7 @@
map_location = "cpu"
if (
shard_file.endswith(".safetensors")
and not is_hqq_or_bnb
and not is_hqq_or_bnb_or_ao
and not (is_deepspeed_zero3_enabled() and not is_quantized)
):
map_location = "meta"
@@ -922,6 +923,13 @@

# Fix the key names
state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
metadata = None
if shard_file.endswith(".safetensors") and is_safetensors_available():
with safe_open(shard_file, framework="pt") as f:
metadata = f.metadata()

if hf_quantizer:
state_dict = hf_quantizer.update_state_dict_with_metadata(state_dict, metadata)

error_msgs = []

@@ -5277,9 +5285,10 @@ def _load_pretrained_model(
QuantizationMethod.HQQ,
QuantizationMethod.QUARK,
}
is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
QuantizationMethod.HQQ,
QuantizationMethod.BITS_AND_BYTES,
QuantizationMethod.TORCHAO,
}

# Get all the keys of the state dicts that we have to initialize the model
@@ -5451,7 +5460,7 @@ def _load_pretrained_model(
shard_file,
state_dict,
disk_only_shard_files,
is_hqq_or_bnb,
is_hqq_or_bnb_or_ao,
is_quantized,
device_map,
hf_quantizer,
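For orientation, the pattern this file now follows in `load_shard_file` can be summarized with a minimal sketch (not the actual transformers code path): read the safetensors metadata of the shard and let the quantizer rewrite the state dict from it. The standalone helper name below is hypothetical; `safe_open` and `update_state_dict_with_metadata` are the calls used in the diff above.

```python
from safetensors import safe_open


def load_state_dict_with_quantizer_metadata(shard_file, state_dict, hf_quantizer=None):
    """Hypothetical helper mirroring the new loading path in load_shard_file."""
    metadata = None
    if shard_file.endswith(".safetensors"):
        # Safetensors files can carry a free-form string-to-string metadata dict.
        with safe_open(shard_file, framework="pt") as f:
            metadata = f.metadata()

    if hf_quantizer is not None:
        # The base HfQuantizer returns state_dict unchanged; TorchAoHfQuantizer
        # uses the metadata to rebuild flattened tensor subclasses.
        state_dict = hf_quantizer.update_state_dict_with_metadata(state_dict, metadata)
    return state_dict
```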
4 changes: 4 additions & 0 deletions src/transformers/quantizers/base.py
@@ -342,6 +342,10 @@ def get_state_dict_and_metadata(self, model, safe_serialization=False):
"""Get state dict and metadata. Useful when we need to modify a bit the state dict due to quantization"""
return None, {}

def update_state_dict_with_metadata(self, state_dict, metadata):
"""Update state dict with metadata. Default behaviour returns state_dict"""
return state_dict

@abstractmethod
def _process_model_before_weight_loading(self, model, **kwargs): ...

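As a toy illustration (not part of this PR), a backend-specific quantizer would override the new hook roughly as follows; `MyQuantizer` and `_rebuild_from_metadata` are made-up names, and the other abstract methods of `HfQuantizer` are omitted.

```python
from transformers.quantizers.base import HfQuantizer


class MyQuantizer(HfQuantizer):
    """Toy sketch only: the remaining abstract methods of HfQuantizer are omitted."""

    def update_state_dict_with_metadata(self, state_dict, metadata):
        # Rebuild tensors from backend-specific safetensors metadata when present;
        # otherwise defer to the default, which returns the state dict unchanged.
        if metadata and metadata.get("my_backend_format") == "1":
            return self._rebuild_from_metadata(state_dict, metadata)  # hypothetical helper
        return super().update_state_dict_with_metadata(state_dict, metadata)
```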
60 changes: 56 additions & 4 deletions src/transformers/quantizers/quantizer_torchao.py
@@ -35,6 +35,17 @@
import torch
import torch.nn as nn

if is_torchao_available():
import torchao

if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
from torchao.prototype.safetensors.safetensors_support import (
flatten_tensor_state_dict,
unflatten_tensor_state_dict,
)
from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao


logger = logging.get_logger(__name__)


@@ -81,6 +92,15 @@ def _linear_extra_repr(self):
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"


if is_torchao_available():
SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
torchao.quantization.Float8WeightOnlyConfig,
torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
]

TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))


class TorchAoHfQuantizer(HfQuantizer):
"""
Quantizer for torchao: https://github.com/pytorch/ao/
@@ -137,6 +157,21 @@ def update_dtype(self, dtype):
dtype = torch.float32
return dtype

def get_state_dict_and_metadata(self, model, safe_serialization: Optional[bool] = False):
"""
If the model is safe serializable, we flatten the state dict of tensor subclasses so that it is compatible with
the safetensors format.
"""
if type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization:
if TORCHAO_VERSION >= version.parse("0.14.0"):
return flatten_tensor_state_dict(model.state_dict())
else:
raise RuntimeError(
f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}"
)
else:
return super().get_state_dict_and_metadata(model)

def adjust_target_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
from accelerate.utils import CustomDtype
@@ -279,6 +314,16 @@ def create_quantized_param(

quantize_(module, self.quantization_config.get_apply_tensor_subclass())

def update_state_dict_with_metadata(self, state_dict, metadata):
"""
If the metadata contains torchao tensor subclass information, we reconstruct the tensor subclass state dict
from the provided state_dict and metadata.
"""
if TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(metadata):
return unflatten_tensor_state_dict(state_dict, metadata)
else:
return super().update_state_dict_with_metadata(state_dict, metadata)

def _process_model_after_weight_loading(self, model, **kwargs):
"""No process required for torchao quantized model"""
if self.quantization_config.quant_type == "autoquant":
@@ -297,10 +342,17 @@ def _process_model_after_weight_loading(self, model, **kwargs):

def is_serializable(self, safe_serialization=None) -> bool:
if safe_serialization:
logger.warning(
"torchao quantized model does not support safe serialization, please set `safe_serialization` to False"
)
return False
_is_torchao_serializable = type(
self.quantization_config.quant_type
) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0")
Review comment (Member): also specify the version in the warning.

Reply (Contributor, author): fixed!

if not _is_torchao_serializable:
logger.warning(
f"torchao quantized model only supports safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS}, \
and torchao version >= 0.14.0, please set `safe_serialization` to False for \
{type(self.quantization_config.quant_type)} and {TORCHAO_VERSION}."
)
return _is_torchao_serializable

_is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
"0.25.0"
)
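For context, the round trip that `get_state_dict_and_metadata` and `update_state_dict_with_metadata` enable looks roughly like the following, using the torchao >= 0.14.0 prototype helpers imported in this file. The toy module and config are only illustrative, and hardware/dtype support depends on the torchao build.

```python
import torch
from torchao.prototype.safetensors.safetensors_support import (
    flatten_tensor_state_dict,
    unflatten_tensor_state_dict,
)
from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
from torchao.quantization import Float8WeightOnlyConfig, quantize_

# Quantize a toy module so its weights become torchao tensor subclasses.
model = torch.nn.Linear(64, 64, dtype=torch.bfloat16)
quantize_(model, Float8WeightOnlyConfig())

# Saving: flatten the tensor subclasses into plain tensors plus string metadata
# so they can be written with safetensors (what get_state_dict_and_metadata does).
flat_state_dict, metadata = flatten_tensor_state_dict(model.state_dict())

# Loading: if the metadata is recognized as torchao's, rebuild the tensor
# subclasses (what update_state_dict_with_metadata does).
if is_metadata_torchao(metadata):
    state_dict = unflatten_tensor_state_dict(flat_state_dict, metadata)
```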
62 changes: 54 additions & 8 deletions tests/quantization/torchao_integration/test_torchao.py
@@ -18,6 +18,7 @@
import unittest

from packaging import version
from parameterized import parameterized

from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from transformers.testing_utils import (
@@ -37,6 +38,8 @@
import torch

if is_torchao_available():
import torchao

# renamed in torchao 0.7.0, please install the latest torchao
from torchao.dtypes import (
AffineQuantizedTensor,
@@ -135,7 +138,7 @@ class TorchAoTest(unittest.TestCase):
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
device = "cpu"
quant_scheme_kwargs = (
{"group_size": 32, "layout": Int4CPULayout()}
{"group_size": 32, "layout": Int4CPULayout(), "version": 1}
if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
else {"group_size": 32}
)
@@ -225,6 +228,7 @@ def test_include_input_output_embeddings(self):
weight_dtype=weight_dtype,
granularity=granularity,
mapping_type=mapping_type,
version=1,
)
config = ModuleFqnToConfig(
{"_default": None, "model.embed_tokens": embedding_config, "lm_head": embedding_config}
@@ -277,7 +281,7 @@ def test_per_module_config_skip(self):
@require_torch_accelerator
class TorchAoAcceleratorTest(TorchAoTest):
device = torch_device
quant_scheme_kwargs = {"group_size": 32}
quant_scheme_kwargs = {"group_size": 32, "version": 1}

# called only once for all test in this class
@classmethod
@@ -327,7 +331,7 @@ def test_int4wo_offload(self):
"lm_head": 0,
}

quant_config = TorchAoConfig("int4_weight_only", group_size=32)
quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)

quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
@@ -399,7 +403,7 @@ def test_autoquant(self):

check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)

EXPECTED_OUTPUT = "What are we having for dinner?\n\nJane: (sighs)"
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
Review comment (Contributor): should this be reverted?

Reply (Contributor, author): I double-checked that this fails on main as well, so I just added the correction.
output = quantized_model.generate(
**input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
)
@@ -414,7 +418,7 @@ class TorchAoSerializationTest(unittest.TestCase):
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
quant_scheme = "int4_weight_only"
quant_scheme_kwargs = (
{"group_size": 32, "layout": Int4CPULayout()}
{"group_size": 32, "layout": Int4CPULayout(), "version": 1}
if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
else {"group_size": 32}
)
@@ -447,13 +451,13 @@ def test_original_model_expected_output(self):

self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

def check_serialization_expected_output(self, device, expected_output):
def check_serialization_expected_output(self, device, expected_output, safe_serialization=False):
"""
Test if we can serialize and load/infer the model again on the same device
"""
dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
self.quantized_model.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, dtype=dtype, device_map=device)
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)

@@ -464,6 +468,48 @@ def test_serialization_expected_output(self):
self.check_serialization_expected_output(self.device, self.EXPECTED_OUTPUT)


@require_torchao
@require_torchao_version_greater_or_equal("0.14.0")
class TorchAoSafeSerializationTest(TorchAoSerializationTest):
# called only once for all test in this class
@classmethod
def setUpClass(cls):
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
cls.EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"

def tearDown(self):
gc.collect()
backend_empty_cache(torch_device)
gc.collect()
if hasattr(self, "quantized_model"):
del self.quantized_model
gc.collect()

test_params = (
[
(
torchao.quantization.Float8DynamicActivationFloat8WeightConfig(),
"What are we having for dinner?\n\nJess: (smiling) I",
),
(torchao.quantization.Float8WeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"),
]
if is_torchao_available()
else []
)

@parameterized.expand(test_params, skip_on_empty=True)
def test_serialization_expected_output(self, config, expected_output):
device = "cuda"
self.quant_config = TorchAoConfig(config)
self.quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
dtype=torch.bfloat16,
device_map=device,
quantization_config=self.quant_config,
)
self.check_serialization_expected_output(device, expected_output, safe_serialization=True)


class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}

@@ -500,7 +546,7 @@ def test_serialization_expected_output_on_accelerator(self):

@require_torch_accelerator
class TorchAoSerializationAcceleratorTest(TorchAoSerializationTest):
quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32}
quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32, "version": 1}
device = f"{torch_device}:0"

# called only once for all test in this class
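End to end, the feature this PR adds would be exercised roughly as below; the checkpoint and device are placeholders, and the example assumes torchao >= 0.14.0 with one of the supported float8 configs.

```python
import torch
from torchao.quantization import Float8WeightOnlyConfig

from transformers import AutoModelForCausalLM, TorchAoConfig

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # placeholder checkpoint
quant_config = TorchAoConfig(Float8WeightOnlyConfig())

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.bfloat16,
    device_map="cuda",
    quantization_config=quant_config,
)

# With this PR, the flattened tensor subclasses are written as safetensors
# together with the torchao metadata...
model.save_pretrained("quantized-model", safe_serialization=True)

# ...and reconstructed from that metadata when loading.
reloaded = AutoModelForCausalLM.from_pretrained(
    "quantized-model", dtype=torch.bfloat16, device_map="cuda"
)
```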