
Allow non-default schedulers to be easily swapped into DiffusionPipeline classes #183

@patrickvonplaten

Description

Is your feature request related to a problem? Please describe.

By default, the stable diffusion pipeline uses the PNDM scheduler, but one could easily use other schedulers (we only need to overwrite the self.scheduler attribute).

This can be done with the following code-snippet:

from diffusers import StableDiffusionPipeline, DDIMScheduler

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=True)  # make sure you're logged in with `huggingface-cli login`

# use DDIM scheduler
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
pipe.scheduler = scheduler

This works, but it's a bit hacky and not how we ideally want users to do it!

Describe the solution you'd like

Instead, the following code snippet should work (more or less unchanged) for all pipelines:

from diffusers import StableDiffusionPipeline, DDIMScheduler

# Use DDIM scheduler here instead
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", scheduler=scheduler, use_auth_token=True)  # make sure you're logged in with `huggingface-cli login`

This is a cleaner & more intuitive API. The idea is that every class variable that can be passed to the pipeline's __init__ should also be overwritable when using from_pretrained(...).

Currently, running this command fails with:

TypeError: cannot unpack non-iterable DDIMScheduler object 
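A plausible mechanism for this error can be sketched with plain Python stand-ins, assuming (not confirmed here) that from_pretrained merges the user's kwargs into the loaded config, where each component is stored as a [library_name, class_name] pair, and then unpacks every entry. The DDIMScheduler class and load_components function below are hypothetical dummies, not the diffusers implementations:

```python
# Hypothetical stand-in for a pipeline config: each component maps to a
# [library_name, class_name] pair.
config_dict = {
    "scheduler": ["diffusers", "PNDMScheduler"],
    "unet": ["diffusers", "UNet2DConditionModel"],
}


class DDIMScheduler:
    """Dummy stand-in for diffusers.DDIMScheduler (no real behavior)."""


def load_components(config_dict, **kwargs):
    # Assumption: user kwargs simply override config entries before loading.
    init_dict = {**config_dict, **kwargs}
    loaded = {}
    for name, value in init_dict.items():
        # This unpacking fails if `value` is an already-instantiated object.
        library_name, class_name = value
        loaded[name] = f"{library_name}.{class_name}"
    return loaded


try:
    load_components(config_dict, scheduler=DDIMScheduler())
except TypeError as err:
    print(err)  # cannot unpack non-iterable DDIMScheduler object
```

If this assumption holds, the fix is to intercept passed objects before that unpacking step rather than letting them overwrite the config entries.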

Now we can allow such behavior by adding some logic to the general DiffusionPipeline from_pretrained method here:

def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):

Also, we want this approach to work not just for one pipeline and only the scheduler class, but for all pipelines and all scheduler classes.
We can achieve this by doing more or less the following in the same from_pretrained method:

Pseudo code:

  1. Retrieve all variables that can be passed to the class init here:
    pipeline_class = getattr(diffusers_module, config_dict["_class_name"])

    -> you should get a list of keys such as [vae, text_encoder, tokenizer, unet, scheduler]
  2. Check whether any of those parameters are passed in kwargs -> if yes -> store them in a dict passed_class_obj
  3. In the loop that loads the class variables:
    if is_pipeline_module:
    add a new if statement that checks whether the name is in the passed_class_obj dict -> if yes -> simply use the passed object instead and skip the loading part (i.e. set loaded_sub_model to the passed object)
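The three steps above could look roughly like the following minimal sketch. ToyPipeline, DDIMScheduler, load_sub_model and the free function from_pretrained are hypothetical stand-ins (the real loop calls each component's own loading method), and config_dict here omits metadata keys such as _class_name:

```python
import inspect


class DDIMScheduler:
    """Dummy stand-in for a user-constructed scheduler."""


class ToyPipeline:
    """Hypothetical pipeline with two components."""

    def __init__(self, unet, scheduler):
        self.unet = unet
        self.scheduler = scheduler


def load_sub_model(library_name, class_name):
    # Hypothetical loader; real code would call the component's from_pretrained.
    return f"<{library_name}.{class_name} loaded from disk>"


def from_pretrained(pipeline_class, config_dict, **kwargs):
    # 1. Names of all components accepted by the pipeline's __init__.
    expected_modules = set(inspect.signature(pipeline_class.__init__).parameters) - {"self"}

    # 2. Pull explicitly passed components out of kwargs.
    passed_class_obj = {k: kwargs.pop(k) for k in list(kwargs) if k in expected_modules}

    # 3. Load each component, preferring a passed object over loading from disk.
    init_kwargs = {}
    for name, (library_name, class_name) in config_dict.items():
        if name in passed_class_obj:
            loaded_sub_model = passed_class_obj[name]  # skip the loading part
        else:
            loaded_sub_model = load_sub_model(library_name, class_name)
        init_kwargs[name] = loaded_sub_model
    return pipeline_class(**init_kwargs)


config_dict = {
    "unet": ["diffusers", "UNet2DConditionModel"],
    "scheduler": ["diffusers", "PNDMScheduler"],
}
scheduler = DDIMScheduler()
pipe = from_pretrained(ToyPipeline, config_dict, scheduler=scheduler)
print(pipe.scheduler is scheduler)  # True
print(pipe.unet)  # <diffusers.UNet2DConditionModel loaded from disk>
```

Deriving the allowed names from the __init__ signature keeps the logic generic: any pipeline, and any of its components (not just the scheduler), can be overridden without per-pipeline code.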

=> After the PR, this should work:

from diffusers import StableDiffusionPipeline, DDIMScheduler

# Use DDIM scheduler here instead
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", scheduler=scheduler, use_auth_token=True)  # make sure you're logged in with `huggingface-cli login`

whereas this should give a clear error message (note how the scheduler is incorrectly passed as vae):

from diffusers import StableDiffusionPipeline, DDIMScheduler

# Use DDIM scheduler here instead
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", vae=scheduler, use_auth_token=True)  # make sure you're logged in with `huggingface-cli login`

The error message can be based on the passed object not having a parent class that matches the expected one (this could be checked using the LOADABLE_CLASSES dict).
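A minimal sketch of such a check follows. It uses empty stand-in classes that only borrow the names of the diffusers base classes (ModelMixin, SchedulerMixin); LOADABLE_BASES and check_passed_component are hypothetical, and a real version would derive the allowed base classes from LOADABLE_CLASSES and the expected class from the pipeline config:

```python
class ModelMixin:
    """Dummy stand-in for a model base class."""


class SchedulerMixin:
    """Dummy stand-in for a scheduler base class."""


class AutoencoderKL(ModelMixin):
    pass


class PNDMScheduler(SchedulerMixin):
    pass


class DDIMScheduler(SchedulerMixin):
    pass


# Hypothetical, flattened analogue of LOADABLE_CLASSES: the allowed base classes.
LOADABLE_BASES = [ModelMixin, SchedulerMixin]


def check_passed_component(name, passed_obj, expected_class):
    """Raise a readable error if passed_obj does not share a loadable base with expected_class."""
    expected_base = next((b for b in LOADABLE_BASES if issubclass(expected_class, b)), None)
    if expected_base is None:
        return  # nothing to compare against; accept the object as-is
    if not isinstance(passed_obj, expected_base):
        raise ValueError(
            f"`{name}` was passed an object of type {type(passed_obj).__name__}, "
            f"but it should be an instance of {expected_base.__name__} "
            f"(the pipeline expects {expected_class.__name__})."
        )


check_passed_component("scheduler", DDIMScheduler(), PNDMScheduler)  # OK: both are schedulers
try:
    check_passed_component("vae", DDIMScheduler(), AutoencoderKL)
except ValueError as err:
    print(err)
```

Checking against the shared parent class rather than the exact expected class is what lets a DDIMScheduler replace a PNDMScheduler while still rejecting a scheduler passed as vae.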

Additional context
As suggested by @apolinario, it's very important to allow users to easily swap out schedulers. At the same time, we don't want to write too much custom code. IMO the solution above handles the problem nicely.
