Skip to content

Commit

Permalink
Fix tests, simplify logic
Browse files Browse the repository at this point in the history
  • Loading branch information
jdorn committed Apr 9, 2024
1 parent 9c3dff0 commit f5a5d6d
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 38 deletions.
95 changes: 63 additions & 32 deletions growthbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,22 +366,6 @@ class Filter(TypedDict):
hashVersion: int
attribute: str

class StickyBucketServiceInterface:
def get_assignments(self, attributeName: str, attributeValue: str) -> Optional[Dict]:
pass
def save_assignments(self, doc: Dict) -> None:
pass
# By default, just loop through all attributes and call get_assignments
# Override this method in subclasses to perform a multi-query instead
def get_all_assignments(self, attributes: Dict[str, str]) -> Dict[str, Dict]:
docs = {}
for attributeName, attributeValue in attributes.items():
doc = self.get_assignments(attributeName, attributeValue)
if doc:
key = f"{doc['attributeName']}||{doc['attributeValue']}"
docs[key] = doc
return docs

class Experiment(object):
def __init__(
self,
Expand Down Expand Up @@ -777,6 +761,41 @@ def clear(self) -> None:
self.cache.clear()


class AbstractStickyBucketService(ABC):
@abstractmethod
def get_assignments(self, attributeName: str, attributeValue: str) -> Optional[Dict]:
pass

@abstractmethod
def save_assignments(self, doc: Dict) -> None:
pass

def get_key(self, attributeName: str, attributeValue: str) -> str:
return f"{attributeName}||{attributeValue}"

# By default, just loop through all attributes and call get_assignments
# Override this method in subclasses to perform a multi-query instead
def get_all_assignments(self, attributes: Dict[str, str]) -> Dict[str, Dict]:
docs = {}
for attributeName, attributeValue in attributes.items():
doc = self.get_assignments(attributeName, attributeValue)
if doc:
docs[self.get_key(attributeName, attributeValue)] = doc
return docs

class InMemoryStickyBucketService(AbstractStickyBucketService):
def __init__(self) -> None:
self.docs: Dict[str, Dict] = {}

def get_assignments(self, attributeName: str, attributeValue: str) -> Optional[Dict]:
return self.docs.get(self.get_key(attributeName, attributeValue), None)

def save_assignments(self, doc: Dict) -> None:
self.docs[self.get_key(doc["attributeName"], doc["attributeValue"])] = doc

def destroy(self) -> None:
self.docs.clear()

class FeatureRepository(object):
def __init__(self) -> None:
self.cache: AbstractFeatureCache = InMemoryFeatureCache()
Expand Down Expand Up @@ -870,9 +889,8 @@ def __init__(
decryption_key: str = "",
cache_ttl: int = 60,
forced_variations: dict = {},
sticky_bucket_service: StickyBucketServiceInterface = None,
sticky_bucket_service: AbstractStickyBucketService = None,
sticky_bucket_identifier_attributes: List[str] = None,
sticky_bucket_assignment_docs: Dict[str, dict] = {},
# Deprecated args
trackingCallback=None,
qaMode: bool = False,
Expand All @@ -891,10 +909,9 @@ def __init__(
self._cache_ttl = cache_ttl
self.sticky_bucket_identifier_attributes = sticky_bucket_identifier_attributes
self.sticky_bucket_service = sticky_bucket_service
self.sticky_bucket_assignment_docs = sticky_bucket_assignment_docs

if features:
self.setFeatures(features)
self.sticky_bucket_assignment_docs: dict = {}
self._using_derived_sticky_bucket_attributes = not sticky_bucket_identifier_attributes
self._sticky_bucket_attributes: Optional[dict] = None

self._qaMode = qa_mode or qaMode
self._trackingCallback = on_experiment_viewed or trackingCallback
Expand All @@ -909,6 +926,9 @@ def __init__(
self._assigned: Dict[str, Any] = {}
self._subscriptions: Set[Any] = set()

if features:
self.setFeatures(features)

def load_features(self) -> None:
if not self._client_key:
raise ValueError("Must specify `client_key` to refresh features")
Expand All @@ -933,6 +953,7 @@ def set_features(self, features: dict) -> None:
rules=feature.get("rules", []),
defaultValue=feature.get("defaultValue", None),
)
self.refresh_sticky_buckets()

# @deprecated, use get_features
def getFeatures(self) -> Dict[str, Feature]:
Expand All @@ -947,6 +968,7 @@ def setAttributes(self, attributes: dict) -> None:

def set_attributes(self, attributes: dict) -> None:
self._attributes = attributes
self.refresh_sticky_buckets()

# @deprecated, use get_attributes
def getAttributes(self) -> dict:
Expand Down Expand Up @@ -1275,6 +1297,9 @@ def _run(self, experiment: Experiment, featureId: Optional[str] = None) -> Resul
found_sticky_bucket = sticky_bucket.get('variation', 0) >= 0
assigned = sticky_bucket.get('variation', 0)
sticky_bucket_version_is_blocked = sticky_bucket.get('versionIsBlocked', False)

if found_sticky_bucket:
logger.debug("Found sticky bucket for experiment %s, assigning sticky variation %s", experiment.key, assigned)

# Some checks are not needed if we already have a sticky bucket
if not found_sticky_bucket:
Expand Down Expand Up @@ -1495,21 +1520,25 @@ def _getExperimentResult(
stickyBucketUsed=stickyBucketUsed
)

def _derive_sticky_bucket_identifier_attributes(self, data: Dict[str, Feature]) -> List[str]:
def _derive_sticky_bucket_identifier_attributes(self) -> List[str]:
attributes = set()
for key in data:
feature = data[key]
for key in self._features:
feature = self._features[key]
for rule in feature.rules:
if rule.variations:
attributes.add(rule.hashAttribute or "id")
if rule.fallbackAttribute:
attributes.add(rule.fallbackAttribute)
return list(attributes)

def _get_sticky_bucket_attributes(self, data: Dict[str, Feature] = None) -> dict:
def _get_sticky_bucket_attributes(self) -> dict:
attributes: Dict[str, str] = {}
if self._using_derived_sticky_bucket_attributes:
self.sticky_bucket_identifier_attributes = self._derive_sticky_bucket_identifier_attributes()

if not self.sticky_bucket_identifier_attributes:
self.sticky_bucket_identifier_attributes = self._derive_sticky_bucket_identifier_attributes(data or self._features)
return attributes

for attr in self.sticky_bucket_identifier_attributes:
(_, hash_value) = self._getHashValue(attr)
if hash_value:
Expand Down Expand Up @@ -1563,15 +1592,17 @@ def _get_sticky_bucket_variation(self, experiment_key: str, bucket_version: int
def _get_sticky_bucket_experiment_key(self, experiment_key: str, bucket_version: int = 0) -> str:
return experiment_key + "__" + str(bucket_version)

def refresh_sticky_buckets(self, data: Dict[str, Feature] = None) -> None:
def refresh_sticky_buckets(self, force: bool = False) -> None:
if not self.sticky_bucket_service:
return

attributes = self._get_sticky_bucket_attributes(data)
self.sticky_bucket_assignment_docs = self.sticky_bucket_service.get_all_assignments(attributes)
attributes = self._get_sticky_bucket_attributes()
if not force and attributes == self._sticky_bucket_attributes:
logger.debug("Skipping refresh of sticky bucket assignments, no changes")
return

def get_sticky_bucket_assignment_docs(self) -> Dict[str, dict]:
return self.sticky_bucket_assignment_docs
self._sticky_bucket_attributes = attributes
self.sticky_bucket_assignment_docs = self.sticky_bucket_service.get_all_assignments(attributes)

def _generate_sticky_bucket_assignment_doc(self, attribute_name: str, attribute_value: str, assignments: dict):
key = attribute_name + "||" + attribute_value
Expand Down
4 changes: 2 additions & 2 deletions tests/cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -4953,7 +4953,7 @@
{
"attributes": {
"id": "i123",
"anonymousId": "ses123",
"anonymousId": "123",
"foo": "bar",
"country": "USA"
},
Expand All @@ -4965,7 +4965,7 @@
"key": "feature-exp",
"seed": "feature-exp",
"hashAttribute": "id",
"fallbackAttribute": "deviceId",
"fallbackAttribute": "anonymousId",
"hashVersion": 2,
"bucketVersion": 0,
"condition": { "country": "USA" },
Expand Down
104 changes: 100 additions & 4 deletions tests/test_growthbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
GrowthBook,
Experiment,
Feature,
StickyBucketServiceInterface,
InMemoryStickyBucketService,
getBucketRanges,
gbhash,
chooseVariation,
Expand All @@ -22,6 +22,7 @@
)
from time import time
import pytest
from typing import Optional, Dict

logger.setLevel("DEBUG")

Expand Down Expand Up @@ -149,25 +150,31 @@ def test_stickyBucket(stickyBucket_data):
_, ctx, key, expected_result, expected_docs = stickyBucket_data

# Just use the interface directly, which passes and doesn't persist anywhere
ctx['sticky_bucket_service'] = StickyBucketServiceInterface()
service = InMemoryStickyBucketService()
ctx['sticky_bucket_service'] = service

if 'stickyBucketIdentifierAttributes' in ctx:
ctx['sticky_bucket_identifier_attributes'] = ctx['stickyBucketIdentifierAttributes']
ctx.pop('stickyBucketIdentifierAttributes')

if 'stickyBucketAssignmentDocs' in ctx:
ctx['sticky_bucket_assignment_docs'] = ctx['stickyBucketAssignmentDocs']
service.docs = ctx['stickyBucketAssignmentDocs']
ctx.pop('stickyBucketAssignmentDocs')

gb = GrowthBook(**ctx)
res = gb.eval_feature(key)

print(service.docs)

if not res.experimentResult:
assert None == expected_result
else:
assert res.experimentResult.to_dict() == expected_result

assert gb.get_sticky_bucket_assignment_docs() == expected_docs
assert service.docs == expected_docs

service.destroy()
gb.destroy()


def getTrackingMock(gb: GrowthBook):
Expand Down Expand Up @@ -740,3 +747,92 @@ def test_load_features(mocker):

feature_repo.clear_cache()
gb.destroy()

def test_loose_unmarshalling(mocker):
m = mocker.patch.object(feature_repo, "_get")
m.return_value = MockHttpResp(
200, json.dumps({"features": {"feature": {"defaultValue": 5, "rules": [{"force": 3, "unknown": "foo"}], "unknown": "foo"}}, "unknown": "foo"})
)

gb = GrowthBook(api_host="https://cdn.growthbook.io", client_key="sdk-abc123")

assert m.call_count == 0

gb.load_features()
m.assert_called_once_with("https://cdn.growthbook.io/api/features/sdk-abc123")

assert gb.get_features()["feature"].to_dict() == {"defaultValue": 5, "rules": [{"force": 3, "hashVersion": 1}]}

feature_repo.clear_cache()
gb.destroy()

def test_sticky_bucket_service(mocker):
# Start forcing everyone to variation1
features = {
"feature": {
"defaultValue": 5,
"rules": [{
"key": "exp",
"variations": [0, 1],
"weights": [0, 1],
"meta": [
{"key": "control"},
{"key": "variation1"}
]
}]
},
}

service = InMemoryStickyBucketService()
gb = GrowthBook(
sticky_bucket_service=service,
attributes={
"id": "1"
},
features=features
)

assert gb.get_feature_value("feature", -1) == 1
assert service.get_assignments("id", "1") == {
"attributeName": "id",
"attributeValue": "1",
"assignments": {
"exp__0": "variation1"
}
}

logger.debug("Change weights and ensure old user still gets variation")
features["feature"]["rules"][0]["weights"] = [1, 0]
gb.set_features(features)
assert gb.get_feature_value("feature", -1) == 1

logger.debug("New GrowthBook instance should also get variation")
gb2 = GrowthBook(
sticky_bucket_service=service,
attributes={
"id": "1"
},
features=features
)
assert gb2.get_feature_value("feature", -1) == 1
gb2.destroy()

logger.debug("New users should get control")
gb.set_attributes({"id": "2"})
assert gb.get_feature_value("feature", -1) == 0

logger.debug("Bumping bucketVersion, should reset sticky buckets")
gb.set_attributes({"id": "1"})
features["feature"]["rules"][0]["bucketVersion"] = 1
gb.set_features(features)
assert gb.get_feature_value("feature", -1) == 0

assert service.get_assignments("id", "1") == {
"attributeName": "id",
"attributeValue": "1",
"assignments": {
"exp__0": "variation1",
"exp__1": "control"
}
}
gb.destroy()

0 comments on commit f5a5d6d

Please sign in to comment.