Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise when function not found on get function call #428

Merged
merged 8 commits into from
Sep 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlrun/api/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fastapi import HTTPException, Request
from sqlalchemy.orm import Session

import mlrun.errors
from mlrun.api import schemas
from mlrun.api.db.sqldb.db import SQLDB
from mlrun.api.utils.singletons.db import get_db
Expand Down Expand Up @@ -142,6 +143,8 @@ def submit(db_session: Session, data):
except HTTPException:
logger.error(traceback.format_exc())
raise
except mlrun.errors.MLRunHTTPStatusError:
raise
except Exception as err:
logger.error(traceback.format_exc())
log_and_raise(
Expand Down
18 changes: 16 additions & 2 deletions mlrun/api/db/sqldb/db.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pytz
from copy import deepcopy
from datetime import datetime, timedelta, timezone
from typing import Any, List

import pytz
from sqlalchemy import and_, func
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
Expand All @@ -29,7 +29,13 @@
)
from mlrun.config import config
from mlrun.lists import ArtifactList, FunctionList, RunList
from mlrun.utils import get_in, update_in, logger, fill_function_hash
from mlrun.utils import (
get_in,
update_in,
logger,
fill_function_hash,
generate_function_uri,
)

NULL = None # Avoid flake8 issuing warnings when comparing in filter
run_time_fmt = "%Y-%m-%dT%H:%M:%S.%fZ"
Expand Down Expand Up @@ -309,6 +315,11 @@ def get_function(self, session, name, project="", tag="", hash_key=""):
tag_function_uid = self._resolve_tag_function_uid(
session, Function, project, name, computed_tag
)
if tag_function_uid is None:
function_uri = generate_function_uri(project, name, tag)
raise mlrun.errors.MLRunNotFoundError(
f"Function tag not found {function_uri}"
)
uid = tag_function_uid
if uid:
query = query.filter(Function.uid == uid)
Expand All @@ -324,6 +335,9 @@ def get_function(self, session, name, project="", tag="", hash_key=""):
if tag_function_uid:
function["metadata"]["tag"] = computed_tag
return function
else:
function_uri = generate_function_uri(project, name, tag, hash_key)
raise mlrun.errors.MLRunNotFoundError(f"Function not found {function_uri}")

def list_functions(self, session, name, project=None, tag=None, labels=None):
project = project or config.default_project
Expand Down
4 changes: 3 additions & 1 deletion mlrun/db/filedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
match_value,
update_in,
fill_function_hash,
generate_function_uri,
)

run_logs = "runs"
Expand Down Expand Up @@ -309,7 +310,8 @@ def get_function(self, name, project="", tag="", hash_key=""):
+ self.format
)
if not pathlib.Path(filepath).is_file():
return None
function_uri = generate_function_uri(project, name, tag, hash_key)
raise mlrun.errors.MLRunNotFoundError(f"Function not found {function_uri}")
data = self._datastore.get(filepath)
parsed_data = self._loads(data)

Expand Down
16 changes: 7 additions & 9 deletions mlrun/runtimes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from mlrun.api import schemas
from mlrun.api.constants import LogSources
from mlrun.api.db.base import DBInterface
from mlrun.utils.helpers import verify_field_regex
from mlrun.utils.helpers import verify_field_regex, generate_function_uri
from .constants import PodPhases, RunStates
from .generators import get_generator
from .utils import calc_hash, RunError, results_to_iter
Expand Down Expand Up @@ -198,14 +198,12 @@ def _use_remote_api(self):
return False

def _function_uri(self, tag=None, hash_key=None):
url = "{}/{}".format(self.metadata.project, self.metadata.name)

# prioritize hash key over tag
if hash_key:
url += "@{}".format(hash_key)
elif tag or self.metadata.tag:
url += ":{}".format(tag or self.metadata.tag)
return url
return generate_function_uri(
self.metadata.project,
self.metadata.name,
tag=tag or self.metadata.tag,
hash_key=hash_key,
)

def _get_db(self):
if not self._db_conn:
Expand Down
11 changes: 11 additions & 0 deletions mlrun/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,17 @@ def parse_function_uri(uri):
return project, uri, tag, hash_key


def generate_function_uri(project, name, tag=None, hash_key=None):
uri = "{}/{}".format(project, name)

# prioritize hash key over tag
if hash_key:
uri += "@{}".format(hash_key)
elif tag:
uri += ":{}".format(tag)
return uri


def extend_hub_uri(uri):
if not uri.startswith(hub_prefix):
return uri
Expand Down
5 changes: 1 addition & 4 deletions tests/api/api/test_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,4 @@ def test_submit_job_failure_function_not_found(db: Session, client: TestClient)
}
resp = client.post("/api/submit_job", json=body)
assert resp.status_code == HTTPStatus.NOT_FOUND.value
assert (
resp.json()["detail"]["reason"]
== f"runtime error: function {function_reference} not found"
)
assert f"Function not found {function_reference}" in resp.json()["detail"]
22 changes: 18 additions & 4 deletions tests/api/db/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from sqlalchemy.orm import Session

import mlrun.errors
from mlrun.api.db.base import DBInterface
from tests.api.db.conftest import dbs

Expand Down Expand Up @@ -81,10 +82,8 @@ def test_store_function_not_versioned(db: DBInterface, db_session: Session):
assert function_result_1["metadata"]["tag"] == "latest"

# not versioned so not queryable by hash key
function_result_2 = db.get_function(
db_session, function_name_1, hash_key=function_hash_key
)
assert function_result_2 is None
with pytest.raises(mlrun.errors.MLRunNotFoundError):
db.get_function(db_session, function_name_1, hash_key=function_hash_key)

function_2 = {"bla": "blabla", "bla2": "blabla2"}
db.store_function(db_session, function_2, function_name_1, versioned=False)
Expand Down Expand Up @@ -138,6 +137,21 @@ def test_get_function_by_tag(db: DBInterface, db_session: Session):
assert function_queried_by_hash_key["status"] is None


@pytest.mark.parametrize(
"db,db_session", [(db, db) for db in dbs], indirect=["db", "db_session"]
)
def test_get_function_not_found(db: DBInterface, db_session: Session):
function_1 = {"bla": "blabla", "status": {"bla": "blabla"}}
function_name_1 = "function_name_1"
db.store_function(db_session, function_1, function_name_1, versioned=True)

with pytest.raises(mlrun.errors.MLRunNotFoundError):
db.get_function(db_session, function_name_1, tag="inexistent_tag")

with pytest.raises(mlrun.errors.MLRunNotFoundError):
db.get_function(db_session, function_name_1, hash_key="inexistent_hash_key")


@pytest.mark.parametrize(
"db,db_session", [(db, db) for db in dbs], indirect=["db", "db_session"]
)
Expand Down