From 626c1f4f81e3f6efe8cc7f6808b8b3227b9dd957 Mon Sep 17 00:00:00 2001 From: Evan Hubinger Date: Fri, 2 Dec 2022 14:16:00 -0800 Subject: [PATCH] Fix cartesian_product Resolves #688. --- DOCS.md | 59 +++++++++++++++++-- _coconut/__init__.pyi | 1 + coconut/compiler/templates/header.py_template | 22 +++++-- coconut/tests/src/cocotest/agnostic/main.coco | 11 ++-- 4 files changed, 79 insertions(+), 14 deletions(-) diff --git a/DOCS.md b/DOCS.md index 4b4f4d6f8..9502cbcec 100644 --- a/DOCS.md +++ b/DOCS.md @@ -423,9 +423,11 @@ To distribute your code with checkable type annotations, you'll need to include To allow for better use of [`numpy`](https://numpy.org/) objects in Coconut, all compiled Coconut code will do a number of special things to better integrate with `numpy` (if `numpy` is available to import when the code is run). Specifically: - Coconut's [multidimensional array literal and array concatenation syntax](#multidimensional-array-literalconcatenation-syntax) supports `numpy` objects, including using fast `numpy` concatenation methods if given `numpy` arrays rather than Coconut's default much slower implementation built for Python lists of lists. -- Coconut's [`multi_enumerate`](#multi_enumerate) built-in allows for easily looping over all the multi-dimensional indices in a `numpy` array. -- Coconut's [`all_equal`](#all_equal) built-in allows for easily checking if all the elements in a `numpy` array are the same. -- When a `numpy` object is passed to [`fmap`](#fmap), [`numpy.vectorize`](https://numpy.org/doc/stable/reference/generated/numpy.vectorize.html) is used instead of the default `fmap` implementation. +- Many of Coconut's built-ins include special `numpy` support, specifically: + * [`fmap`](#fmap) will use [`numpy.vectorize`](https://numpy.org/doc/stable/reference/generated/numpy.vectorize.html) to map over `numpy` arrays. + * [`multi_enumerate`](#multi_enumerate) allows for easily looping over all the multi-dimensional indices in a `numpy` array. + * [`cartesian_product`](#cartesian_product) can compute the Cartesian product of given `numpy` arrays as a `numpy` array. + * [`all_equal`](#all_equal) allows for easily checking if all the elements in a `numpy` array are the same. - [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html) is registered as a [`collections.abc.Sequence`](https://docs.python.org/3/library/collections.abc.html#collections.abc.Sequence), enabling it to be used in [sequence patterns](#semantics-specification). - Coconut supports `@` for matrix multiplication of `numpy` arrays on all Python versions, as well as supplying the `(@)` [operator function](#operator-functions). @@ -3101,7 +3103,7 @@ for x in input_data: ### `flatten` -Coconut provides an enhanced version of `itertools.chain.from_iterable` as a built-in under the name `flatten` with added support for `reversed`, `repr`, `in`, `.count()`, `.index()`, and `fmap`. +Coconut provides an enhanced version of `itertools.chain.from_iterable` as a built-in under the name `flatten` with added support for `reversed`, `len`, `repr`, `in`, `.count()`, `.index()`, and `fmap`. ##### Python Docs @@ -3132,6 +3134,55 @@ iter_of_iters = [[1, 2], [3, 4]] flat_it = iter_of_iters |> chain.from_iterable |> list ``` +### `cartesian_product` + +Coconut provides an enhanced version of `itertools.product` as a built-in under the name `cartesian_product` with added support for `len`, `repr`, `in`, `.count()`, and `fmap`. + +Additionally, `cartesian_product` includes special support for [`numpy`](http://www.numpy.org/)/[`pandas`](https://pandas.pydata.org/)/[`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects, in which case a multidimensional array is returned instead of an iterator. + +##### Python Docs + +itertools.**product**(_\*iterables, repeat=1_) + +Cartesian product of input iterables. + +Roughly equivalent to nested for-loops in a generator expression. For example, `product(A, B)` returns the same as `((x,y) for x in A for y in B)`. + +The nested loops cycle like an odometer with the rightmost element advancing on every iteration. This pattern creates a lexicographic ordering so that if the input’s iterables are sorted, the product tuples are emitted in sorted order. + +To compute the product of an iterable with itself, specify the number of repetitions with the optional repeat keyword argument. For example, `product(A, repeat=4)` means the same as `product(A, A, A, A)`. + +This function is roughly equivalent to the following code, except that the actual implementation does not build up intermediate results in memory: + +```coconut_python +def product(*args, repeat=1): + # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy + # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111 + pools = [tuple(pool) for pool in args] * repeat + result = [[]] + for pool in pools: + result = [x+[y] for x in result for y in pool] + for prod in result: + yield tuple(prod) +``` + +Before `product()` runs, it completely consumes the input iterables, keeping pools of values in memory to generate the products. Accordingly, it is only useful with finite inputs. + +##### Example + +**Coconut:** +```coconut +v = [1, 2] +assert cartesian_product(v, v) |> list == [(1, 1), (1, 2), (2, 1), (2, 2)] +``` + +**Python:** +```coconut_python +from itertools import product +v = [1, 2] +assert list(product(v, v)) == [(1, 1), (1, 2), (2, 1), (2, 2)] +``` + ### `multi_enumerate` Coconut's `multi_enumerate` enumerates through an iterable of iterables. `multi_enumerate` works like enumerate, but indexes through inner iterables and produces a tuple index representing the index in each inner iterable. Supports indexing. diff --git a/_coconut/__init__.pyi b/_coconut/__init__.pyi index 9c66413ea..12273869d 100644 --- a/_coconut/__init__.pyi +++ b/_coconut/__init__.pyi @@ -153,6 +153,7 @@ property = property range = range reversed = reversed set = set +setattr = setattr slice = slice str = str sum = sum diff --git a/coconut/compiler/templates/header.py_template b/coconut/compiler/templates/header.py_template index 307b43346..4b500e679 100644 --- a/coconut/compiler/templates/header.py_template +++ b/coconut/compiler/templates/header.py_template @@ -32,7 +32,7 @@ def _coconut_super(type=None, object_or_type=None): numpy_modules = {numpy_modules} jax_numpy_modules = {jax_numpy_modules} abc.Sequence.register(collections.deque) - Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bytes, classmethod, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, map, min, max, next, object, property, range, reversed, set, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bytes, classmethod, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, map, min, max, next, object, property, range, reversed, set, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray} + Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bytes, classmethod, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bytes, classmethod, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, map, min, max, next, object, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray} class _coconut_sentinel{object}: __slots__ = () class _coconut_base_hashable{object}: @@ -43,8 +43,11 @@ class _coconut_base_hashable{object}: return self.__class__ is other.__class__ and self.__reduce__() == other.__reduce__() def __hash__(self): return _coconut.hash(self.__reduce__()) - def __bool__(self): - return True{COMMENT.avoids_expensive_len_calls} + def __bool__(self):{COMMENT.avoids_expensive_len_calls} + return True + def __setstate__(self, setvars):{COMMENT.fixes_unpickling_with_slots} + for k, v in setvars.items(): + _coconut.setattr(self, k, v) class MatchError(_coconut_base_hashable, Exception): """Pattern-matching error. Has attributes .pattern, .value, and .message.""" __slots__ = ("pattern", "value", "_message") @@ -440,6 +443,9 @@ class flatten(_coconut_base_hashable): def __contains__(self, elem): self.iter, new_iter = _coconut_tee(self.iter) return _coconut.any(elem in it for it in new_iter) + def __len__(self): + self.iter, new_iter = _coconut_tee(self.iter) + return _coconut.sum(_coconut.len(it) for it in new_iter) def count(self, elem): """Count the number of times elem appears in the flattened iterable.""" self.iter, new_iter = _coconut_tee(self.iter) @@ -466,11 +472,15 @@ Additionally supports Cartesian products of numpy arrays.""" if kwargs: raise _coconut.TypeError("cartesian_product() got unexpected keyword arguments " + _coconut.repr(kwargs)) if iterables and _coconut.all(it.__class__.__module__ in _coconut.numpy_modules for it in iterables): + if _coconut.any(it.__class__.__module__ in _coconut.jax_numpy_modules for it in iterables): + from jax import numpy + else: + numpy = _coconut.numpy iterables *= repeat la = _coconut.len(iterables) - dtype = _coconut.numpy.result_type(*iterables) - arr = _coconut.numpy.empty([_coconut.len(a) for a in iterables] + [la], dtype=dtype) - for i, a in _coconut.enumerate(_coconut.numpy.ix_(*iterables)): + dtype = numpy.result_type(*iterables) + arr = numpy.empty([_coconut.len(a) for a in iterables] + [la], dtype=dtype) + for i, a in _coconut.enumerate(numpy.ix_(*iterables)): arr[...,i] = a return arr.reshape(-1, la) self = _coconut.object.__new__(cls) diff --git a/coconut/tests/src/cocotest/agnostic/main.coco b/coconut/tests/src/cocotest/agnostic/main.coco index 6ff303379..af74e879f 100644 --- a/coconut/tests/src/cocotest/agnostic/main.coco +++ b/coconut/tests/src/cocotest/agnostic/main.coco @@ -1235,10 +1235,13 @@ def main_test() -> bool: \list = [1, 2, 3] return \list assert test_list() == list((1, 2, 3)) - match def only_one(1) = 1 - only_one.one = 1 - assert only_one.one == 1 - assert cartesian_product() |> list == [] == cartesian_product(repeat=10) |> list + match def one_or_two(1) = one_or_two.one + addpattern def one_or_two(2) = one_or_two.two # type: ignore + one_or_two.one = 10 + one_or_two.two = 20 + assert one_or_two(1) == 10 + assert one_or_two(2) == 20 + assert cartesian_product() |> list == [()] == cartesian_product(repeat=10) |> list assert cartesian_product() |> len == 1 == cartesian_product(repeat=10) |> len assert () in cartesian_product() assert () in cartesian_product(repeat=10)