Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Remove more usages of cursor_to_dict #16551

Merged
merged 12 commits into from
Oct 26, 2023
1 change: 1 addition & 0 deletions changelog.d/16551.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
18 changes: 9 additions & 9 deletions synapse/handlers/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import urllib.parse
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple

import attr

from synapse.api.errors import (
CodeMessageException,
Codes,
Expand Down Expand Up @@ -357,9 +359,9 @@ async def send_threepid_validation(

# Check to see if a session already exists and that it is not yet
# marked as validated
if session and session.get("validated_at") is None:
session_id = session["session_id"]
last_send_attempt = session["last_send_attempt"]
if session and session.validated_at is None:
session_id = session.session_id
last_send_attempt = session.last_send_attempt

# Check that the send_attempt is higher than previous attempts
if send_attempt <= last_send_attempt:
Expand Down Expand Up @@ -480,27 +482,25 @@ async def validate_threepid_session(

# We don't actually know which medium this 3PID is. Thus we first assume it's email,
# and if validation fails we try msisdn
validation_session = None

# Try to validate as email
if self.hs.config.email.can_verify_email:
# Get a validated session matching these details
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
)

if validation_session:
return validation_session
if validation_session:
return attr.asdict(validation_session)

# Try to validate as msisdn
if self.hs.config.registration.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
validation_session = await self.threepid_from_creds(
return await self.threepid_from_creds(
self.hs.config.registration.account_threepid_delegate_msisdn,
threepid_creds,
)

return validation_session
return None

async def proxy_msisdn_submit_token(
self, id_server: str, client_secret: str, sid: str, token: str
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/ui_auth/checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ async def _check_threepid(self, medium: str, authdict: dict) -> dict:

if row:
threepid = {
"medium": row["medium"],
"address": row["address"],
"validated_at": row["validated_at"],
"medium": row.medium,
"address": row.address,
"validated_at": row.validated_at,
}

# Valid threepid returned, delete from the db
Expand Down
5 changes: 1 addition & 4 deletions synapse/media/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,10 +949,7 @@ async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:

deleted = 0

for media in old_media:
origin = media["media_origin"]
media_id = media["media_id"]
file_id = media["filesystem_id"]
for origin, media_id, file_id in old_media:
key = (origin, media_id)

logger.info("Deleting: %r", key)
Expand Down
14 changes: 13 additions & 1 deletion synapse/rest/admin/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,19 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
destinations, total = await self._store.get_destinations_paginate(
start, limit, destination, order_by, direction
)
response = {"destinations": destinations, "total": total}
response = {
"destinations": [
{
"destination": r[0],
"retry_last_ts": r[1],
"retry_interval": r[2],
"failure_ts": r[3],
"last_successful_stream_ordering": r[4],
}
for r in destinations
],
"total": total,
}
if (start + limit) < total:
response["next_token"] = str(start + len(destinations))

Expand Down
12 changes: 11 additions & 1 deletion synapse/rest/admin/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,17 @@ async def on_GET(
room_id, _ = await self.resolve_room_id(room_identifier)

extremities = await self.store.get_forward_extremities_for_room(room_id)
return HTTPStatus.OK, {"count": len(extremities), "results": extremities}
result = [
{
"event_id": ex[0],
"state_group": ex[1],
"depth": ex[2],
"received_ts": ex[3],
}
for ex in extremities
]

return HTTPStatus.OK, {"count": len(extremities), "results": result}


class RoomEventContextServlet(RestServlet):
Expand Down
13 changes: 12 additions & 1 deletion synapse/rest/admin/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,18 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
users_media, total = await self.store.get_users_media_usage_paginate(
start, limit, from_ts, until_ts, order_by, direction, search_term
)
ret = {"users": users_media, "total": total}
ret = {
"users": [
{
"user_id": r[0],
"displayname": r[1],
"media_count": r[2],
"media_length": r[3],
}
for r in users_media
],
"total": total,
}
if (start + limit) < total:
ret["next_token"] = start + len(users_media)

Expand Down
30 changes: 3 additions & 27 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
Expand Down Expand Up @@ -1047,43 +1046,20 @@ def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
results = [dict(zip(col_headers, row)) for row in cursor]
return results

@overload
async def execute(
self, desc: str, decoder: Literal[None], query: str, *args: Any
) -> List[Tuple[Any, ...]]:
...

@overload
async def execute(
self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
) -> R:
...

async def execute(
self,
desc: str,
decoder: Optional[Callable[[Cursor], R]],
query: str,
*args: Any,
) -> Union[List[Tuple[Any, ...]], R]:
async def execute(self, desc: str, query: str, *args: Any) -> List[Tuple[Any, ...]]:
"""Runs a single query for a result set.

Args:
desc: description of the transaction, for logging and metrics
decoder - The function which can resolve the cursor results to
something meaningful.
query - The query string to execute
*args - Query args.
Returns:
The result of decoder(results)
"""

def interaction(txn: LoggingTransaction) -> Union[List[Tuple[Any, ...]], R]:
def interaction(txn: LoggingTransaction) -> List[Tuple[Any, ...]]:
txn.execute(query, args)
if decoder:
return decoder(txn)
else:
return txn.fetchall()
return txn.fetchall()

return await self.runInteraction(desc, interaction)

Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/censor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def _censor_redactions(self) -> None:
"""

rows = await self.db_pool.execute(
"_censor_redactions_fetch", None, sql, before_ts, 100
"_censor_redactions_fetch", sql, before_ts, 100
)

updates = []
Expand Down
3 changes: 1 addition & 2 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,6 @@ async def get_all_devices_changed(

rows = await self.db_pool.execute(
"get_all_devices_changed",
None,
sql,
from_key,
to_key,
Expand Down Expand Up @@ -978,7 +977,7 @@ async def get_users_whose_signatures_changed(
WHERE from_user_id = ? AND stream_id > ?
"""
rows = await self.db_pool.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
"get_users_whose_signatures_changed", sql, user_id, from_key
)
return {user for row in rows for user in db_to_json(row[0])}
else:
Expand Down
1 change: 0 additions & 1 deletion synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ async def get_e2e_device_keys_for_federation_query(
"""
rows = await self.db_pool.execute(
"get_e2e_device_keys_for_federation_query_check",
None,
sql,
now_stream_id,
user_id,
Expand Down
7 changes: 2 additions & 5 deletions synapse/storage/databases/main/events_bg_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,12 +1310,9 @@ def process(txn: Cursor) -> None:

# ANALYZE the new column to build stats on it, to encourage PostgreSQL to use the
# indexes on it.
# We need to pass execute a dummy function to handle the txn's result otherwise
# it tries to call fetchall() on it and fails because there's no result to fetch.
await self.db_pool.execute(
await self.db_pool.runInteraction(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the point that we should use runInteraction when we don't expect---or don't care---about the result set?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and the only place where we used a decoder before was where we were trying to ignore the result set!?!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the point that we should use runInteraction when we don't expect---or don't care---about the result set?

execute(...) automatically calls fetchall() (or iterates the txn if you use cursor_to_dict. You could do hacky things like this to not iterate the txn and just leave it, but...I think this is as clear.

and the only place where we used a decoder before was where we were trying to ignore the result set!?!

Or we passed cursor_to_dict in many many places, which are all gone.

"background_analyze_new_stream_ordering_column",
lambda txn: None,
"ANALYZE events(stream_ordering2)",
lambda txn: txn.execute("ANALYZE events(stream_ordering2)"),
)

await self.db_pool.runInteraction(
Expand Down
15 changes: 10 additions & 5 deletions synapse/storage/databases/main/events_forward_extremities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Any, Dict, List
from typing import List, Optional, Tuple, cast

from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
Expand Down Expand Up @@ -91,12 +91,17 @@ def delete_forward_extremities_for_room_txn(txn: LoggingTransaction) -> int:

async def get_forward_extremities_for_room(
self, room_id: str
) -> List[Dict[str, Any]]:
"""Get list of forward extremities for a room."""
) -> List[Tuple[str, int, int, Optional[int]]]:
"""
Get list of forward extremities for a room.

Returns:
A list of tuples of event_id, state_group, depth, and received_ts.
"""

def get_forward_extremities_for_room_txn(
txn: LoggingTransaction,
) -> List[Dict[str, Any]]:
) -> List[Tuple[str, int, int, Optional[int]]]:
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
Expand All @@ -106,7 +111,7 @@ def get_forward_extremities_for_room_txn(
"""

txn.execute(sql, (room_id,))
return self.db_pool.cursor_to_dict(txn)
return cast(List[Tuple[str, int, int, Optional[int]]], txn.fetchall())

return await self.db_pool.runInteraction(
"get_forward_extremities_for_room",
Expand Down
19 changes: 11 additions & 8 deletions synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ async def store_remote_media_thumbnail(

async def get_remote_media_ids(
self, before_ts: int, include_quarantined_media: bool
) -> List[Dict[str, str]]:
) -> List[Tuple[str, str, str]]:
"""
Retrieve a list of server name, media ID tuples from the remote media cache.

Expand All @@ -664,21 +664,24 @@ async def get_remote_media_ids(
A list of tuples containing:
* The server name of homeserver where the media originates from,
* The ID of the media.
* The filesystem ID.
"""

sql = """
SELECT media_origin, media_id, filesystem_id
FROM remote_media_cache
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
WHERE last_access_ts < ?
"""
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)

if include_quarantined_media is False:
# Only include media that has not been quarantined
sql += """
AND quarantined_by IS NULL
"""

return await self.db_pool.execute(
"get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
return cast(
List[Tuple[str, str, str]],
await self.db_pool.execute("get_remote_media_ids", sql, before_ts),
)

async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
Expand Down
Loading
Loading