Flux2: Tensor tuples can cause issues for checkpointing#12777
dg845 merged 11 commits into huggingface:main
Conversation
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi @dxqb, thanks for opening this PR and thanks for your patience! This change looks good to me. As mentioned in #12776 (comment), it would be nice to have a small script to reproduce/test this behavior.
No repro code, but it's now clear why this happens; it's documented by PyTorch: #12776 (comment)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
dg845 left a comment:
Thanks for the changes! Can you solve the merge conflicts with main? I think they may be a result of #12524, which switches over to using Python 3.9+ style type hints without explicit typing imports, including in transformer_flux2.py.
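For context, the type-hint style the merge conflicts stem from is the PEP 585 built-in generics syntax, which works on Python 3.9+ without importing from `typing`. A minimal illustration (the function name and signature here are hypothetical, not taken from `transformer_flux2.py`):

```python
# Old style, needing explicit typing imports:
#   from typing import Tuple
#   def split_lengths(total: int, first: int) -> Tuple[int, int]: ...

# Python 3.9+ style (PEP 585): built-in generics, no typing import needed.
def split_lengths(total: int, first: int) -> tuple[int, int]:
    """Split a sequence length into a (first, remainder) pair."""
    return first, total - first

print(split_lengths(6, 2))  # (2, 4)
```

Note that `X | None` unions (PEP 604) additionally require Python 3.10, or `from __future__ import annotations` on 3.9.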
Done and tested using Nerogar/OneTrainer#1279
@bot /style
Style bot fixed some files and pushed the changes.
Merging as the CI failure is unrelated.
* upgrade diffusers for huggingface/diffusers#12777
* add preset
addresses #12776
What does this PR do?
This PR keeps the tuples, but moves the splitting of tensors into tuples of tensors into the transformer blocks, to avoid issues with checkpointing. When a tensor is passed directly, torch.utils.checkpoint() identifies the tensor and saves it accordingly, without running a backward pass through it multiple times.
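A minimal sketch of the pattern described above (the block function and split sizes are hypothetical, not the actual Flux2 code): the checkpointed function receives a plain tensor and performs the split into parts *inside* the checkpointed region, so `torch.utils.checkpoint` only ever saves a single tensor input.

```python
import torch
from torch.utils.checkpoint import checkpoint


def block(hidden_states: torch.Tensor, text_len: int) -> torch.Tensor:
    # Split into (text, image) parts inside the checkpointed function,
    # so checkpoint() sees one tensor input rather than a tuple of tensors.
    image_len = hidden_states.shape[1] - text_len
    text, image = hidden_states.split([text_len, image_len], dim=1)
    # Placeholder computation standing in for the real transformer block.
    return torch.cat([text * 2.0, image + 1.0], dim=1)


x = torch.randn(1, 6, 4, requires_grad=True)
# use_reentrant=False allows mixing tensor and non-tensor arguments.
out = checkpoint(block, x, 2, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)  # gradients flow through the checkpointed split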
This is a draft; if you agree with this change, I can clean it up. Among other things:
Who can review?
@yiyixuxu and @asomoza