Skip to content

Commit

Permalink
Merge pull request #149 from jschlyter/unpack_check_nbf_exp
Browse files Browse the repository at this point in the history
JWT check nbf/exp
  • Loading branch information
rohe committed Jul 31, 2023
2 parents 361ecaa + 9ff84a2 commit 7ec089f
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 16 deletions.
41 changes: 31 additions & 10 deletions src/cryptojwt/jwt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Basic JSON Web Token implementation."""
import json
import logging
import time
import uuid
from datetime import datetime
from datetime import timezone
from json import JSONDecodeError

from .exception import HeaderError
Expand All @@ -28,9 +27,7 @@ def utc_time_sans_frac():
:return: A number of seconds
"""

now_timestampt = int(datetime.now(timezone.utc).timestamp())
return now_timestampt
return int(time.time())


def pick_key(keys, use, alg="", key_type="", kid=""):
Expand Down Expand Up @@ -95,6 +92,7 @@ def __init__(
allowed_sign_algs=None,
allowed_enc_algs=None,
allowed_enc_encs=None,
allowed_max_lifetime=None,
zip="",
):
self.key_jar = key_jar # KeyJar instance
Expand All @@ -115,6 +113,7 @@ def __init__(
self.allowed_sign_algs = allowed_sign_algs
self.allowed_enc_algs = allowed_enc_algs
self.allowed_enc_encs = allowed_enc_encs
self.allowed_max_lifetime = allowed_max_lifetime
self.zip = zip

def receiver_keys(self, recv, use):
Expand Down Expand Up @@ -176,13 +175,13 @@ def put_together_aud(recv, aud=None):

return _aud

def pack_init(self, recv, aud):
def pack_init(self, recv, aud, iat=None):
"""
Gather initial information for the payload.
:return: A dictionary with claims and values
"""
argv = {"iss": self.iss, "iat": utc_time_sans_frac()}
argv = {"iss": self.iss, "iat": iat or utc_time_sans_frac()}
if self.lifetime:
argv["exp"] = argv["iat"] + self.lifetime

Expand All @@ -207,7 +206,7 @@ def pack_key(self, issuer_id="", kid=""):

return keys[0] # Might be more then one if kid == ''

def pack(self, payload=None, kid="", issuer_id="", recv="", aud=None, **kwargs):
def pack(self, payload=None, kid="", issuer_id="", recv="", aud=None, iat=None, **kwargs):
"""
:param payload: Information to be carried as payload in the JWT
Expand All @@ -216,13 +215,14 @@ def pack(self, payload=None, kid="", issuer_id="", recv="", aud=None, **kwargs):
:param recv: The intended immediate receiver
:param aud: Intended audience for this JWS/JWE, not expected to
contain the recipient.
:param iat: Override issued at (default current timestamp)
:param kwargs: Extra keyword arguments
:return: A signed or signed and encrypted Json Web Token
"""
_args = {}
if payload is not None:
_args.update(payload)
_args.update(self.pack_init(recv, aud))
_args.update(self.pack_init(recv, aud, iat))

try:
_encrypt = kwargs["encrypt"]
Expand Down Expand Up @@ -304,11 +304,12 @@ def verify_profile(msg_cls, info, **kwargs):
raise VerificationError()
return _msg

def unpack(self, token):
def unpack(self, token, timestamp=None):
"""
Unpack a received signed or signed and encrypted Json Web Token
:param token: The Json Web Token
:param timestamp: Time for evaluation (default now)
:return: If decryption and signature verification work the payload
will be returned as a Message instance if possible.
"""
Expand Down Expand Up @@ -378,6 +379,26 @@ def unpack(self, token):
except KeyError:
_msg_cls = None

timestamp = timestamp or utc_time_sans_frac()

if "nbf" in _info:
nbf = int(_info["nbf"])
if timestamp < nbf - self.skew:
raise VerificationError("Token not yet valid")

if "exp" in _info:
exp = int(_info["exp"])
if timestamp >= exp + self.skew:
raise VerificationError("Token expired")
else:
exp = None

if "iat" in _info:
iat = int(_info["iat"])
if self.allowed_max_lifetime and exp:
if abs(exp - iat) > self.allowed_max_lifetime:
raise VerificationError("Token lifetime exceeded")

if _msg_cls:
vp_args = {"skew": self.skew}
if self.iss:
Expand Down
81 changes: 75 additions & 6 deletions tests/test_09_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from cryptojwt.exception import IssuerNotFound
from cryptojwt.jws.exception import NoSuitableSigningKeys
from cryptojwt.jwt import JWT
from cryptojwt.jwt import VerificationError
from cryptojwt.jwt import pick_key
from cryptojwt.jwt import utc_time_sans_frac
from cryptojwt.key_bundle import KeyBundle
from cryptojwt.key_jar import KeyJar
from cryptojwt.key_jar import init_key_jar
Expand Down Expand Up @@ -81,15 +83,82 @@ def test_jwt_pack_and_unpack():
assert set(info.keys()) == {"iat", "iss", "sub"}


def test_jwt_pack_and_unpack_unknown_issuer():
def test_jwt_pack_and_unpack_valid():
alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256")
t = utc_time_sans_frac()
payload = {"sub": "sub", "nbf": t, "exp": t + 3600}
_jwt = alice.pack(payload=payload)

bob = JWT(key_jar=BOB_KEY_JAR, iss=BOB, allowed_sign_algs=["RS256"])
info = bob.unpack(_jwt)

assert set(info.keys()) == {"iat", "iss", "sub", "nbf", "exp"}


def test_jwt_pack_and_unpack_not_yet_valid():
lifetime = 3600
skew = 15
alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256", lifetime=lifetime)
timestamp = utc_time_sans_frac()
payload = {"sub": "sub", "nbf": timestamp}
_jwt = alice.pack(payload=payload)

bob = JWT(key_jar=BOB_KEY_JAR, iss=BOB, allowed_sign_algs=["RS256"], skew=skew)
_ = bob.unpack(_jwt, timestamp=timestamp - skew)
with pytest.raises(VerificationError):
_ = bob.unpack(_jwt, timestamp=timestamp - skew - 1)


def test_jwt_pack_and_unpack_expired():
lifetime = 3600
skew = 15
alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256", lifetime=lifetime)
payload = {"sub": "sub"}
_jwt = alice.pack(payload=payload)

kj = KeyJar()
bob = JWT(key_jar=kj, iss=BOB, allowed_sign_algs=["RS256"])
with pytest.raises(IssuerNotFound):
info = bob.unpack(_jwt)
bob = JWT(key_jar=BOB_KEY_JAR, iss=BOB, allowed_sign_algs=["RS256"], skew=skew)
iat = bob.unpack(_jwt)["iat"]
_ = bob.unpack(_jwt, timestamp=iat + lifetime + skew - 1)
with pytest.raises(VerificationError):
_ = bob.unpack(_jwt, timestamp=iat + lifetime + skew)


def test_jwt_pack_and_unpack_max_lifetime_exceeded():
lifetime = 3600
alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256", lifetime=lifetime)
payload = {"sub": "sub"}
_jwt = alice.pack(payload=payload)

bob = JWT(
key_jar=BOB_KEY_JAR, iss=BOB, allowed_sign_algs=["RS256"], allowed_max_lifetime=lifetime - 1
)
with pytest.raises(VerificationError):
_ = bob.unpack(_jwt)


def test_jwt_pack_and_unpack_max_lifetime_exceeded():
lifetime = 3600
alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256", lifetime=lifetime)
payload = {"sub": "sub"}
_jwt = alice.pack(payload=payload)

bob = JWT(
key_jar=BOB_KEY_JAR, iss=BOB, allowed_sign_algs=["RS256"], allowed_max_lifetime=lifetime - 1
)
with pytest.raises(VerificationError):
_ = bob.unpack(_jwt)


def test_jwt_pack_and_unpack_timestamp():
lifetime = 3600
alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256", lifetime=lifetime)
payload = {"sub": "sub"}
_jwt = alice.pack(payload=payload, iat=42)

bob = JWT(key_jar=BOB_KEY_JAR, iss=BOB, allowed_sign_algs=["RS256"])
_ = bob.unpack(_jwt, timestamp=42)
with pytest.raises(VerificationError):
_ = bob.unpack(_jwt)


def test_jwt_pack_and_unpack_unknown_key():
Expand Down Expand Up @@ -261,4 +330,4 @@ def test_eddsa_jwt():
kj = KeyJar()
kj.add_kb(ISSUER, KeyBundle(JWKS_DICT))
jwt = JWT(key_jar=kj)
_ = jwt.unpack(JWT_TEST)
_ = jwt.unpack(JWT_TEST, timestamp=1655278809)

0 comments on commit 7ec089f

Please sign in to comment.