diff --git a/.travis.yml b/.travis.yml index e3af11da..ba96f5d2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,7 @@ script: after_success: - pip install coveralls - - coverage run --source=flask_oauthlib setup.py -q nosetests + - DEBUG=1 coverage run --source=flask_oauthlib setup.py -q nosetests - coveralls notifications: diff --git a/README.rst b/README.rst index d17b5318..1a3133bd 100644 --- a/README.rst +++ b/README.rst @@ -1,6 +1,11 @@ Flask-OAuthlib ============== +.. image:: https://travis-ci.org/lepture/flask-oauthlib.png?branch=master + :target: https://travis-ci.org/lepture/flask-oauthlib +.. image:: https://coveralls.io/repos/lepture/flask-oauthlib/badge.png?branch=master + :target: https://coveralls.io/r/lepture/flask-oauthlib + Flask-OAuthlib is an extension to Flask that allows you to interact with remote OAuth enabled applications. It is a replacement for Flask-OAuth. diff --git a/docs/api.rst b/docs/api.rst index ca325bc5..3f293c53 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -22,3 +22,15 @@ Client Reference .. autoclass:: OAuthException :members: + + +Provider Reference +------------------ + +.. module:: flask_oauthlib.provider + +.. autoclass:: OAuth2Provider + :members: + +.. autoclass:: OAuth2RequestValidator + :members: diff --git a/flask_oauthlib/client.py b/flask_oauthlib/client.py index a05a4439..e844ed15 100644 --- a/flask_oauthlib/client.py +++ b/flask_oauthlib/client.py @@ -84,7 +84,7 @@ def __getattr__(self, key): _etree = None -def get_etree(): +def get_etree(): # pragma: no cover global _etree if _etree is not None: return _etree diff --git a/flask_oauthlib/provider/oauth2.py b/flask_oauthlib/provider/oauth2.py index 418b2cb7..d4770165 100644 --- a/flask_oauthlib/provider/oauth2.py +++ b/flask_oauthlib/provider/oauth2.py @@ -19,6 +19,7 @@ from werkzeug import cached_property from oauthlib import oauth2 from oauthlib.oauth2 import RequestValidator, Server +from oauthlib.common import to_unicode __all__ = ('OAuth2Provider', 'OAuth2RequestValidator') @@ -84,11 +85,17 @@ def server(self): hasattr(self, '_tokensetter') and \ hasattr(self, '_grantgetter') and \ hasattr(self, '_grantsetter'): + + usergetter = None + if hasattr(self, '_usergetter'): + usergetter = self._usergetter + validator = OAuth2RequestValidator( clientgetter=self._clientgetter, tokengetter=self._tokengetter, - tokensetter=self._tokensetter, grantgetter=self._grantgetter, + usergetter=usergetter, + tokensetter=self._tokensetter, grantsetter=self._grantsetter, ) return Server(validator, token_expires_in=expires_in) @@ -144,6 +151,19 @@ def bearer_token(access_token=None, refresh_token=None): """ self._tokengetter = f + def usergetter(self, f): + """Register a function as the user getter. + + This decorator is only required for password credential + authorization:: + + @oauth.usergetter + def get_user(username=username, password=password, + *args, **kwargs): + return get_user_by_username(username, password) + """ + self._usergetter = f + def tokensetter(self, f): """Register a function to save the bearer token. """ @@ -210,7 +230,7 @@ def decorated(*args, **kwargs): kwargs.update(credentials) return f(*args, **kwargs) except oauth2.FatalClientError as e: - log.debug('Fatal client error') + log.debug('Fatal client error %r', e) return redirect(e.in_uri(self.error_uri)) if request.method == 'POST': @@ -308,10 +328,11 @@ class OAuth2RequestValidator(RequestValidator): :param grantgetter: a function to get grant token :param grantsetter: a function to save grant token """ - def __init__(self, clientgetter, tokengetter, tokensetter, - grantgetter, grantsetter): + def __init__(self, clientgetter, tokengetter, grantgetter, + usergetter=None, tokensetter=None, grantsetter=None): self._clientgetter = clientgetter self._tokengetter = tokengetter + self._usergetter = usergetter self._tokensetter = tokensetter self._grantgetter = grantgetter self._grantsetter = grantsetter @@ -323,11 +344,14 @@ def authenticate_client(self, request, *args, **kwargs): .. _`Section 3.2.1`: http://tools.ietf.org/html/rfc6749#section-3.2.1 """ - auth = request.headers.get('HTTP_AUTHORIZATION', None) + auth = request.headers.get('Http-Authorization', None) + log.debug('Authenticate client %r', auth) if auth: try: _, base64 = auth.split(' ') client_id, client_secret = base64.decode('base64').split(':') + client_id = to_unicode(client_id, 'utf-8') + client_secret = to_unicode(client_secret, 'utf-8') except Exception as e: log.debug('Authenticate client failed with exception: %r', e) return False @@ -409,18 +433,24 @@ def confirm_redirect_uri(self, client_id, code, redirect_uri, client, def confirm_scopes(self, refresh_token, scopes, request, *args, **kwargs): #TODO + log.debug('Confirm scopes %r for refresh token %r', + scopes, refresh_token) tok = self._tokengetter(refresh_token=refresh_token) return set(tok.scopes) == set(scopes) def get_default_redirect_uri(self, client_id, request, *args, **kwargs): """Default redirect_uri for the given client.""" request.client = request.client or self._clientgetter(client_id) - return request.client.default_redirect_uri + redirect_uri = request.client.default_redirect_uri + log.debug('Found default redirect uri %r', redirect_uri) + return redirect_uri def get_default_scopes(self, client_id, request, *args, **kwargs): """Default scopes for the given client.""" request.client = request.client or self._clientgetter(client_id) - return request.client.default_scopes + scopes = request.client.default_scopes + log.debug('Found default scopes %r', scopes) + return scopes def invalidate_authorization_code(self, client_id, code, request, *args, **kwargs): @@ -429,6 +459,7 @@ def invalidate_authorization_code(self, client_id, code, request, We keep the temporary code in a grant, which has a `delete` function to destroy itself. """ + log.debug('Destroy grant token for client %r, %r', client_id, code) grant = self._grantgetter(client_id=client_id, code=code) if grant: grant.delete() @@ -523,12 +554,25 @@ def validate_grant_type(self, client_id, grant_type, client, request, It is suggested that `allowed_grant_types` should contain at least `authorization_code` and `refresh_token`. """ + if self._usergetter is None and grant_type == 'password': + log.debug('Password credential authorization is disabled.') + return False + if grant_type not in ('authorization_code', 'password', 'client_credentials', 'refresh_token'): return False if hasattr(client, 'allowed_grant_types'): return grant_type in client.allowed_grant_types + + if grant_type == 'client_credentials': + # TODO: other means + if hasattr(client, 'user'): + request.user = client.user + return True + log.debug('Client should has a user property') + return False + return True def validate_redirect_uri(self, client_id, redirect_uri, request, @@ -580,8 +624,22 @@ def validate_scopes(self, client_id, scopes, client, request, def validate_user(self, username, password, client, request, *args, **kwargs): - # TODO - pass + """Ensure the username and password is valid. + + Attach user object on request for later using. + """ + log.debug('Validating username %r and password %r', + username, password) + if self._usergetter is not None: + user = self._usergetter( + username, password, client, request, *args, **kwargs + ) + if user: + request.user = user + return True + return False + log.debug('Password credential authorization is disabled.') + return False def _extract_params(): @@ -594,8 +652,8 @@ def _extract_params(): del headers['wsgi.input'] if 'wsgi.errors' in headers: del headers['wsgi.errors'] - if 'HTTP_AUTHORIZATION' in headers: - headers['Authorization'] = headers['HTTP_AUTHORIZATION'] + if 'Http-Authorization' in headers: + headers['Authorization'] = headers['Http-Authorization'] body = request.form.to_dict() return uri, http_method, body, headers diff --git a/tests/oauth2_server.py b/tests/oauth2_server.py index b3e3a121..983a7854 100644 --- a/tests/oauth2_server.py +++ b/tests/oauth2_server.py @@ -9,6 +9,13 @@ db = SQLAlchemy() +def enable_log(name='flask_oauthlib'): + import logging + logger = logging.getLogger(name) + logger.addHandler(logging.StreamHandler()) + logger.setLevel(logging.DEBUG) + + class User(db.Model): id = db.Column(db.Integer, primary_key=True) username = db.Column(db.Unicode(40), unique=True, index=True, @@ -27,6 +34,10 @@ class Client(db.Model): _redirect_uris = db.Column(db.UnicodeText) default_scope = db.Column(db.UnicodeText) + @property + def user(self): + return User.query.get(1) + @property def redirect_uris(self): if self._redirect_uris: @@ -102,17 +113,26 @@ def prepare_app(app): db.app = app db.create_all() - client = Client( + client1 = Client( name=u'dev', client_id=u'dev', client_secret=u'dev', _redirect_uris=u'http://localhost:8000/authorized' ) + + client2 = Client( + name=u'confidential', client_id=u'confidential', + client_secret=u'confidential', client_type=u'confidential', + _redirect_uris=u'http://localhost:8000/authorized' + ) + user = User(username=u'admin') + try: + db.session.add(client1) + db.session.add(client2) db.session.add(user) - db.session.add(client) db.session.commit() except: - pass + db.session.rollback() return app @@ -161,6 +181,12 @@ def set_token(token, request, *args, **kwargs): db.session.add(tok) db.session.commit() + @oauth.usergetter + def get_user(username, password, *args, **kwargs): + # This is optional, if you don't need password credential + # there is no need to implement this method + return User.query.get(1) + @app.before_request def load_current_user(): user = User.query.get(1) diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 00000000..358b65a5 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,20 @@ +from flask_oauthlib.client import encode_request_data, add_query + + +def test_encode_request_data(): + data, _ = encode_request_data('foo', None) + assert data == 'foo' + + data, f = encode_request_data(None, 'json') + assert data == '{}' + assert f == 'application/json' + + data, f = encode_request_data(None, 'urlencoded') + assert data == '' + assert f == 'application/x-www-form-urlencoded' + + +def test_add_query(): + assert 'path' == add_query('path', None) + assert 'path?foo=foo' == add_query('path', {'foo': 'foo'}) + assert '?path&foo=foo' == add_query('?path', {'foo': 'foo'}) diff --git a/tests/test_oauth2.py b/tests/test_oauth2.py index d8d25f7b..a57ee45d 100644 --- a/tests/test_oauth2.py +++ b/tests/test_oauth2.py @@ -2,13 +2,14 @@ import os import tempfile -from urlparse import urlparse +import unittest +from urlparse import urlparse from flask import Flask -from .oauth2_server import create_server +from .oauth2_server import create_server, db from .oauth2_client import create_client -class BaseSuite(object): +class BaseSuite(unittest.TestCase): def setUp(self): app = Flask(__name__) app.debug = True @@ -19,6 +20,7 @@ def setUp(self): config = { 'SQLALCHEMY_DATABASE_URI': 'sqlite:///%s' % self.db_file } + app.config.update(config) app = create_server(app) app = create_client(app) @@ -28,6 +30,9 @@ def setUp(self): return app def tearDown(self): + db.session.remove() + db.drop_all() + os.close(self.db_fd) os.unlink(self.db_file) @@ -38,7 +43,7 @@ def tearDown(self): ) -class TestAuth(BaseSuite): +class TestWebAuth(BaseSuite): def test_login(self): rv = self.client.get('/login') assert 'response_type=code' in rv.location @@ -72,11 +77,36 @@ def test_get_access_token(self): assert 'access_token' in rv.data def test_full_flow(self): - self.test_get_access_token() + rv = self.client.post(authorize_url, data={'confirm': 'yes'}) + rv = self.client.get(clean_url(rv.location)) + assert 'access_token' in rv.data + rv = self.client.get('/') assert 'username' in rv.data +class TestPasswordAuth(BaseSuite): + def test_get_access_token(self): + auth_code = 'confidential:confidential'.encode('base64').strip() + url = ('/oauth/access_token?grant_type=password' + '&scope=email+address&username=admin&password=admin') + rv = self.client.get(url, headers={ + 'HTTP_AUTHORIZATION': 'Basic %s' % auth_code, + }, data={'confirm': 'yes'}) + assert 'access_token' in rv.data + + +class TestCredentialAuth(BaseSuite): + def test_get_access_token(self): + auth_code = 'confidential:confidential'.encode('base64').strip() + url = ('/oauth/access_token?grant_type=client_credentials' + '&scope=email+address&username=admin&password=admin') + rv = self.client.get(url, headers={ + 'HTTP_AUTHORIZATION': 'Basic %s' % auth_code, + }, data={'confirm': 'yes'}) + assert 'access_token' in rv.data + + def clean_url(location): ret = urlparse(location) return '%s?%s' % (ret.path, ret.query)