From 9f113c689fcc70955afcedbfc0fb00f2fca276ba Mon Sep 17 00:00:00 2001 From: Randy Topliffe Date: Mon, 10 Jun 2013 16:32:26 -0400 Subject: [PATCH] Implemented refresh token validation and handler. Added tests for refresh token validator and handler. --- flask_oauthlib/provider/oauth2.py | 55 ++++++++++++++++++++++++++++--- tests/oauth2_server.py | 22 ++++++++++--- tests/test_oauth2.py | 26 +++++++++++++++ 3 files changed, 93 insertions(+), 10 deletions(-) diff --git a/flask_oauthlib/provider/oauth2.py b/flask_oauthlib/provider/oauth2.py index d4770165..29e350c7 100644 --- a/flask_oauthlib/provider/oauth2.py +++ b/flask_oauthlib/provider/oauth2.py @@ -299,8 +299,34 @@ def decorated(*args, **kwargs): return response return decorated - def refresh_token_handler(self, func): - pass + def refresh_token_handler(self, f): + """Refresh token handler + + The decorated function should return an dictionary or None as + the extra credentials for creating the token response. + + You can control the access method with standard flask route mechanism. + If you only allow the `POST` method:: + + @app.route('/oauth/refresh_token') + @oauth.refresh_token_handler + def refresh_token(): + return None + """ + @wraps(f) + def decorated(*args, **kwargs): + uri, http_method, body, headers = _extract_params() + credentials = f(*args, **kwargs) or {} + log.debug('Fetched extra credentials, %r.', credentials) + server = self.server + uri, headers, body, status = server.create_token_response( + uri, http_method, body, headers, credentials + ) + response = make_response(body, status) + for k, v in headers.items(): + response.headers[k] = v + return response + return decorated def require_oauth(self, scopes=None): """Protect resource with specified scopes.""" @@ -432,7 +458,13 @@ def confirm_redirect_uri(self, client_id, code, redirect_uri, client, return grant.redirect_uri == redirect_uri def confirm_scopes(self, refresh_token, scopes, request, *args, **kwargs): - #TODO + """Ensures the requested scope matches the scope originally granted + by the resource owner. If the scope is omitted it is treated as equal + to the scope originally granted by the resource owner + """ + if not scopes: + log.debug('Scope omitted for refresh token %r', refresh_token) + return True log.debug('Confirm scopes %r for refresh token %r', scopes, refresh_token) tok = self._tokengetter(refresh_token=refresh_token) @@ -592,8 +624,21 @@ def validate_redirect_uri(self, client_id, redirect_uri, request, def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs): - # TODO - return True + """Ensure the token is valid and belongs to the client + + This method is used by the authorization code grant indirectly by + issuing refresh tokens, resource owner password credentials grant + (also indirectly) and the refresh token grant. + """ + + token = self._tokengetter(refresh_token=refresh_token) + + if token and token.client == client: + # Make sure the request object contains user and client_id + request.client_id = token.client.client_id + request.user = token.user + return True + return False def validate_response_type(self, client_id, response_type, client, request, *args, **kwargs): diff --git a/tests/oauth2_server.py b/tests/oauth2_server.py index 983a7854..4ca14303 100644 --- a/tests/oauth2_server.py +++ b/tests/oauth2_server.py @@ -23,11 +23,10 @@ class User(db.Model): class Client(db.Model): - id = db.Column(db.Integer, primary_key=True) + #id = db.Column(db.Integer, primary_key=True) # human readable name name = db.Column(db.Unicode(40)) - client_id = db.Column(db.Unicode(40), unique=True, index=True, - nullable=False) + client_id = db.Column(db.Unicode(40), primary_key=True) client_secret = db.Column(db.Unicode(55), unique=True, index=True, nullable=False) client_type = db.Column(db.Unicode(20), default=u'public') @@ -62,7 +61,11 @@ class Grant(db.Model): ) user = relationship('User') - client_id = db.Column(db.Unicode(40), nullable=False) + client_id = db.Column( + db.Unicode(40), db.ForeignKey('client.client_id', ondelete='CASCADE'), + nullable=False, + ) + client = relationship('Client') code = db.Column(db.Unicode(255), index=True, nullable=False) redirect_uri = db.Column(db.Unicode(255)) @@ -83,11 +86,15 @@ def scopes(self): class Token(db.Model): id = db.Column(db.Integer, primary_key=True) - client_id = db.Column(db.Unicode(40), nullable=False) + client_id = db.Column( + db.Unicode(40), db.ForeignKey('client.client_id', ondelete='CASCADE'), + nullable=False, + ) user_id = db.Column( db.Integer, db.ForeignKey('user.id', ondelete='CASCADE') ) user = relationship('User') + client = relationship('Client') token_type = db.Column(db.Unicode(40)) access_token = db.Column(db.Unicode(255)) refresh_token = db.Column(db.Unicode(255)) @@ -212,6 +219,11 @@ def authorize(*args, **kwargs): def access_token(): return {} + @app.route('/oauth/refresh_token') + @oauth.refresh_token_handler + def refresh_token(): + return {} + @app.route('/api/email') @oauth.require_oauth(['email']) def email(data): diff --git a/tests/test_oauth2.py b/tests/test_oauth2.py index a57ee45d..94f43993 100644 --- a/tests/test_oauth2.py +++ b/tests/test_oauth2.py @@ -3,6 +3,7 @@ import os import tempfile import unittest +import json from urlparse import urlparse from flask import Flask from .oauth2_server import create_server, db @@ -96,6 +97,31 @@ def test_get_access_token(self): assert 'access_token' in rv.data +class TestRefreshToken(BaseSuite): + + def test_refresh_token(self): + auth_code = 'confidential:confidential'.encode('base64').strip() + url = ('/oauth/access_token?grant_type=password' + '&scope=email+address&username=admin&password=admin') + rv = self.client.get(url, headers={ + 'HTTP_AUTHORIZATION': 'Basic %s' % auth_code, + }, data={'confirm': 'yes'}) + assert 'access_token' in rv.data + + data = json.loads(rv.data) + + args = (data.get('scope').replace(' ', '+'), + data.get('refresh_token')) + auth_code_r = 'confidential:confidential'.encode('base64').strip() + url_r = ('/oauth/refresh_token?grant_type=refresh_token' + '&scope={}&refresh_token={}&username=admin') + url_r = url_r.format(*args) + rv_r = self.client.get(url_r, headers={ + 'HTTP_AUTHORIZATION': 'Basic %s' % auth_code_r, + }, data={'confirm': 'yes'}) + assert 'access_token' in rv_r.data + + class TestCredentialAuth(BaseSuite): def test_get_access_token(self): auth_code = 'confidential:confidential'.encode('base64').strip()