From 2578d8235238730a44f465ff495883d806db2deb Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 25 Oct 2022 13:41:53 +0200 Subject: [PATCH 1/6] [WIP] Debugging mps DDIM tests --- tests/pipelines/ddim/test_ddim.py | 32 ++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py index 4445fe7feecf..06b2ee2977a4 100644 --- a/tests/pipelines/ddim/test_ddim.py +++ b/tests/pipelines/ddim/test_ddim.py @@ -52,7 +52,37 @@ def test_inference(self): # Warmup pass when using mps (see #372) if torch_device == "mps": - _ = ddpm(num_inference_steps=1) + _ = ddpm(num_inference_steps=5) + _ = ddpm(num_inference_steps=5) + + generator = torch.manual_seed(0) + image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images + + generator = torch.manual_seed(0) + image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array( + [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04] + ) + tolerance = 1e-2 if torch_device != "mps" else 3e-2 + assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance + + def test_inference_2(self): + unet = self.dummy_uncond_unet + scheduler = DDIMScheduler() + + ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) + ddpm.to(torch_device) + ddpm.set_progress_bar_config(disable=None) + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = ddpm(num_inference_steps=5) generator = torch.manual_seed(0) image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images From 09090bcd1f57369b0bf952776444d157a3feee5c Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 25 Oct 2022 13:45:42 +0200 Subject: [PATCH 2/6] revert num_steps --- tests/pipelines/ddim/test_ddim.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py index 06b2ee2977a4..0ac4c299244e 100644 --- a/tests/pipelines/ddim/test_ddim.py +++ b/tests/pipelines/ddim/test_ddim.py @@ -52,8 +52,7 @@ def test_inference(self): # Warmup pass when using mps (see #372) if torch_device == "mps": - _ = ddpm(num_inference_steps=5) - _ = ddpm(num_inference_steps=5) + _ = ddpm(num_inference_steps=1) generator = torch.manual_seed(0) image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images @@ -82,7 +81,7 @@ def test_inference_2(self): # Warmup pass when using mps (see #372) if torch_device == "mps": - _ = ddpm(num_inference_steps=5) + _ = ddpm(num_inference_steps=1) generator = torch.manual_seed(0) image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images From b258635a40e7c92e09bd5fa585d4fa1831b43694 Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 25 Oct 2022 14:20:14 +0200 Subject: [PATCH 3/6] check warmup with a generator --- tests/pipelines/ddim/test_ddim.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py index 0ac4c299244e..82ea76378333 100644 --- a/tests/pipelines/ddim/test_ddim.py +++ b/tests/pipelines/ddim/test_ddim.py @@ -52,7 +52,8 @@ def test_inference(self): # Warmup pass when using mps (see #372) if torch_device == "mps": - _ = ddpm(num_inference_steps=1) + generator = torch.manual_seed(42) + _ = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images generator = torch.manual_seed(0) image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images From f5e3f14f95777cdc0bce087a90da84c934508838 Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 25 Oct 2022 14:25:17 +0200 Subject: [PATCH 4/6] more warmup! --- tests/pipelines/ddim/test_ddim.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py index 82ea76378333..8d2447c015ac 100644 --- a/tests/pipelines/ddim/test_ddim.py +++ b/tests/pipelines/ddim/test_ddim.py @@ -43,6 +43,19 @@ def dummy_uncond_unet(self): return model def test_inference(self): + # Warmup pass when using mps (see #372) + if torch_device == "mps": + unet = self.dummy_uncond_unet + scheduler = DDIMScheduler() + + ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) + ddpm.to(torch_device) + ddpm.set_progress_bar_config(disable=None) + + generator = torch.manual_seed(0) + _ = ddpm(generator=generator, num_inference_steps=2) + del unet, scheduler, ddpm + unet = self.dummy_uncond_unet scheduler = DDIMScheduler() @@ -52,8 +65,7 @@ def test_inference(self): # Warmup pass when using mps (see #372) if torch_device == "mps": - generator = torch.manual_seed(42) - _ = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images + _ = ddpm(num_inference_steps=1) generator = torch.manual_seed(0) image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images From 34998e712718184f7231f08ce6fa13837b292e44 Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 25 Oct 2022 15:08:59 +0200 Subject: [PATCH 5/6] remove xdist --- .github/workflows/pr_tests.yml | 3 ++- tests/pipelines/ddim/test_ddim.py | 42 ------------------------------- 2 files changed, 2 insertions(+), 43 deletions(-) diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 493ff51484b0..ca32de4cdc4d 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -79,6 +79,7 @@ jobs: run: | ${CONDA_RUN} python -m pip install --upgrade pip ${CONDA_RUN} python -m pip install -e .[quality,test] + ${CONDA_RUN} python -m pip uninstall -y pytest-xdist # reproducibility issues with mps + multiprocessing ${CONDA_RUN} python -m pip install --pre torch==${MPS_TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/test/cpu - name: Environment @@ -89,7 +90,7 @@ jobs: - name: Run all fast tests on MPS shell: arch -arch arm64 bash {0} run: | - ${CONDA_RUN} python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_torch_mps tests/ + ${CONDA_RUN} python -m pytest -s -v --make-reports=tests_torch_mps tests/ - name: Failure short reports if: ${{ failure() }} diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py index 8d2447c015ac..4445fe7feecf 100644 --- a/tests/pipelines/ddim/test_ddim.py +++ b/tests/pipelines/ddim/test_ddim.py @@ -43,48 +43,6 @@ def dummy_uncond_unet(self): return model def test_inference(self): - # Warmup pass when using mps (see #372) - if torch_device == "mps": - unet = self.dummy_uncond_unet - scheduler = DDIMScheduler() - - ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) - ddpm.to(torch_device) - ddpm.set_progress_bar_config(disable=None) - - generator = torch.manual_seed(0) - _ = ddpm(generator=generator, num_inference_steps=2) - del unet, scheduler, ddpm - - unet = self.dummy_uncond_unet - scheduler = DDIMScheduler() - - ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) - ddpm.to(torch_device) - ddpm.set_progress_bar_config(disable=None) - - # Warmup pass when using mps (see #372) - if torch_device == "mps": - _ = ddpm(num_inference_steps=1) - - generator = torch.manual_seed(0) - image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images - - generator = torch.manual_seed(0) - image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0] - - image_slice = image[0, -3:, -3:, -1] - image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - - assert image.shape == (1, 32, 32, 3) - expected_slice = np.array( - [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04] - ) - tolerance = 1e-2 if torch_device != "mps" else 3e-2 - assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance - - def test_inference_2(self): unet = self.dummy_uncond_unet scheduler = DDIMScheduler() From 6150de253312eb1c49cf49a7e7810c3e1d0bdcfb Mon Sep 17 00:00:00 2001 From: anton-l Date: Tue, 25 Oct 2022 15:13:17 +0200 Subject: [PATCH 6/6] just use a single process --- .github/workflows/pr_tests.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index ca32de4cdc4d..cf21edf99165 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -79,7 +79,6 @@ jobs: run: | ${CONDA_RUN} python -m pip install --upgrade pip ${CONDA_RUN} python -m pip install -e .[quality,test] - ${CONDA_RUN} python -m pip uninstall -y pytest-xdist # reproducibility issues with mps + multiprocessing ${CONDA_RUN} python -m pip install --pre torch==${MPS_TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/test/cpu - name: Environment @@ -90,7 +89,7 @@ jobs: - name: Run all fast tests on MPS shell: arch -arch arm64 bash {0} run: | - ${CONDA_RUN} python -m pytest -s -v --make-reports=tests_torch_mps tests/ + ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps tests/ - name: Failure short reports if: ${{ failure() }}