Skip to content

Commit

Permalink
[APIGateway] Keep mlrun function names in API gateway object on clien…
Browse files Browse the repository at this point in the history
…t side (#5613)
  • Loading branch information
rokatyy committed May 29, 2024
1 parent 9263dc2 commit 39bbdf8
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 10 deletions.
1 change: 1 addition & 0 deletions mlrun/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
MLRUN_SERVING_SPEC_PATH = (
f"{MLRUN_SERVING_SPEC_MOUNT_PATH}/{MLRUN_SERVING_SPEC_FILENAME}"
)
MLRUN_FUNCTIONS_ANNOTATION = "mlrun/mlrun-functions"
MYSQL_MEDIUMBLOB_SIZE_BYTES = 16 * 1024 * 1024
MLRUN_LABEL_PREFIX = "mlrun/"
DASK_LABEL_PREFIX = "dask.org/"
Expand Down
52 changes: 52 additions & 0 deletions mlrun/common/schemas/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pydantic

import mlrun.common.types
from mlrun.common.constants import MLRUN_FUNCTIONS_ANNOTATION


class APIGatewayAuthenticationMode(mlrun.common.types.StrEnum):
Expand Down Expand Up @@ -55,6 +56,7 @@ class APIGatewayMetadata(_APIGatewayBaseModel):
name: str
namespace: Optional[str]
labels: Optional[dict] = {}
annotations: Optional[dict] = {}


class APIGatewayBasicAuth(_APIGatewayBaseModel):
Expand Down Expand Up @@ -91,6 +93,56 @@ class APIGateway(_APIGatewayBaseModel):
spec: APIGatewaySpec
status: Optional[APIGatewayStatus]

def get_function_names(self):
return [
upstream.nucliofunction.get("name")
for upstream in self.spec.upstreams
if upstream.nucliofunction.get("name")
]

def enrich_mlrun_function_names(self):
upstream_with_nuclio_names = []
mlrun_function_uris = []
for upstream in self.spec.upstreams:
uri = upstream.nucliofunction.get("name")
project, function_name, tag, _ = (
mlrun.common.helpers.parse_versioned_object_uri(uri)
)
upstream.nucliofunction["name"] = (
mlrun.runtimes.nuclio.function.get_fullname(function_name, project, tag)
)

upstream_with_nuclio_names.append(upstream)
mlrun_function_uris.append(uri)

self.spec.upstreams = upstream_with_nuclio_names
if len(mlrun_function_uris) == 1:
self.metadata.annotations[MLRUN_FUNCTIONS_ANNOTATION] = mlrun_function_uris[
0
]
elif len(mlrun_function_uris) == 2:
self.metadata.annotations[MLRUN_FUNCTIONS_ANNOTATION] = "&".join(
mlrun_function_uris
)
return self

def replace_nuclio_names_with_mlrun_uri(self):
mlrun_functions = self.metadata.annotations.get(MLRUN_FUNCTIONS_ANNOTATION)
if mlrun_functions:
mlrun_function_uris = (
mlrun_functions.split("&")
if "&" in mlrun_functions
else [mlrun_functions]
)
if len(mlrun_function_uris) != len(self.spec.upstreams):
raise mlrun.errors.MLRunValueError(
"Error when translating nuclio names to mlrun names in api gateway:"
" number of functions doesn't match the mlrun functions in annotation"
)
for i in range(len(mlrun_function_uris)):
self.spec.upstreams[i].nucliofunction["name"] = mlrun_function_uris[i]
return self


class APIGatewaysOutput(_APIGatewayBaseModel):
api_gateways: typing.Optional[dict[str, APIGateway]] = {}
27 changes: 23 additions & 4 deletions mlrun/runtimes/nuclio/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from mlrun.platforms.iguazio import min_iguazio_versions
from mlrun.utils import logger

from .function import get_fullname, min_nuclio_versions
from .function import min_nuclio_versions


class APIGatewayAuthenticator(typing.Protocol):
Expand Down Expand Up @@ -283,7 +283,21 @@ def _validate_functions(
function_names = []
for func in functions:
if isinstance(func, str):
function_names.append(func)
# check whether the function was passed as a URI or just a name
parsed_project, function_name, _, _ = (
mlrun.common.helpers.parse_versioned_object_uri(func)
)

if parsed_project and function_name:
# check that parsed project and passed project are the same
if parsed_project != project:
raise mlrun.errors.MLRunInvalidArgumentError(
"Function doesn't belong to passed project"
)
function_uri = func
else:
function_uri = mlrun.utils.generate_object_uri(project, func)
function_names.append(function_uri)
continue

function_name = (
Expand All @@ -298,8 +312,13 @@ def _validate_functions(
f"input function {function_name} "
f"does not belong to this project"
)
nuclio_name = get_fullname(function_name, project, func.metadata.tag)
function_names.append(nuclio_name)
function_uri = mlrun.utils.generate_object_uri(
project,
function_name,
func.metadata.tag,
func.metadata.hash,
)
function_names.append(function_uri)
return function_names


Expand Down
9 changes: 7 additions & 2 deletions server/api/utils/clients/async_nuclio.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ async def list_api_gateways(
)
parsed_api_gateways = {}
for name, gw in api_gateways.items():
parsed_api_gateways[name] = mlrun.common.schemas.APIGateway.parse_obj(gw)
parsed_api_gateways[name] = mlrun.common.schemas.APIGateway.parse_obj(
gw
).replace_nuclio_names_with_mlrun_uri()
return parsed_api_gateways

async def api_gateway_exists(self, name: str, project_name: str = None):
Expand All @@ -77,7 +79,9 @@ async def get_api_gateway(self, name: str, project_name: str = None):
path=NUCLIO_API_GATEWAYS_ENDPOINT_TEMPLATE.format(api_gateway=name),
headers=headers,
)
return mlrun.common.schemas.APIGateway.parse_obj(api_gateway)
return mlrun.common.schemas.APIGateway.parse_obj(
api_gateway
).replace_nuclio_names_with_mlrun_uri()

async def store_api_gateway(
self,
Expand Down Expand Up @@ -220,4 +224,5 @@ def _enrich_nuclio_api_gateway(
api_gateway: mlrun.common.schemas.APIGateway,
) -> mlrun.common.schemas.APIGateway:
self._set_iguazio_labels(api_gateway, project_name)
api_gateway.enrich_mlrun_function_names()
return api_gateway
60 changes: 59 additions & 1 deletion tests/api/api/test_nuclio.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
from unittest.mock import patch

import fastapi
import pytest

import mlrun
import mlrun.common.schemas
import mlrun.runtimes.nuclio
import server.api.utils.clients.async_nuclio
import server.api.utils.clients.iguazio
from mlrun.common.constants import MLRUN_FUNCTIONS_ANNOTATION

PROJECT = "project-name"

Expand Down Expand Up @@ -70,7 +73,7 @@ def test_list_api_gateways(
assert response.json() == {
"api_gateways": {
"new-gw": {
"metadata": {"name": "new-gw", "labels": {}},
"metadata": {"name": "new-gw", "labels": {}, "annotations": {}},
"spec": {
"name": "new-gw",
"path": "/",
Expand Down Expand Up @@ -153,3 +156,58 @@ def test_store_api_gateway(
json=api_gateway.dict(),
)
assert response.status_code == 200


@pytest.mark.parametrize(
"functions, expected_nuclio_function_names, expected_mlrun_functions_label",
[
(
["test-func"],
["test-project-test-func"],
"test-project/test-func",
),
(
["test-func1", "test-func2"],
["test-project-test-func1", "test-project-test-func2"],
"test-project/test-func1&test-project/test-func2",
),
(
["test-func1:latest", "test-func2:latest"],
["test-project-test-func1", "test-project-test-func2"],
"test-project/test-func1:latest&test-project/test-func2:latest",
),
(
["test-func1:tag1", "test-func2:tag2"],
["test-project-test-func1-tag1", "test-project-test-func2-tag2"],
"test-project/test-func1:tag1&test-project/test-func2:tag2",
),
],
)
def test_mlrun_function_translation_to_nuclio(
functions, expected_nuclio_function_names, expected_mlrun_functions_label
):
project_name = "test-project"
api_gateway_client_side = mlrun.runtimes.APIGateway(
metadata=mlrun.runtimes.nuclio.api_gateway.APIGatewayMetadata(name="new-gw"),
spec=mlrun.runtimes.nuclio.api_gateway.APIGatewaySpec(
functions=functions, project=project_name
),
)
api_gateway_server_side = (
api_gateway_client_side.to_scheme().enrich_mlrun_function_names()
)
assert (
api_gateway_server_side.get_function_names() == expected_nuclio_function_names
)

assert (
api_gateway_server_side.metadata.annotations[MLRUN_FUNCTIONS_ANNOTATION]
== expected_mlrun_functions_label
)
api_gateway_with_replaced_nuclio_names_to_mlrun = (
api_gateway_server_side.replace_nuclio_names_with_mlrun_uri()
)
assert (
api_gateway_with_replaced_nuclio_names_to_mlrun.get_function_names()
== api_gateway_client_side.spec.functions
)
13 changes: 11 additions & 2 deletions tests/api/utils/clients/test_async_nuclio.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pytest
from aioresponses import aioresponses as aioresponses_

import mlrun.common.constants
import mlrun.common.schemas
import mlrun.config
import mlrun.errors
Expand Down Expand Up @@ -66,9 +67,14 @@ async def test_nuclio_get_api_gateway(
api_gateway.with_canary(["test", "test2"], [20, 80])

request_url = f"{api_url}/api/api_gateways/test-basic"

expected_payload = api_gateway.to_scheme()
expected_payload.metadata.labels = {
mlrun.common.constants.MLRunInternalLabels.nuclio_project_name: "default-project"
}
mock_aioresponse.get(
request_url,
payload=api_gateway.to_scheme().dict(),
payload=expected_payload.dict(),
status=http.HTTPStatus.ACCEPTED,
)
r = await nuclio_client.get_api_gateway("test-basic", "default")
Expand All @@ -79,7 +85,10 @@ async def test_nuclio_get_api_gateway(
received_api_gateway.authentication.authentication_mode
== api_gateway.spec.authentication.authentication_mode
)
assert received_api_gateway.spec.functions == ["test", "test2"]
assert received_api_gateway.spec.functions == [
"default-project/test",
"default-project/test2",
]
assert received_api_gateway.spec.canary == [20, 80]


Expand Down
2 changes: 1 addition & 1 deletion tests/projects/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -1983,7 +1983,7 @@ def test_list_api_gateways(patched_list_api_gateways, context):

assert gateways[0].name == "test"
assert gateways[0].host == "http://gateway-f1-f2-project-name.some-domain.com"
assert gateways[0].spec.functions == ["my-func1"]
assert gateways[0].spec.functions == ["project-name/my-func1"]

assert gateways[1].invoke_url == "http://test-basic-default.domain.com/"

Expand Down

0 comments on commit 39bbdf8

Please sign in to comment.