Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert simple_select_one to return tuples.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Nov 8, 2023
1 parent 381fc7c commit c4b714e
Show file tree
Hide file tree
Showing 27 changed files with 199 additions and 196 deletions.
3 changes: 1 addition & 2 deletions synapse/_scripts/synapse_port_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,7 @@ async def setup_table(self, table: str) -> Tuple[str, int, int, int, int]:
backward_chunk = 0
already_ported = 0
else:
forward_chunk = row["forward_rowid"]
backward_chunk = row["backward_rowid"]
forward_chunk, backward_chunk = row

if total_to_port is None:
already_ported, total_to_port = await self._get_total_count_to_port(
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ async def _upgrade_room(
self,
requester: Requester,
old_room_id: str,
old_room: Dict[str, Any],
old_room: Tuple[bool, str, bool],
new_room_id: str,
new_version: RoomVersion,
tombstone_event: EventBase,
Expand All @@ -279,7 +279,7 @@ async def _upgrade_room(
Args:
requester: the user requesting the upgrade
old_room_id: the id of the room to be replaced
old_room: a dict containing room information for the room to be replaced,
old_room: a tuple containing room information for the room to be replaced,
as returned by `RoomWorkerStore.get_room`.
new_room_id: the id of the replacement room
new_version: the version to upgrade the room to
Expand All @@ -299,7 +299,7 @@ async def _upgrade_room(
await self.store.store_room(
room_id=new_room_id,
room_creator_user_id=user_id,
is_public=old_room["is_public"],
is_public=old_room[0],
room_version=new_version,
)

Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/room_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,8 @@ async def transfer_room_state_on_room_upgrade(
# Add new room to the room directory if the old room was there
# Remove old room from the room directory
old_room = await self.store.get_room(old_room_id)
if old_room is not None and old_room["is_public"]:
# If
if old_room is not None and old_room[0]:
await self.store.set_room_is_public(old_room_id, False)
await self.store.set_room_is_public(room_id, True)

Expand Down
3 changes: 2 additions & 1 deletion synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1860,7 +1860,8 @@ async def room_is_in_public_room_list(self, room_id: str) -> bool:
if not room:
return False

return room.get("is_public", False)
# The first item is whether the room is public.
return room[0]

async def add_room_to_public_room_list(self, room_id: str) -> None:
"""Publishes a room to the public room list.
Expand Down
8 changes: 4 additions & 4 deletions synapse/rest/admin/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,8 @@ async def on_GET(
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

ret = await self.store.get_room(room_id)
if not ret:
room = await self.store.get_room(room_id)
if not room:
raise NotFoundError("Room not found")

members = await self.store.get_users_in_room(room_id)
Expand Down Expand Up @@ -442,8 +442,8 @@ async def on_GET(
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

ret = await self.store.get_room(room_id)
if not ret:
room = await self.store.get_room(room_id)
if not room:
raise NotFoundError("Room not found")

event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
Expand Down
2 changes: 1 addition & 1 deletion synapse/rest/client/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
if room is None:
raise NotFoundError("Unknown room")

return 200, {"visibility": "public" if room["is_public"] else "private"}
return 200, {"visibility": "public" if room[0] else "private"}

class PutBody(RequestBodyModel):
visibility: Literal["public", "private"] = "public"
Expand Down
9 changes: 4 additions & 5 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,7 +1597,7 @@ async def simple_select_one(
retcols: Collection[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one",
) -> Dict[str, Any]:
) -> Tuple[Any, ...]:
...

@overload
Expand All @@ -1608,7 +1608,7 @@ async def simple_select_one(
retcols: Collection[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
...

async def simple_select_one(
Expand All @@ -1618,7 +1618,7 @@ async def simple_select_one(
retcols: Collection[str],
allow_none: bool = False,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
) -> Optional[Tuple[Any, ...]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
Expand All @@ -1630,7 +1630,7 @@ async def simple_select_one(
statement returns no rows
desc: description of the transaction, for logging and metrics
"""
row = await self.runInteraction(
return await self.runInteraction(
desc,
self.simple_select_one_txn,
table,
Expand All @@ -1639,7 +1639,6 @@ async def simple_select_one(
allow_none,
db_autocommit=True,
)
return dict(zip(retcols, row)) if row is not None else row

@overload
async def simple_select_one_onecol(
Expand Down
43 changes: 13 additions & 30 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,33 +255,16 @@ async def get_device(
A dict containing the device information, or `None` if the device does not
exist.
"""
return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
allow_none=True,
)

async def get_device_opt(
self, user_id: str, device_id: str
) -> Optional[Dict[str, Any]]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
Args:
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
A dict containing the device information, or None if the device does not exist.
"""
return await self.db_pool.simple_select_one(
row = await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
allow_none=True,
)
if row is None:
return None
return {"user_id": row[0], "device_id": row[1], "display_name": row[2]}

async def get_devices_by_user(
self, user_id: str
Expand Down Expand Up @@ -1221,9 +1204,7 @@ async def get_dehydrated_device(
retcols=["device_id", "device_data"],
allow_none=True,
)
return (
(row["device_id"], json_decoder.decode(row["device_data"])) if row else None
)
return (row[0], json_decoder.decode(row[1])) if row else None

def _store_dehydrated_device_txn(
self,
Expand Down Expand Up @@ -2326,13 +2307,15 @@ async def get_device_change_last_converted_pos(self) -> Tuple[int, str]:
`FALSE` have not been converted.
"""

row = await self.db_pool.simple_select_one(
table="device_lists_changes_converted_stream_position",
keyvalues={},
retcols=["stream_id", "room_id"],
desc="get_device_change_last_converted_pos",
return cast(
Tuple[int, str],
await self.db_pool.simple_select_one(
table="device_lists_changes_converted_stream_position",
keyvalues={},
retcols=["stream_id", "room_id"],
desc="get_device_change_last_converted_pos",
),
)
return row["stream_id"], row["room_id"]

async def set_device_change_last_converted_pos(
self,
Expand Down
4 changes: 1 addition & 3 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,9 +1268,7 @@ async def _claim_e2e_fallback_keys_simple(
if row is None:
continue

key_id = row["key_id"]
key_json = row["key_json"]
used = row["used"]
key_id, key_json, used = row

# Mark fallback key as used if not already.
if not used and mark_as_used:
Expand Down
6 changes: 4 additions & 2 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ async def get_auth_chain_ids(
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
# If the room has an auth chain index.
if room[2]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_ids_chains",
Expand Down Expand Up @@ -411,7 +412,8 @@ async def get_auth_chain_difference(
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
if room["has_auth_chain_index"]:
# If the room has an auth chain index.
if room[2]:
try:
return await self.db_pool.runInteraction(
"get_auth_chain_difference_chains",
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1998,7 +1998,7 @@ async def get_event_ordering(self, event_id: str) -> Tuple[int, int]:
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))

return int(res["topological_ordering"]), int(res["stream_ordering"])
return int(res[0]), int(res[1])

async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
Expand Down
30 changes: 23 additions & 7 deletions synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,17 @@ async def get_local_media(self, media_id: str) -> Optional[LocalMedia]:
)
if row is None:
return None
return LocalMedia(media_id=media_id, **row)
return LocalMedia(
media_id=media_id,
media_type=row[0],
media_length=row[1],
upload_name=row[2],
created_ts=row[3],
quarantined_by=row[4],
url_cache=row[5],
last_access_ts=row[6],
safe_from_quarantine=row[7],
)

async def get_local_media_by_user_paginate(
self,
Expand Down Expand Up @@ -541,7 +551,17 @@ async def get_cached_remote_media(
)
if row is None:
return row
return RemoteMedia(media_origin=origin, media_id=media_id, **row)
return RemoteMedia(
media_origin=origin,
media_id=media_id,
media_type=row[0],
media_length=row[1],
upload_name=row[2],
created_ts=row[3],
filesystem_id=row[4],
last_access_ts=row[5],
quarantined_by=row[6],
)

async def store_cached_remote_media(
self,
Expand Down Expand Up @@ -665,11 +685,7 @@ async def get_remote_media_thumbnail(
if row is None:
return None
return ThumbnailInfo(
width=row["thumbnail_width"],
height=row["thumbnail_height"],
method=row["thumbnail_method"],
type=row["thumbnail_type"],
length=row["thumbnail_length"],
width=row[0], height=row[1], method=row[2], type=row[3], length=row[4]
)

@trace
Expand Down
28 changes: 11 additions & 17 deletions synapse/storage/databases/main/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
from typing import TYPE_CHECKING, Optional

from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
Expand Down Expand Up @@ -138,23 +137,18 @@ def _final_batch(txn: LoggingTransaction, lower_bound_id: str) -> None:
return 50

async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one(
table="profiles",
keyvalues={"full_user_id": user_id.to_string()},
retcols=("displayname", "avatar_url"),
desc="get_profileinfo",
)
except StoreError as e:
if e.code == 404:
# no match
return ProfileInfo(None, None)
else:
raise

return ProfileInfo(
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
profile = await self.db_pool.simple_select_one(
table="profiles",
keyvalues={"full_user_id": user_id.to_string()},
retcols=("displayname", "avatar_url"),
desc="get_profileinfo",
allow_none=True,
)
if profile is None:
# no match
return ProfileInfo(None, None)

return ProfileInfo(avatar_url=profile[1], display_name=profile[0])

async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
Expand Down

0 comments on commit c4b714e

Please sign in to comment.