Skip to content

Commit

Permalink
Make sydent.validators pass mypy --strict (#425)
Browse files Browse the repository at this point in the history
* Bump phonenumbers so we can use its type stubs
* Use snake_case instead of lowerCamelCase
* Don't int(...) an int
  • Loading branch information
David Robertson committed Oct 15, 2021
1 parent 1280a2b commit 57ba780
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 84 deletions.
1 change: 1 addition & 0 deletions changelog.d/425.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints so `sydent.validators` passes `mypy --strict`.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ files = [
"sydent/db",
"sydent/users",
"sydent/util",
"sydent/validators",
# TODO the rest of CI checks these---mypy ought to too.
# "tests",
# "matrix_is_test",
Expand All @@ -66,7 +67,6 @@ module = [
"nacl.*",
"netaddr",
"prometheus_client",
"phonenumbers",
"sentry_sdk",
"signedjson.*",
"sortedcontainers",
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def read(fname):
"Twisted>=18.4.0",
# twisted warns about about the absence of this
"service_identity>=1.0.0",
"phonenumbers",
"phonenumbers>=8.12.32",
"pyopenssl",
"attrs>=19.1.0",
"netaddr>=0.7.0",
Expand Down
59 changes: 34 additions & 25 deletions sydent/db/valsession.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@
# limitations under the License.

from random import SystemRandom
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Tuple

import sydent.util.tokenutils
from sydent.util import time_msec
from sydent.validators import (
THREEPID_SESSION_VALID_LIFETIME_MS,
IncorrectClientSecretException,
InvalidSessionIdException,
SessionExpiredException,
SessionNotValidatedException,
TokenInfo,
ValidationSession,
)

Expand All @@ -36,7 +38,7 @@ def __init__(self, syd: "Sydent") -> None:

def getOrCreateTokenSession(
self, medium: str, address: str, clientSecret: str
) -> ValidationSession:
) -> Tuple[ValidationSession, TokenInfo]:
"""
Retrieves the validation session for a given medium, address and client secret,
or creates one if none was found.
Expand All @@ -56,13 +58,16 @@ def getOrCreateTokenSession(
"where s.medium = ? and s.address = ? and s.clientSecret = ? and t.validationSession = s.id",
(medium, address, clientSecret),
)
row = cur.fetchone()
row: Optional[
Tuple[int, str, str, str, Optional[int], int, str, int]
] = cur.fetchone()

if row:
s = ValidationSession(
row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7]
session = ValidationSession(
row[0], row[1], row[2], row[3], bool(row[4]), row[5]
)
return s
token_info = TokenInfo(row[6], row[7])
return session, token_info

sid = self.addValSession(
medium, address, clientSecret, time_msec(), commit=False
Expand All @@ -76,10 +81,16 @@ def getOrCreateTokenSession(
)
self.sydent.db.commit()

s = ValidationSession(
sid, medium, address, clientSecret, False, time_msec(), tokenString, -1
session = ValidationSession(
sid,
medium,
address,
clientSecret,
False,
time_msec(),
)
return s
token_info = TokenInfo(tokenString, -1)
return session, token_info

def addValSession(
self,
Expand Down Expand Up @@ -178,16 +189,16 @@ def getSessionById(self, sid: int) -> Optional[ValidationSession]:
+ "threepid_validation_sessions where id = ?",
(sid,),
)
row = cur.fetchone()
row: Optional[Tuple[int, str, str, str, Optional[int], int]] = cur.fetchone()

if not row:
return None

return ValidationSession(
row[0], row[1], row[2], row[3], row[4], row[5], None, None
)
return ValidationSession(row[0], row[1], row[2], row[3], bool(row[4]), row[5])

def getTokenSessionById(self, sid: int) -> Optional[ValidationSession]:
def getTokenSessionById(
self, sid: int
) -> Optional[Tuple[ValidationSession, TokenInfo]]:
"""
Retrieves a validation session using the session's ID.
Expand All @@ -203,23 +214,23 @@ def getTokenSessionById(self, sid: int) -> Optional[ValidationSession]:
"where s.id = ? and t.validationSession = s.id",
(sid,),
)
row: Optional[Tuple[int, str, str, str, Optional[int], int, str, int]]
row = cur.fetchone()

if row:
s = ValidationSession(
row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7]
)
return s
s = ValidationSession(row[0], row[1], row[2], row[3], bool(row[4]), row[5])
t = TokenInfo(row[6], row[7])
return s, t

return None

def getValidatedSession(self, sid: int, clientSecret: str) -> ValidationSession:
def getValidatedSession(self, sid: int, client_secret: str) -> ValidationSession:
"""
Retrieve a validated and still-valid session whose client secret matches the
one passed in.
:param sid: The ID of the session to retrieve.
:param clientSecret: A client secret to check against the one retrieved from
:param client_secret: A client secret to check against the one retrieved from
the database.
:return: The retrieved session.
Expand All @@ -236,10 +247,10 @@ def getValidatedSession(self, sid: int, clientSecret: str) -> ValidationSession:
if not s:
raise InvalidSessionIdException()

if not s.clientSecret == clientSecret:
if not s.client_secret == client_secret:
raise IncorrectClientSecretException()

if s.mtime + ValidationSession.THREEPID_SESSION_VALID_LIFETIME_MS < time_msec():
if s.mtime + THREEPID_SESSION_VALID_LIFETIME_MS < time_msec():
raise SessionExpiredException()

if not s.validated:
Expand All @@ -252,9 +263,7 @@ def deleteOldSessions(self) -> None:

cur = self.sydent.db.cursor()

delete_before_ts = (
time_msec() - 5 * ValidationSession.THREEPID_SESSION_VALID_LIFETIME_MS
)
delete_before_ts = time_msec() - 5 * THREEPID_SESSION_VALID_LIFETIME_MS

sql = """
DELETE FROM threepid_validation_sessions
Expand Down
47 changes: 21 additions & 26 deletions sydent/validators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import attr

# how long a user can wait before validating a session after starting it
THREEPID_SESSION_VALIDATION_TIMEOUT_MS = 24 * 60 * 60 * 1000

# how long we keep sessions for after they've been validated
THREEPID_SESSION_VALID_LIFETIME_MS = 24 * 60 * 60 * 1000


@attr.s(frozen=True, slots=True, auto_attribs=True)
class ValidationSession:
# how long a user can wait before validating a session after starting it
THREEPID_SESSION_VALIDATION_TIMEOUT_MS = 24 * 60 * 60 * 1000

# how long we keep sessions for after they've been validated
THREEPID_SESSION_VALID_LIFETIME_MS = 24 * 60 * 60 * 1000

def __init__(
self,
_id: int,
_medium: str,
_address: str,
_clientSecret: str,
_validated: int, # bool, but sqlite has no bool type
_mtime: int,
_token: Optional[str],
_sendAttemptNumber: Optional[int],
):
self.id = _id
self.medium = _medium
self.address = _address
self.clientSecret = _clientSecret
self.validated = _validated
self.mtime = _mtime
self.token = _token
self.sendAttemptNumber = _sendAttemptNumber
id: int
medium: str
address: str
client_secret: str
validated: bool
mtime: int


@attr.s(frozen=True, slots=True, auto_attribs=True)
class TokenInfo:
token: str
send_attempt_number: int


class IncorrectClientSecretException(Exception):
Expand Down
18 changes: 10 additions & 8 deletions sydent/validators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from sydent.db.valsession import ThreePidValSessionStore
from sydent.util import time_msec
from sydent.validators import (
THREEPID_SESSION_VALIDATION_TIMEOUT_MS,
IncorrectClientSecretException,
IncorrectSessionTokenException,
InvalidSessionIdException,
SessionExpiredException,
ValidationSession,
)

if TYPE_CHECKING:
Expand Down Expand Up @@ -39,26 +39,28 @@ def validateSessionWithToken(
:raise IncorrectSessionTokenException: The provided token is incorrect
"""
valSessionStore = ThreePidValSessionStore(sydent)
s = valSessionStore.getTokenSessionById(sid)
if not s:
result = valSessionStore.getTokenSessionById(sid)
if not result:
logger.info("Session ID %s not found", sid)
raise InvalidSessionIdException()

if not clientSecret == s.clientSecret:
session, token_info = result

if not clientSecret == session.client_secret:
logger.info("Incorrect client secret", sid)
raise IncorrectClientSecretException()

if s.mtime + ValidationSession.THREEPID_SESSION_VALIDATION_TIMEOUT_MS < time_msec():
if session.mtime + THREEPID_SESSION_VALIDATION_TIMEOUT_MS < time_msec():
logger.info("Session expired")
raise SessionExpiredException()

# TODO once we can validate the token oob
# if tokenObj.validated and clientSecret == tokenObj.clientSecret:
# return True

if s.token == token:
logger.info("Setting session %s as validated", s.id)
valSessionStore.setValidated(s.id, True)
if token_info.token == token:
logger.info("Setting session %s as validated", session.id)
valSessionStore.setValidated(session.id, True)

return {"success": True}
else:
Expand Down
32 changes: 19 additions & 13 deletions sydent/validators/emailvalidator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

if TYPE_CHECKING:
from sydent.sydent import Sydent
from sydent.validators import ValidationSession

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -57,7 +56,7 @@ def requestToken(
"""
valSessionStore = ThreePidValSessionStore(self.sydent)

valSession = valSessionStore.getOrCreateTokenSession(
valSession, token_info = valSessionStore.getOrCreateTokenSession(
medium="email", address=emailAddress, clientSecret=clientSecret
)

Expand All @@ -72,24 +71,26 @@ def requestToken(
else:
templateFile = self.sydent.config.email.template

if int(valSession.sendAttemptNumber) >= int(sendAttempt):
if token_info.send_attempt_number >= sendAttempt:
logger.info(
"Not mailing code because current send attempt (%d) is not less than given send attempt (%s)",
int(sendAttempt),
int(valSession.sendAttemptNumber),
sendAttempt,
token_info.send_attempt_number,
)
return valSession.id

ipstring = ipaddress if ipaddress else "an unknown location"

substitutions = {
"ipaddress": ipstring,
"link": self.makeValidateLink(valSession, clientSecret, nextLink),
"token": valSession.token,
"link": self.makeValidateLink(
valSession.id, token_info.token, clientSecret, nextLink
),
"token": token_info.token,
}
logger.info(
"Attempting to mail code %s (nextLink: %s) to %s",
valSession.token,
token_info.token,
nextLink,
emailAddress,
)
Expand All @@ -100,12 +101,17 @@ def requestToken(
return valSession.id

def makeValidateLink(
self, valSession: "ValidationSession", clientSecret: str, nextLink: str
self,
session_id: int,
token: str,
clientSecret: str,
nextLink: str,
) -> str:
"""
Creates a validation link that can be sent via email to the user.
:param valSession: The current validation session.
:param session_id: The current validation session's ID.
:param token: The token to make a link for.
:param clientSecret: The client secret to include in the link.
:param nextLink: The link to redirect the user to once they have completed the
validation.
Expand All @@ -115,9 +121,9 @@ def makeValidateLink(
base = self.sydent.config.http.server_http_url_base
link = "%s/_matrix/identity/api/v1/validate/email/submitToken?token=%s&client_secret=%s&sid=%d" % (
base,
urllib.parse.quote(valSession.token),
urllib.parse.quote(token),
urllib.parse.quote(clientSecret),
valSession.id,
session_id,
)
if nextLink:
# manipulate the nextLink to add the sid, because
Expand All @@ -127,7 +133,7 @@ def makeValidateLink(
nextLink += "&"
else:
nextLink += "?"
nextLink += "sid=" + urllib.parse.quote(str(valSession.id))
nextLink += "sid=" + urllib.parse.quote(str(session_id))

link += "&nextLink=%s" % (urllib.parse.quote(nextLink))
return link
Expand Down

0 comments on commit 57ba780

Please sign in to comment.