Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
branch: master
436 lines (383 sloc) 17.752 kb
# -*- coding: utf-8 -*-
"""
flask_oauth
~~~~~~~~~~~
Implements basic OAuth support for Flask.
:copyright: (c) 2010 by Armin Ronacher.
:license: BSD, see LICENSE for more details.
"""
import httplib2
from functools import wraps
from urlparse import urljoin
from flask import request, session, json, redirect, Response
from werkzeug import url_decode, url_encode, url_quote, \
parse_options_header, Headers
import oauth2
_etree = None


def get_etree():
    """Return an ElementTree-compatible module, importing it on first use.

    The candidates are tried in this order: ``lxml.etree``,
    ``xml.etree.cElementTree`` and finally ``xml.etree.ElementTree``.  The
    first one that imports is remembered in the module-level ``_etree``
    and handed back by every later call.

    :raises TypeError: if none of the candidates can be imported.
    """
    global _etree
    if _etree is not None:
        # already resolved by an earlier call
        return _etree
    try:
        from lxml import etree as found
    except ImportError:
        try:
            from xml.etree import cElementTree as found
        except ImportError:
            try:
                from xml.etree import ElementTree as found
            except ImportError:
                raise TypeError('lxml or etree not found')
    _etree = found
    return _etree
def parse_response(resp, content, strict=False):
    """Decode a response body according to its ``content-type`` header.

    :param resp: the response headers; only ``content-type`` is read.
                 Assumed to offer a dict-style ``get`` -- TODO confirm for
                 callers outside this file.
    :param content: the raw response body.
    :param strict: if true, a body whose content type is not JSON, XML or
                   form-urlencoded is returned unchanged instead of being
                   parsed as form data.
    :return: the result of ``json.loads`` for JSON, an element parsed by
             :func:`get_etree` for XML, `content` itself for unrecognised
             types in strict mode, and otherwise a dictionary of the
             URL-decoded form fields.
    """
    # A response without a content-type header is handled like any other
    # unrecognised type instead of failing with a KeyError.
    ct, options = parse_options_header(resp.get('content-type', ''))
    if ct in ('application/json', 'text/javascript'):
        return json.loads(content)
    if ct in ('application/xml', 'text/xml'):
        # technically, text/xml is ascii based but because many
        # implementations get that wrong and utf-8 is a superset
        # of ascii anyways, there is not much harm in assuming
        # utf-8 here
        charset = options.get('charset', 'utf-8')
        return get_etree().fromstring(content.decode(charset))
    if strict and ct != 'application/x-www-form-urlencoded':
        return content
    # Form encoded data, or (when not strict) anything unrecognised.
    charset = options.get('charset', 'utf-8')
    return url_decode(content, charset=charset).to_dict()
def add_query(url, args):
    """Append `args` to `url` as an encoded query string.

    The URL is returned untouched when `args` is empty.  Otherwise the
    encoded parameters are attached with ``&`` if the URL already contains
    a ``?`` and with ``?`` if it does not.
    """
    if args:
        separator = '&' if '?' in url else '?'
        return url + separator + url_encode(args)
    return url
def encode_request_data(data, format):
    """Serialize `data` for use as a request body.

    :param data: the payload.  For the ``'json'`` and ``'urlencoded'``
                 formats a false value is replaced by an empty dictionary.
    :param format: ``None`` to pass `data` through untouched, or one of
                   ``'json'`` and ``'urlencoded'``.
    :return: a ``(body, content_type)`` tuple; the content type is ``None``
             when `format` is ``None``.
    :raises TypeError: for any other `format`.
    """
    if format is None:
        return data, None
    if format == 'json':
        body = json.dumps(data or {})
        content_type = 'application/json'
    elif format == 'urlencoded':
        body = url_encode(data or {})
        content_type = 'application/x-www-form-urlencoded'
    else:
        raise TypeError('Unknown format %r' % format)
    return body, content_type
class OAuthResponse(object):
    """The outcome of a request sent to an OAuth protected remote
    application: its headers, its raw body and its parsed body.
    """

    def __init__(self, resp, content):
        #: the response headers, wrapped in a :class:`~werkzeug.Headers`
        #: object.
        self.headers = Headers(resp)
        #: the body exactly as it was received
        self.raw_data = content
        #: the body as decoded by :func:`parse_response` in strict mode
        self.data = parse_response(resp, content, True)

    @property
    def status(self):
        """The ``status`` header of the response, converted to an int."""
        status_code = self.headers.get('status', type=int)
        return status_code
class OAuthClient(oauth2.Client):
    """An ``oauth2.Client`` with a helper for requesting new tokens."""

    def request_new_token(self, uri, callback=None, params=None):
        """Send a signed ``POST`` request for a new token to `uri`.

        :param uri: the URL the request is sent to.
        :param callback: optional callback URL; when given it is sent as
                         the ``oauth_callback`` parameter.
        :param params: optional dictionary of additional parameters to
                       sign and send.  The dictionary is copied and never
                       modified.
        :return: whatever ``httplib2.Http.request`` returns for the call.
        """
        # Work on a private copy: writing ``oauth_callback`` into the
        # caller's dictionary (or into a shared default) would carry one
        # request's callback over into every later request.
        params = dict(params or {})
        if callback is not None:
            params['oauth_callback'] = callback
        req = oauth2.Request.from_consumer_and_token(
            self.consumer, token=self.token,
            http_method='POST', http_url=uri, parameters=params,
            is_form_encoded=True)
        req.sign_request(self.method, self.consumer, self.token)
        body = req.to_postdata()
        headers = {
            'Content-Type': 'application/x-www-form-urlencoded',
            'Content-Length': str(len(body))
        }
        # The request was signed above, so the httplib2 implementation is
        # invoked explicitly instead of going through ``self.request``.
        return httplib2.Http.request(self, uri, method='POST',
                                     body=body, headers=headers)
class OAuthException(RuntimeError):
    """Raised if authorization fails for some reason."""

    message = None
    type = None

    def __init__(self, message, type=None, data=None):
        #: A helpful error message for debugging
        self.message = message
        #: A unique type for this exception if available.
        self.type = type
        #: If available, the parsed data from the remote API that can be
        #: used to pinpoint the error.
        self.data = data

    def __str__(self):
        """Return the message as the interpreter's native ``str`` type.

        Where ``str`` is a byte string, text messages are encoded as UTF-8
        and byte-string messages are returned unchanged.  Where ``str`` is
        text, text messages are returned unchanged and byte-string messages
        are decoded as UTF-8.
        """
        message = self.message
        if isinstance(message, str):
            # Already the native string type; re-encoding it would either
            # produce a non-``str`` result or force an implicit ASCII
            # decode that fails on non-ASCII bytes.
            return message
        if isinstance(message, bytes):
            # Only reachable where ``bytes`` and ``str`` are distinct types.
            return message.decode('utf-8', 'replace')
        return message.encode('utf-8')

    def __unicode__(self):
        """Return the message unchanged."""
        return self.message
class OAuth(object):
    """Keeps track of the remote applications that were created through
    it.  In the future this will also be the central class for OAuth
    provider functionality.
    """

    def __init__(self):
        #: maps the name of each registered remote application to its
        #: :class:`OAuthRemoteApp` object.
        self.remote_apps = dict()

    def remote_app(self, name, register=True, **kwargs):
        """Creates a new remote application and returns it.  Unless
        `register` is `False` the application is also stored in the
        :attr:`remote_apps` dictionary under `name`, which must not be
        taken yet.  The keyword arguments are forwarded to the
        :class:`OAuthRemoteApp` constructor.
        """
        app = OAuthRemoteApp(self, name, **kwargs)
        if not register:
            return app
        assert name not in self.remote_apps, \
            'application already registered'
        self.remote_apps[name] = app
        return app
class OAuthRemoteApp(object):
    """Represents a remote application.

    :param oauth: the associated :class:`OAuth` object.
    :param name: the name of the remote application
    :param base_url: the URL all other URLs are joined with
    :param request_token_url: the URL for requesting new tokens
    :param access_token_url: the URL for token exchange
    :param authorize_url: the URL for authorization
    :param consumer_key: the application specific consumer key
    :param consumer_secret: the application specific consumer secret
    :param request_token_params: an optional dictionary of parameters
                                 to forward to the request token URL
                                 or authorize URL depending on oauth
                                 version.
    :param access_token_params: an optional dictionary of parameters to
                                forward to the access token URL
    :param access_token_method: the HTTP method that should be used
                                for the access_token_url.  Defaults
                                to ``'GET'``.
    """

    def __init__(self, oauth, name, base_url,
                 request_token_url,
                 access_token_url, authorize_url,
                 consumer_key, consumer_secret,
                 request_token_params=None,
                 access_token_params=None,
                 access_token_method='GET'):
        self.oauth = oauth
        #: the `base_url` all URLs are joined with.
        self.base_url = base_url
        self.name = name
        self.request_token_url = request_token_url
        self.access_token_url = access_token_url
        self.authorize_url = authorize_url
        self.consumer_key = consumer_key
        self.consumer_secret = consumer_secret
        self.tokengetter_func = None
        self.request_token_params = request_token_params or {}
        self.access_token_params = access_token_params or {}
        self.access_token_method = access_token_method
        self._consumer = oauth2.Consumer(self.consumer_key,
                                         self.consumer_secret)
        self._client = OAuthClient(self._consumer)

    def status_okay(self, resp):
        """Given request data, checks if the status is okay.

        :return: `True` if ``resp['status']`` is 200 or 201, `False`
                 otherwise -- including when the status is missing or
                 cannot be converted to an integer.
        """
        try:
            return int(resp['status']) in (200, 201)
        except (KeyError, TypeError, ValueError):
            return False

    def get(self, *args, **kwargs):
        """Sends a ``GET`` request.  Accepts the same parameters as
        :meth:`request`.
        """
        kwargs['method'] = 'GET'
        return self.request(*args, **kwargs)

    def post(self, *args, **kwargs):
        """Sends a ``POST`` request.  Accepts the same parameters as
        :meth:`request`.
        """
        kwargs['method'] = 'POST'
        return self.request(*args, **kwargs)

    def put(self, *args, **kwargs):
        """Sends a ``PUT`` request.  Accepts the same parameters as
        :meth:`request`.
        """
        kwargs['method'] = 'PUT'
        return self.request(*args, **kwargs)

    def delete(self, *args, **kwargs):
        """Sends a ``DELETE`` request.  Accepts the same parameters as
        :meth:`request`.
        """
        kwargs['method'] = 'DELETE'
        return self.request(*args, **kwargs)

    def make_client(self, token=None):
        """Creates a new `oauth2` Client object with the token attached.
        Usually you don't have to do that but use the :meth:`request`
        method instead.
        """
        return oauth2.Client(self._consumer, self.get_request_token(token))

    def request(self, url, data="", headers=None, format='urlencoded',
                method='GET', content_type=None, token=None):
        """Sends a request to the remote server with OAuth tokens attached.
        The `url` is joined with :attr:`base_url` if the URL is relative.

        .. versionadded:: 0.12
           added the `token` parameter.

        :param url: where to send the request to
        :param data: the data to be sent to the server.  If the request method
                     is ``GET`` the data is appended to the URL as query
                     parameters, otherwise encoded to `format` if the format
                     is given.  If a `content_type` is provided instead, the
                     data must be a string encoded for the given content
                     type and used as request body.
        :param headers: an optional dictionary of headers.
        :param format: the format for the `data`.  Can be `urlencoded` for
                       URL encoded data or `json` for JSON.
        :param method: the HTTP request method to use.
        :param content_type: an optional content type.  If a content type is
                             provided, the data is passed as is and the
                             `format` parameter is ignored.
        :param token: an optional token to pass to tokengetter.  Use this if you
                      want to support sending requests using multiple tokens.
                      If you set this to anything not None, `tokengetter_func`
                      will receive the given token as an argument, in which case
                      the tokengetter should return the `(token, secret)` tuple
                      for the given token.
        :return: an :class:`OAuthResponse` object.
        """
        # copy so that the caller's header dictionary is never modified
        headers = dict(headers or {})
        client = self.make_client(token)
        url = self.expand_url(url)
        if method == 'GET':
            assert format == 'urlencoded'
            if data:
                # GET requests carry their data in the query string
                url = add_query(url, data)
                data = ""
        else:
            if content_type is None:
                data, content_type = encode_request_data(data, format)
            if content_type is not None:
                headers['Content-Type'] = content_type
        return OAuthResponse(*client.request(url, method=method,
                                             body=data or '',
                                             headers=headers))

    def expand_url(self, url):
        """Joins `url` with :attr:`base_url` and returns the result."""
        return urljoin(self.base_url, url)

    def generate_request_token(self, callback=None):
        """Requests a new request token from the remote application and
        stores it in the session.

        :param callback: an optional callback URL; it is joined with the
                         URL of the current request before it is sent.
        :return: the ``(oauth_token, oauth_token_secret)`` tuple.
        :raises OAuthException: if the response status is not okay or the
                                response could not be parsed.
        """
        if callback is not None:
            callback = urljoin(request.url, callback)
        # Hand over a copy so that the configured parameters cannot be
        # altered by the callee.
        resp, content = self._client.request_new_token(
            self.expand_url(self.request_token_url), callback,
            dict(self.request_token_params))
        if not self.status_okay(resp):
            raise OAuthException('Failed to generate request token',
                                 type='token_generation_failed')
        data = parse_response(resp, content)
        if data is None:
            raise OAuthException('Invalid token response from ' + self.name,
                                 type='token_generation_failed')
        tup = (data['oauth_token'], data['oauth_token_secret'])
        session[self.name + '_oauthtok'] = tup
        return tup

    def get_request_token(self, token=None):
        """Returns the ``oauth2.Token`` to sign requests with.

        The registered tokengetter is asked first; if it returns `None`
        the token stored in the session by :meth:`generate_request_token`
        is used.

        :param token: forwarded to the tokengetter unless it is `None`.
        :raises OAuthException: if neither source provides a token.
        """
        assert self.tokengetter_func is not None, 'missing tokengetter function'
        # Don't pass the token if the token is None to support old
        # tokengetter functions.  The test is for None specifically so that
        # falsy tokens such as '' or 0 are still forwarded.
        if token is None:
            rv = self.tokengetter_func()
        else:
            rv = self.tokengetter_func(token)
        if rv is None:
            rv = session.get(self.name + '_oauthtok')
            if rv is None:
                raise OAuthException('No token available', type='token_missing')
        return oauth2.Token(*rv)

    def free_request_token(self):
        """Removes the temporary token and the stored redirect URL of this
        application from the session, if present.
        """
        session.pop(self.name + '_oauthtok', None)
        session.pop(self.name + '_oauthredir', None)

    def authorize(self, callback=None):
        """Returns a redirect response to the remote authorization URL with
        the signed callback given.  The callback may be `None` in which
        case the application will most likely switch to PIN based authentication
        or use a remotely stored callback URL.  Alternatively it's an URL
        on the system that has to be decorated as :meth:`authorized_handler`.
        If no `request_token_url` is configured (OAuth 2) the callback is
        required.
        """
        if self.request_token_url:
            token = self.generate_request_token(callback)[0]
            # add_query chooses between ``?`` and ``&``, so an authorize
            # URL that already carries a query string stays valid.
            url = add_query(self.expand_url(self.authorize_url),
                            {'oauth_token': token})
        else:
            assert callback is not None, 'Callback is required OAuth2'
            # This is for things like facebook's oauth.  Since we need the
            # callback for the access_token_url we need to keep it in the
            # session.
            params = dict(self.request_token_params)
            params['redirect_uri'] = callback
            params['client_id'] = self.consumer_key
            params['response_type'] = 'code'
            session[self.name + '_oauthredir'] = callback
            url = add_query(self.expand_url(self.authorize_url), params)
        return redirect(url)

    def tokengetter(self, f):
        """Registers a function as tokengetter.  The tokengetter has to return
        a tuple of ``(token, secret)`` with the user's token and token secret.
        If the data is unavailable, the function must return `None`.

        If the `token` parameter is passed to the request function it's
        forwarded to the tokengetter function::

            @oauth.tokengetter
            def get_token(token='user'):
                if token == 'user':
                    return find_the_user_token()
                elif token == 'app':
                    return find_the_app_token()
                raise RuntimeError('invalid token')
        """
        self.tokengetter_func = f
        return f

    def handle_oauth1_response(self):
        """Handles an oauth1 authorization response.  The return value of
        this method is forwarded as first argument to the handling view
        function.

        :raises OAuthException: if the response status is not okay.
        """
        client = self.make_client()
        # The verifier comes straight from the incoming request, so it is
        # encoded by add_query instead of being interpolated into the URL
        # as is.
        url = add_query(self.expand_url(self.access_token_url),
                        {'oauth_verifier': request.args['oauth_verifier']})
        resp, content = client.request(url, self.access_token_method)
        data = parse_response(resp, content)
        if not self.status_okay(resp):
            raise OAuthException('Invalid response from ' + self.name,
                                 type='invalid_response', data=data)
        return data

    def handle_oauth2_response(self):
        """Handles an oauth2 authorization response.  The return value of
        this method is forwarded as first argument to the handling view
        function.

        :raises OAuthException: if the access token method is neither
                                ``'GET'`` nor ``'POST'`` or the response
                                status is not okay.
        """
        remote_args = {
            'code': request.args.get('code'),
            'client_id': self.consumer_key,
            'client_secret': self.consumer_secret,
            'redirect_uri': session.get(self.name + '_oauthredir')
        }
        # configured parameters take precedence over the defaults above
        remote_args.update(self.access_token_params)
        if self.access_token_method == 'POST':
            resp, content = self._client.request(
                self.expand_url(self.access_token_url),
                self.access_token_method,
                url_encode(remote_args))
        elif self.access_token_method == 'GET':
            url = add_query(self.expand_url(self.access_token_url),
                            remote_args)
            resp, content = self._client.request(url,
                                                 self.access_token_method)
        else:
            raise OAuthException('Unsupported access_token_method: ' +
                                 self.access_token_method)
        data = parse_response(resp, content)
        if not self.status_okay(resp):
            raise OAuthException('Invalid response from ' + self.name,
                                 type='invalid_response', data=data)
        return data

    def handle_unknown_response(self):
        """Called if an unknown response came back from the server.  This
        usually indicates a denied response.  The default implementation
        just returns `None`.
        """
        return None

    def authorized_handler(self, f):
        """Injects additional authorization functionality into the function.
        The function will be passed the response object as first argument
        if the request was allowed, or `None` if access was denied.  When the
        authorized handler is called, the temporary issued tokens are already
        destroyed.
        """
        @wraps(f)
        def decorated(*args, **kwargs):
            if 'oauth_verifier' in request.args:
                data = self.handle_oauth1_response()
            elif 'code' in request.args:
                data = self.handle_oauth2_response()
            else:
                data = self.handle_unknown_response()
            # the temporary tokens are dropped before the view runs
            self.free_request_token()
            return f(data, *args, **kwargs)
        return decorated
Jump to Line
Something went wrong with that request. Please try again.