diff --git a/neural_compressor/torch/algorithms/weight_only/gptq.py b/neural_compressor/torch/algorithms/weight_only/gptq.py
index 4cd9918d93d..4c2df596282 100644
--- a/neural_compressor/torch/algorithms/weight_only/gptq.py
+++ b/neural_compressor/torch/algorithms/weight_only/gptq.py
@@ -27,7 +27,13 @@
 import torch.nn as nn
 from tqdm import tqdm
 
-from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger, set_module
+from neural_compressor.torch.utils import (
+    get_accelerator,
+    get_model_device,
+    is_transformers_imported,
+    logger,
+    set_module,
+)
 from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 
 from .modules import WeightOnlyLinear
@@ -995,6 +1001,7 @@ def prepare(
 
         if use_layer_wise:  # pragma: no cover
             assert model_path is not None, "model_path should not be None when use layer wise mode"
+        self.model_device = get_model_device(model)  # return model on the same device
         self.gptq_quantizer = RAWGPTQuantizer(
             model,
             weight_config=self.quant_config,
@@ -1013,6 +1020,7 @@ def convert(self, model, *args, **kwargs):
         self.gptq_quantizer.model = model
         self.gptq_quantizer.remove_prepare_for_calibration()
         q_model, gptq_config = self.gptq_quantizer.execute_quantization()
+        q_model = q_model.to(self.model_device)
         q_model.gptq_config = gptq_config
         logger.info("GPTQ quantizing done.")
         return q_model
diff --git a/neural_compressor/torch/algorithms/weight_only/modules.py b/neural_compressor/torch/algorithms/weight_only/modules.py
index 18cf6e46e55..dcb2ff421f4 100644
--- a/neural_compressor/torch/algorithms/weight_only/modules.py
+++ b/neural_compressor/torch/algorithms/weight_only/modules.py
@@ -270,8 +270,8 @@ def recover(self):
 
     def pack_tensor_with_torch(self, raw_tensor):
         target_len = math.ceil(raw_tensor.shape[1] / self.n_pack)
-        packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(raw_tensor.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(raw_tensor.device)
         for j in range(packed_tensor.shape[1]):
             start = self.n_pack * j
             end = self.n_pack * (j + 1)
@@ -286,8 +286,8 @@ def pack_tensor_with_torch(self, raw_tensor):
     def unpack_tensor_with_torch(self, packed_tensor):
         target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8
         target_len = packed_tensor.shape[1] * self.n_pack
-        unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(self.device)
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(packed_tensor.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(packed_tensor.device)
         for j in range(packed_tensor.shape[1]):
             for e in range(self.n_pack):
                 index = j * self.n_pack + e
@@ -338,13 +338,13 @@ def unpack_tensor_with_numpy(self, packed_tensor):
         return unpacked_tensor
 
     def pack_tensor(self, raw_tensor):
-        if "cuda" in self.device:
+        if "cuda" in raw_tensor.device.type:
             return self.pack_tensor_with_torch(raw_tensor)
         else:
             return self.pack_tensor_with_numpy(raw_tensor)
 
     def unpack_tensor(self, packed_tensor):
-        if "cuda" in self.device:
+        if "cuda" in packed_tensor.device.type:
             return self.unpack_tensor_with_torch(packed_tensor)
         else:
             return self.unpack_tensor_with_numpy(packed_tensor)
diff --git a/neural_compressor/torch/algorithms/weight_only/rtn.py b/neural_compressor/torch/algorithms/weight_only/rtn.py
index ca6f3e499e5..6a95bec4550 100644
--- a/neural_compressor/torch/algorithms/weight_only/rtn.py
+++ b/neural_compressor/torch/algorithms/weight_only/rtn.py
@@ -28,6 +28,7 @@
 from neural_compressor.torch.utils import (
     get_accelerator,
     get_attr,
+    get_model_device,
     is_transformers_imported,
     logger,
     set_attr,
@@ -99,10 +100,7 @@ def convert(
         """
         weight_config = self.quant_config
         device = get_accelerator(kwargs.pop("device", "auto")).current_device_name()
-
-        # Put model on device explicitly
-        # TODO: refine it later, Put module on device one by one instead of the whole model
-        model.to(device)
+        model_device = get_model_device(model)  # return model on the same device
 
         # for transformers model. If lm_head is tied from embedding, we deepcopy it.
         if quant_lm_head and getattr(getattr(model, "config", None), "tie_word_embeddings", False):
@@ -132,6 +130,8 @@ def convert(
             dtype = weight_config[name].get("dtype", "int")
             if dtype == "fp32":
                 continue
+            # Move modules to the accelerator device layer-by-layer
+            m.to(device)
             ### FP8 cast part
             if dtype in ["fp8_e5m2", "fp8_e5m2fnuz", "fp8_e4m3fn", "fp8_e4m3fnuz"]:
                 logger.debug("Cast module {} to FP8 using qdq mode, no scaling".format(name))
@@ -223,4 +223,8 @@ def convert(
                 return new_module
             else:
                 set_module(model, name, new_module)
+            # Move modules back to the model device layer-by-layer
+            m.to(model_device)
+            new_module.to(model_device)
+        model.to(model_device)
         return model
diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py
index 5c764cff7d3..e312a9c388b 100644
--- a/neural_compressor/torch/utils/utility.py
+++ b/neural_compressor/torch/utils/utility.py
@@ -265,3 +265,16 @@ def dump_model_op_stats(mode, tune_cfg):
         output_data.append(field_results)
 
     Statistics(output_data, header="Mixed Precision Statistics", field_names=field_names).print_stat()
+
+
+def get_model_device(model: torch.nn.Module):
+    """Get the device.
+
+    Args:
+        model (torch.nn.Module): the input model.
+
+    Returns:
+        device (str): a string.
+    """
+    for n, p in model.named_parameters():
+        return p.data.device.type  # p.data.device == device(type='cpu')
diff --git a/test/3x/torch/quantization/weight_only/test_gptq.py b/test/3x/torch/quantization/weight_only/test_gptq.py
index 7fbe0ad7737..dfbc39c25e7 100644
--- a/test/3x/torch/quantization/weight_only/test_gptq.py
+++ b/test/3x/torch/quantization/weight_only/test_gptq.py
@@ -36,6 +36,20 @@ def setup_class(self):
     def teardown_class(self):
         shutil.rmtree("saved_results", ignore_errors=True)
 
+    @pytest.mark.skipif(device == "cpu", reason="no available accelerator")
+    def test_auto_host2device(self):
+        # if model is on CPU, we move it to device layer-by-layer for acceleration,
+        # and then move it back to CPU after quantization.
+        model = copy.deepcopy(self.tiny_gptj).to("cpu")
+        example_inputs = copy.deepcopy(self.example_inputs).to("cpu")
+        quant_config = get_default_gptq_config()
+        model = prepare(model, quant_config)
+        run_fn(model)
+        model = convert(model)
+        gptq_label = model(example_inputs)[0]
+        gptq_atol = (gptq_label - self.label.to("cpu")).amax()
+        assert gptq_atol < 0.06, "GPTQ should have low atol."
+
     def test_accuracy_improvement(self):
         # test_default_rtn_config
         model = copy.deepcopy(self.tiny_gptj)
@@ -215,9 +229,9 @@ def test_conv1d(self):
         from transformers import GPT2Model, GPT2Tokenizer
 
         tokenizer = GPT2Tokenizer.from_pretrained("sshleifer/tiny-gpt2")
-        model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2")
+        model = GPT2Model.from_pretrained("sshleifer/tiny-gpt2").to(device)
         text = "Replace me by any text you'd like."
-        encoded_input = tokenizer(text, return_tensors="pt")
+        encoded_input = tokenizer(text, return_tensors="pt").to(device)
 
         def run_fn_conv1d(model):
             model(**encoded_input)
diff --git a/test/3x/torch/quantization/weight_only/test_rtn.py b/test/3x/torch/quantization/weight_only/test_rtn.py
index 206bd20aa10..04f6c444485 100644
--- a/test/3x/torch/quantization/weight_only/test_rtn.py
+++ b/test/3x/torch/quantization/weight_only/test_rtn.py
@@ -352,3 +352,16 @@ def mock_is_transformers_imported():
         model = convert(model)
         out = model(self.example_inputs)[0]
         assert torch.allclose(out, self.label, atol=1e-1), "Accuracy gap atol > 0.1 is unexpected."
+
+    @pytest.mark.skipif(device == "cpu", reason="no available accelerator")
+    def test_auto_host2device(self):
+        # if model is on CPU, we move it to device layer-by-layer for acceleration,
+        # and then move it back to CPU after quantization.
+        model = copy.deepcopy(self.tiny_gptj).to("cpu")
+        example_inputs = copy.deepcopy(self.example_inputs).to("cpu")
+        quant_config = get_default_rtn_config()
+        model = prepare(model, quant_config)
+        model = convert(model)
+        rtn_label = model(example_inputs)[0]
+        rtn_atol = (rtn_label - self.label.to("cpu")).amax()
+        assert rtn_atol < 0.08, "RTN should have low atol."
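Note (not part of the patch): a minimal usage sketch of the behavior this change enables. The toy torch.nn.Sequential model, the input sizes, and the final forward call are illustrative assumptions, not code from the patch; the prepare/convert/get_default_rtn_config API and get_model_device come from neural_compressor.torch as used above. On a machine with an accelerator, convert() stages each quantizable layer through the accelerator one at a time and returns the quantized model on its original device; on a CPU-only machine everything simply stays on the host.

    # Minimal sketch, assuming a toy model; any module containing torch.nn.Linear layers behaves the same.
    import torch
    from neural_compressor.torch.quantization import convert, get_default_rtn_config, prepare
    from neural_compressor.torch.utils import get_model_device

    model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 32))
    assert get_model_device(model) == "cpu"  # model starts on the host

    quant_config = get_default_rtn_config()
    model = prepare(model, quant_config)  # RTN needs no calibration run
    model = convert(model)                # layers are moved to the accelerator layer-by-layer

    assert get_model_device(model) == "cpu"  # quantized model is returned on its original device
    out = model(torch.randn(2, 32))          # hypothetical smoke-test forward pass on CPU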