Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
name: CI

on:
# Triggers the workflow on push or pull request events but only for the "main" branch
pull_request:
branches: [ "main" ]

jobs:
build:
runs-on: ubuntu-latest

strategy:
matrix:
python-version: ["3.11", "3.12"]

steps:
- name: Checkout repository
uses: actions/checkout@v6

# Sets up a specific version of Python
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pre-commit uv
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi

- name: Run pre-commit hooks
run: |
echo "Running pre-commit hooks..."
pre-commit run --all-files --verbose || {
echo "❌ Pre-commit hooks failed!"
echo ""
echo "Files modified by hooks:"
git diff --name-only
echo ""
echo "Detailed changes:"
git diff --stat
exit 1
}

# Future: testing
# - name: Test
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.8
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
143 changes: 80 additions & 63 deletions src/server.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
# Translated from Rust WASM filter with help of generative model
# Standard
import asyncio
import re
import logging
from typing import AsyncIterator
from concurrent import futures

import json
import os
import grpc

from envoy.service.ext_proc.v3 import external_processor_pb2 as ep
from envoy.service.ext_proc.v3 import external_processor_pb2_grpc as ep_grpc
from envoy.config.core.v3 import base_pb2 as core
from envoy.type.v3 import http_status_pb2 as http_status_pb2


# plugin manager
import sys
import json
# sys.path.append("/app/apex")
import os
# First-Party
#from apex.mcp.entities.models import HookType, Message, PromptResult, Role, TextContent, PromptPosthookPayload, PromptPrehookPayload
# from apex.mcp.entities.models import HookType, Message, PromptResult, Role, TextContent, PromptPosthookPayload, PromptPrehookPayload
# import apex.mcp.entities.models as apex
# import mcpgateway.plugins.tools.models as apex
from mcpgateway.plugins.framework import PromptHookType, ToolHookType, HttpHeaderPayload, PromptPosthookPayload, PromptPrehookPayload, ToolPostInvokePayload, ToolPreInvokePayload
from mcpgateway.plugins.framework import (
ToolHookType,
PromptPrehookPayload,
ToolPostInvokePayload,
ToolPreInvokePayload,
)
from mcpgateway.plugins.framework import PluginManager
from mcpgateway.plugins.framework.models import GlobalContext
# from apex.framework.manager import PluginManager
#from apex.framework.models import GlobalContext
# from apex.framework.models import GlobalContext
# from plugins.regex_filter.search_replace import SearchReplaceConfig

log_level = os.environ.get('LOGLEVEL', 'INFO').upper()
log_level = os.environ.get("LOGLEVEL", "INFO").upper()

logging.basicConfig(level=log_level)
logger = logging.getLogger("ext-proc-PM")
Expand All @@ -42,23 +42,25 @@


async def getToolPostInvokeResponse(body):
#FIXME: size of content array is expected to be 1
#for content in body["result"]["content"]:
# FIXME: size of content array is expected to be 1
# for content in body["result"]["content"]:

logger.debug("**** Tool Post Invoke ****")
payload = ToolPostInvokePayload(name="replaceme", result = body)
#TODO: hard-coded ids
payload = ToolPostInvokePayload(name="replaceme", result=body)
# TODO: hard-coded ids
logger.debug("**** Tool Post Invoke result ****")
logger.deub(payload)
global_context = GlobalContext(request_id="1", server_id="2")
result, contexts = await manager.invoke_hook(ToolHookType.TOOL_POST_INVOKE, payload, global_context=global_context)
result, contexts = await manager.invoke_hook(
ToolHookType.TOOL_POST_INVOKE, payload, global_context=global_context
)
logger.info(result)
if not result.continue_processing:
body_resp = ep.ProcessingResponse(
immediate_response=ep.ImmediateResponse(
#TODO: hard-coded error reason
status=http_status_pb2.HttpStatus(code = http_status_pb2.Forbidden),
details="No go"
# TODO: hard-coded error reason
status=http_status_pb2.HttpStatus(code=http_status_pb2.Forbidden),
details="No go",
)
)
else:
Expand All @@ -70,105 +72,114 @@ async def getToolPostInvokeResponse(body):
body_resp = ep.ProcessingResponse(
request_body=ep.BodyResponse(
response=ep.CommonResponse(
body_mutation=ep.BodyMutation(
body=json.dumps(body).encode("utf-8")
)
body_mutation=ep.BodyMutation(body=json.dumps(body).encode("utf-8"))
)
)
)
return body_resp


async def getToolPreInvokeResponse(body):
logger.debug(body)
payload_args = { "tool_name": body["params"]['name'],
"tool_args": body["params"]["arguments"],
"session_id": "replaceme"
}
payload = ToolPreInvokePayload(name=body["params"]["name"], args = payload_args)
#TODO: hard-coded ids
payload_args = {
"tool_name": body["params"]["name"],
"tool_args": body["params"]["arguments"],
"session_id": "replaceme",
}
payload = ToolPreInvokePayload(name=body["params"]["name"], args=payload_args)
# TODO: hard-coded ids
global_context = GlobalContext(request_id="1", server_id="2")
logger.debug("**** Invoking Tool Pre Invoke with payload ****")
logger.debug(payload)
result, contexts = await manager.invoke_hook(ToolHookType.TOOL_PRE_INVOKE, payload, global_context=global_context)
result, contexts = await manager.invoke_hook(
ToolHookType.TOOL_PRE_INVOKE, payload, global_context=global_context
)
logger.debug("**** Tool Pre Invoke Result ****")
logger.info(result)
if not result.continue_processing:
body_resp = ep.ProcessingResponse(
immediate_response=ep.ImmediateResponse(
status=http_status_pb2.HttpStatus(code = http_status_pb2.Forbidden),
details="No go"
status=http_status_pb2.HttpStatus(code=http_status_pb2.Forbidden),
details="No go",
)
)
else:
logger.debug(result)
print(result)
result_payload = result.modified_payload
if result_payload is not None and result_payload.args is not None:
if result_payload is not None and result_payload.args is not None:
body["params"]["arguments"] = result_payload.args
# else:
# body["params"]["arguments"] = None

body_resp = ep.ProcessingResponse(
request_body=ep.BodyResponse(
response=ep.CommonResponse(
body_mutation=ep.BodyMutation(
body=json.dumps(body).encode("utf-8")
)
body_mutation=ep.BodyMutation(body=json.dumps(body).encode("utf-8"))
)
)
)
logger.info("****Tool Pre Invoke Return body****")
logger.info(body_resp)
return body_resp


async def getPromptPreFetchResponse(body):
prompt = PromptPrehookPayload(name=body["params"]["name"], args = body["params"]["arguments"])
#TODO: hard-coded ids
prompt = PromptPrehookPayload(
name=body["params"]["name"], args=body["params"]["arguments"]
)
# TODO: hard-coded ids
global_context = GlobalContext(request_id="1", server_id="2")
result, contexts = await manager.invoke_hook(ToolHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context)
result, contexts = await manager.invoke_hook(
ToolHookType.PROMPT_PRE_FETCH, prompt, global_context=global_context
)
logger.info(result)
if not result.continue_processing:
body_resp = ep.ProcessingResponse(
immediate_response=ep.ImmediateResponse(
status=http_status_pb2.HttpStatus(code = http_status_pb2.Forbidden),
details="No go"
status=http_status_pb2.HttpStatus(code=http_status_pb2.Forbidden),
details="No go",
)
)
else:
body["params"]["arguments"] = result.modified_payload.args
body_resp = ep.ProcessingResponse(
request_body=ep.BodyResponse(
response=ep.CommonResponse(
body_mutation=ep.BodyMutation(
body=json.dumps(body).encode("utf-8")
)
body_mutation=ep.BodyMutation(body=json.dumps(body).encode("utf-8"))
)
)
)
logger.info("****body ")
logger.info(body_resp)
return body_resp


class ExtProcServicer(ep_grpc.ExternalProcessorServicer):
async def Process(self, request_iterator: AsyncIterator[ep.ProcessingRequest], context) -> AsyncIterator[ep.ProcessingResponse]:
async def Process(
self, request_iterator: AsyncIterator[ep.ProcessingRequest], context
) -> AsyncIterator[ep.ProcessingResponse]:
req_body_buf = bytearray()
resp_body_buf = bytearray()

async for request in request_iterator:
# logger.info(request)
if request.HasField("request_headers"):
# Modify request headers
headers = request.request_headers.headers
_headers = request.request_headers.headers
yield ep.ProcessingResponse(
request_headers=ep.HeadersResponse(
response=ep.CommonResponse(
header_mutation=ep.HeaderMutation(
set_headers=[
core.HeaderValueOption(
header=core.HeaderValue(
key="x-ext-proc-header", raw_value="hello-from-ext-proc".encode('utf-8')
key="x-ext-proc-header",
raw_value="hello-from-ext-proc".encode(
"utf-8"
),
),
append_action=core.HeaderValueOption.APPEND_IF_EXISTS_OR_ADD
append_action=core.HeaderValueOption.APPEND_IF_EXISTS_OR_ADD,
)
]
)
Expand All @@ -177,17 +188,20 @@ async def Process(self, request_iterator: AsyncIterator[ep.ProcessingRequest], c
)
elif request.HasField("response_headers"):
# Modify response headers
headers = request.response_headers.headers
_headers = request.response_headers.headers
yield ep.ProcessingResponse(
response_headers=ep.HeadersResponse(
response=ep.CommonResponse(
header_mutation=ep.HeaderMutation(
set_headers=[
core.HeaderValueOption(
header=core.HeaderValue(
key="x-ext-proc-response-header", raw_value="processed-by-ext-proc".encode('utf-8')
key="x-ext-proc-response-header",
raw_value="processed-by-ext-proc".encode(
"utf-8"
),
),
append_action=core.HeaderValueOption.APPEND_IF_EXISTS_OR_ADD
append_action=core.HeaderValueOption.APPEND_IF_EXISTS_OR_ADD,
)
]
)
Expand All @@ -207,9 +221,9 @@ async def Process(self, request_iterator: AsyncIterator[ep.ProcessingRequest], c
else:
logger.info(json.loads(text))
body = json.loads(text)
if 'method' in body and body['method'] == "tools/call":
if "method" in body and body["method"] == "tools/call":
body_resp = await getToolPreInvokeResponse(body)
elif 'method' in body and body['method'] == "prompts/get":
elif "method" in body and body["method"] == "prompts/get":
body_resp = await getPromptPreFetchResponse(body)
else:
body_resp = ep.ProcessingResponse(
Expand All @@ -233,13 +247,13 @@ async def Process(self, request_iterator: AsyncIterator[ep.ProcessingRequest], c
logger.debug("Response body not UTF-8; skipping")
else:
logger.info(text.split("\n"))
#find data key
# find data key
data = [d for d in text.split("\n") if d.startswith("data:")]
#logger.info(json.loads(data[0].strip("data:")))
if data: #List can be empty
# logger.info(json.loads(data[0].strip("data:")))
if data: # List can be empty
data = json.loads(data[0].strip("data:"))
#TODO: check for tool call
if 'result' in data:
# TODO: check for tool call
if "result" in data:
body_resp = await getToolPostInvokeResponse(data)
else:
body_resp = ep.ProcessingResponse(
Expand All @@ -254,12 +268,13 @@ async def Process(self, request_iterator: AsyncIterator[ep.ProcessingRequest], c
else:
logger.warn("Not processed")


async def serve(host: str = "0.0.0.0", port: int = 50052):
await manager.initialize()
logger.info(manager.config)

server = grpc.aio.server()
#server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
# server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
ep_grpc.add_ExternalProcessorServicer_to_server(ExtProcServicer(), server)
listen_addr = f"{host}:{port}"
server.add_insecure_port(listen_addr)
Expand All @@ -268,15 +283,17 @@ async def serve(host: str = "0.0.0.0", port: int = 50052):
# wait forever
await server.wait_for_termination()


if __name__ == "__main__":
try:
logging.getLogger("mcpgateway.config").setLevel(logging.DEBUG)
logging.getLogger("mcpgateway.observability").setLevel(logging.DEBUG)
logger.info("Manager main")
pm_config = os.environ.get('PLUGIN_MANAGER_CONFIG', './resources/config/config.yaml')
pm_config = os.environ.get(
"PLUGIN_MANAGER_CONFIG", "./resources/config/config.yaml"
)
manager = PluginManager(pm_config)
asyncio.run(serve())
#serve()
# serve()
except KeyboardInterrupt:
logger.info("Shutting down")