diff --git a/flask_oauthlib/provider/oauth2.py b/flask_oauthlib/provider/oauth2.py index ea0a0eaa..5bf114e6 100644 --- a/flask_oauthlib/provider/oauth2.py +++ b/flask_oauthlib/provider/oauth2.py @@ -131,11 +131,18 @@ def validate_client_id(self, client_id): if token_generator and not callable(token_generator): token_generator = import_string(token_generator) + refresh_token_generator = self.app.config.get( + 'OAUTH2_PROVIDER_REFRESH_TOKEN_GENERATOR', None + ) + if refresh_token_generator and not callable(refresh_token_generator): + refresh_token_generator = import_string(refresh_token_generator) + if hasattr(self, '_validator'): return Server( self._validator, token_expires_in=expires_in, token_generator=token_generator, + refresh_token_generator=refresh_token_generator, ) if hasattr(self, '_clientgetter') and \ @@ -161,6 +168,7 @@ def validate_client_id(self, client_id): validator, token_expires_in=expires_in, token_generator=token_generator, + refresh_token_generator=refresh_token_generator, ) raise RuntimeError('application not bound to required getters') diff --git a/tests/oauth2/test_oauth2.py b/tests/oauth2/test_oauth2.py index 862cdef7..abd10cc1 100644 --- a/tests/oauth2/test_oauth2.py +++ b/tests/oauth2/test_oauth2.py @@ -389,7 +389,7 @@ class TestTokenGenerator(OAuthSuite): def create_oauth_provider(self, app): - def generator(request, refresh_token=False): + def generator(request): return 'foobar' app.config['OAUTH2_PROVIDER_TOKEN_GENERATOR'] = generator @@ -403,6 +403,28 @@ def test_get_access_token(self): assert data['refresh_token'] == 'foobar' +class TestRefreshTokenGenerator(OAuthSuite): + + def create_oauth_provider(self, app): + + def at_generator(request): + return 'foobar' + + def rt_generator(request): + return 'abracadabra' + + app.config['OAUTH2_PROVIDER_TOKEN_GENERATOR'] = at_generator + app.config['OAUTH2_PROVIDER_REFRESH_TOKEN_GENERATOR'] = rt_generator + return default_provider(app) + + def test_get_access_token(self): + rv = self.client.post(authorize_url, data={'confirm': 'yes'}) + rv = self.client.get(clean_url(rv.location)) + data = json.loads(u(rv.data)) + assert data['access_token'] == 'foobar' + assert data['refresh_token'] == 'abracadabra' + + class TestConfidentialClient(OAuthSuite): def create_oauth_provider(self, app):