Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ script:

after_success:
- pip install coveralls
- coverage run --source=flask_oauthlib setup.py -q nosetests
- DEBUG=1 coverage run --source=flask_oauthlib setup.py -q nosetests
- coveralls

notifications:
Expand Down
5 changes: 5 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Flask-OAuthlib
==============

.. image:: https://travis-ci.org/lepture/flask-oauthlib.png?branch=master
:target: https://travis-ci.org/lepture/flask-oauthlib
.. image:: https://coveralls.io/repos/lepture/flask-oauthlib/badge.png?branch=master
:target: https://coveralls.io/r/lepture/flask-oauthlib

Flask-OAuthlib is an extension to Flask that allows you to interact with
remote OAuth enabled applications. It is a replacement for Flask-OAuth.

Expand Down
12 changes: 12 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,15 @@ Client Reference

.. autoclass:: OAuthException
:members:


Provider Reference
------------------

.. module:: flask_oauthlib.provider

.. autoclass:: OAuth2Provider
:members:

.. autoclass:: OAuth2RequestValidator
:members:
2 changes: 1 addition & 1 deletion flask_oauthlib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __getattr__(self, key):
_etree = None


def get_etree():
def get_etree(): # pragma: no cover
global _etree
if _etree is not None:
return _etree
Expand Down
80 changes: 69 additions & 11 deletions flask_oauthlib/provider/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from werkzeug import cached_property
from oauthlib import oauth2
from oauthlib.oauth2 import RequestValidator, Server
from oauthlib.common import to_unicode

__all__ = ('OAuth2Provider', 'OAuth2RequestValidator')

Expand Down Expand Up @@ -84,11 +85,17 @@ def server(self):
hasattr(self, '_tokensetter') and \
hasattr(self, '_grantgetter') and \
hasattr(self, '_grantsetter'):

usergetter = None
if hasattr(self, '_usergetter'):
usergetter = self._usergetter

validator = OAuth2RequestValidator(
clientgetter=self._clientgetter,
tokengetter=self._tokengetter,
tokensetter=self._tokensetter,
grantgetter=self._grantgetter,
usergetter=usergetter,
tokensetter=self._tokensetter,
grantsetter=self._grantsetter,
)
return Server(validator, token_expires_in=expires_in)
Expand Down Expand Up @@ -144,6 +151,19 @@ def bearer_token(access_token=None, refresh_token=None):
"""
self._tokengetter = f

def usergetter(self, f):
"""Register a function as the user getter.

This decorator is only required for password credential
authorization::

@oauth.usergetter
def get_user(username=username, password=password,
*args, **kwargs):
return get_user_by_username(username, password)
"""
self._usergetter = f

def tokensetter(self, f):
"""Register a function to save the bearer token.
"""
Expand Down Expand Up @@ -210,7 +230,7 @@ def decorated(*args, **kwargs):
kwargs.update(credentials)
return f(*args, **kwargs)
except oauth2.FatalClientError as e:
log.debug('Fatal client error')
log.debug('Fatal client error %r', e)
return redirect(e.in_uri(self.error_uri))

if request.method == 'POST':
Expand Down Expand Up @@ -308,10 +328,11 @@ class OAuth2RequestValidator(RequestValidator):
:param grantgetter: a function to get grant token
:param grantsetter: a function to save grant token
"""
def __init__(self, clientgetter, tokengetter, tokensetter,
grantgetter, grantsetter):
def __init__(self, clientgetter, tokengetter, grantgetter,
usergetter=None, tokensetter=None, grantsetter=None):
self._clientgetter = clientgetter
self._tokengetter = tokengetter
self._usergetter = usergetter
self._tokensetter = tokensetter
self._grantgetter = grantgetter
self._grantsetter = grantsetter
Expand All @@ -323,11 +344,14 @@ def authenticate_client(self, request, *args, **kwargs):

.. _`Section 3.2.1`: http://tools.ietf.org/html/rfc6749#section-3.2.1
"""
auth = request.headers.get('HTTP_AUTHORIZATION', None)
auth = request.headers.get('Http-Authorization', None)
log.debug('Authenticate client %r', auth)
if auth:
try:
_, base64 = auth.split(' ')
client_id, client_secret = base64.decode('base64').split(':')
client_id = to_unicode(client_id, 'utf-8')
client_secret = to_unicode(client_secret, 'utf-8')
except Exception as e:
log.debug('Authenticate client failed with exception: %r', e)
return False
Expand Down Expand Up @@ -409,18 +433,24 @@ def confirm_redirect_uri(self, client_id, code, redirect_uri, client,

def confirm_scopes(self, refresh_token, scopes, request, *args, **kwargs):
#TODO
log.debug('Confirm scopes %r for refresh token %r',
scopes, refresh_token)
tok = self._tokengetter(refresh_token=refresh_token)
return set(tok.scopes) == set(scopes)

def get_default_redirect_uri(self, client_id, request, *args, **kwargs):
"""Default redirect_uri for the given client."""
request.client = request.client or self._clientgetter(client_id)
return request.client.default_redirect_uri
redirect_uri = request.client.default_redirect_uri
log.debug('Found default redirect uri %r', redirect_uri)
return redirect_uri

def get_default_scopes(self, client_id, request, *args, **kwargs):
"""Default scopes for the given client."""
request.client = request.client or self._clientgetter(client_id)
return request.client.default_scopes
scopes = request.client.default_scopes
log.debug('Found default scopes %r', scopes)
return scopes

def invalidate_authorization_code(self, client_id, code, request,
*args, **kwargs):
Expand All @@ -429,6 +459,7 @@ def invalidate_authorization_code(self, client_id, code, request,
We keep the temporary code in a grant, which has a `delete`
function to destroy itself.
"""
log.debug('Destroy grant token for client %r, %r', client_id, code)
grant = self._grantgetter(client_id=client_id, code=code)
if grant:
grant.delete()
Expand Down Expand Up @@ -523,12 +554,25 @@ def validate_grant_type(self, client_id, grant_type, client, request,
It is suggested that `allowed_grant_types` should contain at least
`authorization_code` and `refresh_token`.
"""
if self._usergetter is None and grant_type == 'password':
log.debug('Password credential authorization is disabled.')
return False

if grant_type not in ('authorization_code', 'password',
'client_credentials', 'refresh_token'):
return False

if hasattr(client, 'allowed_grant_types'):
return grant_type in client.allowed_grant_types

if grant_type == 'client_credentials':
# TODO: other means
if hasattr(client, 'user'):
request.user = client.user
return True
log.debug('Client should has a user property')
return False

return True

def validate_redirect_uri(self, client_id, redirect_uri, request,
Expand Down Expand Up @@ -580,8 +624,22 @@ def validate_scopes(self, client_id, scopes, client, request,

def validate_user(self, username, password, client, request,
*args, **kwargs):
# TODO
pass
"""Ensure the username and password is valid.

Attach user object on request for later using.
"""
log.debug('Validating username %r and password %r',
username, password)
if self._usergetter is not None:
user = self._usergetter(
username, password, client, request, *args, **kwargs
)
if user:
request.user = user
return True
return False
log.debug('Password credential authorization is disabled.')
return False


def _extract_params():
Expand All @@ -594,8 +652,8 @@ def _extract_params():
del headers['wsgi.input']
if 'wsgi.errors' in headers:
del headers['wsgi.errors']
if 'HTTP_AUTHORIZATION' in headers:
headers['Authorization'] = headers['HTTP_AUTHORIZATION']
if 'Http-Authorization' in headers:
headers['Authorization'] = headers['Http-Authorization']

body = request.form.to_dict()
return uri, http_method, body, headers
32 changes: 29 additions & 3 deletions tests/oauth2_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
db = SQLAlchemy()


def enable_log(name='flask_oauthlib'):
import logging
logger = logging.getLogger(name)
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.DEBUG)


class User(db.Model):
id = db.Column(db.Integer, primary_key=True)
username = db.Column(db.Unicode(40), unique=True, index=True,
Expand All @@ -27,6 +34,10 @@ class Client(db.Model):
_redirect_uris = db.Column(db.UnicodeText)
default_scope = db.Column(db.UnicodeText)

@property
def user(self):
return User.query.get(1)

@property
def redirect_uris(self):
if self._redirect_uris:
Expand Down Expand Up @@ -102,17 +113,26 @@ def prepare_app(app):
db.app = app
db.create_all()

client = Client(
client1 = Client(
name=u'dev', client_id=u'dev', client_secret=u'dev',
_redirect_uris=u'http://localhost:8000/authorized'
)

client2 = Client(
name=u'confidential', client_id=u'confidential',
client_secret=u'confidential', client_type=u'confidential',
_redirect_uris=u'http://localhost:8000/authorized'
)

user = User(username=u'admin')

try:
db.session.add(client1)
db.session.add(client2)
db.session.add(user)
db.session.add(client)
db.session.commit()
except:
pass
db.session.rollback()
return app


Expand Down Expand Up @@ -161,6 +181,12 @@ def set_token(token, request, *args, **kwargs):
db.session.add(tok)
db.session.commit()

@oauth.usergetter
def get_user(username, password, *args, **kwargs):
# This is optional, if you don't need password credential
# there is no need to implement this method
return User.query.get(1)

@app.before_request
def load_current_user():
user = User.query.get(1)
Expand Down
20 changes: 20 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from flask_oauthlib.client import encode_request_data, add_query


def test_encode_request_data():
data, _ = encode_request_data('foo', None)
assert data == 'foo'

data, f = encode_request_data(None, 'json')
assert data == '{}'
assert f == 'application/json'

data, f = encode_request_data(None, 'urlencoded')
assert data == ''
assert f == 'application/x-www-form-urlencoded'


def test_add_query():
assert 'path' == add_query('path', None)
assert 'path?foo=foo' == add_query('path', {'foo': 'foo'})
assert '?path&foo=foo' == add_query('?path', {'foo': 'foo'})
Loading