Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented the oauth2 integration between FastAPI and Authlib #278

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
d9c2d03
Implementing the oauth2 integration between FastAPI and Authlib
gmachado-nextreason Oct 7, 2020
4e32028
Remove the deprecated code and code linter
gmachado-nextreason Oct 15, 2020
3e30635
Implemented the authorization code grant pytests
gmachado-nextreason Oct 16, 2020
4729d67
Merge remote-tracking branch 'upstream/master' into master
gmachado-nextreason Oct 16, 2020
cbef5fb
Implemented the client credentials grant and registration pytests for…
gmachado-nextreason Oct 19, 2020
30dad2f
Implemented the implicit grant pytest for FastAPI integration
gmachado-nextreason Oct 20, 2020
340cabd
Implemented the introspection and jwt bearer pytests for FastAPI inte…
gmachado-nextreason Oct 20, 2020
a167172
Implemented the oauth2 server pytests for FastAPI integration
gmachado-nextreason Oct 20, 2020
b826d29
Implemented the openid code grant pytests for FastAPI integration
gmachado-nextreason Oct 21, 2020
fc34de8
Implemented the openid hybrid grant pytests for FastAPI integration
gmachado-nextreason Oct 21, 2020
786259e
Implemented the openid implict grant pytests for FastAPI integration
gmachado-nextreason Oct 21, 2020
3414d82
Implemented the password grant pytests for FastAPI integration
gmachado-nextreason Oct 21, 2020
a84ca83
Implemented the refresh token pytests for FastAPI integration
gmachado-nextreason Oct 21, 2020
ad11669
Implemented the token revocation pytests for FastAPI integration
gmachado-nextreason Oct 21, 2020
1fdb53f
Merged the database into models file
gmachado-nextreason Oct 22, 2020
1a5f4c1
Merge remote-tracking branch 'upstream/master'
gmachado-nextreason Nov 13, 2020
8988b40
Merge remote-tracking branch 'upstream/master'
gmachado-nextreason Dec 11, 2020
cf1d6d3
Adding fastapi tests run to `make tox` and to github pipeline
Alexey-Unosquare Dec 15, 2020
8666fe4
Integrating refactored authlib into fastapi integration
Alexey-Unosquare Dec 16, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:

- name: Test with tox ${{ matrix.python.toxenv }}
env:
TOXENV: py,flask,django,starlette
TOXENV: py,flask,django,fastapi,starlette
run: tox

- name: Report coverage
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
*.pyo
*.egg-info
*.swp
*.db
__pycache__
build
develop-eggs
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
clean: clean-build clean-pyc clean-docs clean-tox

tests:
@TOXENV=py,flask,django,coverage tox
@TOXENV=py,flask,django,fastapi,coverage tox

clean-build:
@rm -fr build/
Expand Down
4 changes: 4 additions & 0 deletions authlib/integrations/fastapi_oauth2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""FastAPI package implementation."""

from .authorization_server import AuthorizationServer
from .resource_protector import ResourceProtector
145 changes: 145 additions & 0 deletions authlib/integrations/fastapi_oauth2/authorization_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""Implementation of authlib.oauth2.rfc6749.AuthorizationServer class for FastAPI."""

import json

from authlib.common.security import generate_token
from authlib.oauth2 import AuthorizationServer as _AuthorizationServer
from authlib.oauth2 import HttpRequest, OAuth2Request
from authlib.oauth2.rfc6750 import BearerToken
from authlib.oauth2.rfc8414 import AuthorizationServerMetadata
from fastapi.responses import JSONResponse
from werkzeug.utils import import_string


class AuthorizationServer(_AuthorizationServer):
"""AuthorizationServer class."""

def __init__(self, app=None, query_client=None, save_token=None):
super(AuthorizationServer, self).__init__()
self._query_client = query_client
self._save_token = save_token
self.config = {}
if app:
self.init_app(app)

def init_app(self, app, query_client=None, save_token=None):
"""Initialize the FastAPI app."""
if query_client:
self.query_client = query_client
if save_token:
self.save_token = save_token

self.generate_token = create_bearer_token_generator(app.config)

metadata_class = AuthorizationServerMetadata

metadata_file = app.config.get("OAUTH2_METADATA_FILE")
if metadata_file:
with open(metadata_file) as metadata_file_content:
metadata = metadata_class(json.loads(metadata_file_content))
metadata.validate()
self.metadata = metadata

self.scopes_supported = app.config.get("OAUTH2_SCOPES_SUPPORTED")
self._error_uris = app.config.get("OAUTH2_ERROR_URIS")

def query_client(self, client_id):
return self._query_client(client_id)

def save_token(self, token, request):
return self._save_token(token, request)

def get_error_uri(self, request, error):
if self._error_uris:
uris = dict(self._error_uris)
return uris.get(error.error)

def create_oauth2_request(self, request):
return OAuth2Request(
request.method, str(request.url), request.body, request.headers
)

def create_json_request(self, request):
return HttpRequest(
request.method, str(request.url), request.body, request.headers
)

def send_signal(self, name, *args, **kwargs):
pass

def handle_response(self, status, body, headers):
return JSONResponse(content=body, status_code=status, headers=dict(headers))

def validate_consent_request(self, request=None, end_user=None):
"""Validate current HTTP request for authorization page. This page
is designed for resource owner to grant or deny the authorization"""
req = self.create_oauth2_request(request)
req.user = end_user

grant = self.get_authorization_grant(req)
grant.validate_consent_request()
if not hasattr(grant, "prompt"):
grant.prompt = None
return grant


def create_bearer_token_generator(config):
"""Create a generator function for generating ``token`` value. This
method will create a Bearer Token generator with
:class:`authlib.oauth2.rfc6750.BearerToken`. By default, it will not
generate ``refresh_token``, which can be turn on by configuration
``OAUTH2_REFRESH_TOKEN_GENERATOR=True``.
"""
conf = config.get("OAUTH2_ACCESS_TOKEN_GENERATOR", True)
access_token_generator = create_token_generator(conf, 42)

conf = config.get("OAUTH2_REFRESH_TOKEN_GENERATOR", False)
refresh_token_generator = create_token_generator(conf, 48)

expires_generator = create_token_expires_in_generator(config)

return BearerToken(
access_token_generator, refresh_token_generator, expires_generator
)


def create_token_expires_in_generator(config):
"""Create a generator function for generating ``expires_in`` value.
Developers can re-implement this method with a subclass if other means
required. The default expires_in value is defined by ``grant_type``,
different ``grant_type`` has different value. It can be configured
with::

OAUTH2_TOKEN_EXPIRES_IN = {
'authorization_code': 864000
}
"""
data = {}
data.update(BearerToken.GRANT_TYPES_EXPIRES_IN)

expires_in_conf = config.get("OAUTH2_TOKEN_EXPIRES_IN")
if expires_in_conf:
data.update(expires_in_conf)

def expires_in(client, grant_type): # pylint: disable=W0613
return data.get(grant_type, BearerToken.DEFAULT_EXPIRES_IN)

return expires_in


def create_token_generator(token_generator_conf, length=42):
"""Create a token generator function."""
if callable(token_generator_conf):
return token_generator_conf

if isinstance(token_generator_conf, str):
return import_string(token_generator_conf)

if token_generator_conf is True:

def token_generator(*args, **kwargs): # pylint: disable=W0613
return generate_token(length)

return token_generator

return None
60 changes: 60 additions & 0 deletions authlib/integrations/fastapi_oauth2/resource_protector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Implementation of authlib.oauth2.rfc6749.ResourceProtector class for FastAPI."""

import functools
from contextlib import contextmanager

from authlib.oauth2 import OAuth2Error
from authlib.oauth2 import ResourceProtector as _ResourceProtector
from authlib.oauth2.rfc6749 import HttpRequest, MissingAuthorizationError
from fastapi import HTTPException


class ResourceProtector(_ResourceProtector):
"""ResourceProtector class."""

def acquire_token(self, request=None, scope=None):
"""A method to acquire current valid token with the given scope.

:param request: request object
:param scope: string or list of scope values
:return: token object
"""
http_request = HttpRequest(request.method, request.url, {}, request.headers)
token = self.validate_request(scope, http_request)
request.state.token = token
return token

@contextmanager
def acquire(self, request=None, scope=None):
"""The with statement of ``require_oauth``. Instead of using a
decorator, you can use a with statement instead."""
try:
yield self.acquire_token(request, scope)
except OAuth2Error as error:
raise_error_response(error)

def __call__(self, scope=None, optional=False):
def wrapper(func):
@functools.wraps(func)
def decorated(request, *args, **kwargs):
try:
self.acquire_token(request, scope)
except MissingAuthorizationError as error:
if optional:
return func(request, *args, **kwargs)
raise_error_response(error)
except OAuth2Error as error:
raise_error_response(error)
return func(request, *args, **kwargs)

return decorated

return wrapper


def raise_error_response(error):
"""Raise the FastAPI HTTPException method."""
status = error.status_code
body = dict(error.get_body())
headers = error.get_headers()
raise HTTPException(status_code=status, detail=body, headers=dict(headers))
Empty file added tests/fastapi/__init__.py
Empty file.
Empty file.
115 changes: 115 additions & 0 deletions tests/fastapi/test_oauth2/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import time

from authlib.integrations.sqla_oauth2 import (OAuth2AuthorizationCodeMixin,
OAuth2ClientMixin,
OAuth2TokenMixin)
from authlib.oidc.core import UserInfo
from sqlalchemy import (Boolean, Column, ForeignKey, Integer, String,
create_engine)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker

engine = create_engine(
"sqlite:///fastapi_auth2_sql.db", connect_args={"check_same_thread": False}
)

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()

db = SessionLocal()


class User(Base):
__tablename__ = "user"

id = Column(Integer, primary_key=True)
username = Column(String(40), unique=True, nullable=False)

def get_user_id(self):
return self.id

def check_password(self, password):
return password != "wrong"

def generate_user_info(self, scopes):
profile = {"sub": str(self.id), "name": self.username}
return UserInfo(profile)


class Client(Base, OAuth2ClientMixin):
__tablename__ = "oauth2_client"

id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE"))
user = relationship("User")


class AuthorizationCode(Base, OAuth2AuthorizationCodeMixin):
__tablename__ = "oauth2_code"

id = Column(Integer, primary_key=True)
user_id = Column(Integer, nullable=False)

@property
def user(self):
return db.query(User).filter(User.id == self.user_id).first()


class Token(Base, OAuth2TokenMixin):
__tablename__ = "oauth2_token"

id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey("user.id", ondelete="CASCADE"))
user = relationship("User")
revoked = Column(Boolean)

def is_refresh_token_expired(self):
expired_at = self.issued_at + self.expires_in * 2
return expired_at < time.time()


class CodeGrantMixin(object):
def query_authorization_code(self, code, client):
item = (
db.query(AuthorizationCode)
.filter(
AuthorizationCode.code == code, Client.client_id == client.client_id
)
.first()
)
if item and not item.is_expired():
return item

def delete_authorization_code(self, authorization_code):
db.delete(authorization_code)
db.commit()

def authenticate_user(self, authorization_code):
return db.query(User).filter(User.id == authorization_code.user_id).first()


def save_authorization_code(code, request):
client = request.client
auth_code = AuthorizationCode(
code=code,
client_id=client.client_id,
redirect_uri=request.redirect_uri,
scope=request.scope,
nonce=request.data.get("nonce"),
user_id=request.user.id,
code_challenge=request.data.get("code_challenge"),
code_challenge_method=request.data.get("code_challenge_method"),
)
db.add(auth_code)
db.commit()
return auth_code


def exists_nonce(nonce, req):
exists = (
db.query(AuthorizationCode)
.filter(Client.client_id == req.client_id, AuthorizationCode.nonce == nonce)
.first()
)
return bool(exists)
Loading