Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add some type hints to datastore (#12485)
Browse files Browse the repository at this point in the history
  • Loading branch information
dklimpel committed Apr 27, 2022
1 parent 63ba9ba commit b76f1a4
Show file tree
Hide file tree
Showing 12 changed files with 188 additions and 84 deletions.
1 change: 1 addition & 0 deletions changelog.d/12485.misc
@@ -0,0 +1 @@
Add some type hints to datastore.
21 changes: 15 additions & 6 deletions synapse/storage/databases/main/__init__.py
Expand Up @@ -15,12 +15,17 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple, cast

from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import PostgresEngine
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
IdGenerator,
MultiWriterIdGenerator,
Expand Down Expand Up @@ -266,7 +271,9 @@ async def get_users_paginate(
A tuple of a list of mappings from user to information and a count of total users.
"""

def get_users_paginate_txn(txn):
def get_users_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
filters = []
args = [self.hs.config.server.server_name]

Expand Down Expand Up @@ -301,7 +308,7 @@ def get_users_paginate_txn(txn):
"""
sql = "SELECT COUNT(*) as total_users " + sql_base
txn.execute(sql, args)
count = txn.fetchone()[0]
count = cast(Tuple[int], txn.fetchone())[0]

sql = f"""
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
Expand Down Expand Up @@ -338,7 +345,9 @@ async def search_users(self, term: str) -> Optional[List[JsonDict]]:
)


def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig):
def check_database_before_upgrade(
cur: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
) -> None:
"""Called before upgrading an existing database to check that it is broadly sane
compared with the configuration.
"""
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/appservice.py
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast

from synapse.appservice import (
ApplicationService,
Expand Down Expand Up @@ -83,7 +83,7 @@ def get_max_as_txn_id(txn: Cursor) -> int:
txn.execute(
"SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
)
return txn.fetchone()[0] # type: ignore
return cast(Tuple[int], txn.fetchone())[0]

self._as_txn_seq_gen = build_sequence_generator(
db_conn,
Expand Down
79 changes: 56 additions & 23 deletions synapse/storage/databases/main/deviceinbox.py
Expand Up @@ -14,7 +14,17 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
cast,
)

from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
Expand Down Expand Up @@ -118,7 +128,13 @@ def __init__(
prefilled_cache=device_outbox_prefill,
)

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
) -> None:
if stream_name == ToDeviceStream.NAME:
# If replication is happening than postgres must be being used.
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
Expand All @@ -134,7 +150,7 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

def get_to_device_stream_token(self):
def get_to_device_stream_token(self) -> int:
return self._device_inbox_id_gen.get_current_token()

async def get_messages_for_user_devices(
Expand Down Expand Up @@ -301,7 +317,9 @@ async def _get_device_messages(
if not user_ids_to_query:
return {}, to_stream_id

def get_device_messages_txn(txn: LoggingTransaction):
def get_device_messages_txn(
txn: LoggingTransaction,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
# Build a query to select messages from any of the given devices that
# are between the given stream id bounds.

Expand Down Expand Up @@ -428,7 +446,7 @@ async def delete_messages_for_device(
log_kv({"message": "No changes in cache since last check"})
return 0

def delete_messages_for_device_txn(txn):
def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
sql = (
"DELETE FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
Expand All @@ -455,15 +473,14 @@ def delete_messages_for_device_txn(txn):

@trace
async def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit
) -> Tuple[List[dict], int]:
self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
) -> Tuple[List[JsonDict], int]:
"""
Args:
destination(str): The name of the remote server.
last_stream_id(int|long): The last position of the device message stream
destination: The name of the remote server.
last_stream_id: The last position of the device message stream
that the server sent up to.
current_stream_id(int|long): The current position of the device
message stream.
current_stream_id: The current position of the device message stream.
Returns:
A list of messages for the device and where in the stream the messages got to.
"""
Expand All @@ -485,7 +502,9 @@ async def get_new_device_msgs_for_remote(
return [], last_stream_id

@trace
def get_new_messages_for_remote_destination_txn(txn):
def get_new_messages_for_remote_destination_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox"
" WHERE destination = ?"
Expand Down Expand Up @@ -527,7 +546,7 @@ async def delete_device_msgs_for_remote(
up_to_stream_id: Where to delete messages up to.
"""

def delete_messages_for_remote_destination_txn(txn):
def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None:
sql = (
"DELETE FROM device_federation_outbox"
" WHERE destination = ?"
Expand Down Expand Up @@ -566,7 +585,9 @@ async def get_all_new_device_messages(
if last_id == current_id:
return [], current_id, False

def get_all_new_device_messages_txn(txn):
def get_all_new_device_messages_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
Expand Down Expand Up @@ -607,8 +628,8 @@ def get_all_new_device_messages_txn(txn):
@trace
async def add_messages_to_device_inbox(
self,
local_messages_by_user_then_device: dict,
remote_messages_by_destination: dict,
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
remote_messages_by_destination: Dict[str, JsonDict],
) -> int:
"""Used to send messages from this server.
Expand All @@ -624,7 +645,9 @@ async def add_messages_to_device_inbox(

assert self._can_write_to_device

def add_messages_txn(txn, now_ms, stream_id):
def add_messages_txn(
txn: LoggingTransaction, now_ms: int, stream_id: int
) -> None:
# Add the local messages directly to the local inbox.
self._add_messages_to_local_device_inbox_txn(
txn, stream_id, local_messages_by_user_then_device
Expand Down Expand Up @@ -677,11 +700,16 @@ def add_messages_txn(txn, now_ms, stream_id):
return self._device_inbox_id_gen.get_current_token()

async def add_messages_from_remote_to_device_inbox(
self, origin: str, message_id: str, local_messages_by_user_then_device: dict
self,
origin: str,
message_id: str,
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
) -> int:
assert self._can_write_to_device

def add_messages_txn(txn, now_ms, stream_id):
def add_messages_txn(
txn: LoggingTransaction, now_ms: int, stream_id: int
) -> None:
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
Expand Down Expand Up @@ -727,8 +755,11 @@ def add_messages_txn(txn, now_ms, stream_id):
return stream_id

def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
):
self,
txn: LoggingTransaction,
stream_id: int,
messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
) -> None:
assert self._can_write_to_device

local_by_user_then_device = {}
Expand Down Expand Up @@ -840,8 +871,10 @@ def __init__(
self._remove_dead_devices_from_device_inbox,
)

async def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
async def _background_drop_index_device_inbox(
self, progress: JsonDict, batch_size: int
) -> int:
def reindex_txn(conn: LoggingDatabaseConnection) -> None:
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
Expand Down

0 comments on commit b76f1a4

Please sign in to comment.