Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Remove more usages of cursor_to_dict #16551

Merged
merged 12 commits into from
Oct 26, 2023
18 changes: 9 additions & 9 deletions synapse/handlers/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import urllib.parse
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple

import attr

from synapse.api.errors import (
CodeMessageException,
Codes,
Expand Down Expand Up @@ -357,9 +359,9 @@ async def send_threepid_validation(

# Check to see if a session already exists and that it is not yet
# marked as validated
if session and session.get("validated_at") is None:
session_id = session["session_id"]
last_send_attempt = session["last_send_attempt"]
if session and session.validated_at is None:
session_id = session.session_id
last_send_attempt = session.last_send_attempt

# Check that the send_attempt is higher than previous attempts
if send_attempt <= last_send_attempt:
Expand Down Expand Up @@ -480,27 +482,25 @@ async def validate_threepid_session(

# We don't actually know which medium this 3PID is. Thus we first assume it's email,
# and if validation fails we try msisdn
validation_session = None

# Try to validate as email
if self.hs.config.email.can_verify_email:
# Get a validated session matching these details
validation_session = await self.store.get_threepid_validation_session(
"email", client_secret, sid=sid, validated=True
)

if validation_session:
return validation_session
if validation_session:
return attr.asdict(validation_session)

# Try to validate as msisdn
if self.hs.config.registration.account_threepid_delegate_msisdn:
# Ask our delegated msisdn identity server
validation_session = await self.threepid_from_creds(
return await self.threepid_from_creds(
self.hs.config.registration.account_threepid_delegate_msisdn,
threepid_creds,
)

return validation_session
return None

async def proxy_msisdn_submit_token(
self, id_server: str, client_secret: str, sid: str, token: str
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/ui_auth/checkers.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ async def _check_threepid(self, medium: str, authdict: dict) -> dict:

if row:
threepid = {
"medium": row["medium"],
"address": row["address"],
"validated_at": row["validated_at"],
"medium": row.medium,
"address": row.address,
"validated_at": row.validated_at,
}

# Valid threepid returned, delete from the db
Expand Down
5 changes: 1 addition & 4 deletions synapse/media/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,10 +949,7 @@ async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:

deleted = 0

for media in old_media:
origin = media["media_origin"]
media_id = media["media_id"]
file_id = media["filesystem_id"]
for origin, media_id, file_id in old_media:
key = (origin, media_id)

logger.info("Deleting: %r", key)
Expand Down
12 changes: 11 additions & 1 deletion synapse/rest/admin/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,17 @@ async def on_GET(
room_id, _ = await self.resolve_room_id(room_identifier)

extremities = await self.store.get_forward_extremities_for_room(room_id)
return HTTPStatus.OK, {"count": len(extremities), "results": extremities}
result = [
{
"event_id": ex[0],
"state_group": ex[1],
"depth": ex[2],
"received_ts": ex[3],
}
for ex in extremities
]

return HTTPStatus.OK, {"count": len(extremities), "results": result}


class RoomEventContextServlet(RestServlet):
Expand Down
13 changes: 12 additions & 1 deletion synapse/rest/admin/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,18 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
users_media, total = await self.store.get_users_media_usage_paginate(
start, limit, from_ts, until_ts, order_by, direction, search_term
)
ret = {"users": users_media, "total": total}
ret = {
"users": [
{
"user_id": r[0],
"displayname": r[1],
"media_count": r[2],
"media_length": r[3],
}
for r in users_media
],
"total": total,
}
if (start + limit) < total:
ret["next_token"] = start + len(users_media)

Expand Down
15 changes: 10 additions & 5 deletions synapse/storage/databases/main/events_forward_extremities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Any, Dict, List
from typing import List, Tuple, cast

from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
Expand Down Expand Up @@ -91,12 +91,17 @@ def delete_forward_extremities_for_room_txn(txn: LoggingTransaction) -> int:

async def get_forward_extremities_for_room(
self, room_id: str
) -> List[Dict[str, Any]]:
"""Get list of forward extremities for a room."""
) -> List[Tuple[str, int, int, int]]:
"""
Get list of forward extremities for a room.

Returns:
A list of tuples of event_id, state_group, depth, and received_ts.
"""

def get_forward_extremities_for_room_txn(
txn: LoggingTransaction,
) -> List[Dict[str, Any]]:
) -> List[Tuple[str, int, int, int]]:
clokep marked this conversation as resolved.
Show resolved Hide resolved
sql = """
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
Expand All @@ -106,7 +111,7 @@ def get_forward_extremities_for_room_txn(
"""

txn.execute(sql, (room_id,))
return self.db_pool.cursor_to_dict(txn)
return cast(List[Tuple[str, int, int, int]], txn.fetchall())

return await self.db_pool.runInteraction(
"get_forward_extremities_for_room",
Expand Down
19 changes: 11 additions & 8 deletions synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ async def store_remote_media_thumbnail(

async def get_remote_media_ids(
self, before_ts: int, include_quarantined_media: bool
) -> List[Dict[str, str]]:
) -> List[Tuple[str, str, str]]:
"""
Retrieve a list of server name, media ID tuples from the remote media cache.

Expand All @@ -666,21 +666,24 @@ async def get_remote_media_ids(
A list of tuples containing:
* The server name of homeserver where the media originates from,
* The ID of the media.
* The filesystem ID.
"""

sql = """
SELECT media_origin, media_id, filesystem_id
FROM remote_media_cache
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
WHERE last_access_ts < ?
"""
sql = (
"SELECT media_origin, media_id, filesystem_id"
" FROM remote_media_cache"
" WHERE last_access_ts < ?"
)

if include_quarantined_media is False:
# Only include media that has not been quarantined
sql += """
AND quarantined_by IS NULL
"""

return await self.db_pool.execute(
"get_remote_media_ids", self.db_pool.cursor_to_dict, sql, before_ts
return cast(
List[Tuple[str, str, str]],
await self.db_pool.execute("get_remote_media_ids", None, sql, before_ts),
)

async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
Expand Down
43 changes: 29 additions & 14 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,22 @@ class ThreepidResult:
added_at: int


@attr.s(frozen=True, slots=True, auto_attribs=True)
class ThreepidValidationSession:
address: str
"""address of the 3pid"""
medium: str
"""medium of the 3pid"""
client_secret: str
"""a secret provided by the client for this validation session"""
session_id: str
"""ID of the validation session"""
last_send_attempt: int
"""a number serving to dedupe send attempts for this session"""
validated_at: Optional[int]
"""timestamp of when this session was validated if so"""


class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
Expand Down Expand Up @@ -1156,7 +1172,7 @@ async def get_threepid_validation_session(
address: Optional[str] = None,
sid: Optional[str] = None,
validated: Optional[bool] = True,
) -> Optional[Dict[str, Any]]:
) -> Optional[ThreepidValidationSession]:
"""Gets a session_id and last_send_attempt (if available) for a
combination of validation metadata

Expand All @@ -1171,15 +1187,7 @@ async def get_threepid_validation_session(
perform no filtering

Returns:
A dict containing the following:
* address - address of the 3pid
* medium - medium of the 3pid
* client_secret - a secret provided by the client for this validation session
* session_id - ID of the validation session
* send_attempt - a number serving to dedupe send attempts for this session
* validated_at - timestamp of when this session was validated if so

Otherwise None if a validation session is not found
A ThreepidValidationSession or None if a validation session is not found
"""
if not client_secret:
raise SynapseError(
Expand All @@ -1198,7 +1206,7 @@ async def get_threepid_validation_session(

def get_threepid_validation_session_txn(
txn: LoggingTransaction,
) -> Optional[Dict[str, Any]]:
) -> Optional[ThreepidValidationSession]:
sql = """
SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at
Expand All @@ -1213,11 +1221,18 @@ def get_threepid_validation_session_txn(
sql += " LIMIT 1"

txn.execute(sql, list(keyvalues.values()))
rows = self.db_pool.cursor_to_dict(txn)
if not rows:
row = txn.fetchone()
if not row:
return None

return rows[0]
return ThreepidValidationSession(
address=row[0],
session_id=row[1],
medium=row[2],
client_secret=row[3],
last_send_attempt=row[4],
validated_at=row[5],
)

return await self.db_pool.runInteraction(
"get_threepid_validation_session", get_threepid_validation_session_txn
Expand Down
Loading