From b9b7e260143d5688793dc3cd617a03842f78cfdf Mon Sep 17 00:00:00 2001
From: RunDevelopment
Date: Mon, 8 Jan 2024 16:38:44 +0100
Subject: [PATCH 1/2] Replace global execution options with node context
 parameter

---
 backend/src/api/__init__.py                   |   1 +
 backend/src/api/api.py                        | 105 ++----------------
 backend/src/api/node_check.py                 |  22 +++-
 backend/src/api/node_context.py               |  45 ++++++++
 backend/src/api/settings.py                   |  98 ++++++++++++++--
 backend/src/api/types.py                      |   1 +
 .../ncnn/processing/upscale_image.py          |  15 ++-
 .../src/packages/chaiNNer_ncnn/settings.py    |   6 +-
 .../onnx/processing/remove_background.py      |   5 +-
 .../onnx/processing/upscale_image.py          |   5 +-
 .../src/packages/chaiNNer_onnx/settings.py    |   6 +-
 .../chaiNNer_pytorch/pytorch/io/load_model.py |   8 +-
 .../pytorch/processing/guided_upscale.py      |   5 +-
 .../pytorch/processing/inpaint.py             |   5 +-
 .../pytorch/processing/upscale_image.py       |   5 +-
 .../pytorch/restoration/upscale_face.py       |   5 +-
 .../pytorch/utility/convert_to_ncnn.py        |   6 +-
 .../pytorch/utility/convert_to_onnx.py        |   6 +-
 .../pytorch/utility/interpolate_models.py     |   2 -
 .../src/packages/chaiNNer_pytorch/settings.py |   6 +-
 backend/src/process.py                        |  66 ++++++++++-
 backend/src/server.py                         |   8 +-
 backend/src/settings.py                       |  17 ---
 23 files changed, 287 insertions(+), 161 deletions(-)
 create mode 100644 backend/src/api/node_context.py
 delete mode 100644 backend/src/settings.py
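Notes (patch 1/2): a node opts into the new API by registering with
`node_context=True` and taking the context as its first parameter; the
executor then builds one context per schema id and hands every node of a
package the same cached `SettingsParser` instead of reading mutable global
state. A minimal usage sketch — the group, schema id, icon, and the
"use_cpu" key below are illustrative placeholders, not part of this patch:

    import numpy as np

    from api import NodeContext
    from nodes.properties.inputs import ImageInput
    from nodes.properties.outputs import ImageOutput


    @example_group.register(  # hypothetical group, for illustration only
        schema_id="chainner:example:passthrough",
        name="Passthrough",
        description="Returns the input image unchanged.",
        icon="ImExample",  # hypothetical icon name
        inputs=[ImageInput()],
        outputs=[ImageOutput()],
        node_context=True,  # without this flag, the executor does not pass `context`
    )
    def passthrough_node(context: NodeContext, img: np.ndarray) -> np.ndarray:
        # Package settings now come from the execution context.
        settings = context.settings
        use_cpu = settings.get_bool("use_cpu", False)  # illustrative key

        # The same context supports cooperative cancellation and progress.
        context.check_aborted()  # raises Aborted if the run was cancelled
        context.set_progress(1.0)  # progress events are still a TODO (see process.py)

        return img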
diff --git a/backend/src/api/__init__.py b/backend/src/api/__init__.py
index 062f95096..b9975e4e2 100644
--- a/backend/src/api/__init__.py
+++ b/backend/src/api/__init__.py
@@ -1,6 +1,7 @@
 from .api import *
 from .group import *
 from .input import *
+from .node_context import *
 from .output import *
 from .settings import *
 from .types import *
diff --git a/backend/src/api/api.py b/backend/src/api/api.py
index 7401b4e4a..f3432e1fe 100644
--- a/backend/src/api/api.py
+++ b/backend/src/api/api.py
@@ -8,10 +8,8 @@
     Callable,
     Generic,
     Iterable,
-    NewType,
     TypedDict,
     TypeVar,
-    Union,
 )
 
 from sanic.log import logger
@@ -29,8 +27,8 @@
     check_schema_types,
 )
 from .output import BaseOutput
-from .settings import SettingsJson, get_execution_options
-from .types import InputId, NodeId, NodeType, OutputId, RunFn
+from .settings import Setting
+from .types import FeatureId, InputId, NodeId, NodeType, OutputId, RunFn
 
 KB = 1024**1
 MB = 1024**2
@@ -124,6 +122,7 @@ class NodeData:
 
     side_effects: bool
    deprecated: bool
+    node_context: bool
     features: list[FeatureId]
 
     run: RunFn
@@ -180,6 +179,7 @@ def register(
         limited_to_8bpc: bool | str = False,
         iterator_inputs: list[IteratorInputInfo] | IteratorInputInfo | None = None,
         iterator_outputs: list[IteratorOutputInfo] | IteratorOutputInfo | None = None,
+        node_context: bool = False,
     ):
         if not isinstance(description, str):
             description = "\n\n".join(description)
@@ -233,7 +233,9 @@ def inner_wrapper(wrapped_func: T) -> T:
             if node_type == "regularNode":
                 run_check(
                     TYPE_CHECK_LEVEL,
-                    lambda _: check_schema_types(wrapped_func, p_inputs, p_outputs),
+                    lambda _: check_schema_types(
+                        wrapped_func, p_inputs, p_outputs, node_context
+                    ),
                 )
                 run_check(
                     NAME_CHECK_LEVEL,
@@ -258,6 +260,7 @@ def inner_wrapper(wrapped_func: T) -> T:
                 iterator_outputs=iterator_outputs,
                 side_effects=side_effects,
                 deprecated=deprecated,
+                node_context=node_context,
                 features=features,
                 run=wrapped_func,
             )
@@ -322,9 +325,6 @@ def to_dict(self):
         }
 
 
-FeatureId = NewType("FeatureId", str)
-
-
 @dataclass
 class Feature:
     id: str
@@ -366,89 +366,6 @@ def disabled(details: str | None = None) -> FeatureState:
         return FeatureState(is_enabled=False, details=details)
 
 
-@dataclass
-class ToggleSetting:
-    label: str
-    key: str
-    description: str
-    default: bool = False
-    disabled: bool = False
-    type: str = "toggle"
-
-
-class DropdownOption(TypedDict):
-    label: str
-    value: str
-
-
-@dataclass
-class DropdownSetting:
-    label: str
-    key: str
-    description: str
-    options: list[DropdownOption]
-    default: str
-    disabled: bool = False
-    type: str = "dropdown"
-
-
-@dataclass
-class NumberSetting:
-    label: str
-    key: str
-    description: str
-    min: float
-    max: float
-    default: float = 0
-    disabled: bool = False
-    type: str = "number"
-
-
-@dataclass
-class CacheSetting:
-    label: str
-    key: str
-    description: str
-    directory: str
-    default: str = ""
-    disabled: bool = False
-    type: str = "cache"
-
-
-Setting = Union[ToggleSetting, DropdownSetting, NumberSetting, CacheSetting]
-
-
-class SettingsParser:
-    def __init__(self, raw: SettingsJson) -> None:
-        self.__settings = raw
-
-    def get_bool(self, key: str, default: bool) -> bool:
-        value = self.__settings.get(key, default)
-        if isinstance(value, bool):
-            return value
-        raise ValueError(f"Invalid bool value for {key}: {value}")
-
-    def get_int(self, key: str, default: int, parse_str: bool = False) -> int:
-        value = self.__settings.get(key, default)
-        if parse_str and isinstance(value, str):
-            return int(value)
-        if isinstance(value, int) and not isinstance(value, bool):
-            return value
-        raise ValueError(f"Invalid str value for {key}: {value}")
-
-    def get_str(self, key: str, default: str) -> str:
-        value = self.__settings.get(key, default)
-        if isinstance(value, str):
-            return value
-        raise ValueError(f"Invalid str value for {key}: {value}")
-
-    def get_cache_location(self, key: str) -> str | None:
-        value = self.__settings.get(key)
-        if isinstance(value, str) or value is None:
-            return value or None
-        raise ValueError(f"Invalid cache location value for {key}: {value}")
-
-
 @dataclass
 class Package:
     where: str
@@ -501,9 +418,6 @@ def add_feature(
         self.features.append(feature)
         return feature
 
-    def get_settings(self) -> SettingsParser:
-        return SettingsParser(get_execution_options().get_package_settings(self.id))
-
 
 def _iter_py_files(directory: str):
     for root, _, files in os.walk(directory):
@@ -528,6 +442,9 @@ def __init__(self) -> None:
     def get_node(self, schema_id: str) -> NodeData:
         return self.nodes[schema_id][0]
 
+    def get_package(self, schema_id: str) -> Package:
+        return self.nodes[schema_id][1].category.package
+
     def add(self, package: Package) -> Package:
         # assert package.where not in self.packages
         self.packages[package.where] = package
diff --git a/backend/src/api/node_check.py b/backend/src/api/node_check.py
index c040a372f..2fb8a2955 100644
--- a/backend/src/api/node_check.py
+++ b/backend/src/api/node_check.py
@@ -4,10 +4,12 @@
 import inspect
 import os
 import pathlib
+from collections import OrderedDict
 from enum import Enum
 from typing import Any, Callable, NewType, Tuple, Union, cast, get_args
 
 from .input import BaseInput
+from .node_context import NodeContext
 from .output import BaseOutput
 
 _Ty = NewType("_Ty", object)
@@ -190,24 +192,38 @@ def check_schema_types(
     wrapped_func: Callable,
     inputs: list[BaseInput],
     outputs: list[BaseOutput],
+    node_context: bool,
 ):
     """
     Runtime validation for the number of inputs/outputs compared to the type args
     """
-    ann = get_type_annotations(wrapped_func)
+    ann = OrderedDict(get_type_annotations(wrapped_func))
 
     # check return type
     if "return" in ann:
         validate_return_type(ann.pop("return"), outputs)
 
-    # check inputs
-
+    # check arguments
     arg_spec = inspect.getfullargspec(wrapped_func)
 
     for arg in arg_spec.args:
         if arg not in ann:
             raise CheckFailedError(f"Missing type annotation for '{arg}'")
 
+    if node_context:
+        first = arg_spec.args[0]
+        if first != "context":
+            raise CheckFailedError(
+                f"Expected the first parameter to be 'context: NodeContext' but found '{first}'."
+            )
+        context_type = ann.pop(first)
+        if context_type != NodeContext:  # type: ignore
+            raise CheckFailedError(
+                f"Expected type of 'context' to be 'api.NodeContext' but found '{context_type}'"
+            )
+
+    # check inputs
+
     if arg_spec.varargs is not None:
         if arg_spec.varargs not in ann:
             raise CheckFailedError(f"Missing type annotation for '{arg_spec.varargs}'")
diff --git a/backend/src/api/node_context.py b/backend/src/api/node_context.py
new file mode 100644
index 000000000..ddbb7cfd1
--- /dev/null
+++ b/backend/src/api/node_context.py
@@ -0,0 +1,45 @@
+from abc import ABC, abstractmethod
+
+from .settings import SettingsParser
+
+
+class Aborted(Exception):
+    pass
+
+
+class NodeProgress(ABC):
+    @property
+    @abstractmethod
+    def aborted(self) -> bool:
+        """
+        Returns whether the current operation was aborted.
+        """
+
+    def check_aborted(self) -> None:
+        """
+        Raises an `Aborted` exception if the current operation was aborted. Does nothing otherwise.
+        """
+
+        if self.aborted:
+            raise Aborted()
+
+    @abstractmethod
+    def set_progress(self, progress: float) -> None:
+        """
+        Sets the progress of the current node execution. `progress` must be a value between 0 and 1.
+
+        Raises an `Aborted` exception if the current operation was aborted.
+        """
+
+
+class NodeContext(NodeProgress, ABC):
+    """
+    The execution context of the current node.
+    """
+
+    @property
+    @abstractmethod
+    def settings(self) -> SettingsParser:
+        """
+        Returns the settings of the current node execution.
+        """
diff --git a/backend/src/api/settings.py b/backend/src/api/settings.py
index 0534aa304..1f05b3881 100644
--- a/backend/src/api/settings.py
+++ b/backend/src/api/settings.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
-from typing import Dict
+from dataclasses import dataclass
+from typing import Dict, TypedDict, Union
 
 from sanic.log import logger
 
@@ -14,6 +15,7 @@ def __init__(
         backend_settings: JsonExecutionOptions,
     ) -> None:
         self.__settings = backend_settings
+        self.__parsers: dict[str, SettingsParser] = {}
 
         logger.info(f"Execution options: {self.__settings}")
 
@@ -21,19 +23,95 @@ def __init__(
     def parse(json: JsonExecutionOptions) -> ExecutionOptions:
         return ExecutionOptions(backend_settings=json)
 
-    def get_package_settings(self, package_id: str) -> SettingsJson:
+    def get_package_settings_json(self, package_id: str) -> SettingsJson:
         return self.__settings.get(package_id, {})
 
+    def get_package_settings(self, package_id: str) -> SettingsParser:
+        parser = self.__parsers.get(package_id)
+        if parser is None:
+            parser = SettingsParser(self.get_package_settings_json(package_id))
+            self.__parsers[package_id] = parser
+        return parser
 
-__global_exec_options = ExecutionOptions({})
 
+class SettingsParser:
+    def __init__(self, raw: SettingsJson) -> None:
+        self.__settings = raw
 
-def get_execution_options() -> ExecutionOptions:
-    return __global_exec_options
+    def get_bool(self, key: str, default: bool) -> bool:
+        value = self.__settings.get(key, default)
+        if isinstance(value, bool):
+            return value
+        raise ValueError(f"Invalid bool value for {key}: {value}")
 
+    def get_int(self, key: str, default: int, parse_str: bool = False) -> int:
+        value = self.__settings.get(key, default)
+        if parse_str and isinstance(value, str):
+            return int(value)
+        if isinstance(value, int) and not isinstance(value, bool):
+            return value
+        raise ValueError(f"Invalid int value for {key}: {value}")
 
-def set_execution_options(value: ExecutionOptions):
-    # TODO: Make the mutable global state unnecessary
-    # pylint: disable=global-statement
-    global __global_exec_options
-    __global_exec_options = value
+    def get_str(self, key: str, default: str) -> str:
+        value = self.__settings.get(key, default)
+        if isinstance(value, str):
+            return value
+        raise ValueError(f"Invalid str value for {key}: {value}")
+
+    def get_cache_location(self, key: str) -> str | None:
+        value = self.__settings.get(key)
+        if isinstance(value, str) or value is None:
+            return value or None
+        raise ValueError(f"Invalid cache location value for {key}: {value}")
+
+
+@dataclass
+class ToggleSetting:
+    label: str
+    key: str
+    description: str
+    default: bool = False
+    disabled: bool = False
+    type: str = "toggle"
+
+
+class DropdownOption(TypedDict):
+    label: str
+    value: str
+
+
+@dataclass
+class DropdownSetting:
+    label: str
+    key: str
+    description: str
+    options: list[DropdownOption]
+    default: str
+    disabled: bool = False
+    type: str = "dropdown"
+
+
+@dataclass
+class NumberSetting:
+    label: str
+    key: str
+    description: str
+    min: float
+    max: float
+    default: float = 0
+    disabled: bool = False
+    type: str = "number"
+
+
+@dataclass
+class CacheSetting:
+    label: str
+    key: str
+    description: str
+    directory: str
+    default: str = ""
+    disabled: bool = False
+    type: str = "cache"
+
+
+Setting = Union[ToggleSetting, DropdownSetting, NumberSetting, CacheSetting]
diff --git a/backend/src/api/types.py b/backend/src/api/types.py
index 206560340..183536e48 100644
--- a/backend/src/api/types.py
+++ b/backend/src/api/types.py
@@ -5,6 +5,7 @@
 
 NodeId = NewType("NodeId", str)
 InputId = NewType("InputId", int)
 OutputId = NewType("OutputId", int)
+FeatureId = NewType("FeatureId", str)
 
 RunFn = Callable[..., Any]
diff --git a/backend/src/packages/chaiNNer_ncnn/ncnn/processing/upscale_image.py b/backend/src/packages/chaiNNer_ncnn/ncnn/processing/upscale_image.py
index 80005264f..0ccc053a4 100644
--- a/backend/src/packages/chaiNNer_ncnn/ncnn/processing/upscale_image.py
+++ b/backend/src/packages/chaiNNer_ncnn/ncnn/processing/upscale_image.py
@@ -15,6 +15,7 @@
     use_gpu = False
 
 from sanic.log import logger
 
+from api import NodeContext
 from nodes.groups import Condition, if_group
 from nodes.impl.ncnn.auto_split import ncnn_auto_split
 from nodes.impl.ncnn.model import NcnnModelWrapper
@@ -36,7 +37,7 @@
 from nodes.utils.utils import get_h_w_c
 from system import is_mac
 
-from ...settings import get_settings
+from ...settings import NcnnSettings, get_settings
 from .. import processing_group
 
 
@@ -66,13 +67,13 @@ def ncnn_allocators(vkdev: ncnn.VulkanDevice):
 
 
 def upscale_impl(
+    settings: NcnnSettings,
     img: np.ndarray,
     model: NcnnModelWrapper,
     input_name: str,
     output_name: str,
     tile_size: TileSize,
 ):
-    settings = get_settings()
     net = get_ncnn_net(model, settings=settings)
     # Try/except block to catch errors
     try:
@@ -181,10 +182,17 @@ def estimate():
         ImageOutput(image_type="""convenientUpscale(Input0, Input1)"""),
     ],
     limited_to_8bpc=True,
+    node_context=True,
 )
 def upscale_image_node(
-    img: np.ndarray, model: NcnnModelWrapper, tile_size: TileSize, separate_alpha: bool
+    context: NodeContext,
+    img: np.ndarray,
+    model: NcnnModelWrapper,
+    tile_size: TileSize,
+    separate_alpha: bool,
 ) -> np.ndarray:
+    settings = get_settings(context)
+
     def upscale(i: np.ndarray) -> np.ndarray:
         ic = get_h_w_c(i)[2]
         if ic == 3:
@@ -192,6 +200,7 @@ def upscale(i: np.ndarray) -> np.ndarray:
         elif ic == 4:
             i = cv2.cvtColor(i, cv2.COLOR_BGRA2RGBA)
         i = upscale_impl(
+            settings,
             i,
             model,
             model.model.layers[0].outputs[0],
diff --git a/backend/src/packages/chaiNNer_ncnn/settings.py b/backend/src/packages/chaiNNer_ncnn/settings.py
index ee2ce35f8..ee6f1c343 100644
--- a/backend/src/packages/chaiNNer_ncnn/settings.py
+++ b/backend/src/packages/chaiNNer_ncnn/settings.py
@@ -9,7 +9,7 @@
     use_gpu = False
 
-from api import DropdownSetting, NumberSetting, ToggleSetting
+from api import DropdownSetting, NodeContext, NumberSetting, ToggleSetting
 from system import is_arm_mac
 
 from . import package
 
@@ -99,8 +99,8 @@ class NcnnSettings:
     budget_limit: int
 
 
-def get_settings() -> NcnnSettings:
-    settings = package.get_settings()
+def get_settings(context: NodeContext) -> NcnnSettings:
+    settings = context.settings
 
     return NcnnSettings(
         gpu_index=settings.get_int("gpu_index", 0, parse_str=True),
diff --git a/backend/src/packages/chaiNNer_onnx/onnx/processing/remove_background.py b/backend/src/packages/chaiNNer_onnx/onnx/processing/remove_background.py
index b45d11a0f..764dd3061 100644
--- a/backend/src/packages/chaiNNer_onnx/onnx/processing/remove_background.py
+++ b/backend/src/packages/chaiNNer_onnx/onnx/processing/remove_background.py
@@ -3,6 +3,7 @@
 import numpy as np
 
 import navi
+from api import NodeContext
 from nodes.groups import Condition, if_group
 from nodes.impl.onnx.model import OnnxRemBg
 from nodes.impl.onnx.session import get_onnx_session
@@ -53,8 +54,10 @@
         ImageOutput("Mask", image_type=navi.Image(size_as="Input0"), channels=1),
     ],
     limited_to_8bpc=True,
+    node_context=True,
 )
 def remove_background_node(
+    context: NodeContext,
     img: np.ndarray,
     model: OnnxRemBg,
     post_process_mask: bool,
@@ -65,7 +68,7 @@
 ) -> tuple[np.ndarray, np.ndarray]:
     """Removes background from image"""
 
-    settings = get_settings()
+    settings = get_settings(context)
     session = get_onnx_session(
         model,
         settings.gpu_index,
diff --git a/backend/src/packages/chaiNNer_onnx/onnx/processing/upscale_image.py b/backend/src/packages/chaiNNer_onnx/onnx/processing/upscale_image.py
index 8bfeb43a3..987a5dd69 100644
--- a/backend/src/packages/chaiNNer_onnx/onnx/processing/upscale_image.py
+++ b/backend/src/packages/chaiNNer_onnx/onnx/processing/upscale_image.py
@@ -4,6 +4,7 @@
 import onnxruntime as ort
 from sanic.log import logger
 
+from api import NodeContext
 from nodes.groups import Condition, if_group
 from nodes.impl.onnx.auto_split import onnx_auto_split
 from nodes.impl.onnx.model import OnnxModel
@@ -89,15 +90,17 @@ def estimate():
     outputs=[ImageOutput("Image")],
     name="Upscale Image",
     icon="ONNX",
+    node_context=True,
 )
 def upscale_image_node(
+    context: NodeContext,
     img: np.ndarray,
     model: OnnxModel,
     tile_size: TileSize,
     separate_alpha: bool,
 ) -> np.ndarray:
     """Upscales an image with a pretrained model"""
-    settings = get_settings()
+    settings = get_settings(context)
     session = get_onnx_session(
         model,
         settings.gpu_index,
diff --git a/backend/src/packages/chaiNNer_onnx/settings.py b/backend/src/packages/chaiNNer_onnx/settings.py
index 5e09ad7e7..9e0ca0481 100644
--- a/backend/src/packages/chaiNNer_onnx/settings.py
+++ b/backend/src/packages/chaiNNer_onnx/settings.py
@@ -7,7 +7,7 @@
 import onnxruntime as ort
 from sanic.log import logger
 
-from api import CacheSetting, DropdownSetting, ToggleSetting
+from api import CacheSetting, DropdownSetting, NodeContext, ToggleSetting
 from gpu import get_nvidia_helper
 from system import is_arm_mac
 
@@ -90,8 +90,8 @@ class OnnxSettings:
     tensorrt_fp16_mode: bool
 
 
-def get_settings() -> OnnxSettings:
-    settings = package.get_settings()
+def get_settings(context: NodeContext) -> OnnxSettings:
+    settings = context.settings
 
     tensorrt_cache_path = settings.get_cache_location("onnx_tensorrt_cache")
     logger.info(f"TensorRT cache location: {tensorrt_cache_path}")
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/io/load_model.py b/backend/src/packages/chaiNNer_pytorch/pytorch/io/load_model.py
index 5b6beb956..99dd56345 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/io/load_model.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/io/load_model.py
@@ -5,6 +5,7 @@
 from sanic.log import logger
 from spandrel import ModelDescriptor, ModelLoader
 
+from api import NodeContext
 from nodes.properties.inputs import PthFileInput
 from nodes.properties.outputs import DirectoryOutput, FileNameOutput, ModelOutput
 from nodes.utils.utils import split_file_path
@@ -56,11 +57,14 @@ def parse_ckpt_state_dict(checkpoint: dict):
         DirectoryOutput("Directory", of_input=0).with_id(2),
         FileNameOutput("Name", of_input=0).with_id(1),
     ],
+    node_context=True,
     see_also=[
         "chainner:pytorch:load_models",
     ],
 )
-def load_model_node(path: str) -> tuple[ModelDescriptor, str, str]:
+def load_model_node(
+    context: NodeContext, path: str
+) -> tuple[ModelDescriptor, str, str]:
     """Read a pth file from the specified path and return it as a state dict
     and loaded model after finding arch config"""
 
@@ -68,7 +72,7 @@ def load_model_node(path: str) -> tuple[ModelDescriptor, str, str]:
 
     assert os.path.isfile(path), f"Path {path} is not a file"
 
-    exec_options = get_settings()
+    exec_options = get_settings(context)
     pytorch_device = exec_options.device
 
     try:
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/guided_upscale.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/guided_upscale.py
index d9f16efb5..0951b38da 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/guided_upscale.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/guided_upscale.py
@@ -2,6 +2,7 @@
 
 import numpy as np
 
+from api import NodeContext
 from nodes.impl.pytorch.pix_transform.auto_split import pix_transform_auto_split
 from nodes.impl.pytorch.pix_transform.pix_transform import Params
 from nodes.impl.upscale.grayscale import SplitMode
@@ -66,8 +67,10 @@
             "The guide image must be larger than the source image, and the size of the guide image must be an integer multiple of the size of the source image (e.g. 2x, 3x, 4x, ...)."
         ),
     ],
+    node_context=True,
 )
 def guided_upscale_node(
+    context: NodeContext,
     source: np.ndarray,
     guide: np.ndarray,
     iterations: float,
@@ -76,7 +79,7 @@
     return pix_transform_auto_split(
         source=source,
         guide=guide,
-        device=get_settings().device,
+        device=get_settings(context).device,
         params=Params(iteration=int(iterations * 1000)),
         split_mode=split_mode,
     )
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py
index 472df92c9..7f11ada19 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/inpaint.py
@@ -7,6 +7,7 @@
 from spandrel import MaskedImageModelDescriptor
 
 import navi
+from api import NodeContext
 from nodes.impl.image_utils import as_3d
 from nodes.impl.pytorch.utils import np2tensor, safe_cuda_cache_empty, tensor2np
 from nodes.properties.inputs import ImageInput
@@ -143,8 +144,10 @@ def inpaint(
             channels=3,
         ).with_never_reason("The given image and mask must have the same resolution.")
     ],
+    node_context=True,
 )
 def inpaint_node(
+    context: NodeContext,
     img: np.ndarray,
     mask: np.ndarray,
     model: MaskedImageModelDescriptor,
@@ -153,6 +156,6 @@ def inpaint_node(
         img.shape[:2] == mask.shape[:2]
     ), "Input image and mask must have the same resolution"
 
-    exec_options = get_settings()
+    exec_options = get_settings(context)
 
     return inpaint(img, mask, model, exec_options)
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py
index 16023e5ef..e8a10ff8e 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/processing/upscale_image.py
@@ -5,6 +5,7 @@
 from sanic.log import logger
 from spandrel import ImageModelDescriptor, ModelTiling
 
+from api import NodeContext
 from nodes.groups import Condition, if_group
 from nodes.impl.pytorch.auto_split import pytorch_auto_split
 from nodes.impl.upscale.auto_split_tiles import (
@@ -141,8 +142,10 @@ def estimate():
             assume_normalized=True,  # pytorch_auto_split already does clipping internally
         )
     ],
+    node_context=True,
 )
 def upscale_image_node(
+    context: NodeContext,
     img: np.ndarray,
     model: ImageModelDescriptor,
     tile_size: TileSize,
@@ -150,7 +153,7 @@ def upscale_image_node(
 ) -> np.ndarray:
     """Upscales an image with a pretrained model"""
 
-    exec_options = get_settings()
+    exec_options = get_settings(context)
 
     logger.debug("Upscaling image...")
 
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/restoration/upscale_face.py b/backend/src/packages/chaiNNer_pytorch/pytorch/restoration/upscale_face.py
index 38cad9e3a..615559861 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/restoration/upscale_face.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/restoration/upscale_face.py
@@ -11,6 +11,7 @@
 from spandrel import ImageModelDescriptor
 from torchvision.transforms.functional import normalize as tv_normalize
 
+from api import NodeContext
 from nodes.groups import Condition, if_group
 from nodes.impl.image_utils import to_uint8
 from nodes.impl.pytorch.utils import np2tensor, safe_cuda_cache_empty, tensor2np
@@ -146,8 +147,10 @@ def upscale(
         )
     ],
     limited_to_8bpc=True,
+    node_context=True,
 )
 def upscale_face_node(
+    context: NodeContext,
     img: np.ndarray,
     face_model: ImageModelDescriptor,
     background_img: np.ndarray | None,
@@ -160,7 +163,7 @@ def upscale_face_node(
     try:
         img = denormalize(img)
 
-        exec_options = get_settings()
+        exec_options = get_settings(context)
         device = exec_options.device
 
         with torch.no_grad():
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_ncnn.py b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_ncnn.py
index 58bec21e5..ef904b205 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_ncnn.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_ncnn.py
@@ -9,6 +9,7 @@
 from spandrel.architectures.Swin2SR import Swin2SR
 from spandrel.architectures.SwinIR import SwinIR
 
+from api import NodeContext
 from nodes.impl.ncnn.model import NcnnModelWrapper
 from nodes.impl.onnx.model import OnnxGeneric
 from nodes.impl.pytorch.convert_to_onnx_impl import convert_to_onnx_impl
@@ -43,9 +44,10 @@
         NcnnModelOutput(label="NCNN Model"),
         TextOutput("FP Mode", "FpMode::toString(Input1)"),
     ],
+    node_context=True,
 )
 def convert_to_ncnn_node(
-    model: ImageModelDescriptor, is_fp16: int
+    context: NodeContext, model: ImageModelDescriptor, is_fp16: int
 ) -> tuple[NcnnModelWrapper, str]:
     if onnx_convert_to_ncnn_node is None:
         raise ModuleNotFoundError(
@@ -58,7 +60,7 @@
         model.model, (HAT, DAT, OmniSR, SwinIR, Swin2SR, SCUNet, SRFormer)
     ), f"{model.architecture} is not supported for NCNN conversions at this time."
 
-    exec_options = get_settings()
+    exec_options = get_settings(context)
     device = exec_options.device
 
     # Intermediate conversion to ONNX is always fp32
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_onnx.py b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_onnx.py
index 2b4a4a37c..084d470df 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_onnx.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/convert_to_onnx.py
@@ -5,6 +5,7 @@
 from spandrel import ImageModelDescriptor
 from spandrel.architectures.SCUNet import SCUNet
 
+from api import NodeContext
 from nodes.impl.onnx.model import OnnxGeneric
 from nodes.impl.pytorch.convert_to_onnx_impl import convert_to_onnx_impl
 from nodes.properties.inputs import EnumInput, OnnxFpDropdown, SrModelInput
@@ -64,16 +65,17 @@ class Opset(Enum):
         """,
         ),
     ],
+    node_context=True,
 )
 def convert_to_onnx_node(
-    model: ImageModelDescriptor, is_fp16: int, opset: Opset
+    context: NodeContext, model: ImageModelDescriptor, is_fp16: int, opset: Opset
 ) -> tuple[OnnxGeneric, str, str]:
     assert not isinstance(
         model.model, SCUNet
     ), "SCUNet is not supported for ONNX conversion at this time."
 
     fp16 = bool(is_fp16)
-    exec_options = get_settings()
+    exec_options = get_settings(context)
     device = exec_options.device
     if fp16:
         assert exec_options.use_fp16, "PyTorch fp16 mode must be supported and turned on in settings to convert model as fp16."
diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py
index 885128c59..4c6e7afcb 100644
--- a/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py
+++ b/backend/src/packages/chaiNNer_pytorch/pytorch/utility/interpolate_models.py
@@ -16,7 +16,6 @@
 from nodes.properties.inputs import ModelInput, SliderInput
 from nodes.properties.outputs import ModelOutput, NumberOutput
 
-from ...settings import get_settings
 from .. import utility_group
 
 
@@ -119,7 +118,6 @@ def interpolate_models_node(
 
     state_dict = perform_interp(state_a, state_b, amount)
 
-    get_settings()
     model = ModelLoader().load_from_state_dict(state_dict)
 
     return model, 100 - amount, amount
diff --git a/backend/src/packages/chaiNNer_pytorch/settings.py b/backend/src/packages/chaiNNer_pytorch/settings.py
index d371f0608..bc1b130a1 100644
--- a/backend/src/packages/chaiNNer_pytorch/settings.py
+++ b/backend/src/packages/chaiNNer_pytorch/settings.py
@@ -3,7 +3,7 @@
 import torch
 from sanic.log import logger
 
-from api import DropdownSetting, ToggleSetting
+from api import DropdownSetting, NodeContext, ToggleSetting
 from gpu import get_nvidia_helper
 from system import is_arm_mac
 
@@ -104,8 +104,8 @@ def device(self) -> torch.device:
         return torch.device(device)
 
 
-def get_settings() -> PyTorchSettings:
-    settings = package.get_settings()
+def get_settings(context: NodeContext) -> PyTorchSettings:
+    settings = context.settings
 
     return PyTorchSettings(
         use_cpu=settings.get_bool("use_cpu", False),
diff --git a/backend/src/process.py b/backend/src/process.py
index 97420a0e3..4a1ff541d 100644
--- a/backend/src/process.py
+++ b/backend/src/process.py
@@ -11,12 +11,24 @@
 
 from sanic.log import logger
 
-from api import BaseOutput, Collector, InputId, Iterator, NodeData, NodeId, OutputId
+from api import (
+    BaseOutput,
+    Collector,
+    ExecutionOptions,
+    InputId,
+    Iterator,
+    NodeContext,
+    NodeData,
+    NodeId,
+    OutputId,
+    SettingsParser,
+    registry,
+)
 from chain.cache import CacheStrategy, OutputCache, StaticCaching, get_cache_strategies
 from chain.chain import Chain, CollectorNode, FunctionNode, NewIteratorNode, Node
 from chain.input import EdgeInput, Input, InputMap
 from events import EventConsumer, InputsDict
-from progress_controller import Aborted, ProgressController
+from progress_controller import Aborted, ProgressController, ProgressToken
 from util import timed_supplier
 
 Output = List[object]
@@ -129,7 +141,7 @@ def enforce_iterator_output(raw_output: object, node: NodeData) -> IteratorOutpu
 
 
 def run_node(
-    node: NodeData, inputs: list[object], node_id: NodeId
+    node: NodeData, context: NodeContext, inputs: list[object], node_id: NodeId
 ) -> NodeOutput | CollectorOutput:
     if node.type == "collector":
         ignored_inputs = node.single_iterator_input.inputs
@@ -139,7 +151,10 @@ def run_node(
     enforced_inputs = enforce_inputs(inputs, node, node_id, ignored_inputs)
 
     try:
-        raw_output = node.run(*enforced_inputs)
+        if node.node_context:
+            raw_output = node.run(context, *enforced_inputs)
+        else:
+            raw_output = node.run(*enforced_inputs)
 
         if node.type == "collector":
             assert isinstance(raw_output, Collector)
@@ -265,6 +280,30 @@ class CollectorOutput:
 ExecutionId = NewType("ExecutionId", str)
 
 
+class _ExecutorNodeContext(NodeContext):
+    def __init__(self, progress: ProgressToken, settings: SettingsParser) -> None:
+        super().__init__()
+
+        self.progress = progress
+        self.__settings = settings
+
+    @property
+    def aborted(self) -> bool:
+        return self.progress.aborted
+
+    def set_progress(self, progress: float) -> None:
+        self.check_aborted()
+
+        # TODO: send progress event
+
+    @property
+    def settings(self) -> SettingsParser:
+        """
+        Returns the settings of the current node execution.
+        """
+        return self.__settings
+
+
 class Executor:
     """
     Class for executing chaiNNer's processing logic
@@ -276,6 +315,7 @@ def __init__(
         chain: Chain,
        inputs: InputMap,
         send_broadcast_data: bool,
+        options: ExecutionOptions,
         loop: asyncio.AbstractEventLoop,
         queue: EventConsumer,
         pool: ThreadPoolExecutor,
@@ -285,8 +325,10 @@ def __init__(
         self.chain = chain
         self.inputs = inputs
         self.send_broadcast_data: bool = send_broadcast_data
+        self.options: ExecutionOptions = options
         self.cache: OutputCache[NodeOutput] = OutputCache(parent=parent_cache)
         self.__broadcast_tasks: list[asyncio.Task[None]] = []
+        self.__context_cache: dict[str, _ExecutorNodeContext] = {}
 
         self.progress = ProgressController()
 
@@ -418,6 +460,17 @@ async def __gather_collector_inputs(self, node: CollectorNode) -> list[object]:
 
         return inputs
 
+    def __get_node_context(self, node: Node) -> _ExecutorNodeContext:
+        context = self.__context_cache.get(node.data.schema_id, None)
+        if context is None:
+            package_id = registry.get_package(node.data.schema_id).id
+            settings = self.options.get_package_settings(package_id)
+
+            context = _ExecutorNodeContext(self.progress, settings)
+            self.__context_cache[node.data.schema_id] = context
+
+        return context
+
     async def __process(self, node: Node) -> NodeOutput | CollectorOutput:
         """
         Process a single node.
@@ -430,6 +483,7 @@ async def __process(self, node: Node) -> NodeOutput | CollectorOutput:
         logger.debug(f"Running node {node.id}")
 
         inputs = await self.__gather_inputs(node)
+        context = self.__get_node_context(node)
 
         await self.progress.suspend()
         await self.__send_node_start(node)
@@ -437,7 +491,9 @@
 
         output, execution_time = await self.loop.run_in_executor(
             self.pool,
-            timed_supplier(functools.partial(run_node, node.data, inputs, node.id)),
+            timed_supplier(
+                functools.partial(run_node, node.data, context, inputs, node.id)
+            ),
         )
 
         await self.progress.suspend()
diff --git a/backend/src/server.py b/backend/src/server.py
index ae6a79fd5..199bc10dc 100644
--- a/backend/src/server.py
+++ b/backend/src/server.py
@@ -26,7 +26,6 @@
     Group,
     JsonExecutionOptions,
     NodeId,
-    set_execution_options,
 )
 from chain.cache import OutputCache
 from chain.chain import Chain, FunctionNode
@@ -173,13 +172,12 @@ async def run(request: Request):
         optimize(chain)
 
         logger.info("Running new executor...")
-        exec_opts = ExecutionOptions.parse(full_data["options"])
-        set_execution_options(exec_opts)
         executor = Executor(
             id=ExecutionId("main-executor " + uuid.uuid4().hex),
             chain=chain,
             inputs=inputs,
             send_broadcast_data=full_data["sendBroadcastData"],
+            options=ExecutionOptions.parse(full_data["options"]),
             loop=app.loop,
             queue=ctx.queue,
             pool=ctx.pool,
@@ -234,9 +232,6 @@ async def run_individual(request: Request):
     node_id = full_data["id"]
     ctx.cache.pop(node_id, None)
 
-    exec_opts = ExecutionOptions.parse(full_data["options"])
-    set_execution_options(exec_opts)
-
     node = FunctionNode(node_id, full_data["schemaId"])
     chain = Chain()
     chain.add_node(node)
@@ -255,6 +250,7 @@
         chain=chain,
         inputs=input_map,
         send_broadcast_data=True,
+        options=ExecutionOptions.parse(full_data["options"]),
         loop=app.loop,
         queue=queue,
         pool=ctx.pool,
diff --git a/backend/src/settings.py b/backend/src/settings.py
deleted file mode 100644
index 532be0541..000000000
--- a/backend/src/settings.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from dataclasses import dataclass
-from typing import Any
-
-from api import SettingsParser, get_execution_options
-
-
-@dataclass(frozen=True)
-class GeneralSettings:
-    example: bool
-
-
-def get_global_settings() -> Any:
-    settings = SettingsParser(get_execution_options().get_package_settings("general"))
-
-    return GeneralSettings(
-        example=settings.get_bool("example", default=False),
-    )

From 02ea3d98b91b46bb9c34692c7384d8a6b30d1980 Mon Sep 17 00:00:00 2001
From: RunDevelopment
Date: Mon, 8 Jan 2024 16:55:22 +0100
Subject: [PATCH 2/2] Fixed type errors

---
 .../chaiNNer_ncnn/ncnn/utility/interpolate_models.py | 13 +++++++++----
 .../chaiNNer_onnx/onnx/utility/interpolate_models.py |  9 ++++++---
 .../pytorch/iteration/load_models.py                 |  6 ++++--
 3 files changed, 19 insertions(+), 9 deletions(-)
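Notes (patch 2/2): these type errors came from patch 1/2 — node functions
such as `upscale_image_node` and `load_model_node` now take a `NodeContext`
as their first argument, so nodes that call them directly must accept a
context and forward it, and must themselves be registered with
`node_context=True`. A sketch of the pattern, with a hypothetical helper:

    import numpy as np

    from api import NodeContext
    from nodes.impl.ncnn.model import NcnnModelWrapper
    from nodes.impl.upscale.auto_split_tiles import NO_TILING

    from ..processing.upscale_image import upscale_image_node


    def probe_model(context: NodeContext, model: NcnnModelWrapper) -> np.ndarray:
        # Node functions are ordinary Python functions, so a caller simply
        # forwards the context it received from the executor.
        fake_img = np.ones((3, 3, 3), dtype=np.float32)
        return upscale_image_node(context, fake_img, model, NO_TILING, False)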
Input2"), NumberOutput("Amount B", output_type="Input2"), ], + node_context=True, ) def interpolate_models_node( + context: NodeContext, a: OnnxModel, b: OnnxModel, amount: int, @@ -116,7 +119,7 @@ def interpolate_models_node( model_interp = model_proto_interp.SerializeToString() # type: ignore model = load_onnx_model(model_interp) - if not check_will_upscale(model): + if not check_will_upscale(context, model): raise ValueError( "These models are not compatible and not able to be interpolated together" ) diff --git a/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py b/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py index 4d33b472a..42fe7b52c 100644 --- a/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py +++ b/backend/src/packages/chaiNNer_pytorch/pytorch/iteration/load_models.py @@ -5,7 +5,7 @@ from sanic.log import logger from spandrel import ModelDescriptor -from api import Iterator, IteratorOutputInfo +from api import Iterator, IteratorOutputInfo, NodeContext from nodes.properties.inputs import DirectoryInput from nodes.properties.inputs.generic_inputs import BoolInput from nodes.properties.outputs import DirectoryOutput, NumberOutput, TextOutput @@ -43,15 +43,17 @@ ], iterator_outputs=IteratorOutputInfo(outputs=[0, 2, 3, 4]), node_type="newIterator", + node_context=True, ) def load_models_node( + context: NodeContext, directory: str, defer_errors: bool, ) -> tuple[Iterator[tuple[ModelDescriptor, str, str, int]], str]: logger.debug(f"Iterating over models in directory: {directory}") def load_model(path: str, index: int): - model, dirname, basename = load_model_node(path) + model, dirname, basename = load_model_node(context, path) # Get relative path from root directory passed by Iterator directory input rel_path = os.path.relpath(dirname, directory) return model, rel_path, basename, index