Skip to content

Commit

Permalink
asynchronous tasks (v0.13)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Mar 3, 2017
1 parent a560333 commit 0c37246
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 0 deletions.
4 changes: 4 additions & 0 deletions flack/__init__.py
Expand Up @@ -32,4 +32,8 @@ def create_app(config_name=None):
from .api import api as api_blueprint
app.register_blueprint(api_blueprint, url_prefix='/api')

# Register async tasks support
from .tasks import tasks_bp as tasks_blueprint
app.register_blueprint(tasks_blueprint, url_prefix='/tasks')

return app
3 changes: 3 additions & 0 deletions flack/api/messages.py
Expand Up @@ -4,11 +4,13 @@
from ..auth import token_auth, token_optional_auth
from ..models import Message
from ..utils import timestamp, url_for
from ..tasks import async
from . import api


@api.route('/messages', methods=['POST'])
@token_auth.login_required
@async
def new_message():
"""
Post a new message.
Expand Down Expand Up @@ -54,6 +56,7 @@ def get_message(id):

@api.route('/messages/<id>', methods=['PUT'])
@token_auth.login_required
@async
def edit_message(id):
"""
Modify an existing message.
Expand Down
101 changes: 101 additions & 0 deletions flack/tasks.py
@@ -0,0 +1,101 @@
from functools import wraps
import threading
import time
import uuid

from flask import Blueprint, abort, current_app, g, request
from werkzeug.exceptions import HTTPException, InternalServerError

from . import db
from .models import User
from .utils import timestamp, url_for

tasks_bp = Blueprint('tasks', __name__)
tasks = {}


@tasks_bp.before_app_first_request
def before_first_request():
"""Start a background thread that cleans up old tasks."""
def clean_old_tasks():
"""
This function cleans up old tasks from our in-memory data structure.
"""
global tasks
while True:
# Only keep tasks that are running or that finished less than 5
# minutes ago.
five_min_ago = timestamp() - 5 * 60
tasks = {id: task for id, task in tasks.items()
if 't' not in task or task['t'] > five_min_ago}
time.sleep(60)

if not current_app.config['TESTING']:
thread = threading.Thread(target=clean_old_tasks)
thread.start()


def async(f):
"""
This decorator transforms a sync route to asynchronous by running it
in a background thread.
"""
@wraps(f)
def wrapped(*args, **kwargs):
def task(app, environ, current_user_nickname):
# Create a request context similar to that of the original request
# so that the task can have access to flask.g, flask.request, etc.
with app.request_context(environ):
# Install the current user in the thread's flask.g
current_user = User.query.filter_by(
nickname=current_user_nickname).first()
if current_user is None:
raise RuntimeError('Invalid user.')
g.current_user = current_user
try:
# Run the route function and record the response
tasks[id]['rv'] = f(*args, **kwargs)
except HTTPException as e:
tasks[id]['rv'] = current_app.handle_http_exception(e)
except Exception as e:
# The function raised an exception, so we set a 500 error
tasks[id]['rv'] = InternalServerError()
if current_app.debug:
# We want to find out if something happened so reraise
raise
finally:
# We record the time of the response, to help in garbage
# collecting old tasks
tasks[id]['t'] = timestamp()

# close the database session
db.session.remove()

# Assign an id to the asynchronous task
id = uuid.uuid4().hex

# Record the task, and then launch it
tasks[id] = {'task': threading.Thread(
target=task, args=(current_app._get_current_object(),
request.environ, g.current_user.nickname))}
tasks[id]['task'].start()

# Return a 202 response, with a link that the client can use to
# obtain task status
return '', 202, {'Location': url_for('tasks.get_status', id=id)}
return wrapped


@tasks_bp.route('/status/<id>', methods=['GET'])
def get_status(id):
"""
Return status about an asynchronous task. If this request returns a 202
status code, it means that task hasn't finished yet. Else, the response
from the task is returned.
"""
task = tasks.get(id)
if task is None:
abort(404)
if 'rv' not in task:
return '', 202, {'Location': url_for('tasks.get_status', id=id)}
return task['rv']
76 changes: 76 additions & 0 deletions tests/tests.py
Expand Up @@ -252,12 +252,28 @@ def test_message(self):
# create a message
r, s, h = self.post('/api/messages', data={'source': 'hello *world*!'},
token_auth=token)
self.assertEqual(s, 202)
url = h['Location']

# wait for asnychronous task to complete
while True:
r, s, h = self.get(url)
if s != 202:
break
self.assertEqual(s, 201)
url = h['Location']

# create incomplete message
r, s, h = self.post('/api/messages', data={'foo': 'hello *world*!'},
token_auth=token)
self.assertEqual(s, 202)
url2 = h['Location']

# wait for asynchronous task to complete
while True:
r, s, h = self.get(url2)
if s != 202:
break
self.assertEqual(s, 400)

# get message
Expand All @@ -269,6 +285,14 @@ def test_message(self):
# modify message
r, s, h = self.put(url, data={'source': '*hello* world!'},
token_auth=token)
self.assertEqual(s, 202)
url2 = h['Location']

# wait for asynchronous task to complete
while True:
r, s, h = self.get(url2)
if s != 202:
break
self.assertEqual(s, 204)

# check modified message
Expand All @@ -283,6 +307,14 @@ def test_message(self):
r, s, h = self.post('/api/messages',
data={'source': 'bye *world*!'},
token_auth=token)
self.assertEqual(s, 202)
url2 = h['Location']

# wait for asynchronous task to complete
while True:
r, s, h = self.get(url2)
if s != 202:
break
self.assertEqual(s, 201)

# get list of messages
Expand Down Expand Up @@ -311,6 +343,14 @@ def test_message(self):
# modify message from first user with second user's token
r, s, h = self.put(url, data={'source': '*hello* world!'},
token_auth=token2)
self.assertEqual(s, 202)
url2 = h['Location']

# wait for asynchronous task to complete
while True:
r, s, h = self.get(url2)
if s != 202:
break
self.assertEqual(s, 403)

def responses():
Expand Down Expand Up @@ -341,7 +381,16 @@ def responses():
'/api/messages',
data={'source': 'hello http://foo.com!'},
token_auth=token)
self.assertEqual(s, 202)
url2 = h['Location']

# wait for asynchronous task to complete
while True:
r, s, h = self.get(url2)
if s != 202:
break
self.assertEqual(s, 201)

self.assertEqual(
r['html'],
'hello <a href="http://foo.com" rel="nofollow">'
Expand All @@ -352,7 +401,16 @@ def responses():
'/api/messages',
data={'source': 'hello http://foo.com!'},
token_auth=token)
self.assertEqual(s, 202)
url2 = h['Location']

# wait for asynchronous task to complete
while True:
r, s, h = self.get(url2)
if s != 202:
break
self.assertEqual(s, 201)

self.assertEqual(
r['html'],
'hello <a href="http://foo.com" rel="nofollow">'
Expand All @@ -363,7 +421,16 @@ def responses():
'/api/messages',
data={'source': 'hello foo.com!'},
token_auth=token)
self.assertEqual(s, 202)
url2 = h['Location']

# wait for asynchronous task to complete
while True:
r, s, h = self.get(url2)
if s != 202:
break
self.assertEqual(s, 201)

self.assertEqual(
r['html'],
'hello <a href="http://foo.com" rel="nofollow">'
Expand All @@ -374,7 +441,16 @@ def responses():
'/api/messages',
data={'source': 'hello foo.com!'},
token_auth=token)
self.assertEqual(s, 202)
url2 = h['Location']

# wait for asynchronous task to complete
while True:
r, s, h = self.get(url2)
if s != 202:
break
self.assertEqual(s, 201)

self.assertEqual(
r['html'],
'hello <a href="http://foo.com" rel="nofollow">'
Expand Down

0 comments on commit 0c37246

Please sign in to comment.