From deede45db4fa740f6d9198271301197e4b87d211 Mon Sep 17 00:00:00 2001 From: Teagan Glenn Date: Sat, 11 Oct 2025 17:39:02 -0600 Subject: [PATCH] refactor(docs): describe reranker model overrides --- README.md | 7 +++++++ mcp_plex/server/__init__.py | 2 +- mcp_plex/server/cli.py | 6 ++++++ mcp_plex/server/config.py | 4 ++++ tests/test_server.py | 6 +++++- tests/test_server_config_additional.py | 11 +++++++++++ 6 files changed, 34 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 19d93bd..1111d16 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,13 @@ uv run load-data --dense-model my-dense --sparse-model my-sparse uv run mcp-server --dense-model my-dense --sparse-model my-sparse ``` +Cross-encoder reranking defaults to `cross-encoder/ms-marco-MiniLM-L-6-v2`. +Set `RERANKER_MODEL` or pass `--reranker-model` to point at a different model: + +```bash +uv run mcp-server --reranker-model sentence-transformers/ms-marco-mini +``` + ## Docker A Dockerfile builds a GPU-enabled image based on `nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04` using `uv` for dependency diff --git a/mcp_plex/server/__init__.py b/mcp_plex/server/__init__.py index 461e033..70c2208 100644 --- a/mcp_plex/server/__init__.py +++ b/mcp_plex/server/__init__.py @@ -170,7 +170,7 @@ async def ensure_reranker(self) -> CrossEncoder | None: return self._reranker try: reranker = await asyncio.to_thread( - CrossEncoder, "cross-encoder/ms-marco-MiniLM-L-6-v2" + CrossEncoder, self.settings.reranker_model ) except Exception as exc: logger.warning( diff --git a/mcp_plex/server/cli.py b/mcp_plex/server/cli.py index 9ec13b6..d543982 100644 --- a/mcp_plex/server/cli.py +++ b/mcp_plex/server/cli.py @@ -56,6 +56,11 @@ def main(argv: list[str] | None = None) -> None: default=settings.sparse_model, help="Sparse embedding model name (env: SPARSE_MODEL)", ) + parser.add_argument( + "--reranker-model", + default=settings.reranker_model, + help="Cross-encoder reranker model name (env: RERANKER_MODEL)", + ) args = parser.parse_args(argv) env_transport = os.getenv("MCP_TRANSPORT") @@ -105,6 +110,7 @@ def main(argv: list[str] | None = None) -> None: settings.dense_model = args.dense_model settings.sparse_model = args.sparse_model + settings.reranker_model = args.reranker_model plex_server.run(transport=transport, **run_config.to_kwargs()) diff --git a/mcp_plex/server/config.py b/mcp_plex/server/config.py index 3bc6cfc..d930833 100644 --- a/mcp_plex/server/config.py +++ b/mcp_plex/server/config.py @@ -35,6 +35,10 @@ class Settings(BaseSettings): ) cache_size: int = Field(default=128, validation_alias="CACHE_SIZE") use_reranker: bool = Field(default=True, validation_alias="USE_RERANKER") + reranker_model: str = Field( + default="cross-encoder/ms-marco-MiniLM-L-6-v2", + validation_alias="RERANKER_MODEL", + ) plex_url: AnyHttpUrl | None = Field(default=None, validation_alias="PLEX_URL") plex_token: str | None = Field(default=None, validation_alias="PLEX_TOKEN") plex_player_aliases: PlexPlayerAliasMap = Field( diff --git a/tests/test_server.py b/tests/test_server.py index 7a3f2cd..1cdcf26 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -541,6 +541,10 @@ def __init__(self, *args, **kwargs): def test_ensure_reranker_uses_thread_executor(monkeypatch): + monkeypatch.setenv( + "RERANKER_MODEL", + "sentence-transformers/test-cross-encoder", + ) module, Dummy = _reload_server_with_dummy_reranker(monkeypatch) calls: list[tuple[object, tuple[object, ...], dict[str, object]]] = [] @@ -560,7 +564,7 @@ async def exercise(): assert len(calls) == 1 fn, args, _ = calls[0] assert fn is Dummy - assert args == ("cross-encoder/ms-marco-MiniLM-L-6-v2",) + assert args == ("sentence-transformers/test-cross-encoder",) asyncio.run(module.server.close()) diff --git a/tests/test_server_config_additional.py b/tests/test_server_config_additional.py index 0648644..9075843 100644 --- a/tests/test_server_config_additional.py +++ b/tests/test_server_config_additional.py @@ -5,6 +5,17 @@ from mcp_plex.server.config import Settings +def test_settings_defaults_reranker_model(): + settings = Settings() + assert settings.reranker_model == "cross-encoder/ms-marco-MiniLM-L-6-v2" + + +def test_settings_env_alias_for_reranker_model(monkeypatch): + monkeypatch.setenv("RERANKER_MODEL", "sentence-transformers/custom") + settings = Settings() + assert settings.reranker_model == "sentence-transformers/custom" + + def test_parse_aliases_rejects_invalid_json(): with pytest.raises(ValueError): Settings._parse_aliases("not json")