Skip to content

Commit

Permalink
added in sum of iterable post hook
Browse files Browse the repository at this point in the history
  • Loading branch information
ncilfone committed Apr 28, 2022
1 parent 733f987 commit 1a5ec5d
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 5 deletions.
51 changes: 47 additions & 4 deletions spock/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from time import localtime, strftime
from typing import Any, Dict, List, Tuple, Type, TypeVar, Union
from warnings import warn
from math import isclose

import attr
import git
Expand Down Expand Up @@ -54,15 +55,15 @@ def _get_callable_type():
_C = TypeVar("_C", bound=type)


def eq_len(val: List[Union[Tuple, List, None]], allow_optional: bool = True):
"""Checks that all values passed in the iterable are of the same length
def _filter_optional(val: List, allow_optional: bool = True):
"""Filters an iterable for None values if they are allowed
Args:
val: iterable to compare lengths
val: iterable of values that might contain None
allow_optional: allows the check to succeed if a given val in the iterable is None
Returns:
None
filtered list of values with None values removed
Raises:
_SpockValueError
Expand All @@ -76,6 +77,48 @@ def eq_len(val: List[Union[Tuple, List, None]], allow_optional: bool = True):
)
elif v is not None:
filtered_val.append(v)
return filtered_val


def sum_vals(val: List[Union[float, int, None]], sum_val: Union[float, int], allow_optional: bool = True, rel_tol: float = 1E-9, abs_tol: float = 0.0):
"""Checks if an iterable of values sums within tolerance to a specified value
Args:
val: iterable of values to sum
sum_val: sum value to compare against
allow_optional: allows the check to succeed if a given val in the iterable is None
rel_tol: relative tolerance – it is the maximum allowed difference between a and b
abs_tol: the minimum absolute tolerance – useful for comparisons near zero
Returns:
None
Raises:
_SpockValueError
"""
filtered_val = _filter_optional(val, allow_optional)
if not isclose(sum(filtered_val), sum_val, rel_tol=rel_tol, abs_tol=abs_tol):
raise _SpockValueError(
f"Sum of iterable is `{sum(filtered_val)}` which is not equal to specified value `{sum_val}` within given tolerances"
)


def eq_len(val: List[Union[Tuple, List, None]], allow_optional: bool = True):
"""Checks that all values passed in the iterable are of the same length
Args:
val: iterable to compare lengths
allow_optional: allows the check to succeed if a given val in the iterable is None
Returns:
None
Raises:
_SpockValueError
"""
filtered_val = _filter_optional(val, allow_optional)
# just do a set comprehension -- iterables shouldn't be that long so pay the O(n) price
lens = {len(v) for v in filtered_val}
if len(lens) != 1:
Expand Down
54 changes: 53 additions & 1 deletion tests/base/test_post_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from spock import spock
from spock import SpockBuilder
from spock.utils import within, gt, ge, lt, le, eq_len
from spock.utils import within, gt, ge, lt, le, eq_len, sum_vals
from spock.exceptions import _SpockInstantiationError


Expand Down Expand Up @@ -130,8 +130,60 @@ def __post_hook__(self):
eq_len([self.val_1, self.val_2, self.val_3], allow_optional=True)


@spock
class SumNoneFailConfig:
val_1: float = 0.5
val_2: float = 0.5
val_3: Optional[float] = None

def __post_hook__(self):
sum_vals([self.val_1, self.val_2, self.val_3], sum_val=1.0, allow_optional=False)


@spock
class SumNoneNotEqualConfig:
val_1: float = 0.5
val_2: float = 0.5
val_3: Optional[float] = None

def __post_hook__(self):
sum_vals([self.val_1, self.val_2, self.val_3], sum_val=0.75)


class TestPostHooks:

def test_sum_none_fail_config(self, monkeypatch, tmp_path):
"""Test serialization/de-serialization"""
with monkeypatch.context() as m:
m.setattr(
sys,
"argv",
[""],
)
with pytest.raises(_SpockInstantiationError):
config = SpockBuilder(
SumNoneFailConfig,
desc="Test Builder",
)
config.generate()

def test_sum_not_equal_config(self, monkeypatch, tmp_path):
"""Test serialization/de-serialization"""
with monkeypatch.context() as m:
m.setattr(
sys,
"argv",
[""],
)
with pytest.raises(_SpockInstantiationError):
config = SpockBuilder(
SumNoneNotEqualConfig,
desc="Test Builder",
)
config.generate()



def test_eq_len_two_len_fail(self, monkeypatch, tmp_path):
"""Test serialization/de-serialization"""
with monkeypatch.context() as m:
Expand Down

0 comments on commit 1a5ec5d

Please sign in to comment.