Skip to content

Commit

Permalink
Additional type hints in sydent.db (#426)
Browse files Browse the repository at this point in the history
We have to manually annotate return values of fetchone(), fetchmany() and fetchall(). (Unless we used something like sqlalchemy, hint hint)

Also remove a bytes -> unicode conversion. sqlite varchar is always `str` in python 3 so this is now redundant.
  • Loading branch information
David Robertson committed Oct 15, 2021
1 parent ba9ec24 commit 1280a2b
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 40 deletions.
1 change: 1 addition & 0 deletions changelog.d/426.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Type annotate the result of reading from the db in `sydent.db`.
4 changes: 2 additions & 2 deletions sydent/db/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Optional, Tuple, cast

from sydent.users.accounts import Account

Expand All @@ -39,7 +39,7 @@ def getAccountByToken(self, token: str) -> Optional[Account]:
(token,),
)

row = res.fetchone()
row: Optional[Tuple[str, int, Optional[str]]] = res.fetchone()
if row is None:
return None

Expand Down
19 changes: 7 additions & 12 deletions sydent/db/hashing_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
# Actions on the hashing_metadata table which is defined in the migration process in
# sqlitedb.py
from sqlite3 import Cursor
from typing import TYPE_CHECKING, Callable, Optional, Tuple
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

from typing_extensions import Literal

if TYPE_CHECKING:
from sydent.sydent import Sydent
Expand Down Expand Up @@ -87,7 +89,7 @@ def _rehash_threepids(
cur: Cursor,
hashing_function: Callable[[str], str],
pepper: str,
table: str,
table: Literal["local_threepid_associations", "global_threepid_associations"],
) -> None:
"""Rehash 3PIDs of a given table using a given hashing_function and pepper
Expand All @@ -96,24 +98,17 @@ def _rehash_threepids(
the made changes to the database.
:param cur: Database cursor
:type cur:
:param hashing_function: A function with single input and output strings
:type hashing_function func(str) -> str
:param pepper: A pepper to append to the end of the 3PID (after a space) before hashing
:type pepper: str
:param table: The database table to perform the rehashing on
:type table: str
"""

# Get count of all 3PID records
# Medium/address combos are marked as UNIQUE in the database
sql = "SELECT COUNT(*) FROM %s" % table
res = cur.execute(sql)
row_count = res.fetchone()
row_count = row_count[0]
row: Tuple[int] = res.fetchone()
row_count = row[0]

# Iterate through each medium, address combo, hash it,
# and store in the db
Expand All @@ -126,7 +121,7 @@ def _rehash_threepids(
count,
)
res = cur.execute(sql)
rows = res.fetchall()
rows: List[Tuple[str, str]] = res.fetchall()

for medium, address in rows:
# Skip broken db entry
Expand Down
33 changes: 14 additions & 19 deletions sydent/db/invite_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, List, Optional, Tuple, cast

from typing_extensions import TypedDict

if TYPE_CHECKING:
from sydent.sydent import Sydent


class PendingInviteTokens(TypedDict):
medium: str
address: str
room_id: str
sender: str
token: str


class JoinTokenStore:
def __init__(self, sydent: "Sydent") -> None:
self.sydent = sydent
Expand Down Expand Up @@ -46,9 +56,7 @@ def storeToken(
)
self.sydent.db.commit()

def getTokens(
self, medium: str, address: str
) -> List[Dict[str, Union[str, Dict[str, str]]]]:
def getTokens(self, medium: str, address: str) -> List[PendingInviteTokens]:
"""
Retrieves the pending invites tokens for this 3PID that haven't been delivered
yet.
Expand All @@ -69,25 +77,12 @@ def getTokens(
address,
),
)
rows = res.fetchall()
rows: List[Tuple[str, str, str, str, str]] = res.fetchall()

ret = []
ret: List[PendingInviteTokens] = []

for row in rows:
medium, address, roomId, sender, token = row

# Ensure we're dealing with unicode.
if isinstance(medium, bytes):
medium = medium.decode("UTF-8")
if isinstance(address, bytes):
address = address.decode("UTF-8")
if isinstance(roomId, bytes):
roomId = roomId.decode("UTF-8")
if isinstance(sender, bytes):
sender = sender.decode("UTF-8")
if isinstance(token, bytes):
token = token.decode("UTF-8")

ret.append(
{
"medium": medium,
Expand Down
24 changes: 22 additions & 2 deletions sydent/db/threepid_associations.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,16 @@ def getAssociationsAfterId(
maxId = None

assocs = {}
row: Tuple[
int,
str,
str,
Optional[str],
Optional[str],
Optional[int],
Optional[int],
Optional[int],
]
for row in res.fetchall():
assoc = ThreepidAssociation(
row[1], row[2], row[3], row[4], row[5], row[6], row[7]
Expand Down Expand Up @@ -147,7 +157,7 @@ def removeAssociation(self, threepid: Dict[str, str], mxid: str) -> None:
"WHERE medium = ? AND address = ? AND mxid = ?",
(threepid["medium"], threepid["address"], mxid),
)
row = cur.fetchone()
row: Tuple[int] = cur.fetchone()
if row[0] > 0:
ts = time_msec()
cur.execute(
Expand Down Expand Up @@ -233,7 +243,7 @@ def getMxid(self, medium: str, normalised_address: str) -> Optional[str]:
(medium, normalised_address, time_msec(), time_msec()),
)

row: Tuple[str] = res.fetchone()
row: Tuple[Optional[str]] = res.fetchone()

if not row:
return None
Expand Down Expand Up @@ -281,6 +291,7 @@ def getMxids(

results = []
current = None
row: Tuple[str, str, int, str]
for row in res.fetchall():
# only use the most recent entry for each
# threepid (they're sorted by ts)
Expand Down Expand Up @@ -429,6 +440,15 @@ def retrieveMxidsForHashes(self, addresses: List[str]) -> Dict[str, str]:
# Place the results from the query into a dictionary
# Results are sorted from oldest to newest, so if there are multiple mxid's for
# the same lookup hash, only the newest mapping will be returned

# Type safety: lookup_hash is a nullable string in
# global_threepid_associations. But it must be equal to a lookup_hash
# in the temporary table thanks to the join condition.
# The temporary table gets hashes from the `addresses` argument,
# which is a list of (non-None) strings.
# So lookup_hash really is a str.
lookup_hash: str
mxid: str
for lookup_hash, mxid in res.fetchall():
results[lookup_hash] = mxid

Expand Down
10 changes: 6 additions & 4 deletions sydent/threepid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ def __init__(
medium: str,
address: str,
lookup_hash: Optional[str],
mxid: str,
ts: int,
not_before: int,
not_after: int,
# Note: the next four fields were made optional in schema version 2.
# See sydent.db.sqlitedb.SqliteDatabase._upgradeSchema
mxid: Optional[str],
ts: Optional[int],
not_before: Optional[int],
not_after: Optional[int],
):
"""
:param medium: The medium of the 3pid (eg. email)
Expand Down
5 changes: 4 additions & 1 deletion sydent/users/accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional


class Account:
def __init__(self, user_id: str, creation_ts: int, consent_version: str) -> None:
def __init__(
self, user_id: str, creation_ts: int, consent_version: Optional[str]
) -> None:
"""
:param user_id: The Matrix user ID for the account.
:param creation_ts: The timestamp in milliseconds of the account's creation.
Expand Down

0 comments on commit 1280a2b

Please sign in to comment.