Description
Is your feature request related to a problem? Please describe.
By default, the Stable Diffusion pipeline uses the PNDM scheduler, but one could easily use other schedulers (we only need to overwrite the `self.scheduler` attribute).
This can be done with the following code snippet:
```python
from diffusers import StableDiffusionPipeline, DDIMScheduler

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", use_auth_token=True)  # make sure you're logged in with `huggingface-cli login`
# use DDIM scheduler
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, clip_alpha_at_one=False)
pipe.scheduler = scheduler
```
However, that's a bit hacky and ideally not the way we want users to do it!
Describe the solution you'd like
Instead, the following code snippet should work more or less for all pipelines:
```python
from diffusers import StableDiffusionPipeline, DDIMScheduler
# Use DDIM scheduler here instead
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, clip_alpha_at_one=False)
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", scheduler=scheduler, use_auth_token=True)  # make sure you're logged in with `huggingface-cli login`
```
This is a cleaner and more intuitive API. The idea is that every class variable that can be passed to the pipeline's `__init__(...)` can also be passed to `from_pretrained(...)`.
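A minimal sketch of how the component names could be read off a pipeline's `__init__` signature with the standard-library `inspect.signature`. `ToyPipeline` and `expected_component_names` are hypothetical stand-ins, not diffusers code, and the sketch assumes every component is a named `__init__` parameter:

```python
import inspect


class ToyPipeline:
    """Hypothetical stand-in for a pipeline whose components are __init__ args."""

    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler):
        self.vae = vae
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.unet = unet
        self.scheduler = scheduler


def expected_component_names(pipeline_class):
    """Return the names of all __init__ parameters except `self`, in order."""
    params = inspect.signature(pipeline_class.__init__).parameters
    return [name for name in params if name != "self"]


print(expected_component_names(ToyPipeline))
# ['vae', 'text_encoder', 'tokenizer', 'unet', 'scheduler']
```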
Currently, running this snippet fails with:
```
TypeError: cannot unpack non-iterable DDIMScheduler object
```
We can allow such behavior by adding some logic to the general `DiffusionPipeline.from_pretrained` method in `diffusers/src/diffusers/pipeline_utils.py` (line 115 at commit 051b346): `def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):`
Also, we want this approach to work not just for one pipeline and the scheduler class, but for all pipelines and all scheduler classes.
We can achieve this by doing more or less the following in the `from_pretrained` method of `diffusers/src/diffusers/pipeline_utils.py`:
Pseudo code:
- Retrieve all variables that can be passed to the `__init__` of the pipeline class, which is resolved here (`diffusers/src/diffusers/pipeline_utils.py`, line 149 at 051b346): `pipeline_class = getattr(diffusers_module, config_dict["_class_name"])`
  -> you should get a list of keys such as `[vae, text_encoder, tokenizer, unet, scheduler]`
- Check if any of those parameters are passed in `kwargs` -> if yes, store them in a dict `passed_class_obj`
- In the loop that loads the class variables (`diffusers/src/diffusers/pipeline_utils.py`, line 162 at 051b346, starting at `if is_pipeline_module:`), check whether `name` is in the `passed_class_obj` dict -> if yes, simply use the passed object instead and skip the loading part (assign the passed object to `loaded_sub_model`)
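The pseudo code above can be sketched in plain Python roughly as follows. This is not the diffusers implementation: `TOY_CONFIG` is a hypothetical stand-in for the pipeline config (component name -> (library name, class name)), `load_components`, `load_fn` and `fake_load` are hypothetical names, and for brevity the expected component names are taken from the config keys rather than from the `__init__` signature:

```python
from typing import Any, Callable, Dict, Tuple

# Hypothetical stand-in for the pipeline config:
# component name -> (library name, class name).
TOY_CONFIG: Dict[str, Tuple[str, str]] = {
    "unet": ("diffusers", "UNet2DConditionModel"),
    "scheduler": ("diffusers", "PNDMScheduler"),
}


def load_components(
    config: Dict[str, Tuple[str, str]],
    load_fn: Callable[[str, str, str], Any],
    **kwargs: Any,
) -> Dict[str, Any]:
    """Use components passed via kwargs as-is; load all others with load_fn."""
    # Steps 1 + 2: collect kwargs whose names match an expected component.
    passed_class_obj = {name: kwargs[name] for name in config if name in kwargs}

    init_kwargs: Dict[str, Any] = {}
    for name, (library_name, class_name) in config.items():
        if name in passed_class_obj:
            # Step 3: reuse the user-provided object and skip loading.
            loaded_sub_model = passed_class_obj[name]
        else:
            loaded_sub_model = load_fn(name, library_name, class_name)
        init_kwargs[name] = loaded_sub_model
    return init_kwargs


class MyScheduler:
    """Hypothetical user-constructed scheduler."""


def fake_load(name: str, library_name: str, class_name: str) -> str:
    # Hypothetical loader: returns a description instead of loading weights.
    return f"loaded {library_name}.{class_name}"


my_scheduler = MyScheduler()
components = load_components(TOY_CONFIG, fake_load, scheduler=my_scheduler)
print(components["unet"])  # loaded diffusers.UNet2DConditionModel
print(components["scheduler"] is my_scheduler)  # True
```

The resulting dict would then be passed on as `pipeline_class(**init_kwargs)`, so the pipeline itself needs no changes.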
=> After the PR, this should work:
```python
from diffusers import StableDiffusionPipeline, DDIMScheduler
# Use DDIM scheduler here instead
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, clip_alpha_at_one=False)
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", scheduler=scheduler, use_auth_token=True)  # make sure you're logged in with `huggingface-cli login`
```
whereas this should raise a helpful error message (note how the scheduler is incorrectly passed as `vae`):
```python
from diffusers import StableDiffusionPipeline, DDIMScheduler
# Use DDIM scheduler here instead
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, clip_alpha_at_one=False)
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-3-diffusers", vae=scheduler, use_auth_token=True)  # make sure you're logged in with `huggingface-cli login`
```
The error message can be based on the passed object not having a parent class that matches the expected one. This could be checked using the `LOADABLE_CLASSES = {...}` dict in `diffusers/src/diffusers/pipeline_utils.py` (line 34 at 051b346).
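A minimal sketch of such a check, assuming the expected class of each component is known from the config. The toy classes, `LOADABLE_BASES` and `check_passed_component` are hypothetical; `LOADABLE_BASES` is a flattened simplification of `LOADABLE_CLASSES`, whose real structure may differ:

```python
from typing import Dict, List, Type


# Hypothetical toy hierarchy standing in for the real base classes.
class ModelMixin:
    pass


class SchedulerMixin:
    pass


class ToyVAE(ModelMixin):
    pass


class ToyPNDMScheduler(SchedulerMixin):
    pass


class ToyDDIMScheduler(SchedulerMixin):
    pass


# Hypothetical, flattened stand-in for LOADABLE_CLASSES: known base classes.
LOADABLE_BASES: Dict[str, Type] = {
    "ModelMixin": ModelMixin,
    "SchedulerMixin": SchedulerMixin,
}


def check_passed_component(name: str, passed_obj: object, expected_class: Type) -> None:
    """Raise ValueError if passed_obj doesn't share a known base with expected_class."""
    expected_bases: List[Type] = [
        base for base in LOADABLE_BASES.values() if issubclass(expected_class, base)
    ]
    if not expected_bases:
        return  # unknown component type: nothing to check against
    if not any(isinstance(passed_obj, base) for base in expected_bases):
        base_names = ", ".join(base.__name__ for base in expected_bases)
        raise ValueError(
            f"{type(passed_obj).__name__} was passed as `{name}`, but `{name}` "
            f"expects an instance of a {base_names} subclass "
            f"(the checkpoint uses {expected_class.__name__})."
        )


# A DDIM scheduler may replace the PNDM scheduler: both derive from SchedulerMixin.
check_passed_component("scheduler", ToyDDIMScheduler(), ToyPNDMScheduler)

# Passing a scheduler as `vae` raises a descriptive error.
try:
    check_passed_component("vae", ToyDDIMScheduler(), ToyVAE)
except ValueError as err:
    print(err)
```

Checking against the shared base class rather than the exact class is what allows swapping one scheduler for another while still catching a scheduler passed in place of a model.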
Additional context
As suggested by @apolinario, it's very important to let users easily swap out schedulers. At the same time, we don't want to create too much custom code. IMO the solution above handles the problem nicely.