From 9e522c55aa411d2924db7ed8dc6b687189430b74 Mon Sep 17 00:00:00 2001 From: Teagan Glenn Date: Sun, 31 Aug 2025 04:52:04 -0600 Subject: [PATCH] feat: allow configuring qdrant connection --- README.md | 10 ++++ mcp_plex/loader.py | 78 +++++++++++++++++++++++++++++++- mcp_plex/server.py | 21 ++++++++- tests/test_loader_integration.py | 42 +++++++++++++++-- tests/test_server.py | 26 +++++++++++ 5 files changed, 171 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index d409701..24b54ac 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,16 @@ The included `docker-compose.yml` launches both Qdrant and the MCP server. The server will connect to the `qdrant` service at `http://qdrant:6333` and expose an SSE endpoint at `http://localhost:8000/mcp`. +### Qdrant Configuration + +Connection settings can be provided via environment variables: + +- `QDRANT_URL` – full URL or SQLite path. +- `QDRANT_HOST`/`QDRANT_PORT` – HTTP host and port. +- `QDRANT_GRPC_PORT` – gRPC port. +- `QDRANT_HTTPS` – set to `1` to enable HTTPS. +- `QDRANT_PREFER_GRPC` – set to `1` to prefer gRPC. + ## Development Run linting and tests through `uv`: ```bash diff --git a/mcp_plex/loader.py b/mcp_plex/loader.py index 6a7275b..0ec6883 100644 --- a/mcp_plex/loader.py +++ b/mcp_plex/loader.py @@ -285,6 +285,11 @@ async def run( sample_dir: Optional[Path], qdrant_url: Optional[str], qdrant_api_key: Optional[str], + qdrant_host: Optional[str] = None, + qdrant_port: int = 6333, + qdrant_grpc_port: int = 6334, + qdrant_https: bool = False, + qdrant_prefer_grpc: bool = False, ) -> None: """Core execution logic for the CLI.""" @@ -323,7 +328,17 @@ async def run( dense_vectors = list(dense_model.embed(texts)) sparse_vectors = list(sparse_model.passage_embed(texts)) - client = AsyncQdrantClient(qdrant_url or ":memory:", api_key=qdrant_api_key) + if qdrant_url is None and qdrant_host is None: + qdrant_url = ":memory:" + client = AsyncQdrantClient( + location=qdrant_url, + api_key=qdrant_api_key, + host=qdrant_host, + port=qdrant_port, + grpc_port=qdrant_grpc_port, + https=qdrant_https, + prefer_grpc=qdrant_prefer_grpc, + ) collection_name = "media-items" vectors_config = { "dense": models.VectorParams( @@ -456,6 +471,47 @@ async def run( required=False, help="Qdrant API key", ) +@click.option( + "--qdrant-host", + envvar="QDRANT_HOST", + show_envvar=True, + required=False, + help="Qdrant host", +) +@click.option( + "--qdrant-port", + envvar="QDRANT_PORT", + show_envvar=True, + type=int, + default=6333, + show_default=True, + required=False, + help="Qdrant HTTP port", +) +@click.option( + "--qdrant-grpc-port", + envvar="QDRANT_GRPC_PORT", + show_envvar=True, + type=int, + default=6334, + show_default=True, + required=False, + help="Qdrant gRPC port", +) +@click.option( + "--qdrant-https/--no-qdrant-https", + envvar="QDRANT_HTTPS", + show_envvar=True, + default=False, + help="Use HTTPS when connecting to Qdrant", +) +@click.option( + "--qdrant-prefer-grpc/--no-qdrant-prefer-grpc", + envvar="QDRANT_PREFER_GRPC", + show_envvar=True, + default=False, + help="Prefer gRPC when connecting to Qdrant", +) @click.option( "--continuous", is_flag=True, @@ -479,6 +535,11 @@ def main( sample_dir: Optional[Path], qdrant_url: Optional[str], qdrant_api_key: Optional[str], + qdrant_host: Optional[str], + qdrant_port: int, + qdrant_grpc_port: int, + qdrant_https: bool, + qdrant_prefer_grpc: bool, continuous: bool, delay: float, ) -> None: @@ -492,6 +553,11 @@ def main( sample_dir, qdrant_url, qdrant_api_key, + qdrant_host, + qdrant_port, + qdrant_grpc_port, + qdrant_https, + qdrant_prefer_grpc, continuous, delay, ) @@ -505,6 +571,11 @@ async def load_media( sample_dir: Optional[Path], qdrant_url: Optional[str], qdrant_api_key: Optional[str], + qdrant_host: Optional[str], + qdrant_port: int, + qdrant_grpc_port: int, + qdrant_https: bool, + qdrant_prefer_grpc: bool, continuous: bool, delay: float, ) -> None: @@ -518,6 +589,11 @@ async def load_media( sample_dir, qdrant_url, qdrant_api_key, + qdrant_host, + qdrant_port, + qdrant_grpc_port, + qdrant_https, + qdrant_prefer_grpc, ) if not continuous: break diff --git a/mcp_plex/server.py b/mcp_plex/server.py index 375be10..2b40de3 100644 --- a/mcp_plex/server.py +++ b/mcp_plex/server.py @@ -25,11 +25,28 @@ CrossEncoder = None # Environment configuration for Qdrant -_QDRANT_URL = os.getenv("QDRANT_URL", ":memory:") +_QDRANT_URL = os.getenv("QDRANT_URL") _QDRANT_API_KEY = os.getenv("QDRANT_API_KEY") +_QDRANT_HOST = os.getenv("QDRANT_HOST") +_QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333")) +_QDRANT_GRPC_PORT = int(os.getenv("QDRANT_GRPC_PORT", "6334")) +_QDRANT_PREFER_GRPC = os.getenv("QDRANT_PREFER_GRPC", "0") == "1" +_https_env = os.getenv("QDRANT_HTTPS") +_QDRANT_HTTPS = None if _https_env is None else _https_env == "1" + +if _QDRANT_URL is None and _QDRANT_HOST is None: + _QDRANT_URL = ":memory:" # Instantiate global client and embedding models -_client = AsyncQdrantClient(_QDRANT_URL, api_key=_QDRANT_API_KEY) +_client = AsyncQdrantClient( + location=_QDRANT_URL, + api_key=_QDRANT_API_KEY, + host=_QDRANT_HOST, + port=_QDRANT_PORT, + grpc_port=_QDRANT_GRPC_PORT, + prefer_grpc=_QDRANT_PREFER_GRPC, + https=_QDRANT_HTTPS, +) _dense_model = TextEmbedding("BAAI/bge-small-en-v1.5") _sparse_model = SparseTextEmbedding("Qdrant/bm42-all-minilm-l6-v2-attentions") diff --git a/tests/test_loader_integration.py b/tests/test_loader_integration.py index c2b97ee..cda92c8 100644 --- a/tests/test_loader_integration.py +++ b/tests/test_loader_integration.py @@ -39,9 +39,10 @@ def passage_embed(self, texts): class DummyQdrantClient: instance = None - def __init__(self, url: str, api_key: str | None = None): + def __init__(self, url: str | None = None, api_key: str | None = None, **kwargs): self.collections = {} self.upserted = [] + self.kwargs = kwargs DummyQdrantClient.instance = self async def collection_exists(self, name: str) -> bool: @@ -68,8 +69,8 @@ async def upsert(self, collection_name: str, points): class TrackingQdrantClient(DummyQdrantClient): """Qdrant client that starts with a mismatched collection size.""" - def __init__(self, url: str, api_key: str | None = None): - super().__init__(url, api_key) + def __init__(self, url: str | None = None, api_key: str | None = None, **kwargs): + super().__init__(url, api_key, **kwargs) # Pre-create a collection with the wrong vector size to force recreation wrong_params = SimpleNamespace( vectors={ @@ -117,3 +118,38 @@ def test_run_recreates_mismatched_collection(monkeypatch): client.collections["media-items"].config.params.vectors["dense"].size == 3 ) + + +def test_run_uses_connection_options(monkeypatch): + monkeypatch.setattr(loader, "TextEmbedding", DummyTextEmbedding) + monkeypatch.setattr(loader, "SparseTextEmbedding", DummySparseEmbedding) + + captured = {} + + class CaptureClient(DummyQdrantClient): + def __init__(self, url: str | None = None, api_key: str | None = None, **kwargs): + super().__init__(url, api_key, **kwargs) + captured.update(kwargs) + + monkeypatch.setattr(loader, "AsyncQdrantClient", CaptureClient) + sample_dir = Path(__file__).resolve().parents[1] / "sample-data" + asyncio.run( + loader.run( + None, + None, + None, + sample_dir, + None, + None, + qdrant_host="example", + qdrant_port=1111, + qdrant_grpc_port=2222, + qdrant_https=True, + qdrant_prefer_grpc=True, + ) + ) + assert captured["host"] == "example" + assert captured["port"] == 1111 + assert captured["grpc_port"] == 2222 + assert captured["https"] is True + assert captured["prefer_grpc"] is True diff --git a/tests/test_server.py b/tests/test_server.py index f93e717..d4b7819 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -144,6 +144,32 @@ async def _setup_db(tmp_path: Path) -> str: return "dummy" +def test_qdrant_env_config(monkeypatch): + from qdrant_client import async_qdrant_client + + captured = {} + + class CaptureClient: + def __init__(self, *args, **kwargs): + captured.update(kwargs) + + monkeypatch.setattr(async_qdrant_client, "AsyncQdrantClient", CaptureClient) + monkeypatch.setenv("QDRANT_HOST", "example.com") + monkeypatch.setenv("QDRANT_PORT", "1234") + monkeypatch.setenv("QDRANT_GRPC_PORT", "5678") + monkeypatch.setenv("QDRANT_PREFER_GRPC", "1") + monkeypatch.setenv("QDRANT_HTTPS", "1") + import importlib + import mcp_plex.server as server + importlib.reload(server) + + assert captured["host"] == "example.com" + assert captured["port"] == 1234 + assert captured["grpc_port"] == 5678 + assert captured["prefer_grpc"] is True + assert captured["https"] is True + + def test_server_tools(tmp_path, monkeypatch): # Patch embeddings and Qdrant client to use dummy implementations monkeypatch.setattr(loader, "TextEmbedding", DummyTextEmbedding)