Skip to content

Commit

Permalink
Add to arg to all_equal
Browse files Browse the repository at this point in the history
Resolves   #817.
  • Loading branch information
evhub committed Dec 22, 2023
1 parent 2abb276 commit 32ca306
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 7 deletions.
10 changes: 8 additions & 2 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4210,9 +4210,15 @@ _Can't be done without the definition of `windowsof`; see the compiled header fo

#### `all_equal`

**all\_equal**(_iterable_)
**all\_equal**(_iterable_, _to_=`...`)

Coconut's `all_equal` built-in takes in an iterable and determines whether all of its elements are equal to each other. `all_equal` assumes transitivity of equality and that `!=` is the negation of `==`. Special support is provided for [`numpy`](#numpy-integration) objects.
Coconut's `all_equal` built-in takes in an iterable and determines whether all of its elements are equal to each other.

If _to_ is passed, `all_equal` will check that all the elements are specifically equal to that value, rather than just equal to each other.

Note that `all_equal` assumes transitivity of equality, that `!=` is the negation of `==`, and that empty arrays always have all their elements equal.

Special support is provided for [`numpy`](#numpy-integration) objects.

##### Example

Expand Down
2 changes: 1 addition & 1 deletion __coconut__/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1680,7 +1680,7 @@ def lift_apart(func: _t.Callable[..., _W]) -> _t.Callable[..., _t.Callable[...,
...


def all_equal(iterable: _Iterable) -> bool:
def all_equal(iterable: _t.Iterable[_T], to: _T = ...) -> bool:
"""For a given iterable, check whether all elements in that iterable are equal to each other.
Supports numpy arrays. Assumes transitivity and 'x != y' being equivalent to 'not (x == y)'.
Expand Down
7 changes: 4 additions & 3 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -1953,8 +1953,9 @@ class lift_apart(lift):
lift_apart also supports a shortcut form such that lift_apart(f, *func_args, **func_kwargs) is equivalent to lift_apart(f)(*func_args, **func_kwargs).
"""
_apart = True
def all_equal(iterable):
def all_equal(iterable, to=_coconut_sentinel):
"""For a given iterable, check whether all elements in that iterable are equal to each other.
If 'to' is passed, check that all the elements are equal to that value.

Supports numpy arrays. Assumes transitivity and 'x != y' being equivalent to 'not (x == y)'.
"""
Expand All @@ -1964,8 +1965,8 @@ def all_equal(iterable):
iterable = _coconut_xarray_to_numpy(iterable)
elif iterable_module in _coconut.pandas_modules:
iterable = iterable.to_numpy()
return not _coconut.len(iterable) or (iterable == iterable[0]).all()
first_item = _coconut_sentinel
return not _coconut.len(iterable) or (iterable == (iterable[0] if to is _coconut_sentinel else to)).all()
first_item = to
for item in iterable:
if first_item is _coconut_sentinel:
first_item = item
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "3.0.4"
VERSION_NAME = None
# False for release, int >= 1 for develop
DEVELOP = 10
DEVELOP = 11
ALPHA = False # for pre releases rather than post releases

assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1"
Expand Down
3 changes: 3 additions & 0 deletions coconut/tests/src/cocotest/agnostic/primary_2.coco
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,9 @@ def primary_test_2() -> bool:
arr |>= [[3; 4] ;; .]
assert arr == [3; 4;; 1; 2] == [[3; 4] ;; .] |> call$(?, [. ; 2] |> call$(?, 1))
assert (if)(10, 20, 30) == 20 == (if)(0, 10, 20)
assert all_equal([], to=10)
assert all_equal([10; 10; 10; 10], to=10)
assert not all_equal([1, 1], to=10)

with process_map.multiple_sequential_calls(): # type: ignore
assert map((+), range(3), range(4)$[:-1], strict=True) |> list == [0, 2, 4] == process_map((+), range(3), range(4)$[:-1], strict=True) |> list # type: ignore
Expand Down
3 changes: 3 additions & 0 deletions coconut/tests/src/extras.coco
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,9 @@ def test_numpy() -> bool:
assert all_equal(np.array([1, 1]))
assert all_equal(np.array([1, 1;; 1, 1]))
assert not all_equal(np.array([1, 1;; 1, 2]))
assert all_equal(np.array([]), to=10)
assert all_equal(np.array([10; 10;; 10; 10]), to=10)
assert not all_equal(np.array([1, 1]), to=10)
assert (
cartesian_product(np.array([1, 2]), np.array([3, 4]))
`np.array_equal`
Expand Down

0 comments on commit 32ca306

Please sign in to comment.