diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py
index 316e79da4fd6..b2b5baff7d95 100644
--- a/src/diffusers/models/transformers/transformer_kandinsky.py
+++ b/src/diffusers/models/transformers/transformer_kandinsky.py
@@ -165,10 +165,12 @@ def __init__(self, model_dim, time_dim, max_period=10000.0):
         self.activation = nn.SiLU()
         self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
 
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, time):
-        args = torch.outer(time, self.freqs.to(device=time.device))
+        time = time.to(dtype=torch.float32)
+        freqs = self.freqs.to(device=time.device, dtype=torch.float32)
+        args = torch.outer(time, freqs)
         time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
         time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
         return time_embed
 
@@ -269,8 +271,8 @@ def __init__(self, time_dim, model_dim, num_params):
         self.out_layer.weight.data.zero_()
         self.out_layer.bias.data.zero_()
 
-    @torch.autocast(device_type="cuda", dtype=torch.float32)
     def forward(self, x):
+        x = x.to(dtype=self.out_layer.weight.dtype)
        return self.out_layer(self.activation(x))
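
The pattern being applied here: the `@torch.autocast(device_type="cuda", ...)` decorator forces the embedding math into float32 only when running on CUDA, so on CPU/MPS (or under a non-CUDA autocast context) the sinusoidal features could be computed in a low-precision dtype. The replacement casts explicitly: do the trig in float32, then cast the result to the dtype of the projection layers. Below is a minimal, self-contained sketch of that pattern, assuming an illustrative `SinusoidalTimeEmbed` module with a made-up frequency buffer and layer sizes; it is not the actual Kandinsky class, just a demonstration of the autocast-free approach.

```python
import torch
import torch.nn as nn


class SinusoidalTimeEmbed(nn.Module):
    """Illustrative sketch: sinusoidal timestep embedding without @torch.autocast."""

    def __init__(self, model_dim, time_dim, max_period=10000.0):
        super().__init__()
        half_dim = model_dim // 2
        # Fixed float32 frequencies; a buffer so they follow .to(device) with the module.
        freqs = torch.exp(
            -torch.log(torch.tensor(max_period))
            * torch.arange(half_dim, dtype=torch.float32)
            / half_dim
        )
        self.register_buffer("freqs", freqs, persistent=False)
        self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
        self.activation = nn.SiLU()
        self.out_layer = nn.Linear(time_dim, time_dim, bias=True)

    def forward(self, time):
        # Do the trig math in float32 regardless of model dtype or device.
        time = time.to(dtype=torch.float32)
        freqs = self.freqs.to(device=time.device, dtype=torch.float32)
        args = torch.outer(time, freqs)
        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        # Cast back to the layer dtype (e.g. bfloat16) before the projections.
        time_embed = time_embed.to(dtype=self.in_layer.weight.dtype)
        return self.out_layer(self.activation(self.in_layer(time_embed)))


# Behaves the same on CPU, CUDA, or MPS, with the module in float32 or bfloat16:
embed = SinusoidalTimeEmbed(model_dim=64, time_dim=128).to(torch.bfloat16)
out = embed(torch.tensor([0.0, 250.0, 999.0]))
print(out.dtype, out.shape)  # torch.bfloat16 torch.Size([3, 128])
```

The same idea applies to the second hunk: instead of relying on a CUDA-only autocast to reconcile dtypes, the input is cast to the weight dtype of the layer it is about to enter.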