
Conversation

@sywangyi
Contributor

Fix the corrupted output when Ulysses is enabled with native attention. Native attention is widely used on non-CUDA platforms, and Ulysses does not rely on the LSE (log-sum-exp), so Ulysses attention can still be used on the native attention path.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
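For intuition, here is a minimal single-process sketch (illustrative only; it is not the diffusers implementation and uses no real collectives) of why the Ulysses layout composes with plain SDPA without any LSE bookkeeping: after the all-to-all exchange, each rank holds the full sequence for a subset of heads, so its local softmax is already exact and the outputs only need to be re-sharded, never numerically merged.

import torch
import torch.nn.functional as F

# Emulate world_size "ranks" in one process; the real implementation moves the
# same data with all-to-all collectives.
world_size = 4
b, h, s, d = 1, 8, 32, 64  # heads and sequence length divisible by world_size
torch.manual_seed(0)
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

# Sequence-sharded inputs, as each rank would hold them before attention
q_shards = list(q.chunk(world_size, dim=2))
k_shards = list(k.chunk(world_size, dim=2))
v_shards = list(v.chunk(world_size, dim=2))

def seq_shard_to_head_shard(shards):
    # Emulated all-to-all: "rank" r gathers every rank's sequence chunk but
    # keeps only its own group of h // world_size heads.
    return [
        torch.cat([shard.chunk(world_size, dim=1)[r] for shard in shards], dim=2)
        for r in range(world_size)
    ]

q_full = seq_shard_to_head_shard(q_shards)
k_full = seq_shard_to_head_shard(k_shards)
v_full = seq_shard_to_head_shard(v_shards)

# Each "rank" runs plain SDPA over the full sequence for its head subset.
# The result is exact, so no log-sum-exp is needed to combine anything later.
outs = [
    F.scaled_dot_product_attention(qr, kr, vr)
    for qr, kr, vr in zip(q_full, k_full, v_full)
]

# The inverse all-to-all only re-shards; concatenating the head groups
# recovers the result of full attention.
merged = torch.cat(outs, dim=1)
ref = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(merged, ref, atol=1e-5))  # True

This is why the native SDPA backend can back Ulysses directly, whereas ring attention (discussed further down) genuinely needs the per-block LSE to merge partial results.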
@sywangyi
Contributor Author

import random

import numpy as np
import torch
import torch.distributed as dist

from diffusers import AutoencoderKLWan, ContextParallelConfig, WanPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video


model_id = "Wan-AI/Wan2.2-T2V-A14B-Diffusers"

def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{dist.get_rank()}")
    torch.cuda.set_device(device)
    return device


def set_seed_for_all_ranks(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    generator = torch.Generator(device="cuda")
    generator.manual_seed(seed)
    return generator


device = setup_distributed()
generator = set_seed_for_all_ranks(42)
onload_device = device
offload_device = torch.device("cpu")

vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
# group offloading (apply_group_offloading with onload_device/offload_device above) could be added here; not applied in this repro
pipe = WanPipeline.from_pretrained(
    model_id,
    vae=vae,
    torch_dtype=torch.bfloat16,
)
ulysses_degree = torch.distributed.get_world_size()
pipe.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=ulysses_degree))
pipe.transformer_2.enable_parallelism(config=ContextParallelConfig(ulysses_degree=ulysses_degree))
pipe.to("cuda")

pipe.vae.enable_tiling(
    tile_sample_min_height=480,
    tile_sample_min_width=960,
    tile_sample_stride_height=352,
    tile_sample_stride_width=640,
)
height = 704
width = 1280
num_frames = 24
num_inference_steps = 50
guidance_scale = 5.0


prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
    generator=generator,
).frames[0]
if torch.distributed.get_rank() == 0:
    export_to_video(output, "5bit2v_output_bad.mp4", fps=24)
if dist.is_initialized():
    torch.distributed.destroy_process_group()
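To be certain the run exercises the native SDPA path rather than an optional flash/sage backend, recent diffusers versions expose a per-model backend switch; assuming set_attention_backend is available in the installed version, something like this pins it:

# assumes diffusers' set_attention_backend helper is available on the model
pipe.transformer.set_attention_backend("native")
pipe.transformer_2.set_attention_backend("native")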

@sywangyi
Contributor Author

torchrun --nproc-per-node 4 test.py
corrupted video output: [image]
video output after fix: [image]

@sayakpaul sayakpaul requested a review from DN6 October 30, 2025 12:14
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
@DefTruth
Contributor

Also worked for me, thanks

enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
if _parallel_config is None:
Contributor

Here, 'supports_context_parallel=True' should also be added to the register call. @sywangyi @sayakpaul

Contributor Author

thanks, added

@sayakpaul
Member
Oct 31, 2025

@DefTruth that should also fix #12446 (comment) right? Could you give this a check?

Contributor

@DefTruth that should also fix #12446 (comment) right? Could you give this a check?

Confirmed fixed.

sywangyi and others added 2 commits November 2, 2025 21:36
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
@yiyixuxu
Collaborator

@yiyixuxu left a comment

thanks!

@yiyixuxu yiyixuxu merged commit 1ec28a2 into huggingface:main Nov 3, 2025
8 of 11 checks passed
@sayakpaul
Member

@sywangyi can we enable it for Ring x Native, too? I don't see ring's reliance on lse, either.

@sywangyi
Contributor Author

sywangyi commented Nov 4, 2025

@sywangyi can we enable it for Ring x Native, too? I don't see ring's reliance on lse, either.

No, ring attention needs the LSE to guarantee precision; see https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_dispatch.py#L1062-L1065
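To make that dependence concrete, here is a standalone sketch of the merge step (illustrative; the linked attention_dispatch.py code performs the equivalent update once per ring step): attention computed over a single KV block is only normalized by that block's local softmax sum, and the log-sum-exp is exactly what is needed to reweight and combine the partial outputs into the exact full-attention result.

import torch
import torch.nn.functional as F

def partial_attention(q, k, v):
    # Attention over one KV block, returning the output and its log-sum-exp
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1, keepdim=True)
    out = torch.softmax(scores, dim=-1) @ v
    return out, lse

def merge(prev_out, prev_lse, out, lse):
    # Numerically stable combination of two partial results; without the LSE
    # the relative weighting of the two blocks would be lost.
    max_lse = torch.maximum(prev_lse, lse)
    w_prev = torch.exp(prev_lse - max_lse)
    w_new = torch.exp(lse - max_lse)
    merged_out = (prev_out * w_prev + out * w_new) / (w_prev + w_new)
    merged_lse = max_lse + torch.log(w_prev + w_new)
    return merged_out, merged_lse

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 32, 64) for _ in range(3))
# Treat the two halves of the KV sequence as two "ring steps"
out1, lse1 = partial_attention(q, k[..., :16, :], v[..., :16, :])
out2, lse2 = partial_attention(q, k[..., 16:, :], v[..., 16:, :])
merged, _ = merge(out1, lse1, out2, lse2)
ref = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(merged, ref, atol=1e-5))  # True, but only because the LSE was kept

Ulysses never has to do this merge because each rank's attention already covers the full sequence, which is why it works with the plain native backend even though that backend does not return an LSE.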

@sayakpaul
Member

Yeah, but it's always None, no?

prev_out = prev_lse = None

@sywangyi
Contributor Author

sywangyi commented Nov 4, 2025

No, see:

prev_out = out
prev_lse = lse

@sayakpaul
Member

Ah sorry for the oversight. Thanks for clarifying.
