Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert simple_select_list and simple_select_list_txn to return lists…
Browse files Browse the repository at this point in the history
… of tuples (#16505)

This should use fewer allocations and improves type hints.
  • Loading branch information
clokep committed Oct 26, 2023
1 parent c14a7de commit 9407d5b
Show file tree
Hide file tree
Showing 31 changed files with 609 additions and 509 deletions.
1 change: 1 addition & 0 deletions changelog.d/16505.misc
@@ -0,0 +1 @@
Reduce memory allocations.
4 changes: 2 additions & 2 deletions synapse/handlers/deactivate_account.py
Expand Up @@ -103,10 +103,10 @@ async def deactivate_account(
# Attempt to unbind any known bound threepids to this account from identity
# server(s).
bound_threepids = await self.store.user_get_bound_threepids(user_id)
for threepid in bound_threepids:
for medium, address in bound_threepids:
try:
result = await self._identity_handler.try_unbind_threepid(
user_id, threepid["medium"], threepid["address"], id_server
user_id, medium, address, id_server
)
except Exception:
# Do we want this to be a fatal error or should we carry on?
Expand Down
5 changes: 1 addition & 4 deletions synapse/handlers/sso.py
Expand Up @@ -1206,10 +1206,7 @@ async def revoke_sessions_for_provider_session_id(
# We have no guarantee that all the devices of that session are for the same
# `user_id`. Hence, we have to iterate over the list of devices and log them out
# one by one.
for device in devices:
user_id = device["user_id"]
device_id = device["device_id"]

for user_id, device_id in devices:
# If the user_id associated with that device/session is not the one we got
# out of the `sub` claim, skip that device and show log an error.
if expected_user_id is not None and user_id != expected_user_id:
Expand Down
31 changes: 16 additions & 15 deletions synapse/storage/database.py
Expand Up @@ -606,13 +606,16 @@ async def _check_safe_to_upsert(self) -> None:
If the background updates have not completed, wait 15 sec and check again.
"""
updates = await self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
desc="check_background_updates",
updates = cast(
List[Tuple[str]],
await self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
desc="check_background_updates",
),
)
background_update_names = [x["update_name"] for x in updates]
background_update_names = [x[0] for x in updates]

for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
if update_name not in background_update_names:
Expand Down Expand Up @@ -1804,9 +1807,9 @@ async def simple_select_list(
keyvalues: Optional[Dict[str, Any]],
retcols: Collection[str],
desc: str = "simple_select_list",
) -> List[Dict[str, Any]]:
) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
more rows, returning the result as a list of tuples.
Args:
table: the table name
Expand All @@ -1817,8 +1820,7 @@ async def simple_select_list(
desc: description of the transaction, for logging and metrics
Returns:
A list of dictionaries, one per result row, each a mapping between the
column names from `retcols` and that column's value for the row.
A list of tuples, one per result row, each the retcolumn's value for the row.
"""
return await self.runInteraction(
desc,
Expand All @@ -1836,9 +1838,9 @@ def simple_select_list_txn(
table: str,
keyvalues: Optional[Dict[str, Any]],
retcols: Iterable[str],
) -> List[Dict[str, Any]]:
) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
more rows, returning the result as a list of tuples.
Args:
txn: Transaction object
Expand All @@ -1849,8 +1851,7 @@ def simple_select_list_txn(
retcols: the names of the columns to return
Returns:
A list of dictionaries, one per result row, each a mapping between the
column names from `retcols` and that column's value for the row.
A list of tuples, one per result row, each the retcolumn's value for the row.
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
Expand All @@ -1863,7 +1864,7 @@ def simple_select_list_txn(
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
txn.execute(sql)

return cls.cursor_to_dict(txn)
return txn.fetchall()

async def simple_select_many_batch(
self,
Expand Down
18 changes: 11 additions & 7 deletions synapse/storage/databases/main/account_data.py
Expand Up @@ -286,16 +286,20 @@ async def get_account_data_for_room(

def get_account_data_for_room_txn(
txn: LoggingTransaction,
) -> Dict[str, JsonDict]:
rows = self.db_pool.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id, "room_id": room_id},
["account_data_type", "content"],
) -> Dict[str, JsonMapping]:
rows = cast(
List[Tuple[str, str]],
self.db_pool.simple_select_list_txn(
txn,
table="room_account_data",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=["account_data_type", "content"],
),
)

return {
row["account_data_type"]: db_to_json(row["content"]) for row in rows
account_data_type: db_to_json(content)
for account_data_type, content in rows
}

return await self.db_pool.runInteraction(
Expand Down
13 changes: 9 additions & 4 deletions synapse/storage/databases/main/appservice.py
Expand Up @@ -197,16 +197,21 @@ async def get_appservices_by_state(
Returns:
A list of ApplicationServices, which may be empty.
"""
results = await self.db_pool.simple_select_list(
"application_services_state", {"state": state.value}, ["as_id"]
results = cast(
List[Tuple[str]],
await self.db_pool.simple_select_list(
table="application_services_state",
keyvalues={"state": state.value},
retcols=("as_id",),
),
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
services = []

for res in results:
for (as_id,) in results:
for service in as_list:
if service.id == res["as_id"]:
if service.id == as_id:
services.append(service)
return services

Expand Down
25 changes: 14 additions & 11 deletions synapse/storage/databases/main/client_ips.py
Expand Up @@ -508,21 +508,24 @@ async def _get_last_client_ip_by_device_from_database(
if device_id is not None:
keyvalues["device_id"] = device_id

res = await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
res = cast(
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
),
)

return {
(d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
user_id=d["user_id"],
device_id=d["device_id"],
ip=d["ip"],
user_agent=d["user_agent"],
last_seen=d["last_seen"],
(user_id, device_id): DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id,
ip=ip,
user_agent=user_agent,
last_seen=last_seen,
)
for d in res
for user_id, ip, user_agent, device_id, last_seen in res
}

async def _get_user_ip_and_agents_from_database(
Expand Down
70 changes: 40 additions & 30 deletions synapse/storage/databases/main/devices.py
Expand Up @@ -283,28 +283,36 @@ async def get_device_opt(
allow_none=True,
)

async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
async def get_devices_by_user(
self, user_id: str
) -> Dict[str, Dict[str, Optional[str]]]:
"""Retrieve all of a user's registered devices. Only returns devices
that are not marked as hidden.
Args:
user_id:
Returns:
A mapping from device_id to a dict containing "device_id", "user_id"
and "display_name" for each device.
and "display_name" for each device. Display name may be null.
"""
devices = await self.db_pool.simple_select_list(
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user",
devices = cast(
List[Tuple[str, str, Optional[str]]],
await self.db_pool.simple_select_list(
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user",
),
)

return {d["device_id"]: d for d in devices}
return {
d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]}
for d in devices
}

async def get_devices_by_auth_provider_session_id(
self, auth_provider_id: str, auth_provider_session_id: str
) -> List[Dict[str, Any]]:
) -> List[Tuple[str, str]]:
"""Retrieve the list of devices associated with a SSO IdP session ID.
Args:
Expand All @@ -313,14 +321,17 @@ async def get_devices_by_auth_provider_session_id(
Returns:
A list of dicts containing the device_id and the user_id of each device
"""
return await self.db_pool.simple_select_list(
table="device_auth_providers",
keyvalues={
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
retcols=("user_id", "device_id"),
desc="get_devices_by_auth_provider_session_id",
return cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="device_auth_providers",
keyvalues={
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
retcols=("user_id", "device_id"),
desc="get_devices_by_auth_provider_session_id",
),
)

@trace
Expand Down Expand Up @@ -821,15 +832,16 @@ async def _get_cached_user_device(
async def get_cached_devices_for_user(
self, user_id: str
) -> Mapping[str, JsonMapping]:
devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
desc="get_cached_devices_for_user",
devices = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
desc="get_cached_devices_for_user",
),
)
return {
device["device_id"]: db_to_json(device["content"]) for device in devices
}
return {device[0]: db_to_json(device[1]) for device in devices}

def get_cached_device_list_changes(
self,
Expand Down Expand Up @@ -1080,7 +1092,7 @@ async def get_user_ids_requiring_device_list_resync(
The IDs of users whose device lists need resync.
"""
if user_ids:
row_tuples = cast(
rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync",
Expand All @@ -1090,11 +1102,9 @@ async def get_user_ids_requiring_device_list_resync(
desc="get_user_ids_requiring_device_list_resync_with_iterable",
),
)

return {row[0] for row in row_tuples}
else:
rows = cast(
List[Dict[str, str]],
List[Tuple[str]],
await self.db_pool.simple_select_list(
table="device_lists_remote_resync",
keyvalues=None,
Expand All @@ -1103,7 +1113,7 @@ async def get_user_ids_requiring_device_list_resync(
),
)

return {row["user_id"] for row in rows}
return {row[0] for row in rows}

async def mark_remote_users_device_caches_as_stale(
self, user_ids: StrCollection
Expand Down
49 changes: 29 additions & 20 deletions synapse/storage/databases/main/e2e_room_keys.py
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast

from typing_extensions import Literal, TypedDict

Expand Down Expand Up @@ -274,32 +274,41 @@ async def get_e2e_room_keys(
if session_id:
keyvalues["session_id"] = session_id

rows = await self.db_pool.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
"user_id",
"room_id",
"session_id",
"first_message_index",
"forwarded_count",
"is_verified",
"session_data",
rows = cast(
List[Tuple[str, str, int, int, int, str]],
await self.db_pool.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
"room_id",
"session_id",
"first_message_index",
"forwarded_count",
"is_verified",
"session_data",
),
desc="get_e2e_room_keys",
),
desc="get_e2e_room_keys",
)

sessions: Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
] = {"rooms": {}}
for row in rows:
room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
room_entry["sessions"][row["session_id"]] = {
"first_message_index": row["first_message_index"],
"forwarded_count": row["forwarded_count"],
for (
room_id,
session_id,
first_message_index,
forwarded_count,
is_verified,
session_data,
) in rows:
room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}})
room_entry["sessions"][session_id] = {
"first_message_index": first_message_index,
"forwarded_count": forwarded_count,
# is_verified must be returned to the client as a boolean
"is_verified": bool(row["is_verified"]),
"session_data": db_to_json(row["session_data"]),
"is_verified": bool(is_verified),
"session_data": db_to_json(session_data),
}

return sessions
Expand Down

0 comments on commit 9407d5b

Please sign in to comment.