30 changes: 19 additions & 11 deletions docs/source/optimization/mps.mdx
@@ -17,9 +17,13 @@ specific language governing permissions and limitations under the License.
## Requirements

- Mac computer with Apple silicon (M1/M2) hardware.
- macOS 12.3 or later.
- macOS 12.6 or later (13.0 or later recommended).
- arm64 version of Python.
- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later.
- PyTorch 1.13.0 RC (Release Candidate). You can install it with `pip` using:

```
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/test/cpu
```
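
To verify that the `mps` backend is actually available, you can run a quick check. This is a minimal sketch; it only assumes the standard `torch.backends.mps` API:

```python
import torch

# Both should print True on a correctly configured Apple Silicon setup
print(torch.backends.mps.is_built())      # PyTorch was compiled with MPS support
print(torch.backends.mps.is_available())  # hardware and macOS version support MPS
```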

## Inference Pipeline

@@ -34,6 +38,9 @@ from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe = pipe.to("mps")

# Recommended if your computer has < 64 GB of RAM
pipe.enable_attention_slicing()

prompt = "a photo of an astronaut riding a horse on mars"

# First-time "warmup" pass (see explanation above)
@@ -43,16 +50,17 @@ _ = pipe(prompt, num_inference_steps=1)
image = pipe(prompt).images[0]
```

## Known Issues
## Performance Recommendations

- As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372).
- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this might be related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039#issuecomment-1237735249), but we need to investigate in more depth. For now, we recommend iterating instead of batching.
M1/M2 performance is very sensitive to memory pressure. The system will automatically swap if it needs to, but performance will degrade significantly when it does.

## Performance
We recommend you use _attention slicing_ to reduce memory pressure during inference and prevent swapping, particularly if your computer has less than 64 GB of system RAM, or if you generate images at non-standard resolutions larger than 512 × 512 pixels. Attention slicing performs the costly attention operation in multiple steps instead of all at once. It usually has a performance impact of ~20% on computers without universal memory, but we have observed _better performance_ on most Apple Silicon computers, unless you have 64 GB of RAM or more.

These are the results we got on an M1 Max MacBook Pro with 64 GB of RAM, running macOS Ventura Version 13.0 Beta (22A5331f). We performed Stable Diffusion text-to-image generation of the same prompt for 50 inference steps, using a guidance scale of 7.5.
```python
pipe.enable_attention_slicing()
```
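
To our knowledge, `enable_attention_slicing` also accepts an explicit `slice_size` argument (the default, `"auto"`, halves the input to the attention heads); smaller slices lower peak memory further at some speed cost:

```python
# Smaller slices reduce peak memory at some cost in speed;
# slice_size="auto" (the default) halves the attention-head input
pipe.enable_attention_slicing(slice_size=1)
```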

| Device | Steps | Time |
|--------|-------|---------|
| CPU | 50 | 213.46s |
| MPS | 50 | 30.81s |
## Known Issues

- As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372).
- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this is related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039). For now, we recommend iterating instead of batching, as in the sketch below.
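
A minimal sketch of that workaround, reusing the `pipe` object from the example above:

```python
prompts = [
    "a photo of an astronaut riding a horse on mars",
    "a watercolor painting of a fox in the snow",
]

# Generate one image at a time instead of passing the list as a batch
images = [pipe(prompt).images[0] for prompt in prompts]
```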
2 changes: 1 addition & 1 deletion examples/community/clip_guided_stable_diffusion.py
@@ -249,7 +249,7 @@ def __call__(
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":
# randn does not exist on mps
# randn does not work reproducibly on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
self.device
)
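The pattern these pipelines use for reproducibility is worth spelling out: sample the initial latents on the CPU with a seeded generator, then move them to the `mps` device. A standalone sketch (the latent shape and dtype are illustrative):

```python
import torch

generator = torch.Generator(device="cpu").manual_seed(0)

# Sample on the CPU so seeded runs are reproducible, then move to mps
latents_shape = (1, 4, 64, 64)  # illustrative latent shape
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=torch.float32)
latents = latents.to("mps")
```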
2 changes: 1 addition & 1 deletion examples/community/interpolate_stable_diffusion.py
@@ -324,7 +324,7 @@ def __call__(
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":
# randn does not exist on mps
# randn does not work reproducibly on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
self.device
)
32 changes: 25 additions & 7 deletions src/diffusers/models/attention.py
@@ -207,7 +207,6 @@ def _set_attention_slice(self, slice_size):
self.attn2._slice_size = slice_size

def forward(self, hidden_states, context=None):
hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
@@ -288,10 +287,19 @@ def forward(self, hidden_states, context=None, mask=None):

def _attention(self, query, key, value):
# TODO: use baddbmm for better performance
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
if query.device.type == "mps":
# Better performance on mps (~20-25%)
attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
else:
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
attention_probs = attention_scores.softmax(dim=-1)
# compute attention output
hidden_states = torch.matmul(attention_probs, value)

if query.device.type == "mps":
hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
else:
hidden_states = torch.matmul(attention_probs, value)

# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
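
The einsum formulation computes exactly the same scores and outputs as the matmul path; only the kernel dispatch differs. A quick standalone equivalence check (shapes are illustrative):

```python
import torch

query = torch.randn(2, 5, 8)
key = torch.randn(2, 7, 8)
value = torch.randn(2, 7, 8)

scores_matmul = torch.matmul(query, key.transpose(-1, -2))
scores_einsum = torch.einsum("b i d, b j d -> b i j", query, key)
assert torch.allclose(scores_matmul, scores_einsum, atol=1e-6)

probs = scores_einsum.softmax(dim=-1)
out_matmul = torch.matmul(probs, value)
out_einsum = torch.einsum("b i j, b j d -> b i d", probs, value)
assert torch.allclose(out_matmul, out_einsum, atol=1e-6)
```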
@@ -305,11 +313,21 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
attn_slice = (
torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
) # TODO: use baddbmm for better performance
if query.device.type == "mps":
# Better performance on mps (~20-25%)
attn_slice = (
torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
* self.scale
)
else:
attn_slice = (
torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
) # TODO: use baddbmm for better performance
attn_slice = attn_slice.softmax(dim=-1)
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
if query.device.type == "mps":
attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
else:
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])

hidden_states[start_idx:end_idx] = attn_slice
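
Stripped of the per-device branching and head reshaping, the slicing pattern itself is short. A simplified sketch of the idea, assuming query and value share the inner dimension:

```python
import torch

def sliced_attention(query, key, value, slice_size, scale):
    # Process the batch in slices so only one slice's attention matrix
    # is materialized at a time, capping peak memory usage
    out = torch.empty_like(query)
    for i in range(query.shape[0] // slice_size):
        s, e = i * slice_size, (i + 1) * slice_size
        scores = torch.einsum("b i d, b j d -> b i j", query[s:e], key[s:e]) * scale
        out[s:e] = torch.einsum("b i j, b j d -> b i d", scores.softmax(dim=-1), value[s:e])
    return out
```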

4 changes: 0 additions & 4 deletions src/diffusers/models/resnet.py
@@ -492,10 +492,6 @@ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
kernel_h, kernel_w = kernel.shape

out = tensor.view(-1, in_h, 1, in_w, 1, minor)

# Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
if tensor.device.type == "mps":
out = out.to("cpu")
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)

@@ -287,7 +287,7 @@ def __call__(
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":
# randn does not exist on mps
# randn does not work reproducibly on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
self.device
)