diff --git a/mlrun/runtimes/serving.py b/mlrun/runtimes/serving.py
index cf6ca2f02a0..2b71d6df2e1 100644
--- a/mlrun/runtimes/serving.py
+++ b/mlrun/runtimes/serving.py
@@ -106,6 +106,7 @@ def __init__(
         error_stream=None,
         track_models=None,
         secret_sources=None,
+        default_content_type=None,
     ):
 
         super().__init__(
@@ -145,6 +146,7 @@ def __init__(
         self.error_stream = error_stream
         self.track_models = track_models
         self.secret_sources = secret_sources or []
+        self.default_content_type = default_content_type
 
     @property
     def graph(self) -> Union[RouterStep, RootFlowStep]:
@@ -463,6 +465,7 @@ def _get_runtime_env(self):
             "graph_initializer": self.spec.graph_initializer,
             "error_stream": self.spec.error_stream,
             "track_models": self.spec.track_models,
+            "default_content_type": self.spec.default_content_type,
         }
 
         if self.spec.secret_sources:
@@ -492,6 +495,7 @@ def to_mock_server(
             track_models=self.spec.track_models,
             function_uri=self._function_uri(),
             secret_sources=self.spec.secret_sources,
+            default_content_type=self.spec.default_content_type,
             **kwargs,
         )
         server.init_states(
diff --git a/mlrun/serving/server.py b/mlrun/serving/server.py
index 5e087a787f6..3b94acc1a05 100644
--- a/mlrun/serving/server.py
+++ b/mlrun/serving/server.py
@@ -81,6 +81,7 @@ def __init__(
         error_stream=None,
         track_models=None,
         secret_sources=None,
+        default_content_type=None,
     ):
         self._graph = None
         self.graph: Union[RouterStep, RootFlowStep] = graph
@@ -100,6 +101,7 @@ def __init__(
         self._secrets = SecretsStore.from_list(secret_sources)
         self._db_conn = None
         self.resource_cache = None
+        self.default_content_type = default_content_type
 
     def set_current_function(self, function):
         """set which child function this server is currently running on"""
@@ -213,6 +215,14 @@ def run(self, event, context=None, get_body=False, extra_args=None):
         server_context = self.context
         context = context or server_context
         try:
+            if not event.content_type and self.default_content_type:
+                event.content_type = self.default_content_type
+            if (
+                isinstance(event.body, (str, bytes))
+                and event.content_type
+                and event.content_type in ["json", "application/json"]
+            ):
+                event.body = json.loads(event.body)
             response = self.graph.run(event, **(extra_args or {}))
         except Exception as exc:
             message = str(exc)
diff --git a/mlrun/serving/states.py b/mlrun/serving/states.py
index c256152c148..fe005ba34fd 100644
--- a/mlrun/serving/states.py
+++ b/mlrun/serving/states.py
@@ -1157,7 +1157,11 @@ def do_event(self, event):
             raise RuntimeError(f"bad function response {resp.text}")
 
         data = resp.content
-        if self.format == "json" or resp.headers["content-type"] == "application/json":
+        # only decode raw str/bytes payloads; the type guard covers both json triggers
+        if (
+            self.format == "json"
+            or resp.headers["content-type"] == "application/json"
+        ) and isinstance(data, (str, bytes)):
             data = json.loads(data)
         event.body = data
         return event
diff --git a/tests/serving/test_flow.py b/tests/serving/test_flow.py
index 11d59357720..80958a26451 100644
--- a/tests/serving/test_flow.py
+++ b/tests/serving/test_flow.py
@@ -109,3 +109,34 @@ def test_on_error():
     logger.info(f"flow: {graph.to_yaml()}")
     resp = server.test(body=[])
     assert resp["error"] and resp["origin_state"] == "raiser", "error wasnt caught"
+
+
+def return_type(event):
+    return event.__class__.__name__
+
+
+def test_content_type():
+    fn = mlrun.new_function("tests", kind="serving")
+    graph = fn.set_topology("flow", engine="sync")
+    graph.to(name="totype", handler=return_type)
+    server = fn.to_mock_server()
+
+    # test that we json.load() when the content type is json
+    resp = server.test(body={"a": 1})
+    assert resp == "dict", "invalid type"
+    resp = server.test(body="[1,2]")
+    assert resp == "str", "invalid type"
+    resp = server.test(body={"a": 1}, content_type="application/json")
+    assert resp == "dict", "invalid type"
+    resp = server.test(body="[1,2]", content_type="application/json")
+    assert resp == "list", "did not load json"
+
+    # test the use of default content type
+    fn = mlrun.new_function("tests", kind="serving")
+    fn.spec.default_content_type = "application/json"
+    graph = fn.set_topology("flow", engine="sync")
+    graph.to(name="totype", handler=return_type)
+
+    server = fn.to_mock_server()
+    resp = server.test(body="[1,2]")
+    assert resp == "list", "did not load json"