Skip to content

Commit

Permalink
feat: Fetch container logs of a specific kernel (#2364)
Browse files Browse the repository at this point in the history
Co-authored-by: Sanghun Lee <sanghun@lablup.com>
  • Loading branch information
lizable and fregataa committed Jul 5, 2024
1 parent 4cd4204 commit f16b76e
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 26 deletions.
1 change: 1 addition & 0 deletions changes/2364.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for fetching container logs of a specific kernel.
15 changes: 12 additions & 3 deletions src/ai/backend/client/cli/session/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,18 +748,27 @@ def ls(session_id, path):

@session.command()
@click.argument("session_id", metavar="SESSID")
def logs(session_id):
@click.option(
"-k",
"--kernel",
"--kernel-id",
type=str,
default=None,
help="The target kernel id of logs. Default value is None, in which case logs of a main kernel are fetched.",
)
def logs(session_id, kernel: str | None):
"""
Shows the full console log of a compute session.
\b
SESSID: Session ID or its alias given when creating the session.
"""
_kernel_id = uuid.UUID(kernel) if kernel is not None else None
with Session() as session:
try:
print_wait("Retrieving live container logs...")
kernel = session.ComputeSession(session_id)
result = kernel.get_logs().get("result")
_session = session.ComputeSession(session_id)
result = _session.get_logs(_kernel_id).get("result")
logs = result.get("logs") if "logs" in result else ""
print(logs)
print_done("End of logs.")
Expand Down
4 changes: 3 additions & 1 deletion src/ai/backend/client/func/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,13 +699,15 @@ async def get_info(self):
return await resp.json()

@api_function
async def get_logs(self):
async def get_logs(self, kernel_id: UUID | None = None):
"""
Retrieves the console log of the compute session container.
"""
params = {}
if self.owner_access_key:
params["owner_access_key"] = self.owner_access_key
if kernel_id is not None:
params["kernel_id"] = str(kernel_id)
prefix = get_naming(api_session.get().api_version, "path")
rqst = Request(
"GET",
Expand Down
4 changes: 4 additions & 0 deletions src/ai/backend/manager/api/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ class MainKernelNotFound(ObjectNotFound):
object_name = "main kernel"


class KernelNotFound(ObjectNotFound):
object_name = "kernel"


class EndpointNotFound(ObjectNotFound):
object_name = "endpoint"

Expand Down
69 changes: 50 additions & 19 deletions src/ai/backend/manager/api/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
import trafaret as t
from aiohttp import hdrs, web
from dateutil.tz import tzutc
from pydantic import BaseModel, Field
from pydantic import AliasChoices, BaseModel, Field
from redis.asyncio import Redis
from sqlalchemy.orm import noload, selectinload
from sqlalchemy.sql.expression import null, true
Expand Down Expand Up @@ -73,6 +73,7 @@
AgentId,
ClusterMode,
ImageRegistry,
KernelId,
MountPermission,
MountTypes,
SessionTypes,
Expand Down Expand Up @@ -2109,19 +2110,36 @@ async def list_files(request: web.Request) -> web.Response:
return web.json_response(resp, status=200)


class ContainerLogRequestModel(BaseModel):
owner_access_key: str | None = Field(
validation_alias=AliasChoices("owner_access_key", "ownerAccessKey"),
default=None,
)
kernel_id: uuid.UUID | None = Field(
validation_alias=AliasChoices("kernel_id", "kernelId"),
description="Target kernel to get container logs.",
default=None,
)


@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(
t.Dict({
t.Key("owner_access_key", default=None): t.Null | t.String,
})
)
async def get_container_logs(request: web.Request, params: Any) -> web.Response:
@pydantic_params_api_handler(ContainerLogRequestModel)
async def get_container_logs(
request: web.Request, params: ContainerLogRequestModel
) -> web.Response:
root_ctx: RootContext = request.app["_root.context"]
session_name: str = request.match_info["session_name"]
requester_access_key, owner_access_key = await get_access_key_scopes(request, params)
requester_access_key, owner_access_key = await get_access_key_scopes(
request, {"owner_access_key": params.owner_access_key}
)
kernel_id = KernelId(params.kernel_id) if params.kernel_id is not None else None
log.info(
"GET_CONTAINER_LOG (ak:{}/{}, s:{})", requester_access_key, owner_access_key, session_name
"GET_CONTAINER_LOG (ak:{}/{}, s:{}, k:{})",
requester_access_key,
owner_access_key,
session_name,
kernel_id,
)
resp = {"result": {"logs": ""}}
async with root_ctx.db.begin_readonly_session() as db_sess:
Expand All @@ -2130,25 +2148,38 @@ async def get_container_logs(request: web.Request, params: Any) -> web.Response:
session_name,
owner_access_key,
allow_stale=True,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY,
kernel_loading_strategy=KernelLoadingStrategy.MAIN_KERNEL_ONLY
if kernel_id is None
else KernelLoadingStrategy.ALL_KERNELS,
)
if (
compute_session.status in DEAD_SESSION_STATUSES
and compute_session.main_kernel.container_log is not None
):
log.debug("returning log from database record")
resp["result"]["logs"] = compute_session.main_kernel.container_log.decode("utf-8")
return web.json_response(resp, status=200)

if compute_session.status in DEAD_SESSION_STATUSES:
if kernel_id is not None:
# Get logs from the specific kernel
kernel_row = compute_session.get_kernel_by_id(kernel_id)
kernel_log = kernel_row.container_log
else:
# Get logs from the main kernel
kernel_log = compute_session.main_kernel.container_log
if kernel_log is not None:
# Get logs from database record
log.debug("returning log from database record")
resp["result"]["logs"] = kernel_log.decode("utf-8")
return web.json_response(resp, status=200)

try:
registry = root_ctx.registry
await registry.increment_session_usage(compute_session)
resp["result"]["logs"] = await registry.get_logs_from_agent(compute_session)
resp["result"]["logs"] = await registry.get_logs_from_agent(
session=compute_session, kernel_id=kernel_id
)
log.debug("returning log from agent")
except BackendError:
log.exception(
"GET_CONTAINER_LOG(ak:{}/{}, s:{}): unexpected error",
"GET_CONTAINER_LOG(ak:{}/{}, kernel_id: {}, s:{}): unexpected error",
requester_access_key,
owner_access_key,
kernel_id,
session_name,
)
raise
Expand Down
10 changes: 10 additions & 0 deletions src/ai/backend/manager/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
KernelCreationFailed,
KernelDestructionFailed,
KernelExecutionFailed,
KernelNotFound,
KernelRestartFailed,
MainKernelNotFound,
SessionNotFound,
Expand Down Expand Up @@ -79,6 +80,7 @@

from .gql import GraphQueryContext

log = BraceStyleAdapter(logging.getLogger(__spec__.name)) # type: ignore[name-defined]

__all__ = (
"determine_session_status",
Expand Down Expand Up @@ -730,6 +732,14 @@ def resource_opts(self) -> dict[str, Any]:
def is_private(self) -> bool:
return any([kernel.is_private for kernel in self.kernels])

def get_kernel_by_id(self, kernel_id: KernelId) -> KernelRow:
kerns = tuple(kern for kern in self.kernels if kern.id == kernel_id)
if len(kerns) > 1:
raise TooManyKernelsFound(f"Multiple kernels found (id:{kernel_id}).")
if len(kerns) == 0:
raise KernelNotFound(f"Session has no such kernel (sid:{self.id}, kid:{kernel_id}))")
return kerns[0]

def get_kernel_by_cluster_name(self, cluster_name: str) -> KernelRow:
kerns = tuple(kern for kern in self.kernels if kern.cluster_name == cluster_name)
if len(kerns) > 1:
Expand Down
12 changes: 9 additions & 3 deletions src/ai/backend/manager/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2810,14 +2810,20 @@ async def list_files(
async def get_logs_from_agent(
self,
session: SessionRow,
kernel_id: KernelId | None = None,
) -> str:
async with handle_session_exception(self.db, "get_logs_from_agent", session.id):
kernel = (
session.get_kernel_by_id(kernel_id)
if kernel_id is not None
else session.main_kernel
)
async with self.agent_cache.rpc_context(
session.main_kernel.agent,
agent_id=kernel.agent,
invoke_timeout=30,
order_key=session.main_kernel.id,
order_key=kernel.id,
) as rpc:
reply = await rpc.call.get_logs(str(session.main_kernel.id))
reply = await rpc.call.get_logs(str(kernel.id))
return reply["logs"]

async def increment_session_usage(
Expand Down

0 comments on commit f16b76e

Please sign in to comment.