Fix random validation errors for good (and restore torch.compile for the validation pipeline at the same time) #1131
Merged
bghira merged 2 commits into bghira:main on Nov 10, 2024
Conversation
Fixes random validation errors with SD3.5 after commit 48cfc09 removed some earlier fixes. This also fixes torch.compile not getting called for the validation pipeline. Calling self.pipeline.to(self.inference_device) appears to have an unwanted side effect: it also moves the additional text encoders to the accelerator device. With SD3.5, I saw text_encoder_2 and text_encoder_3 getting moved to the GPU, which caused my RTX 3090 to go OOM when generating validation images during training. Explicitly setting text_encoder_2 and text_encoder_3 to None in extra_pipeline_kwargs fixes this issue.
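A minimal sketch of the idea, assuming a diffusers-style pipeline and precomputed prompt embeddings; the model id, dtype, and variable names are illustrative, not the actual SimpleTuner code:

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Passing None for a pipeline component tells from_pretrained not to load it,
# so a later .to(device) has nothing extra to move onto the accelerator.
extra_pipeline_kwargs = {
    "text_encoder_2": None,
    "text_encoder_3": None,
}

pipeline = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",  # illustrative model id
    torch_dtype=torch.bfloat16,
    **extra_pipeline_kwargs,
)

# Only the components that were actually loaded end up on the GPU.
pipeline.to("cuda")
```

With the extra encoders skipped, prompts presumably have to arrive as precomputed prompt embeddings at call time, which is what a trainer with cached text embeds does anyway.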
Owner
the reason this is done is that some of the diffusers pipelines require the text encoder to be present simply so that its dtype can be checked 😮💨 i could copy those pipelines in and fix them, but they'd need a full compatibility check first. i don't want to move the whole pipeline to the GPU when doing validations because the necessary components are already there - the text encoders then remained on CPU. but it'd be better for them to not be loaded at all.
bghira reviewed on Nov 9, 2024
I saw the random validation errors were back after commit 48cfc09 removed some of my earlier fixes. This time I decided to dig down to the root of why the pipeline randomly ends up on the CPU.

The root cause turned out to be an early return in setup_pipeline, which means the self.pipeline.to(self.inference_device) call at the end of setup_pipeline is never executed, and the torch.compile call just before it is never reached either.
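Roughly, the fix is to make sure the device move (and the optional torch.compile) at the end of setup_pipeline runs on every path instead of being skipped by the early return. A hedged sketch of the intended control flow; _build_pipeline, _compiled, self.args.validation_torch_compile, and compiling only pipeline.transformer are illustrative assumptions, not the actual implementation:

```python
import torch

def setup_pipeline(self):
    if self.pipeline is None:
        # Stand-in for the real pipeline construction, which forwards
        # extra_pipeline_kwargs (including the None-ed out text encoders).
        self.pipeline = self._build_pipeline()

    # These steps used to sit below an early return, so they never ran when the
    # pipeline had already been created, leaving it (randomly) on the CPU.
    if self.args.validation_torch_compile and not getattr(self, "_compiled", False):
        # Compile only the denoiser module; the pipeline object itself is not
        # an nn.Module, so it is not what torch.compile operates on.
        self.pipeline.transformer = torch.compile(self.pipeline.transformer)
        self._compiled = True
    self.pipeline.to(self.inference_device)
```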
I did notice that moving the entire pipeline to the accelerator device has some unwanted side effects: when I was testing with SD3.5 it also moved text_encoder_2 and text_encoder_3 to the GPU. Setting text_encoder_2 and text_encoder_3 to None in extra_pipeline_kwargs prevents them from being loaded in the first place. I'm hoping other pipelines won't complain about text_encoder_2 and text_encoder_3 in kwargs; I tried throwing text_encoder_4 in there and the SD3 pipeline didn't complain.

I tried testing validation_torch_compile, but it seemed to be taking forever, so I eventually just killed it.
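As a quick way to sanity-check which components actually end up on the accelerator after setup_pipeline, something like the following debugging helper (not part of the PR, just a sketch built on diffusers' public components mapping) can be dropped into the validation path:

```python
import torch

def report_component_devices(pipeline):
    """Print the device of every loaded torch module in a diffusers pipeline."""
    for name, component in pipeline.components.items():
        if component is None:
            print(f"{name}: not loaded")
        elif isinstance(component, torch.nn.Module):
            try:
                device = next(component.parameters()).device
            except StopIteration:
                device = "no parameters"
            print(f"{name}: {device}")
        # tokenizers, schedulers, etc. have no device and are skipped
```

With the fix in place, text_encoder_2 and text_encoder_3 should show up as not loaded, while the transformer and VAE sit on the accelerator.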