Skip to content

Commit

Permalink
implement gql_mutation_wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
kyujin-cho committed May 28, 2024
1 parent a0317be commit 1fef8e9
Showing 1 changed file with 24 additions and 36 deletions.
60 changes: 24 additions & 36 deletions src/ai/backend/manager/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Awaitable,
Callable,
ClassVar,
Coroutine,
Dict,
Generic,
Iterable,
Expand Down Expand Up @@ -1029,6 +1030,27 @@ async def wrapped(cls, root, info: graphene.ResolveInfo, *args, **kwargs) -> Any
ItemType = TypeVar("ItemType", bound=graphene.ObjectType)


async def gql_mutation_wrapper(
result_cls: Type[ResultType], _do_mutate: Callable[[], Coroutine[Any, Any, ResultType]]
) -> ResultType:
try:
return await execute_with_retry(_do_mutate)
except sa.exc.IntegrityError as e:
log.warning("gql_mutation_wrapper(): integrity error ({})", repr(e))
return result_cls(False, f"integrity error: {e}")
except sa.exc.StatementError as e:
log.warning(
"gql_mutation_wrapper(): statement error ({})\n{}", repr(e), e.statement or "(unknown)"
)
orig_exc = e.orig
return result_cls(False, str(orig_exc), None)
except (asyncio.CancelledError, asyncio.TimeoutError):
raise
except Exception as e:
log.exception("gql_mutation_wrapper(): other error")
return result_cls(False, f"unexpected error: {e}")


async def simple_db_mutate(
result_cls: Type[ResultType],
graph_ctx: GraphQueryContext,
Expand All @@ -1046,15 +1068,12 @@ async def simple_db_mutate(
See details about the arguments in :func:`simple_db_mutate_returning_item`.
"""
raw_query = "(unknown)"

async def _do_mutate() -> ResultType:
nonlocal raw_query
async with graph_ctx.db.begin() as conn:
if pre_func:
await pre_func(conn)
_query = mutation_query() if callable(mutation_query) else mutation_query
raw_query = str(_query)
result = await conn.execute(_query)
if post_func:
await post_func(conn, result)
Expand All @@ -1063,20 +1082,7 @@ async def _do_mutate() -> ResultType:
else:
return result_cls(False, f"no matching {result_cls.__name__.lower()}")

try:
return await execute_with_retry(_do_mutate)
except sa.exc.IntegrityError as e:
log.warning("simple_db_mutate(): integrity error ({})", repr(e))
return result_cls(False, f"integrity error: {e}")
except sa.exc.StatementError as e:
log.warning("simple_db_mutate(): statement error ({})\n{}", repr(e), raw_query)
orig_exc = e.orig
return result_cls(False, str(orig_exc), None)
except (asyncio.CancelledError, asyncio.TimeoutError):
raise
except Exception as e:
log.exception("simple_db_mutate(): other error")
return result_cls(False, f"unexpected error: {e}")
return await gql_mutation_wrapper(result_cls, _do_mutate)


async def simple_db_mutate_returning_item(
Expand Down Expand Up @@ -1111,16 +1117,13 @@ async def simple_db_mutate_returning_item(
from the given mutation result**, because the result object could be fetched only one
time due to its cursor-like nature.
"""
raw_query = "(unknown)"

async def _do_mutate() -> ResultType:
nonlocal raw_query
async with graph_ctx.db.begin() as conn:
if pre_func:
await pre_func(conn)
_query = mutation_query() if callable(mutation_query) else mutation_query
_query = _query.returning(_query.table)
raw_query = str(_query)
result = await conn.execute(_query)
if post_func:
row = await post_func(conn, result)
Expand All @@ -1131,22 +1134,7 @@ async def _do_mutate() -> ResultType:
else:
return result_cls(False, f"no matching {result_cls.__name__.lower()}", None)

try:
return await execute_with_retry(_do_mutate)
except sa.exc.IntegrityError as e:
log.warning("simple_db_mutate_returning_item(): integrity error ({})", repr(e))
return result_cls(False, f"integrity error: {e}", None)
except sa.exc.StatementError as e:
log.warning(
"simple_db_mutate_returning_item(): statement error ({})\n{}", repr(e), raw_query
)
orig_exc = e.orig
return result_cls(False, str(orig_exc), None)
except (asyncio.CancelledError, asyncio.TimeoutError):
raise
except Exception as e:
log.exception("simple_db_mutate_returning_item(): other error")
return result_cls(False, f"unexpected error: {e}", None)
return await gql_mutation_wrapper(result_cls, _do_mutate)


def set_if_set(
Expand Down

0 comments on commit 1fef8e9

Please sign in to comment.