Describe the bug
Just checking whether this is the expected behavior: calling WanTransformer3DModel.from_pretrained with torch_dtype=torch.bfloat16 leaves some parameters (e.g. the block norm biases and scale_shift_table) in float32, while others (e.g. the attention weights) are cast to bfloat16 as requested.
Reproduction
import torch
from diffusers import WanTransformer3DModel

repo_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
transformer = WanTransformer3DModel.from_pretrained(
    repo_id, subfolder="transformer", torch_dtype=torch.bfloat16
)

# Some parameters keep their original float32 dtype...
print(transformer.blocks[0].norm2.bias.dtype)  # torch.float32
print(transformer.blocks[0].scale_shift_table.dtype)  # torch.float32
# ...while others are cast to bfloat16 as requested.
print(transformer.blocks[0].attn1.norm_k.weight.dtype)  # torch.bfloat16
print(transformer.blocks[0].attn1.to_k.weight.dtype)  # torch.bfloat16
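In case it helps triage, here is a small sketch (plain PyTorch only, nothing assumed about diffusers internals) that lists every parameter left in float32 and confirms that an explicit module-wide cast does convert them:

# List every parameter that stayed in float32 after loading.
fp32_params = [n for n, p in transformer.named_parameters() if p.dtype == torch.float32]
print(fp32_params)  # includes blocks.0.norm2.bias, blocks.0.scale_shift_table, ...

# An explicit cast converts them, so they appear to be skipped only
# by the from_pretrained torch_dtype path.
transformer.to(torch.bfloat16)
print(transformer.blocks[0].norm2.bias.dtype)  # torch.bfloat16

If these parameters are being kept in fp32 on purpose (for numerical stability, e.g. via something like _keep_in_fp32_modules on the model class), feel free to close this.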
System Info
Diffusers at commit 812b4e1eaa20fa8d88aa48b645b9d34ca45ecfde (2025-03-06), Linux, Python 3.10
Who can help?
cc @DN6 @a-r-r-o-w