-
Notifications
You must be signed in to change notification settings - Fork 403
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix infer task for stable diffusion #1793
Fix infer task for stable diffusion #1793
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
optimum/exporters/tasks.py
Outdated
class_name = model_info.config["diffusers"].get("class_name", None) or model_info.config[ | ||
"diffusers" | ||
].get("_class_name", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class_name = model_info.config["diffusers"].get("class_name", None) or model_info.config[ | |
"diffusers" | |
].get("_class_name", None) | |
if hasattr(model_info.config["diffusers"], "class_name"): | |
class_name = model_info.config["diffusers"]["class_name"] | |
elif hasattr(model_info.config["diffusers"], "_class_name"): | |
class_name = model_info.config["diffusers"]["_class_name"] | |
else: | |
raise ValueError(f"Could not automatically infer the class name for {model_name_or_path}. Please open an issue at https://github.com/huggingface/optimum/issues.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fxmarty model_info.config["diffusers"]
is a dictionary and class_name
or _class_name
are keys instead of attributes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
feel free to edit accordingly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JingyaHuang Is the suggestion fine as is in the end?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's good.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you
* fix * apply suggestions
What does this PR do?
The current task inference is not working for stable diffusion checkpoints, eg.
runwayml/stable-diffusion-v1-5
p.s. I did not see checkpoints using
class_name
instead of_class_name
, if there is not we shall remove it (keep it for now).Before submitting