Skip to content
This repository has been archived by the owner on Jan 29, 2019. It is now read-only.

Commit

Permalink
Simplify how we fetch CSRF tokens.
Browse files Browse the repository at this point in the history
After running into issues with the template magic we use to fetch the
CSRF token, I decided to just keep things simple and add an explicit
special case for checking where session-csrf stores its token.
  • Loading branch information
Michael Kelly committed Apr 30, 2015
1 parent e331d9b commit 443fddc
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 117 deletions.
3 changes: 1 addition & 2 deletions django_browserid/tests/__init__.py
Expand Up @@ -8,7 +8,6 @@
from django.utils.functional import wraps

from mock import patch
from nose.tools import eq_

from django_browserid.auth import BrowserIDBackend
from django_browserid.base import MockVerifier
Expand Down Expand Up @@ -63,7 +62,7 @@ def inner(*args, **kwargs):

class TestCase(DjangoTestCase):
def assert_json_equals(self, json_str, value):
return eq_(json.loads(smart_text(json_str)), value)
return self.assertEqual(json.loads(smart_text(json_str)), value)

def shortDescription(self):
# Stop nose using the test docstring and instead the test method
Expand Down
81 changes: 39 additions & 42 deletions django_browserid/tests/test_base.py
Expand Up @@ -11,8 +11,6 @@

import requests
from mock import Mock, patch
from nose.plugins.skip import SkipTest
from nose.tools import eq_, ok_

from django_browserid import base
from django_browserid.compat import pybrowserid_found
Expand All @@ -30,7 +28,7 @@ def test_debug_true(self):
run the checks.
"""
request = self.factory.get('/')
ok_(base.sanity_checks(request))
self.assertTrue(base.sanity_checks(request))

@override_settings(DEBUG=False)
def test_debug_false(self):
Expand All @@ -39,7 +37,7 @@ def test_debug_false(self):
run the checks.
"""
request = self.factory.get('/')
ok_(not base.sanity_checks(request))
self.assertTrue(not base.sanity_checks(request))

@override_settings(BROWSERID_DISABLE_SANITY_CHECKS=True)
def test_disable_sanity_checks(self):
Expand All @@ -48,7 +46,7 @@ def test_disable_sanity_checks(self):
checks.
"""
request = self.factory.get('/')
ok_(not base.sanity_checks(request))
self.assertTrue(not base.sanity_checks(request))

@override_settings(BROWSERID_DISABLE_SANITY_CHECKS=False, SESSION_COOKIE_SECURE=True)
def test_sanity_session_cookie(self):
Expand All @@ -60,7 +58,7 @@ def test_sanity_session_cookie(self):
request.is_secure = Mock(return_value=False)
with patch('django_browserid.base.logger.warning') as warning:
base.sanity_checks(request)
ok_(warning.called)
self.assertTrue(warning.called)

@override_settings(BROWSERID_DISABLE_SANITY_CHECKS=False,
MIDDLEWARE_CLASSES=['csp.middleware.CSPMiddleware'])
Expand All @@ -77,31 +75,31 @@ def test_sanity_csp(self, warning):
CSP_SCRIPT_SRC=['https://login.persona.org'],
CSP_FRAME_SRC=['https://login.persona.org']):
base.sanity_checks(request)
ok_(not warning.called)
self.assertTrue(not warning.called)
warning.reset_mock()

# Test fallback to default-src.
with self.settings(CSP_DEFAULT_SRC=['https://login.persona.org'],
CSP_SCRIPT_SRC=[],
CSP_FRAME_SRC=[]):
base.sanity_checks(request)
ok_(not warning.called)
self.assertTrue(not warning.called)
warning.reset_mock()

# Test incorrect csp.
with self.settings(CSP_DEFAULT_SRC=[],
CSP_SCRIPT_SRC=[],
CSP_FRAME_SRC=[]):
base.sanity_checks(request)
ok_(warning.called)
self.assertTrue(warning.called)
warning.reset_mock()

# Test partial incorrectness.
with self.settings(CSP_DEFAULT_SRC=[],
CSP_SCRIPT_SRC=['https://login.persona.org'],
CSP_FRAME_SRC=[]):
base.sanity_checks(request)
ok_(warning.called)
self.assertTrue(warning.called)

@override_settings(BROWSERID_DISABLE_SANITY_CHECKS=False,
MIDDLEWARE_CLASSES=['csp.middleware.CSPMiddleware'])
Expand All @@ -122,19 +120,19 @@ def test_unset_csp(self, warning):
with self.settings(**setting_kwargs):
del settings.CSP_DEFAULT_SRC
base.sanity_checks(request)
ok_(not warning.called)
self.assertTrue(not warning.called)
warning.reset_mock()

with self.settings(**setting_kwargs):
del settings.CSP_FRAME_SRC
base.sanity_checks(request)
ok_(not warning.called)
self.assertTrue(not warning.called)
warning.reset_mock()

with self.settings(**setting_kwargs):
del settings.CSP_SCRIPT_SRC
base.sanity_checks(request)
ok_(not warning.called)
self.assertTrue(not warning.called)
warning.reset_mock()


Expand Down Expand Up @@ -165,7 +163,7 @@ def test_same_origin_found(self):

audiences = ['https://example.com', 'http://testserver']
with self.settings(BROWSERID_AUDIENCES=audiences, DEBUG=False):
eq_(base.get_audience(request), 'http://testserver')
self.assertEqual(base.get_audience(request), 'http://testserver')

def test_no_audience(self):
"""
Expand All @@ -189,7 +187,7 @@ def test_missing_setting_but_in_debug(self):
with patch('django_browserid.base.settings') as settings:
del settings.BROWSERID_AUDIENCES
settings.DEBUG = True
eq_(base.get_audience(request), 'http://testserver')
self.assertEqual(base.get_audience(request), 'http://testserver')

def test_no_audience_but_in_debug(self):
"""
Expand All @@ -200,7 +198,7 @@ def test_no_audience_but_in_debug(self):

# Simulate that no BROWSERID_AUDIENCES has been set
with self.settings(BROWSERID_AUDIENCES=[], DEBUG=True):
eq_(base.get_audience(request), 'http://testserver')
self.assertEqual(base.get_audience(request), 'http://testserver')


class VerificationResultTests(TestCase):
Expand All @@ -210,7 +208,7 @@ def test_getattr_attribute_exists(self):
attribute on the result.
"""
result = base.VerificationResult({'myattr': 'foo'})
eq_(result.myattr, 'foo')
self.assertEqual(result.myattr, 'foo')

def test_getattr_attribute_doesnt_exist(self):
"""
Expand All @@ -236,52 +234,52 @@ def test_expires_invalid_timestamp(self):
the raw string instead.
"""
result = base.VerificationResult({'expires': 'foasdfhas'})
eq_(result.expires, 'foasdfhas')
self.assertEqual(result.expires, 'foasdfhas')

def test_expires_valid_timestamp(self):
"""
If expires contains a valid millisecond timestamp, return a
corresponding datetime.
"""
result = base.VerificationResult({'expires': '1379307128000'})
eq_(datetime(2013, 9, 16, 4, 52, 8), result.expires)
self.assertEqual(datetime(2013, 9, 16, 4, 52, 8), result.expires)

def test_nonzero_failure(self):
"""
If the response status is not 'okay', the result should be
falsy.
"""
ok_(not base.VerificationResult({'status': 'failure'}))
self.assertTrue(not base.VerificationResult({'status': 'failure'}))

def test_nonzero_okay(self):
"""
If the response status is 'okay', the result should be truthy.
"""
ok_(base.VerificationResult({'status': 'okay'}))
self.assertTrue(base.VerificationResult({'status': 'okay'}))

def test_str_success(self):
"""
If the result is successful, include 'Success' and the email in
the string.
"""
result = base.VerificationResult({'status': 'okay', 'email': 'a@example.com'})
eq_(six.text_type(result), '<VerificationResult Success email=a@example.com>')
self.assertEqual(six.text_type(result), '<VerificationResult Success email=a@example.com>')

# If the email is missing, don't include it.
result = base.VerificationResult({'status': 'okay'})
eq_(six.text_type(result), '<VerificationResult Success>')
self.assertEqual(six.text_type(result), '<VerificationResult Success>')

def test_str_failure(self):
"""
If the result is a failure, include 'Failure' in the string.
"""
result = base.VerificationResult({'status': 'failure'})
eq_(six.text_type(result), '<VerificationResult Failure>')
self.assertEqual(six.text_type(result), '<VerificationResult Failure>')

def test_str_unicode(self):
"""Ensure that __str__ can handle unicode values."""
result = base.VerificationResult({'status': 'okay', 'email': six.u('\x80@example.com')})
eq_(six.text_type(result), six.u('<VerificationResult Success email=\x80@example.com>'))
self.assertEqual(six.text_type(result), six.u('<VerificationResult Success email=\x80@example.com>'))


class RemoteVerifierTests(TestCase):
Expand All @@ -302,7 +300,7 @@ class MyVerifier(base.RemoteVerifier):
verifier.verify('asdf', 'http://testserver')

# foo parameter passed with 'bar' value.
eq_(post.call_args[1]['foo'], 'bar')
self.assertEqual(post.call_args[1]['foo'], 'bar')

def test_verify_kwargs(self):
"""
Expand All @@ -316,8 +314,8 @@ def test_verify_kwargs(self):
verifier.verify('asdf', 'http://testserver', foo='bar', baz=5)

# foo parameter passed with 'bar' value.
eq_(post.call_args[1]['data']['foo'], 'bar')
eq_(post.call_args[1]['data']['baz'], 5)
self.assertEqual(post.call_args[1]['data']['foo'], 'bar')
self.assertEqual(post.call_args[1]['data']['baz'], 5)

def test_verify_request_exception(self):
"""
Expand All @@ -332,7 +330,7 @@ def test_verify_request_exception(self):
with self.assertRaises(base.BrowserIDException) as cm:
verifier.verify('asdf', 'http://testserver')

eq_(cm.exception.exc, request_exception)
self.assertEqual(cm.exception.exc, request_exception)

def test_verify_invalid_json(self):
"""
Expand All @@ -345,9 +343,8 @@ def test_verify_invalid_json(self):
response.json.side_effect = ValueError("Couldn't parse json")
post.return_value = response
result = verifier.verify('asdf', 'http://testserver')
ok_(not result)
ok_(result.reason.startswith('Could not parse verifier response'))

self.assertTrue(not result)
self.assertTrue(result.reason.startswith('Could not parse verifier response'))

def test_verify_success(self):
"""
Expand All @@ -362,8 +359,8 @@ def test_verify_success(self):
response.json.return_value = {"status": "okay", "email": "foo@example.com"}
post.return_value = response
result = verifier.verify('asdf', 'http://testserver')
ok_(result)
eq_(result.email, 'foo@example.com')
self.assertTrue(result)
self.assertEqual(result.email, 'foo@example.com')


class MockVerifierTests(TestCase):
Expand All @@ -374,8 +371,8 @@ def test_verify_no_email(self):
"""
verifier = base.MockVerifier(None)
result = verifier.verify('asdf', 'http://testserver')
ok_(not result)
eq_(result.reason, 'No email given to MockVerifier.')
self.assertTrue(not result)
self.assertEqual(result.reason, 'No email given to MockVerifier.')

def test_verify_email(self):
"""
Expand All @@ -384,23 +381,23 @@ def test_verify_email(self):
"""
verifier = base.MockVerifier('a@example.com')
result = verifier.verify('asdf', 'http://testserver')
ok_(result)
eq_(result.audience, 'http://testserver')
eq_(result.email, 'a@example.com')
self.assertTrue(result)
self.assertEqual(result.audience, 'http://testserver')
self.assertEqual(result.email, 'a@example.com')

def test_verify_result_attributes(self):
"""Extra kwargs to the constructor are added to the result."""
verifier = base.MockVerifier('a@example.com', foo='bar', baz=5)
result = verifier.verify('asdf', 'http://testserver')
eq_(result.foo, 'bar')
eq_(result.baz, 5)
self.assertEqual(result.foo, 'bar')
self.assertEqual(result.baz, 5)


class LocalVerifierTests(TestCase):
def setUp(self):
# Skip tests if PyBrowserID is not installed.
if not pybrowserid_found:
raise SkipTest
self.skipTest('PyBrowserID required for test but not installed.')

self.verifier = base.LocalVerifier()

Expand Down
5 changes: 2 additions & 3 deletions django_browserid/tests/test_helpers.py
@@ -1,7 +1,6 @@
from django.utils.functional import lazy

from mock import patch
from nose.tools import eq_

from django_browserid import helpers
from django_browserid.tests import TestCase
Expand All @@ -22,7 +21,7 @@ def test_defaults(self):
with self.settings(BROWSERID_REQUEST_ARGS={'foo': 'bar', 'baz': 1}):
output = helpers.browserid_info()

eq_(output, self.render_to_string.return_value)
self.assertEqual(output, self.render_to_string.return_value)
expected_info = {
'loginUrl': '/browserid/login/',
'logoutUrl': '/browserid/logout/',
Expand All @@ -35,7 +34,7 @@ def test_lazy_request_args(self):
with self.settings(BROWSERID_REQUEST_ARGS=lazy_request_args()):
output = helpers.browserid_info()

eq_(output, self.render_to_string.return_value)
self.assertEqual(output, self.render_to_string.return_value)
expected_info = {
'loginUrl': '/browserid/login/',
'logoutUrl': '/browserid/logout/',
Expand Down
8 changes: 3 additions & 5 deletions django_browserid/tests/test_http.py
@@ -1,5 +1,3 @@
from nose.tools import eq_

from django_browserid.http import JSONResponse
from django_browserid.tests import TestCase

Expand All @@ -8,13 +6,13 @@ class JSONResponseTests(TestCase):
def test_basic(self):
response = JSONResponse({'blah': 'foo', 'bar': 7})
self.assert_json_equals(response.content, {'blah': 'foo', 'bar': 7})
eq_(response.status_code, 200)
self.assertEqual(response.status_code, 200)

response = JSONResponse(['baz', {'biff': False}])
self.assert_json_equals(response.content, ['baz', {'biff': False}])
eq_(response.status_code, 200)
self.assertEqual(response.status_code, 200)

def test_status(self):
response = JSONResponse({'blah': 'foo', 'bar': 7}, status=404)
self.assert_json_equals(response.content, {'blah': 'foo', 'bar': 7})
eq_(response.status_code, 404)
self.assertEqual(response.status_code, 404)
4 changes: 1 addition & 3 deletions django_browserid/tests/test_util.py
Expand Up @@ -5,8 +5,6 @@
from django.utils import six
from django.utils.functional import lazy

from nose.tools import eq_

from django_browserid.tests import TestCase
from django_browserid.util import import_from_setting, LazyEncoder

Expand All @@ -20,7 +18,7 @@ class TestLazyEncoder(TestCase):
def test_lazy(self):
thing = ['foo', lazy_string]
thing_json = json.dumps(thing, cls=LazyEncoder)
eq_('["foo", "blah"]', thing_json)
self.assertEqual('["foo", "blah"]', thing_json)


import_value = 1
Expand Down

0 comments on commit 443fddc

Please sign in to comment.