From de5ba23e73c2a3b1dfea7b4ad38efc352c0c7692 Mon Sep 17 00:00:00 2001 From: Yaron Haviv Date: Tue, 29 Dec 2020 20:27:17 +0200 Subject: [PATCH] [Serving] Add flow topology support (#621) --- dev-requirements.txt | 2 + dockerfiles/test/Dockerfile | 3 +- mlrun/runtimes/serving.py | 207 ++++-- mlrun/serving/server.py | 259 +++++--- mlrun/serving/states.py | 1013 +++++++++++++++++++++++++++--- mlrun/serving/v2_serving.py | 10 +- mlrun/utils/helpers.py | 46 ++ requirements.txt | 2 +- tests/serving/demo_states.py | 83 +++ tests/serving/test_async_flow.py | 87 +++ tests/serving/test_flow.py | 88 +++ tests/serving/test_serving.py | 64 +- 12 files changed, 1612 insertions(+), 252 deletions(-) create mode 100644 tests/serving/demo_states.py create mode 100644 tests/serving/test_async_flow.py create mode 100644 tests/serving/test_flow.py diff --git a/dev-requirements.txt b/dev-requirements.txt index 9a478d60e6f..981ecb1f7cf 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -9,3 +9,5 @@ pytest-alembic~=0.2 requests-mock~=1.8 # needed for system tests matplotlib~=3.0 +graphviz~=0.16.0 +storey @ git+https://github.com/mlrun/storey.git \ No newline at end of file diff --git a/dockerfiles/test/Dockerfile b/dockerfiles/test/Dockerfile index 91100321393..76e58346a86 100644 --- a/dockerfiles/test/Dockerfile +++ b/dockerfiles/test/Dockerfile @@ -28,7 +28,8 @@ RUN apt-get update && apt-get install -y \ git-core \ gnupg2 \ make \ - software-properties-common + software-properties-common \ + graphviz RUN curl -fsSL https://download.docker.com/linux/debian/gpg | apt-key add - diff --git a/mlrun/runtimes/serving.py b/mlrun/runtimes/serving.py index 03facc55336..81e78d01a7c 100644 --- a/mlrun/runtimes/serving.py +++ b/mlrun/runtimes/serving.py @@ -13,16 +13,22 @@ # limitations under the License. import json +from typing import List, Union +import mlrun import nuclio +from ..model import ObjectList from .function import RemoteRuntime, NuclioSpec -from ..utils import logger -from ..serving.server import create_mock_server +from .function_reference import FunctionReference +from ..utils import logger, get_caller_globals +from ..serving.server import create_graph_server, GraphServer from ..serving.states import ( - ServingRouterState, + RouterState, + StateKinds, + RootFlowState, + graph_root_setter, new_remote_endpoint, new_model_endpoint, - StateKinds, ) serving_subkind = "serving_v2" @@ -91,6 +97,10 @@ def __init__( parameters=None, default_class=None, load_mode=None, + build=None, + function_refs=None, + graph_initializer=None, + error_stream=None, ): super().__init__( @@ -115,22 +125,37 @@ def __init__( function_kind=serving_subkind, service_account=service_account, readiness_timeout=readiness_timeout, + build=build, ) self.models = models or {} self._graph = None - self.graph: ServingRouterState = graph + self.graph: Union[RouterState, RootFlowState] = graph self.parameters = parameters or {} self.default_class = default_class self.load_mode = load_mode + self._function_refs: ObjectList = None + self.function_refs = function_refs or [] + self.graph_initializer = graph_initializer + self.error_stream = error_stream @property - def graph(self) -> ServingRouterState: + def graph(self) -> Union[RouterState, RootFlowState]: + """states graph, holding the serving workflow/DAG topology""" return self._graph @graph.setter def graph(self, graph): - self._graph = self._verify_dict(graph, "graph", ServingRouterState) + graph_root_setter(self, graph) + + @property + def function_refs(self) -> List[FunctionReference]: + """function references, list of optional child function refs""" + return self._function_refs + + @function_refs.setter + def function_refs(self, function_refs: List[FunctionReference]): + self._function_refs = ObjectList.from_list(FunctionReference, function_refs) class ServingRuntime(RemoteRuntime): @@ -145,19 +170,47 @@ def spec(self, spec): self._spec = self._verify_dict(spec, "spec", ServingSpec) def set_topology( - self, topology=None, class_name=None, exist_ok=False, **class_args - ): - """set the serving graph topology (router/flow/endpoint) and root class""" + self, topology=None, class_name=None, engine=None, exist_ok=False, **class_args, + ) -> Union[RootFlowState, RouterState]: + """set the serving graph topology (router/flow) and root class or params + + e.g.: + graph = fn.set_topology("flow", engine="async") + graph.to("MyClass").to(name="to_json", handler="json.dumps").respond() + + topology can be: + router - root router + multiple child route states/models + route is usually determined by the path (route key/name) + can specify special router class and router arguments + + flow - workflow (DAG) with a chain of states + flow support "sync" and "async" engines, branches are not allowed in sync mode + when using async mode calling state.respond() will mark the state as the + one which generates the (REST) call response + + :param topology: - graph topology, router or flow + :param class_name: - optional for router, router class name/path + :param engine: - optional for flow, sync or async engine (default to async) + :param exist_ok: - allow overriding existing topology + :param class_args: - optional, router/flow class init args + + :return graph object (fn.spec.graph) + """ topology = topology or StateKinds.router if self.spec.graph and not exist_ok: - raise ValueError("graph topology is already set") + raise mlrun.errors.MLRunInvalidArgumentError( + "graph topology is already set, cannot be overwritten" + ) - # currently we only support router topology - if topology != StateKinds.router: - raise NotImplementedError("currently only supporting router topology") - self.spec.graph = ServingRouterState( - class_name=class_name, class_args=class_args - ) + if topology == StateKinds.router: + self.spec.graph = RouterState(class_name=class_name, class_args=class_args) + elif topology == StateKinds.flow: + self.spec.graph = RootFlowState(engine=engine) + else: + raise mlrun.errors.MLRunInvalidArgumentError( + f"unsupported topology {topology}, use 'router' or 'flow'" + ) + return self.spec.graph def set_tracking(self, stream_path, batch=None, sample=None): """set tracking log stream parameters""" @@ -176,7 +229,7 @@ def add_model( handler=None, **class_args, ): - """add ml model and/or route to the function + """add ml model and/or route to the function. Example, create a function (from the notebook), add a model class, and deploy: @@ -184,6 +237,9 @@ def add_model( fn.add_model('boost', model_path, model_class='MyClass', my_arg=5) fn.deploy() + only works with router topology, for nested topologies (model under router under flow) + need to add router to flow and use router.add_route() + :param key: model api key (or name:version), will determine the relative url/path :param model_path: path to mlrun model artifact or model directory file/object path :param class_name: V2 Model python class name @@ -193,10 +249,17 @@ def add_model( :param class_args: extra kwargs to pass to the model serving class __init__ (can be read in the model using .get_param(key) method) """ + graph = self.spec.graph + if not graph: + self.set_topology() + + if graph.kind != StateKinds.router: + raise ValueError("models can only be added under router state") + if not model_path and not model_url: raise ValueError("model_path or model_url must be provided") class_name = class_name or self.spec.default_class - if not isinstance(class_name, str): + if class_name and not isinstance(class_name, str): raise ValueError( "class name must be a string (name ot module.submodule.name)" ) @@ -205,58 +268,126 @@ def add_model( if model_path: model_path = str(model_path) - if not self.spec.graph: - self.set_topology() - if model_url: - route = new_remote_endpoint(model_url, **class_args) + state = new_remote_endpoint(model_url, **class_args) else: - route = new_model_endpoint(class_name, model_path, handler, **class_args) - self.spec.graph.add_route(key, route) + state = new_model_endpoint(class_name, model_path, handler, **class_args) + + return graph.add_route(key, state) + + def add_child_function( + self, name, url=None, image=None, requirements=None, kind=None + ): + """in a multi-function pipeline add child function - def remove_models(self, keys: list): - """remove one, multiple, or all models from the spec (blank list for all)""" + example: + fn.add_child_function('enrich', './enrich.ipynb', 'mlrun/mlrun') + + :param name: - child function name + :param url: - function/code url, support .py, .ipynb, .yaml extensions + :param image: - base docker image for the function + :param requirements - py package requirements file path OR list of packages + :param kind: - mlrun function/runtime kind + + :return function object + """ + function_reference = FunctionReference( + url, image, requirements=requirements, kind=kind or "serving" + ) + self._spec.function_refs.update(function_reference, name) + func = function_reference.to_function(self.kind) + func.set_env("SERVING_CURRENT_FUNCTION", function_reference.name) + return func + + def _add_ref_triggers(self): + """add stream trigger to downstream child functions""" + for function_name, stream in self.spec.graph.get_queue_links().items(): + if stream.path: + if function_name not in self._spec.function_refs.keys(): + raise ValueError(f"function reference {function_name} not present") + group = stream.options.get("group", "serving") + + child_function = self._spec.function_refs[function_name] + child_function.function_object().add_stream_trigger( + stream.path, group=group, shards=stream.shards + ) + + def _deploy_function_refs(self): + """set metadata and deploy child functions""" + for function_ref in self._spec.function_refs.values(): + logger.info(f"deploy child function {function_ref.name} ...") + function_object = function_ref.function_object + function_object.metadata.name = function_ref.fullname(self) + function_object.metadata.project = self.metadata.project + function_object.metadata.tag = self.metadata.tag + function_object.spec.graph = self.spec.graph + # todo: may want to copy parent volumes to child functions + function_object.apply(mlrun.v3io_cred()) + function_ref.db_uri = function_object._function_uri() + function_object.verbose = self.verbose + function_object.deploy() + + def remove_states(self, keys: list): + """remove one, multiple, or all states/models from the spec (blank list for all)""" if self.spec.graph: - self.spec.graph.clear_routes(keys) + self.spec.graph.clear_children(keys) - def deploy(self, dashboard="", project="", tag=""): + def deploy(self, dashboard="", project="", tag="", verbose=False): """deploy model serving function to a local/remote cluster :param dashboard: remote nuclio dashboard url (blank for local or auto detection) :param project: optional, overide function specified project name :param tag: specify unique function tag (a different function service is created for every tag) + :param verbose: verbose logging """ load_mode = self.spec.load_mode if load_mode and load_mode not in ["sync", "async"]: raise ValueError(f"illegal model loading mode {load_mode}") if not self.spec.graph: raise ValueError("nothing to deploy, .spec.graph is none, use .add_model()") - return super().deploy(dashboard, project, tag) + + if self.spec.graph.kind != StateKinds.router: + # initialize or create required streams/queues + self.spec.graph.check_and_process_graph() + self.spec.graph.init_queues() + if self._spec.function_refs: + # deploy child functions + self._add_ref_triggers() + self._deploy_function_refs() + logger.info(f"deploy root function {self.metadata.name} ...") + return super().deploy(dashboard, project, tag, verbose=verbose) def _get_runtime_env(self): - # we currently support a minimal topology of one router + multiple child routes/models - # in the future we will extend the support to a full graph, the spec is already built accordingly + + function_name_uri_map = {f.name: f.uri(self) for f in self.spec.function_refs} serving_spec = { "function_uri": self._function_uri(), "version": "v2", "parameters": self.spec.parameters, "graph": self.spec.graph.to_dict(), "load_mode": self.spec.load_mode, - "verbose": self.verbose, + "functions": function_name_uri_map, + "graph_initializer": self.spec.graph_initializer, + "error_stream": self.spec.error_stream, } return {"SERVING_SPEC_ENV": json.dumps(serving_spec)} - def to_mock_server(self, namespace=None, log_level="debug"): + def to_mock_server( + self, namespace=None, current_function=None, **kwargs + ) -> GraphServer: """create mock server object for local testing/emulation :param namespace: classes search namespace, use globals() for current notebook :param log_level: log level (error | info | debug) + :param current_function: specify if you want to simulate a child function """ - return create_mock_server( + server = create_graph_server( parameters=self.spec.parameters, load_mode=self.spec.load_mode, graph=self.spec.graph, - namespace=namespace, - logger=logger, - level=log_level, + verbose=self.verbose, + current_function=current_function, + **kwargs, ) + server.init(None, namespace or get_caller_globals(), logger=logger) + return server diff --git a/mlrun/serving/server.py b/mlrun/serving/server.py index ea7a8468490..4f441d33ed4 100644 --- a/mlrun/serving/server.py +++ b/mlrun/serving/server.py @@ -17,9 +17,19 @@ import sys import traceback import uuid -from copy import deepcopy - -from .states import ServingRouterState, ServingTaskState +from typing import Union + +import mlrun +from mlrun.secrets import SecretsStore +from mlrun.config import config + +from .states import ( + RouterState, + RootFlowState, + get_function, + graph_root_setter, +) +from ..errors import MLRunInvalidArgumentError from ..model import ModelObj from ..platforms.iguazio import OutputStream from ..utils import create_logger, get_caller_globals @@ -37,9 +47,7 @@ def __init__(self, parameters, function_uri): self.output_stream = OutputStream(out_stream) -# Model server host currently support a basic topology of single router + multiple -# routes (models/tasks). it will be enhanced later to support more complex topologies -class ModelServerHost(ModelObj): +class GraphServer(ModelObj): kind = "server" def __init__( @@ -50,119 +58,148 @@ def __init__( load_mode=None, verbose=False, version=None, + functions=None, + graph_initializer=None, + error_stream=None, ): self._graph = None - self.graph: ServingRouterState = graph + self.graph: RouterState = graph self.function_uri = function_uri self.parameters = parameters or {} self.verbose = verbose self.load_mode = load_mode or "sync" self.version = version or "v2" self.context = None - self._namespace = None + self._current_function = None + self.functions = functions or {} + self.graph_initializer = graph_initializer + self.error_stream = error_stream + self._error_stream_object = None + self._secrets = SecretsStore() + self._db_conn = None + self.resource_cache = None + + def set_current_function(self, function): + """set which child function this server is currently running on""" + self._current_function = function @property - def graph(self) -> ServingRouterState: + def graph(self) -> Union[RootFlowState, RouterState]: return self._graph @graph.setter def graph(self, graph): - self._graph = self._verify_dict(graph, "spec", ServingRouterState) + graph_root_setter(self, graph) + + def set_error_stream(self, error_stream): + """set/initialize the error notification stream""" + self.error_stream = error_stream + if error_stream: + self._error_stream_object = OutputStream(error_stream) + else: + self._error_stream_object = None - def merge_root_params(self, params={}): - """for internal use, enrich child states with root params""" - for key, val in self.parameters.items(): - if key not in params: - params[key] = val - return params + def _get_db(self): + return mlrun.get_run_db(secrets=self._secrets) - def init(self, context, namespace): + def init(self, context, namespace, resource_cache=None, logger=None): """for internal use, initialize all states (recursively)""" + + if self.error_stream: + self._error_stream_object = OutputStream(self.error_stream) + self.resource_cache = resource_cache + context = GraphContext(server=self, nuclio_context=context, logger=logger) + + context.stream = _StreamContext(self.parameters, self.function_uri) + context.current_function = self._current_function + context.verbose = self.verbose + context.root = self.graph self.context = context - # enrich the context with classes and methods which will be used when - # initializing classes or handling the event - setattr(context, "stream", _StreamContext(self.parameters, self.function_uri)) - setattr(context, "merge_root_params", self.merge_root_params) - setattr(context, "verbose", self.verbose) + + if self.graph_initializer: + if callable(self.graph_initializer): + handler = self.graph_initializer + else: + handler = get_function(self.graph_initializer, namespace) + handler(self) self.graph.init_object(context, namespace, self.load_mode) - setattr(self.context, "root", self.graph) return v2_serving_handler - def add_model( - self, name, class_name, model_path, handler=None, namespace=None, **class_args - ): - """add child model/route to the server, will register, init and connect the child class - the local or global (module.submodule.class) class specified by the class_name - the context, name, model_path, and **class_args will be used to initialize that class - - every event with "/{router.url_prefix}/{name}/.." or "{name}/.." will be routed to the class. - - keep the handler=None for model server classes, for custom classes you can specify the class handler - which will be invoked when traffic arrives to that route (class.{handler}(event)) - - :param name: name (and url prefix) used for the route/model - :param class_name: class object or name (str) or full path (module.submodule.class) - :param model_path: path to mlrun model artifact or model directory file/object path - :param handler: for advanced users!, override default class handler name (do_event) - :param namespace: class search path when using string_name, for local use py globals() - :param class_args: extra kwargs to pass to the model serving class __init__ - (can be read in the model using .get_param(key) method) - """ - class_args = deepcopy(class_args) - class_args["model_path"] = model_path - route = ServingTaskState(class_name, class_args, handler) - namespace = namespace or get_caller_globals() - self.graph.add_route(name, route).init_object(self.context, namespace) - def test( - self, path, body, method="", content_type=None, silent=False, get_body=True + self, + path="/", + body=None, + method="", + content_type=None, + silent=False, + get_body=True, ): """invoke a test event into the server to simulate/test server behaviour e.g.: - server = create_mock_server() + server = create_graph_server() server.add_model("my", class_name=MyModelClass, model_path="{path}", z=100) print(server.test("my/infer", testdata)) - :param path: relative ({route-name}/..) or absolute (/{router.url_prefix}/{name}/..) path - :param body: message body (dict or json str/bytes) - :param method: optional, GET, POST, .. + :param path: api path, e.g. (/{router.url_prefix}/{model-name}/..) path + :param body: message body (dict or json str/bytes) + :param method: optional, GET, POST, .. :param content_type: optional, http mime type - :param silent: dont raise on error responses (when not 20X) - :param get_body: return the body (vs serialize response into json) + :param silent: dont raise on error responses (when not 20X) + :param get_body: return the body as py object (vs serialize response into json) """ if not self.graph: - raise ValueError("no model or router was added, use .add_model()") - if not path.startswith("/"): - path = self.graph.object.url_prefix + path + raise MLRunInvalidArgumentError( + "no models or steps were set, use function.set_topology() and add steps" + ) event = MockEvent( body=body, path=path, method=method, content_type=content_type ) resp = v2_serving_handler(self.context, event, get_body=get_body) - if hasattr(resp, "status_code") and resp.status_code > 300 and not silent: + if hasattr(resp, "status_code") and resp.status_code >= 300 and not silent: raise RuntimeError(f"failed ({resp.status_code}): {resp.body}") return resp + def wait_for_completion(self): + """wait for async operation to complete""" + self.graph.wait_for_completion() + def v2_serving_init(context, namespace=None): + """hook for nuclio init_context()""" + data = os.environ.get("SERVING_SPEC_ENV", "") if not data: - raise ValueError("failed to find spec env var") + raise MLRunInvalidArgumentError("failed to find spec env var") spec = json.loads(data) - server = ModelServerHost.from_dict(spec) - serving_handler = server.init(context, namespace or globals()) + server = GraphServer.from_dict(spec) + if config.log_level.lower() == "debug": + server.verbose = True + server.set_current_function(os.environ.get("SERVING_CURRENT_FUNCTION", "")) + serving_handler = server.init(context, namespace or get_caller_globals()) # set the handler hook to point to our handler setattr(context, "mlrun_handler", serving_handler) + setattr(context, "root", server.graph) + context.logger.info(f"serving was initialized, verbose={server.verbose}") + if server.verbose: + context.logger.info(server.to_yaml()) def v2_serving_handler(context, event, get_body=False): + """hook for nuclio handler()""" + try: response = context.root.run(event) except Exception as e: + message = str(e) if context.verbose: - context.logger.error(traceback.format_exc()) - return context.Response(body=str(e), content_type="text/plain", status_code=400) + message += "\n" + str(traceback.format_exc()) + context.logger.error(f"run error, {traceback.format_exc()}") + context.push_error(event, message, source="_handler") + return context.Response( + body=message, content_type="text/plain", status_code=400 + ) body = response.body if isinstance(body, context.Response) or get_body: @@ -176,32 +213,26 @@ def v2_serving_handler(context, event, get_body=False): return body -def create_mock_server( - context=None, - router_class=None, - router_args={}, +def create_graph_server( parameters={}, load_mode=None, graph=None, - namespace=None, - logger=None, - level="debug", -): - """create serving emulator/tester for locally testing models and servers + verbose=False, + current_function=None, + **kwargs, +) -> GraphServer: + """create serving host/emulator for local or test runs Usage: - host = create_mock_server() - host.add_model("my", class_name=MyModelClass, model_path="{path}", z=100) - print(host.test("my/infer", testdata)) + server = create_graph_server(graph=RouterState(), parameters={}) + server.init(None, globals()) + server.graph.add_route("my", class_name=MyModelClass, model_path="{path}", z=100) + print(server.test("/v2/models/my/infer", testdata)) """ - if not context: - context = MockContext(level, logger=logger) - - if not graph: - graph = ServingRouterState(class_name=router_class, class_args=router_args) - namespace = namespace or get_caller_globals() - server = ModelServerHost(graph, parameters, load_mode, verbose=level == "debug") - server.init(context, namespace or {}) + server = GraphServer(graph, parameters, load_mode, verbose=verbose, **kwargs) + server.set_current_function( + current_function or os.environ.get("SERVING_CURRENT_FUNCTION", "") + ) return server @@ -222,9 +253,11 @@ def __init__( self.path = path or "/" self.content_type = content_type self.trigger = None + self.error = None def __str__(self): - return f"Event(id={self.id}, body={self.body}, method={self.method}, path={self.path})" + error = f", error={self.error}" if self.error else "" + return f"Event(id={self.id}, body={self.body}, method={self.method}, path={self.path}{error})" class Response(object): @@ -241,11 +274,57 @@ def __repr__(self): return "{}({})".format(cls, ", ".join(args)) -class MockContext: - """mock basic nuclio context object""" +class GraphContext: + """Graph context object""" - def __init__(self, level="debug", logger=None): + def __init__(self, level="debug", logger=None, server=None, nuclio_context=None): self.state = None - self.logger = logger or create_logger(level, "human", "flow", sys.stdout) + self.logger = logger self.worker_id = 0 self.Response = Response + self.verbose = False + self.stream = None + self.root = None + + if nuclio_context: + self.logger = nuclio_context.logger + self.Response = nuclio_context.Response + self.worker_id = nuclio_context.worker_id + elif not logger: + self.logger = create_logger(level, "human", "flow", sys.stdout) + + self._server = server + self.current_function = None + self.get_data_resource = None + self.get_table = None + + def push_error(self, event, message, source=None, **kwargs): + if self.verbose: + self.logger.error( + f"got error from {source} state:\n{event.body}\n{message}" + ) + if self._server and self._server._error_stream_object: + message = format_error(self._server, self, source, event, message, kwargs) + self._server._error_stream_object.push(message) + + def get_param(self, key: str, default=None): + if self._server and self._server.parameters: + return self.parameters.get(key, default) + return default + + def get_secret(self, key: str): + if self._server and self._server._secrets: + return self._secrets.get(key) + return None + + +def format_error(server, context, source, event, message, args): + return { + "function_uri": server.function_uri, + "worker": context.worker_id, + "host": socket.gethostname(), + "source": source, + "event": {"id": event.id, "body": event.body}, + "message": message, + "args": args, + } diff --git a/mlrun/serving/states.py b/mlrun/serving/states.py index e6db21529f8..cf24c29ef31 100644 --- a/mlrun/serving/states.py +++ b/mlrun/serving/states.py @@ -11,20 +11,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import json -from copy import deepcopy +import os +import pathlib +import traceback +from copy import deepcopy, copy +from inspect import getfullargspec +from typing import Union from requests.adapters import HTTPAdapter from requests.packages.urllib3.util.retry import Retry import requests +from ..platforms.iguazio import OutputStream from ..model import ModelObj, ObjectDict -from ..utils import create_class +from ..utils import get_function, get_class +from ..errors import MLRunInvalidArgumentError + +callable_prefix = "_" +path_splitter = "/" + + +class GraphError(Exception): + """error in graph topology or configuration""" + + pass class StateKinds: router = "router" task = "task" + flow = "flow" + queue = "queue" + choice = "choice" + root = "root" _task_state_fields = [ @@ -33,88 +54,248 @@ class StateKinds: "class_args", "handler", "skip_context", - "next", - "resource", + "after", + "function", "comment", - "end", + "shape", + "full_event", + "on_error", ] def new_model_endpoint(class_name, model_path, handler=None, **class_args): class_args = deepcopy(class_args) class_args["model_path"] = model_path - return ServingTaskState(class_name, class_args, handler=handler) + return TaskState(class_name, class_args, handler=handler) def new_remote_endpoint(url, **class_args): class_args = deepcopy(class_args) class_args["url"] = url - return ServingTaskState("$remote", class_args) + return TaskState("$remote", class_args) class BaseState(ModelObj): kind = "BaseState" - _dict_fields = ["kind", "comment", "next", "end", "resource"] + default_shape = "ellipse" + _dict_fields = ["kind", "comment", "after", "on_error"] - def __init__(self, name=None, next=None): + def __init__(self, name: str = None, after: list = None, shape: str = None): self.name = name self._parent = None self.comment = None self.context = None - self.next = next - self.end = None - self.resource = None + self.after = after + self._next = None + self.shape = shape + self.on_error = None + self._on_error_handler = None + + def get_shape(self): + """graphviz shape""" + return self.shape or self.default_shape def set_parent(self, parent): + """set/link the state parent (flow/router)""" self._parent = parent - def set_next(self, key): + @property + def next(self): + return self._next + + @property + def parent(self): + """state parent (flow/router)""" + return self._parent + + def set_next(self, key: str): + """set/insert the key as next after this state, optionally remove other keys""" if not self.next: - self.next = [key] + self._next = [key] elif key not in self.next: - self.next.append(key) - - def init_object(self, context, namespace, mode="sync", **extra_kwargs): + self._next.append(key) + return self + + def after_state(self, after): + """specify the previous state name""" + # most states only accept one source + self.after = [after] if after else [] + return self + + def error_handler(self, state_name: str): + """set error handler state (on failure/raise of this state)""" + self.on_error = state_name + return self + + def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs): + """init the state class""" self.context = context + def _is_local_function(self, context): + return True + def get_children(self): + """get child states (for router/flow)""" return [] + def __iter__(self): + yield from [] + @property def fullname(self): - name = self.name - if self._parent: - name = ".".join([self._parent.fullname, name]) - return name + """full path/name (include parents)""" + name = self.name or "" + if self._parent and self._parent.fullname: + name = path_splitter.join([self._parent.fullname, name]) + return name.replace(":", "_") # replace for graphviz escaping def _post_init(self, mode="sync"): pass + def _set_error_handler(self): + """init/link the error handler for this state""" + if self.on_error: + error_state = self.context.root.path_to_state(self.on_error) + self._on_error_handler = error_state.run + + def _log_error(self, event, err, **kwargs): + """on failure log (for sync mode)""" + self.context.logger.error( + f"state {self.name} got error {err} when processing an event:\n {event.body}" + ) + message = traceback.format_exc() + self.context.logger.error(message) + self.context.push_error( + event, f"{err}\n{message}", source=self.fullname, **kwargs + ) + + def _call_error_handler(self, event, err, **kwargs): + """call the error handler if exist""" + if self._on_error_handler: + event.error = str(err) + event.origin_state = self.fullname + return self._on_error_handler(event) + + def path_to_state(self, path: str): + """return state object from state relative/fullname""" + path = path or "" + tree = path.split(path_splitter) + next_level = self + for state in tree: + if state not in next_level: + raise GraphError( + f"step {state} doesnt exist in the graph under {next_level.fullname}" + ) + next_level = next_level[state] + return next_level + + def to( + self, + class_name: Union[str, type] = None, + name: str = None, + handler: str = None, + graph_shape: str = None, + function: str = None, + full_event: bool = None, + **class_args, + ): + """add a state right after this state and return the new state + + example, a 4 step pipeline ending with a stream: + graph.to('URLDownloader')\ + .to('ToParagraphs')\ + .to(name='to_json', handler='json.dumps')\ + .to('>', 'to_v3io', path=stream_path)\ + + :param class_name: class name or state object to build the state from + for router states the class name should start with '*' + for queue/stream state the class should be '>>' or '$queue' + :param name: unique name (and path) for the child state, default is class name + :param handler: class/function handler to invoke on run/event + :param graph_shape: graphviz shape name + :param function: function this state should run in + :param full_event: this step accepts the full event (not just body) + :param class_args: class init arguments + """ + if hasattr(self, "states"): + parent = self + elif self._parent: + parent = self._parent + else: + raise GraphError( + f"state {self.name} parent is not set or its not part of a graph" + ) + + name, state = params_to_state( + class_name, + name, + handler, + graph_shape=graph_shape, + function=function, + full_event=full_event, + class_args=class_args, + ) + state = parent._states.update(name, state) + state.set_parent(parent) + if not hasattr(self, "states"): + # check that its not the root, todo: in future may gave nested flows + state.after_state(self.name) + parent._last_added = state + return state + + +class TaskState(BaseState): + """task execution state, runs a class or handler""" -class ServingTaskState(BaseState): kind = "task" _dict_fields = _task_state_fields _default_class = "" def __init__( - self, class_name=None, class_args=None, handler=None, name=None, next=None + self, + class_name: Union[str, type] = None, + class_args: dict = None, + handler: str = None, + name: str = None, + after: list = None, + full_event: bool = None, + function: str = None, + responder: bool = None, ): - super().__init__(name, next) + super().__init__(name, after) self.class_name = class_name self.class_args = class_args or {} self.handler = handler + self.function = function self._handler = None self._object = None + self._async_object = None self.skip_context = None self.context = None self._class_object = None + self.responder = responder + self.full_event = full_event + self.on_error = None + + def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs): + self.context = context + if not self._is_local_function(context): + # skip init of non local functions + return + + if self.handler and not self.class_name: + # link to function + if callable(self.handler): + self._handler = self.handler + self.handler = self.handler.__name__ + else: + self._handler = get_function(self.handler, namespace) + return - def init_object(self, context, namespace, mode="sync", **extra_kwargs): if isinstance(self.class_name, type): self._class_object = self.class_name self.class_name = self.class_name.__name__ - self.context = context if not self._class_object: if self.class_name == "$remote": self._class_object = RemoteHttpHandler @@ -123,74 +304,178 @@ def init_object(self, context, namespace, mode="sync", **extra_kwargs): self.class_name or self._default_class, namespace ) - if not self._object: - class_args = {k: v for k, v in self.class_args.items()} + if not self._object or reset: + # init the state class + args + class_args = {} + for key, arg in self.class_args.items(): + if key.startswith(callable_prefix): + class_args[key[1:]] = get_function(arg, namespace) + else: + class_args[key] = arg class_args.update(extra_kwargs) - if self.skip_context is None or not self.skip_context: - class_args["name"] = self.name + + # add name and context only if target class can accept them + argspec = getfullargspec(self._class_object) + if argspec.varkw or "context" in argspec.args: class_args["context"] = self.context - self._object = self._class_object(**class_args) - self._handler = getattr(self._object, self.handler or "do_event", None) + if argspec.varkw or "name" in argspec.args: + class_args["name"] = self.name + + try: + self._object = self._class_object(**class_args) + except TypeError as e: + raise TypeError( + f"failed to init state {self.name}, {e}\n args={self.class_args}" + ) + # determine the right class handler to use + handler = self.handler + if handler: + if not hasattr(self._object, handler): + raise GraphError( + f"handler ({handler}) specified but doesnt exist in class {self.class_name}" + ) + else: + if hasattr(self._object, "do"): + handler = "do" + elif hasattr(self._object, "do_event"): + handler = "do_event" + self.full_event = True + if handler: + self._handler = getattr(self._object, handler, None) + + self._set_error_handler() if mode != "skip": self._post_init(mode) + def _is_local_function(self, context): + # detect if the class is local (and should be initialized) + current_function = get_current_function(context) + if current_function == "*": + return True + if not self.function and not current_function: + return True + if ( + self.function and self.function == "*" + ) or self.function == current_function: + return True + return False + @property - def object(self): - return self._object + def async_object(self): + """return the sync or async (storey) class instance""" + return self._async_object or self._object + + def clear_object(self): + self._object = None def _post_init(self, mode="sync"): if self._object and hasattr(self._object, "post_init"): self._object.post_init(mode) + def respond(self): + """mark this state as the responder. + + state output will be returned as the flow result, no other state can follow + """ + self.responder = True + return self + def run(self, event, *args, **kwargs): - return self._handler(event, *args, **kwargs) + """run this state, in async flows the run is done through storey""" + if not self._is_local_function(self.context): + # todo invoke remote via REST call + return event + + if self.context.verbose: + self.context.logger.info(f"state {self.name} got event {event.body}") + + try: + if self.full_event: + return self._handler(event, *args, **kwargs) + event.body = self._handler(event.body, *args, **kwargs) + except Exception as e: + self._log_error(event, e) + handled = self._call_error_handler(event, e) + if not handled: + raise e + event.terminated = True + return event -class ServingRouterState(ServingTaskState): +class RouterState(TaskState): + """router state, implement routing logic for running child routes""" + kind = "router" + default_shape = "doubleoctagon" _dict_fields = _task_state_fields + ["routes"] _default_class = "mlrun.serving.ModelRouter" def __init__( - self, class_name=None, class_args=None, handler=None, routes=None, name=None + self, + class_name: Union[str, type] = None, + class_args: dict = None, + handler: str = None, + routes: list = None, + name: str = None, + function: str = None, ): - super().__init__(class_name, class_args, handler, name=name) - self._routes = {} + super().__init__(class_name, class_args, handler, name=name, function=function) + self._routes: ObjectDict = None self.routes = routes def get_children(self): + """get child states (routes)""" return self._routes.values() @property def routes(self): + """child routes/states, traffic is routed to routes based on router logic""" return self._routes @routes.setter def routes(self, routes: dict): self._routes = ObjectDict.from_dict(classes_map, routes, "task") - def add_route(self, key, route): + def add_route(self, key, route=None, class_name=None, handler=None, **class_args): + """add child route state or class to the router + + :param key: unique name (and route path) for the child state + :param route: child state object (Task, ..) + :param class_name: class name to build the route state from (when route is not provided) + :param class_args: class init arguments + :param handler: class handler to invoke on run/event + """ + + if not route and not class_name: + raise MLRunInvalidArgumentError("route or class_name must be specified") + if not route: + route = TaskState(class_name, class_args, handler=handler) route = self._routes.update(key, route) route.set_parent(self) return route - def clear_routes(self, routes: list): + def clear_children(self, routes: list): + """clear child states (routes)""" if not routes: routes = self._routes.keys() for key in routes: del self._routes[key] - def init_object(self, context, namespace, mode="sync", **extra_kwargs): + def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs): + if not self._is_local_function(context): + return + self.class_args = self.class_args or {} super().init_object( - context, namespace, "skip", routes=self._routes, **extra_kwargs + context, namespace, "skip", reset=reset, routes=self._routes, **extra_kwargs ) for route in self._routes.values(): route.set_parent(self) - route.init_object(context, namespace, mode) + route.init_object(context, namespace, mode, reset=reset) + self._set_error_handler() self._post_init(mode) def __getitem__(self, name): @@ -205,55 +490,208 @@ def __delitem__(self, key): def __iter__(self): yield from self._routes.keys() + def plot(self, filename=None, format=None, source=None, **kw): + """plot/save a graphviz plot""" + return _generate_graphviz( + self, _add_graphviz_router, filename, format, source=source, **kw + ) + + +class QueueState(BaseState): + """queue state, implement an async queue or represent a stream""" + + kind = "queue" + default_shape = "cds" + _dict_fields = BaseState._dict_fields + [ + "path", + "shards", + "retention_in_hours", + "options", + ] + + def __init__( + self, + name: str = None, + path: str = None, + after: list = None, + shards: int = None, + retention_in_hours: int = None, + **options, + ): + super().__init__(name, after) + self.path = path + self.shards = shards + self.retention_in_hours = retention_in_hours + self.options = options + self._stream = None + self._async_object = None + + def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs): + self.context = context + if self.path: + self._stream = OutputStream(self.path, self.shards, self.retention_in_hours) + self._set_error_handler() + + @property + def async_object(self): + return self._async_object + + def after_state(self, after): + # queue states accept multiple sources + if self.after: + if after: + self.after.append(after) + else: + self.after = [after] if after else [] + return self + + def run(self, event, *args, **kwargs): + data = event.body + if not data: + return event + + if self._stream: + self._stream.push({"id": event.id, "body": data, "path": event.path}) + event.terminated = True + event.body = None + return event + + +class FlowState(BaseState): + """flow state, represent a workflow or DAG""" -class ServingFlowState(BaseState): kind = "flow" - _dict_fields = BaseState._dict_fields + ["states", "start_at"] + _dict_fields = BaseState._dict_fields + [ + "states", + "engine", + "default_final_state", + ] - def __init__(self, name=None, states=None, next=None, start_at=None): - super().__init__(name, next) + def __init__( + self, name=None, states=None, after: list = None, engine=None, final_state=None, + ): + super().__init__(name, after) self._states = None self.states = states - self.start_at = start_at - self.from_state = None + self.engine = engine + self.from_state = os.environ.get("START_FROM_STATE", None) + self.final_state = final_state + self._last_added = None + self._controller = None + self._wait_for_result = False + self._source = None + self._start_states = [] def get_children(self): return self._states.values() @property def states(self): + """child (workflow) states""" return self._states + @property + def controller(self): + """async (storey) flow controller""" + return self._controller + @states.setter def states(self, states): self._states = ObjectDict.from_dict(classes_map, states, "task") - def add_state(self, key, state, after=None): + def add_step( + self, + class_name=None, + name=None, + handler=None, + after=None, + before=None, + graph_shape=None, + function=None, + **class_args, + ): + """add task, queue or router state/class to the flow + + use after/before to insert into a specific location + + example: + graph = fn.set_topology("flow", exist_ok=True) + graph.add_step(class_name="Chain", name="s1") + graph.add_step(class_name="Chain", name="s3", after="$prev") + graph.add_step(class_name="Chain", name="s2", after="s1", before="s3") + + :param class_name: class name or state object to build the state from + for router states the class name should start with '*' + for queue/stream state the class should be '>>' or '$queue' + :param name: unique name (and path) for the child state, default is class name + :param handler: class/function handler to invoke on run/event + :param after: the step name this step comes after + can use $prev to indicate the last added state + :param before: string or list of next step names that will run after this step + :param graph_shape: graphviz shape name + :param function: function this state should run in + :param class_args: class init arguments + """ + + name, state = params_to_state( + class_name, + name, + handler, + graph_shape=graph_shape, + function=function, + class_args=class_args, + ) + + self.insert_state(name, state, after, before) + return state + + def insert_state(self, key, state, after, before=None): + """insert state object into the flow, specify before and after""" + state = self._states.update(key, state) state.set_parent(self) - if not self.start_at and len(self._states) <= 1: - self.start_at = key + if after == "$prev" and len(self._states) == 1: + after = None + previous = "" if after: - if isinstance(after, str): + if after == "$prev" and self._last_added: + previous = self._last_added.name + else: if after not in self._states.keys(): - raise ValueError( - f"there is no state named {after}, cant set next state" + raise MLRunInvalidArgumentError( + f"cant set after, there is no state named {after}" ) - after = self._states[after] - after.set_next(key) - elif self._last_added: - self._last_added.set_next(key) + previous = after + state.after_state(previous) + + if before: + if before not in self._states.keys(): + raise MLRunInvalidArgumentError( + f"cant set before, there is no state named {before}" + ) + if before == state.name or before == previous: + raise GraphError( + f"graph loop, state {before} is specified in before and/or after {key}" + ) + self[before].after_state(state.name) self._last_added = state return state + def clear_children(self, states: list = None): + """remove some or all of the states, empty/None for all""" + if not states: + states = self._states.keys() + for key in states: + del self._states[key] + def __getitem__(self, name): return self._states[name] def __setitem__(self, name, state): - self.add_state(name, state) + self.add_step(name, state) def __delitem__(self, key): del self._states[key] @@ -261,35 +699,273 @@ def __delitem__(self, key): def __iter__(self): yield from self._states.keys() - def init_object(self, context, namespace, mode="sync", **extra_kwargs): + def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs): self.context = context + self.check_and_process_graph() + for state in self._states.values(): state.set_parent(self) - state.init_object(context, namespace, mode) + state.init_object(context, namespace, mode, reset=reset) + self._set_error_handler() self._post_init(mode) - def get_start_state(self, from_state=None): - from_state = from_state or self.from_state or self.start_at - if not from_state: - raise ValueError( - f"start step {from_state} was not specified in {self.name}" + if self.engine != "sync": + self._build_async_flow() + + def check_and_process_graph(self, allow_empty=False): + """validate correct graph layout and initialize the .next links""" + + if self.is_empty() and allow_empty: + self._start_states = [] + return [], None, [] + + def has_loop(state, previous): + for next_state in state.after or []: + if next_state in previous: + return state.name + downstream = has_loop(self[next_state], previous + [next_state]) + if downstream: + return downstream + return None + + start_states = [] + for state in self._states.values(): + state._next = None + if state.after: + loop_state = has_loop(state, []) + if loop_state: + raise GraphError( + f"Error, loop detected in state {loop_state}, graph must be acyclic (DAG)" + ) + else: + start_states.append(state.name) + + responders = [] + for state in self._states.values(): + if hasattr(state, "responder") and state.responder: + responders.append(state.name) + if state.on_error and state.on_error in start_states: + start_states.remove(state.on_error) + if state.after: + prev_state = state.after[0] + self[prev_state].set_next(state.name) + + if ( + not start_states + ): # for safety, not sure if its possible to get here (since its a loop) + raise GraphError("there are no starting states (ones without .after)") + + if ( + len(responders) > 1 + ): # should not have multiple steps which respond to request + raise GraphError( + f'there are more than one responder states in the graph ({",".join(responders)})' ) - tree = from_state.split(".") - next_obj = self - for state in tree: - if state not in next_obj.keys(): - raise ValueError(f"start step {from_state} doesnt exist in {self.name}") - next_obj = next_obj[state] - return next_obj + + if self.from_state: + if self.from_state not in self.states: + raise GraphError( + f"from_state ({self.from_state}) specified and not found in graph states" + ) + start_states = [self.from_state] + + self._start_states = [self[name] for name in start_states] + + def get_first_function_state(state, current_function): + # find the first state which belongs to the function + if ( + hasattr(state, "function") + and state.function + and state.function == current_function + ): + return state + for item in state.next or []: + next_state = self[item] + returned_state = get_first_function_state(next_state, current_function) + if returned_state: + return returned_state + + current_function = get_current_function(self.context) + if current_function: + new_start_states = [] + for from_state in self._start_states: + state = get_first_function_state(from_state, current_function) + if state: + new_start_states.append(state) + if not new_start_states: + raise GraphError( + f"did not find states pointing to current function ({current_function})" + ) + self._start_states = new_start_states + + if self.engine == "sync" and len(self._start_states) > 1: + raise GraphError( + "sync engine can only have one starting state (without .after)" + ) + + default_final_state = None + if self.final_state: + if self.final_state not in self.states: + raise GraphError( + f"final_state ({self.final_state}) specified and not found in graph states" + ) + default_final_state = self.final_state + + elif len(self._start_states) == 1: + # find the final state in case if a simple sequence of steps + next_obj = self._start_states[0] + while next_obj: + next = next_obj.next + if not next: + default_final_state = next_obj.name + break + next_obj = self[next[0]] if len(next) == 1 else None + + return self._start_states, default_final_state, responders + + def set_flow_source(self, source): + """set the async flow (storey) source""" + self._source = source + + def _build_async_flow(self): + """initialize and build the async/storey DAG""" + try: + import storey + except ImportError: + raise GraphError("storey package is not installed, use pip install storey") + + def process_step(state, step, root): + if not state._is_local_function(self.context): + return + for item in state.next or []: + next_state = root[item] + next_step = step.to(next_state.async_object) + process_step(next_state, next_step, root) + + for state in self._states.values(): + if hasattr(state, "async_object"): + if state.kind == StateKinds.queue: + if state.path: + state._async_object = storey.WriteToV3IOStream( + storey.V3ioDriver(), state.path + ) + else: + state._async_object = storey.Map(lambda x: x) + + elif not state.async_object or not hasattr( + state.async_object, "_outlets" + ): + # if regular class, wrap with storey Map + state._async_object = storey.Map( + state._handler, + full_event=state.full_event, + name=state.name, + context=self.context, + ) + if not state.next and hasattr(state, "responder") and state.responder: + # if responder state (return result), add Complete() + state.async_object.to(storey.Complete(full_event=True)) + self._wait_for_result = True + + # todo: allow source array (e.g. data->json loads..) + source = self._source or storey.Source() + for next_state in self._start_states: + next_step = source.to(next_state.async_object) + process_step(next_state, next_step, self) + + for state in self._states.values(): + # add error handler hooks + if state.on_error and state.async_object: + error_state = self._states[state.on_error] + state.async_object.set_recovery_step(error_state.async_object) + + self._controller = source.run() + + def get_queue_links(self): + """return dict of function and queue its listening on, for building stream triggers""" + links = {} + for state in self.get_children(): + if state.kind == StateKinds.queue: + for item in state.next or []: + next_state = self[item] + if next_state.function: + if next_state.function in links: + raise GraphError( + f"function ({next_state.function}) cannot read from multiple queues" + ) + links[next_state.function] = state + return links + + def init_queues(self): + """init/create the streams used in this flow""" + for state in self.get_children(): + if state.kind == StateKinds.queue: + state.init_object(self.context, None) + + def is_empty(self): + """is the graph empty (no child states)""" + return len(self.states) == 0 def run(self, event, *args, **kwargs): - next_obj = self.get_start_state(kwargs.get("from_state", None)) - return next_obj.run(event, *args, **kwargs) + + if self._controller: + # async flow (using storey) + event._awaitable_result = None + resp = self._controller.emit( + event, return_awaitable_result=self._wait_for_result + ) + if self._wait_for_result: + return resp.await_result() + event = copy(event) + event.body = {"id": event.id} + return event + + next_obj = self._start_states[0] + while next_obj: + try: + event = next_obj.run(event, *args, **kwargs) + except Exception as e: + self._log_error(event, e, failed_state=next_obj.name) + handled = self._call_error_handler(event, e) + if not handled: + raise e + event.terminated = True + return event + + if hasattr(event, "terminated") and event.terminated: + return event + next = next_obj.next + if next and len(next) > 1: + raise GraphError( + f"synchronous flow engine doesnt support branches use async, state={next_obj.name}" + ) + next_obj = self[next[0]] if next else None + return event + + def wait_for_completion(self): + """wait for completion of run in async flows""" + if self._controller: + self._controller.terminate() + self._controller.await_termination() + + def plot(self, filename=None, format=None, source=None, targets=None, **kw): + """plot/save graph using graphviz""" + return _generate_graphviz( + self, + _add_graphviz_flow, + filename, + format, + source=source, + targets=targets, + **kw, + ) -class ServingRootFlowState(ServingFlowState): - kind = "rootFlow" - _dict_fields = ["states", "start_at"] +class RootFlowState(FlowState): + """root flow state""" + + kind = "root" + _dict_fields = ["states", "engine", "final_state", "on_error"] http_adapter = HTTPAdapter( @@ -332,23 +1008,168 @@ def do_event(self, event): return event -def get_class(class_name, namespace): - """return class object from class name string""" - if isinstance(class_name, type): - return class_name - if class_name in namespace: - class_object = namespace[class_name] - return class_object +classes_map = { + "task": TaskState, + "router": RouterState, + "flow": FlowState, + "queue": QueueState, +} + +def get_current_function(context): + if context and hasattr(context, "current_function"): + return context.current_function or "" + return "" + + +def _add_graphviz_router(graph, state, source=None, **kwargs): + if source: + graph.node("_start", source.name, shape=source.shape, style="filled") + graph.edge("_start", state.fullname) + + graph.node(state.fullname, label=state.name, shape=state.get_shape()) + for route in state.get_children(): + graph.node(route.fullname, label=route.name, shape=route.get_shape()) + graph.edge(state.fullname, route.fullname) + + +def _add_graphviz_flow( + graph, state, source=None, targets=None, +): + start_states, default_final_state, responders = state.check_and_process_graph( + allow_empty=True + ) + graph.node("_start", source.name, shape=source.shape, style="filled") + for start_state in start_states: + graph.edge("_start", start_state.fullname) + for child in state.get_children(): + kind = child.kind + if kind == StateKinds.router: + with graph.subgraph(name="cluster_" + child.fullname) as sg: + _add_graphviz_router(sg, child) + else: + graph.node(child.fullname, label=child.name, shape=child.get_shape()) + after = child.after or [] + for item in after: + previous_object = state[item] + kw = ( + {"ltail": "cluster_" + child.fullname} + if child.kind == StateKinds.router + else {} + ) + graph.edge(previous_object.fullname, child.fullname, **kw) + if child.on_error: + graph.edge(child.fullname, child.on_error, style="dashed", **kw) + + # draw targets after the last state (if specified) + if targets: + for target in targets or []: + graph.node(target.fullname, label=target.name, shape=target.get_shape()) + last_state = target.after or default_final_state + if last_state: + graph.edge(last_state, target.fullname) + + +def _generate_graphviz( + state, renderer, filename=None, format=None, source=None, targets=None, **kw, +): try: - class_object = create_class(class_name) - except (ImportError, ValueError) as e: - raise ImportError(f"state init failed, class {class_name} not found, {e}") - return class_object + from graphviz import Digraph + except ImportError: + raise ImportError( + 'graphviz is not installed, run "pip install graphviz" first!' + ) + graph = Digraph("mlrun-flow", format="jpg") + graph.attr(compound="true", **kw) + source = source or BaseState("start", shape="egg") + renderer(graph, state, source=source, targets=targets) + if filename: + suffix = pathlib.Path(filename).suffix + if suffix: + filename = filename[: -len(suffix)] + format = format or suffix[1:] + format = format or "png" + graph.render(filename, format=format) + return graph + + +def graph_root_setter(server, graph): + """set graph root object from class or dict""" + if graph: + if isinstance(graph, dict): + kind = graph.get("kind") + elif hasattr(graph, "kind"): + kind = graph.kind + else: + raise MLRunInvalidArgumentError("graph must be a dict or a valid object") + if kind == StateKinds.router: + server._graph = server._verify_dict(graph, "graph", RouterState) + elif not kind or kind == StateKinds.root: + server._graph = server._verify_dict(graph, "graph", RootFlowState) + else: + raise GraphError(f"illegal root state {kind}") + + +def get_name(name, class_name): + """get task name from provided name or class""" + if name: + return name + if not class_name: + raise MLRunInvalidArgumentError("name or class_name must be provided") + if isinstance(class_name, type): + return class_name.__name__ + return class_name + + +def params_to_state( + class_name, + name, + handler=None, + graph_shape=None, + function=None, + full_event=None, + class_args=None, +): + """return state object from provided params or classes/objects""" + if class_name and hasattr(class_name, "to_dict"): + struct = class_name.to_dict() + kind = struct.get("kind", StateKinds.task) + name = name or struct.get("name", struct.get("class_name")) + cls = classes_map.get(kind, RootFlowState) + state = cls.from_dict(struct) + state.function = function + state.full_event = full_event + + elif class_name and class_name in [">>", "$queue"]: + if "path" not in class_args: + raise MLRunInvalidArgumentError( + "path= must be specified for queues" + ) + if not name: + raise MLRunInvalidArgumentError("queue name must be specified") + state = QueueState(name, **class_args) + + elif class_name and class_name.startswith("*"): + routes = class_args.get("routes", None) + class_name = class_name[1:] + name = get_name(name, class_name or "router") + state = RouterState( + class_name, class_args, handler, name=name, function=function, routes=routes + ) + elif class_name or handler: + name = get_name(name, class_name) + state = TaskState( + class_name, + class_args, + handler, + name=name, + function=function, + full_event=full_event, + ) + else: + raise MLRunInvalidArgumentError("class_name or handler must be provided") -classes_map = { - "task": ServingTaskState, - "router": ServingRouterState, - "flow": ServingFlowState, -} + if graph_shape: + state.shape = graph_shape + return name, state diff --git a/mlrun/serving/v2_serving.py b/mlrun/serving/v2_serving.py index 251cc0a52d8..4cf7417da98 100644 --- a/mlrun/serving/v2_serving.py +++ b/mlrun/serving/v2_serving.py @@ -70,7 +70,7 @@ def __init__( self.protocol = protocol or "v2" self.model_path = model_path self.model_spec: mlrun.artifacts.ModelArtifact = None - self._params = context.merge_root_params(class_args) + self._params = class_args self._model_logger = _ModelLogPusher(self, context) self.metrics = {} @@ -101,7 +101,9 @@ def post_init(self, mode="sync"): def get_param(self, key: str, default=None): """get param by key (specified in the model or the function)""" - return self._params.get(key, default) + if key in self._params: + return self._params.get(key) + return self.context.get_param(key, default=default) def set_metric(self, name: str, value): """set real time metric (for model monitoring)""" @@ -247,11 +249,11 @@ def postprocess(self, request: Dict) -> Dict: def predict(self, request: Dict) -> Dict: """model prediction operation""" - raise NotImplementedError + raise NotImplementedError() def explain(self, request: Dict) -> Dict: """model explain operation""" - raise NotImplementedError + raise NotImplementedError() class _ModelLogPusher: diff --git a/mlrun/utils/helpers.py b/mlrun/utils/helpers.py index 9de8c9d7b74..1b1931b8a24 100644 --- a/mlrun/utils/helpers.py +++ b/mlrun/utils/helpers.py @@ -17,6 +17,7 @@ import re import sys import time +from types import ModuleType from typing import Optional, Tuple from datetime import datetime, timezone from dateutil import parser @@ -724,6 +725,51 @@ def get_caller_globals(level=2): return None +def _module_to_namespace(namespace): + if isinstance(namespace, ModuleType): + members = inspect.getmembers( + namespace, lambda o: inspect.isfunction(o) or isinstance(o, type) + ) + return {key: mod for key, mod in members} + return namespace + + +def get_class(class_name, namespace): + """return class object from class name string""" + if isinstance(class_name, type): + return class_name + namespace = _module_to_namespace(namespace) + if namespace and class_name in namespace: + return namespace[class_name] + + try: + class_object = create_class(class_name) + except (ImportError, ValueError) as e: + raise ImportError(f"state init failed, class {class_name} not found, {e}") + return class_object + + +def get_function(function, namespace): + """return function callable object from function name string""" + if callable(function): + return function + + function = function.strip() + if function.startswith("("): + if not function.endswith(")"): + raise ValueError('function expression must start with "(" and end with ")"') + return eval("lambda event: " + function[1:-1], {}, {}) + namespace = _module_to_namespace(namespace) + if function in namespace: + return namespace[function] + + try: + function_object = create_function(function) + except (ImportError, ValueError) as e: + raise ImportError(f"state init failed, function {function} not found, {e}") + return function_object + + def datetime_from_iso(time_str: str) -> Optional[datetime]: if not time_str: return diff --git a/requirements.txt b/requirements.txt index 1a2d5a92afc..07a92bae800 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ ipython>=5.5, <7.17 nuclio-jupyter>=0.8.9 pandas~=1.0 # used as a the engine for parquet files by pandas -pyarrow~=2.0 +pyarrow~=1.0 pyyaml~=5.1 requests~=2.22 sqlalchemy~=1.3 diff --git a/tests/serving/demo_states.py b/tests/serving/demo_states.py new file mode 100644 index 00000000000..fa4616f74c0 --- /dev/null +++ b/tests/serving/demo_states.py @@ -0,0 +1,83 @@ +from copy import copy +from mlrun.serving import V2ModelServer + + +class BaseClass: + def __init__(self, context, name=None): + self.context = context + self.name = name + + +class Echo(BaseClass): + def __init__(self, name=None): + self.name = name + + def do(self, x): + print("Echo:", self.name, x) + return x + + +class RespName(BaseClass): + def __init__(self, **kwargs): + self.name = kwargs.get("name") + + def do(self, x): + print("Echo:", self.name, x) + return [x, self.name] + + +class EchoError(BaseClass): + def do(self, x): + x.body = {"body": x.body, "origin_state": x.origin_state, "error": x.error} + print("EchoError:", x) + return x + + +class Chain(BaseClass): + def do(self, x): + x = copy(x) + x.append(self.name) + return x + + +class ChainWithContext(BaseClass): + def do(self, x): + visits = self.context.visits.get(self.name, 0) + self.context.visits[self.name] = visits + 1 + x = copy(x) + x.append(self.name) + return x + + +class Message(BaseClass): + def __init__(self, msg="", context=None, name=None): + self.msg = msg + + def do(self, x): + print("Messsage:", self.msg) + return x + + +class Raiser: + def __init__(self, msg="", context=None, name=None): + self.context = context + self.name = name + self.msg = msg + + def do(self, x): + raise ValueError(f" this is an error, {x}") + + +def multiply_input(request): + request["inputs"][0] = request["inputs"][0] * 2 + return request + + +class ModelClass(V2ModelServer): + def load(self): + print("loading") + + def predict(self, request): + print("predict:", request) + resp = request["inputs"][0] * self.get_param("multiplier") + return resp diff --git a/tests/serving/test_async_flow.py b/tests/serving/test_async_flow.py new file mode 100644 index 00000000000..ce47e565d19 --- /dev/null +++ b/tests/serving/test_async_flow.py @@ -0,0 +1,87 @@ +import mlrun +from mlrun.utils import logger +from .demo_states import * # noqa +from tests.conftest import results + + +def test_async_basic(): + function = mlrun.new_function("tests", kind="serving") + flow = function.set_topology("flow", engine="async") + queue = flow.to(name="s1", class_name="ChainWithContext").to( + "$queue", "q1", path="" + ) + + s2 = queue.to(name="s2", class_name="ChainWithContext") + s2.to(name="s4", class_name="ChainWithContext") + s2.to( + name="s5", class_name="ChainWithContext" + ).respond() # this state returns the resp + + queue.to(name="s3", class_name="ChainWithContext") + + # plot the graph for test & debug + flow.plot(f"{results}/serving/async.png") + + server = function.to_mock_server() + server.context.visits = {} + logger.info(f"\nAsync Flow:\n{flow.to_yaml()}") + resp = server.test(body=[]) + + server.wait_for_completion() + assert resp == ["s1", "s2", "s5"], "flow result is incorrect" + assert server.context.visits == { + "s1": 1, + "s2": 1, + "s4": 1, + "s3": 1, + "s5": 1, + }, "flow didnt visit expected states" + + +def test_async_nested(): + function = mlrun.new_function("tests", kind="serving") + graph = function.set_topology("flow", engine="async") + graph.add_step(name="s1", class_name="Echo") + graph.add_step(name="s2", handler="multiply_input", after="s1") + graph.add_step(name="s3", class_name="Echo", after="s2") + + router_step = graph.add_step("*", name="ensemble", after="s2") + router_step.add_route("m1", class_name="ModelClass", model_path=".", multiplier=100) + router_step.add_route("m2", class_name="ModelClass", model_path=".", multiplier=200) + router_step.add_route( + "m3:v1", class_name="ModelClass", model_path=".", multiplier=300 + ) + + graph.add_step(name="final", class_name="Echo", after="ensemble").respond() + + logger.info(graph.to_yaml()) + server = function.to_mock_server() + + # plot the graph for test & debug + graph.plot(f"{results}/serving/nested.png") + resp = server.test("/v2/models/m2/infer", body={"inputs": [5]}) + server.wait_for_completion() + # resp should be input (5) * multiply_input (2) * m2 multiplier (200) + assert resp["outputs"] == 5 * 2 * 200, f"wrong health response {resp}" + + +def test_on_error(): + function = mlrun.new_function("tests", kind="serving") + graph = function.set_topology("flow", engine="async") + chain = graph.to("Chain", name="s1") + chain.to("Raiser").error_handler("catch").to("Chain", name="s3") + + graph.add_step( + name="catch", class_name="EchoError", after="" + ).respond().full_event = True + function.verbose = True + server = function.to_mock_server() + logger.info(graph.to_yaml()) + + # plot the graph for test & debug + graph.plot(f"{results}/serving/on_error.png") + resp = server.test(body=[]) + server.wait_for_completion() + assert ( + resp["error"] and resp["origin_state"] == "Raiser" + ), f"error wasnt caught, resp={resp}" diff --git a/tests/serving/test_flow.py b/tests/serving/test_flow.py new file mode 100644 index 00000000000..264d8559465 --- /dev/null +++ b/tests/serving/test_flow.py @@ -0,0 +1,88 @@ +import mlrun +from mlrun.utils import logger +import pytest +from .demo_states import * # noqa + + +engines = [ + "sync", + "async", +] + + +def test_basic_flow(): + fn = mlrun.new_function("tests", kind="serving") + graph = fn.set_topology("flow", engine="sync") + graph.add_step(name="s1", class_name="Chain") + graph.add_step(name="s2", class_name="Chain", after="$prev") + graph.add_step(name="s3", class_name="Chain", after="$prev") + + server = fn.to_mock_server() + # graph.plot("flow.png") + print("\nFlow1:\n", graph.to_yaml()) + resp = server.test(body=[]) + assert resp == ["s1", "s2", "s3"], "flow1 result is incorrect" + + graph = fn.set_topology("flow", exist_ok=True, engine="sync") + graph.add_step(name="s2", class_name="Chain") + graph.add_step( + name="s1", class_name="Chain", before="s2" + ) # should place s1 first and s2 after it + graph.add_step(name="s3", class_name="Chain", after="s2") + + server = fn.to_mock_server() + logger.info(f"flow: {graph.to_yaml()}") + resp = server.test(body=[]) + assert resp == ["s1", "s2", "s3"], "flow2 result is incorrect" + + graph = fn.set_topology("flow", exist_ok=True, engine="sync") + graph.add_step(name="s1", class_name="Chain") + graph.add_step(name="s3", class_name="Chain", after="$prev") + graph.add_step(name="s2", class_name="Chain", after="s1", before="s3") + + server = fn.to_mock_server() + logger.info(f"flow: {graph.to_yaml()}") + resp = server.test(body=[]) + assert resp == ["s1", "s2", "s3"], "flow3 result is incorrect" + + +@pytest.mark.parametrize("engine", engines) +def test_handler(engine): + fn = mlrun.new_function("tests", kind="serving") + graph = fn.set_topology("flow", engine=engine) + graph.to(name="s1", handler="(event + 1)").to(name="s2", handler="json.dumps") + if engine == "async": + graph["s2"].respond() + + server = fn.to_mock_server() + resp = server.test(body=5) + if engine == "async": + server.wait_for_completion() + # the json.dumps converts the 6 to "6" (string) + assert resp == "6", f"got unexpected result {resp}" + + +def test_init_class(): + fn = mlrun.new_function("tests", kind="serving") + graph = fn.set_topology("flow", engine="sync") + graph.to(name="s1", class_name="Echo").to(name="s2", class_name="RespName") + + server = fn.to_mock_server() + resp = server.test(body=5) + assert resp == [5, "s2"], f"got unexpected result {resp}" + + +def test_on_error(): + fn = mlrun.new_function("tests", kind="serving") + graph = fn.set_topology("flow", engine="sync") + graph.add_step(name="s1", class_name="Chain") + graph.add_step(name="raiser", class_name="Raiser", after="$prev").error_handler( + "catch" + ) + graph.add_step(name="s3", class_name="Chain", after="$prev") + graph.add_step(name="catch", class_name="EchoError").full_event = True + + server = fn.to_mock_server() + logger.info(f"flow: {graph.to_yaml()}") + resp = server.test(body=[]) + assert resp["error"] and resp["origin_state"] == "raiser", "error wasnt caught" diff --git a/tests/serving/test_serving.py b/tests/serving/test_serving.py index f27c9c38edc..66c8e756d65 100644 --- a/tests/serving/test_serving.py +++ b/tests/serving/test_serving.py @@ -1,26 +1,28 @@ import json import os import time +import mlrun +from mlrun.utils import logger from mlrun.runtimes import nuclio_init_hook from mlrun.runtimes.serving import serving_subkind from mlrun.serving import V2ModelServer -from mlrun.serving.server import MockEvent, MockContext, create_mock_server -from mlrun.serving.states import ServingRouterState, ServingTaskState +from mlrun.serving.server import MockEvent, GraphContext, create_graph_server +from mlrun.serving.states import RouterState, TaskState -router_object = ServingRouterState() +router_object = RouterState() router_object.routes = { - "m1": ServingTaskState( - "ModelTestingClass", class_args={"model_path": "", "z": 100} + "m1": TaskState( + "ModelTestingClass", class_args={"model_path": "", "multiplier": 100} ), - "m2": ServingTaskState( - "ModelTestingClass", class_args={"model_path": "", "z": 200} + "m2": TaskState( + "ModelTestingClass", class_args={"model_path": "", "multiplier": 200} ), - "m3:v1": ServingTaskState( - "ModelTestingClass", class_args={"model_path": "", "z": 300} + "m3:v1": TaskState( + "ModelTestingClass", class_args={"model_path": "", "multiplier": 300} ), - "m3:v2": ServingTaskState( - "ModelTestingClass", class_args={"model_path": "", "z": 400} + "m3:v2": TaskState( + "ModelTestingClass", class_args={"model_path": "", "multiplier": 400} ), } @@ -59,7 +61,7 @@ def load(self): def predict(self, request): print("predict:", request) - resp = request["inputs"][0] * self.get_param("z") + resp = request["inputs"][0] * self.get_param("multiplier") return resp def explain(self, request): @@ -85,7 +87,7 @@ def predict(self, request): def init_ctx(): os.environ["SERVING_SPEC_ENV"] = json.dumps(spec) - context = MockContext() + context = GraphContext() nuclio_init_hook(context, globals(), serving_subkind) return context @@ -127,7 +129,7 @@ def test_v2_stream_mode(): '{"model": "m3:v2", "operation": "explain", "inputs": [5]}', path="" ) resp = context.mlrun_handler(context, event) - print(resp.body) + logger.info(f"resp: {resp.body}") data = json.loads(resp.body) assert data["outputs"]["explained"] == 5, f"wrong model response {data}" @@ -135,7 +137,7 @@ def test_v2_stream_mode(): def test_v2_async_mode(): # model loading is async os.environ["SERVING_SPEC_ENV"] = json.dumps(asyncspec) - context = MockContext() + context = GraphContext() nuclio_init_hook(context, globals(), serving_subkind) context.logger.info("model initialized") @@ -149,7 +151,7 @@ def test_v2_async_mode(): event = MockEvent(testdata, path="/v2/models/m5/infer") resp = context.mlrun_handler(context, event) context.logger.info("model responded") - print(resp) + logger.info(resp) assert ( resp.status_code != 200 ), f"expected failure, got {resp.status_code} {resp.body}" @@ -158,7 +160,7 @@ def test_v2_async_mode(): event.trigger = "stream" resp = context.mlrun_handler(context, event) context.logger.info("model responded") - print(resp) + logger.info(resp) data = json.loads(resp.body) assert data["outputs"] == 5, f"wrong model response {data}" @@ -175,7 +177,7 @@ def test_v2_get_modelmeta(): def get_model(name, version, url): event = MockEvent("", path=f"/v2/models/{url}", method="GET") resp = context.mlrun_handler(context, event) - print(resp) + logger.info(f"resp: {resp}") data = json.loads(resp.body) # expected: {"name": "m3", "version": "v2", "inputs": [], "outputs": []} @@ -231,7 +233,25 @@ def test_v2_health(): def test_v2_mock(): - host = create_mock_server() - host.add_model("my", class_name=ModelTestingClass, model_path="", z=100) - print(host.test("my/infer", testdata)) - print(host.to_yaml()) + host = create_graph_server(graph=RouterState()) + host.graph.add_route( + "my", class_name=ModelTestingClass, model_path="", multiplier=100 + ) + host.init(None, globals()) + logger.info(host.to_yaml()) + resp = host.test("/v2/models/my/infer", testdata) + logger.info(f"resp: {resp}") + # expected: source (5) * multiplier (100) + assert resp["outputs"] == 5 * 100, f"wrong health response {resp}" + + +def test_function(): + fn = mlrun.new_function("tests", kind="serving") + graph = fn.set_topology("router") + fn.add_model("my", class_name="ModelTestingClass", model_path=".", multiplier=100) + + server = fn.to_mock_server() + logger.info(f"flow: {graph.to_yaml()}") + resp = server.test("/v2/models/my/infer", testdata) + # expected: source (5) * multiplier (100) + assert resp["outputs"] == 5 * 100, f"wrong health response {resp}"