diff --git a/backend/src/server_host.py b/backend/src/server_host.py index c37c7692c..f20cfa927 100644 --- a/backend/src/server_host.py +++ b/backend/src/server_host.py @@ -20,7 +20,7 @@ from events import EventQueue from gpu import get_nvidia_helper from server_config import ServerConfig -from server_process_helper import ExecutorServer +from server_process_helper import WorkerServer class AppContext: @@ -51,7 +51,7 @@ def filter(self, record): # noqa: ANN001 ) -executor_server: ExecutorServer = ExecutorServer() +worker: WorkerServer = WorkerServer() setup_task = None @@ -60,39 +60,39 @@ def filter(self, record): # noqa: ANN001 @app.route("/nodes") async def nodes(request: Request): - resp = await executor_server.proxy_request(request) + resp = await worker.proxy_request(request) return resp @app.route("/run", methods=["POST"]) async def run(request: Request): - return await executor_server.proxy_request(request, timeout=None) + return await worker.proxy_request(request, timeout=None) @app.route("/run/individual", methods=["POST"]) async def run_individual(request: Request): logger.info("Running individual") - return await executor_server.proxy_request(request) + return await worker.proxy_request(request) @app.route("/clear-cache/individual", methods=["POST"]) async def clear_cache_individual(request: Request): - return await executor_server.proxy_request(request) + return await worker.proxy_request(request) @app.route("/pause", methods=["POST"]) async def pause(request: Request): - return await executor_server.proxy_request(request) + return await worker.proxy_request(request) @app.route("/resume", methods=["POST"]) async def resume(request: Request): - return await executor_server.proxy_request(request, timeout=None) + return await worker.proxy_request(request, timeout=None) @app.route("/kill", methods=["POST"]) async def kill(request: Request): - return await executor_server.proxy_request(request) + return await worker.proxy_request(request) @app.route("/python-info", methods=["GET"]) @@ -131,13 +131,13 @@ async def system_usage(_request: Request): @app.route("/packages", methods=["GET"]) async def get_packages(request: Request): - return await executor_server.proxy_request(request) + return await worker.proxy_request(request) @app.route("/installed-dependencies", methods=["GET"]) async def get_installed_dependencies(request: Request): installed_deps: dict[str, str] = {} - packages = await executor_server.get_packages() + packages = await worker.get_packages() for package in packages: for pkg_dep in package.dependencies: installed_version = installed_packages.get(pkg_dep.pypi_name, None) @@ -149,7 +149,7 @@ async def get_installed_dependencies(request: Request): @app.route("/features") async def get_features(request: Request): - return await executor_server.proxy_request(request) + return await worker.proxy_request(request) @app.get("/sse") @@ -158,7 +158,7 @@ async def sse(request: Request): response = await request.respond(headers=headers, content_type="text/event-stream") while True: try: - async for data in executor_server.get_sse(request): + async for data in worker.get_sse(request): if response is not None: await response.send(data) except Exception: @@ -196,7 +196,7 @@ async def install_deps(dependencies: list[api.Dependency]): ] await install_dependencies(dep_info, update_progress_cb, logger) - packages = await executor_server.get_packages() + packages = await worker.get_packages() logger.info("Checking dependencies...") @@ -228,7 +228,7 @@ async def install_deps(dependencies: list[api.Dependency]): if config.close_after_start: flags.append("--close-after-start") - await executor_server.restart(flags) + await worker.restart(flags) except Exception as ex: logger.error(f"Error installing dependencies: {ex}", exc_info=True) if config.close_after_start: @@ -277,7 +277,7 @@ async def update_progress( await update_progress("Loading Nodes...", 1.0, None) # Wait to send backend-ready until nodes are loaded - await executor_server.wait_for_backend_ready() + await worker.wait_for_ready() await setup_queue.put_and_wait( { @@ -305,26 +305,26 @@ async def close_server(sanic_app: Sanic): except Exception as ex: logger.error(f"Error waiting for server to start: {ex}") - await executor_server.stop() + await worker.stop() sanic_app.stop() @app.after_server_stop async def after_server_stop(_sanic_app: Sanic, _loop: asyncio.AbstractEventLoop): - await executor_server.stop() + await worker.stop() logger.info("Server closed.") @app.after_server_start async def after_server_start(sanic_app: Sanic, loop: asyncio.AbstractEventLoop): global setup_task - await executor_server.start() + await worker.start() # initialize the queues ctx = AppContext.get(sanic_app) ctx.setup_queue = EventQueue() - await executor_server.wait_for_server_start() + await worker.wait_for_ready() # start the setup task setup_task = loop.create_task(setup(sanic_app, loop)) diff --git a/backend/src/server_process_helper.py b/backend/src/server_process_helper.py index bac12fe49..7ee799405 100644 --- a/backend/src/server_process_helper.py +++ b/backend/src/server_process_helper.py @@ -6,6 +6,8 @@ import subprocess import sys import threading +import time +from typing import Iterable import aiohttp from sanic import HTTPResponse, Request @@ -14,104 +16,102 @@ from api import Package -def find_free_port(): +def _find_free_port(): with socket.socket() as s: s.bind(("", 0)) # Bind to a free port provided by the host. return s.getsockname()[1] # Return the port number assigned. -class ExecutorServerWorker: - def __init__(self, port: int, flags: list[str] | None = None): - self.process = None - self.stop_event = threading.Event() - self.finished_starting = False +def _port_in_use(port: int): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("127.0.0.1", port)) == 0 - self.port = port - self.flags = flags or [] - def start_process(self): +class _WorkerProcess: + def __init__(self, flags: list[str]): server_file = os.path.join(os.path.dirname(__file__), "server.py") python_location = sys.executable - self.process = subprocess.Popen( - [python_location, server_file, str(self.port), *self.flags], + + self._process = subprocess.Popen( + [python_location, server_file, *flags], shell=False, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) + self._stop_event = threading.Event() + # Create a separate thread to read and print the output of the subprocess threading.Thread( - target=self._read_output, daemon=True, name="output reader" + target=self._read_output, + daemon=True, + name="output reader", ).start() - def stop_process(self): - if self.process: - self.stop_event.set() - self.process.terminate() - self.process.kill() + def close(self): + self._stop_event.set() + self._process.terminate() + self._process.kill() def _read_output(self): - if self.process is None or self.process.stdout is None: + if self._process.stdout is None: return - for line in self.process.stdout: - if self.stop_event.is_set(): + for line in self._process.stdout: + if self._stop_event.is_set(): break - if not self.finished_starting: - if "Starting worker" in line.decode(): - self.finished_starting = True print(line.decode().strip()) -class ExecutorServer: - def __init__(self, flags: list[str] | None = None): - self.flags = flags - - self.server_process = None +class WorkerServer: + def __init__(self): + self._process = None - self.port = find_free_port() - self.base_url = f"http://127.0.0.1:{self.port}" - self.session = None + self._port = _find_free_port() + self._base_url = f"http://127.0.0.1:{self._port}" + self._session = None - self.backend_ready = False - - async def start(self, flags: list[str] | None = None): - del self.server_process - self.server_process = ExecutorServerWorker(self.port, flags or self.flags) - self.server_process.start_process() - self.session = aiohttp.ClientSession(base_url=self.base_url) - await self.wait_for_server_start() - await self.session.get("/nodes", timeout=None) - self.backend_ready = True - return self + async def start(self, flags: Iterable[str] = []): + logger.info("Starting worker process...") + self._process = _WorkerProcess([str(self._port), *flags]) + self._session = aiohttp.ClientSession(base_url=self._base_url) + await self.wait_for_ready() + logger.info("Worker process started") async def stop(self): - if self.server_process: - self.server_process.stop_process() - if self.session: - await self.session.close() + if self._process: + self._process.close() + if self._session: + await self._session.close() + logger.info("Worker process stopped") - async def restart(self, flags: list[str] | None = None): + async def restart(self, flags: Iterable[str] = []): await self.stop() await self.start(flags) - async def wait_for_server_start(self): - while ( - self.server_process is None - or self.server_process.finished_starting is False - ): - await asyncio.sleep(0.1) + async def wait_for_ready(self, timeout: float = 300): + start = time.time() + while time.time() - start < timeout: + if ( + self._process is not None + and self._session is not None + and _port_in_use(self._port) + ): + try: + await self._session.get("/nodes", timeout=5) + return + except Exception: + pass - async def wait_for_backend_ready(self): - while not self.backend_ready: await asyncio.sleep(0.1) + raise TimeoutError("Server did not start in time") + async def proxy_request(self, request: Request, timeout: int | None = 300): - assert self.session is not None - await self.wait_for_server_start() - await self.wait_for_backend_ready() + await self.wait_for_ready() + assert self._session is not None if request.route is None: raise ValueError("Route not found") - async with self.session.request( + async with self._session.request( request.method, f"/{request.route.path}", headers=request.headers, @@ -129,10 +129,9 @@ async def proxy_request(self, request: Request, timeout: int | None = 300): ) async def get_sse(self, request: Request): - assert self.session is not None - await self.wait_for_server_start() - await self.wait_for_backend_ready() - async with self.session.request( + await self.wait_for_ready() + assert self._session is not None + async with self._session.request( request.method, "/sse", headers=request.headers, @@ -143,11 +142,10 @@ async def get_sse(self, request: Request): yield data async def get_packages(self): - await self.wait_for_server_start() - await self.wait_for_backend_ready() - assert self.session is not None + await self.wait_for_ready() + assert self._session is not None logger.debug("Fetching packages...") - packages_resp = await self.session.get( + packages_resp = await self._session.get( "/packages", params={"hideInternal": "false"} ) packages_json = await packages_resp.json()