Skip to content

Commit

Permalink
Tag radio endpoint enhancements (#2840)
Browse files Browse the repository at this point in the history
Flatten source list into one and rename query parameter.
  • Loading branch information
amCap1712 committed Apr 17, 2024
1 parent ae07ad2 commit 9374d6b
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 89 deletions.
102 changes: 44 additions & 58 deletions listenbrainz/db/tags.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from flask import current_app
from psycopg2.sql import Literal, SQL
from sqlalchemy import text

Expand Down Expand Up @@ -44,36 +45,32 @@ def get_inserts(self, message):
TagsDataset = _TagsDataset()


def get_once(connection, query, count_query, params, results, counts):
def get_once(connection, query, params, results):
""" One pass over the lb radio tags dataset to retrieve matching recordings """
rows = connection.execute(text(query), params)
for row in rows:
source = row.source
results[source].extend(row.recordings)

rows = connection.execute(text(count_query), params)
for row in rows:
source = row.source
counts[source] += row.total_count


def get(query, count_query, more_query, more_count_query, params):
def get(query, more_query, params):
""" Retrieve recordings and tags for the given query. First it tries to retrieve recordings matching the
specified criteria. If there are not enough recordings matching the given criteria, it relaxes the
percentage bounds and to gather and return more recordings.
"""
results = {"artist": [], "recording": [], "release-group": []}
counts = {"artist": 0, "recording": 0, "release-group": 0}

recordings = []
count = params["count"]

with timescale.engine.connect() as connection:
get_once(connection, query, count_query, params, results, counts)
if len(results["artist"]) < count or len(results["recording"]) < count or len(results["release-group"]) < count:
get_once(connection, more_query, more_count_query, params, results, counts)
result = connection.execute(text(query), params).first()
if result is not None and result.recordings is not None:
recordings.extend(result.recordings)
if len(recordings) < count:
result = connection.execute(text(more_query), params).first()
if result is not None and result.recordings is not None:
recordings.extend(result.recordings)

results["count"] = counts
return results
return recordings


def get_partial_clauses(expanded):
Expand All @@ -86,16 +83,16 @@ def get_partial_clauses(expanded):
if expanded:
order_clause = """
ORDER BY CASE
WHEN :begin_percent > percent THEN percent - :begin_percent
WHEN :end_percent < percent THEN :end_percent - percent
WHEN percent < :begin_percent THEN :begin_percent - percent
WHEN percent > :end_percent THEN percent - :end_percent
ELSE 1
END
, RANDOM()
"""
percent_clause = ":begin_percent <= percent AND percent < :end_percent"
percent_clause = "percent < :begin_percent OR percent > :end_percent"
else:
order_clause = "ORDER BY RANDOM()"
percent_clause = "percent < :begin_percent OR percent >= :end_percent"
percent_clause = ":begin_percent <= percent AND percent <= :end_percent"

return order_clause, percent_clause

Expand All @@ -113,7 +110,6 @@ def build_and_query(tags, expanded):
clauses.append(f"""
SELECT recording_mbid
, source
, percent
FROM tags.lb_tag_radio
WHERE tag = :{param}
AND ({percent_clause})
Expand All @@ -123,44 +119,44 @@ def build_and_query(tags, expanded):
query = f"""
WITH all_recs AS (
{clause}
), randomize_recs AS (
), add_percent_to_recs AS (
SELECT ar.recording_mbid
, ar.source
, ltr.percent
FROM tags.lb_tag_radio ltr
JOIN all_recs ar
ON ar.recording_mbid = ltr.recording_mbid
AND ar.source = ltr.source
AND ltr.tag = :tag_0
), randomize_recs AS (
SELECT recording_mbid
, source
, row_number() OVER (PARTITION BY source {order_clause}) AS rnum
FROM all_recs
), selected_recs AS (
FROM add_percent_to_recs
), selected_recs AS (
SELECT recording_mbid
, source
FROM randomize_recs
WHERE rnum <= :count
) SELECT source
, jsonb_agg(
) SELECT jsonb_agg(
jsonb_build_object(
'recording_mbid'
, recording_mbid
, 'tag_count'
, tag_count
, 'percent'
, percent
, percent * 100
, 'source'
, source
)
ORDER BY tag_count DESC
) AS recordings
FROM selected_recs
JOIN tags.lb_tag_radio
USING (recording_mbid, source)
WHERE tag = :tag_0
GROUP BY source
"""
count_query = f"""
WITH all_recs AS (
{clause}
) SELECT source
, count(*) AS total_count
FROM all_recs
GROUP BY source
"""

return query, count_query, params
return query, params


def build_or_query(expanded=True):
Expand All @@ -178,33 +174,23 @@ def build_or_query(expanded=True):
FROM tags.lb_tag_radio
WHERE tag IN :tags
AND ({percent_clause})
) SELECT source
, jsonb_agg(
) SELECT jsonb_agg(
jsonb_build_object(
'recording_mbid'
, recording_mbid
, 'tag_count'
, tag_count
, 'percent'
, percent
, percent * 100
, 'source'
, source
)
ORDER BY tag_count DESC
) AS recordings
FROM all_tags
WHERE rnum <= :count
GROUP BY source
"""

count_query = f"""
SELECT source
, count(*) AS total_count
FROM tags.lb_tag_radio
WHERE tag IN :tags
AND ({percent_clause})
GROUP BY source
"""

return query, count_query
return query


def get_and(tags, begin_percent, end_percent, count):
Expand All @@ -213,10 +199,10 @@ def get_and(tags, begin_percent, end_percent, count):
outside the percent bounds may also be returned.)
"""
params = {"count": count, "begin_percent": begin_percent, "end_percent": end_percent}
query, count_query, _params = build_and_query(tags, False)
more_query, more_count_query, _ = build_and_query(tags, True)
query, _params = build_and_query(tags, False)
more_query, _ = build_and_query(tags, True)
params.update(_params)
return get(query, count_query, more_query, more_count_query, params)
return get(query, more_query, params)


def get_or(tags, begin_percent, end_percent, count):
Expand All @@ -225,6 +211,6 @@ def get_or(tags, begin_percent, end_percent, count):
outside the percent bounds may also be returned.)
"""
params = {"tags": tuple(tags), "begin_percent": begin_percent, "end_percent": end_percent, "count": count}
query, count_query = build_or_query(False)
more_query, more_count_query = build_or_query(True)
return get(query, count_query, more_query, more_count_query, params)
query = build_or_query(False)
more_query = build_or_query(True)
return get(query, more_query, params)
62 changes: 31 additions & 31 deletions listenbrainz/webserver/views/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,12 +707,12 @@ def get_tags_dataset():
:param tag: the MusicBrainz tag to fetch recordings for, this parameter can be specified multiple times. if more
than one tag is specified, the condition param should also be specified.
:param condition: specify AND to retrieve recordings that have all the tags, otherwise specify OR to retrieve
than one tag is specified, the operator param should also be specified.
:param operator: specify AND to retrieve recordings that have all the tags, otherwise specify OR to retrieve
recordings that have any one of the tags.
:param begin_percent: percent is a measure of the recording's popularity, begin_percent denotes a preferred
:param pop_begin: percent is a measure of the recording's popularity, pop_begin denotes a preferred
lower bound on the popularity of recordings to be returned.
:param end_percent: percent is a measure of the recording's popularity, end_percent denotes a preferred
:param pop_end: percent is a measure of the recording's popularity, pop_end denotes a preferred
upper bound on the popularity of recordings to be returned.
:param count: number of recordings to return for the
:resheader Content-Type: *application/json*
Expand All @@ -723,37 +723,37 @@ def get_tags_dataset():
if tag is None:
raise APIBadRequest("tag param is missing")

condition = request.args.get("condition")
operator = request.args.get("operator")

# if there is only one tag, then we can use any of the condition's query to retrieve data
if len(tag) == 1 and condition is None:
condition = "OR"
# if there is only one tag, then we can use any of the operator's query to retrieve data
if len(tag) == 1 and operator is None:
operator = "OR"

if condition is None:
raise APIBadRequest("multiple tags are specified but the condition param is missing")
condition = condition.upper()
if condition != "AND" and condition != "OR":
raise APIBadRequest("condition param should be either 'AND' or 'OR'")
if operator is None:
raise APIBadRequest("multiple tags are specified but the operator param is missing")
operator = operator.upper()
if operator != "AND" and operator != "OR":
raise APIBadRequest("operator param should be either 'AND' or 'OR'")

try:
begin_percent = request.args.get("begin_percent")
if begin_percent is None:
raise APIBadRequest("begin_percent param is missing")
begin_percent = float(begin_percent) / 100
if begin_percent < 0 or begin_percent > 1:
raise APIBadRequest("begin_percent should be between the range: 0 to 100")
pop_begin = request.args.get("pop_begin")
if pop_begin is None:
raise APIBadRequest("pop_begin param is missing")
pop_begin = float(pop_begin) / 100
if pop_begin < 0 or pop_begin > 1:
raise APIBadRequest("pop_begin should be between the range: 0 to 100")
except ValueError:
raise APIBadRequest(f"begin_percent: '{begin_percent}' is not a valid number")
raise APIBadRequest(f"pop_begin: '{pop_begin}' is not a valid number")

try:
end_percent = request.args.get("end_percent")
if end_percent is None:
raise APIBadRequest("end_percent param is missing")
end_percent = float(end_percent) / 100
if end_percent < 0 or end_percent > 1:
raise APIBadRequest("end_percent should be between the range: 0 to 100")
pop_end = request.args.get("pop_end")
if pop_end is None:
raise APIBadRequest("pop_end param is missing")
pop_end = float(pop_end) / 100
if pop_end < 0 or pop_end > 1:
raise APIBadRequest("pop_end should be between the range: 0 to 100")
except ValueError:
raise APIBadRequest(f"end_percent: '{end_percent}' is not a valid number")
raise APIBadRequest(f"pop_end: '{pop_end}' is not a valid number")

try:
count = request.args.get("count")
Expand All @@ -765,11 +765,11 @@ def get_tags_dataset():
except ValueError:
raise APIBadRequest(f"count: '{count}' is not a valid positive number")

if condition == "AND":
results = tags.get_and(tag, begin_percent, end_percent, count)
if operator == "AND":
recordings = tags.get_and(tag, pop_begin, pop_end, count)
else:
results = tags.get_or(tag, begin_percent, end_percent, count)
return jsonify(results)
recordings = tags.get_or(tag, pop_begin, pop_end, count)
return jsonify(recordings[:count])


def _get_listen_type(listen_type):
Expand Down

0 comments on commit 9374d6b

Please sign in to comment.