diff --git a/changelog.d/426.misc b/changelog.d/426.misc new file mode 100644 index 00000000..a49b345d --- /dev/null +++ b/changelog.d/426.misc @@ -0,0 +1 @@ +Type annotate the result of reading from the db in `sydent.db`. \ No newline at end of file diff --git a/sydent/db/accounts.py b/sydent/db/accounts.py index 5f6a82d6..1581c5d9 100644 --- a/sydent/db/accounts.py +++ b/sydent/db/accounts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Optional, Tuple, cast from sydent.users.accounts import Account @@ -39,7 +39,7 @@ def getAccountByToken(self, token: str) -> Optional[Account]: (token,), ) - row = res.fetchone() + row: Optional[Tuple[str, int, Optional[str]]] = res.fetchone() if row is None: return None diff --git a/sydent/db/hashing_metadata.py b/sydent/db/hashing_metadata.py index 782d6621..f852c409 100644 --- a/sydent/db/hashing_metadata.py +++ b/sydent/db/hashing_metadata.py @@ -15,7 +15,9 @@ # Actions on the hashing_metadata table which is defined in the migration process in # sqlitedb.py from sqlite3 import Cursor -from typing import TYPE_CHECKING, Callable, Optional, Tuple +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple + +from typing_extensions import Literal if TYPE_CHECKING: from sydent.sydent import Sydent @@ -87,7 +89,7 @@ def _rehash_threepids( cur: Cursor, hashing_function: Callable[[str], str], pepper: str, - table: str, + table: Literal["local_threepid_associations", "global_threepid_associations"], ) -> None: """Rehash 3PIDs of a given table using a given hashing_function and pepper @@ -96,24 +98,17 @@ def _rehash_threepids( the made changes to the database. :param cur: Database cursor - :type cur: - :param hashing_function: A function with single input and output strings - :type hashing_function func(str) -> str - :param pepper: A pepper to append to the end of the 3PID (after a space) before hashing - :type pepper: str - :param table: The database table to perform the rehashing on - :type table: str """ # Get count of all 3PID records # Medium/address combos are marked as UNIQUE in the database sql = "SELECT COUNT(*) FROM %s" % table res = cur.execute(sql) - row_count = res.fetchone() - row_count = row_count[0] + row: Tuple[int] = res.fetchone() + row_count = row[0] # Iterate through each medium, address combo, hash it, # and store in the db @@ -126,7 +121,7 @@ def _rehash_threepids( count, ) res = cur.execute(sql) - rows = res.fetchall() + rows: List[Tuple[str, str]] = res.fetchall() for medium, address in rows: # Skip broken db entry diff --git a/sydent/db/invite_tokens.py b/sydent/db/invite_tokens.py index af028a01..6c304b2d 100644 --- a/sydent/db/invite_tokens.py +++ b/sydent/db/invite_tokens.py @@ -12,12 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. import time -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, List, Optional, Tuple, cast + +from typing_extensions import TypedDict if TYPE_CHECKING: from sydent.sydent import Sydent +class PendingInviteTokens(TypedDict): + medium: str + address: str + room_id: str + sender: str + token: str + + class JoinTokenStore: def __init__(self, sydent: "Sydent") -> None: self.sydent = sydent @@ -46,9 +56,7 @@ def storeToken( ) self.sydent.db.commit() - def getTokens( - self, medium: str, address: str - ) -> List[Dict[str, Union[str, Dict[str, str]]]]: + def getTokens(self, medium: str, address: str) -> List[PendingInviteTokens]: """ Retrieves the pending invites tokens for this 3PID that haven't been delivered yet. @@ -69,25 +77,12 @@ def getTokens( address, ), ) - rows = res.fetchall() + rows: List[Tuple[str, str, str, str, str]] = res.fetchall() - ret = [] + ret: List[PendingInviteTokens] = [] for row in rows: medium, address, roomId, sender, token = row - - # Ensure we're dealing with unicode. - if isinstance(medium, bytes): - medium = medium.decode("UTF-8") - if isinstance(address, bytes): - address = address.decode("UTF-8") - if isinstance(roomId, bytes): - roomId = roomId.decode("UTF-8") - if isinstance(sender, bytes): - sender = sender.decode("UTF-8") - if isinstance(token, bytes): - token = token.decode("UTF-8") - ret.append( { "medium": medium, diff --git a/sydent/db/threepid_associations.py b/sydent/db/threepid_associations.py index eac36f1c..bd67b19f 100644 --- a/sydent/db/threepid_associations.py +++ b/sydent/db/threepid_associations.py @@ -90,6 +90,16 @@ def getAssociationsAfterId( maxId = None assocs = {} + row: Tuple[ + int, + str, + str, + Optional[str], + Optional[str], + Optional[int], + Optional[int], + Optional[int], + ] for row in res.fetchall(): assoc = ThreepidAssociation( row[1], row[2], row[3], row[4], row[5], row[6], row[7] @@ -147,7 +157,7 @@ def removeAssociation(self, threepid: Dict[str, str], mxid: str) -> None: "WHERE medium = ? AND address = ? AND mxid = ?", (threepid["medium"], threepid["address"], mxid), ) - row = cur.fetchone() + row: Tuple[int] = cur.fetchone() if row[0] > 0: ts = time_msec() cur.execute( @@ -233,7 +243,7 @@ def getMxid(self, medium: str, normalised_address: str) -> Optional[str]: (medium, normalised_address, time_msec(), time_msec()), ) - row: Tuple[str] = res.fetchone() + row: Tuple[Optional[str]] = res.fetchone() if not row: return None @@ -281,6 +291,7 @@ def getMxids( results = [] current = None + row: Tuple[str, str, int, str] for row in res.fetchall(): # only use the most recent entry for each # threepid (they're sorted by ts) @@ -429,6 +440,15 @@ def retrieveMxidsForHashes(self, addresses: List[str]) -> Dict[str, str]: # Place the results from the query into a dictionary # Results are sorted from oldest to newest, so if there are multiple mxid's for # the same lookup hash, only the newest mapping will be returned + + # Type safety: lookup_hash is a nullable string in + # global_threepid_associations. But it must be equal to a lookup_hash + # in the temporary table thanks to the join condition. + # The temporary table gets hashes from the `addresses` argument, + # which is a list of (non-None) strings. + # So lookup_hash really is a str. + lookup_hash: str + mxid: str for lookup_hash, mxid in res.fetchall(): results[lookup_hash] = mxid diff --git a/sydent/threepid/__init__.py b/sydent/threepid/__init__.py index 7f4086f1..fe30b95a 100644 --- a/sydent/threepid/__init__.py +++ b/sydent/threepid/__init__.py @@ -42,10 +42,12 @@ def __init__( medium: str, address: str, lookup_hash: Optional[str], - mxid: str, - ts: int, - not_before: int, - not_after: int, + # Note: the next four fields were made optional in schema version 2. + # See sydent.db.sqlitedb.SqliteDatabase._upgradeSchema + mxid: Optional[str], + ts: Optional[int], + not_before: Optional[int], + not_after: Optional[int], ): """ :param medium: The medium of the 3pid (eg. email) diff --git a/sydent/users/accounts.py b/sydent/users/accounts.py index 18f3fda8..9199bcb7 100644 --- a/sydent/users/accounts.py +++ b/sydent/users/accounts.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional class Account: - def __init__(self, user_id: str, creation_ts: int, consent_version: str) -> None: + def __init__( + self, user_id: str, creation_ts: int, consent_version: Optional[str] + ) -> None: """ :param user_id: The Matrix user ID for the account. :param creation_ts: The timestamp in milliseconds of the account's creation.