-
Notifications
You must be signed in to change notification settings - Fork 45
feat: refactor pipeline registry and support external plugin registra… #254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| import json | ||
| import os | ||
|
|
||
| from diffsynth_engine.pipelines.base import Pipeline | ||
| from diffsynth_engine.plugins import load_general_plugins | ||
| from diffsynth_engine.utils import logging | ||
| from diffsynth_engine.utils.constants import MODEL_INDEX_NAME | ||
| from diffsynth_engine.utils.import_utils import LazyImport | ||
|
|
||
| logger = logging.get_logger(__name__) | ||
|
|
||
| _DIFFSYNTH_PIPELINES: dict[str, str] = { | ||
| "QwenImagePipeline": "diffsynth_engine.pipelines.qwen_image.pipeline_qwenimage:QwenImagePipeline", | ||
| "QwenImageEditPipeline": "diffsynth_engine.pipelines.qwen_image.pipeline_qwenimage_edit:QwenImageEditPipeline", | ||
| "QwenImageEditPlusPipeline": "diffsynth_engine.pipelines.qwen_image.pipeline_qwenimage_edit_plus:QwenImageEditPlusPipeline", | ||
| "QwenImageLayeredPipeline": "diffsynth_engine.pipelines.qwen_image.pipeline_qwenimage_layered:QwenImageLayeredPipeline", | ||
| } | ||
|
|
||
| PIPELINE_REGISTRY: dict[str, LazyImport] = {} | ||
| _registry_initialized = False | ||
|
|
||
|
|
||
| def register_pipeline(name: str, target: str) -> None: | ||
| """Register a pipeline for lazy import. | ||
|
|
||
| `target` must be a "module_name:class_name" string. The class is not | ||
| imported until `get_pipeline_class(name)` is called. | ||
| """ | ||
| module_name, class_name = target.split(":", 1) | ||
| PIPELINE_REGISTRY[name] = LazyImport(module_name, class_name) | ||
|
|
||
|
|
||
| def _register_builtin_pipelines() -> None: | ||
| for name, pipeline_cls in _DIFFSYNTH_PIPELINES.items(): | ||
| register_pipeline(name, pipeline_cls) | ||
|
|
||
|
|
||
| def get_pipeline_class_name(model_path: str) -> str: | ||
| model_index_path = os.path.join(model_path, MODEL_INDEX_NAME) | ||
| if not os.path.exists(model_index_path): | ||
| raise FileNotFoundError(f"Model index file not found: {model_index_path}") | ||
|
|
||
| with open(model_index_path, "r", encoding="utf-8") as f: | ||
| model_index = json.load(f) | ||
|
|
||
| if "_class_name" not in model_index: | ||
| raise KeyError(f"_class_name field not found in {model_index_path}") | ||
|
|
||
| return model_index["_class_name"] | ||
|
|
||
|
|
||
| def get_pipeline_class(name: str) -> type[Pipeline]: | ||
| global _registry_initialized | ||
| if not _registry_initialized: | ||
| _registry_initialized = True | ||
| _register_builtin_pipelines() | ||
| load_general_plugins() | ||
| if name not in PIPELINE_REGISTRY: | ||
| raise ValueError(f"Pipeline class {name!r} not found. Available pipelines: {sorted(PIPELINE_REGISTRY)}") | ||
| return PIPELINE_REGISTRY[name].load() | ||
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| # References: | ||
| # - https://github.com/vllm-project/vllm-omni/blob/v0.20.0/vllm_omni/plugins/__init__.py | ||
|
|
||
| from collections.abc import Callable | ||
| from importlib.metadata import entry_points | ||
|
|
||
| from diffsynth_engine.utils import logging | ||
|
|
||
| logger = logging.get_logger(__name__) | ||
|
|
||
| # entry point group name | ||
| DIFFSYNTH_DEFAULT_PLUGINS_GROUP = "diffsynth_engine.general_plugins" | ||
|
|
||
| _plugins_loaded = False | ||
|
|
||
|
|
||
| def load_plugins_by_group(group: str) -> dict[str, Callable]: | ||
| """Discover external plugins via entry points.""" | ||
| plugins: dict[str, Callable] = {} | ||
| for ep in entry_points(group=group): | ||
| try: | ||
| func = ep.load() | ||
| plugins[ep.name] = func | ||
| except Exception: | ||
| logger.exception("Failed to load plugin %r (%s)", ep.name, ep.value) | ||
| continue | ||
| logger.info("Loaded plugin %r from %s", ep.name, ep.value) | ||
|
|
||
| return plugins | ||
|
|
||
|
|
||
| def load_general_plugins() -> None: | ||
| global _plugins_loaded | ||
| if _plugins_loaded: | ||
| return | ||
| _plugins_loaded = True | ||
|
|
||
| plugins = load_plugins_by_group(DIFFSYNTH_DEFAULT_PLUGINS_GROUP) | ||
| # execute the loaded functions of general plugins | ||
| for name, func in plugins.items(): | ||
| try: | ||
| func() | ||
| logger.info("Executed general plugin %r", name) | ||
| except Exception: | ||
| logger.warning("Failed to execute general plugin %r", name, exc_info=True) | ||
|
Comment on lines
+32
to
+45
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a potential race condition here in multi-threaded environments (e.g., when running under a multi-threaded web server). Setting We should use a def load_general_plugins() -> None:
global _plugins_loaded
if _plugins_loaded:
return
import threading
lock = getattr(load_general_plugins, "_lock", None)
if lock is None:
lock = threading.Lock()
setattr(load_general_plugins, "_lock", lock)
with lock:
if _plugins_loaded:
return
plugins = load_plugins_by_group(DIFFSYNTH_DEFAULT_PLUGINS_GROUP)
# execute the loaded functions of general plugins
for name, func in plugins.items():
try:
func()
logger.info("Executed general plugin %r", name)
except Exception:
logger.warning("Failed to execute general plugin %r", name, exc_info=True)
_plugins_loaded = True |
||
Uh oh!
There was an error while loading. Please reload this page.