Skip to content

Commit

Permalink
Refactor worker class (#2651)
Browse files Browse the repository at this point in the history
* Refactor worker class

* typo
  • Loading branch information
RunDevelopment committed Mar 5, 2024
1 parent f984e78 commit 6a312a8
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 87 deletions.
40 changes: 20 additions & 20 deletions backend/src/server_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from events import EventQueue
from gpu import get_nvidia_helper
from server_config import ServerConfig
from server_process_helper import ExecutorServer
from server_process_helper import WorkerServer


class AppContext:
Expand Down Expand Up @@ -51,7 +51,7 @@ def filter(self, record): # noqa: ANN001
)


executor_server: ExecutorServer = ExecutorServer()
worker: WorkerServer = WorkerServer()

setup_task = None

Expand All @@ -60,39 +60,39 @@ def filter(self, record): # noqa: ANN001

@app.route("/nodes")
async def nodes(request: Request):
resp = await executor_server.proxy_request(request)
resp = await worker.proxy_request(request)
return resp


@app.route("/run", methods=["POST"])
async def run(request: Request):
return await executor_server.proxy_request(request, timeout=None)
return await worker.proxy_request(request, timeout=None)


@app.route("/run/individual", methods=["POST"])
async def run_individual(request: Request):
logger.info("Running individual")
return await executor_server.proxy_request(request)
return await worker.proxy_request(request)


@app.route("/clear-cache/individual", methods=["POST"])
async def clear_cache_individual(request: Request):
return await executor_server.proxy_request(request)
return await worker.proxy_request(request)


@app.route("/pause", methods=["POST"])
async def pause(request: Request):
return await executor_server.proxy_request(request)
return await worker.proxy_request(request)


@app.route("/resume", methods=["POST"])
async def resume(request: Request):
return await executor_server.proxy_request(request, timeout=None)
return await worker.proxy_request(request, timeout=None)


@app.route("/kill", methods=["POST"])
async def kill(request: Request):
return await executor_server.proxy_request(request)
return await worker.proxy_request(request)


@app.route("/python-info", methods=["GET"])
Expand Down Expand Up @@ -131,13 +131,13 @@ async def system_usage(_request: Request):

@app.route("/packages", methods=["GET"])
async def get_packages(request: Request):
return await executor_server.proxy_request(request)
return await worker.proxy_request(request)


@app.route("/installed-dependencies", methods=["GET"])
async def get_installed_dependencies(request: Request):
installed_deps: dict[str, str] = {}
packages = await executor_server.get_packages()
packages = await worker.get_packages()
for package in packages:
for pkg_dep in package.dependencies:
installed_version = installed_packages.get(pkg_dep.pypi_name, None)
Expand All @@ -149,7 +149,7 @@ async def get_installed_dependencies(request: Request):

@app.route("/features")
async def get_features(request: Request):
return await executor_server.proxy_request(request)
return await worker.proxy_request(request)


@app.get("/sse")
Expand All @@ -158,7 +158,7 @@ async def sse(request: Request):
response = await request.respond(headers=headers, content_type="text/event-stream")
while True:
try:
async for data in executor_server.get_sse(request):
async for data in worker.get_sse(request):
if response is not None:
await response.send(data)
except Exception:
Expand Down Expand Up @@ -196,7 +196,7 @@ async def install_deps(dependencies: list[api.Dependency]):
]
await install_dependencies(dep_info, update_progress_cb, logger)

packages = await executor_server.get_packages()
packages = await worker.get_packages()

logger.info("Checking dependencies...")

Expand Down Expand Up @@ -228,7 +228,7 @@ async def install_deps(dependencies: list[api.Dependency]):
if config.close_after_start:
flags.append("--close-after-start")

await executor_server.restart(flags)
await worker.restart(flags)
except Exception as ex:
logger.error(f"Error installing dependencies: {ex}", exc_info=True)
if config.close_after_start:
Expand Down Expand Up @@ -277,7 +277,7 @@ async def update_progress(
await update_progress("Loading Nodes...", 1.0, None)

# Wait to send backend-ready until nodes are loaded
await executor_server.wait_for_backend_ready()
await worker.wait_for_ready()

await setup_queue.put_and_wait(
{
Expand Down Expand Up @@ -305,26 +305,26 @@ async def close_server(sanic_app: Sanic):
except Exception as ex:
logger.error(f"Error waiting for server to start: {ex}")

await executor_server.stop()
await worker.stop()
sanic_app.stop()


@app.after_server_stop
async def after_server_stop(_sanic_app: Sanic, _loop: asyncio.AbstractEventLoop):
await executor_server.stop()
await worker.stop()
logger.info("Server closed.")


@app.after_server_start
async def after_server_start(sanic_app: Sanic, loop: asyncio.AbstractEventLoop):
global setup_task
await executor_server.start()
await worker.start()

# initialize the queues
ctx = AppContext.get(sanic_app)
ctx.setup_queue = EventQueue()

await executor_server.wait_for_server_start()
await worker.wait_for_ready()

# start the setup task
setup_task = loop.create_task(setup(sanic_app, loop))
Expand Down
132 changes: 65 additions & 67 deletions backend/src/server_process_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import subprocess
import sys
import threading
import time
from typing import Iterable

import aiohttp
from sanic import HTTPResponse, Request
Expand All @@ -14,104 +16,102 @@
from api import Package


def find_free_port():
def _find_free_port():
with socket.socket() as s:
s.bind(("", 0)) # Bind to a free port provided by the host.
return s.getsockname()[1] # Return the port number assigned.


class ExecutorServerWorker:
def __init__(self, port: int, flags: list[str] | None = None):
self.process = None
self.stop_event = threading.Event()
self.finished_starting = False
def _port_in_use(port: int):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("127.0.0.1", port)) == 0

self.port = port
self.flags = flags or []

def start_process(self):
class _WorkerProcess:
def __init__(self, flags: list[str]):
server_file = os.path.join(os.path.dirname(__file__), "server.py")
python_location = sys.executable
self.process = subprocess.Popen(
[python_location, server_file, str(self.port), *self.flags],

self._process = subprocess.Popen(
[python_location, server_file, *flags],
shell=False,
stdin=None,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
self._stop_event = threading.Event()

# Create a separate thread to read and print the output of the subprocess
threading.Thread(
target=self._read_output, daemon=True, name="output reader"
target=self._read_output,
daemon=True,
name="output reader",
).start()

def stop_process(self):
if self.process:
self.stop_event.set()
self.process.terminate()
self.process.kill()
def close(self):
self._stop_event.set()
self._process.terminate()
self._process.kill()

def _read_output(self):
if self.process is None or self.process.stdout is None:
if self._process.stdout is None:
return
for line in self.process.stdout:
if self.stop_event.is_set():
for line in self._process.stdout:
if self._stop_event.is_set():
break
if not self.finished_starting:
if "Starting worker" in line.decode():
self.finished_starting = True
print(line.decode().strip())


class ExecutorServer:
def __init__(self, flags: list[str] | None = None):
self.flags = flags

self.server_process = None
class WorkerServer:
def __init__(self):
self._process = None

self.port = find_free_port()
self.base_url = f"http://127.0.0.1:{self.port}"
self.session = None
self._port = _find_free_port()
self._base_url = f"http://127.0.0.1:{self._port}"
self._session = None

self.backend_ready = False

async def start(self, flags: list[str] | None = None):
del self.server_process
self.server_process = ExecutorServerWorker(self.port, flags or self.flags)
self.server_process.start_process()
self.session = aiohttp.ClientSession(base_url=self.base_url)
await self.wait_for_server_start()
await self.session.get("/nodes", timeout=None)
self.backend_ready = True
return self
async def start(self, flags: Iterable[str] = []):
logger.info("Starting worker process...")
self._process = _WorkerProcess([str(self._port), *flags])
self._session = aiohttp.ClientSession(base_url=self._base_url)
await self.wait_for_ready()
logger.info("Worker process started")

async def stop(self):
if self.server_process:
self.server_process.stop_process()
if self.session:
await self.session.close()
if self._process:
self._process.close()
if self._session:
await self._session.close()
logger.info("Worker process stopped")

async def restart(self, flags: list[str] | None = None):
async def restart(self, flags: Iterable[str] = []):
await self.stop()
await self.start(flags)

async def wait_for_server_start(self):
while (
self.server_process is None
or self.server_process.finished_starting is False
):
await asyncio.sleep(0.1)
async def wait_for_ready(self, timeout: float = 300):
start = time.time()
while time.time() - start < timeout:
if (
self._process is not None
and self._session is not None
and _port_in_use(self._port)
):
try:
await self._session.get("/nodes", timeout=5)
return
except Exception:
pass

async def wait_for_backend_ready(self):
while not self.backend_ready:
await asyncio.sleep(0.1)

raise TimeoutError("Server did not start in time")

async def proxy_request(self, request: Request, timeout: int | None = 300):
assert self.session is not None
await self.wait_for_server_start()
await self.wait_for_backend_ready()
await self.wait_for_ready()
assert self._session is not None
if request.route is None:
raise ValueError("Route not found")
async with self.session.request(
async with self._session.request(
request.method,
f"/{request.route.path}",
headers=request.headers,
Expand All @@ -129,10 +129,9 @@ async def proxy_request(self, request: Request, timeout: int | None = 300):
)

async def get_sse(self, request: Request):
assert self.session is not None
await self.wait_for_server_start()
await self.wait_for_backend_ready()
async with self.session.request(
await self.wait_for_ready()
assert self._session is not None
async with self._session.request(
request.method,
"/sse",
headers=request.headers,
Expand All @@ -143,11 +142,10 @@ async def get_sse(self, request: Request):
yield data

async def get_packages(self):
await self.wait_for_server_start()
await self.wait_for_backend_ready()
assert self.session is not None
await self.wait_for_ready()
assert self._session is not None
logger.debug("Fetching packages...")
packages_resp = await self.session.get(
packages_resp = await self._session.get(
"/packages", params={"hideInternal": "false"}
)
packages_json = await packages_resp.json()
Expand Down

0 comments on commit 6a312a8

Please sign in to comment.