Skip to content

Commit

Permalink
Merge pull request #3115 from hypothesis/refactor-feature-caching
Browse files Browse the repository at this point in the history
Refactor feature caching
  • Loading branch information
nickstenning committed Mar 21, 2016
2 parents f143363 + bc85c23 commit be5d793
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 43 deletions.
53 changes: 42 additions & 11 deletions h/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,30 +101,64 @@ def __repr__(self):


class Client(object):
"""
Determine if the named feature is enabled for the current request.
If the feature has no override in the database, it will default to
False. Features must be documented, and an UnknownFeatureError will be
thrown if an undocumented feature is interrogated.
"""

def __init__(self, request):
self.request = request

all_ = request.db.query(Feature).filter(
Feature.name.in_(FEATURES.keys())).all()
self._cache = {f.name: f for f in all_}
self._cache = {}

def __call__(self, name):
return self.enabled(name)

def load(self):
"""Loads the feature flag states into the internal cache."""
all_ = self._fetch_features()
features = {f.name: f for f in all_}
self._cache = {n: self._state(features.get(n))
for n in FEATURES.keys()}

def enabled(self, name):
"""
Determine if the named feature is enabled for the current request.
If the feature has no override in the database, it will default to
False. Features must be documented, and an UnknownFeatureError will be
thrown if an undocumented feature is interrogated.
When the internal cache is empty, it will automatically load the
feature flags from the database first.
"""
if name not in FEATURES:
raise UnknownFeatureError(
'{0} is not a valid feature name'.format(name))

feature = self._cache.get(name)
if not self._cache:
self.load()

return self._cache[name]

def all(self):
"""
Returns a dict mapping feature flag names to enabled states
for the user associated with a given request.
When the internal cache is empty, it will automatically load the
feature flags from the database first.
"""
if not self._cache:
self.load()

return self._cache

def clear(self):
self._cache = {}

def _state(self, feature):
# Features that don't exist in the database are off.
if feature is None:
return False
Expand All @@ -141,12 +175,9 @@ def enabled(self, name):
return True
return False

def all(self):
"""
Returns a dict mapping feature flag names to enabled states
for the user associated with a given request.
"""
return {name: self.enabled(name) for name in FEATURES.keys()}
def _fetch_features(self):
return self.request.db.query(Feature).filter(
Feature.name.in_(FEATURES.keys())).all()


def remove_old_flags():
Expand Down
2 changes: 2 additions & 0 deletions h/mailer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def worker(request):
"""
def handle_message(_, message):
"""Receive a message from nsq and send it as an email."""
request.feature.clear()

body = json.loads(message.body)
email = pyramid_mailer.message.Message(
subject=body["subject"], recipients=body["recipients"],
Expand Down
2 changes: 2 additions & 0 deletions h/nipsa/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def worker(request):
"""
def handle_message(_, message):
"""Handle a message on the "nipsa_users_annotations" channel."""
request.feature.clear()

add_or_remove_nipsa(
client=request.es.conn,
index=request.es.index,
Expand Down
2 changes: 2 additions & 0 deletions h/notification/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def run(request):
round-robin fashion.
"""
def handle_message(reader, message=None):
request.feature.clear()

if message is None:
return
with request.tm:
Expand Down
9 changes: 9 additions & 0 deletions h/streamer/test/websocket_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ def test_socket_enqueues_incoming_messages():
assert result.payload == 'client data'


def test_handle_message_clears_feature_cache():
socket = mock.Mock()
message = websocket.Message(socket=socket, payload=json.dumps({
'messageType': 'foo'}))
websocket.handle_message(message)

socket.request.feature.clear.assert_called_with()


def test_handle_message_sets_socket_client_id_for_client_id_messages():
socket = mock.Mock()
socket.client_id = None
Expand Down
1 change: 1 addition & 0 deletions h/streamer/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def closed(self, code, reason=None):

def handle_message(message):
socket = message.socket
socket.request.feature.clear()

data = json.loads(message.payload)

Expand Down
120 changes: 88 additions & 32 deletions h/test/features_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,45 @@ def features_pending_removal_override(request):


class TestClient(object):
def test_init_loads_features(self):
def test_init_stores_the_request(self):
request = DummyRequest()
client = features.Client(request)
assert client.request == request

def test_init_initializes_an_empty_cache(self):
client = features.Client(DummyRequest())
assert client._cache == {}

def test_load_loads_features(self, client):
db.Session.add(features.Feature(name='notification'))
db.Session.flush()

client = self.client()
client.load()
assert client._cache.keys() == ['notification']

def test_init_skips_database_features_missing_from_dict(self):
def test_load_includes_features_not_in_db(self, client):
client.load()
assert client._cache.keys() == ['notification']

def test_load_skips_database_features_missing_from_dict(self, client):
"""
Test that init does not load features that are still in the database
Test that load does not load features that are still in the database
but not in the FEATURES dict anymore
"""
db.Session.add(features.Feature(name='notification'))
db.Session.add(features.Feature(name='new_homepage'))
db.Session.flush()

client = self.client()
assert len(client._cache) == 0
client.load()
assert client._cache.keys() == ['notification']

def test_init_skips_pending_removal_features(self):
def test_load_skips_pending_removal_features(self, client):
db.Session.add(features.Feature(name='notification'))
db.Session.add(features.Feature(name='abouttoberemoved'))
db.Session.flush()

client = self.client()
assert len(client._cache) == 0
client.load()
assert client._cache.keys() == ['notification']

def test_enabled_raises_for_undocumented_feature(self, client):
with pytest.raises(features.UnknownFeatureError):
Expand All @@ -62,53 +77,94 @@ def test_enabled_raises_for_feature_pending_removal(self, client):
with pytest.raises(features.UnknownFeatureError):
client.enabled('abouttoberemoved')

def test_enabled_loads_cache_when_empty(self,
client,
client_load):

def test_load():
client._cache = {'notification': True}
client_load.side_effect = test_load

client._cache = {}
client.enabled('notification')
client_load.assert_called_with()

def test_enabled_false_if_not_in_database(self, client):
assert client.enabled('notification') == False
assert client.enabled('notification') is False

def test_enabled_false_if_everyone_false(self, client):
client._cache['notification'] = features.Feature(everyone=False)
assert client.enabled('notification') == False
def test_enabled_false_if_everyone_false(self, client, fetcher):
fetcher.return_value = [
features.Feature(name='notification', everyone=False)]
assert client.enabled('notification') is False

def test_enabled_true_if_everyone_true(self, client):
client._cache['notification'] = features.Feature(everyone=True)
assert client.enabled('notification') == True
def test_enabled_true_if_everyone_true(self, client, fetcher):
fetcher.return_value = [
features.Feature(name='notification', everyone=True)]
assert client.enabled('notification') is True

def test_enabled_false_when_admins_true_normal_request(self, client):
client._cache['notification'] = features.Feature(admins=True)
assert client.enabled('notification') == False
def test_enabled_false_when_admins_true_normal_request(self,
client,
fetcher):
fetcher.return_value = [
features.Feature(name='notification', admins=True)]
assert client.enabled('notification') is False

def test_enabled_true_when_admins_true_admin_request(self,
client,
fetcher,
authn_policy):
client._cache['notification'] = features.Feature(admins=True)
authn_policy.effective_principals.return_value = [role.Admin]
assert client.enabled('notification') == True
fetcher.return_value = [
features.Feature(name='notification', admins=True)]
assert client.enabled('notification') is True

def test_enabled_false_when_staff_true_normal_request(self,
client,
fetcher):
fetcher.return_value = [
features.Feature(name='notification', staff=True)]

def test_enabled_false_when_staff_true_normal_request(self, client):
client._cache['notification'] = features.Feature(staff=True)
assert client.enabled('notification') == False
assert client.enabled('notification') is False

def test_enabled_true_when_staff_true_staff_request(self,
client,
fetcher,
authn_policy):
client._cache['notification'] = features.Feature(staff=True)
authn_policy.effective_principals.return_value = [role.Staff]
assert client.enabled('notification') == True
fetcher.return_value = [
features.Feature(name='notification', staff=True)]

assert client.enabled('notification') is True

def test_all_checks_enabled(self, client, enabled):
def test_all_loads_cache_when_empty(self, client, client_load):
client._cache = {}
client.all()
enabled.assert_called_with('notification')
client_load.assert_called_with()

def test_all_omits_features_pending_removal(self, client):
assert client.all() == {'notification': False}
def test_all_returns_cache(self, client):
cache = mock.Mock()
client._cache = cache
assert client.all() == cache

def test_clear(self, client):
client._cache = mock.Mock()
client.clear()
assert client._cache == {}

@pytest.fixture
def client(self):
return features.Client(DummyRequest(db=db.Session))

@pytest.fixture
def enabled(self, request, client):
patcher = mock.patch('h.features.Client.enabled')
def client_load(self, request, client):
patcher = mock.patch('h.features.Client.load')
method = patcher.start()
request.addfinalizer(patcher.stop)
return method

@pytest.fixture
def fetcher(self, request, client):
patcher = mock.patch('h.features.Client._fetch_features')
method = patcher.start()
request.addfinalizer(patcher.stop)
return method
Expand Down

0 comments on commit be5d793

Please sign in to comment.