Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert simple_select_one_txn and simple_select_one to return tuples. #16612

Merged
merged 4 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16612.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
3 changes: 1 addition & 2 deletions synapse/_scripts/synapse_port_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,7 @@ async def setup_table(self, table: str) -> Tuple[str, int, int, int, int]:
backward_chunk = 0
already_ported = 0
else:
forward_chunk = row["forward_rowid"]
backward_chunk = row["backward_rowid"]
forward_chunk, backward_chunk = row

if total_to_port is None:
already_ported, total_to_port = await self._get_total_count_to_port(
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ async def _upgrade_room(
self,
requester: Requester,
old_room_id: str,
old_room: Dict[str, Any],
old_room: Tuple[bool, str, bool],
new_room_id: str,
new_version: RoomVersion,
tombstone_event: EventBase,
Expand All @@ -279,7 +279,7 @@ async def _upgrade_room(
Args:
requester: the user requesting the upgrade
old_room_id: the id of the room to be replaced
old_room: a dict containing room information for the room to be replaced,
old_room: a tuple containing room information for the room to be replaced,
as returned by `RoomWorkerStore.get_room`.
new_room_id: the id of the replacement room
new_version: the version to upgrade the room to
Expand All @@ -299,7 +299,7 @@ async def _upgrade_room(
await self.store.store_room(
room_id=new_room_id,
room_creator_user_id=user_id,
is_public=old_room["is_public"],
is_public=old_room[0],
room_version=new_version,
)

Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/room_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,8 @@ async def transfer_room_state_on_room_upgrade(
# Add new room to the room directory if the old room was there
# Remove old room from the room directory
old_room = await self.store.get_room(old_room_id)
if old_room is not None and old_room["is_public"]:
# If the old room exists and is public.
if old_room is not None and old_room[0]:
await self.store.set_room_is_public(old_room_id, False)
await self.store.set_room_is_public(room_id, True)

Expand Down
3 changes: 2 additions & 1 deletion synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1860,7 +1860,8 @@ async def room_is_in_public_room_list(self, room_id: str) -> bool:
if not room:
return False

return room.get("is_public", False)
# The first item is whether the room is public.
return room[0]

async def add_room_to_public_room_list(self, room_id: str) -> None:
"""Publishes a room to the public room list.
Expand Down
8 changes: 4 additions & 4 deletions synapse/rest/admin/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,8 @@ async def on_GET(
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

ret = await self.store.get_room(room_id)
if not ret:
room = await self.store.get_room(room_id)
if not room:
raise NotFoundError("Room not found")

members = await self.store.get_users_in_room(room_id)
Expand Down Expand Up @@ -442,8 +442,8 @@ async def on_GET(
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

ret = await self.store.get_room(room_id)
if not ret:
room = await self.store.get_room(room_id)
if not room:
raise NotFoundError("Room not found")

event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/client/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
if room is None:
raise NotFoundError("Unknown room")

return 200, {"visibility": "public" if room["is_public"] else "private"}
return 200, {"visibility": "public" if room[0] else "private"}

class PutBody(RequestBodyModel):
visibility: Literal["public", "private"] = "public"
Expand Down
10 changes: 5 additions & 5 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,7 +1597,7 @@ async def simple_select_one(
retcols: Collection[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one",
) -> Dict[str, Any]:
) -> Tuple[Any, ...]:
...

@overload
Expand All @@ -1608,7 +1608,7 @@ async def simple_select_one(
retcols: Collection[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
...

async def simple_select_one(
Expand All @@ -1618,7 +1618,7 @@ async def simple_select_one(
retcols: Collection[str],
allow_none: bool = False,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.

Expand Down Expand Up @@ -2127,7 +2127,7 @@ def simple_select_one_txn(
keyvalues: Dict[str, Any],
retcols: Collection[str],
allow_none: bool = False,
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
select_sql = "SELECT %s FROM %s" % (", ".join(retcols), table)

if keyvalues:
Expand All @@ -2145,7 +2145,7 @@ def simple_select_one_txn(
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))

return dict(zip(retcols, row))
return row

async def simple_delete_one(
self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
Expand Down
43 changes: 13 additions & 30 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,33 +255,16 @@ async def get_device(
A dict containing the device information, or `None` if the device does not
exist.
"""
return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
allow_none=True,
)

async def get_device_opt(
self, user_id: str, device_id: str
) -> Optional[Dict[str, Any]]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.

Args:
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
A dict containing the device information, or None if the device does not exist.
"""
return await self.db_pool.simple_select_one(
row = await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
allow_none=True,
)
if row is None:
return None
return {"user_id": row[0], "device_id": row[1], "display_name": row[2]}

async def get_devices_by_user(
self, user_id: str
Expand Down Expand Up @@ -1221,9 +1204,7 @@ async def get_dehydrated_device(
retcols=["device_id", "device_data"],
allow_none=True,
)
return (
(row["device_id"], json_decoder.decode(row["device_data"])) if row else None
)
return (row[0], json_decoder.decode(row[1])) if row else None

def _store_dehydrated_device_txn(
self,
Expand Down Expand Up @@ -2326,13 +2307,15 @@ async def get_device_change_last_converted_pos(self) -> Tuple[int, str]:
`FALSE` have not been converted.
"""

row = await self.db_pool.simple_select_one(
table="device_lists_changes_converted_stream_position",
keyvalues={},
retcols=["stream_id", "room_id"],
desc="get_device_change_last_converted_pos",
return cast(
Tuple[int, str],
await self.db_pool.simple_select_one(
table="device_lists_changes_converted_stream_position",
keyvalues={},
retcols=["stream_id", "room_id"],
desc="get_device_change_last_converted_pos",
),
)
return row["stream_id"], row["room_id"]

async def set_device_change_last_converted_pos(
self,
Expand Down
31 changes: 19 additions & 12 deletions synapse/storage/databases/main/e2e_room_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,19 +506,26 @@ def _get_e2e_room_keys_version_info_txn(txn: LoggingTransaction) -> JsonDict:
# it isn't there.
raise StoreError(404, "No backup with that version exists")

result = self.db_pool.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data", "etag"),
allow_none=False,
row = cast(
Tuple[int, str, str, Optional[int]],
self.db_pool.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={
"user_id": user_id,
"version": this_version,
"deleted": 0,
},
retcols=("version", "algorithm", "auth_data", "etag"),
allow_none=False,
),
)
assert result is not None # see comment on `simple_select_one_txn`
result["auth_data"] = db_to_json(result["auth_data"])
result["version"] = str(result["version"])
if result["etag"] is None:
result["etag"] = 0
return result
return {
"auth_data": db_to_json(row[2]),
"version": str(row[0]),
"algorithm": row[1],
"etag": 0 if row[3] is None else row[3],
}

return await self.db_pool.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
Expand Down
4 changes: 1 addition & 3 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,9 +1266,7 @@ async def _claim_e2e_fallback_keys_simple(
if row is None:
continue

key_id = row["key_id"]
key_json = row["key_json"]
used = row["used"]
key_id, key_json, used = row

# Mark fallback key as used if not already.
if not used and mark_as_used:
Expand Down
24 changes: 10 additions & 14 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ async def get_auth_chain_ids(
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
# If the room has an auth chain index.
if room[1]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_ids_chains",
Expand Down Expand Up @@ -411,7 +412,8 @@ async def get_auth_chain_difference(
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
# If the room has an auth chain index.
if room[1]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_difference_chains",
Expand Down Expand Up @@ -1437,24 +1439,18 @@ def _get_backfill_events(
)

if event_lookup_result is not None:
event_type, depth, stream_ordering = event_lookup_result
logger.debug(
"_get_backfill_events(room_id=%s): seed_event_id=%s depth=%s stream_ordering=%s type=%s",
room_id,
seed_event_id,
event_lookup_result["depth"],
event_lookup_result["stream_ordering"],
event_lookup_result["type"],
depth,
stream_ordering,
event_type,
)

if event_lookup_result["depth"]:
queue.put(
(
-event_lookup_result["depth"],
-event_lookup_result["stream_ordering"],
seed_event_id,
event_lookup_result["type"],
)
)
if depth:
queue.put((-depth, -stream_ordering, seed_event_id, event_type))

while not queue.empty() and len(event_id_results) < limit:
try:
Expand Down
3 changes: 1 addition & 2 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1934,8 +1934,7 @@ def _handle_redact_relations(
if row is None:
return

redacted_relates_to = row["relates_to_id"]
rel_type = row["relation_type"]
redacted_relates_to, rel_type = row
self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1998,7 +1998,7 @@ async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))

return int(res["topological_ordering"]), int(res["stream_ordering"])
return int(res[0]), int(res[1])

async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
Expand Down
30 changes: 23 additions & 7 deletions synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,17 @@ async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
)
if row is None:
return None
return LocalMedia(media_id=media_id, **row)
return LocalMedia(
media_id=media_id,
media_type=row[0],
media_length=row[1],
upload_name=row[2],
created_ts=row[3],
quarantined_by=row[4],
url_cache=row[5],
last_access_ts=row[6],
safe_from_quarantine=row[7],
)

async def get_local_media_by_user_paginate(
self,
Expand Down Expand Up @@ -541,7 +551,17 @@ async def get_cached_remote_media(
)
if row is None:
return row
return RemoteMedia(media_origin=origin, media_id=media_id, **row)
return RemoteMedia(
media_origin=origin,
media_id=media_id,
media_type=row[0],
media_length=row[1],
upload_name=row[2],
created_ts=row[3],
filesystem_id=row[4],
last_access_ts=row[5],
quarantined_by=row[6],
)

async def store_cached_remote_media(
self,
Expand Down Expand Up @@ -665,11 +685,7 @@ async def get_remote_media_thumbnail(
if row is None:
return None
return ThumbnailInfo(
width=row["thumbnail_width"],
height=row["thumbnail_height"],
method=row["thumbnail_method"],
type=row["thumbnail_type"],
length=row["thumbnail_length"],
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)

@trace
Expand Down