diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 48cf0cd1a..8ceaefc00 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1422,7 +1422,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None: m.zp = None else: try: - m = m.to(self.device) + m = m.to(m.tuning_device if hasattr(m, "tuning_device") else self.device) m = WrapperLinear( m, enable_minmax_tuning=False, diff --git a/test/test_cuda/test_multiple_card.py b/test/test_cuda/test_multiple_card.py index ad33f071b..f2f1685be 100644 --- a/test/test_cuda/test_multiple_card.py +++ b/test/test_cuda/test_multiple_card.py @@ -242,6 +242,19 @@ def test_device_map_dict(self): ) autoround.quantize() + # test rtn + autoround = AutoRound( + model_name, + tokenizer, + bits=bits, + group_size=group_size, + sym=sym, + iters=0, + seqlen=2, + device_map=device_map, + ) + autoround.quantize() + @multi_card @require_greater_than_050 def test_device_map_for_triton(self):