Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert simple_select_list and simple_select_list_txn to return lists of tuples #16505

Merged
merged 7 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16505.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Reduce memory allocations.
4 changes: 2 additions & 2 deletions synapse/handlers/deactivate_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ async def deactivate_account(
# Attempt to unbind any known bound threepids to this account from identity
# server(s).
bound_threepids = await self.store.user_get_bound_threepids(user_id)
for threepid in bound_threepids:
for medium, address in bound_threepids:
try:
result = await self._identity_handler.try_unbind_threepid(
user_id, threepid["medium"], threepid["address"], id_server
user_id, medium, address, id_server
)
except Exception:
# Do we want this to be a fatal error or should we carry on?
Expand Down
5 changes: 1 addition & 4 deletions synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,10 +1206,7 @@ async def revoke_sessions_for_provider_session_id(
# We have no guarantee that all the devices of that session are for the same
# `user_id`. Hence, we have to iterate over the list of devices and log them out
# one by one.
for device in devices:
user_id = device["user_id"]
device_id = device["device_id"]

for user_id, device_id in devices:
# If the user_id associated with that device/session is not the one we got
# out of the `sub` claim, skip that device and show log an error.
if expected_user_id is not None and user_id != expected_user_id:
Expand Down
31 changes: 16 additions & 15 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,13 +606,16 @@ async def _check_safe_to_upsert(self) -> None:

If the background updates have not completed, wait 15 sec and check again.
"""
updates = await self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
desc="check_background_updates",
updates = cast(
List[Tuple[str]],
await self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
desc="check_background_updates",
),
)
background_update_names = [x["update_name"] for x in updates]
background_update_names = [x[0] for x in updates]

for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
if update_name not in background_update_names:
Expand Down Expand Up @@ -1804,9 +1807,9 @@ async def simple_select_list(
keyvalues: Optional[Dict[str, Any]],
retcols: Collection[str],
desc: str = "simple_select_list",
) -> List[Dict[str, Any]]:
) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
more rows, returning the result as a list of tuples.

Args:
table: the table name
Expand All @@ -1817,8 +1820,7 @@ async def simple_select_list(
desc: description of the transaction, for logging and metrics

Returns:
A list of dictionaries, one per result row, each a mapping between the
column names from `retcols` and that column's value for the row.
A list of tuples, one per result row, each the retcolumn's value for the row.
"""
return await self.runInteraction(
desc,
Expand All @@ -1836,9 +1838,9 @@ def simple_select_list_txn(
table: str,
keyvalues: Optional[Dict[str, Any]],
retcols: Iterable[str],
) -> List[Dict[str, Any]]:
) -> List[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
more rows, returning the result as a list of tuples.

Args:
txn: Transaction object
Expand All @@ -1849,8 +1851,7 @@ def simple_select_list_txn(
retcols: the names of the columns to return

Returns:
A list of dictionaries, one per result row, each a mapping between the
column names from `retcols` and that column's value for the row.
A list of tuples, one per result row, each the retcolumn's value for the row.
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (
Expand All @@ -1863,7 +1864,7 @@ def simple_select_list_txn(
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
txn.execute(sql)

return cls.cursor_to_dict(txn)
return txn.fetchall()

async def simple_select_many_batch(
self,
Expand Down
18 changes: 11 additions & 7 deletions synapse/storage/databases/main/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,16 +286,20 @@ async def get_account_data_for_room(

def get_account_data_for_room_txn(
txn: LoggingTransaction,
) -> Dict[str, JsonDict]:
rows = self.db_pool.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id, "room_id": room_id},
["account_data_type", "content"],
) -> Dict[str, JsonMapping]:
rows = cast(
List[Tuple[str, str]],
self.db_pool.simple_select_list_txn(
txn,
table="room_account_data",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=["account_data_type", "content"],
),
)

return {
row["account_data_type"]: db_to_json(row["content"]) for row in rows
account_data_type: db_to_json(content)
for account_data_type, content in rows
}

return await self.db_pool.runInteraction(
Expand Down
13 changes: 9 additions & 4 deletions synapse/storage/databases/main/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,16 +197,21 @@ async def get_appservices_by_state(
Returns:
A list of ApplicationServices, which may be empty.
"""
results = await self.db_pool.simple_select_list(
"application_services_state", {"state": state.value}, ["as_id"]
results = cast(
List[Tuple[str]],
await self.db_pool.simple_select_list(
table="application_services_state",
keyvalues={"state": state.value},
retcols=("as_id",),
),
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
services = []

for res in results:
for (as_id,) in results:
for service in as_list:
if service.id == res["as_id"]:
if service.id == as_id:
services.append(service)
return services

Expand Down
25 changes: 14 additions & 11 deletions synapse/storage/databases/main/client_ips.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,21 +508,24 @@ async def _get_last_client_ip_by_device_from_database(
if device_id is not None:
keyvalues["device_id"] = device_id

res = await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
res = cast(
List[Tuple[str, Optional[str], Optional[str], str, Optional[int]]],
await self.db_pool.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
),
)

return {
(d["user_id"], d["device_id"]): DeviceLastConnectionInfo(
user_id=d["user_id"],
device_id=d["device_id"],
ip=d["ip"],
user_agent=d["user_agent"],
last_seen=d["last_seen"],
(user_id, device_id): DeviceLastConnectionInfo(
user_id=user_id,
device_id=device_id,
ip=ip,
user_agent=user_agent,
last_seen=last_seen,
)
for d in res
for user_id, ip, user_agent, device_id, last_seen in res
}

async def _get_user_ip_and_agents_from_database(
Expand Down
70 changes: 40 additions & 30 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,28 +283,36 @@ async def get_device_opt(
allow_none=True,
)

async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
async def get_devices_by_user(
self, user_id: str
) -> Dict[str, Dict[str, Optional[str]]]:
"""Retrieve all of a user's registered devices. Only returns devices
that are not marked as hidden.

Args:
user_id:
Returns:
A mapping from device_id to a dict containing "device_id", "user_id"
and "display_name" for each device.
and "display_name" for each device. Display name may be null.
"""
devices = await self.db_pool.simple_select_list(
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user",
devices = cast(
List[Tuple[str, str, Optional[str]]],
await self.db_pool.simple_select_list(
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user",
),
)

return {d["device_id"]: d for d in devices}
return {
d[1]: {"user_id": d[0], "device_id": d[1], "display_name": d[2]}
for d in devices
}

async def get_devices_by_auth_provider_session_id(
self, auth_provider_id: str, auth_provider_session_id: str
) -> List[Dict[str, Any]]:
) -> List[Tuple[str, str]]:
"""Retrieve the list of devices associated with a SSO IdP session ID.

Args:
Expand All @@ -313,14 +321,17 @@ async def get_devices_by_auth_provider_session_id(
Returns:
A list of dicts containing the device_id and the user_id of each device
"""
return await self.db_pool.simple_select_list(
table="device_auth_providers",
keyvalues={
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
retcols=("user_id", "device_id"),
desc="get_devices_by_auth_provider_session_id",
return cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="device_auth_providers",
keyvalues={
"auth_provider_id": auth_provider_id,
"auth_provider_session_id": auth_provider_session_id,
},
retcols=("user_id", "device_id"),
desc="get_devices_by_auth_provider_session_id",
),
)

@trace
Expand Down Expand Up @@ -821,15 +832,16 @@ async def _get_cached_user_device(
async def get_cached_devices_for_user(
self, user_id: str
) -> Mapping[str, JsonMapping]:
devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
desc="get_cached_devices_for_user",
devices = cast(
List[Tuple[str, str]],
await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
desc="get_cached_devices_for_user",
),
)
return {
device["device_id"]: db_to_json(device["content"]) for device in devices
}
return {device[0]: db_to_json(device[1]) for device in devices}

def get_cached_device_list_changes(
self,
Expand Down Expand Up @@ -1080,7 +1092,7 @@ async def get_user_ids_requiring_device_list_resync(
The IDs of users whose device lists need resync.
"""
if user_ids:
row_tuples = cast(
rows = cast(
List[Tuple[str]],
await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync",
Expand All @@ -1090,11 +1102,9 @@ async def get_user_ids_requiring_device_list_resync(
desc="get_user_ids_requiring_device_list_resync_with_iterable",
),
)

return {row[0] for row in row_tuples}
else:
rows = cast(
List[Dict[str, str]],
List[Tuple[str]],
await self.db_pool.simple_select_list(
table="device_lists_remote_resync",
keyvalues=None,
Expand All @@ -1103,7 +1113,7 @@ async def get_user_ids_requiring_device_list_resync(
),
)

return {row["user_id"] for row in rows}
return {row[0] for row in rows}

async def mark_remote_users_device_caches_as_stale(
self, user_ids: StrCollection
Expand Down
49 changes: 29 additions & 20 deletions synapse/storage/databases/main/e2e_room_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, cast
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, cast

from typing_extensions import Literal, TypedDict

Expand Down Expand Up @@ -274,32 +274,41 @@ async def get_e2e_room_keys(
if session_id:
keyvalues["session_id"] = session_id

rows = await self.db_pool.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
"user_id",
"room_id",
"session_id",
"first_message_index",
"forwarded_count",
"is_verified",
"session_data",
rows = cast(
List[Tuple[str, str, int, int, int, str]],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

first_message_index and forwarded_count and is_verified is nullable. (The last doesn't matter we cast to a bool below.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #16548.

await self.db_pool.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
"room_id",
"session_id",
"first_message_index",
"forwarded_count",
"is_verified",
"session_data",
),
desc="get_e2e_room_keys",
),
desc="get_e2e_room_keys",
)

sessions: Dict[
Literal["rooms"], Dict[str, Dict[Literal["sessions"], Dict[str, RoomKey]]]
] = {"rooms": {}}
for row in rows:
room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
room_entry["sessions"][row["session_id"]] = {
"first_message_index": row["first_message_index"],
"forwarded_count": row["forwarded_count"],
for (
room_id,
session_id,
first_message_index,
forwarded_count,
is_verified,
session_data,
) in rows:
room_entry = sessions["rooms"].setdefault(room_id, {"sessions": {}})
room_entry["sessions"][session_id] = {
"first_message_index": first_message_index,
"forwarded_count": forwarded_count,
# is_verified must be returned to the client as a boolean
"is_verified": bool(row["is_verified"]),
"session_data": db_to_json(row["session_data"]),
"is_verified": bool(is_verified),
"session_data": db_to_json(session_data),
}

return sessions
Expand Down
Loading
Loading