Skip to content

Commit

Permalink
Add xarray support
Browse files Browse the repository at this point in the history
Resolves   #816.
  • Loading branch information
evhub committed Dec 22, 2023
1 parent 9d16819 commit 2abb276
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 39 deletions.
15 changes: 8 additions & 7 deletions DOCS.md
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ To allow for better use of [`numpy`](https://numpy.org/) objects in Coconut, all
- `numpy` objects are allowed seamlessly in Coconut's [implicit coefficient syntax](#implicit-function-application-and-coefficients), allowing the use of e.g. `A B**2` shorthand for `A * B**2` when `A` and `B` are `numpy` arrays (note: **not** `A @ B**2`).
- Coconut supports `@` for matrix multiplication of `numpy` arrays on all Python versions, as well as supplying the `(@)` [operator function](#operator-functions).

Additionally, Coconut provides the exact same support for [`pandas`](https://pandas.pydata.org/), [`pytorch`](https://pytorch.org/), and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html), including using `pandas`/`jax`-specific methods over `numpy` methods when given `pandas`/`jax` objects.
Additionally, Coconut provides the exact same support for [`pandas`](https://pandas.pydata.org/), [`xarray`](https://docs.xarray.dev/en/stable/), [`pytorch`](https://pytorch.org/), and [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) objects.

#### `xonsh` Support

Expand Down Expand Up @@ -3383,14 +3383,8 @@ In Haskell, `fmap(func, obj)` takes a data type `obj` and returns a new data typ

`fmap` can also be used on the built-in objects `str`, `dict`, `list`, `tuple`, `set`, `frozenset`, and `dict` as a variant of `map` that returns back an object of the same type.

The behavior of `fmap` for a given object can be overridden by defining an `__fmap__(self, func)` magic method that will be called whenever `fmap` is invoked on that object. Note that `__fmap__` implementations should always satisfy the [Functor Laws](https://wiki.haskell.org/Functor).

For `dict`, or any other `collections.abc.Mapping`, `fmap` will map over the mapping's `.items()` instead of the default iteration through its `.keys()`, with the new mapping reconstructed from the mapped over items. _Deprecated: `fmap$(starmap_over_mappings=True)` will `starmap` over the `.items()` instead of `map` over them._

For [`numpy`](#numpy-integration) objects, `fmap` will use [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) to produce the result.

For [`pandas`](https://pandas.pydata.org/) objects, `fmap` will use [`.apply`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.apply.html) along the last axis (so row-wise for `DataFrame`'s, element-wise for `Series`'s).

For asynchronous iterables, `fmap` will map asynchronously, making `fmap` equivalent in that case to
```coconut_python
async def fmap_over_async_iters(func, async_iter):
Expand All @@ -3399,6 +3393,13 @@ async def fmap_over_async_iters(func, async_iter):
```
such that `fmap` can effectively be used as an async map.

Some objects from external libraries are also given special support:
* For [`numpy`](#numpy-integration) objects, `fmap` will use [`np.vectorize`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.vectorize.html) to produce the result.
* For [`pandas`](https://pandas.pydata.org/) objects, `fmap` will use [`.apply`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.apply.html) along the last axis (so row-wise for `DataFrame`'s, element-wise for `Series`'s).
* For [`xarray`](https://docs.xarray.dev/en/stable/) objects, `fmap` will first convert them into `pandas` objects, apply `fmap`, then convert them back.

The behavior of `fmap` for a given object can be overridden by defining an `__fmap__(self, func)` magic method that will be called whenever `fmap` is invoked on that object. Note that `__fmap__` implementations should always satisfy the [Functor Laws](https://wiki.haskell.org/Functor).

_Deprecated: `fmap(func, obj, fallback_to_init=True)` will fall back to `obj.__class__(map(func, obj))` if no `fmap` implementation is available rather than raise `TypeError`._

##### Example
Expand Down
2 changes: 2 additions & 0 deletions __coconut__/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1466,6 +1466,8 @@ def fmap(func: _t.Callable[[_T, _U], _t.Tuple[_V, _W]], obj: _t.Mapping[_T, _U],
"""
...

_coconut_fmap = fmap


def _coconut_handle_cls_kwargs(**kwargs: _t.Dict[_t.Text, _t.Any]) -> _t.Callable[[_T], _T]: ...

Expand Down
4 changes: 3 additions & 1 deletion _coconut/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ npt = _npt # Fake, like typing
zip_longest = _zip_longest

numpy_modules: _t.Any = ...
pandas_numpy_modules: _t.Any = ...
xarray_modules: _t.Any = ...
pandas_modules: _t.Any = ...
jax_numpy_modules: _t.Any = ...

tee_type: _t.Any = ...
reiterables: _t.Any = ...
fmappables: _t.Any = ...
Expand Down
6 changes: 4 additions & 2 deletions coconut/compiler/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@
justify_len,
report_this_text,
numpy_modules,
pandas_numpy_modules,
pandas_modules,
jax_numpy_modules,
xarray_modules,
self_match_types,
is_data_var,
data_defaults_var,
Expand Down Expand Up @@ -291,7 +292,8 @@ def process_header_args(which, use_hash, target, no_tco, strict, no_wrap):
from_None=" from None" if target.startswith("3") else "",
process_="process_" if target_info >= (3, 13) else "",
numpy_modules=tuple_str_of(numpy_modules, add_quotes=True),
pandas_numpy_modules=tuple_str_of(pandas_numpy_modules, add_quotes=True),
xarray_modules=tuple_str_of(xarray_modules, add_quotes=True),
pandas_modules=tuple_str_of(pandas_modules, add_quotes=True),
jax_numpy_modules=tuple_str_of(jax_numpy_modules, add_quotes=True),
self_match_types=tuple_str_of(self_match_types),
comma_bytearray=", bytearray" if not target.startswith("3") else "",
Expand Down
65 changes: 43 additions & 22 deletions coconut/compiler/templates/header.py_template
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ class _coconut{object}:{COMMENT.EVERYTHING_HERE_MUST_BE_COPIED_TO_STUB_FILE}
else:
abc.Sequence.register(numpy.ndarray)
numpy_modules = {numpy_modules}
pandas_numpy_modules = {pandas_numpy_modules}
xarray_modules = {xarray_modules}
pandas_modules = {pandas_modules}
jax_numpy_modules = {jax_numpy_modules}
tee_type = type(itertools.tee((), 1)[0])
reiterables = abc.Sequence, abc.Mapping, abc.Set
Expand Down Expand Up @@ -121,6 +122,20 @@ class _coconut_Sentinel(_coconut_baseclass):
_coconut_sentinel = _coconut_Sentinel()
def _coconut_get_base_module(obj):
return obj.__class__.__module__.split(".", 1)[0]
def _coconut_xarray_to_pandas(obj):
import xarray
if isinstance(obj, xarray.Dataset):
return obj.to_dataframe()
elif isinstance(obj, xarray.DataArray):
return obj.to_series()
else:
return obj.to_pandas()
def _coconut_xarray_to_numpy(obj):
import xarray
if isinstance(obj, xarray.Dataset):
return obj.to_dataframe().to_numpy()
else:
return obj.to_numpy()
class MatchError(_coconut_baseclass, Exception):
"""Pattern-matching error. Has attributes .pattern, .value, and .message."""{COMMENT.no_slots_to_allow_setattr_below}
max_val_repr_len = 500
Expand Down Expand Up @@ -752,8 +767,10 @@ Additionally supports Cartesian products of numpy arrays."""
if iterables:
it_modules = [_coconut_get_base_module(it) for it in iterables]
if _coconut.all(mod in _coconut.numpy_modules for mod in it_modules):
if _coconut.any(mod in _coconut.pandas_numpy_modules for mod in it_modules):
iterables = tuple((it.to_numpy() if _coconut_get_base_module(it) in _coconut.pandas_numpy_modules else it) for it in iterables)
if _coconut.any(mod in _coconut.xarray_modules for mod in it_modules):
iterables = tuple((_coconut_xarray_to_numpy(it) if mod in _coconut.xarray_modules else it) for it, mod in _coconut.zip(iterables, it_modules))
if _coconut.any(mod in _coconut.pandas_modules for mod in it_modules):
iterables = tuple((it.to_numpy() if mod in _coconut.pandas_modules else it) for it, mod in _coconut.zip(iterables, it_modules))
if _coconut.any(mod in _coconut.jax_numpy_modules for mod in it_modules):
from jax import numpy
else:
Expand Down Expand Up @@ -1605,7 +1622,9 @@ def fmap(func, obj, **kwargs):
if result is not _coconut.NotImplemented:
return result
obj_module = _coconut_get_base_module(obj)
if obj_module in _coconut.pandas_numpy_modules:
if obj_module in _coconut.xarray_modules:
return {_coconut_}fmap(func, _coconut_xarray_to_pandas(obj)).to_xarray()
if obj_module in _coconut.pandas_modules:
if obj.ndim <= 1:
return obj.apply(func)
return obj.apply(func, axis=obj.ndim-1)
Expand Down Expand Up @@ -1941,7 +1960,9 @@ def all_equal(iterable):
"""
iterable_module = _coconut_get_base_module(iterable)
if iterable_module in _coconut.numpy_modules:
if iterable_module in _coconut.pandas_numpy_modules:
if iterable_module in _coconut.xarray_modules:
iterable = _coconut_xarray_to_numpy(iterable)
elif iterable_module in _coconut.pandas_modules:
iterable = iterable.to_numpy()
return not _coconut.len(iterable) or (iterable == iterable[0]).all()
first_item = _coconut_sentinel
Expand Down Expand Up @@ -2014,8 +2035,11 @@ def _coconut_mk_anon_namedtuple(fields, types=None, of_kwargs=None):
return NT
return NT(**of_kwargs)
def _coconut_ndim(arr):
if (_coconut_get_base_module(arr) in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "ndim"):
arr_mod = _coconut_get_base_module(arr)
if (arr_mod in _coconut.numpy_modules or _coconut.hasattr(arr.__class__, "__matconcat__")) and _coconut.hasattr(arr, "ndim"):
return arr.ndim
if arr_mod in _coconut.xarray_modules:{COMMENT.if_we_got_here_its_a_Dataset_not_a_DataArray}
return 2
if not _coconut.isinstance(arr, _coconut.abc.Sequence) or _coconut.isinstance(arr, (_coconut.str, _coconut.bytes)):
return 0
if _coconut.len(arr) == 0:
Expand All @@ -2040,23 +2064,20 @@ def _coconut_expand_arr(arr, new_dims):
arr = [arr]
return arr
def _coconut_concatenate(arrs, axis):
matconcat = None
for a in arrs:
if _coconut.hasattr(a.__class__, "__matconcat__"):
matconcat = a.__class__.__matconcat__
break
a_module = _coconut_get_base_module(a)
if a_module in _coconut.pandas_numpy_modules:
from pandas import concat as matconcat
break
if a_module in _coconut.jax_numpy_modules:
from jax.numpy import concatenate as matconcat
break
if a_module in _coconut.numpy_modules:
matconcat = _coconut.numpy.concatenate
break
if matconcat is not None:
return matconcat(arrs, axis=axis)
return a.__class__.__matconcat__(arrs, axis=axis)
arr_modules = [_coconut_get_base_module(a) for a in arrs]
if any(mod in _coconut.xarray_modules for mod in arr_modules):
return _coconut_concatenate([(_coconut_xarray_to_pandas(a) if mod in _coconut.xarray_modules else a) for a, mod in _coconut.zip(arrs, arr_modules)], axis).to_xarray()
if any(mod in _coconut.pandas_modules for mod in arr_modules):
import pandas
return pandas.concat(arrs, axis=axis)
if any(mod in _coconut.jax_numpy_modules for mod in arr_modules):
import jax.numpy
return jax.numpy.concatenate(arrs, axis=axis)
if any(mod in _coconut.numpy_modules for mod in arr_modules):
return _coconut.numpy.concatenate(arrs, axis=axis)
if not axis:
return _coconut.list(_coconut.itertools.chain.from_iterable(arrs))
return [_coconut_concatenate(rows, axis - 1) for rows in _coconut.zip(*arrs)]
Expand Down Expand Up @@ -2209,4 +2230,4 @@ class _coconut_SupportsInv(_coconut.typing.Protocol):
{def_async_map}
{def_aliases}
_coconut_self_match_types = {self_match_types}
_coconut_Expected, _coconut_MatchError, _coconut_cartesian_product, _coconut_count, _coconut_cycle, _coconut_enumerate, _coconut_flatten, _coconut_filter, _coconut_groupsof, _coconut_ident, _coconut_lift, _coconut_map, _coconut_mapreduce, _coconut_multiset, _coconut_range, _coconut_reiterable, _coconut_reversed, _coconut_scan, _coconut_starmap, _coconut_tee, _coconut_windowsof, _coconut_zip, _coconut_zip_longest, TYPE_CHECKING, reduce, takewhile, dropwhile = Expected, MatchError, cartesian_product, count, cycle, enumerate, flatten, filter, groupsof, ident, lift, map, mapreduce, multiset, range, reiterable, reversed, scan, starmap, tee, windowsof, zip, zip_longest, False, _coconut.functools.reduce, _coconut.itertools.takewhile, _coconut.itertools.dropwhile{COMMENT.anything_added_here_should_be_copied_to_stub_file}
_coconut_Expected, _coconut_MatchError, _coconut_cartesian_product, _coconut_count, _coconut_cycle, _coconut_enumerate, _coconut_flatten, _coconut_fmap, _coconut_filter, _coconut_groupsof, _coconut_ident, _coconut_lift, _coconut_map, _coconut_mapreduce, _coconut_multiset, _coconut_range, _coconut_reiterable, _coconut_reversed, _coconut_scan, _coconut_starmap, _coconut_tee, _coconut_windowsof, _coconut_zip, _coconut_zip_longest, TYPE_CHECKING, reduce, takewhile, dropwhile = Expected, MatchError, cartesian_product, count, cycle, enumerate, flatten, fmap, filter, groupsof, ident, lift, map, mapreduce, multiset, range, reiterable, reversed, scan, starmap, tee, windowsof, zip, zip_longest, False, _coconut.functools.reduce, _coconut.itertools.takewhile, _coconut.itertools.dropwhile{COMMENT.anything_added_here_should_be_copied_to_stub_file}
10 changes: 8 additions & 2 deletions coconut/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ def get_path_env_var(env_var, default):
sys.setrecursionlimit(default_recursion_limit)

# modules that numpy-like arrays can live in
pandas_numpy_modules = (
xarray_modules = (
"xarray",
)
pandas_modules = (
"pandas",
)
jax_numpy_modules = (
Expand All @@ -190,7 +193,8 @@ def get_path_env_var(env_var, default):
"numpy",
"torch",
) + (
pandas_numpy_modules
xarray_modules
+ pandas_modules
+ jax_numpy_modules
)

Expand Down Expand Up @@ -999,6 +1003,7 @@ def get_path_env_var(env_var, default):
("numpy", "py34;py<39"),
("numpy", "py39"),
("pandas", "py36"),
("xarray", "py39"),
),
"tests": (
("pytest", "py<36"),
Expand All @@ -1021,6 +1026,7 @@ def get_path_env_var(env_var, default):
("trollius", "py<3;cpy"): (2, 2),
"requests": (2, 31),
("numpy", "py39"): (1, 26),
("xarray", "py39"): (2023,),
("dataclasses", "py==36"): (0, 8),
("aenum", "py<34"): (3, 1, 15),
"pydata-sphinx-theme": (0, 14),
Expand Down
2 changes: 1 addition & 1 deletion coconut/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
VERSION = "3.0.4"
VERSION_NAME = None
# False for release, int >= 1 for develop
DEVELOP = 9
DEVELOP = 10
ALPHA = False # for pre releases rather than post releases

assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1"
Expand Down
33 changes: 29 additions & 4 deletions coconut/tests/src/extras.coco
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ from coconut.constants import (
PY34,
PY35,
PY36,
PY39,
PYPY,
) # type: ignore
from coconut._pyparsing import USE_COMPUTATION_GRAPH # type: ignore
Expand Down Expand Up @@ -664,22 +665,46 @@ def test_pandas() -> bool:
return True


def test_xarray() -> bool:
import xarray as xr
import numpy as np
def ds1 `dataset_equal` ds2 = (ds1 == ds2).all().values() |> all
da = xr.DataArray([10, 11;; 12, 13], dims=["x", "y"])
ds = xr.Dataset({"a": da, "b": da + 10})
assert ds$[0] == "a"
ds_ = [da; da + 10]
assert ds `dataset_equal` ds_ # type: ignore
ds__ = [da; da |> fmap$(.+10)]
assert ds `dataset_equal` ds__ # type: ignore
assert ds `dataset_equal` (ds |> fmap$(ident))
assert da.to_numpy() `np.array_equal` (da |> fmap$(ident) |> .to_numpy())
assert (ds |> fmap$(r -> r["a"] + r["b"]) |> .to_numpy()) `np.array_equal` np.array([30; 32;; 34; 36])
assert not all_equal(da)
assert not all_equal(ds)
assert multi_enumerate(da) |> list == [((0, 0), 10), ((0, 1), 11), ((1, 0), 12), ((1, 1), 13)]
assert cartesian_product(da.sel(x=0), da.sel(x=1)) `np.array_equal` np.array([10; 12;; 10; 13;; 11; 12;; 11; 13]) # type: ignore
return True


def test_extras() -> bool:
if not PYPY and (PY2 or PY34):
assert test_numpy() is True
print(".", end="")
if not PYPY and PY36:
assert test_pandas() is True # .
print(".", end="")
if not PYPY and PY39:
assert test_xarray() is True # ..
print(".") # newline bc we print stuff after this
assert test_setup_none() is True # ..
assert test_setup_none() is True # ...
print(".") # ditto
assert test_convenience() is True # ...
assert test_convenience() is True # ....
# everything after here uses incremental parsing, so it must come last
print(".", end="")
assert test_incremental() is True # ....
assert test_incremental() is True # .....
if IPY:
print(".", end="")
assert test_kernel() is True # .....
assert test_kernel() is True # ......
return True


Expand Down

0 comments on commit 2abb276

Please sign in to comment.