diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 1a4464432425..a830d5bf8141 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -114,7 +114,26 @@ def __call__(
         if image_rotary_emb is not None:
             query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
             key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
-
+        if (
+            torch.distributed.is_initialized()
+            and torch.distributed.get_world_size() > 1
+            and not attn.training
+            and self._parallel_config is not None
+            and (
+                self._parallel_config.context_parallel_config.ring_degree > 1
+                or self._parallel_config.context_parallel_config.ulysses_degree > 1
+            )
+        ):
+            world_size = torch.distributed.get_world_size()
+
+            seq_len = key.shape[1]
+            if seq_len > 100:
+                key_list = [torch.empty_like(key) for _ in range(world_size)]
+                value_list = [torch.empty_like(value) for _ in range(world_size)]
+                torch.distributed.all_gather(key_list, key.contiguous())
+                torch.distributed.all_gather(value_list, value.contiguous())
+                key = torch.cat(key_list, dim=1)
+                value = torch.cat(value_list, dim=1)
         hidden_states = dispatch_attention_fn(
             query,
             key,
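
The added branch gathers every rank's key/value shard and concatenates them along the sequence dimension before attention is dispatched. As a point of reference only, here is a minimal standalone sketch of that all_gather-and-concat pattern, not the diffusers code itself: it assumes two CPU processes on the gloo backend, and the shapes, port, and function names (run, local_seq_len, etc.) are purely illustrative.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    # Illustrative setup: single-node rendezvous on the CPU "gloo" backend.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank holds its local shard of key/value: (batch, local_seq_len, heads, head_dim).
    local_seq_len, heads, head_dim = 128, 4, 8
    key = torch.full((1, local_seq_len, heads, head_dim), float(rank))
    value = torch.full((1, local_seq_len, heads, head_dim), float(rank))

    # Gather every rank's shard and concatenate along the sequence dimension (dim=1),
    # mirroring the pattern in the diff above.
    key_list = [torch.empty_like(key) for _ in range(world_size)]
    value_list = [torch.empty_like(value) for _ in range(world_size)]
    dist.all_gather(key_list, key.contiguous())
    dist.all_gather(value_list, value.contiguous())
    key = torch.cat(key_list, dim=1)
    value = torch.cat(value_list, dim=1)

    # After the gather, each rank sees the full sequence of keys/values.
    assert key.shape[1] == world_size * local_seq_len
    if rank == 0:
        print("gathered key shape:", tuple(key.shape))

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)

Running this with two processes prints a gathered key shape of (1, 256, 4, 8), i.e. the two 128-token shards joined on dim=1, which is the effect the guarded branch in the diff produces for the attention keys and values.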