
Flux2: Tensor tuples can cause issues for checkpointing #12776

@dxqb

Description


Describe the bug

The modulations calculated here...

self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)

...return tuples of Tensors:

return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))

These tuples are passed from outside the transformer blocks into the checkpointed transformer blocks.
If the tensors inside the tuples require gradients, this can cause issues for the backward pass:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

torch checkpointing doesn't look inside the tuples; only top-level tensor arguments are identified:
https://github.com/pytorch/pytorch/blob/d38164a545b4a4e4e0cf73ce67173f70574890b6/torch/utils/checkpoint.py#L252
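For illustration, here is a minimal sketch of that mechanism outside the model (the block and names below are hypothetical, and it assumes reentrant checkpointing): when a tuple of grad-requiring tensors is passed into two checkpointed calls, only the top-level tensor arguments get detached before recomputation, so each block's inner backward reaches back into the shared modulation graph and should hit the same error.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Hypothetical stand-in for a transformer block that consumes a tuple of
# modulation tensors (shift, scale), similar in spirit to Flux2Modulation's output.
def block(hidden, mods):
    shift, scale = mods
    return hidden * (1 + scale) + shift

# Modulation tensors produced by a Linear, so they require grad and carry a grad_fn.
mod_proj = torch.nn.Linear(4, 8)
mod_params = mod_proj(torch.randn(2, 4))
mods = (mod_params[:, :4], mod_params[:, 4:])  # a tuple of tensors, as in the issue

x = torch.randn(2, 4, requires_grad=True)

# Two checkpointed blocks receive the same tuple. With reentrant checkpointing,
# only top-level tensor args are detached before recomputation; the tensors inside
# the tuple stay attached to the original modulation graph, so the first block's
# inner backward frees that shared graph and the second pass over it fails.
h = checkpoint(block, x, mods, use_reentrant=True)
h = checkpoint(block, h, mods, use_reentrant=True)

h.sum().backward()  # expected: RuntimeError "Trying to backward through the graph a second time ..."
```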

Reproduction

Isolated reproduction code is difficult because of the size of the model, but I'll post a draft PR in a minute.

Logs

System Info

torch 2.8, diffusers HEAD

Who can help?

@DN6 @yiyixuxu @sayakpaul
