Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions flask_oauthlib/provider/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,34 @@ def decorated(*args, **kwargs):
return response
return decorated

def refresh_token_handler(self, func):
pass
def refresh_token_handler(self, f):
"""Refresh token handler

The decorated function should return an dictionary or None as
the extra credentials for creating the token response.

You can control the access method with standard flask route mechanism.
If you only allow the `POST` method::

@app.route('/oauth/refresh_token')
@oauth.refresh_token_handler
def refresh_token():
return None
"""
@wraps(f)
def decorated(*args, **kwargs):
uri, http_method, body, headers = _extract_params()
credentials = f(*args, **kwargs) or {}
log.debug('Fetched extra credentials, %r.', credentials)
server = self.server
uri, headers, body, status = server.create_token_response(
uri, http_method, body, headers, credentials
)
response = make_response(body, status)
for k, v in headers.items():
response.headers[k] = v
return response
return decorated

def require_oauth(self, scopes=None):
"""Protect resource with specified scopes."""
Expand Down Expand Up @@ -432,7 +458,13 @@ def confirm_redirect_uri(self, client_id, code, redirect_uri, client,
return grant.redirect_uri == redirect_uri

def confirm_scopes(self, refresh_token, scopes, request, *args, **kwargs):
#TODO
"""Ensures the requested scope matches the scope originally granted
by the resource owner. If the scope is omitted it is treated as equal
to the scope originally granted by the resource owner
"""
if not scopes:
log.debug('Scope omitted for refresh token %r', refresh_token)
return True
log.debug('Confirm scopes %r for refresh token %r',
scopes, refresh_token)
tok = self._tokengetter(refresh_token=refresh_token)
Expand Down Expand Up @@ -592,8 +624,21 @@ def validate_redirect_uri(self, client_id, redirect_uri, request,

def validate_refresh_token(self, refresh_token, client, request,
*args, **kwargs):
# TODO
return True
"""Ensure the token is valid and belongs to the client

This method is used by the authorization code grant indirectly by
issuing refresh tokens, resource owner password credentials grant
(also indirectly) and the refresh token grant.
"""

token = self._tokengetter(refresh_token=refresh_token)

if token and token.client == client:
# Make sure the request object contains user and client_id
request.client_id = token.client.client_id
request.user = token.user
return True
return False

def validate_response_type(self, client_id, response_type, client, request,
*args, **kwargs):
Expand Down
22 changes: 17 additions & 5 deletions tests/oauth2_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ class User(db.Model):


class Client(db.Model):
id = db.Column(db.Integer, primary_key=True)
#id = db.Column(db.Integer, primary_key=True)
# human readable name
name = db.Column(db.Unicode(40))
client_id = db.Column(db.Unicode(40), unique=True, index=True,
nullable=False)
client_id = db.Column(db.Unicode(40), primary_key=True)
client_secret = db.Column(db.Unicode(55), unique=True, index=True,
nullable=False)
client_type = db.Column(db.Unicode(20), default=u'public')
Expand Down Expand Up @@ -62,7 +61,11 @@ class Grant(db.Model):
)
user = relationship('User')

client_id = db.Column(db.Unicode(40), nullable=False)
client_id = db.Column(
db.Unicode(40), db.ForeignKey('client.client_id', ondelete='CASCADE'),
nullable=False,
)
client = relationship('Client')
code = db.Column(db.Unicode(255), index=True, nullable=False)

redirect_uri = db.Column(db.Unicode(255))
Expand All @@ -83,11 +86,15 @@ def scopes(self):

class Token(db.Model):
id = db.Column(db.Integer, primary_key=True)
client_id = db.Column(db.Unicode(40), nullable=False)
client_id = db.Column(
db.Unicode(40), db.ForeignKey('client.client_id', ondelete='CASCADE'),
nullable=False,
)
user_id = db.Column(
db.Integer, db.ForeignKey('user.id', ondelete='CASCADE')
)
user = relationship('User')
client = relationship('Client')
token_type = db.Column(db.Unicode(40))
access_token = db.Column(db.Unicode(255))
refresh_token = db.Column(db.Unicode(255))
Expand Down Expand Up @@ -212,6 +219,11 @@ def authorize(*args, **kwargs):
def access_token():
return {}

@app.route('/oauth/refresh_token')
@oauth.refresh_token_handler
def refresh_token():
return {}

@app.route('/api/email')
@oauth.require_oauth(['email'])
def email(data):
Expand Down
26 changes: 26 additions & 0 deletions tests/test_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import tempfile
import unittest
import json
from urlparse import urlparse
from flask import Flask
from .oauth2_server import create_server, db
Expand Down Expand Up @@ -96,6 +97,31 @@ def test_get_access_token(self):
assert 'access_token' in rv.data


class TestRefreshToken(BaseSuite):

def test_refresh_token(self):
auth_code = 'confidential:confidential'.encode('base64').strip()
url = ('/oauth/access_token?grant_type=password'
'&scope=email+address&username=admin&password=admin')
rv = self.client.get(url, headers={
'HTTP_AUTHORIZATION': 'Basic %s' % auth_code,
}, data={'confirm': 'yes'})
assert 'access_token' in rv.data

data = json.loads(rv.data)

args = (data.get('scope').replace(' ', '+'),
data.get('refresh_token'))
auth_code_r = 'confidential:confidential'.encode('base64').strip()
url_r = ('/oauth/refresh_token?grant_type=refresh_token'
'&scope={}&refresh_token={}&username=admin')
url_r = url_r.format(*args)
rv_r = self.client.get(url_r, headers={
'HTTP_AUTHORIZATION': 'Basic %s' % auth_code_r,
}, data={'confirm': 'yes'})
assert 'access_token' in rv_r.data


class TestCredentialAuth(BaseSuite):
def test_get_access_token(self):
auth_code = 'confidential:confidential'.encode('base64').strip()
Expand Down