Skip to content

Commit

Permalink
Enable querying related flexible attributes
Browse files Browse the repository at this point in the history
Unify query creation logic from
- queryparse.py:construct_query_part,
- Model.field_query,
- DefaultTemplateFunctions._tmpl_unique

to a single implementation under `LibModel.field_query` class method.
This method should be used for query resolution for model (flex)fields.

Allow filtering item attributes in album queries and vice versa by
merging `flex_attrs` from Album and Item together as `all_flex_attrs`.
This field is only used for filtering and is discarded after.
  • Loading branch information
snejus committed May 8, 2024
1 parent 64ac3bd commit de85dd1
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 78 deletions.
49 changes: 19 additions & 30 deletions beets/dbcore/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
from ..util import cached_classproperty, functemplate, py3_path
from . import types
from .query import (
AndQuery,
FieldQuery,
FieldSort,
MatchQuery,
Expand Down Expand Up @@ -372,6 +371,10 @@ def table_with_flex_attrs(cls) -> str:
) {cls._table}
"""

@cached_classproperty
def all_model_db_fields(cls) -> Set[str]:
return set()

@classmethod
def _getters(cls: Type["Model"]):
"""Return a mapping from field names to getter functions."""
Expand Down Expand Up @@ -748,33 +751,6 @@ def set_parse(self, key, string: str):
"""Set the object's key to a value represented by a string."""
self[key] = self._parse(key, string)

# Convenient queries.

@classmethod
def field_query(
cls,
field,
pattern,
query_cls: Type[FieldQuery] = MatchQuery,
) -> FieldQuery:
"""Get a `FieldQuery` for this model."""
return query_cls(field, pattern, field in cls._fields)

@classmethod
def all_fields_query(
cls: Type["Model"],
pats: Mapping,
query_cls: Type[FieldQuery] = MatchQuery,
):
"""Get a query that matches many fields with different patterns.
`pats` should be a mapping from field names to patterns. The
resulting query is a conjunction ("and") of per-field queries
for all of these field/pattern pairs.
"""
subqueries = [cls.field_query(k, v, query_cls) for k, v in pats.items()]
return AndQuery(subqueries)


# Database controller and supporting interfaces.

Expand Down Expand Up @@ -1247,13 +1223,26 @@ def _fetch(
where, subvals = query.clause()
order_by = sort.order_clause()

this_table = model_cls._table
select_fields = [f"{this_table}.*"]
_from = model_cls.table_with_flex_attrs

required_fields = query.field_names
if required_fields - model_cls._fields.keys():
_from += f" {model_cls.relation_join}"

table = model_cls._table
sql = f"SELECT {table}.* FROM {_from} WHERE {where or 1} GROUP BY {table}.id"
if required_fields - model_cls.all_model_db_fields:
# merge all flexible attribute into a single JSON field
select_fields.append(
f"""
json_patch(
COALESCE({this_table}."flex_attrs [json_str]", '{{}}'),
COALESCE({model_cls._relation._table}."flex_attrs [json_str]", '{{}}')
) AS all_flex_attrs
"""
)

sql = f"SELECT {', '.join(select_fields)} FROM {_from} WHERE {where or 1} GROUP BY {this_table}.id"

if order_by:
# the sort field may exist in both 'items' and 'albums' tables
Expand Down
2 changes: 1 addition & 1 deletion beets/dbcore/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def field_names(self) -> Set[str]:
@property
def col_name(self) -> str:
if not self.fast:
return f'json_extract("flex_attrs [json_str]", "$.{self.field}")'
return f'json_extract(all_flex_attrs, "$.{self.field}")'

return f"{self.table}.{self.field}" if self.table else self.field

Expand Down
29 changes: 15 additions & 14 deletions beets/dbcore/queryparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,23 @@

import itertools
import re
from typing import Collection, Dict, List, Optional, Sequence, Tuple, Type
from typing import (
TYPE_CHECKING,
Collection,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
)

from .. import library
from . import Model, query
from .query import Sort

if TYPE_CHECKING:
from ..library import LibModel

PARSE_QUERY_PART_REGEX = re.compile(
# Non-capturing optional segment for the keyword.
r"(-|\^)?" # Negation prefixes.
Expand Down Expand Up @@ -105,7 +116,7 @@ def parse_query_part(


def construct_query_part(
model_cls: Type[Model],
model_cls: Type["LibModel"],
prefixes: Dict,
query_part: str,
) -> query.Query:
Expand Down Expand Up @@ -153,17 +164,7 @@ def construct_query_part(
# Field queries get constructed according to the name of the field
# they are querying.
else:
key = key.lower()
album_fields = library.Album._fields.keys()
item_fields = library.Item._fields.keys()
fast = key in album_fields | item_fields

if key in album_fields & item_fields:
# This field exists in both tables, so SQLite will encounter
# an OperationalError. Using an explicit table name resolves this.
key = f"{model_cls._table}.{key}"

out_query = query_class(key, pattern, fast)
out_query = model_cls.field_query(key.lower(), pattern, query_class)

# Apply negation.
if negate:
Expand Down
55 changes: 49 additions & 6 deletions beets/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""The core data store and collection logic for beets.
"""
from __future__ import annotations

import os
import re
Expand All @@ -22,7 +23,7 @@
import sys
import time
import unicodedata
from typing import Type
from typing import Mapping, Set, Type

from mediafile import MediaFile, UnreadableFileError

Expand Down Expand Up @@ -388,6 +389,14 @@ class LibModel(dbcore.Model):
# Config key that specifies how an instance should be formatted.
_format_config_key = None

@cached_classproperty
def all_model_db_fields(cls) -> Set[str]:
return cls._fields.keys() | cls._relation._fields.keys()

@cached_classproperty
def shared_model_db_fields(cls) -> Set[str]:
return cls._fields.keys() & cls._relation._fields.keys()

def _template_funcs(self):
funcs = DefaultTemplateFunctions(self, self._db).functions()
funcs.update(plugins.template_funcs())
Expand Down Expand Up @@ -417,6 +426,39 @@ def __str__(self):
def __bytes__(self):
return self.__str__().encode("utf-8")

# Convenient queries.

@classmethod
def field_query(
cls, field: str, pattern: str, query_cls: Type[dbcore.FieldQuery]
) -> dbcore.Query:
"""Get a `FieldQuery` for this model."""
fast = field in cls.all_model_db_fields
if field in cls.shared_model_db_fields:
# This field exists in both tables, so SQLite will encounter
# an OperationalError if we try to use it in a query.
# Using an explicit table name resolves this.
field = f"{cls._table}.{field}"

return query_cls(field, pattern, fast)

@classmethod
def all_fields_query(
cls, pattern_by_field: Mapping[str, str]
) -> dbcore.AndQuery:
"""Get a query that matches many fields with different patterns.
`pattern_by_field` should be a mapping from field names to patterns.
The resulting query is a conjunction ("and") of per-field queries
for all of these field/pattern pairs.
"""
return dbcore.AndQuery(
[
cls.field_query(f, p, dbcore.MatchQuery)
for f, p in pattern_by_field.items()
]
)


class FormattedItemMapping(dbcore.db.FormattedMapping):
"""Add lookup for album-level fields.
Expand Down Expand Up @@ -653,7 +695,7 @@ def relation_join(cls) -> str:
We need to use a LEFT JOIN here, otherwise items that are not part of
an album (e.g. singletons) would be left out.
"""
return f"LEFT JOIN {cls._relation._table} ON {cls._table}.album_id = {cls._relation._table}.id"
return f"LEFT JOIN {cls._relation.table_with_flex_attrs} ON {cls._table}.album_id = {cls._relation._table}.id"

@property
def _cached_album(self):
Expand Down Expand Up @@ -1266,7 +1308,7 @@ def relation_join(cls) -> str:
Here we can use INNER JOIN (which is more performant than LEFT JOIN),
since we only want to see albums that
"""
return f"INNER JOIN {cls._relation._table} ON {cls._table}.id = {cls._relation._table}.album_id"
return f"INNER JOIN {cls._relation.table_with_flex_attrs} ON {cls._table}.id = {cls._relation._table}.album_id"

@classmethod
def _getters(cls):
Expand Down Expand Up @@ -1956,9 +1998,10 @@ def _tmpl_unique(
subqueries.extend(initial_subqueries)
for key in keys:
value = db_item.get(key, "")
# Use slow queries for flexible attributes.
fast = key in item_keys
subqueries.append(dbcore.MatchQuery(key, value, fast))
subqueries.append(
db_item.field_query(key, value, dbcore.MatchQuery)
)

query = dbcore.AndQuery(subqueries)
ambigous_items = (
self.lib.items(query)
Expand Down
55 changes: 29 additions & 26 deletions test/plugins/test_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import platform
import shutil
import unittest
from pathlib import Path

from beets import logging
from beets.library import Album, Item
Expand All @@ -29,36 +30,38 @@ def setUp(self):
# Add library elements. Note that self.lib.add overrides any "id=<n>"
# and assigns the next free id number.
# The following adds will create items #1, #2 and #3
path1 = (
self.path_prefix + os.sep + os.path.join(b"path_1").decode("utf-8")
)
self.lib.add(
Item(title="title", path=path1, album_id=2, artist="AAA Singers")
)
path2 = (
self.path_prefix
+ os.sep
+ os.path.join(b"somewhere", b"a").decode("utf-8")
)
self.lib.add(
Item(title="another title", path=path2, artist="AAA Singers")
base_path = Path(self.path_prefix + os.sep)
album2_item1 = Item(
title="title",
path=str(base_path / "path_1"),
album_id=2,
artist="AAA Singers",
)
path3 = (
self.path_prefix
+ os.sep
+ os.path.join(b"somewhere", b"abc").decode("utf-8")
album1_item = Item(
title="another title",
path=str(base_path / "somewhere" / "a"),
artist="AAA Singers",
)
self.lib.add(
Item(title="and a third", testattr="ABC", path=path3, album_id=2)
album2_item2 = Item(
title="and a third",
testattr="ABC",
path=str(base_path / "somewhere" / "abc"),
album_id=2,
)
self.lib.add(album2_item1)
self.lib.add(album1_item)
self.lib.add(album2_item2)

# The following adds will create albums #1 and #2
self.lib.add(Album(album="album", albumtest="xyz"))
path4 = (
self.path_prefix
+ os.sep
+ os.path.join(b"somewhere2", b"art_path_2").decode("utf-8")
)
self.lib.add(Album(album="other album", artpath=path4))
album1 = self.lib.add_album([album1_item])
album1.album = "album"
album1.albumtest = "xyz"
album1.store()

album2 = self.lib.add_album([album2_item1, album2_item2])
album2.album = "other album"
album2.artpath = str(base_path / "somewhere2" / "art_path_2")
album2.store()

web.app.config["TESTING"] = True
web.app.config["lib"] = self.lib
Expand Down
3 changes: 2 additions & 1 deletion test/test_dbcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tempfile import mkstemp

from beets import dbcore
from beets.library import LibModel
from beets.test import _common

# Fixture: concrete database and model classes. For migration tests, we
Expand All @@ -42,7 +43,7 @@ def match(self):
return True


class ModelFixture1(dbcore.Model):
class ModelFixture1(LibModel):
_table = "test"
_flex_table = "testflex"
_fields = {
Expand Down
20 changes: 20 additions & 0 deletions test/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,6 +1152,26 @@ def test_get_albums_filter_by_album_flex(self):
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])

def test_get_albums_filter_by_track_flex(self):
q = "item_flex1:Album1"
results = self.lib.albums(q)
self.assert_albums_matched(results, ["Album1"])

def test_get_items_filter_by_album_flex(self):
q = "album_flex:Album1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album1 Item2"])

def test_filter_by_flex(self):
q = "item_flex1:'Item1 Flex1'"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1", "Album2 Item1"])

def test_filter_by_many_flex(self):
q = "item_flex1:'Item1 Flex1' item_flex2:Album1"
results = self.lib.items(q)
self.assert_items_matched(results, ["Album1 Item1"])


def suite():
return unittest.TestLoader().loadTestsFromName(__name__)
Expand Down

0 comments on commit de85dd1

Please sign in to comment.