Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion diffsynth_engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.cuda import set_device

from diffsynth_engine.configs import PipelineConfig
from diffsynth_engine.pipelines.utils import (
from diffsynth_engine.pipelines.registry import (
get_pipeline_class,
get_pipeline_class_name,
)
Expand Down
60 changes: 60 additions & 0 deletions diffsynth_engine/pipelines/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import json
import os

from diffsynth_engine.pipelines.base import Pipeline
from diffsynth_engine.plugins import load_general_plugins
from diffsynth_engine.utils import logging
from diffsynth_engine.utils.constants import MODEL_INDEX_NAME
from diffsynth_engine.utils.import_utils import LazyImport

logger = logging.get_logger(__name__)

_DIFFSYNTH_PIPELINES: dict[str, str] = {
"QwenImagePipeline": "diffsynth_engine.pipelines.qwen_image.pipeline_qwenimage:QwenImagePipeline",
"QwenImageEditPipeline": "diffsynth_engine.pipelines.qwen_image.pipeline_qwenimage_edit:QwenImageEditPipeline",
"QwenImageEditPlusPipeline": "diffsynth_engine.pipelines.qwen_image.pipeline_qwenimage_edit_plus:QwenImageEditPlusPipeline",
"QwenImageLayeredPipeline": "diffsynth_engine.pipelines.qwen_image.pipeline_qwenimage_layered:QwenImageLayeredPipeline",
}

PIPELINE_REGISTRY: dict[str, LazyImport] = {}
_registry_initialized = False


def register_pipeline(name: str, target: str) -> None:
"""Register a pipeline for lazy import.

`target` must be a "module_name:class_name" string. The class is not
imported until `get_pipeline_class(name)` is called.
"""
module_name, class_name = target.split(":", 1)
PIPELINE_REGISTRY[name] = LazyImport(module_name, class_name)


def _register_builtin_pipelines() -> None:
for name, pipeline_cls in _DIFFSYNTH_PIPELINES.items():
register_pipeline(name, pipeline_cls)


def get_pipeline_class_name(model_path: str) -> str:
model_index_path = os.path.join(model_path, MODEL_INDEX_NAME)
if not os.path.exists(model_index_path):
raise FileNotFoundError(f"Model index file not found: {model_index_path}")

with open(model_index_path, "r", encoding="utf-8") as f:
model_index = json.load(f)

if "_class_name" not in model_index:
raise KeyError(f"_class_name field not found in {model_index_path}")

return model_index["_class_name"]


def get_pipeline_class(name: str) -> type[Pipeline]:
global _registry_initialized
if not _registry_initialized:
_registry_initialized = True
_register_builtin_pipelines()
load_general_plugins()
if name not in PIPELINE_REGISTRY:
raise ValueError(f"Pipeline class {name!r} not found. Available pipelines: {sorted(PIPELINE_REGISTRY)}")
return PIPELINE_REGISTRY[name].load()
Comment thread
molepi40 marked this conversation as resolved.
69 changes: 0 additions & 69 deletions diffsynth_engine/pipelines/utils.py

This file was deleted.

45 changes: 45 additions & 0 deletions diffsynth_engine/plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# References:
# - https://github.com/vllm-project/vllm-omni/blob/v0.20.0/vllm_omni/plugins/__init__.py

from collections.abc import Callable
from importlib.metadata import entry_points

from diffsynth_engine.utils import logging

logger = logging.get_logger(__name__)

# entry point group name
DIFFSYNTH_DEFAULT_PLUGINS_GROUP = "diffsynth_engine.general_plugins"

_plugins_loaded = False


def load_plugins_by_group(group: str) -> dict[str, Callable]:
"""Discover external plugins via entry points."""
plugins: dict[str, Callable] = {}
for ep in entry_points(group=group):
try:
func = ep.load()
plugins[ep.name] = func
except Exception:
logger.exception("Failed to load plugin %r (%s)", ep.name, ep.value)
continue
logger.info("Loaded plugin %r from %s", ep.name, ep.value)

return plugins


def load_general_plugins() -> None:
global _plugins_loaded
if _plugins_loaded:
return
_plugins_loaded = True

plugins = load_plugins_by_group(DIFFSYNTH_DEFAULT_PLUGINS_GROUP)
# execute the loaded functions of general plugins
for name, func in plugins.items():
try:
func()
logger.info("Executed general plugin %r", name)
except Exception:
logger.warning("Failed to execute general plugin %r", name, exc_info=True)
Comment on lines +32 to +45
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is a potential race condition here in multi-threaded environments (e.g., when running under a multi-threaded web server).

Setting _plugins_loaded = True at the very beginning of the function means that any concurrent thread calling load_general_plugins() will return immediately, even though the first thread is still in the middle of loading and executing the plugins. This can cause the second thread to attempt to use pipelines or plugins before they are fully registered.

We should use a threading.Lock to ensure thread-safe initialization, and only set _plugins_loaded = True after the plugins have been successfully loaded and executed.

def load_general_plugins() -> None:
    global _plugins_loaded
    if _plugins_loaded:
        return
    import threading
    lock = getattr(load_general_plugins, "_lock", None)
    if lock is None:
        lock = threading.Lock()
        setattr(load_general_plugins, "_lock", lock)
    with lock:
        if _plugins_loaded:
            return
        plugins = load_plugins_by_group(DIFFSYNTH_DEFAULT_PLUGINS_GROUP)
        # execute the loaded functions of general plugins
        for name, func in plugins.items():
            try:
                func()
                logger.info("Executed general plugin %r", name)
            except Exception:
                logger.warning("Failed to execute general plugin %r", name, exc_info=True)
        _plugins_loaded = True

6 changes: 3 additions & 3 deletions diffsynth_engine/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ def __init__(self, module_name: str, class_name: str):
self.class_name = class_name
self._module = None

def _load(self):
def load(self):
if self._module is None:
module = importlib.import_module(self.module_name)
self._module = getattr(module, self.class_name)
return self._module

def __getattr__(self, name: str):
module = self._load()
module = self.load()
return getattr(module, name)

def __call__(self, *args, **kwargs):
module = self._load()
module = self.load()
return module(*args, **kwargs)
4 changes: 2 additions & 2 deletions diffsynth_engine/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
init_distributed_environment,
initialize_model_parallel,
)
from diffsynth_engine.pipelines.utils import get_pipeline_class
from diffsynth_engine.pipelines.registry import get_pipeline_class
from diffsynth_engine.utils import logging
from diffsynth_engine.utils.torch_profiler import TorchProfiler

Expand Down Expand Up @@ -69,7 +69,7 @@ def stop_profile(self, **kwargs):
result = TorchProfiler.stop()
get_world_group().barrier()
return result

def __getattr__(self, name):
pipeline = self.__dict__.get("pipeline")
if pipeline is None:
Expand Down