diff --git a/flask_oauthlib/provider/oauth2.py b/flask_oauthlib/provider/oauth2.py index d4770165..4c29a7ad 100644 --- a/flask_oauthlib/provider/oauth2.py +++ b/flask_oauthlib/provider/oauth2.py @@ -84,17 +84,19 @@ def server(self): hasattr(self, '_tokengetter') and \ hasattr(self, '_tokensetter') and \ hasattr(self, '_grantgetter') and \ - hasattr(self, '_grantsetter'): + hasattr(self, '_grantsetter') and \ + hasattr(self, '_usergetter'): - usergetter = None - if hasattr(self, '_usergetter'): - usergetter = self._usergetter + usernamegetter = None + if hasattr(self, '_usernamegetter'): + usernamegetter = self._usernamegetter validator = OAuth2RequestValidator( clientgetter=self._clientgetter, tokengetter=self._tokengetter, grantgetter=self._grantgetter, - usergetter=usergetter, + usergetter=self._usergetter, + usernamegetter=usernamegetter, tokensetter=self._tokensetter, grantsetter=self._grantsetter, ) @@ -151,21 +153,53 @@ def bearer_token(access_token=None, refresh_token=None): """ self._tokengetter = f - def usergetter(self, f): - """Register a function as the user getter. + def usernamegetter(self, f): + """Register a function as the username getter. This decorator is only required for password credential authorization:: - @oauth.usergetter - def get_user(username=username, password=password, + @oauth.usernamegetter + def get_user(username=None, password=None, *args, **kwargs): return get_user_by_username(username, password) """ + self._usernamegetter = f + + def usergetter(self, f): + """Register a function as the user getter. + + The function is used to retrieve the currently logged + in user:: + + @oauth.usergetter + def get_user(): + return current_user + """ self._usergetter = f def tokensetter(self, f): """Register a function to save the bearer token. + + The function accepts the folllowing parameters: + + - access_token: A string token + - token_type: A string describing the type of token provided + - scopes: A list of scopes for the token + - client: The client object + - user: The user object + - expires_in: Number of seconds until token should expire + - refresh_token: Either a string token or None + + Implement the token setter:: + + @oauth.tokensetter + def set_token(access_token, scopes, client, user, + expires_in, refresh_token): + expires_at = datetime.now() + timedelta(seconds=expires_in) + save_token(access_token, scopes, client, user, + expires_at, refresh_token) + """ self._tokensetter = f @@ -180,6 +214,8 @@ def grant(client_id, code): It returns a grant object with at least these information: + - user: The user for the grant + - scopes: A list of scopes provided by the grant - delete: A function to delete itself """ self._grantgetter = f @@ -187,11 +223,20 @@ def grant(client_id, code): def grantsetter(self, f): """Register a function to save the grant code. - The function accepts `client_id`, `code`, `request` and more:: + The function accepts the following arguments: + + - code: A string authorization code + - redirect_uri: A string representing the redirect uri + - user: The user object + - client: The client object + - scopes: A list of scopes for the code + - state: A string provided to prevent CSS + + Here is an exmample implementation of the setter:: @oauth.grantsetter - def set_grant(client_id, code, request, *args, **kwargs): - save_grant(client_id, code, request.user, request.scopes) + def set_grant(code, redirect_uri, user, client, scopes, state): + save_grant(code, redirect_uri, user, client, scopes, state) """ self._grantsetter = f @@ -329,10 +374,12 @@ class OAuth2RequestValidator(RequestValidator): :param grantsetter: a function to save grant token """ def __init__(self, clientgetter, tokengetter, grantgetter, - usergetter=None, tokensetter=None, grantsetter=None): + usergetter=None, usernamegetter=None, tokensetter=None, + grantsetter=None): self._clientgetter = clientgetter self._tokengetter = tokengetter self._usergetter = usergetter + self._usernamegetter = usernamegetter self._tokensetter = tokensetter self._grantgetter = grantgetter self._grantsetter = grantsetter @@ -472,13 +519,29 @@ def save_authorization_code(self, client_id, code, request, code, client_id ) request.client = request.client or self._clientgetter(client_id) - self._grantsetter(client_id, code, request, *args, **kwargs) + request.user = request.user or self._usergetter() + self._grantsetter( + code=code['code'], + redirect_uri=request.redirect_uri, + user=request.user, + client=request.client, + scopes=request.scopes, + state=request.state + ) return request.client.default_redirect_uri def save_bearer_token(self, token, request, *args, **kwargs): """Persist the Bearer token.""" log.debug('Save bearer token %r', token) - self._tokensetter(token, request, *args, **kwargs) + self._tokensetter( + access_token=token['access_token'], + token_type=token['token_type'], + scopes=request.scopes, + client=request.client, + user=request.user, + expires_in=token['expires_in'], + refresh_token=token.get('refresh_token', None) + ) return request.client.default_redirect_uri def validate_bearer_token(self, token, scopes, request): @@ -537,7 +600,7 @@ def validate_code(self, client_id, code, client, request, *args, **kwargs): datetime.datetime.utcnow() > grant.expires: log.debug('Grant is expired.') return False - request.state = kwargs.get('state') + request.state = grant.state request.user = grant.user request.scopes = grant.scopes return True @@ -554,7 +617,7 @@ def validate_grant_type(self, client_id, grant_type, client, request, It is suggested that `allowed_grant_types` should contain at least `authorization_code` and `refresh_token`. """ - if self._usergetter is None and grant_type == 'password': + if self._usernamegetter is None and grant_type == 'password': log.debug('Password credential authorization is disabled.') return False @@ -630,8 +693,8 @@ def validate_user(self, username, password, client, request, """ log.debug('Validating username %r and password %r', username, password) - if self._usergetter is not None: - user = self._usergetter( + if self._usernamegetter is not None: + user = self._usernamegetter( username, password, client, request, *args, **kwargs ) if user: diff --git a/tests/oauth2_server.py b/tests/oauth2_server.py index 983a7854..3345a2f3 100644 --- a/tests/oauth2_server.py +++ b/tests/oauth2_server.py @@ -66,6 +66,7 @@ class Grant(db.Model): code = db.Column(db.Unicode(255), index=True, nullable=False) redirect_uri = db.Column(db.Unicode(255)) + state = db.Column(db.Unicode(255)) scope = db.Column(db.UnicodeText) expires = db.Column(db.DateTime) @@ -158,35 +159,49 @@ def get_token(access_token=None, refresh_token=None): return None @oauth.grantsetter - def set_grant(client_id, code, request, *args, **kwargs): + def set_grant(code, redirect_uri, user, client, scopes, state): expires = datetime.datetime.utcnow() + datetime.timedelta(seconds=100) grant = Grant( - client_id=client_id, - code=code['code'], - redirect_uri=request.redirect_uri, - scope=' '.join(request.scopes), - user_id=g.user.id, + client_id=client.client_id, + code=code, + redirect_uri=redirect_uri, + scope=' '.join(scopes), + user_id=user.id, expires=expires, + state=state, ) db.session.add(grant) db.session.commit() @oauth.tokensetter - def set_token(token, request, *args, **kwargs): + def set_token(access_token, token_type, scopes, client, user, + expires_in, refresh_token): # In real project, a token is unique bound to user and client. # Which means, you don't need to create a token every time. - tok = Token(**token) - tok.user_id = request.user.id - tok.client_id = request.client.client_id + tok = Token( + access_token=access_token, + token_type=token_type, + expires_in=expires_in, + refresh_token=refresh_token, + scope=' '.join(scopes) + ) + tok.user_id = user.id + tok.client_id = client.client_id db.session.add(tok) db.session.commit() - @oauth.usergetter + @oauth.usernamegetter def get_user(username, password, *args, **kwargs): # This is optional, if you don't need password credential # there is no need to implement this method return User.query.get(1) + @oauth.usergetter + def get_active_user(): + # This is used to retrieve the user object on the currently + # logged in user. + return User.query.get(1) + @app.before_request def load_current_user(): user = User.query.get(1)