Skip to content

Commit

Permalink
Add option to delete function (#446)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hedingber committed Sep 24, 2020
1 parent 9bc8760 commit fb12230
Show file tree
Hide file tree
Showing 9 changed files with 123 additions and 1 deletion.
10 changes: 10 additions & 0 deletions mlrun/api/api/endpoints/functions.py
Expand Up @@ -67,6 +67,16 @@ def get_function(
}


@router.delete(
"/projects/{project}/functions/{name}", status_code=HTTPStatus.NO_CONTENT.value
)
def delete_function(
project: str, name: str, db_session: Session = Depends(deps.get_db_session),
):
get_db().delete_function(db_session, name, project)
return Response(status_code=HTTPStatus.NO_CONTENT.value)


# curl http://localhost:8080/funcs?project=p1&name=x&label=l1&label=l2
@router.get("/funcs")
def list_functions(
Expand Down
4 changes: 4 additions & 0 deletions mlrun/api/db/base.py
Expand Up @@ -114,6 +114,10 @@ def store_function(
def get_function(self, session, name, project="", tag="", hash_key=""):
pass

@abstractmethod
def delete_function(self, session, project: str, name: str):
pass

@abstractmethod
def list_functions(self, session, name, project="", tag="", labels=None):
pass
Expand Down
3 changes: 3 additions & 0 deletions mlrun/api/db/filedb/db.py
Expand Up @@ -98,6 +98,9 @@ def get_function(self, session, name, project="", tag="", hash_key=""):
self.db.get_function, name, project, tag, hash_key
)

def delete_function(self, session, project: str, name: str):
raise NotImplementedError()

def list_functions(self, session, name, project="", tag="", labels=None):
return self._transform_run_db_error(
self.db.list_functions, name, project, tag, labels
Expand Down
28 changes: 27 additions & 1 deletion mlrun/api/db/sqldb/db.py
Expand Up @@ -339,6 +339,12 @@ def get_function(self, session, name, project="", tag="", hash_key=""):
function_uri = generate_function_uri(project, name, tag, hash_key)
raise mlrun.errors.MLRunNotFoundError(f"Function not found {function_uri}")

def delete_function(self, session: Session, project: str, name: str):
logger.debug("Removing function from db", project=project, name=name)
self._delete_function_tags(session, project, name, commit=False)
self._delete_function_labels(session, project, name, commit=False)
self._delete(session, Function, project=project, name=name)

def list_functions(self, session, name, project=None, tag=None, labels=None):
project = project or config.default_project
uid = None
Expand Down Expand Up @@ -367,6 +373,26 @@ def list_functions(self, session, name, project=None, tag=None, labels=None):
funcs.append(function_dict)
return funcs

def _delete_function_tags(self, session, project, function_name, commit=True):
query = session.query(Function.Tag).filter(
Function.Tag.project == project, Function.Tag.obj_name == function_name
)
for obj in query:
session.delete(obj)
if commit:
session.commit()

def _delete_function_labels(self, session, project, function_name, commit=True):
labels = (
session.query(Function.Label)
.join(Function)
.filter(Function.project == project, Function.name == function_name)
)
for label in labels:
session.delete(label)
if commit:
session.commit()

def _list_function_tags(self, session, project, function_id):
query = (
session.query(Function.Tag.name)
Expand Down Expand Up @@ -695,7 +721,7 @@ def _find_artifacts(self, session, project, uids, labels, since, until):

return query

def _find_functions(self, session, name, project, uid, labels):
def _find_functions(self, session, name, project, uid=None, labels=None):
query = self._query(session, Function, name=name, project=project)
if uid:
query = query.filter(Function.uid == uid)
Expand Down
4 changes: 4 additions & 0 deletions mlrun/db/base.py
Expand Up @@ -106,6 +106,10 @@ def store_function(self, function, name, project="", tag="", versioned=False):
def get_function(self, name, project="", tag="", hash_key=""):
pass

@abstractmethod
def delete_function(self, name: str, project: str = ""):
pass

@abstractmethod
def list_functions(self, name, project="", tag="", labels=None):
pass
Expand Down
3 changes: 3 additions & 0 deletions mlrun/db/filedb.py
Expand Up @@ -319,6 +319,9 @@ def get_function(self, name, project="", tag="", hash_key=""):
parsed_data["metadata"]["tag"] = "" if hash_key else tag
return parsed_data

def delete_function(self, name: str, project: str = ""):
raise NotImplementedError()

def list_functions(self, name, project="", tag="", labels=None):
labels = labels or []
logger.info(f"reading functions in {project} name/mask: {name} tag: {tag} ...")
Expand Down
7 changes: 7 additions & 0 deletions mlrun/db/httpdb.py
Expand Up @@ -355,6 +355,13 @@ def get_function(self, name, project="", tag=None, hash_key=""):
resp = self.api_call("GET", path, error, params=params)
return resp.json()["func"]

def delete_function(self, name: str, project: str = ""):
raise NotImplementedError()
project = project or default_project
path = f"projects/{project}/functions/{name}"
error_message = f"Failed deleting function {project}/{name}"
self.api_call("DELETE", path, error_message)

def list_functions(self, name=None, project=None, tag=None, labels=None):
params = {
"project": project or default_project,
Expand Down
5 changes: 5 additions & 0 deletions mlrun/db/sqldb.py
Expand Up @@ -148,6 +148,11 @@ def get_function(self, name, project="", tag="", hash_key=""):
self.db.get_function, self.session, name, project, tag, hash_key
)

def delete_function(self, name: str, project: str = ""):
return self._transform_db_error(
self.db.delete_function, self.session, project, name
)

def list_functions(self, name, project=None, tag=None, labels=None):
return self._transform_db_error(
self.db.list_functions, self.session, name, project, tag, labels
Expand Down
60 changes: 60 additions & 0 deletions tests/api/db/test_functions.py
Expand Up @@ -3,6 +3,7 @@

import mlrun.errors
from mlrun.api.db.base import DBInterface
from mlrun.api.db.sqldb.models import Function
from tests.api.db.conftest import dbs


Expand Down Expand Up @@ -198,3 +199,62 @@ def test_list_functions_multiple_tags(db: DBInterface, db_session: Session):
function_tag = function["metadata"]["tag"]
tags.remove(function_tag)
assert len(tags) == 0


# running only on sqldb cause filedb is not really a thing anymore, will be removed soon
@pytest.mark.parametrize(
"db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"]
)
def test_delete_function(db: DBInterface, db_session: Session):
labels = {
"name": "value",
"name2": "value2",
}
function = {
"bla": "blabla",
"metadata": {"labels": labels},
"status": {"bla": "blabla"},
}
function_name = "function_name_1"
project = "bla"
tags = ["some_tag", "some_tag2", "some_tag3"]
function_hash_key = None
for tag in tags:
function_hash_key = db.store_function(
db_session, function, function_name, project, tag=tag, versioned=True
)

# if not exploding then function exists
for tag in tags:
db.get_function(db_session, function_name, project, tag=tag)
db.get_function(db_session, function_name, project, hash_key=function_hash_key)
assert len(tags) == len(db.list_functions(db_session, function_name, project))
number_of_tags = (
db_session.query(Function.Tag)
.filter_by(project=project, obj_name=function_name)
.count()
)
number_of_labels = db_session.query(Function.Label).count()

assert len(tags) == number_of_tags
assert len(labels) == number_of_labels

db.delete_function(db_session, project, function_name)

for tag in tags:
with pytest.raises(mlrun.errors.MLRunNotFoundError):
db.get_function(db_session, function_name, project, tag=tag)
with pytest.raises(mlrun.errors.MLRunNotFoundError):
db.get_function(db_session, function_name, project, hash_key=function_hash_key)
assert 0 == len(db.list_functions(db_session, function_name, project))

# verifying tags and labels (different table) records were removed
number_of_tags = (
db_session.query(Function.Tag)
.filter_by(project=project, obj_name=function_name)
.count()
)
number_of_labels = db_session.query(Function.Label).count()

assert number_of_tags == 0
assert number_of_labels == 0

0 comments on commit fb12230

Please sign in to comment.