Skip to content

Commit

Permalink
[Serving] Fix async input_path bug in RemoteStep (#1355)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaronha committed Sep 24, 2021
1 parent c4f6b74 commit d07e57f
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 26 deletions.
61 changes: 35 additions & 26 deletions mlrun/serving/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

import mlrun

from .utils import _extract_input_data, _update_result_body

http_adapter = HTTPAdapter(
Expand All @@ -25,6 +23,7 @@ def __init__(
method: str = None,
headers: dict = None,
url_expression: str = None,
body_expression: str = None,
return_json: bool = True,
input_path: str = None,
result_path: str = None,
Expand All @@ -49,6 +48,7 @@ def __init__(
:param method: HTTP method (GET, POST, ..), default to POST
:param headers: dictionary with http header values
:param url_expression: an expression for getting the url from the event, e.g. "event['url']"
:param body_expression: an expression for getting the request body from the event, e.g. "event['data']"
:param return_json: indicate the returned value is json, and convert it to a py object
:param input_path: when specified selects the key/path in the event to use as body
this require that the event body will behave like a dict, example:
Expand All @@ -58,16 +58,13 @@ def __init__(
event: {"x": 5} , result_path="resp" means the returned response will be written
to event["y"] resulting in {"x": 5, "resp": <result>}
"""
if url and url_expression:
raise mlrun.errors.MLRunInvalidArgumentError(
"cannot set both url and url_expression"
)
self.url = url
self.url_expression = url_expression
self.body_expression = body_expression
self.headers = headers
self.method = method
self.return_json = return_json
self.subpath = subpath or ""
self.subpath = subpath
super().__init__(
None, None, input_path=input_path, result_path=result_path, **kwargs
)
Expand All @@ -76,22 +73,32 @@ def __init__(
self._endpoint = ""
self._session = None
self._url_function_handler = None
self._body_function_handler = None
self._full_event = False

def post_init(self, mode="sync"):
self._endpoint = self.url
if self.url and self.context:
self._endpoint = self.context.get_remote_endpoint(self.url).strip("/")
if self.body_expression:
# init lambda function for calculating url from event
self._body_function_handler = eval(
"lambda event: " + self.body_expression, {}, {}
)
if self.url_expression:
# init lambda function for calculating url from event
self._url_function_handler = eval(
"lambda event: " + self.url_expression, {}, {}
"lambda event: " + self.url_expression, {"endpoint": self._endpoint}, {}
)
else:
elif self.subpath:
self._append_event_path = self.subpath == "$path"
self._endpoint = self.context.get_remote_endpoint(self.url).strip("/")
if self.subpath and not self._append_event_path:
if not self._append_event_path:
self._endpoint = self._endpoint + "/" + self.subpath.lstrip("/")

async def _process_event(self, event):
# async implementation (with storey)
method, url, headers, body = self._generate_request(event)
body = self._get_event_or_body(event)
method, url, headers, body = self._generate_request(event, body)
return await self._client_session.request(
method, url, headers=headers, data=body, ssl=False
)
Expand All @@ -111,9 +118,8 @@ def do_event(self, event):
self._session.mount("http://", http_adapter)
self._session.mount("https://", http_adapter)

original_body = event.body
event.body = _extract_input_data(self._input_path, event.body)
method, url, headers, body = self._generate_request(event)
body = _extract_input_data(self._input_path, event.body)
method, url, headers, body = self._generate_request(event, body)
try:
resp = self._session.request(
method, url, verify=False, headers=headers, data=body
Expand All @@ -124,27 +130,28 @@ def do_event(self, event):
raise RuntimeError(f"bad http response {resp.text}")

result = self._get_data(resp.content, resp.headers)
event.body = _update_result_body(self._result_path, original_body, result)
event.body = _update_result_body(self._result_path, event.body, result)
return event

def _generate_request(self, event):
def _generate_request(self, event, body):
method = self.method or event.method or "POST"
headers = self.headers or event.headers or {}

body = None
if method != "GET" and event.body is not None:
if isinstance(event.body, (str, bytes)):
body = event.body
else:
body = json.dumps(event.body)
headers["Content-Type"] = "application/json"

if self._url_function_handler:
url = self._url_function_handler(event.body)
url = self._url_function_handler(body)
else:
url = self._endpoint
if self._append_event_path:
url = url + "/" + event.path.lstrip("/")

if method == "GET":
body = None
elif body is not None and not isinstance(body, (str, bytes)):
if self._body_function_handler:
body = self._body_function_handler(body)
body = json.dumps(body)
headers["Content-Type"] = "application/json"

return method, url, headers, body

def _get_data(self, data, headers):
Expand All @@ -165,7 +172,9 @@ def to_dict(self):
"headers",
"return_json",
"url_expression",
"body_expression",
]
if getattr(self, key) is not None
}
return {
"class_name": f"{__name__}.{self.__class__.__name__}",
Expand Down
28 changes: 28 additions & 0 deletions tests/serving/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,31 @@ def test_remote_class(httpserver, engine):
resp = server.test(body={"req": {"x": 5}})
server.wait_for_completion()
assert resp == {"req": {"x": 5}, "resp": {"cat": "ok"}}


@pytest.mark.parametrize("engine", ["sync", "async"])
def test_remote_advance(httpserver, engine):
from mlrun.serving.remote import RemoteStep

httpserver.expect_request("/dog", method="POST", json={"x": 5}).respond_with_json(
{"post": "ok"}
)

function = mlrun.new_function("test2", kind="serving")
flow = function.set_topology("flow", engine=engine)
flow.to(name="s1", handler="echo").to(
RemoteStep(
name="remote_echo",
url=httpserver.url_for("/"),
url_expression="endpoint + event['url']",
body_expression="event['data']",
input_path="req",
result_path="resp",
)
).to(name="s3", handler="echo").respond()

server = function.to_mock_server()
resp = server.test(body={"req": {"url": "/dog", "data": {"x": 5}}})
server.wait_for_completion()
print(resp)
assert resp == {"req": {"url": "/dog", "data": {"x": 5}}, "resp": {"post": "ok"}}

0 comments on commit d07e57f

Please sign in to comment.