Skip to content
This repository has been archived by the owner on Mar 7, 2023. It is now read-only.

Commit

Permalink
Rollback
Browse files Browse the repository at this point in the history
  • Loading branch information
jackieleng committed Jul 6, 2016
1 parent 12c27d3 commit a8f1170
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 141 deletions.
110 changes: 27 additions & 83 deletions lizard_auth_server/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,47 +22,6 @@
MIN_LENGTH = 8


class JWTField(forms.CharField):
"""This Field verifies the JWT signature and also verifies if the contents
of the JWT payload contains what we expect.
Note: the JWTField needs the secret key of the JWT to be able to decode
the JWT. Therefore the 'secret_key' field must be set manually. Typically
you can do this in the __init__ method of your Form.
"""
def __init__(self, allowed_keys=(), secret_key='', *args, **kwargs):
"""Constructor
Args:
allowed_keys: (iterable) set of expected keys
secret_key: the secret for decoding the JWT, generally set not
via this argument but manually when used in conjunction with
forms
"""
super(JWTField, self).__init__(*args, **kwargs)
self.allowed_keys = allowed_keys
self.secret_key = secret_key

def clean(self, value):
# Call the CharField cleaning method with its validators
super(JWTField, self).clean(value)
# Do our own cleaning
try:
custom_cleaned = jwt.decode(value, self.secret_key,
algorithms=['HS256'])
# custom_cleaned = jwt.decode(value, verify=False)
except jwt.exceptions.DecodeError:
raise ValidationError("Failed to decode JWT.")
except Exception as e:
raise ValidationError(
"Unknown exception while decoding JWT: %s" % e)

for key in self.allowed_keys:
if key not in custom_cleaned:
raise ValidationError("Missing key in JWT.")
return custom_cleaned


class DecryptForm(forms.Form):
key = forms.CharField(max_length=1024)
message = forms.CharField(max_length=8192)
Expand All @@ -86,54 +45,39 @@ def clean(self):


class JWTDecryptForm(forms.Form):
"""Form to decrypt and verify JWT requests."""
key = forms.CharField(max_length=1024)
message = JWTField(max_length=8192,
allowed_keys=('key', 'domain', 'force_sso_login'))
message = forms.CharField(max_length=8192)

def __init__(self, *args, **kwargs):
"""This init override is necessary for setting the secret key of
the JWTField."""
super(JWTDecryptForm, self).__init__(*args, **kwargs)
self.init_errors = 0
self.init_error_msgs = []
# We'll try to get the SSO key. If this somehow fails, we can just
# return and we'll let the validation fail in the clean method
# due to insufficient data. Plus, for extra safety, the failed
# exception msg is saved and 're-raised' in an ad-hoc manner in the
# clean method so that it can be found in Form.errors.
if 'key' not in self.data:
self.init_errors += 1
self.init_error_msgs.append('No SSO key.')
return
def clean(self):
"""Verifies the JWT from the site and returns the JWT payload.
Note: replaces the original form data with JWT payload, which should
contain a dictionary with the following keys:
['key',
'domain',
'force_sso_login', (this is optional)
]
"""
cleaned_data = super(JWTDecryptForm, self).clean()
if 'key' not in cleaned_data:
raise ValidationError('No SSO key')
try:
self.site = Site.objects.get(sso_key=self.data['key'])
self.site = Site.objects.get(sso_key=cleaned_data['key'])
except Site.DoesNotExist:
self.init_errors += 1
self.init_error_msgs.append('Invalid SSO key.')
return
# The key for decoding the JWTField must be set via this way
self.fields['message'].secret_key = self.site.sso_secret
raise ValidationError('Invalid SSO key')
try:
new_data = jwt.decode(cleaned_data['message'],
self.site.sso_secret,
algorithms=['HS256'])
except jwt.exceptions.DecodeError:
raise ValidationError("Failed to decode JWT.")

def clean(self):
"""Verifies additional stuff and cleans form data. """
super(JWTDecryptForm, self).clean()
if self.init_errors > 0:
raise ValidationError(
"There were errors in the __init__ method of the form: %s" %
' '.join(self.init_error_msgs))
if 'key' not in self.cleaned_data:
raise ValidationError('No SSO key')
# This check is useful because we can check if the key from the GET
# parameter has not been tampered with.
if self.cleaned_data['key'] != self.cleaned_data.get(
'message', {}).get('key'):
# This is useful for verifying if the key of the GET parameter (which
# could be tampered with) is same as the key in the payload.
if cleaned_data['key'] != new_data['key']:
raise ValidationError('Public SSO key does not match signed key')


class JWTLogoutDecryptForm(JWTDecryptForm):
key = forms.CharField(max_length=1024)
message = JWTField(max_length=8192, allowed_keys=('key', 'domain'))
return new_data


def validate_password(cleaned_password):
Expand Down
53 changes: 3 additions & 50 deletions lizard_auth_server/tests/test_forms.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,12 @@
from django.forms import ValidationError
from django.test import TestCase
import jwt

from lizard_auth_server.forms import (
JWTField,
JWTDecryptForm,
)

ALGORITHM = 'HS256'


class TestFormField(TestCase):

def setUp(self):
self.sso_key = "some sso key"
self.secret_key = "a secret"
self.payload = {
"foo": "bar"
}
self.message = jwt.encode(self.payload, self.secret_key,
algorithm=ALGORITHM)
class TestForm(TestCase):

def test_smoke(self):
jwtfield = JWTField()
self.assertTrue(jwtfield is not None)

def test_jwt_field_no_secret_key(self):
"""Test that the JWTField gives an exception when the secret key
isn't set"""
jwtfield = JWTField()
with self.assertRaises(ValidationError):
jwtfield.clean(self.message)

def test_jwt_field_validates(self):
"""JWTField validates with the correct secret key."""
jwtfield = JWTField()
jwtfield.secret_key = self.secret_key
jwtfield.clean(self.message)

def test_jwt_field_wrong_key(self):
"""JWTField doesn't validate with the wrong secret key."""
jwtfield = JWTField()
jwtfield.secret_key = "not the right key"
with self.assertRaises(ValidationError):
jwtfield.clean(self.message)

def test_jwt_field_unallowed_keys(self):
"""JWTField doesn't validate with unknown allowed_keys."""
jwtfield = JWTField(allowed_keys=('unknown_key',))
jwtfield.secret_key = self.secret_key
with self.assertRaises(ValidationError):
jwtfield.clean(self.message)

def test_jwt_field_allowed_keys(self):
"""JWTField validates with correct allowed_keys."""
jwtfield = JWTField(allowed_keys=('foo',))
jwtfield.secret_key = self.secret_key
jwtfield.clean(self.message)
jwtform = JWTDecryptForm()
self.assertTrue(jwtform is not None)
13 changes: 5 additions & 8 deletions lizard_auth_server/views_api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@
from urllib.parse import urljoin, urlparse, urlencode

from django.core.urlresolvers import reverse
from django.http import HttpResponse
from django.http import HttpResponseRedirect
from django.template.context import RequestContext
from django.template.response import TemplateResponse
from django.utils.translation import ugettext as _
import jwt

from lizard_auth_server import forms
Expand All @@ -23,7 +21,6 @@
domain_match,
FormInvalidMixin,
)
from lizard_auth_server.views import ErrorMessageResponse
from lizard_auth_server.models import Profile


Expand All @@ -42,12 +39,12 @@ def get_domain(form):
"""
portal_redirect = form.site.redirect_url
domain = form.cleaned_data.get('message', {}).get('domain', None)
domain = form.cleaned_data.get('domain', None)

# BBB, previously the "next" parameter was used, but django itself also
# uses it, leading to conflicts. IF "next" starts with "http", we use it
# and otherwise we omit it.
next = form.cleaned_data.get('message', {}).get('next', None)
next = form.cleaned_data.get('next', None)
if next:
if next.startswith('http'): # Includes https :-)
domain = next
Expand Down Expand Up @@ -105,7 +102,7 @@ def form_valid(self, form):
if self.request.user.is_authenticated():
return self.form_valid_authenticated()
return self.form_valid_unauthenticated(
form.cleaned_data.get('message', {}).get('force_sso_login', True))
form.cleaned_data.get('force_sso_login', True))

def form_valid_authenticated(self):
"""
Expand Down Expand Up @@ -196,7 +193,7 @@ class LogoutView(FormInvalidMixin, ProcessGetFormView):
"""
View for logging out.
"""
form_class = forms.JWTLogoutDecryptForm
form_class = forms.JWTDecryptForm

def form_valid(self, form):
next_url = reverse('lizard_auth_server.api_v2.logout_redirect')
Expand All @@ -220,7 +217,7 @@ class LogoutRedirectView(FormInvalidMixin, ProcessGetFormView):
"""
View that redirects the user to the logout page of the portal.
"""
form_class = forms.JWTLogoutDecryptForm
form_class = forms.JWTDecryptForm

def form_valid(self, form):
url = urljoin(get_domain(form), 'sso/local_logout/')
Expand Down

0 comments on commit a8f1170

Please sign in to comment.