Skip to content

Commit

Permalink
Add teardown_request decorator. Fixes issue #174
Browse files Browse the repository at this point in the history
  • Loading branch information
glyphobet authored and mitsuhiko committed Mar 14, 2011
1 parent 3deae1b commit 04e70bd
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGES
Expand Up @@ -37,6 +37,8 @@ Release date to be announced, codename to be selected
was incorrectly introduced in 0.6.
- Added `create_jinja_loader` to override the loader creation process.
- Implemented a silent flag for `config.from_pyfile`.
- Added `teardown_request` decorator, for functions that should run at the end
of a request regardless of whether an exception occurred.

Version 0.6.1
-------------
Expand Down
25 changes: 22 additions & 3 deletions docs/tutorial/dbcon.rst
Expand Up @@ -8,8 +8,11 @@ but how can we elegantly do that for requests? We will need the database
connection in all our functions so it makes sense to initialize them
before each request and shut them down afterwards.

Flask allows us to do that with the :meth:`~flask.Flask.before_request` and
:meth:`~flask.Flask.after_request` decorators::
Flask allows us to do that with the :meth:`~flask.Flask.before_request`,
:meth:`~flask.Flask.after_request` and :meth:`~flask.Flask.teardown_request`
decorators. In debug mode, if an error is raised,
:meth:`~flask.Flask.after_request` won't be run, and you'll have access to the
db connection in the interactive debugger::

@app.before_request
def before_request():
Expand All @@ -20,13 +23,29 @@ Flask allows us to do that with the :meth:`~flask.Flask.before_request` and
g.db.close()
return response

If you want to guarantee that the connection is always closed in debug mode, you
can close it in a function decorated with :meth:`~flask.Flask.teardown_request`:

@app.before_request
def before_request():
g.db = connect_db()

@app.teardown_request
def teardown_request(exception):
g.db.close()

Functions marked with :meth:`~flask.Flask.before_request` are called before
a request and passed no arguments, functions marked with
a request and passed no arguments. Functions marked with
:meth:`~flask.Flask.after_request` are called after a request and
passed the response that will be sent to the client. They have to return
that response object or a different one. In this case we just return it
unchanged.

Functions marked with :meth:`~flask.Flask.teardown_request` get called after the
response has been constructed. They are not allowed to modify the request, and
their return values are ignored. If an exception occurred while the request was
being processed, it is passed to each function; otherwise, None is passed in.

We store our current database connection on the special :data:`~flask.g`
object that flask provides for us. This object stores information for one
request only and is available from within each function. Never store such
Expand Down
39 changes: 39 additions & 0 deletions flask/app.py
Expand Up @@ -11,6 +11,7 @@

from __future__ import with_statement

import sys
from threading import Lock
from datetime import timedelta, datetime
from itertools import chain
Expand Down Expand Up @@ -247,6 +248,18 @@ def __init__(self, import_name, static_path=None):
#: :meth:`after_request` decorator.
self.after_request_funcs = {}

#: A dictionary with lists of functions that are called after
#: each request, even if an exception has occurred. The key of the
#: dictionary is the name of the module this function is active for,
#: `None` for all requests. These functions are not allowed to modify
#: the request, and their return values are ignored. If an exception
#: occurred while processing the request, it gets passed to each
#: teardown_request function. To register a function here, use the
#: :meth:`teardown_request` decorator.
#:
#: .. versionadded:: 0.7
self.teardown_request_funcs = {}

#: A dictionary with list of functions that are called without argument
#: to populate the template context. The key of the dictionary is the
#: name of the module this function is active for, `None` for all
Expand Down Expand Up @@ -704,6 +717,11 @@ def after_request(self, f):
self.after_request_funcs.setdefault(None, []).append(f)
return f

def teardown_request(self, f):
"""Register a function to be run at the end of each request, regardless of whether there was an exception or not."""
self.teardown_request_funcs.setdefault(None, []).append(f)
return f

def context_processor(self, f):
"""Registers a template context processor function."""
self.template_context_processors[None].append(f)
Expand Down Expand Up @@ -869,6 +887,20 @@ def process_response(self, response):
response = handler(response)
return response

def do_teardown_request(self):
"""Called after the actual request dispatching and will
call every as :meth:`teardown_request` decorated function.
"""
funcs = reversed(self.teardown_request_funcs.get(None, ()))
mod = request.module
if mod and mod in self.teardown_request_funcs:
funcs = chain(funcs, reversed(self.teardown_request_funcs[mod]))
exc = sys.exc_info()[1]
for func in funcs:
rv = func(exc)
if rv is not None:
return rv

def request_context(self, environ):
"""Creates a request context from the given environment and binds
it to the current context. This must be used in combination with
Expand Down Expand Up @@ -947,6 +979,11 @@ def wsgi_app(self, environ, start_response):
even if an exception happens database have the chance to
properly close the connection.
.. versionchanged:: 0.7
The :meth:`teardown_request` functions get called at the very end of
processing the request. If an exception was thrown, it gets passed to
each teardown_request function.
:param environ: a WSGI environment
:param start_response: a callable accepting a status code,
a list of headers and an optional
Expand All @@ -965,6 +1002,8 @@ def wsgi_app(self, environ, start_response):
response = self.process_response(response)
except Exception, e:
response = self.make_response(self.handle_exception(e))
finally:
self.do_teardown_request()
request_finished.send(self, response=response)
return response(environ, start_response)

Expand Down
74 changes: 73 additions & 1 deletion tests/flask_tests.py
Expand Up @@ -413,6 +413,72 @@ def fails():
assert 'Internal Server Error' in rv.data
assert len(called) == 1

def test_teardown_request_handler(self):
called = []
app = flask.Flask(__name__)
@app.teardown_request
def teardown_request(exc):
called.append(True)
return "Ignored"
@app.route('/')
def root():
return "Response"
rv = app.test_client().get('/')
assert rv.status_code == 200
assert 'Response' in rv.data
assert len(called) == 1

def test_teardown_request_handler_debug_mode(self):
called = []
app = flask.Flask(__name__)
app.debug = True
@app.teardown_request
def teardown_request(exc):
called.append(True)
return "Ignored"
@app.route('/')
def root():
return "Response"
rv = app.test_client().get('/')
assert rv.status_code == 200
assert 'Response' in rv.data
assert len(called) == 1


def test_teardown_request_handler_error(self):
called = []
app = flask.Flask(__name__)
@app.teardown_request
def teardown_request1(exc):
assert type(exc) == ZeroDivisionError
called.append(True)
# This raises a new error and blows away sys.exc_info(), so we can
# test that all teardown_requests get passed the same original
# exception.
try:
raise TypeError
except:
pass
@app.teardown_request
def teardown_request2(exc):
assert type(exc) == ZeroDivisionError
called.append(True)
# This raises a new error and blows away sys.exc_info(), so we can
# test that all teardown_requests get passed the same original
# exception.
try:
raise TypeError
except:
pass
@app.route('/')
def fails():
1/0
rv = app.test_client().get('/')
assert rv.status_code == 500
assert 'Internal Server Error' in rv.data
assert len(called) == 2


def test_before_after_request_order(self):
called = []
app = flask.Flask(__name__)
Expand All @@ -430,12 +496,18 @@ def after1(response):
def after2(response):
called.append(3)
return response
@app.teardown_request
def finish1(exc):
called.append(6)
@app.teardown_request
def finish2(exc):
called.append(5)
@app.route('/')
def index():
return '42'
rv = app.test_client().get('/')
assert rv.data == '42'
assert called == [1, 2, 3, 4]
assert called == [1, 2, 3, 4, 5, 6]

def test_error_handling(self):
app = flask.Flask(__name__)
Expand Down

0 comments on commit 04e70bd

Please sign in to comment.