Skip to content

Commit 774b140

Browse files
authored
Merge pull request #543 from tonial/tonial/use_pyjwt
Replace josepy with PyJWT
2 parents 2c2334f + c0fb88d commit 774b140

File tree

4 files changed

+202
-257
lines changed

4 files changed

+202
-257
lines changed

mozilla_django_oidc/auth.py

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
import base64
22
import hashlib
3-
import json
43
import logging
54

65
import inspect
6+
import jwt
77
import requests
88
from django.contrib.auth import get_user_model
99
from django.contrib.auth.backends import ModelBackend
1010
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation
1111
from django.urls import reverse
12-
from django.utils.encoding import force_bytes, smart_bytes, smart_str
12+
from django.utils.encoding import force_bytes, smart_str
1313
from django.utils.module_loading import import_string
14-
from josepy.b64 import b64decode
15-
from josepy.jwk import JWK
16-
from josepy.jws import JWS, Header
1714
from requests.auth import HTTPBasicAuth
1815
from requests.exceptions import HTTPError
1916

@@ -127,10 +124,10 @@ def update_user(self, user, claims):
127124

128125
def _verify_jws(self, payload, key):
129126
"""Verify the given JWS payload with the given key and return the payload"""
130-
jws = JWS.from_compact(payload)
127+
jws = jwt.get_unverified_header(payload)
131128

132129
try:
133-
alg = jws.signature.combined.alg.name
130+
alg = jws["alg"]
134131
except KeyError:
135132
msg = "No alg value found in header"
136133
raise SuspiciousOperation(msg)
@@ -142,21 +139,19 @@ def _verify_jws(self, payload, key):
142139
)
143140
raise SuspiciousOperation(msg)
144141

145-
if isinstance(key, str):
146-
# Use smart_bytes here since the key string comes from settings.
147-
jwk = JWK.load(smart_bytes(key))
148-
else:
149-
# The key is a json returned from the IDP JWKS endpoint.
150-
jwk = JWK.from_json(key)
151-
152-
if not jws.verify(jwk):
142+
try:
143+
# Maybe add a settings to enforce audiance validation
144+
return jwt.decode(payload, key, algorithms=alg, options={"verify_aud": False})
145+
except jwt.DecodeError:
153146
msg = "JWS token verification failed."
154147
raise SuspiciousOperation(msg)
155148

156-
return jws.payload
157-
158149
def retrieve_matching_jwk(self, token):
159-
"""Get the signing key by exploring the JWKS endpoint of the OP."""
150+
"""Get the signing key by exploring the JWKS endpoint of the OP.
151+
152+
Don't use jwt.PyJWKClient()get_signing_key_from_jwt() because it doesn't check
153+
the algorithm in case of multiple jwk with the same kid.
154+
"""
160155
response_jwks = requests.get(
161156
self.OIDC_OP_JWKS_ENDPOINT,
162157
verify=self.get_settings("OIDC_VERIFY_SSL", True),
@@ -167,32 +162,29 @@ def retrieve_matching_jwk(self, token):
167162
jwks = response_jwks.json()
168163

169164
# Compute the current header from the given token to find a match
170-
jws = JWS.from_compact(token)
171-
json_header = jws.signature.protected
172-
header = Header.json_loads(json_header)
165+
jws = jwt.get_unverified_header(token)
173166

174167
key = None
175168
for jwk in jwks["keys"]:
176169
if import_from_settings("OIDC_VERIFY_KID", True) and jwk[
177170
"kid"
178-
] != smart_str(header.kid):
171+
] != smart_str(jws["kid"]):
179172
continue
180-
if "alg" in jwk and jwk["alg"] != smart_str(header.alg):
173+
if "alg" in jwk and jwk["alg"] != smart_str(jws["alg"]):
181174
continue
182175
key = jwk
183176
if key is None:
184177
raise SuspiciousOperation("Could not find a valid JWKS.")
185-
return key
178+
return jwt.PyJWK(key)
186179

187180
def get_payload_data(self, token, key):
188181
"""Helper method to get the payload of the JWT token."""
189182
if self.get_settings("OIDC_ALLOW_UNSECURED_JWT", False):
190-
header, payload_data, signature = token.split(b".")
191-
header = json.loads(smart_str(b64decode(header)))
183+
header = jwt.get_unverified_header(token)
192184

193185
# If config allows unsecured JWTs check the header and return the decoded payload
194186
if "alg" in header and header["alg"] == "none":
195-
return b64decode(payload_data)
187+
return jwt.decode(token, options={"verify_signature": False})
196188

197189
# By default fallback to verify JWT signatures
198190
return self._verify_jws(token, key)
@@ -201,7 +193,6 @@ def verify_token(self, token, **kwargs):
201193
"""Validate the token signature."""
202194
nonce = kwargs.get("nonce")
203195

204-
token = force_bytes(token)
205196
if self.OIDC_RP_SIGN_ALGO.startswith("RS") or self.OIDC_RP_SIGN_ALGO.startswith(
206197
"ES"
207198
):
@@ -212,16 +203,7 @@ def verify_token(self, token, **kwargs):
212203
else:
213204
key = self.OIDC_RP_CLIENT_SECRET
214205

215-
payload_data = self.get_payload_data(token, key)
216-
217-
# The 'token' will always be a byte string since it's
218-
# the result of base64.urlsafe_b64decode().
219-
# The payload is always the result of base64.urlsafe_b64decode().
220-
# In Python 3 and 2, that's always a byte string.
221-
# In Python3.6, the json.loads() function can accept a byte string
222-
# as it will automagically decode it to a unicode string before
223-
# deserializing https://bugs.python.org/issue17909
224-
payload = json.loads(payload_data.decode("utf-8"))
206+
payload = self.get_payload_data(token, key)
225207
token_nonce = payload.get("nonce")
226208

227209
if self.get_settings("OIDC_USE_NONCE", True) and nonce != token_nonce:

mozilla_django_oidc/utils.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import logging
22
import time
33
import warnings
4+
from base64 import urlsafe_b64decode, urlsafe_b64encode
45
from hashlib import sha256
56
from urllib.request import parse_http_list, parse_keqv_list
67

7-
# Make it obvious that these aren't the usual base64 functions
8-
import josepy.b64
98
from django.conf import settings
109
from django.core.exceptions import ImproperlyConfigured
10+
from django.utils.encoding import force_bytes
1111

1212
LOGGER = logging.getLogger(__name__)
1313

@@ -57,16 +57,12 @@ def is_authenticated(user):
5757

5858
def base64_url_encode(bytes_like_obj):
5959
"""Return a URL-Safe, base64 encoded version of bytes_like_obj
60-
6160
Implements base64urlencode as described in
6261
https://datatracker.ietf.org/doc/html/rfc7636#appendix-A
62+
This function is not used by the OpenID client; it's just for testing PKCE related functions.
6363
"""
64-
65-
s = josepy.b64.b64encode(bytes_like_obj).decode("ascii") # base64 encode
66-
# the josepy base64 encoder (strips '='s padding) automatically
67-
s = s.replace("+", "-") # 62nd char of encoding
68-
s = s.replace("/", "_") # 63rd char of encoding
69-
64+
s = urlsafe_b64encode(force_bytes(bytes_like_obj)).decode('utf-8')
65+
s = s.rstrip("=")
7066
return s
7167

7268

@@ -78,11 +74,14 @@ def base64_url_decode(string_like_obj):
7874
"""
7975
s = string_like_obj
8076

81-
s = s.replace("_", "/") # 63rd char of encoding
82-
s = s.replace("-", "+") # 62nd char of encoding
83-
b = josepy.b64.b64decode(s) # josepy base64 encoder (decodes without '='s padding)
84-
85-
return b
77+
size = len(s) % 4
78+
if size == 2:
79+
s += '=='
80+
elif size == 3:
81+
s += '='
82+
elif size != 0:
83+
raise ValueError('Invalid base64 string')
84+
return urlsafe_b64decode(s.encode('utf-8'))
8685

8786

8887
def generate_code_challenge(code_verifier, method):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
install_requirements = [
3535
"Django >= 3.2",
36-
"josepy",
36+
"pyjwt",
3737
"requests",
3838
"cryptography",
3939
]

0 commit comments

Comments
 (0)