Skip to content

Commit

Permalink
Add set_callback_url function (#239)
Browse files Browse the repository at this point in the history
* Add set_callback_url function

* Add exception handler for set_callback_url
  • Loading branch information
mosi-kha committed Mar 24, 2021
1 parent 553d262 commit 7fb793e
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions dj_rest_auth/registration/serializers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from django.contrib.auth import get_user_model
from django.http import HttpRequest
from django.urls.exceptions import NoReverseMatch
from django.utils.translation import gettext_lazy as _
from requests.exceptions import HTTPError
from rest_framework import serializers
from rest_framework.reverse import reverse

try:
from allauth.account import app_settings as allauth_settings
Expand All @@ -20,6 +22,7 @@ class SocialAccountSerializer(serializers.ModelSerializer):
"""
serialize allauth SocialAccounts for use with a REST API
"""

class Meta:
model = SocialAccount
fields = (
Expand Down Expand Up @@ -57,6 +60,21 @@ def get_social_login(self, adapter, app, token, response):
social_login.token = token
return social_login

def set_callback_url(self, view, adapter_class):
# first set url from view
self.callback_url = getattr(view, 'callback_url', None)
if not self.callback_url:
# auto generate base on adapter and request
try:
self.callback_url = reverse(
viewname=adapter_class.provider_id + '_callback',
request=self._get_request()
)
except NoReverseMatch:
raise serializers.ValidationError(
_("Define callback_url in view")
)

def validate(self, attrs):
view = self.context.get('view')
request = self._get_request()
Expand Down Expand Up @@ -89,13 +107,9 @@ def validate(self, attrs):

# Case 2: We received the authorization code
elif code:
self.callback_url = getattr(view, 'callback_url', None)
self.set_callback_url(view=view, adapter_class=adapter_class)
self.client_class = getattr(view, 'client_class', None)

if not self.callback_url:
raise serializers.ValidationError(
_("Define callback_url in view")
)
if not self.client_class:
raise serializers.ValidationError(
_("Define client_class in view")
Expand Down

0 comments on commit 7fb793e

Please sign in to comment.