From 4c888fea8d1cb3e770dea338b5f788a64ba74d12 Mon Sep 17 00:00:00 2001 From: YAO Matrix Date: Mon, 7 Apr 2025 00:05:34 -0700 Subject: [PATCH] enable case on XPU: 1. tests/quantization/bnb/test_mixed_int8.py::BnB8bitTrainingTests::test_training Signed-off-by: YAO Matrix --- tests/quantization/bnb/test_mixed_int8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 8809bac25f58..a5e38f931e09 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -379,7 +379,7 @@ def test_training(self): model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs}) # Step 4: Check if the gradient is not None - with torch.amp.autocast("cuda", dtype=torch.float16): + with torch.amp.autocast(torch_device, dtype=torch.float16): out = self.model_8bit(**model_inputs)[0] out.norm().backward()