Describe the bug
The modulations calculated here...
```python
self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
```
...return tuples of Tensors:
```python
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
```
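For context, here is a hypothetical stand-in with the same return structure (the class name and internals below are made up for illustration; only the return statement mirrors the real module). It shows that every leaf of the returned nested tuple is a graph-attached tensor once the projection is trainable:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for Flux2Modulation: one projection whose output is
# split into `mod_param_sets` groups of three tensors, returned as a tuple of tuples.
class ToyModulation(nn.Module):
    def __init__(self, dim, mod_param_sets=2, bias=False):
        super().__init__()
        self.mod_param_sets = mod_param_sets
        self.linear = nn.Linear(dim, 3 * mod_param_sets * dim, bias=bias)

    def forward(self, temb):
        mod_params = self.linear(temb).chunk(3 * self.mod_param_sets, dim=-1)
        return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))

mods = ToyModulation(8)(torch.randn(2, 8))
print(type(mods[0]), mods[0][0].requires_grad)  # <class 'tuple'> True
```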
These tuples are passed from outside the transformer blocks into the checkpointed transformer blocks.
If the tensors inside the tuples require gradients, this can cause issues for the backward pass:
```
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
```
torch checkpointing doesn't recognize the tuples as tensor inputs; only arguments that are themselves Tensors are identified:
https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/torch/utils/checkpoint.py#L252
Reproduction
Isolated reproduction code is difficult to provide because of the size of the model, but I'll post a draft PR in a minute.
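In the meantime, here is a minimal standalone sketch of the pattern (all module and function names are made up, and I'm forcing reentrant checkpointing with `use_reentrant=True` to keep the example self-contained; the diffusers code path may be configured differently). A tuple of graph-attached tensors is shared by two checkpointed blocks, so the second recomputed backward reaches the shared upstream graph after its saved tensors have already been freed:

```python
import torch
import torch.utils.checkpoint as checkpoint

# All names below are made up for illustration; only the pattern matters:
# a tuple of graph-attached tensors computed once outside the checkpointed
# regions and reused inside several of them.
modulation = torch.nn.Linear(8, 24)
block1 = torch.nn.Linear(8, 8)
block2 = torch.nn.Linear(8, 8)

x = torch.randn(2, 8, requires_grad=True)
# Tuple of tensors that require grad (they depend on `modulation`'s weights).
mods = tuple(modulation(x).chunk(3, dim=-1))

def run_block(block, hidden_states, mod_params):
    shift, scale, gate = mod_params
    return block(hidden_states * (1 + scale) + shift) * gate

# Only top-level Tensor arguments get detached for recomputation;
# the `mods` tuple stays attached to the upstream modulation graph.
h = checkpoint.checkpoint(run_block, block1, x, mods, use_reentrant=True)
h = checkpoint.checkpoint(run_block, block2, h, mods, use_reentrant=True)

# The first recomputed backward frees the saved tensors of the shared
# modulation graph; the second recomputation then tries to backward
# through it again and raises the RuntimeError quoted above.
h.sum().backward()
```

In this sketch, passing the three tensors as top-level Tensor arguments instead of a tuple makes the error go away, since checkpoint then detaches them before recomputation and routes their gradients through the outer graph.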
Logs
System Info
torch 2.8, diffusers HEAD
Who can help?