Skip to content

Commit

Permalink
Merge pull request #24 from gmcquillan/expiration
Browse files Browse the repository at this point in the history
Expiration
  • Loading branch information
gmcquillan committed Jan 7, 2016
2 parents 8d584f9 + 2388219 commit 1620090
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 72 deletions.
14 changes: 14 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,20 @@ This is a fork of Django Ratelimit, to support:

The intention is to remain API compliant with Django Ratelimit.

NOTICE - UPGRADES
=================

If you are upgrading from a version prior to 1.4.1, please upgrade to that version first,
then upgrade to the latest version. There has been a serialization change for
cached count values so that expirations are more precise.

To upgrade from any version prior to 1.4.X:

- First upgrade to 1.4.1. It's backwards compatible with all previous versions, but won't cause a service interruption while you're deploying the latest version of django-brake. ``pip install django-brake==1.4.1``

- After this is fully deployed to all your webservers, then you can safely deploy the latest: ``pip install -U django-brake``


Using Django Brake
==================

Expand Down
2 changes: 1 addition & 1 deletion brake/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = (1, 4, 1)
VERSION = (1, 5, 0)
28 changes: 11 additions & 17 deletions brake/backends/cachebe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import hashlib
import time

from django.core.cache import cache
from django.core.cache.backends.base import BaseCache
Expand All @@ -7,7 +8,6 @@


CACHE_PREFIX = 'rl:'
BASE_CACHE = BaseCache({})
IP_PREFIX = 'ip:'
KEY_TEMPLATE = 'func:%s:%s%s:%s%s'
PERIOD_PREFIX = 'period:'
Expand Down Expand Up @@ -46,30 +46,23 @@ def _keys(self, func_name, request, ip=True, field=None, period=None):
))

return [
BASE_CACHE.make_key(CACHE_PREFIX + k) for k in keys
CACHE_PREFIX + k for k in keys
]

def count(self, func_name, request, ip=True, field=None, period=60):
"""Increment counters for all relevant cache_keys given a request."""
cache_keys = self._keys(func_name, request, ip, field, period)
counters = dict((key, 1) for key in cache_keys)
counters_read = cache.get_many(cache_keys)
for key, value in counters_read.items():
#Handle old values.
counters = dict((key, (1, time.time() + period)) for key in self._keys(
func_name, request, ip, field, period))
counters.update(cache.get_many(counters.keys()))
for key, value in counters.items():
# Handle old values.
if isinstance(value, tuple):
count, _ = value # Futureproofing for expiration values.
count, expiration = value
else:
count = value

expiration = time.time() + period
count += 1

# These changes come from:
# https://github.com/gmcquillan/django-brake/pull/21
# However, to future proof them, we accept the new values, but
# continue to write the old-style values as part of an upgrade path.
counters[key] = count

cache.set_many(counters, timeout=period)
cache.set(key, (count, expiration), timeout=(expiration - time.time()))

def limit(self, func_name, request,
ip=True, field=None, count=5, period=None):
Expand All @@ -87,6 +80,7 @@ def limit(self, func_name, request,
current_count = counters[counter]
if isinstance(current_count, tuple):
current_count = current_count[0]

if current_count > count:
limits.append({
'ratelimited_by': ratelimited_by,
Expand Down
97 changes: 43 additions & 54 deletions brake/tests/tests.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time

import unittest
from django.core.cache import cache
from django.http import HttpResponse
Expand Down Expand Up @@ -54,9 +55,9 @@ def setUpClass(cls):

cls.PERIODS = (60, 3600, 86400)
# Setup the keys used for the ip-specific counters.
cls.IP_TEMPLATE = ':1:rl:func:%s:period:%d:ip:127.0.0.1'
cls.IP_TEMPLATE = 'rl:func:%s:period:%d:ip:127.0.0.1'
# Keys using this template are for form field-specific counters.
cls.FIELD_TEMPLATE = ':1:rl:func:%s:period:%s:field:username:%s'
cls.FIELD_TEMPLATE = 'rl:func:%s:period:%s:field:username:%s'
# Sha1 hash of 'user' used in rate limit related tests:
cls.USERNAME_SHA1_DIGEST = 'efe049ccead779e455e93893366c119d44ddd8b5'
cls.KEYS = MockRLKeys()
Expand All @@ -77,6 +78,14 @@ def setUpClass(cls):
'%s_ip_%d' % (function, period),
cls.IP_TEMPLATE % (function, period)
)
cls.FAKE_LOGIN_CACHE_KEYS = [
cls.KEYS.fake_login_field_60,
cls.KEYS.fake_login_field_3600,
cls.KEYS.fake_login_field_86400,
cls.KEYS.fake_login_ip_60,
cls.KEYS.fake_login_ip_3600,
cls.KEYS.fake_login_ip_86400,
]

def _make_rl_key(self, func_name, period, field_hash):
"""Makes a ratelimit-style memcached key."""
Expand Down Expand Up @@ -143,6 +152,7 @@ def fake_login_use_request_path(request):
"""Used to test use_request_path=True"""
return HttpResponse()


class TestRateLimiting(RateLimitTestCase):

def setUp(self):
Expand All @@ -161,27 +171,33 @@ def test_allow_some_failures(self):

def test_fake_keys_work(self):
"""Ensure our ability to artificially set keys is accurate."""
cache.set(self.KEYS.fake_login_ip_60, 4)
cache.set(self.KEYS.fake_login_field_60, 4)
cache.set(self.KEYS.fake_login_ip_3600, 4)
cache.set(self.KEYS.fake_login_field_3600, 4)
cache.set(self.KEYS.fake_login_ip_86400, 4)
cache.set(self.KEYS.fake_login_field_86400, 4)
for initial_key in self.FAKE_LOGIN_CACHE_KEYS:
cache.set(initial_key, (4, time.time() + 120))

self.client.post(fake_login, self.good_payload)

self.assertEqual(cache.get(self.KEYS.fake_login_ip_60), 5)
self.assertEqual(cache.get(self.KEYS.fake_login_field_60), 5)
self.assertEqual(cache.get(self.KEYS.fake_login_ip_3600), 5)
self.assertEqual(cache.get(self.KEYS.fake_login_field_3600), 5)
self.assertEqual(cache.get(self.KEYS.fake_login_ip_86400), 5)
self.assertEqual(cache.get(self.KEYS.fake_login_field_86400), 5)
for test_key in self.FAKE_LOGIN_CACHE_KEYS:
self.assertEqual(cache.get(test_key)[0], 5)

def test_expiration_ttl_set_correctly(self):
"""Ensure our cache TTLs are set correctly."""
cur_time = int(time.time())
self.client.post(fake_login, self.bad_payload)

for key in self.FAKE_LOGIN_CACHE_KEYS:
# We have to use the default prefix that django cache puts on keys
# because we are reaching into the implementation of our LocMemCache
# implementation.
test_ttl = int(cache._expire_info.get(':1:' + key, 0))
expected_ttl = int(key.split(':')[4]) + cur_time
# within a second
self.assertAlmostEqual(test_ttl, expected_ttl, delta=1)

def test_ratelimit_by_ip_one_minute(self):
"""Block requests after 1 minute limit is exceeded."""
# Set our counter as the threshold for our lowest period
# We're only setting the counter for this remote IP
cache.set(self.KEYS.fake_login_ip_60, 5)
cache.set(self.KEYS.fake_login_ip_60, (5, time.time() + 120))
# Ensure that correct logins still go through.
self.assertFalse(self.client.post(fake_login, self.bad_payload))
# Now this most recent login has exceeded the threshold, we should get
Expand All @@ -196,39 +212,39 @@ def test_ratelimit_by_ip_one_minute(self):

def test_ratelimit_by_field_one_minute(self):
"""Block requests after one minute limit is exceeded for a username."""
cache.set(self.KEYS.fake_login_field_60, 5)
cache.set(self.KEYS.fake_login_field_60, (5, time.time() + 120))
self.assertFalse(self.client.post(fake_login, self.bad_payload))
self.assertRaises(
RateLimitError, self.client.post, fake_login, self.bad_payload
)

def test_ratelimit_one_hour(self):
"""Block requests after 1 hour limit is exceeded."""
cache.set(self.KEYS.fake_login_ip_3600, 10)
cache.set(self.KEYS.fake_login_ip_3600, (10, time.time() + 120))
self.assertFalse(self.client.post(fake_login, self.bad_payload))
self.assertRaises(
RateLimitError, self.client.post, fake_login, self.bad_payload
)

def test_ratelimit_by_field_one_hour(self):
"""Block requests after 1 hour limit is exceeded for a username."""
cache.set(self.KEYS.fake_login_field_3600, 10)
cache.set(self.KEYS.fake_login_field_3600, (10, time.time() + 120))
self.assertFalse(self.client.post(fake_login, self.bad_payload))
self.assertRaises(
RateLimitError, self.client.post, fake_login, self.bad_payload
)

def test_ratelimit_one_day(self):
"""Block requests after 1 hour limit is exceeded."""
cache.set(self.KEYS.fake_login_ip_86400, 20)
cache.set(self.KEYS.fake_login_ip_86400, (20, time.time() + 120))
self.assertFalse(self.client.post(fake_login, self.bad_payload))
self.assertRaises(
RateLimitError, self.client.post, fake_login, self.bad_payload
)

def test_ratelimit_by_field_one_day(self):
"""Block requests after 1 hour limit is exceeded for a username."""
cache.set(self.KEYS.fake_login_field_86400, 20)
cache.set(self.KEYS.fake_login_field_86400, (20, time.time() + 120))
self.assertFalse(self.client.post(fake_login, self.bad_payload))
self.assertRaises(
RateLimitError, self.client.post, fake_login, self.bad_payload
Expand All @@ -238,13 +254,13 @@ def test_smaller_periods_unaffected_by_larger_periods(self):
"""Ensure that counts above a smaller period's threshold."""
# Here we set the cache way above the 1 minute threshold, but for the
# hourly period.
cache.set(self.KEYS.fake_login_ip_86400, 15)
cache.set(self.KEYS.fake_login_ip_86400, (15, time.time() + 120))
# We will not be limited because this doesn't put us over any threshold.
self.assertTrue(self.client.post(fake_login, self.good_payload))

def test_overridden_get_ip_works(self):
"""Test that our MyBrake Class defined in test_settings works."""
cache.set(self.KEYS.fake_login_ip_60, 6)
cache.set(self.KEYS.fake_login_ip_60, (6, time.time() + 120))
# Should trigger a ratelimit, but only from the HTTP_TRUE_CLIENT_IP
# REMOTE_ADDR (the default) isn't in our cache at all.
self.assertRaises(
Expand All @@ -260,44 +276,25 @@ def test_overridden_get_ip_works(self):

def test_status_code(self):
"""Test that our custom status code is returned."""
cache.set(self.KEYS.fake_login_no_exception_ip_60, 20)
cache.set(self.KEYS.fake_login_no_exception_ip_60, (20, time.time() + 120))
result = self.client.post(fake_login_no_exception, self.bad_payload)
# The default is 403, if we see 429, then we know our setting worked.
self.assertEqual(result.status_code, 429)

def test_use_request_path(self):
"""Test use_request_path=True = use request.path instead of view function name in cache key"""
cache.set(self.KEYS.fake_login_path_ip_60, 6)
cache.set(self.KEYS.fake_login_path_ip_60, (6, time.time() + 120))
rl = ratelimit(method='POST', use_request_path=True, rate='5/m', block=True)
result = self.client.post(rl(fake_login_use_request_path), self.bad_payload)
self.assertEqual(result.status_code, 429)

def test_dont_use_request_path(self):
"""Test use_request_path=False for the same view function above"""
cache.set(self.KEYS.fake_login_path_ip_60, 6)
cache.set(self.KEYS.fake_login_path_ip_60, (6, time.time() + 120))
rl = ratelimit(method='POST', use_request_path=False, rate='5/m', block=True)
result = self.client.post(rl(fake_login_use_request_path), self.bad_payload)
self.assertEqual(result.status_code, 200)

def test_accept_new_expiration_value_write_legacy_value(self):
"""Make sure we read new version, but write the legacy format.
This is as necessary use-case for a fault-free upgrade path from
legacy to the improved expiration-base values.
"""
# The new value format will be a tuple of the (count, expiration_time).
cache.set(self.KEYS.fake_login_field_60, (5, time.time() + 120))
# Make another incorrect login attempt
self.assertFalse(self.client.post(fake_login, self.bad_payload))
# Check that our count was not only incremented,
# but is in the legacy format.
self.assertEqual(6, cache.get(self.KEYS.fake_login_field_60))

# Next attempt should be Ratelimited.
self.assertRaises(
RateLimitError, self.client.post, fake_login, self.bad_payload
)

def test_new_counters_are_created(self):
"""Makes sure that we create counters for keys/buckets.
Expand All @@ -308,14 +305,6 @@ def test_new_counters_are_created(self):
self.assertFalse(self.client.post(fake_login, self.bad_payload))
# These are the cache keys that are specified by the decorator
# for this view.
fake_login_cache_keys = [
self.KEYS.fake_login_field_60,
self.KEYS.fake_login_field_3600,
self.KEYS.fake_login_field_86400,
self.KEYS.fake_login_ip_60,
self.KEYS.fake_login_ip_3600,
self.KEYS.fake_login_ip_86400,
]
for key in fake_login_cache_keys:
self.assertEquals(1, cache.get(key))
for key in self.FAKE_LOGIN_CACHE_KEYS:
self.assertTrue(cache.get(key) > 1)

0 comments on commit 1620090

Please sign in to comment.