Skip to content

Commit

Permalink
[API] Enrich function object before build (#3688)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonmr committed Jun 1, 2023
1 parent ece1d53 commit 3a0d322
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 11 deletions.
2 changes: 2 additions & 0 deletions mlrun/api/api/endpoints/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import mlrun.api.crud
import mlrun.api.crud.runtimes.nuclio.function
import mlrun.api.db.session
import mlrun.api.launcher
import mlrun.api.utils.auth.verifier
import mlrun.api.utils.background_tasks
import mlrun.api.utils.clients.chief
Expand Down Expand Up @@ -662,6 +663,7 @@ def _build_function(
ready = None
try:
fn = new_function(runtime=function)
mlrun.api.launcher.ServerSideLauncher.enrich_runtime(runtime=fn)
except Exception as err:
logger.error(traceback.format_exc())
log_and_raise(
Expand Down
4 changes: 2 additions & 2 deletions mlrun/api/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def launch(
notifications: Optional[List[mlrun.model.Notification]] = None,
returns: Optional[List[Union[str, Dict[str, str]]]] = None,
) -> mlrun.run.RunObject:
self._enrich_runtime(runtime, project)
self.enrich_runtime(runtime, project)

run = self._create_run_object(task)

Expand Down Expand Up @@ -146,7 +146,7 @@ def launch(
return self._wrap_run_result(runtime, result, run, err=last_err)

@staticmethod
def _enrich_runtime(
def enrich_runtime(
runtime: "mlrun.runtimes.base.BaseRuntime", project: Optional[str] = ""
):
"""
Expand Down
2 changes: 1 addition & 1 deletion mlrun/api/utils/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def build_image(

else:
raise mlrun.errors.MLRunInvalidArgumentError(
f"Load of relative source ({source}) is not supported at build time"
f"Load of relative source ({source}) is not supported at build time "
"see 'mlrun.runtimes.kubejob.KubejobRuntime.with_source_archive' or "
"'mlrun.projects.project.MlrunProject.set_source' for more details"
)
Expand Down
2 changes: 1 addition & 1 deletion mlrun/launcher/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def prepare_image_for_deploy(runtime: "mlrun.runtimes.BaseRuntime"):

@staticmethod
@abc.abstractmethod
def _enrich_runtime(
def enrich_runtime(
runtime: "mlrun.runtimes.base.BaseRuntime",
project: Optional[str] = "",
):
Expand Down
2 changes: 1 addition & 1 deletion mlrun/launcher/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class ClientBaseLauncher(mlrun.launcher.base.BaseLauncher, abc.ABC):
"""

@staticmethod
def _enrich_runtime(
def enrich_runtime(
runtime: "mlrun.runtimes.base.BaseRuntime", project: Optional[str] = ""
):
runtime.try_auto_mount_based_on_config()
Expand Down
2 changes: 1 addition & 1 deletion mlrun/launcher/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def launch(
"local and schedule cannot be used together"
)

self._enrich_runtime(runtime)
self.enrich_runtime(runtime)
run = self._create_run_object(task)

if self._is_run_local:
Expand Down
2 changes: 1 addition & 1 deletion mlrun/launcher/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def launch(
notifications: Optional[List[mlrun.model.Notification]] = None,
returns: Optional[List[Union[str, Dict[str, str]]]] = None,
) -> "mlrun.run.RunObject":
self._enrich_runtime(runtime)
self.enrich_runtime(runtime)
run = self._create_run_object(task)

run = self._enrich_run(
Expand Down
47 changes: 47 additions & 0 deletions tests/api/api/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import mlrun.api.api.utils
import mlrun.api.crud
import mlrun.api.main
import mlrun.api.utils.builder
import mlrun.api.utils.clients.chief
import mlrun.api.utils.singletons.db
import mlrun.api.utils.singletons.k8s
Expand Down Expand Up @@ -443,6 +444,52 @@ def test_build_function_with_mlrun_bool(
mlrun.api.api.endpoints.functions._build_function = original_build_function


@pytest.mark.parametrize(
"source, load_source_on_run",
[
("./", False),
(".", False),
("./", True),
(".", True),
],
)
def test_build_function_with_project_repo(
db: sqlalchemy.orm.Session,
client: fastapi.testclient.TestClient,
source,
load_source_on_run,
):
git_repo = "git://github.com/mlrun/test.git"
tests.api.api.utils.create_project(
client, PROJECT, source=git_repo, load_source_on_run=load_source_on_run
)
function_dict = {
"kind": "job",
"metadata": {
"name": "function-name",
"project": "project-name",
"tag": "latest",
},
"spec": {
"build": {
"source": source,
},
},
}
original_build_runtime = mlrun.api.utils.builder.build_image
mlrun.api.utils.builder.build_image = unittest.mock.Mock(return_value="success")
response = client.post(
"build/function",
json={"function": function_dict},
)
assert response.status_code == HTTPStatus.OK.value
function = mlrun.new_function(runtime=response.json()["data"])
assert function.spec.build.source == git_repo
assert function.spec.build.load_source_on_run == load_source_on_run

mlrun.api.utils.builder.build_image = original_build_runtime


def test_start_function_succeeded(
db: sqlalchemy.orm.Session, client: fastapi.testclient.TestClient, monkeypatch
):
Expand Down
19 changes: 15 additions & 4 deletions tests/api/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,16 @@
PROJECT = "project-name"


def create_project(client: TestClient, project_name: str = PROJECT, artifact_path=None):
project = _create_project_obj(project_name, artifact_path)
def create_project(
client: TestClient,
project_name: str = PROJECT,
artifact_path=None,
source="source",
load_source_on_run=False,
):
project = _create_project_obj(
project_name, artifact_path, source, load_source_on_run
)
resp = client.post("projects", json=project.dict())
assert resp.status_code == HTTPStatus.CREATED.value
return resp
Expand Down Expand Up @@ -69,12 +77,15 @@ async def create_project_async(
return resp


def _create_project_obj(project_name, artifact_path) -> mlrun.common.schemas.Project:
def _create_project_obj(
project_name, artifact_path, source, load_source_on_run=False
) -> mlrun.common.schemas.Project:
return mlrun.common.schemas.Project(
metadata=mlrun.common.schemas.ProjectMetadata(name=project_name),
spec=mlrun.common.schemas.ProjectSpec(
description="banana",
source="source",
source=source,
load_source_on_run=load_source_on_run,
goals="some goals",
artifact_path=artifact_path,
),
Expand Down
1 change: 1 addition & 0 deletions tests/api/runtimes/test_kubejob.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def custom_setup(self):
def _generate_runtime(self) -> mlrun.runtimes.KubejobRuntime:
runtime = mlrun.runtimes.KubejobRuntime()
runtime.spec.image = self.image_name
runtime.metadata.project = self.project
return runtime

def test_run_without_runspec(self, db: Session, client: TestClient):
Expand Down

0 comments on commit 3a0d322

Please sign in to comment.