# async_app.py (205 lines, 171 loc, 7.66 KB)
import time
import logging
from authlib.common.urls import urlparse
from authlib.jose import JsonWebToken, JsonWebKey
from authlib.oidc.core import UserInfo, CodeIDToken, ImplicitIDToken
from .base_app import BaseApp
from .errors import (
MissingRequestTokenError,
MissingTokenError,
)
# Public names exported by ``from <module> import *``.
__all__ = ['AsyncRemoteApp']
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class AsyncRemoteApp(BaseApp):
    """Asynchronous remote OAuth application built on :class:`BaseApp`.

    Every network-bound step (metadata loading, authorization URL creation,
    token exchange, resource requests, ID token parsing) is a coroutine.
    The OAuth 1 code paths are taken whenever ``self.request_token_url`` is
    set; otherwise the OAuth 2 code paths are used.
    """

    async def load_server_metadata(self):
        """Return ``self.server_metadata``, fetching the remote document once.

        The document is downloaded from ``self._server_metadata_url`` only
        when that URL is configured and ``self.server_metadata`` carries no
        ``'_loaded_at'`` marker yet.

        :return: ``self.server_metadata`` (updated in place when fetched).
        """
        if self._server_metadata_url and '_loaded_at' not in self.server_metadata:
            metadata = await self._fetch_server_metadata(self._server_metadata_url)
            # The timestamp doubles as the "already loaded" marker checked above.
            metadata['_loaded_at'] = time.time()
            self.server_metadata.update(metadata)
        return self.server_metadata

    async def _on_update_token(self, token, refresh_token=None, access_token=None):
        """Forward a token update to the ``self._update_token`` hook, if any.

        :param token: The new token.
        :param refresh_token: Passed through to the hook as a keyword.
        :param access_token: Passed through to the hook as a keyword.
        """
        if self._update_token:
            await self._update_token(
                token,
                refresh_token=refresh_token,
                access_token=access_token,
            )

    async def _create_oauth1_authorization_url(self, client, authorization_endpoint, **kwargs):
        """Fetch an OAuth 1 request token and build the authorization URL.

        :param client: Client used to fetch the request token and build the URL.
        :param authorization_endpoint: Authorization URL of the provider.
        :param kwargs: Extra parameters for the authorization URL.
        :return: dict with the keys ``url`` and ``request_token``.
        """
        params = {}
        if self.request_token_params:
            params.update(self.request_token_params)
        token = await client.fetch_request_token(
            self.request_token_url, **params
        )
        # Lazy %-style arguments: the repr is only built if DEBUG is enabled.
        log.debug('Fetch request token: %r', token)
        url = client.create_authorization_url(authorization_endpoint, **kwargs)
        return {'url': url, 'request_token': token}

    async def create_authorization_url(self, redirect_uri=None, **kwargs):
        """Generate the authorization url and state for HTTP redirect.

        :param redirect_uri: Callback or redirect URI for authorization.
        :param kwargs: Extra parameters to include.
        :return: dict
        :raises RuntimeError: If no authorization endpoint is configured or
            found in the server metadata.
        """
        metadata = await self.load_server_metadata()
        authorization_endpoint = self.authorize_url
        # Metadata is only consulted as a fallback on the OAuth 2 path.
        if not authorization_endpoint and not self.request_token_url:
            authorization_endpoint = metadata.get('authorization_endpoint')

        if not authorization_endpoint:
            raise RuntimeError('Missing "authorize_url" value')

        # App-level authorize params take precedence over caller kwargs.
        if self.authorize_params:
            kwargs.update(self.authorize_params)

        async with self._get_oauth_client(**metadata) as client:
            client.redirect_uri = redirect_uri

            if self.request_token_url:
                return await self._create_oauth1_authorization_url(
                    client, authorization_endpoint, **kwargs)
            else:
                return self._create_oauth2_authorization_url(
                    client, authorization_endpoint, **kwargs)

    async def fetch_access_token(self, redirect_uri=None, request_token=None, **params):
        """Fetch access token in one step.

        :param redirect_uri: Callback or Redirect URI that is used in
                             previous :meth:`authorize_redirect`.
        :param request_token: A previous request token for OAuth 1.
        :param params: Extra parameters to fetch access token.
        :return: A token dict.
        :raises MissingRequestTokenError: On the OAuth 1 path when
            ``request_token`` is not supplied.
        """
        metadata = await self.load_server_metadata()
        token_endpoint = self.access_token_url
        # Metadata is only consulted as a fallback on the OAuth 2 path.
        if not token_endpoint and not self.request_token_url:
            token_endpoint = metadata.get('token_endpoint')

        async with self._get_oauth_client(**metadata) as client:
            if self.request_token_url:
                client.redirect_uri = redirect_uri
                if request_token is None:
                    raise MissingRequestTokenError()
                # merge request token with verifier
                token = {}
                token.update(request_token)
                token.update(params)
                client.token = token
                kwargs = self.access_token_params or {}
                token = await client.fetch_access_token(token_endpoint, **kwargs)
                client.redirect_uri = None
            else:
                client.redirect_uri = redirect_uri
                kwargs = {}
                if self.access_token_params:
                    kwargs.update(self.access_token_params)
                # Caller-supplied params override the app-level defaults.
                kwargs.update(params)
                token = await client.fetch_token(token_endpoint, **kwargs)
        return token

    async def request(self, method, url, token=None, **kwargs):
        """Send an HTTP request through the OAuth client.

        A ``url`` that does not start with ``http://`` or ``https://`` is
        joined onto ``self.api_base_url`` when that is set.

        :param method: HTTP method.
        :param url: Absolute URL, or a path relative to ``api_base_url``.
        :param token: Token to attach; when omitted it is looked up via
            ``self._fetch_token(request)`` if a ``request`` keyword is given.
        :param kwargs: Passed on to the client's ``request``. ``request`` is
            removed first; a truthy ``withhold_token`` sends the request
            without resolving a token or loading server metadata.
        :return: The response returned by the client.
        :raises MissingTokenError: If a token is required but none is found.
        """
        if self.api_base_url and not url.startswith(('https://', 'http://')):
            url = urlparse.urljoin(self.api_base_url, url)

        withhold_token = kwargs.get('withhold_token')
        if not withhold_token:
            metadata = await self.load_server_metadata()
        else:
            metadata = {}

        async with self._get_oauth_client(**metadata) as client:
            request = kwargs.pop('request', None)

            if withhold_token:
                return await client.request(method, url, **kwargs)

            if token is None and request:
                token = await self._fetch_token(request)

            if token is None:
                raise MissingTokenError()

            client.token = token
            return await client.request(method, url, **kwargs)

    async def userinfo(self, **kwargs):
        """Fetch user info from ``userinfo_endpoint``.

        :param kwargs: Passed on to ``self.get``.
        :return: A :class:`UserInfo` built from the JSON response, after the
            optional ``userinfo_compliance_fix`` from the metadata is applied.
        """
        metadata = await self.load_server_metadata()
        resp = await self.get(metadata['userinfo_endpoint'], **kwargs)
        data = resp.json()

        compliance_fix = metadata.get('userinfo_compliance_fix')
        if compliance_fix:
            data = await compliance_fix(self, data)
        return UserInfo(data)

    async def _parse_id_token(self, token, nonce, claims_options=None):
        """Return an instance of UserInfo from token's ``id_token``.

        :param token: Token dict containing ``id_token`` and, optionally,
            ``access_token``.
        :param nonce: Nonce passed to the claims as a validation parameter.
        :param claims_options: Claims options for decoding; when omitted and
            the metadata has an ``issuer``, ``iss`` is restricted to it.
        :return: A :class:`UserInfo` built from the validated claims.
        """
        claims_params = dict(
            nonce=nonce,
            client_id=self.client_id,
        )
        if 'access_token' in token:
            claims_params['access_token'] = token['access_token']
            claims_cls = CodeIDToken
        else:
            claims_cls = ImplicitIDToken

        metadata = await self.load_server_metadata()
        if claims_options is None and 'issuer' in metadata:
            claims_options = {'iss': {'values': [metadata['issuer']]}}

        alg_values = metadata.get('id_token_signing_alg_values_supported')
        if not alg_values:
            alg_values = ['RS256']

        jwt = JsonWebToken(alg_values)

        def decode_with(key_set):
            # Single definition of the decode call shared by both attempts.
            return jwt.decode(
                token['id_token'],
                key=JsonWebKey.import_key_set(key_set),
                claims_cls=claims_cls,
                claims_options=claims_options,
                claims_params=claims_params,
            )

        # The first fetch stays outside the ``try`` so that only a decode
        # failure triggers the retry, never a failure of the fetch itself.
        jwk_set = await self._fetch_jwk_set()
        try:
            claims = decode_with(jwk_set)
        except ValueError:
            # Retry once with a freshly downloaded key set (presumably to
            # cope with a stale cached JWKS -- confirm against authlib docs).
            jwk_set = await self._fetch_jwk_set(force=True)
            claims = decode_with(jwk_set)

        claims.validate(leeway=120)
        return UserInfo(claims)

    async def _fetch_jwk_set(self, force=False):
        """Return the JSON Web Key set, downloading it when necessary.

        :param force: When true, ignore any ``jwks`` already present in the
            metadata and download again from ``jwks_uri``.
        :return: The JWK set.
        :raises RuntimeError: If a download is needed but the metadata has no
            ``jwks_uri``.
        """
        metadata = await self.load_server_metadata()
        jwk_set = metadata.get('jwks')
        if jwk_set and not force:
            return jwk_set

        uri = metadata.get('jwks_uri')
        if not uri:
            raise RuntimeError('Missing "jwks_uri" in metadata')

        jwk_set = await self._fetch_server_metadata(uri)
        # Cache the downloaded set so later calls can skip the request.
        self.server_metadata['jwks'] = jwk_set
        return jwk_set

    async def _fetch_server_metadata(self, url):
        """GET ``url`` without a token and return the decoded JSON body.

        :param url: URL of the JSON document to download.
        :return: The parsed JSON body.
        :raises: Whatever ``resp.raise_for_status()`` raises for an HTTP
            error status.
        """
        async with self._get_oauth_client() as client:
            resp = await client.request('GET', url, withhold_token=True)
            # Reject HTTP error responses before decoding. Otherwise a JSON
            # error payload would be merged into ``server_metadata`` (or
            # stored as ``jwks``) and treated as a successfully loaded
            # document. Assumes an httpx/requests-style response object,
            # as the synchronous ``json()`` call below already does.
            resp.raise_for_status()
            return resp.json()