diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py
index 09682a218d91..4bf00f749f25 100644
--- a/src/diffusers/models/transformers/transformer_ernie_image.py
+++ b/src/diffusers/models/transformers/transformer_ernie_image.py
@@ -44,7 +44,7 @@ class ErnieImageTransformer2DModelOutput(BaseOutput):
 
 def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
     assert dim % 2 == 0
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+    scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
     omega = 1.0 / (theta**scale)
     out = torch.einsum("...n,d->...nd", pos, omega)
     return out.float()
@@ -400,8 +400,8 @@ def forward(
         ]
 
         # AdaLN
-        sample = self.time_proj(timestep.to(dtype))
-        sample = sample.to(self.time_embedding.linear_1.weight.dtype)
+        sample = self.time_proj(timestep)
+        sample = sample.to(dtype=dtype)
         c = self.time_embedding(sample)
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
             t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)