diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 8809bac25f58..a5e38f931e09 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -379,7 +379,7 @@ def test_training(self):
         model_inputs.update({k: v for k, v in input_dict_for_transformer.items() if k not in model_inputs})
 
         # Step 4: Check if the gradient is not None
-        with torch.amp.autocast("cuda", dtype=torch.float16):
+        with torch.amp.autocast(torch_device, dtype=torch.float16):
             out = self.model_8bit(**model_inputs)[0]
             out.norm().backward()
 
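
For context, the change swaps the hard-coded `"cuda"` device string for the `torch_device` variable, so the autocast region follows whatever accelerator the test environment resolves (in `transformers` tests this comes from `transformers.testing_utils`). Below is a minimal, self-contained sketch of the same pattern; the hand-rolled `torch_device` resolution, the toy `Linear` model, and the dtype fallback are illustrative assumptions, not the test's actual setup:

```python
import torch

# A hand-rolled stand-in for transformers.testing_utils.torch_device,
# used here only to keep the sketch self-contained: resolve the
# accelerator string once, then reuse it everywhere.
if torch.cuda.is_available():
    torch_device = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    torch_device = "xpu"
else:
    torch_device = "cpu"

# Toy model and input standing in for the quantized model under test.
model = torch.nn.Linear(16, 16).to(torch_device)
x = torch.randn(4, 16, device=torch_device)

# torch.amp.autocast takes the device type as its first positional
# argument, so passing the resolved string instead of a hard-coded
# "cuda" makes the mixed-precision region device-agnostic.
# (CPU autocast prefers bfloat16, hence the fallback dtype.)
dtype = torch.float16 if torch_device != "cpu" else torch.bfloat16
with torch.amp.autocast(torch_device, dtype=dtype):
    out = model(x)
    out.norm().backward()

# Mirrors the test's intent: gradients should be populated after
# backprop through the autocast region.
assert model.weight.grad is not None
```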