Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

- [#59](https://github.com/castle/castle-python/pull/59) drop requests min version in ci
- [#56](https://github.com/castle/castle-python/pull/56) drop special ip header behavior
- [#58](https://github.com/castle/castle-python/pull/58) Adds `ip_header` configuration option

### Breaking Changes:

Expand Down
5 changes: 5 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ import and configure the library with your Castle API secret.
# some headers are always scrubbed, for security reasons.
configuration.blacklisted = ['HTTP-X-header']

# Castle needs the original IP of the client, not the IP of your proxy or load balancer.
# If that IP is sent as a header you can configure the SDK to extract it automatically.
# Note that format, it should be prefixed with `HTTP`, capitalized and separated by underscores.
configuration.ip_headers = ["HTTP_X_FORWARDED_FOR"]

Tracking
--------

Expand Down
12 changes: 12 additions & 0 deletions castle/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self):
self.blacklisted = []
self.request_timeout = REQUEST_TIMEOUT
self.failover_strategy = 'allow'
self.ip_headers = []

@property
def api_secret(self):
Expand Down Expand Up @@ -111,6 +112,17 @@ def failover_strategy(self, value):
else:
raise ConfigurationError

@property
def ip_headers(self):
return self.__ip_headers

@ip_headers.setter
def ip_headers(self, value):
if isinstance(value, list):
self.__ip_headers = value
else:
raise ConfigurationError


# pylint: disable=invalid-name
configuration = Configuration()
14 changes: 14 additions & 0 deletions castle/extractors/ip.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,23 @@
from castle.configuration import configuration


class ExtractorsIp(object):
def __init__(self, request):
self.request = request

def call(self):
ip_address = self.get_ip_from_headers()
if ip_address:
return ip_address

if hasattr(self.request, 'ip'):
return self.request.ip

return self.request.environ.get('REMOTE_ADDR')

def get_ip_from_headers(self):
for header in configuration.ip_headers:
value = self.request.environ.get(header)
if value:
return value
return None
12 changes: 12 additions & 0 deletions castle/test/configuration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_default_values(self):
self.assertEqual(config.blacklisted, [])
self.assertEqual(config.request_timeout, 500)
self.assertEqual(config.failover_strategy, 'allow')
self.assertEqual(config.ip_headers, [])

def test_api_secret_setter(self):
config = Configuration()
Expand Down Expand Up @@ -80,3 +81,14 @@ def test_failover_strategy_setter_invalid(self):
config = Configuration()
with self.assertRaises(ConfigurationError):
config.failover_strategy = 'invalid'

def test_ip_headers_setter_valid(self):
config = Configuration()
ip_headers = ['HTTP_X_FORWARDED_FOR']
config.ip_headers = ip_headers
self.assertEqual(config.ip_headers, ip_headers)

def test_ip_headers_setter_invalid(self):
config = Configuration()
with self.assertRaises(ConfigurationError):
config.ip_headers = 'invalid'
33 changes: 33 additions & 0 deletions castle/test/extractors/ip_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from castle.test import unittest, mock
from castle.extractors.ip import ExtractorsIp
from castle.configuration import configuration


def request_ip():
Expand All @@ -22,7 +23,23 @@ def request_with_ip_remote_addr():
return req


def request_with_ip_x_forwarded_for():
req = mock.Mock(spec=['environ'])
req.environ = {'HTTP_X_FORWARDED_FOR': request_ip()}
return req


def request_with_ip_cf_connecting_ip():
req = mock.Mock(spec=['environ'])
req.environ = {'HTTP_CF_CONNECTING_IP': request_ip_next()}
return req


class ExtractorsIpTestCase(unittest.TestCase):
@classmethod
def tearDownClass(cls):
configuration.ip_headers = []

def test_extract_ip(self):
self.assertEqual(ExtractorsIp(request()).call(), request_ip())

Expand All @@ -31,3 +48,19 @@ def test_extract_ip_from_wsgi_request_remote_addr(self):
ExtractorsIp(request_with_ip_remote_addr()).call(),
request_ip()
)

def test_extract_ip_from_wsgi_request_configured_ip_header_first(self):
configuration.ip_headers = ["HTTP_CF_CONNECTING_IP"]
self.assertEqual(
ExtractorsIp(request_with_ip_cf_connecting_ip()).call(),
request_ip_next()
)
configuration.ip_headers = []

def test_extract_ip_from_wsgi_request_configured_ip_header_second(self):
configuration.ip_headers = ["HTTP_CF_CONNECTING_IP", "HTTP_X_FORWARDED_FOR"]
self.assertEqual(
ExtractorsIp(request_with_ip_x_forwarded_for()).call(),
request_ip()
)
configuration.ip_headers = []