src/diffusers/schedulers/scheduling_lms_discrete.py (7 additions, 3 deletions)
@@ -252,9 +252,13 @@ def add_noise(
     ) -> torch.FloatTensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
-        dtype = torch.float32 if original_samples.device.type == "mps" else timesteps.dtype
-        self.timesteps = self.timesteps.to(original_samples.device, dtype=dtype)
-        timesteps = timesteps.to(original_samples.device)
+        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+            # mps does not support float64
+            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+        else:
+            self.timesteps = self.timesteps.to(original_samples.device)
+            timesteps = timesteps.to(original_samples.device)

         schedule_timesteps = self.timesteps

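The hunk above replaces the unconditional dtype override with an explicit branch: on Apple's mps backend, which has no float64 support, both the stored schedule and the user-supplied timesteps are cast to float32; on every other device they keep their original dtype and are only moved to the sample's device. Below is a minimal sketch of how the fixed path would be exercised; the scheduler configuration and the mps availability check are illustrative assumptions, not part of this diff.

import torch
from diffusers import LMSDiscreteScheduler

# Illustrative configuration; any setup that produces float timesteps behaves the same.
scheduler = LMSDiscreteScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(100)

device = "mps" if torch.backends.mps.is_available() else "cpu"
sample = torch.randn(1, 4, 32, 32, device=device)
noise = torch.randn_like(sample)

# The LMS schedule timesteps are float64 on CPU; add_noise moves them to the
# sample's device and, on mps, downcasts them to float32 before the sigma lookup.
t = scheduler.timesteps[5][None]
noised = scheduler.add_noise(sample, noise, t)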
tests/test_scheduler.py (2 additions, 1 deletion)
@@ -266,13 +266,14 @@ def test_add_noise_device(self):
                 continue
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(100)

             sample = self.dummy_sample.to(torch_device)
             scaled_sample = scheduler.scale_model_input(sample, 0.0)
             self.assertEqual(sample.shape, scaled_sample.shape)

             noise = torch.randn_like(scaled_sample).to(torch_device)
-            t = torch.tensor([10]).to(torch_device)
+            t = scheduler.timesteps[5][None]
             noised = scheduler.add_noise(scaled_sample, noise, t)
             self.assertEqual(noised.shape, scaled_sample.shape)

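The test change follows from how the scheduler consumes timesteps: add_noise resolves sigmas by matching the passed timesteps against self.timesteps (the schedule_timesteps assignment shown in the first hunk), and that schedule only exists after set_timesteps has been called. A hard-coded torch.tensor([10]) is generally not one of the 100 float schedule values, so the test now pulls t from the schedule itself. A rough sketch of the lookup the new value has to satisfy, with an assumed scheduler configuration (index 5 is arbitrary):

from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler(num_train_timesteps=1000)  # assumed config, for illustration only
scheduler.set_timesteps(100)

t = scheduler.timesteps[5][None]                   # a real schedule value, kept as a 1-element tensor
idx = (scheduler.timesteps == t).nonzero().item()  # the exact match add_noise relies on -> 5
sigma = scheduler.sigmas[idx]                      # sigma used to scale the added noise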