Skip to content

Commit

Permalink
[Serving] Add new set_flow() API (#5437)
Browse files Browse the repository at this point in the history
* [Serving] Add new `set_flow()` API

[ML-6220](https://iguazio.atlassian.net/browse/ML-6220)

* Add to docstring

* Remove redundant comma

* Improve docstring

* Minor change to docstring

* Refactoring
  • Loading branch information
gtopper committed Apr 30, 2024
1 parent 75fcc2d commit 03aecff
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 2 deletions.
55 changes: 53 additions & 2 deletions mlrun/serving/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import traceback
from copy import copy, deepcopy
from inspect import getfullargspec, signature
from typing import Union
from typing import Any, Union

import mlrun

Expand Down Expand Up @@ -327,7 +327,7 @@ def to(
parent = self._parent
else:
raise GraphError(
f"step {self.name} parent is not set or its not part of a graph"
f"step {self.name} parent is not set or it's not part of a graph"
)

name, step = params_to_step(
Expand All @@ -349,6 +349,36 @@ def to(
parent._last_added = step
return step

def set_flow(
self,
steps: list[Union[str, StepToDict, dict[str, Any]]],
force: bool = False,
):
"""set list of steps as downstream from this step, in the order specified. This will overwrite any existing
downstream steps.
:param steps: list of steps to follow this one
:param force: whether to overwrite existing downstream steps. If False, this method will fail if any downstream
steps have already been defined. Defaults to False.
:return: the last step added to the flow
example:
The below code sets the downstream nodes of step1 by using a list of steps (provided to `set_flow()`) and a
single step (provided to `to()`), resulting in the graph (step1 -> step2 -> step3 -> step4).
Notice that using `force=True` is required in case step1 already had downstream nodes (e.g. if the existing
graph is step1 -> step2_old) and that following the execution of this code the existing downstream steps
are removed. If the intention is to split the graph (and not to overwrite), please use `to()`.
step1.set_flow(
[
dict(name="step2", handler="step2_handler"),
dict(name="step3", class_name="Step3Class"),
],
force=True,
).to(dict(name="step4", class_name="Step4Class"))
"""
raise NotImplementedError("set_flow() can only be called on a FlowStep")


class TaskStep(BaseStep):
"""task execution step, runs a class or handler"""
Expand Down Expand Up @@ -1258,6 +1288,27 @@ def _insert_error_step(self, name, step):
)
self[step_name].after_step(name)

def set_flow(
self,
steps: list[Union[str, StepToDict, dict[str, Any]]],
force: bool = False,
):
if not force and self.steps:
raise mlrun.errors.MLRunInvalidArgumentError(
"set_flow() called on a step that already has downstream steps. "
"If you want to overwrite existing steps, set force=True."
)

self.steps = None
step = self
for next_step in steps:
if isinstance(next_step, dict):
step = step.to(**next_step)
else:
step = step.to(next_step)

return step


class RootFlowStep(FlowStep):
"""root flow step"""
Expand Down
33 changes: 33 additions & 0 deletions tests/serving/test_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,3 +396,36 @@ def test_add_aggregate_as_insert():

assert graph_2["s2"].after == ["Aggregates"]
assert graph_2["Aggregates"].after == ["s1"]


def test_set_flow_error():
fn = mlrun.new_function("tests", kind="serving")
graph = fn.set_topology("flow", engine="sync")
s1 = dict(name="s1", handler="(event + 1)")
s2 = dict(name="s2", handler="json.dumps")
graph.to(**s1).to(**s2)

r1 = dict(name="r1", handler="(event + 10)")
r2 = dict(name="r2", handler="json.dumps")
with pytest.raises(
mlrun.errors.MLRunInvalidArgumentError,
match=r"set_flow\(\) called on a step that already has downstream steps. "
"If you want to overwrite existing steps, set force=True.",
):
graph.set_flow(steps=[r1, r2])


def test_set_flow():
fn = mlrun.new_function("tests", kind="serving")
graph = fn.set_topology("flow", engine="sync")
s1 = dict(name="s1", handler="(event + 1)")
s2 = dict(name="s2", handler="json.dumps")
graph.to(**s1).to(**s2)

r1 = dict(name="r1", handler="(event + 10)")
r2 = dict(name="r2", handler="json.dumps")
graph.set_flow(steps=[r1, r2], force=True)

server = fn.to_mock_server()
resp = server.test(body=5)
assert resp == "15"

0 comments on commit 03aecff

Please sign in to comment.