Skip to content

Commit

Permalink
Add blacklist view to log out users (#306)
Browse files Browse the repository at this point in the history
* feat: add blacklist view to log out users

* fix: import order

* test: check if blacklisted tokens cannot be used

Co-authored-by: Hodossy, Szabolcs <szabolcs.hodossy@continental-corporation.com>
Co-authored-by: Andrew Chen Wang <60190294+Andrew-Chen-Wang@users.noreply.github.com>
  • Loading branch information
3 people committed Oct 1, 2021
1 parent c9e989e commit 9b06293
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 1 deletion.
12 changes: 12 additions & 0 deletions rest_framework_simplejwt/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,15 @@ def validate(self, attrs):
raise ValidationError("Token is blacklisted")

return {}


class TokenBlacklistSerializer(serializers.Serializer):
refresh = serializers.CharField()

def validate(self, attrs):
refresh = RefreshToken(attrs['refresh'])
try:
refresh.blacklist()
except AttributeError:
pass
return {}
11 changes: 11 additions & 0 deletions rest_framework_simplejwt/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,14 @@ class TokenVerifyView(TokenViewBase):


token_verify = TokenVerifyView.as_view()


class TokenBlacklistView(TokenViewBase):
"""
Takes a token and blacklists it. Must be used with the
`rest_framework_simplejwt.token_blacklist` app installed.
"""
serializer_class = serializers.TokenBlacklistSerializer


token_blacklist = TokenBlacklistView.as_view()
75 changes: 74 additions & 1 deletion tests/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.serializers import (
TokenObtainPairSerializer, TokenObtainSerializer,
TokenBlacklistSerializer, TokenObtainPairSerializer, TokenObtainSerializer,
TokenObtainSlidingSerializer, TokenRefreshSerializer,
TokenRefreshSlidingSerializer, TokenVerifySerializer,
)
Expand Down Expand Up @@ -365,3 +365,76 @@ def test_it_should_return_given_token_if_everything_ok(self):
self.assertTrue(s.is_valid())

self.assertEqual(len(s.validated_data), 0)


class TestTokenBlacklistSerializer(TestCase):
def test_it_should_raise_token_error_if_token_invalid(self):
token = RefreshToken()
del token['exp']

s = TokenBlacklistSerializer(data={'refresh': str(token)})

with self.assertRaises(TokenError) as e:
s.is_valid()

self.assertIn("has no 'exp' claim", e.exception.args[0])

token.set_exp(lifetime=-timedelta(days=1))

s = TokenBlacklistSerializer(data={'refresh': str(token)})

with self.assertRaises(TokenError) as e:
s.is_valid()

self.assertIn('invalid or expired', e.exception.args[0])

def test_it_should_raise_token_error_if_token_has_wrong_type(self):
token = RefreshToken()
token[api_settings.TOKEN_TYPE_CLAIM] = 'wrong_type'

s = TokenBlacklistSerializer(data={'refresh': str(token)})

with self.assertRaises(TokenError) as e:
s.is_valid()

self.assertIn("wrong type", e.exception.args[0])

def test_it_should_return_nothing_if_everything_ok(self):
refresh = RefreshToken()
refresh['test_claim'] = 'arst'

# Serializer validates
s = TokenBlacklistSerializer(data={'refresh': str(refresh)})

now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

with patch('rest_framework_simplejwt.tokens.aware_utcnow') as fake_aware_utcnow:
fake_aware_utcnow.return_value = now
self.assertTrue(s.is_valid())

self.assertDictEqual(s.validated_data, {})

def test_it_should_blacklist_refresh_token_if_everything_ok(self):
self.assertEqual(OutstandingToken.objects.count(), 0)
self.assertEqual(BlacklistedToken.objects.count(), 0)

refresh = RefreshToken()

refresh['test_claim'] = 'arst'

old_jti = refresh['jti']

# Serializer validates
ser = TokenBlacklistSerializer(data={'refresh': str(refresh)})

now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

with patch('rest_framework_simplejwt.tokens.aware_utcnow') as fake_aware_utcnow:
fake_aware_utcnow.return_value = now
self.assertTrue(ser.is_valid())

self.assertEqual(OutstandingToken.objects.count(), 1)
self.assertEqual(BlacklistedToken.objects.count(), 1)

# Assert old refresh token is blacklisted
self.assertEqual(BlacklistedToken.objects.first().token.jti, old_jti)
69 changes: 69 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,3 +335,72 @@ def test_it_should_ignore_token_type(self):
res = self.view_post(data={'token': str(token)})
self.assertEqual(res.status_code, 200)
self.assertEqual(len(res.data), 0)


class TestTokenBlacklistView(APIViewTestCase):
view_name = 'token_blacklist'

def setUp(self):
self.username = 'test_user'
self.password = 'test_password'

self.user = User.objects.create_user(
username=self.username,
password=self.password,
)

def test_fields_missing(self):
res = self.view_post(data={})
self.assertEqual(res.status_code, 400)
self.assertIn('refresh', res.data)

def test_it_should_return_401_if_token_invalid(self):
token = RefreshToken()
del token['exp']

res = self.view_post(data={'refresh': str(token)})
self.assertEqual(res.status_code, 401)
self.assertEqual(res.data['code'], 'token_not_valid')

token.set_exp(lifetime=-timedelta(seconds=1))

res = self.view_post(data={'refresh': str(token)})
self.assertEqual(res.status_code, 401)
self.assertEqual(res.data['code'], 'token_not_valid')

def test_it_should_return_if_everything_ok(self):
refresh = RefreshToken()
refresh['test_claim'] = 'arst'

# View returns 200
now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

with patch('rest_framework_simplejwt.tokens.aware_utcnow') as fake_aware_utcnow:
fake_aware_utcnow.return_value = now

res = self.view_post(data={'refresh': str(refresh)})

self.assertEqual(res.status_code, 200)

self.assertDictEqual(res.data, {})

def test_it_should_return_401_if_token_is_blacklisted(self):
refresh = RefreshToken()
refresh['test_claim'] = 'arst'

# View returns 200
now = aware_utcnow() - api_settings.ACCESS_TOKEN_LIFETIME / 2

with patch('rest_framework_simplejwt.tokens.aware_utcnow') as fake_aware_utcnow:
fake_aware_utcnow.return_value = now

res = self.view_post(data={'refresh': str(refresh)})

self.assertEqual(res.status_code, 200)

self.view_name = 'token_refresh'
res = self.view_post(data={'refresh': str(refresh)})
# make sure other tests are not affected
del self.view_name

self.assertEqual(res.status_code, 401)
2 changes: 2 additions & 0 deletions tests/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@

re_path(r'^token/verify/$', jwt_views.token_verify, name='token_verify'),

re_path(r'^token/blacklist/$', jwt_views.token_blacklist, name='token_blacklist'),

re_path(r'^test-view/$', views.test_view, name='test_view'),
]

0 comments on commit 9b06293

Please sign in to comment.