Skip to content

Commit

Permalink
Simplify reloader and remove --workers (#141)
Browse files Browse the repository at this point in the history
* Simplify reloader and remove --workers

* Remove commented-out code
  • Loading branch information
tomchristie committed Jul 20, 2018
1 parent 84ecdbd commit 63e3a17
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 128 deletions.
5 changes: 4 additions & 1 deletion uvicorn/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ async def __call__(self, receive, send):
accept = get_accept_header(self.scope)
if "text/html" in accept:
exc_html = html.escape(traceback.format_exc())
content = "<html><body><h1>500 Server Error</h1><pre>%s</pre></body></html>" % exc_html
content = (
"<html><body><h1>500 Server Error</h1><pre>%s</pre></body></html>"
% exc_html
)
response = HTMLResponse(content, status_code=500)
else:
content = traceback.format_exc()
Expand Down
155 changes: 44 additions & 111 deletions uvicorn/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from uvicorn.debug import DebugMiddleware
from uvicorn.importer import import_from_string, ImportFromStringError
from uvicorn.reloaders.noreload import NoReload
from uvicorn.reloaders.statreload import StatReload
import asyncio
import click
Expand Down Expand Up @@ -45,7 +44,6 @@ def get_socket(host, port):
sock = socket.socket()
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind((host, port))
sock.set_inheritable(True)
return sock


Expand All @@ -61,21 +59,9 @@ def get_logger(log_level):
@click.option("--port", type=int, default=8000, help="Port")
@click.option("--loop", type=LOOP_CHOICES, default="auto", help="Event loop")
@click.option("--http", type=HTTP_CHOICES, default="auto", help="HTTP Handler")
@click.option("--workers", type=int, default=1, help="Number of worker processes")
@click.option(
"--debug", type=bool, is_flag=True, default=False, help="Enable debug mode"
)
@click.option("--debug", is_flag=True, default=False, help="Enable debug mode")
@click.option("--log-level", type=LEVEL_CHOICES, default="info", help="Log level")
def main(
app,
host: str,
port: int,
loop: str,
http: str,
workers: int,
debug: bool,
log_level: str,
):
def main(app, host: str, port: int, loop: str, http: str, debug: bool, log_level: str):
sys.path.insert(0, ".")

kwargs = {
Expand All @@ -85,10 +71,15 @@ def main(
"loop": loop,
"http": http,
"log_level": log_level,
"workers": workers,
"debug": debug,
}
run(**kwargs)

if debug:
logger = get_logger(log_level)
reloader = StatReload(logger)
reloader.run(run, kwargs)
else:
run(**kwargs)


def run(
Expand All @@ -98,74 +89,8 @@ def run(
loop="auto",
http="auto",
log_level="info",
workers=1,
debug=False,
):
sock = get_socket(host, port)
logger = get_logger(log_level)
pid = os.getpid()

message = "* Uvicorn running on http://%s:%d 🦄 (Press CTRL+C to quit)"
click.echo(message % (host, port))
logger.info("Started parent [{}]".format(pid))

processes = []
seen_shutdown = False
seen_restart = False

if debug:
reloader = StatReload(logger)
else:
reloader = NoReload()

def shutdown(sig, frame):
nonlocal seen_shutdown

seen_shutdown = True

logger.warning("Got signal %s. Shutting down.", signal.Signals(sig).name)

for process, event in processes:
event.set()

for sig in HANDLED_SIGNALS:
signal.signal(sig, shutdown)

while not seen_shutdown:
for _ in range(workers):
event = multiprocessing.Event()
kwargs = {
"app": app,
"sock": sock,
"event": event,
"logger": logger,
"loop": loop,
"http": http,
"debug": debug,
}
process = multiprocessing.Process(target=run_one, kwargs=kwargs)
process.start()
processes.append((process, event))

while not (seen_shutdown or seen_restart):
if not any([process.is_alive() for process, event in processes]):
seen_shutdown = True
time.sleep(0.2)
seen_restart = reloader.should_restart()

if seen_restart:
for process, event in processes:
event.set()
seen_restart = False
reloader.clear()

for process, event in processes:
process.join()

logger.info("Stopping parent [{}]".format(pid))


def run_one(app, sock, event, logger, debug=False, loop="auto", http="auto"):
try:
app = import_from_string(app)
except ImportFromStringError as exc:
Expand All @@ -175,61 +100,69 @@ def run_one(app, sock, event, logger, debug=False, loop="auto", http="auto"):
if debug:
app = DebugMiddleware(app)

logger = get_logger(log_level)
loop_setup = import_from_string(LOOP_SETUPS[loop])
protocol_class = import_from_string(HTTP_PROTOCOLS[http])

loop = loop_setup()

# Ignore signals, instead allowing the parent process to handle them.
# Communication with subprocesses is via the 'multiprocessing.Event' instance.
def ignore(sig, frame):
pass

for sig in HANDLED_SIGNALS:
signal.signal(sig, ignore)

server = Server(app, sock, event, logger, loop, protocol_class)
server = Server(
app=app,
host=host,
port=port,
logger=logger,
loop=loop,
protocol_class=protocol_class,
)
server.run()


class Server:
def __init__(self, app, sock, event, logger, loop, protocol_class):
def __init__(self, app, host, port, logger, loop, protocol_class):
self.app = app
self.sock = sock
self.event = event
self.host = host
self.port = port
self.logger = logger
self.loop = loop
self.protocol_class = protocol_class
self.should_exit = False
self.pid = os.getpid()

def set_signal_handlers(self):
try:
for sig in HANDLED_SIGNALS:
self.loop.add_signal_handler(sig, self.handle_exit, sig, None)
except NotImplementedError:
# Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, self.handle_exit)

def handle_exit(self, sig, frame):
self.should_exit = True

def run(self):
self.logger.info("Started worker [{}]".format(self.pid))
self.logger.info("Started server process [{}]".format(self.pid))
self.set_signal_handlers()
self.loop.run_until_complete(self.create_server())
self.loop.create_task(self.tick())
self.loop.run_forever()

def create_protocol(self):
try:
return self.protocol_class(app=self.app, loop=self.loop, logger=self.logger)
except Exception as exc:
self.logger.error(exc)
self.event.set()
return self.protocol_class(app=self.app, loop=self.loop, logger=self.logger)

async def create_server(self):
try:
self.server = await self.loop.create_server(
self.create_protocol, sock=self.sock
)
except Exception as exc:
self.logger.error(exc)
self.event.set()
self.server = await self.loop.create_server(
self.create_protocol, self.host, self.port
)
message = "* Uvicorn running on http://%s:%d 🦄 (Press CTRL+C to quit)"
click.echo(message % (self.host, self.port))

async def tick(self):
while not self.event.is_set():
while not self.should_exit:
self.protocol_class.tick()
await asyncio.sleep(1)

self.logger.info("Stopping worker [{}]".format(self.pid))
self.logger.info("Stopping server process [{}]".format(self.pid))
self.server.close()
await self.server.wait_closed()
self.loop.stop()
Expand Down
9 changes: 0 additions & 9 deletions uvicorn/reloaders/noreload.py

This file was deleted.

50 changes: 43 additions & 7 deletions uvicorn/reloaders/statreload.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,54 @@
import os
import signal
import sys
import time
import multiprocessing


def _iter_py_files():
for subdir, dirs, files in os.walk("."):
for file in files:
filepath = subdir + os.sep + file
if filepath.endswith(".py"):
yield filepath
HANDLED_SIGNALS = (
signal.SIGINT, # Unix signal 2. Sent by Ctrl+C.
signal.SIGTERM, # Unix signal 15. Sent by `kill <pid>`.
)


class StatReload:
def __init__(self, logger):
self.logger = logger
self.should_exit = False
self.mtimes = {}

def handle_exit(self, sig, frame):
self.should_exit = True

def run(self, target, kwargs):
pid = os.getpid()

self.logger.info("Started reloader process [{}]".format(pid))

for sig in HANDLED_SIGNALS:
signal.signal(sig, self.handle_exit)

process = multiprocessing.Process(target=target, kwargs=kwargs)
process.start()

while process.is_alive() and not self.should_exit:
time.sleep(0.2)
if self.should_restart():
self.clear()
os.kill(process.pid, signal.SIGTERM)
process.join()
process = multiprocessing.Process(target=target, kwargs=kwargs)
process.start()

self.logger.info("Stopping reloader process [{}]".format(pid))

sys.exit(process.exitcode)

def clear(self):
self.mtimes = {}

def should_restart(self):
for filename in _iter_py_files():
for filename in self.iter_py_files():
try:
mtime = os.stat(filename).st_mtime
except OSError:
Expand All @@ -34,3 +63,10 @@ def should_restart(self):
self.logger.warning(message, filename)
return True
return False

def iter_py_files(self):
for subdir, dirs, files in os.walk("."):
for file in files:
filepath = subdir + os.sep + file
if filepath.endswith(".py"):
yield filepath

0 comments on commit 63e3a17

Please sign in to comment.