Skip to content

Commit

Permalink
fix: function default argument from a mutable object to None (#1986)
Browse files Browse the repository at this point in the history
Backported-from: main
Backported-to: 23.09
  • Loading branch information
fregataa committed Apr 1, 2024
1 parent 80cf820 commit 2d39a0a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
1 change: 1 addition & 0 deletions changes/1986.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Change function default arguments from mutable object to `None`.
18 changes: 10 additions & 8 deletions src/ai/backend/manager/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ async def get_session(
allow_stale: bool = False,
for_update: bool = False,
kernel_loading_strategy: KernelLoadingStrategy = KernelLoadingStrategy.NONE,
eager_loading_op: list[Any] = [],
eager_loading_op: list[Any] | None = None,
) -> SessionRow:
"""
Retrieve the session information by session's UUID,
Expand All @@ -935,9 +935,10 @@ async def get_session(
:param kernel_loading_strategy: Determines JOIN strategy of `kernels` relation when fetching session rows.
:param eager_loading_op: Extra loading operators to be passed directly to `match_sessions()` API.
"""
_eager_loading_op = eager_loading_op or []
match kernel_loading_strategy:
case KernelLoadingStrategy.ALL_KERNELS:
eager_loading_op.extend([
_eager_loading_op.extend([
noload("*"),
selectinload(SessionRow.kernels).options(
noload("*"),
Expand All @@ -947,7 +948,7 @@ async def get_session(
case KernelLoadingStrategy.MAIN_KERNEL_ONLY:
kernel_rel = SessionRow.kernels
kernel_rel.and_(KernelRow.cluster_role == DEFAULT_ROLE)
eager_loading_op.extend([
_eager_loading_op.extend([
noload("*"),
selectinload(kernel_rel).options(
noload("*"),
Expand All @@ -961,7 +962,7 @@ async def get_session(
access_key,
allow_stale=allow_stale,
for_update=for_update,
eager_loading_op=eager_loading_op,
eager_loading_op=_eager_loading_op,
)
if not session_list:
raise SessionNotFound(f"Session (id={session_name_or_id}) does not exist.")
Expand All @@ -988,11 +989,12 @@ async def list_sessions(
allow_stale: bool = False,
for_update: bool = False,
kernel_loading_strategy=KernelLoadingStrategy.NONE,
eager_loading_op: list[Any] = [],
eager_loading_op: list[Any] | None = None,
) -> Iterable[SessionRow]:
_eager_loading_op = eager_loading_op or []
match kernel_loading_strategy:
case KernelLoadingStrategy.ALL_KERNELS:
eager_loading_op.extend([
_eager_loading_op.extend([
noload("*"),
selectinload(SessionRow.kernels).options(
noload("*"),
Expand All @@ -1002,7 +1004,7 @@ async def list_sessions(
case KernelLoadingStrategy.MAIN_KERNEL_ONLY:
kernel_rel = SessionRow.kernels
kernel_rel.and_(KernelRow.cluster_role == DEFAULT_ROLE)
eager_loading_op.extend([
_eager_loading_op.extend([
noload("*"),
selectinload(kernel_rel).options(
noload("*"),
Expand All @@ -1016,7 +1018,7 @@ async def list_sessions(
access_key,
allow_stale=allow_stale,
for_update=for_update,
eager_loading_op=eager_loading_op,
eager_loading_op=_eager_loading_op,
)
try:
return session_list
Expand Down

0 comments on commit 2d39a0a

Please sign in to comment.