Skip to content
This repository has been archived by the owner on Nov 5, 2019. It is now read-only.

Commit

Permalink
Merge pull request #452 from jonparrott/flask-util-expired-credentials
Browse files Browse the repository at this point in the history
Fix flask required decorator to redirect on expired credentials.
  • Loading branch information
nathanielmanistaatgoogle committed Mar 11, 2016
2 parents 15c945f + a82146a commit 20dcfe2
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 23 deletions.
9 changes: 8 additions & 1 deletion oauth2client/contrib/flask_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,14 @@ def credentials(self):

def has_credentials(self):
"""Returns True if there are valid credentials for the current user."""
return self.credentials and not self.credentials.invalid
if not self.credentials:
return False
# Is the access token expired? If so, do we have an refresh token?
elif (self.credentials.access_token_expired
and not self.credentials.refresh_token):
return False
else:
return True

@property
def email(self):
Expand Down
92 changes: 70 additions & 22 deletions tests/contrib/test_flask_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

"""Unit tests for the Flask utilities"""

import datetime
import httplib2
import json
import unittest
import unittest2

import flask
import six.moves.http_client as httplib
Expand Down Expand Up @@ -64,7 +65,7 @@ def __call__(self, *args, **kwargs):
return self


class FlaskOAuth2Tests(unittest.TestCase):
class FlaskOAuth2Tests(unittest2.TestCase):

def setUp(self):
self.app = flask.Flask(__name__)
Expand All @@ -81,7 +82,7 @@ def _generate_credentials(self, scopes=None):
'client_idz',
'client_secretz',
'refresh_tokenz',
'3600',
datetime.datetime.utcnow() + datetime.timedelta(seconds=3600),
GOOGLE_TOKEN_URI,
'Test',
id_token={
Expand Down Expand Up @@ -175,13 +176,13 @@ def test_create_flow(self):
with self.app.test_request_context():
flow = self.oauth2._make_flow()
state = json.loads(flow.params['state'])
self.assertTrue('google_oauth2_csrf_token' in flask.session)
self.assertIn('google_oauth2_csrf_token', flask.session)
self.assertEqual(
flask.session['google_oauth2_csrf_token'], state['csrf_token'])
self.assertEqual(flow.client_id, self.oauth2.client_id)
self.assertEqual(flow.client_secret, self.oauth2.client_secret)
self.assertTrue('http' in flow.redirect_uri)
self.assertTrue('oauth2callback' in flow.redirect_uri)
self.assertIn('http', flow.redirect_uri)
self.assertIn('oauth2callback', flow.redirect_uri)

flow = self.oauth2._make_flow(return_url='/return_url')
state = json.loads(flow.params['state'])
Expand All @@ -208,9 +209,9 @@ def test_authorize_view(self):
q = urlparse.parse_qs(location.split('?', 1)[1])
state = json.loads(q['state'][0])

self.assertTrue(GOOGLE_AUTH_URI in location)
self.assertFalse(self.oauth2.client_secret in location)
self.assertTrue(self.oauth2.client_id in q['client_id'])
self.assertIn(GOOGLE_AUTH_URI, location)
self.assertNotIn(self.oauth2.client_secret, location)
self.assertIn(self.oauth2.client_id, q['client_id'])
self.assertEqual(
flask.session['google_oauth2_csrf_token'], state['csrf_token'])
self.assertEqual(state['return_url'], '/')
Expand All @@ -225,7 +226,7 @@ def test_authorize_view(self):
with self.app.test_client() as client:
response = client.get('/oauth2authorize?extra_param=test')
location = response.headers['Location']
self.assertTrue('extra_param=test' in location)
self.assertIn('extra_param=test', location)

def _setup_callback_state(self, client, **kwargs):
with self.app.test_request_context():
Expand Down Expand Up @@ -255,9 +256,9 @@ def test_callback_view(self):
'/oauth2callback?state={0}&code=codez'.format(state))

self.assertEqual(response.status_code, httplib.FOUND)
self.assertTrue('/return_url' in response.headers['Location'])
self.assertTrue(self.oauth2.client_secret in http.body)
self.assertTrue('codez' in http.body)
self.assertIn('/return_url', response.headers['Location'])
self.assertIn(self.oauth2.client_secret, http.body)
self.assertIn('codez', http.body)
self.assertTrue(self.oauth2.storage.put.called)

def test_authorize_callback(self):
Expand All @@ -273,7 +274,7 @@ def test_callback_view_errors(self):

response = client.get('/oauth2callback?state={}&error=something')
self.assertEqual(response.status_code, httplib.BAD_REQUEST)
self.assertTrue('something' in response.data.decode('utf-8'))
self.assertIn('something', response.data.decode('utf-8'))

# CSRF mismatch
with self.app.test_client() as client:
Expand Down Expand Up @@ -352,6 +353,24 @@ def test_with_credentials(self):
self.assertEqual(self.oauth2.email, 'user@example.com')
self.assertTrue(self.oauth2.http())

@mock.patch('oauth2client.client._UTCNOW')
def test_with_expired_credentials(self, utcnow):
utcnow.return_value = datetime.datetime(1990, 5, 29)

credentials = self._generate_credentials()
credentials.token_expiry = datetime.datetime(1990, 5, 28)

# Has a refresh token, so this should be fine.
with self.app.test_request_context():
self.oauth2.storage.put(credentials)
self.assertTrue(self.oauth2.has_credentials())

# Without a refresh token this should return false.
credentials.refresh_token = None
with self.app.test_request_context():
self.oauth2.storage.put(credentials)
self.assertFalse(self.oauth2.has_credentials())

def test_bad_id_token(self):
credentials = self._generate_credentials()
credentials.id_token = {}
Expand All @@ -370,8 +389,8 @@ def index():
with self.app.test_client() as client:
response = client.get('/protected')
self.assertEqual(response.status_code, httplib.FOUND)
self.assertTrue('oauth2authorize' in response.headers['Location'])
self.assertTrue('protected' in response.headers['Location'])
self.assertIn('oauth2authorize', response.headers['Location'])
self.assertIn('protected', response.headers['Location'])

credentials = self._generate_credentials(scopes=self.oauth2.scopes)

Expand All @@ -382,7 +401,36 @@ def index():

response = client.get('/protected')
self.assertEqual(response.status_code, httplib.OK)
self.assertTrue('Hello' in response.data.decode('utf-8'))
self.assertIn('Hello', response.data.decode('utf-8'))

# Expired credentials with refresh token, should allow.
credentials.token_expiry = datetime.datetime(1990, 5, 28)
with mock.patch('oauth2client.client._UTCNOW') as utcnow:
utcnow.return_value = datetime.datetime(1990, 5, 29)

with self.app.test_client() as client:
with client.session_transaction() as session:
session['google_oauth2_credentials'] = (
credentials.to_json())

response = client.get('/protected')
self.assertEqual(response.status_code, httplib.OK)
self.assertIn('Hello', response.data.decode('utf-8'))

# Expired credentials without a refresh token, should redirect.
credentials.refresh_token = None
with mock.patch('oauth2client.client._UTCNOW') as utcnow:
utcnow.return_value = datetime.datetime(1990, 5, 29)

with self.app.test_client() as client:
with client.session_transaction() as session:
session['google_oauth2_credentials'] = (
credentials.to_json())

response = client.get('/protected')
self.assertEqual(response.status_code, httplib.FOUND)
self.assertIn('oauth2authorize', response.headers['Location'])
self.assertIn('protected', response.headers['Location'])

def _create_incremental_auth_app(self):
self.app = flask.Flask(__name__)
Expand Down Expand Up @@ -410,7 +458,7 @@ def test_incremental_auth(self):
# No credentials, should redirect
with self.app.test_client() as client:
response = client.get('/one')
self.assertTrue('one' in response.headers['Location'])
self.assertIn('one', response.headers['Location'])
self.assertEqual(response.status_code, httplib.FOUND)

# Credentials for one. /one should allow, /two should redirect.
Expand All @@ -424,14 +472,14 @@ def test_incremental_auth(self):
self.assertEqual(response.status_code, httplib.OK)

response = client.get('/two')
self.assertTrue('two' in response.headers['Location'])
self.assertIn('two', response.headers['Location'])
self.assertEqual(response.status_code, httplib.FOUND)

# Starting the authorization flow should include the
# include_granted_scopes parameter as well as the scopes.
response = client.get(response.headers['Location'][17:])
q = urlparse.parse_qs(response.headers['Location'].split('?', 1)[1])
self.assertTrue('include_granted_scopes' in q)
self.assertIn('include_granted_scopes', q)
self.assertEqual(
set(q['scope'][0].split(' ')),
set(['one', 'email', 'two', 'three']))
Expand Down Expand Up @@ -483,8 +531,8 @@ def test_delete(self):
self.oauth2.storage.put(self._generate_credentials())
self.oauth2.storage.delete()

self.assertFalse('google_oauth2_credentials' in flask.session)
self.assertNotIn('google_oauth2_credentials', flask.session)


if __name__ == '__main__': # pragma: NO COVER
unittest.main()
unittest2.main()

0 comments on commit 20dcfe2

Please sign in to comment.