
7529 do not disable autocast for cuda devices #7530

Merged: 15 commits merged from bugfix/7529-autocast-failure-cuda into huggingface:main on Apr 2, 2024

Conversation

@bghira (Contributor) commented Mar 30, 2024

What does this PR do?

Fixes #7529

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@bghira (Contributor, Author) commented Mar 30, 2024

@sayakpaul (Member) left a comment


Thanks for taking care of it!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@bghira (Contributor, Author) commented Mar 31, 2024

i'm going to switch from this method to contextlib's nullcontext, since local testing with nightly pytorch indicates that'll avoid some new errors coming down the pipe. specifically, the null context lets us entirely bypass autocast issues; sometimes autocast is disabled incorrectly for a platform, or the platform only has partial support.

the nullcontext lets us decide explicitly when to commit to providing autocast on a new platform.
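
roughly the pattern i have in mind (a sketch; the helper name and conditions are illustrative, not the exact code in this PR):

```python
import contextlib

import torch


def autocast_ctx(device_type: str, dtype: torch.dtype):
    # Illustrative helper, not the code merged in this PR.
    # Only hand out a real autocast context where we've committed to it (CUDA with a
    # half-precision dtype); everywhere else a no-op nullcontext sidesteps platforms
    # where autocast is disabled incorrectly or only partially supported.
    if device_type == "cuda" and dtype in (torch.float16, torch.bfloat16):
        return torch.autocast(device_type=device_type, dtype=dtype)
    return contextlib.nullcontext()


# usage during validation/inference:
# with autocast_ctx("cuda", torch.float16):
#     images = pipeline(prompt, num_inference_steps=25).images
```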

@bghira (Contributor, Author) commented Mar 31, 2024

@sayakpaul ready :-)

@bghira (Contributor, Author) commented Mar 31, 2024

@pcuenca i tried to unify the behaviour of the context selection to use contextlib

@sayakpaul (Member) left a comment


Nice 👌

@simbrams commented Apr 1, 2024

Can you also remove the error test in the instruct pix2pix xl pipeline 🙏 :

```python
else:
    raise ValueError(
        "For the given accelerator, there seems to be an unexpected problem in type-casting. Please file an issue on the PyTorch GitHub repository. See also: https://github.com/huggingface/diffusers/pull/7446/."
    )
```

@bghira (Contributor, Author) commented Apr 1, 2024

@simbrams done

@bghira (Contributor, Author) commented Apr 1, 2024

@sayakpaul @DN6 @yiyixuxu i've updated every training example that uses 🤗 Accelerate to disable AMP, so that accelerator.prepare(...) does not try to load an autocast ctx.
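
concretely, each script gets a change along these lines (a sketch, not the exact diff; in the real scripts the Accelerator is built from the parsed training args):

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # illustrative; the real scripts pass their configured kwargs

# Disable 🤗 Accelerate's native AMP on MPS so that accelerator.prepare() does not
# wrap the forward pass in an autocast context.
if torch.backends.mps.is_available():
    accelerator.native_amp = False
```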

@christopher5106 left a comment


Works for me now. I mainly checked the pipeline and the train script for sdxl lora.

@sayakpaul (Member) commented

@bghira thanks for the changes. A couple of comments:

  • GPUs are first-class citizens for the training scripts. In order to support Silicon compatibility, we cannot break that. I am sorry that I didn't check that earlier.
  • I see we're now making changes to the pipelines so that they use an inference context. I don't think we can have that level of change in our pipelines. The silicon support is experimental from PyTorch itself. So when that gets stabilized a bit, we might have to undo the changes we're currently making. So, I am not in favor of going via this route.
  • IMO, the changes should be kept as minimal as possible. If making changes to the training scripts and the schedulers (like we did before) to support Silicon compatibility isn't enough, that means this feature itself is a little too experimental to support as that requires non-trivial changes to the legacy core.

So, IMO, we should make sure:

  • Training works as before on GPUs.
  • Silicon training works but with minimal changes. If that prevents FP16 training, I would still be okay with that because it seems like it's quite experimental. If disabling the intermediate validation could work -- that could be another route too (albeit not a pragmatic one).

@pcuenca @yiyixuxu would like to get your opinions too.

@yiyixuxu (Collaborator) commented Apr 2, 2024

@sayakpaul

> I see we're now making changes to the pipelines so that they use an inference context.

I didn't see that in this PR, though.

@yiyixuxu (Collaborator) commented Apr 2, 2024

@sayakpaul

> So, IMO, we should make sure:
>
> • Training works as before on GPUs.
> • Silicon training works but with minimal changes. If that prevents FP16 training, I would still be okay with that because it seems like it's quite experimental. If disabling the intermediate validation could work -- that could be another route too (albeit not a pragmatic one).

agree with this

@sayakpaul (Member) commented

@bghira (Contributor, Author) commented Apr 2, 2024

that was already using a context manager, i simply updated it to be consistent

@bghira (Contributor, Author) commented Apr 2, 2024

as an aside, i've finetuned the issue out of 2.1-v weights. so, a fixed model is available ... but it seems like the actual 2.1-v weights should be fixable somehow with a new or fixed extraction. (SDXL works great)

@sayakpaul (Member) commented

Okay good to know. Thanks!

What are our blockers now? I am happy to cook up a context manager in training_utils.py and then we can take it from there?

@bghira (Contributor, Author) commented Apr 2, 2024

also, the instruct pix2pix script (which you wrote :D) is where a lot of the original confusion came from! there, the autocast disables for fp16 and pulls the device type str to chop :0 off, so i learned a few things along the way to fixing this. but now everything is at least consistent; there were some community pipelines i did not update, as it didn't seem appropriate to do so.

@bghira (Contributor, Author) commented Apr 2, 2024

i've been trying to chase down that training issue. i want to try running SD 1.5 and see if the problem for legacy models is limited to 2.x, in which case it'll be best to limit training to SD 1.x. will update shortly.

@sayakpaul (Member) commented

> the autocast disables for fp16 and pulls the device type str to chop :0 off

I see it's enabled when mixed precision is FP16:

```python
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
```

But anyway, as mentioned I am down to writing a context manager to help ease things a bit. Maybe providing everything you have discovered so far regarding training on Silicon could be a very nice resource.
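
For reference, a minimal illustration of why the `:0` suffix gets stripped (illustrative snippet, not the training script's exact code):

```python
import torch

# torch.autocast expects a bare device *type* string ("cuda", "cpu", "mps"), but an
# accelerator device stringifies with an index, e.g. "cuda:0", hence the .replace(":0", "").
device = torch.device("cuda:0")
device_type = str(device).replace(":0", "")   # -> "cuda"
ctx = torch.autocast(device_type, enabled=True)
```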

@bghira (Contributor, Author) commented Apr 2, 2024

on mps,

  • sd 2.1 does not work in diffusers or simpletuner
  • sd 1.5 does not work in diffusers or simpletuner
    • both of these have MPS internal crashes
  • sdxl does not work in diffusers, but does work in simpletuner
    • in diffusers, fp16 training on sdxl causes the attention processor query to be in float32 while the weights are half

i've gone through them a lot today, and i'm really not sure why it would cause the MPS crash; it's incredibly unfortunate.

the SDXL error does seem like it can be overcome, i just need to put a bit more time into hunting it down

so far, this patch works as advertised to at least resolve the cuda-specific regression, and it should probably be merged.

@sayakpaul (Member) commented

That's quite informative indeed.

> sdxl does not work in diffusers, but does work in simpletuner
> in diffusers, fp16 training on sdxl causes the attention processor query to be in float32 while the weights are half

Curious, what do you do on simpletuner to have that fixed?

@sayakpaul (Member) commented

I will wait for @pcuenca to give it a shot too, before merging.

@bghira (Contributor, Author) commented Apr 2, 2024

in simpletuner i have a lot more brute-force general handling of dtypes and leave less responsibility for that on other parts of the stack.

@bghira (Contributor, Author) commented Apr 2, 2024

good morning. after some more checking, fp32 training works fine for MPS on SD 1.5 and 2.1. but this feels disappointing, as dtype issues just shouldn't be the only thing holding back memory-efficient training. but i guess it's all we can do for now?

@bghira (Contributor, Author) commented Apr 2, 2024

pytorch 2.3 supports MPS bf16:

```
Valid Types: [torch.float32, torch.float32, torch.float16, torch.float16, torch.bfloat16, torch.complex64, torch.uint8, torch.int8, torch.int16, torch.int16, torch.int32, torch.int32, torch.int64, torch.int64, torch.bool]
Invalid Types: [torch.float64, torch.float64, torch.complex128, torch.complex128, torch.quint8, torch.qint8, torch.quint4x2]
```

pytorch 2.2.1 does not:

```
Valid Types: [torch.float32, torch.float32, torch.float16, torch.float16, torch.uint8, torch.int8, torch.int16, torch.int16, torch.int32, torch.int32, torch.int64, torch.int64, torch.bool]
Invalid Types: [torch.float64, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.complex128, torch.quint8, torch.qint8, torch.quint4x2]
---
[bfloat16] BFloat16 is not supported on MPS
```
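
a quick runtime probe along these lines confirms it (a sketch; the helper name is mine, and the exception type may differ across torch builds):

```python
import torch


def mps_supports_bf16() -> bool:
    # Illustrative check: allocating a bfloat16 tensor on MPS succeeds on pytorch >= 2.3
    # and raises ("BFloat16 is not supported on MPS") on 2.2.x.
    if not torch.backends.mps.is_available():
        return False
    try:
        torch.zeros(1, dtype=torch.bfloat16, device="mps")
        return True
    except (TypeError, RuntimeError):
        return False
```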

@sayakpaul (Member) commented

@bghira I am going to merge the PR in a while. But tomorrow, I will take a closer look into autocast related things that we're using in our training scripts and see if we can get rid of them. I can do these tests fast in a CUDA environment. Will look into an autocast context manager (to add to training_utils.py too).

But just so I understand the issues for M3 training correctly, w.r.t #7530 (comment), were you able to pinpoint a bug when training with FP16? Does it fail during intermediate inference?

```
@@ -752,6 +752,10 @@ def main(args):
        project_config=accelerator_project_config,
    )

    # Disable AMP for MPS.
```
A Member commented on the diff above:

@bghira possible to add a more descriptive comment here?

@sayakpaul (Member) commented Apr 2, 2024

Also, I see this is not added to some scripts such as the advanced diffusion or consistency distillation scripts.

@bghira (Contributor, Author) commented Apr 2, 2024

the fp16 mps bug happens on the third call to nn.Linear: the Parameter object is created with fp32 dtype for weight and bias, but the query is fp16.
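
a minimal repro of that shape of failure (illustrative dims and layer name, assumes an MPS device; not the actual attention processor code):

```python
import torch
import torch.nn as nn

# nn.Linear creates its weight/bias Parameters in float32, so feeding it a fp16
# "query" on mps reproduces the mismatched-dtype failure described above.
to_q = nn.Linear(320, 320).to("mps")                                 # weight/bias: float32
query = torch.randn(2, 77, 320, device="mps", dtype=torch.float16)   # activations: float16
out = to_q(query)                                                    # raises a dtype-mismatch error
```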

@bghira (Contributor, Author) commented Apr 2, 2024

it's an actual training failure on step 1, when t=999

i saw that nn.Linear has a dtype parameter and tried setting it, but the Parameter class it gets passed into doesn't actually make use of the dtype parameter, which seems like a torch bug

@bghira (Contributor, Author) commented Apr 2, 2024

@sayakpaul (Member) commented

Thanks very much folks for the insightful discussions on autocast and bfloat16. I am under the impression that PyTorch, indeed, has some weird voodoo going under the hood for these.

Meanwhile, I am going to merge this PR to unblock our users. So, thanks to @bghira for the prompt action here.

This thread still remains open for discussions around autocast and bfloat16 which I believe will be valuable for the community.

@bghira (Contributor, Author) commented Apr 2, 2024

without attention slicing, we see:

```
/AppleInternal/Library/BuildRoots/ce725a5f-c761-11ee-a4ec-b6ef2fd8d87b/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:788: failed assertion `[MPSNDArray initWithDevice:descriptor:] Error: total bytes of NDArray > 2**32'
```

but with attention slicing, this error disappears.

the resolution seems to be:

```python
    # Base components to prepare
    if torch.backends.mps.is_available():
        accelerator.native_amp = False
    results = accelerator.prepare(unet, lr_scheduler, optimizer, *train_dataloaders)
    unet = results[0]
    if torch.backends.mps.is_available():
        unet.set_attention_slice()
```

however, the casting type issue is still there for bf16 and fp16, but bf16 at least no longer crashes. still digging into it.

@sayakpaul merged commit 8e963d1 into huggingface:main on Apr 2, 2024 (15 checks passed).
@bghira deleted the bugfix/7529-autocast-failure-cuda branch on April 2, 2024 at 15:12.
@AmericanPresidentJimmyCarter (Contributor) commented

> Thanks very much folks for the insightful discussions on autocast and bfloat16. I am under the impression that PyTorch, indeed, has some weird voodoo going under the hood for these.
>
> Meanwhile, I am going to merge this PR to unblock our users. So, thanks to @bghira for the prompt action here.
>
> This thread still remains open for discussions around autocast and bfloat16 which I believe will be valuable for the community.

The issues seem to be scoped to bfloat16, which is unfortunate because most hardware has moved to it over the past several years.

  • pytorch/pytorch#120930: The assert_close should fail. Interestingly, the assert_close passes if we change the autocast dtype to torch.float16.
  • pytorch/pytorch#71631: A model with PyTorch AMP produces different model results when used in conjunction with CUDA graphs. In other words, it seems that conducting CUDA graph capture while wrapped with autocast leads to wrong outputs during training (specifically, after the first weight update).

The second issue is from Jan 2022, so we may still be waiting for a while for a fix. It is clear that autocast was originally designed with fp16 and the fp16 grad scaler in mind and has issues with other types.

@bghira (Contributor, Author) commented Apr 2, 2024

so disabling autocast for bf16 everywhere does seem like the way to go. honestly, the new optimiser removes any need to use fp16; pure bf16 is lighter weight and simpler code.
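
for illustration, a bare-bones picture of what pure bf16 means here (a toy model standing in for the unet; not code from this PR): cast the weights once, then train with no autocast and no GradScaler.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-in for the unet: weights cast to bfloat16 once, trained without
# autocast or a gradient scaler.
model = nn.Sequential(nn.Linear(64, 64), nn.SiLU(), nn.Linear(64, 64)).to(torch.bfloat16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

x = torch.randn(8, 64, dtype=torch.bfloat16)
target = torch.randn(8, 64, dtype=torch.bfloat16)

loss = F.mse_loss(model(x).float(), target.float())
loss.backward()
optimizer.step()
optimizer.zero_grad()
```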

noskill pushed a commit to noskill/diffusers that referenced this pull request Apr 5, 2024
* 7529 do not disable autocast for cuda devices

* Remove typecasting error check for non-mps platforms, as a correct autocast implementation makes it a non-issue

* add autocast fix to other training examples

* disable native_amp for dreambooth (sdxl)

* disable native_amp for pix2pix (sdxl)

* remove tests from remaining files

* disable native_amp on huggingface accelerator for every training example that uses it

* convert more usages of autocast to nullcontext, make style fixes

* make style fixes

* style.

* Empty-Commit

---------

Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>