
MPS fails with float16 and PyTorch 2.0 #2521

@mannby

Description


Describe the bug

I don't know whether this can be fixed in diffusers or whether it's a PyTorch or MPS issue, but if you take the basic diffusers example from https://huggingface.co/docs/diffusers/optimization/mps and change the dtype to float16, like this:

# make sure you're logged in with `huggingface-cli login`
from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("mps")

# Recommended if your computer has < 64 GB of RAM
pipe.enable_attention_slicing()

prompt = "a photo of an astronaut riding a horse on mars"

# First-time "warmup" pass (see the explanation in the diffusers MPS guide linked above)
_ = pipe(prompt, num_inference_steps=1)

# Results match those from the CPU device after the warmup pass.
image = pipe(prompt).images[0]

you get what looks like an f16/f32 mismatch error:

loc("varianceEps"("(mpsFileLoc): /AppleInternal/Library/BuildRoots/9e200cfa-7d96-11ed-886f-a23c4f261b56/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm":228:0)):
error: input types 'tensor<1x77x1xf16>' and 'tensor<1xf32>' are not broadcast compatible
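The message says "not broadcast compatible", but under ordinary NumPy-style broadcasting the shapes 1x77x1 and 1 are compatible; only the element types (f16 vs f32) differ, which supports reading this as a dtype mismatch. The sketch below checks that. It is illustrative only: it assumes NumPy broadcasting rules (MPSGraph's own check is not shown here), and the small parser handles only static shapes like the two in this log, not dynamic (`?`) dimensions.

```python
import re
from itertools import zip_longest

# Matches simple MLIR-style ranked tensor types such as "tensor<1x77x1xf16>".
_TENSOR_RE = re.compile(r"tensor<((?:\d+x)*)([a-z]\w*)>")


def parse_tensor_type(text):
    """Return (shape, element_type) for a string like 'tensor<1x77x1xf16>'."""
    m = _TENSOR_RE.fullmatch(text)
    if m is None:
        raise ValueError(f"not a static ranked tensor type: {text!r}")
    dims, elem = m.groups()
    shape = tuple(int(d) for d in dims.split("x") if d)
    return shape, elem


def broadcast_shape(a, b):
    """NumPy-style broadcasting: result shape, or None if incompatible."""
    out = []
    # Align dimensions from the right; missing leading dims count as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            return None
    return tuple(reversed(out))


def diagnose(lhs, rhs):
    """Separate the shape question from the dtype question."""
    shape_a, elem_a = parse_tensor_type(lhs)
    shape_b, elem_b = parse_tensor_type(rhs)
    result = broadcast_shape(shape_a, shape_b)
    return {
        "shapes_broadcast": result is not None,
        "result_shape": result,
        "dtypes_match": elem_a == elem_b,
    }


report = diagnose("tensor<1x77x1xf16>", "tensor<1xf32>")
print(report)
# {'shapes_broadcast': True, 'result_shape': (1, 77, 1), 'dtypes_match': False}
```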

Reproduction

Run the script in the description with the latest diffusers (0.13.1), huggingface-hub 0.12.1, transformers 4.26.1, and PyTorch 2.0.0.dev20230228 (py3.10_0) installed with conda from the pytorch-nightly channel.
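A smaller repro without diffusers could help tell a PyTorch/MPS bug apart from a diffusers one. The sketch below is hypothetical: based only on the `varianceEps` name and the 1x77x1 shape in the log, it guesses that the failing op is a layer norm over the text encoder's hidden states (77 tokens, 768 wide for the CLIP text encoder used by SD v1.5). `probe_layer_norm` is an illustrative name, not part of any library. If MPSGraph rejects the graph it may abort the interpreter instead of raising, so run it in a fresh process.

```python
def probe_layer_norm(dtype_name="float16", seq_len=77, hidden_size=768):
    """Run a single layer norm on MPS in the given dtype and report the outcome.

    Hypothetical repro: assumes the failing op is layer normalization over
    hidden states shaped (1, seq_len, hidden_size), guessed from the log.
    """
    try:
        import torch
        import torch.nn.functional as F
    except ImportError:
        return "torch not installed"

    # torch.backends.mps only exists in PyTorch >= 1.12.
    mps = getattr(torch.backends, "mps", None)
    if mps is None or not mps.is_available():
        return "mps not available"

    dtype = getattr(torch, dtype_name)  # e.g. torch.float16 or torch.float32
    x = torch.randn(1, seq_len, hidden_size).to(device="mps", dtype=dtype)
    try:
        # Normalizing over the last dim computes a (1, seq_len, 1) variance
        # and adds eps to it; if the guess is right, that is where f16 meets f32.
        y = F.layer_norm(x, normalized_shape=(hidden_size,))
        y.cpu()  # copy back to force the MPS graph to execute
    except RuntimeError as exc:
        return f"error: {exc}"
    return "ok"


if __name__ == "__main__":
    for name in ("float32", "float16"):  # float32 serves as a control
        print(name, "->", probe_layer_norm(name))
```

If float32 succeeds and float16 fails here, the problem would appear to sit in PyTorch's MPS backend rather than in diffusers.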


System Info

Intel Mac (x86_64) with an AMD Radeon Pro 5700 XT (16 GB) GPU and a 3.6 GHz 10-core Intel Core i9 CPU

diffusers 0.13.1, huggingface-hub 0.12.1, transformers 4.26.1, and PyTorch 2.0.0.dev20230228 (py3.10_0) installed with conda from the pytorch-nightly channel. Full `conda list` output:

# Name                    Version                   Build  Channel
absl-py                   1.3.0           py310hecd8cb5_0  
accelerate                0.16.0                   pypi_0    pypi
aiohttp                   3.8.3           py310h6c40b1e_0  
aiosignal                 1.2.0              pyhd3eb1b0_0  
async-timeout             4.0.2           py310hecd8cb5_0  
attrs                     22.1.0          py310hecd8cb5_0  
autopep8                  1.6.0              pyhd3eb1b0_1  
blas                      1.0                         mkl  
blinker                   1.4             py310hecd8cb5_0  
brotli                    1.0.9                hca72f7f_7  
brotli-bin                1.0.9                hca72f7f_7  
brotlipy                  0.7.0           py310hca72f7f_1002  
bzip2                     1.0.8                h1de35cc_0  
c-ares                    1.18.1               hca72f7f_0  
ca-certificates           2023.01.10           hecd8cb5_0  
cachetools                4.2.2              pyhd3eb1b0_0  
certifi                   2022.12.7       py310hecd8cb5_0  
cffi                      1.15.1          py310h6c40b1e_3  
charset-normalizer        2.0.4              pyhd3eb1b0_0  
click                     8.0.4           py310hecd8cb5_0  
contourpy                 1.0.5           py310haf03e11_0  
cryptography              37.0.1          py310hf6deb26_0  
cycler                    0.11.0             pyhd3eb1b0_0  
diffusers                 0.13.1                   pypi_0    pypi
facenet-pytorch           2.5.2                    pypi_0    pypi
ffmpeg                    4.3                  h0a44026_0    pytorch
filelock                  3.9.0           py310hecd8cb5_0  
flit-core                 3.6.0              pyhd3eb1b0_0  
fonttools                 4.25.0             pyhd3eb1b0_0  
freetype                  2.12.1               hd8bbffd_0  
frozenlist                1.3.3           py310h6c40b1e_0  
gettext                   0.21.0               h7535e17_0  
giflib                    5.2.1                h6c40b1e_3  
gmp                       6.2.1                he9d5cce_3  
gmpy2                     2.1.2           py310hd5de756_0  
gnutls                    3.6.15               hed9c0bf_0  
google-auth               2.6.0              pyhd3eb1b0_0  
google-auth-oauthlib      0.4.4              pyhd3eb1b0_0  
grpcio                    1.42.0          py310ha29bfda_0  
huggingface-hub           0.12.1                   pypi_0    pypi
icu                       58.2                 h0a44026_3  
idna                      3.4             py310hecd8cb5_0  
importlib-metadata        6.0.0                    pypi_0    pypi
intel-openmp              2021.4.0          hecd8cb5_3538  
jpeg                      9e                   hca72f7f_0  
kiwisolver                1.4.4           py310hcec6c5f_0  
lame                      3.100                h1de35cc_0  
lcms2                     2.12                 hf1fd2bf_0  
lerc                      3.0                  he9d5cce_0  
libbrotlicommon           1.0.9                hca72f7f_7  
libbrotlidec              1.0.9                hca72f7f_7  
libbrotlienc              1.0.9                hca72f7f_7  
libcxx                    14.0.6               h9765a3e_0  
libdeflate                1.17                 hb664fd8_0  
libffi                    3.4.2                hecd8cb5_6  
libiconv                  1.16                 hca72f7f_2  
libidn2                   2.3.2                h9ed2024_0  
libpng                    1.6.37               ha441bb4_0  
libprotobuf               3.20.3               hfff2838_0  
libtasn1                  4.16.0               h9ed2024_0  
libtiff                   4.5.0                hcec6c5f_2  
libunistring              0.9.10               h9ed2024_0  
libwebp                   1.2.4                hf6ce154_1  
libwebp-base              1.2.4                h6c40b1e_1  
libxml2                   2.9.14               hbf8cd5e_0  
llvm-openmp               14.0.6               h0dcd299_0  
lz4-c                     1.9.4                hcec6c5f_0  
markdown                  3.4.1           py310hecd8cb5_0  
markupsafe                2.1.1           py310hca72f7f_0  
matplotlib                3.6.2           py310hecd8cb5_0  
matplotlib-base           3.6.2           py310h220de94_0  
mkl                       2021.4.0           hecd8cb5_637  
mkl-service               2.4.0           py310hca72f7f_0  
mkl_fft                   1.3.1           py310hf879493_0  
mkl_random                1.2.2           py310hc081a56_0  
mpc                       1.1.0                h6ef4df4_1  
mpfr                      4.0.2                h9066e36_1  
mpmath                    1.2.1                    pypi_0    pypi
multidict                 6.0.2           py310hca72f7f_0  
munkres                   1.1.4                      py_0  
ncurses                   6.4                  hcec6c5f_0  
nettle                    3.7.3                h230ac6f_1  
networkx                  3.0                      pypi_0    pypi
numpy                     1.23.5          py310h9638375_0  
numpy-base                1.23.5          py310ha98c3c9_0  
oauthlib                  3.2.1           py310hecd8cb5_0  
opencv-python-headless    4.7.0.72                 pypi_0    pypi
openh264                  2.1.1                h8346a28_0  
openssl                   1.1.1t               hca72f7f_0  
packaging                 22.0            py310hecd8cb5_0  
pillow                    9.4.0           py310hcec6c5f_0  
pip                       22.3.1          py310hecd8cb5_0  
protobuf                  3.20.3          py310hcec6c5f_0  
psutil                    5.9.4                    pypi_0    pypi
pyasn1                    0.4.8              pyhd3eb1b0_0  
pyasn1-modules            0.2.8                      py_0  
pycodestyle               2.10.0          py310hecd8cb5_0  
pycparser                 2.21               pyhd3eb1b0_0  
pyjwt                     2.4.0           py310hecd8cb5_0  
pyopenssl                 22.0.0             pyhd3eb1b0_0  
pyparsing                 3.0.9           py310hecd8cb5_0  
pysocks                   1.7.1           py310hecd8cb5_0  
python                    3.10.9               h218abb5_0  
python-dateutil           2.8.2              pyhd3eb1b0_0  
pytorch                   2.0.0.dev20230228        py3.10_0    pytorch-nightly
pyyaml                    6.0             py310h6c40b1e_1  
readline                  8.2                  hca72f7f_0  
regex                     2022.7.9        py310hca72f7f_0  
requests                  2.28.1          py310hecd8cb5_0  
requests-oauthlib         1.3.0                      py_0  
rsa                       4.7.2              pyhd3eb1b0_1  
setuptools                65.6.3          py310hecd8cb5_0  
six                       1.16.0             pyhd3eb1b0_1  
sqlite                    3.40.1               h880c91c_0  
sympy                     1.11.1          py310hecd8cb5_0  
tensorboard               2.10.0          py310hecd8cb5_0  
tensorboard-data-server   0.6.1           py310h7242b5c_0  
tensorboard-plugin-wit    1.6.0                      py_0  
tk                        8.6.12               h5d9f67b_0  
tokenizers                0.11.4          py310h8776b5c_1  
toml                      0.10.2             pyhd3eb1b0_0  
torchaudio                2.0.0.dev20230228       py310_cpu    pytorch-nightly
torchvision               0.15.0.dev20230228       py310_cpu    pytorch-nightly
tornado                   6.2             py310hca72f7f_0  
tqdm                      4.64.1          py310hecd8cb5_0  
transformers              4.26.1                   pypi_0    pypi
typing-extensions         4.4.0           py310hecd8cb5_0  
typing_extensions         4.4.0           py310hecd8cb5_0  
tzdata                    2022g                h04d1e81_0  
urllib3                   1.26.14         py310hecd8cb5_0  
werkzeug                  2.2.2           py310hecd8cb5_0  
wheel                     0.38.4          py310hecd8cb5_0  
xz                        5.2.10               h6c40b1e_1  
yaml                      0.2.5                haf1e3a3_0  
yarl                      1.8.1           py310hca72f7f_0  
zipp                      3.15.0                   pypi_0    pypi
zlib                      1.2.13               h4dc903c_0  
zstd                      1.5.2                hcb37349_0 
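The key versions from a listing like the one above can also be gathered with a short script, which makes it easier to compare environments between reporters. This is a minimal sketch using only the standard library's `importlib.metadata`; the package list is an assumption chosen to match the packages mentioned in this report, and it assumes conda's `pytorch` package registers its metadata under the distribution name `torch`. diffusers' own `diffusers-cli env` command serves a similar purpose.

```python
import platform
from importlib import metadata

# Packages relevant to this report (an assumption; extend as needed).
PACKAGES = ("diffusers", "transformers", "huggingface-hub", "accelerate", "torch")


def collect_env(packages=PACKAGES):
    """Return a dict of interpreter/platform details and package versions."""
    info = {
        "python": platform.python_version(),
        "platform": platform.platform(),
        "machine": platform.machine(),  # e.g. x86_64 on an Intel Mac
    }
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"{key}: {value}")
```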

Labels: bug (Something isn't working), stale (Issues that haven't received updates)
