
Device support improvements (MPS) #1054

Merged: 3 commits into kohya-ss:dev_device_support, Jan 31, 2024
Conversation

@akx (Contributor) commented Jan 16, 2024

This PR follows what #666 started to make things work better for MPS (Apple Silicon).

It turns out that accelerate already prefers MPS by default, but the third commit in the PR makes the accelerate device more obvious to the user.

Beyond that, this PR:

  • refactors the duplicated device-memory cleanup code (cuda.empty_cache(), gc.collect()) into a single function and teaches it about MPS (see the sketch after this list)
  • refactors the duplicated IPEX (XPU) initialization code into a single function (moved to #1060, "Deduplicate ipex initialization code")
  • makes it possible to control the non-Accelerate device selection via environment variables (documented in the README)
    • the automatic determination here prefers cuda, then mps, then cpu
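
For illustration, here is a minimal sketch of what such a cleanup helper could look like (the function name and exact shape are assumptions for this write-up, not necessarily what the PR adds):

```python
import gc

import torch


def clean_memory_on_device(device: torch.device) -> None:
    """Free cached allocator memory on the given device (hypothetical helper)."""
    # Drop unreachable Python objects first so the backend caches can actually shrink.
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "mps":
        # torch.mps.empty_cache() is available on PyTorch 2.0+ builds with MPS support.
        torch.mps.empty_cache()
```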

I tested that training basically works on my machine, but I don't have an XPU or suitable CUDA machine to test the other changes on.

@kohya-ss (Owner) commented

Thank you for this! This is a very nice PR, but the changes are large enough that verification will take some time. Please understand.

@kohya-ss (Owner) commented

Thank you for the great work! I have started my review and have a few concerns.

First, I do not like the use of environment variables. It also makes breaking changes to the existing environment. Is it possible to get the target device from accelerate?

Also, I do not want to change the IPEX part as it is maintained by another author and not by me. I would like to keep it separate from the rest of the scripts as much as possible.

Is it possible to take these into account?

@akx (Contributor, Author) commented Jan 18, 2024

> First, I do not like the use of environment variables. It also makes breaking changes to the existing environment.

What breaking changes are you referring to? 🤔

If no SD_DEVICE or SD_INFERENCE_DEVICE environment variable is set, the logic is the same as before (with the addition of "mps" being preferred over "cpu").

> Is it possible to get the target device from accelerate?

Sure, that could be used too, but that would certainly be a change from the previous behavior.

> Also, I do not want to change the IPEX part as it is maintained by another author and not by me. I would like to keep it separate from the rest of the scripts as much as possible.

I can move the IPEX utility into a separate file, sure. The change itself (that files are less "littered" with the IPEX initialization code with its try-excepts and all) is for the better, though, I think?

EDIT: I moved the IPEX stuff to #1060 instead to keep this simple.

@akx force-pushed the mps branch 2 times, most recently from bf120b4 to 34d5260, on January 18, 2024 16:21
@akx changed the title from "Device support improvements (MPS, XPU)" to "Device support improvements (MPS)" on Jan 18, 2024
@kohya-ss (Owner) commented

>> First, I do not like the use of environment variables. It also makes breaking changes to the existing environment.
>
> What breaking changes are you referring to? 🤔
>
> If no SD_DEVICE or SD_INFERENCE_DEVICE environment variable is set, the logic is the same as before (with the addition of "mps" being preferred over "cpu").

Oh, sorry, I didn't notice that. That's nice! However, I prefer not to use the environment variables... I would like to have all settings in one place.

>> Is it possible to get the target device from accelerate?
>
> Sure, that could be used too, but that would certainly be a change from the previous behavior.

In my understanding, if we use the device from accelerate when needed, the behavior will be the same as now, because training currently runs on CUDA or MPS anyway.

> EDIT: I moved the IPEX stuff to #1060 instead to keep this simple.

Thank you! However, I just don't want to touch any part of IPEX to keep the responsibility separate😅

@akx (Contributor, Author) commented Jan 19, 2024

> Oh, sorry, I didn't notice that. That's nice! However, I prefer not to use the environment variables... I would like to have all settings in one place.
>
> In my understanding, if we use the device from accelerate when needed, the behavior will be the same as now, because training currently runs on CUDA or MPS anyway.

I changed this to prefer the Accelerate device. It will be close enough, unless someone is actually using a distributed device type for Accelerate and still wants to do inference with a local device (CUDA, MPS, CPU). Those users are probably few and far between, and I'm sure they'll be technical enough to find a workaround 😄

As for environment variables, accelerate actually also reads some on its own (e.g. ACCELERATE_USE_CPU and so on), but I understand wanting to keep environment variables out of sd-scripts itself 😄

>> EDIT: I moved the IPEX stuff to #1060 instead to keep this simple.
>
> Thank you! However, I just don't want to touch any part of IPEX to keep the responsibility separate 😅

Aye, I pinged the IPEX maintainer in that PR. It's now a very simple refactoring of the repeated initialization code to a separate function, so it's functionally the same but keeps the IPEX stuff even better encapsulated out of the main training scripts.

Oh, and by the way, thank you for your work on sd-scripts!

@kohya-ss (Owner) commented

Thank you for updating!

It looks like accelerate gets initialized just to obtain the device. I wonder whether that could be a problem.

In addition, I am very sorry I did not notice this in my previous comment, but it does not seem to work correctly when sd-scripts is used without setting up accelerate (e.g. when using the LoRA-related utilities or only the model conversion scripts).

So I think a redundant but simple and effective approach is to get the device from the already-initialized accelerate in the training scripts, and to pass the device as a command line argument in the other scripts.

@akx (Contributor, Author) commented Jan 19, 2024

> In addition, I am very sorry I did not notice this in my previous comment, but it does not seem to work correctly when sd-scripts is used without setting up accelerate (e.g. when using the LoRA-related utilities or only the model conversion scripts).

No worries – can you give me an example command line (and the resulting error/traceback) that doesn't work? I tried with a python sdxl_train_network.py ... call, and I don't have a .cache/huggingface/accelerate/default_config.yaml file, and it seems to start fine...

The code here currently does try to fall back to CUDA/MPS/CPU if there's any issue talking to Accelerate.
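
As a rough sketch of that fallback idea (assuming torch and accelerate are importable; the function name is illustrative, and this is not the actual diff):

```python
import torch


def get_device_via_accelerate() -> torch.device:
    """Prefer the device reported by accelerate, falling back to a local choice."""
    try:
        from accelerate import Accelerator

        return Accelerator().device
    except Exception:
        # If accelerate is missing or misconfigured, pick CUDA, then MPS, then CPU.
        name = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        return torch.device(name)
```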

@kohya-ss (Owner) commented

> No worries – can you give me an example command line (and the resulting error/traceback) that doesn't work? I tried with a python sdxl_train_network.py ... call, and I don't have a .cache/huggingface/accelerate/default_config.yaml file, and it seems to start fine...

Oh, sorry. Even if we remove default_config.yaml, accelerate.Accelerator().device returns some device ("cuda" in my env). But I don't think we can trust that the device is appropriate.

So I think the simple way may be better.

@akx (Contributor, Author) commented Jan 19, 2024

@kohya-ss Understood – done! get_preferred_device() now simply does CUDA, MPS, CPU in that order of availability.
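
Based on that description, the function presumably reduces to something like the following (a sketch of the stated behavior, not a verbatim copy of the PR code):

```python
import torch


def get_preferred_device() -> torch.device:
    # Order of preference: CUDA, then MPS (Apple Silicon), then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```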

@kohya-ss (Owner) commented

Thank you for updating! I will review and merge this soon!

@akx (Contributor, Author) commented Jan 23, 2024

I'll rebase this now that #1060 was merged.

@akx force-pushed the mps branch 2 times, most recently from 32dca12 to 04cef21, on January 23, 2024 12:26
@akx marked this pull request as draft on January 23, 2024 12:27
@akx marked this pull request as ready for review on January 23, 2024 12:31
@akx (Contributor, Author) commented Jan 23, 2024

@kohya-ss Rebased, ready for review again. By the way, would you prefer these PRs to target dev instead of main?

@kohya-ss (Owner) commented

Thank you! I prefer dev because I can test it together with other updates at the same time. But I can edit the target branch if needed.

@akx changed the base branch from main to dev on January 25, 2024 10:10
@kohya-ss changed the base branch from dev to dev_device_support on January 31, 2024 12:29
@kohya-ss merged commit 2ca4d0c into kohya-ss:dev_device_support on Jan 31, 2024
1 check passed
@kohya-ss (Owner) commented

Thank you again for the PR. I have merged it into my new branch and will add some small modifications. I will merge that branch into dev (and main) soon. I appreciate your understanding.

@akx (Contributor, Author) commented Jan 31, 2024

@kohya-ss Sure thing. Thank you for all of your work! 😄

@akx deleted the mps branch on January 31, 2024 12:32
@Disty0 (Contributor) commented Jan 31, 2024

Importing device_utils before running ipex_init() breaks ipex support.

torch.cuda.is_available() returns False without ipex_init, so sd-scripts falls back to cpu.
I will merge ipex_interop into device_utils, add the xpu device to it, and then open a PR.
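
A minimal sketch of the kind of check a merged device_utils would need (an assumed shape, not Disty0's actual code; note that torch.xpu only exists once intel_extension_for_pytorch has been imported, e.g. by ipex_init()):

```python
import torch


def xpu_is_available() -> bool:
    # Without ipex_init(), torch has no xpu attribute and torch.cuda.is_available()
    # reports False on Intel GPUs, so an explicit XPU check has to come first.
    return hasattr(torch, "xpu") and torch.xpu.is_available()
```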
