1
1
import base64
2
2
import hashlib
3
- import json
4
3
import logging
5
4
6
5
import inspect
6
+ import jwt
7
7
import requests
8
8
from django .contrib .auth import get_user_model
9
9
from django .contrib .auth .backends import ModelBackend
10
10
from django .core .exceptions import ImproperlyConfigured , SuspiciousOperation
11
11
from django .urls import reverse
12
- from django .utils .encoding import force_bytes , smart_bytes , smart_str
12
+ from django .utils .encoding import force_bytes , smart_str
13
13
from django .utils .module_loading import import_string
14
- from josepy .b64 import b64decode
15
- from josepy .jwk import JWK
16
- from josepy .jws import JWS , Header
17
14
from requests .auth import HTTPBasicAuth
18
15
from requests .exceptions import HTTPError
19
16
@@ -127,10 +124,10 @@ def update_user(self, user, claims):
127
124
128
125
def _verify_jws (self , payload , key ):
129
126
"""Verify the given JWS payload with the given key and return the payload"""
130
- jws = JWS . from_compact (payload )
127
+ jws = jwt . get_unverified_header (payload )
131
128
132
129
try :
133
- alg = jws . signature . combined . alg . name
130
+ alg = jws [ " alg" ]
134
131
except KeyError :
135
132
msg = "No alg value found in header"
136
133
raise SuspiciousOperation (msg )
@@ -142,21 +139,19 @@ def _verify_jws(self, payload, key):
142
139
)
143
140
raise SuspiciousOperation (msg )
144
141
145
- if isinstance (key , str ):
146
- # Use smart_bytes here since the key string comes from settings.
147
- jwk = JWK .load (smart_bytes (key ))
148
- else :
149
- # The key is a json returned from the IDP JWKS endpoint.
150
- jwk = JWK .from_json (key )
151
-
152
- if not jws .verify (jwk ):
142
+ try :
143
+ # Maybe add a settings to enforce audiance validation
144
+ return jwt .decode (payload , key , algorithms = alg , options = {"verify_aud" : False })
145
+ except jwt .DecodeError :
153
146
msg = "JWS token verification failed."
154
147
raise SuspiciousOperation (msg )
155
148
156
- return jws .payload
157
-
158
149
def retrieve_matching_jwk (self , token ):
159
- """Get the signing key by exploring the JWKS endpoint of the OP."""
150
+ """Get the signing key by exploring the JWKS endpoint of the OP.
151
+
152
+ Don't use jwt.PyJWKClient()get_signing_key_from_jwt() because it doesn't check
153
+ the algorithm in case of multiple jwk with the same kid.
154
+ """
160
155
response_jwks = requests .get (
161
156
self .OIDC_OP_JWKS_ENDPOINT ,
162
157
verify = self .get_settings ("OIDC_VERIFY_SSL" , True ),
@@ -167,32 +162,29 @@ def retrieve_matching_jwk(self, token):
167
162
jwks = response_jwks .json ()
168
163
169
164
# Compute the current header from the given token to find a match
170
- jws = JWS .from_compact (token )
171
- json_header = jws .signature .protected
172
- header = Header .json_loads (json_header )
165
+ jws = jwt .get_unverified_header (token )
173
166
174
167
key = None
175
168
for jwk in jwks ["keys" ]:
176
169
if import_from_settings ("OIDC_VERIFY_KID" , True ) and jwk [
177
170
"kid"
178
- ] != smart_str (header . kid ):
171
+ ] != smart_str (jws [ " kid" ] ):
179
172
continue
180
- if "alg" in jwk and jwk ["alg" ] != smart_str (header . alg ):
173
+ if "alg" in jwk and jwk ["alg" ] != smart_str (jws [ " alg" ] ):
181
174
continue
182
175
key = jwk
183
176
if key is None :
184
177
raise SuspiciousOperation ("Could not find a valid JWKS." )
185
- return key
178
+ return jwt . PyJWK ( key )
186
179
187
180
def get_payload_data (self , token , key ):
188
181
"""Helper method to get the payload of the JWT token."""
189
182
if self .get_settings ("OIDC_ALLOW_UNSECURED_JWT" , False ):
190
- header , payload_data , signature = token .split (b"." )
191
- header = json .loads (smart_str (b64decode (header )))
183
+ header = jwt .get_unverified_header (token )
192
184
193
185
# If config allows unsecured JWTs check the header and return the decoded payload
194
186
if "alg" in header and header ["alg" ] == "none" :
195
- return b64decode ( payload_data )
187
+ return jwt . decode ( token , options = { "verify_signature" : False } )
196
188
197
189
# By default fallback to verify JWT signatures
198
190
return self ._verify_jws (token , key )
@@ -201,7 +193,6 @@ def verify_token(self, token, **kwargs):
201
193
"""Validate the token signature."""
202
194
nonce = kwargs .get ("nonce" )
203
195
204
- token = force_bytes (token )
205
196
if self .OIDC_RP_SIGN_ALGO .startswith ("RS" ) or self .OIDC_RP_SIGN_ALGO .startswith (
206
197
"ES"
207
198
):
@@ -212,16 +203,7 @@ def verify_token(self, token, **kwargs):
212
203
else :
213
204
key = self .OIDC_RP_CLIENT_SECRET
214
205
215
- payload_data = self .get_payload_data (token , key )
216
-
217
- # The 'token' will always be a byte string since it's
218
- # the result of base64.urlsafe_b64decode().
219
- # The payload is always the result of base64.urlsafe_b64decode().
220
- # In Python 3 and 2, that's always a byte string.
221
- # In Python3.6, the json.loads() function can accept a byte string
222
- # as it will automagically decode it to a unicode string before
223
- # deserializing https://bugs.python.org/issue17909
224
- payload = json .loads (payload_data .decode ("utf-8" ))
206
+ payload = self .get_payload_data (token , key )
225
207
token_nonce = payload .get ("nonce" )
226
208
227
209
if self .get_settings ("OIDC_USE_NONCE" , True ) and nonce != token_nonce :
0 commit comments