Skip to content

Conversation

@lawrence-cj
Copy link
Contributor

@lawrence-cj lawrence-cj commented Nov 26, 2025

This PR supports LongSANA: a minute-length real-time video generation model

Related links:

project: https://nvlabs.github.io/Sana/Video
code: https://github.com/NVlabs/Sana
paper: https://arxiv.org/pdf/2509.24695

PR feature:

LongSANA uses Causal Linear Attention KV Cache during inference, which is crucial for long video generation(FlashAttention may need other PR). This PR adds Causal computation logi for both Linear Attention and Mix-FFN (Conv in MLP)

Added classes and functions

  1. add SanaVideoCausalTransformerBlock and SanaVideoCausalTransformer3DModel;
  2. add LongSanaVideoPipeline for Linear Attention KV-Cache;
  3. support LongSANA converting from pth to diffusers safetensor;

Cc: @sayakpaul @dg845
Co-author: @HeliosZhao

Code snap:

from diffusers import LongSanaVideoPipeline
from diffusers.utils import export_to_video

pipe = LongSanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_LongLive_diffusers", torch_dtype=torch.bfloat16)

pipe.scheduler = FlowMatchEulerDiscreteScheduler()
pipe.vae.to(torch.float32)
pipe.text_encoder.to(torch.bfloat16)
pipe.to("cuda")

prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    frames=161,
    guidance_scale=1.0,
    timesteps=[1000, 960, 889, 727, 0],  # Multi-step denoising per chunk
    generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]
export_to_video(video, "longsana.mp4", fps=16)

lawrence-cj and others added 4 commits November 26, 2025 07:32
1. add `SanaVideoCausalTransformerBlock` and `SanaVideoCausalTransformer3DModel`;
2. add `LongSanaVideoPipeline` for Linear Attention KV-Cache;
3. support LongSANA converting from pth to diffusers safetensor;
Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com>
@sayakpaul sayakpaul requested review from DN6, dg845 and yiyixuxu November 26, 2025 15:48
@sayakpaul
Copy link
Member

FlashAttention may need other PR

We can actually leverage our attention backends:
https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends

@lawrence-cj
Copy link
Contributor Author

FlashAttention may need other PR

We can actually leverage our attention backends: https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends

Is KV cache is supported in any backends?
Actually, in my PR, the kv-cache part is not well organized. So we do need your kind help to do it better to match diffusers style.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants