-
Notifications
You must be signed in to change notification settings - Fork 18
/
base.py
403 lines (344 loc) · 14.1 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
import time
import uuid
from contextlib import contextmanager
from json.decoder import JSONDecodeError
# The Python Social Auth base handler gives us:
# user_id, get_current_user, login_user
#
# `get_current_user` is needed by tornado.authentication,
# and provides a cached version, `current_user`, that should
# be used to look up the logged in user.
import sqlalchemy
import tornado.escape
from tornado.log import app_log
from tornado.web import RequestHandler
from ...log import make_log
# Initialize PSA tornado models
from .. import psa
from ..custom_exceptions import AccessError
from ..env import load_env
from ..flow import Flow
from ..json_util import to_json
from ..models import DBSession, User, VerifiedSession, bulk_verify, session_context_id
# Load the environment and application configuration once, at import time.
# `cfg` is read by the handlers below (e.g. `cfg["app"]` in
# PSABaseHandler.write_error).
env, cfg = load_env()
# Module-level logger; `log(...)` is called throughout the handlers below.
log = make_log("basehandler")
# Python Social Auth documentation (the auth API that PSABaseHandler implements):
# https://python-social-auth.readthedocs.io/en/latest/backends/implementation.html#auth-apis
class NoValue:
    """Sentinel meaning "no value was supplied".

    Used as a default argument (e.g. in ``BaseHandler.get_query_argument``)
    so that an omitted argument can be told apart from any explicitly
    passed value, including ``None``.
    """
class PSABaseHandler(RequestHandler):
    """
    Mixin used by Python Social Auth

    Implements the ``user_id``, ``get_current_user`` and ``login_user``
    hooks using two secure cookies:

    - ``user_id``: the primary key of the logged-in ``User``;
    - ``user_oauth_uid``: the social-auth ``uid`` recorded at login.

    Tornado caches the result of ``get_current_user`` as
    ``self.current_user``; use that attribute to look up the logged-in user.
    """

    def user_id(self):
        """Return the ``user_id`` secure cookie (bytes), or None if absent/invalid."""
        return self.get_secure_cookie("user_id")

    def get_current_user(self):
        """Return the logged-in ``User``, or None.

        The user whose id is stored in the ``user_id`` cookie is returned
        when either no social-auth record exists for that user (probably a
        machine-generated user) or the record's ``uid`` matches the
        ``user_oauth_uid`` cookie. Every other outcome yields None: a
        missing cookie, an unknown user, a uid mismatch, or a database
        error (which is rolled back and logged).

        Raises
        ------
        ValueError
            If the ``user_id`` cookie is present but is not an integer.
        """
        # Verify/decode the signed cookie once instead of once per access.
        raw_user_id = self.user_id()
        if raw_user_id is None:
            return None
        user_id = int(raw_user_id)
        oauth_uid = self.get_secure_cookie("user_oauth_uid")
        if not (user_id and oauth_uid):
            return None

        with DBSession() as session:
            try:
                user = session.scalars(
                    sqlalchemy.select(User).where(User.id == user_id)
                ).first()
                if user is None:
                    return None
                sa = session.scalars(
                    sqlalchemy.select(psa.TornadoStorage.user).where(
                        psa.TornadoStorage.user.user_id == user_id
                    )
                ).first()
                if sa is None:
                    # No SocialAuth entry; probably machine generated user
                    return user
                # The cookie value is bytes while the stored uid is text.
                if sa.uid.encode("utf-8") == oauth_uid:
                    return user
            except Exception as e:
                session.rollback()
                log(f"Could not get current user: {e}")
        return None

    def login_user(self, user):
        """Mark ``user`` as logged in by setting the session cookies.

        First sets the ``user_id`` cookie. Then, if the user exists in the
        database and has a social-auth record, also sets ``user_oauth_uid``
        to that record's ``uid`` so ``get_current_user`` can validate it on
        later requests. Database errors are rolled back and logged, not
        raised.
        """
        with DBSession() as session:
            try:
                self.set_secure_cookie("user_id", str(user.id))
                db_user = session.scalars(
                    sqlalchemy.select(User).where(User.id == user.id)
                ).first()
                if db_user is None:
                    return
                sa = session.scalars(
                    sqlalchemy.select(psa.TornadoStorage.user).where(
                        psa.TornadoStorage.user.user_id == db_user.id
                    )
                ).first()
                if sa is not None:
                    self.set_secure_cookie("user_oauth_uid", sa.uid)
            except Exception as e:
                session.rollback()
                log(f"Could not login user: {e}")

    def write_error(self, status_code, exc_info=None):
        """Render ``loginerror.html`` with the exception message.

        ``exc_info`` is the ``(type, value, traceback)`` triple supplied by
        tornado; when it is absent a generic message is shown instead.
        """
        if exc_info is not None:
            _, err, _ = exc_info
        else:
            err = "An unknown error occurred"
        self.render("loginerror.html", app=cfg["app"], error_message=str(err))

    def log_exception(self, typ=None, value=None, tb=None):
        """Log an exception raised while handling the request.

        Exceptions whose message contains one of the expected
        authentication/authorization phrases are logged as one line via
        ``log``; anything else is logged with its traceback through
        tornado's ``app_log.error``.
        """
        expected_messages = (
            "Authentication Error:",
            "User account expired",
            "Credentials malformed",
            "Method Not Allowed",
            "Unauthorized",
        )
        v_str = str(value)
        if any(message in v_str for message in expected_messages):
            log(f"Error response returned by [{self.request.path}]: [{v_str}]")
        else:
            app_log.error(
                "Uncaught exception %s\n%r",
                self._request_summary(),
                self.request,
                exc_info=(typ, value, tb),
            )

    def on_finish(self):
        """Call ``DBSession.remove()`` once the request has finished.

        This releases the session used during the request (DBSession
        appears to be a SQLAlchemy ``scoped_session``).
        """
        DBSession.remove()
class BaseHandler(PSABaseHandler):
    """Base class for application request handlers.

    Builds on ``PSABaseHandler`` with:

    - database-session helpers that verify the current user's permissions;
    - frontend push helpers (via ``Flow``);
    - JSON request/response helpers (``get_json``, ``success``, ``error``);
    - boolean-aware query-argument parsing.
    """

    @contextmanager
    def Session(self):
        """
        Generate a session that also has knowledge of the current user.

        When commit() is called on it, it will also verify that all rows
        being committed are accessible to the user (per the
        ``VerifiedSession`` contract). The current user is taken from the
        handler's ``current_user``. This shortcut saves manually passing the
        user object to ``VerifiedSession``.

        Yields
        ------
        A ``VerifiedSession`` for ``self.current_user``, bound to the same
        engine as ``DBSession``. Use it inside a ``with`` statement to
        access the database.
        """
        with VerifiedSession(self.current_user) as session:
            # The current user object must be attached to this session.
            # ref: https://docs.sqlalchemy.org/en/14/orm/session_basics.html#adding-new-or-existing-items
            session.add(self.current_user)
            # Use the same engine as the global DBSession.
            session.bind = DBSession.session_factory.kw["bind"]
            yield session

    def verify_permissions(self):
        """Check that the current user may create/read/update/delete the
        rows present in the session.

        If not, raise an AccessError, causing the transaction to fail and
        the API to respond with 401. Read, update and delete permissions are
        checked before flushing; create permissions are checked after.
        """
        # DBSession is a session registry (SQLAlchemy scoped_session), so
        # each call returns the same session; fetch it once.
        session = DBSession()

        # items to be inserted / updated / deleted
        new_rows = list(session.new)
        updated_rows = [row for row in session.dirty if session.is_modified(row)]
        deleted_rows = list(session.deleted)

        # everything else in the identity map was only read
        written_rows = set(updated_rows) | set(new_rows) | set(deleted_rows)
        read_rows = list(set(session.identity_map.values()) - written_rows)

        # need to check delete permissions before flushing, as deleted records
        # are not present in the transaction after flush (thus can't be used in
        # joins). Read permissions can be checked here or below as they do not
        # change on flush.
        for mode, collection in zip(
            ["read", "update", "delete"],
            [read_rows, updated_rows, deleted_rows],
        ):
            bulk_verify(mode, collection, self.current_user)

        # update transaction state in DB, but don't commit yet. this updates
        # or adds rows in the database and uses their new state in joins,
        # for permissions checking purposes.
        session.flush()
        bulk_verify("create", new_rows, self.current_user)

    def verify_and_commit(self):
        """Verify permissions on the current database session and commit if
        successful, otherwise raise an AccessError.
        """
        self.verify_permissions()
        DBSession().commit()

    def prepare(self):
        """Per-request setup, run by tornado before the HTTP verb method.

        - Sets ``self.cfg`` and ``self.flow``.
        - Stores a fresh ``session_context_id`` for this request.
        - Normalizes ``self.path_args``: strips leading slashes and maps
          empty strings to None; a single missing argument becomes ``[]``.
        - Checks that the database engine is bound, retrying up to 5 times
          (5 s apart) before re-raising the failure.
        """
        self.cfg = self.application.cfg
        self.flow = Flow()
        session_context_id.set(uuid.uuid4().hex)

        # Remove slash prefixes from arguments
        if self.path_args:
            self.path_args = [
                arg.lstrip("/") if arg is not None else None for arg in self.path_args
            ]
            self.path_args = [arg if (arg != "") else None for arg in self.path_args]

            # If there are no arguments, make it explicit, otherwise
            # get / post / put / delete all have to accept an optional kwd argument
            if len(self.path_args) == 1 and self.path_args[0] is None:
                self.path_args = []

        # TODO Refactor to be a context manager or utility function
        # NOTE(review): time.sleep blocks the IOLoop while retrying.
        max_attempts = 5
        for attempt in range(1, max_attempts + 1):
            try:
                # Explicit check instead of `assert`, which is stripped under
                # `python -O`; AssertionError is kept as the failure type.
                if DBSession.session_factory.kw["bind"] is None:
                    raise AssertionError("Database engine is not bound")
            except Exception:
                if attempt == max_attempts:
                    raise
                log("Error connecting to database, sleeping for a while")
                time.sleep(5)
            else:
                # Engine is bound; no need to check again.
                break
        return super().prepare()

    def push(self, action, payload=NoValue):
        """Broadcast a message to current frontend user.

        Parameters
        ----------
        action : str
            Name of frontend action to perform after API success. This action
            is sent to the frontend over WebSocket.
        payload : dict, optional
            Action payload. This data accompanies the action string
            to the frontend. Defaults to a new empty dict.
        """
        if payload is NoValue:
            payload = {}
        # Don't push messages if current user is a token
        if hasattr(self.current_user, "username"):
            self.flow.push(self.current_user.id, action, payload)

    def push_all(self, action, payload=NoValue):
        """Broadcast a message to all frontend users.

        Use this functionality with care for two reasons:

        - It emits many messages, and if those messages trigger a response from
          frontends, it can result in many incoming API requests
        - Any information included in the message will be seen by everyone; and
          everyone will know it was sent. Do not, e.g., send around a message
          saying "secret object XYZ was updated; fetch the latest version".
          Even though the user won't be able to fetch the object, they'll
          know that it exists, and that it was modified.

        Parameters
        ----------
        action : str
            Name of frontend action to perform after API success. This action
            is sent to the frontend over WebSocket.
        payload : dict, optional
            Action payload. This data accompanies the action string
            to the frontend. Defaults to a new empty dict.
        """
        if payload is NoValue:
            payload = {}
        self.flow.push("*", action, payload=payload)

    def get_json(self):
        """Decode the request body as a JSON object.

        Returns
        -------
        dict
            The decoded object, or ``{}`` when the body is empty.

        Raises
        ------
        Exception
            If the body is not valid UTF-8 JSON, or decodes to a value that
            is not a JSON object.
        """
        if not self.request.body:
            return {}
        try:
            body = tornado.escape.json_decode(self.request.body)
        except (JSONDecodeError, UnicodeDecodeError) as err:
            # json_decode decodes bytes as UTF-8 first, so a non-UTF-8 body
            # fails with UnicodeDecodeError rather than JSONDecodeError.
            raise Exception(
                f"JSON decode of request body failed on {self.request.uri}."
                " Please ensure all requests are of type application/json."
            ) from err
        if not isinstance(body, dict):
            raise Exception("Please ensure posted data is of type application/json")
        return body

    def error(self, message, data=NoValue, status=400, extra=NoValue):
        """Write a JSON error response.

        The return JSON has the following format::

            {
                "status": "error",
                "message": ...,
                "data": ...,
                ...extra...
            }

        Parameters
        ----------
        message : str
            Description of the error.
        data : dict, optional
            Any data to be included with error message. Defaults to ``{}``.
        status : int, optional
            HTTP status code. Defaults to 400 (bad request).
            See https://www.restapitutorial.com/httpstatuscodes.html for a full
            list.
        extra : dict, optional
            Extra fields to be included in the response. Defaults to ``{}``.
        """
        if data is NoValue:
            data = {}
        if extra is NoValue:
            extra = {}
        self.set_header("Content-Type", "application/json")
        self.set_status(status)
        self.write({"status": "error", "message": message, "data": data, **extra})

    def action(self, action, payload=NoValue):
        """Push an action to the frontend via WebSocket connection.

        Parameters
        ----------
        action : str
            Name of frontend action to perform after API success. This action
            is sent to the frontend over WebSocket.
        payload : dict, optional
            Action payload. This data accompanies the action string
            to the frontend. Defaults to a new empty dict.
        """
        self.push(action, {} if payload is NoValue else payload)

    def success(
        self, data=NoValue, action=None, payload=NoValue, status=200, extra=NoValue
    ):
        """Write data and send actions on API success.

        The return JSON has the following format::

            {
                "status": "success",
                "data": ...,
                ...extra...
            }

        Parameters
        ----------
        data : dict, optional
            The JSON returned by the API call in the `data` field.
            Defaults to ``{}``.
        action : str, optional
            Name of frontend action to perform after API success. This action
            is sent to the frontend over WebSocket.
        payload : dict, optional
            Action payload. This data accompanies the action string
            to the frontend. Defaults to a new empty dict.
        status : int, optional
            HTTP status code. Defaults to 200 (OK).
            See https://www.restapitutorial.com/httpstatuscodes.html for a full
            list.
        extra : dict, optional
            Extra fields to be included in the response. Defaults to ``{}``.
        """
        if data is NoValue:
            data = {}
        if extra is NoValue:
            extra = {}
        if action is not None:
            self.action(action, {} if payload is NoValue else payload)
        self.set_header("Content-Type", "application/json")
        self.set_status(status)
        self.write(to_json({"status": "success", "data": data, **extra}))

    def write_error(self, status_code, exc_info=None):
        """Report an unhandled exception as a JSON error response.

        ``exc_info`` is the ``(type, value, traceback)`` triple supplied by
        tornado. An ``AccessError`` is reported with status 401; anything
        else keeps ``status_code``. Without ``exc_info`` a generic message
        is sent.
        """
        if exc_info is not None:
            _, err, _ = exc_info
            # Test the exception instance: the first element of exc_info is
            # a class, which is never an *instance* of AccessError.
            if isinstance(err, AccessError):
                status_code = 401
        else:
            err = "An unknown error occurred"
        self.error(str(err), status=status_code)

    async def _get_client(self, timeout=5):
        """Connect an asynchronous dask ``distributed.Client`` to the local
        scheduler on port ``cfg["ports.dask"]``.
        """
        scheduler_port = self.cfg["ports.dask"]
        from distributed import Client

        client = await Client(
            f"127.0.0.1:{scheduler_port}", asynchronous=True, timeout=timeout
        )
        return client

    def push_notification(self, note, notification_type="info"):
        """Show a notification to the current frontend user.

        Pushes the ``baselayer/SHOW_NOTIFICATION`` action with ``note`` and
        ``notification_type`` as payload.
        """
        self.push(
            action="baselayer/SHOW_NOTIFICATION",
            payload={"note": note, "type": notification_type},
        )

    def get_query_argument(self, value, default=NoValue, **kwargs):
        """Return query argument ``value`` (see tornado's ``get_query_argument``).

        If ``default`` is omitted, tornado's own missing-argument handling
        applies. If ``default`` is a bool, the result is coerced to bool:
        True iff its lowercased string form is "true", "yes", "t" or "1".
        """
        if default is not NoValue:
            kwargs["default"] = default
        arg = super().get_query_argument(value, **kwargs)
        if isinstance(kwargs.get("default"), bool):
            arg = str(arg).lower() in ["true", "yes", "t", "1"]
        return arg