diff --git a/mlrun/common/constants.py b/mlrun/common/constants.py index 1441c78c80d..3cf1ac9c944 100644 --- a/mlrun/common/constants.py +++ b/mlrun/common/constants.py @@ -20,6 +20,7 @@ MLRUN_SERVING_SPEC_PATH = ( f"{MLRUN_SERVING_SPEC_MOUNT_PATH}/{MLRUN_SERVING_SPEC_FILENAME}" ) +MLRUN_FUNCTIONS_ANNOTATION = "mlrun/mlrun-functions" MYSQL_MEDIUMBLOB_SIZE_BYTES = 16 * 1024 * 1024 MLRUN_LABEL_PREFIX = "mlrun/" DASK_LABEL_PREFIX = "dask.org/" diff --git a/mlrun/common/schemas/api_gateway.py b/mlrun/common/schemas/api_gateway.py index 3e78573a191..971674bacaf 100644 --- a/mlrun/common/schemas/api_gateway.py +++ b/mlrun/common/schemas/api_gateway.py @@ -18,6 +18,7 @@ import pydantic import mlrun.common.types +from mlrun.common.constants import MLRUN_FUNCTIONS_ANNOTATION class APIGatewayAuthenticationMode(mlrun.common.types.StrEnum): @@ -55,6 +56,7 @@ class APIGatewayMetadata(_APIGatewayBaseModel): name: str namespace: Optional[str] labels: Optional[dict] = {} + annotations: Optional[dict] = {} class APIGatewayBasicAuth(_APIGatewayBaseModel): @@ -91,6 +93,56 @@ class APIGateway(_APIGatewayBaseModel): spec: APIGatewaySpec status: Optional[APIGatewayStatus] + def get_function_names(self): + return [ + upstream.nucliofunction.get("name") + for upstream in self.spec.upstreams + if upstream.nucliofunction.get("name") + ] + + def enrich_mlrun_function_names(self): + upstream_with_nuclio_names = [] + mlrun_function_uris = [] + for upstream in self.spec.upstreams: + uri = upstream.nucliofunction.get("name") + project, function_name, tag, _ = ( + mlrun.common.helpers.parse_versioned_object_uri(uri) + ) + upstream.nucliofunction["name"] = ( + mlrun.runtimes.nuclio.function.get_fullname(function_name, project, tag) + ) + + upstream_with_nuclio_names.append(upstream) + mlrun_function_uris.append(uri) + + self.spec.upstreams = upstream_with_nuclio_names + if len(mlrun_function_uris) == 1: + self.metadata.annotations[MLRUN_FUNCTIONS_ANNOTATION] = mlrun_function_uris[ + 0 + ] + elif len(mlrun_function_uris) == 2: + self.metadata.annotations[MLRUN_FUNCTIONS_ANNOTATION] = "&".join( + mlrun_function_uris + ) + return self + + def replace_nuclio_names_with_mlrun_uri(self): + mlrun_functions = self.metadata.annotations.get(MLRUN_FUNCTIONS_ANNOTATION) + if mlrun_functions: + mlrun_function_uris = ( + mlrun_functions.split("&") + if "&" in mlrun_functions + else [mlrun_functions] + ) + if len(mlrun_function_uris) != len(self.spec.upstreams): + raise mlrun.errors.MLRunValueError( + "Error when translating nuclio names to mlrun names in api gateway:" + " number of functions doesn't match the mlrun functions in annotation" + ) + for i in range(len(mlrun_function_uris)): + self.spec.upstreams[i].nucliofunction["name"] = mlrun_function_uris[i] + return self + class APIGatewaysOutput(_APIGatewayBaseModel): api_gateways: typing.Optional[dict[str, APIGateway]] = {} diff --git a/mlrun/runtimes/nuclio/api_gateway.py b/mlrun/runtimes/nuclio/api_gateway.py index 945106c3c97..2d68ba52c2a 100644 --- a/mlrun/runtimes/nuclio/api_gateway.py +++ b/mlrun/runtimes/nuclio/api_gateway.py @@ -28,7 +28,7 @@ from mlrun.platforms.iguazio import min_iguazio_versions from mlrun.utils import logger -from .function import get_fullname, min_nuclio_versions +from .function import min_nuclio_versions class APIGatewayAuthenticator(typing.Protocol): @@ -283,7 +283,21 @@ def _validate_functions( function_names = [] for func in functions: if isinstance(func, str): - function_names.append(func) + # check whether the function was passed as a URI or just a name + parsed_project, function_name, _, _ = ( + mlrun.common.helpers.parse_versioned_object_uri(func) + ) + + if parsed_project and function_name: + # check that parsed project and passed project are the same + if parsed_project != project: + raise mlrun.errors.MLRunInvalidArgumentError( + "Function doesn't belong to passed project" + ) + function_uri = func + else: + function_uri = mlrun.utils.generate_object_uri(project, func) + function_names.append(function_uri) continue function_name = ( @@ -298,8 +312,13 @@ def _validate_functions( f"input function {function_name} " f"does not belong to this project" ) - nuclio_name = get_fullname(function_name, project, func.metadata.tag) - function_names.append(nuclio_name) + function_uri = mlrun.utils.generate_object_uri( + project, + function_name, + func.metadata.tag, + func.metadata.hash, + ) + function_names.append(function_uri) return function_names diff --git a/server/api/utils/clients/async_nuclio.py b/server/api/utils/clients/async_nuclio.py index 2c8ca32a7d0..e2335a21fbe 100644 --- a/server/api/utils/clients/async_nuclio.py +++ b/server/api/utils/clients/async_nuclio.py @@ -60,7 +60,9 @@ async def list_api_gateways( ) parsed_api_gateways = {} for name, gw in api_gateways.items(): - parsed_api_gateways[name] = mlrun.common.schemas.APIGateway.parse_obj(gw) + parsed_api_gateways[name] = mlrun.common.schemas.APIGateway.parse_obj( + gw + ).replace_nuclio_names_with_mlrun_uri() return parsed_api_gateways async def api_gateway_exists(self, name: str, project_name: str = None): @@ -77,7 +79,9 @@ async def get_api_gateway(self, name: str, project_name: str = None): path=NUCLIO_API_GATEWAYS_ENDPOINT_TEMPLATE.format(api_gateway=name), headers=headers, ) - return mlrun.common.schemas.APIGateway.parse_obj(api_gateway) + return mlrun.common.schemas.APIGateway.parse_obj( + api_gateway + ).replace_nuclio_names_with_mlrun_uri() async def store_api_gateway( self, @@ -220,4 +224,5 @@ def _enrich_nuclio_api_gateway( api_gateway: mlrun.common.schemas.APIGateway, ) -> mlrun.common.schemas.APIGateway: self._set_iguazio_labels(api_gateway, project_name) + api_gateway.enrich_mlrun_function_names() return api_gateway diff --git a/tests/api/api/test_nuclio.py b/tests/api/api/test_nuclio.py index f5b17917432..b7319b4c783 100644 --- a/tests/api/api/test_nuclio.py +++ b/tests/api/api/test_nuclio.py @@ -16,11 +16,14 @@ from unittest.mock import patch import fastapi +import pytest import mlrun import mlrun.common.schemas +import mlrun.runtimes.nuclio import server.api.utils.clients.async_nuclio import server.api.utils.clients.iguazio +from mlrun.common.constants import MLRUN_FUNCTIONS_ANNOTATION PROJECT = "project-name" @@ -70,7 +73,7 @@ def test_list_api_gateways( assert response.json() == { "api_gateways": { "new-gw": { - "metadata": {"name": "new-gw", "labels": {}}, + "metadata": {"name": "new-gw", "labels": {}, "annotations": {}}, "spec": { "name": "new-gw", "path": "/", @@ -153,3 +156,58 @@ def test_store_api_gateway( json=api_gateway.dict(), ) assert response.status_code == 200 + + +@pytest.mark.parametrize( + "functions, expected_nuclio_function_names, expected_mlrun_functions_label", + [ + ( + ["test-func"], + ["test-project-test-func"], + "test-project/test-func", + ), + ( + ["test-func1", "test-func2"], + ["test-project-test-func1", "test-project-test-func2"], + "test-project/test-func1&test-project/test-func2", + ), + ( + ["test-func1:latest", "test-func2:latest"], + ["test-project-test-func1", "test-project-test-func2"], + "test-project/test-func1:latest&test-project/test-func2:latest", + ), + ( + ["test-func1:tag1", "test-func2:tag2"], + ["test-project-test-func1-tag1", "test-project-test-func2-tag2"], + "test-project/test-func1:tag1&test-project/test-func2:tag2", + ), + ], +) +def test_mlrun_function_translation_to_nuclio( + functions, expected_nuclio_function_names, expected_mlrun_functions_label +): + project_name = "test-project" + api_gateway_client_side = mlrun.runtimes.APIGateway( + metadata=mlrun.runtimes.nuclio.api_gateway.APIGatewayMetadata(name="new-gw"), + spec=mlrun.runtimes.nuclio.api_gateway.APIGatewaySpec( + functions=functions, project=project_name + ), + ) + api_gateway_server_side = ( + api_gateway_client_side.to_scheme().enrich_mlrun_function_names() + ) + assert ( + api_gateway_server_side.get_function_names() == expected_nuclio_function_names + ) + + assert ( + api_gateway_server_side.metadata.annotations[MLRUN_FUNCTIONS_ANNOTATION] + == expected_mlrun_functions_label + ) + api_gateway_with_replaced_nuclio_names_to_mlrun = ( + api_gateway_server_side.replace_nuclio_names_with_mlrun_uri() + ) + assert ( + api_gateway_with_replaced_nuclio_names_to_mlrun.get_function_names() + == api_gateway_client_side.spec.functions + ) diff --git a/tests/api/utils/clients/test_async_nuclio.py b/tests/api/utils/clients/test_async_nuclio.py index f99788e7fba..7c0fed0a2e5 100644 --- a/tests/api/utils/clients/test_async_nuclio.py +++ b/tests/api/utils/clients/test_async_nuclio.py @@ -17,6 +17,7 @@ import pytest from aioresponses import aioresponses as aioresponses_ +import mlrun.common.constants import mlrun.common.schemas import mlrun.config import mlrun.errors @@ -66,9 +67,14 @@ async def test_nuclio_get_api_gateway( api_gateway.with_canary(["test", "test2"], [20, 80]) request_url = f"{api_url}/api/api_gateways/test-basic" + + expected_payload = api_gateway.to_scheme() + expected_payload.metadata.labels = { + mlrun.common.constants.MLRunInternalLabels.nuclio_project_name: "default-project" + } mock_aioresponse.get( request_url, - payload=api_gateway.to_scheme().dict(), + payload=expected_payload.dict(), status=http.HTTPStatus.ACCEPTED, ) r = await nuclio_client.get_api_gateway("test-basic", "default") @@ -79,7 +85,10 @@ async def test_nuclio_get_api_gateway( received_api_gateway.authentication.authentication_mode == api_gateway.spec.authentication.authentication_mode ) - assert received_api_gateway.spec.functions == ["test", "test2"] + assert received_api_gateway.spec.functions == [ + "default-project/test", + "default-project/test2", + ] assert received_api_gateway.spec.canary == [20, 80] diff --git a/tests/projects/test_project.py b/tests/projects/test_project.py index 5db57c4b37d..7b1b82cd69e 100644 --- a/tests/projects/test_project.py +++ b/tests/projects/test_project.py @@ -1983,7 +1983,7 @@ def test_list_api_gateways(patched_list_api_gateways, context): assert gateways[0].name == "test" assert gateways[0].host == "http://gateway-f1-f2-project-name.some-domain.com" - assert gateways[0].spec.functions == ["my-func1"] + assert gateways[0].spec.functions == ["project-name/my-func1"] assert gateways[1].invoke_url == "http://test-basic-default.domain.com/"