From cd221bceb3182a7483c427adb116d427a145c892 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Mon, 17 Oct 2022 20:22:35 +0200
Subject: [PATCH 1/8] Docs: refer to pre-RC version of PyTorch 1.13.0.

---
 docs/source/optimization/mps.mdx | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/docs/source/optimization/mps.mdx b/docs/source/optimization/mps.mdx
index ff9d614c870f..affb3bac6806 100644
--- a/docs/source/optimization/mps.mdx
+++ b/docs/source/optimization/mps.mdx
@@ -19,7 +19,11 @@ specific language governing permissions and limitations under the License.
 - Mac computer with Apple silicon (M1/M2) hardware.
 - macOS 12.3 or later.
 - arm64 version of Python.
-- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later.
+- PyTorch 1.13.0 test version (pre-RC). You can install it with `pip` using:
+
+```
+pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/test/cpu
+```
 
 ## Inference Pipeline
 

From c3960edf5b8dfe8c6e3353b7d08dd5ce300af9f4 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Mon, 17 Oct 2022 20:23:25 +0200
Subject: [PATCH 2/8] Remove temporary workaround for unavailable op.

---
 src/diffusers/models/resnet.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index d4cb367ebc0b..fbd78b512a6b 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -492,10 +492,6 @@ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
     kernel_h, kernel_w = kernel.shape
 
     out = tensor.view(-1, in_h, 1, in_w, 1, minor)
-
-    # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
-    if tensor.device.type == "mps":
-        out = out.to("cpu")
     out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
     out = out.view(-1, in_h * up_y, in_w * up_x, minor)
 
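To illustrate why PATCH 2 can drop the workaround: `F.pad` on the 6-dimensional view used by `upfirdn2d_native` now runs natively on `mps`. A minimal check, assuming an Apple silicon machine with the PyTorch build named in PATCH 1 (the tensor shape and upsampling factors are illustrative):

```python
import torch
import torch.nn.functional as F

# Confirm this PyTorch build ships the mps backend.
assert torch.backends.mps.is_available()

# The padding pattern from upfirdn2d_native, which previously had to
# round-trip through the CPU (pytorch/pytorch#84535). Illustrative shape:
# (batch * channels, in_h, 1, in_w, 1, minor) with up_x = up_y = 2.
out = torch.randn(1, 8, 1, 8, 1, 3, device="mps")
out = F.pad(out, [0, 0, 0, 1, 0, 0, 0, 1])  # pads dims 4 and 2 by up - 1
print(out.shape)  # torch.Size([1, 8, 2, 8, 2, 3])
```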
From 8674cf76833fb81e25a5ead2c7b02bc953960ebf Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Mon, 17 Oct 2022 20:40:40 +0200
Subject: [PATCH 3/8] Update comment to make it less ambiguous.

---
 examples/community/clip_guided_stable_diffusion.py          | 2 +-
 examples/community/interpolate_stable_diffusion.py          | 2 +-
 .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py
index 974f4ab2e883..2c86e9130fdc 100644
--- a/examples/community/clip_guided_stable_diffusion.py
+++ b/examples/community/clip_guided_stable_diffusion.py
@@ -249,7 +249,7 @@ def __call__(
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
-                # randn does not exist on mps
+                # randn does not work reproducibly on mps
                 latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                     self.device
                 )
diff --git a/examples/community/interpolate_stable_diffusion.py b/examples/community/interpolate_stable_diffusion.py
index 97116bdc77b4..bbb1b0f9e633 100644
--- a/examples/community/interpolate_stable_diffusion.py
+++ b/examples/community/interpolate_stable_diffusion.py
@@ -324,7 +324,7 @@ def __call__(
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
-                # randn does not exist on mps
+                # randn does not work reproducibly on mps
                 latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                     self.device
                 )
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 8ae51999a7b3..30fe1d118b6d 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -287,7 +287,7 @@ def __call__(
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
-                # randn does not exist on mps
+                # randn does not work reproducibly on mps
                 latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                     self.device
                 )
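For context on the comment reworded in PATCH 3, this is the pattern those pipelines use: because `randn` on `mps` is not reproducible even with a seeded generator, latents are drawn on the CPU and then moved to the device. A minimal sketch (the latents shape is illustrative):

```python
import torch

# Seeded CPU generator: sampling here keeps results reproducible across runs.
generator = torch.Generator(device="cpu").manual_seed(42)

# Illustrative Stable Diffusion latents shape: (batch, channels, h // 8, w // 8).
latents_shape = (1, 4, 64, 64)

# Draw on the CPU, then move the tensor to the mps device.
latents = torch.randn(latents_shape, generator=generator, device="cpu").to("mps")
```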
From 762650fac35d26b61b8dd6a5fda858e36b7a153a Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Mon, 17 Oct 2022 20:41:03 +0200
Subject: [PATCH 4/8] Remove use of contiguous in mps.

It appears to no longer be necessary.
---
 src/diffusers/models/attention.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 4906e10f27c4..33abe3aaa1e3 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -207,7 +207,6 @@ def _set_attention_slice(self, slice_size):
         self.attn2._slice_size = slice_size
 
     def forward(self, hidden_states, context=None):
-        hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
         hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
         hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
         hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

From ef41ddd5e6789a780e98e93a301563fb4de6b3a2 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Mon, 24 Oct 2022 12:37:48 +0200
Subject: [PATCH 5/8] Special case: use einsum for much better performance in mps

---
 src/diffusers/models/attention.py | 31 +++++++++++++++++++++++++------
 1 file changed, 25 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 33abe3aaa1e3..dce30c6a4aa6 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -287,10 +287,19 @@ def forward(self, hidden_states, context=None, mask=None):
 
     def _attention(self, query, key, value):
         # TODO: use baddbmm for better performance
-        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        if query.device.type == "mps":
+            # Better performance on mps (~20-25%)
+            attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
+        else:
+            attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
         attention_probs = attention_scores.softmax(dim=-1)
+
         # compute attention output
-        hidden_states = torch.matmul(attention_probs, value)
+        if query.device.type == "mps":
+            hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
+        else:
+            hidden_states = torch.matmul(attention_probs, value)
+
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
@@ -304,11 +313,21 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            attn_slice = (
-                torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
-            )  # TODO: use baddbmm for better performance
+            if query.device.type == "mps":
+                # Better performance on mps (~20-25%)
+                attn_slice = (
+                    torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
+                    * self.scale
+                )
+            else:
+                attn_slice = (
+                    torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
+                )  # TODO: use baddbmm for better performance
             attn_slice = attn_slice.softmax(dim=-1)
+            if query.device.type == "mps":
+                attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
+            else:
+                attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
 
             hidden_states[start_idx:end_idx] = attn_slice
 
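The two `einsum` expressions introduced in PATCH 5 are drop-in replacements for the `matmul` formulation: they compute identical tensors and only differ in speed on `mps`. A quick equivalence check on random inputs (shapes are illustrative):

```python
import torch

batch, seq_q, seq_k, dim = 2, 77, 77, 64
query = torch.randn(batch, seq_q, dim)
key = torch.randn(batch, seq_k, dim)
value = torch.randn(batch, seq_k, dim)

# Attention scores: summing over the shared feature dim d equals q @ k^T.
scores_mm = torch.matmul(query, key.transpose(-1, -2))
scores_es = torch.einsum("b i d, b j d -> b i j", query, key)
print(torch.allclose(scores_mm, scores_es, atol=1e-5))  # True

# Attention output: summing over the key dim j equals probs @ v.
probs = scores_es.softmax(dim=-1)
out_mm = torch.matmul(probs, value)
out_es = torch.einsum("b i j, b j d -> b i d", probs, value)
print(torch.allclose(out_mm, out_es, atol=1e-5))  # True
```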
From 8d47740cf3ddfa75eaddff6572b526001afbae38 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Mon, 24 Oct 2022 12:40:46 +0200
Subject: [PATCH 6/8] Update mps docs.

---
 docs/source/optimization/mps.mdx | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/docs/source/optimization/mps.mdx b/docs/source/optimization/mps.mdx
index affb3bac6806..fd5156dee50e 100644
--- a/docs/source/optimization/mps.mdx
+++ b/docs/source/optimization/mps.mdx
@@ -17,7 +17,7 @@ specific language governing permissions and limitations under the License.
 ## Requirements
 
 - Mac computer with Apple silicon (M1/M2) hardware.
-- macOS 12.3 or later.
+- macOS 12.6 or later (13.0 or later recommended).
 - arm64 version of Python.
 - PyTorch 1.13.0 test version (pre-RC). You can install it with `pip` using:
 
@@ -38,6 +38,9 @@ from diffusers import StableDiffusionPipeline
 pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
 pipe = pipe.to("mps")
 
+# Recommended if your computer has < 64 GB of RAM
+pipe.enable_attention_slicing()
+
 prompt = "a photo of an astronaut riding a horse on mars"
 
 # First-time "warmup" pass (see explanation above)
@@ -47,16 +50,13 @@ _ = pipe(prompt, num_inference_steps=1)
 image = pipe(prompt).images[0]
 ```
 
-## Known Issues
+## Performance Recommendations
 
-- As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372).
-- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this might be related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039#issuecomment-1237735249), but we need to investigate in more depth. For now, we recommend to iterate instead of batching.
+M1/M2 performance is very sensitive to memory pressure. The system will automatically swap if it needs to, but performance will degrade significantly when it does.
 
-## Performance
+We recommend you use _attention slicing_ to reduce memory pressure during inference and prevent swapping, particularly if your computer has less than 64 GB of system RAM, or if you generate images at non-standard resolutions larger than 512 × 512 pixels. Attention slicing performs the costly attention operation in multiple steps instead of all at once. It usually has a performance impact of ~20% in computers without universal memory, but we have observed _better performance_ in most Apple Silicon computers, unless you have 64 GB or more.
 
-These are the results we got on a M1 Max MacBook Pro with 64 GB of RAM, running macOS Ventura Version 13.0 Beta (22A5331f). We performed Stable Diffusion text-to-image generation of the same prompt for 50 inference steps, using a guidance scale of 7.5.
+## Known Issues
 
-| Device | Steps | Time    |
-|--------|-------|---------|
-| CPU    | 50    | 213.46s |
-| MPS    | 50    | 30.81s  |
\ No newline at end of file
+- As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372).
+- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this is related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039). For now, we recommend iterating instead of batching.
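A rough way to measure the attention-slicing trade-off that PATCH 6 describes on your own hardware; a sketch using the same model and prompt as the docs example (timings vary widely by machine):

```python
import time

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("mps")
pipe.enable_attention_slicing()  # comment out to time full attention instead

prompt = "a photo of an astronaut riding a horse on mars"
_ = pipe(prompt, num_inference_steps=1)  # warmup pass, excluded from timing

start = time.perf_counter()
image = pipe(prompt).images[0]  # default 50 inference steps
print(f"Generation took {time.perf_counter() - start:.1f}s")
```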
From 421dcf45765c58ccaf4dcee265b4c8facaa73329 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Mon, 24 Oct 2022 13:05:53 +0200
Subject: [PATCH 7/8] Minor doc update.

---
 docs/source/optimization/mps.mdx | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/optimization/mps.mdx b/docs/source/optimization/mps.mdx
index fd5156dee50e..267ec056a6af 100644
--- a/docs/source/optimization/mps.mdx
+++ b/docs/source/optimization/mps.mdx
@@ -19,7 +19,7 @@ specific language governing permissions and limitations under the License.
 - Mac computer with Apple silicon (M1/M2) hardware.
 - macOS 12.6 or later (13.0 or later recommended).
 - arm64 version of Python.
-- PyTorch 1.13.0 test version (pre-RC). You can install it with `pip` using:
+- PyTorch 1.13.0 RC (Release Candidate). You can install it with `pip` using:
 
 ```
 pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/test/cpu

From c60bc9a012627d911790b21a7474b90e3117c9bc Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Mon, 24 Oct 2022 18:38:34 +0200
Subject: [PATCH 8/8] Accept suggestion

Co-authored-by: Anton Lozhkov
---
 docs/source/optimization/mps.mdx | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/docs/source/optimization/mps.mdx b/docs/source/optimization/mps.mdx
index 267ec056a6af..3fe4428c87a8 100644
--- a/docs/source/optimization/mps.mdx
+++ b/docs/source/optimization/mps.mdx
@@ -56,6 +56,10 @@ M1/M2 performance is very sensitive to memory pressure. The system will automati
 
 We recommend you use _attention slicing_ to reduce memory pressure during inference and prevent swapping, particularly if your computer has less than 64 GB of system RAM, or if you generate images at non-standard resolutions larger than 512 × 512 pixels. Attention slicing performs the costly attention operation in multiple steps instead of all at once. It usually has a performance impact of ~20% in computers without universal memory, but we have observed _better performance_ in most Apple Silicon computers, unless you have 64 GB or more.
 
+```python
+pipe.enable_attention_slicing()
+```
+
 ## Known Issues
 
 - As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372).
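Finally, a sketch of the batching workaround recommended in the Known Issues section: iterate over prompts one at a time instead of passing them as a single list (the second prompt is illustrative):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("mps")
pipe.enable_attention_slicing()

# Batched prompts are unreliable on mps (diffusers#363), so loop instead.
prompts = [
    "a photo of an astronaut riding a horse on mars",
    "a watercolor painting of a lighthouse at dusk",
]
images = [pipe(prompt).images[0] for prompt in prompts]
```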