[Serving] Fix input/result path handling (#1491)
yaronha committed Nov 14, 2021
1 parent c946ca2 commit 4e93115
Showing 3 changed files with 58 additions and 18 deletions.
4 changes: 2 additions & 2 deletions mlrun/serving/utils.py
@@ -68,9 +68,9 @@ def to_dict(self, fields=None, exclude=None):
}
if hasattr(self, "_STEP_KIND"):
struct["kind"] = self._STEP_KIND
if hasattr(self, "_input_path"):
if hasattr(self, "_input_path") and self._input_path is not None:
struct["input_path"] = self._input_path
if hasattr(self, "_result_path"):
if hasattr(self, "_result_path") and self._result_path is not None:
struct["result_path"] = self._result_path
if hasattr(self, "_full_event") and self._full_event:
struct["full_event"] = self._full_event
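The added "is not None" checks mean that unset paths are no longer serialized, since the attributes are always present (set to None) on the step. A minimal self-contained sketch (not the actual mlrun class) of the resulting behavior:

class FakeStep:
    def __init__(self, input_path=None, result_path=None):
        self._input_path = input_path
        self._result_path = result_path

    def to_dict(self):
        struct = {}
        # only emit the keys when a path was actually configured;
        # hasattr() alone is not enough because __init__ always sets the attributes
        if hasattr(self, "_input_path") and self._input_path is not None:
            struct["input_path"] = self._input_path
        if hasattr(self, "_result_path") and self._result_path is not None:
            struct["result_path"] = self._result_path
        return struct

print(FakeStep().to_dict())             # {} -- unset paths are omitted
print(FakeStep("req", "m1").to_dict())  # {'input_path': 'req', 'result_path': 'm1'}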
50 changes: 34 additions & 16 deletions mlrun/serving/v2_serving.py
@@ -28,7 +28,7 @@
from mlrun.utils import logger, now_date, parse_versioned_object_uri
from mlrun.utils.model_monitoring import EndpointType

from .utils import StepToDict
from .utils import StepToDict, _extract_input_data, _update_result_body


class V2ModelServer(StepToDict):
@@ -41,6 +41,8 @@ def __init__(
model_path: str = None,
model=None,
protocol=None,
input_path: str = None,
result_path: str = None,
**kwargs,
):
"""base model serving class (v2), using similar API to KFServing v2 and Triton
@@ -86,6 +88,13 @@ def predict(self, request):
:param model_path: model file/dir or artifact path
:param model: model object (for local testing)
:param protocol: serving API protocol (default "v2")
:param input_path: when specified, selects the key/path in the event to use as the request body
this requires that the event body behaves like a dict, for example:
event: {"data": {"a": 5, "b": 7}}, input_path="data.b" means the request body will be 7
:param result_path: selects the key/path in the event to write the results to
this requires that the event body behaves like a dict, for example:
event: {"x": 5}, result_path="resp" means the returned response will be written
to event["resp"], resulting in {"x": 5, "resp": <result>}
:param kwargs: extra arguments (can be accessed using self.get_param(key))
"""
self.name = name
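The input_path/result_path behavior described in the docstring above roughly amounts to the sketch below. This is an illustration only, assuming simple dot-separated key traversal over dict-like bodies; the real _extract_input_data/_update_result_body helpers in mlrun/serving/utils.py may handle more cases (nested result paths, missing keys, non-dict bodies):

def _extract_input_data(input_path, body):
    # walk a dot-separated path, e.g. "data.b" -> body["data"]["b"]
    if not input_path:
        return body
    for key in input_path.split("."):
        body = body[key]
    return body

def _update_result_body(result_path, original_body, result):
    # write the step result under the given key, keeping the original payload
    if not result_path:
        return result
    original_body[result_path] = result
    return original_body

print(_extract_input_data("data.b", {"data": {"a": 5, "b": 7}}))  # 7
print(_update_result_body("resp", {"x": 5}, "<result>"))          # {'x': 5, 'resp': '<result>'}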
@@ -98,6 +107,8 @@ def predict(self, request):
self.protocol = protocol or "v2"
self.model_path = model_path
self.model_spec: mlrun.artifacts.ModelArtifact = None
self._input_path = input_path
self._result_path = result_path
self._kwargs = kwargs # for to_dict()
self._params = kwargs
self._model_logger = (
@@ -208,34 +219,37 @@ def _check_readiness(self, event):
return
raise RuntimeError(f"model {self.name} is not ready {self.error}")

def _pre_event_processing_actions(self, event, op):
def _pre_event_processing_actions(self, event, event_body, op):
self._check_readiness(event)
request = self.preprocess(event.body, op)
if "id" not in request:
request["id"] = event.id
request = self.preprocess(event_body, op)
return self.validate(request, op)

def do_event(self, event, *args, **kwargs):
"""main model event handler method"""
start = now_date()
original_body = event.body
event_body = _extract_input_data(self._input_path, event.body)
event_id = event.id
op = event.path.strip("/")
if not op and event.body and isinstance(event.body, dict):
op = event.body.get("operation")
if event_body and isinstance(event_body, dict):
op = op or event_body.get("operation")
event_id = event_body.get("id", event_id)
if not op and event.method != "GET":
op = "infer"

if op == "predict" or op == "infer":
# predict operation
request = self._pre_event_processing_actions(event, op)
request = self._pre_event_processing_actions(event, event_body, op)
try:
outputs = self.predict(request)
except Exception as exc:
request["id"] = event_id
if self._model_logger:
self._model_logger.push(start, request, op=op, error=exc)
raise exc

response = {
"id": request["id"],
"id": event_id,
"model_name": self.name,
"outputs": outputs,
}
@@ -256,29 +270,33 @@ def do_event(self, event, *args, **kwargs):
elif op == "" and event.method == "GET":
# get model metadata operation
setattr(event, "terminated", True)
event.body = {
event_body = {
"name": self.name,
"version": self.version,
"inputs": [],
"outputs": [],
}
if self.model_spec:
event.body["inputs"] = self.model_spec.inputs
event.body["outputs"] = self.model_spec.outputs
event_body["inputs"] = self.model_spec.inputs
event_body["outputs"] = self.model_spec.outputs
event.body = _update_result_body(
self._result_path, original_body, event_body
)
return event

elif op == "explain":
# explain operation
request = self._pre_event_processing_actions(event, op)
request = self._pre_event_processing_actions(event, event_body, op)
try:
outputs = self.explain(request)
except Exception as exc:
request["id"] = event_id
if self._model_logger:
self._model_logger.push(start, request, op=op, error=exc)
raise exc

response = {
"id": request["id"],
"id": event_id,
"model_name": self.name,
"outputs": outputs,
}
@@ -288,7 +306,7 @@ def do_event(self, event, *args, **kwargs):
elif hasattr(self, "op_" + op):
# custom operation (child methods starting with "op_")
response = getattr(self, "op_" + op)(event)
event.body = response
event.body = _update_result_body(self._result_path, original_body, response)
return event

else:
@@ -297,7 +315,7 @@ def do_event(self, event, *args, **kwargs):
response = self.postprocess(response)
if self._model_logger:
self._model_logger.push(start, request, response, op)
event.body = response
event.body = _update_result_body(self._result_path, original_body, response)
return event

def validate(self, request, operation):
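Taken together, the do_event changes let a model step read its request from a sub-key of the event and write its response next to the original payload instead of replacing it, whether the paths are set on the class or on the graph step (as in the new test below). A hypothetical single-model graph illustrating this (EchoModel and the path names are illustrative, not part of the commit):

import mlrun
from mlrun.serving import V2ModelServer

class EchoModel(V2ModelServer):  # hypothetical model class
    def load(self):
        pass  # nothing to load in this sketch

    def predict(self, request):
        return request["inputs"]  # echo the inputs back

fn = mlrun.new_function("single-model", kind="serving")
graph = fn.set_topology("flow", engine="async")
# read the model input from event["payload"], write the response under event["answer"]
graph.to(
    EchoModel(name="echo", model_path="."), input_path="payload", result_path="answer"
).respond()

server = fn.to_mock_server()
resp = server.test(body={"payload": {"inputs": [1, 2, 3]}})
server.wait_for_completion()
# expected (roughly): {"payload": {...}, "answer": {"id": ..., "model_name": "echo", "outputs": [1, 2, 3]}}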
22 changes: 22 additions & 0 deletions tests/serving/test_serving.py
@@ -368,3 +368,25 @@ def test_serving_no_router():
resp = server.test("/", testdata)
# expected: source (5) * multiplier (100)
assert resp["outputs"] == 5 * 100, f"wrong data response {resp}"


def test_model_chained():
fn = mlrun.new_function("demo", kind="serving")
graph = fn.set_topology("flow", engine="async")
graph.to(
ModelTestingClass(name="m1", model_path=".", multiplier=2),
result_path="m1",
input_path="req",
).to(
ModelTestingClass(
name="m2", model_path=".", result_path="m2", multiplier=3, input_path="req"
)
).respond()
server = fn.to_mock_server()

resp = server.test(body={"req": {"inputs": [5]}})
server.wait_for_completion()
assert list(resp.keys()) == ["req", "m1", "m2"], "unexpected keys in resp"
assert (
resp["m1"]["outputs"] == 5 * 2 and resp["m2"]["outputs"] == 5 * 3
), "unexpected model results"
