
Potential scheduler issues #5685

@vakker

Describe the bug

While adjusting the scheduler configs, I came across a few potential issues.

Edit: issue 1 is a duplicate of #5628.

1. denoising_start and denoising_end are computed differently for different schedulers

Repro:

```python
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")
# The refiner reuses the base model's second text encoder and VAE.
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to("cuda")

# Total number of steps and the fraction run by each expert (80/20)
n_steps = 40
high_noise_frac = 0.8

prompt = "A majestic lion jumping from a big stone at night"

# 1) Run both experts with the default EulerDiscreteScheduler
image = base(
    prompt=prompt,
    num_inference_steps=n_steps,
    denoising_end=high_noise_frac,
    output_type="latent",
).images
image = refiner(
    prompt=prompt,
    num_inference_steps=n_steps,
    denoising_start=high_noise_frac,
    image=image,
).images[0]


# 2) Switch both pipelines to DPMSolverMultistepScheduler and run both experts again
base.scheduler = DPMSolverMultistepScheduler.from_config(
    base.scheduler.config,
)

refiner.scheduler = DPMSolverMultistepScheduler.from_config(
    refiner.scheduler.config,
)

image = base(
    prompt=prompt,
    num_inference_steps=n_steps,
    denoising_end=high_noise_frac,
    output_type="latent",
).images
image = refiner(
    prompt=prompt,
    num_inference_steps=n_steps,
    denoising_start=high_noise_frac,
    image=image,
).images[0]


# 3) Same scheduler, but with Karras sigmas enabled
base.scheduler = DPMSolverMultistepScheduler.from_config(
    base.scheduler.config,
    use_karras_sigmas=True,
)

refiner.scheduler = DPMSolverMultistepScheduler.from_config(
    refiner.scheduler.config,
    use_karras_sigmas=True,
)

image = base(
    prompt=prompt,
    num_inference_steps=n_steps,
    denoising_end=high_noise_frac,
    output_type="latent",
).images
image = refiner(
    prompt=prompt,
    num_inference_steps=n_steps,
    denoising_start=high_noise_frac,
    image=image,
).images[0]
```

This runs three configurations:

  1. The default EulerDiscreteScheduler: a 32/8 split, which matches the expected high_noise_frac=0.8.
  2. DPMSolverMultistepScheduler: a 32/8 split (still as expected).
  3. DPMSolverMultistepScheduler with use_karras_sigmas=True: a 25/15 split, which is unexpected.

I'm not sure whether this is intended or a bug. It might be due to the calculation around here?
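A rough sketch of why the split could move, reimplemented in plain NumPy rather than by calling diffusers. It assumes the usual SDXL noise schedule (scaled_linear betas from 0.00085 to 0.012 over 1000 training steps), "leading" spacing with steps_offset=1 for the evenly spaced case, rho=7 for the Karras case, and that the base model keeps every timestep at or above the cutoff quoted in point 3 below. The helpers (train_sigmas, leading_timesteps, karras_timesteps, base_steps) are hypothetical, and the exact timesteps diffusers produces may differ slightly.

```python
import numpy as np

# Assumed SDXL-style schedule: "scaled_linear" betas over 1000 training steps.
NUM_TRAIN_TIMESTEPS = 1000
BETA_START, BETA_END = 0.00085, 0.012


def train_sigmas():
    """sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t) for t = 0..999 (increasing in t)."""
    betas = np.linspace(BETA_START**0.5, BETA_END**0.5, NUM_TRAIN_TIMESTEPS) ** 2
    alphas_cumprod = np.cumprod(1.0 - betas)
    return np.sqrt((1.0 - alphas_cumprod) / alphas_cumprod)


def leading_timesteps(n_steps, steps_offset=1):
    """Evenly spaced ("leading") timesteps, ordered from high noise to low noise."""
    step_ratio = NUM_TRAIN_TIMESTEPS // n_steps
    return (np.arange(n_steps) * step_ratio + steps_offset)[::-1]


def karras_timesteps(n_steps, rho=7.0):
    """Karras-spaced sigmas, mapped back to (fractional) timesteps."""
    sigmas = train_sigmas()
    s_min, s_max = sigmas[0], sigmas[-1]
    ramp = np.linspace(0, 1, n_steps)
    karras = (s_max ** (1 / rho) + ramp * (s_min ** (1 / rho) - s_max ** (1 / rho))) ** rho
    # log(sigma_t) increases with t, so np.interp can invert sigma -> t.
    return np.interp(np.log(karras), np.log(sigmas), np.arange(NUM_TRAIN_TIMESTEPS))


def base_steps(timesteps, denoising_end):
    """Steps run by the base model: every timestep at or above the cutoff."""
    cutoff = int(round(NUM_TRAIN_TIMESTEPS - denoising_end * NUM_TRAIN_TIMESTEPS))
    return int(np.sum(timesteps >= cutoff))


n_steps, high_noise_frac = 40, 0.8
uniform = base_steps(leading_timesteps(n_steps), high_noise_frac)
karras = base_steps(karras_timesteps(n_steps), high_noise_frac)
print(f"leading spacing: {uniform}/{n_steps - uniform}")
print(f"Karras sigmas:   {karras}/{n_steps - karras}")
```

With evenly spaced timesteps this gives the expected 32/8 split. Karras sigmas are packed toward the low-noise end, so fewer of the 40 timesteps land above the fixed cutoff of 200 and the base model's share shrinks, in the same direction as the 25/15 split reported above.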

2. The denoising_start validity check in StableDiffusionXLImg2ImgPipeline.__call__

See here.

The condition if denoising_value_valid else None is always true: it tests the function object itself, which is always truthy, instead of calling it.
It should be something like if denoising_value_valid(self.denoising_start) else None, no?
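A minimal, self-contained illustration of the truthiness problem. The body of denoising_value_valid here is a simplified, hypothetical stand-in (a float strictly between 0 and 1), not a copy of the diffusers helper; only the difference between testing the function and calling it matters.

```python
def denoising_value_valid(dnv):
    """Hypothetical stand-in for the pipeline helper: a float strictly in (0, 1)."""
    return isinstance(dnv, float) and 0 < dnv < 1


denoising_start = 1.5  # deliberately invalid value

# Current form: tests the function object (always truthy), so the check never runs.
buggy = denoising_start if denoising_value_valid else None
# Suggested form: actually calls the validator on the value.
fixed = denoising_start if denoising_value_valid(denoising_start) else None

print(buggy, fixed)  # 1.5 None
```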

3. The discrete_timestep_cutoff calculation in StableDiffusionXLPipeline.__call__

See here.

The cutoff is set as

```python
self.scheduler.config.num_train_timesteps - (self.denoising_end * self.scheduler.config.num_train_timesteps)
```

but I would have assumed that it's simply

```python
self.denoising_end * self.scheduler.config.num_train_timesteps
```

The cutoff should be higher with higher denoising_end, with no cutoff at all if denoising_end=1.0, no?
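To compare the two formulas concretely, the sketch below counts how many of 40 evenly spaced ("leading", steps_offset=1) timesteps the base model would keep under each cutoff. It assumes that timesteps are ordered from high noise to low noise and that the base pipeline keeps every timestep at or above the cutoff; the helper names are hypothetical and not part of diffusers.

```python
import numpy as np

NUM_TRAIN_TIMESTEPS = 1000


def leading_timesteps(n_steps, steps_offset=1):
    """Evenly spaced timesteps, ordered from high noise (large t) to low noise."""
    step_ratio = NUM_TRAIN_TIMESTEPS // n_steps
    return (np.arange(n_steps) * step_ratio + steps_offset)[::-1]


def current_cutoff(denoising_end):
    # Formula quoted above from StableDiffusionXLPipeline.__call__
    return int(round(NUM_TRAIN_TIMESTEPS - denoising_end * NUM_TRAIN_TIMESTEPS))


def proposed_cutoff(denoising_end):
    # Alternative suggested above
    return int(round(denoising_end * NUM_TRAIN_TIMESTEPS))


def steps_kept(cutoff, timesteps):
    # Assumption: the base model runs every timestep at or above the cutoff.
    return int(np.sum(timesteps >= cutoff))


timesteps = leading_timesteps(40)
for end in (0.5, 0.8, 1.0):
    print(end, steps_kept(current_cutoff(end), timesteps),
          steps_kept(proposed_cutoff(end), timesteps))
```

Note that a larger t corresponds to more noise, so the steps kept are the first, noisiest ones in the schedule.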

Reproduction

See above.

Logs

See above.

System Info

  • diffusers version: 0.22.1
  • Platform: Linux-5.4.0-164-generic-x86_64-with-glibc2.31
  • Python version: 3.10.7
  • PyTorch version (GPU?): 2.1.0+cu121 (True)
  • Huggingface_hub version: 0.18.0
  • Transformers version: 4.35.0
  • Accelerate version: 0.24.1
  • xFormers version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@yiyixuxu
@patrickvonplaten
@sayakpaul
@DN6

Labels: bug (Something isn't working), stale (Issues that haven't received updates)
