From 1a5ec5d8cc72964175a2d1723407a0f77c77a9d2 Mon Sep 17 00:00:00 2001 From: Nicholas Cilfone Date: Thu, 28 Apr 2022 11:35:05 -0400 Subject: [PATCH 1/2] added in sum of iterable post hook --- spock/utils.py | 51 ++++++++++++++++++++++++++++++--- tests/base/test_post_hooks.py | 54 ++++++++++++++++++++++++++++++++++- 2 files changed, 100 insertions(+), 5 deletions(-) diff --git a/spock/utils.py b/spock/utils.py index 6ecc01ad..4b3e9ef9 100644 --- a/spock/utils.py +++ b/spock/utils.py @@ -16,6 +16,7 @@ from time import localtime, strftime from typing import Any, Dict, List, Tuple, Type, TypeVar, Union from warnings import warn +from math import isclose import attr import git @@ -54,15 +55,15 @@ def _get_callable_type(): _C = TypeVar("_C", bound=type) -def eq_len(val: List[Union[Tuple, List, None]], allow_optional: bool = True): - """Checks that all values passed in the iterable are of the same length +def _filter_optional(val: List, allow_optional: bool = True): + """Filters an iterable for None values if they are allowed Args: - val: iterable to compare lengths + val: iterable of values that might contain None allow_optional: allows the check to succeed if a given val in the iterable is None Returns: - None + filtered list of values with None values removed Raises: _SpockValueError @@ -76,6 +77,48 @@ def eq_len(val: List[Union[Tuple, List, None]], allow_optional: bool = True): ) elif v is not None: filtered_val.append(v) + return filtered_val + + +def sum_vals(val: List[Union[float, int, None]], sum_val: Union[float, int], allow_optional: bool = True, rel_tol: float = 1E-9, abs_tol: float = 0.0): + """Checks if an iterable of values sums within tolerance to a specified value + + Args: + val: iterable of values to sum + sum_val: sum value to compare against + allow_optional: allows the check to succeed if a given val in the iterable is None + rel_tol: relative tolerance – it is the maximum allowed difference between a and b + abs_tol: the minimum absolute tolerance – useful for comparisons near zero + + Returns: + None + + Raises: + _SpockValueError + + """ + filtered_val = _filter_optional(val, allow_optional) + if not isclose(sum(filtered_val), sum_val, rel_tol=rel_tol, abs_tol=abs_tol): + raise _SpockValueError( + f"Sum of iterable is `{sum(filtered_val)}` which is not equal to specified value `{sum_val}` within given tolerances" + ) + + +def eq_len(val: List[Union[Tuple, List, None]], allow_optional: bool = True): + """Checks that all values passed in the iterable are of the same length + + Args: + val: iterable to compare lengths + allow_optional: allows the check to succeed if a given val in the iterable is None + + Returns: + None + + Raises: + _SpockValueError + + """ + filtered_val = _filter_optional(val, allow_optional) # just do a set comprehension -- iterables shouldn't be that long so pay the O(n) price lens = {len(v) for v in filtered_val} if len(lens) != 1: diff --git a/tests/base/test_post_hooks.py b/tests/base/test_post_hooks.py index bdf55d3d..c5782eb7 100644 --- a/tests/base/test_post_hooks.py +++ b/tests/base/test_post_hooks.py @@ -8,7 +8,7 @@ from spock import spock from spock import SpockBuilder -from spock.utils import within, gt, ge, lt, le, eq_len +from spock.utils import within, gt, ge, lt, le, eq_len, sum_vals from spock.exceptions import _SpockInstantiationError @@ -130,8 +130,60 @@ def __post_hook__(self): eq_len([self.val_1, self.val_2, self.val_3], allow_optional=True) +@spock +class SumNoneFailConfig: + val_1: float = 0.5 + val_2: float = 0.5 + val_3: Optional[float] = None + + def __post_hook__(self): + sum_vals([self.val_1, self.val_2, self.val_3], sum_val=1.0, allow_optional=False) + + +@spock +class SumNoneNotEqualConfig: + val_1: float = 0.5 + val_2: float = 0.5 + val_3: Optional[float] = None + + def __post_hook__(self): + sum_vals([self.val_1, self.val_2, self.val_3], sum_val=0.75) + + class TestPostHooks: + def test_sum_none_fail_config(self, monkeypatch, tmp_path): + """Test serialization/de-serialization""" + with monkeypatch.context() as m: + m.setattr( + sys, + "argv", + [""], + ) + with pytest.raises(_SpockInstantiationError): + config = SpockBuilder( + SumNoneFailConfig, + desc="Test Builder", + ) + config.generate() + + def test_sum_not_equal_config(self, monkeypatch, tmp_path): + """Test serialization/de-serialization""" + with monkeypatch.context() as m: + m.setattr( + sys, + "argv", + [""], + ) + with pytest.raises(_SpockInstantiationError): + config = SpockBuilder( + SumNoneNotEqualConfig, + desc="Test Builder", + ) + config.generate() + + + def test_eq_len_two_len_fail(self, monkeypatch, tmp_path): """Test serialization/de-serialization""" with monkeypatch.context() as m: From 74d8ae679e66945bc8a50756716646982a7d1ee7 Mon Sep 17 00:00:00 2001 From: Nicholas Cilfone Date: Thu, 28 Apr 2022 11:36:01 -0400 Subject: [PATCH 2/2] linted --- spock/utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/spock/utils.py b/spock/utils.py index 4b3e9ef9..f61e96d5 100644 --- a/spock/utils.py +++ b/spock/utils.py @@ -12,11 +12,11 @@ import sys from argparse import _ArgumentGroup from enum import EnumMeta +from math import isclose from pathlib import Path from time import localtime, strftime from typing import Any, Dict, List, Tuple, Type, TypeVar, Union from warnings import warn -from math import isclose import attr import git @@ -80,7 +80,13 @@ def _filter_optional(val: List, allow_optional: bool = True): return filtered_val -def sum_vals(val: List[Union[float, int, None]], sum_val: Union[float, int], allow_optional: bool = True, rel_tol: float = 1E-9, abs_tol: float = 0.0): +def sum_vals( + val: List[Union[float, int, None]], + sum_val: Union[float, int], + allow_optional: bool = True, + rel_tol: float = 1e-9, + abs_tol: float = 0.0, +): """Checks if an iterable of values sums within tolerance to a specified value Args: