From c3c8e06aeb9c75c7e14a5d40224ae6d88fc39e8d Mon Sep 17 00:00:00 2001
From: DavidBert
Date: Fri, 14 Nov 2025 11:47:18 +0000
Subject: [PATCH] rope in float32

---
 src/diffusers/models/transformers/transformer_prx.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_prx.py b/src/diffusers/models/transformers/transformer_prx.py
index 9b2664b9cb26..ccbc83ffca03 100644
--- a/src/diffusers/models/transformers/transformer_prx.py
+++ b/src/diffusers/models/transformers/transformer_prx.py
@@ -275,7 +275,12 @@ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
 
     def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
         assert dim % 2 == 0
-        scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+
+        is_mps = pos.device.type == "mps"
+        is_npu = pos.device.type == "npu"
+        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+
+        scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
         omega = 1.0 / (theta**scale)
         out = pos.unsqueeze(-1) * omega.unsqueeze(0)
         out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
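
Note: the PyTorch MPS (Apple Silicon) backend does not support torch.float64,
and per this patch the same holds for the NPU backend, so the RoPE frequency
table falls back to float32 on those devices while CPU/CUDA keep the float64
path. Below is a minimal standalone sketch of the pattern the patch applies;
the helper name rope_frequencies and the theta value are illustrative only,
not part of the patched module:

    import torch

    def rope_frequencies(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
        # dim must be even: frequencies come in cos/sin pairs
        assert dim % 2 == 0
        # float64 is unavailable on mps/npu; fall back to float32 there
        dtype = torch.float32 if pos.device.type in ("mps", "npu") else torch.float64
        scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
        omega = 1.0 / (theta**scale)  # inverse frequencies, as in the patch
        # outer product of positions and frequencies -> angle table
        return pos.unsqueeze(-1) * omega.unsqueeze(0)

    # Example usage (theta=10000 is a common RoPE base, chosen here only
    # for illustration). On CPU/CUDA the result keeps float64 precision:
    angles = rope_frequencies(torch.arange(8, dtype=torch.float32), dim=16, theta=10000)
    print(angles.dtype)  # torch.float64 on cpu/cuda, torch.float32 on mps/npu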