diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 1475e2ff07d2..2b29e3ae9eeb 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -99,14 +99,13 @@ def test_vae_slicing(self): assert np.abs(output_2[0].flatten() - output_1[0].flatten()).max() < 1e-2 def test_vae_tiling(self): - device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() # make sure here that pndm scheduler skips prk if "safety_checker" in components: components["safety_checker"] = None pipe = self.pipeline_class(**components) - pipe = pipe.to(device) + pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(torch_device) @@ -126,7 +125,7 @@ def test_vae_tiling(self): # test that tiled decode works with various shapes shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)] for shape in shapes: - zeros = torch.zeros(shape).to(device) + zeros = torch.zeros(shape).to(torch_device) pipe.vae.decode(zeros) def test_freeu_enabled(self):