From 57ba780bbd27fdffbb6fdf6c4044393f35cb4e16 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 15 Oct 2021 17:31:51 +0100 Subject: [PATCH] Make sydent.validators pass `mypy --strict` (#425) * Bump phonenumbers so we can use its type stubs * Use snake_case instead of lowerCamelCase * Don't int(...) an int --- changelog.d/425.misc | 1 + pyproject.toml | 2 +- setup.py | 2 +- sydent/db/valsession.py | 59 ++++++++++++++++------------ sydent/validators/__init__.py | 47 ++++++++++------------ sydent/validators/common.py | 18 +++++---- sydent/validators/emailvalidator.py | 32 +++++++++------ sydent/validators/msisdnvalidator.py | 20 +++++----- 8 files changed, 97 insertions(+), 84 deletions(-) create mode 100644 changelog.d/425.misc diff --git a/changelog.d/425.misc b/changelog.d/425.misc new file mode 100644 index 00000000..b54bdfde --- /dev/null +++ b/changelog.d/425.misc @@ -0,0 +1 @@ +Add type hints so `sydent.validators` passes `mypy --strict`. \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 53b96d40..d63901ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ files = [ "sydent/db", "sydent/users", "sydent/util", + "sydent/validators", # TODO the rest of CI checks these---mypy ought to too. # "tests", # "matrix_is_test", @@ -66,7 +67,6 @@ module = [ "nacl.*", "netaddr", "prometheus_client", - "phonenumbers", "sentry_sdk", "signedjson.*", "sortedcontainers", diff --git a/setup.py b/setup.py index 2d1a7b28..0e292e35 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def read(fname): "Twisted>=18.4.0", # twisted warns about about the absence of this "service_identity>=1.0.0", - "phonenumbers", + "phonenumbers>=8.12.32", "pyopenssl", "attrs>=19.1.0", "netaddr>=0.7.0", diff --git a/sydent/db/valsession.py b/sydent/db/valsession.py index 87d0d2cb..96ed0f00 100644 --- a/sydent/db/valsession.py +++ b/sydent/db/valsession.py @@ -13,15 +13,17 @@ # limitations under the License. from random import SystemRandom -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Tuple import sydent.util.tokenutils from sydent.util import time_msec from sydent.validators import ( + THREEPID_SESSION_VALID_LIFETIME_MS, IncorrectClientSecretException, InvalidSessionIdException, SessionExpiredException, SessionNotValidatedException, + TokenInfo, ValidationSession, ) @@ -36,7 +38,7 @@ def __init__(self, syd: "Sydent") -> None: def getOrCreateTokenSession( self, medium: str, address: str, clientSecret: str - ) -> ValidationSession: + ) -> Tuple[ValidationSession, TokenInfo]: """ Retrieves the validation session for a given medium, address and client secret, or creates one if none was found. @@ -56,13 +58,16 @@ def getOrCreateTokenSession( "where s.medium = ? and s.address = ? and s.clientSecret = ? and t.validationSession = s.id", (medium, address, clientSecret), ) - row = cur.fetchone() + row: Optional[ + Tuple[int, str, str, str, Optional[int], int, str, int] + ] = cur.fetchone() if row: - s = ValidationSession( - row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7] + session = ValidationSession( + row[0], row[1], row[2], row[3], bool(row[4]), row[5] ) - return s + token_info = TokenInfo(row[6], row[7]) + return session, token_info sid = self.addValSession( medium, address, clientSecret, time_msec(), commit=False @@ -76,10 +81,16 @@ def getOrCreateTokenSession( ) self.sydent.db.commit() - s = ValidationSession( - sid, medium, address, clientSecret, False, time_msec(), tokenString, -1 + session = ValidationSession( + sid, + medium, + address, + clientSecret, + False, + time_msec(), ) - return s + token_info = TokenInfo(tokenString, -1) + return session, token_info def addValSession( self, @@ -178,16 +189,16 @@ def getSessionById(self, sid: int) -> Optional[ValidationSession]: + "threepid_validation_sessions where id = ?", (sid,), ) - row = cur.fetchone() + row: Optional[Tuple[int, str, str, str, Optional[int], int]] = cur.fetchone() if not row: return None - return ValidationSession( - row[0], row[1], row[2], row[3], row[4], row[5], None, None - ) + return ValidationSession(row[0], row[1], row[2], row[3], bool(row[4]), row[5]) - def getTokenSessionById(self, sid: int) -> Optional[ValidationSession]: + def getTokenSessionById( + self, sid: int + ) -> Optional[Tuple[ValidationSession, TokenInfo]]: """ Retrieves a validation session using the session's ID. @@ -203,23 +214,23 @@ def getTokenSessionById(self, sid: int) -> Optional[ValidationSession]: "where s.id = ? and t.validationSession = s.id", (sid,), ) + row: Optional[Tuple[int, str, str, str, Optional[int], int, str, int]] row = cur.fetchone() if row: - s = ValidationSession( - row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7] - ) - return s + s = ValidationSession(row[0], row[1], row[2], row[3], bool(row[4]), row[5]) + t = TokenInfo(row[6], row[7]) + return s, t return None - def getValidatedSession(self, sid: int, clientSecret: str) -> ValidationSession: + def getValidatedSession(self, sid: int, client_secret: str) -> ValidationSession: """ Retrieve a validated and still-valid session whose client secret matches the one passed in. :param sid: The ID of the session to retrieve. - :param clientSecret: A client secret to check against the one retrieved from + :param client_secret: A client secret to check against the one retrieved from the database. :return: The retrieved session. @@ -236,10 +247,10 @@ def getValidatedSession(self, sid: int, clientSecret: str) -> ValidationSession: if not s: raise InvalidSessionIdException() - if not s.clientSecret == clientSecret: + if not s.client_secret == client_secret: raise IncorrectClientSecretException() - if s.mtime + ValidationSession.THREEPID_SESSION_VALID_LIFETIME_MS < time_msec(): + if s.mtime + THREEPID_SESSION_VALID_LIFETIME_MS < time_msec(): raise SessionExpiredException() if not s.validated: @@ -252,9 +263,7 @@ def deleteOldSessions(self) -> None: cur = self.sydent.db.cursor() - delete_before_ts = ( - time_msec() - 5 * ValidationSession.THREEPID_SESSION_VALID_LIFETIME_MS - ) + delete_before_ts = time_msec() - 5 * THREEPID_SESSION_VALID_LIFETIME_MS sql = """ DELETE FROM threepid_validation_sessions diff --git a/sydent/validators/__init__.py b/sydent/validators/__init__.py index 9ecace07..15c434f0 100644 --- a/sydent/validators/__init__.py +++ b/sydent/validators/__init__.py @@ -11,35 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +import attr +# how long a user can wait before validating a session after starting it +THREEPID_SESSION_VALIDATION_TIMEOUT_MS = 24 * 60 * 60 * 1000 + +# how long we keep sessions for after they've been validated +THREEPID_SESSION_VALID_LIFETIME_MS = 24 * 60 * 60 * 1000 + + +@attr.s(frozen=True, slots=True, auto_attribs=True) class ValidationSession: - # how long a user can wait before validating a session after starting it - THREEPID_SESSION_VALIDATION_TIMEOUT_MS = 24 * 60 * 60 * 1000 - - # how long we keep sessions for after they've been validated - THREEPID_SESSION_VALID_LIFETIME_MS = 24 * 60 * 60 * 1000 - - def __init__( - self, - _id: int, - _medium: str, - _address: str, - _clientSecret: str, - _validated: int, # bool, but sqlite has no bool type - _mtime: int, - _token: Optional[str], - _sendAttemptNumber: Optional[int], - ): - self.id = _id - self.medium = _medium - self.address = _address - self.clientSecret = _clientSecret - self.validated = _validated - self.mtime = _mtime - self.token = _token - self.sendAttemptNumber = _sendAttemptNumber + id: int + medium: str + address: str + client_secret: str + validated: bool + mtime: int + + +@attr.s(frozen=True, slots=True, auto_attribs=True) +class TokenInfo: + token: str + send_attempt_number: int class IncorrectClientSecretException(Exception): diff --git a/sydent/validators/common.py b/sydent/validators/common.py index 8a049555..34c73c2d 100644 --- a/sydent/validators/common.py +++ b/sydent/validators/common.py @@ -4,11 +4,11 @@ from sydent.db.valsession import ThreePidValSessionStore from sydent.util import time_msec from sydent.validators import ( + THREEPID_SESSION_VALIDATION_TIMEOUT_MS, IncorrectClientSecretException, IncorrectSessionTokenException, InvalidSessionIdException, SessionExpiredException, - ValidationSession, ) if TYPE_CHECKING: @@ -39,16 +39,18 @@ def validateSessionWithToken( :raise IncorrectSessionTokenException: The provided token is incorrect """ valSessionStore = ThreePidValSessionStore(sydent) - s = valSessionStore.getTokenSessionById(sid) - if not s: + result = valSessionStore.getTokenSessionById(sid) + if not result: logger.info("Session ID %s not found", sid) raise InvalidSessionIdException() - if not clientSecret == s.clientSecret: + session, token_info = result + + if not clientSecret == session.client_secret: logger.info("Incorrect client secret", sid) raise IncorrectClientSecretException() - if s.mtime + ValidationSession.THREEPID_SESSION_VALIDATION_TIMEOUT_MS < time_msec(): + if session.mtime + THREEPID_SESSION_VALIDATION_TIMEOUT_MS < time_msec(): logger.info("Session expired") raise SessionExpiredException() @@ -56,9 +58,9 @@ def validateSessionWithToken( # if tokenObj.validated and clientSecret == tokenObj.clientSecret: # return True - if s.token == token: - logger.info("Setting session %s as validated", s.id) - valSessionStore.setValidated(s.id, True) + if token_info.token == token: + logger.info("Setting session %s as validated", session.id) + valSessionStore.setValidated(session.id, True) return {"success": True} else: diff --git a/sydent/validators/emailvalidator.py b/sydent/validators/emailvalidator.py index 5b057a70..8d29e95a 100644 --- a/sydent/validators/emailvalidator.py +++ b/sydent/validators/emailvalidator.py @@ -23,7 +23,6 @@ if TYPE_CHECKING: from sydent.sydent import Sydent - from sydent.validators import ValidationSession logger = logging.getLogger(__name__) @@ -57,7 +56,7 @@ def requestToken( """ valSessionStore = ThreePidValSessionStore(self.sydent) - valSession = valSessionStore.getOrCreateTokenSession( + valSession, token_info = valSessionStore.getOrCreateTokenSession( medium="email", address=emailAddress, clientSecret=clientSecret ) @@ -72,11 +71,11 @@ def requestToken( else: templateFile = self.sydent.config.email.template - if int(valSession.sendAttemptNumber) >= int(sendAttempt): + if token_info.send_attempt_number >= sendAttempt: logger.info( "Not mailing code because current send attempt (%d) is not less than given send attempt (%s)", - int(sendAttempt), - int(valSession.sendAttemptNumber), + sendAttempt, + token_info.send_attempt_number, ) return valSession.id @@ -84,12 +83,14 @@ def requestToken( substitutions = { "ipaddress": ipstring, - "link": self.makeValidateLink(valSession, clientSecret, nextLink), - "token": valSession.token, + "link": self.makeValidateLink( + valSession.id, token_info.token, clientSecret, nextLink + ), + "token": token_info.token, } logger.info( "Attempting to mail code %s (nextLink: %s) to %s", - valSession.token, + token_info.token, nextLink, emailAddress, ) @@ -100,12 +101,17 @@ def requestToken( return valSession.id def makeValidateLink( - self, valSession: "ValidationSession", clientSecret: str, nextLink: str + self, + session_id: int, + token: str, + clientSecret: str, + nextLink: str, ) -> str: """ Creates a validation link that can be sent via email to the user. - :param valSession: The current validation session. + :param session_id: The current validation session's ID. + :param token: The token to make a link for. :param clientSecret: The client secret to include in the link. :param nextLink: The link to redirect the user to once they have completed the validation. @@ -115,9 +121,9 @@ def makeValidateLink( base = self.sydent.config.http.server_http_url_base link = "%s/_matrix/identity/api/v1/validate/email/submitToken?token=%s&client_secret=%s&sid=%d" % ( base, - urllib.parse.quote(valSession.token), + urllib.parse.quote(token), urllib.parse.quote(clientSecret), - valSession.id, + session_id, ) if nextLink: # manipulate the nextLink to add the sid, because @@ -127,7 +133,7 @@ def makeValidateLink( nextLink += "&" else: nextLink += "?" - nextLink += "sid=" + urllib.parse.quote(str(valSession.id)) + nextLink += "sid=" + urllib.parse.quote(str(session_id)) link += "&nextLink=%s" % (urllib.parse.quote(nextLink)) return link diff --git a/sydent/validators/msisdnvalidator.py b/sydent/validators/msisdnvalidator.py index 541d2c31..7dde5ffe 100644 --- a/sydent/validators/msisdnvalidator.py +++ b/sydent/validators/msisdnvalidator.py @@ -16,7 +16,7 @@ import logging from typing import TYPE_CHECKING, Dict, Optional -import phonenumbers # type: ignore +import phonenumbers from sydent.db.valsession import ThreePidValSessionStore from sydent.sms.openmarket import OpenMarketSMS @@ -42,7 +42,7 @@ def requestToken( self, phoneNumber: phonenumbers.PhoneNumber, clientSecret: str, - sendAttempt: int, + send_attempt: int, brand: Optional[str] = None, ) -> int: """ @@ -51,7 +51,7 @@ def requestToken( :param phoneNumber: The phone number to send the email to. :param clientSecret: The client secret to use. - :param sendAttempt: The current send attempt. + :param send_attempt: The current send attempt. :param brand: A hint at a brand from the request. :return: The ID of the session created (or of the existing one if any) @@ -67,17 +67,17 @@ def requestToken( phoneNumber, phonenumbers.PhoneNumberFormat.E164 )[1:] - valSession = valSessionStore.getOrCreateTokenSession( + valSession, token_info = valSessionStore.getOrCreateTokenSession( medium="msisdn", address=msisdn, clientSecret=clientSecret ) valSessionStore.setMtime(valSession.id, time_msec()) - if int(valSession.sendAttemptNumber) >= int(sendAttempt): + if token_info.send_attempt_number >= send_attempt: logger.info( "Not texting code because current send attempt (%d) is not less than given send attempt (%s)", - int(sendAttempt), - int(valSession.sendAttemptNumber), + send_attempt, + token_info.send_attempt_number, ) return valSession.id @@ -86,17 +86,17 @@ def requestToken( logger.info( "Attempting to text code %s to %s (country %d) with originator %s", - valSession.token, + token_info.token, msisdn, phoneNumber.country_code, originator, ) - smsBody = smsBodyTemplate.format(token=valSession.token) + smsBody = smsBodyTemplate.format(token=token_info.token) self.omSms.sendTextSMS(smsBody, msisdn, originator) - valSessionStore.setSendAttemptNumber(valSession.id, sendAttempt) + valSessionStore.setSendAttemptNumber(valSession.id, send_attempt) return valSession.id