Skip to content

Commit

Permalink
server_host.py refactor (#2647)
Browse files Browse the repository at this point in the history
* Refactor server_host.py to use helper class

* Break from the while loop when see error

* rename some things
  • Loading branch information
joeyballentine committed Mar 4, 2024
1 parent 2bf4570 commit f3ec0ff
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 160 deletions.
193 changes: 33 additions & 160 deletions backend/src/server_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,26 @@

import asyncio
import logging
import os
import socket
import subprocess
import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass
from json import dumps as stringify

import aiohttp
import psutil
from sanic import Sanic
from sanic.log import access_logger, logger
from sanic.request import Request
from sanic.response import HTTPResponse, json
from sanic.response import json
from sanic_cors import CORS

import api
from api import (
Package,
)
from custom_types import UpdateProgressFn
from dependencies.store import DependencyInfo, install_dependencies, installed_packages
from events import EventQueue
from gpu import get_nvidia_helper
from server_config import ServerConfig
from server_process_helper import ExecutorServer


def find_free_port():
Expand All @@ -37,11 +31,6 @@ def find_free_port():
return s.getsockname()[1] # Return the port number assigned.


port = find_free_port()
base_url = f"http://127.0.0.1:{port}"
session = None


class AppContext:
def __init__(self):
self.config: ServerConfig = None # type: ignore
Expand Down Expand Up @@ -70,155 +59,49 @@ def filter(self, record): # noqa: ANN001
)


class ExecutorServerProcess:
def __init__(self, flags: list[str] | None = None):
self.process = None
self.stop_event = threading.Event()
self.finished_starting = False

self.flags = flags or []

def start_process(self):
server_file = os.path.join(os.path.dirname(__file__), "server.py")
python_location = sys.executable
self.process = subprocess.Popen(
[python_location, server_file, str(port), *self.flags],
shell=False,
stdin=None,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
# Create a separate thread to read and print the output of the subprocess
threading.Thread(
target=self._read_output, daemon=True, name="output reader"
).start()

def stop_process(self):
if self.process:
self.stop_event.set()
self.process.terminate()
self.process.kill()

def _read_output(self):
if self.process is None or self.process.stdout is None:
return
for line in self.process.stdout:
if self.stop_event.is_set():
break
if not self.finished_starting:
if "Starting worker" in line.decode():
self.finished_starting = True
print(line.decode().strip())


# Module-level executor process, spawned eagerly at import time so the
# backend child is already launching while the Sanic app initializes.
server_process: ExecutorServerProcess = ExecutorServerProcess()
server_process.start_process()


def start_executor_server(flags: list[str] | None = None):
    """Replace the module-level executor process with a fresh one.

    Note: this does NOT stop the previous process; use
    restart_executor_server() when the old process must be shut down first.

    Returns the new ExecutorServerProcess.
    """
    global server_process
    # Rebind directly: the original `del server_process` before the
    # assignment would leave the global undefined if the constructor raised.
    server_process = ExecutorServerProcess(flags)
    server_process.start_process()
    return server_process


def stop_executor_server():
    """Stop the module-level executor server process."""
    server_process.stop_process()


def restart_executor_server(flags: list[str] | None = None):
    """Stop the current executor process and start a new one with *flags*."""
    stop_executor_server()
    start_executor_server(flags)


async def wait_for_server_start():
    """Poll until the executor subprocess reports its worker has started."""
    while server_process.finished_starting is False:
        await asyncio.sleep(0.1)


backend_ready = False


async def wait_for_backend_ready():
    """Poll until setup() flags the backend as ready (nodes loaded)."""
    while not backend_ready:
        await asyncio.sleep(0.1)

# Pick a free local port and build the helper that owns the executor
# subprocess and the HTTP session to it.
port = find_free_port()
executor_server: ExecutorServer = ExecutorServer(port)

# Handle for the background setup task created in after_server_start.
setup_task = None

# Suppress access-log noise for the SSE endpoints.
access_logger.addFilter(SSEFilter())


async def proxy_request(request: Request, timeout: int | None = 300) -> HTTPResponse:
    """Forward *request* to the executor server and mirror its response.

    Blocks until the executor subprocess has started and the backend reports
    ready. Raises ValueError if the incoming request has no matched route.
    `timeout=None` disables the client timeout for long-running requests.
    """
    assert session is not None
    await wait_for_server_start()
    await wait_for_backend_ready()
    if request.route is None:
        raise ValueError("Route not found")
    # Re-issue the same method/headers/body against the executor server,
    # keyed by the matched route's path.
    async with session.request(
        request.method,
        f"/{request.route.path}",
        headers=request.headers,
        data=request.body,
        timeout=timeout,
    ) as resp:
        headers = resp.headers
        status = resp.status
        body = await resp.read()
        # Mirror status/headers/body back to the original caller.
        return HTTPResponse(
            body,
            status=status,
            headers=dict(headers),
            content_type=request.content_type,
        )


async def get_packages_req():
    """Fetch the package list from the executor server as Package objects.

    Includes internal packages (hideInternal=false).
    """
    await wait_for_server_start()
    assert session is not None
    logger.info("Fetching packages...")
    packages_resp = await session.get("/packages", params={"hideInternal": "false"})
    packages_json = await packages_resp.json()
    packages = [Package.from_dict(p) for p in packages_json]
    return packages


@app.route("/nodes")
async def nodes(request: Request):
    """Proxy /nodes to the executor server."""
    # Diff residue removed: the pre-refactor proxy_request() call duplicated
    # this assignment; only the ExecutorServer helper call is kept.
    resp = await executor_server.proxy_request(request)
    return resp


@app.route("/run", methods=["POST"])
async def run(request: Request):
    """Proxy /run with no timeout (chain executions can run indefinitely)."""
    # Diff residue removed: the old proxy_request() return made this line
    # unreachable; keep only the ExecutorServer helper call.
    return await executor_server.proxy_request(request, timeout=None)


@app.route("/run/individual", methods=["POST"])
async def run_individual(request: Request):
    """Proxy a single-node run to the executor server."""
    logger.info("Running individual")
    # Diff residue removed: only the ExecutorServer helper call is kept.
    return await executor_server.proxy_request(request)


@app.route("/clear-cache/individual", methods=["POST"])
async def clear_cache_individual(request: Request):
    """Proxy an individual cache-clear to the executor server."""
    # Diff residue removed: only the ExecutorServer helper call is kept.
    return await executor_server.proxy_request(request)


@app.route("/pause", methods=["POST"])
async def pause(request: Request):
    """Proxy a pause request to the executor server."""
    # Diff residue removed: only the ExecutorServer helper call is kept.
    return await executor_server.proxy_request(request)


@app.route("/resume", methods=["POST"])
async def resume(request: Request):
    """Proxy a resume request; no timeout since the run may continue long."""
    # Diff residue removed: only the ExecutorServer helper call is kept.
    return await executor_server.proxy_request(request, timeout=None)


@app.route("/kill", methods=["POST"])
async def kill(request: Request):
    """Proxy a kill request to the executor server."""
    # Diff residue removed: only the ExecutorServer helper call is kept.
    return await executor_server.proxy_request(request)


@app.route("/python-info", methods=["GET"])
Expand Down Expand Up @@ -257,13 +140,13 @@ async def system_usage(_request: Request):

@app.route("/packages", methods=["GET"])
async def get_packages(request: Request):
    """Proxy the package listing to the executor server."""
    # Diff residue removed: only the ExecutorServer helper call is kept.
    return await executor_server.proxy_request(request)


@app.route("/installed-dependencies", methods=["GET"])
async def get_installed_dependencies(request: Request):
installed_deps: dict[str, str] = {}
packages = await get_packages_req()
packages = await executor_server.get_packages()
for package in packages:
for pkg_dep in package.dependencies:
installed_version = installed_packages.get(pkg_dep.pypi_name, None)
Expand All @@ -275,23 +158,20 @@ async def get_installed_dependencies(request: Request):

@app.route("/features")
async def get_features(request: Request):
    """Proxy the feature list to the executor server."""
    # Diff residue removed: only the ExecutorServer helper call is kept.
    return await executor_server.proxy_request(request)


@app.get("/sse")
async def sse(request: Request):
    """Relay the executor server's SSE stream to the client.

    Reconstructed post-refactor handler: the diff residue still contained
    the old raw `session.request` streaming code alongside the new
    `executor_server.get_sse` loop; only the refactored path is kept.
    """
    headers = {"Cache-Control": "no-cache"}
    response = await request.respond(headers=headers, content_type="text/event-stream")
    while True:
        try:
            # Forward each chunk from the executor's /sse stream as-is.
            async for data in executor_server.get_sse(request):
                if response is not None:
                    await response.send(data)
        except Exception:
            # Stream closed or errored: stop relaying (break on error,
            # per the commit message).
            break


@app.get("/setup-sse")
Expand All @@ -300,10 +180,13 @@ async def setup_sse(request: Request):
headers = {"Cache-Control": "no-cache"}
response = await request.respond(headers=headers, content_type="text/event-stream")
while True:
message = await ctx.setup_queue.get()
if response is not None:
await response.send(f"event: {message['event']}\n")
await response.send(f"data: {stringify(message['data'])}\n\n")
try:
message = await ctx.setup_queue.get()
if response is not None:
await response.send(f"event: {message['event']}\n")
await response.send(f"data: {stringify(message['data'])}\n\n")
except Exception:
break


async def import_packages(
Expand All @@ -322,7 +205,7 @@ async def install_deps(dependencies: list[api.Dependency]):
]
await install_dependencies(dep_info, update_progress_cb, logger)

packages = await get_packages_req()
packages = await executor_server.get_packages()

logger.info("Checking dependencies...")

Expand Down Expand Up @@ -354,7 +237,7 @@ async def install_deps(dependencies: list[api.Dependency]):
if config.close_after_start:
flags.append("--close-after-start")

restart_executor_server(flags)
await executor_server.restart(flags)
except Exception as ex:
logger.error(f"Error installing dependencies: {ex}", exc_info=True)
if config.close_after_start:
Expand All @@ -364,8 +247,6 @@ async def install_deps(dependencies: list[api.Dependency]):


async def setup(sanic_app: Sanic, loop: asyncio.AbstractEventLoop):
global backend_ready

setup_queue = AppContext.get(sanic_app).setup_queue

async def update_progress(
Expand Down Expand Up @@ -405,11 +286,7 @@ async def update_progress(
await update_progress("Loading Nodes...", 1.0, None)

# Wait to send backend-ready until nodes are loaded
await wait_for_server_start()
assert session is not None
await session.get("/nodes", timeout=None)

backend_ready = True
await executor_server.wait_for_backend_ready()

await setup_queue.put_and_wait(
{
Expand Down Expand Up @@ -437,30 +314,26 @@ async def close_server(sanic_app: Sanic):
except Exception as ex:
logger.error(f"Error waiting for server to start: {ex}")

stop_executor_server()
if session is not None:
await session.close()
await executor_server.stop()
sanic_app.stop()


@app.after_server_stop
async def after_server_stop(_sanic_app: Sanic, _loop: asyncio.AbstractEventLoop):
    """Shut down the executor server once Sanic has fully stopped."""
    # Diff residue removed: the old server_process.stop_process() /
    # session.close() pair is superseded by ExecutorServer.stop().
    await executor_server.stop()
    logger.info("Server closed.")


@app.after_server_start
async def after_server_start(sanic_app: Sanic, loop: asyncio.AbstractEventLoop):
    """Start the executor server and launch the background setup task.

    Diff residue removed: the old aiohttp session creation and
    wait_for_server_start() call are superseded by the ExecutorServer
    helper's start()/wait_for_server_start().
    """
    global setup_task
    await executor_server.start()

    # initialize the queues
    ctx = AppContext.get(sanic_app)
    ctx.setup_queue = EventQueue()

    await executor_server.wait_for_server_start()

    # start the setup task
    setup_task = loop.create_task(setup(sanic_app, loop))
Expand Down

0 comments on commit f3ec0ff

Please sign in to comment.