Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
uv sync --extra dev
```

## Versioning
- Bump the version in `pyproject.toml` for any user-facing change.
- Update `uv.lock` after version or dependency changes by running `uv lock`.

## Checks
- Run linting with `uv run ruff check .`.
- Run the test suite with `uv run pytest` and ensure it passes before committing.
Expand Down
12 changes: 4 additions & 8 deletions mcp_plex/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ async def _find_records(identifier: str, limit: int = 5) -> list[models.Record]:
points, _ = await _client.scroll(
collection_name="media-items",
limit=limit,
filter=flt,
scroll_filter=flt,
with_payload=True,
)
return points
Expand Down Expand Up @@ -174,18 +174,13 @@ async def search_media(
) -> list[dict[str, Any]]:
"""Hybrid similarity search across media items using dense and sparse vectors."""
dense_task = asyncio.to_thread(lambda: list(_dense_model.embed([query]))[0])
sparse_task = asyncio.to_thread(lambda: next(_sparse_model.query_embed(query)))
dense_vec, sparse_vec = await asyncio.gather(dense_task, sparse_task)
dense_vec = await dense_task
named_dense = models.NamedVector(name="dense", vector=dense_vec)
sv = models.SparseVector(
indices=sparse_vec.indices.tolist(), values=sparse_vec.values.tolist()
)
named_sparse = models.NamedSparseVector(name="sparse", vector=sv)
candidate_limit = limit * 3 if _reranker is not None else limit
hits = await _client.search(
collection_name="media-items",
query_vector=named_dense,
query_sparse_vector=named_sparse,
query_filter=None,
limit=candidate_limit,
with_payload=True,
)
Expand Down Expand Up @@ -261,6 +256,7 @@ async def recommend_media(
positive=[record.id],
limit=limit,
with_payload=True,
using="dense",
)
return [r.payload["data"] for r in recs]

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
dev = [
"ruff>=0.12.0",
"pytest>=8.4.1",
"sentence-transformers>=2.7.0",
]

[project.scripts]
Expand Down
159 changes: 22 additions & 137 deletions tests/test_loader_integration.py
Original file line number Diff line number Diff line change
@@ -1,155 +1,40 @@
from __future__ import annotations

import asyncio
from pathlib import Path
from types import SimpleNamespace

from qdrant_client import models
from qdrant_client.async_qdrant_client import AsyncQdrantClient

from mcp_plex import loader


class DummyTextEmbedding:
    """Test double for the loader's ``TextEmbedding``.

    Every input text is mapped to the same three-element vector, so no real
    embedding model is needed.
    """

    def __init__(self, name: str):
        # ``name`` is accepted so the constructor can be called with a model
        # name; it is not used.
        self.embedding_size = 3

    def embed(self, texts):
        """Yield ``[0.1, 0.2, 0.3]`` once for each item in *texts*."""
        for _ in texts:
            yield [0.1, 0.2, 0.3]


class DummyArray(list):
    """``list`` subclass that additionally offers a ``tolist()`` method."""

    def tolist(self):
        """Return the contents as a new plain ``list``."""
        return list(self)


class DummySparseVector:
    """Sparse-vector test double holding ``indices`` and ``values``.

    Both attributes are wrapped in ``DummyArray`` so that callers can invoke
    ``.tolist()`` on them.
    """

    def __init__(self, indices, values):
        self.indices = DummyArray(indices)
        self.values = DummyArray(values)


class DummySparseEmbedding:
    """Test double for the loader's ``SparseTextEmbedding``."""

    def __init__(self, name: str):
        # Nothing to set up; ``name`` is accepted and ignored.
        pass

    def passage_embed(self, texts):
        """Yield one ``DummySparseVector`` per text.

        The vector for the text at position ``i`` has the single index ``i``
        with value ``1.0``.
        """
        for i, _ in enumerate(texts):
            yield DummySparseVector([i], [1.0])


class DummyQdrantClient:
instance = None

def __init__(self, url: str | None = None, api_key: str | None = None, **kwargs):
self.collections = {}
self.upserted = []
self.kwargs = kwargs
DummyQdrantClient.instance = self

async def collection_exists(self, name: str) -> bool:
return name in self.collections

async def get_collection(self, name: str):
return self.collections[name]

async def delete_collection(self, name: str):
self.collections.pop(name, None)

async def create_collection(self, collection_name: str, vectors_config, sparse_vectors_config):
size = vectors_config["dense"].size
params = SimpleNamespace(vectors={"dense": models.VectorParams(size=size, distance=models.Distance.COSINE)})
self.collections[collection_name] = SimpleNamespace(config=SimpleNamespace(params=params))

async def create_payload_index(self, **kwargs):
return None

async def upsert(self, collection_name: str, points):
self.upserted.extend(points)
class CaptureClient(AsyncQdrantClient):
    """``AsyncQdrantClient`` subclass that remembers the last instance created.

    A test can install this class in place of the client the loader
    constructs and then reach the resulting object through ``instance``.
    """

    # The whole union is quoted so the annotation stays valid even if it is
    # evaluated eagerly (a bare ``"CaptureClient" | None`` would be ``str | None``
    # at runtime and fail).
    instance: "CaptureClient | None" = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        CaptureClient.instance = self

class TrackingQdrantClient(DummyQdrantClient):
    """Qdrant client that starts with a mismatched collection size.

    It also records, in ``deleted``, whether ``delete_collection`` was called.
    """

    def __init__(self, url: str | None = None, api_key: str | None = None, **kwargs):
        super().__init__(url, api_key, **kwargs)
        # Pre-create a collection with the wrong vector size to force recreation
        wrong_params = SimpleNamespace(
            vectors={
                "dense": models.VectorParams(size=99, distance=models.Distance.COSINE)
            }
        )
        self.collections["media-items"] = SimpleNamespace(
            config=SimpleNamespace(params=wrong_params)
        )
        self.deleted = False

    async def delete_collection(self, name: str):
        """Flag the deletion, then delegate to the base implementation."""
        self.deleted = True
        await super().delete_collection(name)


async def _run_loader(sample_dir: Path) -> None:
    """Run the loader on *sample_dir*, passing ``None`` for every other positional argument."""
    await loader.run(None, None, None, sample_dir, None, None)
async def _run_loader(sample_dir: Path) -> None:
    """Run the loader on *sample_dir*, passing ``None`` for every other positional argument."""
    await loader.run(
        None,
        None,
        None,
        sample_dir,
        None,
        None,
    )


def test_run_writes_points(monkeypatch):
    """The loader upserts two points from ``sample-data``, each with ``title`` and ``type``."""
    # Swap the embedding models and the Qdrant client for in-memory doubles.
    monkeypatch.setattr(loader, "TextEmbedding", DummyTextEmbedding)
    monkeypatch.setattr(loader, "SparseTextEmbedding", DummySparseEmbedding)
    monkeypatch.setattr(loader, "AsyncQdrantClient", DummyQdrantClient)
    sample_dir = Path(__file__).resolve().parents[1] / "sample-data"
    asyncio.run(_run_loader(sample_dir))
    # The loader builds the client itself, so fetch it from the class attribute.
    client = DummyQdrantClient.instance
    assert client is not None
    assert len(client.upserted) == 2
    payloads = [p.payload for p in client.upserted]
    assert all("title" in p and "type" in p for p in payloads)


def test_run_recreates_mismatched_collection(monkeypatch):
monkeypatch.setattr(loader, "TextEmbedding", DummyTextEmbedding)
monkeypatch.setattr(loader, "SparseTextEmbedding", DummySparseEmbedding)
monkeypatch.setattr(loader, "AsyncQdrantClient", TrackingQdrantClient)
monkeypatch.setattr(loader, "AsyncQdrantClient", CaptureClient)
sample_dir = Path(__file__).resolve().parents[1] / "sample-data"
asyncio.run(_run_loader(sample_dir))
client = TrackingQdrantClient.instance
client = CaptureClient.instance
assert client is not None
# The pre-created collection should have been deleted and recreated
assert client.deleted is True
assert (
client.collections["media-items"].config.params.vectors["dense"].size
== 3
)


def test_run_uses_connection_options(monkeypatch):
monkeypatch.setattr(loader, "TextEmbedding", DummyTextEmbedding)
monkeypatch.setattr(loader, "SparseTextEmbedding", DummySparseEmbedding)
points, _ = asyncio.run(client.scroll("media-items", limit=10, with_payload=True))
assert len(points) == 2
assert all("title" in p.payload and "type" in p.payload for p in points)

captured = {}

class CaptureClient(DummyQdrantClient):
def __init__(self, url: str | None = None, api_key: str | None = None, **kwargs):
super().__init__(url, api_key, **kwargs)
captured.update(kwargs)

monkeypatch.setattr(loader, "AsyncQdrantClient", CaptureClient)
sample_dir = Path(__file__).resolve().parents[1] / "sample-data"
asyncio.run(
loader.run(
None,
None,
None,
sample_dir,
None,
None,
qdrant_host="example",
qdrant_port=1111,
qdrant_grpc_port=2222,
qdrant_https=True,
qdrant_prefer_grpc=True,
)
)
assert captured["host"] == "example"
assert captured["port"] == 1111
assert captured["grpc_port"] == 2222
assert captured["https"] is True
assert captured["prefer_grpc"] is True
Loading