Skip to content

Commit

Permalink
[API] Fix session usage when deleting functions (#5584)
Browse files Browse the repository at this point in the history
  • Loading branch information
rokatyy committed May 17, 2024
1 parent 3e16327 commit b7982e7
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 24 deletions.
21 changes: 11 additions & 10 deletions server/api/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,8 +1295,8 @@ async def _delete_function(
)

# update functions with deletion task id
await _update_functions_with_deletion_task_ids(
db_session, functions, project, background_task_name
await _update_functions_with_deletion_info(
functions, project, {"status.deletion_task_id": background_task_name}
)

# Since we request functions by a specific name and project,
Expand All @@ -1316,6 +1316,9 @@ async def _delete_function(
)
if failed_requests:
error_message = f"Failed to delete function {function_name}. Errors: {' '.join(failed_requests)}"
await _update_functions_with_deletion_info(
functions, project, {"status.deletion_error": error_message}
)
raise mlrun.errors.MLRunInternalServerError(error_message)

# delete the function from the database
Expand All @@ -1327,24 +1330,22 @@ async def _delete_function(
)


async def _update_functions_with_deletion_task_ids(
db_session, functions, project, background_task_name
):
async def _update_functions_with_deletion_info(functions, project, updates: dict):
semaphore = asyncio.Semaphore(
mlrun.mlconf.background_tasks.function_deletion_batch_size
)

async def update_function_with_task_id(function):
async def update_function(function):
async with semaphore:
await run_in_threadpool(
server.api.crud.Functions().set_function_deletion_task_id,
db_session,
server.api.db.session.run_function_with_new_db_session,
server.api.crud.Functions().update_function,
function,
project,
background_task_name,
updates,
)

tasks = [update_function_with_task_id(function) for function in functions]
tasks = [update_function(function) for function in functions]
await asyncio.gather(*tasks)


Expand Down
14 changes: 8 additions & 6 deletions server/api/crud/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,18 @@ def start_function(self, function, client_version=None, client_python_version=No
)
function.save(versioned=False)

def set_function_deletion_task_id(
self, db_session: sqlalchemy.orm.Session, function, project, deletion_task_id
def update_function(
self,
db_session: sqlalchemy.orm.Session,
function,
project,
updates: dict,
):
deleting_updates = {
"status.deletion_task_id": deletion_task_id,
}
return server.api.utils.singletons.db.get_db().update_function(
session=db_session,
name=function["metadata"]["name"],
tag=function["metadata"]["tag"],
hash_key=function.get("metadata", {}).get("hash"),
project=project,
updates=deleting_updates,
updates=updates,
)
4 changes: 2 additions & 2 deletions server/api/db/sqldb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1692,8 +1692,8 @@ def update_function(
name,
updates: dict,
project: str = None,
tag: str = None,
hash_key: str = None,
tag: str = "",
hash_key: str = "",
):
project = project or config.default_project
query = self._query(session, Function, name=name, project=project)
Expand Down
12 changes: 8 additions & 4 deletions tests/api/api/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
_generate_function_and_task_from_submit_run_body,
_mask_v3io_access_key_env_var,
_mask_v3io_volume_credentials,
_update_functions_with_deletion_task_ids,
_update_functions_with_deletion_info,
ensure_function_has_auth_set,
ensure_function_security_context,
get_scheduler,
Expand Down Expand Up @@ -1695,7 +1695,7 @@ async def test_delete_function_calls_k8s_helper_methods():


@pytest.mark.asyncio
async def test_update_functions_with_deletion_task_ids(db: sqlalchemy.orm.Session):
async def test_update_functions_with_deletion_info(db: sqlalchemy.orm.Session):
project = "my_project"
deletion_task_id = "12345"
function_name = "test_function"
Expand All @@ -1708,8 +1708,12 @@ async def test_update_functions_with_deletion_task_ids(db: sqlalchemy.orm.Sessio
)
functions = [function]

await _update_functions_with_deletion_task_ids(
db, functions, project, deletion_task_id
await _update_functions_with_deletion_info(
functions,
project,
updates={
"status.deletion_task_id": deletion_task_id,
},
)
function = server.api.crud.Functions().get_function(
db, name=function_name, project=project, tag=function_tag
Expand Down
6 changes: 4 additions & 2 deletions tests/api/crud/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,13 @@ def test_set_function_deletion_task_id_updates_correctly(db: sqlalchemy.orm.Sess
db, name=function_name, project=project, tag=function_tag
)

result = server.api.crud.Functions().set_function_deletion_task_id(
result = server.api.crud.Functions().update_function(
db_session=db,
function=function,
project=project,
deletion_task_id=deletion_task_id,
updates={
"status.deletion_task_id": deletion_task_id,
},
)

assert result["status"]["deletion_task_id"] == deletion_task_id

0 comments on commit b7982e7

Please sign in to comment.