Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Remove more usages of cursor_to_dict #16551

Merged
merged 12 commits into from
Oct 26, 2023
56 changes: 34 additions & 22 deletions synapse/storage/databases/main/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Set,
Tuple,
Union,
cast,
)

import attr
Expand Down Expand Up @@ -506,16 +507,18 @@ async def search_msgs(
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"

results = await self.db_pool.execute(
"search_msgs", self.db_pool.cursor_to_dict, sql, *args
# List of tuples of (rank, room_id, event_id).
results = cast(
List[Tuple[int, str, str]],
clokep marked this conversation as resolved.
Show resolved Hide resolved
await self.db_pool.execute("search_msgs", None, sql, *args),
)

results = list(filter(lambda row: row["room_id"] in room_ids, results))
results = list(filter(lambda row: row[1] in room_ids, results))

# We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results],
[r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block,
)

Expand All @@ -527,16 +530,20 @@ async def search_msgs(

count_sql += " GROUP BY room_id"

count_results = await self.db_pool.execute(
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
# List of tuples of (room_id, count).
count_results = cast(
List[Tuple[str, int]],
await self.db_pool.execute(
"search_rooms_count", None, count_sql, *count_args
),
)

count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
count = sum(row[1] for row in count_results if row[0] in room_ids)
return {
"results": [
{"event": event_map[r["event_id"]], "rank": r["rank"]}
{"event": event_map[r[2]], "rank": r[0]}
for r in results
if r["event_id"] in event_map
if r[2] in event_map
],
"highlights": highlights,
"count": count,
Expand Down Expand Up @@ -604,7 +611,7 @@ async def search_rooms(
search_query = search_term
sql = """
SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
origin_server_ts, stream_ordering, room_id, event_id
room_id, event_id, origin_server_ts, stream_ordering
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
FROM event_search
WHERE vector @@ websearch_to_tsquery('english', ?) AND
"""
Expand Down Expand Up @@ -665,16 +672,18 @@ async def search_rooms(
# mypy expects to append only a `str`, not an `int`
args.append(limit)

results = await self.db_pool.execute(
"search_rooms", self.db_pool.cursor_to_dict, sql, *args
# List of tuples of (rank, room_id, event_id, origin_server_ts, stream_ordering).
results = cast(
List[Tuple[int, str, str, int, int]],
await self.db_pool.execute("search_rooms", None, sql, *args),
)

results = list(filter(lambda row: row["room_id"] in room_ids, results))
results = list(filter(lambda row: row[1] in room_ids, results))

# We set redact_behaviour to block here to prevent redacted events being returned in
# search results (which is a data leak)
events = await self.get_events_as_list( # type: ignore[attr-defined]
[r["event_id"] for r in results],
[r[2] for r in results],
redact_behaviour=EventRedactBehaviour.block,
)

Expand All @@ -686,22 +695,25 @@ async def search_rooms(

count_sql += " GROUP BY room_id"

count_results = await self.db_pool.execute(
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
# List of tuples of (room_id, count).
count_results = cast(
List[Tuple[str, int]],
await self.db_pool.execute(
"search_rooms_count", None, count_sql, *count_args
),
)

count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
count = sum(row[1] for row in count_results if row[0] in room_ids)

return {
"results": [
{
"event": event_map[r["event_id"]],
"rank": r["rank"],
"pagination_token": "%s,%s"
% (r["origin_server_ts"], r["stream_ordering"]),
"event": event_map[r[2]],
"rank": r[0],
"pagination_token": "%s,%s" % (r[3], r[4]),
}
for r in results
if r["event_id"] in event_map
if r[2] in event_map
],
"highlights": highlights,
"count": count,
Expand Down