From f3ec0ff3357ba3260085a1c3ad834d6a7473a911 Mon Sep 17 00:00:00 2001 From: Joey Ballentine <34788790+joeyballentine@users.noreply.github.com> Date: Mon, 4 Mar 2024 13:37:01 -0500 Subject: [PATCH] server_host.py refactor (#2647) * Refactor server_host.py to use helper class * Break from the while loop when see error * rename some things --- backend/src/server_host.py | 193 +++++---------------------- backend/src/server_process_helper.py | 148 ++++++++++++++++++++ 2 files changed, 181 insertions(+), 160 deletions(-) create mode 100644 backend/src/server_process_helper.py diff --git a/backend/src/server_host.py b/backend/src/server_host.py index 3da6f1286..7e07f5a3f 100644 --- a/backend/src/server_host.py +++ b/backend/src/server_host.py @@ -2,32 +2,26 @@ import asyncio import logging -import os import socket -import subprocess import sys -import threading from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict, dataclass from json import dumps as stringify -import aiohttp import psutil from sanic import Sanic from sanic.log import access_logger, logger from sanic.request import Request -from sanic.response import HTTPResponse, json +from sanic.response import json from sanic_cors import CORS import api -from api import ( - Package, -) from custom_types import UpdateProgressFn from dependencies.store import DependencyInfo, install_dependencies, installed_packages from events import EventQueue from gpu import get_nvidia_helper from server_config import ServerConfig +from server_process_helper import ExecutorServer def find_free_port(): @@ -37,11 +31,6 @@ def find_free_port(): return s.getsockname()[1] # Return the port number assigned. -port = find_free_port() -base_url = f"http://127.0.0.1:{port}" -session = None - - class AppContext: def __init__(self): self.config: ServerConfig = None # type: ignore @@ -70,155 +59,49 @@ def filter(self, record): # noqa: ANN001 ) -class ExecutorServerProcess: - def __init__(self, flags: list[str] | None = None): - self.process = None - self.stop_event = threading.Event() - self.finished_starting = False - - self.flags = flags or [] - - def start_process(self): - server_file = os.path.join(os.path.dirname(__file__), "server.py") - python_location = sys.executable - self.process = subprocess.Popen( - [python_location, server_file, str(port), *self.flags], - shell=False, - stdin=None, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - # Create a separate thread to read and print the output of the subprocess - threading.Thread( - target=self._read_output, daemon=True, name="output reader" - ).start() - - def stop_process(self): - if self.process: - self.stop_event.set() - self.process.terminate() - self.process.kill() - - def _read_output(self): - if self.process is None or self.process.stdout is None: - return - for line in self.process.stdout: - if self.stop_event.is_set(): - break - if not self.finished_starting: - if "Starting worker" in line.decode(): - self.finished_starting = True - print(line.decode().strip()) - - -server_process: ExecutorServerProcess = ExecutorServerProcess() -server_process.start_process() - - -def start_executor_server(flags: list[str] | None = None): - global server_process - del server_process - server_process = ExecutorServerProcess(flags) - server_process.start_process() - return server_process - - -def stop_executor_server(): - server_process.stop_process() - - -def restart_executor_server(flags: list[str] | None = None): - stop_executor_server() - start_executor_server(flags) - - -async def wait_for_server_start(): - while server_process.finished_starting is False: - await asyncio.sleep(0.1) - - -backend_ready = False - - -async def wait_for_backend_ready(): - while not backend_ready: - await asyncio.sleep(0.1) - +port = find_free_port() +executor_server: ExecutorServer = ExecutorServer(port) setup_task = None access_logger.addFilter(SSEFilter()) -async def proxy_request(request: Request, timeout: int | None = 300): - assert session is not None - await wait_for_server_start() - await wait_for_backend_ready() - if request.route is None: - raise ValueError("Route not found") - async with session.request( - request.method, - f"/{request.route.path}", - headers=request.headers, - data=request.body, - timeout=timeout, - ) as resp: - headers = resp.headers - status = resp.status - body = await resp.read() - return HTTPResponse( - body, - status=status, - headers=dict(headers), - content_type=request.content_type, - ) - - -async def get_packages_req(): - await wait_for_server_start() - assert session is not None - logger.info("Fetching packages...") - packages_resp = await session.get("/packages", params={"hideInternal": "false"}) - packages_json = await packages_resp.json() - packages = [Package.from_dict(p) for p in packages_json] - return packages - - @app.route("/nodes") async def nodes(request: Request): - resp = await proxy_request(request) + resp = await executor_server.proxy_request(request) return resp @app.route("/run", methods=["POST"]) async def run(request: Request): - return await proxy_request(request, timeout=None) + return await executor_server.proxy_request(request, timeout=None) @app.route("/run/individual", methods=["POST"]) async def run_individual(request: Request): logger.info("Running individual") - return await proxy_request(request) + return await executor_server.proxy_request(request) @app.route("/clear-cache/individual", methods=["POST"]) async def clear_cache_individual(request: Request): - return await proxy_request(request) + return await executor_server.proxy_request(request) @app.route("/pause", methods=["POST"]) async def pause(request: Request): - return await proxy_request(request) + return await executor_server.proxy_request(request) @app.route("/resume", methods=["POST"]) async def resume(request: Request): - return await proxy_request(request, timeout=None) + return await executor_server.proxy_request(request, timeout=None) @app.route("/kill", methods=["POST"]) async def kill(request: Request): - return await proxy_request(request) + return await executor_server.proxy_request(request) @app.route("/python-info", methods=["GET"]) @@ -257,13 +140,13 @@ async def system_usage(_request: Request): @app.route("/packages", methods=["GET"]) async def get_packages(request: Request): - return await proxy_request(request) + return await executor_server.proxy_request(request) @app.route("/installed-dependencies", methods=["GET"]) async def get_installed_dependencies(request: Request): installed_deps: dict[str, str] = {} - packages = await get_packages_req() + packages = await executor_server.get_packages() for package in packages: for pkg_dep in package.dependencies: installed_version = installed_packages.get(pkg_dep.pypi_name, None) @@ -275,23 +158,20 @@ async def get_installed_dependencies(request: Request): @app.route("/features") async def get_features(request: Request): - return await proxy_request(request) + return await executor_server.proxy_request(request) @app.get("/sse") async def sse(request: Request): - assert session is not None headers = {"Cache-Control": "no-cache"} response = await request.respond(headers=headers, content_type="text/event-stream") - async with session.request( - request.method, "/sse", headers=request.headers, data=request.body, timeout=None - ) as resp: + while True: try: - async for data, _ in resp.content.iter_chunks(): + async for data in executor_server.get_sse(request): if response is not None: await response.send(data) - except Exception as ex: - logger.error(f"Error in sse: {ex}") + except Exception: + break @app.get("/setup-sse") @@ -300,10 +180,13 @@ async def setup_sse(request: Request): headers = {"Cache-Control": "no-cache"} response = await request.respond(headers=headers, content_type="text/event-stream") while True: - message = await ctx.setup_queue.get() - if response is not None: - await response.send(f"event: {message['event']}\n") - await response.send(f"data: {stringify(message['data'])}\n\n") + try: + message = await ctx.setup_queue.get() + if response is not None: + await response.send(f"event: {message['event']}\n") + await response.send(f"data: {stringify(message['data'])}\n\n") + except Exception: + break async def import_packages( @@ -322,7 +205,7 @@ async def install_deps(dependencies: list[api.Dependency]): ] await install_dependencies(dep_info, update_progress_cb, logger) - packages = await get_packages_req() + packages = await executor_server.get_packages() logger.info("Checking dependencies...") @@ -354,7 +237,7 @@ async def install_deps(dependencies: list[api.Dependency]): if config.close_after_start: flags.append("--close-after-start") - restart_executor_server(flags) + await executor_server.restart(flags) except Exception as ex: logger.error(f"Error installing dependencies: {ex}", exc_info=True) if config.close_after_start: @@ -364,8 +247,6 @@ async def install_deps(dependencies: list[api.Dependency]): async def setup(sanic_app: Sanic, loop: asyncio.AbstractEventLoop): - global backend_ready - setup_queue = AppContext.get(sanic_app).setup_queue async def update_progress( @@ -405,11 +286,7 @@ async def update_progress( await update_progress("Loading Nodes...", 1.0, None) # Wait to send backend-ready until nodes are loaded - await wait_for_server_start() - assert session is not None - await session.get("/nodes", timeout=None) - - backend_ready = True + await executor_server.wait_for_backend_ready() await setup_queue.put_and_wait( { @@ -437,30 +314,26 @@ async def close_server(sanic_app: Sanic): except Exception as ex: logger.error(f"Error waiting for server to start: {ex}") - stop_executor_server() - if session is not None: - await session.close() + await executor_server.stop() sanic_app.stop() @app.after_server_stop async def after_server_stop(_sanic_app: Sanic, _loop: asyncio.AbstractEventLoop): - server_process.stop_process() - if session is not None: - await session.close() + await executor_server.stop() logger.info("Server closed.") @app.after_server_start async def after_server_start(sanic_app: Sanic, loop: asyncio.AbstractEventLoop): - global session, setup_task - session = aiohttp.ClientSession(base_url=base_url) + global setup_task + await executor_server.start() # initialize the queues ctx = AppContext.get(sanic_app) ctx.setup_queue = EventQueue() - await wait_for_server_start() + await executor_server.wait_for_server_start() # start the setup task setup_task = loop.create_task(setup(sanic_app, loop)) diff --git a/backend/src/server_process_helper.py b/backend/src/server_process_helper.py new file mode 100644 index 000000000..54e17f9b0 --- /dev/null +++ b/backend/src/server_process_helper.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import asyncio +import os +import subprocess +import sys +import threading + +import aiohttp +from sanic import HTTPResponse, Request +from sanic.log import logger + +from api import Package + + +class ExecutorServerWorker: + def __init__(self, port: int, flags: list[str] | None = None): + self.process = None + self.stop_event = threading.Event() + self.finished_starting = False + + self.port = port + self.flags = flags or [] + + def start_process(self): + server_file = os.path.join(os.path.dirname(__file__), "server.py") + python_location = sys.executable + self.process = subprocess.Popen( + [python_location, server_file, str(self.port), *self.flags], + shell=False, + stdin=None, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + # Create a separate thread to read and print the output of the subprocess + threading.Thread( + target=self._read_output, daemon=True, name="output reader" + ).start() + + def stop_process(self): + if self.process: + self.stop_event.set() + self.process.terminate() + self.process.kill() + + def _read_output(self): + if self.process is None or self.process.stdout is None: + return + for line in self.process.stdout: + if self.stop_event.is_set(): + break + if not self.finished_starting: + if "Starting worker" in line.decode(): + self.finished_starting = True + print(line.decode().strip()) + + +class ExecutorServer: + def __init__(self, port: int, flags: list[str] | None = None): + self.port = port + self.flags = flags + + self.server_process = None + + self.base_url = f"http://127.0.0.1:{port}" + self.session = None + + self.backend_ready = False + + async def start(self, flags: list[str] | None = None): + del self.server_process + self.server_process = ExecutorServerWorker(self.port, flags or self.flags) + self.server_process.start_process() + self.session = aiohttp.ClientSession(base_url=self.base_url) + await self.wait_for_server_start() + await self.session.get("/nodes", timeout=None) + self.backend_ready = True + return self + + async def stop(self): + if self.server_process: + self.server_process.stop_process() + if self.session: + await self.session.close() + + async def restart(self, flags: list[str] | None = None): + await self.stop() + await self.start(flags) + + async def wait_for_server_start(self): + while ( + self.server_process is None + or self.server_process.finished_starting is False + ): + await asyncio.sleep(0.1) + + async def wait_for_backend_ready(self): + while not self.backend_ready: + await asyncio.sleep(0.1) + + async def proxy_request(self, request: Request, timeout: int | None = 300): + assert self.session is not None + await self.wait_for_server_start() + await self.wait_for_backend_ready() + if request.route is None: + raise ValueError("Route not found") + async with self.session.request( + request.method, + f"/{request.route.path}", + headers=request.headers, + data=request.body, + timeout=timeout, + ) as resp: + headers = resp.headers + status = resp.status + body = await resp.read() + return HTTPResponse( + body, + status=status, + headers=dict(headers), + content_type=request.content_type, + ) + + async def get_sse(self, request: Request): + assert self.session is not None + await self.wait_for_server_start() + await self.wait_for_backend_ready() + async with self.session.request( + request.method, + "/sse", + headers=request.headers, + data=request.body, + timeout=None, + ) as resp: + async for data, _ in resp.content.iter_chunks(): + yield data + + async def get_packages(self): + await self.wait_for_server_start() + await self.wait_for_backend_ready() + assert self.session is not None + logger.debug("Fetching packages...") + packages_resp = await self.session.get( + "/packages", params={"hideInternal": "false"} + ) + packages_json = await packages_resp.json() + packages = [Package.from_dict(p) for p in packages_json] + return packages