fix: timeout on compute_session_list GQL Query (#2084)
This PR fixes the `compute_session_list` GQL query raising a timeout when it is queried with the `commit_status` field and a large enough number of kernels resides in the query scope. In the current implementation, every single `commit_status` resolver performs its own database query and Redis request, resulting in a burst of requests to both Redis and PostgreSQL. To remove these bottlenecks, this PR refactors both `AgentRegistry.get_commit_status()` and `ComputeSession.resolve_commit_status()` so that they can benefit from Redis' pipeline feature.
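The batching idea, sketched outside the Backend.AI codebase: this is a minimal example using plain `redis.asyncio` instead of the project's `redis_helper` wrapper, and the `get_commit_statuses` helper, connection URL, and `READY` constant are illustrative only; the `kernel.<id>.commit` key scheme follows the diff below.

```python
import asyncio

from redis.asyncio import Redis

READY = "ready"  # stand-in for CommitStatus.READY.value


async def get_commit_statuses(r: Redis, kernel_ids: list[str]) -> dict[str, str]:
    # Queue one GET per kernel and flush them to Redis in a single round trip
    # instead of issuing N independent requests.
    pipe = r.pipeline(transaction=False)
    for kernel_id in kernel_ids:
        pipe.get(f"kernel.{kernel_id}.commit")
    results = await pipe.execute()
    # A missing key means no commit is in progress for that kernel.
    return {
        kernel_id: (raw.decode("utf-8") if raw is not None else READY)
        for kernel_id, raw in zip(kernel_ids, results)
    }


async def main() -> None:
    async with Redis.from_url("redis://localhost:6379") as r:
        print(await get_commit_statuses(r, ["kernel-a", "kernel-b"]))


if __name__ == "__main__":
    asyncio.run(main())
```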

**Checklist:** (if applicable)

- [x] Milestone metadata specifying the target backport version

<!-- readthedocs-preview sorna start -->
----
📚 Documentation preview 📚: https://sorna--2084.org.readthedocs.build/en/2084/

<!-- readthedocs-preview sorna end -->

<!-- readthedocs-preview sorna-ko start -->
----
📚 Documentation preview 📚: https://sorna-ko--2084.org.readthedocs.build/ko/2084/

<!-- readthedocs-preview sorna-ko end -->
kyujin-cho committed Apr 29, 2024
1 parent b69fa68 commit d41aab0
Showing 4 changed files with 28 additions and 20 deletions.
1 change: 1 addition & 0 deletions changes/2084.fix.md
@@ -0,0 +1 @@
Fix `compute_session_list` GQL query not responding when there are an abundant number of sessions
4 changes: 2 additions & 2 deletions src/ai/backend/manager/api/session.py
@@ -904,11 +904,11 @@ async def get_commit_status(request: web.Request, params: Mapping[str, Any]) ->
owner_access_key,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
status_info = await root_ctx.registry.get_commit_status(session)
statuses = await root_ctx.registry.get_commit_status([session.main_kernel.id])
except BackendError:
log.exception("GET_COMMIT_STATUS: exception")
raise
resp = {"status": status_info["status"], "kernel": status_info["kernel"]}
resp = {"status": statuses[session.main_kernel.id], "kernel": str(session.main_kernel.id)}
return web.json_response(resp, status=200)


21 changes: 13 additions & 8 deletions src/ai/backend/manager/models/session.py
@@ -1332,14 +1332,10 @@ async def resolve_dependencies(

async def resolve_commit_status(self, info: graphene.ResolveInfo) -> str:
graph_ctx: GraphQueryContext = info.context
async with graph_ctx.db.begin_readonly_session() as db_sess:
session: SessionRow = await SessionRow.get_session(
db_sess,
self.id,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
)
commit_status = await graph_ctx.registry.get_commit_status(session)
return commit_status["status"]
loader = graph_ctx.dataloader_manager.get_loader(
graph_ctx, "ComputeSession.commit_statuses"
)
return await loader.load(self.main_kernel_id)

async def resolve_resource_opts(self, info: graphene.ResolveInfo) -> dict[str, Any]:
containers = self.containers
@@ -1584,6 +1580,15 @@ async def batch_load_by_dependency(
lambda row: row.SessionRow.id,
)

@classmethod
async def batch_load_commit_statuses(
cls,
ctx: GraphQueryContext,
kernel_ids: Sequence[KernelId],
) -> Sequence[str]:
commit_statuses = await ctx.registry.get_commit_status(kernel_ids)
return [commit_statuses[kernel_id] for kernel_id in kernel_ids]


class ComputeSessionList(graphene.ObjectType):
class Meta:
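On the GraphQL side, `resolve_commit_status` above now goes through a per-query DataLoader, so the many per-session resolver calls in one `compute_session_list` response collapse into a single batched lookup. A minimal sketch of that pattern using the standalone `aiodataloader` package (the `fetch_commit_statuses` function and the fake data are illustrative; Backend.AI uses its own `DataLoaderManager` rather than this package directly):

```python
import asyncio

from aiodataloader import DataLoader

# Pretend backing store: kernel_id -> commit status.
FAKE_STATUSES = {"k1": "ongoing", "k2": "ready", "k3": "ready"}


async def fetch_commit_statuses(kernel_ids: list[str]) -> list[str]:
    # Called once per batch with all keys collected in the same event-loop tick;
    # must return results in the same order as the requested keys.
    print(f"batched fetch for {kernel_ids}")
    return [FAKE_STATUSES.get(k, "ready") for k in kernel_ids]


async def main() -> None:
    loader = DataLoader(batch_load_fn=fetch_commit_statuses)
    # Three resolver-style loads scheduled concurrently trigger one backend call.
    statuses = await asyncio.gather(
        loader.load("k1"), loader.load("k2"), loader.load("k3")
    )
    print(statuses)  # ['ongoing', 'ready', 'ready']


if __name__ == "__main__":
    asyncio.run(main())
```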
22 changes: 12 additions & 10 deletions src/ai/backend/manager/registry.py
@@ -3166,17 +3166,19 @@ async def _get_user_email(

async def get_commit_status(
self,
session: SessionRow,
) -> Mapping[str, str]:
kern_id = str(session.main_kernel.id)
key = f"kernel.{kern_id}.commit"
result: Optional[bytes] = await redis_helper.execute(
self.redis_stat,
lambda r: r.get(key),
)
kernel_ids: Sequence[KernelId],
) -> Mapping[KernelId, str]:
async def _pipe_builder(r: Redis):
pipe = r.pipeline()
for kernel_id in kernel_ids:
await pipe.get(f"kernel.{kernel_id}.commit")
return pipe

commit_statuses = await redis_helper.execute(self.redis_stat, _pipe_builder)

return {
"kernel": kern_id,
"status": str(result, "utf-8") if result is not None else CommitStatus.READY.value,
kernel_id: str(result, "utf-8") if result is not None else CommitStatus.READY.value
for kernel_id, result in zip(kernel_ids, commit_statuses)
}

async def commit_session(
