diff --git a/django/middleware/csrf.py b/django/middleware/csrf.py index 2602ce88243f0..b5a85795b2452 100644 --- a/django/middleware/csrf.py +++ b/django/middleware/csrf.py @@ -13,6 +13,7 @@ from django.core.urlresolvers import get_callable from django.utils.cache import patch_vary_headers from django.utils.hashcompat import md5_constructor +from django.utils.http import same_origin from django.utils.log import getLogger from django.utils.safestring import mark_safe from django.utils.crypto import constant_time_compare @@ -161,10 +162,9 @@ def process_view(self, request, callback, callback_args, callback_kwargs): ) return self._reject(request, REASON_NO_REFERER) - # The following check ensures that the referer is HTTPS, - # the domains match and the ports match - the same origin policy. + # Note that request.get_host() includes the port good_referer = 'https://%s/' % request.get_host() - if not referer.startswith(good_referer): + if not same_origin(referer, good_referer): reason = REASON_BAD_REFERER % (referer, good_referer) logger.warning('Forbidden (%s): %s' % (reason, request.path), extra={ diff --git a/django/utils/http.py b/django/utils/http.py index ae2dabf7363d2..c93a338c30810 100644 --- a/django/utils/http.py +++ b/django/utils/http.py @@ -3,6 +3,7 @@ import re import sys import urllib +import urlparse from email.Utils import formatdate from django.utils.encoding import smart_str, force_unicode @@ -186,3 +187,20 @@ def quote_etag(etag): """ return '"%s"' % etag.replace('\\', '\\\\').replace('"', '\\"') +if sys.version_info >= (2, 6): + def same_origin(url1, url2): + """ + Checks if two URLs are 'same-origin' + """ + p1, p2 = urlparse.urlparse(url1), urlparse.urlparse(url2) + return (p1.scheme, p1.hostname, p1.port) == (p2.scheme, p2.hostname, p2.port) +else: + # Python 2.4, 2.5 compatibility. This actually works for Python 2.6 and + # above, but the above definition is much more obviously correct and so is + # preferred going forward. + def same_origin(url1, url2): + """ + Checks if two URLs are 'same-origin' + """ + p1, p2 = urlparse.urlparse(url1), urlparse.urlparse(url2) + return p1[0:2] == p2[0:2] diff --git a/tests/regressiontests/csrf_tests/tests.py b/tests/regressiontests/csrf_tests/tests.py index 396f98f419bc1..d75adc917bb02 100644 --- a/tests/regressiontests/csrf_tests/tests.py +++ b/tests/regressiontests/csrf_tests/tests.py @@ -382,3 +382,16 @@ def test_https_good_referer(self): req.META['HTTP_REFERER'] = 'https://www.example.com/somepage' req2 = CsrfViewMiddleware().process_view(req, post_form_view, (), {}) self.assertEqual(None, req2) + + def test_https_good_referer_2(self): + """ + Test that a POST HTTPS request with a good referer is accepted + where the referer contains no trailing slash + """ + # See ticket #15617 + req = self._get_POST_request_with_token() + req._is_secure = True + req.META['HTTP_HOST'] = 'www.example.com' + req.META['HTTP_REFERER'] = 'https://www.example.com' + req2 = CsrfViewMiddleware().process_view(req, post_form_view, (), {}) + self.assertEqual(None, req2) diff --git a/tests/regressiontests/utils/http.py b/tests/regressiontests/utils/http.py new file mode 100644 index 0000000000000..83a4a7f54d7ed --- /dev/null +++ b/tests/regressiontests/utils/http.py @@ -0,0 +1,23 @@ +from django.utils import http +from django.utils import unittest + +class TestUtilsHttp(unittest.TestCase): + + def test_same_origin_true(self): + # Identical + self.assertTrue(http.same_origin('http://foo.com/', 'http://foo.com/')) + # One with trailing slash - see #15617 + self.assertTrue(http.same_origin('http://foo.com', 'http://foo.com/')) + self.assertTrue(http.same_origin('http://foo.com/', 'http://foo.com')) + # With port + self.assertTrue(http.same_origin('https://foo.com:8000', 'https://foo.com:8000/')) + + def test_same_origin_false(self): + # Different scheme + self.assertFalse(http.same_origin('http://foo.com', 'https://foo.com')) + # Different host + self.assertFalse(http.same_origin('http://foo.com', 'http://goo.com')) + # Different host again + self.assertFalse(http.same_origin('http://foo.com', 'http://foo.com.evil.com')) + # Different port + self.assertFalse(http.same_origin('http://foo.com:8000', 'http://foo.com:8001')) diff --git a/tests/regressiontests/utils/tests.py b/tests/regressiontests/utils/tests.py index 6d3bbfa86cc17..5c4c0602e8f12 100644 --- a/tests/regressiontests/utils/tests.py +++ b/tests/regressiontests/utils/tests.py @@ -7,6 +7,7 @@ from module_loading import * from termcolors import * from html import * +from http import * from checksums import * from text import * from simplelazyobject import *