Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions mcp_plex/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from collections import OrderedDict
from typing import Annotated, Any

from fastembed import SparseTextEmbedding, TextEmbedding
from fastmcp.server import FastMCP
from pydantic import Field
from qdrant_client import models
from qdrant_client.async_qdrant_client import AsyncQdrantClient
from qdrant_client.hybrid.fusion import reciprocal_rank_fusion

try:
from sentence_transformers import CrossEncoder
Expand Down Expand Up @@ -43,7 +43,7 @@
if _QDRANT_URL is None and _QDRANT_HOST is None:
_QDRANT_URL = ":memory:"

# Instantiate global client and embedding models
# Instantiate global client
_client = AsyncQdrantClient(
location=_QDRANT_URL,
api_key=_QDRANT_API_KEY,
Expand All @@ -53,8 +53,6 @@
prefer_grpc=_QDRANT_PREFER_GRPC,
https=_QDRANT_HTTPS,
)
_dense_model = TextEmbedding(_DENSE_MODEL_NAME)
_sparse_model = SparseTextEmbedding(_SPARSE_MODEL_NAME)

_USE_RERANKER = os.getenv("USE_RERANKER", "1") == "1"
_reranker = None
Expand Down Expand Up @@ -179,17 +177,26 @@ async def search_media(
] = 5,
) -> list[dict[str, Any]]:
"""Hybrid similarity search across media items using dense and sparse vectors."""
dense_task = asyncio.to_thread(lambda: list(_dense_model.embed([query]))[0])
dense_vec = await dense_task
named_dense = models.NamedVector(name="dense", vector=dense_vec)
dense_doc = models.Document(text=query, model=_DENSE_MODEL_NAME)
sparse_doc = models.Document(text=query, model=_SPARSE_MODEL_NAME)
candidate_limit = limit * 3 if _reranker is not None else limit
hits = await _client.search(
collection_name="media-items",
query_vector=named_dense,
query_filter=None,
limit=candidate_limit,
with_payload=True,
dense_resp, sparse_resp = await asyncio.gather(
_client.query_points(
collection_name="media-items",
query=dense_doc,
using="dense",
limit=candidate_limit,
with_payload=True,
),
_client.query_points(
collection_name="media-items",
query=sparse_doc,
using="sparse",
limit=candidate_limit,
with_payload=True,
),
)
hits = reciprocal_rank_fusion([dense_resp.points, sparse_resp.points], limit=candidate_limit)

async def _prefetch(hit: models.ScoredPoint) -> None:
data = hit.payload["data"]
Expand Down Expand Up @@ -468,6 +475,7 @@ async def media_background(

def main(argv: list[str] | None = None) -> None:
"""CLI entrypoint for running the MCP server."""
global _DENSE_MODEL_NAME, _SPARSE_MODEL_NAME
parser = argparse.ArgumentParser(description="Run the MCP server")
parser.add_argument("--bind", help="Host address to bind to")
parser.add_argument("--port", type=int, help="Port to listen on")
Expand Down Expand Up @@ -502,9 +510,8 @@ def main(argv: list[str] | None = None) -> None:
if args.mount:
run_kwargs["path"] = args.mount

global _dense_model, _sparse_model
_dense_model = TextEmbedding(args.dense_model)
_sparse_model = SparseTextEmbedding(args.sparse_model)
_DENSE_MODEL_NAME = args.dense_model
_SPARSE_MODEL_NAME = args.sparse_model

server.run(transport=args.transport, **run_kwargs)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "mcp-plex"
version = "0.26.0"
version = "0.26.1"

description = "Plex-Oriented Model Context Protocol Server"
requires-python = ">=3.11,<4"
Expand Down
55 changes: 14 additions & 41 deletions tests/test_server_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,7 @@

import pytest

class _StubDense:
def __init__(self, *args, **kwargs) -> None:
pass

@staticmethod
def list_supported_models() -> list[str]:
return ["stub-dense"]


class _StubSparse:
def __init__(self, *args, **kwargs) -> None:
pass

@staticmethod
def list_supported_models() -> list[str]:
return ["stub"]


with patch("fastembed.TextEmbedding", _StubDense), patch(
"fastembed.SparseTextEmbedding", _StubSparse
):
from mcp_plex import server
from mcp_plex import server


def test_main_stdio_runs():
Expand Down Expand Up @@ -51,32 +30,26 @@ def test_main_http_with_mount_runs():


def test_main_model_overrides():
with patch("mcp_plex.server.TextEmbedding") as mock_dense, patch(
"mcp_plex.server.SparseTextEmbedding"
) as mock_sparse, patch.object(server.server, "run") as mock_run:
with patch.object(server.server, "run") as mock_run:
server.main([
"--dense-model",
"foo",
"--sparse-model",
"bar",
])
mock_dense.assert_called_with("foo")
mock_sparse.assert_called_with("bar")
assert server._DENSE_MODEL_NAME == "foo"
assert server._SPARSE_MODEL_NAME == "bar"
mock_run.assert_called_once_with(transport="stdio")


def test_env_model_overrides(monkeypatch):
with patch("fastembed.TextEmbedding") as mock_dense, patch(
"fastembed.SparseTextEmbedding"
) as mock_sparse:
monkeypatch.setenv("DENSE_MODEL", "foo")
monkeypatch.setenv("SPARSE_MODEL", "bar")
import importlib

importlib.reload(server)
mock_dense.assert_called_with("foo")
mock_sparse.assert_called_with("bar")
with patch("fastembed.TextEmbedding"), patch("fastembed.SparseTextEmbedding"):
import importlib

importlib.reload(server)
monkeypatch.setenv("DENSE_MODEL", "foo")
monkeypatch.setenv("SPARSE_MODEL", "bar")
import importlib

importlib.reload(server)
assert server._DENSE_MODEL_NAME == "foo"
assert server._SPARSE_MODEL_NAME == "bar"

# reload to reset globals
importlib.reload(server)
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.