Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[enhancement]:Add queue interface #3444

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ venv/
!/web/extensions/logging.js.example
!/web/extensions/core/
/tests-ui/data/object_info.json
/user/
/user/
/queues/*
!/queues/example_queue.py.example
1 change: 1 addition & 0 deletions comfy/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __call__(self, parser, namespace, values, option_string=None):


parser = argparse.ArgumentParser()
parser.add_argument("--queue_name", type=str, help="Set the Prompt Queue.")

parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
Expand Down
4 changes: 3 additions & 1 deletion execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import nodes

import comfy.model_management
from prompt_queue_interface import PromptQueueInterface


def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}):
valid_inputs = class_def.INPUT_TYPES()
Expand Down Expand Up @@ -710,7 +712,7 @@ def validate_prompt(prompt):

MAXIMUM_HISTORY_SIZE = 10000

class PromptQueue:
class PromptQueue(PromptQueueInterface):
def __init__(self, server):
self.server = server
self.mutex = threading.RLock()
Expand Down
2 changes: 2 additions & 0 deletions folder_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])

folder_names_and_paths["queues"] = ([os.path.join(base_path, "queues")], [])

folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)

folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)
Expand Down
5 changes: 4 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import importlib.util
import folder_paths
import time
from prompt_queue_load import load_queues, create_prompt_queue


def execute_prestartup_script():
def execute_script(script_path):
Expand Down Expand Up @@ -204,7 +206,8 @@ def load_extra_path_config(yaml_path):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
server = server.PromptServer(loop)
q = execution.PromptQueue(server)
load_queues()
q = create_prompt_queue(server, args.queue_name)

extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
if os.path.isfile(extra_model_paths_config_path):
Expand Down
48 changes: 48 additions & 0 deletions prompt_queue_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from abc import ABC, abstractmethod


class PromptQueueInterface(ABC):
@abstractmethod
def __init__(self, server):
pass

@abstractmethod
def put(self, item):
pass

@abstractmethod
def get(self, timeout=None):
pass

@abstractmethod
def get_history(self, prompt_id=None, max_items=None, offset=-1):
pass

@abstractmethod
def delete_queue_item(self, function):
pass

@abstractmethod
def wipe_history(self):
pass

@abstractmethod
def delete_history_item(self, id_to_delete):
pass

@abstractmethod
def get_tasks_remaining(self):
pass

@abstractmethod
def get_current_queue(self):
pass

@abstractmethod
def wipe_queue(self):
pass

@abstractmethod
def set_flag(self, name, data):
pass

92 changes: 92 additions & 0 deletions prompt_queue_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import execution
import folder_paths
import os
import time
import sys
import logging
import traceback
import importlib

from prompt_queue_interface import PromptQueueInterface

QUEUE_CLASS_MAPPINGS = {
}


def check_queue_class(cls) -> bool:
if cls is None or not isinstance(cls, type) or not issubclass(cls, PromptQueueInterface):
return False
return True


def create_prompt_queue(server, queuq_name: str) -> PromptQueueInterface:
PromptQueueClass = QUEUE_CLASS_MAPPINGS.get(queuq_name, execution.PromptQueue)
return PromptQueueClass(server)


def load_queues():
base_names = set(QUEUE_CLASS_MAPPINGS.keys())
queue_import_times = []

paths = folder_paths.get_folder_paths("queues")
for queue_path in paths:
possible_modules = os.listdir(os.path.realpath(queue_path))

if "__pycache__" in possible_modules:
possible_modules.remove("__pycache__")

for possible_module in possible_modules:
module_path = os.path.join(queue_path, possible_module)

if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py":
continue

if module_path.endswith(".disabled"):
continue

time_before = time.perf_counter()
success = load_queue(module_path, base_names)
queue_import_times.append((time.perf_counter() - time_before, module_path, success))

if len(queue_import_times) > 0:
logging.info("\nImport times for queues:")
for n in sorted(queue_import_times):
if n[2]:
queue_import_times = ""
else:
queue_import_times = " (IMPORT FAILED)"
logging.info("{:6.1f} seconds{}: {}".format(n[0], queue_import_times, n[1]))
logging.info("")


def load_queue(module_path, ignore: set):
module_name = os.path.basename(module_path)

if os.path.isfile(module_path):
sp = os.path.splitext(module_path)
module_name = sp[0]

try:
logging.debug("Trying to load queue {}".format(module_path))

if os.path.isfile(module_path):
module_spec = importlib.util.spec_from_file_location(module_name, module_path)
else:
module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))

module = importlib.util.module_from_spec(module_spec)
sys.modules[module_name] = module
module_spec.loader.exec_module(module)

if hasattr(module, "QUEUE_CLASS_MAPPINGS") and getattr(module, "QUEUE_CLASS_MAPPINGS") is not None:
for name in module.QUEUE_CLASS_MAPPINGS:
if name not in ignore and check_queue_class(module.QUEUE_CLASS_MAPPINGS[name]):
QUEUE_CLASS_MAPPINGS[name] = module.QUEUE_CLASS_MAPPINGS[name]
return True
else:
logging.warning(f"Skip {module_path} module for queue due to the lack of QUEUE_CLASS_MAPPINGS.")
return False
except Exception as e:
logging.warning(traceback.format_exc())
logging.warning(f"Cannot import {module_path} module for queue: {e}")
return False
30 changes: 30 additions & 0 deletions queues/example_queue.py.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import json
import copy
from execution import PromptQueue, MAXIMUM_HISTORY_SIZE
from typing import Optional


class ExamplePromptQueue(PromptQueue): # Child类继承自Parent类
def task_done(self, item_id, outputs,
status: Optional['PromptQueue.ExecutionStatus']):
with self.mutex:
prompt = self.currently_running.pop(item_id)
if len(self.history) > MAXIMUM_HISTORY_SIZE:
self.history.pop(next(iter(self.history)))

status_dict: Optional[dict] = None
if status is not None:
status_dict = copy.deepcopy(status._asdict())
record = {
"prompt": prompt,
"outputs": copy.deepcopy(outputs),
'status': status_dict,
}
print(record)
self.history[prompt[1]] = record
self.server.queue_updated()


QUEUE_CLASS_MAPPINGS = {
"ExamplePromptQueue": ExamplePromptQueue
}
3 changes: 3 additions & 0 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import comfy.model_management

from app.user_manager import UserManager
from prompt_queue_interface import PromptQueueInterface


class BinaryEventTypes:
PREVIEW_IMAGE = 1
Expand Down Expand Up @@ -62,6 +64,7 @@ async def cors_middleware(request: web.Request, handler):
return cors_middleware

class PromptServer():
prompt_queue: PromptQueueInterface
def __init__(self, loop):
PromptServer.instance = self

Expand Down