Skip to content

Commit

Permalink
[Serving] Handle content-type header + add default_content_type attri…
Browse files Browse the repository at this point in the history
…bute (#1108)

(cherry picked from commit d4db7b3)
  • Loading branch information
yaronha authored and Hedingber committed Jul 12, 2021
1 parent a58dc9d commit c4ce9f9
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 1 deletion.
4 changes: 4 additions & 0 deletions mlrun/runtimes/serving.py
Expand Up @@ -106,6 +106,7 @@ def __init__(
error_stream=None,
track_models=None,
secret_sources=None,
default_content_type=None,
):

super().__init__(
Expand Down Expand Up @@ -145,6 +146,7 @@ def __init__(
self.error_stream = error_stream
self.track_models = track_models
self.secret_sources = secret_sources or []
self.default_content_type = default_content_type

@property
def graph(self) -> Union[RouterStep, RootFlowStep]:
Expand Down Expand Up @@ -463,6 +465,7 @@ def _get_runtime_env(self):
"graph_initializer": self.spec.graph_initializer,
"error_stream": self.spec.error_stream,
"track_models": self.spec.track_models,
"default_content_type": self.spec.default_content_type,
}

if self.spec.secret_sources:
Expand Down Expand Up @@ -492,6 +495,7 @@ def to_mock_server(
track_models=self.spec.track_models,
function_uri=self._function_uri(),
secret_sources=self.spec.secret_sources,
default_content_type=self.spec.default_content_type,
**kwargs,
)
server.init_states(
Expand Down
10 changes: 10 additions & 0 deletions mlrun/serving/server.py
Expand Up @@ -81,6 +81,7 @@ def __init__(
error_stream=None,
track_models=None,
secret_sources=None,
default_content_type=None,
):
self._graph = None
self.graph: Union[RouterStep, RootFlowStep] = graph
Expand All @@ -100,6 +101,7 @@ def __init__(
self._secrets = SecretsStore.from_list(secret_sources)
self._db_conn = None
self.resource_cache = None
self.default_content_type = default_content_type

def set_current_function(self, function):
"""set which child function this server is currently running on"""
Expand Down Expand Up @@ -213,6 +215,14 @@ def run(self, event, context=None, get_body=False, extra_args=None):
server_context = self.context
context = context or server_context
try:
if not event.content_type and self.default_content_type:
event.content_type = self.default_content_type
if (
isinstance(event.body, (str, bytes))
and event.content_type
and event.content_type in ["json", "application/json"]
):
event.body = json.loads(event.body)
response = self.graph.run(event, **(extra_args or {}))
except Exception as exc:
message = str(exc)
Expand Down
6 changes: 5 additions & 1 deletion mlrun/serving/states.py
Expand Up @@ -1157,7 +1157,11 @@ def do_event(self, event):
raise RuntimeError(f"bad function response {resp.text}")

data = resp.content
if self.format == "json" or resp.headers["content-type"] == "application/json":
if (
self.format == "json"
or resp.headers["content-type"] == "application/json"
and isinstance(data, (str, bytes))
):
data = json.loads(data)
event.body = data
return event
Expand Down
31 changes: 31 additions & 0 deletions tests/serving/test_flow.py
Expand Up @@ -109,3 +109,34 @@ def test_on_error():
logger.info(f"flow: {graph.to_yaml()}")
resp = server.test(body=[])
assert resp["error"] and resp["origin_state"] == "raiser", "error wasnt caught"


def return_type(event):
return event.__class__.__name__


def test_content_type():
fn = mlrun.new_function("tests", kind="serving")
graph = fn.set_topology("flow", engine="sync")
graph.to(name="totype", handler=return_type)
server = fn.to_mock_server()

# test that we json.load() when the content type is json
resp = server.test(body={"a": 1})
assert resp == "dict", "invalid type"
resp = server.test(body="[1,2]")
assert resp == "str", "invalid type"
resp = server.test(body={"a": 1}, content_type="application/json")
assert resp == "dict", "invalid type"
resp = server.test(body="[1,2]", content_type="application/json")
assert resp == "list", "did not load json"

# test the use of default content type
fn = mlrun.new_function("tests", kind="serving")
fn.spec.default_content_type = "application/json"
graph = fn.set_topology("flow", engine="sync")
graph.to(name="totype", handler=return_type)

server = fn.to_mock_server()
resp = server.test(body="[1,2]")
assert resp == "list", "did not load json"

0 comments on commit c4ce9f9

Please sign in to comment.