Skip to content

Commit

Permalink
fix: Improve error handling of initialization failure in kernel runner (
Browse files Browse the repository at this point in the history
  • Loading branch information
achimnol authored Jul 16, 2024
1 parent 735f737 commit 395e6b2
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 11 deletions.
1 change: 1 addition & 0 deletions changes/2478.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve error handling of initialization failures in the kernel runner
3 changes: 0 additions & 3 deletions src/ai/backend/kernel/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
import sys
from pathlib import Path

import uvloop

from . import lang_map
from .compat import asyncio_run_forever

Expand Down Expand Up @@ -61,5 +59,4 @@ def main(args) -> None:


args = parse_args()
uvloop.install()
main(args)
53 changes: 46 additions & 7 deletions src/ai/backend/kernel/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
MutableMapping,
Optional,
Sequence,
TypeVar,
Union,
)

Expand All @@ -44,9 +45,20 @@
from .jupyter_client import aexecute_interactive
from .logging import BraceStyleAdapter, setup_logger
from .service import ServiceParser
from .utils import scan_proc_stats, wait_local_port_open
from .utils import TracebackSourceFilter, scan_proc_stats, wait_local_port_open

log = BraceStyleAdapter(logging.getLogger())
logger = logging.getLogger()
logger.addFilter(TracebackSourceFilter(str(Path(__file__).parent)))
log = BraceStyleAdapter(logger)

TReturn = TypeVar("TReturn")


# Single-member enum used as a typed sentinel: because the marker has its own
# type, a return annotation such as `TReturn | FailureSentinel` can distinguish
# "the call failed" from any ordinary return value.
class FailureSentinel(enum.Enum):
TOKEN = 0


# Module-level alias for the sole sentinel member. `_handle_exception` returns it
# after logging an exception, and callers compare against it with `is`.
FAILURE = FailureSentinel.TOKEN


class HealthStatus(enum.Enum):
Expand Down Expand Up @@ -325,15 +337,42 @@ async def _shutdown_jupyter_kernel(self):
self.kernel_client.stop_channels()
assert not await self.kernel_mgr.is_alive(), "ipykernel failed to shutdown"

async def _handle_exception(
self, coro: Awaitable[TReturn], help_text: str | None = None
) -> TReturn | FailureSentinel:
try:
return await coro
except Exception as e:
match e:
case FileNotFoundError():
msg = "File not found: {!r}"
if help_text:
msg += f" ({help_text})"
log.exception(msg, e.filename)
case _:
msg = "Unexpected error!"
if help_text:
msg += f" ({help_text})"
log.exception(msg)
return FAILURE

# Run the runtime-specific initialization and the sshd service initialization,
# then set `self.init_done` (when it is not None) regardless of the outcome.
# NOTE(review): this span is a rendered diff hunk without +/- markers. The direct
# awaits guarded by the blanket `except Exception` look like the removed lines and
# the `_handle_exception(...)` calls look like their replacements — confirm
# against the actual file before relying on the exact statement order.
async def _init_with_loop(self) -> None:
if self.init_done is not None:
# Clear the completion flag before (re-)running initialization.
self.init_done.clear()
try:
await self.init_with_loop()
await init_sshd_service(self.child_env)
except Exception:
log.exception("Unexpected error!")
log.warning("We are skipping the error but the container may not work as expected.")
# `_handle_exception` logs any exception and returns FAILURE instead of raising,
# so a failed runtime-specific step does not prevent the sshd step from running.
ret = await self._handle_exception(
self.init_with_loop(),
"Check the image configs/labels like `ai.backend.runtime-path`",
)
if ret is FAILURE:
log.warning(
"We are skipping the runtime-specific initialization failure, "
"and the container may not work as expected."
)
# The return value is not inspected here; a failure is only logged.
await self._handle_exception(
init_sshd_service(self.child_env),
"Verify agent installation with the embedded prebuilt binaries",
)
finally:
# Always signal completion, even when initialization raised.
if self.init_done is not None:
self.init_done.set()
Expand Down
2 changes: 1 addition & 1 deletion src/ai/backend/kernel/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def setup_logger(log_queue, log_prefix, debug):
loghandlers.append(LogQHandler(log_queue))
logging.basicConfig(
level=logging.DEBUG if debug else logging.INFO,
format=log_prefix + ": {message}",
format=log_prefix + ": [{levelname}] {message}",
style="{",
handlers=loghandlers,
)
24 changes: 24 additions & 0 deletions src/ai/backend/kernel/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import asyncio
import logging
import os
import site
import traceback
from pathlib import Path
from typing import Final

Expand Down Expand Up @@ -34,6 +37,27 @@ def find_executable(*paths):
return None


class TracebackSourceFilter(logging.Filter):
    """Replace a record's exception traceback with a condensed frame summary.

    Only frames whose filename starts with ``path_prefix`` are listed, and
    frames of functions named ``_handle_exception`` are omitted. The exception
    type and message are appended after the frames so that the condensed text
    still says what was raised.
    """

    def __init__(self, path_prefix: str) -> None:
        """Create the filter.

        Args:
            path_prefix: Filename prefix of the frames to keep in the summary.
        """
        super().__init__()
        self.path_prefix = path_prefix
        # Guard against a `site` module without getsitepackages() or an empty
        # result; with an empty prefix the paths are simply not shortened.
        getter = getattr(site, "getsitepackages", None)
        site_dirs = getter() if getter is not None else []
        self.site_prefix = site_dirs[0] if site_dirs else ""

    def filter(self, record: logging.LogRecord) -> bool:
        """Attach the condensed traceback to *record*; never drops a record.

        Returns:
            Always ``True``.
        """
        if record.exc_info:
            exc_type, exc_value, exc_tb = record.exc_info
            lines = [" Traceback:"]
            for tb in traceback.extract_tb(exc_tb):
                if not tb.filename.startswith(self.path_prefix):
                    continue
                if tb.name == "_handle_exception":
                    continue
                short_path = tb.filename.removeprefix(self.site_prefix).removeprefix("/")
                lines.append(f" {short_path} (L{tb.lineno}): {tb.name}()")
            if exc_type is not None:
                # A pre-set exc_text makes the formatter skip its own exception
                # rendering, so the "Type: message" line has to be added here
                # or it would be lost from the log output.
                for summary in traceback.format_exception_only(exc_type, exc_value):
                    lines.append(f" {summary.rstrip()}")
            record.exc_text = "\n".join(lines)
        return True


async def safe_close_task(task):
if task is not None and not task.done():
task.cancel()
Expand Down

0 comments on commit 395e6b2

Please sign in to comment.