From a3cfccc454e7b841e143c56c93efa6548a59de2b Mon Sep 17 00:00:00 2001 From: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> Date: Wed, 10 Dec 2025 09:51:48 -0700 Subject: [PATCH] :construction_worker::art: Format and lint server Signed-off-by: Evaline Ju <69598118+evaline-ju@users.noreply.github.com> --- .github/workflows/ci.yaml | 47 +++++++++++++ .pre-commit-config.yaml | 9 +++ src/server.py | 143 +++++++++++++++++++++----------------- 3 files changed, 136 insertions(+), 63 deletions(-) create mode 100644 .github/workflows/ci.yaml create mode 100644 .pre-commit-config.yaml diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..6c01f01 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,47 @@ +name: CI + +on: + # Triggers the workflow on push or pull request events but only for the "main" branch + pull_request: + branches: [ "main" ] + +jobs: + build: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.11", "3.12"] + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + + # Sets up a specific version of Python + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pre-commit uv + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + + - name: Run pre-commit hooks + run: | + echo "Running pre-commit hooks..." + pre-commit run --all-files --verbose || { + echo "❌ Pre-commit hooks failed!" + echo "" + echo "Files modified by hooks:" + git diff --name-only + echo "" + echo "Detailed changes:" + git diff --stat + exit 1 + } + + # Future: testing + # - name: Test diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a2cd02a --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,9 @@ +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.14.8 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format diff --git a/src/server.py b/src/server.py index b33e4b6..674a05f 100644 --- a/src/server.py +++ b/src/server.py @@ -1,34 +1,34 @@ -# Translated from Rust WASM filter with help of generative model +# Standard import asyncio -import re import logging from typing import AsyncIterator -from concurrent import futures - +import json +import os import grpc + from envoy.service.ext_proc.v3 import external_processor_pb2 as ep from envoy.service.ext_proc.v3 import external_processor_pb2_grpc as ep_grpc from envoy.config.core.v3 import base_pb2 as core from envoy.type.v3 import http_status_pb2 as http_status_pb2 - # plugin manager -import sys -import json -# sys.path.append("/app/apex") -import os # First-Party -#from apex.mcp.entities.models import HookType, Message, PromptResult, Role, TextContent, PromptPosthookPayload, PromptPrehookPayload +# from apex.mcp.entities.models import HookType, Message, PromptResult, Role, TextContent, PromptPosthookPayload, PromptPrehookPayload # import apex.mcp.entities.models as apex # import mcpgateway.plugins.tools.models as apex -from mcpgateway.plugins.framework import PromptHookType, ToolHookType, HttpHeaderPayload, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload +from mcpgateway.plugins.framework import ( + ToolHookType, + PromptPrehookPayload, + ToolPostInvokePayload, + ToolPreInvokePayload, +) from mcpgateway.plugins.framework import PluginManager from mcpgateway.plugins.framework.models import GlobalContext # from apex.framework.manager import PluginManager -#from apex.framework.models import GlobalContext +# from apex.framework.models import GlobalContext # from plugins.regex_filter.search_replace import SearchReplaceConfig -log_level = os.environ.get('LOGLEVEL', 'INFO').upper() +log_level = os.environ.get("LOGLEVEL", "INFO").upper() logging.basicConfig(level=log_level) logger = logging.getLogger("ext-proc-PM") @@ -42,23 +42,25 @@ async def getToolPostInvokeResponse(body): - #FIXME: size of content array is expected to be 1 - #for content in body["result"]["content"]: + # FIXME: size of content array is expected to be 1 + # for content in body["result"]["content"]: logger.debug("**** Tool Post Invoke ****") - payload = ToolPostInvokePayload(name="replaceme", result = body) - #TODO: hard-coded ids + payload = ToolPostInvokePayload(name="replaceme", result=body) + # TODO: hard-coded ids logger.debug("**** Tool Post Invoke result ****") logger.deub(payload) global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, payload, global_context=global_context) + result, contexts = await manager.invoke_hook( + ToolHookType.TOOL_POST_INVOKE, payload, global_context=global_context + ) logger.info(result) if not result.continue_processing: body_resp = ep.ProcessingResponse( immediate_response=ep.ImmediateResponse( - #TODO: hard-coded error reason - status=http_status_pb2.HttpStatus(code = http_status_pb2.Forbidden), - details="No go" + # TODO: hard-coded error reason + status=http_status_pb2.HttpStatus(code=http_status_pb2.Forbidden), + details="No go", ) ) else: @@ -70,40 +72,42 @@ async def getToolPostInvokeResponse(body): body_resp = ep.ProcessingResponse( request_body=ep.BodyResponse( response=ep.CommonResponse( - body_mutation=ep.BodyMutation( - body=json.dumps(body).encode("utf-8") - ) + body_mutation=ep.BodyMutation(body=json.dumps(body).encode("utf-8")) ) ) ) return body_resp + async def getToolPreInvokeResponse(body): logger.debug(body) - payload_args = { "tool_name": body["params"]['name'], - "tool_args": body["params"]["arguments"], - "session_id": "replaceme" - } - payload = ToolPreInvokePayload(name=body["params"]["name"], args = payload_args) - #TODO: hard-coded ids + payload_args = { + "tool_name": body["params"]["name"], + "tool_args": body["params"]["arguments"], + "session_id": "replaceme", + } + payload = ToolPreInvokePayload(name=body["params"]["name"], args=payload_args) + # TODO: hard-coded ids global_context = GlobalContext(request_id="1", server_id="2") logger.debug("**** Invoking Tool Pre Invoke with payload ****") logger.debug(payload) - result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, payload, global_context=global_context) + result, contexts = await manager.invoke_hook( + ToolHookType.TOOL_PRE_INVOKE, payload, global_context=global_context + ) logger.debug("**** Tool Pre Invoke Result ****") logger.info(result) if not result.continue_processing: body_resp = ep.ProcessingResponse( immediate_response=ep.ImmediateResponse( - status=http_status_pb2.HttpStatus(code = http_status_pb2.Forbidden), - details="No go" + status=http_status_pb2.HttpStatus(code=http_status_pb2.Forbidden), + details="No go", ) ) else: logger.debug(result) print(result) result_payload = result.modified_payload - if result_payload is not None and result_payload.args is not None: + if result_payload is not None and result_payload.args is not None: body["params"]["arguments"] = result_payload.args # else: # body["params"]["arguments"] = None @@ -111,9 +115,7 @@ async def getToolPreInvokeResponse(body): body_resp = ep.ProcessingResponse( request_body=ep.BodyResponse( response=ep.CommonResponse( - body_mutation=ep.BodyMutation( - body=json.dumps(body).encode("utf-8") - ) + body_mutation=ep.BodyMutation(body=json.dumps(body).encode("utf-8")) ) ) ) @@ -121,17 +123,22 @@ async def getToolPreInvokeResponse(body): logger.info(body_resp) return body_resp + async def getPromptPreFetchResponse(body): - prompt = PromptPrehookPayload(name=body["params"]["name"], args = body["params"]["arguments"]) - #TODO: hard-coded ids + prompt = PromptPrehookPayload( + name=body["params"]["name"], args=body["params"]["arguments"] + ) + # TODO: hard-coded ids global_context = GlobalContext(request_id="1", server_id="2") - result, contexts = await manager.invoke_hook(ToolHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context) + result, contexts = await manager.invoke_hook( + ToolHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context + ) logger.info(result) if not result.continue_processing: body_resp = ep.ProcessingResponse( immediate_response=ep.ImmediateResponse( - status=http_status_pb2.HttpStatus(code = http_status_pb2.Forbidden), - details="No go" + status=http_status_pb2.HttpStatus(code=http_status_pb2.Forbidden), + details="No go", ) ) else: @@ -139,9 +146,7 @@ async def getPromptPreFetchResponse(body): body_resp = ep.ProcessingResponse( request_body=ep.BodyResponse( response=ep.CommonResponse( - body_mutation=ep.BodyMutation( - body=json.dumps(body).encode("utf-8") - ) + body_mutation=ep.BodyMutation(body=json.dumps(body).encode("utf-8")) ) ) ) @@ -149,8 +154,11 @@ async def getPromptPreFetchResponse(body): logger.info(body_resp) return body_resp + class ExtProcServicer(ep_grpc.ExternalProcessorServicer): - async def Process(self, request_iterator: AsyncIterator[ep.ProcessingRequest], context) -> AsyncIterator[ep.ProcessingResponse]: + async def Process( + self, request_iterator: AsyncIterator[ep.ProcessingRequest], context + ) -> AsyncIterator[ep.ProcessingResponse]: req_body_buf = bytearray() resp_body_buf = bytearray() @@ -158,7 +166,7 @@ async def Process(self, request_iterator: AsyncIterator[ep.ProcessingRequest], c # logger.info(request) if request.HasField("request_headers"): # Modify request headers - headers = request.request_headers.headers + _headers = request.request_headers.headers yield ep.ProcessingResponse( request_headers=ep.HeadersResponse( response=ep.CommonResponse( @@ -166,9 +174,12 @@ async def Process(self, request_iterator: AsyncIterator[ep.ProcessingRequest], c set_headers=[ core.HeaderValueOption( header=core.HeaderValue( - key="x-ext-proc-header", raw_value="hello-from-ext-proc".encode('utf-8') + key="x-ext-proc-header", + raw_value="hello-from-ext-proc".encode( + "utf-8" + ), ), - append_action=core.HeaderValueOption.APPEND_IF_EXISTS_OR_ADD + append_action=core.HeaderValueOption.APPEND_IF_EXISTS_OR_ADD, ) ] ) @@ -177,7 +188,7 @@ async def Process(self, request_iterator: AsyncIterator[ep.ProcessingRequest], c ) elif request.HasField("response_headers"): # Modify response headers - headers = request.response_headers.headers + _headers = request.response_headers.headers yield ep.ProcessingResponse( response_headers=ep.HeadersResponse( response=ep.CommonResponse( @@ -185,9 +196,12 @@ async def Process(self, request_iterator: AsyncIterator[ep.ProcessingRequest], c set_headers=[ core.HeaderValueOption( header=core.HeaderValue( - key="x-ext-proc-response-header", raw_value="processed-by-ext-proc".encode('utf-8') + key="x-ext-proc-response-header", + raw_value="processed-by-ext-proc".encode( + "utf-8" + ), ), - append_action=core.HeaderValueOption.APPEND_IF_EXISTS_OR_ADD + append_action=core.HeaderValueOption.APPEND_IF_EXISTS_OR_ADD, ) ] ) @@ -207,9 +221,9 @@ async def Process(self, request_iterator: AsyncIterator[ep.ProcessingRequest], c else: logger.info(json.loads(text)) body = json.loads(text) - if 'method' in body and body['method'] == "tools/call": + if "method" in body and body["method"] == "tools/call": body_resp = await getToolPreInvokeResponse(body) - elif 'method' in body and body['method'] == "prompts/get": + elif "method" in body and body["method"] == "prompts/get": body_resp = await getPromptPreFetchResponse(body) else: body_resp = ep.ProcessingResponse( @@ -233,13 +247,13 @@ async def Process(self, request_iterator: AsyncIterator[ep.ProcessingRequest], c logger.debug("Response body not UTF-8; skipping") else: logger.info(text.split("\n")) - #find data key + # find data key data = [d for d in text.split("\n") if d.startswith("data:")] - #logger.info(json.loads(data[0].strip("data:"))) - if data: #List can be empty + # logger.info(json.loads(data[0].strip("data:"))) + if data: # List can be empty data = json.loads(data[0].strip("data:")) - #TODO: check for tool call - if 'result' in data: + # TODO: check for tool call + if "result" in data: body_resp = await getToolPostInvokeResponse(data) else: body_resp = ep.ProcessingResponse( @@ -254,12 +268,13 @@ async def Process(self, request_iterator: AsyncIterator[ep.ProcessingRequest], c else: logger.warn("Not processed") + async def serve(host: str = "0.0.0.0", port: int = 50052): await manager.initialize() logger.info(manager.config) server = grpc.aio.server() - #server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + # server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) ep_grpc.add_ExternalProcessorServicer_to_server(ExtProcServicer(), server) listen_addr = f"{host}:{port}" server.add_insecure_port(listen_addr) @@ -268,15 +283,17 @@ async def serve(host: str = "0.0.0.0", port: int = 50052): # wait forever await server.wait_for_termination() + if __name__ == "__main__": try: logging.getLogger("mcpgateway.config").setLevel(logging.DEBUG) logging.getLogger("mcpgateway.observability").setLevel(logging.DEBUG) logger.info("Manager main") - pm_config = os.environ.get('PLUGIN_MANAGER_CONFIG', './resources/config/config.yaml') + pm_config = os.environ.get( + "PLUGIN_MANAGER_CONFIG", "./resources/config/config.yaml" + ) manager = PluginManager(pm_config) asyncio.run(serve()) - #serve() + # serve() except KeyboardInterrupt: logger.info("Shutting down") -