Skip to content

Commit

Permalink
Modify JWT Serializer Field Names (#501)
Browse files Browse the repository at this point in the history
* Rename JWTserializer fields access_token and refresh_token to access and refresh

* Modify JWT tests to use 'access' and 'refresh' as key names

* Use key names of 'access' and 'refresh'

* Modify JWT tests to ensure all access and refresh fields are returned

* Include refresh_expiration field in JWT refresh view
  • Loading branch information
Dresdn committed May 7, 2023
1 parent 6012859 commit a174adb
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 28 deletions.
3 changes: 2 additions & 1 deletion dj_rest_auth/jwt_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,10 @@ class RefreshViewWithCookieSupport(TokenRefreshView):
def finalize_response(self, request, response, *args, **kwargs):
if response.status_code == status.HTTP_200_OK and 'access' in response.data:
set_jwt_access_cookie(response, response.data['access'])
response.data['access_token_expiration'] = (timezone.now() + jwt_settings.ACCESS_TOKEN_LIFETIME)
response.data['access_expiration'] = (timezone.now() + jwt_settings.ACCESS_TOKEN_LIFETIME)
if response.status_code == status.HTTP_200_OK and 'refresh' in response.data:
set_jwt_refresh_cookie(response, response.data['refresh'])
response.data['refresh_expiration'] = (timezone.now() + jwt_settings.REFRESH_TOKEN_LIFETIME)
return super().finalize_response(request, response, *args, **kwargs)
return RefreshViewWithCookieSupport

Expand Down
4 changes: 2 additions & 2 deletions dj_rest_auth/registration/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def get_response_data(self, user):
if api_settings.USE_JWT:
data = {
'user': user,
'access_token': self.access_token,
'refresh_token': self.refresh_token,
'access': self.access_token,
'refresh': self.refresh_token,
}
return api_settings.JWT_SERIALIZER(data, context=self.get_serializer_context()).data
elif api_settings.SESSION_LOGIN:
Expand Down
8 changes: 4 additions & 4 deletions dj_rest_auth/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ class JWTSerializer(serializers.Serializer):
"""
Serializer for JWT authentication.
"""
access_token = serializers.CharField()
refresh_token = serializers.CharField()
access = serializers.CharField()
refresh = serializers.CharField()
user = serializers.SerializerMethodField()

def get_user(self, obj):
Expand All @@ -204,8 +204,8 @@ class JWTSerializerWithExpiration(JWTSerializer):
"""
Serializer for JWT authentication with expiration times.
"""
access_token_expiration = serializers.DateTimeField()
refresh_token_expiration = serializers.DateTimeField()
access_expiration = serializers.DateTimeField()
refresh_expiration = serializers.DateTimeField()


class PasswordResetSerializer(serializers.Serializer):
Expand Down
37 changes: 22 additions & 15 deletions dj_rest_auth/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,8 @@ def test_login_jwt(self):
get_user_model().objects.create_user(self.USERNAME, '', self.PASS)

self.post(self.login_url, data=payload, status_code=200)
self.assertEqual('access_token' in self.response.json.keys(), True)
self.token = self.response.json['access_token']
self.assertEqual('access' in self.response.json.keys(), True)
self.token = self.response.json['access']

@modify_settings(INSTALLED_APPS={'remove': ['allauth', 'allauth.account']})
def test_login_by_email(self):
Expand Down Expand Up @@ -278,7 +278,7 @@ def test_password_change_honors_password_validators(self):
self.token = self.response.json['key']
new_password_payload = {"new_password1": 123, "new_password2": 123}
self.post(self.password_change_url, data=new_password_payload, status_code=400)

@override_api_settings(OLD_PASSWORD_FIELD_ENABLED=True)
def test_password_change_with_old_password(self):
login_payload = {
Expand Down Expand Up @@ -439,7 +439,7 @@ def test_user_details_using_jwt(self):
'password': self.PASS,
}
self.post(self.login_url, data=payload, status_code=200)
self.token = self.response.json['access_token']
self.token = self.response.json['access']
self.get(self.user_url, status_code=200)

self.patch(self.user_url, data=self.BASIC_USER_DATA, status_code=200)
Expand Down Expand Up @@ -522,7 +522,7 @@ def test_registration_with_jwt(self):
self.post(self.register_url, data={}, status_code=400)

result = self.post(self.register_url, data=self.REGISTRATION_DATA, status_code=201)
self.assertIn('access_token', result.data)
self.assertIn('access', result.data)
self.assertEqual(get_user_model().objects.all().count(), user_count + 1)

self._login()
Expand Down Expand Up @@ -727,7 +727,7 @@ def test_blacklisting_not_installed(self):
}
get_user_model().objects.create_user(self.USERNAME, '', self.PASS)
resp = self.post(self.login_url, data=payload, status_code=200)
token = resp.data['refresh_token']
token = resp.data['refresh']
resp = self.post(self.logout_url, status=200, data={'refresh': token})
self.assertEqual(resp.status_code, 200)
self.assertEqual(
Expand All @@ -745,7 +745,7 @@ def test_blacklisting(self):
}
get_user_model().objects.create_user(self.USERNAME, '', self.PASS)
resp = self.post(self.login_url, data=payload, status_code=200)
token = resp.data['refresh_token']
token = resp.data['refresh']
# test refresh token not included in request data
self.post(self.logout_url, status_code=401)
# test token is invalid or expired
Expand Down Expand Up @@ -776,8 +776,8 @@ def test_custom_jwt_claims(self):
get_user_model().objects.create_user(self.USERNAME, self.EMAIL, self.PASS)

self.post(self.login_url, data=payload, status_code=200)
self.assertEqual('access_token' in self.response.json.keys(), True)
self.token = self.response.json['access_token']
self.assertEqual('access' in self.response.json.keys(), True)
self.token = self.response.json['access']
claims = decode_jwt(self.token, settings.SECRET_KEY, algorithms='HS256')
self.assertEquals(claims['user_id'], 1)
self.assertEquals(claims['name'], 'person')
Expand Down Expand Up @@ -839,7 +839,7 @@ def test_wo_csrf_enforcement(self):

## TEST WITH JWT AUTH HEADER
jwtclient = APIClient(enforce_csrf_checks=True)
token = resp.data['access_token']
token = resp.data['access']
resp = jwtclient.get('/protected-view/', HTTP_AUTHORIZATION='Bearer ' + token)
self.assertEquals(resp.status_code, 200)
resp = jwtclient.post('/protected-view/', {}, HTTP_AUTHORIZATION='Bearer ' + token)
Expand Down Expand Up @@ -885,7 +885,7 @@ def test_csrf_wo_login_csrf_enforcement(self):

## TEST WITH JWT AUTH HEADER
jwtclient = APIClient(enforce_csrf_checks=True)
token = resp.data['access_token']
token = resp.data['access']
resp = jwtclient.get('/protected-view/')
self.assertEquals(resp.status_code, 403)
resp = jwtclient.get('/protected-view/', HTTP_AUTHORIZATION='Bearer ' + token)
Expand Down Expand Up @@ -1017,8 +1017,8 @@ def test_return_expiration(self):
get_user_model().objects.create_user(self.USERNAME, '', self.PASS)

resp = self.post(self.login_url, data=payload, status_code=200)
self.assertIn('access_token_expiration', resp.data.keys())
self.assertIn('refresh_token_expiration', resp.data.keys())
self.assertIn('access_expiration', resp.data.keys())
self.assertIn('refresh_expiration', resp.data.keys())

@override_api_settings(JWT_AUTH_RETURN_EXPIRATION=True)
@override_api_settings(USE_JWT=True)
Expand Down Expand Up @@ -1053,14 +1053,18 @@ def test_custom_token_refresh_view(self):

get_user_model().objects.create_user(self.USERNAME, '', self.PASS)
resp = self.post(self.login_url, data=payload, status_code=200)
refresh = resp.data.get('refresh_token')
refresh = resp.data.get('refresh')
refresh_resp = self.post(
reverse('token_refresh'),
data=dict(refresh=refresh),
status_code=200,
)
self.assertIn('xxx', refresh_resp.cookies)

# Ensure access keys are provided in response
self.assertIn('access', refresh_resp.data)
self.assertIn('access_expiration', refresh_resp.data)

@override_api_settings(USE_JWT=True)
@override_api_settings(JWT_AUTH_HTTPONLY=False)
def test_rotate_token_refresh_view(self):
Expand All @@ -1075,13 +1079,16 @@ def test_rotate_token_refresh_view(self):
resp = self.post(self.login_url, data=payload)
self.assertEqual(resp.status_code, status.HTTP_200_OK)

refresh = resp.data.get('refresh_token', None)
refresh = resp.data.get('refresh', None)
resp = self.post(
reverse('token_refresh'),
data=dict(refresh=refresh),
)
self.assertEqual(resp.status_code, status.HTTP_200_OK)

# Ensure access keys are provided in response
self.assertIn('refresh', resp.data)
self.assertIn('refresh_expiration', resp.data)

@override_api_settings(TOKEN_MODEL=None)
@modify_settings(INSTALLED_APPS={'remove': ['rest_framework.authtoken']})
Expand Down
2 changes: 1 addition & 1 deletion dj_rest_auth/tests/test_social.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def test_jwt(self):
}

self.post(self.fb_login_url, data=payload, status_code=200)
self.assertIn('access_token', self.response.json.keys())
self.assertIn('access', self.response.json.keys())
self.assertIn('user', self.response.json.keys())

self.assertEqual(get_user_model().objects.all().count(), users_count + 1)
Expand Down
10 changes: 5 additions & 5 deletions dj_rest_auth/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,18 @@ def get_response(self):

data = {
'user': self.user,
'access_token': self.access_token,
'access': self.access_token,
}

if not auth_httponly:
data['refresh_token'] = self.refresh_token
data['refresh'] = self.refresh_token
else:
# Wasnt sure if the serializer needed this
data['refresh_token'] = ""
data['refresh'] = ""

if return_expiration_times:
data['access_token_expiration'] = access_token_expiration
data['refresh_token_expiration'] = refresh_token_expiration
data['access_expiration'] = access_token_expiration
data['refresh_expiration'] = refresh_token_expiration

serializer = serializer_class(
instance=data,
Expand Down

0 comments on commit a174adb

Please sign in to comment.