Loading WanTransformer3DModel using torch_dtype=torch.bfloat16 keeps some parameters as float32 #10992

Description

@spezialspezial

Describe the bug

Just checking whether this is expected behavior: calling WanTransformer3DModel.from_pretrained with torch_dtype=torch.bfloat16 leaves some parameters in float32.

Reproduction

import torch
from diffusers import WanTransformer3DModel

repo_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
transformer = WanTransformer3DModel.from_pretrained(repo_id, subfolder="transformer", torch_dtype=torch.bfloat16)

# Norm biases and the scale/shift table stay in float32 ...
print(transformer.blocks[0].norm2.bias.dtype)  # torch.float32
print(transformer.blocks[0].scale_shift_table.dtype)  # torch.float32
# ... while attention weights are cast to bfloat16 as requested
print(transformer.blocks[0].attn1.norm_k.weight.dtype)  # torch.bfloat16
print(transformer.blocks[0].attn1.to_k.weight.dtype)  # torch.bfloat16
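
If it helps triage: diffusers' ModelMixin supports a _keep_in_fp32_modules class attribute that pins the listed modules to float32 even when a lower torch_dtype is requested, which could explain the mixed dtypes. A minimal check, assuming that mechanism is what is in play here (and reusing the transformer loaded above):

from collections import defaultdict

# Modules/parameters the class pins to float32 via _keep_in_fp32_modules, if any
print(getattr(WanTransformer3DModel, "_keep_in_fp32_modules", None))

# Group parameter names by dtype to see exactly which ones stayed in float32
params_by_dtype = defaultdict(list)
for name, param in transformer.named_parameters():
    params_by_dtype[param.dtype].append(name)
for dtype, names in params_by_dtype.items():
    print(dtype, len(names), names[:5])

If the float32 parameters line up with that attribute, this is intentional casting behavior rather than a bug.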

Logs

System Info

Diffusers at commit 812b4e1eaa20fa8d88aa48b645b9d34ca45ecfde (2025-03-06), Linux, Python 3.10

Who can help?

Calling @DN6 @a-r-r-o-w
