Skip to content

Commit

Permalink
Fix infer task for stable diffusion (#1793)
Browse files Browse the repository at this point in the history
* fix

* apply suggestions
  • Loading branch information
JingyaHuang committed Apr 5, 2024
1 parent 253c6c2 commit 5584eb8
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1560,7 +1560,14 @@ def _infer_task_from_model_name_or_path(
library_name = TasksManager.infer_library_from_model(model_name_or_path, subfolder, revision)

if library_name == "diffusers":
class_name = model_info.config["diffusers"]["class_name"]
if model_info.config["diffusers"].get("class_name", None):
class_name = model_info.config["diffusers"]["class_name"]
elif model_info.config["diffusers"].get("_class_name", None):
class_name = model_info.config["diffusers"]["_class_name"]
else:
raise ValueError(
f"Could not automatically infer the class name for {model_name_or_path}. Please open an issue at https://github.com/huggingface/optimum/issues."
)
inferred_task_name = "stable-diffusion-xl" if "StableDiffusionXL" in class_name else "stable-diffusion"
elif library_name == "timm":
inferred_task_name = "image-classification"
Expand Down

0 comments on commit 5584eb8

Please sign in to comment.