From d53f848720a03423bb9998e75a30b4c3cd04e96d Mon Sep 17 00:00:00 2001 From: leffff Date: Sat, 4 Oct 2025 10:10:23 +0000 Subject: [PATCH 01/77] add transformer pipeline first version --- src/diffusers/__init__.py | 4 + src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/lora_pipeline.py | 288 +++++++- src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_kandinsky.py | 630 ++++++++++++++++++ src/diffusers/pipelines/__init__.py | 2 + .../pipelines/kandinsky5/__init__.py | 48 ++ .../kandinsky5/pipeline_kandinsky.py | 545 +++++++++++++++ .../pipelines/kandinsky5/pipeline_output.py | 20 + 10 files changed, 1541 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/models/transformers/transformer_kandinsky.py create mode 100644 src/diffusers/pipelines/kandinsky5/__init__.py create mode 100644 src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py create mode 100644 src/diffusers/pipelines/kandinsky5/pipeline_output.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8867250deda8..19670053a3c5 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -260,6 +260,7 @@ "VQModel", "WanTransformer3DModel", "WanVACETransformer3DModel", + "Kandinsky5Transformer3DModel", "attention_backend", ] ) @@ -618,6 +619,7 @@ "WanPipeline", "WanVACEPipeline", "WanVideoToVideoPipeline", + "Kandinsky5T2VPipeline", "WuerstchenCombinedPipeline", "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", @@ -947,6 +949,7 @@ VQModel, WanTransformer3DModel, WanVACETransformer3DModel, + Kandinsky5Transformer3DModel, attention_backend, ) from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks @@ -1275,6 +1278,7 @@ WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline, + Kandinsky5T2VPipeline, WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 742548653800..6a48ac1b0deb 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -77,6 +77,7 @@ def text_encoder_attn_modules(text_encoder): "SanaLoraLoaderMixin", "Lumina2LoraLoaderMixin", "WanLoraLoaderMixin", + "KandinskyLoraLoaderMixin", "HiDreamImageLoraLoaderMixin", "SkyReelsV2LoraLoaderMixin", "QwenImageLoraLoaderMixin", @@ -126,6 +127,7 @@ def text_encoder_attn_modules(text_encoder): StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, WanLoraLoaderMixin, + KandinskyLoraLoaderMixin ) from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index e25a29e1c00e..ea1b92c68b59 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -3638,6 +3638,292 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): """ super().unfuse_lora(components=components, **kwargs) + +class KandinskyLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`Kandinsky5Transformer3DModel`], + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. 
+ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + - A string, the *model id* of a pretrained model hosted on the Hub. + - A path to a *directory* containing the model weights. + - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository. + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. + use_safetensors (`bool`, *optional*): + Whether to use safetensors for loading. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata. + """ + # Load the main state dict first which has the LoRA layers + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
+ logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. + hotswap (`bool`, *optional*): + Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + kwargs (`dict`, *optional*): + See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + # Load LoRA into transformer + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + Load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. + transformer (`Kandinsky5Transformer3DModel`): + The transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights. + hotswap (`bool`, *optional*): + See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. + """ + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. 
+ logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + transformer_lora_adapter_metadata=None, + ): + r""" + Save the LoRA parameters corresponding to the transformer and text encoders. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process. + save_function (`Callable`): + The function to use to save the state dictionary. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer. + """ + lora_layers = {} + lora_metadata = {} + + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata + + if not lora_layers: + raise ValueError( + "You must pass at least one of `transformer_lora_layers`" + ) + + cls._save_lora_weights( + save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. + + Example: + ```py + from diffusers import Kandinsky5T2VPipeline + + pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V") + pipeline.load_lora_weights("path/to/lora.safetensors") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of [`pipe.fuse_lora()`]. + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. 
+ """ + super().unfuse_lora(components=components, **kwargs) + class WanLoraLoaderMixin(LoraBaseMixin): r""" @@ -4802,4 +5088,4 @@ class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." deprecate("LoraLoaderMixin", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) \ No newline at end of file diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 457f70448af3..89ca9d39774b 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -101,6 +101,7 @@ _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] + _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] @@ -200,6 +201,7 @@ TransformerTemporalModel, WanTransformer3DModel, WanVACETransformer3DModel, + Kandinsky5Transformer3DModel, ) from .unets import ( I2VGenXLUNet, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index b60f0636e6dc..4b9911f9cb5d 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -37,3 +37,4 @@ from .transformer_temporal import TransformerTemporalModel from .transformer_wan import WanTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel + from .transformer_kandinsky import Kandinsky5Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py new file mode 100644 index 000000000000..a057cc13cc0f --- /dev/null +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -0,0 +1,630 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
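The mixin above exposes the same entry points as the other per-pipeline LoRA loaders (`load_lora_weights`, `fuse_lora`, `unfuse_lora`, `save_lora_weights`). A minimal usage sketch, mirroring the `fuse_lora` docstring example in this patch; the LoRA file path is a placeholder and the repo id is the one already referenced in the patch's docstrings:

```python
import torch
from diffusers import Kandinsky5T2VPipeline

# Hypothetical checkpoint locations; only the repo id below appears in this patch's docstrings.
pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Routes the LoRA state dict into the transformer via `load_lora_into_transformer`.
pipe.load_lora_weights("path/to/lora.safetensors", adapter_name="style")

# Optionally bake the adapter into the base weights, then undo it.
pipe.fuse_lora(components=["transformer"], lora_scale=0.7)
pipe.unfuse_lora(components=["transformer"])
```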
+ +import math +from typing import Any, Dict, Optional, Tuple, Union, List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm + +logger = logging.get_logger(__name__) + + +# @torch.compile() +@torch.autocast(device_type="cuda", dtype=torch.float32) +def apply_scale_shift_norm(norm, x, scale, shift): + return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16) + +# @torch.compile() +@torch.autocast(device_type="cuda", dtype=torch.float32) +def apply_gate_sum(x, out, gate): + return (x + gate * out).to(torch.bfloat16) + +# @torch.compile() +@torch.autocast(device_type="cuda", enabled=False) +def apply_rotary(x, rope): + x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) + x_out = (rope * x_).sum(dim=-1) + return x_out.reshape(*x.shape).to(torch.bfloat16) + + +@torch.autocast(device_type="cuda", enabled=False) +def get_freqs(dim, max_period=10000.0): + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=dim, dtype=torch.float32) + / dim + ) + return freqs + + +class TimeEmbeddings(nn.Module): + def __init__(self, model_dim, time_dim, max_period=10000.0): + super().__init__() + assert model_dim % 2 == 0 + self.model_dim = model_dim + self.max_period = max_period + self.register_buffer( + "freqs", get_freqs(model_dim // 2, max_period), persistent=False + ) + self.in_layer = nn.Linear(model_dim, time_dim, bias=True) + self.activation = nn.SiLU() + self.out_layer = nn.Linear(time_dim, time_dim, bias=True) + + def forward(self, time): + args = torch.outer(time, self.freqs.to(device=time.device)) + time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) + return time_embed + + +class TextEmbeddings(nn.Module): + def __init__(self, text_dim, model_dim): + super().__init__() + self.in_layer = nn.Linear(text_dim, model_dim, bias=True) + self.norm = nn.LayerNorm(model_dim, elementwise_affine=True) + + def forward(self, text_embed): + text_embed = self.in_layer(text_embed) + return self.norm(text_embed).type_as(text_embed) + + +class VisualEmbeddings(nn.Module): + def __init__(self, visual_dim, model_dim, patch_size): + super().__init__() + self.patch_size = patch_size + self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim) + + def forward(self, x): + batch_size, duration, height, width, dim = x.shape + x = ( + x.view( + batch_size, + duration // self.patch_size[0], + self.patch_size[0], + height // self.patch_size[1], + self.patch_size[1], + width // self.patch_size[2], + self.patch_size[2], + dim, + ) + .permute(0, 1, 3, 5, 2, 4, 6, 7) + .flatten(4, 7) + ) + return self.in_layer(x) + + +class RoPE1D(nn.Module): + """ + 1D Rotary Positional Embeddings for text sequences. 
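+
+    The forward pass returns stacked 2x2 rotation matrices; `apply_rotary` reshapes the last
+    dimension of the query/key tensors into channel pairs and rotates each pair with its matrix.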
+ + Args: + dim: Dimension of the rotary embeddings + max_pos: Maximum sequence length + max_period: Maximum period for sinusoidal embeddings + """ + + def __init__(self, dim, max_pos=1024, max_period=10000.0): + super().__init__() + self.max_period = max_period + self.dim = dim + self.max_pos = max_pos + freq = get_freqs(dim // 2, max_period) + pos = torch.arange(max_pos, dtype=freq.dtype) + self.register_buffer("args", torch.outer(pos, freq), persistent=False) + + def forward(self, pos): + """ + Args: + pos: Position indices of shape [seq_len] or [batch_size, seq_len] + + Returns: + Rotary embeddings of shape [seq_len, 1, 2, 2] + """ + args = self.args[pos] + cosine = torch.cos(args) + sine = torch.sin(args) + rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) + rope = rope.view(*rope.shape[:-1], 2, 2) + return rope.unsqueeze(-4) + + +class RoPE3D(nn.Module): + def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0): + super().__init__() + self.axes_dims = axes_dims + self.max_pos = max_pos + self.max_period = max_period + + for i, (axes_dim, ax_max_pos) in enumerate(zip(axes_dims, max_pos)): + freq = get_freqs(axes_dim // 2, max_period) + pos = torch.arange(ax_max_pos, dtype=freq.dtype) + self.register_buffer(f"args_{i}", torch.outer(pos, freq), persistent=False) + + @torch.autocast(device_type="cuda", enabled=False) + def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): + batch_size, duration, height, width = shape + args_t = self.args_0[pos[0]] / scale_factor[0] + args_h = self.args_1[pos[1]] / scale_factor[1] + args_w = self.args_2[pos[2]] / scale_factor[2] + + # Replicate the original logic with batch dimension + args_t_expanded = args_t.view(1, duration, 1, 1, -1).expand(batch_size, -1, height, width, -1) + args_h_expanded = args_h.view(1, 1, height, 1, -1).expand(batch_size, duration, -1, width, -1) + args_w_expanded = args_w.view(1, 1, 1, width, -1).expand(batch_size, duration, height, -1, -1) + + # Concatenate along the last dimension + args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1) # [B, D, H, W, F] + + cosine = torch.cos(args) + sine = torch.sin(args) + rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) # [B, D, H, W, F, 4] + rope = rope.view(*rope.shape[:-1], 2, 2) # [B, D, H, W, F, 2, 2] + return rope.unsqueeze(-4) # [B, D, H, 1, W, F, 2, 2] + + +class Modulation(nn.Module): + def __init__(self, time_dim, model_dim, num_params): + super().__init__() + self.activation = nn.SiLU() + self.out_layer = nn.Linear(time_dim, num_params * model_dim) + self.out_layer.weight.data.zero_() + self.out_layer.bias.data.zero_() + + def forward(self, x): + return self.out_layer(self.activation(x)) + + +class MultiheadSelfAttentionEnc(nn.Module): + def __init__(self, num_channels, head_dim): + super().__init__() + assert num_channels % head_dim == 0 + self.num_heads = num_channels // head_dim + + self.to_query = nn.Linear(num_channels, num_channels, bias=True) + self.to_key = nn.Linear(num_channels, num_channels, bias=True) + self.to_value = nn.Linear(num_channels, num_channels, bias=True) + self.query_norm = nn.RMSNorm(head_dim) + self.key_norm = nn.RMSNorm(head_dim) + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + + def forward(self, x, rope): + query = self.to_query(x) + key = self.to_key(x) + value = self.to_value(x) + + shape = query.shape[:-1] + query = query.reshape(*shape, self.num_heads, -1) + key = key.reshape(*shape, self.num_heads, -1) + value = value.reshape(*shape, self.num_heads, -1) + + 
query = self.query_norm(query.float()).type_as(query) + key = self.key_norm(key.float()).type_as(key) + + query = apply_rotary(query, rope).type_as(query) + key = apply_rotary(key, rope).type_as(key) + + # Use torch's scaled_dot_product_attention + out = F.scaled_dot_product_attention( + query, + key, + value, + ).flatten(-2, -1) + + out = self.out_layer(out) + return out + + +class MultiheadSelfAttentionDec(nn.Module): + def __init__(self, num_channels, head_dim): + super().__init__() + assert num_channels % head_dim == 0 + self.num_heads = num_channels // head_dim + + self.to_query = nn.Linear(num_channels, num_channels, bias=True) + self.to_key = nn.Linear(num_channels, num_channels, bias=True) + self.to_value = nn.Linear(num_channels, num_channels, bias=True) + self.query_norm = nn.RMSNorm(head_dim) + self.key_norm = nn.RMSNorm(head_dim) + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + + def forward(self, x, rope, sparse_params=None): + query = self.to_query(x) + key = self.to_key(x) + value = self.to_value(x) + + shape = query.shape[:-1] + query = query.reshape(*shape, self.num_heads, -1) + key = key.reshape(*shape, self.num_heads, -1) + value = value.reshape(*shape, self.num_heads, -1) + + query = self.query_norm(query.float()).type_as(query) + key = self.key_norm(key.float()).type_as(key) + + query = apply_rotary(query, rope).type_as(query) + key = apply_rotary(key, rope).type_as(key) + + # Use standard attention (can be extended with sparse attention) + out = F.scaled_dot_product_attention( + query, + key, + value, + ).flatten(-2, -1) + + out = self.out_layer(out) + return out + + +class MultiheadCrossAttention(nn.Module): + def __init__(self, num_channels, head_dim): + super().__init__() + assert num_channels % head_dim == 0 + self.num_heads = num_channels // head_dim + + self.to_query = nn.Linear(num_channels, num_channels, bias=True) + self.to_key = nn.Linear(num_channels, num_channels, bias=True) + self.to_value = nn.Linear(num_channels, num_channels, bias=True) + self.query_norm = nn.RMSNorm(head_dim) + self.key_norm = nn.RMSNorm(head_dim) + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + + def forward(self, x, cond): + query = self.to_query(x) + key = self.to_key(cond) + value = self.to_value(cond) + + shape, cond_shape = query.shape[:-1], key.shape[:-1] + query = query.reshape(*shape, self.num_heads, -1) + key = key.reshape(*cond_shape, self.num_heads, -1) + value = value.reshape(*cond_shape, self.num_heads, -1) + + query = self.query_norm(query.float()).type_as(query) + key = self.key_norm(key.float()).type_as(key) + + out = F.scaled_dot_product_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + ).permute(0, 2, 1, 3).flatten(-2, -1) + + out = self.out_layer(out) + return out + + +class FeedForward(nn.Module): + def __init__(self, dim, ff_dim): + super().__init__() + self.in_layer = nn.Linear(dim, ff_dim, bias=False) + self.activation = nn.GELU() + self.out_layer = nn.Linear(ff_dim, dim, bias=False) + + def forward(self, x): + return self.out_layer(self.activation(self.in_layer(x))) + + +class TransformerEncoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim): + super().__init__() + self.text_modulation = Modulation(time_dim, model_dim, 6) + + self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.self_attention = MultiheadSelfAttentionEnc(model_dim, head_dim) + + self.feed_forward_norm = nn.LayerNorm(model_dim, 
elementwise_affine=False) + self.feed_forward = FeedForward(model_dim, ff_dim) + + def forward(self, x, time_embed, rope): + self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1) + shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) + + out = self.self_attention_norm(x) + out = out * (scale + 1.0) + shift + out = self.self_attention(out, rope) + x = x + gate * out + + shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) + out = self.feed_forward_norm(x) + out = out * (scale + 1.0) + shift + out = self.feed_forward(out) + x = x + gate * out + return x + + +class TransformerDecoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim): + super().__init__() + self.visual_modulation = Modulation(time_dim, model_dim, 9) + + self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.self_attention = MultiheadSelfAttentionDec(model_dim, head_dim) + + self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.cross_attention = MultiheadCrossAttention(model_dim, head_dim) + + self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.feed_forward = FeedForward(model_dim, ff_dim) + + def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): + self_attn_params, cross_attn_params, ff_params = torch.chunk( + self.visual_modulation(time_embed), 3, dim=-1 + ) + shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) + + visual_out = self.self_attention_norm(visual_embed) + visual_out = visual_out * (scale + 1.0) + shift + visual_out = self.self_attention(visual_out, rope, sparse_params) + visual_embed = visual_embed + gate * visual_out + + shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) + visual_out = self.cross_attention_norm(visual_embed) + visual_out = visual_out * (scale + 1.0) + shift + visual_out = self.cross_attention(visual_out, text_embed) + visual_embed = visual_embed + gate * visual_out + + shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) + visual_out = self.feed_forward_norm(visual_embed) + visual_out = visual_out * (scale + 1.0) + shift + visual_out = self.feed_forward(visual_out) + visual_embed = visual_embed + gate * visual_out + return visual_embed + + +class OutLayer(nn.Module): + def __init__(self, model_dim, time_dim, visual_dim, patch_size): + super().__init__() + self.patch_size = patch_size + self.modulation = Modulation(time_dim, model_dim, 2) + self.norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.out_layer = nn.Linear( + model_dim, math.prod(patch_size) * visual_dim, bias=True + ) + + def forward(self, visual_embed, text_embed, time_embed): + # Handle the new batch dimension: [batch, duration, height, width, model_dim] + batch_size, duration, height, width, _ = visual_embed.shape + + shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1) + + # Apply modulation with proper broadcasting for the new shape + visual_embed = apply_scale_shift_norm( + self.norm, + visual_embed, + scale[:, None, None, None], # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1] + shift[:, None, None, None], # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1] + ).type_as(visual_embed) + + x = self.out_layer(visual_embed) + + # Now x has shape [batch, duration, height, width, patch_prod * visual_dim] + x = ( + x.view( + batch_size, + duration, + height, + width, + -1, + self.patch_size[0], + self.patch_size[1], + self.patch_size[2], + ) + .permute(0, 5, 1, 6, 2, 7, 3, 4) # [batch, patch_t, duration, 
patch_h, height, patch_w, width, features] + .flatten(1, 2) # [batch, patch_t * duration, height, patch_w, width, features] + .flatten(2, 3) # [batch, patch_t * duration, patch_h * height, width, features] + .flatten(3, 4) # [batch, patch_t * duration, patch_h * height, patch_w * width] + ) + return x + + +@maybe_allow_in_graph +class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin): + r""" + A 3D Transformer model for video generation used in Kandinsky 5.0. + + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods implemented for all models (such as downloading or saving). + + Args: + in_visual_dim (`int`, defaults to 16): + Number of channels in the input visual latent. + out_visual_dim (`int`, defaults to 16): + Number of channels in the output visual latent. + time_dim (`int`, defaults to 512): + Dimension of the time embeddings. + patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): + Patch size for the visual embeddings (temporal, height, width). + model_dim (`int`, defaults to 1792): + Hidden dimension of the transformer model. + ff_dim (`int`, defaults to 7168): + Intermediate dimension of the feed-forward networks. + num_text_blocks (`int`, defaults to 2): + Number of transformer blocks in the text encoder. + num_visual_blocks (`int`, defaults to 32): + Number of transformer blocks in the visual decoder. + axes_dims (`Tuple[int]`, defaults to `(16, 24, 24)`): + Dimensions for the rotary positional embeddings (temporal, height, width). + visual_cond (`bool`, defaults to `True`): + Whether to use visual conditioning (for image/video conditioning). + in_text_dim (`int`, defaults to 3584): + Dimension of the text embeddings from Qwen2.5-VL. + in_text_dim2 (`int`, defaults to 768): + Dimension of the pooled text embeddings from CLIP. + """ + + @register_to_config + def __init__( + self, + in_visual_dim: int = 16, + out_visual_dim: int = 16, + time_dim: int = 512, + patch_size: Tuple[int, int, int] = (1, 2, 2), + model_dim: int = 1792, + ff_dim: int = 7168, + num_text_blocks: int = 2, + num_visual_blocks: int = 32, + axes_dims: Tuple[int, int, int] = (16, 24, 24), + visual_cond: bool = True, + in_text_dim: int = 3584, + in_text_dim2: int = 768, + ): + super().__init__() + + self.in_visual_dim = in_visual_dim + self.model_dim = model_dim + self.patch_size = patch_size + self.visual_cond = visual_cond + + # Calculate head dimension for attention + head_dim = sum(axes_dims) + + # Determine visual embedding dimension based on conditioning + visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim + + # 1. Embedding layers + self.time_embeddings = TimeEmbeddings(model_dim, time_dim) + self.text_embeddings = TextEmbeddings(in_text_dim, model_dim) + self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim) + self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size) + + # 2. Rotary positional embeddings + self.text_rope_embeddings = RoPE1D(head_dim) + self.visual_rope_embeddings = RoPE3D(axes_dims) + + # 3. Transformer blocks + self.text_transformer_blocks = nn.ModuleList([ + TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) + for _ in range(num_text_blocks) + ]) + + self.visual_transformer_blocks = nn.ModuleList([ + TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim) + for _ in range(num_visual_blocks) + ]) + + # 4. 
Output layer + self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + pooled_text_embed: torch.Tensor, + timestep: torch.Tensor, + visual_rope_pos: List[torch.Tensor], + text_rope_pos: torch.Tensor, + scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0), + sparse_params: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + Forward pass of the Kandinsky 5.0 3D Transformer. + + Args: + hidden_states (`torch.Tensor`): + Input visual latent tensor of shape `(batch_size, num_frames, height, width, channels)`. + encoder_hidden_states (`torch.Tensor`): + Text embeddings from Qwen2.5-VL of shape `(batch_size, sequence_length, text_dim)`. + pooled_text_embed (`torch.Tensor`): + Pooled text embeddings from CLIP of shape `(batch_size, pooled_text_dim)`. + timestep (`torch.Tensor`): + Timestep tensor of shape `(batch_size,)` or `(batch_size * num_frames,)`. + visual_rope_pos (`List[torch.Tensor]`): + List of tensors for visual rotary positional embeddings [temporal, height, width]. + text_rope_pos (`torch.Tensor`): + Tensor for text rotary positional embeddings. + scale_factor (`Tuple[float, float, float]`, defaults to `(1.0, 1.0, 1.0)`): + Scale factors for rotary positional embeddings. + sparse_params (`Dict[str, Any]`, *optional*): + Parameters for sparse attention. + return_dict (`bool`, defaults to `True`): + Whether to return a dictionary or a tensor. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + If `return_dict` is `True`, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned, + otherwise a `tuple` where the first element is the sample tensor. + """ + batch_size, num_frames, height, width, channels = hidden_states.shape + + # 1. Process text embeddings + text_embed = self.text_embeddings(encoder_hidden_states) + time_embed = self.time_embeddings(timestep) + + # Add pooled text embedding to time embedding + pooled_embed = self.pooled_text_embeddings(pooled_text_embed) + time_embed = time_embed + pooled_embed + + # visual_embed shape: [batch_size, seq_len, model_dim] + visual_embed = self.visual_embeddings(hidden_states) + + # 3. Text rotary embeddings + text_rope = self.text_rope_embeddings(text_rope_pos) + + # 4. Text transformer blocks + for text_block in self.text_transformer_blocks: + if self.gradient_checkpointing and self.training: + text_embed = torch.utils.checkpoint.checkpoint( + text_block, text_embed, time_embed, text_rope, use_reentrant=False + ) + else: + text_embed = text_block(text_embed, time_embed, text_rope) + + # 5. Prepare visual rope + visual_shape = visual_embed.shape[:-1] + visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) + + visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1]) + visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:])) + + # 6. Visual transformer blocks + for visual_block in self.visual_transformer_blocks: + if self.gradient_checkpointing and self.training: + visual_embed = torch.utils.checkpoint.checkpoint( + visual_block, + visual_embed, + text_embed, + time_embed, + visual_rope, + # visual_rope_flat, + sparse_params, + use_reentrant=False, + ) + else: + visual_embed = visual_block( + visual_embed, text_embed, time_embed, visual_rope, sparse_params + ) + + # 7. 
Output projection + visual_embed = visual_embed.reshape(batch_size, num_frames, height // 2, width // 2, -1) + output = self.out_layer(visual_embed, text_embed, time_embed) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 190c7871d270..201d92afb07c 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -382,6 +382,7 @@ "WuerstchenPriorPipeline", ] _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"] + _import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", @@ -787,6 +788,7 @@ ) from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline + from .kandinsky5 import Kandinsky5T2VPipeline from .wuerstchen import ( WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, diff --git a/src/diffusers/pipelines/kandinsky5/__init__.py b/src/diffusers/pipelines/kandinsky5/__init__.py new file mode 100644 index 000000000000..af8e12421740 --- /dev/null +++ b/src/diffusers/pipelines/kandinsky5/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_kandinsky"] = ["Kandinsky5T2VPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_kandinsky import Kandinsky5T2VPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py new file mode 100644 index 000000000000..02eae1363303 --- /dev/null +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -0,0 +1,545 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
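The pipeline below hands channels-last latents and per-axis rotary positions to `Kandinsky5Transformer3DModel.forward`. A small sketch of the shapes involved, using the pipeline defaults from this patch (512x768, 25 frames, 8x spatial / 4x temporal VAE compression, patch size (1, 2, 2)); the numbers are illustrative only:

```python
import torch

height, width, num_frames = 512, 768, 25

# Latent grid after VAE compression (8x spatial, 4x temporal).
latent_frames = (num_frames - 1) // 4 + 1     # 7
latent_h, latent_w = height // 8, width // 8  # 64, 96

# `hidden_states` is channels-last: (batch, frames, height, width, channels).
hidden_states = torch.randn(1, latent_frames, latent_h, latent_w, 16)

# One index tensor per RoPE axis (time, height, width), at the patchified resolution.
visual_rope_pos = [
    torch.arange(latent_frames),
    torch.arange(latent_h // 2),
    torch.arange(latent_w // 2),
]
text_rope_pos = torch.arange(128)  # one position per text token; 128 is an illustrative length

# Visual sequence length seen by self-attention after (1, 2, 2) patchification.
seq_len = latent_frames * (latent_h // 2) * (latent_w // 2)
print(hidden_states.shape, seq_len)  # torch.Size([1, 7, 64, 96, 16]) 10752
```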
+ +import html +from typing import Any, Callable, Dict, List, Optional, Union + +import regex as re +import torch +from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer +import torchvision +from torchvision.transforms import ToPILImage + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import KandinskyLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo +from ...models.transformers import Kandinsky5Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import KandinskyPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + + ```python + >>> import torch + >>> from diffusers import Kandinsky5T2VPipeline, Kandinsky5Transformer3DModel + >>> from diffusers.utils import export_to_video + + >>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V") + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=512, + ... width=768, + ... num_frames=25, + ... num_inference_steps=50, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=6) + ``` +""" + + +class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using Kandinsky 5.0. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`Kandinsky5Transformer3DModel`]): + Conditional Transformer to denoise the encoded video latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder (Qwen2.5-VL). + tokenizer ([`AutoProcessor`]): + Tokenizer for Qwen2.5-VL. + text_encoder_2 ([`CLIPTextModel`]): + Frozen CLIP text encoder. + tokenizer_2 ([`CLIPTokenizer`]): + Tokenizer for CLIP. 
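+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            A flow-matching scheduler registered with the pipeline.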
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + transformer: Kandinsky5Transformer3DModel, + vae: AutoencoderKLHunyuanVideo, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2VLProcessor, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio + self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio + + def _encode_prompt_qwen( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 256, + ): + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Kandinsky specific prompt template + prompt_template = "\n".join([ + "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.", + "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", + "Describe the location of the video, main characters or objects and their action.", + "Describe the dynamism of the video and presented actions.", + "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.", + "Describe the visual effects, postprocessing and transitions if they are presented in the video.", + "Pay attention to the order of key actions shown in the scene.<|im_end|>", + "<|im_start|>user\n{}<|im_end|>", + ]) + crop_start = 129 + + full_texts = [prompt_template.format(p) for p in prompt] + + inputs = self.tokenizer( + text=full_texts, + images=None, + videos=None, + max_length=max_sequence_length + crop_start, + truncation=True, + return_tensors="pt", + padding=True, + ).to(device) + + with torch.no_grad(): + embeds = self.text_encoder( + input_ids=inputs["input_ids"], + return_dict=True, + output_hidden_states=True, + )["hidden_states"][-1][:, crop_start:] + + attention_mask = inputs["attention_mask"][:, crop_start:] + embeds = embeds[attention_mask.bool()] + cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) + cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32) + + # duplicate for each generation per prompt + batch_size = len(prompt) + seq_len = embeds.shape[0] // batch_size + embeds = embeds.reshape(batch_size, seq_len, -1) + embeds = embeds.repeat(1, num_videos_per_prompt, 1) + embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return embeds, cu_seqlens + + def _encode_prompt_clip( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_videos_per_prompt: int = 1, + ): + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + + inputs = self.tokenizer_2( + prompt, + max_length=77, + truncation=True, + add_special_tokens=True, + padding="max_length", + return_tensors="pt", + ).to(device) + + with torch.no_grad(): + pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] + + # duplicate for each generation per prompt + batch_size = len(prompt) + pooled_embed = 
pooled_embed.repeat(1, num_videos_per_prompt, 1) + pooled_embed = pooled_embed.view(batch_size * num_videos_per_prompt, -1) + + return pooled_embed + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + # Encode with Qwen2.5-VL + prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen( + prompt, device, num_videos_per_prompt + ) + pooled_embed = self._encode_prompt_clip(prompt, device, num_videos_per_prompt) + + if do_classifier_free_guidance: + negative_prompt = negative_prompt or "" + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen( + negative_prompt, device, num_videos_per_prompt + ) + negative_pooled_embed = self._encode_prompt_clip(negative_prompt, device, num_videos_per_prompt) + else: + negative_prompt_embeds = None + negative_pooled_embed = None + negative_cu_seqlens = None + + text_embeds = { + "text_embeds": prompt_embeds, + "pooled_embed": pooled_embed, + } + negative_text_embeds = { + "text_embeds": negative_prompt_embeds, + "pooled_embed": negative_pooled_embed, + } if do_classifier_free_guidance else None + + return text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + visual_cond: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + num_channels_latents, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + if visual_cond: + # For visual conditioning, concatenate with zeros and mask + visual_cond = torch.zeros_like(latents) + visual_cond_mask = torch.zeros( + [batch_size, num_latent_frames, int(height) // self.vae_scale_factor_spatial, int(width) // self.vae_scale_factor_spatial, 1], + dtype=latents.dtype, + device=latents.device + ) + latents = torch.cat([latents, visual_cond, visual_cond_mask], dim=-1) + + return latents + + def get_velocity( + self, + latents: torch.Tensor, + timestep: torch.Tensor, + text_embeds: Dict[str, torch.Tensor], + negative_text_embeds: Optional[Dict[str, torch.Tensor]], + visual_rope_pos: List[torch.Tensor], + text_rope_pos: torch.Tensor, + negative_text_rope_pos: torch.Tensor, + guidance_scale: float, + sparse_params: Optional[Dict] = None, + ): + # print(latents.shape, text_embeds["text_embeds"].shape, text_embeds["pooled_embed"].shape, timestep.shape, [el.shape for el in visual_rope_pos], text_rope_pos, sparse_params) + + pred_velocity = self.transformer( + latents, + text_embeds["text_embeds"], + text_embeds["pooled_embed"], + timestep * 1000, # Scale to match training + visual_rope_pos, + text_rope_pos, + scale_factor=(1, 2, 2), # From Kandinsky config + sparse_params=sparse_params, + return_dict=False + )[0] + + if guidance_scale > 1.0 and negative_text_embeds is not None: + uncond_pred_velocity = self.transformer( + latents, + negative_text_embeds["text_embeds"], + negative_text_embeds["pooled_embed"], + timestep * 1000, + visual_rope_pos, + negative_text_rope_pos, + scale_factor=(1, 2, 2), + sparse_params=sparse_params, + return_dict=False + )[0] + + pred_velocity = uncond_pred_velocity + guidance_scale * ( + pred_velocity - uncond_pred_velocity + ) + + return pred_velocity + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 768, + num_frames: int = 25, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + scheduler_scale: float = 10.0, + num_videos_per_prompt: int = 1, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to avoid during video generation. + height (`int`, defaults to `512`): + The height in pixels of the generated video. + width (`int`, defaults to `768`): + The width in pixels of the generated video. + num_frames (`int`, defaults to `25`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in classifier-free guidance. + scheduler_scale (`float`, defaults to `10.0`): + Scale factor for the custom flow matching scheduler. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator`, *optional*): + A torch generator to make generation deterministic. 
+ latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`KandinskyPipelineOutput`]. + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step. + + Examples: + + Returns: + [`~KandinskyPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + # 1. Check inputs + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + # 2. Define call parameters + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + + device = self._execution_device + do_classifier_free_guidance = guidance_scale > 1.0 + + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + + # 3. Encode input prompt + text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + ) + + # 4. Prepare timesteps (Kandinsky uses custom flow matching) + timesteps = torch.linspace(1, 0, num_inference_steps + 1, device=device) + timesteps = scheduler_scale * timesteps / (1 + (scheduler_scale - 1) * timesteps) + + # 5. Prepare latent variables + num_channels_latents = 16 + latents = self.prepare_latents( + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=16, + height=height, + width=width, + num_frames=num_frames, + visual_cond=self.transformer.visual_cond, + dtype=self.transformer.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 6. Prepare rope positions + visual_rope_pos = [ + torch.arange(num_frames // 4 + 1, device=device), + torch.arange(height // 8 // 2, device=device), # patch size 2 + torch.arange(width // 8 // 2, device=device), + ] + + text_rope_pos = torch.arange(prompt_cu_seqlens[-1].item(), device=device) + + negative_text_rope_pos = ( + torch.arange(negative_cu_seqlens[-1].item(), device=device) + if negative_cu_seqlens is not None + else None + ) + + # 7. Prepare sparse attention params if needed + sparse_params = None # Can be extended based on Kandinsky attention config + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, (timestep, timestep_diff) in enumerate(zip(timesteps[:-1], torch.diff(timesteps))): + # Expand timestep to match batch size + time = timestep.unsqueeze(0) + + pred_velocity = self.get_velocity( + latents, + time, + text_embeds, + negative_text_embeds, + visual_rope_pos, + text_rope_pos, + negative_text_rope_pos, + guidance_scale, + sparse_params, + ) + + # Update latents using flow matching + latents[:, :, :, :, :16] = latents[:, :, :, :, :16] + timestep_diff * pred_velocity + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % 1 == 0): + progress_bar.update() + + latents = latents[:, :, :, :, :16] + + # 9. Decode latents to video + if output_type != "latent": + latents = latents.to(self.vae.dtype) + # Reshape and normalize latents + video = latents.reshape( + batch_size, + num_videos_per_prompt, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + height // 8, + width // 8, + 16, + ) + video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width] + video = video.reshape(batch_size * num_videos_per_prompt, 16, (num_frames - 1) // self.vae_scale_factor_temporal + 1, height // 8, width // 8) + + # Normalize and decode + video = video / self.vae.config.scaling_factor + video = self.vae.decode(video).sample + video = ((video.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8) + + # Convert to output format + if output_type == "pil": + if num_frames == 1: + # Single image + video = [ToPILImage()(frame.squeeze(1)) for frame in video] + else: + # Video frames + video = [video[i] for i in range(video.shape[0])] + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return KandinskyPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_output.py b/src/diffusers/pipelines/kandinsky5/pipeline_output.py new file mode 100644 index 000000000000..ed77d42a9a83 --- /dev/null +++ b/src/diffusers/pipelines/kandinsky5/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class KandinskyPipelineOutput(BaseOutput): + r""" + Output class for Wan pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. 
+ """ + + frames: torch.Tensor From 7db6093c539b84450bbc683193b75c91cfc599e3 Mon Sep 17 00:00:00 2001 From: leffff Date: Mon, 6 Oct 2025 12:43:04 +0000 Subject: [PATCH 02/77] updates --- .../transformers/transformer_kandinsky.py | 125 ++++++++----- .../kandinsky5/pipeline_kandinsky.py | 171 +++++++----------- 2 files changed, 144 insertions(+), 152 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index a057cc13cc0f..cca83988a762 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -35,6 +35,23 @@ logger = logging.get_logger(__name__) +if torch.cuda.get_device_capability()[0] >= 9: + try: + from flash_attn_interface import flash_attn_func as FA + except: + FA = None + + try: + from flash_attn import flash_attn_func as FA + except: + FA = None +else: + try: + from flash_attn import flash_attn_func as FA + except: + FA = None + + # @torch.compile() @torch.autocast(device_type="cuda", dtype=torch.float32) def apply_scale_shift_norm(norm, x, scale, shift): @@ -99,7 +116,7 @@ def __init__(self, visual_dim, model_dim, patch_size): super().__init__() self.patch_size = patch_size self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim) - + def forward(self, x): batch_size, duration, height, width, dim = x.shape x = ( @@ -107,7 +124,7 @@ def forward(self, x): batch_size, duration // self.patch_size[0], self.patch_size[0], - height // self.patch_size[1], + height // self.patch_size[1], self.patch_size[1], width // self.patch_size[2], self.patch_size[2], @@ -169,24 +186,23 @@ def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0): @torch.autocast(device_type="cuda", enabled=False) def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): batch_size, duration, height, width = shape + args_t = self.args_0[pos[0]] / scale_factor[0] args_h = self.args_1[pos[1]] / scale_factor[1] args_w = self.args_2[pos[2]] / scale_factor[2] - # Replicate the original logic with batch dimension args_t_expanded = args_t.view(1, duration, 1, 1, -1).expand(batch_size, -1, height, width, -1) args_h_expanded = args_h.view(1, 1, height, 1, -1).expand(batch_size, duration, -1, width, -1) args_w_expanded = args_w.view(1, 1, 1, width, -1).expand(batch_size, duration, height, -1, -1) - # Concatenate along the last dimension - args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1) # [B, D, H, W, F] + args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1) cosine = torch.cos(args) sine = torch.sin(args) - rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) # [B, D, H, W, F, 4] - rope = rope.view(*rope.shape[:-1], 2, 2) # [B, D, H, W, F, 2, 2] - return rope.unsqueeze(-4) # [B, D, H, 1, W, F, 2, 2] - + rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) + rope = rope.view(*rope.shape[:-1], 2, 2) + return rope.unsqueeze(-4) + class Modulation(nn.Module): def __init__(self, time_dim, model_dim, num_params): @@ -230,11 +246,14 @@ def forward(self, x, rope): key = apply_rotary(key, rope).type_as(key) # Use torch's scaled_dot_product_attention - out = F.scaled_dot_product_attention( - query, - key, - value, - ).flatten(-2, -1) + # print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionEnc SHAPE") + # out = F.scaled_dot_product_attention( + # query.permute(0, 2, 1, 3), + # key.permute(0, 2, 1, 3), + # value.permute(0, 2, 1, 3), + # ).permute(0, 2, 1, 
3).flatten(-2, -1) + + out = FA(q=query, k=key, v=value).flatten(-2, -1) out = self.out_layer(out) return out @@ -270,11 +289,15 @@ def forward(self, x, rope, sparse_params=None): key = apply_rotary(key, rope).type_as(key) # Use standard attention (can be extended with sparse attention) - out = F.scaled_dot_product_attention( - query, - key, - value, - ).flatten(-2, -1) + # out = F.scaled_dot_product_attention( + # query.permute(0, 2, 1, 3), + # key.permute(0, 2, 1, 3), + # value.permute(0, 2, 1, 3), + # ).permute(0, 2, 1, 3).flatten(-2, -1) + + # print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionDec SHAPE") + + out = FA(q=query, k=key, v=value).flatten(-2, -1) out = self.out_layer(out) return out @@ -306,11 +329,15 @@ def forward(self, x, cond): query = self.query_norm(query.float()).type_as(query) key = self.key_norm(key.float()).type_as(key) - out = F.scaled_dot_product_attention( - query.permute(0, 2, 1, 3), - key.permute(0, 2, 1, 3), - value.permute(0, 2, 1, 3), - ).permute(0, 2, 1, 3).flatten(-2, -1) + # out = F.scaled_dot_product_attention( + # query.permute(0, 2, 1, 3), + # key.permute(0, 2, 1, 3), + # value.permute(0, 2, 1, 3), + # ).permute(0, 2, 1, 3).flatten(-2, -1) + + # print(query.shape, key.shape, value.shape, "QKV MultiheadCrossAttention SHAPE") + + out = FA(q=query, k=key, v=value).flatten(-2, -1) out = self.out_layer(out) return out @@ -339,19 +366,18 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim): self.feed_forward = FeedForward(model_dim, ff_dim) def forward(self, x, time_embed, rope): - self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1) + self_attn_params, ff_params = torch.chunk( + self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 + ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) - - out = self.self_attention_norm(x) - out = out * (scale + 1.0) + shift + out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift) out = self.self_attention(out, rope) - x = x + gate * out + x = apply_gate_sum(x, out, gate) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) - out = self.feed_forward_norm(x) - out = out * (scale + 1.0) + shift + out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift) out = self.feed_forward(out) - x = x + gate * out + x = apply_gate_sum(x, out, gate) return x @@ -371,26 +397,22 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim): def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): self_attn_params, cross_attn_params, ff_params = torch.chunk( - self.visual_modulation(time_embed), 3, dim=-1 + self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1 ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) - - visual_out = self.self_attention_norm(visual_embed) - visual_out = visual_out * (scale + 1.0) + shift + visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift) visual_out = self.self_attention(visual_out, rope, sparse_params) - visual_embed = visual_embed + gate * visual_out + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) - visual_out = self.cross_attention_norm(visual_embed) - visual_out = visual_out * (scale + 1.0) + shift + visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift) visual_out = self.cross_attention(visual_out, text_embed) - visual_embed = visual_embed + gate * visual_out + visual_embed = apply_gate_sum(visual_embed, 
visual_out, gate) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) - visual_out = self.feed_forward_norm(visual_embed) - visual_out = visual_out * (scale + 1.0) + shift + visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift) visual_out = self.feed_forward(visual_out) - visual_embed = visual_embed + gate * visual_out + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) return visual_embed @@ -575,7 +597,7 @@ def forward( # 1. Process text embeddings text_embed = self.text_embeddings(encoder_hidden_states) time_embed = self.time_embeddings(timestep) - + # Add pooled text embedding to time embedding pooled_embed = self.pooled_text_embeddings(pooled_text_embed) time_embed = time_embed + pooled_embed @@ -587,22 +609,29 @@ def forward( text_rope = self.text_rope_embeddings(text_rope_pos) # 4. Text transformer blocks + i = 0 for text_block in self.text_transformer_blocks: if self.gradient_checkpointing and self.training: text_embed = torch.utils.checkpoint.checkpoint( text_block, text_embed, time_embed, text_rope, use_reentrant=False ) + else: text_embed = text_block(text_embed, time_embed, text_rope) + i += 1 + # 5. Prepare visual rope visual_shape = visual_embed.shape[:-1] visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) + + # visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1]) + # visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:])) + visual_embed = visual_embed.flatten(1, 3) + visual_rope = visual_rope.flatten(1, 3) - visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1]) - visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:])) - # 6. Visual transformer blocks + i = 0 for visual_block in self.visual_transformer_blocks: if self.gradient_checkpointing and self.training: visual_embed = torch.utils.checkpoint.checkpoint( @@ -619,6 +648,8 @@ def forward( visual_embed = visual_block( visual_embed, text_embed, time_embed, visual_rope, sparse_params ) + + i += 1 # 7. 
Output projection visual_embed = visual_embed.reshape(batch_size, num_frames, height // 2, width // 2, -1) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 02eae1363303..9dbf31fea960 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -220,19 +220,14 @@ def encode_prompt( ): device = device or self._execution_device - # Encode with Qwen2.5-VL - prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen( - prompt, device, num_videos_per_prompt - ) + prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen(prompt, device, num_videos_per_prompt) pooled_embed = self._encode_prompt_clip(prompt, device, num_videos_per_prompt) if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen( - negative_prompt, device, num_videos_per_prompt - ) + negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen(negative_prompt, device, num_videos_per_prompt) negative_pooled_embed = self._encode_prompt_clip(negative_prompt, device, num_videos_per_prompt) else: negative_prompt_embeds = None @@ -264,23 +259,25 @@ def prepare_latents( latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: if latents is not None: - return latents.to(device=device, dtype=dtype) - - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - shape = ( - batch_size, - num_latent_frames, - int(height) // self.vae_scale_factor_spatial, - int(width) // self.vae_scale_factor_spatial, - num_channels_latents, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." + num_latent_frames = latents.shape[1] + latents = latents.to(device=device, dtype=dtype) + + else: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + num_channels_latents, ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) if visual_cond: # For visual conditioning, concatenate with zeros and mask @@ -294,50 +291,6 @@ def prepare_latents( return latents - def get_velocity( - self, - latents: torch.Tensor, - timestep: torch.Tensor, - text_embeds: Dict[str, torch.Tensor], - negative_text_embeds: Optional[Dict[str, torch.Tensor]], - visual_rope_pos: List[torch.Tensor], - text_rope_pos: torch.Tensor, - negative_text_rope_pos: torch.Tensor, - guidance_scale: float, - sparse_params: Optional[Dict] = None, - ): - # print(latents.shape, text_embeds["text_embeds"].shape, text_embeds["pooled_embed"].shape, timestep.shape, [el.shape for el in visual_rope_pos], text_rope_pos, sparse_params) - - pred_velocity = self.transformer( - latents, - text_embeds["text_embeds"], - text_embeds["pooled_embed"], - timestep * 1000, # Scale to match training - visual_rope_pos, - text_rope_pos, - scale_factor=(1, 2, 2), # From Kandinsky config - sparse_params=sparse_params, - return_dict=False - )[0] - - if guidance_scale > 1.0 and negative_text_embeds is not None: - uncond_pred_velocity = self.transformer( - latents, - negative_text_embeds["text_embeds"], - negative_text_embeds["pooled_embed"], - timestep * 1000, - visual_rope_pos, - negative_text_rope_pos, - scale_factor=(1, 2, 2), - sparse_params=sparse_params, - return_dict=False - )[0] - - pred_velocity = uncond_pred_velocity + guidance_scale * ( - pred_velocity - uncond_pred_velocity - ) - - return pred_velocity @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -402,11 +355,9 @@ def __call__( indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ - # 1. Check inputs if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") - # 2. Define call parameters if isinstance(prompt, str): batch_size = 1 else: @@ -415,16 +366,18 @@ def __call__( device = self._execution_device do_classifier_free_guidance = guidance_scale > 1.0 - if num_frames % self.vae_scale_factor_temporal != 1: logger.warning( f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." ) num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - - # 3. Encode input prompt text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, @@ -433,11 +386,6 @@ def __call__( device=device, ) - # 4. Prepare timesteps (Kandinsky uses custom flow matching) - timesteps = torch.linspace(1, 0, num_inference_steps + 1, device=device) - timesteps = scheduler_scale * timesteps / (1 + (scheduler_scale - 1) * timesteps) - - # 5. Prepare latent variables num_channels_latents = 16 latents = self.prepare_latents( batch_size=batch_size * num_videos_per_prompt, @@ -451,11 +399,12 @@ def __call__( generator=generator, latents=latents, ) + + visual_cond = latents[:, :, :, :, 16:] - # 6. 
Prepare rope positions visual_rope_pos = [ torch.arange(num_frames // 4 + 1, device=device), - torch.arange(height // 8 // 2, device=device), # patch size 2 + torch.arange(height // 8 // 2, device=device), torch.arange(width // 8 // 2, device=device), ] @@ -467,31 +416,43 @@ def __call__( else None ) - # 7. Prepare sparse attention params if needed - sparse_params = None # Can be extended based on Kandinsky attention config - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, (timestep, timestep_diff) in enumerate(zip(timesteps[:-1], torch.diff(timesteps))): - # Expand timestep to match batch size - time = timestep.unsqueeze(0) - - pred_velocity = self.get_velocity( - latents, - time, - text_embeds, - negative_text_embeds, - visual_rope_pos, - text_rope_pos, - negative_text_rope_pos, - guidance_scale, - sparse_params, - ) - - # Update latents using flow matching - latents[:, :, :, :, :16] = latents[:, :, :, :, :16] + timestep_diff * pred_velocity + for i, t in enumerate(timesteps): + timestep = t.unsqueeze(0) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + # print(latents.shape) + pred_velocity = self.transformer( + latents, + text_embeds["text_embeds"], + text_embeds["pooled_embed"], + timestep, + visual_rope_pos, + text_rope_pos, + scale_factor=(1, 2, 2), + sparse_params=None, + return_dict=False + )[0] + + if guidance_scale > 1.0 and negative_text_embeds is not None: + uncond_pred_velocity = self.transformer( + latents, + negative_text_embeds["text_embeds"], + negative_text_embeds["pooled_embed"], + timestep, + visual_rope_pos, + negative_text_rope_pos, + scale_factor=(1, 2, 2), + sparse_params=None, + return_dict=False + )[0] + + pred_velocity = uncond_pred_velocity + guidance_scale * ( + pred_velocity - uncond_pred_velocity + ) + + latents = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0] + latents = torch.cat([latents, visual_cond], dim=-1) if callback_on_step_end is not None: callback_kwargs = {} @@ -499,8 +460,8 @@ def __call__( callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs) latents = callback_outputs.pop("latents", latents) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % 1 == 0): + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() latents = latents[:, :, :, :, :16] @@ -524,7 +485,6 @@ def __call__( video = video / self.vae.config.scaling_factor video = self.vae.decode(video).sample video = ((video.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8) - # Convert to output format if output_type == "pil": if num_frames == 1: @@ -533,6 +493,7 @@ def __call__( else: # Video frames video = [video[i] for i in range(video.shape[0])] + else: video = latents From a0cf07f7e086b73a49b46e2e87d0ebb10056dcd4 Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 9 Oct 2025 15:09:50 +0000 Subject: [PATCH 03/77] fix 5sec generation --- .../transformers/transformer_kandinsky.py | 660 +++++++++--------- .../kandinsky5/pipeline_kandinsky.py | 51 +- 2 files changed, 368 insertions(+), 343 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index cca83988a762..3bbb9421f7ce 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ 
b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -13,21 +13,27 @@ # limitations under the License. import math -from typing import Any, Dict, Optional, Tuple, Union, List +from typing import Any, Dict, List, Optional, Tuple, Union +from einops import rearrange import torch import torch.nn as nn import torch.nn.functional as F +from torch import BoolTensor, IntTensor, Tensor, nn +from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, + flex_attention) from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import (USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, + unscale_lora_layers) from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin -from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..embeddings import (PixArtAlphaTextProjection, TimestepEmbedding, + Timesteps, get_1d_rotary_pos_embed) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm @@ -35,34 +41,129 @@ logger = logging.get_logger(__name__) -if torch.cuda.get_device_capability()[0] >= 9: - try: - from flash_attn_interface import flash_attn_func as FA - except: - FA = None - - try: - from flash_attn import flash_attn_func as FA - except: - FA = None -else: - try: - from flash_attn import flash_attn_func as FA - except: - FA = None - - -# @torch.compile() +def exist(item): + return item is not None + + +def freeze(model): + for p in model.parameters(): + p.requires_grad = False + return model + + +@torch.autocast(device_type="cuda", enabled=False) +def get_freqs(dim, max_period=10000.0): + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=dim, dtype=torch.float32) + / dim + ) + return freqs + + +def fractal_flatten(x, rope, shape, block_mask=False): + if block_mask: + pixel_size = 8 + x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0) + rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0) + x = x.flatten(1, 2) + rope = rope.flatten(1, 2) + else: + x = x.flatten(1, 3) + rope = rope.flatten(1, 3) + return x, rope + + +def fractal_unflatten(x, shape, block_mask=False): + if block_mask: + pixel_size = 8 + x = x.reshape(-1, pixel_size**2, *x.shape[1:]) + x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0) + else: + x = x.reshape(*shape, *x.shape[2:]) + return x + + +def local_patching(x, shape, group_size, dim=0): + duration, height, width = shape + g1, g2, g3 = group_size + x = x.reshape( + *x.shape[:dim], + duration // g1, + g1, + height // g2, + g2, + width // g3, + g3, + *x.shape[dim + 3 :] + ) + x = x.permute( + *range(len(x.shape[:dim])), + dim, + dim + 2, + dim + 4, + dim + 1, + dim + 3, + dim + 5, + *range(dim + 6, len(x.shape)) + ) + x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3) + return x + + +def local_merge(x, shape, group_size, dim=0): + duration, height, width = shape + g1, g2, g3 = group_size + x = x.reshape( + *x.shape[:dim], + duration // g1, + height // g2, + width // g3, + g1, + g2, + g3, + *x.shape[dim + 2 :] + ) + x = x.permute( + 
*range(len(x.shape[:dim])), + dim, + dim + 3, + dim + 1, + dim + 4, + dim + 2, + dim + 5, + *range(dim + 6, len(x.shape)) + ) + x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3) + return x + + +def sdpa(q, k, v): + query = q.transpose(1, 2).contiguous() + key = k.transpose(1, 2).contiguous() + value = v.transpose(1, 2).contiguous() + out = ( + F.scaled_dot_product_attention( + query, + key, + value + ) + .transpose(1, 2) + .contiguous() + ) + return out + + @torch.autocast(device_type="cuda", dtype=torch.float32) def apply_scale_shift_norm(norm, x, scale, shift): return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16) -# @torch.compile() + @torch.autocast(device_type="cuda", dtype=torch.float32) def apply_gate_sum(x, out, gate): return (x + gate * out).to(torch.bfloat16) -# @torch.compile() + @torch.autocast(device_type="cuda", enabled=False) def apply_rotary(x, rope): x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) @@ -70,16 +171,6 @@ def apply_rotary(x, rope): return x_out.reshape(*x.shape).to(torch.bfloat16) -@torch.autocast(device_type="cuda", enabled=False) -def get_freqs(dim, max_period=10000.0): - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=dim, dtype=torch.float32) - / dim - ) - return freqs - - class TimeEmbeddings(nn.Module): def __init__(self, model_dim, time_dim, max_period=10000.0): super().__init__() @@ -93,12 +184,16 @@ def __init__(self, model_dim, time_dim, max_period=10000.0): self.activation = nn.SiLU() self.out_layer = nn.Linear(time_dim, time_dim, bias=True) + @torch.autocast(device_type="cuda", dtype=torch.float32) def forward(self, time): args = torch.outer(time, self.freqs.to(device=time.device)) time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) return time_embed + def reset_dtype(self): + self.freqs = get_freqs(self.model_dim // 2, self.max_period) + class TextEmbeddings(nn.Module): def __init__(self, text_dim, model_dim): @@ -116,7 +211,7 @@ def __init__(self, visual_dim, model_dim, patch_size): super().__init__() self.patch_size = patch_size self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim) - + def forward(self, x): batch_size, duration, height, width, dim = x.shape x = ( @@ -124,7 +219,7 @@ def forward(self, x): batch_size, duration // self.patch_size[0], self.patch_size[0], - height // self.patch_size[1], + height // self.patch_size[1], self.patch_size[1], width // self.patch_size[2], self.patch_size[2], @@ -137,15 +232,6 @@ def forward(self, x): class RoPE1D(nn.Module): - """ - 1D Rotary Positional Embeddings for text sequences. 
- - Args: - dim: Dimension of the rotary embeddings - max_pos: Maximum sequence length - max_period: Maximum period for sinusoidal embeddings - """ - def __init__(self, dim, max_pos=1024, max_period=10000.0): super().__init__() self.max_period = max_period @@ -153,22 +239,21 @@ def __init__(self, dim, max_pos=1024, max_period=10000.0): self.max_pos = max_pos freq = get_freqs(dim // 2, max_period) pos = torch.arange(max_pos, dtype=freq.dtype) - self.register_buffer("args", torch.outer(pos, freq), persistent=False) + self.register_buffer(f"args", torch.outer(pos, freq), persistent=False) + @torch.autocast(device_type="cuda", enabled=False) def forward(self, pos): - """ - Args: - pos: Position indices of shape [seq_len] or [batch_size, seq_len] - - Returns: - Rotary embeddings of shape [seq_len, 1, 2, 2] - """ args = self.args[pos] cosine = torch.cos(args) sine = torch.sin(args) rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) rope = rope.view(*rope.shape[:-1], 2, 2) return rope.unsqueeze(-4) + + def reset_dtype(self): + freq = get_freqs(self.dim // 2, self.max_period).to(self.args.device) + pos = torch.arange(self.max_pos, dtype=freq.dtype, device=freq.device) + self.args = torch.outer(pos, freq) class RoPE3D(nn.Module): @@ -186,22 +271,29 @@ def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0): @torch.autocast(device_type="cuda", enabled=False) def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): batch_size, duration, height, width = shape - args_t = self.args_0[pos[0]] / scale_factor[0] args_h = self.args_1[pos[1]] / scale_factor[1] args_w = self.args_2[pos[2]] / scale_factor[2] - args_t_expanded = args_t.view(1, duration, 1, 1, -1).expand(batch_size, -1, height, width, -1) - args_h_expanded = args_h.view(1, 1, height, 1, -1).expand(batch_size, duration, -1, width, -1) - args_w_expanded = args_w.view(1, 1, 1, width, -1).expand(batch_size, duration, height, -1, -1) - - args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1) - + args = torch.cat( + [ + args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1), + args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1), + args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1), + ], + dim=-1, + ) cosine = torch.cos(args) sine = torch.sin(args) rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) rope = rope.view(*rope.shape[:-1], 2, 2) return rope.unsqueeze(-4) + + def reset_dtype(self): + for i, (axes_dim, ax_max_pos) in enumerate(zip(self.axes_dims, self.max_pos)): + freq = get_freqs(axes_dim // 2, self.max_period).to(self.args_0.device) + pos = torch.arange(ax_max_pos, dtype=freq.dtype, device=freq.device) + setattr(self, f'args_{i}', torch.outer(pos, freq)) class Modulation(nn.Module): @@ -212,10 +304,11 @@ def __init__(self, time_dim, model_dim, num_params): self.out_layer.weight.data.zero_() self.out_layer.bias.data.zero_() + @torch.autocast(device_type="cuda", dtype=torch.float32) def forward(self, x): return self.out_layer(self.activation(x)) - + class MultiheadSelfAttentionEnc(nn.Module): def __init__(self, num_channels, head_dim): super().__init__() @@ -227,9 +320,10 @@ def __init__(self, num_channels, head_dim): self.to_value = nn.Linear(num_channels, num_channels, bias=True) self.query_norm = nn.RMSNorm(head_dim) self.key_norm = nn.RMSNorm(head_dim) + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - def forward(self, x, rope): + def get_qkv(self, x): query = self.to_query(x) key = 
self.to_key(x) value = self.to_value(x) @@ -239,26 +333,31 @@ def forward(self, x, rope): key = key.reshape(*shape, self.num_heads, -1) value = value.reshape(*shape, self.num_heads, -1) - query = self.query_norm(query.float()).type_as(query) - key = self.key_norm(key.float()).type_as(key) + return query, key, value + + def norm_qk(self, q, k): + q = self.query_norm(q.float()).type_as(q) + k = self.key_norm(k.float()).type_as(k) + return q, k + + def scaled_dot_product_attention(self, query, key, value): + out = sdpa(q=query, k=key, v=value).flatten(-2, -1) + return out + + def out_l(self, x): + return self.out_layer(x) + def forward(self, x, rope): + query, key, value = self.get_qkv(x) + query, key = self.norm_qk(query, key) query = apply_rotary(query, rope).type_as(query) key = apply_rotary(key, rope).type_as(key) - # Use torch's scaled_dot_product_attention - # print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionEnc SHAPE") - # out = F.scaled_dot_product_attention( - # query.permute(0, 2, 1, 3), - # key.permute(0, 2, 1, 3), - # value.permute(0, 2, 1, 3), - # ).permute(0, 2, 1, 3).flatten(-2, -1) - - out = FA(q=query, k=key, v=value).flatten(-2, -1) + out = self.scaled_dot_product_attention(query, key, value) - out = self.out_layer(out) + out = self.out_l(out) return out - class MultiheadSelfAttentionDec(nn.Module): def __init__(self, num_channels, head_dim): super().__init__() @@ -270,9 +369,10 @@ def __init__(self, num_channels, head_dim): self.to_value = nn.Linear(num_channels, num_channels, bias=True) self.query_norm = nn.RMSNorm(head_dim) self.key_norm = nn.RMSNorm(head_dim) + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - def forward(self, x, rope, sparse_params=None): + def get_qkv(self, x): query = self.to_query(x) key = self.to_key(x) value = self.to_value(x) @@ -282,24 +382,29 @@ def forward(self, x, rope, sparse_params=None): key = key.reshape(*shape, self.num_heads, -1) value = value.reshape(*shape, self.num_heads, -1) - query = self.query_norm(query.float()).type_as(query) - key = self.key_norm(key.float()).type_as(key) + return query, key, value + + def norm_qk(self, q, k): + q = self.query_norm(q.float()).type_as(q) + k = self.key_norm(k.float()).type_as(k) + return q, k + + def attention(self, query, key, value): + out = sdpa(q=query, k=key, v=value).flatten(-2, -1) + return out + + def out_l(self, x): + return self.out_layer(x) + def forward(self, x, rope, sparse_params=None): + query, key, value = self.get_qkv(x) + query, key = self.norm_qk(query, key) query = apply_rotary(query, rope).type_as(query) key = apply_rotary(key, rope).type_as(key) - # Use standard attention (can be extended with sparse attention) - # out = F.scaled_dot_product_attention( - # query.permute(0, 2, 1, 3), - # key.permute(0, 2, 1, 3), - # value.permute(0, 2, 1, 3), - # ).permute(0, 2, 1, 3).flatten(-2, -1) - - # print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionDec SHAPE") - - out = FA(q=query, k=key, v=value).flatten(-2, -1) + out = self.attention(query, key, value) - out = self.out_layer(out) + out = self.out_l(out) return out @@ -314,32 +419,39 @@ def __init__(self, num_channels, head_dim): self.to_value = nn.Linear(num_channels, num_channels, bias=True) self.query_norm = nn.RMSNorm(head_dim) self.key_norm = nn.RMSNorm(head_dim) + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - def forward(self, x, cond): + def get_qkv(self, x, cond): query = self.to_query(x) key = self.to_key(cond) value = self.to_value(cond) - + 
shape, cond_shape = query.shape[:-1], key.shape[:-1] query = query.reshape(*shape, self.num_heads, -1) key = key.reshape(*cond_shape, self.num_heads, -1) value = value.reshape(*cond_shape, self.num_heads, -1) - - query = self.query_norm(query.float()).type_as(query) - key = self.key_norm(key.float()).type_as(key) - - # out = F.scaled_dot_product_attention( - # query.permute(0, 2, 1, 3), - # key.permute(0, 2, 1, 3), - # value.permute(0, 2, 1, 3), - # ).permute(0, 2, 1, 3).flatten(-2, -1) - - # print(query.shape, key.shape, value.shape, "QKV MultiheadCrossAttention SHAPE") - out = FA(q=query, k=key, v=value).flatten(-2, -1) + return query, key, value + + def norm_qk(self, q, k): + q = self.query_norm(q.float()).type_as(q) + k = self.key_norm(k.float()).type_as(k) + return q, k - out = self.out_layer(out) + def attention(self, query, key, value): + out = sdpa(q=query, k=key, v=value).flatten(-2, -1) + return out + + def out_l(self, x): + return self.out_layer(x) + + def forward(self, x, cond): + query, key, value = self.get_qkv(x, cond) + query, key = self.norm_qk(query, key) + + out = self.attention(query, key, value) + out = self.out_l(out) return out @@ -354,6 +466,48 @@ def forward(self, x): return self.out_layer(self.activation(self.in_layer(x))) +class OutLayer(nn.Module): + def __init__(self, model_dim, time_dim, visual_dim, patch_size): + super().__init__() + self.patch_size = patch_size + self.modulation = Modulation(time_dim, model_dim, 2) + self.norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.out_layer = nn.Linear( + model_dim, math.prod(patch_size) * visual_dim, bias=True + ) + + def forward(self, visual_embed, text_embed, time_embed): + shift, scale = torch.chunk(self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) + visual_embed = apply_scale_shift_norm( + self.norm, + visual_embed, + scale[:, None, None], + shift[:, None, None], + ).type_as(visual_embed) + x = self.out_layer(visual_embed) + + batch_size, duration, height, width, _ = x.shape + x = ( + x.view( + batch_size, + duration, + height, + width, + -1, + self.patch_size[0], + self.patch_size[1], + self.patch_size[2], + ) + .permute(0, 1, 5, 2, 6, 3, 7, 4) + .flatten(1, 2) + .flatten(2, 3) + .flatten(3, 4) + ) + return x + + + + class TransformerEncoderBlock(nn.Module): def __init__(self, model_dim, time_dim, ff_dim, head_dim): super().__init__() @@ -366,9 +520,7 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim): self.feed_forward = FeedForward(model_dim, ff_dim) def forward(self, x, time_embed, rope): - self_attn_params, ff_params = torch.chunk( - self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 - ) + self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift) out = self.self_attention(out, rope) @@ -416,246 +568,116 @@ def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): return visual_embed -class OutLayer(nn.Module): - def __init__(self, model_dim, time_dim, visual_dim, patch_size): - super().__init__() - self.patch_size = patch_size - self.modulation = Modulation(time_dim, model_dim, 2) - self.norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.out_layer = nn.Linear( - model_dim, math.prod(patch_size) * visual_dim, bias=True - ) - - def forward(self, visual_embed, text_embed, time_embed): - # Handle the new batch dimension: [batch, duration, height, width, 
model_dim] - batch_size, duration, height, width, _ = visual_embed.shape - - shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1) - - # Apply modulation with proper broadcasting for the new shape - visual_embed = apply_scale_shift_norm( - self.norm, - visual_embed, - scale[:, None, None, None], # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1] - shift[:, None, None, None], # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1] - ).type_as(visual_embed) - - x = self.out_layer(visual_embed) - - # Now x has shape [batch, duration, height, width, patch_prod * visual_dim] - x = ( - x.view( - batch_size, - duration, - height, - width, - -1, - self.patch_size[0], - self.patch_size[1], - self.patch_size[2], - ) - .permute(0, 5, 1, 6, 2, 7, 3, 4) # [batch, patch_t, duration, patch_h, height, patch_w, width, features] - .flatten(1, 2) # [batch, patch_t * duration, height, patch_w, width, features] - .flatten(2, 3) # [batch, patch_t * duration, patch_h * height, width, features] - .flatten(3, 4) # [batch, patch_t * duration, patch_h * height, patch_w * width] - ) - return x - - -@maybe_allow_in_graph class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin): - r""" - A 3D Transformer model for video generation used in Kandinsky 5.0. - - This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic - methods implemented for all models (such as downloading or saving). - - Args: - in_visual_dim (`int`, defaults to 16): - Number of channels in the input visual latent. - out_visual_dim (`int`, defaults to 16): - Number of channels in the output visual latent. - time_dim (`int`, defaults to 512): - Dimension of the time embeddings. - patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): - Patch size for the visual embeddings (temporal, height, width). - model_dim (`int`, defaults to 1792): - Hidden dimension of the transformer model. - ff_dim (`int`, defaults to 7168): - Intermediate dimension of the feed-forward networks. - num_text_blocks (`int`, defaults to 2): - Number of transformer blocks in the text encoder. - num_visual_blocks (`int`, defaults to 32): - Number of transformer blocks in the visual decoder. - axes_dims (`Tuple[int]`, defaults to `(16, 24, 24)`): - Dimensions for the rotary positional embeddings (temporal, height, width). - visual_cond (`bool`, defaults to `True`): - Whether to use visual conditioning (for image/video conditioning). - in_text_dim (`int`, defaults to 3584): - Dimension of the text embeddings from Qwen2.5-VL. - in_text_dim2 (`int`, defaults to 768): - Dimension of the pooled text embeddings from CLIP. """ - + A 3D Diffusion Transformer model for video-like data. 
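+
+    A minimal shape sketch for the inputs `forward` expects (the sizes below are illustrative assumptions,
+    not requirements; in the pipeline the model is run on GPU in bfloat16 under autocast):
+
+    ```python
+    import torch
+
+    # Latent video in channels-last layout: (batch, frames, height, width, channels).
+    # With `visual_cond=False` the channel count equals `in_visual_dim`.
+    batch, frames, height, width = 1, 8, 32, 32
+    hidden_states = torch.randn(batch, frames, height, width, 4)
+    encoder_hidden_states = torch.randn(batch, 20, 3584)  # Qwen2.5-VL token features
+    pooled_projections = torch.randn(batch, 768)  # pooled CLIP embedding
+    timestep = torch.tensor([1000.0])
+
+    # RoPE positions follow the patchified grid: patch_size=(1, 2, 2) keeps frames and halves height/width.
+    visual_rope_pos = [torch.arange(frames), torch.arange(height // 2), torch.arange(width // 2)]
+    text_rope_pos = torch.arange(encoder_hidden_states.shape[1])
+
+    # model = Kandinsky5Transformer3DModel()  # default config
+    # sample = model(
+    #     hidden_states, encoder_hidden_states, timestep, pooled_projections,
+    #     visual_rope_pos, text_rope_pos, return_dict=True,
+    # ).sample
+    ```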
+ """ + @register_to_config def __init__( self, - in_visual_dim: int = 16, - out_visual_dim: int = 16, - time_dim: int = 512, - patch_size: Tuple[int, int, int] = (1, 2, 2), - model_dim: int = 1792, - ff_dim: int = 7168, - num_text_blocks: int = 2, - num_visual_blocks: int = 32, - axes_dims: Tuple[int, int, int] = (16, 24, 24), - visual_cond: bool = True, - in_text_dim: int = 3584, - in_text_dim2: int = 768, + in_visual_dim=4, + in_text_dim=3584, + in_text_dim2=768, + time_dim=512, + out_visual_dim=4, + patch_size=(1, 2, 2), + model_dim=2048, + ff_dim=5120, + num_text_blocks=2, + num_visual_blocks=32, + axes_dims=(16, 24, 24), + visual_cond=False, ): super().__init__() - + + head_dim = sum(axes_dims) self.in_visual_dim = in_visual_dim self.model_dim = model_dim self.patch_size = patch_size self.visual_cond = visual_cond - # Calculate head dimension for attention - head_dim = sum(axes_dims) - - # Determine visual embedding dimension based on conditioning visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim - - # 1. Embedding layers self.time_embeddings = TimeEmbeddings(model_dim, time_dim) self.text_embeddings = TextEmbeddings(in_text_dim, model_dim) self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim) self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size) - # 2. Rotary positional embeddings self.text_rope_embeddings = RoPE1D(head_dim) - self.visual_rope_embeddings = RoPE3D(axes_dims) - - # 3. Transformer blocks - self.text_transformer_blocks = nn.ModuleList([ - TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) - for _ in range(num_text_blocks) - ]) + self.text_transformer_blocks = nn.ModuleList( + [ + TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) + for _ in range(num_text_blocks) + ] + ) - self.visual_transformer_blocks = nn.ModuleList([ - TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim) - for _ in range(num_visual_blocks) - ]) + self.visual_rope_embeddings = RoPE3D(axes_dims) + self.visual_transformer_blocks = nn.ModuleList( + [ + TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim) + for _ in range(num_visual_blocks) + ] + ) - # 4. 
Output layer self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size) - self.gradient_checkpointing = False + def before_text_transformer_blocks(self, text_embed, time, pooled_text_embed, x, + text_rope_pos): + text_embed = self.text_embeddings(text_embed) + time_embed = self.time_embeddings(time) + time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed) + visual_embed = self.visual_embeddings(x) + text_rope = self.text_rope_embeddings(text_rope_pos) + text_rope = text_rope.unsqueeze(dim=0) + return text_embed, time_embed, text_rope, visual_embed + + def before_visual_transformer_blocks(self, visual_embed, visual_rope_pos, scale_factor, + sparse_params): + visual_shape = visual_embed.shape[:-1] + visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) + to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False + visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape, + block_mask=to_fractal) + return visual_embed, visual_shape, to_fractal, visual_rope + + def after_blocks(self, visual_embed, visual_shape, to_fractal, text_embed, time_embed): + visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal) + x = self.out_layer(visual_embed, text_embed, time_embed) + return x def forward( self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - pooled_text_embed: torch.Tensor, - timestep: torch.Tensor, - visual_rope_pos: List[torch.Tensor], - text_rope_pos: torch.Tensor, - scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0), - sparse_params: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[torch.Tensor, Transformer2DModelOutput]: - """ - Forward pass of the Kandinsky 5.0 3D Transformer. - - Args: - hidden_states (`torch.Tensor`): - Input visual latent tensor of shape `(batch_size, num_frames, height, width, channels)`. - encoder_hidden_states (`torch.Tensor`): - Text embeddings from Qwen2.5-VL of shape `(batch_size, sequence_length, text_dim)`. - pooled_text_embed (`torch.Tensor`): - Pooled text embeddings from CLIP of shape `(batch_size, pooled_text_dim)`. - timestep (`torch.Tensor`): - Timestep tensor of shape `(batch_size,)` or `(batch_size * num_frames,)`. - visual_rope_pos (`List[torch.Tensor]`): - List of tensors for visual rotary positional embeddings [temporal, height, width]. - text_rope_pos (`torch.Tensor`): - Tensor for text rotary positional embeddings. - scale_factor (`Tuple[float, float, float]`, defaults to `(1.0, 1.0, 1.0)`): - Scale factors for rotary positional embeddings. - sparse_params (`Dict[str, Any]`, *optional*): - Parameters for sparse attention. - return_dict (`bool`, defaults to `True`): - Whether to return a dictionary or a tensor. - - Returns: - [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: - If `return_dict` is `True`, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned, - otherwise a `tuple` where the first element is the sample tensor. - """ - batch_size, num_frames, height, width, channels = hidden_states.shape - - # 1. Process text embeddings - text_embed = self.text_embeddings(encoder_hidden_states) - time_embed = self.time_embeddings(timestep) - - # Add pooled text embedding to time embedding - pooled_embed = self.pooled_text_embeddings(pooled_text_embed) - time_embed = time_embed + pooled_embed - - # visual_embed shape: [batch_size, seq_len, model_dim] - visual_embed = self.visual_embeddings(hidden_states) - - # 3. 
Text rotary embeddings - text_rope = self.text_rope_embeddings(text_rope_pos) + hidden_states, # x + encoder_hidden_states, #text_embed + timestep, # time + pooled_projections, #pooled_text_embed, + visual_rope_pos, + text_rope_pos, + scale_factor=(1.0, 1.0, 1.0), + sparse_params=None, + return_dict=True, + ): + x = hidden_states + text_embed = encoder_hidden_states + time = timestep + pooled_text_embed = pooled_projections + + text_embed, time_embed, text_rope, visual_embed = self.before_text_transformer_blocks( + text_embed, time, pooled_text_embed, x, text_rope_pos) - # 4. Text transformer blocks - i = 0 - for text_block in self.text_transformer_blocks: - if self.gradient_checkpointing and self.training: - text_embed = torch.utils.checkpoint.checkpoint( - text_block, text_embed, time_embed, text_rope, use_reentrant=False - ) - - else: - text_embed = text_block(text_embed, time_embed, text_rope) + for text_transformer_block in self.text_transformer_blocks: + text_embed = text_transformer_block(text_embed, time_embed, text_rope) - i += 1 + visual_embed, visual_shape, to_fractal, visual_rope = self.before_visual_transformer_blocks( + visual_embed, visual_rope_pos, scale_factor, sparse_params) - # 5. Prepare visual rope - visual_shape = visual_embed.shape[:-1] - visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) - - # visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1]) - # visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:])) - visual_embed = visual_embed.flatten(1, 3) - visual_rope = visual_rope.flatten(1, 3) + for visual_transformer_block in self.visual_transformer_blocks: + visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed, + visual_rope, sparse_params) + + x = self.after_blocks(visual_embed, visual_shape, to_fractal, text_embed, time_embed) - # 6. Visual transformer blocks - i = 0 - for visual_block in self.visual_transformer_blocks: - if self.gradient_checkpointing and self.training: - visual_embed = torch.utils.checkpoint.checkpoint( - visual_block, - visual_embed, - text_embed, - time_embed, - visual_rope, - # visual_rope_flat, - sparse_params, - use_reentrant=False, - ) - else: - visual_embed = visual_block( - visual_embed, text_embed, time_embed, visual_rope, sparse_params - ) - - i += 1 - - # 7. Output projection - visual_embed = visual_embed.reshape(batch_size, num_frames, height // 2, width // 2, -1) - output = self.out_layer(visual_embed, text_embed, time_embed) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) + if return_dict: + return Transformer2DModelOutput(sample=x) + + return x diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 9dbf31fea960..214b2b953c1c 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -300,7 +300,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 768, - num_frames: int = 25, + num_frames: int = 121, num_inference_steps: int = 50, guidance_scale: float = 5.0, scheduler_scale: float = 10.0, @@ -354,6 +354,11 @@ def __call__( the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. 
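+
+        Note: `num_frames - 1` should be divisible by the VAE temporal compression ratio (4 for the
+        HunyuanVideo VAE used here), e.g. the default of 121 frames; other values are adjusted to the
+        nearest valid frame count, as the warning below describes.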
""" + self.transformer.time_embeddings.reset_dtype() + self.transformer.text_rope_embeddings.reset_dtype() + self.transformer.visual_rope_embeddings.reset_dtype() + + dtype = self.transformer.dtype if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -394,7 +399,7 @@ def __call__( width=width, num_frames=num_frames, visual_cond=self.transformer.visual_cond, - dtype=self.transformer.dtype, + dtype=dtype, device=device, generator=generator, latents=latents, @@ -418,41 +423,39 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - timestep = t.unsqueeze(0) + timestep = t.unsqueeze(0).flatten() - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - # print(latents.shape) + with torch.autocast(device_type="cuda", dtype=dtype): pred_velocity = self.transformer( - latents, - text_embeds["text_embeds"], - text_embeds["pooled_embed"], - timestep, - visual_rope_pos, - text_rope_pos, + hidden_states=latents, + encoder_hidden_states=text_embeds["text_embeds"], + pooled_projections=text_embeds["pooled_embed"], + timestep=timestep, + visual_rope_pos=visual_rope_pos, + text_rope_pos=text_rope_pos, scale_factor=(1, 2, 2), sparse_params=None, - return_dict=False - )[0] - + return_dict=True + ).sample + if guidance_scale > 1.0 and negative_text_embeds is not None: uncond_pred_velocity = self.transformer( - latents, - negative_text_embeds["text_embeds"], - negative_text_embeds["pooled_embed"], - timestep, - visual_rope_pos, - negative_text_rope_pos, + hidden_states=latents, + encoder_hidden_states=negative_text_embeds["text_embeds"], + pooled_projections=negative_text_embeds["pooled_embed"], + timestep=timestep, + visual_rope_pos=visual_rope_pos, + text_rope_pos=negative_text_rope_pos, scale_factor=(1, 2, 2), sparse_params=None, - return_dict=False - )[0] + return_dict=True + ).sample pred_velocity = uncond_pred_velocity + guidance_scale * ( pred_velocity - uncond_pred_velocity ) - latents = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0] - latents = torch.cat([latents, visual_cond], dim=-1) + latents[:, :, :, :, :16] = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} From c8f3a36fba49799c21161858872f03ffde7bef57 Mon Sep 17 00:00:00 2001 From: leffff Date: Fri, 10 Oct 2025 14:39:59 +0000 Subject: [PATCH 04/77] rewrite Kandinsky5T2VPipeline to diffusers style --- .../kandinsky5/pipeline_kandinsky.py | 531 ++++++++++++++---- 1 file changed, 407 insertions(+), 124 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 214b2b953c1c..cea079251bc3 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -75,6 +75,101 @@ ``` """ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Any, Callable, Dict, List, Optional, Union + +import regex as re +import torch +from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer +import torchvision +from torchvision.transforms import ToPILImage + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import KandinskyLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo +from ...models.transformers import Kandinsky5Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import KandinskyPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + + ```python + >>> import torch + >>> from diffusers import Kandinsky5T2VPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V") + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=512, + ... width=768, + ... num_frames=25, + ... num_inference_steps=50, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=6) + ``` +""" + + +def basic_clean(text): + if is_ftfy_available(): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): r""" @@ -96,9 +191,11 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): Frozen CLIP text encoder. tokenizer_2 ([`CLIPTokenizer`]): Tokenizer for CLIP. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. 
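+
+    A minimal loading sketch (the checkpoint id matches the example above; the offload order follows
+    `model_cpu_offload_seq`):
+
+    ```python
+    import torch
+    from diffusers import Kandinsky5T2VPipeline
+
+    pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V", torch_dtype=torch.bfloat16)
+    pipe.enable_model_cpu_offload()  # keeps only the currently active sub-model on the GPU
+    ```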
""" - model_cpu_offload_seq = "text_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( @@ -125,6 +222,7 @@ def __init__( self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) def _encode_prompt_qwen( self, @@ -132,9 +230,12 @@ def _encode_prompt_qwen( device: Optional[torch.device] = None, num_videos_per_prompt: int = 1, max_sequence_length: int = 256, + dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(p) for p in prompt] # Kandinsky specific prompt template prompt_template = "\n".join([ @@ -180,16 +281,19 @@ def _encode_prompt_qwen( embeds = embeds.repeat(1, num_videos_per_prompt, 1) embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - return embeds, cu_seqlens + return embeds.to(dtype), cu_seqlens def _encode_prompt_clip( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, num_videos_per_prompt: int = 1, + dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(p) for p in prompt] inputs = self.tokenizer_2( prompt, @@ -208,7 +312,7 @@ def _encode_prompt_clip( pooled_embed = pooled_embed.repeat(1, num_videos_per_prompt, 1) pooled_embed = pooled_embed.view(batch_size * num_videos_per_prompt, -1) - return pooled_embed + return pooled_embed.to(dtype) def encode_prompt( self, @@ -216,34 +320,151 @@ def encode_prompt( negative_prompt: Optional[Union[str, List[str]]] = None, do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length for text encoding. 
+ device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ device = device or self._execution_device - prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen(prompt, device, num_videos_per_prompt) - pooled_embed = self._encode_prompt_clip(prompt, device, num_videos_per_prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( + prompt=prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + dtype=dtype, + ) + prompt_embeds_clip = self._encode_prompt_clip( + prompt=prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + dtype=dtype, + ) + else: + prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = prompt_embeds - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - - negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen(negative_prompt, device, num_videos_per_prompt) - negative_pooled_embed = self._encode_prompt_clip(negative_prompt, device, num_videos_per_prompt) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds_qwen, negative_cu_seqlens = self._encode_prompt_qwen( + prompt=negative_prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + dtype=dtype, + ) + negative_prompt_embeds_clip = self._encode_prompt_clip( + prompt=negative_prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + dtype=dtype, + ) else: - negative_prompt_embeds = None - negative_pooled_embed = None + negative_prompt_embeds_qwen = None + negative_prompt_embeds_clip = None negative_cu_seqlens = None - text_embeds = { - "text_embeds": prompt_embeds, - "pooled_embed": pooled_embed, + prompt_embeds_dict = { + "text_embeds": prompt_embeds_qwen, + "pooled_embed": prompt_embeds_clip, } - negative_text_embeds = { - "text_embeds": negative_prompt_embeds, - "pooled_embed": negative_pooled_embed, + negative_prompt_embeds_dict = { + "text_embeds": negative_prompt_embeds_qwen, + "pooled_embed": negative_prompt_embeds_clip, } if do_classifier_free_guidance else None - return text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens + return prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") def prepare_latents( self, @@ -252,34 +473,31 @@ def prepare_latents( height: int = 480, width: int = 832, num_frames: int = 81, - visual_cond: bool = False, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: if latents is not None: - num_latent_frames = latents.shape[1] - latents = latents.to(device=device, dtype=dtype) - - else: - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - shape = ( - batch_size, - num_latent_frames, - int(height) // self.vae_scale_factor_spatial, - int(width) // self.vae_scale_factor_spatial, - num_channels_latents, + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + num_channels_latents, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - - if visual_cond: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + if self.transformer.visual_cond: # For visual conditioning, concatenate with zeros and mask visual_cond = torch.zeros_like(latents) visual_cond_mask = torch.zeros( @@ -291,26 +509,46 @@ def prepare_latents( return latents + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], + prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 768, - num_frames: int = 121, + num_frames: int = 25, num_inference_steps: int = 50, guidance_scale: float = 5.0, scheduler_scale: float = 10.0, - num_videos_per_prompt: int = 1, - generator: Optional[torch.Generator] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, **kwargs, ): r""" @@ -318,9 +556,10 @@ def __call__( Args: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the video generation. + The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to avoid during video generation. + The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). height (`int`, defaults to `512`): The height in pixels of the generated video. width (`int`, defaults to `768`): @@ -335,82 +574,109 @@ def __call__( Scale factor for the custom flow matching scheduler. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. - generator (`torch.Generator`, *optional*): + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A torch generator to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`KandinskyPipelineOutput`]. - callback_on_step_end (`Callable`, *optional*): + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function that is called at the end of each denoising step. 
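Since `do_classifier_free_guidance` is simply `guidance_scale > 1.0`, each denoising step runs the transformer twice (conditional and negative-prompt passes) and blends the two velocity predictions. A dummy-tensor sketch of that blend, using the same update that appears in the denoising loop further down (tensor shapes here are illustrative only):

```python
import torch

# Illustrative shapes: (batch, latent frames, latent height, latent width, channels).
guidance_scale = 5.0
pred_cond = torch.randn(1, 31, 64, 96, 16)    # conditional velocity prediction
pred_uncond = torch.randn(1, 31, 64, 96, 16)  # unconditional (negative-prompt) prediction

# Same combination as in the denoising loop: move away from the unconditional
# prediction in the direction of the conditional one, scaled by guidance_scale.
pred_velocity = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
print(pred_velocity.shape)  # torch.Size([1, 31, 64, 96, 16])
```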
+ callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length for text encoding. Examples: Returns: [`~KandinskyPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned where - the first element is a list with the generated images and the second element is a list of `bool`s - indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + the first element is a list with the generated images. """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Reset embeddings dtype self.transformer.time_embeddings.reset_dtype() self.transformer.text_rope_embeddings.reset_dtype() self.transformer.visual_rope_embeddings.reset_dtype() - - dtype = self.transformer.dtype - - if height % 16 != 0 or width % 16 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") - if isinstance(prompt, str): - batch_size = 1 - else: - batch_size = len(prompt) + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) - device = self._execution_device - do_classifier_free_guidance = guidance_scale > 1.0 - if num_frames % self.vae_scale_factor_temporal != 1: logger.warning( f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." ) num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) - - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt( + self._guidance_scale = guidance_scale + self._interrupt = False + + device = self._execution_device + dtype = self.transformer.dtype + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, device=device, + dtype=dtype, ) + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. 
Prepare latent variables num_channels_latents = 16 latents = self.prepare_latents( - batch_size=batch_size * num_videos_per_prompt, - num_channels_latents=16, - height=height, - width=width, - num_frames=num_frames, - visual_cond=self.transformer.visual_cond, - dtype=dtype, - device=device, - generator=generator, - latents=latents, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + dtype, + device, + generator, + latents, ) - - visual_cond = latents[:, :, :, :, 16:] + # 6. Prepare rope positions + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 visual_rope_pos = [ - torch.arange(num_frames // 4 + 1, device=device), - torch.arange(height // 8 // 2, device=device), - torch.arange(width // 8 // 2, device=device), + torch.arange(num_latent_frames, device=device), + torch.arange(height // self.vae_scale_factor_spatial // 2, device=device), + torch.arange(width // self.vae_scale_factor_spatial // 2, device=device), ] text_rope_pos = torch.arange(prompt_cu_seqlens[-1].item(), device=device) @@ -421,52 +687,72 @@ def __call__( else None ) + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + timestep = t.unsqueeze(0).flatten() - with torch.autocast(device_type="cuda", dtype=dtype): - pred_velocity = self.transformer( - hidden_states=latents, - encoder_hidden_states=text_embeds["text_embeds"], - pooled_projections=text_embeds["pooled_embed"], - timestep=timestep, + + + # Predict noise residual + # with torch.autocast(device_type="cuda", dtype=dtype): + pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=prompt_embeds_dict["text_embeds"].to(dtype), + pooled_projections=prompt_embeds_dict["pooled_embed"].to(dtype), + timestep=timestep.to(dtype), + visual_rope_pos=visual_rope_pos, + text_rope_pos=text_rope_pos, + scale_factor=(1, 2, 2), + sparse_params=None, + return_dict=True + ).sample + + if self.do_classifier_free_guidance and negative_prompt_embeds_dict is not None: + uncond_pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=negative_prompt_embeds_dict["text_embeds"].to(dtype), + pooled_projections=negative_prompt_embeds_dict["pooled_embed"].to(dtype), + timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, - text_rope_pos=text_rope_pos, - scale_factor=(1, 2, 2), + text_rope_pos=negative_text_rope_pos, + scale_factor=(1, 2, 2), sparse_params=None, return_dict=True ).sample - if guidance_scale > 1.0 and negative_text_embeds is not None: - uncond_pred_velocity = self.transformer( - hidden_states=latents, - encoder_hidden_states=negative_text_embeds["text_embeds"], - pooled_projections=negative_text_embeds["pooled_embed"], - timestep=timestep, - visual_rope_pos=visual_rope_pos, - text_rope_pos=negative_text_rope_pos, - scale_factor=(1, 2, 2), - sparse_params=None, - return_dict=True - ).sample - - pred_velocity = uncond_pred_velocity + guidance_scale * ( - pred_velocity - uncond_pred_velocity - ) + pred_velocity = uncond_pred_velocity + guidance_scale * ( + pred_velocity - uncond_pred_velocity + ) - latents[:, :, :, :, :16] = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0] + # Compute previous sample + latents[:, :, :, :, :16] = self.scheduler.step( + pred_velocity, t, latents[:, 
:, :, :, :16], return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs) + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) - + prompt_embeds_dict = callback_outputs.pop("prompt_embeds", prompt_embeds_dict) + negative_prompt_embeds_dict = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds_dict) + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post-processing latents = latents[:, :, :, :, :16] # 9. Decode latents to video @@ -477,26 +763,23 @@ def __call__( batch_size, num_videos_per_prompt, (num_frames - 1) // self.vae_scale_factor_temporal + 1, - height // 8, - width // 8, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, 16, ) video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width] - video = video.reshape(batch_size * num_videos_per_prompt, 16, (num_frames - 1) // self.vae_scale_factor_temporal + 1, height // 8, width // 8) + video = video.reshape( + batch_size * num_videos_per_prompt, + 16, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial + ) # Normalize and decode video = video / self.vae.config.scaling_factor video = self.vae.decode(video).sample - video = ((video.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8) - # Convert to output format - if output_type == "pil": - if num_frames == 1: - # Single image - video = [ToPILImage()(frame.squeeze(1)) for frame in video] - else: - # Video frames - video = [video[i] for i in range(video.shape[0])] - + video = self.video_processor.postprocess_video(video, output_type=output_type) else: video = latents From 723d149dc1dad0db009abcb210e671a775b23db6 Mon Sep 17 00:00:00 2001 From: leffff Date: Fri, 10 Oct 2025 17:00:23 +0000 Subject: [PATCH 05/77] add multiprompt support --- .../kandinsky5/pipeline_kandinsky.py | 40 ++++++++++++------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index cea079251bc3..a417d9967548 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -269,18 +269,21 @@ def _encode_prompt_qwen( output_hidden_states=True, )["hidden_states"][-1][:, crop_start:] + batch_size = len(prompt) + attention_mask = inputs["attention_mask"][:, crop_start:] - embeds = embeds[attention_mask.bool()] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32) - # duplicate for each generation per prompt - batch_size = len(prompt) - seq_len = embeds.shape[0] // batch_size - embeds = embeds.reshape(batch_size, seq_len, -1) - embeds = embeds.repeat(1, num_videos_per_prompt, 1) - embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) +# # duplicate for each generation per prompt +# seq_len = embeds.shape[0] // batch_size +# embeds = embeds.reshape(batch_size, seq_len, -1) +# embeds = embeds.repeat(1, num_videos_per_prompt, 1) +# embeds = embeds.view(batch_size * 
num_videos_per_prompt, seq_len, -1) +# print(embeds.shape, cu_seqlens, "ENCODE PROMPT") + embeds = torch.cat([embeds[i].unsqueeze(dim=0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)], dim=0) + return embeds.to(dtype), cu_seqlens def _encode_prompt_clip( @@ -679,10 +682,10 @@ def __call__( torch.arange(width // self.vae_scale_factor_spatial // 2, device=device), ] - text_rope_pos = torch.arange(prompt_cu_seqlens[-1].item(), device=device) + text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device) negative_text_rope_pos = ( - torch.arange(negative_cu_seqlens[-1].item(), device=device) + torch.arange(negative_cu_seqlens.diff().max().item(), device=device) if negative_cu_seqlens is not None else None ) @@ -696,12 +699,19 @@ def __call__( if self.interrupt: continue - timestep = t.unsqueeze(0).flatten() - - - - # Predict noise residual - # with torch.autocast(device_type="cuda", dtype=dtype): + timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt) + + # Predict noise residual + # print( + # latents.shape, + # prompt_embeds_dict["text_embeds"].shape, + # prompt_embeds_dict["pooled_embed"].shape, + # timestep.shape, + # [el.shape for el in visual_rope_pos], + # text_rope_pos.shape, + # prompt_cu_seqlens, + # ) + pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=prompt_embeds_dict["text_embeds"].to(dtype), From 22e14bdac82fd5c100c4b1f34f5726c9c4aa4705 Mon Sep 17 00:00:00 2001 From: leffff Date: Fri, 10 Oct 2025 17:03:09 +0000 Subject: [PATCH 06/77] remove prints in pipeline --- .../kandinsky5/pipeline_kandinsky.py | 20 +------------------ 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index a417d9967548..5d1eb7d60507 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -274,14 +274,6 @@ def _encode_prompt_qwen( attention_mask = inputs["attention_mask"][:, crop_start:] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32) - -# # duplicate for each generation per prompt -# seq_len = embeds.shape[0] // batch_size -# embeds = embeds.reshape(batch_size, seq_len, -1) -# embeds = embeds.repeat(1, num_videos_per_prompt, 1) -# embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - -# print(embeds.shape, cu_seqlens, "ENCODE PROMPT") embeds = torch.cat([embeds[i].unsqueeze(dim=0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)], dim=0) return embeds.to(dtype), cu_seqlens @@ -642,7 +634,7 @@ def __call__( batch_size = len(prompt) else: batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] - + # 3. 
Encode input prompt prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt( prompt=prompt, @@ -702,16 +694,6 @@ def __call__( timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt) # Predict noise residual - # print( - # latents.shape, - # prompt_embeds_dict["text_embeds"].shape, - # prompt_embeds_dict["pooled_embed"].shape, - # timestep.shape, - # [el.shape for el in visual_rope_pos], - # text_rope_pos.shape, - # prompt_cu_seqlens, - # ) - pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=prompt_embeds_dict["text_embeds"].to(dtype), From 70fa62baeaa019e7a47abb5e3a2662ba509d5bb8 Mon Sep 17 00:00:00 2001 From: leffff Date: Sun, 12 Oct 2025 21:59:23 +0000 Subject: [PATCH 07/77] add nabla attention --- .../transformers/transformer_kandinsky.py | 84 +++++++++++++++++-- .../kandinsky5/pipeline_kandinsky.py | 69 ++++++++++++++- 2 files changed, 142 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 3bbb9421f7ce..45d4ccdf9af3 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -64,8 +64,8 @@ def get_freqs(dim, max_period=10000.0): def fractal_flatten(x, rope, shape, block_mask=False): if block_mask: pixel_size = 8 - x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0) - rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0) + x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1) + rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1) x = x.flatten(1, 2) rope = rope.flatten(1, 2) else: @@ -77,15 +77,15 @@ def fractal_flatten(x, rope, shape, block_mask=False): def fractal_unflatten(x, shape, block_mask=False): if block_mask: pixel_size = 8 - x = x.reshape(-1, pixel_size**2, *x.shape[1:]) - x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0) + x = x.reshape(x.shape[0], -1, pixel_size**2, *x.shape[2:]) + x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1) else: x = x.reshape(*shape, *x.shape[2:]) return x def local_patching(x, shape, group_size, dim=0): - duration, height, width = shape + batch_size, duration, height, width = shape g1, g2, g3 = group_size x = x.reshape( *x.shape[:dim], @@ -112,7 +112,7 @@ def local_patching(x, shape, group_size, dim=0): def local_merge(x, shape, group_size, dim=0): - duration, height, width = shape + batch_size, duration, height, width = shape g1, g2, g3 = group_size x = x.reshape( *x.shape[:dim], @@ -138,6 +138,36 @@ def local_merge(x, shape, group_size, dim=0): return x +def nablaT_v2( + q: Tensor, + k: Tensor, + sta: Tensor, + thr: float = 0.9, +) -> BlockMask: + # Map estimation + B, h, S, D = q.shape + s1 = S // 64 + qa = q.reshape(B, h, s1, 64, D).mean(-2) + ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1) + map = qa @ ka + + map = torch.softmax(map / math.sqrt(D), dim=-1) + # Map binarization + vals, inds = map.sort(-1) + cvals = vals.cumsum_(-1) + mask = (cvals >= 1 - thr).int() + mask = mask.gather(-1, inds.argsort(-1)) + + mask = torch.logical_or(mask, sta) + + # BlockMask creation + kv_nb = mask.sum(-1).to(torch.int32) + kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32) + return BlockMask.from_kv_blocks( + torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None + ) + + def sdpa(q, k, v): query = q.transpose(1, 2).contiguous() 
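The `nablaT_v2` helper added above estimates a block-level attention map from 64-token block averages and then keeps, for each query block, just enough key blocks to cover a `thr` fraction of the attention mass. A self-contained sketch of that map-estimation and binarization step on toy shapes (the STA union and the `BlockMask` construction for `flex_attention` are omitted here):

```python
import math
import torch

# Toy sizes: 1 batch, 1 head, 256 tokens -> 4 blocks of 64 tokens each.
B, h, S, D, thr = 1, 1, 256, 32, 0.9
q, k = torch.randn(B, h, S, D), torch.randn(B, h, S, D)

s1 = S // 64
qa = q.reshape(B, h, s1, 64, D).mean(-2)                    # block-averaged queries
ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)  # block-averaged keys
attn_map = torch.softmax(qa @ ka / math.sqrt(D), dim=-1)    # (B, h, s1, s1) block map

# Binarization: per query block, keep the smallest set of key blocks whose
# cumulative probability reaches `thr`, i.e. drop the low-mass tail of each row.
vals, inds = attn_map.sort(-1)
mask = (vals.cumsum(-1) >= 1 - thr).int().gather(-1, inds.argsort(-1))
print(mask[0, 0])  # 0/1 matrix: which key blocks each query block will attend to
```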
key = k.transpose(1, 2).contiguous() @@ -392,6 +422,29 @@ def norm_qk(self, q, k): def attention(self, query, key, value): out = sdpa(q=query, k=key, v=value).flatten(-2, -1) return out + + def nabla(self, query, key, value, sparse_params=None): + query = query.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + block_mask = nablaT_v2( + query, + key, + sparse_params["sta_mask"], + thr=sparse_params["P"], + ) + out = ( + flex_attention( + query, + key, + value, + block_mask=block_mask + ) + .transpose(1, 2) + .contiguous() + ) + out = out.flatten(-2, -1) + return out def out_l(self, x): return self.out_layer(x) @@ -402,7 +455,10 @@ def forward(self, x, rope, sparse_params=None): query = apply_rotary(query, rope).type_as(query) key = apply_rotary(key, rope).type_as(key) - out = self.attention(query, key, value) + if sparse_params is not None: + out = self.nabla(query, key, value, sparse_params=sparse_params) + else: + out = self.attention(query, key, value) out = self.out_l(out) return out @@ -587,7 +643,18 @@ def __init__( num_text_blocks=2, num_visual_blocks=32, axes_dims=(16, 24, 24), - visual_cond=False, + visual_cond=False, + attention_type: str = "regular", + attention_causal: bool = None, #Deffault for Nabla: false, + attention_local: bool = None, #Deffault for Nabla: false, + attention_glob:bool = None, #Deffault for Nabla: false, + attention_window: int = None, #Deffault for Nabla: 3 + attention_P: float = None, #Deffault for Nabla: 0.9 + attention_wT: int = None, #Deffault for Nabla: 11 + attention_wW:int = None, #Deffault for Nabla: 3 + attention_wH:int = None, #Deffault for Nabla: 3 + attention_add_sta: bool = None, #Deffault for Nabla: true + attention_method: str = None, #Deffault for Nabla: "topcdf" ): super().__init__() @@ -596,6 +663,7 @@ def __init__( self.model_dim = model_dim self.patch_size = patch_size self.visual_cond = visual_cond + self.attention_type = attention_type visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim self.time_embeddings = TimeEmbeddings(model_dim, time_dim) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 5d1eb7d60507..05230a604fa4 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -223,6 +223,66 @@ def __init__( self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + @staticmethod + def fast_sta_nabla( + T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda" + ) -> torch.Tensor: + l = torch.Tensor([T, H, W]).amax() + r = torch.arange(0, l, 1, dtype=torch.int16, device=device) + mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs() + sta_t, sta_h, sta_w = ( + mat[:T, :T].flatten(), + mat[:H, :H].flatten(), + mat[:W, :W].flatten(), + ) + sta_t = sta_t <= wT // 2 + sta_h = sta_h <= wH // 2 + sta_w = sta_w <= wW // 2 + sta_hw = ( + (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)) + .reshape(H, H, W, W) + .transpose(1, 2) + .flatten() + ) + sta = ( + (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)) + .reshape(T, T, H * W, H * W) + .transpose(1, 2) + ) + return sta.reshape(T * H * W, T * H * W) + + def get_sparse_params(self, sample, device): + assert self.transformer.config.patch_size[0] == 1 + B, T, H, W, _ = sample.shape + 
T, H, W = ( + T // self.transformer.config.patch_size[0], + H // self.transformer.config.patch_size[1], + W // self.transformer.config.patch_size[2], + ) + if self.transformer.config.attention_type == "nabla": + sta_mask = self.fast_sta_nabla( + T, H // 8, W // 8, + self.transformer.config.attention_wT, self.transformer.config.attention_wH, self.transformer.config.attention_wW, + device=device + ) + + sparse_params = { + "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0), + "attention_type": self.transformer.config.attention_type, + "to_fractal": True, + "P": self.transformer.config.attention_P, + "wT": self.transformer.config.attention_wT, + "wW": self.transformer.config.attention_wW, + "wH": self.transformer.config.attention_wH, + "add_sta": self.transformer.config.attention_add_sta, + "visual_shape": (T, H, W), + "method": self.transformer.config.attention_method, + } + else: + sparse_params = None + + return sparse_params def _encode_prompt_qwen( self, @@ -681,8 +741,11 @@ def __call__( if negative_cu_seqlens is not None else None ) + + # 7. Sparse Params + sparse_params = self.get_sparse_params(latents, device) - # 7. Denoising loop + # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) @@ -702,7 +765,7 @@ def __call__( visual_rope_pos=visual_rope_pos, text_rope_pos=text_rope_pos, scale_factor=(1, 2, 2), - sparse_params=None, + sparse_params=sparse_params, return_dict=True ).sample @@ -715,7 +778,7 @@ def __call__( visual_rope_pos=visual_rope_pos, text_rope_pos=negative_text_rope_pos, scale_factor=(1, 2, 2), - sparse_params=None, + sparse_params=sparse_params, return_dict=True ).sample From 45240a7317d12228d16c3fad31920dbb939cc538 Mon Sep 17 00:00:00 2001 From: leffff Date: Mon, 13 Oct 2025 12:27:03 +0000 Subject: [PATCH 08/77] Wrap Transformer in Diffusers style --- .../transformers/transformer_kandinsky.py | 301 ++++++++++++------ .../kandinsky5/pipeline_kandinsky.py | 4 +- 2 files changed, 209 insertions(+), 96 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 45d4ccdf9af3..4ba7e144030f 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -201,7 +201,7 @@ def apply_rotary(x, rope): return x_out.reshape(*x.shape).to(torch.bfloat16) -class TimeEmbeddings(nn.Module): +class Kandinsky5TimeEmbeddings(nn.Module): def __init__(self, model_dim, time_dim, max_period=10000.0): super().__init__() assert model_dim % 2 == 0 @@ -225,7 +225,7 @@ def reset_dtype(self): self.freqs = get_freqs(self.model_dim // 2, self.max_period) -class TextEmbeddings(nn.Module): +class Kandinsky5TextEmbeddings(nn.Module): def __init__(self, text_dim, model_dim): super().__init__() self.in_layer = nn.Linear(text_dim, model_dim, bias=True) @@ -236,7 +236,7 @@ def forward(self, text_embed): return self.norm(text_embed).type_as(text_embed) -class VisualEmbeddings(nn.Module): +class Kandinsky5VisualEmbeddings(nn.Module): def __init__(self, visual_dim, model_dim, patch_size): super().__init__() self.patch_size = patch_size @@ -261,7 +261,7 @@ def forward(self, x): return self.in_layer(x) -class RoPE1D(nn.Module): +class Kandinsky5RoPE1D(nn.Module): def __init__(self, dim, max_pos=1024, max_period=10000.0): super().__init__() self.max_period = max_period @@ -286,7 +286,7 @@ def reset_dtype(self): self.args = torch.outer(pos, freq) -class RoPE3D(nn.Module): 
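`fast_sta_nabla` builds the sliding-tile attention prior that `nablaT_v2` unions with the data-dependent map: a (T·H·W) × (T·H·W) boolean matrix allowing attention only within a local window along each axis. A CPU-sized sketch that reproduces the construction on toy dimensions and inspects how sparse the resulting mask is (the pipeline itself calls it on the 8×8-block grid of the latent video):

```python
import torch

def fast_sta_nabla(T, H, W, wT=3, wH=3, wW=3):
    # Same construction as the pipeline helper above, on toy CPU-sized inputs.
    r = torch.arange(max(T, H, W))
    dist = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
    sta_t = dist[:T, :T].flatten() <= wT // 2
    sta_h = dist[:H, :H].flatten() <= wH // 2
    sta_w = dist[:W, :W].flatten() <= wW // 2
    sta_hw = (sta_h.unsqueeze(1) & sta_w.unsqueeze(0)).reshape(H, H, W, W).transpose(1, 2).flatten()
    sta = (sta_t.unsqueeze(1) & sta_hw.unsqueeze(0)).reshape(T, T, H * W, H * W).transpose(1, 2)
    return sta.reshape(T * H * W, T * H * W)

mask = fast_sta_nabla(T=8, H=4, W=4, wT=3, wH=3, wW=3)
print(mask.shape, mask.float().mean().item())  # torch.Size([128, 128]), ~0.13 of pairs kept
```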
+class Kandinsky5RoPE3D(nn.Module): def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0): super().__init__() self.axes_dims = axes_dims @@ -326,7 +326,7 @@ def reset_dtype(self): setattr(self, f'args_{i}', torch.outer(pos, freq)) -class Modulation(nn.Module): +class Kandinsky5Modulation(nn.Module): def __init__(self, time_dim, model_dim, num_params): super().__init__() self.activation = nn.SiLU() @@ -338,8 +338,63 @@ def __init__(self, time_dim, model_dim, num_params): def forward(self, x): return self.out_layer(self.activation(x)) + +class Kandinsky5SDPAAttentionProcessor(nn.Module): + """Custom attention processor for standard SDPA attention""" + + def __call__( + self, + attn, + query, + key, + value, + **kwargs, + ): + # Process attention with the given query, key, value tensors + out = sdpa(q=query, k=key, v=value).flatten(-2, -1) + return out + + +class Kandinsky5NablaAttentionProcessor(nn.Module): + """Custom attention processor for Nabla attention""" -class MultiheadSelfAttentionEnc(nn.Module): + def __call__( + self, + attn, + query, + key, + value, + sparse_params=None, + **kwargs, + ): + if sparse_params is None: + raise ValueError("sparse_params is required for Nabla attention") + + query = query.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + + block_mask = nablaT_v2( + query, + key, + sparse_params["sta_mask"], + thr=sparse_params["P"], + ) + out = ( + flex_attention( + query, + key, + value, + block_mask=block_mask + ) + .transpose(1, 2) + .contiguous() + ) + out = out.flatten(-2, -1) + return out + + +class Kandinsky5MultiheadSelfAttentionEnc(nn.Module): def __init__(self, num_channels, head_dim): super().__init__() assert num_channels % head_dim == 0 @@ -352,6 +407,9 @@ def __init__(self, num_channels, head_dim): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + + # Initialize attention processor + self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() def get_qkv(self, x): query = self.to_query(x) @@ -371,8 +429,14 @@ def norm_qk(self, q, k): return q, k def scaled_dot_product_attention(self, query, key, value): - out = sdpa(q=query, k=key, v=value).flatten(-2, -1) - return out + # Use the processor + return self.sdpa_processor( + attn=self, + query=query, + key=key, + value=value, + **{} + ) def out_l(self, x): return self.out_layer(x) @@ -388,7 +452,8 @@ def forward(self, x, rope): out = self.out_l(out) return out -class MultiheadSelfAttentionDec(nn.Module): + +class Kandinsky5MultiheadSelfAttentionDec(nn.Module): def __init__(self, num_channels, head_dim): super().__init__() assert num_channels % head_dim == 0 @@ -401,6 +466,10 @@ def __init__(self, num_channels, head_dim): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + + # Initialize attention processors + self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() + self.nabla_processor = Kandinsky5NablaAttentionProcessor() def get_qkv(self, x): query = self.to_query(x) @@ -420,31 +489,25 @@ def norm_qk(self, q, k): return q, k def attention(self, query, key, value): - out = sdpa(q=query, k=key, v=value).flatten(-2, -1) - return out + # Use the processor + return self.sdpa_processor( + attn=self, + query=query, + key=key, + value=value, + **{} + ) def nabla(self, query, key, value, sparse_params=None): - query = query.transpose(1, 2).contiguous() - key = key.transpose(1, 2).contiguous() - value = value.transpose(1, 
2).contiguous() - block_mask = nablaT_v2( - query, - key, - sparse_params["sta_mask"], - thr=sparse_params["P"], + # Use the processor + return self.nabla_processor( + attn=self, + query=query, + key=key, + value=value, + sparse_params=sparse_params, + **{} ) - out = ( - flex_attention( - query, - key, - value, - block_mask=block_mask - ) - .transpose(1, 2) - .contiguous() - ) - out = out.flatten(-2, -1) - return out def out_l(self, x): return self.out_layer(x) @@ -464,7 +527,7 @@ def forward(self, x, rope, sparse_params=None): return out -class MultiheadCrossAttention(nn.Module): +class Kandinsky5MultiheadCrossAttention(nn.Module): def __init__(self, num_channels, head_dim): super().__init__() assert num_channels % head_dim == 0 @@ -477,6 +540,9 @@ def __init__(self, num_channels, head_dim): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + + # Initialize attention processor + self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() def get_qkv(self, x, cond): query = self.to_query(x) @@ -496,8 +562,14 @@ def norm_qk(self, q, k): return q, k def attention(self, query, key, value): - out = sdpa(q=query, k=key, v=value).flatten(-2, -1) - return out + # Use the processor + return self.sdpa_processor( + attn=self, + query=query, + key=key, + value=value, + **{} + ) def out_l(self, x): return self.out_layer(x) @@ -511,7 +583,7 @@ def forward(self, x, cond): return out -class FeedForward(nn.Module): +class Kandinsky5FeedForward(nn.Module): def __init__(self, dim, ff_dim): super().__init__() self.in_layer = nn.Linear(dim, ff_dim, bias=False) @@ -522,11 +594,11 @@ def forward(self, x): return self.out_layer(self.activation(self.in_layer(x))) -class OutLayer(nn.Module): +class Kandinsky5OutLayer(nn.Module): def __init__(self, model_dim, time_dim, visual_dim, patch_size): super().__init__() self.patch_size = patch_size - self.modulation = Modulation(time_dim, model_dim, 2) + self.modulation = Kandinsky5Modulation(time_dim, model_dim, 2) self.norm = nn.LayerNorm(model_dim, elementwise_affine=False) self.out_layer = nn.Linear( model_dim, math.prod(patch_size) * visual_dim, bias=True @@ -561,19 +633,17 @@ def forward(self, visual_embed, text_embed, time_embed): ) return x - - -class TransformerEncoderBlock(nn.Module): +class Kandinsky5TransformerEncoderBlock(nn.Module): def __init__(self, model_dim, time_dim, ff_dim, head_dim): super().__init__() - self.text_modulation = Modulation(time_dim, model_dim, 6) + self.text_modulation = Kandinsky5Modulation(time_dim, model_dim, 6) self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.self_attention = MultiheadSelfAttentionEnc(model_dim, head_dim) + self.self_attention = Kandinsky5MultiheadSelfAttentionEnc(model_dim, head_dim) self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.feed_forward = FeedForward(model_dim, ff_dim) + self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) def forward(self, x, time_embed, rope): self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) @@ -589,19 +659,19 @@ def forward(self, x, time_embed, rope): return x -class TransformerDecoderBlock(nn.Module): +class Kandinsky5TransformerDecoderBlock(nn.Module): def __init__(self, model_dim, time_dim, ff_dim, head_dim): super().__init__() - self.visual_modulation = Modulation(time_dim, model_dim, 9) + self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9) self.self_attention_norm = 
nn.LayerNorm(model_dim, elementwise_affine=False) - self.self_attention = MultiheadSelfAttentionDec(model_dim, head_dim) + self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim) self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.cross_attention = MultiheadCrossAttention(model_dim, head_dim) + self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim) self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.feed_forward = FeedForward(model_dim, ff_dim) + self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): self_attn_params, cross_attn_params, ff_params = torch.chunk( @@ -645,16 +715,16 @@ def __init__( axes_dims=(16, 24, 24), visual_cond=False, attention_type: str = "regular", - attention_causal: bool = None, #Deffault for Nabla: false, - attention_local: bool = None, #Deffault for Nabla: false, - attention_glob:bool = None, #Deffault for Nabla: false, - attention_window: int = None, #Deffault for Nabla: 3 - attention_P: float = None, #Deffault for Nabla: 0.9 - attention_wT: int = None, #Deffault for Nabla: 11 - attention_wW:int = None, #Deffault for Nabla: 3 - attention_wH:int = None, #Deffault for Nabla: 3 - attention_add_sta: bool = None, #Deffault for Nabla: true - attention_method: str = None, #Deffault for Nabla: "topcdf" + attention_causal: bool = None, # Default for Nabla: false + attention_local: bool = None, # Default for Nabla: false + attention_glob: bool = None, # Default for Nabla: false + attention_window: int = None, # Default for Nabla: 3 + attention_P: float = None, # Default for Nabla: 0.9 + attention_wT: int = None, # Default for Nabla: 11 + attention_wW: int = None, # Default for Nabla: 3 + attention_wH: int = None, # Default for Nabla: 3 + attention_add_sta: bool = None, # Default for Nabla: true + attention_method: str = None, # Default for Nabla: "topcdf" ): super().__init__() @@ -666,31 +736,37 @@ def __init__( self.attention_type = attention_type visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim - self.time_embeddings = TimeEmbeddings(model_dim, time_dim) - self.text_embeddings = TextEmbeddings(in_text_dim, model_dim) - self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim) - self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size) + + # Initialize embeddings + self.time_embeddings = Kandinsky5TimeEmbeddings(model_dim, time_dim) + self.text_embeddings = Kandinsky5TextEmbeddings(in_text_dim, model_dim) + self.pooled_text_embeddings = Kandinsky5TextEmbeddings(in_text_dim2, time_dim) + self.visual_embeddings = Kandinsky5VisualEmbeddings(visual_embed_dim, model_dim, patch_size) - self.text_rope_embeddings = RoPE1D(head_dim) + # Initialize positional embeddings + self.text_rope_embeddings = Kandinsky5RoPE1D(head_dim) + self.visual_rope_embeddings = Kandinsky5RoPE3D(axes_dims) + + # Initialize transformer blocks self.text_transformer_blocks = nn.ModuleList( [ - TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) + Kandinsky5TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) for _ in range(num_text_blocks) ] ) - self.visual_rope_embeddings = RoPE3D(axes_dims) self.visual_transformer_blocks = nn.ModuleList( [ - TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim) + Kandinsky5TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim) for _ in range(num_visual_blocks) ] ) - self.out_layer = 
OutLayer(model_dim, time_dim, out_visual_dim, patch_size) + # Initialize output layer + self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size) - def before_text_transformer_blocks(self, text_embed, time, pooled_text_embed, x, - text_rope_pos): + def prepare_text_embeddings(self, text_embed, time, pooled_text_embed, x, text_rope_pos): + """Prepare text embeddings and related components""" text_embed = self.text_embeddings(text_embed) time_embed = self.time_embeddings(time) time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed) @@ -699,8 +775,8 @@ def before_text_transformer_blocks(self, text_embed, time, pooled_text_embed, x, text_rope = text_rope.unsqueeze(dim=0) return text_embed, time_embed, text_rope, visual_embed - def before_visual_transformer_blocks(self, visual_embed, visual_rope_pos, scale_factor, - sparse_params): + def prepare_visual_embeddings(self, visual_embed, visual_rope_pos, scale_factor, sparse_params): + """Prepare visual embeddings and related components""" visual_shape = visual_embed.shape[:-1] visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False @@ -708,44 +784,79 @@ def before_visual_transformer_blocks(self, visual_embed, visual_rope_pos, scale_ block_mask=to_fractal) return visual_embed, visual_shape, to_fractal, visual_rope - def after_blocks(self, visual_embed, visual_shape, to_fractal, text_embed, time_embed): + def process_text_transformer_blocks(self, text_embed, time_embed, text_rope): + """Process text through transformer blocks""" + for text_transformer_block in self.text_transformer_blocks: + text_embed = text_transformer_block(text_embed, time_embed, text_rope) + return text_embed + + def process_visual_transformer_blocks(self, visual_embed, text_embed, time_embed, visual_rope, sparse_params): + """Process visual through transformer blocks""" + for visual_transformer_block in self.visual_transformer_blocks: + visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed, + visual_rope, sparse_params) + return visual_embed + + def prepare_output(self, visual_embed, visual_shape, to_fractal, text_embed, time_embed): + """Prepare the final output""" visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal) x = self.out_layer(visual_embed, text_embed, time_embed) return x def forward( self, - hidden_states, # x - encoder_hidden_states, #text_embed - timestep, # time - pooled_projections, #pooled_text_embed, - visual_rope_pos, - text_rope_pos, - scale_factor=(1.0, 1.0, 1.0), - sparse_params=None, - return_dict=True, - ): + hidden_states: torch.FloatTensor, # x + encoder_hidden_states: torch.FloatTensor, # text_embed + timestep: Union[torch.Tensor, float, int], # time + pooled_projections: torch.FloatTensor, # pooled_text_embed + visual_rope_pos: Tuple[int, int, int], + text_rope_pos: torch.LongTensor, + scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0), + sparse_params: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[Transformer2DModelOutput, torch.FloatTensor]: + """ + Forward pass of the Kandinsky5 3D Transformer. 
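On the modulation layers: `Kandinsky5Modulation` emits 6 parameters per channel in the encoder block and 9 in the decoder block, which the block forwards chunk into per-branch groups (self-attention, cross-attention, feed-forward). The exact wiring inside the branches is not visible in these hunks, so the following is only a hypothetical sketch of the usual shift/scale/gate pattern that such chunked parameters typically feed, with invented tensor values:

```python
import torch
import torch.nn as nn

# Hypothetical sketch (not the patch's exact block code): 6 modulation parameters
# split into (shift, scale, gate) for one branch plus (shift, scale, gate) for a
# second branch, applied around a branch in adaptive-LayerNorm style.
dim, batch, seq = 16, 2, 8
x = torch.randn(batch, seq, dim)
time_params = torch.randn(batch, 1, 6 * dim)   # stand-in for a Modulation output

attn_params, ff_params = torch.chunk(time_params, 2, dim=-1)
shift, scale, gate = torch.chunk(attn_params, 3, dim=-1)

norm = nn.LayerNorm(dim, elementwise_affine=False)
branch_in = norm(x) * (1 + scale) + shift      # conditioned input to the branch
branch_out = branch_in                         # stand-in for the attention/FF output
x = x + gate * branch_out                      # gated residual update
print(x.shape)                                 # torch.Size([2, 8, 16])
```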
+ + Args: + hidden_states (`torch.FloatTensor`): Input visual states + encoder_hidden_states (`torch.FloatTensor`): Text embeddings + timestep (`torch.Tensor` or `float` or `int`): Current timestep + pooled_projections (`torch.FloatTensor`): Pooled text embeddings + visual_rope_pos (`Tuple[int, int, int]`): Position for visual RoPE + text_rope_pos (`torch.LongTensor`): Position for text RoPE + scale_factor (`Tuple[float, float, float]`, optional): Scale factor for RoPE + sparse_params (`Dict[str, Any]`, optional): Parameters for sparse attention + return_dict (`bool`, optional): Whether to return a dictionary + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `torch.FloatTensor`: + The output of the transformer + """ x = hidden_states text_embed = encoder_hidden_states time = timestep pooled_text_embed = pooled_projections - text_embed, time_embed, text_rope, visual_embed = self.before_text_transformer_blocks( + # Prepare text embeddings and related components + text_embed, time_embed, text_rope, visual_embed = self.prepare_text_embeddings( text_embed, time, pooled_text_embed, x, text_rope_pos) - for text_transformer_block in self.text_transformer_blocks: - text_embed = text_transformer_block(text_embed, time_embed, text_rope) + # Process text through transformer blocks + text_embed = self.process_text_transformer_blocks(text_embed, time_embed, text_rope) - visual_embed, visual_shape, to_fractal, visual_rope = self.before_visual_transformer_blocks( + # Prepare visual embeddings and related components + visual_embed, visual_shape, to_fractal, visual_rope = self.prepare_visual_embeddings( visual_embed, visual_rope_pos, scale_factor, sparse_params) - for visual_transformer_block in self.visual_transformer_blocks: - visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed, - visual_rope, sparse_params) - - x = self.after_blocks(visual_embed, visual_shape, to_fractal, text_embed, time_embed) + # Process visual through transformer blocks + visual_embed = self.process_visual_transformer_blocks( + visual_embed, text_embed, time_embed, visual_rope, sparse_params) - if return_dict: - return Transformer2DModelOutput(sample=x) + # Prepare final output + x = self.prepare_output(visual_embed, visual_shape, to_fractal, text_embed, time_embed) - return x + if not return_dict: + return x + + return Transformer2DModelOutput(sample=x) \ No newline at end of file diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 05230a604fa4..12bc12cca205 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -263,7 +263,9 @@ def get_sparse_params(self, sample, device): if self.transformer.config.attention_type == "nabla": sta_mask = self.fast_sta_nabla( T, H // 8, W // 8, - self.transformer.config.attention_wT, self.transformer.config.attention_wH, self.transformer.config.attention_wW, + self.transformer.config.attention_wT, + self.transformer.config.attention_wH, + self.transformer.config.attention_wW, device=device ) From 43bd1e81d2b0aba750477af04f0c3927c84e0761 Mon Sep 17 00:00:00 2001 From: leffff Date: Mon, 13 Oct 2025 14:41:50 +0000 Subject: [PATCH 09/77] fix license --- src/diffusers/models/transformers/transformer_kandinsky.py | 4 ++-- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git 
a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 4ba7e144030f..01c9b258b7c3 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -859,4 +859,4 @@ def forward( if not return_dict: return x - return Transformer2DModelOutput(sample=x) \ No newline at end of file + return Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 12bc12cca205..a30484c701b0 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 149fd53df84c42100062def55d25ca02dc023979 Mon Sep 17 00:00:00 2001 From: leffff Date: Mon, 13 Oct 2025 22:38:03 +0000 Subject: [PATCH 10/77] fix prompt type --- .../kandinsky5/pipeline_kandinsky.py | 227 ++++++++++-------- 1 file changed, 130 insertions(+), 97 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index a30484c701b0..407dc127fda8 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -33,83 +33,6 @@ from .pipeline_output import KandinskyPipelineOutput -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -if is_ftfy_available(): - import ftfy - - -logger = logging.get_logger(__name__) - -EXAMPLE_DOC_STRING = """ - Examples: - - ```python - >>> import torch - >>> from diffusers import Kandinsky5T2VPipeline, Kandinsky5Transformer3DModel - >>> from diffusers.utils import export_to_video - - >>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V") - >>> pipe = pipe.to("cuda") - - >>> prompt = "A cat and a dog baking a cake together in a kitchen." - >>> negative_prompt = "Bright tones, overexposed, static, blurred details" - - >>> output = pipe( - ... prompt=prompt, - ... negative_prompt=negative_prompt, - ... height=512, - ... width=768, - ... num_frames=25, - ... num_inference_steps=50, - ... guidance_scale=5.0, - ... ).frames[0] - >>> export_to_video(output, "output.mp4", fps=6) - ``` -""" - -# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import html -from typing import Any, Callable, Dict, List, Optional, Union - -import regex as re -import torch -from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer -import torchvision -from torchvision.transforms import ToPILImage - -from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...loaders import KandinskyLoraLoaderMixin -from ...models import AutoencoderKLHunyuanVideo -from ...models.transformers import Kandinsky5Transformer3DModel -from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring -from ...utils.torch_utils import randn_tensor -from ...video_processor import VideoProcessor -from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import KandinskyPipelineOutput - - if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -137,23 +60,23 @@ >>> pipe = pipe.to("cuda") >>> prompt = "A cat and a dog baking a cake together in a kitchen." - >>> negative_prompt = "Bright tones, overexposed, static, blurred details" - + >>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" >>> output = pipe( ... prompt=prompt, ... negative_prompt=negative_prompt, ... height=512, ... width=768, - ... num_frames=25, + ... num_frames=121, ... num_inference_steps=50, ... guidance_scale=5.0, ... ).frames[0] - >>> export_to_video(output, "output.mp4", fps=6) + >>> export_to_video(output, "output.mp4", fps=24) ``` """ def basic_clean(text): + """Clean text using ftfy if available and unescape HTML entities.""" if is_ftfy_available(): text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) @@ -161,12 +84,14 @@ def basic_clean(text): def whitespace_clean(text): + """Normalize whitespace in text by replacing multiple spaces with single space.""" text = re.sub(r"\s+", " ", text) text = text.strip() return text def prompt_clean(text): + """Apply both basic cleaning and whitespace normalization to prompts.""" text = whitespace_clean(basic_clean(text)) return text @@ -228,6 +153,24 @@ def __init__( def fast_sta_nabla( T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda" ) -> torch.Tensor: + """ + Create a sparse temporal attention (STA) mask for efficient video generation. + + This method generates a mask that limits attention to nearby frames and spatial positions, + reducing computational complexity for video generation. 
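To see why limiting attention matters at the resolutions this pipeline targets, here is a rough sequence-length count for the 121-frame, 512×768 example above. The compression factors are assumptions for illustration: 4× temporal / 8× spatial VAE compression (as in the HunyuanVideo VAE) and the 1×2×2 patching implied by `get_sparse_params` and the RoPE positions.

```python
# Back-of-the-envelope sequence length for the default example above.
# Assumed factors: 4x temporal / 8x spatial VAE compression, 1x2x2 patch size.
num_frames, height, width = 121, 512, 768
vae_t, vae_s = 4, 8
patch_t, patch_h, patch_w = 1, 2, 2

latent_t = (num_frames - 1) // vae_t + 1                 # 31 latent frames
latent_h, latent_w = height // vae_s, width // vae_s     # 64 x 96 latent pixels
tokens = (latent_t // patch_t) * (latent_h // patch_h) * (latent_w // patch_w)
print(tokens)           # 47616 visual tokens
print(tokens * tokens)  # ~2.27e9 query-key pairs for dense self-attention
```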
+ + Args: + T (int): Number of temporal frames + H (int): Height in latent space + W (int): Width in latent space + wT (int): Temporal attention window size + wH (int): Height attention window size + wW (int): Width attention window size + device (str): Device to create tensor on + + Returns: + torch.Tensor: Sparse attention mask of shape (T*H*W, T*H*W) + """ l = torch.Tensor([T, H, W]).amax() r = torch.arange(0, l, 1, dtype=torch.int16, device=device) mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs() @@ -253,6 +196,19 @@ def fast_sta_nabla( return sta.reshape(T * H * W, T * H * W) def get_sparse_params(self, sample, device): + """ + Generate sparse attention parameters for the transformer based on sample dimensions. + + This method computes the sparse attention configuration needed for efficient + video processing in the transformer model. + + Args: + sample (torch.Tensor): Input sample tensor + device (torch.device): Device to place tensors on + + Returns: + Dict: Dictionary containing sparse attention parameters + """ assert self.transformer.config.patch_size[0] == 1 B, T, H, W, _ = sample.shape T, H, W = ( @@ -294,12 +250,28 @@ def _encode_prompt_qwen( max_sequence_length: int = 256, dtype: Optional[torch.dtype] = None, ): + """ + Encode prompt using Qwen2.5-VL text encoder. + + This method processes the input prompt through the Qwen2.5-VL model to generate + text embeddings suitable for video generation. + + Args: + prompt (Union[str, List[str]]): Input prompt or list of prompts + device (torch.device): Device to run encoding on + num_videos_per_prompt (int): Number of videos to generate per prompt + max_sequence_length (int): Maximum sequence length for tokenization + dtype (torch.dtype): Data type for embeddings + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths + """ device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt_clean(p) for p in prompt] - # Kandinsky specific prompt template + # Kandinsky specific prompt template for detailed video description prompt_template = "\n".join([ "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.", "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", @@ -310,7 +282,7 @@ def _encode_prompt_qwen( "Pay attention to the order of key actions shown in the scene.<|im_end|>", "<|im_start|>user\n{}<|im_end|>", ]) - crop_start = 129 + crop_start = 129 # Position to start cropping from (system prompt tokens) full_texts = [prompt_template.format(p) for p in prompt] @@ -347,6 +319,21 @@ def _encode_prompt_clip( num_videos_per_prompt: int = 1, dtype: Optional[torch.dtype] = None, ): + """ + Encode prompt using CLIP text encoder. + + This method processes the input prompt through the CLIP model to generate + pooled embeddings that capture semantic information. + + Args: + prompt (Union[str, List[str]]): Input prompt or list of prompts + device (torch.device): Device to run encoding on + num_videos_per_prompt (int): Number of videos to generate per prompt + dtype (torch.dtype): Data type for embeddings + + Returns: + torch.Tensor: Pooled text embeddings from CLIP + """ device = device or self._execution_device dtype = dtype or self.text_encoder_2.dtype prompt = [prompt] if isinstance(prompt, str) else prompt @@ -386,6 +373,9 @@ def encode_prompt( r""" Encodes the prompt into text encoder hidden states. 
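For orientation, a hedged sketch of pre-computing embeddings with this method; the keyword arguments follow the parameters documented here, while the checkpoint id, the prompts, and the handling of the returned tuple are placeholders rather than confirmed usage:

```python
import torch
from diffusers import Kandinsky5T2VPipeline

pipe = Kandinsky5T2VPipeline.from_pretrained(
    "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

encoded = pipe.encode_prompt(
    prompt="A cat and a dog baking a cake together in a kitchen.",
    negative_prompt="static, blurred details, low quality",
    do_classifier_free_guidance=True,
    num_videos_per_prompt=1,
    device="cuda",
)
# Per the docstring, the returned tuple carries the prompt embeddings, the
# negative prompt embeddings, and the sequence-length information used later.
```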
+ This method combines embeddings from both Qwen2.5-VL and CLIP text encoders + to create comprehensive text representations for video generation. + Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded @@ -410,11 +400,15 @@ def encode_prompt( torch device dtype: (`torch.dtype`, *optional*): torch dtype + + Returns: + Tuple: Contains prompt embeddings, negative prompt embeddings, and sequence length information """ device = device or self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 + prompt = [prompt] elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: @@ -438,7 +432,7 @@ def encode_prompt( prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = prompt_embeds if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" + negative_prompt = negative_prompt or "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt if prompt is not None and type(prompt) is not type(negative_prompt): @@ -492,6 +486,21 @@ def check_inputs( negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, ): + """ + Validate input parameters for the pipeline. + + Args: + prompt: Input prompt + negative_prompt: Negative prompt for guidance + height: Video height + width: Video width + prompt_embeds: Pre-computed prompt embeddings + negative_prompt_embeds: Pre-computed negative prompt embeddings + callback_on_step_end_tensor_inputs: Callback tensor inputs + + Raises: + ValueError: If inputs are invalid + """ if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -535,6 +544,26 @@ def prepare_latents( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: + """ + Prepare initial latent variables for video generation. + + This method creates random noise latents or uses provided latents as starting point + for the denoising process. 
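A small sketch of the shape arithmetic behind this step, assuming the channels-last latent layout used elsewhere in the pipeline; the concrete scale factors and channel count below are illustrative stand-ins for the values read from the VAE and transformer configs:

```python
import torch

# Illustrative values; the real ones come from the configs and the call arguments.
batch_size, num_channels_latents = 1, 16
height, width, num_frames = 512, 768, 121
vae_scale_factor_spatial, vae_scale_factor_temporal = 8, 4

latent_shape = (
    batch_size,
    (num_frames - 1) // vae_scale_factor_temporal + 1,  # latent frames
    height // vae_scale_factor_spatial,                 # latent height
    width // vae_scale_factor_spatial,                  # latent width
    num_channels_latents,                               # channels-last layout
)
latents = torch.randn(latent_shape)  # the pipeline draws noise via randn_tensor with a generator
print(latents.shape)  # torch.Size([1, 31, 64, 96, 16])
```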
+ + Args: + batch_size (int): Number of videos to generate + num_channels_latents (int): Number of channels in latent space + height (int): Height of generated video + width (int): Width of generated video + num_frames (int): Number of frames in video + dtype (torch.dtype): Data type for latents + device (torch.device): Device to create latents on + generator (torch.Generator): Random number generator + latents (torch.Tensor): Pre-existing latents to use + + Returns: + torch.Tensor: Prepared latent tensor + """ if latents is not None: return latents.to(device=device, dtype=dtype) @@ -568,18 +597,22 @@ def prepare_latents( @property def guidance_scale(self): + """Get the current guidance scale value.""" return self._guidance_scale @property def do_classifier_free_guidance(self): + """Check if classifier-free guidance is enabled.""" return self._guidance_scale > 1.0 @property def num_timesteps(self): + """Get the number of denoising timesteps.""" return self._num_timesteps @property def interrupt(self): + """Check if generation has been interrupted.""" return self._interrupt @torch.no_grad() @@ -590,10 +623,10 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 768, - num_frames: int = 25, + num_frames: int = 121, num_inference_steps: int = 50, guidance_scale: float = 5.0, - scheduler_scale: float = 10.0, + scheduler_scale: float = 5.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -715,7 +748,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = 16 + num_channels_latents = self.transformer.config.in_visual_dim latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, @@ -728,7 +761,7 @@ def __call__( latents, ) - # 6. Prepare rope positions + # 6. Prepare rope positions for positional encoding num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 visual_rope_pos = [ torch.arange(num_latent_frames, device=device), @@ -744,7 +777,7 @@ def __call__( else None ) - # 7. Sparse Params + # 7. Sparse Params for efficient attention sparse_params = self.get_sparse_params(latents, device) # 8. Denoising loop @@ -788,9 +821,9 @@ def __call__( pred_velocity - uncond_pred_velocity ) - # Compute previous sample - latents[:, :, :, :, :16] = self.scheduler.step( - pred_velocity, t, latents[:, :, :, :, :16], return_dict=False + # Compute previous sample using the scheduler + latents[:, :, :, :, :num_channels_latents] = self.scheduler.step( + pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False )[0] if callback_on_step_end is not None: @@ -809,8 +842,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - # 8. Post-processing - latents = latents[:, :, :, :, :16] + # 8. Post-processing - extract main latents + latents = latents[:, :, :, :, :num_channels_latents] # 9. 
Decode latents to video if output_type != "latent": @@ -822,18 +855,18 @@ def __call__( (num_frames - 1) // self.vae_scale_factor_temporal + 1, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, - 16, + num_channels_latents, ) video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width] video = video.reshape( batch_size * num_videos_per_prompt, - 16, + num_channels_latents, (num_frames - 1) // self.vae_scale_factor_temporal + 1, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial ) - # Normalize and decode + # Normalize and decode through VAE video = video / self.vae.config.scaling_factor video = self.vae.decode(video).sample video = self.video_processor.postprocess_video(video, output_type=output_type) From 7af80e9ffcf4daef408d0f1c99b115c70ae73756 Mon Sep 17 00:00:00 2001 From: leffff Date: Tue, 14 Oct 2025 11:24:24 +0000 Subject: [PATCH 11/77] add gradient checkpointing and peft support --- .../transformers/transformer_kandinsky.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 01c9b258b7c3..6dec8d93ac9e 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -22,6 +22,7 @@ from torch import BoolTensor, IntTensor, Tensor, nn from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, flex_attention) +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin @@ -694,11 +695,12 @@ def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): return visual_embed -class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin): +class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin): """ A 3D Diffusion Transformer model for video-like data. 
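Since the class opts into gradient checkpointing and the PEFT adapter mixin, a training script could enable both roughly as follows; the repository id, subfolder, and LoRA target module names are placeholders, not values confirmed for this model:

```python
import torch
from diffusers import Kandinsky5Transformer3DModel
from peft import LoraConfig  # assumes the PEFT backend is installed

# Placeholder checkpoint; any Kandinsky 5.0 repo with a `transformer` subfolder is assumed to work.
transformer = Kandinsky5Transformer3DModel.from_pretrained(
    "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# Provided by ModelMixin because `_supports_gradient_checkpointing = True`.
transformer.enable_gradient_checkpointing()

# PeftAdapterMixin allows LoRA adapters to be attached directly to the transformer;
# match `target_modules` to the model's actual attention projection names.
lora_config = LoraConfig(r=16, lora_alpha=16, target_modules=["to_query", "to_key", "to_value"])
transformer.add_adapter(lora_config)
```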
""" - + _supports_gradient_checkpointing = True + @register_to_config def __init__( self, @@ -764,6 +766,7 @@ def __init__( # Initialize output layer self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size) + self.gradient_checkpointing = False def prepare_text_embeddings(self, text_embed, time, pooled_text_embed, x, text_rope_pos): """Prepare text embeddings and related components""" @@ -787,13 +790,20 @@ def prepare_visual_embeddings(self, visual_embed, visual_rope_pos, scale_factor, def process_text_transformer_blocks(self, text_embed, time_embed, text_rope): """Process text through transformer blocks""" for text_transformer_block in self.text_transformer_blocks: - text_embed = text_transformer_block(text_embed, time_embed, text_rope) + if torch.is_grad_enabled() and self.gradient_checkpointing: + text_embed = self._gradient_checkpointing_func(text_transformer_block, text_embed, time_embed, text_rope) + else: + text_embed = text_transformer_block(text_embed, time_embed, text_rope) return text_embed def process_visual_transformer_blocks(self, visual_embed, text_embed, time_embed, visual_rope, sparse_params): """Process visual through transformer blocks""" for visual_transformer_block in self.visual_transformer_blocks: - visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed, + if torch.is_grad_enabled() and self.gradient_checkpointing: + visual_embed = self._gradient_checkpointing_func(visual_transformer_block, visual_embed, text_embed, time_embed, + visual_rope, sparse_params) + else: + visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed, visual_rope, sparse_params) return visual_embed From 04efb19b1aeba3b41b7b1bd6d0353a1715c0f839 Mon Sep 17 00:00:00 2001 From: leffff Date: Tue, 14 Oct 2025 12:14:37 +0000 Subject: [PATCH 12/77] add usage example --- .../pipelines/kandinsky5/pipeline_kandinsky.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 407dc127fda8..38d94ded42ad 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -55,12 +55,20 @@ >>> import torch >>> from diffusers import Kandinsky5T2VPipeline >>> from diffusers.utils import export_to_video - - >>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V") + + >>> # Available models: + >>> # ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers + >>> # ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers + >>> # ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers + >>> # ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers + + >>> model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers" + >>> pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) >>> pipe = pipe.to("cuda") >>> prompt = "A cat and a dog baking a cake together in a kitchen." >>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" + >>> output = pipe( ... prompt=prompt, ... negative_prompt=negative_prompt, @@ -70,7 +78,8 @@ ... num_inference_steps=50, ... guidance_scale=5.0, ... 
).frames[0] - >>> export_to_video(output, "output.mp4", fps=24) + + >>> export_to_video(output, "output.mp4", fps=24, quality=9) ``` """ From 235f0d5df8a7d9842c63d458044ea823e921c8a8 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Tue, 14 Oct 2025 21:53:32 +0300 Subject: [PATCH 13/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Álvaro Somoza --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 38d94ded42ad..73868c972c32 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -13,7 +13,7 @@ # limitations under the License. import html -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import regex as re import torch From 88a8eea0962a3d209039e01c30d7601d14343ce0 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Tue, 14 Oct 2025 21:53:47 +0300 Subject: [PATCH 14/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Álvaro Somoza --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 73868c972c32..3840ad11dd5f 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -17,7 +17,7 @@ import regex as re import torch -from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer +from transformers import Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, CLIPTextModel, CLIPTokenizer import torchvision from torchvision.transforms import ToPILImage From f52f3b45b75e461cbd9a28f280cdbad015059420 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Tue, 14 Oct 2025 21:54:10 +0300 Subject: [PATCH 15/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Álvaro Somoza --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 3840ad11dd5f..39306cb9e812 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -18,7 +18,6 @@ import regex as re import torch from transformers import Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, CLIPTextModel, CLIPTokenizer -import torchvision from torchvision.transforms import ToPILImage from ...callbacks import MultiPipelineCallbacks, PipelineCallback From 0190e55641e70ab65f656b2499ee325ce2149f83 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Tue, 14 Oct 2025 21:54:21 +0300 Subject: [PATCH 16/77] 
Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Álvaro Somoza --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 39306cb9e812..3a8628a1b339 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -18,7 +18,6 @@ import regex as re import torch from transformers import Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, CLIPTextModel, CLIPTokenizer -from torchvision.transforms import ToPILImage from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import KandinskyLoraLoaderMixin From d62dffcb212ea6f6281615f23230d77de3efc988 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Tue, 14 Oct 2025 23:25:14 +0300 Subject: [PATCH 17/77] Update src/diffusers/models/transformers/transformer_kandinsky.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Álvaro Somoza --- src/diffusers/models/transformers/transformer_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 6dec8d93ac9e..24b2c4ae99b6 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -15,7 +15,6 @@ import math from typing import Any, Dict, List, Optional, Tuple, Union -from einops import rearrange import torch import torch.nn as nn import torch.nn.functional as F From 7084106eaaa9b998efd520e72b4a69a6e2dd90cf Mon Sep 17 00:00:00 2001 From: leffff Date: Tue, 14 Oct 2025 20:38:40 +0000 Subject: [PATCH 18/77] remove unused imports --- .../transformers/transformer_kandinsky.py | 250 ++++++++++-------- 1 file changed, 142 insertions(+), 108 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 24b2c4ae99b6..ac2fe58d60b4 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -19,21 +19,27 @@ import torch.nn as nn import torch.nn.functional as F from torch import BoolTensor, IntTensor, Tensor, nn -from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, - flex_attention) -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from torch.nn.attention.flex_attention import ( + BlockMask, + flex_attention, +) from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import (USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, - unscale_lora_layers) +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import maybe_allow_in_graph -from .._modeling_parallel import ContextParallelInput, ContextParallelOutput -from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward -from ..attention_dispatch import dispatch_attention_fn +from ..attention import AttentionMixin, FeedForward from ..cache_utils import CacheMixin -from ..embeddings 
import (PixArtAlphaTextProjection, TimestepEmbedding, - Timesteps, get_1d_rotary_pos_embed) +from ..embeddings import ( + TimestepEmbedding, + get_1d_rotary_pos_embed, +) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm @@ -95,7 +101,7 @@ def local_patching(x, shape, group_size, dim=0): g2, width // g3, g3, - *x.shape[dim + 3 :] + *x.shape[dim + 3 :], ) x = x.permute( *range(len(x.shape[:dim])), @@ -105,7 +111,7 @@ def local_patching(x, shape, group_size, dim=0): dim + 1, dim + 3, dim + 5, - *range(dim + 6, len(x.shape)) + *range(dim + 6, len(x.shape)), ) x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3) return x @@ -122,7 +128,7 @@ def local_merge(x, shape, group_size, dim=0): g1, g2, g3, - *x.shape[dim + 2 :] + *x.shape[dim + 2 :], ) x = x.permute( *range(len(x.shape[:dim])), @@ -132,7 +138,7 @@ def local_merge(x, shape, group_size, dim=0): dim + 4, dim + 2, dim + 5, - *range(dim + 6, len(x.shape)) + *range(dim + 6, len(x.shape)), ) x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3) return x @@ -172,15 +178,7 @@ def sdpa(q, k, v): query = q.transpose(1, 2).contiguous() key = k.transpose(1, 2).contiguous() value = v.transpose(1, 2).contiguous() - out = ( - F.scaled_dot_product_attention( - query, - key, - value - ) - .transpose(1, 2) - .contiguous() - ) + out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous() return out @@ -279,7 +277,7 @@ def forward(self, pos): rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) rope = rope.view(*rope.shape[:-1], 2, 2) return rope.unsqueeze(-4) - + def reset_dtype(self): freq = get_freqs(self.dim // 2, self.max_period).to(self.args.device) pos = torch.arange(self.max_pos, dtype=freq.dtype, device=freq.device) @@ -307,9 +305,15 @@ def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): args = torch.cat( [ - args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1), - args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1), - args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1), + args_t.view(1, duration, 1, 1, -1).repeat( + batch_size, 1, height, width, 1 + ), + args_h.view(1, 1, height, 1, -1).repeat( + batch_size, duration, 1, width, 1 + ), + args_w.view(1, 1, 1, width, -1).repeat( + batch_size, duration, height, 1, 1 + ), ], dim=-1, ) @@ -318,12 +322,12 @@ def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) rope = rope.view(*rope.shape[:-1], 2, 2) return rope.unsqueeze(-4) - + def reset_dtype(self): for i, (axes_dim, ax_max_pos) in enumerate(zip(self.axes_dims, self.max_pos)): freq = get_freqs(axes_dim // 2, self.max_period).to(self.args_0.device) pos = torch.arange(ax_max_pos, dtype=freq.dtype, device=freq.device) - setattr(self, f'args_{i}', torch.outer(pos, freq)) + setattr(self, f"args_{i}", torch.outer(pos, freq)) class Kandinsky5Modulation(nn.Module): @@ -341,7 +345,7 @@ def forward(self, x): class Kandinsky5SDPAAttentionProcessor(nn.Module): """Custom attention processor for standard SDPA attention""" - + def __call__( self, attn, @@ -357,7 +361,7 @@ def __call__( class Kandinsky5NablaAttentionProcessor(nn.Module): """Custom attention processor for Nabla attention""" - + def __call__( self, attn, @@ -369,11 +373,11 @@ def __call__( ): if sparse_params is None: raise ValueError("sparse_params is required for Nabla attention") - + query = query.transpose(1, 
2).contiguous() key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() - + block_mask = nablaT_v2( query, key, @@ -381,12 +385,7 @@ def __call__( thr=sparse_params["P"], ) out = ( - flex_attention( - query, - key, - value, - block_mask=block_mask - ) + flex_attention(query, key, value, block_mask=block_mask) .transpose(1, 2) .contiguous() ) @@ -407,7 +406,7 @@ def __init__(self, num_channels, head_dim): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - + # Initialize attention processor self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() @@ -430,13 +429,7 @@ def norm_qk(self, q, k): def scaled_dot_product_attention(self, query, key, value): # Use the processor - return self.sdpa_processor( - attn=self, - query=query, - key=key, - value=value, - **{} - ) + return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) def out_l(self, x): return self.out_layer(x) @@ -466,7 +459,7 @@ def __init__(self, num_channels, head_dim): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - + # Initialize attention processors self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() self.nabla_processor = Kandinsky5NablaAttentionProcessor() @@ -490,14 +483,8 @@ def norm_qk(self, q, k): def attention(self, query, key, value): # Use the processor - return self.sdpa_processor( - attn=self, - query=query, - key=key, - value=value, - **{} - ) - + return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) + def nabla(self, query, key, value, sparse_params=None): # Use the processor return self.nabla_processor( @@ -506,7 +493,7 @@ def nabla(self, query, key, value, sparse_params=None): key=key, value=value, sparse_params=sparse_params, - **{} + **{}, ) def out_l(self, x): @@ -540,7 +527,7 @@ def __init__(self, num_channels, head_dim): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - + # Initialize attention processor self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() @@ -563,13 +550,7 @@ def norm_qk(self, q, k): def attention(self, query, key, value): # Use the processor - return self.sdpa_processor( - attn=self, - query=query, - key=key, - value=value, - **{} - ) + return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) def out_l(self, x): return self.out_layer(x) @@ -605,7 +586,9 @@ def __init__(self, model_dim, time_dim, visual_dim, patch_size): ) def forward(self, visual_embed, text_embed, time_embed): - shift, scale = torch.chunk(self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) + shift, scale = torch.chunk( + self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 + ) visual_embed = apply_scale_shift_norm( self.norm, visual_embed, @@ -646,7 +629,9 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim): self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) def forward(self, x, time_embed, rope): - self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) + self_attn_params, ff_params = torch.chunk( + self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 + ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift) out = self.self_attention(out, rope) @@ -678,26 +663,40 @@ def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): self.visual_modulation(time_embed).unsqueeze(dim=1), 3, 
dim=-1 ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) - visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift) + visual_out = apply_scale_shift_norm( + self.self_attention_norm, visual_embed, scale, shift + ) visual_out = self.self_attention(visual_out, rope, sparse_params) visual_embed = apply_gate_sum(visual_embed, visual_out, gate) shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) - visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift) + visual_out = apply_scale_shift_norm( + self.cross_attention_norm, visual_embed, scale, shift + ) visual_out = self.cross_attention(visual_out, text_embed) visual_embed = apply_gate_sum(visual_embed, visual_out, gate) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) - visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift) + visual_out = apply_scale_shift_norm( + self.feed_forward_norm, visual_embed, scale, shift + ) visual_out = self.feed_forward(visual_out) visual_embed = apply_gate_sum(visual_embed, visual_out, gate) return visual_embed -class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin): +class Kandinsky5Transformer3DModel( + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + CacheMixin, + AttentionMixin, +): """ A 3D Diffusion Transformer model for video-like data. """ + _supports_gradient_checkpointing = True @register_to_config @@ -714,21 +713,21 @@ def __init__( num_text_blocks=2, num_visual_blocks=32, axes_dims=(16, 24, 24), - visual_cond=False, + visual_cond=False, attention_type: str = "regular", - attention_causal: bool = None, # Default for Nabla: false - attention_local: bool = None, # Default for Nabla: false - attention_glob: bool = None, # Default for Nabla: false - attention_window: int = None, # Default for Nabla: 3 - attention_P: float = None, # Default for Nabla: 0.9 - attention_wT: int = None, # Default for Nabla: 11 - attention_wW: int = None, # Default for Nabla: 3 - attention_wH: int = None, # Default for Nabla: 3 - attention_add_sta: bool = None, # Default for Nabla: true - attention_method: str = None, # Default for Nabla: "topcdf" + attention_causal: bool = None, # Default for Nabla: false + attention_local: bool = None, # Default for Nabla: false + attention_glob: bool = None, # Default for Nabla: false + attention_window: int = None, # Default for Nabla: 3 + attention_P: float = None, # Default for Nabla: 0.9 + attention_wT: int = None, # Default for Nabla: 11 + attention_wW: int = None, # Default for Nabla: 3 + attention_wH: int = None, # Default for Nabla: 3 + attention_add_sta: bool = None, # Default for Nabla: true + attention_method: str = None, # Default for Nabla: "topcdf" ): super().__init__() - + head_dim = sum(axes_dims) self.in_visual_dim = in_visual_dim self.model_dim = model_dim @@ -737,12 +736,14 @@ def __init__( self.attention_type = attention_type visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim - + # Initialize embeddings self.time_embeddings = Kandinsky5TimeEmbeddings(model_dim, time_dim) self.text_embeddings = Kandinsky5TextEmbeddings(in_text_dim, model_dim) self.pooled_text_embeddings = Kandinsky5TextEmbeddings(in_text_dim2, time_dim) - self.visual_embeddings = Kandinsky5VisualEmbeddings(visual_embed_dim, model_dim, patch_size) + self.visual_embeddings = Kandinsky5VisualEmbeddings( + visual_embed_dim, model_dim, patch_size + ) # Initialize 
positional embeddings self.text_rope_embeddings = Kandinsky5RoPE1D(head_dim) @@ -764,10 +765,14 @@ def __init__( ) # Initialize output layer - self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size) + self.out_layer = Kandinsky5OutLayer( + model_dim, time_dim, out_visual_dim, patch_size + ) self.gradient_checkpointing = False - def prepare_text_embeddings(self, text_embed, time, pooled_text_embed, x, text_rope_pos): + def prepare_text_embeddings( + self, text_embed, time, pooled_text_embed, x, text_rope_pos + ): """Prepare text embeddings and related components""" text_embed = self.text_embeddings(text_embed) time_embed = self.time_embeddings(time) @@ -777,38 +782,58 @@ def prepare_text_embeddings(self, text_embed, time, pooled_text_embed, x, text_r text_rope = text_rope.unsqueeze(dim=0) return text_embed, time_embed, text_rope, visual_embed - def prepare_visual_embeddings(self, visual_embed, visual_rope_pos, scale_factor, sparse_params): + def prepare_visual_embeddings( + self, visual_embed, visual_rope_pos, scale_factor, sparse_params + ): """Prepare visual embeddings and related components""" visual_shape = visual_embed.shape[:-1] - visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) + visual_rope = self.visual_rope_embeddings( + visual_shape, visual_rope_pos, scale_factor + ) to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False - visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape, - block_mask=to_fractal) + visual_embed, visual_rope = fractal_flatten( + visual_embed, visual_rope, visual_shape, block_mask=to_fractal + ) return visual_embed, visual_shape, to_fractal, visual_rope def process_text_transformer_blocks(self, text_embed, time_embed, text_rope): """Process text through transformer blocks""" for text_transformer_block in self.text_transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - text_embed = self._gradient_checkpointing_func(text_transformer_block, text_embed, time_embed, text_rope) + text_embed = self._gradient_checkpointing_func( + text_transformer_block, text_embed, time_embed, text_rope + ) else: text_embed = text_transformer_block(text_embed, time_embed, text_rope) return text_embed - def process_visual_transformer_blocks(self, visual_embed, text_embed, time_embed, visual_rope, sparse_params): + def process_visual_transformer_blocks( + self, visual_embed, text_embed, time_embed, visual_rope, sparse_params + ): """Process visual through transformer blocks""" for visual_transformer_block in self.visual_transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - visual_embed = self._gradient_checkpointing_func(visual_transformer_block, visual_embed, text_embed, time_embed, - visual_rope, sparse_params) + visual_embed = self._gradient_checkpointing_func( + visual_transformer_block, + visual_embed, + text_embed, + time_embed, + visual_rope, + sparse_params, + ) else: - visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed, - visual_rope, sparse_params) + visual_embed = visual_transformer_block( + visual_embed, text_embed, time_embed, visual_rope, sparse_params + ) return visual_embed - def prepare_output(self, visual_embed, visual_shape, to_fractal, text_embed, time_embed): + def prepare_output( + self, visual_embed, visual_shape, to_fractal, text_embed, time_embed + ): """Prepare the final output""" - visual_embed = fractal_unflatten(visual_embed, visual_shape, 
block_mask=to_fractal) + visual_embed = fractal_unflatten( + visual_embed, visual_shape, block_mask=to_fractal + ) x = self.out_layer(visual_embed, text_embed, time_embed) return x @@ -846,25 +871,34 @@ def forward( text_embed = encoder_hidden_states time = timestep pooled_text_embed = pooled_projections - + # Prepare text embeddings and related components text_embed, time_embed, text_rope, visual_embed = self.prepare_text_embeddings( - text_embed, time, pooled_text_embed, x, text_rope_pos) + text_embed, time, pooled_text_embed, x, text_rope_pos + ) # Process text through transformer blocks - text_embed = self.process_text_transformer_blocks(text_embed, time_embed, text_rope) + text_embed = self.process_text_transformer_blocks( + text_embed, time_embed, text_rope + ) # Prepare visual embeddings and related components - visual_embed, visual_shape, to_fractal, visual_rope = self.prepare_visual_embeddings( - visual_embed, visual_rope_pos, scale_factor, sparse_params) + visual_embed, visual_shape, to_fractal, visual_rope = ( + self.prepare_visual_embeddings( + visual_embed, visual_rope_pos, scale_factor, sparse_params + ) + ) # Process visual through transformer blocks visual_embed = self.process_visual_transformer_blocks( - visual_embed, text_embed, time_embed, visual_rope, sparse_params) - + visual_embed, text_embed, time_embed, visual_rope, sparse_params + ) + # Prepare final output - x = self.prepare_output(visual_embed, visual_shape, to_fractal, text_embed, time_embed) - + x = self.prepare_output( + visual_embed, visual_shape, to_fractal, text_embed, time_embed + ) + if not return_dict: return x From b615d5cb131243e20cd40453fd6ceb874a092b25 Mon Sep 17 00:00:00 2001 From: leffff Date: Wed, 15 Oct 2025 18:09:23 +0000 Subject: [PATCH 19/77] add 10 second models support --- src/diffusers/models/transformers/transformer_kandinsky.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index ac2fe58d60b4..8d2bae11cbfa 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -361,7 +361,8 @@ def __call__( class Kandinsky5NablaAttentionProcessor(nn.Module): """Custom attention processor for Nabla attention""" - + + @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True) def __call__( self, attn, From 588c12ab98d67be2c4dd8234877b3c4b16cac965 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Thu, 16 Oct 2025 09:38:02 +0300 Subject: [PATCH 20/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 3a8628a1b339..3d0d68cbe93b 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -303,7 +303,6 @@ def _encode_prompt_qwen( padding=True, ).to(device) - with torch.no_grad(): embeds = self.text_encoder( input_ids=inputs["input_ids"], return_dict=True, From 327ab84d1923518ecc5314831254cfd70faf99e1 Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 06:50:57 +0000 Subject: [PATCH 21/77] remove no_grad and simplified prompt paddings --- .../kandinsky5/pipeline_kandinsky.py | 20 
++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 3d0d68cbe93b..d4470a43d578 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -17,6 +17,7 @@ import regex as re import torch +from torch.nn import functional as F from transformers import Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, CLIPTextModel, CLIPTokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -303,17 +304,19 @@ def _encode_prompt_qwen( padding=True, ).to(device) - embeds = self.text_encoder( - input_ids=inputs["input_ids"], - return_dict=True, - output_hidden_states=True, - )["hidden_states"][-1][:, crop_start:] - + embeds = self.text_encoder( + input_ids=inputs["input_ids"], + return_dict=True, + output_hidden_states=True, + )["hidden_states"][-1][:, crop_start:] + batch_size = len(prompt) attention_mask = inputs["attention_mask"][:, crop_start:] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) - cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32) + # cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32) + embeds = torch.cat([embeds[i].unsqueeze(dim=0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)], dim=0) return embeds.to(dtype), cu_seqlens @@ -354,8 +357,7 @@ def _encode_prompt_clip( return_tensors="pt", ).to(device) - with torch.no_grad(): - pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] + pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] # duplicate for each generation per prompt batch_size = len(prompt) From 9b06afba6b446352b9249a7f632af388174dd6ba Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Thu, 16 Oct 2025 09:54:00 +0300 Subject: [PATCH 22/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 3d0d68cbe93b..58ba3270a5f3 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -314,7 +314,7 @@ def _encode_prompt_qwen( attention_mask = inputs["attention_mask"][:, crop_start:] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32) - embeds = torch.cat([embeds[i].unsqueeze(dim=0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)], dim=0) + embeds = embeds.repeat_interleave(num_videos_per_prompt, dim=0) return embeds.to(dtype), cu_seqlens From 28458d0caf929b90bc36df7f7004dd00fa607517 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Thu, 16 Oct 2025 09:57:56 +0300 Subject: [PATCH 23/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py 
b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 58ba3270a5f3..850795ada162 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -313,7 +313,7 @@ def _encode_prompt_qwen( attention_mask = inputs["attention_mask"][:, crop_start:] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) - cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32) + cu_seqlens =F.pad(cu_seqlens, (1, 0), value=0)).to(dtype=torch.int32) embeds = embeds.repeat_interleave(num_videos_per_prompt, dim=0) return embeds.to(dtype), cu_seqlens From cd3cc6156ea949e0a620b893660ad96933691f77 Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 07:14:47 +0000 Subject: [PATCH 24/77] moved template to __init__ --- .../kandinsky5/pipeline_kandinsky.py | 40 +++++++++---------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 6ebedd04e830..bdf7e41df919 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -152,6 +152,16 @@ def __init__( tokenizer_2=tokenizer_2, scheduler=scheduler, ) + + self.prompt_template = "\n".join(["<|im_start|>system\nYou are a promt engineer. Describe the video in detail.", + "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", + "Describe the location of the video, main characters or objects and their action.", + "Describe the dynamism of the video and presented actions.", + "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.", + "Describe the visual effects, postprocessing and transitions if they are presented in the video.", + "Pay attention to the order of key actions shown in the scene.<|im_end|>", + "<|im_start|>user\n{}<|im_end|>"]) + self.prompt_template_encode_start_idx = 129 self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio @@ -276,29 +286,14 @@ def _encode_prompt_qwen( """ device = device or self._execution_device dtype = dtype or self.text_encoder.dtype - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt = [prompt_clean(p) for p in prompt] - # Kandinsky specific prompt template for detailed video description - prompt_template = "\n".join([ - "<|im_start|>system\nYou are a promt engineer. 
Describe the video in detail.", - "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", - "Describe the location of the video, main characters or objects and their action.", - "Describe the dynamism of the video and presented actions.", - "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.", - "Describe the visual effects, postprocessing and transitions if they are presented in the video.", - "Pay attention to the order of key actions shown in the scene.<|im_end|>", - "<|im_start|>user\n{}<|im_end|>", - ]) - crop_start = 129 # Position to start cropping from (system prompt tokens) - - full_texts = [prompt_template.format(p) for p in prompt] + full_texts = [self.prompt_template.format(p) for p in prompt] inputs = self.tokenizer( text=full_texts, images=None, videos=None, - max_length=max_sequence_length + crop_start, + max_length=max_sequence_length + self.prompt_template_encode_start_idx, truncation=True, return_tensors="pt", padding=True, @@ -308,11 +303,11 @@ def _encode_prompt_qwen( input_ids=inputs["input_ids"], return_dict=True, output_hidden_states=True, - )["hidden_states"][-1][:, crop_start:] + )["hidden_states"][-1][:, self.prompt_template_encode_start_idx:] batch_size = len(prompt) - attention_mask = inputs["attention_mask"][:, crop_start:] + attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx:] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32) embeds = embeds.repeat_interleave(num_videos_per_prompt, dim=0) @@ -343,8 +338,6 @@ def _encode_prompt_clip( """ device = device or self._execution_device dtype = dtype or self.text_encoder_2.dtype - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt = [prompt_clean(p) for p in prompt] inputs = self.tokenizer_2( prompt, @@ -357,7 +350,6 @@ def _encode_prompt_clip( pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] - # duplicate for each generation per prompt batch_size = len(prompt) pooled_embed = pooled_embed.repeat(1, num_videos_per_prompt, 1) pooled_embed = pooled_embed.view(batch_size * num_videos_per_prompt, -1) @@ -421,6 +413,8 @@ def encode_prompt( batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] if prompt_embeds is None: + prompt = [prompt_clean(p) for p in prompt] + prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( prompt=prompt, device=device, @@ -452,6 +446,8 @@ def encode_prompt( f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) + + negative_prompt = [prompt_clean(p) for p in negative_prompt] negative_prompt_embeds_qwen, negative_cu_seqlens = self._encode_prompt_qwen( prompt=negative_prompt, From 4450265bf76ee29ae2cbd7371d1237b1b4db24cf Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Thu, 16 Oct 2025 10:19:26 +0300 Subject: [PATCH 25/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index bdf7e41df919..ff674b10ec1b 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -260,7 +260,7 @@ def get_sparse_params(self, sample, device): return sparse_params - def _encode_prompt_qwen( + def _get_qwen_prompt_embeds( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, From b9a3be2a152e0135ef0f0739e9aa62938a7d8dec Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Thu, 16 Oct 2025 10:19:45 +0300 Subject: [PATCH 26/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index ff674b10ec1b..3e61ae0bf2c6 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -314,7 +314,7 @@ def _get_qwen_prompt_embeds( return embeds.to(dtype), cu_seqlens - def _encode_prompt_clip( + def _get_clip_prompt_embeds( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, From 78a23b9ddefa4199c1218b0ee0330785b6d5f43e Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Thu, 16 Oct 2025 10:34:59 +0300 Subject: [PATCH 27/77] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 8d2bae11cbfa..b8723bfe86ea 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -335,8 +335,6 @@ def __init__(self, time_dim, model_dim, num_params): super().__init__() self.activation = nn.SiLU() self.out_layer = nn.Linear(time_dim, num_params * model_dim) - self.out_layer.weight.data.zero_() - self.out_layer.bias.data.zero_() @torch.autocast(device_type="cuda", dtype=torch.float32) def forward(self, x): From 56b90b10ef1fe17d7aae3cdbb65025084177fc27 Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 07:35:17 +0000 Subject: [PATCH 28/77] moved sdps inside processor --- .../models/transformers/transformer_kandinsky.py | 15 ++++++--------- .../pipelines/kandinsky5/pipeline_kandinsky.py | 4 ++-- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 8d2bae11cbfa..680b456df3f7 100644 --- 
a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -174,14 +174,6 @@ def nablaT_v2( ) -def sdpa(q, k, v): - query = q.transpose(1, 2).contiguous() - key = k.transpose(1, 2).contiguous() - value = v.transpose(1, 2).contiguous() - out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous() - return out - - @torch.autocast(device_type="cuda", dtype=torch.float32) def apply_scale_shift_norm(norm, x, scale, shift): return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16) @@ -355,7 +347,12 @@ def __call__( **kwargs, ): # Process attention with the given query, key, value tensors - out = sdpa(q=query, k=key, v=value).flatten(-2, -1) + + query = query.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous().flatten(-2, -1) + return out diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 3e61ae0bf2c6..bdf7e41df919 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -260,7 +260,7 @@ def get_sparse_params(self, sample, device): return sparse_params - def _get_qwen_prompt_embeds( + def _encode_prompt_qwen( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, @@ -314,7 +314,7 @@ def _get_qwen_prompt_embeds( return embeds.to(dtype), cu_seqlens - def _get_clip_prompt_embeds( + def _encode_prompt_clip( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, From 31a1474378a0ae3fe22bc626f7fe274c99ed30fd Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 08:46:34 +0000 Subject: [PATCH 29/77] remove oneline function --- .../transformers/transformer_kandinsky.py | 91 ++++++++++++------- 1 file changed, 59 insertions(+), 32 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index febe6cff7ae7..bed1938ae34d 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -174,16 +174,6 @@ def nablaT_v2( ) -@torch.autocast(device_type="cuda", dtype=torch.float32) -def apply_scale_shift_norm(norm, x, scale, shift): - return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16) - - -@torch.autocast(device_type="cuda", dtype=torch.float32) -def apply_gate_sum(x, out, gate): - return (x + gate * out).to(torch.bfloat16) - - @torch.autocast(device_type="cuda", enabled=False) def apply_rotary(x, rope): x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) @@ -327,6 +317,8 @@ def __init__(self, time_dim, model_dim, num_params): super().__init__() self.activation = nn.SiLU() self.out_layer = nn.Linear(time_dim, num_params * model_dim) + self.out_layer.weight.data.zero_() + self.out_layer.bias.data.zero_() @torch.autocast(device_type="cuda", dtype=torch.float32) def forward(self, x): @@ -585,12 +577,9 @@ def forward(self, visual_embed, text_embed, time_embed): shift, scale = torch.chunk( self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 ) - visual_embed = apply_scale_shift_norm( - self.norm, - visual_embed, - scale[:, None, None], - shift[:, None, None], - ).type_as(visual_embed) + + visual_embed = (self.norm(visual_embed.float()) * (scale.float()[:, None, None] + 1.0) + 
shift.float()[:, None, None]).type_as(visual_embed) + x = self.out_layer(visual_embed) batch_size, duration, height, width, _ = x.shape @@ -629,17 +618,59 @@ def forward(self, x, time_embed, rope): self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) - out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift) + out = (self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x) out = self.self_attention(out, rope) - x = apply_gate_sum(x, out, gate) + x = (x.float() + gate.float() * out.float()).type_as(x) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) - out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift) + out = (self.feed_forward_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x) out = self.feed_forward(out) - x = apply_gate_sum(x, out, gate) + x = (x.float() + gate.float() * out.float()).type_as(x) + return x +# class Kandinsky5TransformerDecoderBlock(nn.Module): +# def __init__(self, model_dim, time_dim, ff_dim, head_dim): +# super().__init__() +# self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9) + +# self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) +# self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim) + +# self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) +# self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim) + +# self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) +# self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) + +# def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): +# self_attn_params, cross_attn_params, ff_params = torch.chunk( +# self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1 +# ) +# shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) +# visual_out = apply_scale_shift_norm( +# self.self_attention_norm, visual_embed, scale, shift +# ) +# visual_out = self.self_attention(visual_out, rope, sparse_params) +# visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + +# shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) +# visual_out = apply_scale_shift_norm( +# self.cross_attention_norm, visual_embed, scale, shift +# ) +# visual_out = self.cross_attention(visual_out, text_embed) +# visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + +# shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) +# visual_out = apply_scale_shift_norm( +# self.feed_forward_norm, visual_embed, scale, shift +# ) +# visual_out = self.feed_forward(visual_out) +# visual_embed = apply_gate_sum(visual_embed, visual_out, gate) +# return visual_embed + + class Kandinsky5TransformerDecoderBlock(nn.Module): def __init__(self, model_dim, time_dim, ff_dim, head_dim): super().__init__() @@ -658,26 +689,22 @@ def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): self_attn_params, cross_attn_params, ff_params = torch.chunk( self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1 ) + shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) - visual_out = apply_scale_shift_norm( - self.self_attention_norm, visual_embed, scale, shift - ) + visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) visual_out = self.self_attention(visual_out, rope, sparse_params) - visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + 
visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) - visual_out = apply_scale_shift_norm( - self.cross_attention_norm, visual_embed, scale, shift - ) + visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) visual_out = self.cross_attention(visual_out, text_embed) - visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) - visual_out = apply_scale_shift_norm( - self.feed_forward_norm, visual_embed, scale, shift - ) + visual_out = (self.feed_forward_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) visual_out = self.feed_forward(visual_out) - visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) + return visual_embed From 894aa98a2753dfc448f4398cf9a4fd256f763a61 Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 09:17:39 +0000 Subject: [PATCH 30/77] remove reset_dtype methods --- .../transformers/transformer_kandinsky.py | 20 +++---------------- .../kandinsky5/pipeline_kandinsky.py | 5 ----- 2 files changed, 3 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index bed1938ae34d..8d3b4fac513e 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -189,7 +189,8 @@ def __init__(self, model_dim, time_dim, max_period=10000.0): self.max_period = max_period self.register_buffer( "freqs", get_freqs(model_dim // 2, max_period), persistent=False - ) + ) + self.freqs = get_freqs(self.model_dim // 2, self.max_period) self.in_layer = nn.Linear(model_dim, time_dim, bias=True) self.activation = nn.SiLU() self.out_layer = nn.Linear(time_dim, time_dim, bias=True) @@ -199,10 +200,7 @@ def forward(self, time): args = torch.outer(time, self.freqs.to(device=time.device)) time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) - return time_embed - - def reset_dtype(self): - self.freqs = get_freqs(self.model_dim // 2, self.max_period) + return time_embed class Kandinsky5TextEmbeddings(nn.Module): @@ -260,11 +258,6 @@ def forward(self, pos): rope = rope.view(*rope.shape[:-1], 2, 2) return rope.unsqueeze(-4) - def reset_dtype(self): - freq = get_freqs(self.dim // 2, self.max_period).to(self.args.device) - pos = torch.arange(self.max_pos, dtype=freq.dtype, device=freq.device) - self.args = torch.outer(pos, freq) - class Kandinsky5RoPE3D(nn.Module): def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0): @@ -305,12 +298,6 @@ def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): rope = rope.view(*rope.shape[:-1], 2, 2) return rope.unsqueeze(-4) - def reset_dtype(self): - for i, (axes_dim, ax_max_pos) in enumerate(zip(self.axes_dims, self.max_pos)): - freq = get_freqs(axes_dim // 2, self.max_period).to(self.args_0.device) - pos = torch.arange(ax_max_pos, dtype=freq.dtype, device=freq.device) - setattr(self, f"args_{i}", torch.outer(pos, freq)) - class Kandinsky5Modulation(nn.Module): def __init__(self, time_dim, model_dim, 
num_params): @@ -337,7 +324,6 @@ def __call__( **kwargs, ): # Process attention with the given query, key, value tensors - query = query.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index bdf7e41df919..b1f7924e9b9f 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -695,11 +695,6 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - # 0. Reset embeddings dtype - self.transformer.time_embeddings.reset_dtype() - self.transformer.text_rope_embeddings.reset_dtype() - self.transformer.visual_rope_embeddings.reset_dtype() - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, From c8be08149e80ae22e7a7d3b4a1f2413a9f149690 Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 09:31:12 +0000 Subject: [PATCH 31/77] Transformer: move all methods to forward --- .../transformers/transformer_kandinsky.py | 185 +++++------------- 1 file changed, 47 insertions(+), 138 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 8d3b4fac513e..45e4238cfb51 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -616,47 +616,6 @@ def forward(self, x, time_embed, rope): return x -# class Kandinsky5TransformerDecoderBlock(nn.Module): -# def __init__(self, model_dim, time_dim, ff_dim, head_dim): -# super().__init__() -# self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9) - -# self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) -# self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim) - -# self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) -# self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim) - -# self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) -# self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) - -# def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): -# self_attn_params, cross_attn_params, ff_params = torch.chunk( -# self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1 -# ) -# shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) -# visual_out = apply_scale_shift_norm( -# self.self_attention_norm, visual_embed, scale, shift -# ) -# visual_out = self.self_attention(visual_out, rope, sparse_params) -# visual_embed = apply_gate_sum(visual_embed, visual_out, gate) - -# shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) -# visual_out = apply_scale_shift_norm( -# self.cross_attention_norm, visual_embed, scale, shift -# ) -# visual_out = self.cross_attention(visual_out, text_embed) -# visual_embed = apply_gate_sum(visual_embed, visual_out, gate) - -# shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) -# visual_out = apply_scale_shift_norm( -# self.feed_forward_norm, visual_embed, scale, shift -# ) -# visual_out = self.feed_forward(visual_out) -# visual_embed = apply_gate_sum(visual_embed, visual_out, gate) -# return visual_embed - - class Kandinsky5TransformerDecoderBlock(nn.Module): def __init__(self, model_dim, 
time_dim, ff_dim, head_dim): super().__init__() @@ -724,16 +683,16 @@ def __init__( axes_dims=(16, 24, 24), visual_cond=False, attention_type: str = "regular", - attention_causal: bool = None, # Default for Nabla: false - attention_local: bool = None, # Default for Nabla: false - attention_glob: bool = None, # Default for Nabla: false - attention_window: int = None, # Default for Nabla: 3 - attention_P: float = None, # Default for Nabla: 0.9 - attention_wT: int = None, # Default for Nabla: 11 - attention_wW: int = None, # Default for Nabla: 3 - attention_wH: int = None, # Default for Nabla: 3 - attention_add_sta: bool = None, # Default for Nabla: true - attention_method: str = None, # Default for Nabla: "topcdf" + attention_causal: bool = None, + attention_local: bool = None, + attention_glob: bool = None, + attention_window: int = None, + attention_P: float = None, + attention_wT: int = None, + attention_wW: int = None, + attention_wH: int = None, + attention_add_sta: bool = None, + attention_method: str = None, ): super().__init__() @@ -779,73 +738,6 @@ def __init__( ) self.gradient_checkpointing = False - def prepare_text_embeddings( - self, text_embed, time, pooled_text_embed, x, text_rope_pos - ): - """Prepare text embeddings and related components""" - text_embed = self.text_embeddings(text_embed) - time_embed = self.time_embeddings(time) - time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed) - visual_embed = self.visual_embeddings(x) - text_rope = self.text_rope_embeddings(text_rope_pos) - text_rope = text_rope.unsqueeze(dim=0) - return text_embed, time_embed, text_rope, visual_embed - - def prepare_visual_embeddings( - self, visual_embed, visual_rope_pos, scale_factor, sparse_params - ): - """Prepare visual embeddings and related components""" - visual_shape = visual_embed.shape[:-1] - visual_rope = self.visual_rope_embeddings( - visual_shape, visual_rope_pos, scale_factor - ) - to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False - visual_embed, visual_rope = fractal_flatten( - visual_embed, visual_rope, visual_shape, block_mask=to_fractal - ) - return visual_embed, visual_shape, to_fractal, visual_rope - - def process_text_transformer_blocks(self, text_embed, time_embed, text_rope): - """Process text through transformer blocks""" - for text_transformer_block in self.text_transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - text_embed = self._gradient_checkpointing_func( - text_transformer_block, text_embed, time_embed, text_rope - ) - else: - text_embed = text_transformer_block(text_embed, time_embed, text_rope) - return text_embed - - def process_visual_transformer_blocks( - self, visual_embed, text_embed, time_embed, visual_rope, sparse_params - ): - """Process visual through transformer blocks""" - for visual_transformer_block in self.visual_transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - visual_embed = self._gradient_checkpointing_func( - visual_transformer_block, - visual_embed, - text_embed, - time_embed, - visual_rope, - sparse_params, - ) - else: - visual_embed = visual_transformer_block( - visual_embed, text_embed, time_embed, visual_rope, sparse_params - ) - return visual_embed - - def prepare_output( - self, visual_embed, visual_shape, to_fractal, text_embed, time_embed - ): - """Prepare the final output""" - visual_embed = fractal_unflatten( - visual_embed, visual_shape, block_mask=to_fractal - ) - x = self.out_layer(visual_embed, text_embed, time_embed) 
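# A minimal, runnable sketch (outside the patch proper) of the float32 modulation
# pattern that the inlined scale/shift/gate code in the blocks above implements;
# the layer size and tensor shapes below are assumptions chosen only for illustration:
# import torch
# import torch.nn as nn
# norm = nn.LayerNorm(64, elementwise_affine=False)
# x = torch.randn(1, 16, 64, dtype=torch.bfloat16)            # hidden states
# scale, shift, gate = torch.randn(3, 1, 1, 64)               # from the modulation MLP
# out = (norm(x.float()) * (scale + 1.0) + shift).type_as(x)  # scale-shift on the normed input
# x = (x.float() + gate * out.float()).type_as(x)             # gated residual update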
- return x - def forward( self, hidden_states: torch.FloatTensor, # x @@ -881,32 +773,49 @@ def forward( time = timestep pooled_text_embed = pooled_projections - # Prepare text embeddings and related components - text_embed, time_embed, text_rope, visual_embed = self.prepare_text_embeddings( - text_embed, time, pooled_text_embed, x, text_rope_pos - ) + text_embed = self.text_embeddings(text_embed) + time_embed = self.time_embeddings(time) + time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed) + visual_embed = self.visual_embeddings(x) + text_rope = self.text_rope_embeddings(text_rope_pos) + text_rope = text_rope.unsqueeze(dim=0) - # Process text through transformer blocks - text_embed = self.process_text_transformer_blocks( - text_embed, time_embed, text_rope - ) + for text_transformer_block in self.text_transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + text_embed = self._gradient_checkpointing_func( + text_transformer_block, text_embed, time_embed, text_rope + ) + else: + text_embed = text_transformer_block(text_embed, time_embed, text_rope) - # Prepare visual embeddings and related components - visual_embed, visual_shape, to_fractal, visual_rope = ( - self.prepare_visual_embeddings( - visual_embed, visual_rope_pos, scale_factor, sparse_params - ) + visual_shape = visual_embed.shape[:-1] + visual_rope = self.visual_rope_embeddings( + visual_shape, visual_rope_pos, scale_factor ) - - # Process visual through transformer blocks - visual_embed = self.process_visual_transformer_blocks( - visual_embed, text_embed, time_embed, visual_rope, sparse_params + to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False + visual_embed, visual_rope = fractal_flatten( + visual_embed, visual_rope, visual_shape, block_mask=to_fractal ) - # Prepare final output - x = self.prepare_output( - visual_embed, visual_shape, to_fractal, text_embed, time_embed + for visual_transformer_block in self.visual_transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + visual_embed = self._gradient_checkpointing_func( + visual_transformer_block, + visual_embed, + text_embed, + time_embed, + visual_rope, + sparse_params, + ) + else: + visual_embed = visual_transformer_block( + visual_embed, text_embed, time_embed, visual_rope, sparse_params + ) + + visual_embed = fractal_unflatten( + visual_embed, visual_shape, block_mask=to_fractal ) + x = self.out_layer(visual_embed, text_embed, time_embed) if not return_dict: return x From 3ffdf7f113e442c68d65da5033e31a195f7a1be7 Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 10:32:47 +0000 Subject: [PATCH 32/77] separated prompt encoding --- .../kandinsky5/pipeline_kandinsky.py | 153 +++++++----------- 1 file changed, 56 insertions(+), 97 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index b1f7924e9b9f..2ff0c1d45d81 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -359,124 +359,64 @@ def _encode_prompt_clip( def encode_prompt( self, prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, max_sequence_length: int = 512, device: Optional[torch.device] = None, dtype: 
Optional[torch.dtype] = None, ): r""" - Encodes the prompt into text encoder hidden states. + Encodes a single prompt (positive or negative) into text encoder hidden states. This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text representations for video generation. - + Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): - Whether to use classifier free guidance or not. + prompt (`str` or `List[str]`): + Prompt to be encoded. num_videos_per_prompt (`int`, *optional*, defaults to 1): - Number of videos that should be generated per prompt. - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. + Number of videos to generate per prompt. max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length for text encoding. - device: (`torch.device`, *optional*): - torch device - dtype: (`torch.dtype`, *optional*): - torch dtype - + device (`torch.device`, *optional*): + Torch device. + dtype (`torch.dtype`, *optional*): + Torch dtype. 
+ Returns: - Tuple: Contains prompt embeddings, negative prompt embeddings, and sequence length information + Tuple[Dict[str, torch.Tensor], torch.Tensor]: + - A dict with keys `"text_embeds"` (from Qwen) and `"pooled_embed"` (from CLIP) + - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings """ device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - prompt = [prompt] - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] - - if prompt_embeds is None: - prompt = [prompt_clean(p) for p in prompt] - - prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( - prompt=prompt, - device=device, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - dtype=dtype, - ) - prompt_embeds_clip = self._encode_prompt_clip( - prompt=prompt, - device=device, - num_videos_per_prompt=num_videos_per_prompt, - dtype=dtype, - ) - else: - prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = prompt_embeds + batch_size = len(prompt) - if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" - negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + prompt = [prompt_clean(p) for p in prompt] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
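# A rough usage sketch of the reworked encode_prompt as it stands at this point in the
# series (the `pipe` variable and the prompt text are assumptions; argument names and the
# two-value return follow the code above):
# prompt_embeds, cu_seqlens = pipe.encode_prompt(
#     prompt=["a cat walking on green grass"],
#     num_videos_per_prompt=1,
#     max_sequence_length=512,
#     device=pipe._execution_device,
# )
# prompt_embeds["text_embeds"]   # Qwen2.5-VL hidden states
# prompt_embeds["pooled_embed"]  # pooled CLIP embedding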
- ) - - negative_prompt = [prompt_clean(p) for p in negative_prompt] + # Encode with Qwen2.5-VL + prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( + prompt=prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + dtype=dtype, + ) - negative_prompt_embeds_qwen, negative_cu_seqlens = self._encode_prompt_qwen( - prompt=negative_prompt, - device=device, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - dtype=dtype, - ) - negative_prompt_embeds_clip = self._encode_prompt_clip( - prompt=negative_prompt, - device=device, - num_videos_per_prompt=num_videos_per_prompt, - dtype=dtype, - ) - else: - negative_prompt_embeds_qwen = None - negative_prompt_embeds_clip = None - negative_cu_seqlens = None + # Encode with CLIP + prompt_embeds_clip = self._encode_prompt_clip( + prompt=prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + dtype=dtype, + ) prompt_embeds_dict = { "text_embeds": prompt_embeds_qwen, "pooled_embed": prompt_embeds_clip, } - negative_prompt_embeds_dict = { - "text_embeds": negative_prompt_embeds_qwen, - "pooled_embed": negative_prompt_embeds_clip, - } if do_classifier_free_guidance else None - return prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens + return prompt_embeds_dict, prompt_cu_seqlens def check_inputs( self, @@ -722,24 +662,43 @@ def __call__( # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 + prompt = [prompt] elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] # 3. Encode input prompt - prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt( + prompt_embeds_dict, prompt_cu_seqlens = self.encode_prompt( prompt=prompt, - negative_prompt=negative_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, num_videos_per_prompt=num_videos_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, max_sequence_length=max_sequence_length, device=device, dtype=dtype, ) + negative_prompt_embeds_dict = None + negative_cu_seqlens = None + + if self.do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt] + elif len(negative_prompt) != len(prompt): + raise ValueError( + f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}." + ) + + negative_prompt_embeds_dict, negative_cu_seqlens = self.encode_prompt( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps From 9f52335290e0e2076166dcc35180557527a7d5eb Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 01:47:38 +0300 Subject: [PATCH 33/77] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 45e4238cfb51..38cc5156bc49 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -57,7 +57,6 @@ def freeze(model): return model -@torch.autocast(device_type="cuda", enabled=False) def get_freqs(dim, max_period=10000.0): freqs = torch.exp( -math.log(max_period) From cc46e2d2defbb922b7e0ef8e1f014e9361850b5c Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 22:48:09 +0000 Subject: [PATCH 34/77] refactoring --- src/diffusers/models/transformers/transformer_kandinsky.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 45e4238cfb51..d08f2a968e15 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -186,10 +186,7 @@ def __init__(self, model_dim, time_dim, max_period=10000.0): super().__init__() assert model_dim % 2 == 0 self.model_dim = model_dim - self.max_period = max_period - self.register_buffer( - "freqs", get_freqs(model_dim // 2, max_period), persistent=False - ) + self.max_period = max_period self.freqs = get_freqs(self.model_dim // 2, self.max_period) self.in_layer = nn.Linear(model_dim, time_dim, bias=True) self.activation = nn.SiLU() From 9672c6bd6f70a28cca896025fc57e89b72117838 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 01:49:19 +0300 Subject: [PATCH 35/77] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 38cc5156bc49..488c44189202 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -173,7 +173,6 @@ def nablaT_v2( ) -@torch.autocast(device_type="cuda", enabled=False) def apply_rotary(x, rope): x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) x_out = (rope * x_).sum(dim=-1) From 900feba4fe196b911344c779cc9c951dfbc067ca Mon Sep 17 00:00:00 2001 From: leffff Date: Fri, 17 Oct 2025 14:38:42 +0000 Subject: [PATCH 36/77] refactoring according to https://github.com/huggingface/diffusers/commit/acabbc0033d4b4933fc651766a4aa026db2e6dc1 --- .../transformers/transformer_kandinsky.py | 318 ++++++------------ 1 file changed, 104 insertions(+), 214 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index f88429fa1714..7a4f85c744ec 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ 
b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -19,10 +19,6 @@ import torch.nn as nn import torch.nn.functional as F from torch import BoolTensor, IntTensor, Tensor, nn -from torch.nn.attention.flex_attention import ( - BlockMask, - flex_attention, -) from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin @@ -34,7 +30,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import AttentionMixin, FeedForward +from ..attention import AttentionMixin, FeedForward, AttentionModuleMixin from ..cache_utils import CacheMixin from ..embeddings import ( TimestepEmbedding, @@ -43,6 +39,7 @@ from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm +from ..attention_dispatch import dispatch_attention_fn, _CAN_USE_FLEX_ATTN logger = logging.get_logger(__name__) @@ -148,7 +145,15 @@ def nablaT_v2( k: Tensor, sta: Tensor, thr: float = 0.9, -) -> BlockMask: +): + if _CAN_USE_FLEX_ATTN: + from torch.nn.attention.flex_attention import BlockMask + else: + raise ValueError("Nabla attention is not supported with this version of PyTorch") + + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + # Map estimation B, h, S, D = q.shape s1 = S // 64 @@ -173,18 +178,15 @@ def nablaT_v2( ) -def apply_rotary(x, rope): - x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) - x_out = (rope * x_).sum(dim=-1) - return x_out.reshape(*x.shape).to(torch.bfloat16) - - class Kandinsky5TimeEmbeddings(nn.Module): def __init__(self, model_dim, time_dim, max_period=10000.0): super().__init__() assert model_dim % 2 == 0 self.model_dim = model_dim - self.max_period = max_period + self.max_period = max_period + self.register_buffer( + "freqs", get_freqs(model_dim // 2, max_period), persistent=False + ) self.freqs = get_freqs(self.model_dim // 2, self.max_period) self.in_layer = nn.Linear(model_dim, time_dim, bias=True) self.activation = nn.SiLU() @@ -307,184 +309,82 @@ def forward(self, x): return self.out_layer(self.activation(x)) -class Kandinsky5SDPAAttentionProcessor(nn.Module): - """Custom attention processor for standard SDPA attention""" - - def __call__( - self, - attn, - query, - key, - value, - **kwargs, - ): - # Process attention with the given query, key, value tensors - query = query.transpose(1, 2).contiguous() - key = key.transpose(1, 2).contiguous() - value = value.transpose(1, 2).contiguous() - out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous().flatten(-2, -1) - - return out - - -class Kandinsky5NablaAttentionProcessor(nn.Module): - """Custom attention processor for Nabla attention""" - - @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True) - def __call__( - self, - attn, - query, - key, - value, - sparse_params=None, - **kwargs, - ): - if sparse_params is None: - raise ValueError("sparse_params is required for Nabla attention") - - query = query.transpose(1, 2).contiguous() - key = key.transpose(1, 2).contiguous() - value = value.transpose(1, 2).contiguous() - - block_mask = nablaT_v2( - query, - key, - sparse_params["sta_mask"], - thr=sparse_params["P"], - ) - out = ( - flex_attention(query, key, value, block_mask=block_mask) - .transpose(1, 2) - .contiguous() - ) - out = out.flatten(-2, -1) - return out - - -class Kandinsky5MultiheadSelfAttentionEnc(nn.Module): - def __init__(self, num_channels, head_dim): - super().__init__() - 
assert num_channels % head_dim == 0 - self.num_heads = num_channels // head_dim - - self.to_query = nn.Linear(num_channels, num_channels, bias=True) - self.to_key = nn.Linear(num_channels, num_channels, bias=True) - self.to_value = nn.Linear(num_channels, num_channels, bias=True) - self.query_norm = nn.RMSNorm(head_dim) - self.key_norm = nn.RMSNorm(head_dim) - - self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - - # Initialize attention processor - self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() - - def get_qkv(self, x): - query = self.to_query(x) - key = self.to_key(x) - value = self.to_value(x) +class Kandinsky5AttnProcessor: - shape = query.shape[:-1] - query = query.reshape(*shape, self.num_heads, -1) - key = key.reshape(*shape, self.num_heads, -1) - value = value.reshape(*shape, self.num_heads, -1) + _attention_backend = None + _parallel_config = None - return query, key, value + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") - def norm_qk(self, q, k): - q = self.query_norm(q.float()).type_as(q) - k = self.key_norm(k.float()).type_as(k) - return q, k - def scaled_dot_product_attention(self, query, key, value): - # Use the processor - return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) + def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None): + # query, key, value = self.get_qkv(x) + query = attn.to_query(hidden_states) - def out_l(self, x): - return self.out_layer(x) + if encoder_hidden_states is not None: + key = attn.to_key(encoder_hidden_states) + value = attn.to_value(encoder_hidden_states) - def forward(self, x, rope): - query, key, value = self.get_qkv(x) - query, key = self.norm_qk(query, key) - query = apply_rotary(query, rope).type_as(query) - key = apply_rotary(key, rope).type_as(key) + shape, cond_shape = query.shape[:-1], key.shape[:-1] + query = query.reshape(*shape, attn.num_heads, -1) + key = key.reshape(*cond_shape, attn.num_heads, -1) + value = value.reshape(*cond_shape, attn.num_heads, -1) + + else: + key = attn.to_key(hidden_states) + value = attn.to_value(hidden_states) - out = self.scaled_dot_product_attention(query, key, value) + shape = query.shape[:-1] + query = query.reshape(*shape, attn.num_heads, -1) + key = key.reshape(*shape, attn.num_heads, -1) + value = value.reshape(*shape, attn.num_heads, -1) - out = self.out_l(out) - return out + # query, key = self.norm_qk(query, key) + query = attn.query_norm(query.float()).type_as(query) + key = attn.key_norm(key.float()).type_as(key) + def apply_rotary(x, rope): + x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) + x_out = (rope * x_).sum(dim=-1) + return x_out.reshape(*x.shape).to(torch.bfloat16) -class Kandinsky5MultiheadSelfAttentionDec(nn.Module): - def __init__(self, num_channels, head_dim): - super().__init__() - assert num_channels % head_dim == 0 - self.num_heads = num_channels // head_dim - - self.to_query = nn.Linear(num_channels, num_channels, bias=True) - self.to_key = nn.Linear(num_channels, num_channels, bias=True) - self.to_value = nn.Linear(num_channels, num_channels, bias=True) - self.query_norm = nn.RMSNorm(head_dim) - self.key_norm = nn.RMSNorm(head_dim) - - self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - - # Initialize attention processors - self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() - self.nabla_processor = 
Kandinsky5NablaAttentionProcessor() - - def get_qkv(self, x): - query = self.to_query(x) - key = self.to_key(x) - value = self.to_value(x) - - shape = query.shape[:-1] - query = query.reshape(*shape, self.num_heads, -1) - key = key.reshape(*shape, self.num_heads, -1) - value = value.reshape(*shape, self.num_heads, -1) - - return query, key, value - - def norm_qk(self, q, k): - q = self.query_norm(q.float()).type_as(q) - k = self.key_norm(k.float()).type_as(k) - return q, k - - def attention(self, query, key, value): - # Use the processor - return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) - - def nabla(self, query, key, value, sparse_params=None): - # Use the processor - return self.nabla_processor( - attn=self, - query=query, - key=key, - value=value, - sparse_params=sparse_params, - **{}, - ) - - def out_l(self, x): - return self.out_layer(x) - - def forward(self, x, rope, sparse_params=None): - query, key, value = self.get_qkv(x) - query, key = self.norm_qk(query, key) - query = apply_rotary(query, rope).type_as(query) - key = apply_rotary(key, rope).type_as(key) + if rotary_emb is not None: + query = apply_rotary(query, rotary_emb).type_as(query) + key = apply_rotary(key, rotary_emb).type_as(key) if sparse_params is not None: - out = self.nabla(query, key, value, sparse_params=sparse_params) + attn_mask = nablaT_v2( + query, + key, + sparse_params["sta_mask"], + thr=sparse_params["P"], + ) else: - out = self.attention(query, key, value) + attn_mask = None + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attn_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(-2, -1) - out = self.out_l(out) - return out + attn_out = attn.out_layer(hidden_states) + return attn_out -class Kandinsky5MultiheadCrossAttention(nn.Module): - def __init__(self, num_channels, head_dim): +class Kandinsky5Attention(nn.Module, AttentionModuleMixin): + + _default_processor_cls = Kandinsky5AttnProcessor + _available_processors = [ + Kandinsky5AttnProcessor, + ] + def __init__(self, num_channels, head_dim, processor=None): super().__init__() assert num_channels % head_dim == 0 self.num_heads = num_channels // head_dim @@ -496,43 +396,33 @@ def __init__(self, num_channels, head_dim): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) - # Initialize attention processor - self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() - - def get_qkv(self, x, cond): - query = self.to_query(x) - key = self.to_key(cond) - value = self.to_value(cond) - - shape, cond_shape = query.shape[:-1], key.shape[:-1] - query = query.reshape(*shape, self.num_heads, -1) - key = key.reshape(*cond_shape, self.num_heads, -1) - value = value.reshape(*cond_shape, self.num_heads, -1) - - return query, key, value - - def norm_qk(self, q, k): - q = self.query_norm(q.float()).type_as(q) - k = self.key_norm(k.float()).type_as(k) - return q, k - - def attention(self, query, key, value): - # Use the processor - return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) - - def out_l(self, x): - return self.out_layer(x) + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + sparse_params: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) 
-> torch.Tensor: - def forward(self, x, cond): - query, key, value = self.get_qkv(x, cond) - query, key = self.norm_qk(query, key) + import inspect - out = self.attention(query, key, value) - out = self.out_l(out) - return out + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"attention_processor_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states=encoder_hidden_states, sparse_params=sparse_params, rotary_emb=rotary_emb, **kwargs) + class Kandinsky5FeedForward(nn.Module): def __init__(self, dim, ff_dim): super().__init__() @@ -589,7 +479,7 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim): self.text_modulation = Kandinsky5Modulation(time_dim, model_dim, 6) self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.self_attention = Kandinsky5MultiheadSelfAttentionEnc(model_dim, head_dim) + self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor()) self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) @@ -600,7 +490,7 @@ def forward(self, x, time_embed, rope): ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) out = (self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x) - out = self.self_attention(out, rope) + out = self.self_attention(out, rotary_emb=rope) x = (x.float() + gate.float() * out.float()).type_as(x) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) @@ -617,10 +507,10 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim): self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9) self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim) + self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor()) self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim) + self.cross_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor()) self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) @@ -632,12 +522,12 @@ def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) - visual_out = self.self_attention(visual_out, rope, sparse_params) + visual_out = self.self_attention(visual_out, rotary_emb=rope, sparse_params=sparse_params) visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) - visual_out = self.cross_attention(visual_out, text_embed) + visual_out = self.cross_attention(visual_out, 
encoder_hidden_states=text_embed) visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) @@ -815,4 +705,4 @@ def forward( if not return_dict: return x - return Transformer2DModelOutput(sample=x) + return Transformer2DModelOutput(sample=x) \ No newline at end of file From 226bbf8ee1c3c1ddc408aaa6664519c36c995176 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:36:09 +0300 Subject: [PATCH 37/77] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 7a4f85c744ec..7569b8cd8006 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -44,8 +44,6 @@ logger = logging.get_logger(__name__) -def exist(item): - return item is not None def freeze(model): From 9504fb0d63f9ddd59c01e290c9d71304981bf7f5 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:36:32 +0300 Subject: [PATCH 38/77] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 7569b8cd8006..d85b411caf07 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -46,10 +46,6 @@ -def freeze(model): - for p in model.parameters(): - p.requires_grad = False - return model def get_freqs(dim, max_period=10000.0): From f0eca0849b68d61b7cf98b54e4a95ec9e92157a4 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:37:35 +0300 Subject: [PATCH 39/77] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index d85b411caf07..03b40e78de55 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -178,9 +178,6 @@ def __init__(self, model_dim, time_dim, max_period=10000.0): assert model_dim % 2 == 0 self.model_dim = model_dim self.max_period = max_period - self.register_buffer( - "freqs", get_freqs(model_dim // 2, max_period), persistent=False - ) self.freqs = get_freqs(self.model_dim // 2, self.max_period) self.in_layer = nn.Linear(model_dim, time_dim, bias=True) self.activation = nn.SiLU() From cc74c1e46e47d2dbd518c40d636e21e20d3bfbc1 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:41:21 +0300 Subject: [PATCH 40/77] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 03b40e78de55..45bc4849749a 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -237,7 +237,6 @@ def __init__(self, dim, max_pos=1024, max_period=10000.0): pos = torch.arange(max_pos, dtype=freq.dtype) self.register_buffer(f"args", torch.outer(pos, freq), persistent=False) - @torch.autocast(device_type="cuda", enabled=False) def forward(self, pos): args = self.args[pos] cosine = torch.cos(args) From cb915d71adb2bcfef1a30b91774ce19542923c0a Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:41:33 +0300 Subject: [PATCH 41/77] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 45bc4849749a..6b9f60432503 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -258,7 +258,6 @@ def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0): pos = torch.arange(ax_max_pos, dtype=freq.dtype) self.register_buffer(f"args_{i}", torch.outer(pos, freq), persistent=False) - @torch.autocast(device_type="cuda", enabled=False) def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): batch_size, duration, height, width = shape args_t = self.args_0[pos[0]] / scale_factor[0] From 9aa3c2eb20d4e16b3c2db2caef458acaaac32fbf Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:41:56 +0300 Subject: [PATCH 42/77] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 6b9f60432503..490b64ffdfd1 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -614,7 +614,7 @@ def __init__( def forward( self, - hidden_states: torch.FloatTensor, # x + hidden_states: torch.Tensor, # x encoder_hidden_states: torch.FloatTensor, # text_embed timestep: Union[torch.Tensor, float, int], # time pooled_projections: torch.FloatTensor, # pooled_text_embed From feac8f095ff285bbe9bfd23989567ab27166b2ad Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:45:30 +0300 Subject: [PATCH 43/77] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 490b64ffdfd1..2c12b0e90b65 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -615,7 +615,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, # x - encoder_hidden_states: 
torch.FloatTensor, # text_embed + encoder_hidden_states: torch.Tensor, # text_embed timestep: Union[torch.Tensor, float, int], # time pooled_projections: torch.FloatTensor, # pooled_text_embed visual_rope_pos: Tuple[int, int, int], From d3b959750bc3e39e44bcd6910504a9e1b23260bd Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:46:34 +0300 Subject: [PATCH 44/77] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 2c12b0e90b65..e674a8ba1f2a 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -616,7 +616,7 @@ def forward( self, hidden_states: torch.Tensor, # x encoder_hidden_states: torch.Tensor, # text_embed - timestep: Union[torch.Tensor, float, int], # time + timestep: torch.Tensor, # time pooled_projections: torch.FloatTensor, # pooled_text_embed visual_rope_pos: Tuple[int, int, int], text_rope_pos: torch.LongTensor, From 693b9aa9c2880d9d570d44996bcfcafd9be9cf01 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:47:03 +0300 Subject: [PATCH 45/77] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index e674a8ba1f2a..ad39a9bed63f 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -617,7 +617,7 @@ def forward( hidden_states: torch.Tensor, # x encoder_hidden_states: torch.Tensor, # text_embed timestep: torch.Tensor, # time - pooled_projections: torch.FloatTensor, # pooled_text_embed + pooled_projections: torch.Tensor, # pooled_text_embed visual_rope_pos: Tuple[int, int, int], text_rope_pos: torch.LongTensor, scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0), From e2ed6ec961d8d2a251d71de5345a5012fd302a17 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:47:57 +0300 Subject: [PATCH 46/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 2ff0c1d45d81..5369bc579b67 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -416,7 +416,7 @@ def encode_prompt( "pooled_embed": prompt_embeds_clip, } - return prompt_embeds_dict, prompt_cu_seqlens + return prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens def check_inputs( self, From 2925447e3339ca3477144f3814106e87952a0c4a Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:48:35 +0300 Subject: [PATCH 47/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py 
Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 5369bc579b67..988cce6b5e79 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -411,10 +411,6 @@ def encode_prompt( dtype=dtype, ) - prompt_embeds_dict = { - "text_embeds": prompt_embeds_qwen, - "pooled_embed": prompt_embeds_clip, - } return prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens From b02ad82513971dfe14c57b9782d0218e9364df97 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:48:55 +0300 Subject: [PATCH 48/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 988cce6b5e79..c1c510dc12c6 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -398,7 +398,6 @@ def encode_prompt( prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( prompt=prompt, device=device, - num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, dtype=dtype, ) From dc67c2bb4bb1367c7dc3fd4a9cdc93b452e531e5 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:49:19 +0300 Subject: [PATCH 49/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index c1c510dc12c6..420748873cf3 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -406,7 +406,6 @@ def encode_prompt( prompt_embeds_clip = self._encode_prompt_clip( prompt=prompt, device=device, - num_videos_per_prompt=num_videos_per_prompt, dtype=dtype, ) From d0fc426a744172595f194d01687ca1bc54300bd1 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:49:48 +0300 Subject: [PATCH 50/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 420748873cf3..f879f9dc5d09 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -305,7 +305,6 @@ def _encode_prompt_qwen( output_hidden_states=True, )["hidden_states"][-1][:, self.prompt_template_encode_start_idx:] - batch_size = len(prompt) attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx:] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) From 222ba4ca4dd2093696937252e21f11c6b04410a6 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 
22:50:06 +0300 Subject: [PATCH 51/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index f879f9dc5d09..1e5a5ac58fa3 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -264,7 +264,6 @@ def _encode_prompt_qwen( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, - num_videos_per_prompt: int = 1, max_sequence_length: int = 256, dtype: Optional[torch.dtype] = None, ): From 3a495058b05dacc7bc2f4eb8982430e4864e8628 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:50:48 +0300 Subject: [PATCH 52/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 1e5a5ac58fa3..6adc611bdc11 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -308,7 +308,6 @@ def _encode_prompt_qwen( attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx:] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32) - embeds = embeds.repeat_interleave(num_videos_per_prompt, dim=0) return embeds.to(dtype), cu_seqlens From 1e12017008ea693823d08fd9b54a1d54b7f1db56 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:51:08 +0300 Subject: [PATCH 53/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 6adc611bdc11..b700df0e485e 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -315,7 +315,6 @@ def _encode_prompt_clip( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, - num_videos_per_prompt: int = 1, dtype: Optional[torch.dtype] = None, ): """ From 5a300798efeee38600c9101882144e3d8ff53f16 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:51:40 +0300 Subject: [PATCH 54/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index b700df0e485e..4b5c19a9e3cf 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -346,9 +346,6 @@ def _encode_prompt_clip( pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] - batch_size = len(prompt) - pooled_embed = pooled_embed.repeat(1, num_videos_per_prompt, 1) - pooled_embed = 
pooled_embed.view(batch_size * num_videos_per_prompt, -1) return pooled_embed.to(dtype) From 0d96ecfdd53f209bedd29b1df6e661eb03cd8dea Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:51:57 +0300 Subject: [PATCH 55/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 4b5c19a9e3cf..4c880e079a55 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -401,7 +401,11 @@ def encode_prompt( device=device, dtype=dtype, ) - + prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1) + + prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1) return prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens From aadafc14d20117db514fd70ddadc9d4fb5c5bf05 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:52:15 +0300 Subject: [PATCH 56/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 4c880e079a55..67a49ecaa5e6 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -668,8 +668,6 @@ def __call__( dtype=dtype, ) - negative_prompt_embeds_dict = None - negative_cu_seqlens = None if self.do_classifier_free_guidance: if negative_prompt is None: From 54cf03c7139c26670edd15a781c5e98f6c56ad88 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:52:29 +0300 Subject: [PATCH 57/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 67a49ecaa5e6..a7b8bd117c1a 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -563,7 +563,7 @@ def __call__( num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_qwen: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, From 22c503fb84b60b2c6eed777c3b4f23ee82ea5936 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:52:55 +0300 Subject: [PATCH 58/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- 
src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index a7b8bd117c1a..0ba0bed9e102 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -564,7 +564,11 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds_qwen: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_clip: Optional[torch.Tensor] = None, + negative_prompt_embeds_qwen: Optional[torch.Tensor] = None, + negative_prompt_embeds_clip: Optional[torch.Tensor] = None, + prompt_cu_seqlens: Optional[torch.Tensor] = None, + negative_prompt_cu_seqlens: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback_on_step_end: Optional[ From 211d3dd3407a413ce414646b0154781a817d9fba Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:53:10 +0300 Subject: [PATCH 59/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- .../pipelines/kandinsky5/pipeline_kandinsky.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 0ba0bed9e102..fcd6bc301ea9 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -664,13 +664,13 @@ def __call__( batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] # 3. Encode input prompt - prompt_embeds_dict, prompt_cu_seqlens = self.encode_prompt( - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) + if prompt_embeds_qwen is None: + prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt( + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) if self.do_classifier_free_guidance: From 70cfb9e984344f72f63834670f05a5a328bfb565 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:54:16 +0300 Subject: [PATCH 60/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- .../pipelines/kandinsky5/pipeline_kandinsky.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index fcd6bc301ea9..5ab69420962d 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -684,13 +684,13 @@ def __call__( f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}." 
) - negative_prompt_embeds_dict, negative_cu_seqlens = self.encode_prompt( - prompt=negative_prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) + if negative_prompt_embeds_qwen is None: + negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_cu_seqlens = self.encode_prompt( + prompt=negative_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) From 6e83133e699855c62824f34cac0dbd8ff86e6f0b Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:54:47 +0300 Subject: [PATCH 61/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 5ab69420962d..1cbf5f84fb94 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -743,7 +743,7 @@ def __call__( # Predict noise residual pred_velocity = self.transformer( hidden_states=latents.to(dtype), - encoder_hidden_states=prompt_embeds_dict["text_embeds"].to(dtype), + encoder_hidden_states=prompt_embeds_qwen.to(dtype), pooled_projections=prompt_embeds_dict["pooled_embed"].to(dtype), timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, From 7ad87f3554e1d64d0fcf510698552a7408b810bb Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:55:06 +0300 Subject: [PATCH 62/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 1cbf5f84fb94..a863b49a8f71 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -744,7 +744,7 @@ def __call__( pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=prompt_embeds_qwen.to(dtype), - pooled_projections=prompt_embeds_dict["pooled_embed"].to(dtype), + pooled_projections=prompt_embeds_clip.to(dtype), timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, text_rope_pos=text_rope_pos, From bf229afa110338bfbd9dd58460605c6670152c02 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:56:04 +0300 Subject: [PATCH 63/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index a863b49a8f71..c12cee5b8027 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -756,7 +756,7 @@ def __call__( if self.do_classifier_free_guidance and negative_prompt_embeds_dict is not None: uncond_pred_velocity = self.transformer( 
hidden_states=latents.to(dtype), - encoder_hidden_states=negative_prompt_embeds_dict["text_embeds"].to(dtype), + encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), pooled_projections=negative_prompt_embeds_dict["pooled_embed"].to(dtype), timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, From 06afd9ba19ab5de8a2bfbfb1ff33f6fb1c845c02 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:57:04 +0300 Subject: [PATCH 64/77] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index c12cee5b8027..fe5c59cc247b 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -757,7 +757,7 @@ def __call__( uncond_pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), - pooled_projections=negative_prompt_embeds_dict["pooled_embed"].to(dtype), + pooled_projections=negative_prompt_embeds_clip.to(dtype), timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, text_rope_pos=negative_text_rope_pos, From e1a635ec7fb0e2b7e29fb9c7e1629ae0fd2ffdea Mon Sep 17 00:00:00 2001 From: leffff Date: Fri, 17 Oct 2025 20:28:06 +0000 Subject: [PATCH 65/77] fixed --- .../kandinsky5/pipeline_kandinsky.py | 175 ++++++++++++++---- 1 file changed, 137 insertions(+), 38 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index fe5c59cc247b..ff6b00d5fb26 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -349,6 +349,66 @@ def _encode_prompt_clip( return pooled_embed.to(dtype) +# def encode_prompt( +# self, +# prompt: Union[str, List[str]], +# num_videos_per_prompt: int = 1, +# max_sequence_length: int = 512, +# device: Optional[torch.device] = None, +# dtype: Optional[torch.dtype] = None, +# ): +# r""" +# Encodes a single prompt (positive or negative) into text encoder hidden states. + +# This method combines embeddings from both Qwen2.5-VL and CLIP text encoders +# to create comprehensive text representations for video generation. + +# Args: +# prompt (`str` or `List[str]`): +# Prompt to be encoded. +# num_videos_per_prompt (`int`, *optional*, defaults to 1): +# Number of videos to generate per prompt. +# max_sequence_length (`int`, *optional*, defaults to 512): +# Maximum sequence length for text encoding. +# device (`torch.device`, *optional*): +# Torch device. +# dtype (`torch.dtype`, *optional*): +# Torch dtype. 
+ +# Returns: +# Tuple[Dict[str, torch.Tensor], torch.Tensor]: +# - A dict with keys `"text_embeds"` (from Qwen) and `"pooled_embed"` (from CLIP) +# - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings +# """ +# device = device or self._execution_device +# dtype = dtype or self.text_encoder.dtype + +# batch_size = len(prompt) + +# prompt = [prompt_clean(p) for p in prompt] + +# # Encode with Qwen2.5-VL +# prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( +# prompt=prompt, +# device=device, +# max_sequence_length=max_sequence_length, +# dtype=dtype, +# ) + +# # Encode with CLIP +# prompt_embeds_clip = self._encode_prompt_clip( +# prompt=prompt, +# device=device, +# dtype=dtype, +# ) +# prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1) +# prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1) + +# prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1) +# prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1) + +# return prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens + def encode_prompt( self, prompt: Union[str, List[str]], @@ -376,9 +436,10 @@ def encode_prompt( Torch dtype. Returns: - Tuple[Dict[str, torch.Tensor], torch.Tensor]: - - A dict with keys `"text_embeds"` (from Qwen) and `"pooled_embed"` (from CLIP) - - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - Qwen text embeddings of shape (batch_size * num_videos_per_prompt, sequence_length, embedding_dim) + - CLIP pooled embeddings of shape (batch_size * num_videos_per_prompt, clip_embedding_dim) + - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size * num_videos_per_prompt + 1,) """ device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -394,6 +455,7 @@ def encode_prompt( max_sequence_length=max_sequence_length, dtype=dtype, ) + # prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim] # Encode with CLIP prompt_embeds_clip = self._encode_prompt_clip( @@ -401,13 +463,30 @@ def encode_prompt( device=device, dtype=dtype, ) - prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1) - prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1) - - prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1) + # prompt_embeds_clip shape: [batch_size, clip_embed_dim] + + # Repeat embeddings for num_videos_per_prompt + # Qwen embeddings: repeat sequence for each video, then reshape + prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1) # [batch_size, seq_len * num_videos_per_prompt, embed_dim] + # Reshape to [batch_size * num_videos_per_prompt, seq_len, embed_dim] + prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1, prompt_embeds_qwen.shape[-1]) + + # CLIP embeddings: repeat for each video + prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1) # [batch_size, num_videos_per_prompt, clip_embed_dim] + # Reshape to [batch_size * num_videos_per_prompt, clip_embed_dim] prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1) - return prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens + # Repeat cumulative sequence lengths for num_videos_per_prompt + # Original cu_seqlens: [0, len1, len1+len2, ...] 
+ # Need to repeat the differences and reconstruct for repeated prompts + # Original differences (lengths) for each prompt in the batch + original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...] + # Repeat the lengths for num_videos_per_prompt + repeated_lengths = original_lengths.repeat_interleave(num_videos_per_prompt) # [len1, len1, ..., len2, len2, ...] + # Reconstruct the cumulative lengths + repeated_cu_seqlens = torch.cat([torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)]) + + return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens def check_inputs( self, @@ -415,22 +494,30 @@ def check_inputs( negative_prompt, height, width, - prompt_embeds=None, - negative_prompt_embeds=None, + prompt_embeds_qwen=None, + prompt_embeds_clip=None, + negative_prompt_embeds_qwen=None, + negative_prompt_embeds_clip=None, + prompt_cu_seqlens=None, + negative_prompt_cu_seqlens=None, callback_on_step_end_tensor_inputs=None, ): """ Validate input parameters for the pipeline. - + Args: prompt: Input prompt negative_prompt: Negative prompt for guidance height: Video height width: Video width - prompt_embeds: Pre-computed prompt embeddings - negative_prompt_embeds: Pre-computed negative prompt embeddings + prompt_embeds_qwen: Pre-computed Qwen prompt embeddings + prompt_embeds_clip: Pre-computed CLIP prompt embeddings + negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings + negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings + prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt + negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt callback_on_step_end_tensor_inputs: Callback tensor inputs - + Raises: ValueError: If inputs are invalid """ @@ -444,23 +531,32 @@ def check_inputs( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: + # Check for consistency within positive prompt embeddings and sequence lengths + if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None: + if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None: + raise ValueError( + f"If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, " + f"all three must be provided." + ) + + # Check for consistency within negative prompt embeddings and sequence lengths + if negative_prompt_embeds_qwen is not None or negative_prompt_embeds_clip is not None or negative_prompt_cu_seqlens is not None: + if negative_prompt_embeds_qwen is None or negative_prompt_embeds_clip is None or negative_prompt_cu_seqlens is None: + raise ValueError( + f"If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, " + f"all three must be provided." 
+ ) + + # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive) + if prompt is None and prompt_embeds_qwen is None: raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + "Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + + # Validate types for prompt and negative_prompt if provided + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif negative_prompt is not None and ( + if negative_prompt is not None and ( not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") @@ -632,13 +728,17 @@ def __call__( # 1. Check inputs. Raise error if not correct self.check_inputs( - prompt, - negative_prompt, - height, - width, - prompt_embeds, - negative_prompt_embeds, - callback_on_step_end_tensor_inputs, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + prompt_embeds_qwen=prompt_embeds_qwen, + prompt_embeds_clip=prompt_embeds_clip, + negative_prompt_embeds_qwen=negative_prompt_embeds_qwen, + negative_prompt_embeds_clip=negative_prompt_embeds_clip, + prompt_cu_seqlens=prompt_cu_seqlens, + negative_prompt_cu_seqlens=negative_prompt_cu_seqlens, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) if num_frames % self.vae_scale_factor_temporal != 1: @@ -739,7 +839,7 @@ def __call__( continue timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt) - + # Predict noise residual pred_velocity = self.transformer( hidden_states=latents.to(dtype), @@ -753,7 +853,7 @@ def __call__( return_dict=True ).sample - if self.do_classifier_free_guidance and negative_prompt_embeds_dict is not None: + if self.do_classifier_free_guidance and negative_prompt_embeds_qwen is not None: uncond_pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), @@ -769,7 +869,6 @@ def __call__( pred_velocity = uncond_pred_velocity + guidance_scale * ( pred_velocity - uncond_pred_velocity ) - # Compute previous sample using the scheduler latents[:, :, :, :, :num_channels_latents] = self.scheduler.step( pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False From 1bf19f0904d9faa6849c75f0a4a6f9441643be66 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 18 Oct 2025 05:20:06 +0200 Subject: [PATCH 66/77] style +copies --- src/diffusers/__init__.py | 8 +- src/diffusers/loaders/__init__.py | 2 +- src/diffusers/loaders/lora_pipeline.py | 19 +- src/diffusers/models/__init__.py | 4 +- src/diffusers/models/transformers/__init__.py | 2 +- .../transformers/transformer_kandinsky.py | 138 +++---- src/diffusers/pipelines/__init__.py | 2 +- .../pipelines/kandinsky5/__init__.py | 2 +- .../kandinsky5/pipeline_kandinsky.py | 348 ++++++++---------- src/diffusers/utils/dummy_pt_objects.py | 15 + .../dummy_torch_and_transformers_objects.py | 15 + 11 files changed, 258 insertions(+), 297 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 54e33d69514f..aa500b149441 100644 --- a/src/diffusers/__init__.py 
+++ b/src/diffusers/__init__.py @@ -220,6 +220,7 @@ "HunyuanVideoTransformer3DModel", "I2VGenXLUNet", "Kandinsky3UNet", + "Kandinsky5Transformer3DModel", "LatteTransformer3DModel", "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", @@ -260,7 +261,6 @@ "VQModel", "WanTransformer3DModel", "WanVACETransformer3DModel", - "Kandinsky5Transformer3DModel", "attention_backend", ] ) @@ -475,6 +475,7 @@ "ImageTextPipelineOutput", "Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline", + "Kandinsky5T2VPipeline", "KandinskyCombinedPipeline", "KandinskyImg2ImgCombinedPipeline", "KandinskyImg2ImgPipeline", @@ -623,7 +624,6 @@ "WanPipeline", "WanVACEPipeline", "WanVideoToVideoPipeline", - "Kandinsky5T2VPipeline", "WuerstchenCombinedPipeline", "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", @@ -914,6 +914,7 @@ HunyuanVideoTransformer3DModel, I2VGenXLUNet, Kandinsky3UNet, + Kandinsky5Transformer3DModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, @@ -953,7 +954,6 @@ VQModel, WanTransformer3DModel, WanVACETransformer3DModel, - Kandinsky5Transformer3DModel, attention_backend, ) from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks @@ -1139,6 +1139,7 @@ ImageTextPipelineOutput, Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline, + Kandinsky5T2VPipeline, KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, KandinskyImg2ImgPipeline, @@ -1286,7 +1287,6 @@ WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline, - Kandinsky5T2VPipeline, WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 6a48ac1b0deb..48507aae038c 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -116,6 +116,7 @@ def text_encoder_attn_modules(text_encoder): FluxLoraLoaderMixin, HiDreamImageLoraLoaderMixin, HunyuanVideoLoraLoaderMixin, + KandinskyLoraLoaderMixin, LoraLoaderMixin, LTXVideoLoraLoaderMixin, Lumina2LoraLoaderMixin, @@ -127,7 +128,6 @@ def text_encoder_attn_modules(text_encoder): StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, WanLoraLoaderMixin, - KandinskyLoraLoaderMixin ) from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index ea1b92c68b59..2bb6c0ea026e 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -3638,7 +3638,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): """ super().unfuse_lora(components=components, **kwargs) - + class KandinskyLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`Kandinsky5Transformer3DModel`], @@ -3662,7 +3662,8 @@ def lora_state_dict( Can be either: - A string, the *model id* of a pretrained model hosted on the Hub. - A path to a *directory* containing the model weights. - - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached. 
@@ -3737,7 +3738,7 @@ def load_lora_weights( ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` - + Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`]. @@ -3746,7 +3747,8 @@ def load_lora_weights( hotswap (`bool`, *optional*): Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. kwargs (`dict`, *optional*): See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`]. """ @@ -3827,7 +3829,6 @@ def load_lora_into_transformer( hotswap=hotswap, ) - @classmethod def save_lora_weights( cls, @@ -3864,9 +3865,7 @@ def save_lora_weights( lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata if not lora_layers: - raise ValueError( - "You must pass at least one of `transformer_lora_layers`" - ) + raise ValueError("You must pass at least one of `transformer_lora_layers`") cls._save_lora_weights( save_directory=save_directory, @@ -3923,7 +3922,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. """ super().unfuse_lora(components=components, **kwargs) - + class WanLoraLoaderMixin(LoraBaseMixin): r""" @@ -5088,4 +5087,4 @@ class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." 
deprecate("LoraLoaderMixin", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) \ No newline at end of file + super().__init__(*args, **kwargs) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 89ca9d39774b..8d029bf5d31c 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -91,6 +91,7 @@ _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] + _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] @@ -101,7 +102,6 @@ _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] - _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] @@ -183,6 +183,7 @@ HunyuanDiT2DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, + Kandinsky5Transformer3DModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, @@ -201,7 +202,6 @@ TransformerTemporalModel, WanTransformer3DModel, WanVACETransformer3DModel, - Kandinsky5Transformer3DModel, ) from .unets import ( I2VGenXLUNet, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 4b9911f9cb5d..6b80ea6c82a5 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -27,6 +27,7 @@ from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel + from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel @@ -37,4 +38,3 @@ from .transformer_temporal import TransformerTemporalModel from .transformer_wan import WanTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel - from .transformer_kandinsky import Kandinsky5Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index ad39a9bed63f..a338922583ca 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -12,48 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import inspect import math -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F -from torch import BoolTensor, IntTensor, Tensor, nn +from torch import Tensor from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import ( - USE_PEFT_BACKEND, - deprecate, logging, - scale_lora_layers, - unscale_lora_layers, ) -from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import AttentionMixin, FeedForward, AttentionModuleMixin +from ..attention import AttentionMixin, AttentionModuleMixin +from ..attention_dispatch import _CAN_USE_FLEX_ATTN, dispatch_attention_fn from ..cache_utils import CacheMixin -from ..embeddings import ( - TimestepEmbedding, - get_1d_rotary_pos_embed, -) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import FP32LayerNorm -from ..attention_dispatch import dispatch_attention_fn, _CAN_USE_FLEX_ATTN - -logger = logging.get_logger(__name__) - - +logger = logging.get_logger(__name__) def get_freqs(dim, max_period=10000.0): - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=dim, dtype=torch.float32) - / dim - ) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim) return freqs @@ -147,7 +131,7 @@ def nablaT_v2( q = q.transpose(1, 2).contiguous() k = k.transpose(1, 2).contiguous() - + # Map estimation B, h, S, D = q.shape s1 = S // 64 @@ -167,9 +151,7 @@ def nablaT_v2( # BlockMask creation kv_nb = mask.sum(-1).to(torch.int32) kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32) - return BlockMask.from_kv_blocks( - torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None - ) + return BlockMask.from_kv_blocks(torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None) class Kandinsky5TimeEmbeddings(nn.Module): @@ -188,7 +170,7 @@ def forward(self, time): args = torch.outer(time, self.freqs.to(device=time.device)) time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) - return time_embed + return time_embed class Kandinsky5TextEmbeddings(nn.Module): @@ -235,7 +217,7 @@ def __init__(self, dim, max_pos=1024, max_period=10000.0): self.max_pos = max_pos freq = get_freqs(dim // 2, max_period) pos = torch.arange(max_pos, dtype=freq.dtype) - self.register_buffer(f"args", torch.outer(pos, freq), persistent=False) + self.register_buffer("args", torch.outer(pos, freq), persistent=False) def forward(self, pos): args = self.args[pos] @@ -266,15 +248,9 @@ def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): args = torch.cat( [ - args_t.view(1, duration, 1, 1, -1).repeat( - batch_size, 1, height, width, 1 - ), - args_h.view(1, 1, height, 1, -1).repeat( - batch_size, duration, 1, width, 1 - ), - args_w.view(1, 1, 1, width, -1).repeat( - batch_size, duration, height, 1, 1 - ), + args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1), + args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1), + args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1), ], dim=-1, ) @@ -299,7 +275,6 @@ def forward(self, x): class Kandinsky5AttnProcessor: - _attention_backend = None _parallel_config = None @@ -307,7 +282,6 @@ def __init__(self): if 
not hasattr(F, "scaled_dot_product_attention"): raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") - def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None): # query, key, value = self.get_qkv(x) query = attn.to_query(hidden_states) @@ -320,7 +294,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=N query = query.reshape(*shape, attn.num_heads, -1) key = key.reshape(*cond_shape, attn.num_heads, -1) value = value.reshape(*cond_shape, attn.num_heads, -1) - + else: key = attn.to_key(hidden_states) value = attn.to_value(hidden_states) @@ -352,10 +326,10 @@ def apply_rotary(x, rope): ) else: attn_mask = None - + hidden_states = dispatch_attention_fn( - query, - key, + query, + key, value, attn_mask=attn_mask, backend=self._attention_backend, @@ -368,11 +342,11 @@ def apply_rotary(x, rope): class Kandinsky5Attention(nn.Module, AttentionModuleMixin): - _default_processor_cls = Kandinsky5AttnProcessor _available_processors = [ Kandinsky5AttnProcessor, ] + def __init__(self, num_channels, head_dim, processor=None): super().__init__() assert num_channels % head_dim == 0 @@ -397,9 +371,6 @@ def forward( rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> torch.Tensor: - - import inspect - attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) quiet_attn_parameters = {} unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] @@ -409,9 +380,16 @@ def forward( ) kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} - return self.processor(self, hidden_states, encoder_hidden_states=encoder_hidden_states, sparse_params=sparse_params, rotary_emb=rotary_emb, **kwargs) + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + sparse_params=sparse_params, + rotary_emb=rotary_emb, + **kwargs, + ) + - class Kandinsky5FeedForward(nn.Module): def __init__(self, dim, ff_dim): super().__init__() @@ -429,16 +407,14 @@ def __init__(self, model_dim, time_dim, visual_dim, patch_size): self.patch_size = patch_size self.modulation = Kandinsky5Modulation(time_dim, model_dim, 2) self.norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.out_layer = nn.Linear( - model_dim, math.prod(patch_size) * visual_dim, bias=True - ) + self.out_layer = nn.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True) def forward(self, visual_embed, text_embed, time_embed): - shift, scale = torch.chunk( - self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 - ) - - visual_embed = (self.norm(visual_embed.float()) * (scale.float()[:, None, None] + 1.0) + shift.float()[:, None, None]).type_as(visual_embed) + shift, scale = torch.chunk(self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) + + visual_embed = ( + self.norm(visual_embed.float()) * (scale.float()[:, None, None] + 1.0) + shift.float()[:, None, None] + ).type_as(visual_embed) x = self.out_layer(visual_embed) @@ -474,9 +450,7 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim): self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) def forward(self, x, time_embed, rope): - self_attn_params, ff_params = torch.chunk( - self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 - ) + self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) out = 
(self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x) out = self.self_attention(out, rotary_emb=rope) @@ -510,17 +484,23 @@ def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) - visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) + visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as( + visual_embed + ) visual_out = self.self_attention(visual_out, rotary_emb=rope, sparse_params=sparse_params) visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) - visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) + visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as( + visual_embed + ) visual_out = self.cross_attention(visual_out, encoder_hidden_states=text_embed) visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) - visual_out = (self.feed_forward_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) + visual_out = (self.feed_forward_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as( + visual_embed + ) visual_out = self.feed_forward(visual_out) visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) @@ -583,9 +563,7 @@ def __init__( self.time_embeddings = Kandinsky5TimeEmbeddings(model_dim, time_dim) self.text_embeddings = Kandinsky5TextEmbeddings(in_text_dim, model_dim) self.pooled_text_embeddings = Kandinsky5TextEmbeddings(in_text_dim2, time_dim) - self.visual_embeddings = Kandinsky5VisualEmbeddings( - visual_embed_dim, model_dim, patch_size - ) + self.visual_embeddings = Kandinsky5VisualEmbeddings(visual_embed_dim, model_dim, patch_size) # Initialize positional embeddings self.text_rope_embeddings = Kandinsky5RoPE1D(head_dim) @@ -593,10 +571,7 @@ def __init__( # Initialize transformer blocks self.text_transformer_blocks = nn.ModuleList( - [ - Kandinsky5TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) - for _ in range(num_text_blocks) - ] + [Kandinsky5TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) for _ in range(num_text_blocks)] ) self.visual_transformer_blocks = nn.ModuleList( @@ -607,9 +582,7 @@ def __init__( ) # Initialize output layer - self.out_layer = Kandinsky5OutLayer( - model_dim, time_dim, out_visual_dim, patch_size - ) + self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size) self.gradient_checkpointing = False def forward( @@ -639,8 +612,7 @@ def forward( return_dict (`bool`, optional): Whether to return a dictionary Returns: - [`~models.transformer_2d.Transformer2DModelOutput`] or `torch.FloatTensor`: - The output of the transformer + [`~models.transformer_2d.Transformer2DModelOutput`] or `torch.FloatTensor`: The output of the transformer """ x = hidden_states text_embed = encoder_hidden_states @@ -663,13 +635,9 @@ def forward( text_embed = text_transformer_block(text_embed, time_embed, text_rope) visual_shape = visual_embed.shape[:-1] - visual_rope = self.visual_rope_embeddings( - visual_shape, visual_rope_pos, scale_factor - ) + visual_rope = 
self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False - visual_embed, visual_rope = fractal_flatten( - visual_embed, visual_rope, visual_shape, block_mask=to_fractal - ) + visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape, block_mask=to_fractal) for visual_transformer_block in self.visual_transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -686,12 +654,10 @@ def forward( visual_embed, text_embed, time_embed, visual_rope, sparse_params ) - visual_embed = fractal_unflatten( - visual_embed, visual_shape, block_mask=to_fractal - ) + visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal) x = self.out_layer(visual_embed, text_embed, time_embed) if not return_dict: return x - return Transformer2DModelOutput(sample=x) \ No newline at end of file + return Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 201d92afb07c..c438caed571f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -672,6 +672,7 @@ Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline, ) + from .kandinsky5 import Kandinsky5T2VPipeline from .latent_consistency_models import ( LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, @@ -788,7 +789,6 @@ ) from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline - from .kandinsky5 import Kandinsky5T2VPipeline from .wuerstchen import ( WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, diff --git a/src/diffusers/pipelines/kandinsky5/__init__.py b/src/diffusers/pipelines/kandinsky5/__init__.py index af8e12421740..a7975bdce926 100644 --- a/src/diffusers/pipelines/kandinsky5/__init__.py +++ b/src/diffusers/pipelines/kandinsky5/__init__.py @@ -23,7 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_kandinsky"] = ["Kandinsky5T2VPipeline"] - + if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index ff6b00d5fb26..3eb706f238ad 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -18,7 +18,7 @@ import regex as re import torch from torch.nn import functional as F -from transformers import Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, CLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModel, CLIPTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import KandinskyLoraLoaderMixin @@ -49,13 +49,13 @@ EXAMPLE_DOC_STRING = """ Examples: - + ```python >>> import torch >>> from diffusers import Kandinsky5T2VPipeline >>> from diffusers.utils import export_to_video - - >>> # Available models: + + >>> # Available models: >>> # ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers >>> # ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers >>> # ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers @@ -67,7 +67,7 @@ >>> prompt = "A cat and a dog baking a cake together in a kitchen." 
>>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" - + >>> output = pipe( ... prompt=prompt, ... negative_prompt=negative_prompt, @@ -77,7 +77,7 @@ ... num_inference_steps=50, ... guidance_scale=5.0, ... ).frames[0] - + >>> export_to_video(output, "output.mp4", fps=24, quality=9) ``` """ @@ -129,7 +129,13 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): """ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds_qwen", + "prompt_embeds_clip", + "negative_prompt_embeds_qwen", + "negative_prompt_embeds_clip", + ] def __init__( self, @@ -152,40 +158,42 @@ def __init__( tokenizer_2=tokenizer_2, scheduler=scheduler, ) - - self.prompt_template = "\n".join(["<|im_start|>system\nYou are a promt engineer. Describe the video in detail.", - "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", - "Describe the location of the video, main characters or objects and their action.", - "Describe the dynamism of the video and presented actions.", - "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.", - "Describe the visual effects, postprocessing and transitions if they are presented in the video.", - "Pay attention to the order of key actions shown in the scene.<|im_end|>", - "<|im_start|>user\n{}<|im_end|>"]) + + self.prompt_template = "\n".join( + [ + "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.", + "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", + "Describe the location of the video, main characters or objects and their action.", + "Describe the dynamism of the video and presented actions.", + "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.", + "Describe the visual effects, postprocessing and transitions if they are presented in the video.", + "Pay attention to the order of key actions shown in the scene.<|im_end|>", + "<|im_start|>user\n{}<|im_end|>", + ] + ) self.prompt_template_encode_start_idx = 129 self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - + @staticmethod - def fast_sta_nabla( - T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda" - ) -> torch.Tensor: + def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> torch.Tensor: """ Create a sparse temporal attention (STA) mask for efficient video generation. - - This method generates a mask that limits attention to nearby frames and spatial positions, - reducing computational complexity for video generation. - + + This method generates a mask that limits attention to nearby frames and spatial positions, reducing + computational complexity for video generation. 
+ Args: T (int): Number of temporal frames H (int): Height in latent space - W (int): Width in latent space + W (int): Width in latent space wT (int): Temporal attention window size wH (int): Height attention window size wW (int): Width attention window size device (str): Device to create tensor on - + Returns: torch.Tensor: Sparse attention mask of shape (T*H*W, T*H*W) """ @@ -200,30 +208,21 @@ def fast_sta_nabla( sta_t = sta_t <= wT // 2 sta_h = sta_h <= wH // 2 sta_w = sta_w <= wW // 2 - sta_hw = ( - (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)) - .reshape(H, H, W, W) - .transpose(1, 2) - .flatten() - ) - sta = ( - (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)) - .reshape(T, T, H * W, H * W) - .transpose(1, 2) - ) + sta_hw = (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)).reshape(H, H, W, W).transpose(1, 2).flatten() + sta = (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)).reshape(T, T, H * W, H * W).transpose(1, 2) return sta.reshape(T * H * W, T * H * W) - + def get_sparse_params(self, sample, device): """ Generate sparse attention parameters for the transformer based on sample dimensions. - - This method computes the sparse attention configuration needed for efficient - video processing in the transformer model. - + + This method computes the sparse attention configuration needed for efficient video processing in the + transformer model. + Args: sample (torch.Tensor): Input sample tensor device (torch.device): Device to place tensors on - + Returns: Dict: Dictionary containing sparse attention parameters """ @@ -236,13 +235,15 @@ def get_sparse_params(self, sample, device): ) if self.transformer.config.attention_type == "nabla": sta_mask = self.fast_sta_nabla( - T, H // 8, W // 8, - self.transformer.config.attention_wT, - self.transformer.config.attention_wH, - self.transformer.config.attention_wW, - device=device + T, + H // 8, + W // 8, + self.transformer.config.attention_wT, + self.transformer.config.attention_wH, + self.transformer.config.attention_wW, + device=device, ) - + sparse_params = { "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0), "attention_type": self.transformer.config.attention_type, @@ -269,17 +270,17 @@ def _encode_prompt_qwen( ): """ Encode prompt using Qwen2.5-VL text encoder. - - This method processes the input prompt through the Qwen2.5-VL model to generate - text embeddings suitable for video generation. - + + This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for + video generation. 
+ Args: prompt (Union[str, List[str]]): Input prompt or list of prompts device (torch.device): Device to run encoding on num_videos_per_prompt (int): Number of videos to generate per prompt max_sequence_length (int): Maximum sequence length for tokenization dtype (torch.dtype): Data type for embeddings - + Returns: Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths """ @@ -287,7 +288,7 @@ def _encode_prompt_qwen( dtype = dtype or self.text_encoder.dtype full_texts = [self.prompt_template.format(p) for p in prompt] - + inputs = self.tokenizer( text=full_texts, images=None, @@ -302,13 +303,12 @@ def _encode_prompt_qwen( input_ids=inputs["input_ids"], return_dict=True, output_hidden_states=True, - )["hidden_states"][-1][:, self.prompt_template_encode_start_idx:] + )["hidden_states"][-1][:, self.prompt_template_encode_start_idx :] - - attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx:] + attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx :] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32) - + return embeds.to(dtype), cu_seqlens def _encode_prompt_clip( @@ -319,16 +319,16 @@ def _encode_prompt_clip( ): """ Encode prompt using CLIP text encoder. - - This method processes the input prompt through the CLIP model to generate - pooled embeddings that capture semantic information. - + + This method processes the input prompt through the CLIP model to generate pooled embeddings that capture + semantic information. + Args: prompt (Union[str, List[str]]): Input prompt or list of prompts device (torch.device): Device to run encoding on num_videos_per_prompt (int): Number of videos to generate per prompt dtype (torch.dtype): Data type for embeddings - + Returns: torch.Tensor: Pooled text embeddings from CLIP """ @@ -346,69 +346,8 @@ def _encode_prompt_clip( pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] - return pooled_embed.to(dtype) -# def encode_prompt( -# self, -# prompt: Union[str, List[str]], -# num_videos_per_prompt: int = 1, -# max_sequence_length: int = 512, -# device: Optional[torch.device] = None, -# dtype: Optional[torch.dtype] = None, -# ): -# r""" -# Encodes a single prompt (positive or negative) into text encoder hidden states. - -# This method combines embeddings from both Qwen2.5-VL and CLIP text encoders -# to create comprehensive text representations for video generation. - -# Args: -# prompt (`str` or `List[str]`): -# Prompt to be encoded. -# num_videos_per_prompt (`int`, *optional*, defaults to 1): -# Number of videos to generate per prompt. -# max_sequence_length (`int`, *optional*, defaults to 512): -# Maximum sequence length for text encoding. -# device (`torch.device`, *optional*): -# Torch device. -# dtype (`torch.dtype`, *optional*): -# Torch dtype. 
- -# Returns: -# Tuple[Dict[str, torch.Tensor], torch.Tensor]: -# - A dict with keys `"text_embeds"` (from Qwen) and `"pooled_embed"` (from CLIP) -# - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings -# """ -# device = device or self._execution_device -# dtype = dtype or self.text_encoder.dtype - -# batch_size = len(prompt) - -# prompt = [prompt_clean(p) for p in prompt] - -# # Encode with Qwen2.5-VL -# prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( -# prompt=prompt, -# device=device, -# max_sequence_length=max_sequence_length, -# dtype=dtype, -# ) - -# # Encode with CLIP -# prompt_embeds_clip = self._encode_prompt_clip( -# prompt=prompt, -# device=device, -# dtype=dtype, -# ) -# prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1) -# prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1) - -# prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1) -# prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1) - -# return prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens - def encode_prompt( self, prompt: Union[str, List[str]], @@ -420,8 +359,8 @@ def encode_prompt( r""" Encodes a single prompt (positive or negative) into text encoder hidden states. - This method combines embeddings from both Qwen2.5-VL and CLIP text encoders - to create comprehensive text representations for video generation. + This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text + representations for video generation. Args: prompt (`str` or `List[str]`): @@ -439,7 +378,8 @@ def encode_prompt( Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - Qwen text embeddings of shape (batch_size * num_videos_per_prompt, sequence_length, embedding_dim) - CLIP pooled embeddings of shape (batch_size * num_videos_per_prompt, clip_embedding_dim) - - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size * num_videos_per_prompt + 1,) + - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size * + num_videos_per_prompt + 1,) """ device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -467,12 +407,18 @@ def encode_prompt( # Repeat embeddings for num_videos_per_prompt # Qwen embeddings: repeat sequence for each video, then reshape - prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1) # [batch_size, seq_len * num_videos_per_prompt, embed_dim] + prompt_embeds_qwen = prompt_embeds_qwen.repeat( + 1, num_videos_per_prompt, 1 + ) # [batch_size, seq_len * num_videos_per_prompt, embed_dim] # Reshape to [batch_size * num_videos_per_prompt, seq_len, embed_dim] - prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1, prompt_embeds_qwen.shape[-1]) + prompt_embeds_qwen = prompt_embeds_qwen.view( + batch_size * num_videos_per_prompt, -1, prompt_embeds_qwen.shape[-1] + ) # CLIP embeddings: repeat for each video - prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1) # [batch_size, num_videos_per_prompt, clip_embed_dim] + prompt_embeds_clip = prompt_embeds_clip.repeat( + 1, num_videos_per_prompt, 1 + ) # [batch_size, num_videos_per_prompt, clip_embed_dim] # Reshape to [batch_size * num_videos_per_prompt, clip_embed_dim] prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1) @@ -480,11 +426,15 @@ def encode_prompt( # Original cu_seqlens: [0, len1, len1+len2, ...] 
# Need to repeat the differences and reconstruct for repeated prompts # Original differences (lengths) for each prompt in the batch - original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...] + original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...] # Repeat the lengths for num_videos_per_prompt - repeated_lengths = original_lengths.repeat_interleave(num_videos_per_prompt) # [len1, len1, ..., len2, len2, ...] + repeated_lengths = original_lengths.repeat_interleave( + num_videos_per_prompt + ) # [len1, len1, ..., len2, len2, ...] # Reconstruct the cumulative lengths - repeated_cu_seqlens = torch.cat([torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)]) + repeated_cu_seqlens = torch.cat( + [torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)] + ) return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens @@ -509,7 +459,7 @@ def check_inputs( prompt: Input prompt negative_prompt: Negative prompt for guidance height: Video height - width: Video width + width: Video width prompt_embeds_qwen: Pre-computed Qwen prompt embeddings prompt_embeds_clip: Pre-computed CLIP prompt embeddings negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings @@ -535,16 +485,24 @@ def check_inputs( if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None: if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None: raise ValueError( - f"If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, " - f"all three must be provided." + "If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, " + "all three must be provided." ) # Check for consistency within negative prompt embeddings and sequence lengths - if negative_prompt_embeds_qwen is not None or negative_prompt_embeds_clip is not None or negative_prompt_cu_seqlens is not None: - if negative_prompt_embeds_qwen is None or negative_prompt_embeds_clip is None or negative_prompt_cu_seqlens is None: + if ( + negative_prompt_embeds_qwen is not None + or negative_prompt_embeds_clip is not None + or negative_prompt_cu_seqlens is not None + ): + if ( + negative_prompt_embeds_qwen is None + or negative_prompt_embeds_clip is None + or negative_prompt_cu_seqlens is None + ): raise ValueError( - f"If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, " - f"all three must be provided." + "If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, " + "all three must be provided." ) # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive) @@ -575,21 +533,20 @@ def prepare_latents( ) -> torch.Tensor: """ Prepare initial latent variables for video generation. - - This method creates random noise latents or uses provided latents as starting point - for the denoising process. - + + This method creates random noise latents or uses provided latents as starting point for the denoising process. 
+ Args: batch_size (int): Number of videos to generate num_channels_latents (int): Number of channels in latent space height (int): Height of generated video - width (int): Width of generated video + width (int): Width of generated video num_frames (int): Number of frames in video dtype (torch.dtype): Data type for latents device (torch.device): Device to create latents on generator (torch.Generator): Random number generator latents (torch.Tensor): Pre-existing latents to use - + Returns: torch.Tensor: Prepared latent tensor """ @@ -611,14 +568,20 @@ def prepare_latents( ) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - + if self.transformer.visual_cond: # For visual conditioning, concatenate with zeros and mask visual_cond = torch.zeros_like(latents) visual_cond_mask = torch.zeros( - [batch_size, num_latent_frames, int(height) // self.vae_scale_factor_spatial, int(width) // self.vae_scale_factor_spatial, 1], - dtype=latents.dtype, - device=latents.device + [ + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + 1, + ], + dtype=latents.dtype, + device=latents.device, ) latents = torch.cat([latents, visual_cond, visual_cond_mask], dim=-1) @@ -715,13 +678,13 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. max_sequence_length (`int`, defaults to `512`): The maximum sequence length for text encoding. - + Examples: - + Returns: [`~KandinskyPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned where - the first element is a list with the generated images. + If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images. """ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -761,17 +724,16 @@ def __call__( elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: - batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] - - # 3. Encode input prompt - if prompt_embeds_qwen is None: - prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt( - prompt=prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) + batch_size = prompt_embeds_qwen.shape[0] + # 3. Encode input prompt + if prompt_embeds_qwen is None: + prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt( + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) if self.do_classifier_free_guidance: if negative_prompt is None: @@ -785,12 +747,12 @@ def __call__( ) if negative_prompt_embeds_qwen is None: - negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_cu_seqlens = self.encode_prompt( - prompt=negative_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) + negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_cu_seqlens = self.encode_prompt( + prompt=negative_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -817,15 +779,15 @@ def __call__( torch.arange(height // self.vae_scale_factor_spatial // 2, device=device), torch.arange(width // self.vae_scale_factor_spatial // 2, device=device), ] - + text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device) - + negative_text_rope_pos = ( torch.arange(negative_cu_seqlens.diff().max().item(), device=device) if negative_cu_seqlens is not None else None ) - + # 7. Sparse Params for efficient attention sparse_params = self.get_sparse_params(latents, device) @@ -839,8 +801,8 @@ def __call__( continue timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt) - - # Predict noise residual + + # Predict noise residual pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=prompt_embeds_qwen.to(dtype), @@ -848,12 +810,12 @@ def __call__( timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, text_rope_pos=text_rope_pos, - scale_factor=(1, 2, 2), + scale_factor=(1, 2, 2), sparse_params=sparse_params, - return_dict=True + return_dict=True, ).sample - if self.do_classifier_free_guidance and negative_prompt_embeds_qwen is not None: + if self.do_classifier_free_guidance and negative_prompt_embeds_qwen is not None: uncond_pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), @@ -863,12 +825,10 @@ def __call__( text_rope_pos=negative_text_rope_pos, scale_factor=(1, 2, 2), sparse_params=sparse_params, - return_dict=True + return_dict=True, ).sample - pred_velocity = uncond_pred_velocity + guidance_scale * ( - pred_velocity - uncond_pred_velocity - ) + pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity) # Compute previous sample using the scheduler latents[:, :, :, :, :num_channels_latents] = self.scheduler.step( pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False @@ -881,8 +841,14 @@ def __call__( callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) - prompt_embeds_dict = callback_outputs.pop("prompt_embeds", prompt_embeds_dict) - negative_prompt_embeds_dict = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds_dict) + prompt_embeds_qwen = callback_outputs.pop("prompt_embeds_qwen", prompt_embeds_qwen) + prompt_embeds_clip = callback_outputs.pop("prompt_embeds_clip", prompt_embeds_clip) + negative_prompt_embeds_qwen = callback_outputs.pop( + "negative_prompt_embeds_qwen", negative_prompt_embeds_qwen + ) + negative_prompt_embeds_clip = callback_outputs.pop( + "negative_prompt_embeds_clip", negative_prompt_embeds_clip + ) if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() @@ -907,13 +873,13 @@ def __call__( ) video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width] video = video.reshape( - batch_size * num_videos_per_prompt, - num_channels_latents, - (num_frames - 1) // self.vae_scale_factor_temporal + 1, - height // self.vae_scale_factor_spatial, - width // self.vae_scale_factor_spatial + batch_size * num_videos_per_prompt, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, ) - + # Normalize and decode through VAE video = video / self.vae.config.scaling_factor video 
= self.vae.decode(video).sample diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6e7d22797902..5d62709c28fd 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -918,6 +918,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class Kandinsky5Transformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LatteTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 9ed625045261..3244ef12ef87 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1247,6 +1247,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Kandinsky5T2VPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class KandinskyCombinedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 1746f6d426dd37541dec98a9c338e0465ced3ead Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Fri, 17 Oct 2025 17:22:58 -1000 Subject: [PATCH 67/77] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: Charles --- src/diffusers/models/transformers/transformer_kandinsky.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index a338922583ca..86032f5462d1 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -518,7 +518,10 @@ class Kandinsky5Transformer3DModel( """ A 3D Diffusion Transformer model for video-like data. """ - +_repeated_blocks = [ + "Kandinsky5TransformerEncoderBlock", + "Kandinsky5TransformerDecoderBlock", +] _supports_gradient_checkpointing = True @register_to_config From 5bb1657f9efb11d50d3c19cbe367e8086e15623a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 18 Oct 2025 05:25:17 +0200 Subject: [PATCH 68/77] more --- .../models/transformers/transformer_kandinsky.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 86032f5462d1..d4ba92acaf6e 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -518,10 +518,11 @@ class Kandinsky5Transformer3DModel( """ A 3D Diffusion Transformer model for video-like data. 
""" -_repeated_blocks = [ - "Kandinsky5TransformerEncoderBlock", - "Kandinsky5TransformerDecoderBlock", -] + + _repeated_blocks = [ + "Kandinsky5TransformerEncoderBlock", + "Kandinsky5TransformerDecoderBlock", + ] _supports_gradient_checkpointing = True @register_to_config From a26300f7335613ae8eaf1ee082038de63dbddfa7 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Fri, 17 Oct 2025 17:32:19 -1000 Subject: [PATCH 69/77] Apply suggestions from code review --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 3eb706f238ad..a1122a82565e 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -618,7 +618,6 @@ def __call__( num_frames: int = 121, num_inference_steps: int = 50, guidance_scale: float = 5.0, - scheduler_scale: float = 5.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -656,8 +655,6 @@ def __call__( The number of denoising steps. guidance_scale (`float`, defaults to `5.0`): Guidance scale as defined in classifier-free guidance. - scheduler_scale (`float`, defaults to `10.0`): - Scale factor for the custom flow matching scheduler. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): From ecbe522399e61b61b2ff26658bd5090d849bb190 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 18 Oct 2025 05:37:42 +0200 Subject: [PATCH 70/77] add lora loader doc --- docs/source/en/api/loaders/lora.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index b1d1ffb63423..8e0326e0c334 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -107,6 +107,9 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi [[autodoc]] loaders.lora_pipeline.QwenImageLoraLoaderMixin +## KandinskyLoraLoaderMixin +[[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin + ## LoraBaseMixin [[autodoc]] loaders.lora_base.LoraBaseMixin \ No newline at end of file From b35445c65ab61f3d0e63b18967ca730757b28ca5 Mon Sep 17 00:00:00 2001 From: leffff Date: Tue, 21 Oct 2025 10:39:17 +0000 Subject: [PATCH 71/77] add compiled Nabla Attention --- .../transformers/transformer_kandinsky.py | 40 +++++++++++++++---- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index d4ba92acaf6e..409238cb4ab1 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -281,6 +281,19 @@ class Kandinsky5AttnProcessor: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") + + @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True) + def compiled_flex_attn(self, query, key, value, attn_mask, backend, parallel_config): + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attn_mask, + backend=backend, + parallel_config=parallel_config, + ) + + return hidden_states def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None): # query, key, value = self.get_qkv(x) @@ -324,17 +337,28 @@ def apply_rotary(x, rope): sparse_params["sta_mask"], thr=sparse_params["P"], ) + + hidden_states = self.compiled_flex_attn( + query, + key, + value, + attn_mask=attn_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config + ) + else: attn_mask = None + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attn_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) - hidden_states = dispatch_attention_fn( - query, - key, - value, - attn_mask=attn_mask, - backend=self._attention_backend, - parallel_config=self._parallel_config, - ) hidden_states = hidden_states.flatten(-2, -1) attn_out = attn.out_layer(hidden_states) From 54e77574f95739df4df75fdb5c61d121d1784be5 Mon Sep 17 00:00:00 2001 From: leffff Date: Wed, 22 Oct 2025 11:25:34 +0000 Subject: [PATCH 72/77] all needed changes for 10 sec models are added! --- .../transformers/transformer_kandinsky.py | 38 ++++--------------- 1 file changed, 8 insertions(+), 30 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 409238cb4ab1..cca211e5ed70 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -281,19 +281,6 @@ class Kandinsky5AttnProcessor: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") - - @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True) - def compiled_flex_attn(self, query, key, value, attn_mask, backend, parallel_config): - hidden_states = dispatch_attention_fn( - query, - key, - value, - attn_mask=attn_mask, - backend=backend, - parallel_config=parallel_config, - ) - - return hidden_states def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None): # query, key, value = self.get_qkv(x) @@ -338,26 +325,17 @@ def apply_rotary(x, rope): thr=sparse_params["P"], ) - hidden_states = self.compiled_flex_attn( - query, - key, - value, - attn_mask=attn_mask, - backend=self._attention_backend, - parallel_config=self._parallel_config - ) - else: attn_mask = None - hidden_states = dispatch_attention_fn( - query, - key, - value, - attn_mask=attn_mask, - backend=self._attention_backend, - parallel_config=self._parallel_config, - ) + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attn_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) hidden_states = hidden_states.flatten(-2, -1) From 25f2e9cc03a7b5678fe739678d83f8552dc42464 Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 23 Oct 2025 15:09:33 +0000 Subject: [PATCH 73/77] add docs --- docs/source/en/api/pipelines/kandinsky_v5.md | 109 +++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 docs/source/en/api/pipelines/kandinsky_v5.md diff --git a/docs/source/en/api/pipelines/kandinsky_v5.md b/docs/source/en/api/pipelines/kandinsky_v5.md new file mode 100644 index 000000000000..c3816a7520d2 --- /dev/null +++ b/docs/source/en/api/pipelines/kandinsky_v5.md @@ -0,0 +1,109 @@ + + +# Kandinsky 5.0 + +Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov + + +Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem. + +The model introduces several key innovations: +- **Latent diffusion pipeline** with **Flow Matching** for improved training stability +- **Diffusion Transformer (DiT)** as the main generative backbone with cross-attention to text embeddings +- Dual text encoding using **Qwen2.5-VL** and **CLIP** for comprehensive text understanding +- **HunyuanVideo 3D VAE** for efficient video encoding and decoding +- **Sparse attention mechanisms** (NABLA) for efficient long-sequence processing + +The original codebase can be found at [ai-forever/Kandinsky-5](https://github.com/ai-forever/Kandinsky-5). + +> [!TIP] +> Check out the [AI Forever](https://huggingface.co/ai-forever) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants. 
+ +## Available Models + +Kandinsky 5.0 T2V Lite comes in several variants optimized for different use cases: + +| Model Type | Description | Use Cases | +|------------|-------------|-----------| +| **SFT** | Supervised Fine-Tuned model | Highest generation quality | +| **no-CFG** | Classifier-Free Guidance distilled | 2× faster inference | +| **Distilled** | Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss | +| **Pretrain** | Base pretrained model | Research and fine-tuning | + +All models are available in 5-second and 10-second video generation versions. + +## Kandinsky5T2VPipeline + +[[autodoc]] Kandinsky5T2VPipeline + - all + - __call__ + +## Usage Examples + +### Basic Text-to-Video Generation + +```python +import torch +from diffusers import Kandinsky5T2VPipeline +from diffusers.utils import export_to_video + +# Load the pipeline +model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers" +pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) +pipe = pipe.to("cuda") + +# Generate video +prompt = "A cat and a dog baking a cake together in a kitchen." +negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" + +output = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + height=512, + width=768, + num_frames=121, # ~5 seconds at 24fps + num_inference_steps=50, + guidance_scale=5.0, +).frames[0] + +export_to_video(output, "output.mp4", fps=24, quality=9) +``` + + +### Using Different Model Variants +```python +# For faster generation with distilled model +model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers" +pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) +pipe = pipe.to("cuda") + +# Generate with fewer steps +output = pipe( + prompt="A beautiful sunset over mountains", + num_inference_steps=16, # Only 16 steps needed for distilled model + guidance_scale=1.0, +).frames[0] +``` + +## Citation +```bibtex +@misc{kandinsky2025, + author = {Alexey Letunovskiy and Maria Kovaleva and Ivan Kirillov and Lev Novitskiy and Denis Koposov and + Dmitrii Mikhailov and Anna Averchenkova and Andrey Shutkin and Julia Agafonova and Olga Kim and + Anastasiia Kargapoltseva and Nikita Kiselev and Vladimir Arkhipkin and Vladimir Korviakov and + Nikolai Gerasimenko and Denis Parkhomenko and Anna Dmitrienko and Anastasia Maltseva and + Kirill Chernyshev and Ilia Vasiliev and Viacheslav Vasilev and Vladimir Polovnikov and + Yury Kolabushin and Alexander Belykh and Mikhail Mamaev and Anastasia Aliaskina and + Tatiana Nikulina and Polina Gavrilova and Denis Dimitrov}, + title = {Kandinsky 5.0: A family of diffusion models for Video & Image generation}, + howpublished = {\url{https://github.com/ai-forever/Kandinsky-5}}, + year = 2025 +} +``` \ No newline at end of file From 3bbc2329b9c5b79589fc6619dabd89625ff63f68 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 23 Oct 2025 17:44:03 +0000 Subject: [PATCH 74/77] Apply style fixes --- src/diffusers/models/transformers/transformer_kandinsky.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index cca211e5ed70..316e79da4fd6 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -324,10 +324,10 @@ def apply_rotary(x, 
rope):
                 sparse_params["sta_mask"],
                 thr=sparse_params["P"],
             )
-
+
         else:
             attn_mask = None
-
+
         hidden_states = dispatch_attention_fn(
             query,
             key,
             value,
             attn_mask=attn_mask,
             backend=self._attention_backend,
             parallel_config=self._parallel_config,
         )
 
         hidden_states = hidden_states.flatten(-2, -1)
 
From dd6bf3982aa8991a2c74c4d44250e341a9b20c55 Mon Sep 17 00:00:00 2001
From: leffff
Date: Fri, 24 Oct 2025 12:09:00 +0000
Subject: [PATCH 75/77] update docs

---
 docs/source/en/api/pipelines/kandinsky_v5.md | 60 ++++++++++++++++----
 1 file changed, 50 insertions(+), 10 deletions(-)

diff --git a/docs/source/en/api/pipelines/kandinsky_v5.md b/docs/source/en/api/pipelines/kandinsky_v5.md
index c3816a7520d2..cb1c119f8099 100644
--- a/docs/source/en/api/pipelines/kandinsky_v5.md
+++ b/docs/source/en/api/pipelines/kandinsky_v5.md
@@ -30,12 +30,16 @@ The original codebase can be found at [ai-forever/Kandinsky-5](https://github.co
 
 Kandinsky 5.0 T2V Lite comes in several variants optimized for different use cases:
 
-| Model Type | Description | Use Cases |
+| model_id | Description | Use Cases |
 |------------|-------------|-----------|
-| **SFT** | Supervised Fine-Tuned model | Highest generation quality |
-| **no-CFG** | Classifier-Free Guidance distilled | 2× faster inference |
-| **Distilled** | Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
-| **Pretrain** | Base pretrained model | Research and fine-tuning |
+| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers** | 5 second Supervised Fine-Tuned model | Highest generation quality |
+| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers** | 10 second Supervised Fine-Tuned model | Highest generation quality |
+| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers** | 5 second Classifier-Free Guidance distilled | 2× faster inference |
+| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-10s-Diffusers** | 10 second Classifier-Free Guidance distilled | 2× faster inference |
+| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers** | 5 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
+| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-10s-Diffusers** | 10 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
+| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers** | 5 second Base pretrained model | Research and fine-tuning |
+| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-10s-Diffusers** | 10 second Base pretrained model | Research and fine-tuning |
 
 All models are available in 5-second and 10-second video generation versions.
 
@@ -76,22 +80,58 @@ output = pipe(
 export_to_video(output, "output.mp4", fps=24, quality=9)
 ```
 
+### 10 second Models
+**⚠️ Warning!** All 10 second models should be used with Flex attention and max-autotune-no-cudagraphs compilation:
+
+```python
+pipe = Kandinsky5T2VPipeline.from_pretrained(
+    "ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers",
+    torch_dtype=torch.bfloat16
+)
+pipe = pipe.to("cuda")
+
+pipe.transformer.set_attention_backend(
+    "flex"
+) # <--- Set attention backend to Flex
+pipe.transformer.compile(
+    mode="max-autotune-no-cudagraphs",
+    dynamic=True
+) # <--- Compile with max-autotune-no-cudagraphs
+
+prompt = "A cat and a dog baking a cake together in a kitchen."
+negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
+
+output = pipe(
+    prompt=prompt,
+    negative_prompt=negative_prompt,
+    height=512,
+    width=768,
+    num_frames=241,
+    num_inference_steps=50,
+    guidance_scale=5.0,
+).frames[0]
+
+export_to_video(output, "output.mp4", fps=24, quality=9)
+```
+
+### Diffusion Distilled model
+**⚠️ Warning!** All no-CFG and diffusion distilled models should be run without CFG (`guidance_scale=1.0`):
 
-### Using Different Model Variants
 ```python
-# For faster generation with distilled model
 model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
 pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
 pipe = pipe.to("cuda")
 
-# Generate with fewer steps
 output = pipe(
     prompt="A beautiful sunset over mountains",
-    num_inference_steps=16, # Only 16 steps needed for distilled model
-    guidance_scale=1.0,
+    num_inference_steps=16, # <--- Model is distilled in 16 steps
+    guidance_scale=1.0, # <--- no CFG
 ).frames[0]
+
+export_to_video(output, "output.mp4", fps=24, quality=9)
 ```
 
+
 ## Citation
 ```bibtex

From 5fb528bfc1372c7bb8b597d4a9a919990c6aaacc Mon Sep 17 00:00:00 2001
From: leffff
Date: Fri, 24 Oct 2025 21:43:55 +0000
Subject: [PATCH 76/77] add kandinsky5 to toctree

---
 docs/source/en/_toctree.yml | 2 ++
 docs/source/en/api/pipelines/{kandinsky_v5.md => kandinsky5.md} | 0
 2 files changed, 2 insertions(+)
 rename docs/source/en/api/pipelines/{kandinsky_v5.md => kandinsky5.md} (100%)

diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 540e99a2c609..44870f680eac 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -525,6 +525,8 @@
         title: Kandinsky 2.2
       - local: api/pipelines/kandinsky3
         title: Kandinsky 3
+      - local: api/pipelines/kandinsky5
+        title: Kandinsky 5
       - local: api/pipelines/kolors
         title: Kolors
       - local: api/pipelines/latent_consistency_models
diff --git a/docs/source/en/api/pipelines/kandinsky_v5.md b/docs/source/en/api/pipelines/kandinsky5.md
similarity index 100%
rename from docs/source/en/api/pipelines/kandinsky_v5.md
rename to docs/source/en/api/pipelines/kandinsky5.md

From d2a206ea16000f913ad16d2ca9063d7ba906655e Mon Sep 17 00:00:00 2001
From: leffff
Date: Mon, 27 Oct 2025 12:56:42 +0000
Subject: [PATCH 77/77] add tests

---
 tests/pipelines/kandinsky5/test_kandinsky5.py | 361 ++++++++++++++++++
 1 file changed, 361 insertions(+)
 create mode 100644 tests/pipelines/kandinsky5/test_kandinsky5.py

diff --git a/tests/pipelines/kandinsky5/test_kandinsky5.py b/tests/pipelines/kandinsky5/test_kandinsky5.py
new file mode 100644
index 000000000000..68aac6a659a2
--- /dev/null
+++ b/tests/pipelines/kandinsky5/test_kandinsky5.py
@@ -0,0 +1,361 @@
+# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import gc +import tempfile +import unittest + +import numpy as np +import torch +from transformers import CLIPTextModel, CLIPTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor + +from diffusers import ( + AutoencoderKLHunyuanVideo, + FlowMatchEulerDiscreteScheduler, + Kandinsky5T2VPipeline, + Kandinsky5Transformer3DModel, +) + +from ...testing_utils import ( + backend_empty_cache, + enable_full_determinism, + require_torch_accelerator, + slow, + torch_device, +) +from ..pipeline_params import TEXT_TO_VIDEO_BATCH_PARAMS, TEXT_TO_VIDEO_VIDEO_PARAMS, TEXT_TO_VIDEO_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class Kandinsky5T2VPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Kandinsky5T2VPipeline + params = TEXT_TO_VIDEO_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_VIDEO_BATCH_PARAMS + image_params = TEXT_TO_VIDEO_VIDEO_PARAMS + image_latents_params = TEXT_TO_VIDEO_VIDEO_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + "max_sequence_length", + ] + ) + test_xformers_attention = False + supports_dduf = False + + def get_dummy_components(self): + torch.manual_seed(0) + vae = AutoencoderKLHunyuanVideo( + in_channels=16, + out_channels=16, + spatial_compression_ratio=8, + temporal_compression_ratio=4, + base_channels=32, + channel_multipliers=[1, 2, 4], + num_res_blocks=2, + ) + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0) + + # Dummy Qwen2.5-VL model + text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-Qwen2.5-VL") + tokenizer = Qwen2VLProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2.5-VL") + + # Dummy CLIP model + text_encoder_2 = CLIPTextModel.from_pretrained("hf-internal-testing/tiny-random-clip") + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + torch.manual_seed(0) + transformer = Kandinsky5Transformer3DModel( + in_visual_dim=16, + in_text_dim=32, # Match tiny Qwen2.5-VL hidden size + in_text_dim2=32, # Match tiny CLIP hidden size + time_dim=32, + out_visual_dim=16, + patch_size=(1, 2, 2), + model_dim=64, + ff_dim=128, + num_text_blocks=1, + num_visual_blocks=1, + axes_dims=(8, 8, 8), + visual_cond=False, + ) + + components = { + "transformer": transformer.eval(), + "vae": vae.eval(), + "scheduler": scheduler, + "text_encoder": text_encoder.eval(), + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2.eval(), + "tokenizer_2": tokenizer_2, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A cat dancing", + "negative_prompt": "blurry, low quality", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 32, + "width": 32, + "num_frames": 5, + "max_sequence_length": 16, + "output_type": "pt", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + # Check video shape: (batch, channels, frames, height, width) + 
expected_shape = (1, 3, 5, 32, 32) + self.assertEqual(generated_video.shape, expected_shape) + + # Check specific values + expected_slice = torch.tensor([ + 0.5015, 0.4929, 0.4990, 0.4985, 0.4980, 0.5044, 0.5044, 0.5005, + 0.4995, 0.4961, 0.4961, 0.4966, 0.4980, 0.4985, 0.4985, 0.4990 + ]) + + generated_slice = generated_video.flatten() + # Take first 8 and last 8 values for comparison + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) + + def test_inference_batch_consistent(self): + # Override to test batch consistency with video + super().test_inference_batch_consistent(batch_sizes=[1, 2]) + + def test_inference_batch_single_identical(self): + # Override to test batch single identical with video + super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-3) + + @unittest.skip("Kandinsky5T2VPipeline does not support attention slicing") + def test_attention_slicing_forward_pass(self): + pass + + @unittest.skip("Kandinsky5T2VPipeline does not support xformers") + def test_xformers_attention_forwardGenerator_pass(self): + pass + + def test_save_load_optional_components(self): + # Kandinsky5T2VPipeline doesn't have optional components like transformer_2 + # but we can test saving/loading with the current components + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs).frames + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs).frames + + max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max() + self.assertLess(max_diff, 1e-4) + + def test_prompt_embeds(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.set_progress_bar_config(disable=None) + pipe.to(torch_device) + + # Test without prompt (should raise error) + inputs = self.get_dummy_inputs(torch_device) + inputs.pop("prompt") + with self.assertRaises(ValueError): + pipe(**inputs) + + # Test with prompt embeddings + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + negative_prompt = inputs.pop("negative_prompt") + + # Encode prompts to get embeddings + prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = pipe.encode_prompt( + prompt, device=torch_device, max_sequence_length=inputs["max_sequence_length"] + ) + negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = pipe.encode_prompt( + negative_prompt, device=torch_device, max_sequence_length=inputs["max_sequence_length"] + ) + + inputs.update({ + "prompt_embeds_qwen": prompt_embeds_qwen, + "prompt_embeds_clip": prompt_embeds_clip, + "prompt_cu_seqlens": prompt_cu_seqlens, + "negative_prompt_embeds_qwen": negative_prompt_embeds_qwen, + "negative_prompt_embeds_clip": negative_prompt_embeds_clip, + "negative_prompt_cu_seqlens": negative_prompt_cu_seqlens, + }) + + output_with_embeds = pipe(**inputs).frames + + # Compare with output from prompt strings + inputs_with_prompt = self.get_dummy_inputs(torch_device) + output_with_prompt = pipe(**inputs_with_prompt).frames + 
+ # Should be similar but not exactly the same due to different encoding + self.assertEqual(output_with_embeds.shape, output_with_prompt.shape) + + +@slow +@require_torch_accelerator +class Kandinsky5T2VPipelineIntegrationTests(unittest.TestCase): + prompt = "A cat dancing in a kitchen with colorful lights" + + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def test_kandinsky_5_t2v(self): + # This is a slow integration test that would use actual pretrained models + # For now, we'll skip it since we don't have tiny models for integration testing + pass + + def test_kandinsky_5_t2v_different_sizes(self): + # Test different video sizes + pipe = Kandinsky5T2VPipeline.from_pretrained( + "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers", torch_dtype=torch.float16 + ) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # Test different resolutions + test_cases = [ + (256, 256, 17), # height, width, frames + (320, 512, 25), + (512, 320, 33), + ] + + for height, width, num_frames in test_cases: + with self.subTest(height=height, width=width, num_frames=num_frames): + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=self.prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=2, # Few steps for quick test + generator=generator, + output_type="np", + ).frames + + self.assertEqual(output.shape, (1, 3, num_frames, height, width)) + + def test_kandinsky_5_t2v_negative_prompt(self): + pipe = Kandinsky5T2VPipeline.from_pretrained( + "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers", torch_dtype=torch.float16 + ) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # Test with negative prompt + generator = torch.Generator(device=torch_device).manual_seed(0) + output_with_negative = pipe( + prompt=self.prompt, + negative_prompt="blurry, low quality, distorted", + height=256, + width=256, + num_frames=17, + num_inference_steps=2, + generator=generator, + output_type="np", + ).frames + + # Test without negative prompt + generator = torch.Generator(device=torch_device).manual_seed(0) + output_without_negative = pipe( + prompt=self.prompt, + height=256, + width=256, + num_frames=17, + num_inference_steps=2, + generator=generator, + output_type="np", + ).frames + + # Outputs should be different + max_diff = np.abs(output_with_negative - output_without_negative).max() + self.assertGreater(max_diff, 1e-3) # Should be noticeably different + + def test_kandinsky_5_t2v_guidance_scale(self): + pipe = Kandinsky5T2VPipeline.from_pretrained( + "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers", torch_dtype=torch.float16 + ) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # Test different guidance scales + guidance_scales = [1.0, 3.0, 7.0] + + outputs = [] + for guidance_scale in guidance_scales: + generator = torch.Generator(device=torch_device).manual_seed(0) + output = pipe( + prompt=self.prompt, + height=256, + width=256, + num_frames=17, + num_inference_steps=2, + guidance_scale=guidance_scale, + generator=generator, + output_type="np", + ).frames + outputs.append(output) + + # All outputs should have same shape but different content + for i, output in enumerate(outputs): + self.assertEqual(output.shape, (1, 3, 17, 256, 256)) + + # Check they are different + for i in range(len(outputs) - 1): + max_diff = 
np.abs(outputs[i] - outputs[i + 1]).max() + self.assertGreater(max_diff, 1e-3) \ No newline at end of file
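
The `num_videos_per_prompt` handling in `encode_prompt` above rebuilds the Qwen cumulative sequence lengths by taking per-prompt lengths, repeating each one per generated video, and re-accumulating. A standalone sketch of that bookkeeping, using plain `torch` with made-up token counts rather than the pipeline's actual encoder output:

```python
import torch

# Made-up cu_seqlens for a batch of 2 prompts with 7 and 5 text tokens.
cu_seqlens = torch.tensor([0, 7, 12])
num_videos_per_prompt = 3

lengths = cu_seqlens.diff()                                  # tensor([7, 5])
repeated = lengths.repeat_interleave(num_videos_per_prompt)  # tensor([7, 7, 7, 5, 5, 5])
repeated_cu_seqlens = torch.cat([torch.tensor([0]), repeated.cumsum(0)])
print(repeated_cu_seqlens)  # tensor([ 0,  7, 14, 21, 26, 31, 36])
```

The leading zero mirrors the `[0, len1, len1+len2, ...]` convention noted in the pipeline comments.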
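The latent shapes used by `prepare_latents` and the visual RoPE positions above follow from the VAE compression factors and the 2×2 spatial patching. A small arithmetic sketch, assuming the 8× spatial / 4× temporal compression of the HunyuanVideo VAE configuration used in the tests:

```python
# Latent-grid bookkeeping mirroring prepare_latents and the visual_rope_pos setup.
height, width, num_frames = 512, 768, 121
vae_scale_factor_spatial, vae_scale_factor_temporal = 8, 4
patch = 2  # spatial patch size of the transformer, patch_size=(1, 2, 2)

num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1  # 31
latent_height = height // vae_scale_factor_spatial                     # 64
latent_width = width // vae_scale_factor_spatial                       # 96

# RoPE positions are arange() vectors over the patched latent grid.
rope_grid = (num_latent_frames, latent_height // patch, latent_width // patch)
print(rope_grid)  # (31, 32, 48)
```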
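`KandinskyLoraLoaderMixin` is referenced in the LoRA loaders documentation above; a minimal usage sketch follows, assuming the mixin exposes the same `load_lora_weights` entry point as the other diffusers LoRA loader mixins. The LoRA repository id and weight file name below are hypothetical placeholders, not real checkpoints.

```python
import torch
from diffusers import Kandinsky5T2VPipeline
from diffusers.utils import export_to_video

pipe = Kandinsky5T2VPipeline.from_pretrained(
    "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# Hypothetical LoRA checkpoint; substitute a real repo id and weight file.
pipe.load_lora_weights(
    "your-username/kandinsky5-style-lora", weight_name="pytorch_lora_weights.safetensors"
)

video = pipe(
    prompt="A cat and a dog baking a cake together in a kitchen.",
    height=512,
    width=768,
    num_frames=121,
    num_inference_steps=50,
    guidance_scale=5.0,
).frames[0]
export_to_video(video, "lora_output.mp4", fps=24)
```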