From 11e54a82a3f372a06f8559c97c0a011caeecd06b Mon Sep 17 00:00:00 2001 From: Adam Date: Thu, 30 May 2024 15:41:15 +0300 Subject: [PATCH] [API] Support Object Formatting on List Requests & Specifically Minimal Function Format (#5659) * function format * function format * function format * fix defaults * split formatters * big oops * less diff --------- Co-authored-by: quaark --- mlrun/common/formatters/__init__.py | 16 ++++++++ mlrun/common/formatters/base.py | 59 +++++++++++++++++++++++++++ mlrun/common/formatters/function.py | 41 +++++++++++++++++++ server/api/api/endpoints/functions.py | 5 +++ server/api/crud/functions.py | 5 ++- server/api/db/base.py | 11 ++++- server/api/db/sqldb/db.py | 36 +++++++++++++--- tests/api/db/test_functions.py | 28 +++++++++++++ 8 files changed, 193 insertions(+), 8 deletions(-) create mode 100644 mlrun/common/formatters/__init__.py create mode 100644 mlrun/common/formatters/base.py create mode 100644 mlrun/common/formatters/function.py diff --git a/mlrun/common/formatters/__init__.py b/mlrun/common/formatters/__init__.py new file mode 100644 index 00000000000..69f1fa0af5b --- /dev/null +++ b/mlrun/common/formatters/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .function import FunctionFormat # noqa diff --git a/mlrun/common/formatters/base.py b/mlrun/common/formatters/base.py new file mode 100644 index 00000000000..d8d00d40353 --- /dev/null +++ b/mlrun/common/formatters/base.py @@ -0,0 +1,59 @@ +# Copyright 2024 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import typing + + +class ObjectFormat: + full = "full" + + @staticmethod + def format_method(_format: str) -> typing.Optional[typing.Callable]: + return { + ObjectFormat.full: None, + }[_format] + + @classmethod + def format_obj(cls, obj: typing.Any, _format: str) -> typing.Any: + _format = _format or cls.full + format_method = cls.format_method(_format) + if not format_method: + return obj + + return format_method(obj) + + @staticmethod + def filter_obj_method(_filter: list[list[str]]) -> typing.Callable: + def _filter_method(obj: dict) -> dict: + formatted_obj = {} + for key_list in _filter: + obj_recursive_iterator = obj + formatted_obj_recursive_iterator = formatted_obj + for idx, key in enumerate(key_list): + if key not in obj_recursive_iterator: + break + value = ( + {} if idx < len(key_list) - 1 else obj_recursive_iterator[key] + ) + formatted_obj_recursive_iterator.setdefault(key, value) + + obj_recursive_iterator = obj_recursive_iterator[key] + formatted_obj_recursive_iterator = formatted_obj_recursive_iterator[ + key + ] + + return formatted_obj + + return _filter_method diff --git a/mlrun/common/formatters/function.py b/mlrun/common/formatters/function.py new file mode 100644 index 00000000000..12bdeecf04e --- /dev/null +++ b/mlrun/common/formatters/function.py @@ -0,0 +1,41 @@ +# Copyright 2024 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import typing + +import mlrun.common.types + +from .base import ObjectFormat + + +class FunctionFormat(ObjectFormat, mlrun.common.types.StrEnum): + minimal = "minimal" + + @staticmethod + def format_method(_format: str) -> typing.Optional[typing.Callable]: + return { + FunctionFormat.full: None, + FunctionFormat.minimal: FunctionFormat.filter_obj_method( + [ + ["kind"], + ["metadata"], + ["status"], + ["spec", "description"], + ["spec", "image"], + ["spec", "default_handler"], + ["spec", "entry_points"], + ] + ), + }[_format] diff --git a/server/api/api/endpoints/functions.py b/server/api/api/endpoints/functions.py index b6431d7d7a9..1150171764b 100644 --- a/server/api/api/endpoints/functions.py +++ b/server/api/api/endpoints/functions.py @@ -30,6 +30,7 @@ from kubernetes.client.rest import ApiException from sqlalchemy.orm import Session +import mlrun.common.formatters import mlrun.common.model_monitoring import mlrun.common.model_monitoring.helpers import mlrun.common.schemas @@ -113,6 +114,7 @@ async def get_function( name: str, tag: str = "", hash_key="", + _format: str = Query(mlrun.common.formatters.FunctionFormat.full, alias="format"), auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): @@ -123,6 +125,7 @@ async def get_function( project, tag, hash_key, + _format, ) await server.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( mlrun.common.schemas.AuthorizationResourceTypes.function, @@ -206,6 +209,7 @@ async def list_functions( page: int = Query(None, gt=0), page_size: int = Query(None, alias="page-size", gt=0), page_token: str = Query(None, alias="page-token"), + _format: str = Query(mlrun.common.formatters.FunctionFormat.full, alias="format"), auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): @@ -245,6 +249,7 @@ async def _filter_functions_by_permissions(_functions): tag=tag, labels=labels, hash_key=hash_key, + _format=_format, ) return { diff --git a/server/api/crud/functions.py b/server/api/crud/functions.py index 8898933fecb..decfe3ddbe7 100644 --- a/server/api/crud/functions.py +++ b/server/api/crud/functions.py @@ -67,10 +67,11 @@ def get_function( project: str = mlrun.mlconf.default_project, tag: str = "", hash_key: str = "", + _format: str = None, ) -> dict: project = project or mlrun.mlconf.default_project return server.api.utils.singletons.db.get_db().get_function( - db_session, name, project, tag, hash_key + db_session, name, project, tag, hash_key, _format ) def delete_function( @@ -93,6 +94,7 @@ def list_functions( hash_key: str = None, page: int = None, page_size: int = None, + _format: str = None, ) -> list: project = project or mlrun.mlconf.default_project if labels is None: @@ -104,6 +106,7 @@ def list_functions( tag=tag, labels=labels, hash_key=hash_key, + _format=_format, page=page, page_size=page_size, ) diff --git a/server/api/db/base.py b/server/api/db/base.py index e7d97d7e79d..b0c7e0f6ea5 100644 --- a/server/api/db/base.py +++ b/server/api/db/base.py @@ -290,7 +290,15 @@ def store_function( pass @abstractmethod - def get_function(self, session, name, project="", tag="", hash_key=""): + def get_function( + self, + session, + name: str = None, + project: str = None, + tag: str = None, + hash_key: str = None, + _format: str = None, + ): pass @abstractmethod @@ -306,6 +314,7 @@ def list_functions( tag: str = None, labels: list[str] = None, hash_key: str = None, + _format: str = None, page: int = None, page_size: int = None, ): diff --git a/server/api/db/sqldb/db.py b/server/api/db/sqldb/db.py index af584301d69..00f4ed8aec5 100644 --- a/server/api/db/sqldb/db.py +++ b/server/api/db/sqldb/db.py @@ -33,6 +33,7 @@ import mlrun import mlrun.common.constants as mlrun_constants +import mlrun.common.formatters import mlrun.common.runtimes.constants import mlrun.common.schemas import mlrun.errors @@ -1632,6 +1633,7 @@ def list_functions( tag: typing.Optional[str] = None, labels: list[str] = None, hash_key: typing.Optional[str] = None, + _format: str = mlrun.common.formatters.FunctionFormat.full, page: typing.Optional[int] = None, page_size: typing.Optional[int] = None, ) -> list[dict]: @@ -1660,10 +1662,22 @@ def list_functions( else: function_dict["metadata"]["tag"] = function_tag - functions.append(function_dict) + functions.append( + mlrun.common.formatters.FunctionFormat.format_obj( + function_dict, _format + ) + ) return functions - def get_function(self, session, name, project="", tag="", hash_key="") -> dict: + def get_function( + self, + session, + name: str = None, + project: str = None, + tag: str = None, + hash_key: str = None, + _format: str = None, + ) -> dict: """ In version 1.4.0 we added a normalization to the function name before storing. To be backwards compatible and allow users to query old non-normalized functions, @@ -1675,7 +1689,7 @@ def get_function(self, session, name, project="", tag="", hash_key="") -> dict: normalized_function_name = mlrun.utils.normalize_name(name) try: return self._get_function( - session, normalized_function_name, project, tag, hash_key + session, normalized_function_name, project, tag, hash_key, _format ) except mlrun.errors.MLRunNotFoundError as exc: if "_" in name: @@ -1683,7 +1697,9 @@ def get_function(self, session, name, project="", tag="", hash_key="") -> dict: "Failed to get underscore-named function, trying without normalization", function_name=name, ) - return self._get_function(session, name, project, tag, hash_key) + return self._get_function( + session, name, project, tag, hash_key, _format + ) else: raise exc @@ -1722,7 +1738,15 @@ def update_function( self._upsert(session, [function]) return function.struct - def _get_function(self, session, name, project="", tag="", hash_key=""): + def _get_function( + self, + session, + name: str = None, + project: str = None, + tag: str = None, + hash_key: str = None, + _format: str = mlrun.common.formatters.FunctionFormat.full, + ): project = project or config.default_project query = self._query(session, Function, name=name, project=project) computed_tag = tag or "latest" @@ -1747,7 +1771,7 @@ def _get_function(self, session, name, project="", tag="", hash_key=""): # If connected to a tag add it to metadata if tag_function_uid: function["metadata"]["tag"] = computed_tag - return function + return mlrun.common.formatters.FunctionFormat.format_obj(function, _format) else: function_uri = generate_object_uri(project, name, tag, hash_key) raise mlrun.errors.MLRunNotFoundError(f"Function not found {function_uri}") diff --git a/tests/api/db/test_functions.py b/tests/api/db/test_functions.py index 45e775aeac8..ca2fad33137 100644 --- a/tests/api/db/test_functions.py +++ b/tests/api/db/test_functions.py @@ -255,6 +255,34 @@ def test_list_functions_filtering_unversioned_untagged( assert functions[0]["metadata"]["hash"] == tagged_function_hash_key +def test_list_functions_with_format(db: DBInterface, db_session: Session): + name = "function_name_1" + tag = "some_tag" + function_body = { + "metadata": {"name": name}, + "kind": "remote", + "status": {"state": "online"}, + "spec": { + "description": "some_description", + "image": "some_image", + "default_handler": "some_handler", + "entry_points": "some_entry_points", + "extra_field": "extra_field", + }, + } + db.store_function(db_session, function_body, name, tag=tag, versioned=True) + functions = db.list_functions(db_session, tag=tag, _format="full") + assert len(functions) == 1 + function = functions[0] + assert function["spec"] == function_body["spec"] + + functions = db.list_functions(db_session, tag=tag, _format="minimal") + assert len(functions) == 1 + function = functions[0] + del function_body["spec"]["extra_field"] + assert function["spec"] == function_body["spec"] + + def test_delete_function(db: DBInterface, db_session: Session): labels = { "name": "value",