From bb055ae21428a5cd1e1c9b0b38ec7aad2068b86d Mon Sep 17 00:00:00 2001 From: molepi40 Date: Fri, 5 Jun 2026 17:59:59 +0800 Subject: [PATCH 1/2] feat: refactor pipeline registry and support external plugin registration - Replace utils.py with dedicated registry module for pipeline management - Add plugins.py for external pipeline registration via exposed function - Lazy import pipeline to improve startup performance - Update engine and worker to use new registry interface --- diffsynth_engine/engine.py | 2 +- diffsynth_engine/pipelines/registry.py | 57 +++++++++++++++++++++ diffsynth_engine/pipelines/utils.py | 69 -------------------------- diffsynth_engine/plugins.py | 45 +++++++++++++++++ diffsynth_engine/utils/import_utils.py | 6 +-- diffsynth_engine/worker.py | 4 +- 6 files changed, 108 insertions(+), 75 deletions(-) create mode 100644 diffsynth_engine/pipelines/registry.py delete mode 100644 diffsynth_engine/pipelines/utils.py create mode 100644 diffsynth_engine/plugins.py diff --git a/diffsynth_engine/engine.py b/diffsynth_engine/engine.py index 767d3ac..03a9a95 100644 --- a/diffsynth_engine/engine.py +++ b/diffsynth_engine/engine.py @@ -4,7 +4,7 @@ from torch.cuda import set_device from diffsynth_engine.configs import PipelineConfig -from diffsynth_engine.pipelines.utils import ( +from diffsynth_engine.pipelines.registry import ( get_pipeline_class, get_pipeline_class_name, ) diff --git a/diffsynth_engine/pipelines/registry.py b/diffsynth_engine/pipelines/registry.py new file mode 100644 index 0000000..9867384 --- /dev/null +++ b/diffsynth_engine/pipelines/registry.py @@ -0,0 +1,57 @@ +import json +import os + +from diffsynth_engine.pipelines.base import Pipeline +from diffsynth_engine.plugins import load_general_plugins +from diffsynth_engine.utils import logging +from diffsynth_engine.utils.constants import MODEL_INDEX_NAME +from diffsynth_engine.utils.import_utils import LazyImport + +logger = logging.get_logger(__name__) + +_DIFFSYNTH_PIPELINES: dict[str, str] = { + "QwenImagePipeline": "diffsynth_engine.pipelines.qwen_image.pipeline_qwenimage:QwenImagePipeline", + "QwenImageEditPipeline": "diffsynth_engine.pipelines.qwen_image.pipeline_qwenimage_edit:QwenImageEditPipeline", + "QwenImageEditPlusPipeline": "diffsynth_engine.pipelines.qwen_image.pipeline_qwenimage_edit_plus:QwenImageEditPlusPipeline", + "QwenImageLayeredPipeline": "diffsynth_engine.pipelines.qwen_image.pipeline_qwenimage_layered:QwenImageLayeredPipeline", +} + +PIPELINE_REGISTRY: dict[str, LazyImport] = {} + + +def register_pipeline(name: str, target: str) -> None: + """Register a pipeline for lazy import. + + `target` must be a "module_name:class_name" string. The class is not + imported until `get_pipeline_class(name)` is called. + """ + module_name, class_name = target.split(":", 1) + PIPELINE_REGISTRY[name] = LazyImport(module_name, class_name) + + +def _register_builtin_pipelines() -> None: + for name, pipeline_cls in _DIFFSYNTH_PIPELINES.items(): + register_pipeline(name, pipeline_cls) + + +def get_pipeline_class_name(model_path: str) -> str: + model_index_path = os.path.join(model_path, MODEL_INDEX_NAME) + if not os.path.exists(model_index_path): + raise FileNotFoundError(f"Model index file not found: {model_index_path}") + + with open(model_index_path, "r", encoding="utf-8") as f: + model_index = json.load(f) + + if "_class_name" not in model_index: + raise KeyError(f"_class_name field not found in {model_index_path}") + + return model_index["_class_name"] + + +def get_pipeline_class(name: str) -> type[Pipeline]: + if not PIPELINE_REGISTRY: + _register_builtin_pipelines() + load_general_plugins() + if name not in PIPELINE_REGISTRY: + raise ValueError(f"Pipeline class {name!r} not found. Available pipelines: {sorted(PIPELINE_REGISTRY)}") + return PIPELINE_REGISTRY[name].load() diff --git a/diffsynth_engine/pipelines/utils.py b/diffsynth_engine/pipelines/utils.py deleted file mode 100644 index 4b0f013..0000000 --- a/diffsynth_engine/pipelines/utils.py +++ /dev/null @@ -1,69 +0,0 @@ -import importlib -import json -import os -import pkgutil -from typing import Dict, Type - -from diffsynth_engine.pipelines.base import Pipeline -from diffsynth_engine.utils import logging -from diffsynth_engine.utils.constants import MODEL_INDEX_NAME - -logger = logging.get_logger(__name__) - - -def _build_pipeline_class_map() -> Dict[str, str]: - pipeline_class_map = {} - module = importlib.import_module("diffsynth_engine.pipelines") - - for _, name, ispkg in pkgutil.iter_modules(module.__path__, "diffsynth_engine.pipelines."): - if not ispkg: - continue - - try: - submodule = importlib.import_module(name) - if not hasattr(submodule, "__all__"): - continue - - for class_name in submodule.__all__: - if not hasattr(submodule, class_name): - continue - - cls = getattr(submodule, class_name) - if isinstance(cls, type) and issubclass(cls, Pipeline): - pipeline_class_map[class_name] = name - except (ImportError, AttributeError, TypeError) as e: - logger.warning(f"Failed to import {name}: {e}", exc_info=True) - continue - - return pipeline_class_map - - -_PIPELINE_CLASS_MAP = _build_pipeline_class_map() - - -def get_pipeline_class_name(model_path: str) -> str: - model_index_path = os.path.join(model_path, MODEL_INDEX_NAME) - if not os.path.exists(model_index_path): - raise FileNotFoundError(f"Model index file not found: {model_index_path}") - - with open(model_index_path, "r", encoding="utf-8") as f: - model_index = json.load(f) - - if "_class_name" not in model_index: - raise KeyError(f"_class_name field not found in {model_index_path}") - - return model_index["_class_name"] - - -def get_pipeline_class(pipeline_class_name: str) -> Type[Pipeline]: - if pipeline_class_name in _PIPELINE_CLASS_MAP: - module_path = _PIPELINE_CLASS_MAP[pipeline_class_name] - module = importlib.import_module(module_path) - if hasattr(module, pipeline_class_name): - pipeline_class = getattr(module, pipeline_class_name) - if not issubclass(pipeline_class, Pipeline): - raise ValueError(f"Class {pipeline_class_name} from {module_path} is not a subclass of Pipeline") - return pipeline_class - raise ValueError( - f"Pipeline class '{pipeline_class_name}' not found. Available pipelines: {list(_PIPELINE_CLASS_MAP.keys())}" - ) diff --git a/diffsynth_engine/plugins.py b/diffsynth_engine/plugins.py new file mode 100644 index 0000000..9a1f576 --- /dev/null +++ b/diffsynth_engine/plugins.py @@ -0,0 +1,45 @@ +# References: +# - https://github.com/vllm-project/vllm-omni/blob/v0.20.0/vllm_omni/plugins/__init__.py + +from collections.abc import Callable +from importlib.metadata import entry_points + +from diffsynth_engine.utils import logging + +logger = logging.get_logger(__name__) + +# entry point group name +DIFFSYNTH_DEFAULT_PLUGINS_GROUP = "diffsynth_engine.general_plugins" + +_plugins_loaded = False + + +def load_plugins_by_group(group: str) -> dict[str, Callable]: + """Discover external plugins via entry points.""" + plugins: dict[str, Callable] = {} + for ep in entry_points(group=group): + try: + func = ep.load() + plugins[ep.name] = func + except Exception: + logger.exception("Failed to load plugin %r (%s)", ep.name, ep.value) + continue + logger.info("Loaded plugin %r from %s", ep.name, ep.value) + + return plugins + + +def load_general_plugins() -> None: + global _plugins_loaded + if _plugins_loaded: + return + _plugins_loaded = True + + plugins = load_plugins_by_group(DIFFSYNTH_DEFAULT_PLUGINS_GROUP) + # execute the loaded functions of general plugins + for name, func in plugins.items(): + try: + func() + logger.info("Executed general plugin %r", name) + except Exception: + logger.warning("Failed to execute general plugin %r", name, exc_info=True) diff --git a/diffsynth_engine/utils/import_utils.py b/diffsynth_engine/utils/import_utils.py index 8c7ce7e..82fd271 100644 --- a/diffsynth_engine/utils/import_utils.py +++ b/diffsynth_engine/utils/import_utils.py @@ -9,16 +9,16 @@ def __init__(self, module_name: str, class_name: str): self.class_name = class_name self._module = None - def _load(self): + def load(self): if self._module is None: module = importlib.import_module(self.module_name) self._module = getattr(module, self.class_name) return self._module def __getattr__(self, name: str): - module = self._load() + module = self.load() return getattr(module, name) def __call__(self, *args, **kwargs): - module = self._load() + module = self.load() return module(*args, **kwargs) diff --git a/diffsynth_engine/worker.py b/diffsynth_engine/worker.py index 06e0d3a..f6b4471 100644 --- a/diffsynth_engine/worker.py +++ b/diffsynth_engine/worker.py @@ -9,7 +9,7 @@ init_distributed_environment, initialize_model_parallel, ) -from diffsynth_engine.pipelines.utils import get_pipeline_class +from diffsynth_engine.pipelines.registry import get_pipeline_class from diffsynth_engine.utils import logging from diffsynth_engine.utils.torch_profiler import TorchProfiler @@ -69,7 +69,7 @@ def stop_profile(self, **kwargs): result = TorchProfiler.stop() get_world_group().barrier() return result - + def __getattr__(self, name): pipeline = self.__dict__.get("pipeline") if pipeline is None: From c555d37ee18507cb4327734f4dced77dd58986ef Mon Sep 17 00:00:00 2001 From: molepi40 Date: Fri, 5 Jun 2026 18:14:38 +0800 Subject: [PATCH 2/2] use explicit registry initialization --- diffsynth_engine/pipelines/registry.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/diffsynth_engine/pipelines/registry.py b/diffsynth_engine/pipelines/registry.py index 9867384..d3f7ac5 100644 --- a/diffsynth_engine/pipelines/registry.py +++ b/diffsynth_engine/pipelines/registry.py @@ -17,6 +17,7 @@ } PIPELINE_REGISTRY: dict[str, LazyImport] = {} +_registry_initialized = False def register_pipeline(name: str, target: str) -> None: @@ -49,7 +50,9 @@ def get_pipeline_class_name(model_path: str) -> str: def get_pipeline_class(name: str) -> type[Pipeline]: - if not PIPELINE_REGISTRY: + global _registry_initialized + if not _registry_initialized: + _registry_initialized = True _register_builtin_pipelines() load_general_plugins() if name not in PIPELINE_REGISTRY: