-
Notifications
You must be signed in to change notification settings - Fork 1.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add MPS support #472
Add MPS support #472
Conversation
Thanks a lot for this, this is awesome. I was thinking about simplifying the device completely, by using using Could then just check the device of the model, and avoiding specifying the decide completely. What do you think about that? |
That's a great idea, I was not aware of that function. |
|
Great! I'm thinking of merging this and then making a second pass later to modify the documentation. |
Fullgrad should be fixed, i overlooked it there. |
Thanks a lot for this contribution @soberhofer Merging this, and will modify the readme afterwards. Happy holidays! |
Thanks for merging :) And thanks a lot for maintaining this awesome project 👍 |
Instead of using the boolean parameter
use_cuda
, we take a parameter of typetorch.device
. This can becuda
but it can also betorch.device("mps")
for example. This adds support for GPU Acceleration on Apple Silicon.I ran all the tests switching the device to
mps
, and they all passed.For now i have only edited the code. Documentation and tutorials still have the current content.
This addresses #471