Skip to content

Commit

Permalink
fix: Allow admins to restart other's session (#1635)
Browse files Browse the repository at this point in the history
Backported-from: main
Backported-to: 23.09
  • Loading branch information
rapsealk authored and achimnol committed Oct 23, 2023
1 parent 107eb6d commit e07a4b6
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
1 change: 1 addition & 0 deletions changes/1635.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow admins to restart other's session by setting an optional parameter `owner_access_key`.
11 changes: 9 additions & 2 deletions src/ai/backend/client/cli/session/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,14 @@ def destroy(session_names, forced, owner, stats, recursive):

def _restart_cmd(docs: str = None):
@click.argument("session_refs", metavar="SESSION_REFS", nargs=-1)
def restart(session_refs):
@click.option(
"-o",
"--owner",
"--owner-access-key",
metavar="ACCESS_KEY",
help="Specify the owner of the target session explicitly.",
)
def restart(session_refs, owner):
"""
Restart the compute session.
Expand All @@ -559,7 +566,7 @@ def restart(session_refs):
has_failure = False
for session_ref in session_refs:
try:
compute_session = session.ComputeSession(session_ref)
compute_session = session.ComputeSession(session_ref, owner)
compute_session.restart()
except BackendAPIError as e:
print_error(e)
Expand Down
11 changes: 9 additions & 2 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,10 +1348,17 @@ async def get_info(request: web.Request) -> web.Response:

@server_status_required(READ_ALLOWED)
@auth_required
async def restart(request: web.Request) -> web.Response:
@check_api_params(
t.Dict(
{
t.Key("owner_access_key", default=None): t.Null | t.String,
}
)
)
async def restart(request: web.Request, params: Any) -> web.Response:
root_ctx: RootContext = request.app["_root.context"]
session_name = request.match_info["session_name"]
requester_access_key, owner_access_key = await get_access_key_scopes(request)
requester_access_key, owner_access_key = await get_access_key_scopes(request, params)
log.info("RESTART (ak:{0}/{1}, s:{2})", requester_access_key, owner_access_key, session_name)
async with root_ctx.db.begin_session() as db_sess:
session = await SessionRow.get_session(
Expand Down

0 comments on commit e07a4b6

Please sign in to comment.