Closed
Labels: bug (Something isn't working), stale (Issues that haven't received updates)
Description
Describe the bug
I was adjusting the scheduler configs and I've come across a few potential issues.
Edit: issue 1 below is a duplicate of #5628.
1. denoising_start and denoising_end are computed differently for different schedulers
Repro:
```python
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to("cuda")

# Define how many steps and what fraction of steps to run on each expert (80/20)
n_steps = 40
high_noise_frac = 0.8
prompt = "A majestic lion jumping from a big stone at night"


def run_both_experts():
    # Base expert: runs the first high_noise_frac of the schedule and returns latents
    latents = base(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_end=high_noise_frac,
        output_type="latent",
    ).images
    # Refiner expert: starts at high_noise_frac and finishes denoising
    return refiner(
        prompt=prompt,
        num_inference_steps=n_steps,
        denoising_start=high_noise_frac,
        image=latents,
    ).images[0]


# Case 1: default scheduler (EulerDiscreteScheduler)
image = run_both_experts()

# Case 2: DPMSolverMultistepScheduler
base.scheduler = DPMSolverMultistepScheduler.from_config(base.scheduler.config)
refiner.scheduler = DPMSolverMultistepScheduler.from_config(refiner.scheduler.config)
image = run_both_experts()

# Case 3: DPMSolverMultistepScheduler with Karras sigmas
base.scheduler = DPMSolverMultistepScheduler.from_config(
    base.scheduler.config, use_karras_sigmas=True
)
refiner.scheduler = DPMSolverMultistepScheduler.from_config(
    refiner.scheduler.config, use_karras_sigmas=True
)
image = run_both_experts()
```
This does:
- a 32/8 split with the default EulerDiscreteScheduler, which matches the expected high_noise_frac=0.8
- a 32/8 split with DPMSolverMultistepScheduler (still good)
- a 25/15 split with DPMSolverMultistepScheduler + use_karras_sigmas=True, which is not expected
I'm not sure whether this is intended or a bug. It might be due to the calculation around here?
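To illustrate how the same denoising_end can yield a different split depending on the timestep schedule, here is a minimal pure-Python sketch. It rests on several assumptions not taken from this issue: SDXL's scaled_linear betas (beta_start=0.00085, beta_end=0.012, 1000 training timesteps), "leading" timestep spacing with steps_offset=1 for the default case, a Karras sigma ramp with rho=7 mapped back to timesteps by interpolating in log-sigma space, and that the base model keeps only timesteps at or above the cutoff. All helper names are hypothetical, not diffusers APIs.

```python
import math

NUM_TRAIN_TIMESTEPS = 1000
BETA_START, BETA_END = 0.00085, 0.012  # assumed SDXL "scaled_linear" betas


def train_sigmas():
    """sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t) for t = 0..999 (increasing)."""
    n = NUM_TRAIN_TIMESTEPS
    lo, hi = math.sqrt(BETA_START), math.sqrt(BETA_END)
    sigmas, alpha_bar = [], 1.0
    for t in range(n):
        beta = (lo + (hi - lo) * t / (n - 1)) ** 2
        alpha_bar *= 1.0 - beta
        sigmas.append(math.sqrt((1.0 - alpha_bar) / alpha_bar))
    return sigmas


def leading_timesteps(num_steps, steps_offset=1):
    # Assumed "leading" spacing: evenly spaced integer timesteps, highest first
    ratio = NUM_TRAIN_TIMESTEPS // num_steps
    return [i * ratio + steps_offset for i in reversed(range(num_steps))]


def sigma_to_t(log_sigma, log_sigmas):
    # Linear interpolation in log-sigma space; log_sigmas is increasing in t
    for t in range(len(log_sigmas) - 1):
        if log_sigmas[t] <= log_sigma <= log_sigmas[t + 1]:
            w = (log_sigma - log_sigmas[t]) / (log_sigmas[t + 1] - log_sigmas[t])
            return t + w
    # Clamp values that fall just outside the table due to rounding
    return 0.0 if log_sigma < log_sigmas[0] else float(len(log_sigmas) - 1)


def karras_timesteps(num_steps, sigmas, rho=7.0):
    # Karras-style sigma ramp from sigma_max down to sigma_min, mapped to timesteps
    lo, hi = sigmas[0] ** (1 / rho), sigmas[-1] ** (1 / rho)
    log_sigmas = [math.log(s) for s in sigmas]
    out = []
    for i in range(num_steps):
        ramp = i / (num_steps - 1)
        sigma = (hi + ramp * (lo - hi)) ** rho
        out.append(round(sigma_to_t(math.log(sigma), log_sigmas)))
    return out


def split(timesteps, denoising_end):
    # Assumption: base keeps timesteps >= cutoff, the refiner gets the rest
    cutoff = int(round(NUM_TRAIN_TIMESTEPS - denoising_end * NUM_TRAIN_TIMESTEPS))
    base_steps = sum(1 for t in timesteps if t >= cutoff)
    return base_steps, len(timesteps) - base_steps


if __name__ == "__main__":
    print("leading:", split(leading_timesteps(40), 0.8))
    print("karras: ", split(karras_timesteps(40, train_sigmas()), 0.8))
```

In this sketch the leading schedule reproduces the expected 32/8 split, while the Karras ramp, which packs more of its steps at low noise levels, places fewer of the 40 timesteps above the fixed cutoff of 200, so the base model gets fewer steps than denoising_end=0.8 suggests.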
2. The denoising_start valid check in StableDiffusionXLImg2ImgPipeline.__call__
See here.
The condition if denoising_value_valid else None is always true, because it tests the truthiness of the function object itself instead of calling it.
It should be something like if denoising_value_valid(self.denoising_start) else None, no?
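A minimal, self-contained sketch of the difference between testing the function object and calling it. denoising_value_valid here is a hypothetical stand-in; its body is an assumption and is not copied from the pipeline.

```python
def denoising_value_valid(dnv):
    # Hypothetical validity check: a float strictly between 0 and 1
    return isinstance(dnv, float) and 0 < dnv < 1


denoising_start = 1.5  # deliberately out of range

# Tests the function object, which is always truthy, so the value passes through
buggy = denoising_start if denoising_value_valid else None

# Calls the check on the actual value, so the invalid value is rejected
fixed = denoising_start if denoising_value_valid(denoising_start) else None

print(buggy, fixed)  # 1.5 None
```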
3. The discrete_timestep_cutoff calculation in StableDiffusionXLPipeline.__call__
See here.
The cutoff is set as
self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps)
but I would have assumed that it's simply
self.denoising_end * self.scheduler.config.num_train_timesteps
The cutoff should be higher with higher denoising_end, with no cutoff at all if denoising_end=1.0, no?
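Which formula is right depends on how the cutoff is consumed. A minimal sketch comparing both, assuming (this is not shown in this issue) that the base pipeline keeps only timesteps greater than or equal to the cutoff from a descending schedule, here a hypothetical 40-step "leading" schedule with steps_offset=1. The int(round(...)) wrapping is also an assumption.

```python
NUM_TRAIN_TIMESTEPS = 1000

# Hypothetical descending schedule: 976, 951, ..., 1
TIMESTEPS = [i * (NUM_TRAIN_TIMESTEPS // 40) + 1 for i in reversed(range(40))]


def current_cutoff(denoising_end):
    # Formula quoted above from StableDiffusionXLPipeline.__call__
    return int(round(NUM_TRAIN_TIMESTEPS - denoising_end * NUM_TRAIN_TIMESTEPS))


def proposed_cutoff(denoising_end):
    # Alternative formula suggested in this issue
    return int(round(denoising_end * NUM_TRAIN_TIMESTEPS))


def base_steps(cutoff, timesteps=TIMESTEPS):
    # Assumption: the base model runs only timesteps >= cutoff
    return sum(1 for t in timesteps if t >= cutoff)


for end in (0.8, 1.0):
    print(end, base_steps(current_cutoff(end)), base_steps(proposed_cutoff(end)))
# 0.8 32 8
# 1.0 40 0
```

Under this assumption, the current formula gives the base model 32 of 40 steps at denoising_end=0.8 and all 40 at 1.0, while the proposed formula gives 8 and 0.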
Reproduction
See above.
Logs
See above.
System Info
- diffusers version: 0.22.1
- Platform: Linux-5.4.0-164-generic-x86_64-with-glibc2.31
- Python version: 3.10.7
- PyTorch version (GPU?): 2.1.0+cu121 (True)
- Huggingface_hub version: 0.18.0
- Transformers version: 4.35.0
- Accelerate version: 0.24.1
- xFormers version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Who can help?
vesmanojlovic