Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ uv run load-data --dense-model my-dense --sparse-model my-sparse
uv run mcp-server --dense-model my-dense --sparse-model my-sparse
```

Cross-encoder reranking defaults to `cross-encoder/ms-marco-MiniLM-L-6-v2`.
Set `RERANKER_MODEL` or pass `--reranker-model` to point at a different model:

```bash
uv run mcp-server --reranker-model sentence-transformers/ms-marco-mini
```

## Docker
A Dockerfile builds a GPU-enabled image based on
`nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04` using `uv` for dependency
Expand Down
2 changes: 1 addition & 1 deletion mcp_plex/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ async def ensure_reranker(self) -> CrossEncoder | None:
return self._reranker
try:
reranker = await asyncio.to_thread(
CrossEncoder, "cross-encoder/ms-marco-MiniLM-L-6-v2"
CrossEncoder, self.settings.reranker_model
)
except Exception as exc:
logger.warning(
Expand Down
6 changes: 6 additions & 0 deletions mcp_plex/server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def main(argv: list[str] | None = None) -> None:
default=settings.sparse_model,
help="Sparse embedding model name (env: SPARSE_MODEL)",
)
parser.add_argument(
"--reranker-model",
default=settings.reranker_model,
help="Cross-encoder reranker model name (env: RERANKER_MODEL)",
)
args = parser.parse_args(argv)

env_transport = os.getenv("MCP_TRANSPORT")
Expand Down Expand Up @@ -105,6 +110,7 @@ def main(argv: list[str] | None = None) -> None:

settings.dense_model = args.dense_model
settings.sparse_model = args.sparse_model
settings.reranker_model = args.reranker_model

plex_server.run(transport=transport, **run_config.to_kwargs())

Expand Down
4 changes: 4 additions & 0 deletions mcp_plex/server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class Settings(BaseSettings):
)
cache_size: int = Field(default=128, validation_alias="CACHE_SIZE")
use_reranker: bool = Field(default=True, validation_alias="USE_RERANKER")
reranker_model: str = Field(
default="cross-encoder/ms-marco-MiniLM-L-6-v2",
validation_alias="RERANKER_MODEL",
)
plex_url: AnyHttpUrl | None = Field(default=None, validation_alias="PLEX_URL")
plex_token: str | None = Field(default=None, validation_alias="PLEX_TOKEN")
plex_player_aliases: PlexPlayerAliasMap = Field(
Expand Down
6 changes: 5 additions & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,10 @@ def __init__(self, *args, **kwargs):


def test_ensure_reranker_uses_thread_executor(monkeypatch):
monkeypatch.setenv(
"RERANKER_MODEL",
"sentence-transformers/test-cross-encoder",
)
module, Dummy = _reload_server_with_dummy_reranker(monkeypatch)
calls: list[tuple[object, tuple[object, ...], dict[str, object]]] = []

Expand All @@ -560,7 +564,7 @@ async def exercise():
assert len(calls) == 1
fn, args, _ = calls[0]
assert fn is Dummy
assert args == ("cross-encoder/ms-marco-MiniLM-L-6-v2",)
assert args == ("sentence-transformers/test-cross-encoder",)
asyncio.run(module.server.close())


Expand Down
11 changes: 11 additions & 0 deletions tests/test_server_config_additional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@
from mcp_plex.server.config import Settings


def test_settings_defaults_reranker_model():
settings = Settings()
assert settings.reranker_model == "cross-encoder/ms-marco-MiniLM-L-6-v2"


def test_settings_env_alias_for_reranker_model(monkeypatch):
monkeypatch.setenv("RERANKER_MODEL", "sentence-transformers/custom")
settings = Settings()
assert settings.reranker_model == "sentence-transformers/custom"


def test_parse_aliases_rejects_invalid_json():
with pytest.raises(ValueError):
Settings._parse_aliases("not json")
Expand Down