diff --git a/docs/api/conditions.md b/docs/api/conditions.md index b7a7f8f..8fabca8 100644 --- a/docs/api/conditions.md +++ b/docs/api/conditions.md @@ -15,7 +15,7 @@ Register a new condition, either as a decorator: ```python from flags import conditions -@conditions.register('path', validator=conditions.validate_path) +@conditions.register('path') def path_condition(path, request=None, **kwargs): return request.path.startswith(path) ``` @@ -26,12 +26,49 @@ Or as a function call: def path_condition(path, request=None, **kwargs): return request.path.startswith(path) -conditions.register('path', fn=path_condition, validator=conditions.validate_path) +conditions.register('path', fn=path_condition) ``` Will raise a `conditions.DuplicateCondition` exception if the condition name is already registered. -A [validator](https://docs.djangoproject.com/en/stable/ref/validators/) can be given to validate the condition's expected value as provided by [the flag sources](../sources/). +A [validator](https://docs.djangoproject.com/en/stable/ref/validators/) can be given to validate the condition's expected value as provided by [the flag sources](../sources/), either as another callable as an argument to the `register` function: + + +```python +from flags import conditions + +def validate_path(value): + if not value.startswith('/'): + raise ValidationError('Enter a valid path') + +@conditions.register('path', validator=validate_path) +def path_condition(path, request=None, **kwargs): + return request.path.startswith(path) +``` + +Or as an attribute on the condition callable: + +```python +from flags import conditions + +class PathCondition: + def __call__(self, path, request=None, **kwargs): + return request.path.startswith(path) + + def validate(self, value): + if not value.startswith('/'): + raise ValidationError('Enter a valid path') + +conditions.register('path', fn=path_condition) +``` + +Validators specified in both ways are available on condition callables as +a `validate` attribute: + +```python +condition = get_condition('path') +condition.validate(value) +``` ## Exceptions diff --git a/flags/checks.py b/flags/checks.py index e68e238..4114da2 100644 --- a/flags/checks.py +++ b/flags/checks.py @@ -25,9 +25,9 @@ def flag_conditions_check(app_configs, **kwargs): id="flags.E001", ) ) - elif condition.validator is not None: + elif condition.fn.validate is not None: try: - condition.validator(condition.value) + condition.fn.validate(condition.value) except ValidationError as e: errors.append( Warning( diff --git a/flags/conditions/__init__.py b/flags/conditions/__init__.py index 6038558..7825296 100644 --- a/flags/conditions/__init__.py +++ b/flags/conditions/__init__.py @@ -13,7 +13,6 @@ from flags.conditions.registry import ( DuplicateCondition, get_condition, - get_condition_validator, get_conditions, register, ) diff --git a/flags/conditions/registry.py b/flags/conditions/registry.py index 87b286d..570ed3d 100644 --- a/flags/conditions/registry.py +++ b/flags/conditions/registry.py @@ -1,7 +1,6 @@ -# These will be maintained by register() as a global dictionary of -# condition_name: function/validator_function +# This will be maintained by register() as the global dictionary of +# condition_name: function _conditions = {} -_validators = {} class DuplicateCondition(ValueError): @@ -9,7 +8,16 @@ class DuplicateCondition(ValueError): def register(condition_name, fn=None, validator=None): - """ Register a condition to test for flag state. Can be decorator. + """ Register a condition to test for flag state. + + This function can be used as a decorator or the condition callable can be + passed as `fn`. + + Validators can be passed as a separate callable, `validator`, or can be an + attribute of the condition callable, fn.validate. If `validator` is + explicitly given, it will override an existing `validate` attribute of the + condition callable. + Conditions can be any callable that takes a value and some number of required arguments (specified in 'requires') that were passed as kwargs when checking the flag state. """ @@ -31,8 +39,13 @@ def decorator(fn): ) ) + # We attach the validator to the callable to allow for both a single source + # of truth for conditions (_conditions) and to allow for validators to be + # defined on a callable class along with their condition. + if validator is not None or not hasattr(fn, "validate"): + fn.validate = validator + _conditions[condition_name] = fn - _validators[condition_name] = validator def get_conditions(): @@ -44,9 +57,3 @@ def get_condition(condition_name): """ Fetch condition checker functions from the registry """ if condition_name in _conditions: return _conditions[condition_name] - - -def get_condition_validator(condition_name): - """ Fetch condition validators from the registry """ - if condition_name in _validators: - return _validators[condition_name] diff --git a/flags/forms.py b/flags/forms.py index 93f1517..d60919c 100644 --- a/flags/forms.py +++ b/flags/forms.py @@ -1,6 +1,6 @@ from django import forms -from flags.conditions import get_condition_validator, get_conditions +from flags.conditions import get_condition, get_conditions from flags.models import FlagState from flags.sources import get_flags @@ -30,13 +30,16 @@ def __init__(self, *args, **kwargs): ] def clean_value(self): - condition = self.cleaned_data.get("condition") + condition_name = self.cleaned_data.get("condition") value = self.cleaned_data.get("value") - - try: - get_condition_validator(condition)(value) - except Exception as e: - raise forms.ValidationError(e) + condition = get_condition(condition_name) + validator = getattr(condition, "validate") + + if validator is not None: + try: + validator(value) + except Exception as e: + raise forms.ValidationError(e) return value diff --git a/flags/sources.py b/flags/sources.py index 4165273..4011f15 100644 --- a/flags/sources.py +++ b/flags/sources.py @@ -5,7 +5,7 @@ from django.conf import settings from django.utils.module_loading import import_string -from flags.conditions import get_condition, get_condition_validator +from flags.conditions import get_condition logger = logging.getLogger(__name__) @@ -18,7 +18,6 @@ def __init__(self, condition, value, required=False): self.condition = condition self.value = value self.fn = get_condition(self.condition) - self.validator = get_condition_validator(self.condition) self.required = required def __eq__(self, other): diff --git a/flags/tests/test_conditions_registry.py b/flags/tests/test_conditions_registry.py index 33deb19..bb32b2e 100644 --- a/flags/tests/test_conditions_registry.py +++ b/flags/tests/test_conditions_registry.py @@ -3,7 +3,6 @@ from flags.conditions.registry import ( DuplicateCondition, _conditions, - _validators, get_condition, register, ) @@ -16,7 +15,7 @@ def test_register_decorator(self): register("decorated", validator=validator)(fn) self.assertIn("decorated", _conditions) self.assertEqual(_conditions["decorated"], fn) - self.assertEqual(_validators["decorated"], validator) + self.assertEqual(_conditions["decorated"].validate, validator) def test_register_fn(self): fn = lambda conditional_value: True @@ -24,7 +23,7 @@ def test_register_fn(self): register("undecorated", fn=fn, validator=validator) self.assertIn("undecorated", _conditions) self.assertEqual(_conditions["undecorated"], fn) - self.assertEqual(_validators["undecorated"], validator) + self.assertEqual(_conditions["undecorated"].validate, validator) def test_register_dup_condition(self): with self.assertRaises(DuplicateCondition):