From 747aa5974c5ce105fe1cb88841664997ab30259c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 4 Feb 2023 14:44:45 +1100 Subject: [PATCH] Fix k_dpm_2 & k_dpm_2_a on MPS Needed to convert `timesteps` to `float32` a bit sooner. Fixes #1537 --- .../scheduling_k_dpm_2_ancestral_discrete.py | 14 +++++++------- .../schedulers/scheduling_k_dpm_2_discrete.py | 13 ++++++------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 175f338b929e..711bdf2d5ef0 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -161,16 +161,16 @@ def set_timesteps( # standard deviation of the initial noise distribution self.init_noise_sigma = self.sigmas.max() - timesteps = torch.from_numpy(timesteps).to(device) - timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) - interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten() - timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) - if str(device).startswith("mps"): # mps does not support float64 - self.timesteps = timesteps.to(device, dtype=torch.float32) + timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) else: - self.timesteps = timesteps + timesteps = torch.from_numpy(timesteps).to(device) + + timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) + interleaved_timesteps = torch.stack((timesteps_interpol[:-2, None], timesteps[1:, None]), dim=-1).flatten() + + self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) self.sample = None diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 18dd97671636..a46cc060522c 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -149,18 +149,17 @@ def set_timesteps( # standard deviation of the initial noise distribution self.init_noise_sigma = self.sigmas.max() - timesteps = torch.from_numpy(timesteps).to(device) + if str(device).startswith("mps"): + # mps does not support float64 + timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + timesteps = torch.from_numpy(timesteps).to(device) # interpolate timesteps timesteps_interpol = self.sigma_to_t(sigmas_interpol).to(device) interleaved_timesteps = torch.stack((timesteps_interpol[1:-1, None], timesteps[1:, None]), dim=-1).flatten() - timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) - if str(device).startswith("mps"): - # mps does not support float64 - self.timesteps = timesteps.to(torch.float32) - else: - self.timesteps = timesteps + self.timesteps = torch.cat([timesteps[:1], interleaved_timesteps]) self.sample = None