[torchao safetensors] integrate torchao safetensors support with transformers #40735
Changes from all commits
First file: the torchao quantizer (defines TorchAoHfQuantizer).
@@ -35,6 +35,17 @@
 import torch
 import torch.nn as nn

+if is_torchao_available():
+    import torchao
+
+    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
+        from torchao.prototype.safetensors.safetensors_support import (
+            flatten_tensor_state_dict,
+            unflatten_tensor_state_dict,
+        )
+        from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
+
+
 logger = logging.get_logger(__name__)
@@ -81,6 +92,15 @@ def _linear_extra_repr(self):
     return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"


+if is_torchao_available():
+    SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
+        torchao.quantization.Float8WeightOnlyConfig,
+        torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
+    ]
+
+    TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
+
+
 class TorchAoHfQuantizer(HfQuantizer):
     """
     Quantizer for torchao: https://github.com/pytorch/ao/
@@ -137,6 +157,21 @@ def update_dtype(self, dtype):
             dtype = torch.float32
         return dtype

+    def get_state_dict_and_metadata(self, model, safe_serialization: Optional[bool] = False):
+        """
+        If the model is safe serializable, we flatten the state dict of tensor subclasses so that it is compatible with
+        the safetensors format.
+        """
+        if type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization:
+            if TORCHAO_VERSION >= version.parse("0.14.0"):
+                return flatten_tensor_state_dict(model.state_dict())
+            else:
+                raise RuntimeError(
+                    f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}"
+                )
+        else:
+            return super().get_state_dict_and_metadata(model)
+
     def adjust_target_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
         if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
             from accelerate.utils import CustomDtype
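For orientation, a minimal save-side sketch of what this hook enables. The model name and quantization config mirror the PR's tests; the output directory and device are illustrative assumptions, not taken from the diff:

import torch
from transformers import AutoModelForCausalLM, TorchAoConfig
from torchao.quantization import Float8WeightOnlyConfig

# One of the configs listed in SUPPORTED_SAFE_SERIALIZATION_CONFIGS above.
quant_config = TorchAoConfig(Float8WeightOnlyConfig())
model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # model used in the PR's tests
    dtype=torch.bfloat16,
    device_map="cuda",  # illustrative; the PR's new test also uses "cuda"
    quantization_config=quant_config,
)

# With torchao >= 0.14.0 this no longer requires safe_serialization=False:
# get_state_dict_and_metadata flattens the tensor subclasses into plain
# tensors plus metadata that the safetensors format can store.
model.save_pretrained("quantized-model", safe_serialization=True)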
@@ -279,6 +314,16 @@ def create_quantized_param(

         quantize_(module, self.quantization_config.get_apply_tensor_subclass())

+    def update_state_dict_with_metadata(self, state_dict, metadata):
+        """
+        If the metadata contains torchao tensor subclass information, we reconstruct the tensor subclass state dict
+        from the provided state_dict and metadata.
+        """
+        if TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(metadata):
+            return unflatten_tensor_state_dict(state_dict, metadata)
+        else:
+            return super().update_state_dict_with_metadata(state_dict, metadata)
+
     def _process_model_after_weight_loading(self, model, **kwargs):
         """No process required for torchao quantized model"""
         if self.quantization_config.quant_type == "autoquant":
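The load path is the inverse of the save path. A sketch of the round trip through the two prototype helpers this hunk relies on, reusing `model` from the sketch above (the tuple return shape follows get_state_dict_and_metadata in the earlier hunk):

from torchao.prototype.safetensors.safetensors_support import (
    flatten_tensor_state_dict,
    unflatten_tensor_state_dict,
)

# Saving: subclass tensors become plain tensors plus string metadata that
# records how to rebuild each torchao tensor subclass.
flat_state_dict, metadata = flatten_tensor_state_dict(model.state_dict())

# Loading: when is_metadata_torchao(metadata) is true, the flat tensors and
# the metadata are recombined into the original subclass state dict, which
# is what update_state_dict_with_metadata returns.
restored_state_dict = unflatten_tensor_state_dict(flat_state_dict, metadata)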
@@ -297,10 +342,17 @@ def _process_model_after_weight_loading(self, model, **kwargs):

     def is_serializable(self, safe_serialization=None) -> bool:
         if safe_serialization:
-            logger.warning(
-                "torchao quantized model does not support safe serialization, please set `safe_serialization` to False"
-            )
-            return False
+            _is_torchao_serializable = type(
+                self.quantization_config.quant_type
+            ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0")
+            if not _is_torchao_serializable:
+                logger.warning(
+                    f"torchao quantized models only support safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS} "
+                    f"and torchao version >= 0.14.0; please set `safe_serialization` to False for "
+                    f"{type(self.quantization_config.quant_type)} with torchao {TORCHAO_VERSION}."
+                )
+            return _is_torchao_serializable
         _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
             "0.25.0"
         )

Review thread on the warning above:
Reviewer: also precise the version in the warning
Author: fixed!
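Note the control flow: the new check only runs when `safe_serialization` is truthy; otherwise the method falls through to the pre-existing huggingface_hub version check visible in the trailing context lines.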
Second file: the torchao integration tests.
@@ -18,6 +18,7 @@
 import unittest

 from packaging import version
+from parameterized import parameterized

 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 from transformers.testing_utils import (
@@ -37,6 +38,8 @@
 import torch

 if is_torchao_available():
+    import torchao
+
     # renamed in torchao 0.7.0, please install the latest torchao
     from torchao.dtypes import (
         AffineQuantizedTensor,
@@ -135,7 +138,7 @@ class TorchAoTest(unittest.TestCase):
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     device = "cpu"
     quant_scheme_kwargs = (
-        {"group_size": 32, "layout": Int4CPULayout()}
+        {"group_size": 32, "layout": Int4CPULayout(), "version": 1}
         if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
         else {"group_size": 32}
     )
@@ -225,6 +228,7 @@ def test_include_input_output_embeddings(self):
             weight_dtype=weight_dtype,
             granularity=granularity,
             mapping_type=mapping_type,
+            version=1,
         )
         config = ModuleFqnToConfig(
             {"_default": None, "model.embed_tokens": embedding_config, "lm_head": embedding_config}
@@ -277,7 +281,7 @@ def test_per_module_config_skip(self):
 @require_torch_accelerator
 class TorchAoAcceleratorTest(TorchAoTest):
     device = torch_device
-    quant_scheme_kwargs = {"group_size": 32}
+    quant_scheme_kwargs = {"group_size": 32, "version": 1}

     # called only once for all test in this class
     @classmethod
@@ -327,7 +331,7 @@ def test_int4wo_offload(self):
             "lm_head": 0,
         }

-        quant_config = TorchAoConfig("int4_weight_only", group_size=32)
+        quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)

         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
@@ -399,7 +403,7 @@ def test_autoquant(self):

         check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)

-        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJane: (sighs)"
+        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"

         output = quantized_model.generate(
             **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
         )

Review thread on the EXPECTED_OUTPUT change:
Reviewer: should this be reverted?
Author: i double checked that this fails on main as well, so i just added the correction
@@ -414,7 +418,7 @@ class TorchAoSerializationTest(unittest.TestCase):
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     quant_scheme = "int4_weight_only"
     quant_scheme_kwargs = (
-        {"group_size": 32, "layout": Int4CPULayout()}
+        {"group_size": 32, "layout": Int4CPULayout(), "version": 1}
         if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
         else {"group_size": 32}
     )
@@ -447,13 +451,13 @@ def test_original_model_expected_output(self):

         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

-    def check_serialization_expected_output(self, device, expected_output):
+    def check_serialization_expected_output(self, device, expected_output, safe_serialization=False):
         """
         Test if we can serialize and load/infer the model again on the same device
         """
         dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
         with tempfile.TemporaryDirectory() as tmpdirname:
-            self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
+            self.quantized_model.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
             loaded_quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, dtype=dtype, device_map=device)
             input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)
@@ -464,6 +468,48 @@ def test_serialization_expected_output(self):
         self.check_serialization_expected_output(self.device, self.EXPECTED_OUTPUT)


+@require_torchao
+@require_torchao_version_greater_or_equal("0.14.0")
+class TorchAoSafeSerializationTest(TorchAoSerializationTest):
+    # called only once for all test in this class
+    @classmethod
+    def setUpClass(cls):
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+        cls.EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
+
+    def tearDown(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+        gc.collect()
+        if hasattr(self, "quantized_model"):
+            del self.quantized_model
+        gc.collect()
+
+    test_params = (
+        [
+            (
+                torchao.quantization.Float8DynamicActivationFloat8WeightConfig(),
+                "What are we having for dinner?\n\nJess: (smiling) I",
+            ),
+            (torchao.quantization.Float8WeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"),
+        ]
+        if is_torchao_available()
+        else []
+    )
+
+    @parameterized.expand(test_params, skip_on_empty=True)
+    def test_serialization_expected_output(self, config, expected_output):
+        device = "cuda"
+        self.quant_config = TorchAoConfig(config)
+        self.quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            dtype=torch.bfloat16,
+            device_map=device,
+            quantization_config=self.quant_config,
+        )
+        self.check_serialization_expected_output(device, expected_output, safe_serialization=True)
+
+
 class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
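Side note on the test mechanics: @parameterized.expand generates one test method per (config, expected_output) tuple at class-definition time, and skip_on_empty=True keeps the class valid when torchao is absent and test_params is empty. A self-contained illustration of the same pattern (names are illustrative, not from the PR):

import unittest

from parameterized import parameterized

class ExampleTest(unittest.TestCase):
    # Each tuple is unpacked into the test's arguments and becomes its own
    # generated test method, e.g. test_square_0 and test_square_1.
    @parameterized.expand([(2, 4), (3, 9)], skip_on_empty=True)
    def test_square(self, value, expected):
        self.assertEqual(value * value, expected)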
@@ -500,7 +546,7 @@ def test_serialization_expected_output_on_accelerator(self):

 @require_torch_accelerator
 class TorchAoSerializationAcceleratorTest(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32}
+    quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32, "version": 1}
     device = f"{torch_device}:0"

     # called only once for all test in this class