diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 06e05e0206f1..52ab289effec 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -274,6 +274,8 @@ title: LatteTransformer3DModel - local: api/models/lumina_nextdit2d title: LuminaNextDiT2DModel + - local: api/models/ltx_video_transformer3d + title: LTXVideoTransformer3DModel - local: api/models/mochi_transformer3d title: MochiTransformer3DModel - local: api/models/pixart_transformer2d @@ -312,6 +314,8 @@ title: AutoencoderKLAllegro - local: api/models/autoencoderkl_cogvideox title: AutoencoderKLCogVideoX + - local: api/models/autoencoderkl_ltx_video + title: AutoencoderKLLTXVideo - local: api/models/autoencoderkl_mochi title: AutoencoderKLMochi - local: api/models/asymmetricautoencoderkl @@ -408,6 +412,8 @@ title: Latte - local: api/pipelines/ledits_pp title: LEDITS++ + - local: api/pipelines/ltx_video + title: LTX - local: api/pipelines/lumina title: Lumina-T2X - local: api/pipelines/marigold diff --git a/docs/source/en/api/models/autoencoderkl_ltx_video.md b/docs/source/en/api/models/autoencoderkl_ltx_video.md new file mode 100644 index 000000000000..694b5ace6fdf --- /dev/null +++ b/docs/source/en/api/models/autoencoderkl_ltx_video.md @@ -0,0 +1,37 @@ + + +# AutoencoderKLLTXVideo + +The 3D variational autoencoder (VAE) model with KL loss used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks. + +The model can be loaded with the following code snippet. + +```python +from diffusers import AutoencoderKLLTXVideo + +vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda") +``` + +## AutoencoderKLLTXVideo + +[[autodoc]] AutoencoderKLLTXVideo + - decode + - encode + - all + +## AutoencoderKLOutput + +[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput + +## DecoderOutput + +[[autodoc]] models.autoencoders.vae.DecoderOutput diff --git a/docs/source/en/api/models/ltx_video_transformer3d.md b/docs/source/en/api/models/ltx_video_transformer3d.md new file mode 100644 index 000000000000..8a60bc0432c6 --- /dev/null +++ b/docs/source/en/api/models/ltx_video_transformer3d.md @@ -0,0 +1,30 @@ + + +# LTXVideoTransformer3DModel + +A Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks. + +The model can be loaded with the following code snippet. + +```python +from diffusers import LTXVideoTransformer3DModel + +transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda") +``` + +## LTXVideoTransformer3DModel + +[[autodoc]] LTXVideoTransformer3DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/ltx_video.md b/docs/source/en/api/pipelines/ltx_video.md new file mode 100644 index 000000000000..162e1334ce9a --- /dev/null +++ b/docs/source/en/api/pipelines/ltx_video.md @@ -0,0 +1,68 @@ + + +# LTX + +[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases. + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +## Loading Single Files + +Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. + +```python +import torch +from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel + +single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors" +transformer = LTXVideoTransformer3DModel.from_single_file(single_file_url, torch_dtype=torch.bfloat16) +vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16) +pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16) + +# ... inference code ... +``` + +Alternatively, the pipeline can be used to load the weights with [~FromSingleFileMixin.from_single_file`]. + +```python +import torch +from diffusers import LTXImageToVideoPipeline +from transformers import T5EncoderModel, T5Tokenizer + +single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors" +text_encoder = T5EncoderModel.from_pretrained("Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16) +tokenizer = T5Tokenizer.from_pretrained("Lightricks/LTX-Video", subfolder="tokenizer", torch_dtype=torch.bfloat16) +pipe = LTXImageToVideoPipeline.from_single_file(single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16) +``` + +## LTXPipeline + +[[autodoc]] LTXPipeline + - all + - __call__ + +## LTXImageToVideoPipeline + +[[autodoc]] LTXImageToVideoPipeline + - all + - __call__ + +## LTXPipelineOutput + +[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py new file mode 100644 index 000000000000..f4398a2e687c --- /dev/null +++ b/scripts/convert_ltx_to_diffusers.py @@ -0,0 +1,209 @@ +import argparse +from typing import Any, Dict + +import torch +from safetensors.torch import load_file +from transformers import T5EncoderModel, T5Tokenizer + +from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel + + +def remove_keys_(key: str, state_dict: Dict[str, Any]): + state_dict.pop(key) + + +TOKENIZER_MAX_LENGTH = 128 + +TRANSFORMER_KEYS_RENAME_DICT = { + "patchify_proj": "proj_in", + "adaln_single": "time_embed", + "q_norm": "norm_q", + "k_norm": "norm_k", +} + +TRANSFORMER_SPECIAL_KEYS_REMAP = {} + +VAE_KEYS_RENAME_DICT = { + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0", + "up_blocks.2": "up_blocks.1.upsamplers.0", + "up_blocks.3": "up_blocks.1", + "up_blocks.4": "up_blocks.2.conv_in", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.conv_in", + "up_blocks.8": "up_blocks.3.upsamplers.0", + "up_blocks.9": "up_blocks.3", + # encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.0.conv_out", + "down_blocks.3": "down_blocks.1", + "down_blocks.4": "down_blocks.1.downsamplers.0", + "down_blocks.5": "down_blocks.1.conv_out", + "down_blocks.6": "down_blocks.2", + "down_blocks.7": "down_blocks.2.downsamplers.0", + "down_blocks.8": "down_blocks.3", + "down_blocks.9": "mid_block", + # common + "conv_shortcut": "conv_shortcut.conv", + "res_blocks": "resnets", + "norm3.norm": "norm3", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", +} + +VAE_SPECIAL_KEYS_REMAP = { + "per_channel_statistics.channel": remove_keys_, + "per_channel_statistics.mean-of-means": remove_keys_, + "per_channel_statistics.mean-of-stds": remove_keys_, +} + + +def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]: + state_dict = saved_dict + if "model" in saved_dict.keys(): + state_dict = state_dict["model"] + if "module" in saved_dict.keys(): + state_dict = state_dict["module"] + if "state_dict" in saved_dict.keys(): + state_dict = state_dict["state_dict"] + return state_dict + + +def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]: + state_dict[new_key] = state_dict.pop(old_key) + + +def convert_transformer( + ckpt_path: str, + dtype: torch.dtype, +): + PREFIX_KEY = "" + + original_state_dict = get_state_dict(load_file(ckpt_path)) + transformer = LTXVideoTransformer3DModel().to(dtype=dtype) + + for key in list(original_state_dict.keys()): + new_key = key[len(PREFIX_KEY) :] + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + transformer.load_state_dict(original_state_dict, strict=True) + return transformer + + +def convert_vae(ckpt_path: str, dtype: torch.dtype): + original_state_dict = get_state_dict(load_file(ckpt_path)) + vae = AutoencoderKLLTXVideo().to(dtype=dtype) + + for key in list(original_state_dict.keys()): + new_key = key[:] + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + update_state_dict_inplace(original_state_dict, key, new_key) + + for key in list(original_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, original_state_dict) + + vae.load_state_dict(original_state_dict, strict=True) + return vae + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint" + ) + parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint") + parser.add_argument( + "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory" + ) + parser.add_argument( + "--typecast_text_encoder", + action="store_true", + default=False, + help="Whether or not to apply fp16/bf16 precision to text_encoder", + ) + parser.add_argument("--save_pipeline", action="store_true") + parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved") + parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.") + return parser.parse_args() + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + +VARIANT_MAPPING = { + "fp32": None, + "fp16": "fp16", + "bf16": "bf16", +} + + +if __name__ == "__main__": + args = get_args() + + transformer = None + dtype = DTYPE_MAPPING[args.dtype] + variant = VARIANT_MAPPING[args.dtype] + + if args.save_pipeline: + assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None + + if args.transformer_ckpt_path is not None: + transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype) + if not args.save_pipeline: + transformer.save_pretrained( + args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant + ) + + if args.vae_ckpt_path is not None: + vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype) + if not args.save_pipeline: + vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant) + + if args.save_pipeline: + text_encoder_id = "google/t5-v1_1-xxl" + tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH) + text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir) + + if args.typecast_text_encoder: + text_encoder = text_encoder.to(dtype=dtype) + + # Apparently, the conversion does not work anymore without this :shrug: + for param in text_encoder.parameters(): + param.data = param.data.contiguous() + + scheduler = FlowMatchEulerDiscreteScheduler( + use_dynamic_shifting=True, + base_shift=0.95, + max_shift=2.05, + base_image_seq_len=1024, + max_image_seq_len=4096, + shift_terminal=0.1, + ) + + pipe = LTXPipeline( + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + ) + + pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB") diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2605f02fab04..ae4ef299abb3 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -84,6 +84,7 @@ "AutoencoderKL", "AutoencoderKLAllegro", "AutoencoderKLCogVideoX", + "AutoencoderKLLTXVideo", "AutoencoderKLMochi", "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", @@ -104,6 +105,7 @@ "I2VGenXLUNet", "Kandinsky3UNet", "LatteTransformer3DModel", + "LTXVideoTransformer3DModel", "LuminaNextDiT2DModel", "MochiTransformer3DModel", "ModelMixin", @@ -317,6 +319,8 @@ "LDMTextToImagePipeline", "LEditsPPPipelineStableDiffusion", "LEditsPPPipelineStableDiffusionXL", + "LTXImageToVideoPipeline", + "LTXPipeline", "LuminaText2ImgPipeline", "MarigoldDepthPipeline", "MarigoldNormalsPipeline", @@ -582,6 +586,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLLTXVideo, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, AutoencoderOobleck, @@ -602,6 +607,7 @@ I2VGenXLUNet, Kandinsky3UNet, LatteTransformer3DModel, + LTXVideoTransformer3DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, ModelMixin, @@ -794,6 +800,8 @@ LDMTextToImagePipeline, LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, + LTXImageToVideoPipeline, + LTXPipeline, LuminaText2ImgPipeline, MarigoldDepthPipeline, MarigoldNormalsPipeline, diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index d1fe55840aff..78ce47273d8f 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -28,6 +28,8 @@ convert_flux_transformer_checkpoint_to_diffusers, convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint, + convert_ltx_transformer_checkpoint_to_diffusers, + convert_ltx_vae_checkpoint_to_diffusers, convert_sd3_transformer_checkpoint_to_diffusers, convert_stable_cascade_unet_single_file_to_diffusers, create_controlnet_diffusers_config_from_ldm, @@ -83,6 +85,14 @@ "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers, "default_subfolder": "transformer", }, + "LTXVideoTransformer3DModel": { + "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, + "AutoencoderKLLTXVideo": { + "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers, + "default_subfolder": "vae", + }, "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers}, } diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 8256e38054fe..21ff2841700d 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -92,6 +92,12 @@ "double_blocks.0.img_attn.norm.key_norm.scale", "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", ], + "ltx-video": [ + ( + "model.diffusion_model.patchify_proj.weight", + "model.diffusion_model.transformer_blocks.27.scale_shift_table", + ), + ], "autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias", "autoencoder-dc-sana": "encoder.project_in.conv.bias", } @@ -140,6 +146,7 @@ "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"}, "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"}, "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, + "ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"}, "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"}, "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"}, "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"}, @@ -571,6 +578,9 @@ def infer_diffusers_model_type(checkpoint): else: model_type = "flux-schnell" + elif any(all(key in checkpoint for key in key_list) for key_list in CHECKPOINT_KEY_NAMES["ltx-video"]): + model_type = "ltx-video" + elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint: encoder_key = "encoder.project_in.conv.conv.bias" decoder_key = "decoder.project_in.main.conv.weight" @@ -2223,6 +2233,96 @@ def swap_scale_shift(weight): return converted_state_dict +def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = { + key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "model.diffusion_model." in key + } + + TRANSFORMER_KEYS_RENAME_DICT = { + "model.diffusion_model.": "", + "patchify_proj": "proj_in", + "adaln_single": "time_embed", + "q_norm": "norm_q", + "k_norm": "norm_k", + } + + TRANSFORMER_SPECIAL_KEYS_REMAP = {} + + for key in list(converted_state_dict.keys()): + new_key = key + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + converted_state_dict[new_key] = converted_state_dict.pop(key) + + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict + + +def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." in key} + + def remove_keys_(key: str, state_dict): + state_dict.pop(key) + + VAE_KEYS_RENAME_DICT = { + # common + "vae.": "", + # decoder + "up_blocks.0": "mid_block", + "up_blocks.1": "up_blocks.0", + "up_blocks.2": "up_blocks.1.upsamplers.0", + "up_blocks.3": "up_blocks.1", + "up_blocks.4": "up_blocks.2.conv_in", + "up_blocks.5": "up_blocks.2.upsamplers.0", + "up_blocks.6": "up_blocks.2", + "up_blocks.7": "up_blocks.3.conv_in", + "up_blocks.8": "up_blocks.3.upsamplers.0", + "up_blocks.9": "up_blocks.3", + # encoder + "down_blocks.0": "down_blocks.0", + "down_blocks.1": "down_blocks.0.downsamplers.0", + "down_blocks.2": "down_blocks.0.conv_out", + "down_blocks.3": "down_blocks.1", + "down_blocks.4": "down_blocks.1.downsamplers.0", + "down_blocks.5": "down_blocks.1.conv_out", + "down_blocks.6": "down_blocks.2", + "down_blocks.7": "down_blocks.2.downsamplers.0", + "down_blocks.8": "down_blocks.3", + "down_blocks.9": "mid_block", + # common + "conv_shortcut": "conv_shortcut.conv", + "res_blocks": "resnets", + "norm3.norm": "norm3", + "per_channel_statistics.mean-of-means": "latents_mean", + "per_channel_statistics.std-of-means": "latents_std", + } + + VAE_SPECIAL_KEYS_REMAP = { + "per_channel_statistics.channel": remove_keys_, + "per_channel_statistics.mean-of-means": remove_keys_, + "per_channel_statistics.mean-of-stds": remove_keys_, + } + + for key in list(converted_state_dict.keys()): + new_key = key + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + converted_state_dict[new_key] = converted_state_dict.pop(key) + + for key in list(converted_state_dict.keys()): + for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items(): + if special_key not in key: + continue + handler_fn_inplace(key, converted_state_dict) + + return converted_state_dict + + def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs): converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 65707e63219d..c8ef85b75229 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -31,6 +31,7 @@ _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] + _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] @@ -65,6 +66,7 @@ _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] + _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] @@ -94,6 +96,7 @@ AutoencoderKL, AutoencoderKLAllegro, AutoencoderKLCogVideoX, + AutoencoderKLLTXVideo, AutoencoderKLMochi, AutoencoderKLTemporalDecoder, AutoencoderOobleck, @@ -127,6 +130,7 @@ FluxTransformer2DModel, HunyuanDiT2DModel, LatteTransformer3DModel, + LTXVideoTransformer3DModel, LuminaNextDiT2DModel, MochiTransformer3DModel, PixArtTransformer2DModel, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 945ceb57e769..6e892ec29637 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -199,12 +199,16 @@ def __init__( self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps) elif qk_norm == "layer_norm_across_heads": - # Lumina applys qk norm across all heads + # Lumina applies qk norm across all heads self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps) self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps) elif qk_norm == "rms_norm": self.norm_q = RMSNorm(dim_head, eps=eps) self.norm_k = RMSNorm(dim_head, eps=eps) + elif qk_norm == "rms_norm_across_heads": + # LTX applies qk norm across all heads + self.norm_q = RMSNorm(dim_head * heads, eps=eps) + self.norm_k = RMSNorm(dim_head * kv_heads, eps=eps) elif qk_norm == "l2": self.norm_q = LpNorm(p=2, dim=-1, eps=eps) self.norm_k = LpNorm(p=2, dim=-1, eps=eps) diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 7a36e88f1a36..d08e67c40975 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -3,6 +3,7 @@ from .autoencoder_kl import AutoencoderKL from .autoencoder_kl_allegro import AutoencoderKLAllegro from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX +from .autoencoder_kl_ltx import AutoencoderKLLTXVideo from .autoencoder_kl_mochi import AutoencoderKLMochi from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_oobleck import AutoencoderOobleck diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py new file mode 100644 index 000000000000..ff202b980b95 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -0,0 +1,1145 @@ +# Copyright 2024 The Lightricks team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils.accelerate_utils import apply_forward_hook +from ..activations import get_activation +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..normalization import RMSNorm +from .vae import DecoderOutput, DiagonalGaussianDistribution + + +class LTXCausalConv3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]] = 3, + stride: Union[int, Tuple[int, int, int]] = 1, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups: int = 1, + padding_mode: str = "zeros", + is_causal: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.is_causal = is_causal + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) + + dilation = dilation if isinstance(dilation, tuple) else (dilation, 1, 1) + stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + height_pad = self.kernel_size[1] // 2 + width_pad = self.kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + self.kernel_size, + stride=stride, + dilation=dilation, + groups=groups, + padding=padding, + padding_mode=padding_mode, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + time_kernel_size = self.kernel_size[0] + + if self.is_causal: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, time_kernel_size - 1, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states], dim=2) + else: + pad_left = hidden_states[:, :, :1, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + pad_right = hidden_states[:, :, -1:, :, :].repeat((1, 1, (time_kernel_size - 1) // 2, 1, 1)) + hidden_states = torch.concatenate([pad_left, hidden_states, pad_right], dim=2) + + hidden_states = self.conv(hidden_states) + return hidden_states + + +class LTXResnetBlock3d(nn.Module): + r""" + A 3D ResNet block used in the LTX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + dropout (`float`, defaults to `0.0`): + Dropout rate. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + elementwise_affine (`bool`, defaults to `False`): + Whether to enable elementwise affinity in the normalization layers. + non_linearity (`str`, defaults to `"swish"`): + Activation function to use. + conv_shortcut (bool, defaults to `False`): + Whether or not to use a convolution shortcut. + """ + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + eps: float = 1e-6, + elementwise_affine: bool = False, + non_linearity: str = "swish", + is_causal: bool = True, + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = RMSNorm(in_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.conv1 = LTXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + ) + + self.norm2 = RMSNorm(out_channels, eps=1e-8, elementwise_affine=elementwise_affine) + self.dropout = nn.Dropout(dropout) + self.conv2 = LTXCausalConv3d( + in_channels=out_channels, out_channels=out_channels, kernel_size=3, is_causal=is_causal + ) + + self.norm3 = None + self.conv_shortcut = None + if in_channels != out_channels: + self.norm3 = nn.LayerNorm(in_channels, eps=eps, elementwise_affine=True, bias=True) + self.conv_shortcut = LTXCausalConv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, is_causal=is_causal + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + hidden_states = inputs + + hidden_states = self.norm1(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.norm3 is not None: + inputs = self.norm3(inputs.movedim(1, -1)).movedim(-1, 1) + + if self.conv_shortcut is not None: + inputs = self.conv_shortcut(inputs) + + hidden_states = hidden_states + inputs + return hidden_states + + +class LTXUpsampler3d(nn.Module): + def __init__( + self, + in_channels: int, + stride: Union[int, Tuple[int, int, int]] = 1, + is_causal: bool = True, + ) -> None: + super().__init__() + + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + + out_channels = in_channels * stride[0] * stride[1] * stride[2] + + self.conv = LTXCausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + is_causal=is_causal, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + + hidden_states = self.conv(hidden_states) + hidden_states = hidden_states.reshape( + batch_size, -1, self.stride[0], self.stride[1], self.stride[2], num_frames, height, width + ) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4).flatten(6, 7).flatten(4, 5).flatten(2, 3) + hidden_states = hidden_states[:, :, self.stride[0] - 1 :] + + return hidden_states + + +class LTXDownBlock3D(nn.Module): + r""" + Down block used in the LTX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + ): + super().__init__() + + out_channels = out_channels or in_channels + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTXResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if spatio_temporal_scale: + self.downsamplers = nn.ModuleList( + [ + LTXCausalConv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=(2, 2, 2), + is_causal=is_causal, + ) + ] + ) + + self.conv_out = None + if in_channels != out_channels: + self.conv_out = LTXResnetBlock3d( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r"""Forward method of the `LTXDownBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) + else: + hidden_states = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + if self.conv_out is not None: + hidden_states = self.conv_out(hidden_states) + + return hidden_states + + +# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d +class LTXMidBlock3d(nn.Module): + r""" + A middle block used in the LTX model. + + Args: + in_channels (`int`): + Number of input channels. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + is_causal: bool = True, + ) -> None: + super().__init__() + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTXResnetBlock3d( + in_channels=in_channels, + out_channels=in_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r"""Forward method of the `LTXMidBlock3D` class.""" + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) + else: + hidden_states = resnet(hidden_states) + + return hidden_states + + +class LTXUpBlock3d(nn.Module): + r""" + Up block used in the LTX model. + + Args: + in_channels (`int`): + Number of input channels. + out_channels (`int`, *optional*): + Number of output channels. If None, defaults to `in_channels`. + num_layers (`int`, defaults to `1`): + Number of resnet layers. + dropout (`float`, defaults to `0.0`): + Dropout rate. + resnet_eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + resnet_act_fn (`str`, defaults to `"swish"`): + Activation function to use. + spatio_temporal_scale (`bool`, defaults to `True`): + Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension. + Whether or not to downsample across temporal dimension. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels: int, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + spatio_temporal_scale: bool = True, + is_causal: bool = True, + ): + super().__init__() + + out_channels = out_channels or in_channels + + self.conv_in = None + if in_channels != out_channels: + self.conv_in = LTXResnetBlock3d( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + + self.upsamplers = None + if spatio_temporal_scale: + self.upsamplers = nn.ModuleList([LTXUpsampler3d(out_channels, stride=(2, 2, 2), is_causal=is_causal)]) + + resnets = [] + for _ in range(num_layers): + resnets.append( + LTXResnetBlock3d( + in_channels=out_channels, + out_channels=out_channels, + dropout=dropout, + eps=resnet_eps, + non_linearity=resnet_act_fn, + is_causal=is_causal, + ) + ) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.conv_in is not None: + hidden_states = self.conv_in(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + for i, resnet in enumerate(self.resnets): + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states) + else: + hidden_states = resnet(hidden_states) + + return hidden_states + + +class LTXEncoder3d(nn.Module): + r""" + The `LTXEncoder3D` layer of a variational autoencoder that encodes input video samples to its latent + representation. + + Args: + in_channels (`int`, defaults to 3): + Number of input channels. + out_channels (`int`, defaults to 128): + Number of latent channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal downscaling layers or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `True`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 128, + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), + layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = True, + ): + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.in_channels = in_channels * patch_size**2 + + output_channel = block_out_channels[0] + + self.conv_in = LTXCausalConv3d( + in_channels=self.in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + is_causal=is_causal, + ) + + # down blocks + num_block_out_channels = len(block_out_channels) + self.down_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel + output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i] + + down_block = LTXDownBlock3D( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + ) + + self.down_blocks.append(down_block) + + # mid block + self.mid_block = LTXMidBlock3d( + in_channels=output_channel, + num_layers=layers_per_block[-1], + resnet_eps=resnet_norm_eps, + is_causal=is_causal, + ) + + # out + self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.conv_act = nn.SiLU() + self.conv_out = LTXCausalConv3d( + in_channels=output_channel, out_channels=out_channels + 1, kernel_size=3, stride=1, is_causal=is_causal + ) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + r"""The forward method of the `LTXEncoder3D` class.""" + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + hidden_states = hidden_states.reshape( + batch_size, num_channels, post_patch_num_frames, p_t, post_patch_height, p, post_patch_width, p + ) + # Thanks for driving me insane with the weird patching order :( + hidden_states = hidden_states.permute(0, 1, 3, 7, 5, 2, 4, 6).flatten(1, 4) + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + for down_block in self.down_blocks: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), hidden_states) + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states) + else: + for down_block in self.down_blocks: + hidden_states = down_block(hidden_states) + + hidden_states = self.mid_block(hidden_states) + + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + last_channel = hidden_states[:, -1:] + last_channel = last_channel.repeat(1, hidden_states.size(1) - 2, 1, 1, 1) + hidden_states = torch.cat([hidden_states, last_channel], dim=1) + + return hidden_states + + +class LTXDecoder3d(nn.Module): + r""" + The `LTXDecoder3d` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, defaults to 128): + Number of latent channels. + out_channels (`int`, defaults to 3): + Number of output channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal upscaling layers or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + is_causal (`bool`, defaults to `False`): + Whether this layer behaves causally (future frames depend only on past frames) or not. + """ + + def __init__( + self, + in_channels: int = 128, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), + layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + is_causal: bool = False, + ) -> None: + super().__init__() + + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.out_channels = out_channels * patch_size**2 + + block_out_channels = tuple(reversed(block_out_channels)) + spatio_temporal_scaling = tuple(reversed(spatio_temporal_scaling)) + layers_per_block = tuple(reversed(layers_per_block)) + output_channel = block_out_channels[0] + + self.conv_in = LTXCausalConv3d( + in_channels=in_channels, out_channels=output_channel, kernel_size=3, stride=1, is_causal=is_causal + ) + + self.mid_block = LTXMidBlock3d( + in_channels=output_channel, num_layers=layers_per_block[0], resnet_eps=resnet_norm_eps, is_causal=is_causal + ) + + # up blocks + num_block_out_channels = len(block_out_channels) + self.up_blocks = nn.ModuleList([]) + for i in range(num_block_out_channels): + input_channel = output_channel + output_channel = block_out_channels[i] + + up_block = LTXUpBlock3d( + in_channels=input_channel, + out_channels=output_channel, + num_layers=layers_per_block[i + 1], + resnet_eps=resnet_norm_eps, + spatio_temporal_scale=spatio_temporal_scaling[i], + is_causal=is_causal, + ) + + self.up_blocks.append(up_block) + + # out + self.norm_out = RMSNorm(out_channels, eps=1e-8, elementwise_affine=False) + self.conv_act = nn.SiLU() + self.conv_out = LTXCausalConv3d( + in_channels=output_channel, out_channels=self.out_channels, kernel_size=3, stride=1, is_causal=is_causal + ) + + self.gradient_checkpointing = False + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.conv_in(hidden_states) + + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module): + def create_forward(*inputs): + return module(*inputs) + + return create_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), hidden_states) + + for up_block in self.up_blocks: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states) + else: + hidden_states = self.mid_block(hidden_states) + + for up_block in self.up_blocks: + hidden_states = up_block(hidden_states) + + hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1) + hidden_states = self.conv_act(hidden_states) + hidden_states = self.conv_out(hidden_states) + + p = self.patch_size + p_t = self.patch_size_t + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + hidden_states = hidden_states.reshape(batch_size, -1, p_t, p, p, num_frames, height, width) + hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 4, 7, 3).flatten(6, 7).flatten(4, 5).flatten(2, 3) + + return hidden_states + + +class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in + [LTX](https://huggingface.co/Lightricks/LTX-Video). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Args: + in_channels (`int`, defaults to `3`): + Number of input channels. + out_channels (`int`, defaults to `3`): + Number of output channels. + latent_channels (`int`, defaults to `128`): + Number of latent channels. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + The number of output channels for each block. + spatio_temporal_scaling (`Tuple[bool, ...], defaults to `(True, True, True, False)`: + Whether a block should contain spatio-temporal downscaling or not. + layers_per_block (`Tuple[int, ...]`, defaults to `(4, 3, 3, 3, 4)`): + The number of layers per block. + patch_size (`int`, defaults to `4`): + The size of spatial patches. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches. + resnet_norm_eps (`float`, defaults to `1e-6`): + Epsilon value for ResNet normalization layers. + scaling_factor (`float`, *optional*, defaults to `1.0`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + encoder_causal (`bool`, defaults to `True`): + Whether the encoder should behave causally (future frames depend only on past frames) or not. + decoder_causal (`bool`, defaults to `False`): + Whether the decoder should behave causally (future frames depend only on past frames) or not. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + latent_channels: int = 128, + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False), + layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4), + patch_size: int = 4, + patch_size_t: int = 1, + resnet_norm_eps: float = 1e-6, + scaling_factor: float = 1.0, + encoder_causal: bool = True, + decoder_causal: bool = False, + ) -> None: + super().__init__() + + self.encoder = LTXEncoder3d( + in_channels=in_channels, + out_channels=latent_channels, + block_out_channels=block_out_channels, + spatio_temporal_scaling=spatio_temporal_scaling, + layers_per_block=layers_per_block, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=encoder_causal, + ) + self.decoder = LTXDecoder3d( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=block_out_channels, + spatio_temporal_scaling=spatio_temporal_scaling, + layers_per_block=layers_per_block, + patch_size=patch_size, + patch_size_t=patch_size_t, + resnet_norm_eps=resnet_norm_eps, + is_causal=decoder_causal, + ) + + latents_mean = torch.zeros((latent_channels,), requires_grad=False) + latents_std = torch.ones((latent_channels,), requires_grad=False) + self.register_buffer("latents_mean", latents_mean, persistent=True) + self.register_buffer("latents_std", latents_std, persistent=True) + + self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling) + self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling) + + # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension + # to perform decoding of a single video latent at a time. + self.use_slicing = False + + # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent + # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the + # intermediate tiles together, the memory requirement can be lowered. + self.use_tiling = False + + # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames + # at a fixed frame batch size (based on `self.num_latent_frames_batch_sizes`), the memory requirement can be lowered. + self.use_framewise_encoding = False + self.use_framewise_decoding = False + + # This can be configured based on the amount of GPU memory available. + # `16` for sample frames and `2` for latent frames are sensible defaults for consumer GPUs. + # Setting it to higher values results in higher memory usage. + self.num_sample_frames_batch_size = 16 + self.num_latent_frames_batch_size = 2 + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 512 + self.tile_sample_min_width = 512 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 448 + self.tile_sample_stride_width = 448 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (LTXEncoder3d, LTXDecoder3d)): + module.gradient_checkpointing = value + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_stride_height: Optional[float] = None, + tile_sample_stride_width: Optional[float] = None, + ) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + + Args: + tile_sample_min_height (`int`, *optional*): + The minimum height required for a sample to be separated into tiles across the height dimension. + tile_sample_min_width (`int`, *optional*): + The minimum width required for a sample to be separated into tiles across the width dimension. + tile_sample_stride_height (`int`, *optional*): + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are + no tiling artifacts produced across the height dimension. + tile_sample_stride_width (`int`, *optional*): + The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling + artifacts produced across the width dimension. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + batch_size, num_channels, num_frames, height, width = x.shape + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + if self.use_framewise_encoding: + # TODO(aryan): requires investigation + raise NotImplementedError( + "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to " + "quality issues caused by splitting inference across frame dimension. If you believe this " + "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." + ) + else: + enc = self.encoder(x) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + if self.use_framewise_decoding: + # TODO(aryan): requires investigation + raise NotImplementedError( + "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to " + "quality issues caused by splitting inference across frame dimension. If you believe this " + "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." + ) + else: + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + r"""Encode a batch of images using a tiled encoder. + + Args: + x (`torch.Tensor`): Input batch of videos. + + Returns: + `torch.Tensor`: + The latent representation of the encoded videos. + """ + batch_size, num_channels, num_frames, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + if self.use_framewise_encoding: + # TODO(aryan): requires investigation + raise NotImplementedError( + "Frame-wise encoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to " + "quality issues caused by splitting inference across frame dimension. If you believe this " + "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." + ) + else: + time = self.encoder( + x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + ) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + + batch_size, num_channels, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + if self.use_framewise_decoding: + # TODO(aryan): requires investigation + raise NotImplementedError( + "Frame-wise decoding has not been implemented for AutoencoderKLLTXVideo, at the moment, due to " + "quality issues caused by splitting inference across frame dimension. If you believe this " + "should be possible, please submit a PR to https://github.com/huggingface/diffusers/pulls." + ) + else: + time = self.decoder(z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]) + + row.append(time) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[torch.Tensor, torch.Tensor]: + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z) + if not return_dict: + return (dec,) + return dec diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index a2c087d708a4..fed64d45fbd0 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -17,6 +17,7 @@ from .transformer_allegro import AllegroTransformer3DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel + from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_temporal import TransformerTemporalModel diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py new file mode 100644 index 000000000000..8aa3a1590fb9 --- /dev/null +++ b/src/diffusers/models/transformers/transformer_ltx.py @@ -0,0 +1,449 @@ +# Copyright 2024 The Genmo team and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin +from ...utils import is_torch_version, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward +from ..attention_processor import Attention +from ..embeddings import PixArtAlphaTextProjection +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class LTXAttentionProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.norm_q(query) + key = attn.norm_k(key) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.to(query.dtype) + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +class LTXRotaryPosEmbed(nn.Module): + def __init__( + self, + dim: int, + base_num_frames: int = 20, + base_height: int = 2048, + base_width: int = 2048, + patch_size: int = 1, + patch_size_t: int = 1, + theta: float = 10000.0, + ) -> None: + super().__init__() + + self.dim = dim + self.base_num_frames = base_num_frames + self.base_height = base_height + self.base_width = base_width + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.theta = theta + + def forward( + self, + hidden_states: torch.Tensor, + num_frames: int, + height: int, + width: int, + rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = hidden_states.size(0) + + # Always compute rope in fp32 + grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device) + grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device) + grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device) + grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij") + grid = torch.stack(grid, dim=0) + grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + + if rope_interpolation_scale is not None: + grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames + grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height + grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width + + grid = grid.flatten(2, 4).transpose(1, 2) + + start = 1.0 + end = self.theta + freqs = self.theta ** torch.linspace( + math.log(start, self.theta), + math.log(end, self.theta), + self.dim // 6, + device=hidden_states.device, + dtype=torch.float32, + ) + freqs = freqs * math.pi / 2.0 + freqs = freqs * (grid.unsqueeze(-1) * 2 - 1) + freqs = freqs.transpose(-1, -2).flatten(2) + + cos_freqs = freqs.cos().repeat_interleave(2, dim=-1) + sin_freqs = freqs.sin().repeat_interleave(2, dim=-1) + + if self.dim % 6 != 0: + cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % 6]) + sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % 6]) + cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1) + sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1) + + return cos_freqs, sin_freqs + + +@maybe_allow_in_graph +class LTXTransformerBlock(nn.Module): + r""" + Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video). + + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + qk_norm (`str`, defaults to `"rms_norm"`): + The normalization layer to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + eps (`float`, defaults to `1e-6`): + Epsilon value for normalization layers. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: int, + qk_norm: str = "rms_norm_across_heads", + activation_fn: str = "gelu-approximate", + attention_bias: bool = True, + attention_out_bias: bool = True, + eps: float = 1e-6, + elementwise_affine: bool = False, + ): + super().__init__() + + self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + cross_attention_dim=None, + out_bias=attention_out_bias, + qk_norm=qk_norm, + processor=LTXAttentionProcessor2_0(), + ) + + self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + kv_heads=num_attention_heads, + dim_head=attention_head_dim, + bias=attention_bias, + out_bias=attention_out_bias, + qk_norm=qk_norm, + processor=LTXAttentionProcessor2_0(), + ) + + self.ff = FeedForward(dim, activation_fn=activation_fn) + + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + batch_size = hidden_states.size(0) + norm_hidden_states = self.norm1(hidden_states) + + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + + attn_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=None, + image_rotary_emb=image_rotary_emb, + ) + hidden_states = hidden_states + attn_hidden_states * gate_msa + + attn_hidden_states = self.attn2( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_rotary_emb=None, + attention_mask=encoder_attention_mask, + ) + hidden_states = hidden_states + attn_hidden_states + norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp + + ff_output = self.ff(norm_hidden_states) + hidden_states = hidden_states + ff_output * gate_mlp + + return hidden_states + + +@maybe_allow_in_graph +class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): + r""" + A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video). + + Args: + in_channels (`int`, defaults to `128`): + The number of channels in the input. + out_channels (`int`, defaults to `128`): + The number of channels in the output. + patch_size (`int`, defaults to `1`): + The size of the spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of the tmeporal patches to use in the patch embedding layer. + num_attention_heads (`int`, defaults to `32`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `64`): + The number of channels in each head. + cross_attention_dim (`int`, defaults to `2048 `): + The number of channels for cross attention heads. + num_layers (`int`, defaults to `28`): + The number of layers of Transformer blocks to use. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + qk_norm (`str`, defaults to `"rms_norm_across_heads"`): + The normalization layer to use. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 128, + out_channels: int = 128, + patch_size: int = 1, + patch_size_t: int = 1, + num_attention_heads: int = 32, + attention_head_dim: int = 64, + cross_attention_dim: int = 2048, + num_layers: int = 28, + activation_fn: str = "gelu-approximate", + qk_norm: str = "rms_norm_across_heads", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + caption_channels: int = 4096, + attention_bias: bool = True, + attention_out_bias: bool = True, + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + inner_dim = num_attention_heads * attention_head_dim + + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False) + + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.rope = LTXRotaryPosEmbed( + dim=inner_dim, + base_num_frames=20, + base_height=2048, + base_width=2048, + patch_size=patch_size, + patch_size_t=patch_size_t, + theta=10000.0, + ) + + self.transformer_blocks = nn.ModuleList( + [ + LTXTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + qk_norm=qk_norm, + activation_fn=activation_fn, + attention_bias=attention_bias, + attention_out_bias=attention_out_bias, + eps=norm_eps, + elementwise_affine=norm_elementwise_affine, + ) + for _ in range(num_layers) + ] + ) + + self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_attention_mask: torch.Tensor, + num_frames: int, + height: int, + width: int, + rope_interpolation_scale: Optional[Tuple[float, float, float]] = None, + return_dict: bool = True, + ) -> torch.Tensor: + image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + batch_size = hidden_states.size(0) + hidden_states = self.proj_in(hidden_states) + + temb, embedded_timestep = self.time_embed( + timestep.flatten(), + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + temb = temb.view(batch_size, -1, temb.size(-1)) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1)) + + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1)) + + for block in self.transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + encoder_attention_mask, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + encoder_attention_mask=encoder_attention_mask, + ) + + scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + + hidden_states = self.norm_out(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + output = self.proj_out(hidden_states) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + +def apply_rotary_emb(x, freqs): + cos, sin = freqs + x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + return out diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3409aea3cfde..7f85ad19e30d 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -250,6 +250,7 @@ ] ) _import_structure["latte"] = ["LattePipeline"] + _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"] _import_structure["lumina"] = ["LuminaText2ImgPipeline"] _import_structure["marigold"].extend( [ @@ -585,6 +586,7 @@ LEditsPPPipelineStableDiffusion, LEditsPPPipelineStableDiffusionXL, ) + from .ltx import LTXImageToVideoPipeline, LTXPipeline from .lumina import LuminaText2ImgPipeline from .marigold import ( MarigoldDepthPipeline, diff --git a/src/diffusers/pipelines/ltx/__init__.py b/src/diffusers/pipelines/ltx/__init__.py new file mode 100644 index 000000000000..20cc1c216522 --- /dev/null +++ b/src/diffusers/pipelines/ltx/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_ltx"] = ["LTXPipeline"] + _import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_ltx import LTXPipeline + from .pipeline_ltx_image2video import LTXImageToVideoPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py new file mode 100644 index 000000000000..72b95fea1ce1 --- /dev/null +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -0,0 +1,755 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKLLTXVideo +from ...models.transformers import LTXVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LTXPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTXPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage" + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> video = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=704, + ... height=480, + ... num_frames=161, + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=24) + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class LTXPipeline(DiffusionPipeline, FromSingleFileMixin): + r""" + Pipeline for text-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`LTXVideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLLTXVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTXVideo, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: LTXVideoTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio if hasattr(self, "vae") else 32 + self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio if hasattr(self, "vae") else 8 + self.transformer_spatial_patch_size = self.transformer.config.patch_size if hasattr(self, "transformer") else 1 + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if hasattr(self, "transformer") else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128 + ) + + # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128 + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128 + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + @staticmethod + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + def prepare_latents( + self, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + + shape = (batch_size, num_channels_latents, num_frames, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 704, + num_frames: int = 161, + frame_rate: int = 25, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 3, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 128, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, defaults to `704`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, defaults to `161`): + The number of video frames to generate + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, defaults to `3 `): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `128 `): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + video_sequence_length = latent_num_frames * latent_height * latent_width + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio + rope_interpolation_scale = ( + 1 / latent_frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + rope_interpolation_scale=rope_interpolation_scale, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + video = latents + else: + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(prompt_embeds.dtype) + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py new file mode 100644 index 000000000000..25ed635a3d17 --- /dev/null +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -0,0 +1,851 @@ +# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from transformers import T5EncoderModel, T5TokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput +from ...loaders import FromSingleFileMixin +from ...models.autoencoders import AutoencoderKLLTXVideo +from ...models.transformers import LTXVideoTransformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import LTXPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import LTXImageToVideoPipeline + >>> from diffusers.utils import export_to_video, load_image + + >>> pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> image = load_image( + ... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" + ... ) + >>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background. Flames engulf the structure, with smoke billowing into the air. Firefighters in protective gear rush to the scene, a fire truck labeled '38' visible behind them. The girl's neutral expression contrasts sharply with the chaos of the fire, creating a poignant and emotionally charged scene." + >>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + + >>> video = pipe( + ... image=image, + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... width=704, + ... height=480, + ... num_frames=161, + ... num_inference_steps=50, + ... ).frames[0] + >>> export_to_video(video, "output.mp4", fps=24) + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin): + r""" + Pipeline for image-to-video generation. + + Reference: https://github.com/Lightricks/LTX-Video + + Args: + transformer ([`LTXVideoTransformer3DModel`]): + Conditional Transformer architecture to denoise the encoded video latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLLTXVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLLTXVideo, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: LTXVideoTransformer3DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + + self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio if hasattr(self, "vae") else 32 + self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio if hasattr(self, "vae") else 8 + self.transformer_spatial_patch_size = self.transformer.config.patch_size if hasattr(self, "transformer") else 1 + self.transformer_temporal_patch_size = ( + self.transformer.config.patch_size_t if hasattr(self, "transformer") else 1 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128 + ) + + self.default_height = 512 + self.default_width = 704 + self.default_frames = 121 + + # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128 + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + prompt_attention_mask = prompt_attention_mask.view(batch_size, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128 + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 128, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents + def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor: + # Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. + # The patch dimensions are then permuted and collapsed into the channel dimension of shape: + # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor). + # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features + batch_size, num_channels, num_frames, height, width = latents.shape + post_patch_num_frames = num_frames // patch_size_t + post_patch_height = height // patch_size + post_patch_width = width // patch_size + latents = latents.reshape( + batch_size, + -1, + post_patch_num_frames, + patch_size_t, + post_patch_height, + patch_size, + post_patch_width, + patch_size, + ) + latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents + def _unpack_latents( + latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1 + ) -> torch.Tensor: + # Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions) + # are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of + # what happens in the `_pack_latents` method. + batch_size = latents.size(0) + latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size) + latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3) + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents + def _normalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Normalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = (latents - latents_mean) * scaling_factor / latents_std + return latents + + @staticmethod + # Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents + def _denormalize_latents( + latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0 + ) -> torch.Tensor: + # Denormalize latents across the channel dimension [B, C, F, H, W] + latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype) + latents = latents * latents_std / scaling_factor + latents_mean + return latents + + def prepare_latents( + self, + image: Optional[torch.Tensor] = None, + batch_size: int = 1, + num_channels_latents: int = 128, + height: int = 512, + width: int = 704, + num_frames: int = 161, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + height = height // self.vae_spatial_compression_ratio + width = width // self.vae_spatial_compression_ratio + num_frames = ( + (num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2) + ) + + shape = (batch_size, num_channels_latents, num_frames, height, width) + mask_shape = (batch_size, 1, num_frames, height, width) + + if latents is not None: + conditioning_mask = latents.new_zeros(shape) + conditioning_mask[:, :, 0] = 1.0 + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + return latents.to(device=device, dtype=dtype), conditioning_mask + + if isinstance(generator, list): + if len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + init_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i]) + for i in range(batch_size) + ] + else: + init_latents = [ + retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator) for img in image + ] + + init_latents = torch.cat(init_latents, dim=0).to(dtype) + init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std) + init_latents = init_latents.repeat(1, 1, num_frames, 1, 1) + conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype) + conditioning_mask[:, :, 0] = 1.0 + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask) + + conditioning_mask = self._pack_latents( + conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ).squeeze(-1) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + return latents, conditioning_mask + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: PipelineImageInput = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 704, + num_frames: int = 161, + frame_rate: int = 25, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 3, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 128, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `512`): + The height in pixels of the generated image. This is set to 480 by default for the best results. + width (`int`, defaults to `704`): + The width in pixels of the generated image. This is set to 848 by default for the best results. + num_frames (`int`, defaults to `161`): + The number of video frames to generate + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, defaults to `3 `): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `128 `): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.ltx.LTXPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Prepare text embeddings + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + device=device, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + + # 4. Prepare latent variables + if latents is None: + image = self.video_processor.preprocess(image, height=height, width=width) + image = image.to(device=device, dtype=prompt_embeds.dtype) + + num_channels_latents = self.transformer.config.in_channels + latents, conditioning_mask = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + if self.do_classifier_free_guidance: + conditioning_mask = torch.cat([conditioning_mask, conditioning_mask]) + + # 5. Prepare timesteps + latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1 + latent_height = height // self.vae_spatial_compression_ratio + latent_width = width // self.vae_spatial_compression_ratio + video_sequence_length = latent_num_frames * latent_height * latent_width + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + mu = calculate_shift( + video_sequence_length, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. Prepare micro-conditions + latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio + rope_interpolation_scale = ( + 1 / latent_frame_rate, + self.vae_spatial_compression_ratio, + self.vae_spatial_compression_ratio, + ) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = latent_model_input.to(prompt_embeds.dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + encoder_attention_mask=prompt_attention_mask, + num_frames=latent_num_frames, + height=latent_height, + width=latent_width, + rope_interpolation_scale=rope_interpolation_scale, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + timestep, _ = timestep.chunk(2) + + # compute the previous noisy sample x_t -> x_t-1 + noise_pred = self._unpack_latents( + noise_pred, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + + noise_pred = noise_pred[:, :, 1:] + noise_latents = latents[:, :, 1:] + pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0] + + latents = torch.cat([latents[:, :, :1], pred_latents], dim=2) + latents = self._pack_latents( + latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size + ) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + video = latents + else: + latents = self._unpack_latents( + latents, + latent_num_frames, + latent_height, + latent_width, + self.transformer_spatial_patch_size, + self.transformer_temporal_patch_size, + ) + latents = self._denormalize_latents( + latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor + ) + latents = latents.to(prompt_embeds.dtype) + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return LTXPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/ltx/pipeline_output.py b/src/diffusers/pipelines/ltx/pipeline_output.py new file mode 100644 index 000000000000..36ec3ea884a2 --- /dev/null +++ b/src/diffusers/pipelines/ltx/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class LTXPipelineOutput(BaseOutput): + r""" + Output class for LTX pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 91264e805a0f..6ddd9ac23009 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -75,6 +75,7 @@ def __init__( base_image_seq_len: Optional[int] = 256, max_image_seq_len: Optional[int] = 4096, invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, use_karras_sigmas: Optional[bool] = False, use_exponential_sigmas: Optional[bool] = False, use_beta_sigmas: Optional[bool] = False, @@ -181,6 +182,27 @@ def _sigma_to_t(self, sigma): def time_shift(self, mu: float, sigma: float, t: torch.Tensor): return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + r""" + Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + def set_timesteps( self, num_inference_steps: int = None, @@ -216,6 +238,9 @@ def set_timesteps( else: sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + if self.config.use_karras_sigmas: sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 3f09b90f6b69..1c3a6123a469 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -107,6 +107,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLLTXVideo(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLMochi(metaclass=DummyObject): _backends = ["torch"] @@ -407,6 +422,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class LTXVideoTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LuminaNextDiT2DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 4fbdc2a83573..55a2a3df7572 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1067,6 +1067,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class LTXImageToVideoPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class LTXPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LuminaText2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py new file mode 100644 index 000000000000..128bf04155e7 --- /dev/null +++ b/tests/models/transformers/test_models_transformer_ltx.py @@ -0,0 +1,83 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import LTXVideoTransformer3DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class LTXTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = LTXVideoTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + num_frames = 2 + height = 16 + width = 16 + embedding_dim = 16 + sequence_length = 16 + + hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + "encoder_attention_mask": encoder_attention_mask, + "num_frames": num_frames, + "height": height, + "width": width, + } + + @property + def input_shape(self): + return (512, 4) + + @property + def output_shape(self): + return (512, 4) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 4, + "out_channels": 4, + "num_attention_heads": 2, + "attention_head_dim": 8, + "cross_attention_dim": 16, + "num_layers": 1, + "qk_norm": "rms_norm_across_heads", + "caption_channels": 16, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"LTXVideoTransformer3DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/ltx/__init__.py b/tests/pipelines/ltx/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py new file mode 100644 index 000000000000..0f9819bfd6d8 --- /dev/null +++ b/tests/pipelines/ltx/test_ltx.py @@ -0,0 +1,256 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTXPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = LTXVideoTransformer3DModel( + in_channels=8, + out_channels=8, + patch_size=1, + patch_size_t=1, + num_attention_heads=4, + attention_head_dim=8, + cross_attention_dim=32, + num_layers=1, + caption_channels=32, + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTXVideo( + latent_channels=8, + block_out_channels=(8, 8, 8, 8), + spatio_temporal_scaling=(True, True, False, False), + layers_per_block=(1, 1, 1, 1, 1), + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "height": 32, + "width": 32, + # 8 * k + 1 is the recommendation + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 32, 32)) + expected_video = torch.randn(9, 3, 32, 32) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) diff --git a/tests/pipelines/ltx/test_ltx_image2video.py b/tests/pipelines/ltx/test_ltx_image2video.py new file mode 100644 index 000000000000..40397e4c3619 --- /dev/null +++ b/tests/pipelines/ltx/test_ltx_image2video.py @@ -0,0 +1,264 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5EncoderModel + +from diffusers import ( + AutoencoderKLLTXVideo, + FlowMatchEulerDiscreteScheduler, + LTXImageToVideoPipeline, + LTXVideoTransformer3DModel, +) +from diffusers.utils.testing_utils import enable_full_determinism, torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class LTXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = LTXImageToVideoPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = LTXVideoTransformer3DModel( + in_channels=8, + out_channels=8, + patch_size=1, + patch_size_t=1, + num_attention_heads=4, + attention_head_dim=8, + cross_attention_dim=32, + num_layers=1, + caption_channels=32, + ) + + torch.manual_seed(0) + vae = AutoencoderKLLTXVideo( + latent_channels=8, + block_out_channels=(8, 8, 8, 8), + spatio_temporal_scaling=(True, True, False, False), + layers_per_block=(1, 1, 1, 1, 1), + patch_size=1, + patch_size_t=1, + encoder_causal=True, + decoder_causal=False, + ) + vae.use_framewise_encoding = False + vae.use_framewise_decoding = False + + torch.manual_seed(0) + scheduler = FlowMatchEulerDiscreteScheduler() + text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5") + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + image = torch.randn((1, 3, 32, 32), generator=generator, device=device) + + inputs = { + "image": image, + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 3.0, + "height": 32, + "width": 32, + # 8 * k + 1 is the recommendation + "num_frames": 9, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (9, 3, 32, 32)) + expected_video = torch.randn(9, 3, 32, 32) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + )