From fb12230d5a313137f4c3ab34415aa2a38f15a8ae Mon Sep 17 00:00:00 2001 From: Hedingber Date: Thu, 24 Sep 2020 04:13:15 +0300 Subject: [PATCH] Add option to delete function (#446) --- mlrun/api/api/endpoints/functions.py | 10 +++++ mlrun/api/db/base.py | 4 ++ mlrun/api/db/filedb/db.py | 3 ++ mlrun/api/db/sqldb/db.py | 28 ++++++++++++- mlrun/db/base.py | 4 ++ mlrun/db/filedb.py | 3 ++ mlrun/db/httpdb.py | 7 ++++ mlrun/db/sqldb.py | 5 +++ tests/api/db/test_functions.py | 60 ++++++++++++++++++++++++++++ 9 files changed, 123 insertions(+), 1 deletion(-) diff --git a/mlrun/api/api/endpoints/functions.py b/mlrun/api/api/endpoints/functions.py index a2953997fd7..0e0e0a164b0 100644 --- a/mlrun/api/api/endpoints/functions.py +++ b/mlrun/api/api/endpoints/functions.py @@ -67,6 +67,16 @@ def get_function( } +@router.delete( + "/projects/{project}/functions/{name}", status_code=HTTPStatus.NO_CONTENT.value +) +def delete_function( + project: str, name: str, db_session: Session = Depends(deps.get_db_session), +): + get_db().delete_function(db_session, name, project) + return Response(status_code=HTTPStatus.NO_CONTENT.value) + + # curl http://localhost:8080/funcs?project=p1&name=x&label=l1&label=l2 @router.get("/funcs") def list_functions( diff --git a/mlrun/api/db/base.py b/mlrun/api/db/base.py index 6945637aedf..1fa18b48732 100644 --- a/mlrun/api/db/base.py +++ b/mlrun/api/db/base.py @@ -114,6 +114,10 @@ def store_function( def get_function(self, session, name, project="", tag="", hash_key=""): pass + @abstractmethod + def delete_function(self, session, project: str, name: str): + pass + @abstractmethod def list_functions(self, session, name, project="", tag="", labels=None): pass diff --git a/mlrun/api/db/filedb/db.py b/mlrun/api/db/filedb/db.py index ff84ee73c82..b7e17921e12 100644 --- a/mlrun/api/db/filedb/db.py +++ b/mlrun/api/db/filedb/db.py @@ -98,6 +98,9 @@ def get_function(self, session, name, project="", tag="", hash_key=""): self.db.get_function, name, project, tag, hash_key ) + def delete_function(self, session, project: str, name: str): + raise NotImplementedError() + def list_functions(self, session, name, project="", tag="", labels=None): return self._transform_run_db_error( self.db.list_functions, name, project, tag, labels diff --git a/mlrun/api/db/sqldb/db.py b/mlrun/api/db/sqldb/db.py index f7cf1ba0daf..9e507ac093f 100644 --- a/mlrun/api/db/sqldb/db.py +++ b/mlrun/api/db/sqldb/db.py @@ -339,6 +339,12 @@ def get_function(self, session, name, project="", tag="", hash_key=""): function_uri = generate_function_uri(project, name, tag, hash_key) raise mlrun.errors.MLRunNotFoundError(f"Function not found {function_uri}") + def delete_function(self, session: Session, project: str, name: str): + logger.debug("Removing function from db", project=project, name=name) + self._delete_function_tags(session, project, name, commit=False) + self._delete_function_labels(session, project, name, commit=False) + self._delete(session, Function, project=project, name=name) + def list_functions(self, session, name, project=None, tag=None, labels=None): project = project or config.default_project uid = None @@ -367,6 +373,26 @@ def list_functions(self, session, name, project=None, tag=None, labels=None): funcs.append(function_dict) return funcs + def _delete_function_tags(self, session, project, function_name, commit=True): + query = session.query(Function.Tag).filter( + Function.Tag.project == project, Function.Tag.obj_name == function_name + ) + for obj in query: + session.delete(obj) + if commit: + session.commit() + + def _delete_function_labels(self, session, project, function_name, commit=True): + labels = ( + session.query(Function.Label) + .join(Function) + .filter(Function.project == project, Function.name == function_name) + ) + for label in labels: + session.delete(label) + if commit: + session.commit() + def _list_function_tags(self, session, project, function_id): query = ( session.query(Function.Tag.name) @@ -695,7 +721,7 @@ def _find_artifacts(self, session, project, uids, labels, since, until): return query - def _find_functions(self, session, name, project, uid, labels): + def _find_functions(self, session, name, project, uid=None, labels=None): query = self._query(session, Function, name=name, project=project) if uid: query = query.filter(Function.uid == uid) diff --git a/mlrun/db/base.py b/mlrun/db/base.py index 618e8428258..e32bc8fafec 100644 --- a/mlrun/db/base.py +++ b/mlrun/db/base.py @@ -106,6 +106,10 @@ def store_function(self, function, name, project="", tag="", versioned=False): def get_function(self, name, project="", tag="", hash_key=""): pass + @abstractmethod + def delete_function(self, name: str, project: str = ""): + pass + @abstractmethod def list_functions(self, name, project="", tag="", labels=None): pass diff --git a/mlrun/db/filedb.py b/mlrun/db/filedb.py index 5562173f370..9e6a335bb58 100644 --- a/mlrun/db/filedb.py +++ b/mlrun/db/filedb.py @@ -319,6 +319,9 @@ def get_function(self, name, project="", tag="", hash_key=""): parsed_data["metadata"]["tag"] = "" if hash_key else tag return parsed_data + def delete_function(self, name: str, project: str = ""): + raise NotImplementedError() + def list_functions(self, name, project="", tag="", labels=None): labels = labels or [] logger.info(f"reading functions in {project} name/mask: {name} tag: {tag} ...") diff --git a/mlrun/db/httpdb.py b/mlrun/db/httpdb.py index b8635b11271..6a874af050e 100644 --- a/mlrun/db/httpdb.py +++ b/mlrun/db/httpdb.py @@ -355,6 +355,13 @@ def get_function(self, name, project="", tag=None, hash_key=""): resp = self.api_call("GET", path, error, params=params) return resp.json()["func"] + def delete_function(self, name: str, project: str = ""): + raise NotImplementedError() + project = project or default_project + path = f"projects/{project}/functions/{name}" + error_message = f"Failed deleting function {project}/{name}" + self.api_call("DELETE", path, error_message) + def list_functions(self, name=None, project=None, tag=None, labels=None): params = { "project": project or default_project, diff --git a/mlrun/db/sqldb.py b/mlrun/db/sqldb.py index 40991c7e1f0..c0ae36005e7 100644 --- a/mlrun/db/sqldb.py +++ b/mlrun/db/sqldb.py @@ -148,6 +148,11 @@ def get_function(self, name, project="", tag="", hash_key=""): self.db.get_function, self.session, name, project, tag, hash_key ) + def delete_function(self, name: str, project: str = ""): + return self._transform_db_error( + self.db.delete_function, self.session, project, name + ) + def list_functions(self, name, project=None, tag=None, labels=None): return self._transform_db_error( self.db.list_functions, self.session, name, project, tag, labels diff --git a/tests/api/db/test_functions.py b/tests/api/db/test_functions.py index 08eb522de2a..a6802cb9694 100644 --- a/tests/api/db/test_functions.py +++ b/tests/api/db/test_functions.py @@ -3,6 +3,7 @@ import mlrun.errors from mlrun.api.db.base import DBInterface +from mlrun.api.db.sqldb.models import Function from tests.api.db.conftest import dbs @@ -198,3 +199,62 @@ def test_list_functions_multiple_tags(db: DBInterface, db_session: Session): function_tag = function["metadata"]["tag"] tags.remove(function_tag) assert len(tags) == 0 + + +# running only on sqldb cause filedb is not really a thing anymore, will be removed soon +@pytest.mark.parametrize( + "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] +) +def test_delete_function(db: DBInterface, db_session: Session): + labels = { + "name": "value", + "name2": "value2", + } + function = { + "bla": "blabla", + "metadata": {"labels": labels}, + "status": {"bla": "blabla"}, + } + function_name = "function_name_1" + project = "bla" + tags = ["some_tag", "some_tag2", "some_tag3"] + function_hash_key = None + for tag in tags: + function_hash_key = db.store_function( + db_session, function, function_name, project, tag=tag, versioned=True + ) + + # if not exploding then function exists + for tag in tags: + db.get_function(db_session, function_name, project, tag=tag) + db.get_function(db_session, function_name, project, hash_key=function_hash_key) + assert len(tags) == len(db.list_functions(db_session, function_name, project)) + number_of_tags = ( + db_session.query(Function.Tag) + .filter_by(project=project, obj_name=function_name) + .count() + ) + number_of_labels = db_session.query(Function.Label).count() + + assert len(tags) == number_of_tags + assert len(labels) == number_of_labels + + db.delete_function(db_session, project, function_name) + + for tag in tags: + with pytest.raises(mlrun.errors.MLRunNotFoundError): + db.get_function(db_session, function_name, project, tag=tag) + with pytest.raises(mlrun.errors.MLRunNotFoundError): + db.get_function(db_session, function_name, project, hash_key=function_hash_key) + assert 0 == len(db.list_functions(db_session, function_name, project)) + + # verifying tags and labels (different table) records were removed + number_of_tags = ( + db_session.query(Function.Tag) + .filter_by(project=project, obj_name=function_name) + .count() + ) + number_of_labels = db_session.query(Function.Label).count() + + assert number_of_tags == 0 + assert number_of_labels == 0