Skip to content

Commit

Permalink
OAuth2 login support
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed May 28, 2023
1 parent 759d64d commit e657d6b
Show file tree
Hide file tree
Showing 9 changed files with 297 additions and 2 deletions.
5 changes: 5 additions & 0 deletions api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@
| `MAIL_USERNAME` | not defined | The username to use for sending emails. |
| `MAIL_PASSWORD` | not defined | The password to use for sending emails. |
| `MAIL_DEFAULT_SENDER` | `donotreply@microblog.example.com` | The default sender to use for emails. |
| `GITHUB_CLIENT_ID` | not defined | The client ID for the GitHub OAuth2 application, used for logging in with a GitHub account. |
| `GITHUB_CLIENT_SECRET` | not defined | The client secret for the GitHub OAuth2 application, used for logging in with a GitHub account. |
| `GOOGLE_CLIENT_ID` | not defined | The client ID for the Google OAuth2 application, used for logging in with a Google account. |
| `GOOGLE_CLIENT_SECRET` | not defined | The client secret for the Google OAuth2 application, used for logging in with a Google account. |
| `OAUTH2_REDIRECT_URI` | `http://localhost:3000/oauth2/{provider}/callback` | The redirect URI to use for OAuth2 logins. A `{provider}` placeholder can be used to have the provider name inserted dynamically. |
## Authentication
Expand Down
4 changes: 4 additions & 0 deletions api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ def __repr__(self): # pragma: no cover
def url(self):
return url_for('users.get', id=self.id)

@property
def has_password(self):
return self.password_hash is not None

@property
def avatar_url(self):
digest = md5(self.email.lower().encode('utf-8')).hexdigest()
Expand Down
6 changes: 6 additions & 0 deletions api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class Meta:
validate.Email()])
password = ma.String(required=True, load_only=True,
validate=validate.Length(min=3))
has_password = ma.Boolean(dump_only=True)
avatar_url = ma.String(dump_only=True)
about_me = ma.auto_field()
first_seen = ma.auto_field(dump_only=True)
Expand Down Expand Up @@ -154,3 +155,8 @@ class Meta:

token = ma.String(required=True)
new_password = ma.String(required=True, validate=validate.Length(min=3))


class OAuth2Schema(ma.Schema):
code = ma.String(required=True)
state = ma.String(required=True)
79 changes: 77 additions & 2 deletions api/tokens.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from flask import Blueprint, request, abort, current_app, url_for
import secrets
from urllib.parse import urlencode

from flask import Blueprint, request, abort, current_app, url_for, session
from werkzeug.http import dump_cookie
from apifairy import authenticate, body, response, other_responses
import requests

from api.app import db
from api.auth import basic_auth
from api.email import send_email
from api.models import User, Token
from api.schemas import TokenSchema, PasswordResetRequestSchema, \
PasswordResetSchema, EmptySchema
PasswordResetSchema, OAuth2Schema, EmptySchema

tokens = Blueprint('tokens', __name__)
token_schema = TokenSchema()
oauth2_schema = OAuth2Schema()


def token_response(token):
Expand Down Expand Up @@ -119,3 +124,73 @@ def password_reset(args):
user.password = args['new_password']
db.session.commit()
return {}


@tokens.route('/tokens/oauth2/<provider>', methods=['GET'])
@response(EmptySchema, status_code=302,
description="Redirect to OAuth2 provider's authentication page")
@other_responses({404: 'Unknown OAuth2 provider'})
def oauth2_authorize(provider):
"""Initiate OAuth2 authentication with a third-party provider"""
provider_data = current_app.config['OAUTH2_PROVIDERS'].get(provider)
if provider_data is None:
abort(404)
session['oauth2_state'] = secrets.token_urlsafe(16)
qs = urlencode({
'client_id': provider_data['client_id'],
'redirect_uri': current_app.config['OAUTH2_REDIRECT_URI'].format(
provider=provider),
'response_type': 'code',
'scope': ' '.join(provider_data['scopes']),
'state': session['oauth2_state'],
})
return {}, 302, {'Location': provider_data['authorize_url'] + '?' + qs}


@tokens.route('/tokens/oauth2/<provider>', methods=['POST'])
@body(oauth2_schema)
@response(token_schema)
@other_responses({401: 'Invalid code or state',
404: 'Unknown OAuth2 provider'})
def oauth2_new(args, provider):
"""Create new access and refresh tokens with OAuth2 authentication
The refresh token is returned in the body of the request or as a hardened
cookie, depending on configuration. A cookie should be used when the
client is running in an insecure environment such as a web browser, and
cannot adequately protect the refresh token against unauthorized access.
"""
provider_data = current_app.config['OAUTH2_PROVIDERS'].get(provider)
if provider_data is None:
abort(404)
if args['state'] != session.get('oauth2_state'):
abort(401)
response = requests.post(provider_data['access_token_url'], data={
'client_id': provider_data['client_id'],
'client_secret': provider_data['client_secret'],
'code': args['code'],
'grant_type': 'authorization_code',
'redirect_uri': current_app.config['OAUTH2_REDIRECT_URI'].format(
provider=provider),
}, headers={'Accept': 'application/json'})
if response.status_code != 200:
abort(401)
oauth2_token = response.json().get('access_token')
if not oauth2_token:
abort(401)
response = requests.get(provider_data['get_user']['url'], headers={
'Authorization': 'Bearer ' + oauth2_token,
'Accept': 'application/json',
})
if response.status_code != 200:
abort(401)
email = provider_data['get_user']['email'](response.json())
user = db.session.scalar(User.select().where(User.email == email))
if user is None:
user = User(email=email, username=email.split('@')[0])
db.session.add(user)
token = user.generate_auth_token()
db.session.add(token)
Token.clean() # keep token table clean of old tokens
db.session.commit()
return token_response(token)
29 changes: 29 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,35 @@ class Config:
'http://localhost:3000/reset'
USE_CORS = as_bool(os.environ.get('USE_CORS') or 'yes')
CORS_SUPPORTS_CREDENTIALS = True
OAUTH2_PROVIDERS = {
# https://developers.google.com/identity/protocols/oauth2/web-server
# #httprest
'google': {
'client_id': os.environ.get('GOOGLE_CLIENT_ID'),
'client_secret': os.environ.get('GOOGLE_CLIENT_SECRET'),
'authorize_url': 'https://accounts.google.com/o/oauth2/auth',
'access_token_url': 'https://accounts.google.com/o/oauth2/token',
'get_user': {
'url': 'https://www.googleapis.com/oauth2/v3/userinfo',
'email': lambda json: json['email'],
},
'scopes': ['https://www.googleapis.com/auth/userinfo.email'],
},
# https://docs.github.com/en/apps/oauth-apps/building-oauth-apps
# /authorizing-oauth-apps
'github': {
'client_id': os.environ.get('GITHUB_CLIENT_ID'),
'client_secret': os.environ.get('GITHUB_CLIENT_SECRET'),
'authorize_url': 'https://github.com/login/oauth/authorize',
'access_token_url': 'https://github.com/login/oauth/access_token',
'get_user': {
'url': 'https://api.github.com/user/emails',
'email': lambda json: json[0]['email'],
},
'scopes': ['user:email'],
},
}
OAUTH2_REDIRECT_URI = 'http://localhost:3000/oauth2/{provider}/callback'

# API documentation
APIFAIRY_TITLE = 'Microblog API'
Expand Down
1 change: 1 addition & 0 deletions requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ gunicorn
marshmallow-sqlalchemy
pyjwt
python-dotenv
requests
10 changes: 10 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ apispec==6.0.2
# via apifairy
blinker==1.5
# via flask-mail
certifi==2023.5.7
# via requests
charset-normalizer==3.1.0
# via requests
click==8.1.3
# via flask
faker==16.8.1
Expand Down Expand Up @@ -48,6 +52,8 @@ greenlet==2.0.2
# via sqlalchemy
gunicorn==20.1.0
# via -r requirements.in
idna==3.4
# via requests
itsdangerous==2.1.2
# via flask
jinja2==3.1.2
Expand Down Expand Up @@ -78,6 +84,8 @@ python-dateutil==2.8.2
# via faker
python-dotenv==0.21.1
# via -r requirements.in
requests==2.31.0
# via -r requirements.in
six==1.16.0
# via
# flask-cors
Expand All @@ -91,6 +99,8 @@ sqlalchemy==2.0.3
# marshmallow-sqlalchemy
typing-extensions==4.4.0
# via sqlalchemy
urllib3==2.0.2
# via requests
webargs==8.2.0
# via apifairy
werkzeug==2.2.2
Expand Down
14 changes: 14 additions & 0 deletions tests/base_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@ class TestConfig(Config):
class TestConfigWithAuth(TestConfig):
DISABLE_AUTH = False
REFRESH_TOKEN_IN_BODY = True
OAUTH2_PROVIDERS = {
'foo': {
'client_id': 'foo-id',
'client_secret': 'foo-secret',
'authorize_url': 'https://foo.com/login',
'access_token_url': 'https://foo.com/token',
'get_user': {
'url': 'https://foo.com/me',
'email': lambda json: json['email'],
},
'scopes': ['user', 'email'],
},
}
OAUTH2_REDIRECT_URI = 'http://localhost/oauth2/{provider}/callback'


class BaseTestCase(unittest.TestCase):
Expand Down
151 changes: 151 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,154 @@ def test_reset_password(self):

rv = self.client.post('/api/tokens', auth=('test', 'bar'))
assert rv.status_code == 200

def test_oauth(self):
rv = self.client.get('/api/tokens/oauth2/bar')
assert rv.status_code == 404
rv = self.client.get('/api/tokens/oauth2/foo')
assert rv.status_code == 302
assert rv.headers['Location'].startswith('https://foo.com/login?')
args = rv.headers['Location'].split('?')[1].split('&')
assert 'client_id=foo-id' in args
assert ('redirect_uri='
'http%3A%2F%2Flocalhost%2Foauth2%2Ffoo%2Fcallback') in args
assert 'response_type=code' in args
assert 'scope=user+email' in args
state = None
for arg in args:
if arg.startswith('state='):
state = arg.split('=')[1]
assert state is not None

# redirect to auth provider
rv = self.client.post('/api/tokens/oauth2/bar',
json={'code': '123', 'state': state})
assert rv.status_code == 404
rv = self.client.post('/api/tokens/oauth2/foo',
json={'code': '123', 'state': 'not-the-state'})
assert rv.status_code == 401
with mock.patch('api.tokens.requests.post') as requests_post:
requests_post.return_value.status_code = 401
rv = self.client.post('/api/tokens/oauth2/foo',
json={'code': '123', 'state': state})
assert rv.status_code == 401
requests_post.assert_called_with(
'https://foo.com/token', data={
'client_id': 'foo-id',
'client_secret': 'foo-secret',
'code': '123',
'grant_type': 'authorization_code',
'redirect_uri': 'http://localhost/oauth2/foo/callback',
}, headers={'Accept': 'application/json'})

# auth with authorization code (failure case)
with mock.patch('api.tokens.requests.post') as requests_post:
with mock.patch('api.tokens.requests.get') as requests_get:
requests_post.return_value.status_code = 200
requests_post.return_value.json.return_value = {
'access_token': 'foo-token',
}
requests_get.return_value.status_code = 401
rv = self.client.post('/api/tokens/oauth2/foo',
json={'code': '123', 'state': state})
assert rv.status_code == 401
requests_post.assert_called_with(
'https://foo.com/token', data={
'client_id': 'foo-id',
'client_secret': 'foo-secret',
'code': '123',
'grant_type': 'authorization_code',
'redirect_uri': 'http://localhost/oauth2/foo/callback',
}, headers={'Accept': 'application/json'})

# auth with authorization code (failure case)
with mock.patch('api.tokens.requests.post') as requests_post:
with mock.patch('api.tokens.requests.get') as requests_get:
requests_post.return_value.status_code = 200
requests_post.return_value.json.return_value = {
'not_access_token': 'foo-token',
}
requests_get.return_value.status_code = 200
rv = self.client.post('/api/tokens/oauth2/foo',
json={'code': '123', 'state': state})
assert rv.status_code == 401
requests_post.assert_called_with(
'https://foo.com/token', data={
'client_id': 'foo-id',
'client_secret': 'foo-secret',
'code': '123',
'grant_type': 'authorization_code',
'redirect_uri': 'http://localhost/oauth2/foo/callback',
}, headers={'Accept': 'application/json'})

# auth with authorization code (success case with new user)
with mock.patch('api.tokens.requests.post') as requests_post:
with mock.patch('api.tokens.requests.get') as requests_get:
requests_post.return_value.status_code = 200
requests_post.return_value.json.return_value = {
'access_token': 'foo-token',
}
requests_get.return_value.status_code = 200
requests_get.return_value.json.return_value = {
'id': 'user-id',
'email': 'foo@foo.com',
}
rv = self.client.post('/api/tokens/oauth2/foo',
json={'code': '123', 'state': state})
assert rv.status_code == 200
requests_post.assert_called_with(
'https://foo.com/token', data={
'client_id': 'foo-id',
'client_secret': 'foo-secret',
'code': '123',
'grant_type': 'authorization_code',
'redirect_uri': 'http://localhost/oauth2/foo/callback',
}, headers={'Accept': 'application/json'})
requests_get.assert_called_with(
'https://foo.com/me', headers={
'Authorization': 'Bearer foo-token',
'Accept': 'application/json',
})
access_token = rv.json['access_token']

# test the access token
rv = self.client.get('/api/me', headers={
'Authorization': f'Bearer {access_token}'})
assert rv.status_code == 200
assert rv.json['username'] == 'foo'

# auth with authorization code (success case with existing user)
with mock.patch('api.tokens.requests.post') as requests_post:
with mock.patch('api.tokens.requests.get') as requests_get:
requests_post.return_value.status_code = 200
requests_post.return_value.json.return_value = {
'access_token': 'foo-token',
}
requests_get.return_value.status_code = 200
requests_get.return_value.json.return_value = {
'id': 'user-id',
'email': 'test@example.com',
}
rv = self.client.post('/api/tokens/oauth2/foo',
json={'code': '123', 'state': state})
assert rv.status_code == 200
requests_post.assert_called_with(
'https://foo.com/token', data={
'client_id': 'foo-id',
'client_secret': 'foo-secret',
'code': '123',
'grant_type': 'authorization_code',
'redirect_uri': 'http://localhost/oauth2/foo/callback',
}, headers={'Accept': 'application/json'})
requests_get.assert_called_with(
'https://foo.com/me', headers={
'Authorization': 'Bearer foo-token',
'Accept': 'application/json',
})
access_token = rv.json['access_token']

# test the access token
rv = self.client.get('/api/me', headers={
'Authorization': f'Bearer {access_token}'})
assert rv.status_code == 200
assert rv.json['username'] == 'test'

0 comments on commit e657d6b

Please sign in to comment.