-
Notifications
You must be signed in to change notification settings - Fork 27
/
token_service.py
111 lines (98 loc) · 3.9 KB
/
token_service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import functools
import time

import requests
from jose import jwk, jwt
from jose.exceptions import JOSEError
from jose.utils import base64url_decode

from flask_awscognito.exceptions import FlaskAWSCognitoError, TokenVerifyError


class TokenService:
    """Verify AWS Cognito JWTs against the user pool's published JWKS.

    The public keys are fetched once, at construction time, from the pool's
    ``/.well-known/jwks.json`` endpoint (unless ``_jwk_keys`` is supplied).
    Each call to :meth:`verify` checks a token's signing key, signature and
    expiry (and optionally its audience) against them.
    """

    # Seconds to wait for the JWKS endpoint when the default HTTP client
    # (``requests.get``) is used; without a timeout ``requests`` can block
    # indefinitely on an unresponsive server.
    JWKS_REQUEST_TIMEOUT = 10

    def __init__(
        self,
        user_pool_id,
        user_pool_client_id,
        region,
        request_client=None,
        _jwk_keys=None,
    ):
        """Create a verifier for one Cognito user pool.

        :param user_pool_id: Cognito user pool id; part of the JWKS URL.
        :param user_pool_client_id: app client id compared against the
            token's ``aud`` / ``client_id`` claim when the audience check
            is enabled.
        :param region: AWS region of the user pool; must be non-empty.
        :param request_client: callable ``f(url) -> response`` used to fetch
            the JWKS; defaults to ``requests.get`` with a timeout.
        :param _jwk_keys: pre-loaded list of JWK dicts; when truthy the
            network fetch is skipped.
        :raises FlaskAWSCognitoError: if ``region`` is empty or the JWKS
            cannot be fetched or parsed.
        """
        self.region = region
        if not self.region:
            raise FlaskAWSCognitoError("No AWS region provided")
        self.user_pool_id = user_pool_id
        self.user_pool_client_id = user_pool_client_id
        # Claims of the most recently verified token. NOTE: this is shared
        # instance state; prefer the value returned by verify() if one
        # TokenService instance serves concurrent requests.
        self.claims = None
        if request_client:
            self.request_client = request_client
        else:
            self.request_client = functools.partial(
                requests.get, timeout=self.JWKS_REQUEST_TIMEOUT
            )
        self.jwk_keys = _jwk_keys if _jwk_keys else self._load_jwk_keys()

    def _load_jwk_keys(self):
        """Download the user pool's JSON Web Key Set.

        :returns: the list stored under the ``"keys"`` member of jwks.json.
        :raises FlaskAWSCognitoError: on a transport error, or when the
            response body is not a JSON object containing ``"keys"``.
        """
        keys_url = (
            f"https://cognito-idp.{self.region}.amazonaws.com/"
            f"{self.user_pool_id}/.well-known/jwks.json"
        )
        try:
            response = self.request_client(keys_url)
            return response.json()["keys"]
        except requests.exceptions.RequestException as e:
            raise FlaskAWSCognitoError(str(e)) from e
        except (KeyError, TypeError, ValueError) as e:
            # e.g. an error page that is not JSON, or a JSON body without "keys"
            raise FlaskAWSCognitoError(
                f"Invalid JWKS response from {keys_url}: {e!r}"
            ) from e

    @staticmethod
    def _extract_headers(token):
        """Return the token's JOSE header without verifying it.

        :raises TokenVerifyError: if the header cannot be decoded.
        """
        try:
            return jwt.get_unverified_headers(token)
        except JOSEError as e:
            raise TokenVerifyError(str(e)) from e

    def _find_pkey(self, headers):
        """Return the JWK whose ``kid`` matches the token header's ``kid``.

        :raises TokenVerifyError: if the header has no ``kid`` or no
            downloaded key matches it.
        """
        kid = headers.get("kid")
        if kid is None:
            raise TokenVerifyError("Token header has no 'kid'")
        for key in self.jwk_keys:
            if key.get("kid") == kid:
                return key
        raise TokenVerifyError("Public key not found in jwks.json")

    @staticmethod
    def _verify_signature(token, pkey_data):
        """Check the token's signature against a single JWK.

        :raises TokenVerifyError: if the key cannot be constructed, the token
            is malformed, or the signature does not match.
        """
        try:
            public_key = jwk.construct(pkey_data)
        except JOSEError as e:
            raise TokenVerifyError(str(e)) from e
        try:
            # The signed message is "<header>.<payload>"; the signature is the
            # base64url-encoded text after the last ".".
            message, encoded_signature = str(token).rsplit(".", 1)
            decoded_signature = base64url_decode(encoded_signature.encode("utf-8"))
        except ValueError as e:
            # No "." in the token, or an undecodable signature
            # (binascii.Error is a ValueError subclass).
            raise TokenVerifyError("Malformed token") from e
        if not public_key.verify(message.encode("utf8"), decoded_signature):
            raise TokenVerifyError("Signature verification failed")

    @staticmethod
    def _extract_claims(token):
        """Return the token's payload claims without verifying them.

        :raises TokenVerifyError: if the payload cannot be decoded.
        """
        try:
            return jwt.get_unverified_claims(token)
        except JOSEError as e:
            raise TokenVerifyError(str(e)) from e

    @staticmethod
    def _check_expiration(claims, current_time):
        """Reject tokens whose ``exp`` claim is earlier than ``current_time``.

        :param current_time: epoch seconds to compare against; ``None`` means
            now (``time.time()``). An explicit 0 is honoured as a timestamp.
        :raises TokenVerifyError: if ``exp`` is missing or has passed.
        """
        if current_time is None:
            current_time = time.time()
        if "exp" not in claims:
            raise TokenVerifyError("Token has no expiration claim")
        if current_time > claims["exp"]:
            raise TokenVerifyError("Token is expired")

    def _check_audience(self, claims):
        """Require the token to have been issued for ``user_pool_client_id``.

        ID tokens carry the client id in ``aud``; access tokens carry it in
        ``client_id``. ``aud`` is preferred when both are present.

        :raises TokenVerifyError: if neither claim is present or the value
            does not match.
        """
        if "aud" in claims:
            audience = claims["aud"]
        elif "client_id" in claims:
            audience = claims["client_id"]
        else:
            raise TokenVerifyError("Token has no audience claim")
        if audience != self.user_pool_client_id:
            raise TokenVerifyError("Token was not issued for this audience")

    def verify(self, token, current_time=None, *, check_audience=False):
        """Verify a Cognito JWT and record its claims.

        Based on https://github.com/awslabs/aws-support-tools/blob/master/Cognito/decode-verify-jwt/decode-verify-jwt.py

        Checks, in order: the signing key (``kid``) is in the pool's JWKS,
        the signature is valid, and the token has not expired. The audience
        check is off by default to preserve existing behaviour; pass
        ``check_audience=True`` to also require the token's ``aud`` /
        ``client_id`` to equal ``user_pool_client_id``.

        :param token: the encoded JWT.
        :param current_time: epoch seconds for the expiry check; ``None``
            means now.
        :param check_audience: also verify the audience claim.
        :returns: the verified claims (also stored on ``self.claims``).
        :raises TokenVerifyError: if the token is missing or fails a check.
        """
        if not token:
            raise TokenVerifyError("No token provided")
        headers = self._extract_headers(token)
        pkey_data = self._find_pkey(headers)
        self._verify_signature(token, pkey_data)
        claims = self._extract_claims(token)
        self._check_expiration(claims, current_time)
        if check_audience:
            # NOTE(review): consider enabling this by default; without it a
            # token issued to any app client of this user pool is accepted.
            self._check_audience(claims)
        self.claims = claims
        return claims