From 137e1d51452f548317caa54e5ee64a6629edbf31 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sun, 4 Sep 2022 21:31:31 +0200 Subject: [PATCH 01/39] Initial support for mps in Stable Diffusion pipeline. Required when classifier-free guidance is enabled. --- src/diffusers/models/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 8d52ee9bde92..f87c4fa7bbb7 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -137,6 +137,7 @@ def __init__(self, dim, n_heads, d_head, dropout=0.0, context_dim=None, gated_ff self.checkpoint = checkpoint def forward(self, x, context=None): + x = x.contiguous() if x.device.type == "mps" else x x = self.attn1(self.norm1(x)) + x x = self.attn2(self.norm2(x), context=context) + x x = self.ff(self.norm3(x)) + x From 0ef1d1e41865799e1c37c1d21ecb01d9c6195913 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sun, 4 Sep 2022 21:35:39 +0200 Subject: [PATCH 02/39] Initial "warmup" implementation when using mps. For some reason the first run produces results different than the rest. --- src/diffusers/models/unet_2d.py | 10 +++- src/diffusers/models/unet_2d_condition.py | 12 +++- src/diffusers/models/vae.py | 9 ++- src/diffusers/mps_warmup_utils.py | 67 +++++++++++++++++++++++ src/diffusers/pipeline_utils.py | 5 +- 5 files changed, 97 insertions(+), 6 deletions(-) create mode 100644 src/diffusers/mps_warmup_utils.py diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index db4c33690c9d..25be6f757ebb 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -1,15 +1,16 @@ -from typing import Dict, Union +from typing import Dict, Union, Tuple import torch import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config +from ..mps_warmup_utils import MPSWarmupMixin from ..modeling_utils import ModelMixin from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block -class UNet2DModel(ModelMixin, ConfigMixin): +class UNet2DModel(ModelMixin, ConfigMixin, MPSWarmupMixin): @register_to_config def __init__( self, @@ -185,3 +186,8 @@ def forward( output = {"sample": sample} return output + + def warmup_inputs(self) -> Tuple: + w_sample = torch.randn((1, self.in_channels, 32, 32)) + t = torch.tensor([10], dtype=torch.long) + return (w_sample, t) \ No newline at end of file diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 25c4e37d8a6d..4e64ad8640b0 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -1,15 +1,16 @@ -from typing import Dict, Union +from typing import Dict, Union, Tuple import torch import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config +from ..mps_warmup_utils import MPSWarmupMixin from ..modeling_utils import ModelMixin from .embeddings import TimestepEmbedding, Timesteps from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block -class UNet2DConditionModel(ModelMixin, ConfigMixin): +class UNet2DConditionModel(ModelMixin, ConfigMixin, MPSWarmupMixin): @register_to_config def __init__( self, @@ -184,3 +185,10 @@ def forward( output = {"sample": sample} return output + + def warmup_inputs(self) -> Tuple: + batch_size = 1 + w_sample = torch.randn((batch_size, self.in_channels, 64, 64)) + t = torch.tensor([10], dtype=torch.long) + w_encoded = torch.rand((batch_size, 77, 768)) + return (w_sample, t, w_encoded) \ No newline at end of file diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 009db1561d9e..2bcd90ad2594 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -2,7 +2,10 @@ import torch import torch.nn as nn +from typing import Dict, Union, Tuple + from ..configuration_utils import ConfigMixin, register_to_config +from ..mps_warmup_utils import MPSWarmupMixin from ..modeling_utils import ModelMixin from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block @@ -389,7 +392,7 @@ def forward(self, sample): return dec -class AutoencoderKL(ModelMixin, ConfigMixin): +class AutoencoderKL(ModelMixin, ConfigMixin, MPSWarmupMixin): @register_to_config def __init__( self, @@ -449,3 +452,7 @@ def forward(self, sample, sample_posterior=False): z = posterior.mode() dec = self.decode(z) return dec + + def warmup_inputs(self) -> Tuple: + w_sample = torch.randn((4, self.in_channels, self.sample_size, self.sample_size)) + return (w_sample,) \ No newline at end of file diff --git a/src/diffusers/mps_warmup_utils.py b/src/diffusers/mps_warmup_utils.py new file mode 100644 index 000000000000..97a43985fa4a --- /dev/null +++ b/src/diffusers/mps_warmup_utils.py @@ -0,0 +1,67 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from typing import Tuple + +class MPSWarmupMixin: + r""" + Temporary class to perform a 1-time warmup operation on some models, when they are moved to the `mps` device. + + It has been observed that the output of some models (`unet`, `vae`) is different the first time they run than the + rest, for the same inputs. We are investigating the root cause of the problem, but meanwhile this class will be + used to warmup those modules so their outputs are consistent. + TODO: link to issue when we open it. + + Classes that require warmup need to adhere to [`MPSWarmupMixin`] and implement the following: + + - **warmup_inputs** -- A method that returns a suitable set of inputs to use during a forward pass. + + IMPORTANT: + + Warmup will be automatically performed when moving a pipeline to the `mps` device. If you move a single module, no + warmup will be applied. + """ + + def warmup_inputs(self) -> Tuple: + r""" + Return inputs suitable for the forward pass of this module. + These will usually be a tuple of tensors. They will be automatically moved to the `mps` device on warmup. + """ + raise NotImplementedError( + """ + You declared conformance to `MPSWarmupMixin` but did not provide an implementation for `warmup_inputs`. + + Please, write a suitable implementation for `warmup_inputs` or remove conformance to `MPSWarmupMixin` + if it's not needed. + """ + ) + + def warmup(self): + r""" + Applies the warmup using `warmup_inputs`. + Assumes this class implements `__call__` and has a `device` property. + """ + if self.device.type != "mps": + return + + with torch.no_grad(): + w_inputs = self.warmup_inputs() + w_inputs = [w.to("mps") for w in w_inputs] + w_shapes = [w.shape for w in w_inputs] + print(f"Will perform warmup with shapes {w_shapes}") + self.__call__(*w_inputs) + print("Done") + \ No newline at end of file diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 214133bc5f17..36e502a112f7 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -26,6 +26,7 @@ from tqdm.auto import tqdm from .configuration_utils import ConfigMixin +from .mps_warmup_utils import MPSWarmupMixin from .utils import DIFFUSERS_CACHE, logging @@ -125,7 +126,9 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None): module = getattr(self, name) if isinstance(module, torch.nn.Module): module.to(torch_device) - return self + if isinstance(module, MPSWarmupMixin): + module.warmup() + return self @property def device(self) -> torch.device: From ae5ea469d55fc0984fd4789c0bafb14304ccf46a Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sun, 4 Sep 2022 21:45:55 +0200 Subject: [PATCH 03/39] Make some deterministic tests pass with mps. --- src/diffusers/models/unet_2d.py | 5 +++-- src/diffusers/models/unet_2d_condition.py | 4 ++-- src/diffusers/models/vae.py | 5 +++-- src/diffusers/mps_warmup_utils.py | 6 +++--- src/diffusers/testing_utils.py | 1 + tests/test_modeling_common.py | 8 ++++++++ tests/test_models_vae.py | 3 +++ 7 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 25be6f757ebb..03372bd7be65 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -187,7 +187,8 @@ def forward( return output - def warmup_inputs(self) -> Tuple: - w_sample = torch.randn((1, self.in_channels, 32, 32)) + def warmup_inputs(self, batch_size) -> Tuple: + batch_size = 1 if batch_size is None else batch_size + w_sample = torch.randn((batch_size, self.in_channels, 32, 32)) t = torch.tensor([10], dtype=torch.long) return (w_sample, t) \ No newline at end of file diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 4e64ad8640b0..3c069be028dc 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -186,8 +186,8 @@ def forward( return output - def warmup_inputs(self) -> Tuple: - batch_size = 1 + def warmup_inputs(self, batch_size) -> Tuple: + batch_size = 1 if batch_size is None else batch_size w_sample = torch.randn((batch_size, self.in_channels, 64, 64)) t = torch.tensor([10], dtype=torch.long) w_encoded = torch.rand((batch_size, 77, 768)) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 2bcd90ad2594..85d24915ebce 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -453,6 +453,7 @@ def forward(self, sample, sample_posterior=False): dec = self.decode(z) return dec - def warmup_inputs(self) -> Tuple: - w_sample = torch.randn((4, self.in_channels, self.sample_size, self.sample_size)) + def warmup_inputs(self, batch_size) -> Tuple: + batch_size = 4 if batch_size is None else batch_size + w_sample = torch.randn((batch_size, self.in_channels, self.sample_size, self.sample_size)) return (w_sample,) \ No newline at end of file diff --git a/src/diffusers/mps_warmup_utils.py b/src/diffusers/mps_warmup_utils.py index 97a43985fa4a..15cffba9bc20 100644 --- a/src/diffusers/mps_warmup_utils.py +++ b/src/diffusers/mps_warmup_utils.py @@ -35,7 +35,7 @@ class MPSWarmupMixin: warmup will be applied. """ - def warmup_inputs(self) -> Tuple: + def warmup_inputs(self, batch_size=None) -> Tuple: r""" Return inputs suitable for the forward pass of this module. These will usually be a tuple of tensors. They will be automatically moved to the `mps` device on warmup. @@ -49,7 +49,7 @@ def warmup_inputs(self) -> Tuple: """ ) - def warmup(self): + def warmup(self, batch_size=None): r""" Applies the warmup using `warmup_inputs`. Assumes this class implements `__call__` and has a `device` property. @@ -58,7 +58,7 @@ def warmup(self): return with torch.no_grad(): - w_inputs = self.warmup_inputs() + w_inputs = self.warmup_inputs(batch_size) w_inputs = [w.to("mps") for w in w_inputs] w_shapes = [w.shape for w in w_inputs] print(f"Will perform warmup with shapes {w_shapes}") diff --git a/src/diffusers/testing_utils.py b/src/diffusers/testing_utils.py index 13f6332a9432..a1288b4edb3d 100644 --- a/src/diffusers/testing_utils.py +++ b/src/diffusers/testing_utils.py @@ -8,6 +8,7 @@ global_rng = random.Random() torch_device = "cuda" if torch.cuda.is_available() else "cpu" +torch_device = "mps" if torch.backends.mps.is_available() else torch_device def parse_flag_from_env(key, default=False): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index e19238050559..eb64977f587d 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -19,6 +19,7 @@ import numpy as np import torch +from diffusers.mps_warmup_utils import MPSWarmupMixin from diffusers.testing_utils import torch_device from diffusers.training_utils import EMAModel @@ -29,12 +30,16 @@ def test_from_pretrained_save_pretrained(self): model = self.model_class(**init_dict) model.to(torch_device) + if isinstance(model, MPSWarmupMixin): + model.warmup() model.eval() with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) new_model = self.model_class.from_pretrained(tmpdirname) new_model.to(torch_device) + if isinstance(new_model, MPSWarmupMixin): + model.warmup(inputs_dict['sample'].shape[0]) with torch.no_grad(): image = model(**inputs_dict) @@ -53,6 +58,9 @@ def test_determinism(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) model.to(torch_device) + if isinstance(model, MPSWarmupMixin): + model.warmup(inputs_dict['sample'].shape[0]) + model.eval() with torch.no_grad(): first = model(**inputs_dict) diff --git a/tests/test_models_vae.py b/tests/test_models_vae.py index 7df6b42bf796..33740020eeb8 100644 --- a/tests/test_models_vae.py +++ b/tests/test_models_vae.py @@ -18,6 +18,7 @@ import torch from diffusers import AutoencoderKL +from diffusers.mps_warmup_utils import MPSWarmupMixin from diffusers.testing_utils import floats_tensor, torch_device from .test_modeling_common import ModelTesterMixin @@ -78,6 +79,8 @@ def test_from_pretrained_hub(self): def test_output_pretrained(self): model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy") model = model.to(torch_device) + if isinstance(model, MPSWarmupMixin): + model.warmup() model.eval() torch.manual_seed(0) From 4ed22c259ee2cf8eb578a07926f5808fb3ccbec8 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sun, 4 Sep 2022 21:47:11 +0200 Subject: [PATCH 04/39] Disable training tests when using mps. --- tests/test_modeling_common.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index eb64977f587d..474bec8b62b8 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -140,6 +140,9 @@ def test_model_from_config(self): self.assertEqual(output_1.shape, output_2.shape) def test_training(self): + if torch_device == "mps": + return + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) @@ -155,6 +158,9 @@ def test_training(self): loss.backward() def test_ema_training(self): + if torch_device == "mps": + return + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) From bc93b51c7b1cb9554a48256376522f3ff8a40219 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sun, 4 Sep 2022 21:59:12 +0200 Subject: [PATCH 05/39] SD: generate latents in CPU then move to device. This is especially important when using the mps device, because generators are not supported there. See for example https://github.com/pytorch/pytorch/issues/84288. In addition, the other pipelines seem to use the same approach: generate the random samples then move to the appropriate device. After this change, generating an image in MPS produces the same result as when using the CPU, if the same seed is used. --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index d4290da6f030..17a65aab1f07 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -104,12 +104,11 @@ def __call__( latents = torch.randn( latents_shape, generator=generator, - device=self.device, ) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) + latents = latents.to(self.device) # set timesteps accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) From 34c0effb184bf79702665772a131026e89d24727 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sun, 4 Sep 2022 22:16:52 +0200 Subject: [PATCH 06/39] Remove prints. --- src/diffusers/mps_warmup_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/mps_warmup_utils.py b/src/diffusers/mps_warmup_utils.py index 15cffba9bc20..4b16187abd38 100644 --- a/src/diffusers/mps_warmup_utils.py +++ b/src/diffusers/mps_warmup_utils.py @@ -60,8 +60,5 @@ def warmup(self, batch_size=None): with torch.no_grad(): w_inputs = self.warmup_inputs(batch_size) w_inputs = [w.to("mps") for w in w_inputs] - w_shapes = [w.shape for w in w_inputs] - print(f"Will perform warmup with shapes {w_shapes}") self.__call__(*w_inputs) - print("Done") \ No newline at end of file From 66b6752350c995fad61b6742b1b2b897eee8dc0d Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sun, 4 Sep 2022 22:53:36 +0200 Subject: [PATCH 07/39] Pass AutoencoderKL test_output_pretrained with mps. Sampling from `posterior` must be done in CPU. --- src/diffusers/models/vae.py | 5 ++++- src/diffusers/mps_warmup_utils.py | 4 ++-- tests/test_models_vae.py | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index caa7b834c00b..f92c7db9c2fb 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -299,7 +299,10 @@ def __init__(self, parameters, deterministic=False): self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: - x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device) + device = self.parameters.device + sample_device = "cpu" if device.type == "mps" else device + sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device) + x = self.mean + self.std * sample return x def kl(self, other=None): diff --git a/src/diffusers/mps_warmup_utils.py b/src/diffusers/mps_warmup_utils.py index 4b16187abd38..9e82e0379b64 100644 --- a/src/diffusers/mps_warmup_utils.py +++ b/src/diffusers/mps_warmup_utils.py @@ -49,7 +49,7 @@ def warmup_inputs(self, batch_size=None) -> Tuple: """ ) - def warmup(self, batch_size=None): + def warmup(self, batch_size=None, **kwargs): r""" Applies the warmup using `warmup_inputs`. Assumes this class implements `__call__` and has a `device` property. @@ -60,5 +60,5 @@ def warmup(self, batch_size=None): with torch.no_grad(): w_inputs = self.warmup_inputs(batch_size) w_inputs = [w.to("mps") for w in w_inputs] - self.__call__(*w_inputs) + self.__call__(*w_inputs, **kwargs) \ No newline at end of file diff --git a/tests/test_models_vae.py b/tests/test_models_vae.py index 33740020eeb8..34f609518bb5 100644 --- a/tests/test_models_vae.py +++ b/tests/test_models_vae.py @@ -79,9 +79,9 @@ def test_from_pretrained_hub(self): def test_output_pretrained(self): model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy") model = model.to(torch_device) - if isinstance(model, MPSWarmupMixin): - model.warmup() model.eval() + if isinstance(model, MPSWarmupMixin): + model.warmup(1, sample_posterior=True) torch.manual_seed(0) if torch.cuda.is_available(): From db7da01ccad785d72f5d5091df4a6d1e490c3b6f Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 5 Sep 2022 07:42:59 +0200 Subject: [PATCH 08/39] Style --- src/diffusers/models/vae.py | 2 +- src/diffusers/mps_warmup_utils.py | 2 +- src/diffusers/pipeline_utils.py | 3 +-- tests/test_modeling_common.py | 4 ++-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index f92c7db9c2fb..c89be18f6aaf 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -461,4 +461,4 @@ def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> def warmup_inputs(self, batch_size) -> Tuple: batch_size = 4 if batch_size is None else batch_size w_sample = torch.randn((batch_size, self.in_channels, self.sample_size, self.sample_size)) - return (w_sample,) \ No newline at end of file + return (w_sample,) diff --git a/src/diffusers/mps_warmup_utils.py b/src/diffusers/mps_warmup_utils.py index 9e82e0379b64..e2d0159aae13 100644 --- a/src/diffusers/mps_warmup_utils.py +++ b/src/diffusers/mps_warmup_utils.py @@ -16,6 +16,7 @@ import torch from typing import Tuple + class MPSWarmupMixin: r""" Temporary class to perform a 1-time warmup operation on some models, when they are moved to the `mps` device. @@ -61,4 +62,3 @@ def warmup(self, batch_size=None, **kwargs): w_inputs = self.warmup_inputs(batch_size) w_inputs = [w.to("mps") for w in w_inputs] self.__call__(*w_inputs, **kwargs) - \ No newline at end of file diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 36e502a112f7..a39bd3e64574 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -56,7 +56,6 @@ class DiffusionPipeline(ConfigMixin): - config_name = "model_index.json" def register_modules(self, **kwargs): @@ -128,7 +127,7 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None): module.to(torch_device) if isinstance(module, MPSWarmupMixin): module.warmup() - return self + return self @property def device(self) -> torch.device: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 474bec8b62b8..9a964169e37a 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -39,7 +39,7 @@ def test_from_pretrained_save_pretrained(self): new_model = self.model_class.from_pretrained(tmpdirname) new_model.to(torch_device) if isinstance(new_model, MPSWarmupMixin): - model.warmup(inputs_dict['sample'].shape[0]) + model.warmup(inputs_dict["sample"].shape[0]) with torch.no_grad(): image = model(**inputs_dict) @@ -59,7 +59,7 @@ def test_determinism(self): model = self.model_class(**init_dict) model.to(torch_device) if isinstance(model, MPSWarmupMixin): - model.warmup(inputs_dict['sample'].shape[0]) + model.warmup(inputs_dict["sample"].shape[0]) model.eval() with torch.no_grad(): From f20a0dda6f1d2b46ac872852be4ad58188e63a5e Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 5 Sep 2022 10:12:28 +0200 Subject: [PATCH 09/39] Do not use torch.long for log op in mps device. --- src/diffusers/models/unet_2d.py | 2 +- src/diffusers/models/unet_2d_condition.py | 2 +- tests/test_modeling_common.py | 2 +- tests/test_models_unet.py | 4 +++- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 9143b090e3ca..b7e359f29c62 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -189,5 +189,5 @@ def forward( def warmup_inputs(self, batch_size) -> Tuple: batch_size = 1 if batch_size is None else batch_size w_sample = torch.randn((batch_size, self.in_channels, 32, 32)) - t = torch.tensor([10], dtype=torch.long) + t = torch.tensor([10], dtype=torch.int32) return (w_sample, t) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index e3654a25dc3e..499a13580de5 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -191,6 +191,6 @@ def forward( def warmup_inputs(self, batch_size) -> Tuple: batch_size = 1 if batch_size is None else batch_size w_sample = torch.randn((batch_size, self.in_channels, 64, 64)) - t = torch.tensor([10], dtype=torch.long) + t = torch.tensor([10], dtype=torch.int32) w_encoded = torch.rand((batch_size, 77, 768)) return (w_sample, t, w_encoded) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 9a964169e37a..439ce958488d 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -58,10 +58,10 @@ def test_determinism(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) model.to(torch_device) + model.eval() if isinstance(model, MPSWarmupMixin): model.warmup(inputs_dict["sample"].shape[0]) - model.eval() with torch.no_grad(): first = model(**inputs_dict) if isinstance(first, dict): diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index 47f562dd5341..bf7a500d8109 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -191,7 +191,9 @@ def dummy_input(self, sizes=(32, 32)): num_channels = 3 noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [10]).to(torch_device) + time_step = torch.tensor(batch_size * [10]).to(device=torch_device) + if torch_device == "mps": + time_step = time_step.to(dtype=torch.int32) return {"sample": noise, "timestep": time_step} From 7f40f24297e4d3dc0279fdfe38ae4015de8b2747 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 5 Sep 2022 10:24:54 +0200 Subject: [PATCH 10/39] Perform incompatible padding ops in CPU. UNet tests now pass. See https://github.com/pytorch/pytorch/issues/84535 --- src/diffusers/models/resnet.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 50382bcab37d..f089b2e2cfb4 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -446,10 +446,15 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): kernel_h, kernel_w = kernel.shape out = input.view(-1, in_h, 1, in_w, 1, minor) + using_mps = out.device.type == "mps" + if using_mps: + out = out.to("cpu") out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + if using_mps: + out = out.to("mps") out = out[ :, max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), From 71039936026095f429c6085ca3d8f53eb2dc2287 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 5 Sep 2022 10:44:35 +0200 Subject: [PATCH 11/39] Style: fix import order. --- src/diffusers/models/unet_2d.py | 2 +- src/diffusers/models/unet_2d_condition.py | 2 +- src/diffusers/models/vae.py | 6 ++---- src/diffusers/mps_warmup_utils.py | 3 ++- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index b7e359f29c62..d874fc682fa1 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -4,8 +4,8 @@ import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config -from ..mps_warmup_utils import MPSWarmupMixin from ..modeling_utils import ModelMixin +from ..mps_warmup_utils import MPSWarmupMixin from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 499a13580de5..1b4e65273d3b 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -4,8 +4,8 @@ import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config -from ..mps_warmup_utils import MPSWarmupMixin from ..modeling_utils import ModelMixin +from ..mps_warmup_utils import MPSWarmupMixin from .embeddings import TimestepEmbedding, Timesteps from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index c89be18f6aaf..893a65310290 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -1,14 +1,12 @@ -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn -from typing import Dict, Union, Tuple - from ..configuration_utils import ConfigMixin, register_to_config -from ..mps_warmup_utils import MPSWarmupMixin from ..modeling_utils import ModelMixin +from ..mps_warmup_utils import MPSWarmupMixin from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block diff --git a/src/diffusers/mps_warmup_utils.py b/src/diffusers/mps_warmup_utils.py index e2d0159aae13..65649afc5d2f 100644 --- a/src/diffusers/mps_warmup_utils.py +++ b/src/diffusers/mps_warmup_utils.py @@ -13,9 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch from typing import Tuple +import torch + class MPSWarmupMixin: r""" From c931d2ac050bf5992a9daaac860bcaa26487f310 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 5 Sep 2022 11:31:39 +0200 Subject: [PATCH 12/39] Remove unused symbols. --- src/diffusers/models/vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 893a65310290..c50ffb4f0795 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple import numpy as np import torch From d0e85f3a4a0c3a1974049f7e561a580dd2bb9ce5 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 5 Sep 2022 13:03:20 +0200 Subject: [PATCH 13/39] Remove MPSWarmupMixin, do not apply automatically. We do apply warmup in the tests, but not during normal use. This adopts some PR suggestions by @patrickvonplaten. --- src/diffusers/modeling_utils.py | 34 ++++++++++++ src/diffusers/models/unet_2d.py | 5 +- src/diffusers/models/unet_2d_condition.py | 5 +- src/diffusers/models/vae.py | 5 +- src/diffusers/mps_warmup_utils.py | 65 ----------------------- src/diffusers/pipeline_utils.py | 3 -- tests/test_modeling_common.py | 14 ++--- tests/test_models_vae.py | 6 +-- 8 files changed, 50 insertions(+), 87 deletions(-) delete mode 100644 src/diffusers/mps_warmup_utils.py diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index ec501e2ae1f8..ef366c0a9cf5 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -586,6 +586,40 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + def _mps_warmup_inputs(self, batch_size=None) -> Optional[Tuple]: + r""" + Temporary procedure to run a one-time forward pass on some models, when using the `mps` device. + + It has been observed that the output of some models (`unet`, `vae`) is different the first time they run than the + rest, for the same inputs. We are investigating the root cause of the problem, but meanwhile these methods cand be + used, if desired, to warmup those modules so their outputs are consistent. + + Return inputs suitable for the forward pass of this model. + These will usually be a tuple of tensors that will be automatically moved to the `mps` device on warmup. + + Return `None` if no warmup is required. + """ + return None + + + def _mps_warmup(self, batch_size=None, **kwargs): + r""" + Temporary procedure to run a one-time forward pass on some models, when using the `mps` device. + + Applies the warmup using `warmup_inputs`. + """ + if self.device.type != "mps": + return + + with torch.no_grad(): + w_inputs = self._mps_warmup_inputs(batch_size) + if w_inputs is None: + return + w_inputs = [w.to("mps") for w in w_inputs] + self.__call__(*w_inputs, **kwargs) + + + def unwrap_model(model: torch.nn.Module) -> torch.nn.Module: """ Recursively unwraps a model from potential containers (as used in distributed training). diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index d874fc682fa1..ccb3b057066d 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -5,12 +5,11 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin -from ..mps_warmup_utils import MPSWarmupMixin from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block -class UNet2DModel(ModelMixin, ConfigMixin, MPSWarmupMixin): +class UNet2DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, @@ -186,7 +185,7 @@ def forward( return output - def warmup_inputs(self, batch_size) -> Tuple: + def _mps_warmup_inputs(self, batch_size) -> Tuple: batch_size = 1 if batch_size is None else batch_size w_sample = torch.randn((batch_size, self.in_channels, 32, 32)) t = torch.tensor([10], dtype=torch.int32) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 1b4e65273d3b..a128f5657ffe 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -5,12 +5,11 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin -from ..mps_warmup_utils import MPSWarmupMixin from .embeddings import TimestepEmbedding, Timesteps from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block -class UNet2DConditionModel(ModelMixin, ConfigMixin, MPSWarmupMixin): +class UNet2DConditionModel(ModelMixin, ConfigMixin): @register_to_config def __init__( self, @@ -188,7 +187,7 @@ def forward( return output - def warmup_inputs(self, batch_size) -> Tuple: + def _mps_warmup_inputs(self, batch_size) -> Tuple: batch_size = 1 if batch_size is None else batch_size w_sample = torch.randn((batch_size, self.in_channels, 64, 64)) t = torch.tensor([10], dtype=torch.int32) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index c50ffb4f0795..b342d2b65690 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -6,7 +6,6 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin -from ..mps_warmup_utils import MPSWarmupMixin from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block @@ -395,7 +394,7 @@ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: return dec -class AutoencoderKL(ModelMixin, ConfigMixin, MPSWarmupMixin): +class AutoencoderKL(ModelMixin, ConfigMixin): @register_to_config def __init__( self, @@ -456,7 +455,7 @@ def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> dec = self.decode(z) return dec - def warmup_inputs(self, batch_size) -> Tuple: + def _mps_warmup_inputs(self, batch_size) -> Tuple: batch_size = 4 if batch_size is None else batch_size w_sample = torch.randn((batch_size, self.in_channels, self.sample_size, self.sample_size)) return (w_sample,) diff --git a/src/diffusers/mps_warmup_utils.py b/src/diffusers/mps_warmup_utils.py deleted file mode 100644 index 65649afc5d2f..000000000000 --- a/src/diffusers/mps_warmup_utils.py +++ /dev/null @@ -1,65 +0,0 @@ -# coding=utf-8 -# Copyright 2022 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Tuple - -import torch - - -class MPSWarmupMixin: - r""" - Temporary class to perform a 1-time warmup operation on some models, when they are moved to the `mps` device. - - It has been observed that the output of some models (`unet`, `vae`) is different the first time they run than the - rest, for the same inputs. We are investigating the root cause of the problem, but meanwhile this class will be - used to warmup those modules so their outputs are consistent. - TODO: link to issue when we open it. - - Classes that require warmup need to adhere to [`MPSWarmupMixin`] and implement the following: - - - **warmup_inputs** -- A method that returns a suitable set of inputs to use during a forward pass. - - IMPORTANT: - - Warmup will be automatically performed when moving a pipeline to the `mps` device. If you move a single module, no - warmup will be applied. - """ - - def warmup_inputs(self, batch_size=None) -> Tuple: - r""" - Return inputs suitable for the forward pass of this module. - These will usually be a tuple of tensors. They will be automatically moved to the `mps` device on warmup. - """ - raise NotImplementedError( - """ - You declared conformance to `MPSWarmupMixin` but did not provide an implementation for `warmup_inputs`. - - Please, write a suitable implementation for `warmup_inputs` or remove conformance to `MPSWarmupMixin` - if it's not needed. - """ - ) - - def warmup(self, batch_size=None, **kwargs): - r""" - Applies the warmup using `warmup_inputs`. - Assumes this class implements `__call__` and has a `device` property. - """ - if self.device.type != "mps": - return - - with torch.no_grad(): - w_inputs = self.warmup_inputs(batch_size) - w_inputs = [w.to("mps") for w in w_inputs] - self.__call__(*w_inputs, **kwargs) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index a39bd3e64574..58f571d709df 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -26,7 +26,6 @@ from tqdm.auto import tqdm from .configuration_utils import ConfigMixin -from .mps_warmup_utils import MPSWarmupMixin from .utils import DIFFUSERS_CACHE, logging @@ -125,8 +124,6 @@ def to(self, torch_device: Optional[Union[str, torch.device]] = None): module = getattr(self, name) if isinstance(module, torch.nn.Module): module.to(torch_device) - if isinstance(module, MPSWarmupMixin): - module.warmup() return self @property diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 439ce958488d..6cd4961e12ae 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -18,8 +18,8 @@ import numpy as np import torch +from diffusers.modeling_utils import ModelMixin -from diffusers.mps_warmup_utils import MPSWarmupMixin from diffusers.testing_utils import torch_device from diffusers.training_utils import EMAModel @@ -30,16 +30,16 @@ def test_from_pretrained_save_pretrained(self): model = self.model_class(**init_dict) model.to(torch_device) - if isinstance(model, MPSWarmupMixin): - model.warmup() + if isinstance(model, ModelMixin): + model._mps_warmup() model.eval() with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) new_model = self.model_class.from_pretrained(tmpdirname) new_model.to(torch_device) - if isinstance(new_model, MPSWarmupMixin): - model.warmup(inputs_dict["sample"].shape[0]) + if isinstance(new_model, ModelMixin): + model._mps_warmup(inputs_dict["sample"].shape[0]) with torch.no_grad(): image = model(**inputs_dict) @@ -59,8 +59,8 @@ def test_determinism(self): model = self.model_class(**init_dict) model.to(torch_device) model.eval() - if isinstance(model, MPSWarmupMixin): - model.warmup(inputs_dict["sample"].shape[0]) + if isinstance(model, ModelMixin): + model._mps_warmup(inputs_dict["sample"].shape[0]) with torch.no_grad(): first = model(**inputs_dict) diff --git a/tests/test_models_vae.py b/tests/test_models_vae.py index 34f609518bb5..852b0b08e13f 100644 --- a/tests/test_models_vae.py +++ b/tests/test_models_vae.py @@ -18,7 +18,7 @@ import torch from diffusers import AutoencoderKL -from diffusers.mps_warmup_utils import MPSWarmupMixin +from diffusers.modeling_utils import ModelMixin from diffusers.testing_utils import floats_tensor, torch_device from .test_modeling_common import ModelTesterMixin @@ -80,8 +80,8 @@ def test_output_pretrained(self): model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy") model = model.to(torch_device) model.eval() - if isinstance(model, MPSWarmupMixin): - model.warmup(1, sample_posterior=True) + if isinstance(model, ModelMixin): + model._mps_warmup(1, sample_posterior=True) torch.manual_seed(0) if torch.cuda.is_available(): From 692b1be1a0135aaf004c2c0bf5c0a2e01db8a6a5 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 5 Sep 2022 13:09:15 +0200 Subject: [PATCH 14/39] Add comment for mps fallback to CPU step. --- src/diffusers/models/resnet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index f089b2e2cfb4..f6b9daefc6a1 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -446,12 +446,14 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): kernel_h, kernel_w = kernel.shape out = input.view(-1, in_h, 1, in_w, 1, minor) + + # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535 using_mps = out.device.type == "mps" if using_mps: out = out.to("cpu") out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) - + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) if using_mps: out = out.to("mps") From 36b6a46fd314c9c0b5009cad8485c76a203c2868 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 5 Sep 2022 13:37:23 +0200 Subject: [PATCH 15/39] Add README_mps.md for mps installation and use. --- README.md | 4 ++++ README_mps.md | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) create mode 100644 README_mps.md diff --git a/README.md b/README.md index 28e1be164bb5..918b787ef500 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,10 @@ pip install --upgrade diffusers # should install diffusers 0.2.4 conda install -c conda-forge diffusers ``` +**Apple Silicon (M1/M2) support** + +See [README_mps.md](README_mps.md). + ## Contributing We ❤️ contributions from the open-source community! diff --git a/README_mps.md b/README_mps.md new file mode 100644 index 000000000000..4cd4fbf46fb6 --- /dev/null +++ b/README_mps.md @@ -0,0 +1,44 @@ +## How to use Stable Diffusion in Apple Silicon (M1/M2) + +🤗 Diffusers is compatible with Apple silicon for inference, using the PyTorch `mps` device. These are the steps you need to follow to use your M1 or M2 computer with Stable Diffusion. + +### Requirements + +- Mac computer with Apple silicon (M1/M2) hardware. +- macOS 12.3 or later. +- arm64 version of Python. + +### Install PyTorch Nightly + +Install the latest [Preview (Nightly) build of PyTorch](https://pytorch.org/get-started/locally/) on your Apple silicon Mac. Please, make sure it meets the requirements above. + +We have tested inference using PyTorch version `1.13.0.dev20220830`. After installing, please verify that your version is at least that one. + +### Inference Pipeline + +The snippet shown below demonstrates how to use the `mps` backend using the familiar `to()` interface to move the Stable Diffusion pipeline to your M1 or M2 device. + +We recommend to "prime" the pipeline using an additional one-time pass through it. This is a temporary workaround for a weird issue we have detected: the first inference pass produces slightly different results than subsequent ones. You only need to do this pass once, and it's ok to use just one inference step and discard the result. + +```python +# make sure you're logged in with `huggingface-cli login` +from diffusers import StableDiffusionPipeline + +pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) +pipe = pipe.to("mps") + +prompt = "a photo of an astronaut riding a horse on mars" + +# First-time "warmup" pass (see explanation above) +_ = pipe(prompt, num_inference_steps=1) + +# Results match those from the CPU device after the warmup pass. +image = pipe(prompt)["sample"][0] +``` + +### Performance + +These are the results we got on a M1 Max MacBook Pro with 64 GB of RAM, running macOS Ventura Version 13.0 Beta (22A5331f). We performed Stable Diffusion text-to-image generation of the same prompt for 50 inference steps, using a guidance scale of 7.5. + +- CPU: 213.46s +- MPS: 30.81s From 261a78407322324ef131a65d6cc9dc996d156325 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 5 Sep 2022 14:06:20 +0200 Subject: [PATCH 16/39] Apply `black` to modified files. --- src/diffusers/modeling_utils.py | 5 +---- src/diffusers/models/resnet.py | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index ef366c0a9cf5..891fab1a8b3b 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -585,7 +585,6 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool else: return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) - def _mps_warmup_inputs(self, batch_size=None) -> Optional[Tuple]: r""" Temporary procedure to run a one-time forward pass on some models, when using the `mps` device. @@ -596,12 +595,11 @@ def _mps_warmup_inputs(self, batch_size=None) -> Optional[Tuple]: Return inputs suitable for the forward pass of this model. These will usually be a tuple of tensors that will be automatically moved to the `mps` device on warmup. - + Return `None` if no warmup is required. """ return None - def _mps_warmup(self, batch_size=None, **kwargs): r""" Temporary procedure to run a one-time forward pass on some models, when using the `mps` device. @@ -619,7 +617,6 @@ def _mps_warmup(self, batch_size=None, **kwargs): self.__call__(*w_inputs, **kwargs) - def unwrap_model(model: torch.nn.Module) -> torch.nn.Module: """ Recursively unwraps a model from potential containers (as used in distributed training). diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index f6b9daefc6a1..9ed8adf37192 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -453,7 +453,7 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): out = out.to("cpu") out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) - + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) if using_mps: out = out.to("mps") From 15d86ff6fcd4fe4eb6bc71b8fe36110867d92134 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 5 Sep 2022 14:10:03 +0200 Subject: [PATCH 17/39] Restrict README_mps to SD, show measures in table. --- README_mps.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README_mps.md b/README_mps.md index 4cd4fbf46fb6..12f82f9711de 100644 --- a/README_mps.md +++ b/README_mps.md @@ -1,6 +1,6 @@ ## How to use Stable Diffusion in Apple Silicon (M1/M2) -🤗 Diffusers is compatible with Apple silicon for inference, using the PyTorch `mps` device. These are the steps you need to follow to use your M1 or M2 computer with Stable Diffusion. +🤗 Diffusers is compatible with Apple silicon for Stable Diffusion inference, using the PyTorch `mps` device. These are the steps you need to follow to use your M1 or M2 computer with Stable Diffusion. ### Requirements @@ -40,5 +40,7 @@ image = pipe(prompt)["sample"][0] These are the results we got on a M1 Max MacBook Pro with 64 GB of RAM, running macOS Ventura Version 13.0 Beta (22A5331f). We performed Stable Diffusion text-to-image generation of the same prompt for 50 inference steps, using a guidance scale of 7.5. -- CPU: 213.46s -- MPS: 30.81s +| Device | Steps | Time | +|--------|-------|---------| +| CPU | 50 | 213.46s | +| MPS | 50 | 30.81s | From 5ed888916d3966e87278d79da70765fdcc678b21 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 5 Sep 2022 17:16:42 +0200 Subject: [PATCH 18/39] Make PNDM indexing compatible with mps. Addresses #239. --- src/diffusers/schedulers/scheduling_pndm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 57ed4fb7e369..e2fef9022583 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -270,7 +270,8 @@ def add_noise( noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> torch.Tensor: - + # mps requires indices to be in the same device, so we use cpu as is the default with cuda + timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 From 12f66700fa7116e24b3bbe42ec6b51b2dc6d7c12 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 5 Sep 2022 18:08:01 +0200 Subject: [PATCH 19/39] Do not use float64 when using LDMScheduler. Fixes #358. --- src/diffusers/models/unet_2d_condition.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index a128f5657ffe..bb23350db965 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -135,7 +135,9 @@ def forward( if not torch.is_tensor(timesteps): timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) + if sample.device.type == "mps": + timesteps = timesteps.to(dtype=torch.float32) + timesteps = timesteps[None].to(device=sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device) From ce1e8635f8dfd672e708a0a6a0a50209b5baee0e Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Sep 2022 10:00:55 +0200 Subject: [PATCH 20/39] Fix typo identified by @patil-suraj Co-authored-by: Suraj Patil --- src/diffusers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 891fab1a8b3b..7ba331a87bb1 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -590,7 +590,7 @@ def _mps_warmup_inputs(self, batch_size=None) -> Optional[Tuple]: Temporary procedure to run a one-time forward pass on some models, when using the `mps` device. It has been observed that the output of some models (`unet`, `vae`) is different the first time they run than the - rest, for the same inputs. We are investigating the root cause of the problem, but meanwhile these methods cand be + rest, for the same inputs. We are investigating the root cause of the problem, but meanwhile these methods can be used, if desired, to warmup those modules so their outputs are consistent. Return inputs suitable for the forward pass of this model. From 3943fded76f082e989ff8b33c680678c82f7f66a Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Sep 2022 10:26:12 +0200 Subject: [PATCH 21/39] Adapt example to new output style. --- README_mps.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README_mps.md b/README_mps.md index 12f82f9711de..44d8a74af3e1 100644 --- a/README_mps.md +++ b/README_mps.md @@ -33,7 +33,7 @@ prompt = "a photo of an astronaut riding a horse on mars" _ = pipe(prompt, num_inference_steps=1) # Results match those from the CPU device after the warmup pass. -image = pipe(prompt)["sample"][0] +image = pipe(prompt).images[0] ``` ### Performance From cec0c7e361e00f51ee1672d58e84550bbf8fb696 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Sep 2022 10:38:23 +0200 Subject: [PATCH 22/39] Restore 1:1 results reproducibility with CompVis. However, mps latents need to be generated in CPU because generators don't work in the mps device. --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 9852aa8022e2..7c4d45295c6a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -101,11 +101,17 @@ def __call__( text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_device = "cpu" if self.device.type == "mps" else self.device latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) if latents is None: latents = torch.randn( latents_shape, generator=generator, + device=latents_device, ) else: if latents.shape != latents_shape: From 1bf8c4c6b418c51130d2bfb561b1375a3acb2289 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Sep 2022 12:20:25 +0200 Subject: [PATCH 23/39] Move PyTorch nightly to requirements. --- README_mps.md | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/README_mps.md b/README_mps.md index 44d8a74af3e1..2deaefc8b032 100644 --- a/README_mps.md +++ b/README_mps.md @@ -7,12 +7,7 @@ - Mac computer with Apple silicon (M1/M2) hardware. - macOS 12.3 or later. - arm64 version of Python. - -### Install PyTorch Nightly - -Install the latest [Preview (Nightly) build of PyTorch](https://pytorch.org/get-started/locally/) on your Apple silicon Mac. Please, make sure it meets the requirements above. - -We have tested inference using PyTorch version `1.13.0.dev20220830`. After installing, please verify that your version is at least that one. +- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.13.0.dev20220830` or later. ### Inference Pipeline From 220999f6d351d3c0dd5b74826bdd111786b524bc Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Sep 2022 12:47:34 +0200 Subject: [PATCH 24/39] Adapt `test_scheduler_outputs_equivalence` ton MPS. --- tests/test_modeling_common.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index c8ff93a5252e..c12bef30570d 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -181,8 +181,13 @@ def test_ema_training(self): def test_scheduler_outputs_equivalence(self): def set_nan_tensor_to_zero(t): + # Temporary fallback until `aten::_index_put_impl_` is implemented in mps + # Track progress in https://github.com/pytorch/pytorch/issues/77764 + device = t.device + if device.type == "mps": + t = t.to("cpu") t[t != t] = 0 - return t + return t.to(device) def recursive_check(tuple_object, dict_object): if isinstance(tuple_object, (List, Tuple)): @@ -211,8 +216,11 @@ def recursive_check(tuple_object, dict_object): model = self.model_class(**init_dict) model.to(torch_device) model.eval() + if isinstance(model, ModelMixin): + model._mps_warmup(inputs_dict['sample'].shape[0]) - outputs_dict = model(**inputs_dict) - outputs_tuple = model(**inputs_dict, return_dict=False) + with torch.no_grad(): + outputs_dict = model(**inputs_dict) + outputs_tuple = model(**inputs_dict, return_dict=False) recursive_check(outputs_tuple, outputs_dict) From 928091373b131484f807382c60a1a5a2f9c934f2 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Sep 2022 13:05:16 +0200 Subject: [PATCH 25/39] mps: skip training tests instead of ignoring silently. --- tests/test_modeling_common.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index c12bef30570d..27899ecd32f1 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -18,6 +18,7 @@ from typing import Dict, List, Tuple import numpy as np +import pytest import torch from diffusers.modeling_utils import ModelMixin @@ -142,7 +143,7 @@ def test_model_from_config(self): def test_training(self): if torch_device == "mps": - return + pytest.skip("mps: unsupported training device") init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -160,7 +161,7 @@ def test_training(self): def test_ema_training(self): if torch_device == "mps": - return + pytest.skip("mps: unsupported training device") init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() From d8a093d645fc60ee9dcf3c4f60ed322160faffd7 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Sep 2022 13:18:14 +0200 Subject: [PATCH 26/39] Make VQModel tests pass on mps. --- src/diffusers/models/vae.py | 6 +++++- tests/test_models_vq.py | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index f77d124d6f98..e8a2f2e2f522 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -449,6 +449,11 @@ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[ return DecoderOutput(sample=dec) + def _mps_warmup_inputs(self, batch_size) -> Tuple: + batch_size = 4 if batch_size is None else batch_size + w_sample = torch.randn((batch_size, self.in_channels, self.sample_size, self.sample_size)) + return (w_sample,) + class AutoencoderKL(ModelMixin, ConfigMixin): @register_to_config @@ -529,4 +534,3 @@ def _mps_warmup_inputs(self, batch_size) -> Tuple: batch_size = 4 if batch_size is None else batch_size w_sample = torch.randn((batch_size, self.in_channels, self.sample_size, self.sample_size)) return (w_sample,) - diff --git a/tests/test_models_vq.py b/tests/test_models_vq.py index c0acceccb492..9d03006aee86 100644 --- a/tests/test_models_vq.py +++ b/tests/test_models_vq.py @@ -18,6 +18,7 @@ import torch from diffusers import VQModel +from diffusers.modeling_utils import ModelMixin from diffusers.testing_utils import floats_tensor, torch_device from .test_modeling_common import ModelTesterMixin @@ -77,6 +78,8 @@ def test_from_pretrained_hub(self): def test_output_pretrained(self): model = VQModel.from_pretrained("fusing/vqgan-dummy") model.to(torch_device).eval() + if isinstance(model, ModelMixin): + model._mps_warmup(1) torch.manual_seed(0) if torch.cuda.is_available(): From 0f60435e869ad18f3c4da034c32ae1721d67d04c Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Sep 2022 13:37:40 +0200 Subject: [PATCH 27/39] mps ddim tests: warmup, increase tolerance. --- tests/test_pipelines.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index fbd0faf02c11..f713306a2147 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -194,6 +194,7 @@ def test_ddim(self): ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) ddpm.to(torch_device) ddpm.set_progress_bar_config(disable=None) + ddpm.unet._mps_warmup() generator = torch.manual_seed(0) image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images @@ -208,8 +209,9 @@ def test_ddim(self): expected_slice = np.array( [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04] ) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + tolerance = 1e-2 if torch_device != "mps" else 2.5e-2 + assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance def test_pndm_cifar10(self): unet = self.dummy_uncond_unet From 3c59b399c503331195ab3068a4d4a01cdcd312ee Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Sep 2022 14:18:25 +0200 Subject: [PATCH 28/39] ScoreSdeVeScheduler indexing made mps compatible. --- src/diffusers/schedulers/scheduling_sde_ve.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index f6b0ba936eea..4f9c4898b10f 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -111,7 +111,7 @@ def get_adjacent_sigma(self, timesteps, t): return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1]) elif tensor_format == "pt": return torch.where( - timesteps == 0, torch.zeros_like(t), self.discrete_sigmas[timesteps - 1].to(timesteps.device) + timesteps == 0, torch.zeros_like(t.to(timesteps.device)), self.discrete_sigmas[timesteps - 1].to(timesteps.device) ) raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") @@ -155,8 +155,11 @@ def step_pred( ) # torch.repeat_interleave(timestep, sample.shape[0]) timesteps = (timestep * (len(self.timesteps) - 1)).long() + # mps requires indices to be in the same device, so we use cpu as is the default with cuda + timesteps = timesteps.to(self.discrete_sigmas.device) + sigma = self.discrete_sigmas[timesteps].to(sample.device) - adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep) + adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device) drift = self.zeros_like(sample) diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 From 1a2af524d2d4613cbc1aafad34203f8a89dd7b9e Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Sep 2022 14:45:10 +0200 Subject: [PATCH 29/39] Make ldm pipeline tests pass using warmup. --- tests/test_pipelines.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index f713306a2147..ece7cf1f7a02 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -247,6 +247,14 @@ def test_ldm_text2img(self): ldm.set_progress_bar_config(disable=None) prompt = "A painting of a squirrel eating a burger" + + # Skip first when using mps (see #372) + if torch_device == "mps": + generator = torch.manual_seed(0) + _ = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy")[ + "sample" + ] + generator = torch.manual_seed(0) image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy")[ "sample" @@ -443,6 +451,11 @@ def test_ldm_uncond(self): ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler) ldm.to(torch_device) ldm.set_progress_bar_config(disable=None) + + # Skip first when using mps (see #372) + if torch_device == "mps": + generator = torch.manual_seed(0) + _ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images generator = torch.manual_seed(0) image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images From e4181f0469c506629e20d8674c375ee969b4ade9 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Sep 2022 18:14:46 +0200 Subject: [PATCH 30/39] Style --- src/diffusers/models/unet_2d.py | 1 - src/diffusers/models/unet_2d_condition.py | 1 - src/diffusers/schedulers/scheduling_sde_ve.py | 4 +++- tests/test_modeling_common.py | 2 +- tests/test_pipelines.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index a1f1b168f806..796c11edd095 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -207,4 +207,3 @@ def _mps_warmup_inputs(self, batch_size) -> Tuple: w_sample = torch.randn((batch_size, self.in_channels, 32, 32)) t = torch.tensor([10], dtype=torch.int32) return (w_sample, t) - diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 267d8b6fe2d8..3c93fdc242a9 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -210,4 +210,3 @@ def _mps_warmup_inputs(self, batch_size) -> Tuple: t = torch.tensor([10], dtype=torch.int32) w_encoded = torch.rand((batch_size, 77, 768)) return (w_sample, t, w_encoded) - diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 4f9c4898b10f..ad5c738d966e 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -111,7 +111,9 @@ def get_adjacent_sigma(self, timesteps, t): return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1]) elif tensor_format == "pt": return torch.where( - timesteps == 0, torch.zeros_like(t.to(timesteps.device)), self.discrete_sigmas[timesteps - 1].to(timesteps.device) + timesteps == 0, + torch.zeros_like(t.to(timesteps.device)), + self.discrete_sigmas[timesteps - 1].to(timesteps.device), ) raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 27899ecd32f1..3c58bc0549cb 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -218,7 +218,7 @@ def recursive_check(tuple_object, dict_object): model.to(torch_device) model.eval() if isinstance(model, ModelMixin): - model._mps_warmup(inputs_dict['sample'].shape[0]) + model._mps_warmup(inputs_dict["sample"].shape[0]) with torch.no_grad(): outputs_dict = model(**inputs_dict) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 511d6f405b8a..e98f4c965a33 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -451,7 +451,7 @@ def test_ldm_uncond(self): ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler) ldm.to(torch_device) ldm.set_progress_bar_config(disable=None) - + # Skip first when using mps (see #372) if torch_device == "mps": generator = torch.manual_seed(0) From 1b05d0f66b280a2f736ce74803ca97da1438c220 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Sep 2022 18:36:35 +0200 Subject: [PATCH 31/39] Simplify casting as suggested in PR. --- src/diffusers/models/unet_2d_condition.py | 3 +-- tests/test_models_unet.py | 4 +--- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 3c93fdc242a9..2fbd44ad43b9 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -149,8 +149,7 @@ def forward( if not torch.is_tensor(timesteps): timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - if sample.device.type == "mps": - timesteps = timesteps.to(dtype=torch.float32) + timesteps = timesteps.to(dtype=torch.float32) timesteps = timesteps[None].to(device=sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index b6c6b960410d..5e6269373973 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -191,9 +191,7 @@ def dummy_input(self, sizes=(32, 32)): num_channels = 3 noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [10]).to(device=torch_device) - if torch_device == "mps": - time_step = time_step.to(dtype=torch.int32) + time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device) return {"sample": noise, "timestep": time_step} From df6683b00fc53afcb6f2c00ab7949a0cf65839bc Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Sep 2022 18:48:38 +0200 Subject: [PATCH 32/39] Add Known Issues to readme. --- README_mps.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README_mps.md b/README_mps.md index 2deaefc8b032..cb1df05de5d9 100644 --- a/README_mps.md +++ b/README_mps.md @@ -31,6 +31,11 @@ _ = pipe(prompt, num_inference_steps=1) image = pipe(prompt).images[0] ``` +### Known Issues + +- As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372). +- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this might be related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039#issuecomment-1237735249), but we need to investigate in more depth. For now, we recommend to iterate instead of batching. + ### Performance These are the results we got on a M1 Max MacBook Pro with 64 GB of RAM, running macOS Ventura Version 13.0 Beta (22A5331f). We performed Stable Diffusion text-to-image generation of the same prompt for 50 inference steps, using a guidance scale of 7.5. From 4e4bd62945e02be6dbb6489da72cbe77d923fe5b Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Sep 2022 18:54:53 +0200 Subject: [PATCH 33/39] `isort` import order. --- tests/test_modeling_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 3c58bc0549cb..79cdfe97b7d6 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -18,10 +18,10 @@ from typing import Dict, List, Tuple import numpy as np -import pytest import torch -from diffusers.modeling_utils import ModelMixin +import pytest +from diffusers.modeling_utils import ModelMixin from diffusers.testing_utils import torch_device from diffusers.training_utils import EMAModel From c985b50e8f8acf49e2170d286512dfd15901c750 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 8 Sep 2022 12:14:10 +0200 Subject: [PATCH 34/39] Remove _mps_warmup helpers from ModelMixin. And just make changes to the tests. --- src/diffusers/modeling_utils.py | 31 ----------------------- src/diffusers/models/unet_2d.py | 6 ----- src/diffusers/models/unet_2d_condition.py | 7 ----- src/diffusers/models/vae.py | 10 -------- tests/test_modeling_common.py | 23 +++++++++++------ tests/test_models_vae.py | 9 +++++-- tests/test_models_vq.py | 5 ++-- tests/test_pipelines.py | 11 +++++--- 8 files changed, 32 insertions(+), 70 deletions(-) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 7ba331a87bb1..ec501e2ae1f8 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -585,37 +585,6 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool else: return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) - def _mps_warmup_inputs(self, batch_size=None) -> Optional[Tuple]: - r""" - Temporary procedure to run a one-time forward pass on some models, when using the `mps` device. - - It has been observed that the output of some models (`unet`, `vae`) is different the first time they run than the - rest, for the same inputs. We are investigating the root cause of the problem, but meanwhile these methods can be - used, if desired, to warmup those modules so their outputs are consistent. - - Return inputs suitable for the forward pass of this model. - These will usually be a tuple of tensors that will be automatically moved to the `mps` device on warmup. - - Return `None` if no warmup is required. - """ - return None - - def _mps_warmup(self, batch_size=None, **kwargs): - r""" - Temporary procedure to run a one-time forward pass on some models, when using the `mps` device. - - Applies the warmup using `warmup_inputs`. - """ - if self.device.type != "mps": - return - - with torch.no_grad(): - w_inputs = self._mps_warmup_inputs(batch_size) - if w_inputs is None: - return - w_inputs = [w.to("mps") for w in w_inputs] - self.__call__(*w_inputs, **kwargs) - def unwrap_model(model: torch.nn.Module) -> torch.nn.Module: """ diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 796c11edd095..46d5ee532961 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -201,9 +201,3 @@ def forward( return (sample,) return UNet2DOutput(sample=sample) - - def _mps_warmup_inputs(self, batch_size) -> Tuple: - batch_size = 1 if batch_size is None else batch_size - w_sample = torch.randn((batch_size, self.in_channels, 32, 32)) - t = torch.tensor([10], dtype=torch.int32) - return (w_sample, t) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 2fbd44ad43b9..dbbcf3458313 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -202,10 +202,3 @@ def forward( return (sample,) return UNet2DConditionOutput(sample=sample) - - def _mps_warmup_inputs(self, batch_size) -> Tuple: - batch_size = 1 if batch_size is None else batch_size - w_sample = torch.randn((batch_size, self.in_channels, 64, 64)) - t = torch.tensor([10], dtype=torch.int32) - w_encoded = torch.rand((batch_size, 77, 768)) - return (w_sample, t, w_encoded) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index e8a2f2e2f522..b90a938aa81c 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -449,11 +449,6 @@ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[ return DecoderOutput(sample=dec) - def _mps_warmup_inputs(self, batch_size) -> Tuple: - batch_size = 4 if batch_size is None else batch_size - w_sample = torch.randn((batch_size, self.in_channels, self.sample_size, self.sample_size)) - return (w_sample,) - class AutoencoderKL(ModelMixin, ConfigMixin): @register_to_config @@ -529,8 +524,3 @@ def forward( return (dec,) return DecoderOutput(sample=dec) - - def _mps_warmup_inputs(self, batch_size) -> Tuple: - batch_size = 4 if batch_size is None else batch_size - w_sample = torch.randn((batch_size, self.in_channels, self.sample_size, self.sample_size)) - return (w_sample,) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 79cdfe97b7d6..98c843f4845b 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -32,18 +32,19 @@ def test_from_pretrained_save_pretrained(self): model = self.model_class(**init_dict) model.to(torch_device) - if isinstance(model, ModelMixin): - model._mps_warmup() model.eval() with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) new_model = self.model_class.from_pretrained(tmpdirname) new_model.to(torch_device) - if isinstance(new_model, ModelMixin): - model._mps_warmup(inputs_dict["sample"].shape[0]) with torch.no_grad(): + # Warmup pass when using mps (see #372) + if torch_device == "mps" and isinstance(model, ModelMixin): + _ = model(**self.dummy_input) + _ = new_model(**self.dummy_input) + image = model(**inputs_dict) if isinstance(image, dict): image = image.sample @@ -61,10 +62,12 @@ def test_determinism(self): model = self.model_class(**init_dict) model.to(torch_device) model.eval() - if isinstance(model, ModelMixin): - model._mps_warmup(inputs_dict["sample"].shape[0]) with torch.no_grad(): + # Warmup pass when using mps (see #372) + if torch_device == "mps" and isinstance(model, ModelMixin): + model(**self.dummy_input) + first = model(**inputs_dict) if isinstance(first, dict): first = first.sample @@ -142,6 +145,7 @@ def test_model_from_config(self): self.assertEqual(output_1.shape, output_2.shape) def test_training(self): + # Warmup pass when using mps (see #372) if torch_device == "mps": pytest.skip("mps: unsupported training device") @@ -160,6 +164,7 @@ def test_training(self): loss.backward() def test_ema_training(self): + # Warmup pass when using mps (see #372) if torch_device == "mps": pytest.skip("mps: unsupported training device") @@ -217,10 +222,12 @@ def recursive_check(tuple_object, dict_object): model = self.model_class(**init_dict) model.to(torch_device) model.eval() - if isinstance(model, ModelMixin): - model._mps_warmup(inputs_dict["sample"].shape[0]) with torch.no_grad(): + # Warmup pass when using mps (see #372) + if torch_device == "mps" and isinstance(model, ModelMixin): + model(**self.dummy_input) + outputs_dict = model(**inputs_dict) outputs_tuple = model(**inputs_dict, return_dict=False) diff --git a/tests/test_models_vae.py b/tests/test_models_vae.py index c5dcdc510158..c772dc7f632d 100644 --- a/tests/test_models_vae.py +++ b/tests/test_models_vae.py @@ -80,8 +80,13 @@ def test_output_pretrained(self): model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy") model = model.to(torch_device) model.eval() - if isinstance(model, ModelMixin): - model._mps_warmup(1, sample_posterior=True) + + # One-time warmup pass (see #372) + if torch_device == "mps" and isinstance(model, ModelMixin): + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + image = image.to(torch_device) + with torch.no_grad(): + _ = model(image, sample_posterior=True).sample torch.manual_seed(0) if torch.cuda.is_available(): diff --git a/tests/test_models_vq.py b/tests/test_models_vq.py index 9d03006aee86..95c79b836b47 100644 --- a/tests/test_models_vq.py +++ b/tests/test_models_vq.py @@ -78,8 +78,6 @@ def test_from_pretrained_hub(self): def test_output_pretrained(self): model = VQModel.from_pretrained("fusing/vqgan-dummy") model.to(torch_device).eval() - if isinstance(model, ModelMixin): - model._mps_warmup(1) torch.manual_seed(0) if torch.cuda.is_available(): @@ -88,6 +86,9 @@ def test_output_pretrained(self): image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) image = image.to(torch_device) with torch.no_grad(): + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = model(image) output = model(image).sample output_slice = output[0, -1, -3:, -3:].flatten().cpu() diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index e98f4c965a33..61775e741afa 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -194,7 +194,10 @@ def test_ddim(self): ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) ddpm.to(torch_device) ddpm.set_progress_bar_config(disable=None) - ddpm.unet._mps_warmup() + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = ddpm(num_inference_steps=1) generator = torch.manual_seed(0) image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images @@ -209,7 +212,7 @@ def test_ddim(self): expected_slice = np.array( [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04] ) - tolerance = 1e-2 if torch_device != "mps" else 2.5e-2 + tolerance = 1e-2 if torch_device != "mps" else 3e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance @@ -248,7 +251,7 @@ def test_ldm_text2img(self): prompt = "A painting of a squirrel eating a burger" - # Skip first when using mps (see #372) + # Warmup pass when using mps (see #372) if torch_device == "mps": generator = torch.manual_seed(0) _ = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy")[ @@ -452,7 +455,7 @@ def test_ldm_uncond(self): ldm.to(torch_device) ldm.set_progress_bar_config(disable=None) - # Skip first when using mps (see #372) + # Warmup pass when using mps (see #372) if torch_device == "mps": generator = torch.manual_seed(0) _ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images From cdd1c416525f6bc5a0adc3ebc83ed58b314a68c1 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 8 Sep 2022 12:16:14 +0200 Subject: [PATCH 35/39] Skip tests using unittest decorator for consistency. --- tests/test_modeling_common.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 98c843f4845b..2a39361e2958 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -15,12 +15,12 @@ import inspect import tempfile +import unittest from typing import Dict, List, Tuple import numpy as np import torch -import pytest from diffusers.modeling_utils import ModelMixin from diffusers.testing_utils import torch_device from diffusers.training_utils import EMAModel @@ -144,11 +144,8 @@ def test_model_from_config(self): self.assertEqual(output_1.shape, output_2.shape) + @unittest.skipIf(torch_device == "mps", "Training is not supported in mps") def test_training(self): - # Warmup pass when using mps (see #372) - if torch_device == "mps": - pytest.skip("mps: unsupported training device") - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) @@ -163,11 +160,8 @@ def test_training(self): loss = torch.nn.functional.mse_loss(output, noise) loss.backward() + @unittest.skipIf(torch_device == "mps", "Training is not supported in mps") def test_ema_training(self): - # Warmup pass when using mps (see #372) - if torch_device == "mps": - pytest.skip("mps: unsupported training device") - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) From 44f485bed1deb4b44cd4df22fc65d69d6dc2bbe7 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 8 Sep 2022 12:30:54 +0200 Subject: [PATCH 36/39] Remove temporary var. --- src/diffusers/models/resnet.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 9ed8adf37192..2bd0cbce77d8 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -448,15 +448,13 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): out = input.view(-1, in_h, 1, in_w, 1, minor) # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535 - using_mps = out.device.type == "mps" - if using_mps: + if input.device.type == "mps": out = out.to("cpu") out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) - if using_mps: - out = out.to("mps") + out = out.to(input.device) # Move back to mps if necessary out = out[ :, max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), From dfd5a6e631ba662a24f2ab7ddbe9aa1cf26186e2 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 8 Sep 2022 12:32:34 +0200 Subject: [PATCH 37/39] Remove spurious blank space. --- tests/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 2a39361e2958..7c098adbd87c 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -42,7 +42,7 @@ def test_from_pretrained_save_pretrained(self): with torch.no_grad(): # Warmup pass when using mps (see #372) if torch_device == "mps" and isinstance(model, ModelMixin): - _ = model(**self.dummy_input) + _ = model(**self.dummy_input) _ = new_model(**self.dummy_input) image = model(**inputs_dict) From b0579c2447eb8fd8dde28cc24e9d30bc72c9f536 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 8 Sep 2022 12:36:29 +0200 Subject: [PATCH 38/39] Remove unused symbol. --- tests/test_models_vq.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_models_vq.py b/tests/test_models_vq.py index 95c79b836b47..69468efbb81a 100644 --- a/tests/test_models_vq.py +++ b/tests/test_models_vq.py @@ -18,7 +18,6 @@ import torch from diffusers import VQModel -from diffusers.modeling_utils import ModelMixin from diffusers.testing_utils import floats_tensor, torch_device from .test_modeling_common import ModelTesterMixin From 2e524575f12918069d64910c457e5c4f22d5883a Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 8 Sep 2022 12:43:39 +0200 Subject: [PATCH 39/39] Remove README_mps. Point to the documentation instead. --- README.md | 2 +- README_mps.md | 46 ---------------------------------------------- 2 files changed, 1 insertion(+), 47 deletions(-) delete mode 100644 README_mps.md diff --git a/README.md b/README.md index 459f99c88162..47b30f9e9272 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ conda install -c conda-forge diffusers **Apple Silicon (M1/M2) support** -See [README_mps.md](README_mps.md). +Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps). ## Contributing diff --git a/README_mps.md b/README_mps.md deleted file mode 100644 index cb1df05de5d9..000000000000 --- a/README_mps.md +++ /dev/null @@ -1,46 +0,0 @@ -## How to use Stable Diffusion in Apple Silicon (M1/M2) - -🤗 Diffusers is compatible with Apple silicon for Stable Diffusion inference, using the PyTorch `mps` device. These are the steps you need to follow to use your M1 or M2 computer with Stable Diffusion. - -### Requirements - -- Mac computer with Apple silicon (M1/M2) hardware. -- macOS 12.3 or later. -- arm64 version of Python. -- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.13.0.dev20220830` or later. - -### Inference Pipeline - -The snippet shown below demonstrates how to use the `mps` backend using the familiar `to()` interface to move the Stable Diffusion pipeline to your M1 or M2 device. - -We recommend to "prime" the pipeline using an additional one-time pass through it. This is a temporary workaround for a weird issue we have detected: the first inference pass produces slightly different results than subsequent ones. You only need to do this pass once, and it's ok to use just one inference step and discard the result. - -```python -# make sure you're logged in with `huggingface-cli login` -from diffusers import StableDiffusionPipeline - -pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True) -pipe = pipe.to("mps") - -prompt = "a photo of an astronaut riding a horse on mars" - -# First-time "warmup" pass (see explanation above) -_ = pipe(prompt, num_inference_steps=1) - -# Results match those from the CPU device after the warmup pass. -image = pipe(prompt).images[0] -``` - -### Known Issues - -- As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372). -- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this might be related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039#issuecomment-1237735249), but we need to investigate in more depth. For now, we recommend to iterate instead of batching. - -### Performance - -These are the results we got on a M1 Max MacBook Pro with 64 GB of RAM, running macOS Ventura Version 13.0 Beta (22A5331f). We performed Stable Diffusion text-to-image generation of the same prompt for 50 inference steps, using a guidance scale of 7.5. - -| Device | Steps | Time | -|--------|-------|---------| -| CPU | 50 | 213.46s | -| MPS | 50 | 30.81s |