
When using LMS scheduler with Karras sigmas, denoising_end and denoising_start seem to get offset for some reason. #5628

@nhnt11


Describe the bug

I've been playing around with the LMS scheduler trying to understand a different bug (which I'll file a separate issue for) and noticed that denoising_start and denoising_end don't seem to work correctly when using LMS with Karras sigmas.

With denoising_end=0.9 and num_inference_steps=50, I see that 35 denoising steps are performed instead of 45. And with denoising_start=0.9, 15 denoising steps are performed instead of 5.

With denoising_end=1.0, 50 steps are performed as expected.

Reproduction

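Here's sample code that performs a generation with the base model and refiner set up in an ensemble-of-experts configuration: the base model handles the first 90% of denoising and the refiner handles the rest. LMS is used with use_karras_sigmas=True, and denoising_end (base) and denoising_start (refiner) are both set to 0.9.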

import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
from typing import cast
from diffusers import LMSDiscreteScheduler

sdxl_model = cast(StableDiffusionXLPipeline, StableDiffusionXLPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
    revision="76d28af79639c28a79fa5c6c6468febd3490a37e",
)).to('cuda')
refiner_model = cast(StableDiffusionXLImg2ImgPipeline, StableDiffusionXLImg2ImgPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-refiner-1.0',
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
    revision="5d4cfe854c9a9a87939ff3653551c2b3c99a4356",
)).to('cuda')

common_config = {'beta_start': 0.00085, 'beta_end': 0.012, 'beta_schedule': 'scaled_linear'}
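# Karras sigmas trigger the step-count mismatch; with use_karras_sigmas=False the split is 45/5 as expected
scheduler = LMSDiscreteScheduler(**common_config, use_karras_sigmas=True)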
sdxl_model.scheduler = scheduler
refiner_model.scheduler = scheduler

sdxl_model.watermark = None
generator = torch.Generator(device='cuda')
generator.manual_seed(12345)

hnf = 0.9  # high-noise fraction: the base model denoises the first 90%, the refiner the remaining 10%
params = {
    'prompt': ['evening sunset scenery blue sky nature, glass bottle with a galaxy in it'],
    'negative_prompt': [''],
    "num_inference_steps": 50,
    "height": 1024,
    "width": 1024,
    "guidance_scale": 7,
}

sdxl_res = sdxl_model(**params, denoising_end=hnf, generator=generator, output_type='latent')
sdxl_latents = sdxl_res.images

refiner_res = refiner_model(
    output_type="pil",
    image=sdxl_latents,
    prompt=params['prompt'],
    num_inference_steps=params['num_inference_steps'],
    negative_prompt=params['negative_prompt'],
    generator=generator,
    denoising_start=hnf,
)
refiner_imgs = refiner_res.images

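# display() is provided by Jupyter/IPython; in a plain script use refiner_imgs[0].save('output.png')
display(refiner_imgs[0])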

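The output looks like what a denoising_start/denoising_end value of 0.7 would produce rather than 0.9, which is consistent with the base model running only 35 of its 50 steps (35/50 = 0.7). If I set use_karras_sigmas=False, the expected number of steps is performed (45 on the base model and 5 on the refiner).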
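To make the mismatch concrete, here is a minimal pure-Python sketch that models how the steps get split. It assumes (1) the SDXL pipelines convert the fraction into a timestep cutoff of round(1000 - 0.9 * 1000) = 100, with the base keeping every timestep t >= cutoff and the refiner every t < cutoff, and (2) LMSDiscreteScheduler with use_karras_sigmas=True builds rho = 7 Karras sigmas between the schedule's sigma_max and sigma_min and maps each back to a fractional timestep by interpolating in log-sigma space. All helper names are hypothetical; this is a model of the behavior, not the diffusers code.

```python
import math

# Hypothetical model of the step selection; not the diffusers implementation.
NUM_TRAIN_TIMESTEPS = 1000  # assumed scheduler default (num_train_timesteps)


def train_sigmas(beta_start=0.00085, beta_end=0.012, n=NUM_TRAIN_TIMESTEPS):
    """Sigma for every training timestep under the 'scaled_linear' beta schedule."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    sigmas, alpha_cumprod = [], 1.0
    for i in range(n):
        beta = (lo + (hi - lo) * i / (n - 1)) ** 2
        alpha_cumprod *= 1.0 - beta
        sigmas.append(math.sqrt((1.0 - alpha_cumprod) / alpha_cumprod))
    return sigmas


def interp(x, xs, ys):
    """Piecewise-linear interpolation over increasing xs, clamped at both ends."""
    if x <= xs[0]:
        return ys[0]
    if x >= xs[-1]:
        return ys[-1]
    for j in range(1, len(xs)):
        if x <= xs[j]:
            w = (x - xs[j - 1]) / (xs[j] - xs[j - 1])
            return ys[j - 1] + w * (ys[j] - ys[j - 1])
    return ys[-1]


def uniform_timesteps(steps, n=NUM_TRAIN_TIMESTEPS):
    """Evenly spaced timesteps from n - 1 down to 0 (the use_karras_sigmas=False case)."""
    return [(n - 1) * (1 - k / (steps - 1)) for k in range(steps)]


def karras_timesteps(steps, rho=7.0):
    """Karras-spaced sigmas from sigma_max to sigma_min, mapped back to timesteps."""
    sigmas = train_sigmas()
    log_sigmas = [math.log(s) for s in sigmas]  # increasing with t
    t_axis = list(range(len(sigmas)))
    min_inv, max_inv = sigmas[0] ** (1 / rho), sigmas[-1] ** (1 / rho)
    result = []
    for k in range(steps):
        ramp = k / (steps - 1)
        sigma = (max_inv + ramp * (min_inv - max_inv)) ** rho
        # Interpolate log(sigma) against the training log-sigmas to get a fractional timestep
        result.append(interp(math.log(sigma), log_sigmas, t_axis))
    return result


def cutoff(fraction, n=NUM_TRAIN_TIMESTEPS):
    """Assumed cutoff rule: round(n - fraction * n), i.e. 100 for 0.9."""
    return int(round(n - fraction * n))


def split_counts(timesteps, fraction):
    """Return (base steps with t >= cutoff, refiner steps with t < cutoff)."""
    c = cutoff(fraction)
    base = sum(1 for t in timesteps if t >= c)
    return base, len(timesteps) - base


uniform = split_counts(uniform_timesteps(50), 0.9)
karras = split_counts(karras_timesteps(50), 0.9)
print("uniform:", uniform)  # (45, 5)
print("karras: ", karras)   # (35, 15)
```

Under these assumptions the sketch yields the same 35/15 split reported above, versus 45/5 for uniform spacing. Karras spacing packs more steps at low noise (small t), so fewer of the 50 timesteps clear the cutoff of 100, which suggests the fraction is being applied to timestep values whose spacing is no longer linear in the step index.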


System Info

  • diffusers version: 0.21.4
  • Platform: Linux-5.4.0-163-generic-x86_64-with-glibc2.31
  • Python version: 3.11.5
  • PyTorch version (GPU?): 2.1.0+cu121 (True)
  • Huggingface_hub version: 0.17.1
  • Transformers version: 4.34.0
  • Accelerate version: 0.22.0
  • xFormers version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@yiyixuxu @patrickvonplaten

Labels

bug (Something isn't working)
