Skip to content

Conversation

hlky
Copy link
Contributor

@hlky hlky commented Dec 11, 2024

What does this PR do?

Example

from diffusers import DiffusionPipeline 
from transformers import T5EncoderModel 
import torch 

repo_id = "black-forest-labs/FLUX.1-dev"
text_encoder_2 = T5EncoderModel.from_pretrained(repo_id, subfolder="text_encoder_2")
pipe = DiffusionPipeline.from_pretrained(
    repo_id, text_encoder=text_encoder_2, torch_dtype=torch.bfloat16
)
ValueError: Expected CLIPTextModel for text_encoder, got T5EncoderModel.

Fixes #10093

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

cc @sayakpaul

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines 836 to 841
for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
if key not in passed_class_obj:
continue
class_name = passed_class_obj[key].__class__.__name__
if class_name != expected_class_name:
raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a test for this too?

@sayakpaul
Copy link
Member

Thanks @hlky! This looks leaner than I had thought. thank you!

@hlky
Copy link
Contributor Author

hlky commented Dec 11, 2024

Thanks @sayakpaul. I've added a test, trimmed Flax from class_name and we skip checking scheduler, this should fix the failed tests.

Comment on lines 837 to 838
if key not in passed_class_obj or key == "scheduler":
continue
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we pass scheduler=text_encoder that should be errored out as well, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some special handling for scheduler

Comment on lines 1806 to 1814
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
with self.assertRaises(ValueError) as error_context:
_ = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer
)

assert "Expected" in str(error_context.exception)
assert "text_encoder" in str(error_context.exception)
assert f"{tokenizer.__class__.__name}" in str(error_context.exception)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also a check for the scheduler as that is handled slightly differently?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will add it. For context this is what we're handling:

class KarrasDiffusionSchedulers(Enum):
DDIMScheduler = 1
DDPMScheduler = 2
PNDMScheduler = 3
LMSDiscreteScheduler = 4
EulerDiscreteScheduler = 5
HeunDiscreteScheduler = 6
EulerAncestralDiscreteScheduler = 7
DPMSolverMultistepScheduler = 8
DPMSolverSinglestepScheduler = 9
KDPM2DiscreteScheduler = 10
KDPM2AncestralDiscreteScheduler = 11
DEISMultistepScheduler = 12
UniPCMultistepScheduler = 13
DPMSolverSDEScheduler = 14
EDMEulerScheduler = 15

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's cool. But I don't see the flow matching schedulers here. So, if I do assign a text encoder to scheduler in an RF pipeline (FluxPipeline, for example), would it still work as expected?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that also works, for pipelines like Flux we're getting the type

<class 'diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler'>

For SD etc we get the enum

<enum 'KarrasDiffusionSchedulers'>
[<KarrasDiffusionSchedulers.DDIMScheduler: 1>, <KarrasDiffusionSchedulers.DDPMScheduler: 2>, <KarrasDiffusionSchedulers.PNDMScheduler: 3>, <KarrasDiffusionSchedulers.LMSDiscreteScheduler: 4>, <KarrasDiffusionSchedulers.EulerDiscreteScheduler: 5>, <KarrasDiffusionSchedulers.HeunDiscreteScheduler: 6>, <KarrasDiffusionSchedulers.EulerAncestralDiscreteScheduler: 7>, <KarrasDiffusionSchedulers.DPMSolverMultistepScheduler: 8>, <KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler: 9>, <KarrasDiffusionSchedulers.KDPM2DiscreteScheduler: 10>, <KarrasDiffusionSchedulers.KDPM2AncestralDiscreteScheduler: 11>, <KarrasDiffusionSchedulers.DEISMultistepScheduler: 12>, <KarrasDiffusionSchedulers.UniPCMultistepScheduler: 13>, <KarrasDiffusionSchedulers.DPMSolverSDEScheduler: 14>, <KarrasDiffusionSchedulers.EDMEulerScheduler: 15>]

So we apply the same processing (str, split, strip applies for type case) to get a list of scheduler

['FlowMatchEulerDiscreteScheduler']
['DDIMScheduler', 'DDPMScheduler', 'PNDMScheduler', 'LMSDiscreteScheduler', 'EulerDiscreteScheduler', 'HeunDiscreteScheduler', 'EulerAncestralDiscreteScheduler', 'DPMSolverMultistepScheduler', 'DPMSolverSinglestepScheduler', 'KDPM2DiscreteScheduler', 'KDPM2AncestralDiscreteScheduler', 'DEISMultistepScheduler', 'UniPCMultistepScheduler', 'DPMSolverSDEScheduler', 'EDMEulerScheduler']

If it's not a scheduler it will raise or if it's the wrong type of scheduler.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining! Works for me.

Copy link
Contributor Author

@hlky hlky Dec 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We now also support Union, context is failed test test_load_connected_checkpoint_with_passed_obj for KandinskyV22CombinedPipeline, we also change scheduler type to Union[DDPMScheduler, UnCLIPScheduler], the test is actually for passing obj to submodels, but changing the scheduler is how that test works.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests for wrong scheduler are added.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Go to go from my side once the scheduler related tests are added. Thanks!

@sayakpaul sayakpaul requested review from DN6 and yiyixuxu December 11, 2024 11:30
@hlky
Copy link
Contributor Author

hlky commented Dec 11, 2024

Note that we add LCMScheduler to KarrasDiffusionSchedulers, it's more of a "compatibles" because some don't actually support Karras and LCMScheduler is generally compatible with pipelines using KarrasDiffusionSchedulers, otherwise we'd need to change most (all?) instances of KarrasDiffusionSchedulers to Union[LCMScheduler, KarrasDiffusionSchedulers]

scheduler_types.extend([str(scheduler_type)])
scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types]

for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
Copy link
Collaborator

@DN6 DN6 Dec 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we already have the types extracted in expected_types can't we fetch them using the key and then check if the passed object is an instance of the type? If the expected type is an enum then we can check if the passed obj class name exists in the keys?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

@DN6 DN6 Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might be better to make this check more agnostic to the component names.

We have a few pipelines with Union types on non-scheduler components (mostly AnimateDiff). So this snippet would fail even though it's valid, because init_dict is based on the model_index.json which doesn't support multiple types.

from diffusers import (
	AnimateDiffPipeline,
	UNetMotionModel,
)

unet = UNetMotionModel()
pipe = AnimateDiffPipeline.from_pretrained(
	"hf-internal-testing/tiny-sd-pipe", unet=unet
)

Enforcing scheduler types might be a breaking change cc: @yiyixuxu . e.g. Using DDIM with Kandinsky is currently valid, but with this change any downstream code doing this it would break. It would be good to enforce on the pipelines with Flow based schedulers though? (perhaps via a new Enum)

I would try something like:

        for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
            if key not in passed_class_obj:
                continue

            class_obj = passed_class_obj[key]
            _expected_class_types = []
            for expected_type in expected_types[key]:
                if isinstance(expected_type, enum.EnumMeta):
                    _expected_class_types.extend(expected_type.__members__.keys())
                else:
                    _expected_class_types.append(expected_type.__name__)

            _is_valid_type = class_obj.__class__.__name__ in _expected_class_types
            if isinstance(class_obj, SchedulerMixin) and not _is_valid_type:
                # Handle case where scheduler is still valid 
                # raise if scheduler is meant to be a Flow based scheduler?
            elif not _is_valid_type:
                raise ValueError(f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}.")

Copy link
Contributor Author

@hlky hlky Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this for scheduler

                _requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types)
                _is_flow_match = "FlowMatch" in class_obj.__class__.__name__
                if _requires_flow_match and not _is_flow_match:
                    raise ValueError(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.")
                elif not _requires_flow_match and _is_flow_match:
                    raise ValueError(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.")

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't need a value error here, a warning is enough, no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A warning should be sufficient, it's mainly for the situation here #10093 (comment) where the wrong text encoder is given because the resulting error is uninformative.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's do a warning then:)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just chiming here a bit to share a perspective as a user (not a strong opinion). Related to #10189 (comment).

Here

https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_loading_utils.py#L287-L290

if there's an unexpected module passed we raise a value error. I think the check is almost along similar lines -- users are passing assigning components that are unexpected / incompatible. We probably cannot predict the consequences of allowing the loading without raising any errors but if we raise an error, users would know what to do to fix the in correct behaviour.

Comment on lines +1808 to +1815
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
with self.assertRaises(ValueError) as error_context:
_ = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer
)

assert "is of type" in str(error_context.exception)
assert "but should be" in str(error_context.exception)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're now using warning, but for this case, CLIPTokenizer in text_encoder we still get a ValueError later on from here

# if the model is in a pipeline module, then we load it from the pipeline
# check that passed_class_obj has correct parent class
maybe_raise_or_warn(
library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
)

So it's a little inconsistent and needs further testing to determine which other cases this already applies to.


_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
if isinstance(class_obj, SchedulerMixin) and not _is_valid_type:
_requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think checking against a FlowMatchSchedulers enum would be better in case we end up not using "FlowMatch" in the class name.

_requires_flow_match = any(class_type in FlowMatchSchedulers.__members__ for class_type in _expected_class_types)
_is_flow_match =  class_obj.__class__.__name__ in FlowMatchSchedulers

cc: @yiyixuxu

_requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types)
_is_flow_match = "FlowMatch" in class_obj.__class__.__name__
if _requires_flow_match and not _is_flow_match:
logger.warning(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably okay to raise an error here because scheduler.scale_noise would raise an error in the flow matching pipelines if a non-FlowMatch scheduler is used.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really don't want to raise any error here because type hint was not something enforced in this library and it is hard even for us to tell which schedulers can be used/cannot.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e.g kandinsky if memory serves I think ddim may also works with some of the pipelines, and the compatibility may change

elif not _requires_flow_match and _is_flow_match:
logger.warning(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.")
elif not _is_valid_type:
logger.warning(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if it's not a scheduler and the types don't match it's okay to raise an error. I think it would break in the model loading step anyway in this case. wdyt @yiyixuxu?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer a warning because:

  1. I think there is very little /no benefits in raising an error vs a warning here
  2. in case we make a mistake in type hint, we will throw an error by mistake

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we just added use_flow_sigma to a few non-flow match schedulers with the SANA pr, and also we plan to refactor them but don't have a design finalized yet
given that, I think maybe we can skip checking for scheduler altogether for now, and revisit later. let me know what you guys think!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed scheduler related changes for now, I think we can revisit that later, as @yiyixuxu mentioned above type hints haven't been strictly enforced there are probably some missing/wrong, especially for schedulers. Warning is better because of that too, if there is some wrong type hint that makes its way into a release we'd have to issue a hotfix release to fix it, that just creates headaches and issue reports.

Copy link
Collaborator

@yiyixuxu yiyixuxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @hlky !

@sayakpaul
Copy link
Member

Thanks for taking it up, @hlky! To me it's a really nice QoL improvement from a DX perspective.

@hlky hlky merged commit 0ed09a1 into huggingface:main Dec 19, 2024
12 checks passed
Foundsheep pushed a commit to Foundsheep/diffusers that referenced this pull request Dec 23, 2024
…10189)

* Check correct model type is passed to `from_pretrained`

* Flax, skip scheduler

* test_wrong_model

* Fix for scheduler

* Update tests/pipelines/test_pipelines.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* EnumMeta

* Flax

* scheduler in expected types

* make

* type object 'CLIPTokenizer' has no attribute '_PipelineFastTests__name'

* support union

* fix typing in kandinsky

* make

* add LCMScheduler

* 'LCMScheduler' object has no attribute 'sigmas'

* tests for wrong scheduler

* make

* update

* warning

* tests

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* import FlaxSchedulerMixin

* skip scheduler

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* Check correct model type is passed to `from_pretrained`

* Flax, skip scheduler

* test_wrong_model

* Fix for scheduler

* Update tests/pipelines/test_pipelines.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* EnumMeta

* Flax

* scheduler in expected types

* make

* type object 'CLIPTokenizer' has no attribute '_PipelineFastTests__name'

* support union

* fix typing in kandinsky

* make

* add LCMScheduler

* 'LCMScheduler' object has no attribute 'sigmas'

* tests for wrong scheduler

* make

* update

* warning

* tests

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* import FlaxSchedulerMixin

* skip scheduler

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[pipelines] add better checking when a wrong model is passed when initializing a pipeline
5 participants