This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert simple_select_many_batch, simple_select_many_txn to tuples. #16444

Merged
merged 10 commits on Oct 11, 2023
1 change: 1 addition & 0 deletions changelog.d/16444.misc
@@ -0,0 +1 @@
Reduce memory allocations.
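As a rough illustration of where the saving comes from: each dict-shaped row previously built by `cursor_to_dict` carries its own hash table plus per-row key references, while the tuples returned by `txn.fetchall()` are single fixed-size objects. A minimal sketch (illustrative values only; exact sizes vary by CPython version and platform):

```python
# Illustrative only: compare the per-row footprint of a dict row (the old
# cursor_to_dict shape) with a plain tuple row (the new fetchall shape).
import sys

columns = ("user_id", "stream_id")
raw_row = ("@alice:example.org", "42")

dict_row = dict(zip(columns, raw_row))  # old: one dict built per row
tuple_row = raw_row                     # new: the driver's tuple, used as-is

# On 64-bit CPython the dict is typically several times larger than the
# tuple, before counting the column-name strings it references per row.
print(sys.getsizeof(dict_row), sys.getsizeof(tuple_row))
```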
18 changes: 12 additions & 6 deletions synapse/storage/database.py
@@ -1874,9 +1874,9 @@ async def simple_select_many_batch(
keyvalues: Optional[Dict[str, Any]] = None,
desc: str = "simple_select_many_batch",
batch_size: int = 100,
) -> List[Dict[str, Any]]:
) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
more rows.

Filters rows by whether the value of `column` is in `iterable`.

@@ -1888,10 +1888,13 @@ async def simple_select_many_batch(
keyvalues: dict of column names and values to select the rows with
desc: description of the transaction, for logging and metrics
batch_size: the number of rows for each select query

Returns:
The results as a list of tuples.
"""
keyvalues = keyvalues or {}

results: List[Dict[str, Any]] = []
results: List[Tuple[Any, ...]] = []

for chunk in batch_iter(iterable, batch_size):
rows = await self.runInteraction(
@@ -1918,9 +1921,9 @@ def simple_select_many_txn(
iterable: Collection[Any],
keyvalues: Dict[str, Any],
retcols: Iterable[str],
) -> List[Dict[str, Any]]:
) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
more rows.

Filters rows by whether the value of `column` is in `iterable`.

@@ -1931,6 +1934,9 @@ def simple_select_many_txn(
iterable: list
keyvalues: dict of column names and values to select the rows with
retcols: list of strings giving the names of the columns to return

Returns:
The results as a list of tuples.
"""
if not iterable:
return []
@@ -1949,7 +1955,7 @@ def simple_select_many_txn(
)

txn.execute(sql, values)
return cls.cursor_to_dict(txn)
return txn.fetchall()

async def simple_update(
self,
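The call-site pattern used in the rest of this PR follows from the new `List[Tuple[Any, ...]]` return type: each caller casts the result to the concrete tuple shape implied by its `retcols`, then unpacks positionally. A minimal sketch with a hypothetical table and columns:

```python
# Sketch of the caller-side pattern (table and columns are hypothetical).
# The cast has no runtime cost; it only tells the type checker the tuple
# shape, which must match the order of `retcols`.
from typing import Any, Dict, List, Tuple, cast

def fetch_display_names(db_pool: Any, txn: Any, user_ids: List[str]) -> Dict[str, str]:
    rows = cast(
        List[Tuple[str, str]],  # (user_id, displayname), matching retcols
        db_pool.simple_select_many_txn(
            txn,
            table="profiles",
            column="user_id",
            iterable=user_ids,
            keyvalues={},
            retcols=("user_id", "displayname"),
        ),
    )
    return {user_id: displayname for user_id, displayname in rows}
```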
42 changes: 22 additions & 20 deletions synapse/storage/databases/main/deviceinbox.py
@@ -344,18 +344,19 @@ def get_device_messages_txn(
# Note that this is more efficient than just dropping `device_id` from the query,
# since device_inbox has an index on `(user_id, device_id, stream_id)`
if not device_ids_to_query:
user_device_dicts = self.db_pool.simple_select_many_txn(
txn,
table="devices",
column="user_id",
iterable=user_ids_to_query,
keyvalues={"hidden": False},
retcols=("device_id",),
user_device_dicts = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn,
table="devices",
column="user_id",
iterable=user_ids_to_query,
keyvalues={"hidden": False},
retcols=("device_id",),
),
)

device_ids_to_query.update(
{row["device_id"] for row in user_device_dicts}
)
device_ids_to_query.update({row[0] for row in user_device_dicts})

if not device_ids_to_query:
# We've ended up with no devices to query.
@@ -845,20 +846,21 @@ def _add_messages_to_local_device_inbox_txn(

# We exclude hidden devices (such as cross-signing keys) here as they are
# not expected to receive to-device messages.
rows = self.db_pool.simple_select_many_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
column="device_id",
iterable=devices,
retcols=("device_id",),
rows = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
column="device_id",
iterable=devices,
retcols=("device_id",),
),
)

for row in rows:
for (device_id,) in rows:
# Only insert into the local inbox if the device exists on
# this server
device_id = row["device_id"]

with start_active_span("serialise_to_device_message"):
msg = messages_by_device[device_id]
set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"])
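One detail in the hunk above: with `retcols=("device_id",)` each row comes back as a one-element tuple, so the loop header `for (device_id,) in rows` destructures it in place. A trivial sketch:

```python
# Sketch: single-retcol rows are 1-tuples; the loop target unpacks them.
rows = [("DEVICEA",), ("DEVICEB",)]
for (device_id,) in rows:
    print(device_id)  # DEVICEA, then DEVICEB
```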
49 changes: 30 additions & 19 deletions synapse/storage/databases/main/devices.py
@@ -1052,16 +1052,19 @@ async def get_device_list_last_stream_id_for_remote(
async def get_device_list_last_stream_id_for_remotes(
self, user_ids: Iterable[str]
) -> Mapping[str, Optional[str]]:
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
retcols=("user_id", "stream_id"),
desc="get_device_list_last_stream_id_for_remotes",
rows = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
retcols=("user_id", "stream_id"),
desc="get_device_list_last_stream_id_for_remotes",
),
)

results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
results.update({row["user_id"]: row["stream_id"] for row in rows})
results.update(rows)

return results
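The `results.update(rows)` line works because `dict.update` accepts any iterable of `(key, value)` pairs, which is exactly the shape of the two-column rows. A short sketch with made-up values:

```python
# Sketch: (user_id, stream_id) 2-tuples feed dict.update directly, replacing
# the old {row["user_id"]: row["stream_id"] for row in rows} comprehension.
rows = [("@alice:example.org", "10"), ("@bob:example.org", "20")]
results = {"@alice:example.org": None, "@carol:example.org": None}
results.update(rows)
# results == {"@alice:example.org": "10", "@carol:example.org": None,
#             "@bob:example.org": "20"}
```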

@@ -1077,22 +1080,30 @@ async def get_user_ids_requiring_device_list_resync(
The IDs of users whose device lists need resync.
"""
if user_ids:
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync",
column="user_id",
iterable=user_ids,
retcols=("user_id",),
desc="get_user_ids_requiring_device_list_resync_with_iterable",
row_tuples = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync",
column="user_id",
iterable=user_ids,
retcols=("user_id",),
desc="get_user_ids_requiring_device_list_resync_with_iterable",
),
)

return {row[0] for row in row_tuples}
else:
rows = await self.db_pool.simple_select_list(
table="device_lists_remote_resync",
keyvalues=None,
retcols=("user_id",),
desc="get_user_ids_requiring_device_list_resync",
rows = cast(
List[Dict[str, str]],
await self.db_pool.simple_select_list(
table="device_lists_remote_resync",
keyvalues=None,
retcols=("user_id",),
desc="get_user_ids_requiring_device_list_resync",
),
)

return {row["user_id"] for row in rows}

async def mark_remote_users_device_caches_as_stale(
self, user_ids: StrCollection
19 changes: 11 additions & 8 deletions synapse/storage/databases/main/end_to_end_keys.py
@@ -493,15 +493,18 @@ async def get_e2e_one_time_keys(
A map from (algorithm, key_id) to json string for key
"""

rows = await self.db_pool.simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
iterable=key_ids,
retcols=("algorithm", "key_id", "key_json"),
keyvalues={"user_id": user_id, "device_id": device_id},
desc="add_e2e_one_time_keys_check",
rows = cast(
List[Tuple[str, str, str]],
await self.db_pool.simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
iterable=key_ids,
retcols=("algorithm", "key_id", "key_json"),
keyvalues={"user_id": user_id, "device_id": device_id},
desc="add_e2e_one_time_keys_check",
),
)
result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
result = {(algorithm, key_id): key_json for algorithm, key_id, key_json in rows}
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
return result

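The hunk above reshapes three-column rows into a map keyed by `(algorithm, key_id)`; unpacking the tuple inside the comprehension keeps that readable. A sketch with placeholder data:

```python
# Sketch: three-column rows -> {(algorithm, key_id): key_json} (placeholder data).
rows = [("signed_curve25519", "AAAAHQ", '{"key": "zKbLg..."}')]
result = {(algorithm, key_id): key_json for algorithm, key_id, key_json in rows}
```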
107 changes: 57 additions & 50 deletions synapse/storage/databases/main/event_federation.py
@@ -1049,26 +1049,29 @@ async def get_max_depth_of(
Args:
event_ids: The event IDs to calculate the max depth of.
"""
rows = await self.db_pool.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
retcols=(
"event_id",
"depth",
rows = cast(
List[Tuple[str, int]],
await self.db_pool.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
retcols=(
"event_id",
"depth",
),
desc="get_max_depth_of",
),
desc="get_max_depth_of",
)

if not rows:
return None, 0
else:
max_depth_event_id = ""
current_max_depth = 0
for row in rows:
if row["depth"] > current_max_depth:
max_depth_event_id = row["event_id"]
current_max_depth = row["depth"]
for event_id, depth in rows:
if depth > current_max_depth:
max_depth_event_id = event_id
current_max_depth = depth

return max_depth_event_id, current_max_depth
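As a side note, with tuple rows the scan above could also be written as a single `max()` over the depth column; the PR keeps the explicit loop, not least because it also covers the empty-`rows` case, which `max()` would not. An equivalent sketch for the non-empty case:

```python
# Equivalent sketch (not part of the PR): pick the row with the greatest depth.
rows = [("$event_a", 3), ("$event_b", 7), ("$event_c", 5)]
max_depth_event_id, current_max_depth = max(rows, key=lambda row: row[1])
# -> ("$event_b", 7)
```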

@@ -1078,26 +1081,29 @@ async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], i
Args:
event_ids: The event IDs to calculate the max depth of.
"""
rows = await self.db_pool.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
retcols=(
"event_id",
"depth",
rows = cast(
List[Tuple[str, int]],
await self.db_pool.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
retcols=(
"event_id",
"depth",
),
desc="get_min_depth_of",
),
desc="get_min_depth_of",
)

if not rows:
return None, 0
else:
min_depth_event_id = ""
current_min_depth = MAX_DEPTH
for row in rows:
if row["depth"] < current_min_depth:
min_depth_event_id = row["event_id"]
current_min_depth = row["depth"]
for event_id, depth in rows:
if depth < current_min_depth:
min_depth_event_id = event_id
current_min_depth = depth

return min_depth_event_id, current_min_depth

@@ -1553,19 +1559,18 @@ async def get_event_ids_with_failed_pull_attempts(
A filtered down list of `event_ids` that have previous failed pull attempts.
"""

rows = await self.db_pool.simple_select_many_batch(
table="event_failed_pull_attempts",
column="event_id",
iterable=event_ids,
keyvalues={},
retcols=("event_id",),
desc="get_event_ids_with_failed_pull_attempts",
rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="event_failed_pull_attempts",
column="event_id",
iterable=event_ids,
keyvalues={},
retcols=("event_id",),
desc="get_event_ids_with_failed_pull_attempts",
),
)
event_ids_with_failed_pull_attempts: Set[str] = {
row["event_id"] for row in rows
}

return event_ids_with_failed_pull_attempts
return {row[0] for row in rows}

@trace
async def get_event_ids_to_not_pull_from_backoff(
Expand All @@ -1585,32 +1590,34 @@ async def get_event_ids_to_not_pull_from_backoff(
A dictionary of event_ids that should not be attempted to be pulled and the
next timestamp at which we may try pulling them again.
"""
event_failed_pull_attempts = await self.db_pool.simple_select_many_batch(
table="event_failed_pull_attempts",
column="event_id",
iterable=event_ids,
keyvalues={},
retcols=(
"event_id",
"last_attempt_ts",
"num_attempts",
event_failed_pull_attempts = cast(
List[Tuple[str, int, int]],
await self.db_pool.simple_select_many_batch(
table="event_failed_pull_attempts",
column="event_id",
iterable=event_ids,
keyvalues={},
retcols=(
"event_id",
"last_attempt_ts",
"num_attempts",
),
desc="get_event_ids_to_not_pull_from_backoff",
),
desc="get_event_ids_to_not_pull_from_backoff",
)

current_time = self._clock.time_msec()

event_ids_with_backoff = {}
for event_failed_pull_attempt in event_failed_pull_attempts:
event_id = event_failed_pull_attempt["event_id"]
for event_id, last_attempt_ts, num_attempts in event_failed_pull_attempts:
# Exponential back-off (up to the upper bound) so we don't try to
# pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc.
backoff_end_time = (
event_failed_pull_attempt["last_attempt_ts"]
last_attempt_ts
+ (
2
** min(
event_failed_pull_attempt["num_attempts"],
num_attempts,
BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS,
)
)
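The final hunk is cut off mid-expression, but the shape of the back-off computation is visible: the window doubles per attempt up to a capped number of doubling steps. A sketch, assuming a one-hour step to match the "2hr, 4hr, 8hr, 16hr" comment (the multiplier constant itself lies beyond the truncation):

```python
# Sketch of the back-off window (step size and cap value are assumptions for
# illustration; the diff is truncated before the multiplier constant).
HOUR_MS = 60 * 60 * 1000
MAX_DOUBLING_STEPS = 8  # stands in for BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS

def backoff_end_time_ms(last_attempt_ts: int, num_attempts: int) -> int:
    # Doubles each attempt (2h, 4h, 8h, ...) up to 2**MAX_DOUBLING_STEPS hours.
    return last_attempt_ts + (2 ** min(num_attempts, MAX_DOUBLING_STEPS)) * HOUR_MS
```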