diff --git a/social_auth/db/base.py b/social_auth/db/base.py index 95dfe7720..4a2e6fa46 100644 --- a/social_auth/db/base.py +++ b/social_auth/db/base.py @@ -46,16 +46,21 @@ def revoke_token(self, drop_token=True): self.save() def refresh_token(self): - data = self.extra_data - if 'refresh_token' in data or 'access_token' in data: + refresh_token = self.extra_data.get('refresh_token') + if refresh_token: backend = self.get_backend() if hasattr(backend, 'refresh_token'): - token = data.get('refresh_token') or data.get('access_token') - response = backend.refresh_token(token) - self.extra_data.update( - backend.AUTH_BACKEND.extra_data(self.user, self.uid, - response) - ) + response = backend.refresh_token(refresh_token) + new_access_token = response.get('access_token') + # We have not got a new access token, so don't lose the + # existing one. + if not new_access_token: + return + self.extra_data['access_token'] = new_access_token + # New refresh token might be given. + new_refresh_token = response.get('refresh_token') + if new_refresh_token: + self.extra_data['refresh_token'] = new_refresh_token self.save() def expiration_datetime(self):