Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pkce implementation #2510

Merged
merged 26 commits into from
Jul 27, 2020
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/install.pip
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
authlib==0.10
authlib==0.14.3
awesome-slugify==1.6.5
Babel==2.8.0
bcrypt==3.1.7
Expand Down
24 changes: 24 additions & 0 deletions udata/api/commands.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import logging
import os
import time

import click

from werkzeug.security import gen_salt
from flask import json, current_app
from flask_restplus import schemas

from udata.api import api
from udata.commands import cli, success, exit_with_error
from udata.models import User
from udata.api.oauth2 import OAuth2Client

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,3 +63,23 @@ def validate():
success('API specifications are valid')
except schemas.SchemaValidationError as e:
exit_with_error('API specifications are not valid', e)


@grp.command()
@click.option('-u', '--user-email', help='User\'s email')
@click.option('--uri', default='http://localhost:8080/login', help='Client\'s redirect uri')
quaxsze marked this conversation as resolved.
Show resolved Hide resolved
def create_oauth_client(user_email, uri):
'''Creates an OAuth2Client instance in DB'''
user = User.objects(email=user_email).first()
if user is None:
exit_with_error('No matching user to email')

client = OAuth2Client.objects.create(
name='test-client',
quaxsze marked this conversation as resolved.
Show resolved Hide resolved
owner=user,
redirect_uris=[uri]
quaxsze marked this conversation as resolved.
Show resolved Hide resolved
)

click.echo(f'New OAuth client')
click.echo(f'Client ID {client.id}')
click.echo(f'Client secret {client.secret}')
quaxsze marked this conversation as resolved.
Show resolved Hide resolved
175 changes: 101 additions & 74 deletions udata/api/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@

from datetime import datetime, timedelta

from authlib.common.security import generate_token
from authlib.flask.error import _HTTPException as AuthlibFlaskException
from authlib.flask.oauth2 import AuthorizationServer, ResourceProtector
from authlib.specs.rfc6749 import grants, ClientMixin
from authlib.specs.rfc6750 import BearerTokenValidator
from authlib.specs.rfc7009 import RevocationEndpoint
from flask import abort, request
from authlib.integrations.flask_oauth2.errors import _HTTPException as AuthlibFlaskException
from authlib.integrations.flask_oauth2 import AuthorizationServer, ResourceProtector
from authlib.oauth2.rfc6749 import grants, ClientMixin
from authlib.oauth2.rfc6750 import BearerTokenValidator
from authlib.oauth2.rfc7009 import RevocationEndpoint
from authlib.oauth2.rfc7636 import CodeChallenge
from authlib.oauth2.rfc6749.util import scope_to_list, list_to_scope
from authlib.oauth2 import OAuth2Error
from flask import request
from flask_security.utils import verify_password
from werkzeug.exceptions import Unauthorized
from werkzeug.security import gen_salt
Expand Down Expand Up @@ -70,7 +72,7 @@ class OAuth2Client(ClientMixin, db.Datetimed, db.Document):
thumbnails=[150, 25])

redirect_uris = db.ListField(db.StringField())
scopes = db.ListField(db.StringField(), default=['default'])
scope = db.StringField(default='')
abulte marked this conversation as resolved.
Show resolved Hide resolved

confidential = db.BooleanField(default=False)
internal = db.BooleanField(default=False)
Expand All @@ -79,6 +81,9 @@ class OAuth2Client(ClientMixin, db.Datetimed, db.Document):
'collection': 'oauth2_client'
}

def get_client_id(self):
quaxsze marked this conversation as resolved.
Show resolved Hide resolved
return str(self.id)

@property
def client_id(self):
return str(self.id)
Expand All @@ -95,6 +100,12 @@ def get_default_redirect_uri(self):
'''Implement required ClientMixin method'''
return self.default_redirect_uri

def get_allowed_scope(self, scope):
if not scope:
return ''
allowed = set(scope_to_list(self.scope))
return list_to_scope([s for s in scope.split() if s in allowed])

def check_redirect_uri(self, redirect_uri):
'''Implement required ClientMixin method'''
return redirect_uri in self.redirect_uris
Expand All @@ -113,42 +124,14 @@ def check_response_type(self, response_type):
def check_grant_type(self, grant_type):
return True

def check_requested_scopes(self, scopes):
allowed = set(self.scopes)
return allowed.issuperset(set(scopes))
def check_requested_scope(self, scope):
allowed = set(self.scope)
return allowed.issuperset(set(scope))

def has_client_secret(self):
return bool(self.secret)


class OAuth2Grant(db.Document):
user = db.ReferenceField('User', required=True)
client = db.ReferenceField('OAuth2Client', required=True)

code = db.StringField(required=True)

redirect_uri = db.StringField()
expires = db.DateTimeField()

scopes = db.ListField(db.StringField())

meta = {
'collection': 'oauth2_grant'
}

def __str__(self):
return '<OAuth2Grant({0.client.name}, {0.user.fullname})>'.format(self)

def is_expired(self):
return self.expires < datetime.utcnow()

def get_redirect_uri(self):
return self.redirect_uri

def get_scope(self):
return ' '.join(self.scopes)


class OAuth2Token(db.Document):
client = db.ReferenceField('OAuth2Client', required=True)
user = db.ReferenceField('User')
Expand All @@ -160,7 +143,8 @@ class OAuth2Token(db.Document):
refresh_token = db.StringField(unique=True, sparse=True)
created_at = db.DateTimeField(default=datetime.utcnow, required=True)
expires_in = db.IntField(required=True, default=TOKEN_EXPIRATION)
scopes = db.ListField(db.StringField())
scope = db.StringField(default='')
revoked = db.BooleanField(default=False)

meta = {
'collection': 'oauth2_token'
Expand All @@ -170,44 +154,81 @@ def __str__(self):
return '<OAuth2Token({0.client.name})>'.format(self)

def get_scope(self):
return ' '.join(self.scopes)
return self.scope

def get_expires_in(self):
return self.expires_in

def get_expires_at(self):
return (self.created_at - EPOCH).total_seconds() + self.expires_in

def is_refresh_token_expired(self):
def get_client_id(self):
abulte marked this conversation as resolved.
Show resolved Hide resolved
return str(self.client.id)

def is_refresh_token_valid(self):
if self.revoked:
return False
expired_at = datetime.fromtimestamp(self.get_expires_at())
expired_at += timedelta(days=REFRESH_EXPIRATION)
return expired_at < datetime.utcnow()
return expired_at > datetime.utcnow()


class OAuth2Code(db.Document):
user = db.ReferenceField('User', required=True)
client = db.ReferenceField('OAuth2Client', required=True)

code = db.StringField(required=True)

redirect_uri = db.StringField()
expires = db.DateTimeField()

scope = db.StringField(default='')
code_challenge = db.StringField()
code_challenge_method = db.StringField()

meta = {
'collection': 'oauth2_code'
}

def __str__(self):
return '<OAuth2Code({0.client.name}, {0.user.fullname})>'.format(self)

def is_expired(self):
return self.expires < datetime.utcnow()

def get_redirect_uri(self):
return self.redirect_uri

def get_scope(self):
return self.scope


class AuthorizationCodeGrant(grants.AuthorizationCodeGrant):
TOKEN_ENDPOINT_AUTH_METHODS = [
'client_secret_basic',
'client_secret_post',
'client_secret_post'
]

def create_authorization_code(self, client, grant_user, request):
code = generate_token(48)
def save_authorization_code(self, code, request):
code_challenge = request.data.get('code_challenge')
code_challenge_method = request.data.get('code_challenge_method')
expires = datetime.utcnow() + timedelta(seconds=GRANT_EXPIRATION)
scopes = request.scope.split(' ') if request.scope else client.scopes
OAuth2Grant.objects.create(
auth_code = OAuth2Code.objects.create(
code=code,
client=client,
client=ObjectId(request.client.client_id),
redirect_uri=request.redirect_uri,
scopes=scopes,
user=ObjectId(grant_user.id),
scope=request.scope,
user=ObjectId(request.user.id),
code_challenge=code_challenge,
code_challenge_method=code_challenge_method,
expires=expires,
)
return code
return auth_code

def parse_authorization_code(self, code, client):
item = OAuth2Grant.objects(code=code, client=client).first()
if item and not item.is_expired():
return item
def query_authorization_code(self, code, client):
auth_code = OAuth2Code.objects(code=code, client=client).first()
if auth_code and not auth_code.is_expired():
return auth_code

def delete_authorization_code(self, authorization_code):
authorization_code.delete()
Expand All @@ -225,26 +246,32 @@ def authenticate_user(self, username, password):

class RefreshTokenGrant(grants.RefreshTokenGrant):
def authenticate_refresh_token(self, refresh_token):
token = OAuth2Token.objects(refresh_token=refresh_token).first()
if token and not token.is_refresh_token_expired():
return token
item = OAuth2Token.objects(refresh_token=refresh_token).first()
if item and item.is_refresh_token_valid():
return item

def authenticate_user(self, credential):
return credential.user

def revoke_old_credential(self, credential):
credential.revoked = True
credential.save()


class RevokeToken(RevocationEndpoint):
def query_token(self, token, token_type_hint, client):
qs = OAuth2Token.objects(client=client)
if token_type_hint:
qs = qs(**{token_type_hint: token})
if token_type_hint == 'access_token':
return qs.filter(access_token=token).first()
elif token_type_hint == 'refresh_token':
return qs.filter(refresh_token=token).first()
else:
qs = qs(db.Q(access_token=token) | db.Q(refresh_token=token))
return qs.first()
return qs.first()

def revoke_token(self, token):
# TODO: mark token as revoked
token.delete()
token.revoked = True
token.save()


class BearerToken(BearerTokenValidator):
Expand All @@ -255,8 +282,7 @@ def request_invalid(self, request):
return False

def token_revoked(self, token):
# TODO: return token.revoked
return False
return token.revoked


@blueprint.route('/token', methods=['POST'], localize=False, endpoint='token')
Expand All @@ -266,6 +292,7 @@ def access_token():


@blueprint.route('/revoke', methods=['POST'], localize=False)
@csrf.exempt
def revoke_token():
return oauth.create_endpoint_response(RevokeToken.ENDPOINT_NAME)

Expand All @@ -274,7 +301,10 @@ def revoke_token():
@login_required
def authorize(*args, **kwargs):
if request.method == 'GET':
grant = oauth.validate_consent_request(end_user=current_user)
try:
grant = oauth.validate_consent_request(end_user=current_user)
except OAuth2Error as error:
return error.error
# Bypass authorization screen for internal clients
if grant.client.internal:
return oauth.create_authorization_response(grant_user=current_user)
Expand All @@ -287,8 +317,6 @@ def authorize(*args, **kwargs):
else:
grant_user = None
return oauth.create_authorization_response(grant_user=grant_user)
else:
abort(405)
abulte marked this conversation as resolved.
Show resolved Hide resolved


@blueprint.route('/error')
Expand All @@ -302,17 +330,17 @@ def query_client(client_id):


def save_token(token, request):
scopes = token.pop('scope', '').split(' ')
scope = token.pop('scope', '')
if request.grant_type == 'refresh_token':
credential = request.credential
credential.update(scopes=scopes, **token)
credential.update(scope=scope, **token)
else:
client = request.client
user = request.user or client.owner
OAuth2Token.objects.create(
client=client,
user=user.id,
scopes=scopes,
scope=scope,
**token
)

Expand All @@ -330,11 +358,10 @@ def init_app(app):
oauth.init_app(app, query_client=query_client, save_token=save_token)

# support all grants
oauth.register_grant(AuthorizationCodeGrant)
oauth.register_grant(AuthorizationCodeGrant, [CodeChallenge(required=True)])
oauth.register_grant(PasswordGrant)
oauth.register_grant(RefreshTokenGrant)
oauth.register_grant(grants.ClientCredentialsGrant)
oauth.register_grant(grants.ImplicitGrant)

# support revocation endpoint
oauth.register_endpoint(RevokeToken)
Expand Down