Skip to content

Commit

Permalink
Start adding cartesian_product
Browse files Browse the repository at this point in the history
  • Loading branch information
evhub committed Dec 30, 2022
1 parent 0af8a6b commit b70b8a7
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 50 deletions.
2 changes: 1 addition & 1 deletion DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2488,7 +2488,7 @@ Coconut's `map`, `zip`, `filter`, `reversed`, and `enumerate` objects are enhanc
- `reversed`,
- `repr`,
- optimized normal (and iterator) slicing (all but `filter`),
- `len` (all but `filter`),
- `len` (all but `filter`) (though `bool` will still always yield `True`),
- the ability to be iterated over multiple times if the underlying iterators are iterables,
- [PEP 618](https://www.python.org/dev/peps/pep-0618) `zip(..., strict=True)` support on all Python versions, and
- have added attributes which subclasses can make use of to get at the original arguments to the object:
Expand Down
1 change: 1 addition & 0 deletions __coconut__/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ takewhile = _coconut.itertools.takewhile
dropwhile = _coconut.itertools.dropwhile
tee = _coconut.itertools.tee
starmap = _coconut.itertools.starmap
cartesian_product = _coconut.itertools.product


_coconut_tee = tee
Expand Down
10 changes: 0 additions & 10 deletions coconut/compiler/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,16 +324,6 @@ def pattern_prepender(func):
if set_name is not None:
set_name(cls, k)'''
),
pattern_func_slots=pycondition(
(3, 7),
if_lt=r'''
__slots__ = ("FunctionMatchError", "patterns", "__doc__", "__name__")
''',
if_ge=r'''
__slots__ = ("FunctionMatchError", "patterns", "__doc__", "__name__", "__qualname__")
''',
indent=1,
),
set_qualname_none=pycondition(
(3, 7),
if_ge=r'''
Expand Down
114 changes: 76 additions & 38 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class _coconut_base_hashable{object}:
return self.__class__ is other.__class__ and self.__reduce__() == other.__reduce__()
def __hash__(self):
return _coconut.hash(self.__reduce__())
def __bool__(self):
return True{COMMENT.avoids_expensive_len_calls}
class MatchError(_coconut_base_hashable, Exception):
"""Pattern-matching error. Has attributes .pattern, .value, and .message."""
__slots__ = ("pattern", "value", "_message")
Expand Down Expand Up @@ -369,6 +371,10 @@ class scan(_coconut_base_hashable):
self.func = function
self.iter = iterable
self.initial = initial
def __repr__(self):
return "scan(%r, %s%s)" % (self.func, _coconut.repr(self.iter), "" if self.initial is _coconut_sentinel else ", " + _coconut.repr(self.initial))
def __reduce__(self):
return (self.__class__, (self.func, self.iter, self.initial))
def __iter__(self):
acc = self.initial
if acc is not _coconut_sentinel:
Expand All @@ -381,16 +387,11 @@ class scan(_coconut_base_hashable):
yield acc
def __len__(self):
return _coconut.len(self.iter)
def __repr__(self):
return "scan(%r, %s%s)" % (self.func, _coconut.repr(self.iter), "" if self.initial is _coconut_sentinel else ", " + _coconut.repr(self.initial))
def __reduce__(self):
return (self.__class__, (self.func, self.iter, self.initial))
def __fmap__(self, func):
return _coconut_map(func, self)
class reversed(_coconut_base_hashable):
__slots__ = ("iter",)
if hasattr(_coconut.map, "__doc__"):
__doc__ = _coconut.reversed.__doc__
__doc__ = getattr(_coconut.reversed, "__doc__", "<see help(py_reversed)>")
def __new__(cls, iterable):
if _coconut.isinstance(iterable, _coconut.range):
return iterable[::-1]
Expand All @@ -399,6 +400,10 @@ class reversed(_coconut_base_hashable):
return _coconut.reversed(iterable)
def __init__(self, iterable):
self.iter = iterable
def __repr__(self):
return "reversed(%s)" % (_coconut.repr(self.iter),)
def __reduce__(self):
return (self.__class__, (self.iter,))
def __iter__(self):
return _coconut.iter(_coconut.reversed(self.iter))
def __getitem__(self, index):
Expand All @@ -409,10 +414,6 @@ class reversed(_coconut_base_hashable):
return self.iter
def __len__(self):
return _coconut.len(self.iter)
def __repr__(self):
return "reversed(%s)" % (_coconut.repr(self.iter),)
def __reduce__(self):
return (self.__class__, (self.iter,))
def __contains__(self, elem):
return elem in self.iter
def count(self, elem):
Expand Down Expand Up @@ -444,6 +445,7 @@ class flatten(_coconut_base_hashable):
self.iter, new_iter = _coconut_tee(self.iter)
return _coconut.sum(it.count(elem) for it in new_iter)
def index(self, elem):
"""Find the index of elem in the flattened iterable."""
self.iter, new_iter = _coconut_tee(self.iter)
ind = 0
for it in new_iter:
Expand All @@ -454,10 +456,61 @@ class flatten(_coconut_base_hashable):
raise ValueError("%r not in %r" % (elem, self))
def __fmap__(self, func):
return self.__class__(_coconut_map(_coconut.functools.partial(_coconut_map, func), self.iter))
class cartesian_product(_coconut_base_hashable):
__slots__ = ("iters", "repeat")
__doc__ = getattr(_coconut.itertools.product, "__doc__", "Cartesian product of input iterables.") + """

Additionally supports Cartesian products of numpy arrays."""
def __new__(cls, *iterables, **kwargs):
repeat = kwargs.pop("repeat", 1)
if kwargs:
raise _coconut.TypeError("cartesian_product() got unexpected keyword arguments " + _coconut.repr(kwargs))
if iterables and _coconut.all(it.__class__.__module__ in _coconut.numpy_modules for it in iterables):
iterables *= repeat
la = _coconut.len(iterables)
dtype = _coconut.numpy.result_type(*iterables)
arr = _coconut.numpy.empty([_coconut.len(a) for a in iterables] + [la], dtype=dtype)
for i, a in _coconut.enumerate(_coconut.numpy.ix_(*iterables)):
arr[...,i] = a
return arr.reshape(-1, la)
self = _coconut.object.__new__(cls)
self.iters = iterables
self.repeat = repeat
return self
def __iter__(self):
return _coconut.itertools.product(*self.iters, repeat=self.repeat)
def __repr__(self):
return "cartesian_product(" + ", ".join(_coconut.repr(it) for it in self.iters) + (", repeat=" + _coconut.repr(self.repeat) if self.repeat != 1 else "") + ")"
def __reduce__(self):
return (self.__class__, self.iters, {lbrace}"repeat": self.repeat{rbrace})
@property
def all_iters(self):
return _coconut.itertools.chain.from_iterable(_coconut.itertools.repeat(self.iters, self.repeat))
def __len__(self):
total_len = 1
for it in self.iters:
total_len *= _coconut.len(it)
return total_len ** self.repeat
def __contains__(self, elem):
for e, it in _coconut.zip_longest(elem, self.all_iters, fillvalue=_coconut_sentinel):
if e is _coconut_sentinel or it is _coconut_sentinel or e not in it:
return False
return True
def count(self, elem):
"""Count the number of times elem appears in the product."""
total_count = 1
for e, it in _coconut.zip_longest(elem, self.all_iters, fillvalue=_coconut_sentinel):
if e is _coconut_sentinel or it is _coconut_sentinel:
return 0
total_count *= it.count(e)
if not total_count:
return total_count
return total_count
def __fmap__(self, func):
return _coconut_map(func, self)
class map(_coconut_base_hashable, _coconut.map):
__slots__ = ("func", "iters")
if hasattr(_coconut.map, "__doc__"):
__doc__ = _coconut.map.__doc__
__doc__ = getattr(_coconut.map, "__doc__", "<see help(py_map)>")
def __new__(cls, function, *iterables):
new_map = _coconut.map.__new__(cls, function, *iterables)
new_map.func = function
Expand Down Expand Up @@ -568,8 +621,7 @@ class concurrent_map(_coconut_base_parallel_concurrent_map):
return "concurrent_" + _coconut_map.__repr__(self)
class filter(_coconut_base_hashable, _coconut.filter):
__slots__ = ("func", "iter")
if hasattr(_coconut.filter, "__doc__"):
__doc__ = _coconut.filter.__doc__
__doc__ = getattr(_coconut.filter, "__doc__", "<see help(py_filter)>")
def __new__(cls, function, iterable):
new_filter = _coconut.filter.__new__(cls, function, iterable)
new_filter.func = function
Expand All @@ -587,8 +639,7 @@ class filter(_coconut_base_hashable, _coconut.filter):
return _coconut_map(func, self)
class zip(_coconut_base_hashable, _coconut.zip):
__slots__ = ("iters", "strict")
if hasattr(_coconut.zip, "__doc__"):
__doc__ = _coconut.zip.__doc__
__doc__ = getattr(_coconut.zip, "__doc__", "<see help(py_zip)>")
def __new__(cls, *iterables, **kwargs):
new_zip = _coconut.zip.__new__(cls, *iterables)
new_zip.iters = iterables
Expand All @@ -607,17 +658,14 @@ class zip(_coconut_base_hashable, _coconut.zip):
def __repr__(self):
return "zip(%s%s)" % (", ".join((_coconut.repr(i) for i in self.iters)), ", strict=True" if self.strict else "")
def __reduce__(self):
return (self.__class__, self.iters, self.strict)
def __setstate__(self, strict):
self.strict = strict
return (self.__class__, self.iters, {lbrace}"strict": self.strict{rbrace})
def __iter__(self):
{zip_iter}
def __fmap__(self, func):
return _coconut_map(func, self)
class zip_longest(zip):
__slots__ = ("fillvalue",)
if hasattr(_coconut.zip_longest, "__doc__"):
__doc__ = (_coconut.zip_longest).__doc__
__doc__ = getattr(_coconut.zip_longest, "__doc__", "Version of zip that fills in missing values with fillvalue.")
def __new__(cls, *iterables, **kwargs):
self = _coconut_zip.__new__(cls, *iterables, strict=False)
self.fillvalue = kwargs.pop("fillvalue", None)
Expand Down Expand Up @@ -647,15 +695,12 @@ class zip_longest(zip):
def __repr__(self):
return "zip_longest(%s, fillvalue=%s)" % (", ".join((_coconut.repr(i) for i in self.iters)), _coconut.repr(self.fillvalue))
def __reduce__(self):
return (self.__class__, self.iters, self.fillvalue)
def __setstate__(self, fillvalue):
self.fillvalue = fillvalue
return (self.__class__, self.iters, {lbrace}"fillvalue": self.fillvalue{rbrace})
def __iter__(self):
return _coconut.iter(_coconut.zip_longest(*self.iters, fillvalue=self.fillvalue))
class enumerate(_coconut_base_hashable, _coconut.enumerate):
__slots__ = ("iter", "start")
if hasattr(_coconut.enumerate, "__doc__"):
__doc__ = _coconut.enumerate.__doc__
__doc__ = getattr(_coconut.enumerate, "__doc__", "<see help(py_enumerate)>")
def __new__(cls, iterable, start=0):
new_enumerate = _coconut.enumerate.__new__(cls, iterable, start)
new_enumerate.iter = iterable
Expand Down Expand Up @@ -894,8 +939,7 @@ def _coconut_get_function_match_error():
return _coconut_MatchError
ctx.taken = True
return ctx.exc_class
class _coconut_base_pattern_func(_coconut_base_hashable):
{pattern_func_slots}
class _coconut_base_pattern_func(_coconut_base_hashable):{COMMENT.no_slots_to_allow_func_attrs}
_coconut_is_match = True
def __init__(self, *funcs):
self.FunctionMatchError = _coconut.type(_coconut_py_str("MatchError"), (_coconut_MatchError,), {empty_dict})
Expand Down Expand Up @@ -958,8 +1002,7 @@ _coconut_addpattern = addpattern
{def_prepattern}
class _coconut_partial(_coconut_base_hashable):
__slots__ = ("func", "_argdict", "_arglen", "_pos_kwargs", "_stargs", "keywords")
if hasattr(_coconut.functools.partial, "__doc__"):
__doc__ = _coconut.functools.partial.__doc__
__doc__ = getattr(_coconut.functools.partial, "__doc__", "Partial application of a function.")
def __init__(self, _coconut_func, _coconut_argdict, _coconut_arglen, _coconut_pos_kwargs, *args, **kwargs):
self.func = _coconut_func
self._argdict = _coconut_argdict
Expand All @@ -968,9 +1011,7 @@ class _coconut_partial(_coconut_base_hashable):
self._stargs = args
self.keywords = kwargs
def __reduce__(self):
return (self.__class__, (self.func, self._argdict, self._arglen, self._pos_kwargs) + self._stargs, self.keywords)
def __setstate__(self, keywords):
self.keywords = keywords
return (self.__class__, (self.func, self._argdict, self._arglen, self._pos_kwargs) + self._stargs, {lbrace}"keywords": self.keywords{rbrace})
@property
def args(self):
return _coconut.tuple(self._argdict.get(i) for i in _coconut.range(self._arglen)) + self._stargs
Expand Down Expand Up @@ -1020,8 +1061,7 @@ def consume(iterable, keep_last=0):
return _coconut.collections.deque(iterable, maxlen=keep_last)
class starmap(_coconut_base_hashable, _coconut.itertools.starmap):
__slots__ = ("func", "iter")
if hasattr(_coconut.itertools.starmap, "__doc__"):
__doc__ = _coconut.itertools.starmap.__doc__
__doc__ = getattr(_coconut.itertools.starmap, "__doc__", "starmap(func, iterable) = (func(*args) for args in iterable)")
def __new__(cls, function, iterable):
new_map = _coconut.itertools.starmap.__new__(cls, function, iterable)
new_map.func = function
Expand Down Expand Up @@ -1203,9 +1243,7 @@ class _coconut_lifted(_coconut_base_hashable):
self.func_args = func_args
self.func_kwargs = func_kwargs
def __reduce__(self):
return (self.__class__, (self.func,) + self.func_args, self.func_kwargs)
def __setstate__(self, func_kwargs):
self.func_kwargs = func_kwargs
return (self.__class__, (self.func,) + self.func_args, {lbrace}"func_kwargs": self.func_kwargs{rbrace})
def __call__(self, *args, **kwargs):
return self.func(*(g(*args, **kwargs) for g in self.func_args), **_coconut.dict((k, h(*args, **kwargs)) for k, h in self.func_kwargs.items()))
def __repr__(self):
Expand Down
1 change: 1 addition & 0 deletions coconut/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ def get_bool_env_var(env_var, default=False):
"all_equal",
"collectby",
"multi_enumerate",
"cartesian_product",
"py_chr",
"py_hex",
"py_input",
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "2.1.1"
VERSION_NAME = "The Spanish Inquisition"
# False for release, int >= 1 for develop
DEVELOP = 11
DEVELOP = 12
ALPHA = False # for pre releases rather than post releases

# -----------------------------------------------------------------------------------------------------------------------
Expand Down
19 changes: 19 additions & 0 deletions coconut/tests/src/cocotest/agnostic/main.coco
Original file line number Diff line number Diff line change
Expand Up @@ -1235,6 +1235,25 @@ def main_test() -> bool:
\list = [1, 2, 3]
return \list
assert test_list() == list((1, 2, 3))
match def only_one(1) = 1
only_one.one = 1
assert only_one.one == 1
assert cartesian_product() |> list == [] == cartesian_product(repeat=10) |> list
assert cartesian_product() |> len == 1 == cartesian_product(repeat=10) |> len
assert () in cartesian_product()
assert () in cartesian_product(repeat=10)
assert (1,) not in cartesian_product()
assert (1,) not in cartesian_product(repeat=10)
assert cartesian_product().count(()) == 1 == cartesian_product(repeat=10).count(())
v = [1, 2]
assert cartesian_product(v, v) |> list == [(1, 1), (1, 2), (2, 1), (2, 2)] == cartesian_product(v, repeat=2) |> list
assert cartesian_product(v, v) |> len == 4 == cartesian_product(v, repeat=2) |> len
assert (2, 2) in cartesian_product(v, v)
assert (2, 2) in cartesian_product(v, repeat=2)
assert (2, 3) not in cartesian_product(v, v)
assert (2, 3) not in cartesian_product(v, repeat=2)
assert cartesian_product(v, v).count((2, 1)) == 1 == cartesian_product(v, repeat=2).count((2, 1))
assert cartesian_product(v, v).count((2, 0)) == 0 == cartesian_product(v, repeat=2).count((2, 0))
return True

def test_asyncio() -> bool:
Expand Down
10 changes: 10 additions & 0 deletions coconut/tests/src/extras.coco
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,16 @@ def test_numpy() -> bool:
assert all_equal(np.array([1, 1]))
assert all_equal(np.array([1, 1;; 1, 1]))
assert not all_equal(np.array([1, 1;; 1, 2]))
assert (
cartesian_product(np.array([1, 2]), np.array([3, 4]))
`np.array_equal`
np.array([1, 3;; 1, 4;; 2, 3;; 2, 4])
) # type: ignore
assert (
cartesian_product(np.array([1, 2]), repeat=2)
`np.array_equal`
np.array([1, 1;; 1, 2;; 2, 1;; 2, 2])
) # type: ignore
return True


Expand Down

0 comments on commit b70b8a7

Please sign in to comment.