Skip to content

Commit

Permalink
OAuth2 token refresh support. Closes #492
Browse files Browse the repository at this point in the history
  • Loading branch information
omab committed Dec 11, 2012
1 parent 9afedc8 commit c55abfa
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 11 deletions.
20 changes: 20 additions & 0 deletions doc/use_cases.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,23 @@ Some particular use cases are listed below.

This view just expect the ``access_token`` as a GET parameter and the
backend name in the URL (check django-social-auth URLs).


4. Token refreshing

OAuth2 defines a mechanism to refresh the ``access_token`` once it expired,
not all the providers support it, and some providers implement it in some
way or another. Usually there's a ``refresh_token`` involved (a token that
identifies the user but it's only used to retrieve a new ``access_token``).

To refresh the token on a given social account just call the
``refresh_token()`` in the ``UserSocialAuth`` instance, like this::

user = User.objects.get(...)
social = user.social_auth.filter(provider='google-oauth2')[0]
social.refresh_token()

At the moment just a few backends were tested against token refreshing
(Google OAuth2, Facebook and Stripe), probably others backends also support
it (if they follow the OAuth2 standard) with the default mechanism. Others
don't support it because the token is not supposed to expire.
38 changes: 33 additions & 5 deletions social_auth/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,13 @@ def get_user_id(self, details, response):
"""OAuth providers return an unique user id in response"""
return response[self.ID_KEY]

def extra_data(self, user, uid, response, details):
@classmethod
def extra_data(cls, user, uid, response, details=None):
"""Return access_token and extra defined names to store in
extra_data field"""
data = {'access_token': response.get('access_token', '')}
name = self.name.replace('-', '_').upper()
names = (self.EXTRA_DATA or []) + setting(name + '_EXTRA_DATA', [])
name = cls.name.replace('-', '_').upper()
names = (cls.EXTRA_DATA or []) + setting(name + '_EXTRA_DATA', [])
for entry in names:
if len(entry) == 2:
(name, alias), discard = entry, False
Expand Down Expand Up @@ -719,6 +720,7 @@ class BaseOAuth2(BaseOAuth):
"""
AUTHORIZATION_URL = None
ACCESS_TOKEN_URL = None
REFRESH_TOKEN_URL = None
RESPONSE_TYPE = 'code'
REDIRECT_STATE = True
STATE_PARAMETER = True
Expand Down Expand Up @@ -802,7 +804,8 @@ def auth_complete_params(self, state=None):
'redirect_uri': self.get_redirect_uri(state)
}

def auth_complete_headers(self):
@classmethod
def auth_headers(cls):
return {'Content-Type': 'application/x-www-form-urlencoded',
'Accept': 'application/json'}

Expand All @@ -811,7 +814,7 @@ def auth_complete(self, *args, **kwargs):
self.process_error(self.data)
params = self.auth_complete_params(self.validate_state())
request = Request(self.ACCESS_TOKEN_URL, data=urlencode(params),
headers=self.auth_complete_headers())
headers=self.auth_headers())

try:
response = simplejson.loads(dsa_urlopen(request).read())
Expand All @@ -827,6 +830,31 @@ def auth_complete(self, *args, **kwargs):
return self.do_auth(response['access_token'], response=response,
*args, **kwargs)

@classmethod
def refresh_token_params(cls, token):
client_id, client_secret = cls.get_key_and_secret()
return {
'refresh_token': token,
'grant_type': 'refresh_token',
'client_id': client_id,
'client_secret': client_secret
}

@classmethod
def process_refresh_token_response(cls, response):
return simplejson.loads(response)

@classmethod
def refresh_token(cls, token):
request = Request(
cls.REFRESH_TOKEN_URL or cls.ACCESS_TOKEN_URL,
data=urlencode(cls.refresh_token_params(token)),
headers=cls.auth_headers()
)
return cls.process_refresh_token_response(
dsa_urlopen(request).read()
)

def do_auth(self, access_token, *args, **kwargs):
"""Finish the auth process once the access_token was retrieved"""
data = self.user_data(access_token, *args, **kwargs)
Expand Down
16 changes: 16 additions & 0 deletions social_auth/backends/facebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class FacebookAuth(BaseOAuth2):
RESPONSE_TYPE = None
SCOPE_SEPARATOR = ','
AUTHORIZATION_URL = 'https://www.facebook.com/dialog/oauth'
ACCESS_TOKEN_URL = ACCESS_TOKEN
SETTINGS_KEY_NAME = 'FACEBOOK_APP_ID'
SETTINGS_SECRET_NAME = 'FACEBOOK_API_SECRET'
SCOPE_VAR_NAME = 'FACEBOOK_EXTENDED_PERMISSIONS'
Expand Down Expand Up @@ -133,6 +134,21 @@ def auth_complete(self, *args, **kwargs):
else:
raise AuthException(self)

@classmethod
def process_refresh_token_response(cls, response):
return dict((key, val[0])
for key, val in cgi.parse_qs(response).iteritems())

@classmethod
def refresh_token_params(cls, token):
client_id, client_secret = cls.get_key_and_secret()
return {
'fb_exchange_token': token,
'grant_type': 'fb_exchange_token',
'client_id': client_id,
'client_secret': client_secret
}

def do_auth(self, access_token, expires=None, *args, **kwargs):
data = self.user_data(access_token)

Expand Down
3 changes: 2 additions & 1 deletion social_auth/backends/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ class GoogleOAuth2Backend(GoogleOAuthBackend):
name = 'google-oauth2'
EXTRA_DATA = [
('refresh_token', 'refresh_token', True),
('expires_in', setting('SOCIAL_AUTH_EXPIRATION', 'expires'))
('expires_in', setting('SOCIAL_AUTH_EXPIRATION', 'expires')),
('token_type', 'token_type', True)
]

def get_user_id(self, details, response):
Expand Down
12 changes: 10 additions & 2 deletions social_auth/backends/stripe.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,21 @@ def auth_complete_params(self, state=None):
'code': self.data['code']
}

def auth_complete_headers(self):
client_id, client_secret = self.get_key_and_secret()
@classmethod
def auth_headers(cls):
client_id, client_secret = cls.get_key_and_secret()
return {
'Accept': 'application/json',
'Authorization': 'Bearer %s' % client_secret
}

@classmethod
def refresh_token_params(cls, refresh_token):
return {
'refresh_token': refresh_token,
'grant_type': 'refresh_token'
}


# Backend definition
BACKENDS = {
Expand Down
22 changes: 19 additions & 3 deletions social_auth/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,33 @@ def __unicode__(self):
"""Return associated user unicode representation"""
return u'%s - %s' % (unicode(self.user), self.provider.title())

def get_backend(self):
# Make import here to avoid recursive imports :-/
from social_auth.backends import get_backends
return get_backends().get(self.provider)

@property
def tokens(self):
"""Return access_token stored in extra_data or None"""
# Make import here to avoid recursive imports :-/
from social_auth.backends import get_backends
backend = get_backends().get(self.provider)
backend = self.get_backend()
if backend:
return backend.AUTH_BACKEND.tokens(self)
else:
return {}

def refresh_token(self):
data = self.extra_data
if 'refresh_token' in data or 'access_token' in data:
backend = self.get_backend()
if hasattr(backend, 'refresh_token'):
token = data.get('refresh_token') or data.get('access_token')
response = backend.refresh_token(token)
self.extra_data.update(
backend.AUTH_BACKEND.extra_data(self.user, self.uid,
response)
)
self.save()

def expiration_datetime(self):
"""Return provider session live seconds. Returns a timedelta ready to
use with session.set_expiry().
Expand Down

0 comments on commit c55abfa

Please sign in to comment.