[show and tell] apple mps support #6

Closed
bghira opened this issue Apr 10, 2024 · 9 comments

bghira commented Apr 10, 2024

with newer pytorch (2.4 nightly) we get bfloat16 support in MPS.

i tested this:

from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
import soundfile as sf
import torch

device = "mps:0"

# load the model in bfloat16 directly onto the MPS device
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device=device, dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

prompt = "welcome to huggingface"
description = "An old man."

input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device=device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device=device)

generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
# cast back to float32 before writing: soundfile can't handle bfloat16 arrays
audio_arr = generation.to(torch.float32).cpu().numpy().squeeze()
sf.write("parler_tts_out.wav", audio_arr, model.config.sampling_rate)
sanchit-gandhi changed the title from "apple mps support" to "[show and tell] apple mps support" on Apr 11, 2024
@sanchit-gandhi
Collaborator

That's awesome, thanks for sharing @bghira! How fast was inference on your local machine?


bghira commented Apr 11, 2024

it gets slower as the sample size increases but this test script takes about 10 seconds to run on an M3 Max.

@maxtheman

I got this working as well! Inference time seems to increase more than linearly with prompt size:

  • 3 seconds of audio: ~10 seconds of generation
  • 8 seconds of audio: ~90 seconds of generation
  • 10 seconds of audio: ~3 minutes of generation

I think the reason is that inference itself takes a surprising amount of memory: loading the model takes the expected ~3GB, but inference then takes ~15GB on top of that, which is probably what's slowing it down on my machine (16GB M2).
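
For anyone who wants to check the memory numbers, a rough sketch using torch.mps's allocator counters, assuming model, input_ids, and prompt_input_ids from the script above are already on the MPS device:

import torch

def mps_memory_gb():
    # current_allocated_memory: bytes held by live tensors;
    # driver_allocated_memory: total bytes the Metal driver has allocated,
    # including the allocator's cache.
    return (torch.mps.current_allocated_memory() / 1e9,
            torch.mps.driver_allocated_memory() / 1e9)

alloc_before, driver_before = mps_memory_gb()
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
alloc_after, driver_after = mps_memory_gb()
print(f"allocated: {alloc_before:.2f} GB -> {alloc_after:.2f} GB")
print(f"driver:    {driver_before:.2f} GB -> {driver_after:.2f} GB")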

@QueryType

> I got this working as well! Inference time seems to increase more than linearly with prompt size:
>
>   • 3 seconds of audio: ~10 seconds of generation
>   • 8 seconds of audio: ~90 seconds of generation
>   • 10 seconds of audio: ~3 minutes of generation
>
> I think the reason is that inference itself takes a surprising amount of memory: loading the model takes the expected ~3GB, but inference then takes ~15GB on top of that, which is probably what's slowing it down on my machine (16GB M2).

Is swapping being activated? I will try on a Mac Mini M2 (24GB). Do we know the performance of CUDA on a similar machine?


bghira commented Apr 12, 2024

on the 128gb M3 Max i can get pretty far into the output window before the time increases to 3 minutes.

it'll take about a minute for 30 seconds of audio.

@QueryType

I am getting 2s of audio in 11 seconds, and 6s of audio in 36 seconds.

@janewu77

My data, on a 64GB M2 Max:

| seconds of audio | CPU (seconds of generation) | MPS (seconds of generation) |
|---|---|---|
| 1 | 7 | 10 |
| 3 | 13 | 17 |
| 7 | 30 | 44 |
| 9 | 41 | 194 |
| 18 | 71 | 308 |
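
A minimal way to collect numbers like these (a hypothetical harness, reusing model, input_ids, and prompt_input_ids from the original script; the synchronize call matters because MPS kernels run asynchronously):

import time
import torch

start = time.perf_counter()
generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
torch.mps.synchronize()  # wait for async MPS kernels before stopping the clock
elapsed = time.perf_counter() - start

audio_seconds = generation.shape[-1] / model.config.sampling_rate
print(f"{audio_seconds:.1f}s of audio in {elapsed:.1f}s of generation")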

@andimarafioti
Member

I'm getting this error

NotImplementedError: Output channels > 65536 not supported at the MPS device. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

Did something change or is it still working for you?

In [2]: torch.__version__
Out[2]: '2.5.0.dev20240726'
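
For reference, the fallback that the error message suggests only takes effect if the environment variable is set before torch is first imported, e.g.:

import os

# Route ops that MPS doesn't implement to the CPU (slower, but it runs).
# Must be set before the first `import torch`, otherwise it has no effect.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402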


bghira commented Jul 26, 2024

my suggestion is to stick with pytorch 2.4 unless you want things blowing up constantly.

bghira closed this as completed on Jul 26, 2024