Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 82 additions & 19 deletions flask_oauthlib/provider/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,19 @@ def server(self):
hasattr(self, '_tokengetter') and \
hasattr(self, '_tokensetter') and \
hasattr(self, '_grantgetter') and \
hasattr(self, '_grantsetter'):
hasattr(self, '_grantsetter') and \
hasattr(self, '_usergetter'):

usergetter = None
if hasattr(self, '_usergetter'):
usergetter = self._usergetter
usernamegetter = None
if hasattr(self, '_usernamegetter'):
usernamegetter = self._usernamegetter

validator = OAuth2RequestValidator(
clientgetter=self._clientgetter,
tokengetter=self._tokengetter,
grantgetter=self._grantgetter,
usergetter=usergetter,
usergetter=self._usergetter,
usernamegetter=usernamegetter,
tokensetter=self._tokensetter,
grantsetter=self._grantsetter,
)
Expand Down Expand Up @@ -151,21 +153,53 @@ def bearer_token(access_token=None, refresh_token=None):
"""
self._tokengetter = f

def usergetter(self, f):
"""Register a function as the user getter.
def usernamegetter(self, f):
"""Register a function as the username getter.

This decorator is only required for password credential
authorization::

@oauth.usergetter
def get_user(username=username, password=password,
@oauth.usernamegetter
def get_user(username=None, password=None,
*args, **kwargs):
return get_user_by_username(username, password)
"""
self._usernamegetter = f

def usergetter(self, f):
"""Register a function as the user getter.

The function is used to retrieve the currently logged
in user::

@oauth.usergetter
def get_user():
return current_user
"""
self._usergetter = f

def tokensetter(self, f):
"""Register a function to save the bearer token.

The function accepts the folllowing parameters:

- access_token: A string token
- token_type: A string describing the type of token provided
- scopes: A list of scopes for the token
- client: The client object
- user: The user object
- expires_in: Number of seconds until token should expire
- refresh_token: Either a string token or None

Implement the token setter::

@oauth.tokensetter
def set_token(access_token, scopes, client, user,
expires_in, refresh_token):
expires_at = datetime.now() + timedelta(seconds=expires_in)
save_token(access_token, scopes, client, user,
expires_at, refresh_token)

"""
self._tokensetter = f

Expand All @@ -180,18 +214,29 @@ def grant(client_id, code):

It returns a grant object with at least these information:

- user: The user for the grant
- scopes: A list of scopes provided by the grant
- delete: A function to delete itself
"""
self._grantgetter = f

def grantsetter(self, f):
"""Register a function to save the grant code.

The function accepts `client_id`, `code`, `request` and more::
The function accepts the following arguments:

- code: A string authorization code
- redirect_uri: A string representing the redirect uri
- user: The user object
- client: The client object
- scopes: A list of scopes for the code
- state: A string provided to prevent CSS

Here is an exmample implementation of the setter::

@oauth.grantsetter
def set_grant(client_id, code, request, *args, **kwargs):
save_grant(client_id, code, request.user, request.scopes)
def set_grant(code, redirect_uri, user, client, scopes, state):
save_grant(code, redirect_uri, user, client, scopes, state)
"""
self._grantsetter = f

Expand Down Expand Up @@ -329,10 +374,12 @@ class OAuth2RequestValidator(RequestValidator):
:param grantsetter: a function to save grant token
"""
def __init__(self, clientgetter, tokengetter, grantgetter,
usergetter=None, tokensetter=None, grantsetter=None):
usergetter=None, usernamegetter=None, tokensetter=None,
grantsetter=None):
self._clientgetter = clientgetter
self._tokengetter = tokengetter
self._usergetter = usergetter
self._usernamegetter = usernamegetter
self._tokensetter = tokensetter
self._grantgetter = grantgetter
self._grantsetter = grantsetter
Expand Down Expand Up @@ -472,13 +519,29 @@ def save_authorization_code(self, client_id, code, request,
code, client_id
)
request.client = request.client or self._clientgetter(client_id)
self._grantsetter(client_id, code, request, *args, **kwargs)
request.user = request.user or self._usergetter()
self._grantsetter(
code=code['code'],
redirect_uri=request.redirect_uri,
user=request.user,
client=request.client,
scopes=request.scopes,
state=request.state
)
return request.client.default_redirect_uri

def save_bearer_token(self, token, request, *args, **kwargs):
"""Persist the Bearer token."""
log.debug('Save bearer token %r', token)
self._tokensetter(token, request, *args, **kwargs)
self._tokensetter(
access_token=token['access_token'],
token_type=token['token_type'],
scopes=request.scopes,
client=request.client,
user=request.user,
expires_in=token['expires_in'],
refresh_token=token.get('refresh_token', None)
)
return request.client.default_redirect_uri

def validate_bearer_token(self, token, scopes, request):
Expand Down Expand Up @@ -537,7 +600,7 @@ def validate_code(self, client_id, code, client, request, *args, **kwargs):
datetime.datetime.utcnow() > grant.expires:
log.debug('Grant is expired.')
return False
request.state = kwargs.get('state')
request.state = grant.state
request.user = grant.user
request.scopes = grant.scopes
return True
Expand All @@ -554,7 +617,7 @@ def validate_grant_type(self, client_id, grant_type, client, request,
It is suggested that `allowed_grant_types` should contain at least
`authorization_code` and `refresh_token`.
"""
if self._usergetter is None and grant_type == 'password':
if self._usernamegetter is None and grant_type == 'password':
log.debug('Password credential authorization is disabled.')
return False

Expand Down Expand Up @@ -630,8 +693,8 @@ def validate_user(self, username, password, client, request,
"""
log.debug('Validating username %r and password %r',
username, password)
if self._usergetter is not None:
user = self._usergetter(
if self._usernamegetter is not None:
user = self._usernamegetter(
username, password, client, request, *args, **kwargs
)
if user:
Expand Down
37 changes: 26 additions & 11 deletions tests/oauth2_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class Grant(db.Model):
code = db.Column(db.Unicode(255), index=True, nullable=False)

redirect_uri = db.Column(db.Unicode(255))
state = db.Column(db.Unicode(255))
scope = db.Column(db.UnicodeText)
expires = db.Column(db.DateTime)

Expand Down Expand Up @@ -158,35 +159,49 @@ def get_token(access_token=None, refresh_token=None):
return None

@oauth.grantsetter
def set_grant(client_id, code, request, *args, **kwargs):
def set_grant(code, redirect_uri, user, client, scopes, state):
expires = datetime.datetime.utcnow() + datetime.timedelta(seconds=100)
grant = Grant(
client_id=client_id,
code=code['code'],
redirect_uri=request.redirect_uri,
scope=' '.join(request.scopes),
user_id=g.user.id,
client_id=client.client_id,
code=code,
redirect_uri=redirect_uri,
scope=' '.join(scopes),
user_id=user.id,
expires=expires,
state=state,
)
db.session.add(grant)
db.session.commit()

@oauth.tokensetter
def set_token(token, request, *args, **kwargs):
def set_token(access_token, token_type, scopes, client, user,
expires_in, refresh_token):
# In real project, a token is unique bound to user and client.
# Which means, you don't need to create a token every time.
tok = Token(**token)
tok.user_id = request.user.id
tok.client_id = request.client.client_id
tok = Token(
access_token=access_token,
token_type=token_type,
expires_in=expires_in,
refresh_token=refresh_token,
scope=' '.join(scopes)
)
tok.user_id = user.id
tok.client_id = client.client_id
db.session.add(tok)
db.session.commit()

@oauth.usergetter
@oauth.usernamegetter
def get_user(username, password, *args, **kwargs):
# This is optional, if you don't need password credential
# there is no need to implement this method
return User.query.get(1)

@oauth.usergetter
def get_active_user():
# This is used to retrieve the user object on the currently
# logged in user.
return User.query.get(1)

@app.before_request
def load_current_user():
user = User.query.get(1)
Expand Down