Check correct model type is passed to from_pretrained
#10189
for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
    if key not in passed_class_obj:
        continue
    class_name = passed_class_obj[key].__class__.__name__
    if class_name != expected_class_name:
        raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.")
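The check above can be tried in isolation. The following is a minimal sketch, not the code merged in the PR: it assumes `init_dict` maps each component name to a `(library, class_name)` pair, as the tuple unpacking in the snippet suggests, and it uses empty stand-in classes instead of the real `transformers` ones. The function name `check_passed_components` is hypothetical.

```python
def check_passed_components(init_dict, passed_class_obj):
    """Raise if a passed component's class name differs from the expected one."""
    for key, (_, expected_class_name) in init_dict.items():
        if key not in passed_class_obj:
            continue  # component was not overridden by the caller
        class_name = passed_class_obj[key].__class__.__name__
        if class_name != expected_class_name:
            raise ValueError(f"Expected {expected_class_name} for {key}, got {class_name}.")


# Stand-in classes (assumption): the real ones live in transformers.
class CLIPTextModel: ...
class CLIPTokenizer: ...


# Assumed shape of init_dict: component name -> (library, class name).
init_dict = {"text_encoder": ("transformers", "CLIPTextModel")}

# A matching component passes silently.
check_passed_components(init_dict, {"text_encoder": CLIPTextModel()})

# A tokenizer passed as the text encoder is rejected.
try:
    check_passed_components(init_dict, {"text_encoder": CLIPTokenizer()})
    message = None
except ValueError as err:
    message = str(err)
```

Comparing class names rather than using `isinstance` keeps the check independent of imports, at the cost of rejecting valid subclasses.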
Let's add a test for this too?
Thanks @hlky! This looks leaner than I had thought. Thank you!
Thanks @sayakpaul. I've added a test, trimmed
    if key not in passed_class_obj or key == "scheduler":
        continue
If we pass `scheduler=text_encoder`, that should error out as well, right?
Added some special handling for scheduler
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
tests/pipelines/test_pipelines.py
Outdated
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
with self.assertRaises(ValueError) as error_context:
    _ = StableDiffusionPipeline.from_pretrained(
        "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer
    )

assert "Expected" in str(error_context.exception)
assert "text_encoder" in str(error_context.exception)
assert f"{tokenizer.__class__.__name__}" in str(error_context.exception)
Maybe also a check for the scheduler, as that is handled slightly differently?
Will add it. For context this is what we're handling:
diffusers/src/diffusers/schedulers/scheduling_utils.py
Lines 33 to 48 in d041dd5
class KarrasDiffusionSchedulers(Enum):
    DDIMScheduler = 1
    DDPMScheduler = 2
    PNDMScheduler = 3
    LMSDiscreteScheduler = 4
    EulerDiscreteScheduler = 5
    HeunDiscreteScheduler = 6
    EulerAncestralDiscreteScheduler = 7
    DPMSolverMultistepScheduler = 8
    DPMSolverSinglestepScheduler = 9
    KDPM2DiscreteScheduler = 10
    KDPM2AncestralDiscreteScheduler = 11
    DEISMultistepScheduler = 12
    UniPCMultistepScheduler = 13
    DPMSolverSDEScheduler = 14
    EDMEulerScheduler = 15
That's cool. But I don't see the flow matching schedulers here. So, if I do assign a text encoder to `scheduler` in an RF pipeline (`FluxPipeline`, for example), would it still work as expected?
Yes that also works, for pipelines like Flux we're getting the type
<class 'diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler'>
For SD etc we get the enum
<enum 'KarrasDiffusionSchedulers'>
[<KarrasDiffusionSchedulers.DDIMScheduler: 1>, <KarrasDiffusionSchedulers.DDPMScheduler: 2>, <KarrasDiffusionSchedulers.PNDMScheduler: 3>, <KarrasDiffusionSchedulers.LMSDiscreteScheduler: 4>, <KarrasDiffusionSchedulers.EulerDiscreteScheduler: 5>, <KarrasDiffusionSchedulers.HeunDiscreteScheduler: 6>, <KarrasDiffusionSchedulers.EulerAncestralDiscreteScheduler: 7>, <KarrasDiffusionSchedulers.DPMSolverMultistepScheduler: 8>, <KarrasDiffusionSchedulers.DPMSolverSinglestepScheduler: 9>, <KarrasDiffusionSchedulers.KDPM2DiscreteScheduler: 10>, <KarrasDiffusionSchedulers.KDPM2AncestralDiscreteScheduler: 11>, <KarrasDiffusionSchedulers.DEISMultistepScheduler: 12>, <KarrasDiffusionSchedulers.UniPCMultistepScheduler: 13>, <KarrasDiffusionSchedulers.DPMSolverSDEScheduler: 14>, <KarrasDiffusionSchedulers.EDMEulerScheduler: 15>]
So we apply the same processing (`str`, `split`; `strip` applies for the `type` case) to get a list of scheduler names:
['FlowMatchEulerDiscreteScheduler']
['DDIMScheduler', 'DDPMScheduler', 'PNDMScheduler', 'LMSDiscreteScheduler', 'EulerDiscreteScheduler', 'HeunDiscreteScheduler', 'EulerAncestralDiscreteScheduler', 'DPMSolverMultistepScheduler', 'DPMSolverSinglestepScheduler', 'KDPM2DiscreteScheduler', 'KDPM2AncestralDiscreteScheduler', 'DEISMultistepScheduler', 'UniPCMultistepScheduler', 'DPMSolverSDEScheduler', 'EDMEulerScheduler']
It will raise if the passed object is not a scheduler, or if it is the wrong type of scheduler.
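To make the `str` / `split` / `strip` processing concrete, here is a small sketch under stated assumptions. The enum is trimmed to two members, the flow-match class is an empty stand-in whose `__module__` is set by hand to mimic the repr quoted above, and the helper name `scheduler_names` is hypothetical.

```python
import enum


class KarrasDiffusionSchedulers(enum.Enum):
    # Trimmed stand-in for the enum quoted above.
    DDIMScheduler = 1
    DDPMScheduler = 2


class FlowMatchEulerDiscreteScheduler:
    # Stand-in class; __module__ is set by hand so that str() matches the repr
    # shown in the discussion.
    __module__ = "diffusers.schedulers.scheduling_flow_match_euler_discrete"


def scheduler_names(expected_type):
    """Turn an enum of schedulers, or a single scheduler class, into class-name strings."""
    if isinstance(expected_type, enum.EnumMeta):
        candidates = list(expected_type)  # enum members
    else:
        candidates = [expected_type]  # a plain class
    # "KarrasDiffusionSchedulers.DDIMScheduler" -> "DDIMScheduler"
    # "<class 'pkg.mod.Name'>" -> "Name'>" -> "Name"
    return [str(c).split(".")[-1].strip("'>") for c in candidates]


karras_names = scheduler_names(KarrasDiffusionSchedulers)
flow_names = scheduler_names(FlowMatchEulerDiscreteScheduler)
```

The string processing relies on the repr containing a dotted module path, which is why a later suggestion in this thread reads the names from `__members__` and `__name__` instead.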
Thanks for explaining! Works for me.
We now also support `Union`. The context is the failed test `test_load_connected_checkpoint_with_passed_obj` for `KandinskyV22CombinedPipeline`; we also change the scheduler type to `Union[DDPMScheduler, UnCLIPScheduler]`. The test is actually for passing objects to submodels, but changing the scheduler is how that test works.
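One way to handle a `Union` annotation is to unpack it into its member types before comparing. The sketch below uses the standard `typing.get_origin` and `typing.get_args` helpers with stand-in scheduler classes; it illustrates the idea and is not necessarily how the PR extracts the types. The helper name `unpack_expected_types` is hypothetical.

```python
from typing import Union, get_args, get_origin


# Stand-in classes (assumption): the real ones live in diffusers.
class DDPMScheduler: ...
class UnCLIPScheduler: ...


def unpack_expected_types(annotation):
    """Return a tuple of accepted types for a plain or Union annotation."""
    if get_origin(annotation) is Union:
        return get_args(annotation)
    return (annotation,)


union_types = unpack_expected_types(Union[DDPMScheduler, UnCLIPScheduler])
single_type = unpack_expected_types(DDPMScheduler)

# Either scheduler is now acceptable for the annotated component.
accepts_unclip = isinstance(UnCLIPScheduler(), union_types)
```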
Tests for wrong scheduler are added.
Good to go from my side once the scheduler-related tests are added. Thanks!
Note that we add
scheduler_types.extend([str(scheduler_type)])
scheduler_types = [str(scheduler).split(".")[-1].strip("'>") for scheduler in scheduler_types]

for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
Since we already have the types extracted in `expected_types`, can't we fetch them using the key and then check if the passed object is an instance of the type? If the expected type is an enum, then we can check if the passed object's class name exists in the keys?
I think it might be better to make this check more agnostic to the component names.
We have a few pipelines with Union types on non-scheduler components (mostly AnimateDiff). So this snippet would fail even though it's valid, because init_dict is based on the model_index.json which doesn't support multiple types.
from diffusers import (
    AnimateDiffPipeline,
    UNetMotionModel,
)

unet = UNetMotionModel()
pipe = AnimateDiffPipeline.from_pretrained(
    "hf-internal-testing/tiny-sd-pipe", unet=unet
)
Enforcing scheduler types might be a breaking change (cc: @yiyixuxu). For example, using DDIM with Kandinsky is currently valid, but with this change any downstream code doing this would break. It would be good to enforce it on the pipelines with flow-based schedulers though? (Perhaps via a new Enum.)
I would try something like:
for key, (_, expected_class_name) in zip(init_dict.keys(), init_dict.values()):
    if key not in passed_class_obj:
        continue

    class_obj = passed_class_obj[key]
    _expected_class_types = []
    for expected_type in expected_types[key]:
        if isinstance(expected_type, enum.EnumMeta):
            _expected_class_types.extend(expected_type.__members__.keys())
        else:
            _expected_class_types.append(expected_type.__name__)

    _is_valid_type = class_obj.__class__.__name__ in _expected_class_types
    if isinstance(class_obj, SchedulerMixin) and not _is_valid_type:
        # Handle case where scheduler is still valid
        # raise if scheduler is meant to be a Flow based scheduler?
        pass  # placeholder so the branch is syntactically valid
    elif not _is_valid_type:
        raise ValueError(f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}.")
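The name-collection step in this suggestion can be exercised on its own. The following is a sketch with stand-in classes, assuming `expected_types` maps each component name to a tuple of accepted types (classes or enums), as the loop over `expected_types[key]` implies. The helper names are hypothetical.

```python
import enum


# Stand-in classes (assumption): the real ones live in diffusers.
class SchedulerMixin: ...
class DDIMScheduler(SchedulerMixin): ...
class UNet2DConditionModel: ...
class UNetMotionModel: ...


class KarrasDiffusionSchedulers(enum.Enum):
    # Trimmed stand-in for the real enum.
    DDIMScheduler = 1
    DDPMScheduler = 2


def expected_class_names(types_for_key):
    """Flatten classes and enums of class names into one list of names."""
    names = []
    for expected_type in types_for_key:
        if isinstance(expected_type, enum.EnumMeta):
            names.extend(expected_type.__members__.keys())
        else:
            names.append(expected_type.__name__)
    return names


def is_valid_component(obj, types_for_key):
    return obj.__class__.__name__ in expected_class_names(types_for_key)


# Assumed shape: component name -> tuple of accepted types.
expected_types = {
    "unet": (UNet2DConditionModel, UNetMotionModel),
    "scheduler": (KarrasDiffusionSchedulers,),
}

unet_ok = is_valid_component(UNetMotionModel(), expected_types["unet"])
scheduler_ok = is_valid_component(DDIMScheduler(), expected_types["scheduler"])
wrong_slot = is_valid_component(UNetMotionModel(), expected_types["scheduler"])
```

Because the names come from the type hints rather than from `model_index.json`, a `Union` on a non-scheduler component, such as the AnimateDiff case mentioned above, is accepted.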
Added this for the scheduler:
_requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types)
_is_flow_match = "FlowMatch" in class_obj.__class__.__name__
if _requires_flow_match and not _is_flow_match:
    raise ValueError(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.")
elif not _requires_flow_match and _is_flow_match:
    raise ValueError(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.")
I think we don't need a value error here, a warning is enough, no?
A warning should be sufficient; it's mainly for the situation in #10093 (comment), where the wrong text encoder is given and the resulting error is uninformative.
let's do a warning then:)
Just chiming in here to share a perspective as a user (not a strong opinion). Related to #10189 (comment).
Here, if an unexpected module is passed, we raise a ValueError. I think this check is along similar lines: users are passing components that are unexpected or incompatible. We probably cannot predict the consequences of allowing the load to proceed without raising an error, but if we raise one, users will know what to do to fix the incorrect behaviour.
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
with self.assertRaises(ValueError) as error_context:
    _ = StableDiffusionPipeline.from_pretrained(
        "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer
    )

assert "is of type" in str(error_context.exception)
assert "but should be" in str(error_context.exception)
We're now using a warning, but for this case (`CLIPTokenizer` in `text_encoder`) we still get a `ValueError` later on from here:
diffusers/src/diffusers/pipelines/pipeline_utils.py
Lines 893 to 897 in 6324340
# if the model is in a pipeline module, then we load it from the pipeline
# check that passed_class_obj has correct parent class
maybe_raise_or_warn(
    library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
)
So it's a little inconsistent and needs further testing to determine which other cases this already applies to.
_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
if isinstance(class_obj, SchedulerMixin) and not _is_valid_type:
    _requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types)
I think checking against a `FlowMatchSchedulers` enum would be better in case we end up not using "FlowMatch" in the class name.
_requires_flow_match = any(class_type in FlowMatchSchedulers.__members__ for class_type in _expected_class_types)
_is_flow_match = class_obj.__class__.__name__ in FlowMatchSchedulers.__members__
cc: @yiyixuxu
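A sketch of what such an enum-based check could look like follows. `FlowMatchSchedulers` is only proposed in this thread, so the enum and its members are assumptions for illustration, and the scheduler classes are empty stand-ins. Membership is tested against `__members__` because the class name is a string rather than an enum member.

```python
from enum import Enum


class FlowMatchSchedulers(Enum):
    # Hypothetical enum, mirroring KarrasDiffusionSchedulers; members are illustrative.
    FlowMatchEulerDiscreteScheduler = 1
    FlowMatchHeunDiscreteScheduler = 2


# Stand-in classes (assumption): the real ones live in diffusers.
class FlowMatchEulerDiscreteScheduler: ...
class DDIMScheduler: ...


def is_flow_match(scheduler):
    """True if the scheduler's class name is listed in the enum."""
    return scheduler.__class__.__name__ in FlowMatchSchedulers.__members__


def requires_flow_match(expected_class_names):
    """True if any expected class name is a flow-match scheduler."""
    return any(name in FlowMatchSchedulers.__members__ for name in expected_class_names)


flux_expected = ["FlowMatchEulerDiscreteScheduler"]
mismatch = requires_flow_match(flux_expected) and not is_flow_match(DDIMScheduler())
match = requires_flow_match(flux_expected) and is_flow_match(FlowMatchEulerDiscreteScheduler())
```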
_requires_flow_match = any("FlowMatch" in class_type for class_type in _expected_class_types)
_is_flow_match = "FlowMatch" in class_obj.__class__.__name__
if _requires_flow_match and not _is_flow_match:
    logger.warning(f"Expected FlowMatch scheduler, got {class_obj.__class__.__name__}.")
Probably okay to raise an error here because `scheduler.scale_noise` would raise an error in the flow matching pipelines if a non-FlowMatch scheduler is used.
I really don't want to raise any error here, because type hints were not something enforced in this library, and it is hard even for us to tell which schedulers can or cannot be used.
E.g. Kandinsky: if memory serves, I think DDIM may also work with some of the pipelines, and the compatibility may change.
    elif not _requires_flow_match and _is_flow_match:
        logger.warning(f"Expected non-FlowMatch scheduler, got {class_obj.__class__.__name__}.")
elif not _is_valid_type:
    logger.warning(
I think if it's not a scheduler and the types don't match it's okay to raise an error. I think it would break in the model loading step anyway in this case. wdyt @yiyixuxu?
I prefer a warning because:
- I think there is very little or no benefit in raising an error vs. a warning here
- in case we make a mistake in a type hint, we would throw an error by mistake
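The warning-based variant argued for above can be sketched as follows. This is an illustration under assumptions, not the merged code: the logger name, the helper `warn_on_unexpected_type`, and the message wording are hypothetical, and the component classes are stand-ins.

```python
import logging

logger = logging.getLogger("pipeline_type_check")  # hypothetical logger name


# Stand-in classes (assumption): the real ones live in transformers.
class CLIPTextModel: ...
class CLIPTokenizer: ...


def warn_on_unexpected_type(key, obj, expected_class_names):
    """Log a warning instead of raising when the class name is unexpected.

    Returns True when the type matches, False when a warning was logged.
    """
    class_name = obj.__class__.__name__
    if class_name in expected_class_names:
        return True
    logger.warning(f"Expected types for {key}: {expected_class_names}, got {class_name}.")
    return False


ok = warn_on_unexpected_type("text_encoder", CLIPTextModel(), ["CLIPTextModel"])
flagged = warn_on_unexpected_type("text_encoder", CLIPTokenizer(), ["CLIPTextModel"])
```

With a warning, a wrong type hint in a release only produces a spurious log line, whereas with an error it would block loading until a fix ships.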
We just added `use_flow_sigma` to a few non-flow-match schedulers with the SANA PR, and we also plan to refactor them but don't have a design finalized yet.
Given that, I think maybe we can skip checking the scheduler altogether for now and revisit later. Let me know what you guys think!
I've removed the scheduler-related changes for now; I think we can revisit that later. As @yiyixuxu mentioned above, type hints haven't been strictly enforced, so there are probably some missing or wrong ones, especially for schedulers. A warning is better because of that too: if a wrong type hint makes its way into a release, we'd have to issue a hotfix release to fix it, which just creates headaches and issue reports.
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
thanks @hlky !
Thanks for taking it up, @hlky! To me it's a really nice QoL improvement from a DX perspective.
* Check correct model type is passed to `from_pretrained`
* Flax, skip scheduler
* test_wrong_model
* Fix for scheduler
* Update tests/pipelines/test_pipelines.py
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* EnumMeta
* Flax
* scheduler in expected types
* make
* type object 'CLIPTokenizer' has no attribute '_PipelineFastTests__name'
* support union
* fix typing in kandinsky
* make
* add LCMScheduler
* 'LCMScheduler' object has no attribute 'sigmas'
* tests for wrong scheduler
* make
* update
* warning
* tests
* Update src/diffusers/pipelines/pipeline_utils.py
  Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
* import FlaxSchedulerMixin
* skip scheduler
---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
What does this PR do?
Example
Fixes #10093
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
cc @sayakpaul