Skip to content

Commit

Permalink
Terminate model server quicker on receiving term signal (#404)
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajroark committed Jun 28, 2023
1 parent 4723acb commit ead203d
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.4.9"
version = "0.4.10rc1"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
64 changes: 64 additions & 0 deletions truss/templates/server/common/termination_handler_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import asyncio
import signal
from typing import Callable

from fastapi import Request

# Grace period applied before terminating once the last in-flight request
# completes. The response may still have to pass through further middleware
# and have its bytes flushed out to the caller, so exit is not immediate.
DEFAULT_TERM_DELAY_SECS = 5.0


class TerminationHandlerMiddleware:
    """
    Coordinates a swift but safe shutdown of the server.

    On construction, handlers for SIGINT, SIGTERM and SIGQUIT are installed.
    When one of those signals arrives the middleware:

    1. Invokes ``on_stop`` right away — shutdown has begun, and the caller
       should stop routing new requests to this FastAPI server.
    2. Waits for all outstanding requests to drain, then (after a short
       delay) invokes ``on_term`` — no requests remain at that point, so it
       is the right moment for the caller to exit the server process.
    """

    def __init__(
        self,
        on_stop: Callable[[], None],
        on_term: Callable[[], None],
        termination_delay_secs: float = DEFAULT_TERM_DELAY_SECS,
    ):
        self._inflight_requests = 0
        self._on_stop = on_stop
        self._on_term = on_term
        self._termination_delay_secs = termination_delay_secs
        self._stopping = False
        for handled_signal in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
            signal.signal(handled_signal, self._stop)

    async def __call__(self, request: Request, call_next):
        self._inflight_requests += 1
        try:
            response = await call_next(request)
        finally:
            self._inflight_requests -= 1
            if self._stopping and self._inflight_requests == 0:
                # Last request just drained after a stop signal: schedule
                # termination with a small delay so the current response can
                # finish flowing back to the caller.
                asyncio.create_task(self._term())
        return response

    def _stop(self, sig, frame):
        # Signal handler: announce the stop first, then flip state so the
        # request path knows to terminate once it drains.
        self._on_stop()
        self._stopping = True
        if self._inflight_requests == 0:
            # Nothing in flight, so there is no response to wait for.
            self._on_term()

    async def _term(self):
        await asyncio.sleep(self._termination_delay_secs)
        self._on_term()
50 changes: 48 additions & 2 deletions truss/templates/server/common/truss_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import os
import signal
import socket
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Union

Expand All @@ -18,12 +20,22 @@
truss_msgpack_deserialize,
truss_msgpack_serialize,
)
from common.termination_handler_middleware import TerminationHandlerMiddleware
from fastapi import Depends, FastAPI, Request
from fastapi.responses import ORJSONResponse
from fastapi.routing import APIRoute as FastAPIRoute
from model_wrapper import ModelWrapper
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

# [IMPORTANT] A lot of things depend on this currently.
# Please consider the following when increasing this:
# 1. Self-termination on model load fail.
# 2. Graceful termination.
NUM_WORKERS = 1
WORKER_TERMINATION_TIMEOUT_SECS = 120.0
WORKER_TERMINATION_CHECK_INTERVAL_SECS = 0.5


async def parse_body(request: Request) -> bytes:
"""
Expand Down Expand Up @@ -171,7 +183,7 @@ def on_startup(self):
self._model.start_load()

def create_application(self):
return FastAPI(
app = FastAPI(
title="Baseten Inference Server",
docs_url=None,
redoc_url=None,
Expand Down Expand Up @@ -213,12 +225,24 @@ def create_application(self):
},
)

def exit_self():
# Note that this kills the current process, the worker process, not
# the main truss_server process.
sys.exit()

termination_handler_middleware = TerminationHandlerMiddleware(
on_stop=lambda: None,
on_term=exit_self,
)
app.add_middleware(BaseHTTPMiddleware, dispatch=termination_handler_middleware)
return app

def start(self):
cfg = uvicorn.Config(
self.create_application(),
host="0.0.0.0",
port=self.http_port,
workers=1,
workers=NUM_WORKERS,
log_config={
"version": 1,
"formatters": {
Expand Down Expand Up @@ -275,9 +299,31 @@ async def serve():
serversocket.listen(5)

logging.info(f"starting uvicorn with {cfg.workers} workers")
servers: List[UvicornCustomServer] = []
for _ in range(cfg.workers):
server = UvicornCustomServer(config=cfg, sockets=[serversocket])
server.start()
servers.append(server)

def stop_servers():
# Send stop signal, then wait for all to exit
for server in servers:
# Sends term signal to the process, which should be handled
# by the termination handler.
server.stop()

termination_check_attempts = int(
WORKER_TERMINATION_TIMEOUT_SECS
/ WORKER_TERMINATION_CHECK_INTERVAL_SECS
)
for _ in range(termination_check_attempts):
time.sleep(WORKER_TERMINATION_CHECK_INTERVAL_SECS)
if utils.all_processes_dead(servers):
# Exit main process
sys.exit()

for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGQUIT]:
signal.signal(sig, lambda sig, frame: stop_servers())

async def servers_task():
servers = [serve()]
Expand Down
9 changes: 9 additions & 0 deletions truss/templates/server/common/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import multiprocessing
import os
import sys
from typing import List

import psutil

Expand Down Expand Up @@ -49,3 +51,10 @@ def cpu_count():
pass

return count


def all_processes_dead(procs: List[multiprocessing.Process]) -> bool:
    """Return True iff none of the given processes is still running."""
    return not any(proc.is_alive() for proc in procs)
53 changes: 53 additions & 0 deletions truss/tests/templates/core/server/common/test_truss_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os
import signal
import socket
import sys
import tempfile
import time
from multiprocessing import Process
from pathlib import Path

import pytest
import yaml


@pytest.mark.integration
def test_truss_server_termination(truss_container_fs):
    """Start a TrussServer in a child process and verify that SIGTERM makes
    it exit and release its port.

    The server's stdout is captured to a temp file so it can be printed for
    debugging if the test fails.
    """
    port = 10123

    def start_truss_server(stdout_capture_file_path):
        # Redirect stdout so the parent process can inspect the server logs.
        sys.stdout = open(stdout_capture_file_path, "w")
        app_path = truss_container_fs / "app"
        sys.path.append(str(app_path))

        from common.truss_server import TrussServer

        config = yaml.safe_load((app_path / "config.yaml").read_text())
        server = TrussServer(http_port=port, config=config)
        server.start()

    stdout_capture_file = tempfile.NamedTemporaryFile()
    subproc = Process(target=start_truss_server, args=(stdout_capture_file.name,))
    subproc.start()
    try:
        time.sleep(2.0)
        # Port should have been taken up by truss server
        assert not _is_port_available(port)
        os.kill(subproc.pid, signal.SIGTERM)
        # Wait for graceful exit rather than sleeping a fixed interval:
        # join returns as soon as the process dies, and the timeout guards
        # against hanging forever if termination is broken.
        subproc.join(timeout=30.0)
        # Print on purpose for help with debugging, otherwise hard to know what's going on
        print(Path(stdout_capture_file.name).read_text())
        assert not subproc.is_alive()
        # Port should be free now
        assert _is_port_available(port)
    finally:
        # Don't leak the server process if any assertion above failed.
        if subproc.is_alive():
            subproc.kill()


def _is_port_available(port):
try:
# Try to bind to the given port
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("localhost", port))
return True
except socket.error:
# Port is already in use
return False
1 change: 1 addition & 0 deletions truss/tests/templates/core/server/common/test_util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# This file doesn't test anything, but provides utilities for testing.
from unittest import mock


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import multiprocessing
import tempfile
import time
from pathlib import Path
from typing import Awaitable, Callable, List

import pytest
from truss.templates.server.common.termination_handler_middleware import (
TerminationHandlerMiddleware,
)


async def noop(*args, **kwargs):
    """Async no-op stand-in for ``call_next``: accepts and ignores any
    arguments, returns ``None``."""
    return None


@pytest.mark.integration
def test_termination_sequence_no_pending_requests(tmp_path):
    """With no requests in flight, a term signal should drive the middleware
    straight through its stop -> term sequence, killing the process before
    the trailing sleep can print anything.
    """

    def build_main_coro(middleware: TerminationHandlerMiddleware):
        import asyncio

        async def scenario(*args, **kwargs):
            # One request that completes immediately, so nothing is pending
            # when the signal arrives.
            await middleware(1, call_next=noop)
            await asyncio.sleep(1)
            print("should not print due to termination")

        return scenario()

    _verify_term(build_main_coro, ["stopped", "terminated"])


@pytest.mark.integration
def test_termination_sequence_with_pending_requests(tmp_path):
    """When a request is still executing at signal time, the middleware must
    let it finish (and its response print) before terminating.
    """

    def build_main_coro(middleware: TerminationHandlerMiddleware):
        import asyncio

        async def scenario(*args, **kwargs):
            async def slow_call_next(req):
                # Keep the request in flight long enough to straddle the
                # term signal sent by the test harness.
                await asyncio.sleep(1.0)
                return "call_next_called"

            resp = await middleware(1, call_next=slow_call_next)
            print(f"call_next response: {resp}")
            await asyncio.sleep(1)
            print("should not print due to termination")

        return scenario()

    _verify_term(
        build_main_coro,
        [
            "stopped",
            "call_next response: call_next_called",
            "terminated",
        ],
    )


def _verify_term(
    main_coro_gen: Callable[[TerminationHandlerMiddleware], Awaitable],
    expected_lines: List[str],
):
    """Run a middleware scenario in a child process, SIGTERM it, and assert
    the exact stdout lines it produced.

    A separate process is used because the middleware installs process-wide
    signal handlers in its constructor; running it in-process would hijack
    the test runner's own signal handling.
    """

    def run(stdout_capture_file_path):
        import asyncio
        import os
        import signal
        import sys

        # Capture everything the scenario prints so the parent can read it.
        sys.stdout = open(stdout_capture_file_path, "w")

        def term():
            print("terminated", flush=True)
            # SIGKILL ourselves so nothing after on_term can ever print.
            os.kill(os.getpid(), signal.SIGKILL)

        middleware = TerminationHandlerMiddleware(
            on_stop=lambda: print("stopped", flush=True),
            on_term=term,
            termination_delay_secs=0.1,
        )
        asyncio.run(main_coro_gen(middleware))

    stdout_capture_file = tempfile.NamedTemporaryFile()
    proc = multiprocessing.Process(target=run, args=(stdout_capture_file.name,))
    proc.start()
    # Give the child time to start the event loop and enter the scenario.
    time.sleep(1)
    # terminate() sends SIGTERM, which the middleware's handler intercepts.
    proc.terminate()
    proc.join(timeout=6.0)
    with Path(stdout_capture_file.name).open() as file:
        lines = [line.strip() for line in file]

    assert lines == expected_lines

0 comments on commit ead203d

Please sign in to comment.