-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Terminate model server quicker on receiving term signal (#404)
- Loading branch information
1 parent
4723acb
commit ead203d
Showing
7 changed files
with
269 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
64 changes: 64 additions & 0 deletions
64
truss/templates/server/common/termination_handler_middleware.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import asyncio | ||
import signal | ||
from typing import Callable | ||
|
||
from fastapi import Request | ||
|
||
# Grace period between the last in-flight request leaving this middleware and
# the on_term callback firing. It lets the last request's response finish
# handling: there may be more middlewares that the response goes through, and
# then there's the time for the bytes to be pushed to the caller.
DEFAULT_TERM_DELAY_SECS = 5.0


class TerminationHandlerMiddleware:
    """
    This middleware allows for swiftly and safely terminating the server.

    On construction it installs itself as the handler for SIGINT, SIGTERM and
    SIGQUIT. On receiving such a signal, it first calls the on_stop callback,
    then waits for currently executing requests to finish, before calling the
    on_term callback.

    Stop means that the process of stopping the server has started. As soon as
    outstanding requests go to zero after this, on_term will be called (after
    termination_delay_secs). If there are no outstanding requests when the
    signal arrives, on_term is called immediately from the signal handler.

    Term means that this is the right time to terminate the server process;
    there are no outstanding requests at this point.

    The caller would typically handle on_stop by no longer sending requests to
    the FastAPI server, and on_term by exiting the server process.

    Args:
        on_stop: Called, with no arguments, each time a termination signal is
            received.
        on_term: Called, with no arguments, when it is time to terminate.
        termination_delay_secs: Seconds to wait after the last outstanding
            request has passed through this middleware before calling on_term.
    """

    def __init__(
        self,
        on_stop: Callable[[], None],
        on_term: Callable[[], None],
        termination_delay_secs: float = DEFAULT_TERM_DELAY_SECS,
    ):
        self._outstanding_request_count = 0
        self._on_stop = on_stop
        self._on_term = on_term
        self._termination_delay_secs = termination_delay_secs
        self._stopped = False
        # Strong references to the scheduled termination tasks. The event loop
        # only keeps weak references to tasks, so a task whose handle is
        # dropped may be garbage collected before it gets to call on_term.
        self._term_tasks = set()
        for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGQUIT]:
            signal.signal(sig, self._stop)

    async def __call__(self, request: "Request", call_next):
        """
        Track the request as outstanding while the rest of the chain handles it.

        Returns whatever call_next returns. If a stop has been signalled and
        this was the last outstanding request, the delayed termination is
        scheduled, even when call_next raised.
        """
        self._outstanding_request_count += 1
        try:
            response = await call_next(request)
        finally:
            self._outstanding_request_count -= 1
            if self._outstanding_request_count == 0 and self._stopped:
                # There's a delay in term to allow some time for current
                # response flow to finish.
                term_task = asyncio.create_task(self._term())
                self._term_tasks.add(term_task)
                term_task.add_done_callback(self._term_tasks.discard)
        return response

    def _stop(self, sig, frame):
        """Signal handler: announce the stop and terminate at once if idle."""
        self._on_stop()
        self._stopped = True
        if self._outstanding_request_count == 0:
            self._on_term()

    async def _term(self):
        """Wait out the termination delay, then call on_term."""
        await asyncio.sleep(self._termination_delay_secs)
        self._on_term()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
53 changes: 53 additions & 0 deletions
53
truss/tests/templates/core/server/common/test_truss_server.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import os | ||
import signal | ||
import socket | ||
import sys | ||
import tempfile | ||
import time | ||
from multiprocessing import Process | ||
from pathlib import Path | ||
|
||
import pytest | ||
import yaml | ||
|
||
|
||
@pytest.mark.integration
def test_truss_server_termination(truss_container_fs):
    """
    Start a TrussServer in a child process, send it SIGTERM, and check that
    the process exits and releases its port within two seconds.
    """
    import multiprocessing

    port = 10123

    def start_truss_server(stdout_capture_file_path):
        # Runs in the child process; stdout is redirected to a file so the
        # parent can print it for debugging.
        sys.stdout = open(stdout_capture_file_path, "w")
        app_path = truss_container_fs / "app"
        sys.path.append(str(app_path))

        from common.truss_server import TrussServer

        config = yaml.safe_load((app_path / "config.yaml").read_text())
        server = TrussServer(http_port=port, config=config)
        server.start()

    # The target is a nested function closing over local state, which cannot
    # be pickled. Only the "fork" start method can run it, so request it
    # explicitly instead of relying on the platform default.
    fork_ctx = multiprocessing.get_context("fork")
    with tempfile.NamedTemporaryFile() as stdout_capture_file:
        subproc = fork_ctx.Process(
            target=start_truss_server, args=(stdout_capture_file.name,)
        )
        subproc.start()
        try:
            proc_id = subproc.pid
            time.sleep(2.0)
            # Port should have been taken up by truss server
            assert not _is_port_available(port)
            os.kill(proc_id, signal.SIGTERM)
            time.sleep(2.0)
            # Print on purpose for help with debugging, otherwise hard to know
            # what's going on
            print(Path(stdout_capture_file.name).read_text())
            assert not subproc.is_alive()
            # Port should be free now
            assert _is_port_available(port)
        finally:
            # Never leave the server running (and holding the port) when an
            # assertion above fails.
            if subproc.is_alive():
                subproc.kill()
            subproc.join()
|
||
|
||
def _is_port_available(port): | ||
try: | ||
# Try to bind to the given port | ||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: | ||
s.bind(("localhost", port)) | ||
return True | ||
except socket.error: | ||
# Port is already in use | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
# This file doesn't test anything, but provides utilities for testing. | ||
from unittest import mock | ||
|
||
|
||
|
93 changes: 93 additions & 0 deletions
93
truss/tests/templates/server/common/test_termination_handler_middleware.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import multiprocessing | ||
import tempfile | ||
import time | ||
from pathlib import Path | ||
from typing import Awaitable, Callable, List | ||
|
||
import pytest | ||
from truss.templates.server.common.termination_handler_middleware import ( | ||
TerminationHandlerMiddleware, | ||
) | ||
|
||
|
||
async def noop(*args, **kwargs):
    """Accept any arguments, do nothing, and resolve to None."""
    return None
|
||
|
||
@pytest.mark.integration
def test_termination_sequence_no_pending_requests(tmp_path):
    """
    With no request in flight when the term signal arrives, the middleware
    should report stop and then term, and nothing else.
    """

    # The middleware is created in a separate process by _verify_term; on
    # sending the term signal to that process, it should print the right
    # messages.
    def build_scenario(middleware: TerminationHandlerMiddleware):
        import asyncio

        async def scenario(*args, **kwargs):
            # One request that completes straight away, so nothing is
            # outstanding while the scenario sleeps.
            await middleware(1, call_next=noop)
            await asyncio.sleep(1)
            print("should not print due to termination")

        return scenario()

    expected_output = ["stopped", "terminated"]
    _verify_term(build_scenario, expected_output)
|
||
|
||
@pytest.mark.integration
def test_termination_sequence_with_pending_requests(tmp_path):
    """
    With a request still in flight when the term signal arrives, the
    middleware should report stop, let the request finish, and only then
    report term.
    """

    def build_scenario(middleware: TerminationHandlerMiddleware):
        import asyncio

        async def slow_call_next(req):
            # Keeps the request outstanding for one second.
            await asyncio.sleep(1.0)
            return "call_next_called"

        async def scenario(*args, **kwargs):
            resp = await middleware(1, call_next=slow_call_next)
            print(f"call_next response: {resp}")
            await asyncio.sleep(1)
            print("should not print due to termination")

        return scenario()

    expected_output = [
        "stopped",
        "call_next response: call_next_called",
        "terminated",
    ]
    _verify_term(build_scenario, expected_output)
|
||
|
||
def _verify_term(
    main_coro_gen: Callable[[TerminationHandlerMiddleware], Awaitable],
    expected_lines: List[str],
):
    """
    Run a scenario against a TerminationHandlerMiddleware in a child process,
    send that process a termination signal, and assert on what it printed.

    Args:
        main_coro_gen: Given the middleware created in the child process,
            returns the awaitable to run there with asyncio.run.
        expected_lines: Exact stdout lines (stripped) the child must produce.
    """

    def run(stdout_capture_file_path):
        # Executed in the child process.
        import asyncio
        import os
        import signal
        import sys

        sys.stdout = open(stdout_capture_file_path, "w")

        def term():
            print("terminated", flush=True)
            # Hard-kill so nothing after termination can run or print.
            os.kill(os.getpid(), signal.SIGKILL)

        middleware = TerminationHandlerMiddleware(
            on_stop=lambda: print("stopped", flush=True),
            on_term=term,
            termination_delay_secs=0.1,
        )
        asyncio.run(main_coro_gen(middleware))

    # The target is a nested function closing over main_coro_gen, which cannot
    # be pickled. Only the "fork" start method can run it, so request it
    # explicitly instead of relying on the platform default.
    fork_ctx = multiprocessing.get_context("fork")
    with tempfile.NamedTemporaryFile() as stdout_capture_file:
        proc = fork_ctx.Process(target=run, args=(stdout_capture_file.name,))
        proc.start()
        try:
            # Signal half-way through the scenarios' one-second sleeps. Waiting
            # a full second would send the signal at the very moment those
            # sleeps end, making the outcome depend on scheduling.
            time.sleep(0.5)
            proc.terminate()
            proc.join(timeout=6.0)
        finally:
            # Do not leave a stuck child behind if it failed to terminate.
            if proc.is_alive():
                proc.kill()
                proc.join()
        with Path(stdout_capture_file.name).open() as file:
            lines = [line.strip() for line in file]

    assert lines == expected_lines