Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace global execution options with node context parameter #2444

Merged
merged 3 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/src/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .api import *
from .group import *
from .input import *
from .node_context import *
from .output import *
from .settings import *
from .types import *
105 changes: 11 additions & 94 deletions backend/src/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
Callable,
Generic,
Iterable,
NewType,
TypedDict,
TypeVar,
Union,
)

from sanic.log import logger
Expand All @@ -29,8 +27,8 @@
check_schema_types,
)
from .output import BaseOutput
from .settings import SettingsJson, get_execution_options
from .types import InputId, NodeId, NodeType, OutputId, RunFn
from .settings import Setting
from .types import FeatureId, InputId, NodeId, NodeType, OutputId, RunFn

KB = 1024**1
MB = 1024**2
Expand Down Expand Up @@ -124,6 +122,7 @@ class NodeData:

side_effects: bool
deprecated: bool
node_context: bool
features: list[FeatureId]

run: RunFn
Expand Down Expand Up @@ -180,6 +179,7 @@ def register(
limited_to_8bpc: bool | str = False,
iterator_inputs: list[IteratorInputInfo] | IteratorInputInfo | None = None,
iterator_outputs: list[IteratorOutputInfo] | IteratorOutputInfo | None = None,
node_context: bool = False,
):
if not isinstance(description, str):
description = "\n\n".join(description)
Expand Down Expand Up @@ -233,7 +233,9 @@ def inner_wrapper(wrapped_func: T) -> T:
if node_type == "regularNode":
run_check(
TYPE_CHECK_LEVEL,
lambda _: check_schema_types(wrapped_func, p_inputs, p_outputs),
lambda _: check_schema_types(
wrapped_func, p_inputs, p_outputs, node_context
),
)
run_check(
NAME_CHECK_LEVEL,
Expand All @@ -258,6 +260,7 @@ def inner_wrapper(wrapped_func: T) -> T:
iterator_outputs=iterator_outputs,
side_effects=side_effects,
deprecated=deprecated,
node_context=node_context,
features=features,
run=wrapped_func,
)
Expand Down Expand Up @@ -322,9 +325,6 @@ def to_dict(self):
}


FeatureId = NewType("FeatureId", str)


@dataclass
class Feature:
id: str
Expand Down Expand Up @@ -366,89 +366,6 @@ def disabled(details: str | None = None) -> FeatureState:
return FeatureState(is_enabled=False, details=details)


@dataclass
class ToggleSetting:
label: str
key: str
description: str
default: bool = False
disabled: bool = False
type: str = "toggle"


class DropdownOption(TypedDict):
label: str
value: str


@dataclass
class DropdownSetting:
label: str
key: str
description: str
options: list[DropdownOption]
default: str
disabled: bool = False
type: str = "dropdown"


@dataclass
class NumberSetting:
label: str
key: str
description: str
min: float
max: float
default: float = 0
disabled: bool = False
type: str = "number"


@dataclass
class CacheSetting:
label: str
key: str
description: str
directory: str
default: str = ""
disabled: bool = False
type: str = "cache"


Setting = Union[ToggleSetting, DropdownSetting, NumberSetting, CacheSetting]


class SettingsParser:
def __init__(self, raw: SettingsJson) -> None:
self.__settings = raw

def get_bool(self, key: str, default: bool) -> bool:
value = self.__settings.get(key, default)
if isinstance(value, bool):
return value
raise ValueError(f"Invalid bool value for {key}: {value}")

def get_int(self, key: str, default: int, parse_str: bool = False) -> int:
value = self.__settings.get(key, default)
if parse_str and isinstance(value, str):
return int(value)
if isinstance(value, int) and not isinstance(value, bool):
return value
raise ValueError(f"Invalid str value for {key}: {value}")

def get_str(self, key: str, default: str) -> str:
value = self.__settings.get(key, default)
if isinstance(value, str):
return value
raise ValueError(f"Invalid str value for {key}: {value}")

def get_cache_location(self, key: str) -> str | None:
value = self.__settings.get(key)
if isinstance(value, str) or value is None:
return value or None
raise ValueError(f"Invalid cache location value for {key}: {value}")


@dataclass
class Package:
where: str
Expand Down Expand Up @@ -501,9 +418,6 @@ def add_feature(
self.features.append(feature)
return feature

def get_settings(self) -> SettingsParser:
return SettingsParser(get_execution_options().get_package_settings(self.id))


def _iter_py_files(directory: str):
for root, _, files in os.walk(directory):
Expand All @@ -528,6 +442,9 @@ def __init__(self) -> None:
def get_node(self, schema_id: str) -> NodeData:
return self.nodes[schema_id][0]

def get_package(self, schema_id: str) -> Package:
return self.nodes[schema_id][1].category.package

def add(self, package: Package) -> Package:
# assert package.where not in self.packages
self.packages[package.where] = package
Expand Down
22 changes: 19 additions & 3 deletions backend/src/api/node_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import inspect
import os
import pathlib
from collections import OrderedDict
from enum import Enum
from typing import Any, Callable, NewType, Tuple, Union, cast, get_args

from .input import BaseInput
from .node_context import NodeContext
from .output import BaseOutput

_Ty = NewType("_Ty", object)
Expand Down Expand Up @@ -190,24 +192,38 @@ def check_schema_types(
wrapped_func: Callable,
inputs: list[BaseInput],
outputs: list[BaseOutput],
node_context: bool,
):
"""
Runtime validation for the number of inputs/outputs compared to the type args
"""

ann = get_type_annotations(wrapped_func)
ann = OrderedDict(get_type_annotations(wrapped_func))

# check return type
if "return" in ann:
validate_return_type(ann.pop("return"), outputs)

# check inputs

# check arguments
arg_spec = inspect.getfullargspec(wrapped_func)
for arg in arg_spec.args:
if arg not in ann:
raise CheckFailedError(f"Missing type annotation for '{arg}'")

if node_context:
first = arg_spec.args[0]
if first != "context":
raise CheckFailedError(
f"Expected the first parameter to be 'context: NodeContext' but found '{first}'."
)
context_type = ann.pop(first)
if context_type != NodeContext: # type: ignore
raise CheckFailedError(
f"Expected type of 'context' to be 'api.NodeContext' but found '{context_type}'"
)

# check inputs

if arg_spec.varargs is not None:
if arg_spec.varargs not in ann:
raise CheckFailedError(f"Missing type annotation for '{arg_spec.varargs}'")
Expand Down
45 changes: 45 additions & 0 deletions backend/src/api/node_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from abc import ABC, abstractmethod

from .settings import SettingsParser


class Aborted(Exception):
pass


class NodeProgress(ABC):
@property
@abstractmethod
def aborted(self) -> bool:
"""
Returns whether the current operation was aborted.
"""

def check_aborted(self) -> None:
"""
Raises an `Aborted` exception if the current operation was aborted. Does nothing otherwise.
"""

if self.aborted:
raise Aborted()

@abstractmethod
def set_progress(self, progress: float) -> None:
"""
Sets the progress of the current node execution. `progress` must be a value between 0 and 1.

Raises an `Aborted` exception if the current operation was aborted.
"""


class NodeContext(NodeProgress, ABC):
"""
The execution context of the current node.
"""

@property
@abstractmethod
def settings(self) -> SettingsParser:
"""
Returns the settings of the current node execution.
"""