diff --git a/docker/pyproject.deps.toml b/docker/pyproject.deps.toml index 8347a4c..17b6cdb 100644 --- a/docker/pyproject.deps.toml +++ b/docker/pyproject.deps.toml @@ -1,6 +1,6 @@ [project] name = "mcp-plex" -version = "0.26.35" +version = "0.26.38" requires-python = ">=3.11,<3.13" dependencies = [ "fastmcp>=2.11.2", diff --git a/mcp_plex/loader.py b/mcp_plex/loader.py index bf17b2c..3a560a5 100644 --- a/mcp_plex/loader.py +++ b/mcp_plex/loader.py @@ -425,11 +425,32 @@ def _build_plex_item(item: PlexPartialObject) -> PlexItem: ) for a in getattr(item, "actors", []) or getattr(item, "roles", []) or [] ] + genres = [ + str(getattr(g, "tag", "")) + for g in getattr(item, "genres", []) or [] + if getattr(g, "tag", None) + ] + collections = [ + str(getattr(c, "tag", "")) + for c in getattr(item, "collections", []) or [] + if getattr(c, "tag", None) + ] + season_number = getattr(item, "parentIndex", None) + if isinstance(season_number, str): + season_number = int(season_number) if season_number.isdigit() else None + episode_number = getattr(item, "index", None) + if isinstance(episode_number, str): + episode_number = int(episode_number) if episode_number.isdigit() else None + return PlexItem( rating_key=str(getattr(item, "ratingKey", "")), guid=str(getattr(item, "guid", "")), type=str(getattr(item, "type", "")), title=str(getattr(item, "title", "")), + show_title=getattr(item, "grandparentTitle", None), + season_title=getattr(item, "parentTitle", None), + season_number=season_number, + episode_number=episode_number, summary=getattr(item, "summary", None), year=getattr(item, "year", None), added_at=getattr(item, "addedAt", None), @@ -441,6 +462,8 @@ def _build_plex_item(item: PlexPartialObject) -> PlexItem: directors=directors, writers=writers, actors=actors, + genres=genres, + collections=collections, ) @@ -556,6 +579,13 @@ def _load_from_sample(sample_dir: Path) -> List[AggregatedItem]: ) for a in movie_data.get("Role", []) ], + genres=[g.get("tag", "") for g in movie_data.get("Genre", []) if g.get("tag")], + collections=[ + c.get("tag", "") + for key in ("Collection", "Collections") + for c in movie_data.get(key, []) or [] + if c.get("tag") + ], ) with (movie_dir / "imdb.json").open("r", encoding="utf-8") as f: imdb_movie = IMDbTitle.model_validate(json.load(f)) @@ -571,6 +601,10 @@ def _load_from_sample(sample_dir: Path) -> List[AggregatedItem]: guid=str(episode_data.get("guid", "")), type=episode_data.get("type", "episode"), title=episode_data.get("title", ""), + show_title=episode_data.get("grandparentTitle"), + season_title=episode_data.get("parentTitle"), + season_number=episode_data.get("parentIndex"), + episode_number=episode_data.get("index"), summary=episode_data.get("summary"), year=episode_data.get("year"), added_at=episode_data.get("addedAt"), @@ -596,6 +630,13 @@ def _load_from_sample(sample_dir: Path) -> List[AggregatedItem]: ) for a in episode_data.get("Role", []) ], + genres=[g.get("tag", "") for g in episode_data.get("Genre", []) if g.get("tag")], + collections=[ + c.get("tag", "") + for key in ("Collection", "Collections") + for c in episode_data.get(key, []) or [] + if c.get("tag") + ], ) with (episode_dir / "imdb.tv.json").open("r", encoding="utf-8") as f: imdb_episode = IMDbTitle.model_validate(json.load(f)) @@ -657,15 +698,43 @@ async def run( # Assemble points with server-side embeddings points: List[models.PointStruct] = [] for item in items: + primary_title = item.plex.title + if item.plex.type == "episode": + title_bits: list[str] = [] + if item.plex.show_title: + title_bits.append(item.plex.show_title) + se_parts: list[str] = [] + if item.plex.season_number is not None: + se_parts.append(f"S{item.plex.season_number:02d}") + if item.plex.episode_number is not None: + se_parts.append(f"E{item.plex.episode_number:02d}") + if se_parts: + title_bits.append("".join(se_parts)) + if item.plex.title: + title_bits.append(item.plex.title) + if title_bits: + primary_title = " - ".join(title_bits) parts = [ - item.plex.title, + primary_title, item.plex.summary or "", item.tmdb.overview if item.tmdb and hasattr(item.tmdb, "overview") else "", item.imdb.plot if item.imdb else "", - " ".join(p.tag for p in item.plex.directors), - " ".join(p.tag for p in item.plex.writers), - " ".join(p.tag for p in item.plex.actors), ] + directors_text = ", ".join(p.tag for p in item.plex.directors if p.tag) + writers_text = ", ".join(p.tag for p in item.plex.writers if p.tag) + actors_text = ", ".join(p.tag for p in item.plex.actors if p.tag) + if directors_text: + parts.append(f"Directed by {directors_text}") + if writers_text: + parts.append(f"Written by {writers_text}") + if actors_text: + parts.append(f"Starring {actors_text}") + if item.plex.tagline: + parts.append(item.plex.tagline) + if item.tmdb and hasattr(item.tmdb, "tagline"): + tagline = getattr(item.tmdb, "tagline", None) + if tagline: + parts.append(tagline) if item.tmdb and hasattr(item.tmdb, "reviews"): parts.extend(r.get("content", "") for r in getattr(item.tmdb, "reviews", [])) text = "\n".join(p for p in parts if p) @@ -674,8 +743,45 @@ async def run( "title": item.plex.title, "type": item.plex.type, } + if item.plex.type == "episode": + if item.plex.show_title: + payload["show_title"] = item.plex.show_title + if item.plex.season_title: + payload["season_title"] = item.plex.season_title + if item.plex.season_number is not None: + payload["season_number"] = item.plex.season_number + if item.plex.episode_number is not None: + payload["episode_number"] = item.plex.episode_number if item.plex.actors: - payload["actors"] = [p.tag for p in item.plex.actors] + payload["actors"] = [p.tag for p in item.plex.actors if p.tag] + if item.plex.directors: + payload["directors"] = [p.tag for p in item.plex.directors if p.tag] + if item.plex.writers: + payload["writers"] = [p.tag for p in item.plex.writers if p.tag] + if item.plex.genres: + payload["genres"] = item.plex.genres + if item.plex.collections: + payload["collections"] = item.plex.collections + summary = item.plex.summary + if summary: + payload["summary"] = summary + overview = getattr(item.tmdb, "overview", None) if item.tmdb else None + if overview: + payload["overview"] = overview + plot = item.imdb.plot if item.imdb else None + if plot: + payload["plot"] = plot + taglines = [item.plex.tagline] + if item.tmdb and hasattr(item.tmdb, "tagline"): + taglines.append(getattr(item.tmdb, "tagline", None)) + taglines = [t for t in taglines if t] + if taglines: + payload["tagline"] = "\n".join(dict.fromkeys(taglines)) + if item.tmdb and hasattr(item.tmdb, "reviews"): + review_texts = [r.get("content", "") for r in getattr(item.tmdb, "reviews", [])] + review_texts = [r for r in review_texts if r] + if review_texts: + payload["reviews"] = review_texts if item.plex.year is not None: payload["year"] = item.plex.year if item.plex.added_at is not None: @@ -719,15 +825,16 @@ async def run( created_collection = True if created_collection: + text_index = models.TextIndexParams( + type=models.PayloadSchemaType.TEXT, + tokenizer=models.TokenizerType.WORD, + min_token_len=2, + lowercase=True, + ) await client.create_payload_index( collection_name=collection_name, field_name="title", - field_schema=models.TextIndexParams( - type=models.PayloadSchemaType.TEXT, - tokenizer=models.TokenizerType.WORD, - min_token_len=2, - lowercase=True, - ), + field_schema=text_index, ) await client.create_payload_index( collection_name=collection_name, @@ -749,6 +856,66 @@ async def run( field_name="actors", field_schema=models.PayloadSchemaType.KEYWORD, ) + await client.create_payload_index( + collection_name=collection_name, + field_name="directors", + field_schema=models.PayloadSchemaType.KEYWORD, + ) + await client.create_payload_index( + collection_name=collection_name, + field_name="writers", + field_schema=models.PayloadSchemaType.KEYWORD, + ) + await client.create_payload_index( + collection_name=collection_name, + field_name="genres", + field_schema=models.PayloadSchemaType.KEYWORD, + ) + await client.create_payload_index( + collection_name=collection_name, + field_name="show_title", + field_schema=models.PayloadSchemaType.KEYWORD, + ) + await client.create_payload_index( + collection_name=collection_name, + field_name="season_number", + field_schema=models.PayloadSchemaType.INTEGER, + ) + await client.create_payload_index( + collection_name=collection_name, + field_name="episode_number", + field_schema=models.PayloadSchemaType.INTEGER, + ) + await client.create_payload_index( + collection_name=collection_name, + field_name="collections", + field_schema=models.PayloadSchemaType.KEYWORD, + ) + await client.create_payload_index( + collection_name=collection_name, + field_name="summary", + field_schema=text_index, + ) + await client.create_payload_index( + collection_name=collection_name, + field_name="overview", + field_schema=text_index, + ) + await client.create_payload_index( + collection_name=collection_name, + field_name="plot", + field_schema=text_index, + ) + await client.create_payload_index( + collection_name=collection_name, + field_name="tagline", + field_schema=text_index, + ) + await client.create_payload_index( + collection_name=collection_name, + field_name="reviews", + field_schema=text_index, + ) await client.create_payload_index( collection_name=collection_name, field_name="data.plex.rating_key", diff --git a/mcp_plex/server.py b/mcp_plex/server.py index 69513d2..46df48a 100644 --- a/mcp_plex/server.py +++ b/mcp_plex/server.py @@ -6,7 +6,7 @@ import inspect import json import os -from typing import Annotated, Any, Callable +from typing import Annotated, Any, Callable, Sequence from fastapi import FastAPI from fastapi.openapi.docs import get_swagger_ui_html @@ -166,6 +166,17 @@ async def _find_records(identifier: str, limit: int = 5) -> list[models.Record]: return points +def _flatten_payload(payload: dict[str, Any]) -> dict[str, Any]: + """Merge top-level payload fields with the nested data block.""" + + data = dict(payload.get("data", {})) + for key, value in payload.items(): + if key == "data": + continue + data[key] = value + return data + + async def _get_media_data(identifier: str) -> dict[str, Any]: """Return the first matching media record's payload.""" cached = server.cache.get_payload(identifier) @@ -174,17 +185,18 @@ async def _get_media_data(identifier: str) -> dict[str, Any]: records = await _find_records(identifier, limit=1) if not records: raise ValueError("Media item not found") - data = records[0].payload["data"] + payload = _flatten_payload(records[0].payload) + data = payload rating_key = str(data.get("plex", {}).get("rating_key")) if rating_key: - server.cache.set_payload(rating_key, data) + server.cache.set_payload(rating_key, payload) thumb = data.get("plex", {}).get("thumb") if thumb: server.cache.set_poster(rating_key, thumb) art = data.get("plex", {}).get("art") if art: server.cache.set_background(rating_key, art) - return data + return payload @server.tool("get-media") @@ -199,7 +211,7 @@ async def get_media( ) -> list[dict[str, Any]]: """Retrieve media items by rating key, IMDb/TMDb ID or title.""" records = await _find_records(identifier, limit=10) - return [r.payload["data"] for r in records] + return [_flatten_payload(r.payload) for r in records] @server.tool("search-media") @@ -248,7 +260,7 @@ async def search_media( hits = res.points async def _prefetch(hit: models.ScoredPoint) -> None: - data = hit.payload["data"] + data = _flatten_payload(hit.payload) rating_key = str(data.get("plex", {}).get("rating_key")) if rating_key: server.cache.set_payload(rating_key, data) @@ -266,7 +278,7 @@ def _rerank(hits: list[models.ScoredPoint]) -> list[models.ScoredPoint]: return hits docs: list[str] = [] for h in hits: - data = h.payload["data"] + data = _flatten_payload(h.payload) parts = [ data.get("title"), data.get("summary"), @@ -274,6 +286,40 @@ def _rerank(hits: list[models.ScoredPoint]) -> list[models.ScoredPoint]: data.get("plex", {}).get("summary"), data.get("tmdb", {}).get("overview"), ] + directors = data.get("directors") or data.get("plex", {}).get("directors") + writers = data.get("writers") or data.get("plex", {}).get("writers") + actors = data.get("actors") or data.get("plex", {}).get("actors") + + def _join_people(values: Any) -> str: + if isinstance(values, list): + names = [] + for val in values: + if isinstance(val, str) and val: + names.append(val) + elif isinstance(val, dict): + tag = val.get("tag") or val.get("name") + if tag: + names.append(str(tag)) + return ", ".join(names) + if isinstance(values, str): + return values + return "" + + director_names = _join_people(directors) + writer_names = _join_people(writers) + actor_names = _join_people(actors) + if director_names: + parts.append(f"Directed by {director_names}") + if writer_names: + parts.append(f"Written by {writer_names}") + if actor_names: + parts.append(f"Starring {actor_names}") + tagline = data.get("tagline") or data.get("plex", {}).get("tagline") + if tagline: + parts.append(tagline if isinstance(tagline, str) else "\n".join(tagline)) + reviews = data.get("reviews") + if isinstance(reviews, list): + parts.extend(str(r) for r in reviews if r) docs.append(" ".join(p for p in parts if p)) pairs = [(query, d) for d in docs] scores = reranker.predict(pairs) @@ -284,7 +330,299 @@ def _rerank(hits: list[models.ScoredPoint]) -> list[models.ScoredPoint]: reranked = await asyncio.to_thread(_rerank, hits) await prefetch_task - return [h.payload["data"] for h in reranked[:limit]] + return [_flatten_payload(h.payload) for h in reranked[:limit]] + + +@server.tool("query-media") +async def query_media( + dense_query: Annotated[ + str | None, + Field( + description="Text used to generate a dense vector query", + examples=["british crime comedy"], + ), + ] = None, + sparse_query: Annotated[ + str | None, + Field( + description="Text used to generate a sparse vector query", + examples=["british crime comedy"], + ), + ] = None, + title: Annotated[ + str | None, + Field(description="Full-text title match", examples=["The Gentlemen"]), + ] = None, + type: Annotated[ + str | None, + Field( + description="Filter by media type", + examples=["movie"], + ), + ] = None, + year: Annotated[ + int | None, + Field(description="Exact release year", examples=[2020]), + ] = None, + year_from: Annotated[ + int | None, + Field(description="Minimum release year", examples=[2018]), + ] = None, + year_to: Annotated[ + int | None, + Field(description="Maximum release year", examples=[2024]), + ] = None, + added_after: Annotated[ + int | None, + Field( + description="Minimum added_at timestamp (seconds since epoch)", + examples=[1_700_000_000], + ), + ] = None, + added_before: Annotated[ + int | None, + Field( + description="Maximum added_at timestamp (seconds since epoch)", + examples=[1_760_000_000], + ), + ] = None, + actors: Annotated[ + Sequence[str] | None, + Field(description="Match actors by name", examples=[["Matthew McConaughey"]]), + ] = None, + directors: Annotated[ + Sequence[str] | None, + Field(description="Match directors by name", examples=[["Guy Ritchie"]]), + ] = None, + writers: Annotated[ + Sequence[str] | None, + Field(description="Match writers by name", examples=[["Guy Ritchie"]]), + ] = None, + genres: Annotated[ + Sequence[str] | None, + Field(description="Match genre tags", examples=[["Action", "Comedy"]]), + ] = None, + collections: Annotated[ + Sequence[str] | None, + Field(description="Match Plex collection names", examples=[["John Wick Collection"]]), + ] = None, + show_title: Annotated[ + str | None, + Field(description="Match the parent show title", examples=["Alien: Earth"]), + ] = None, + season_number: Annotated[ + int | None, + Field(description="Match the season number", examples=[1]), + ] = None, + episode_number: Annotated[ + int | None, + Field(description="Match the episode number", examples=[4]), + ] = None, + summary: Annotated[ + str | None, + Field(description="Full-text search within Plex summaries", examples=["marijuana empire"]), + ] = None, + overview: Annotated[ + str | None, + Field(description="Full-text search within TMDb overviews", examples=["criminal underworld"]), + ] = None, + plot: Annotated[ + str | None, + Field(description="Full-text search within IMDb plots", examples=["drug lord"]), + ] = None, + tagline: Annotated[ + str | None, + Field(description="Full-text search within taglines", examples=["criminal class"]), + ] = None, + reviews: Annotated[ + str | None, + Field(description="Full-text search within review content", examples=["hilarious"]), + ] = None, + plex_rating_key: Annotated[ + str | None, + Field( + description="Match a specific Plex rating key", + examples=["49915"], + ), + ] = None, + imdb_id: Annotated[ + str | None, + Field(description="Match an IMDb identifier", examples=["tt8367814"]), + ] = None, + tmdb_id: Annotated[ + int | None, + Field(description="Match a TMDb identifier", examples=[568467]), + ] = None, + limit: Annotated[ + int, + Field(description="Maximum number of results to return", ge=1, le=50, examples=[5]), + ] = 5, +) -> list[dict[str, Any]]: + """Run a structured query against indexed payload fields and optional vector searches.""" + + def _listify(value: Sequence[str] | str | None) -> list[str]: + if value is None: + return [] + if isinstance(value, str): + return [value] + return [v for v in value if isinstance(v, str) and v] + + vector_queries: list[tuple[str, models.Document]] = [] + if dense_query: + vector_queries.append( + ( + "dense", + models.Document(text=dense_query, model=server.settings.dense_model), + ) + ) + if sparse_query: + vector_queries.append( + ( + "sparse", + models.Document(text=sparse_query, model=server.settings.sparse_model), + ) + ) + + prefetch_entries: list[models.Prefetch] = [] + for name, doc in vector_queries: + prefetch_entries.append( + models.Prefetch( + query=models.NearestQuery(nearest=doc), + using=name, + limit=limit, + ) + ) + + if len(prefetch_entries) > 1: + candidate_limit = limit * 3 + prefetch_entries = [ + models.Prefetch(query=p.query, using=p.using, limit=candidate_limit) + for p in prefetch_entries + ] + query_obj: models.Query = models.FusionQuery(fusion=models.Fusion.RRF) + using_param = None + prefetch_param: Sequence[models.Prefetch] | None = prefetch_entries + elif prefetch_entries: + query_obj = prefetch_entries[0].query + using_param = prefetch_entries[0].using + prefetch_param = None + else: + query_obj = None + using_param = None + prefetch_param = None + + must: list[models.FieldCondition] = [] + + if title: + must.append(models.FieldCondition(key="title", match=models.MatchText(text=title))) + media_type = type + if media_type: + must.append( + models.FieldCondition(key="type", match=models.MatchValue(value=media_type)) + ) + if year is not None: + must.append(models.FieldCondition(key="year", match=models.MatchValue(value=year))) + if year_from is not None or year_to is not None: + rng: dict[str, int] = {} + if year_from is not None: + rng["gte"] = year_from + if year_to is not None: + rng["lte"] = year_to + must.append(models.FieldCondition(key="year", range=models.Range(**rng))) + if added_after is not None or added_before is not None: + rng_at: dict[str, int] = {} + if added_after is not None: + rng_at["gte"] = added_after + if added_before is not None: + rng_at["lte"] = added_before + must.append(models.FieldCondition(key="added_at", range=models.Range(**rng_at))) + + for actor in _listify(actors): + must.append(models.FieldCondition(key="actors", match=models.MatchValue(value=actor))) + for director in _listify(directors): + must.append( + models.FieldCondition(key="directors", match=models.MatchValue(value=director)) + ) + for writer in _listify(writers): + must.append( + models.FieldCondition(key="writers", match=models.MatchValue(value=writer)) + ) + for genre in _listify(genres): + must.append(models.FieldCondition(key="genres", match=models.MatchValue(value=genre))) + for collection in _listify(collections): + must.append( + models.FieldCondition( + key="collections", match=models.MatchValue(value=collection) + ) + ) + + if show_title: + must.append( + models.FieldCondition( + key="show_title", match=models.MatchValue(value=show_title) + ) + ) + if season_number is not None: + must.append( + models.FieldCondition( + key="season_number", match=models.MatchValue(value=season_number) + ) + ) + if episode_number is not None: + must.append( + models.FieldCondition( + key="episode_number", match=models.MatchValue(value=episode_number) + ) + ) + + if summary: + must.append(models.FieldCondition(key="summary", match=models.MatchText(text=summary))) + if overview: + must.append(models.FieldCondition(key="overview", match=models.MatchText(text=overview))) + if plot: + must.append(models.FieldCondition(key="plot", match=models.MatchText(text=plot))) + if tagline: + must.append(models.FieldCondition(key="tagline", match=models.MatchText(text=tagline))) + if reviews: + must.append(models.FieldCondition(key="reviews", match=models.MatchText(text=reviews))) + + if plex_rating_key: + must.append( + models.FieldCondition( + key="data.plex.rating_key", + match=models.MatchValue(value=plex_rating_key), + ) + ) + if imdb_id: + must.append( + models.FieldCondition( + key="data.imdb.id", match=models.MatchValue(value=imdb_id) + ) + ) + if tmdb_id is not None: + must.append( + models.FieldCondition( + key="data.tmdb.id", match=models.MatchValue(value=tmdb_id) + ) + ) + + filter_obj: models.Filter | None = None + if must: + filter_obj = models.Filter(must=must) + + if query_obj is None: + query_obj = models.SampleQuery(sample=models.Sample.RANDOM) + + res = await server.qdrant_client.query_points( + collection_name="media-items", + query=query_obj, + using=using_param, + prefetch=prefetch_param, + query_filter=filter_obj, + limit=limit, + with_payload=True, + ) + return [_flatten_payload(p.payload) for p in res.points] @server.tool("recommend-media") @@ -320,7 +658,7 @@ async def recommend_media( with_payload=True, using="dense", ) - return [r.payload["data"] for r in recs] + return [_flatten_payload(r.payload) for r in recs] @server.tool("new-movies") @@ -353,7 +691,7 @@ async def new_movies( limit=limit, with_payload=True, ) - return [p.payload["data"] for p in res.points] + return [_flatten_payload(p.payload) for p in res.points] @server.tool("new-shows") @@ -386,7 +724,7 @@ async def new_shows( limit=limit, with_payload=True, ) - return [p.payload["data"] for p in res.points] + return [_flatten_payload(p.payload) for p in res.points] @server.tool("actor-movies") @@ -439,7 +777,7 @@ async def actor_movies( limit=limit, with_payload=True, ) - return [p.payload["data"] for p in res.points] + return [_flatten_payload(p.payload) for p in res.points] @server.resource("resource://media-item/{identifier}") diff --git a/mcp_plex/types.py b/mcp_plex/types.py index a5cf4fd..a344a93 100644 --- a/mcp_plex/types.py +++ b/mcp_plex/types.py @@ -127,6 +127,10 @@ class PlexItem(BaseModel): guid: str type: Literal["movie", "episode"] title: str + show_title: Optional[str] = None + season_title: Optional[str] = None + season_number: Optional[int] = None + episode_number: Optional[int] = None summary: Optional[str] = None year: Optional[int] = None added_at: Optional[datetime] = None @@ -138,6 +142,8 @@ class PlexItem(BaseModel): directors: List[PlexPerson] = Field(default_factory=list) writers: List[PlexPerson] = Field(default_factory=list) actors: List[PlexPerson] = Field(default_factory=list) + genres: List[str] = Field(default_factory=list) + collections: List[str] = Field(default_factory=list) class AggregatedItem(BaseModel): diff --git a/pyproject.toml b/pyproject.toml index 74b6053..5c993eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mcp-plex" -version = "0.26.35" +version = "0.26.38" description = "Plex-Oriented Model Context Protocol Server" requires-python = ">=3.11,<3.13" diff --git a/tests/test_loader_integration.py b/tests/test_loader_integration.py index 5a10d8d..f352066 100644 --- a/tests/test_loader_integration.py +++ b/tests/test_loader_integration.py @@ -3,6 +3,7 @@ import asyncio import json from pathlib import Path +from typing import Any from qdrant_client.async_qdrant_client import AsyncQdrantClient from qdrant_client import models @@ -15,6 +16,7 @@ class CaptureClient(AsyncQdrantClient): instance: "CaptureClient" | None = None captured_points: list[models.PointStruct] = [] upsert_calls: int = 0 + created_indexes: list[tuple[str, Any]] = [] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -27,6 +29,21 @@ async def upsert(self, collection_name: str, points, **kwargs): collection_name=collection_name, points=points, **kwargs ) + async def create_payload_index( + self, + collection_name: str, + field_name: str, + field_schema: models.PayloadSchemaType | models.TextIndexParams, + wait: bool | None = None, + ) -> models.UpdateResult: + CaptureClient.created_indexes.append((field_name, field_schema)) + return await super().create_payload_index( + collection_name=collection_name, + field_name=field_name, + field_schema=field_schema, + wait=wait, + ) + async def _run_loader(sample_dir: Path) -> None: await loader.run( @@ -43,10 +60,15 @@ def test_run_writes_points(monkeypatch): monkeypatch.setattr(loader, "AsyncQdrantClient", CaptureClient) CaptureClient.captured_points = [] CaptureClient.upsert_calls = 0 + CaptureClient.created_indexes = [] sample_dir = Path(__file__).resolve().parents[1] / "sample-data" asyncio.run(_run_loader(sample_dir)) client = CaptureClient.instance assert client is not None + index_map = {name: schema for name, schema in CaptureClient.created_indexes} + assert index_map.get("show_title") == models.PayloadSchemaType.KEYWORD + assert index_map.get("season_number") == models.PayloadSchemaType.INTEGER + assert index_map.get("episode_number") == models.PayloadSchemaType.INTEGER points, _ = asyncio.run(client.scroll("media-items", limit=10, with_payload=True)) assert len(points) == 2 assert all("title" in p.payload and "type" in p.payload for p in points) @@ -59,6 +81,28 @@ def test_run_writes_points(monkeypatch): p.vector["sparse"].model == "Qdrant/bm42-all-minilm-l6-v2-attentions" for p in captured ) + texts = [p.vector["dense"].text for p in captured] + assert any("Directed by" in t for t in texts) + assert any("Starring" in t for t in texts) + movie_point = next(p for p in points if p.payload["type"] == "movie") + assert "directors" in movie_point.payload and "Guy Ritchie" in movie_point.payload["directors"] + assert "writers" in movie_point.payload and movie_point.payload["writers"] + assert "genres" in movie_point.payload and movie_point.payload["genres"] + assert movie_point.payload.get("summary") + assert movie_point.payload.get("overview") + assert movie_point.payload.get("plot") + assert movie_point.payload.get("tagline") + assert movie_point.payload.get("reviews") + episode_point = next(p for p in points if p.payload["type"] == "episode") + assert episode_point.payload.get("show_title") == "Alien: Earth" + assert episode_point.payload.get("season_title") == "Season 1" + assert episode_point.payload.get("season_number") == 1 + assert episode_point.payload.get("episode_number") == 4 + episode_vector = next( + p for p in captured if p.payload.get("type") == "episode" + ).vector["dense"].text + assert "Alien: Earth" in episode_vector + assert "S01E04" in episode_vector def test_run_processes_imdb_queue(monkeypatch, tmp_path): diff --git a/tests/test_server.py b/tests/test_server.py index 580b104..c41d8e1 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -82,6 +82,11 @@ def test_server_tools(monkeypatch): res = asyncio.run(server.get_media.fn(identifier="tt8367814")) assert res and res[0]["plex"]["rating_key"] == movie_id + episode = asyncio.run(server.get_media.fn(identifier="61960")) + assert episode and episode[0]["show_title"] == "Alien: Earth" + assert episode[0]["season_number"] == 1 + assert episode[0]["episode_number"] == 4 + poster = asyncio.run(server.media_poster.fn(identifier=movie_id)) assert isinstance(poster, str) and "thumb" in poster assert server.server.cache.get_poster(movie_id) == poster @@ -105,6 +110,30 @@ def test_server_tools(monkeypatch): ) assert res and res[0]["plex"]["title"] == "The Gentlemen" + structured = asyncio.run( + server.query_media.fn( + dense_query="crime comedy", + title="Gentlemen", + type="movie", + directors=["Guy Ritchie"], + limit=1, + ) + ) + assert structured and structured[0]["plex"]["title"] == "The Gentlemen" + assert "directors" in structured[0] + + episode_structured = asyncio.run( + server.query_media.fn( + type="episode", + show_title="Alien: Earth", + season_number=1, + episode_number=4, + limit=1, + ) + ) + assert episode_structured and episode_structured[0]["plex"]["rating_key"] == "61960" + assert episode_structured[0]["show_title"] == "Alien: Earth" + rec = asyncio.run(server.recommend_media.fn(identifier=movie_id, limit=1)) assert rec and rec[0]["plex"]["rating_key"] == "61960" @@ -129,6 +158,9 @@ def test_new_media_tools(monkeypatch): shows = asyncio.run(server.new_shows.fn(limit=1)) assert shows and shows[0]["plex"]["type"] == "episode" assert shows[0]["plex"]["added_at"] is not None + assert shows[0]["show_title"] == "Alien: Earth" + assert shows[0]["season_number"] == 1 + assert shows[0]["episode_number"] == 4 def test_actor_movies(monkeypatch): diff --git a/uv.lock b/uv.lock index dd5d266..0b2989a 100644 --- a/uv.lock +++ b/uv.lock @@ -690,7 +690,7 @@ wheels = [ [[package]] name = "mcp-plex" -version = "0.26.35" +version = "0.26.38" source = { editable = "." } dependencies = [ { name = "fastapi" },