Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement gql_mutation_wrapper() function #2164

Merged
merged 1 commit into from
May 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 24 additions & 36 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Awaitable,
Callable,
ClassVar,
Coroutine,
Dict,
Generic,
Iterable,
Expand Down Expand Up @@ -1029,6 +1030,27 @@ async def wrapped(cls, root, info: graphene.ResolveInfo, *args, **kwargs) -> Any
ItemType = TypeVar("ItemType", bound=graphene.ObjectType)


async def gql_mutation_wrapper(
result_cls: Type[ResultType], _do_mutate: Callable[[], Coroutine[Any, Any, ResultType]]
) -> ResultType:
try:
return await execute_with_retry(_do_mutate)
except sa.exc.IntegrityError as e:
log.warning("gql_mutation_wrapper(): integrity error ({})", repr(e))
return result_cls(False, f"integrity error: {e}")
except sa.exc.StatementError as e:
log.warning(
"gql_mutation_wrapper(): statement error ({})\n{}", repr(e), e.statement or "(unknown)"
)
orig_exc = e.orig
return result_cls(False, str(orig_exc), None)
except (asyncio.CancelledError, asyncio.TimeoutError):
raise
except Exception as e:
log.exception("gql_mutation_wrapper(): other error")
return result_cls(False, f"unexpected error: {e}")


async def simple_db_mutate(
result_cls: Type[ResultType],
graph_ctx: GraphQueryContext,
Expand All @@ -1046,15 +1068,12 @@ async def simple_db_mutate(

See details about the arguments in :func:`simple_db_mutate_returning_item`.
"""
raw_query = "(unknown)"

async def _do_mutate() -> ResultType:
nonlocal raw_query
async with graph_ctx.db.begin() as conn:
if pre_func:
await pre_func(conn)
_query = mutation_query() if callable(mutation_query) else mutation_query
raw_query = str(_query)
result = await conn.execute(_query)
if post_func:
await post_func(conn, result)
Expand All @@ -1063,20 +1082,7 @@ async def _do_mutate() -> ResultType:
else:
return result_cls(False, f"no matching {result_cls.__name__.lower()}")

try:
return await execute_with_retry(_do_mutate)
except sa.exc.IntegrityError as e:
log.warning("simple_db_mutate(): integrity error ({})", repr(e))
return result_cls(False, f"integrity error: {e}")
except sa.exc.StatementError as e:
log.warning("simple_db_mutate(): statement error ({})\n{}", repr(e), raw_query)
orig_exc = e.orig
return result_cls(False, str(orig_exc), None)
except (asyncio.CancelledError, asyncio.TimeoutError):
raise
except Exception as e:
log.exception("simple_db_mutate(): other error")
return result_cls(False, f"unexpected error: {e}")
return await gql_mutation_wrapper(result_cls, _do_mutate)


async def simple_db_mutate_returning_item(
Expand Down Expand Up @@ -1111,16 +1117,13 @@ async def simple_db_mutate_returning_item(
from the given mutation result**, because the result object could be fetched only one
time due to its cursor-like nature.
"""
raw_query = "(unknown)"

async def _do_mutate() -> ResultType:
nonlocal raw_query
async with graph_ctx.db.begin() as conn:
if pre_func:
await pre_func(conn)
_query = mutation_query() if callable(mutation_query) else mutation_query
_query = _query.returning(_query.table)
raw_query = str(_query)
result = await conn.execute(_query)
if post_func:
row = await post_func(conn, result)
Expand All @@ -1131,22 +1134,7 @@ async def _do_mutate() -> ResultType:
else:
return result_cls(False, f"no matching {result_cls.__name__.lower()}", None)

try:
return await execute_with_retry(_do_mutate)
except sa.exc.IntegrityError as e:
log.warning("simple_db_mutate_returning_item(): integrity error ({})", repr(e))
return result_cls(False, f"integrity error: {e}", None)
except sa.exc.StatementError as e:
log.warning(
"simple_db_mutate_returning_item(): statement error ({})\n{}", repr(e), raw_query
)
orig_exc = e.orig
return result_cls(False, str(orig_exc), None)
except (asyncio.CancelledError, asyncio.TimeoutError):
raise
except Exception as e:
log.exception("simple_db_mutate_returning_item(): other error")
return result_cls(False, f"unexpected error: {e}", None)
return await gql_mutation_wrapper(result_cls, _do_mutate)


def set_if_set(
Expand Down
Loading