Skip to content
This repository has been archived by the owner on Mar 7, 2023. It is now read-only.

Commit

Permalink
Merge 65b67d1 into 12c27d3
Browse files Browse the repository at this point in the history
  • Loading branch information
reinout committed Jul 5, 2016
2 parents 12c27d3 + 65b67d1 commit 3b609f5
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 61 deletions.
66 changes: 22 additions & 44 deletions lizard_auth_server/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,38 +30,40 @@ class JWTField(forms.CharField):
the JWT. Therefore the 'secret_key' field must be set manually. Typically
you can do this in the __init__ method of your Form.
"""
def __init__(self, allowed_keys=(), secret_key='', *args, **kwargs):
def __init__(self, required_keys=(), *args, **kwargs):
"""Constructor
Args:
allowed_keys: (iterable) set of expected keys
required_keys: (iterable) set of expected keys
secret_key: the secret for decoding the JWT, generally set not
via this argument but manually when used in conjunction with
forms
"""
super(JWTField, self).__init__(*args, **kwargs)
self.allowed_keys = allowed_keys
self.secret_key = secret_key
self.required_keys = required_keys

def clean(self, value):
# Call the CharField cleaning method with its validators
super(JWTField, self).clean(value)
value = super(JWTField, self).clean(value)
self.original_value = value
# Do our own cleaning
try:
custom_cleaned = jwt.decode(value, self.secret_key,
algorithms=['HS256'])
# custom_cleaned = jwt.decode(value, verify=False)
except jwt.exceptions.DecodeError:
raise ValidationError("Failed to decode JWT.")
custom_cleaned = jwt.decode(value, verify=False)
except Exception as e:
raise ValidationError(
"Unknown exception while decoding JWT: %s" % e)

for key in self.allowed_keys:
for key in self.required_keys:
if key not in custom_cleaned:
raise ValidationError("Missing key in JWT.")
return custom_cleaned

def verify_signature(self, secret_key):
try:
jwt.decode(self.original_value, secret_key, algorithms=['HS256'])
except jwt.exceptions.DecodeError:
raise ValidationError("Signature of the JWT is wrong.")


class DecryptForm(forms.Form):
key = forms.CharField(max_length=1024)
Expand Down Expand Up @@ -89,51 +91,27 @@ class JWTDecryptForm(forms.Form):
"""Form to decrypt and verify JWT requests."""
key = forms.CharField(max_length=1024)
message = JWTField(max_length=8192,
allowed_keys=('key', 'domain', 'force_sso_login'))

def __init__(self, *args, **kwargs):
"""This init override is necessary for setting the secret key of
the JWTField."""
super(JWTDecryptForm, self).__init__(*args, **kwargs)
self.init_errors = 0
self.init_error_msgs = []
# We'll try to get the SSO key. If this somehow fails, we can just
# return and we'll let the validation fail in the clean method
# due to insufficient data. Plus, for extra safety, the failed
# exception msg is saved and 're-raised' in an ad-hoc manner in the
# clean method so that it can be found in Form.errors.
if 'key' not in self.data:
self.init_errors += 1
self.init_error_msgs.append('No SSO key.')
return
try:
self.site = Site.objects.get(sso_key=self.data['key'])
except Site.DoesNotExist:
self.init_errors += 1
self.init_error_msgs.append('Invalid SSO key.')
return
# The key for decoding the JWTField must be set via this way
self.fields['message'].secret_key = self.site.sso_secret
required_keys=('key', 'domain', 'force_sso_login'))

def clean(self):
"""Verifies additional stuff and cleans form data. """
super(JWTDecryptForm, self).clean()
if self.init_errors > 0:
try:
self.site = Site.objects.get(sso_key=self.cleaned_data['key'])
except Site.DoesNotExist:
raise ValidationError(
"There were errors in the __init__ method of the form: %s" %
' '.join(self.init_error_msgs))
if 'key' not in self.cleaned_data:
raise ValidationError('No SSO key')
"No site found matching key %s" % self.cleaned_data['key'])
self.fields['message'].verify_signature(self.site.sso_secret)

# This check is useful because we can check if the key from the GET
# parameter has not been tampered with.
if self.cleaned_data['key'] != self.cleaned_data.get(
'message', {}).get('key'):
if self.cleaned_data['key'] != self.cleaned_data['message']['key']:
raise ValidationError('Public SSO key does not match signed key')


class JWTLogoutDecryptForm(JWTDecryptForm):
key = forms.CharField(max_length=1024)
message = JWTField(max_length=8192, allowed_keys=('key', 'domain'))
message = JWTField(max_length=8192, required_keys=('key', 'domain'))


def validate_password(cleaned_password):
Expand Down
25 changes: 8 additions & 17 deletions lizard_auth_server/tests/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,6 @@ def test_smoke(self):
jwtfield = JWTField()
self.assertTrue(jwtfield is not None)

def test_jwt_field_no_secret_key(self):
"""Test that the JWTField gives an exception when the secret key
isn't set"""
jwtfield = JWTField()
with self.assertRaises(ValidationError):
jwtfield.clean(self.message)

def test_jwt_field_validates(self):
"""JWTField validates with the correct secret key."""
jwtfield = JWTField()
Expand All @@ -41,19 +34,17 @@ def test_jwt_field_validates(self):
def test_jwt_field_wrong_key(self):
"""JWTField doesn't validate with the wrong secret key."""
jwtfield = JWTField()
jwtfield.secret_key = "not the right key"
jwtfield.clean(self.message)
with self.assertRaises(ValidationError):
jwtfield.clean(self.message)
jwtfield.verify_signature('pietje')

def test_jwt_field_unallowed_keys(self):
"""JWTField doesn't validate with unknown allowed_keys."""
jwtfield = JWTField(allowed_keys=('unknown_key',))
jwtfield.secret_key = self.secret_key
def test_jwt_field_nonrequired_keys(self):
"""JWTField doesn't validate with unknown required_keys."""
jwtfield = JWTField(required_keys=('unknown_key',))
with self.assertRaises(ValidationError):
jwtfield.clean(self.message)

def test_jwt_field_allowed_keys(self):
"""JWTField validates with correct allowed_keys."""
jwtfield = JWTField(allowed_keys=('foo',))
jwtfield.secret_key = self.secret_key
def test_jwt_field_required_keys(self):
"""JWTField validates with correct required_keys."""
jwtfield = JWTField(required_keys=('foo',))
jwtfield.clean(self.message)

0 comments on commit 3b609f5

Please sign in to comment.