diff --git a/docs/source/en/api/pipelines/cogvideox.md b/docs/source/en/api/pipelines/cogvideox.md
index 549666e60ebc..c7340eff40c4 100644
--- a/docs/source/en/api/pipelines/cogvideox.md
+++ b/docs/source/en/api/pipelines/cogvideox.md
@@ -29,6 +29,10 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
+There are two models available that can be used with the CogVideoX pipeline:
+- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b)
+- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b)
+
## Inference
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
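For reference, a minimal sketch of compiling the transformer before sampling (the model id, prompt, and compile flags below are illustrative and may need tuning for your setup):

```python
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")

# Compiling the transformer is where most of the latency reduction comes from.
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)

# The first call pays the compilation cost; subsequent calls reuse the compiled graph.
video = pipe(prompt="A panda playing a guitar in a bamboo forest", num_inference_steps=50).frames[0]
```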
@@ -68,7 +72,7 @@ With torch.compile(): Average inference time: 76.27 seconds.
### Memory optimization
-CogVideoX requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
+CogVideoX-2b requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with an output resolution of 720x480 (W x H), which makes it impossible to run on consumer GPUs or the free-tier T4 Colab. The following memory optimizations can be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
- `pipe.enable_model_cpu_offload()`:
- Without enabling cpu offloading, memory usage is `33 GB`
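As a sketch, one possible combination of the optimizations discussed in this section is CPU offloading plus tiled VAE decoding (both are exercised in the pipeline tests; the exact savings depend on hardware):

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)

# Keep submodules on the CPU and move them to the GPU only when they are needed.
pipe.enable_model_cpu_offload()
# Decode the latents in tiles to reduce peak memory during VAE decoding.
pipe.vae.enable_tiling()

video = pipe(
    prompt="A panda, dressed in a small, red jacket, sits on a wooden stool in a serene bamboo forest.",
    num_inference_steps=50,
    guidance_scale=6,
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```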
diff --git a/scripts/convert_cogvideox_to_diffusers.py b/scripts/convert_cogvideox_to_diffusers.py
index c03013a7fff9..6448da7f1131 100644
--- a/scripts/convert_cogvideox_to_diffusers.py
+++ b/scripts/convert_cogvideox_to_diffusers.py
@@ -86,6 +86,9 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
"key_layernorm_list": reassign_query_key_layernorm_inplace,
"adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
"embed_tokens": remove_keys_inplace,
+ "freqs_sin": remove_keys_inplace,
+ "freqs_cos": remove_keys_inplace,
+ "position_embedding": remove_keys_inplace,
}
VAE_KEYS_RENAME_DICT = {
@@ -123,11 +126,21 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
state_dict[new_key] = state_dict.pop(old_key)
-def convert_transformer(ckpt_path: str):
+def convert_transformer(
+ ckpt_path: str,
+ num_layers: int,
+ num_attention_heads: int,
+ use_rotary_positional_embeddings: bool,
+ dtype: torch.dtype,
+):
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
- transformer = CogVideoXTransformer3DModel()
+ transformer = CogVideoXTransformer3DModel(
+ num_layers=num_layers,
+ num_attention_heads=num_attention_heads,
+ use_rotary_positional_embeddings=use_rotary_positional_embeddings,
+ ).to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[len(PREFIX_KEY) :]
@@ -145,9 +158,9 @@ def convert_transformer(ckpt_path: str):
return transformer
-def convert_vae(ckpt_path: str):
+def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
- vae = AutoencoderKLCogVideoX()
+ vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -172,13 +185,26 @@ def get_args():
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
- parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
+ parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
+ parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
parser.add_argument(
"--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
)
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
+ # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
+ parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
+ # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
+ parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
+ # For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
+ parser.add_argument(
+ "--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
+ )
+ # For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
+ parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
+ # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
+ parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="SNR shift scale used by the scheduler")
return parser.parse_args()
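For reference, the per-checkpoint values implied by the comments above, gathered in one place (an illustrative summary only, not part of the script):

```python
# Conversion settings per checkpoint, as documented in the argparse comments above.
COGVIDEOX_CONVERSION_SETTINGS = {
    "CogVideoX-2b": {
        "num_layers": 30,
        "num_attention_heads": 30,
        "use_rotary_positional_embeddings": False,
        "scaling_factor": 1.15258426,
        "snr_shift_scale": 3.0,
        "dtype": "fp16",
    },
    "CogVideoX-5b": {
        "num_layers": 42,
        "num_attention_heads": 48,
        "use_rotary_positional_embeddings": True,
        "scaling_factor": 0.7,
        "snr_shift_scale": 1.0,
        "dtype": "bf16",
    },
}
```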
@@ -188,18 +214,33 @@ def get_args():
transformer = None
vae = None
+ if args.fp16 and args.bf16:
+ raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
+
+ dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
+
if args.transformer_ckpt_path is not None:
- transformer = convert_transformer(args.transformer_ckpt_path)
+ transformer = convert_transformer(
+ args.transformer_ckpt_path,
+ args.num_layers,
+ args.num_attention_heads,
+ args.use_rotary_positional_embeddings,
+ dtype,
+ )
if args.vae_ckpt_path is not None:
- vae = convert_vae(args.vae_ckpt_path)
+ vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
+ # Saving with safetensors fails on non-contiguous tensors, so make the text encoder parameters contiguous.
+ for param in text_encoder.parameters():
+ param.data = param.data.contiguous()
+
scheduler = CogVideoXDDIMScheduler.from_config(
{
- "snr_shift_scale": 3.0,
+ "snr_shift_scale": args.snr_shift_scale,
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
@@ -208,7 +249,7 @@ def get_args():
"prediction_type": "v_prediction",
"rescale_betas_zero_snr": True,
"set_alpha_to_one": True,
- "timestep_spacing": "linspace",
+ "timestep_spacing": "trailing",
}
)
@@ -218,5 +259,10 @@ def get_args():
if args.fp16:
pipe = pipe.to(dtype=torch.float16)
+ if args.bf16:
+ pipe = pipe.to(dtype=torch.bfloat16)
+ # We don't use `variant` here because the 2B model is meant to run in fp16 and the 5B model in bf16. Requiring
+ # users to pass a variant just to load the intended default precision (which is not fp32 here) would be
+ # confusing, so the weights are saved without a variant suffix.
pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index fc225567ddc1..75b4f164eb25 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -1783,6 +1783,148 @@ def __call__(
return hidden_states
+class CogVideoXAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+ query and key vectors, but does not include spatial normalization.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("CogVideoXAttnProcessor2_0 requires PyTorch 2.0 or higher. Please upgrade your PyTorch installation.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+ if not attn.is_cross_attention:
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+ )
+ return hidden_states, encoder_hidden_states
+
+
+class FusedCogVideoXAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention for the CogVideoX model with fused query-key-value
+ projections. It applies a rotary embedding on query and key vectors, but does not include spatial normalization.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("FusedCogVideoXAttnProcessor2_0 requires PyTorch 2.0 or higher. Please upgrade your PyTorch installation.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ qkv = attn.to_qkv(hidden_states)
+ split_size = qkv.shape[-1] // 3
+ query, key, value = torch.split(qkv, split_size, dim=-1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ from .embeddings import apply_rotary_emb
+
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+ if not attn.is_cross_attention:
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ encoder_hidden_states, hidden_states = hidden_states.split(
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+ )
+ return hidden_states, encoder_hidden_states
+
+
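Both processors follow the same pattern: concatenate the text and video tokens, attend over the joint sequence, then split the two streams again before returning. A toy sketch of that pattern with made-up shapes (plain tensors instead of the `Attention` module's projections):

```python
import torch
import torch.nn.functional as F

batch, heads, head_dim = 2, 4, 8
text_seq_length, video_seq_length = 16, 64

text = torch.randn(batch, text_seq_length, heads * head_dim)
video = torch.randn(batch, video_seq_length, heads * head_dim)

# 1. Concatenate text and video tokens and attend over the joint sequence.
hidden = torch.cat([text, video], dim=1)
q = k = v = hidden.view(batch, -1, heads, head_dim).transpose(1, 2)
out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)

# 2. Split the result back into the text and video streams, as the processors return them.
text_out, video_out = out.split([text_seq_length, out.size(1) - text_seq_length], dim=1)
print(text_out.shape, video_out.shape)  # torch.Size([2, 16, 32]) torch.Size([2, 64, 32])
```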
class XFormersAttnAddedKVProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
index 3bf6e68d2628..17fa2bbf40f6 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -902,7 +902,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
- scaling_factor (`float`, *optional*, defaults to 0.18215):
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index b2f496833176..d1366654c448 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -374,6 +374,90 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
return embeds
+def get_3d_rotary_pos_embed(
+ embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """
+ RoPE for video tokens with 3D structure.
+
+ Args:
+ embed_dim (`int`):
+ The embedding dimension size, corresponding to hidden_size_head.
+ crops_coords (`Tuple[int]`):
+ The top-left and bottom-right coordinates of the crop.
+ grid_size (`Tuple[int]`):
+ The grid size of the spatial positional embedding (height, width).
+ temporal_size (`int`):
+ The size of the temporal dimension.
+ theta (`float`):
+ Scaling factor for frequency computation.
+ use_real (`bool`):
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+ Returns:
+ `torch.Tensor` or `Tuple[torch.Tensor, torch.Tensor]`: if `use_real` is `True`, the cosine and sine components are
+ returned separately, each with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim)`; otherwise, a
+ single complex tensor of the same shape is returned.
+ """
+ start, stop = crops_coords
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+ # Compute dimensions for each axis
+ dim_t = embed_dim // 4
+ dim_h = embed_dim // 8 * 3
+ dim_w = embed_dim // 8 * 3
+
+ # Temporal frequencies
+ freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
+ grid_t = torch.from_numpy(grid_t).float()
+ freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
+ freqs_t = freqs_t.repeat_interleave(2, dim=-1)
+
+ # Spatial frequencies for height and width
+ freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
+ freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
+ grid_h = torch.from_numpy(grid_h).float()
+ grid_w = torch.from_numpy(grid_w).float()
+ freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
+ freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
+ freqs_h = freqs_h.repeat_interleave(2, dim=-1)
+ freqs_w = freqs_w.repeat_interleave(2, dim=-1)
+
+ # Broadcast and concatenate tensors along specified dimension
+ def broadcast(tensors, dim=-1):
+ num_tensors = len(tensors)
+ shape_lens = {len(t.shape) for t in tensors}
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+ shape_len = list(shape_lens)[0]
+ dim = (dim + shape_len) if dim < 0 else dim
+ dims = list(zip(*(list(t.shape) for t in tensors)))
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+ assert all(
+ [*(len(set(t[1])) <= 2 for t in expandable_dims)]
+ ), "invalid dimensions for broadcastable concatenation"
+ max_dims = [(t[0], max(t[1])) for t in expandable_dims]
+ expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
+ expanded_dims.insert(dim, (dim, dims[dim]))
+ expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
+ tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
+ return torch.cat(tensors, dim=dim)
+
+ freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
+
+ t, h, w, d = freqs.shape
+ freqs = freqs.view(t * h * w, d)
+
+ # Generate sine and cosine components
+ sin = freqs.sin()
+ cos = freqs.cos()
+
+ if use_real:
+ return cos, sin
+ else:
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+ return freqs_cis
+
+
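A quick shape check of the new helper with made-up grid settings (the sizes below are arbitrary):

```python
from diffusers.models.embeddings import get_3d_rotary_pos_embed

# Hypothetical latent grid: 13 frames and 30x45 spatial tokens, with an attention head dim of 64.
cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,
    crops_coords=((0, 0), (30, 45)),
    grid_size=(30, 45),
    temporal_size=13,
    use_real=True,
)
print(cos.shape, sin.shape)  # torch.Size([17550, 64]) torch.Size([17550, 64]) -- 17550 = 13 * 30 * 45
```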
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
"""
RoPE for image tokens with 2d structure.
diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index 1030b0df04ff..c8d4b1896346 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
@@ -22,6 +22,7 @@
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
+from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -97,6 +98,7 @@ def __init__(
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
+ processor=CogVideoXAttnProcessor2_0(),
)
# 2. Feed Forward
@@ -116,24 +118,24 @@ def forward(
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# attention
- text_length = norm_encoder_hidden_states.size(1)
-
- # CogVideoX uses concatenated text + video embeddings with self-attention instead of using
- # them in cross-attention individually
- norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
- attn_output = self.attn1(
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
- encoder_hidden_states=None,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
)
- hidden_states = hidden_states + gate_msa * attn_output[:, text_length:]
- encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length]
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
@@ -144,8 +146,9 @@ def forward(
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
- hidden_states = hidden_states + gate_ff * ff_output[:, text_length:]
- encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length]
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
return hidden_states, encoder_hidden_states
@@ -231,6 +234,7 @@ def __init__(
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
+ use_rotary_positional_embeddings: bool = False,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
@@ -295,12 +299,113 @@ def __init__(
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ This API is 🧪 experimental.
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+ This API is 🧪 experimental.
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = hidden_states.shape
@@ -319,14 +424,16 @@ def forward(
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
# 3. Position embedding
- seq_length = height * width * num_frames // (self.config.patch_size**2)
+ text_seq_length = encoder_hidden_states.shape[1]
+ if not self.config.use_rotary_positional_embeddings:
+ seq_length = height * width * num_frames // (self.config.patch_size**2)
- pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length]
- hidden_states = hidden_states + pos_embeds
- hidden_states = self.embedding_dropout(hidden_states)
+ pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
+ hidden_states = hidden_states + pos_embeds
+ hidden_states = self.embedding_dropout(hidden_states)
- encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
- hidden_states = hidden_states[:, self.config.max_text_seq_length :]
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
@@ -344,6 +451,7 @@ def custom_forward(*inputs):
hidden_states,
encoder_hidden_states,
emb,
+ image_rotary_emb,
**ckpt_kwargs,
)
else:
@@ -351,9 +459,17 @@ def custom_forward(*inputs):
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
+ image_rotary_emb=image_rotary_emb,
)
- hidden_states = self.norm_final(hidden_states)
+ if not self.config.use_rotary_positional_embeddings:
+ # CogVideoX-2B
+ hidden_states = self.norm_final(hidden_states)
+ else:
+ # CogVideoX-5B
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ hidden_states = self.norm_final(hidden_states)
+ hidden_states = hidden_states[:, text_seq_length:]
# 5. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
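A minimal sketch exercising the new processor plumbing on a tiny, made-up configuration (the constructor arguments mirror the dummy config used in the pipeline tests; the values are chosen only to keep the model small):

```python
from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel(
    num_attention_heads=4,
    attention_head_dim=8,
    in_channels=4,
    out_channels=4,
    time_embed_dim=2,
    text_embed_dim=32,
    num_layers=1,
    sample_width=2,
    sample_height=2,
    sample_frames=9,
    patch_size=2,
    temporal_compression_ratio=4,
    max_text_seq_length=16,
    use_rotary_positional_embeddings=True,
)

# Fusing swaps every processor for FusedCogVideoXAttnProcessor2_0; unfusing restores the originals.
transformer.fuse_qkv_projections()
print({type(p).__name__ for p in transformer.attn_processors.values()})
transformer.unfuse_qkv_projections()
print({type(p).__name__ for p in transformer.attn_processors.values()})
```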
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
index f43edab987fe..e100c1f11e20 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -23,6 +23,7 @@
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
+from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import BaseOutput, logging, replace_example_docstring
@@ -40,6 +41,7 @@
>>> from diffusers import CogVideoXPipeline
>>> from diffusers.utils import export_to_video
+ >>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
>>> prompt = (
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
@@ -55,6 +57,25 @@
"""
+# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
+def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
+ tw = tgt_width
+ th = tgt_height
+ h, w = src
+ r = h / w
+ if r > (th / tw):
+ resize_height = th
+ resize_width = int(round(th / h * w))
+ else:
+ resize_width = tw
+ resize_height = int(round(tw / w * h))
+
+ crop_top = int(round((th - resize_height) / 2.0))
+ crop_left = int(round((tw - resize_width) / 2.0))
+
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
@@ -409,6 +430,46 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)
+ def fuse_qkv_projections(self) -> None:
+ r"""Enables fused QKV projections."""
+ self.fusing_transformer = True
+ self.transformer.fuse_qkv_projections()
+
+ def unfuse_qkv_projections(self) -> None:
+ r"""Disable QKV projection fusion if enabled."""
+ if not self.fusing_transformer:
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.transformer.unfuse_qkv_projections()
+ self.fusing_transformer = False
+
+ def _prepare_rotary_positional_embeddings(
+ self,
+ height: int,
+ width: int,
+ num_frames: int,
+ device: torch.device,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+ base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
+
+ grid_crops_coords = get_resize_crop_region_for_grid(
+ (grid_height, grid_width), base_size_width, base_size_height
+ )
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=self.transformer.config.attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ use_real=True,
+ )
+
+ freqs_cos = freqs_cos.to(device=device)
+ freqs_sin = freqs_sin.to(device=device)
+ return freqs_cos, freqs_sin
+
@property
def guidance_scale(self):
return self._guidance_scale
@@ -599,7 +660,14 @@ def __call__(
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
- # 7. Denoising loop
+ # 7. Create rotary embeds if required
+ image_rotary_emb = (
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
+ if self.transformer.config.use_rotary_positional_embeddings
+ else None
+ )
+
+ # 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -620,6 +688,7 @@ def __call__(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
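End to end, the 5B checkpoint exercises the new rotary-embedding path added above; a sketch (sampling settings are illustrative):

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# The 5B model uses rotary positional embeddings and is intended to run in bfloat16.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

# Optional: fuse the QKV projections added in this PR.
pipe.fuse_qkv_projections()

video = pipe(
    prompt=(
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool "
        "in a serene bamboo forest, strumming a miniature guitar."
    ),
    num_inference_steps=50,
    guidance_scale=6,
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```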
diff --git a/tests/pipelines/cogvideox/test_cogvideox.py b/tests/pipelines/cogvideox/test_cogvideox.py
index 17d0d8f21d5c..c69dcfda93c5 100644
--- a/tests/pipelines/cogvideox/test_cogvideox.py
+++ b/tests/pipelines/cogvideox/test_cogvideox.py
@@ -30,7 +30,12 @@
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+ check_qkv_fusion_matches_attn_procs_length,
+ check_qkv_fusion_processors_exist,
+ to_np,
+)
enable_full_determinism()
@@ -279,6 +284,44 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
def test_xformers_attention_forwardGenerator_pass(self):
pass
+ def test_fused_qkv_projections(self):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ frames = pipe(**inputs).frames # [B, F, C, H, W]
+ original_image_slice = frames[0, -2:, -1, -3:, -3:]
+
+ pipe.fuse_qkv_projections()
+ assert check_qkv_fusion_processors_exist(
+ pipe.transformer
+ ), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ assert check_qkv_fusion_matches_attn_procs_length(
+ pipe.transformer, pipe.transformer.original_attn_processors
+ ), "Something wrong with the attention processors concerning the fused QKV projections."
+
+ inputs = self.get_dummy_inputs(device)
+ frames = pipe(**inputs).frames
+ image_slice_fused = frames[0, -2:, -1, -3:, -3:]
+
+ pipe.transformer.unfuse_qkv_projections()
+ inputs = self.get_dummy_inputs(device)
+ frames = pipe(**inputs).frames
+ image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
+
+ assert np.allclose(
+ original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
+ ), "Fusion of QKV projections shouldn't affect the outputs."
+ assert np.allclose(
+ image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
+ ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ assert np.allclose(
+ original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
+ ), "Original outputs should match when fused QKV projections are disabled."
+
@slow
@require_torch_gpu