Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sum of iterable post hook #251

Merged
merged 2 commits into from
Apr 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 53 additions & 4 deletions spock/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import sys
from argparse import _ArgumentGroup
from enum import EnumMeta
from math import isclose
from pathlib import Path
from time import localtime, strftime
from typing import Any, Dict, List, Tuple, Type, TypeVar, Union
Expand Down Expand Up @@ -54,15 +55,15 @@ def _get_callable_type():
_C = TypeVar("_C", bound=type)


def eq_len(val: List[Union[Tuple, List, None]], allow_optional: bool = True):
"""Checks that all values passed in the iterable are of the same length
def _filter_optional(val: List, allow_optional: bool = True):
"""Filters an iterable for None values if they are allowed

Args:
val: iterable to compare lengths
val: iterable of values that might contain None
allow_optional: allows the check to succeed if a given val in the iterable is None

Returns:
None
filtered list of values with None values removed

Raises:
_SpockValueError
Expand All @@ -76,6 +77,54 @@ def eq_len(val: List[Union[Tuple, List, None]], allow_optional: bool = True):
)
elif v is not None:
filtered_val.append(v)
return filtered_val


def sum_vals(
val: List[Union[float, int, None]],
sum_val: Union[float, int],
allow_optional: bool = True,
rel_tol: float = 1e-9,
abs_tol: float = 0.0,
):
"""Checks if an iterable of values sums within tolerance to a specified value

Args:
val: iterable of values to sum
sum_val: sum value to compare against
allow_optional: allows the check to succeed if a given val in the iterable is None
rel_tol: relative tolerance – it is the maximum allowed difference between a and b
abs_tol: the minimum absolute tolerance – useful for comparisons near zero

Returns:
None

Raises:
_SpockValueError

"""
filtered_val = _filter_optional(val, allow_optional)
if not isclose(sum(filtered_val), sum_val, rel_tol=rel_tol, abs_tol=abs_tol):
raise _SpockValueError(
f"Sum of iterable is `{sum(filtered_val)}` which is not equal to specified value `{sum_val}` within given tolerances"
)


def eq_len(val: List[Union[Tuple, List, None]], allow_optional: bool = True):
"""Checks that all values passed in the iterable are of the same length

Args:
val: iterable to compare lengths
allow_optional: allows the check to succeed if a given val in the iterable is None

Returns:
None

Raises:
_SpockValueError

"""
filtered_val = _filter_optional(val, allow_optional)
# just do a set comprehension -- iterables shouldn't be that long so pay the O(n) price
lens = {len(v) for v in filtered_val}
if len(lens) != 1:
Expand Down
54 changes: 53 additions & 1 deletion tests/base/test_post_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from spock import spock
from spock import SpockBuilder
from spock.utils import within, gt, ge, lt, le, eq_len
from spock.utils import within, gt, ge, lt, le, eq_len, sum_vals
from spock.exceptions import _SpockInstantiationError


Expand Down Expand Up @@ -130,8 +130,60 @@ def __post_hook__(self):
eq_len([self.val_1, self.val_2, self.val_3], allow_optional=True)


@spock
class SumNoneFailConfig:
val_1: float = 0.5
val_2: float = 0.5
val_3: Optional[float] = None

def __post_hook__(self):
sum_vals([self.val_1, self.val_2, self.val_3], sum_val=1.0, allow_optional=False)


@spock
class SumNoneNotEqualConfig:
val_1: float = 0.5
val_2: float = 0.5
val_3: Optional[float] = None

def __post_hook__(self):
sum_vals([self.val_1, self.val_2, self.val_3], sum_val=0.75)


class TestPostHooks:

def test_sum_none_fail_config(self, monkeypatch, tmp_path):
"""Test serialization/de-serialization"""
with monkeypatch.context() as m:
m.setattr(
sys,
"argv",
[""],
)
with pytest.raises(_SpockInstantiationError):
config = SpockBuilder(
SumNoneFailConfig,
desc="Test Builder",
)
config.generate()

def test_sum_not_equal_config(self, monkeypatch, tmp_path):
"""Test serialization/de-serialization"""
with monkeypatch.context() as m:
m.setattr(
sys,
"argv",
[""],
)
with pytest.raises(_SpockInstantiationError):
config = SpockBuilder(
SumNoneNotEqualConfig,
desc="Test Builder",
)
config.generate()



def test_eq_len_two_len_fail(self, monkeypatch, tmp_path):
"""Test serialization/de-serialization"""
with monkeypatch.context() as m:
Expand Down