diff --git a/auto_round/alg_ext.abi3.so b/auto_round/alg_ext.abi3.so
old mode 100644
new mode 100755
index b6f3cbfe1..d172b6a27
Binary files a/auto_round/alg_ext.abi3.so and b/auto_round/alg_ext.abi3.so differ
diff --git a/auto_round/auto_scheme/utils.py b/auto_round/auto_scheme/utils.py
index 73191c40f..0f2b00e06 100644
--- a/auto_round/auto_scheme/utils.py
+++ b/auto_round/auto_scheme/utils.py
@@ -19,13 +19,13 @@
 from accelerate import dispatch_model, infer_auto_device_map
 from accelerate.utils import get_balanced_memory
 
-from auto_round.low_cpu_mem import get_module
 from auto_round.schemes import QuantizationScheme, preset_name_to_scheme
 from auto_round.utils import (
     SUPPORTED_LAYER_TYPES,
     check_to_quantized,
     get_block_names,
     get_layer_features,
+    get_module,
     is_hpex_available,
 )
diff --git a/auto_round/autoround.py b/auto_round/autoround.py
index a335b343c..4716f173b 100644
--- a/auto_round/autoround.py
+++ b/auto_round/autoround.py
@@ -118,7 +118,7 @@ def __new__(
             **kwargs: Backward compatible options:
                 - enable_alg_ext, quant_lm_head, lr, lr_scheduler, sampler, not_use_best_mse, dynamic_max_gap,
                   super_group_size, super_bits, scale_dtype ("fp16" etc.),
-                  nblocks, low_cpu_mem_usage, to_quant_block_names,
+                  nblocks, to_quant_block_names,
                   enable_norm_bias_tuning, enable_quanted_input, disable_deterministic_algorithms, vlm, static_kv_dtype
 
         Raises:
@@ -271,7 +271,7 @@ class AutoRoundLLM(LLMCompressor):
             **kwargs: Backward compatible options:
                 - enable_alg_ext, quant_lm_head, lr, lr_scheduler, sampler, not_use_best_mse, dynamic_max_gap,
                   super_group_size, super_bits, scale_dtype ("fp16" etc.),
-                  nblocks, low_cpu_mem_usage, to_quant_block_names,
+                  nblocks, to_quant_block_names,
                   enable_norm_bias_tuning, enable_quanted_input, disable_deterministic_algorithms, mllm, static_kv_dtype
 
         Raises:
@@ -366,7 +366,6 @@ class AutoRoundAdam(AdamCompressor):
         lr (float): The learning rate (default is 0.005).
         minmax_lr (float): The learning rate for min-max tuning (default is None).
         low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False).
-        low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False).
         iters (int): Number of iterations (default is 200).
         seqlen (int): Length of the sequence.
         nsamples (int): Number of samples (default is 128).
@@ -473,7 +472,6 @@ class AutoRoundMLLM(MLLMCompressor):
         lr (float): The learning rate (default is 0.005).
         minmax_lr (float): The learning rate for min-max tuning (default is None).
         low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False).
-        low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False).
         iters (int): Number of iterations (default is 200).
         seqlen (int): Length of the sequence.
         nsamples (int): Number of samples (default is 128).
diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index 1ab233953..d6df02bc1 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -36,7 +36,6 @@
 from auto_round.export.export_to_autoround import AutoRoundFormat
 from auto_round.export.export_to_gguf.config import GGUF_CONFIG, GGUF_INNER_CONFIG, ModelType
 from auto_round.logger import logger
-from auto_round.low_cpu_mem.utils import get_layers_before_block
 from auto_round.schemes import AutoScheme, QuantizationScheme, get_gguf_scheme, preset_name_to_scheme
 from auto_round.sign_sgd import SignSGD
 from auto_round.special_model_handler import _handle_moe_model
@@ -181,7 +180,7 @@ def __init__(
             **kwargs: Backward compatible options:
                 - enable_alg_ext, quant_lm_head, lr, lr_scheduler, sampler, not_use_best_mse, dynamic_max_gap,
                   super_group_size, super_bits, scale_dtype ("fp16" etc.),
-                  nblocks, low_cpu_mem_usage, to_quant_block_names,
+                  nblocks, to_quant_block_names,
                   enable_norm_bias_tuning, enable_quanted_input, disable_deterministic_algorithms, mllm, static_kv_dtype
 
         Raises:
@@ -254,7 +253,6 @@ def __init__(
         not_use_best_mse = kwargs.pop("not_use_best_mse", False)
         dynamic_max_gap = kwargs.pop("dynamic_max_gap", -1)
         nblocks = kwargs.pop("nblocks", 1)
-        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
         to_quant_block_names: Union[str, list, None] = kwargs.pop("to_quant_block_names", None)
         enable_norm_bias_tuning: bool = kwargs.pop("enable_norm_bias_tuning", False)
         enable_quanted_input: bool = kwargs.pop("enable_quanted_input", True)
@@ -292,14 +290,12 @@ def __init__(
         # Model related
         self.quantized = False
         if isinstance(model, str):
-            model, tokenizer, low_cpu_mem_usage = llm_load_model(
+            model, tokenizer = llm_load_model(
                 model,
                 device="cpu",  # always load cpu first
-                low_cpu_mem_mode=low_cpu_mem_usage,
             )
         elif tokenizer is None and not self.diffusion and iters > 0:
             raise ValueError("A tokenizer must be set for non-str model input")
-        self.low_cpu_mem_usage = bool(low_cpu_mem_usage)
         if unsupported_meta_device(model):
             raise RuntimeError(
                 "AutoRound does not support parameters on meta device. "
@@ -572,7 +568,6 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None:
             and TORCH_VERSION_AT_LEAST_2_6
             and self.act_bits > 8
             and not is_debug_mode()
-            and not self.low_cpu_mem_usage
             and "fp8" not in self.data_type
             and "fp8" not in self.act_data_type
         ):
@@ -581,10 +576,6 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None:
                 "Enabling it can reduce tuning cost by 20%, but it might throw an exception."
             )
 
-        if self.low_cpu_mem_usage and self.enable_torch_compile:
-            self.enable_torch_compile = False
-            logger.warning("reset enable_torch_compile to `False` as low_cpu_mem_usage is enabled")
-
         # if is_debug_mode() and self.enable_torch_compile:
         #     self.enable_torch_compile = False
         #     logger.warning("reset enable_torch_compile to `False` as debug mode is enabled")
@@ -1605,7 +1596,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
                     self._quantize_layer_via_rtn(m.tmp_name)
                     all_to_quantized_module_names.remove(m.tmp_name)
 
-            mv_module_from_gpu(block, self.low_cpu_mem_usage)
+            mv_module_from_gpu(block)
             pbar.update(1)
 
         pbar.close()
@@ -1717,11 +1708,11 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
             all_q_inputs = self.try_cache_inter_data_gpucpu(
                 all_first_block_names, self.nsamples, layer_names=layer_names
             )
-            self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
+            self.model = mv_module_from_gpu(self.model)
             clear_memory()
         if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
             accelerate.hooks.remove_hook_from_submodules(self.model)  # self.model.hf_device_map has not been changed
-            self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
+            self.model = mv_module_from_gpu(self.model)
         logger.info("caching done")
         if len(all_blocks) > 1:
             pbar = tqdm(range(0, sum([len(i) for i in all_blocks]), self.nblocks))
@@ -1855,7 +1846,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
                 self.model
             )  ##self.model.hf_device_map has not been changed
-            self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
+            self.model = mv_module_from_gpu(self.model)
             clear_memory()
         quant_layer = self._quantize_layer
         for layer_name in layer_names:
@@ -1945,12 +1936,6 @@ def calib(self, nsamples, bs):
             self.dataloader = self.dataset
         total_cnt = 0
 
-        # load embed weight if use low_cpu_mem_usage
-        if self.low_cpu_mem_usage:
-            embed_layers = get_layers_before_block(self.model)
-            for n, m in embed_layers:
-                m = m.to(self.device)
-
         for data in self.dataloader:
             if data is None:
                 continue
@@ -2013,11 +1998,6 @@ def calib(self, nsamples, bs):
                 f"Target samples count is {nsamples}, while valid samples count is {total_cnt}"
             )
 
-        # clean embed weight to save memory
-        if self.low_cpu_mem_usage:
-            for n, m in embed_layers:
-                m = m.to("meta")
-
     @torch.no_grad()
     def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, last_cache_name=None):
         """Attempts to cache intermediate data on GPU, if failed, then using CPU.
@@ -2091,7 +2071,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
                 accelerate.hooks.remove_hook_from_submodules(
                     self.model
                 )  ##self.model.hf_device_map has not been changed
-                self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
+                self.model = mv_module_from_gpu(self.model)
                 clear_memory()
             ## Important change after v0.51, on cpu, we use rtn mode for layers in layer_names
             all_inputs = self.cache_inter_data(
@@ -2348,7 +2328,7 @@ def _quantize_layer(
             logger.info(dump_info)
             with torch.no_grad():
                 unwrapper_layer(self.model, wrapper_linear, layer_name, {})
-            mv_module_from_gpu(layer, self.low_cpu_mem_usage)
+            mv_module_from_gpu(layer)
 
         lr = torch.tensor(self.lr)
         minmax_lr = torch.tensor(self.minmax_lr)
@@ -2441,7 +2421,7 @@ def _quantize_layer(
             best_iter = last_best_iter
         with torch.no_grad():
             unwrapper_layer(self.model, wrapper_linear, layer_name, best_params)
-        mv_module_from_gpu(layer, self.low_cpu_mem_usage)
+        mv_module_from_gpu(layer)
 
         dump_info = f"quantized {layer_name}, loss iter 0: {init_loss:.6f} -> iter {best_iter}: {last_loss:.6f}"
         logger.info(dump_info)
@@ -2636,7 +2616,7 @@ def _quantize_block(
             )
             logger.info(dump_info)
             unwrapper_block(block, {})  # TODO Quant layer should change
-            mv_module_from_gpu(block, self.low_cpu_mem_usage)
+            mv_module_from_gpu(block)
             return output, output
 
         if self.lr_scheduler is None:
@@ -2731,8 +2711,6 @@ def _quantize_block(
             set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")
 
         if self.enable_quanted_input:
-            if self.low_cpu_mem_usage:
-                block = block.to(device)
             clear_memory()
             q_outputs = self._get_block_outputs(
                 block,
@@ -2744,7 +2722,7 @@ def _quantize_block(
             )
             if self.device_map is not None:
                 accelerate.hooks.remove_hook_from_submodules(block)
-            mv_module_from_gpu(block, self.low_cpu_mem_usage)
+            mv_module_from_gpu(block)
             clear_memory(input_ids)
 
             return q_outputs, output
@@ -2752,7 +2730,7 @@ def _quantize_block(
         else:
             if self.device_map is not None:
                 accelerate.hooks.remove_hook_from_submodules(block)
-            mv_module_from_gpu(block, self.low_cpu_mem_usage)
+            mv_module_from_gpu(block)
             clear_memory(input_ids)
 
             return None, output
@@ -2848,9 +2826,7 @@ def _quantize_blocks(
             modules = [get_module(model, n) for n in names]
             m = WrapperMultiblock(modules)
 
-            if not self.model.device.type == "meta" or self.low_cpu_mem_usage:
-                m = m.to(device)
-
+            m = m.to(device)
             q_input, input_ids = quantize_block(
                 m,
                 input_ids,
@@ -2889,7 +2865,7 @@ def _quantize_blocks(
             if pbar is not None:
                 pbar.update(1)
 
-        self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
+        self.model = mv_module_from_gpu(self.model)
         for n, m in self.model.named_modules():
             if hasattr(m, "name"):
                 delattr(m, "name")
@@ -2917,9 +2893,6 @@ def save_quantized(
         """
         format = self._check_supported_format(format)
 
-        if self.low_cpu_mem_usage:
-            self.model = self.model.to("cpu")
-
         if not self.quantized:
             logger.warning("please run autoround.quantize first")
             return
@@ -3207,7 +3180,6 @@ class AdamCompressor(BaseCompressor):
         lr (float): The learning rate (default is 0.005).
         minmax_lr (float): The learning rate for min-max tuning (default is None).
         low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False).
-        low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False).
         iters (int): Number of iterations (default is 200).
         seqlen (int): Length of the sequence.
         nsamples (int): Number of samples (default is 128).
diff --git a/auto_round/compressors/diffusion/compressor.py b/auto_round/compressors/diffusion/compressor.py
index 5441d00b5..8508ef291 100644
--- a/auto_round/compressors/diffusion/compressor.py
+++ b/auto_round/compressors/diffusion/compressor.py
@@ -22,7 +22,6 @@
 from auto_round.compressors.base import BaseCompressor
 from auto_round.compressors.diffusion.dataset import get_diffusion_dataloader
 from auto_round.logger import logger
-from auto_round.low_cpu_mem.utils import get_layers_before_block
 from auto_round.schemes import QuantizationScheme
 from auto_round.utils import (
     LazyImport,
@@ -302,11 +301,6 @@ def calib(self, nsamples, bs):
             self.dataloader = self.dataset
         total_cnt = 0
 
-        if self.low_cpu_mem_usage:
-            embed_layers = get_layers_before_block(self.model)
-            for n, m in embed_layers:
-                m = m.to(self.device)
-
         total = nsamples if not hasattr(self.dataloader, "len") else min(nsamples, len(self.dataloader))
         if self.pipe.dtype != self.model.dtype:
             self.pipe.to(self.model.dtype)
@@ -358,10 +352,6 @@ def calib(self, nsamples, bs):
                 if isinstance(v[key], list) and len(v[key]) == total_cnt:
                     self.inputs[k][key] = v[key][:max_len]
 
-        # clean embed weight to save memory
-        if self.low_cpu_mem_usage:
-            for n, m in embed_layers:
-                m = m.to("meta")
         # torch.cuda.empty_cache()
 
     def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs):
diff --git a/auto_round/compressors/mllm/compressor.py b/auto_round/compressors/mllm/compressor.py
index 6fcb50c2c..402a643dc 100644
--- a/auto_round/compressors/mllm/compressor.py
+++ b/auto_round/compressors/mllm/compressor.py
@@ -25,7 +25,6 @@
 from auto_round.compressors.mllm.dataset import get_mllm_dataloader
 from auto_round.compressors.mllm.template import Template, get_template
 from auto_round.logger import logger
-from auto_round.low_cpu_mem.utils import get_layers_before_block
 from auto_round.schemes import QuantizationScheme
 from auto_round.special_model_handler import (
     NOT_SUPPORT_ONLY_TEXT_MODELS,
@@ -108,7 +107,6 @@ class MLLMCompressor(BaseCompressor):
         lr (float): The learning rate (default is 0.005).
         minmax_lr (float): The learning rate for min-max tuning (default is None).
         low_gpu_mem_usage (bool): Whether to use low GPU memory (default is False).
-        low_cpu_mem_usage (bool): Whether to use low CPU memory (default is False).
         iters (int): Number of iterations (default is 200).
         seqlen (int): Length of the sequence.
         nsamples (int): Number of samples (default is 128).
@@ -315,11 +313,6 @@ def calib(self, nsamples, bs):
             self.dataloader = self.dataset
         total_cnt = 0
 
-        if self.low_cpu_mem_usage:
-            embed_layers = get_layers_before_block(self.model)
-            for n, m in embed_layers:
-                m = m.to(self.device)
-
         total = nsamples if not hasattr(self.dataloader, "len") else min(nsamples, len(self.dataloader))
         with tqdm(range(1, total + 1), desc="cache block inputs") as pbar:
             for data in self.dataloader:
@@ -417,10 +410,6 @@ def calib(self, nsamples, bs):
                 if isinstance(v[key], list) and len(v[key]) == total_cnt:
                     self.inputs[k][key] = v[key][:max_len]
 
-        # clean embed weight to save memory
-        if self.low_cpu_mem_usage:
-            for n, m in embed_layers:
-                m = m.to("meta")
         # torch.cuda.empty_cache()
 
     def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **kwargs):
diff --git a/auto_round/low_cpu_mem/__init__.py b/auto_round/low_cpu_mem/__init__.py
deleted file mode 100644
index 453a608c4..000000000
--- a/auto_round/low_cpu_mem/__init__.py
+++ /dev/null
@@ -1,18 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-#
-# Copyright (c) 2023 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Torch layer-wise quantization module."""
-from .utils import *
diff --git a/auto_round/low_cpu_mem/load.py b/auto_round/low_cpu_mem/load.py
deleted file mode 100644
index 44a8f2b0b..000000000
--- a/auto_round/low_cpu_mem/load.py
+++ /dev/null
@@ -1,271 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-#
-# Copyright (c) 2023 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Load one specify tensor from a bin file.""" - -import io -import os -import warnings -from typing import IO, Any, BinaryIO, Callable, Dict, Optional, Union - -from packaging.version import Version -from torch.serialization import ( - StorageType, - _get_restore_location, - _is_torchscript_zip, - _is_zipfile, - _maybe_decode_ascii, - _open_file_like, - _open_zipfile_reader, -) - -from ..low_cpu_mem import modified_pickle as pickle -from .utils import torch - -torch_version = torch.__version__.split("+")[0] -version = Version(torch_version) - -FILE_LIKE = Union[str, os.PathLike, BinaryIO, IO[bytes]] -MAP_LOCATION = Optional[Union[Callable[[torch.Tensor, str], torch.Tensor], torch.device, str, Dict[str, str]]] - -if version.release < Version("1.13.0").release: - UntypedStorage = torch._UntypedStorage -else: - UntypedStorage = torch.UntypedStorage - - -def _load(zip_file, tensor_name, prefix, map_location, pickle_module, pickle_file="data.pkl", **pickle_load_args): - restore_location = _get_restore_location(map_location) - - loaded_storages = {} - - def load_tensor(dtype, numel, key, location): - name = f"data/{key}" - - if version.release < Version("1.13.0").release: - storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped() - typed_storage = torch.storage._TypedStorage(wrap_storage=restore_location(storage, location), dtype=dtype) - loaded_storages[key] = typed_storage - elif version.release < Version("2.0.0").release: # pragma: no cover - storage = zip_file.get_storage_from_record(name, numel, UntypedStorage).storage().untyped() - typed_storage = torch.storage.TypedStorage(wrap_storage=restore_location(storage, location), dtype=dtype) - loaded_storages[key] = typed_storage - else: - storage = zip_file.get_storage_from_record(name, numel, UntypedStorage)._typed_storage()._untyped_storage - typed_storage = torch.storage.TypedStorage( - wrap_storage=restore_location(storage, location), dtype=dtype, _internal=True - ) - - if typed_storage._data_ptr() != 0: - loaded_storages[key] = typed_storage - - return typed_storage - - load_module_mapping: Dict[str, str] = {"torch.tensor": "torch._tensor"} - - # Need to subclass Unpickler instead of directly monkey-patching the find_class method - # because it's marked readonly in pickle. - # The type: ignore is because mypy can't statically determine the type of this class. 
- class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined] - def find_class(self, mod_name, name): - if type(name) is str and "Storage" in name: - try: - return StorageType(name) - except KeyError: # pragma: no cover - pass - mod_name = load_module_mapping.get(mod_name, mod_name) - return super().find_class(mod_name, name) - - def persistent_load(self, saved_id): - assert isinstance(saved_id, tuple) - typename = _maybe_decode_ascii(saved_id[0]) - data = saved_id[1:] - - assert ( - typename == "storage" - ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" - storage_type, key, location, numel = data - - if storage_type is UntypedStorage: # pragma: no cover - dtype = torch.uint8 - else: - dtype = storage_type.dtype - - if key in loaded_storages: - typed_storage = loaded_storages[key] - else: - name_list = [self.tensor_name] - if prefix: - no_prefix_name = self.tensor_name.split(".") - if prefix in no_prefix_name: - no_prefix_name.remove(prefix) - no_prefix_name = ".".join(no_prefix_name) - name_list.append(no_prefix_name) - if self.tensor_name and self.metastack[-1][-2] not in name_list: - # typed_storage = None - # loaded_storages[key] = typed_storage - # nbytes = numel * torch._utils._element_size(dtype) - # typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location)) - typed_storage = None - else: - nbytes = numel * torch._utils._element_size(dtype) - typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location)) - - return typed_storage - - # Load the data (which may in turn use `persistent_load` to load tensors) - data_file = io.BytesIO(zip_file.get_record(pickle_file)) - - unpickler = UnpicklerWrapper(data_file, **pickle_load_args) - # unpickler.persistent_load = persistent_load - result = unpickler.load(tensor_name) - - torch._utils._validate_loaded_sparse_tensors() - return result - - -def load( - f: FILE_LIKE, - tensor_name: str = None, - prefix: str = None, - map_location: MAP_LOCATION = None, - pickle_module: Any = None, - *, - weights_only: bool = False, - **pickle_load_args: Any, -) -> Any: - # Reference: https://github.com/pytorch/pytorch/issues/54354 - # The first line of this docstring overrides the one Sphinx generates for the - # documentation. We need it so that Sphinx doesn't leak `pickle`s path from - # the build environment (e.g. `>> # xdoctest: +SKIP("undefined filepaths") - >>> torch.load('tensors.pt') - # Load all tensors onto the CPU - >>> torch.load('tensors.pt', map_location=torch.device('cpu')) - # Load all tensors onto the CPU, using a function - >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage) - # Load all tensors onto GPU 1 - >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1)) - # Map tensors from GPU 1 to GPU 0 - >>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}) - # Load tensor from io.BytesIO object - >>> with open('tensor.pt', 'rb') as f: - ... 
buffer = io.BytesIO(f.read()) - >>> torch.load(buffer) - # Load a module with 'ascii' encoding for unpickling - >>> torch.load('module.pt', encoding='ascii') - """ - torch._C._log_api_usage_once("torch.load") - # Add ability to force safe only weight loads via environment variable - if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ["1", "y", "yes", "true"]: # pragma: no cover - weights_only = True - - if weights_only: # pragma: no cover - if pickle_module is not None: - raise RuntimeError("Can not safely load weights when explicit pickle_module is specified") - else: - if pickle_module is None: - pickle_module = pickle - - if "encoding" not in pickle_load_args.keys(): - pickle_load_args["encoding"] = "utf-8" - - with _open_file_like(f, "rb") as opened_file: - if _is_zipfile(opened_file): - # The zipfile reader is going to advance the current file position. - # If we want to actually tail call to torch.jit.load, we need to - # reset back to the original position. - orig_position = opened_file.tell() - with _open_zipfile_reader(opened_file) as opened_zipfile: - if _is_torchscript_zip(opened_zipfile): # pragma: no cover - warnings.warn( - "'torch.load' received a zip file that looks like a TorchScript archive" - " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to" - " silence this warning)", - UserWarning, - ) - opened_file.seek(orig_position) - return torch.jit.load(opened_file, map_location=map_location) - return _load(opened_zipfile, tensor_name, prefix, map_location, pickle_module, **pickle_load_args) diff --git a/auto_round/low_cpu_mem/modified_pickle.py b/auto_round/low_cpu_mem/modified_pickle.py deleted file mode 100644 index 22e6ab5c0..000000000 --- a/auto_round/low_cpu_mem/modified_pickle.py +++ /dev/null @@ -1,1840 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright (c) 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Create portable serialized representations of Python objects. - -See module copyreg for a mechanism for registering custom picklers. -See module pickletools source for extensive comments. 
- -Classes: - - Pickler - Unpickler - -Functions: - - dump(object, file) - dumps(object) -> string - load(file) -> object - loads(string) -> object - -Misc variables: - - __version__ - format_version - compatible_formats -""" - -import codecs -import io -import re -import sys -from copyreg import _extension_cache, _extension_registry, _inverted_registry, dispatch_table -from functools import partial -from itertools import islice -from struct import pack, unpack -from sys import maxsize -from types import FunctionType - -import _compat_pickle - -__all__ = ["PickleError", "PicklingError", "UnpicklingError", "Pickler", "Unpickler", "dump", "dumps", "load", "loads"] - -try: - from _pickle import PickleBuffer - - __all__.append("PickleBuffer") - _HAVE_PICKLE_BUFFER = True -except ImportError: - _HAVE_PICKLE_BUFFER = False - - -# Shortcut for use in isinstance testing -bytes_types = (bytes, bytearray) - -# These are purely informational; no code uses these. -format_version = "4.0" # File format version we write -compatible_formats = [ - "1.0", # Original protocol 0 - "1.1", # Protocol 0 with INST added - "1.2", # Original protocol 1 - "1.3", # Protocol 1 with BINFLOAT added - "2.0", # Protocol 2 - "3.0", # Protocol 3 - "4.0", # Protocol 4 - "5.0", # Protocol 5 -] # Old format versions we can read - -# This is the highest protocol number we know how to read. -HIGHEST_PROTOCOL = 5 - -# The protocol we write by default. May be less than HIGHEST_PROTOCOL. -# Only bump this if the oldest still supported version of Python already -# includes it. -DEFAULT_PROTOCOL = 4 - - -class PickleError(Exception): - """A common base class for the other pickling exceptions.""" - - pass - - -class PicklingError(PickleError): - """This exception is raised when an unpicklable object is passed to the - dump() method.""" - - pass - - -class UnpicklingError(PickleError): - """This exception is raised when there is a problem unpickling an object, - such as a security violation. - - Note that other exceptions may also be raised during unpickling, including - (but not necessarily limited to) AttributeError, EOFError, ImportError, - and IndexError. - """ - - pass - - -# An instance of _Stop is raised by Unpickler.load_stop() in response to -# the STOP opcode, passing the object that is the result of unpickling. -class _Stop(Exception): - def __init__(self, value): # pragma: no cover - self.value = value - - -# Jython has PyStringMap; it's a dict subclass with string keys -try: - from org.python.core import PyStringMap -except ImportError: - PyStringMap = None - -# Pickle opcodes. See pickletools.py for extensive docs. The listing -# here is in kind-of alphabetical order of 1-character pickle code. -# pickletools groups them by purpose. -# fmt: off -MARK = b'(' # push special markobject on stack -STOP = b'.' 
# every pickle ends with STOP -POP = b'0' # discard topmost stack item -POP_MARK = b'1' # discard stack top through topmost markobject -DUP = b'2' # duplicate top stack item -FLOAT = b'F' # push float object; decimal string argument -INT = b'I' # push integer or bool; decimal string argument -BININT = b'J' # push four-byte signed int -BININT1 = b'K' # push 1-byte unsigned int -LONG = b'L' # push long; decimal string argument -BININT2 = b'M' # push 2-byte unsigned int -NONE = b'N' # push None -PERSID = b'P' # push persistent object; id is taken from string arg -BINPERSID = b'Q' # " " " ; " " " " stack -REDUCE = b'R' # apply callable to argtuple, both on stack -STRING = b'S' # push string; NL-terminated string argument -BINSTRING = b'T' # push string; counted binary string argument -SHORT_BINSTRING= b'U' # " " ; " " " " < 256 bytes -UNICODE = b'V' # push Unicode string; raw-unicode-escaped'd argument -BINUNICODE = b'X' # " " " ; counted UTF-8 string argument -APPEND = b'a' # append stack top to list below it -BUILD = b'b' # call __setstate__ or __dict__.update() -GLOBAL = b'c' # push self.find_class(modname, name); 2 string args -DICT = b'd' # build a dict from stack items -EMPTY_DICT = b'}' # push empty dict -APPENDS = b'e' # extend list on stack by topmost stack slice -GET = b'g' # push item from memo on stack; index is string arg -BINGET = b'h' # " " " " " " ; " " 1-byte arg -INST = b'i' # build & push class instance -LONG_BINGET = b'j' # push item from memo on stack; index is 4-byte arg -LIST = b'l' # build list from topmost stack items -EMPTY_LIST = b']' # push empty list -OBJ = b'o' # build & push class instance -PUT = b'p' # store stack top in memo; index is string arg -BINPUT = b'q' # " " " " " ; " " 1-byte arg -LONG_BINPUT = b'r' # " " " " " ; " " 4-byte arg -SETITEM = b's' # add key+value pair to dict -TUPLE = b't' # build tuple from topmost stack items -EMPTY_TUPLE = b')' # push empty tuple -SETITEMS = b'u' # modify dict by adding topmost key+value pairs -BINFLOAT = b'G' # push float; arg is 8-byte float encoding - -TRUE = b'I01\n' # not an opcode; see INT docs in pickletools.py -FALSE = b'I00\n' # not an opcode; see INT docs in pickletools.py - -# Protocol 2 - -PROTO = b'\x80' # identify pickle protocol -NEWOBJ = b'\x81' # build object by applying cls.__new__ to argtuple -EXT1 = b'\x82' # push object from extension registry; 1-byte index -EXT2 = b'\x83' # ditto, but 2-byte index -EXT4 = b'\x84' # ditto, but 4-byte index -TUPLE1 = b'\x85' # build 1-tuple from stack top -TUPLE2 = b'\x86' # build 2-tuple from two topmost stack items -TUPLE3 = b'\x87' # build 3-tuple from three topmost stack items -NEWTRUE = b'\x88' # push True -NEWFALSE = b'\x89' # push False -LONG1 = b'\x8a' # push long from < 256 bytes -LONG4 = b'\x8b' # push really big long - -_tuplesize2code = [EMPTY_TUPLE, TUPLE1, TUPLE2, TUPLE3] - -# Protocol 3 (Python 3.x) - -BINBYTES = b'B' # push bytes; counted binary string argument -SHORT_BINBYTES = b'C' # " " ; " " " " < 256 bytes - -# Protocol 4 - -SHORT_BINUNICODE = b'\x8c' # push short string; UTF-8 length < 256 bytes -BINUNICODE8 = b'\x8d' # push very long string -BINBYTES8 = b'\x8e' # push very long bytes string -EMPTY_SET = b'\x8f' # push empty set on the stack -ADDITEMS = b'\x90' # modify set by adding topmost stack items -FROZENSET = b'\x91' # build frozenset from topmost stack items -NEWOBJ_EX = b'\x92' # like NEWOBJ but work with keyword only arguments -STACK_GLOBAL = b'\x93' # same as GLOBAL but using names on the stacks -MEMOIZE = b'\x94' # store top of the 
stack in memo -FRAME = b'\x95' # indicate the beginning of a new frame - -# Protocol 5 - -BYTEARRAY8 = b'\x96' # push bytearray -NEXT_BUFFER = b'\x97' # push next out-of-band buffer -READONLY_BUFFER = b'\x98' # make top of stack readonly -# fmt: on -__all__.extend([x for x in dir() if re.match("[A-Z][A-Z0-9_]+$", x)]) - - -class _Framer: # pragma: no cover - _FRAME_SIZE_MIN = 4 - _FRAME_SIZE_TARGET = 64 * 1024 - - def __init__(self, file_write): - self.file_write = file_write - self.current_frame = None - - def start_framing(self): - self.current_frame = io.BytesIO() - - def end_framing(self): - if self.current_frame and self.current_frame.tell() > 0: - self.commit_frame(force=True) - self.current_frame = None - - def commit_frame(self, force=False): - if self.current_frame: - f = self.current_frame - if f.tell() >= self._FRAME_SIZE_TARGET or force: - data = f.getbuffer() - write = self.file_write - if len(data) >= self._FRAME_SIZE_MIN: - # Issue a single call to the write method of the underlying - # file object for the frame opcode with the size of the - # frame. The concatenation is expected to be less expensive - # than issuing an additional call to write. - write(FRAME + pack("": - raise AttributeError("Can't get local attribute {!r} on {!r}".format(name, obj)) - try: - parent = obj - obj = getattr(obj, subpath) - except AttributeError: - raise AttributeError("Can't get attribute {!r} on {!r}".format(name, obj)) from None - return obj, parent - - -def whichmodule(obj, name): # pragma: no cover - """Find the module an object belong to.""" - module_name = getattr(obj, "__module__", None) - if module_name is not None: - return module_name - # Protect the iteration by using a list copy of sys.modules against dynamic - # modules that trigger imports of other modules upon calls to getattr. - for module_name, module in sys.modules.copy().items(): - if module_name == "__main__" or module_name == "__mp_main__" or module is None: # bpo-42406 - continue - try: - if _getattribute(module, name)[0] is obj: - return module_name - except AttributeError: - pass - return "__main__" - - -def encode_long(x): # pragma: no cover - r"""Encode a long to a two's complement little-endian binary string. - Note that 0 is a special case, returning an empty string, to save a - byte in the LONG1 pickling context. - - >>> encode_long(0) - b'' - >>> encode_long(255) - b'\xff\x00' - >>> encode_long(32767) - b'\xff\x7f' - >>> encode_long(-256) - b'\x00\xff' - >>> encode_long(-32768) - b'\x00\x80' - >>> encode_long(-128) - b'\x80' - >>> encode_long(127) - b'\x7f' - >>> - """ - if x == 0: - return b"" - nbytes = (x.bit_length() >> 3) + 1 - result = x.to_bytes(nbytes, byteorder="little", signed=True) - if x < 0 and nbytes > 1: - if result[-1] == 0xFF and (result[-2] & 0x80) != 0: - result = result[:-1] - return result - - -def decode_long(data): # pragma: no cover - r"""Decode a long from a two's complement little-endian binary string. - - >>> decode_long(b'') - 0 - >>> decode_long(b"\xff\x00") - 255 - >>> decode_long(b"\xff\x7f") - 32767 - >>> decode_long(b"\x00\xff") - -256 - >>> decode_long(b"\x00\x80") - -32768 - >>> decode_long(b"\x80") - -128 - >>> decode_long(b"\x7f") - 127 - """ - return int.from_bytes(data, byteorder="little", signed=True) - - -# Pickling machinery - - -class _Pickler: # pragma: no cover - def __init__(self, file, protocol=None, *, fix_imports=True, buffer_callback=None): - """This takes a binary file for writing a pickle data stream. 
- - The optional *protocol* argument tells the pickler to use the - given protocol; supported protocols are 0, 1, 2, 3, 4 and 5. - The default protocol is 4. It was introduced in Python 3.4, and - is incompatible with previous versions. - - Specifying a negative protocol version selects the highest - protocol version supported. The higher the protocol used, the - more recent the version of Python needed to read the pickle - produced. - - The *file* argument must have a write() method that accepts a - single bytes argument. It can thus be a file object opened for - binary writing, an io.BytesIO instance, or any other custom - object that meets this interface. - - If *fix_imports* is True and *protocol* is less than 3, pickle - will try to map the new Python 3 names to the old module names - used in Python 2, so that the pickle data stream is readable - with Python 2. - - If *buffer_callback* is None (the default), buffer views are - serialized into *file* as part of the pickle stream. - - If *buffer_callback* is not None, then it can be called any number - of times with a buffer view. If the callback returns a false value - (such as None), the given buffer is out-of-band; otherwise the - buffer is serialized in-band, i.e. inside the pickle stream. - - It is an error if *buffer_callback* is not None and *protocol* - is None or smaller than 5. - """ - if protocol is None: - protocol = DEFAULT_PROTOCOL - if protocol < 0: - protocol = HIGHEST_PROTOCOL - elif not 0 <= protocol <= HIGHEST_PROTOCOL: - raise ValueError("pickle protocol must be <= %d" % HIGHEST_PROTOCOL) - if buffer_callback is not None and protocol < 5: - raise ValueError("buffer_callback needs protocol >= 5") - self._buffer_callback = buffer_callback - try: - self._file_write = file.write - except AttributeError: - raise TypeError("file must have a 'write' attribute") - self.framer = _Framer(self._file_write) - self.write = self.framer.write - self._write_large_bytes = self.framer.write_large_bytes - self.memo = {} - self.proto = int(protocol) - self.bin = protocol >= 1 - self.fast = 0 - self.fix_imports = fix_imports and protocol < 3 - - def clear_memo(self): - """Clears the pickler's "memo". - - The memo is the data structure that remembers which objects the - pickler has already seen, so that shared or recursive objects - are pickled by reference and not by value. This method is - useful when reusing picklers. - """ - self.memo.clear() - - def dump(self, obj): - """Write a pickled representation of obj to the open file.""" - # Check whether Pickler was initialized correctly. This is - # only needed to mimic the behavior of _pickle.Pickler.dump(). - if not hasattr(self, "_file_write"): - raise PicklingError("Pickler.__init__() was not called by " "%s.__init__()" % (self.__class__.__name__,)) - if self.proto >= 2: - self.write(PROTO + pack("= 4: - self.framer.start_framing() - self.save(obj) - self.write(STOP) - self.framer.end_framing() - - def memoize(self, obj): - """Store an object in the memo.""" - - # The Pickler memo is a dictionary mapping object ids to 2-tuples - # that contain the Unpickler memo key and the object being memoized. - # The memo key is written to the pickle and will become - # the key in the Unpickler's memo. The object is stored in the - # Pickler memo so that transient objects are kept alive during - # pickling. - - # The use of the Unpickler memo length as the memo key is just a - # convention. The only requirement is that the memo values be unique. 
- # But there appears no advantage to any other scheme, and this - # scheme allows the Unpickler memo to be implemented as a plain (but - # growable) array, indexed by memo key. - if self.fast: - return - assert id(obj) not in self.memo - idx = len(self.memo) - self.write(self.put(idx)) - self.memo[id(obj)] = idx, obj - - # Return a PUT (BINPUT, LONG_BINPUT) opcode string, with argument i. - def put(self, idx): - if self.proto >= 4: - return MEMOIZE - elif self.bin: - if idx < 256: - return BINPUT + pack("= 2 and func_name == "__newobj_ex__": - cls, args, kwargs = args - if not hasattr(cls, "__new__"): - raise PicklingError("args[0] from {} args has no __new__".format(func_name)) - if obj is not None and cls is not obj.__class__: - raise PicklingError("args[0] from {} args has the wrong class".format(func_name)) - if self.proto >= 4: - save(cls) - save(args) - save(kwargs) - write(NEWOBJ_EX) - else: - func = partial(cls.__new__, cls, *args, **kwargs) - save(func) - save(()) - write(REDUCE) - elif self.proto >= 2 and func_name == "__newobj__": - # A __reduce__ implementation can direct protocol 2 or newer to - # use the more efficient NEWOBJ opcode, while still - # allowing protocol 0 and 1 to work normally. For this to - # work, the function returned by __reduce__ should be - # called __newobj__, and its first argument should be a - # class. The implementation for __newobj__ - # should be as follows, although pickle has no way to - # verify this: - # - # def __newobj__(cls, *args): - # return cls.__new__(cls, *args) - # - # Protocols 0 and 1 will pickle a reference to __newobj__, - # while protocol 2 (and above) will pickle a reference to - # cls, the remaining args tuple, and the NEWOBJ code, - # which calls cls.__new__(cls, *args) at unpickling time - # (see load_newobj below). If __reduce__ returns a - # three-tuple, the state from the third tuple item will be - # pickled regardless of the protocol, calling __setstate__ - # at unpickling time (see load_build below). - # - # Note that no standard __newobj__ implementation exists; - # you have to provide your own. This is to enforce - # compatibility with Python 2.2 (pickles written using - # protocol 0 or 1 in Python 2.3 should be unpicklable by - # Python 2.2). - cls = args[0] - if not hasattr(cls, "__new__"): - raise PicklingError("args[0] from __newobj__ args has no __new__") - if obj is not None and cls is not obj.__class__: - raise PicklingError("args[0] from __newobj__ args has the wrong class") - args = args[1:] - save(cls) - save(args) - write(NEWOBJ) - else: - save(func) - save(args) - write(REDUCE) - - if obj is not None: - # If the object is already in the memo, this means it is - # recursive. In this case, throw away everything we put on the - # stack, and fetch the object back from the memo. - if id(obj) in self.memo: - write(POP + self.get(self.memo[id(obj)][0])) - else: - self.memoize(obj) - - # More new special cases (that work with older protocols as - # well): when __reduce__ returns a tuple with 4 or 5 items, - # the 4th and 5th item should be iterators that provide list - # items and dict items (as (key, value) tuples), or None. - - if listitems is not None: - self._batch_appends(listitems) - - if dictitems is not None: - self._batch_setitems(dictitems) - - if state is not None: - if state_setter is None: - save(state) - write(BUILD) - else: - # If a state_setter is specified, call it instead of load_build - # to update obj's with its previous state. 
- # First, push state_setter and its tuple of expected arguments - # (obj, state) onto the stack. - save(state_setter) - save(obj) # simple BINGET opcode as obj is already memoized. - save(state) - write(TUPLE2) - # Trigger a state_setter(obj, state) function call. - write(REDUCE) - # The purpose of state_setter is to carry-out an - # inplace modification of obj. We do not care about what the - # method might return, so its output is eventually removed from - # the stack. - write(POP) - - # Methods below this point are dispatched through the dispatch table - - dispatch = {} - - def save_none(self, obj): - self.write(NONE) - - dispatch[type(None)] = save_none - - def save_bool(self, obj): - if self.proto >= 2: - self.write(NEWTRUE if obj else NEWFALSE) - else: - self.write(TRUE if obj else FALSE) - - dispatch[bool] = save_bool - - def save_long(self, obj): - if self.bin: - # If the int is small enough to fit in a signed 4-byte 2's-comp - # format, we can store it more efficiently than the general - # case. - # First one- and two-byte unsigned ints: - if obj >= 0: - if obj <= 0xFF: - self.write(BININT1 + pack("= 2: - encoded = encode_long(obj) - n = len(encoded) - if n < 256: - self.write(LONG1 + pack("d", obj)) - else: - self.write(FLOAT + repr(obj).encode("ascii") + b"\n") - - dispatch[float] = save_float - - def save_bytes(self, obj): - if self.proto < 3: - if not obj: # bytes object is empty - self.save_reduce(bytes, (), obj=obj) - else: - self.save_reduce(codecs.encode, (str(obj, "latin1"), "latin1"), obj=obj) - return - n = len(obj) - if n <= 0xFF: - self.write(SHORT_BINBYTES + pack(" 0xFFFFFFFF and self.proto >= 4: - self._write_large_bytes(BINBYTES8 + pack("= self.framer._FRAME_SIZE_TARGET: - self._write_large_bytes(BINBYTES + pack("= self.framer._FRAME_SIZE_TARGET: - self._write_large_bytes(BYTEARRAY8 + pack("= 5") - with obj.raw() as m: - if not m.contiguous: - raise PicklingError("PickleBuffer can not be pickled when " "pointing to a non-contiguous buffer") - in_band = True - if self._buffer_callback is not None: - in_band = bool(self._buffer_callback(obj)) - if in_band: - # Write data in-band - # XXX The C implementation avoids a copy here - if m.readonly: - self.save_bytes(m.tobytes()) - else: - self.save_bytearray(m.tobytes()) - else: - # Write data out-of-band - self.write(NEXT_BUFFER) - if m.readonly: - self.write(READONLY_BUFFER) - - dispatch[PickleBuffer] = save_picklebuffer - - def save_str(self, obj): - if self.bin: - encoded = obj.encode("utf-8", "surrogatepass") - n = len(encoded) - if n <= 0xFF and self.proto >= 4: - self.write(SHORT_BINUNICODE + pack(" 0xFFFFFFFF and self.proto >= 4: - self._write_large_bytes(BINUNICODE8 + pack("= self.framer._FRAME_SIZE_TARGET: - self._write_large_bytes(BINUNICODE + pack("= 2: - for element in obj: - save(element) - # Subtle. Same as in the big comment below. - if id(obj) in memo: - get = self.get(memo[id(obj)][0]) - self.write(POP * n + get) - else: - self.write(_tuplesize2code[n]) - self.memoize(obj) - return - - # proto 0 or proto 1 and tuple isn't empty, or proto > 1 and tuple - # has more than 3 elements. - write = self.write - write(MARK) - for element in obj: - save(element) - - if id(obj) in memo: - # Subtle. d was not in memo when we entered save_tuple(), so - # the process of saving the tuple's elements must have saved - # the tuple itself: the tuple is recursive. The proper action - # now is to throw away everything we put on the stack, and - # simply GET the tuple (it's already constructed). 
This check - # could have been done in the "for element" loop instead, but - # recursive tuples are a rare thing. - get = self.get(memo[id(obj)][0]) - if self.bin: - write(POP_MARK + get) - else: # proto 0 -- POP_MARK not available - write(POP * (n + 1) + get) - return - - # No recursion. - write(TUPLE) - self.memoize(obj) - - dispatch[tuple] = save_tuple - - def save_list(self, obj): - if self.bin: - self.write(EMPTY_LIST) - else: # proto 0 -- can't use EMPTY_LIST - self.write(MARK + LIST) - - self.memoize(obj) - self._batch_appends(obj) - - dispatch[list] = save_list - - _BATCHSIZE = 1000 - - def _batch_appends(self, items): - # Helper to batch up APPENDS sequences - save = self.save - write = self.write - - if not self.bin: - for x in items: - save(x) - write(APPEND) - return - - it = iter(items) - while True: - tmp = list(islice(it, self._BATCHSIZE)) - n = len(tmp) - if n > 1: - write(MARK) - for x in tmp: - save(x) - write(APPENDS) - elif n: - save(tmp[0]) - write(APPEND) - # else tmp is empty, and we're done - if n < self._BATCHSIZE: - return - - def save_dict(self, obj): - if self.bin: - self.write(EMPTY_DICT) - else: # proto 0 -- can't use EMPTY_DICT - self.write(MARK + DICT) - - self.memoize(obj) - self._batch_setitems(obj.items()) - - dispatch[dict] = save_dict - if PyStringMap is not None: - dispatch[PyStringMap] = save_dict - - def _batch_setitems(self, items): - # Helper to batch up SETITEMS sequences; proto >= 1 only - save = self.save - write = self.write - - if not self.bin: - for k, v in items: - save(k) - save(v) - write(SETITEM) - return - - it = iter(items) - while True: - tmp = list(islice(it, self._BATCHSIZE)) - n = len(tmp) - if n > 1: - write(MARK) - for k, v in tmp: - save(k) - save(v) - write(SETITEMS) - elif n: - k, v = tmp[0] - save(k) - save(v) - write(SETITEM) - # else tmp is empty, and we're done - if n < self._BATCHSIZE: - return - - def save_set(self, obj): - save = self.save - write = self.write - - if self.proto < 4: - self.save_reduce(set, (list(obj),), obj=obj) - return - - write(EMPTY_SET) - self.memoize(obj) - - it = iter(obj) - while True: - batch = list(islice(it, self._BATCHSIZE)) - n = len(batch) - if n > 0: - write(MARK) - for item in batch: - save(item) - write(ADDITEMS) - if n < self._BATCHSIZE: - return - - dispatch[set] = save_set - - def save_frozenset(self, obj): - save = self.save - write = self.write - - if self.proto < 4: - self.save_reduce(frozenset, (list(obj),), obj=obj) - return - - write(MARK) - for item in obj: - save(item) - - if id(obj) in self.memo: - # If the object is already in the memo, this means it is - # recursive. In this case, throw away everything we put on the - # stack, and fetch the object back from the memo. 
- write(POP_MARK + self.get(self.memo[id(obj)][0])) - return - - write(FROZENSET) - self.memoize(obj) - - dispatch[frozenset] = save_frozenset - - def save_global(self, obj, name=None): - write = self.write - memo = self.memo - - if name is None: - name = getattr(obj, "__qualname__", None) - if name is None: - name = obj.__name__ - - module_name = whichmodule(obj, name) - try: - __import__(module_name, level=0) - module = sys.modules[module_name] - obj2, parent = _getattribute(module, name) - except (ImportError, KeyError, AttributeError): - raise PicklingError("Can't pickle %r: it's not found as %s.%s" % (obj, module_name, name)) from None - else: - if obj2 is not obj: - raise PicklingError("Can't pickle %r: it's not the same object as %s.%s" % (obj, module_name, name)) - - if self.proto >= 2: - code = _extension_registry.get((module_name, name)) - if code: - assert code > 0 - if code <= 0xFF: - write(EXT1 + pack("= 3. - if self.proto >= 4: - self.save(module_name) - self.save(name) - write(STACK_GLOBAL) - elif parent is not module: - self.save_reduce(getattr, (parent, lastname)) - elif self.proto >= 3: - write(GLOBAL + bytes(module_name, "utf-8") + b"\n" + bytes(name, "utf-8") + b"\n") - else: - if self.fix_imports: - r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING - r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING - if (module_name, name) in r_name_mapping: - module_name, name = r_name_mapping[(module_name, name)] - elif module_name in r_import_mapping: - module_name = r_import_mapping[module_name] - try: - write(GLOBAL + bytes(module_name, "ascii") + b"\n" + bytes(name, "ascii") + b"\n") - except UnicodeEncodeError: - raise PicklingError( - "can't pickle global identifier '%s.%s' using " "pickle protocol %i" % (module, name, self.proto) - ) from None - - self.memoize(obj) - - def save_type(self, obj): - if obj is type(None): - return self.save_reduce(type, (None,), obj=obj) - elif obj is type(NotImplemented): - return self.save_reduce(type, (NotImplemented,), obj=obj) - elif obj is type(...): - return self.save_reduce(type, (...,), obj=obj) - return self.save_global(obj) - - dispatch[FunctionType] = save_global - dispatch[type] = save_type - - -# Unpickling machinery - - -class _Unpickler: # pragma: no cover - def __init__(self, file, *, fix_imports=True, encoding="ASCII", errors="strict", buffers=None): - """This takes a binary file for reading a pickle data stream. - - The protocol version of the pickle is detected automatically, so - no proto argument is needed. - - The argument *file* must have two methods, a read() method that - takes an integer argument, and a readline() method that requires - no arguments. Both methods should return bytes. Thus *file* - can be a binary file object opened for reading, an io.BytesIO - object, or any other custom object that meets this interface. - - The file-like object must have two methods, a read() method - that takes an integer argument, and a readline() method that - requires no arguments. Both methods should return bytes. - Thus file-like object can be a binary file object opened for - reading, a BytesIO object, or any other custom object that - meets this interface. - - If *buffers* is not None, it should be an iterable of buffer-enabled - objects that is consumed each time the pickle stream references - an out-of-band buffer view. Such buffers have been given in order - to the *buffer_callback* of a Pickler object. 
- - If *buffers* is None (the default), then the buffers are taken - from the pickle stream, assuming they are serialized there. - It is an error for *buffers* to be None if the pickle stream - was produced with a non-None *buffer_callback*. - - Other optional arguments are *fix_imports*, *encoding* and - *errors*, which are used to control compatibility support for - pickle stream generated by Python 2. If *fix_imports* is True, - pickle will try to map the old Python 2 names to the new names - used in Python 3. The *encoding* and *errors* tell pickle how - to decode 8-bit string instances pickled by Python 2; these - default to 'ASCII' and 'strict', respectively. *encoding* can be - 'bytes' to read these 8-bit string instances as bytes objects. - """ - self._buffers = iter(buffers) if buffers is not None else None - self._file_readline = file.readline - self._file_read = file.read - self.memo = {} - self.encoding = encoding - self.errors = errors - self.proto = 0 - self.fix_imports = fix_imports - - def load(self, tensor_name=None): - """Read a pickled object representation from the open file. - - Return the reconstituted object hierarchy specified in the file. - """ - # Check whether Unpickler was initialized correctly. This is - # only needed to mimic the behavior of _pickle.Unpickler.dump(). - - if not hasattr(self, "_file_read"): - raise UnpicklingError( - "Unpickler.__init__() was not called by " "%s.__init__()" % (self.__class__.__name__,) - ) - self.tensor_name = tensor_name - self._unframer = _Unframer(self._file_read, self._file_readline) - self.read = self._unframer.read - self.readinto = self._unframer.readinto - self.readline = self._unframer.readline - self.metastack = [] - self.stack = [] - self.append = self.stack.append - self.proto = 0 - read = self.read - dispatch = self.dispatch - try: - while True: - key = read(1) - if not key: - raise EOFError - assert isinstance(key, bytes_types) - dispatch[key[0]](self) - except _Stop as stopinst: - return stopinst.value - - # Return a list of items pushed in the stack after last MARK instruction. 
- def pop_mark(self): - items = self.stack - self.stack = self.metastack.pop() - self.append = self.stack.append - return items - - def persistent_load(self, pid): - raise UnpicklingError("unsupported persistent id encountered") - - dispatch = {} - - def load_proto(self): - proto = self.read(1)[0] - if not 0 <= proto <= HIGHEST_PROTOCOL: - raise ValueError("unsupported pickle protocol: %d" % proto) - self.proto = proto - - dispatch[PROTO[0]] = load_proto - - def load_frame(self): - (frame_size,) = unpack(" sys.maxsize: - raise ValueError("frame size > sys.maxsize: %d" % frame_size) - self._unframer.load_frame(frame_size) - - dispatch[FRAME[0]] = load_frame - - def load_persid(self): - try: - pid = self.readline()[:-1].decode("ascii") - except UnicodeDecodeError: - raise UnpicklingError("persistent IDs in protocol 0 must be ASCII strings") - self.append(self.persistent_load(pid)) - - dispatch[PERSID[0]] = load_persid - - def load_binpersid(self): - pid = self.stack.pop() - self.append(self.persistent_load(pid)) - - dispatch[BINPERSID[0]] = load_binpersid - - def load_none(self): - self.append(None) - - dispatch[NONE[0]] = load_none - - def load_false(self): - self.append(False) - - dispatch[NEWFALSE[0]] = load_false - - def load_true(self): - self.append(True) - - dispatch[NEWTRUE[0]] = load_true - - def load_int(self): - data = self.readline() - if data == FALSE[1:]: - val = False - elif data == TRUE[1:]: - val = True - else: - val = int(data, 0) - self.append(val) - - dispatch[INT[0]] = load_int - - def load_binint(self): - self.append(unpack("d", self.read(8))[0]) - - dispatch[BINFLOAT[0]] = load_binfloat - - def _decode_string(self, value): - # Used to allow strings from Python 2 to be decoded either as - # bytes or Unicode strings. This should be used only with the - # STRING, BINSTRING and SHORT_BINSTRING opcodes. 
- if self.encoding == "bytes": - return value - else: - return value.decode(self.encoding, self.errors) - - def load_string(self): - data = self.readline()[:-1] - # Strip outermost quotes - if len(data) >= 2 and data[0] == data[-1] and data[0] in b"\"'": - data = data[1:-1] - else: - raise UnpicklingError("the STRING opcode argument must be quoted") - self.append(self._decode_string(codecs.escape_decode(data)[0])) - - dispatch[STRING[0]] = load_string - - def load_binstring(self): - # Deprecated BINSTRING uses signed 32-bit length - (len,) = unpack(" maxsize: - raise UnpicklingError("BINBYTES exceeds system's maximum size " "of %d bytes" % maxsize) - self.append(self.read(len)) - - dispatch[BINBYTES[0]] = load_binbytes - - def load_unicode(self): - self.append(str(self.readline()[:-1], "raw-unicode-escape")) - - dispatch[UNICODE[0]] = load_unicode - - def load_binunicode(self): - (len,) = unpack(" maxsize: - raise UnpicklingError("BINUNICODE exceeds system's maximum size " "of %d bytes" % maxsize) - self.append(str(self.read(len), "utf-8", "surrogatepass")) - - dispatch[BINUNICODE[0]] = load_binunicode - - def load_binunicode8(self): - (len,) = unpack(" maxsize: - raise UnpicklingError("BINUNICODE8 exceeds system's maximum size " "of %d bytes" % maxsize) - self.append(str(self.read(len), "utf-8", "surrogatepass")) - - dispatch[BINUNICODE8[0]] = load_binunicode8 - - def load_binbytes8(self): - (len,) = unpack(" maxsize: - raise UnpicklingError("BINBYTES8 exceeds system's maximum size " "of %d bytes" % maxsize) - self.append(self.read(len)) - - dispatch[BINBYTES8[0]] = load_binbytes8 - - def load_bytearray8(self): - (len,) = unpack(" maxsize: - raise UnpicklingError("BYTEARRAY8 exceeds system's maximum size " "of %d bytes" % maxsize) - b = bytearray(len) - self.readinto(b) - self.append(b) - - dispatch[BYTEARRAY8[0]] = load_bytearray8 - - def load_next_buffer(self): - if self._buffers is None: - raise UnpicklingError("pickle stream refers to out-of-band data " "but no *buffers* argument was given") - try: - buf = next(self._buffers) - except StopIteration: - raise UnpicklingError("not enough out-of-band buffers") - self.append(buf) - - dispatch[NEXT_BUFFER[0]] = load_next_buffer - - def load_readonly_buffer(self): - buf = self.stack[-1] - with memoryview(buf) as m: - if not m.readonly: - self.stack[-1] = m.toreadonly() - - dispatch[READONLY_BUFFER[0]] = load_readonly_buffer - - def load_short_binstring(self): - len = self.read(1)[0] - data = self.read(len) - self.append(self._decode_string(data)) - - dispatch[SHORT_BINSTRING[0]] = load_short_binstring - - def load_short_binbytes(self): - len = self.read(1)[0] - self.append(self.read(len)) - - dispatch[SHORT_BINBYTES[0]] = load_short_binbytes - - def load_short_binunicode(self): - len = self.read(1)[0] - self.append(str(self.read(len), "utf-8", "surrogatepass")) - - dispatch[SHORT_BINUNICODE[0]] = load_short_binunicode - - def load_tuple(self): - items = self.pop_mark() - self.append(tuple(items)) - - dispatch[TUPLE[0]] = load_tuple - - def load_empty_tuple(self): - self.append(()) - - dispatch[EMPTY_TUPLE[0]] = load_empty_tuple - - def load_tuple1(self): - self.stack[-1] = (self.stack[-1],) - - dispatch[TUPLE1[0]] = load_tuple1 - - def load_tuple2(self): - self.stack[-2:] = [(self.stack[-2], self.stack[-1])] - - dispatch[TUPLE2[0]] = load_tuple2 - - def load_tuple3(self): - self.stack[-3:] = [(self.stack[-3], self.stack[-2], self.stack[-1])] - - dispatch[TUPLE3[0]] = load_tuple3 - - def load_empty_list(self): - self.append([]) - - 
dispatch[EMPTY_LIST[0]] = load_empty_list - - def load_empty_dictionary(self): - self.append({}) - - dispatch[EMPTY_DICT[0]] = load_empty_dictionary - - def load_empty_set(self): - self.append(set()) - - dispatch[EMPTY_SET[0]] = load_empty_set - - def load_frozenset(self): - items = self.pop_mark() - self.append(frozenset(items)) - - dispatch[FROZENSET[0]] = load_frozenset - - def load_list(self): - items = self.pop_mark() - self.append(items) - - dispatch[LIST[0]] = load_list - - def load_dict(self): - items = self.pop_mark() - d = {items[i]: items[i + 1] for i in range(0, len(items), 2)} - self.append(d) - - dispatch[DICT[0]] = load_dict - - # INST and OBJ differ only in how they get a class object. It's not - # only sensible to do the rest in a common routine, the two routines - # previously diverged and grew different bugs. - # klass is the class to instantiate, and k points to the topmost mark - # object, following which are the arguments for klass.__init__. - def _instantiate(self, klass, args): - if args or not isinstance(klass, type) or hasattr(klass, "__getinitargs__"): - try: - value = klass(*args) - except TypeError as err: - raise TypeError("in constructor for %s: %s" % (klass.__name__, str(err)), sys.exc_info()[2]) - else: - value = klass.__new__(klass) - self.append(value) - - def load_inst(self): - module = self.readline()[:-1].decode("ascii") - name = self.readline()[:-1].decode("ascii") - klass = self.find_class(module, name) - self._instantiate(klass, self.pop_mark()) - - dispatch[INST[0]] = load_inst - - def load_obj(self): - # Stack is ... markobject classobject arg1 arg2 ... - args = self.pop_mark() - cls = args.pop(0) - self._instantiate(cls, args) - - dispatch[OBJ[0]] = load_obj - - def load_newobj(self): - args = self.stack.pop() - cls = self.stack.pop() - obj = cls.__new__(cls, *args) - self.append(obj) - - dispatch[NEWOBJ[0]] = load_newobj - - def load_newobj_ex(self): - kwargs = self.stack.pop() - args = self.stack.pop() - cls = self.stack.pop() - obj = cls.__new__(cls, *args, **kwargs) - self.append(obj) - - dispatch[NEWOBJ_EX[0]] = load_newobj_ex - - def load_global(self): - module = self.readline()[:-1].decode("utf-8") - name = self.readline()[:-1].decode("utf-8") - klass = self.find_class(module, name) - self.append(klass) - - dispatch[GLOBAL[0]] = load_global - - def load_stack_global(self): - name = self.stack.pop() - module = self.stack.pop() - if type(name) is not str or type(module) is not str: - raise UnpicklingError("STACK_GLOBAL requires str") - self.append(self.find_class(module, name)) - - dispatch[STACK_GLOBAL[0]] = load_stack_global - - def load_ext1(self): - code = self.read(1)[0] - self.get_extension(code) - - dispatch[EXT1[0]] = load_ext1 - - def load_ext2(self): - (code,) = unpack("= 4: - return _getattribute(sys.modules[module], name)[0] - else: - return getattr(sys.modules[module], name) - - def load_reduce(self): - stack = self.stack - args = stack.pop() - func = stack[-1] - if len(args) > 0 and args[0] is None: - stack[-1] = None - else: - stack[-1] = func(*args) - # stack[-1] = func(*args) - - dispatch[REDUCE[0]] = load_reduce - - def load_pop(self): - if self.stack: - del self.stack[-1] - else: - self.pop_mark() - - dispatch[POP[0]] = load_pop - - def load_pop_mark(self): - self.pop_mark() - - dispatch[POP_MARK[0]] = load_pop_mark - - def load_dup(self): - self.append(self.stack[-1]) - - dispatch[DUP[0]] = load_dup - - def load_get(self): - i = int(self.readline()[:-1]) - self.append(self.memo[i]) - - dispatch[GET[0]] = load_get - - def 
load_binget(self): - i = self.read(1)[0] - self.append(self.memo[i]) - - dispatch[BINGET[0]] = load_binget - - def load_long_binget(self): - (i,) = unpack(" maxsize: - raise ValueError("negative LONG_BINPUT argument") - self.memo[i] = self.stack[-1] - - dispatch[LONG_BINPUT[0]] = load_long_binput - - def load_memoize(self): - memo = self.memo - memo[len(memo)] = self.stack[-1] - - dispatch[MEMOIZE[0]] = load_memoize - - def load_append(self): - stack = self.stack - value = stack.pop() - list = stack[-1] - list.append(value) - - dispatch[APPEND[0]] = load_append - - def load_appends(self): - items = self.pop_mark() - list_obj = self.stack[-1] - try: - extend = list_obj.extend - except AttributeError: - pass - else: - extend(items) - return - # Even if the PEP 307 requires extend() and append() methods, - # fall back on append() if the object has no extend() method - # for backward compatibility. - append = list_obj.append - for item in items: - append(item) - - dispatch[APPENDS[0]] = load_appends - - def load_setitem(self): - stack = self.stack - value = stack.pop() - key = stack.pop() - dict = stack[-1] - dict[key] = value - - dispatch[SETITEM[0]] = load_setitem - - def load_setitems(self): - items = self.pop_mark() - dict = self.stack[-1] - for i in range(0, len(items), 2): - dict[items[i]] = items[i + 1] - - dispatch[SETITEMS[0]] = load_setitems - - def load_additems(self): - items = self.pop_mark() - set_obj = self.stack[-1] - if isinstance(set_obj, set): - set_obj.update(items) - else: - add = set_obj.add - for item in items: - add(item) - - dispatch[ADDITEMS[0]] = load_additems - - def load_build(self): - stack = self.stack - state = stack.pop() - inst = stack[-1] - setstate = getattr(inst, "__setstate__", None) - if setstate is not None: - setstate(state) - return - slotstate = None - if isinstance(state, tuple) and len(state) == 2: - state, slotstate = state - if state: - inst_dict = inst.__dict__ - intern = sys.intern - for k, v in state.items(): - if type(k) is str: - inst_dict[intern(k)] = v - else: - inst_dict[k] = v - if slotstate: - for k, v in slotstate.items(): - setattr(inst, k, v) - - dispatch[BUILD[0]] = load_build - - def load_mark(self): - self.metastack.append(self.stack) - self.stack = [] - self.append = self.stack.append - - dispatch[MARK[0]] = load_mark - - def load_stop(self): - value = self.stack.pop() - raise _Stop(value) - - dispatch[STOP[0]] = load_stop - - -# Shorthands - - -def _dump(obj, file, protocol=None, *, fix_imports=True, buffer_callback=None): # pragma: no cover - _Pickler(file, protocol, fix_imports=fix_imports, buffer_callback=buffer_callback).dump(obj) - - -def _dumps(obj, protocol=None, *, fix_imports=True, buffer_callback=None): # pragma: no cover - f = io.BytesIO() - _Pickler(f, protocol, fix_imports=fix_imports, buffer_callback=buffer_callback).dump(obj) - res = f.getvalue() - assert isinstance(res, bytes_types) - return res - - -def _load(file, *, fix_imports=True, encoding="ASCII", errors="strict", buffers=None): # pragma: no cover - return _Unpickler(file, fix_imports=fix_imports, buffers=buffers, encoding=encoding, errors=errors).load() - - -def _loads(s, *, fix_imports=True, encoding="ASCII", errors="strict", buffers=None): # pragma: no cover - if isinstance(s, str): - raise TypeError("Can't load pickle from unicode string") - file = io.BytesIO(s) - return _Unpickler(file, fix_imports=fix_imports, buffers=buffers, encoding=encoding, errors=errors).load() - - -# Use the faster _pickle if possible -Pickler, Unpickler = _Pickler, _Unpickler 
-dump, dumps, load, loads = _dump, _dumps, _load, _loads
-
-
-# Doctest
-def _test():  # pragma: no cover
-    import doctest
-
-    return doctest.testmod()
diff --git a/auto_round/low_cpu_mem/utils.py b/auto_round/low_cpu_mem/utils.py
deleted file mode 100644
index b54db7f95..000000000
--- a/auto_round/low_cpu_mem/utils.py
+++ /dev/null
@@ -1,474 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-
-#
-# Copyright (c) 2023 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Utils for layer wise quantization."""
-
-import gc
-import json
-import logging
-import os
-import pickle
-from collections import OrderedDict
-from functools import partial
-
-import torch
-from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device
-from transformers import AutoConfig, AutoModelForCausalLM
-from transformers.models.auto.auto_factory import _BaseAutoModelClass
-
-from auto_round.utils import detect_device
-
-from .load import load
-
-logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(filename)s L%(lineno)d: %(message)s")
-logger = logging.getLogger("low_cpu_mem_tools")
-
-LWQ_WORKSPACE = os.path.join("low_cpu_mem_tmp")
-
-
-def get_module(model, key):
-    """Get module from model by key name.
- - Args: - model (torch.nn.Module): original model - key (str): module name to be replaced - """ - attrs = key.split(".") - module = model - for attr in attrs: - try: - attr = int(attr) - module = module[attr] - except: - module = getattr(module, attr) - return module - - -def get_children(model): - """Get all the children of given model.""" - module_list = [] - children = list(model.children()) - if len(children) == 0: - return [model] - for child in children: - module_list += get_children(child) - return module_list - - -def get_named_children(model, pre=[]): - """Get all the name and children of given model.""" - module_list = [] - if len(list(model.children())) == 0: - return [(".".join(pre), model)] - for name, module in model.named_children(): - module_list += get_named_children(module, pre=pre + [name]) - return module_list - - -def download_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None): - """Download hugging face model from hf hub.""" - from huggingface_hub.constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE - from huggingface_hub.file_download import REGEX_COMMIT_HASH, repo_folder_name - from huggingface_hub.utils import EntryNotFoundError - - if cache_dir is None: - cache_dir = HUGGINGFACE_HUB_CACHE - if revision is None: - revision = DEFAULT_REVISION - if repo_type is None: - repo_type = "model" - storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type)) - commit_hash = None - if REGEX_COMMIT_HASH.match(revision): - commit_hash = revision - else: - ref_path = os.path.join(storage_folder, "refs", revision) - if os.path.exists(ref_path): - with open(ref_path) as f: - commit_hash = f.read() - if storage_folder and commit_hash: - pointer_path = os.path.join(storage_folder, "snapshots", commit_hash) - if os.path.isdir(pointer_path): - return pointer_path - else: # pragma: no cover - from huggingface_hub import snapshot_download - - file_path = snapshot_download(repo_id) - return file_path - - -def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, saved_path=None, **kwargs): - """Load a empty model.""" - is_local = os.path.isdir(pretrained_model_name_or_path) - if is_local: # pragma: no cover - path = pretrained_model_name_or_path - else: - path = download_hf_model(pretrained_model_name_or_path) - torch_dtype = kwargs.pop("torch_dtype", None) - if cls.__base__ == _BaseAutoModelClass: - config = AutoConfig.from_pretrained(path, **kwargs) - if str(torch_dtype) == "auto": - if hasattr(config, "torch_dtype") and config.torch_dtype is not None: - torch_dtype = config.torch_dtype - else: - torch_dtype = torch.float32 - with init_empty_weights(): - model = cls.from_config(config, torch_dtype=torch_dtype) - else: # pragma: no cover - config = cls.config_class.from_pretrained(path, **kwargs) - if hasattr(config, "torch_dtype") and config.torch_dtype is not None: - torch_dtype = config.torch_dtype - else: - torch_dtype = torch.float32 - with init_empty_weights(): - model = cls(config, torch_dtype=torch_dtype) - model.tie_weights() - model.eval() - model.path = path - - if saved_path is None: - logger.warning(f"saved_path is not set, use default working space: {LWQ_WORKSPACE}") - saved_path = LWQ_WORKSPACE - convert_model(model, saved_path) - return model - - -def get_super_module_by_name(model, module_name): - """Get the father module with given name of child module.""" - name_list = module_name.split(".") - for name in name_list[:-1]: - if hasattr(model, name): - model = getattr(model, name) - else: # 
pragma: no cover - return None - if hasattr(model, name_list[-1]): - return model - else: # pragma: no cover - return None - - -def update_module(model, module_name, new_module): - """Update module.""" - super_module = get_super_module_by_name(model, module_name) - if super_module: - setattr(super_module, module_name.split(".")[-1], new_module) - - -def get_layers_before_block(model): - """get the embed layers before blocks. - validate on: Llama, opt, bloom, gpt-j, Qwen, Baichuan, Mistral, Mixtral - not work for: phi - """ - return_layers = [] - block_name = None - - def _forward(module, name, *args, **kwargs): - if name == block_name: - raise NotImplementedError - if len(module._modules) == 0: - return_layers.append((name, module)) - return module.ori_forward(*args, **kwargs) - - for n, m in model.named_modules(): - if isinstance(m, torch.nn.ModuleList): - block_name = n + "." + m.named_children().__next__()[0] - m.ori_forward = m.forward - m.forward = partial(_forward, m, n) - - try: - if model.device.type == "meta": - target_device = "cpu" - else: - target_device = model.device - input = { - "input_ids": torch.zeros((1, 1), device=target_device, dtype=torch.int), - "attention_mask": torch.zeros((1, 1), device=target_device, dtype=torch.int), - } - model.forward(**input) - except NotImplementedError: - pass - - for n, m in model.named_modules(): - m.forward = m.ori_forward - del m.ori_forward - - return return_layers - - -def load_layer_wise_quantized_model(path): # pragma: no cover - """Load layer wise quantized model.""" - model = torch.load(os.path.join(path, "model_arch.pt")) - for name, _ in model.named_modules(): - if name + ".pt" in os.listdir(path): - update_module(model, name, torch.load(os.path.join(path, name + ".pt"))) - model.eval() - return model - - -def load_tensor_from_shard(pretrained_model_name_or_path, tensor_name, prefix=None): # pragma: no cover - """Load tensor from shard.""" - path = _get_path(pretrained_model_name_or_path) - idx_dict = json.load(open(os.path.join(path, "pytorch_model.bin.index.json"), "r"))["weight_map"] - if tensor_name not in idx_dict.keys(): - if tensor_name.replace(f"{prefix}.", "") in idx_dict.keys(): - tensor_name = tensor_name.replace(f"{prefix}.", "") - else: - assert False, "{} not in the index.json".format(tensor_name) - return load_tensor(os.path.join(path, idx_dict[tensor_name]), tensor_name, None) - - -def load_tensor(path, tensor_name=None, prefix=None): - """Load a tensor from bin file with given tensor name.""" - # transformers.modeling_utils - if tensor_name: - if "gamma" in tensor_name: # pragma: no cover - tensor_name = tensor_name.replace("gamma", "weight") - if "beta" in tensor_name: # pragma: no cover - tensor_name = tensor_name.replace("beta", "bias") - - if os.path.isdir(path): - path = os.path.join(path, "pytorch_model.bin") - state_dict = load(path, tensor_name, prefix) - if tensor_name: - if tensor_name in state_dict: - return state_dict[tensor_name] - else: # pragma: no cover - return state_dict[tensor_name.replace(f"{prefix}.", "")] - else: # pragma: no cover - return state_dict - - -def _get_path(pretrained_model_name_or_path): - if pretrained_model_name_or_path is None: - return None - is_local = os.path.isdir(pretrained_model_name_or_path) - if is_local: # pragma: no cover - path = pretrained_model_name_or_path - else: - path = download_hf_model(pretrained_model_name_or_path) - return path - - -def load_value(model, param_name, path): - logger.debug(f"load value for layer: {param_name}") - if "lm_head" in 
param_name and getattr(model.config, "tie_word_embeddings", True): - input_embeddings = model.get_input_embeddings() - modules = get_named_children(model) - for name, module in modules: - if module == input_embeddings: - param_name = name + "." + param_name.split(".")[-1] - prefix = model.base_model_prefix - if "pytorch_model.bin.index.json" in os.listdir(path): - value = load_tensor_from_shard(path, param_name, prefix) - else: - value = load_tensor(os.path.join(path, "pytorch_model.bin"), param_name, prefix) - return value - - -def load_module(model, module_name, path, device="cpu"): - module = get_module(model, module_name) - for n, p in module.named_parameters(): - param_name = module_name + "." + n - value = load_value(model, param_name, path) - set_module_tensor_to_device(model, param_name, device, value) - - -def register_weight_hooks(model, path, device="cpu", clean_weight=True, saved_path=None): - if saved_path: - os.makedirs(saved_path, exist_ok=True) - - def forward_pre_hook(name): - def hook(module, input): - logger.debug(f"{name} forward hood load value") - state_dict = None - if os.path.exists(os.path.join(saved_path, f"{name}.pt")): - state_dict = torch.load(os.path.join(saved_path, f"{name}.pt")) - for n, p in module.named_parameters(): - param_name = name + "." + n - if state_dict: - value = state_dict[n] - else: - value = load_value(model, param_name, path) - set_module_tensor_to_device(model, param_name, device, value) - module = module.to(device) - - return hook - - def forward_hook(name): - def hook(module, input, output): - logger.debug(f"{name} forward hood clean value") - if saved_path: - file_path = os.path.join(saved_path, f"{name}.pt") - torch.save(module.state_dict(), file_path) - clean_module_weight(module) - - return hook - - handle = {} - modules = get_named_children(model) - for name, module in modules: - handle[name] = [module.register_forward_pre_hook(forward_pre_hook(name))] - if clean_weight: - handle[name] += [module.register_forward_hook(forward_hook(name))] - return handle - - -def clean_module_weight(submodule): # pragma: no cover - for n, m in submodule.named_parameters(): - is_buffer = n in submodule._buffers - old_value = getattr(submodule, n) - with torch.no_grad(): - if is_buffer: - submodule._buffers[n] = torch.zeros(old_value.shape, device="meta") - else: - param_cls = type(submodule._parameters[n]) - kwargs = submodule._parameters[n].__dict__ - new_value = torch.zeros(old_value.shape, device="meta") - new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to("meta") - submodule._parameters[n] = new_value - gc.collect() - - -def convert_model(empty_model, saved_path=LWQ_WORKSPACE): - os.makedirs(saved_path, exist_ok=True) - - def _get_value(name, n): - state_dict = None - if os.path.exists(os.path.join(saved_path, f"{name}.pt")): - state_dict = torch.load(os.path.join(saved_path, f"{name}.pt")) - param_name = name + "." + n - if state_dict: - value = state_dict[n] - else: - value = load_value(empty_model, param_name, empty_model.path) - return value - - def _update(name, module): - state_dict = None - if os.path.exists(os.path.join(saved_path, f"{name}.pt")): - state_dict = torch.load(os.path.join(saved_path, f"{name}.pt")) - for n, p in module.named_parameters(): - if str(p.device) != "meta": - continue - param_name = name + "." 
+ n - if state_dict: - value = state_dict[n] - else: - value = load_value(empty_model, param_name, saved_path) - set_module_tensor_to_device(empty_model, param_name, "cpu", value) - file_path = os.path.join(saved_path, f"{name}.pt") - torch.save(module.state_dict(), file_path) - - # save quant_info - quant_info = {} - if hasattr(module, "scale"): - quant_info["scale"] = module.scale - if hasattr(module, "zp"): - quant_info["zp"] = module.zp - logger.debug(f"save quant info for layer: {name}") - f = open(os.path.join(saved_path, f"{name}_quant_info.pkl"), "wb") - pickle.dump(quant_info, f) - f.close() - - def _layer_wise_to(module, name, device_or_dtype): - if isinstance(device_or_dtype, torch.dtype): - return module.ori_to(device_or_dtype) - elif len(module._modules) == 0: - # skip method type - if len(module._parameters) == 0 or module.weight.device.type != "meta": - return module.ori_to(device_or_dtype) - else: - for n, _ in module.named_parameters(): - param_name = name + "." + n - value = load_value(empty_model, param_name, empty_model.path) - dtype = None - if hasattr(module, "dtype"): - dtype = module.dtype - set_module_tensor_to_device(module, n, device_or_dtype, value, dtype=dtype) - - if hasattr(module, "scale"): - f = open(os.path.join(saved_path, f"{name}_quant_info.pkl"), "rb") - quant_info = pickle.load(f) - f.close() - module.scale = quant_info["scale"].to(device_or_dtype) - if "zp" in quant_info: - module.zp = quant_info["zp"].to(device_or_dtype) - return module.ori_to(device_or_dtype) - else: - for n, m in module.named_children(): - m.to(device_or_dtype) - return module - - modules = get_named_children(empty_model) - for name, module in modules: - if hasattr(module, "weight"): - # delattr(module, 'weight') - # module.weight = partial(_get_value, name, 'weight')() - module.get_weight = partial(_get_value, name, "weight") - if hasattr(module, "bias") and module.bias is not None: - module.get_bias = partial(_get_value, name, "bias") - module.update = partial(_update, name, module) - - def _replace_to(module, name): - if len(module._modules) > 0: - for n, m in module.named_children(): - if len(name) > 0: - n = name + "." 
+ n - _replace_to(m, n) - module.ori_to = module.to - module.to = partial(_layer_wise_to, module, name) - - _replace_to(empty_model, "") - - -def load_model_with_hooks( - pretrained_model_name_or_path, cls=AutoModelForCausalLM, device=None, clean_weight=True, saved_path=None, **kwargs -): - if saved_path is None: - logger.warning(f"saved_path is not set, use default working space: {LWQ_WORKSPACE}") - saved_path = LWQ_WORKSPACE - device = detect_device(device) - empty_model = load_empty_model(pretrained_model_name_or_path, cls=cls, saved_path=saved_path, **kwargs) - register_weight_hooks(empty_model, empty_model.path, device, clean_weight, saved_path) - return empty_model - - -def layer_wise_save(model, path): - os.makedirs(path, exist_ok=True) - file_path = os.path.join(path, "layer_wise_model.bin") - modules = get_named_children(model) - with open(file_path, "wb") as f: - for name, module in modules: - output = OrderedDict() - if hasattr(module, "get_weight"): - output[f"{name}.weight"] = module.get_weight() - if hasattr(module, "get_bias"): - output[f"{name}.bias"] = module.get_bias() - output = pickle.dumps(output) - f.write(output + b"split_tag") - - -def layer_wise_load(path): - file_path = os.path.join(path, "layer_wise_model.bin") - state_dict = OrderedDict() - with open(file_path, "rb") as f: - data = f.read().split(b"split_tag") - for d in data: - if len(d) > 0: - d = pickle.loads(d) - state_dict.update(d) - return state_dict diff --git a/auto_round/utils.py b/auto_round/utils.py index 84ecadd76..29d1474ee 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -277,27 +277,23 @@ def to_device(input, device=torch.device("cpu")): return input -def mv_module_from_gpu(module, low_cpu_mem_usage=False): - """Moves module from gpu to cpu or meta if low_cpu_mem_usage is true. +def mv_module_from_gpu(module): + """Moves module from gpu to cpu. Args: module: The module to be moved. - low_cpu_mem_usage: Whether to use low CPU memory. If true, move module to meta. Returns: The module on the specified device. 
""" if hasattr(module, "device"): - target_device = "meta" if low_cpu_mem_usage else "cpu" + target_device = "cpu" if module.device.type == target_device: return module else: return module.to(target_device) else: - if low_cpu_mem_usage: - return module.to("meta") - else: - return module.to("cpu") + return module.to("cpu") def to_dtype(input, dtype=torch.float32): @@ -1420,8 +1416,6 @@ def llm_load_model( trust_remote_code=True, model_dtype=None, device="cpu", - low_cpu_mem_mode=0, - low_cpu_mem_tmp_dir=None, **kwargs, ): from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer @@ -1432,7 +1426,6 @@ def llm_load_model( torch_dtype = torch.bfloat16 is_glm = bool(re.search("chatglm", pretrained_model_name_or_path.lower())) - low_cpu_mem_usage = False tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code) @@ -1440,79 +1433,50 @@ def llm_load_model( if "deepseek" in pretrained_model_name_or_path.lower() and trust_remote_code: logger.warning("trust_remote_code is enabled by default, please ensure its correctness.") - if low_cpu_mem_tmp_dir is None: - low_cpu_mem_tmp_dir = "low_cpu_mem_tmp" - if low_cpu_mem_mode == 2: - from auto_round.low_cpu_mem.utils import load_model_with_hooks - - model = load_model_with_hooks( - pretrained_model_name_or_path, - model_cls, - device=device, - clean_weight=True, - saved_path=low_cpu_mem_tmp_dir, - torch_dtype=torch_dtype, - trust_remote_code=trust_remote_code, - ) - elif low_cpu_mem_mode == 1: - from auto_round.low_cpu_mem.utils import load_empty_model - - low_cpu_mem_usage = True - model = load_empty_model( + if _use_hpu_compile_mode(): + model = model_cls.from_pretrained( pretrained_model_name_or_path, - model_cls, - device=device, - saved_path=low_cpu_mem_tmp_dir, torch_dtype=torch_dtype, + attn_implementation="eager", trust_remote_code=trust_remote_code, + device_map="auto" if use_auto_mapping else None, ) else: - if _use_hpu_compile_mode(): + try: model = model_cls.from_pretrained( pretrained_model_name_or_path, torch_dtype=torch_dtype, - attn_implementation="eager", trust_remote_code=trust_remote_code, device_map="auto" if use_auto_mapping else None, ) - else: - try: + except ValueError as e: + if "FP8 quantized" in str(e): + orig_func = set_fake_cuda_device_capability() model = model_cls.from_pretrained( pretrained_model_name_or_path, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, device_map="auto" if use_auto_mapping else None, ) - except ValueError as e: - if "FP8 quantized" in str(e): - orig_func = set_fake_cuda_device_capability() - model = model_cls.from_pretrained( - pretrained_model_name_or_path, - torch_dtype=torch_dtype, - trust_remote_code=trust_remote_code, - device_map="auto" if use_auto_mapping else None, - ) - torch.cuda.get_device_capability = orig_func - logger.warning("the support for fp8 model as input is experimental, please use with caution.") - else: - raise + torch.cuda.get_device_capability = orig_func + logger.warning("the support for fp8 model as input is experimental, please use with caution.") + else: + raise - except OSError as e: - logger.warning( - f"fail to load {pretrained_model_name_or_path}, set trust_remote_code to False and retry." 
- ) - model = model_cls.from_pretrained( - pretrained_model_name_or_path, - torch_dtype=torch_dtype, - trust_remote_code=False, - device_map="auto" if use_auto_mapping else None, - ) + except OSError as e: + logger.warning(f"fail to load {pretrained_model_name_or_path}, set trust_remote_code to False and retry.") + model = model_cls.from_pretrained( + pretrained_model_name_or_path, + torch_dtype=torch_dtype, + trust_remote_code=False, + device_map="auto" if use_auto_mapping else None, + ) model = model.eval() check_and_mark_fp8_model(model) model = _to_model_dtype(model, model_dtype) - return model, tokenizer, low_cpu_mem_usage + return model, tokenizer def mllm_load_model( diff --git a/docs/step_by_step.md b/docs/step_by_step.md index 9a518e31f..c63df50d9 100644 --- a/docs/step_by_step.md +++ b/docs/step_by_step.md @@ -543,15 +543,6 @@ autoround.save_quantized(format="auto_awq", output_dir="tmp_autoround") - Trigger immediate packing: Packing will be triggered immediately when using the command-line interface or the quantize_and_save API, as long as only one export format is specified. - - (only available for .bin file currently) set "--low_cpu_mem_mode 1" to use block-wise mode, load the weights from - disk of each block when tuning and - release the memory of the block after tuning. (more tuning cost) - - - (only available for .bin file currently) set "--low_cpu_mem_mode 2" to use layer-wise mode, load the weights of - each layer from disk when tuning, minimum - memory consumption and also the slowest running speed. - - - **Speedup the tuning:** - set `enable_torch_compile` to True diff --git a/test/test_cpu/test_autoround.py b/test/test_cpu/test_autoround.py index cbd0583df..7a54dc34a 100644 --- a/test/test_cpu/test_autoround.py +++ b/test/test_cpu/test_autoround.py @@ -13,7 +13,7 @@ from auto_round import AutoRound from auto_round.eval.evaluation import simple_evaluate_user_model -from auto_round.low_cpu_mem import get_module +from auto_round.utils import get_module class LLMDataLoader: diff --git a/test/test_cpu/test_low_cpu_mem.py b/test/test_cpu/test_low_cpu_mem.py deleted file mode 100644 index 582b5e47b..000000000 --- a/test/test_cpu/test_low_cpu_mem.py +++ /dev/null @@ -1,122 +0,0 @@ -import os -import shutil -import sys -import unittest - -sys.path.insert(0, "../..") - -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer - -from auto_round import AutoRound -from auto_round.low_cpu_mem.utils import ( - get_layers_before_block, - layer_wise_load, - layer_wise_save, - load_empty_model, - load_model_with_hooks, -) - - -class LLMDataLoader: - def __init__(self): - self.batch_size = 1 - - def __iter__(self): - for i in range(2): - yield torch.ones([1, 10], dtype=torch.long) - - -class TestLowCPUMem(unittest.TestCase): - @classmethod - def setUpClass(self): - self.model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - self.saved_path = "./test_tmp_saved" - self.ori_model = AutoModelForCausalLM.from_pretrained(self.model_name, trust_remote_code=True) - self.model = load_model_with_hooks( - self.model_name, AutoModelForCausalLM, saved_path=self.saved_path, device="cpu" - ) - self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) - self.llm_dataloader = LLMDataLoader() - - @classmethod - def tearDownClass(self): - shutil.rmtree(self.saved_path, ignore_errors=True) - - def test_default(self): - self.assertTrue(self.model.device.type, "meta") - - # TODO: change this func - # layers = get_layers_before_block(self.model) - # 
self.assertEqual(layers[0][0], "model.decoder.embed_tokens") - - # test get_weight bias - self.assertTrue( - torch.equal( - self.model.model.decoder.layers[0].self_attn.k_proj.get_weight(), - self.ori_model.model.decoder.layers[0].self_attn.k_proj.weight, - ) - ) - self.assertTrue( - torch.equal( - self.model.model.decoder.layers[0].self_attn.k_proj.get_bias(), - self.ori_model.model.decoder.layers[0].self_attn.k_proj.bias, - ) - ) - - # test hooks - text = ["Hello, my dog is cute"] - input = self.tokenizer(text) - for key in input: - input[key] = torch.tensor(input[key]) - ori_output = self.ori_model.generate(**input, max_new_tokens=5, do_sample=False) - ori_result = self.tokenizer.decode(ori_output[0]) - print(ori_result) - self.model.to("cpu") - output = self.model.generate(**input, max_new_tokens=5, do_sample=False) - result = self.tokenizer.decode(output[0]) - print(result) - self.assertEqual(ori_result, result) - self.model.to("meta") - - # test save and load - layer_wise_save(self.model, self.saved_path) - state_dict = layer_wise_load(self.saved_path) - self.assertTrue(torch.equal(state_dict["lm_head.weight"], self.ori_model.lm_head.weight)) - - # test layer-wise auto_round - bits, group_size, sym = 4, 128, False - autoround = AutoRound( - self.model, - self.tokenizer, - device="cpu", - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - enable_torch_compile=False, - ) - autoround.quantize() - - # test block-wise auto_round - self.model = load_empty_model(self.model_name, AutoModelForCausalLM, saved_path=self.saved_path, device="cpu") - bits, group_size, sym = 4, 128, False - autoround = AutoRound( - self.model, - self.tokenizer, - device="cpu", - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - low_cpu_mem_usage=True, - ) - autoround.quantize() - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_cuda/test_vlms.py b/test/test_cuda/test_vlms.py index b3bf0f0fd..eee7c055a 100644 --- a/test/test_cuda/test_vlms.py +++ b/test/test_cuda/test_vlms.py @@ -154,7 +154,7 @@ def test_mllm_detect(self): for model_name in ["/models/glm-4-9b-chat", "/models/Qwen2.5-1.5B-Instruct/"]: self.assertFalse(is_mllm_model(model_name)) - model, _, _ = llm_load_model(model_name) + model, _ = llm_load_model(model_name) self.assertFalse(is_mllm_model(model))
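
A minimal sketch of a call site after this change, since llm_load_model now returns only the model and tokenizer; the model path below is an illustrative placeholder, not taken from the patch.

    from auto_round.utils import llm_load_model

    # llm_load_model returns a 2-tuple; the former low_cpu_mem_usage value is no longer part of the API.
    model, tokenizer = llm_load_model("facebook/opt-125m", device="cpu")
    assert tokenizer is not None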