Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def preprocess(image):
return 2.0 * image - 1.0


def posterior_sample(scheduler, latents, timestep, clean_latents, eta):
def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta):
# 1. get previous step value (=t-1)
prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps

Expand All @@ -62,7 +62,9 @@ def posterior_sample(scheduler, latents, timestep, clean_latents, eta):
# direction pointing to x_t
e_t = (latents - alpha_prod_t ** (0.5) * clean_latents) / (1 - alpha_prod_t) ** (0.5)
dir_xt = (1.0 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * e_t
noise = std_dev_t * torch.randn(clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device)
noise = std_dev_t * torch.randn(
clean_latents.shape, dtype=clean_latents.dtype, device=clean_latents.device, generator=generator
)
prev_latents = alpha_prod_t_prev ** (0.5) * clean_latents + dir_xt + noise

return prev_latents
Expand Down Expand Up @@ -499,7 +501,7 @@ def __call__(

# Sample source_latents from the posterior distribution.
prev_source_latents = posterior_sample(
self.scheduler, source_latents, t, clean_latents, **extra_step_kwargs
self.scheduler, source_latents, t, clean_latents, generator=generator, **extra_step_kwargs
)
# Compute noise.
noise = compute_noise(
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_ddim.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def step(

if eta > 0:
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu")
device = model_output.device
if variance_noise is not None and generator is not None:
raise ValueError(
"Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def step(

prev_sample = sample + derivative * dt

device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu")
device = model_output.device
if device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/schedulers/scheduling_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def step(

gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

device = model_output.device if torch.is_tensor(model_output) else torch.device("cpu")
device = model_output.device
if device.type == "mps":
# randn does not work reproducibly on mps
noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to(
Expand Down
8 changes: 5 additions & 3 deletions tests/pipelines/stable_diffusion/test_cycle_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def test_cycle_diffusion_pipeline_fp16(self):
source_prompt = "A black colored car"
prompt = "A blue colored car"

torch.manual_seed(0)
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
prompt=prompt,
source_prompt=source_prompt,
Expand All @@ -303,12 +303,13 @@ def test_cycle_diffusion_pipeline_fp16(self):
strength=0.85,
guidance_scale=3,
source_guidance_scale=1,
generator=generator,
output_type="np",
)
image = output.images

# the values aren't exactly equal, but the images look the same visually
assert np.abs(image - expected_image).max() < 1e-2
assert np.abs(image - expected_image).max() < 5e-1

def test_cycle_diffusion_pipeline(self):
init_image = load_image(
Expand All @@ -331,7 +332,7 @@ def test_cycle_diffusion_pipeline(self):
source_prompt = "A black colored car"
prompt = "A blue colored car"

torch.manual_seed(0)
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipe(
prompt=prompt,
source_prompt=source_prompt,
Expand All @@ -341,6 +342,7 @@ def test_cycle_diffusion_pipeline(self):
strength=0.85,
guidance_scale=3,
source_guidance_scale=1,
generator=generator,
output_type="np",
)
image = output.images
Expand Down
4 changes: 2 additions & 2 deletions tests/pipelines/stable_diffusion/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ def test_stable_diffusion_text2img_pipeline_fp16(self):

def test_stable_diffusion_text2img_pipeline_default(self):
expected_image = load_numpy(
"https://huggingface.co/datasets/lewington/expected-images/resolve/main/astronaut_riding_a_horse.npy"
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text2img/astronaut_riding_a_horse.npy"
)

model_id = "CompVis/stable-diffusion-v1-4"
Expand All @@ -771,7 +771,7 @@ def test_stable_diffusion_text2img_pipeline_default(self):
image = output.images[0]

assert image.shape == (512, 512, 3)
assert np.abs(expected_image - image).max() < 1e-3
assert np.abs(expected_image - image).max() < 5e-3

def test_stable_diffusion_text2img_intermediate_state(self):
number_of_steps = 0
Expand Down
7 changes: 4 additions & 3 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,8 @@ def test_from_pretrained_hub_pass_model(self):
def test_output_format(self):
model_path = "google/ddpm-cifar10-32"

pipe = DDIMPipeline.from_pretrained(model_path)
scheduler = DDIMScheduler.from_config(model_path)
pipe = DDIMPipeline.from_pretrained(model_path, scheduler=scheduler)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

Expand All @@ -451,13 +452,13 @@ def test_output_format(self):
assert images.shape == (1, 32, 32, 3)
assert isinstance(images, np.ndarray)

images = pipe(generator=generator, output_type="pil").images
images = pipe(generator=generator, output_type="pil", num_inference_steps=4).images
assert isinstance(images, list)
assert len(images) == 1
assert isinstance(images[0], PIL.Image.Image)

# use PIL by default
images = pipe(generator=generator).images
images = pipe(generator=generator, num_inference_steps=4).images
assert isinstance(images, list)
assert isinstance(images[0], PIL.Image.Image)

Expand Down
33 changes: 21 additions & 12 deletions tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,10 +1281,11 @@ def test_full_loop_no_noise(self):

scheduler.set_timesteps(self.num_inference_steps)

generator = torch.Generator().manual_seed(0)
generator = torch.Generator(torch_device).manual_seed(0)

model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
sample = sample.to(torch_device)

for i, t in enumerate(scheduler.timesteps):
sample = scheduler.scale_model_input(sample, t)
Expand All @@ -1296,7 +1297,6 @@ def test_full_loop_no_noise(self):

result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
print(result_sum, result_mean)

assert abs(result_sum.item() - 10.0807) < 1e-2
assert abs(result_mean.item() - 0.0131) < 1e-3
Expand All @@ -1308,7 +1308,7 @@ def test_full_loop_device(self):

scheduler.set_timesteps(self.num_inference_steps, device=torch_device)

generator = torch.Generator().manual_seed(0)
generator = torch.Generator(torch_device).manual_seed(0)

model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
Expand All @@ -1324,7 +1324,6 @@ def test_full_loop_device(self):

result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
print(result_sum, result_mean)

assert abs(result_sum.item() - 10.0807) < 1e-2
assert abs(result_mean.item() - 0.0131) < 1e-3
Expand Down Expand Up @@ -1365,10 +1364,11 @@ def test_full_loop_no_noise(self):

scheduler.set_timesteps(self.num_inference_steps)

generator = torch.Generator().manual_seed(0)
generator = torch.Generator(device=torch_device).manual_seed(0)

model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
sample = sample.to(torch_device)

for i, t in enumerate(scheduler.timesteps):
sample = scheduler.scale_model_input(sample, t)
Expand All @@ -1380,9 +1380,14 @@ def test_full_loop_no_noise(self):

result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
print(result_sum, result_mean)
assert abs(result_sum.item() - 152.3192) < 1e-2
assert abs(result_mean.item() - 0.1983) < 1e-3

if str(torch_device).startswith("cpu"):
assert abs(result_sum.item() - 152.3192) < 1e-2
assert abs(result_mean.item() - 0.1983) < 1e-3
else:
# CUDA
assert abs(result_sum.item() - 144.8084) < 1e-2
assert abs(result_mean.item() - 0.18855) < 1e-3

def test_full_loop_device(self):
scheduler_class = self.scheduler_classes[0]
Expand All @@ -1391,7 +1396,7 @@ def test_full_loop_device(self):

scheduler.set_timesteps(self.num_inference_steps, device=torch_device)

generator = torch.Generator().manual_seed(0)
generator = torch.Generator(device=torch_device).manual_seed(0)

model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
Expand All @@ -1407,14 +1412,18 @@ def test_full_loop_device(self):

result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
print(result_sum, result_mean)
if not str(torch_device).startswith("mps"):

if str(torch_device).startswith("cpu"):
# The following sum varies between 148 and 156 on mps. Why?
assert abs(result_sum.item() - 152.3192) < 1e-2
assert abs(result_mean.item() - 0.1983) < 1e-3
else:
elif str(torch_device).startswith("mps"):
# Larger tolerance on mps
assert abs(result_mean.item() - 0.1983) < 1e-2
else:
# CUDA
assert abs(result_sum.item() - 144.8084) < 1e-2
assert abs(result_mean.item() - 0.18855) < 1e-3


class IPNDMSchedulerTest(SchedulerCommonTest):
Expand Down