-
Notifications
You must be signed in to change notification settings - Fork 18
handle signals properly #589
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,37 +1,112 @@ | ||
| import logging | ||
| import os | ||
| from multiprocessing.sharedctypes import Value | ||
| import signal | ||
| import sys | ||
|
|
||
| import anyio | ||
| import click | ||
| from anyio import create_task_group, open_signal_receiver | ||
| from jumpstarter_cli_common.config import opt_config | ||
| from jumpstarter_cli_common.exceptions import handle_exceptions, leaf_exceptions | ||
| from jumpstarter_cli_common.exceptions import handle_exceptions | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| def _serve_with_exc_handling(exporter): | ||
|
|
||
| def _handle_child(config): | ||
| """Handle child process with graceful shutdown.""" | ||
| async def serve_with_graceful_shutdown(): | ||
| received_signal = 0 | ||
| signal_handled = False | ||
| exporter = None | ||
|
|
||
| async def signal_handler(): | ||
| nonlocal received_signal, signal_handled | ||
|
|
||
| with open_signal_receiver(signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT) as signals: | ||
| async for sig in signals: | ||
| if signal_handled: | ||
| continue # Ignore duplicate signals | ||
| received_signal = sig | ||
| logger.info("CHILD: Received %d (%s)", received_signal, signal.Signals(received_signal).name) | ||
| if exporter: | ||
| # Terminate exporter. SIGHUP waits until current lease is let go. Later SIGTERM still overrides | ||
| if received_signal != signal.SIGHUP: | ||
| signal_handled = True | ||
| exporter.stop(wait_for_lease_exit=received_signal == signal.SIGHUP) | ||
|
|
||
|
michalskrivanek marked this conversation as resolved.
|
||
| # Start signal handler first, then create exporter | ||
| async with create_task_group() as signal_tg: | ||
|
|
||
| # Start signal handler immediately | ||
| signal_tg.start_soon(signal_handler) | ||
|
|
||
| # Create exporter and run it | ||
| async with config.create_exporter() as exporter: | ||
| try: | ||
| await exporter.serve() | ||
| except* Exception as excgroup: | ||
|
michalskrivanek marked this conversation as resolved.
|
||
| from jumpstarter_cli_common.exceptions import leaf_exceptions | ||
| for exc in leaf_exceptions(excgroup): | ||
| if not isinstance(exc, anyio.get_cancelled_exc_class()): | ||
| click.echo( | ||
| f"Exception while serving on the exporter: {type(exc).__name__}: {exc}", | ||
| err=True, | ||
| ) | ||
|
|
||
| # Cancel the signal handler after exporter completes | ||
| signal_tg.cancel_scope.cancel() | ||
|
|
||
| # Return signal number if received, otherwise 0 for immediate restart | ||
| return received_signal if received_signal else 0 | ||
|
|
||
|
michalskrivanek marked this conversation as resolved.
|
||
| sys.exit(anyio.run(serve_with_graceful_shutdown)) | ||
|
|
||
|
|
||
| def _handle_parent(pid): | ||
| """Handle parent process waiting for child and signal forwarding.""" | ||
| def parent_signal_handler(signum, _): | ||
| logger.info("PARENT: Received %d (%s), forwarding to child PID %d", signum, signal.Signals(signum).name, pid) | ||
| if pid and pid > 0: | ||
| try: | ||
| os.kill(pid, signum) | ||
| except ProcessLookupError: | ||
| pass | ||
|
|
||
| # Set up signal handlers after fork | ||
| for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT): | ||
| signal.signal(sig, parent_signal_handler) | ||
|
|
||
| _, status = os.waitpid(pid, 0) | ||
| if os.WIFEXITED(status): | ||
| # Interpret child exit code | ||
| child_exit_code = os.WEXITSTATUS(status) | ||
| if child_exit_code == 0: | ||
| return None # restart child (unexpected exit/exception) | ||
| else: | ||
| # Child indicates termination (signal number) | ||
| return 128 + child_exit_code # Return standard Unix exit code | ||
| else: | ||
| # Child killed by unhandled signal - terminate | ||
| child_exit_signal = os.WTERMSIG(status) if os.WIFSIGNALED(status) else 0 | ||
| click.echo(f"Child killed by unhandled signal: {child_exit_signal}", err=True) | ||
| return 128 + child_exit_signal | ||
|
|
||
|
|
||
| def _serve_with_exc_handling(config): | ||
| while True: | ||
| result = Value("i", 0) | ||
| pid = os.fork() | ||
|
|
||
| if pid > 0: | ||
| os.waitpid(pid, 0) | ||
| if result.value != 0: | ||
| return result.value | ||
| if (exit_code := _handle_parent(pid)) is not None: | ||
| return exit_code | ||
| else: | ||
| try: | ||
| anyio.run(exporter.serve) | ||
| except* Exception as excgroup: | ||
| for exc in leaf_exceptions(excgroup): | ||
| click.echo( | ||
| f"Exception while serving on the exporter: {type(exc).__name__}: {exc}", | ||
| err=True, | ||
| ) | ||
| result.value = 1 | ||
| return | ||
| _handle_child(config) | ||
| sys.exit(1) # should never happen | ||
|
|
||
|
|
||
| @click.command("run") | ||
| @opt_config(client=False) | ||
| @handle_exceptions | ||
| def run(config): | ||
| """Run an exporter locally.""" | ||
|
|
||
| return _serve_with_exc_handling(config) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -156,26 +156,44 @@ def serve_unix(self): | |
| with portal.wrap_async_context_manager(self.serve_unix_async()) as path: | ||
| yield path | ||
|
|
||
| async def serve(self): | ||
| @asynccontextmanager | ||
| async def create_exporter(self): | ||
| """Create and manage an exporter instance with proper lifecycle.""" | ||
| # dynamic import to avoid circular imports | ||
| from anyio import CancelScope | ||
|
|
||
| from jumpstarter.exporter import Exporter | ||
|
|
||
| async def channel_factory(): | ||
| if self.endpoint is None or self.token is None: | ||
| raise ConfigurationError("endpoint or token not set in exporter config") | ||
|
|
||
| credentials = grpc.composite_channel_credentials( | ||
| await ssl_channel_credentials(self.endpoint, self.tls), | ||
| call_credentials("Exporter", self.metadata, self.token), | ||
| ) | ||
| return aio_secure_channel(self.endpoint, credentials, self.grpcOptions) | ||
|
|
||
| async with Exporter( | ||
| channel_factory=channel_factory, | ||
| device_factory=ExporterConfigV1Alpha1DriverInstance(children=self.export).instantiate, | ||
| tls=self.tls, | ||
| grpc_options=self.grpcOptions, | ||
| ) as exporter: | ||
| exporter = None | ||
| entered = False | ||
| try: | ||
| exporter = Exporter( | ||
| channel_factory=channel_factory, | ||
| device_factory=ExporterConfigV1Alpha1DriverInstance(children=self.export).instantiate, | ||
| tls=self.tls, | ||
| grpc_options=self.grpcOptions, | ||
| ) | ||
| # Initialize the exporter (registration, etc.) | ||
| await exporter.__aenter__() | ||
| entered = True | ||
| yield exporter | ||
| finally: | ||
| # Shield all cleanup operations from abrupt cancellation for clean shutdown | ||
| if exporter and entered: | ||
| with CancelScope(shield=True): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| await exporter.__aexit__(None, None, None) | ||
|
|
||
| async def serve(self): | ||
| async with self.create_exporter() as exporter: | ||
| await exporter.serve() | ||
|
|
||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,11 @@ | ||
| import logging | ||
| from collections.abc import Callable | ||
| from collections.abc import Awaitable, Callable | ||
| from contextlib import AbstractAsyncContextManager, asynccontextmanager | ||
| from dataclasses import dataclass, field | ||
|
|
||
| import grpc | ||
| from anyio import connect_unix, create_memory_object_stream, create_task_group, sleep | ||
| from anyio.abc import TaskGroup | ||
| from google.protobuf import empty_pb2 | ||
| from jumpstarter_protocol import ( | ||
| jumpstarter_pb2, | ||
|
|
@@ -22,22 +23,57 @@ | |
|
|
||
| @dataclass(kw_only=True) | ||
| class Exporter(AbstractAsyncContextManager, Metadata): | ||
| channel_factory: Callable[[], grpc.aio.Channel] | ||
| channel_factory: Callable[[], Awaitable[grpc.aio.Channel]] | ||
| device_factory: Callable[[], Driver] | ||
| lease_name: str = field(init=False, default="") | ||
| tls: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1) | ||
| grpc_options: dict[str, str] = field(default_factory=dict) | ||
| registered: bool = field(init=False, default=False) | ||
| _stop_requested: bool = field(init=False, default=False) | ||
| _started: bool = field(init=False, default=False) | ||
| _tg: TaskGroup | None = field(init=False, default=None) | ||
|
|
||
| def stop(self, wait_for_lease_exit=False): | ||
| """Signal the exporter to stop. | ||
|
|
||
| Args: | ||
| wait_for_lease_exit (bool): If True, wait for the current lease to exit before stopping. | ||
| """ | ||
|
|
||
| # Stop immediately if not started yet or if immediate stop is requested | ||
| if (not self._started or not wait_for_lease_exit) and self._tg is not None: | ||
| logger.info("Stopping exporter immediately") | ||
| self._tg.cancel_scope.cancel() | ||
| elif not self._stop_requested: | ||
| self._stop_requested = True | ||
| logger.info("Exporter marked for stop upon lease exit") | ||
|
|
||
| async def __aexit__(self, exc_type, exc_value, traceback): | ||
| if self.registered: | ||
| controller = jumpstarter_pb2_grpc.ControllerServiceStub(await self.channel_factory()) | ||
| logger.info("Unregistering exporter with controller") | ||
| await controller.Unregister( | ||
| jumpstarter_pb2.UnregisterRequest( | ||
| reason="TODO", | ||
| ) | ||
| ) | ||
| import anyio | ||
|
|
||
| try: | ||
| if self.registered: | ||
| logger.info("Unregistering exporter with controller") | ||
| try: | ||
| with anyio.move_on_after(10): # 10 second timeout | ||
| channel = await self.channel_factory() | ||
| try: | ||
| controller = jumpstarter_pb2_grpc.ControllerServiceStub(channel) | ||
| await controller.Unregister( | ||
| jumpstarter_pb2.UnregisterRequest( | ||
| reason="Exporter shutdown", | ||
| ) | ||
| ) | ||
| logger.info("Controller unregistration completed successfully") | ||
| finally: | ||
| with anyio.CancelScope(shield=True): | ||
| await channel.close() | ||
| except Exception as e: | ||
| logger.error("Error during controller unregistration: %s", e, exc_info=True) | ||
|
|
||
| except Exception as e: | ||
| logger.error("Error during exporter cleanup: %s", e, exc_info=True) | ||
| # Don't re-raise to avoid masking the original exception | ||
|
|
||
| async def __handle(self, path, endpoint, token, tls_config, grpc_options): | ||
| try: | ||
|
|
@@ -106,10 +142,12 @@ async def listen(retries=5, backoff=3): | |
| ) | ||
|
|
||
| async def serve(self): # noqa: C901 | ||
| """ | ||
| Serve the exporter. | ||
| """ | ||
| # initial registration | ||
| async with self.session(): | ||
| pass | ||
| started = False | ||
| status_tx, status_rx = create_memory_object_stream() | ||
|
|
||
| async def status(retries=5, backoff=3): | ||
|
|
@@ -134,18 +172,23 @@ async def status(retries=5, backoff=3): | |
| retries_left = retries | ||
|
|
||
| async with create_task_group() as tg: | ||
| self._tg = tg | ||
| tg.start_soon(status) | ||
| async for status in status_rx: | ||
| if self.lease_name != "" and self.lease_name != status.lease_name: | ||
| self.lease_name = status.lease_name | ||
| logger.info("Lease status changed, killing existing connections") | ||
| tg.cancel_scope.cancel() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not necessarily wrong, but asking: why do we stop canceling this scope here?
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's probably cancelled when you exit the `with` block...
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yep, that's it.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The intent is to avoid triggering an "abrupt" cancellation, letting it exit the await (https://github.com/jumpstarter-dev/jumpstarter/pull/589/files#diff-a0c03b6fd1ccd1545c4f26d1e232edb53e3b775a456a96fd78151689fcc8dc5fR43, then in turn line 57) cleanly.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bah... but then the other tasks in the exporter's tg (status, handle) don't really finish, so it does indeed need to be cancelled.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I've put it back. |
||
| self.stop() | ||
| break | ||
| self.lease_name = status.lease_name | ||
| if not started and self.lease_name != "": | ||
| started = True | ||
| if not self._started and self.lease_name != "": | ||
| self._started = True | ||
| tg.start_soon(self.handle, self.lease_name, tg) | ||
| if status.leased: | ||
| logger.info("Currently leased by %s under %s", status.client_name, status.lease_name) | ||
| else: | ||
| logger.info("Currently not leased") | ||
| if self._stop_requested: | ||
| self.stop() | ||
| break | ||
| self._tg = None | ||
Uh oh!
There was an error while loading. Please reload this page.