diff --git a/pyproject.toml b/pyproject.toml
index 88369b1..dd09088 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,7 +50,8 @@ test = [
     "awkward-pandas[complete]",
     "pytest >=6.0",
     "pytest-cov >=3.0.0",
-    "distributed"
+    "distributed",
+    "polars"
 ]
 docs = [
     "sphinx",
diff --git a/src/awkward_pandas/dask_connect.py b/src/awkward_pandas/dask_connect.py
index 4b39bc9..9fc561d 100644
--- a/src/awkward_pandas/dask_connect.py
+++ b/src/awkward_pandas/dask_connect.py
@@ -16,9 +16,12 @@ def _(x):
 
 
 try:
-    from dask.dataframe.accessor import Accessor
+    from dask_expr._accessor import Accessor
 except (ImportError, ModuleNotFoundError):
-    Accessor = object
+    try:
+        from dask.dataframe.accessor import Accessor
+    except (ImportError, ModuleNotFoundError):
+        Accessor = object
 
 
 class DaskAwkwardAccessor(Accessor):
diff --git a/src/awkward_pandas/mixin.py b/src/awkward_pandas/mixin.py
new file mode 100644
index 0000000..a2168e2
--- /dev/null
+++ b/src/awkward_pandas/mixin.py
@@ -0,0 +1,126 @@
+"""Reflected binary operators and a mixin that installs operator dunders."""
+
+import operator
+
+
+def radd(left, right):
+    return right + left
+
+
+def rsub(left, right):
+    return right - left
+
+
+def rmul(left, right):
+    return right * left
+
+
+def rdiv(left, right):
+    return right / left
+
+
+def rtruediv(left, right):
+    return right / left
+
+
+def rfloordiv(left, right):
+    return right // left
+
+
+def rmod(left, right):
+    # check if right is a string as % is the string
+    # formatting operation; this is a TypeError
+    # otherwise perform the op
+    if isinstance(right, str):
+        typ = type(left).__name__
+        raise TypeError(f"{typ} cannot perform the operation mod")
+
+    return right % left
+
+
+def rdivmod(left, right):
+    return divmod(right, left)
+
+
+def rpow(left, right):
+    return right**left
+
+
+def rand_(left, right):
+    return operator.and_(right, left)
+
+
+def ror_(left, right):
+    return operator.or_(right, left)
+
+
+def rxor(left, right):
+    return operator.xor(right, left)
+
+
+class AbstractMethodError(NotImplementedError):
+    """Raised by ArithmeticMixin factories that a subclass did not override."""
+
+
+class ArithmeticMixin:
+    """Install arithmetic, comparison and logical dunder methods on a class.
+
+    Subclasses override the ``_create_*_method`` factories, each of which
+    receives a binary function and returns the method to install, and then
+    call ``_add_all()`` once after the class is defined.
+    """
+
+    @classmethod
+    def _create_arithmetic_method(cls, op):
+        raise AbstractMethodError(cls)
+
+    @classmethod
+    def _create_comparison_method(cls, op):
+        raise AbstractMethodError(cls)
+
+    @classmethod
+    def _create_logical_method(cls, op):
+        raise AbstractMethodError(cls)
+
+    @classmethod
+    def _add_arithmetic_ops(cls) -> None:
+        setattr(cls, "__add__", cls._create_arithmetic_method(operator.add))
+        setattr(cls, "__radd__", cls._create_arithmetic_method(radd))
+        setattr(cls, "__sub__", cls._create_arithmetic_method(operator.sub))
+        setattr(cls, "__rsub__", cls._create_arithmetic_method(rsub))
+        setattr(cls, "__mul__", cls._create_arithmetic_method(operator.mul))
+        setattr(cls, "__rmul__", cls._create_arithmetic_method(rmul))
+        setattr(cls, "__pow__", cls._create_arithmetic_method(operator.pow))
+        setattr(cls, "__rpow__", cls._create_arithmetic_method(rpow))
+        setattr(cls, "__mod__", cls._create_arithmetic_method(operator.mod))
+        setattr(cls, "__rmod__", cls._create_arithmetic_method(rmod))
+        setattr(cls, "__floordiv__", cls._create_arithmetic_method(operator.floordiv))
+        setattr(cls, "__rfloordiv__", cls._create_arithmetic_method(rfloordiv))
+        setattr(cls, "__truediv__", cls._create_arithmetic_method(operator.truediv))
+        setattr(cls, "__rtruediv__", cls._create_arithmetic_method(rtruediv))
+        setattr(cls, "__divmod__", cls._create_arithmetic_method(divmod))
+        setattr(cls, "__rdivmod__", cls._create_arithmetic_method(rdivmod))
+
+    @classmethod
+    def _add_comparison_ops(cls) -> None:
+        setattr(cls, "__eq__", cls._create_comparison_method(operator.eq))
+        setattr(cls, "__ne__", cls._create_comparison_method(operator.ne))
+        setattr(cls, "__lt__", cls._create_comparison_method(operator.lt))
+        setattr(cls, "__gt__", cls._create_comparison_method(operator.gt))
+        setattr(cls, "__le__", cls._create_comparison_method(operator.le))
+        setattr(cls, "__ge__", cls._create_comparison_method(operator.ge))
+
+    @classmethod
+    def _add_logical_ops(cls) -> None:
+        setattr(cls, "__and__", cls._create_logical_method(operator.and_))
+        setattr(cls, "__rand__", cls._create_logical_method(rand_))
+        setattr(cls, "__or__", cls._create_logical_method(operator.or_))
+        setattr(cls, "__ror__", cls._create_logical_method(ror_))
+        setattr(cls, "__xor__", cls._create_logical_method(operator.xor))
+        setattr(cls, "__rxor__", cls._create_logical_method(rxor))
+
+    @classmethod
+    def _add_all(cls):
+        cls._add_logical_ops()
+        cls._add_arithmetic_ops()
+        cls._add_comparison_ops()
diff --git a/src/awkward_pandas/polars.py b/src/awkward_pandas/polars.py
new file mode 100644
index 0000000..9a67d3d
--- /dev/null
+++ b/src/awkward_pandas/polars.py
@@ -0,0 +1,110 @@
+"""Expose awkward-array operations on polars objects through ``.ak``."""
+
+import functools
+from typing import Callable, Iterable, Union
+
+import awkward as ak
+import polars as pl
+
+from awkward_pandas.mixin import ArithmeticMixin
+
+
+def _to_awkward(obj):
+    """Replace polars objects and ``.ak`` accessors by their awkward array.
+
+    Any other object is returned unchanged.
+    """
+    if isinstance(obj, AwkwardOperations):
+        return obj.array
+    if isinstance(obj, (pl.DataFrame, pl.Series)):
+        return obj.ak.array
+    return obj
+
+
+@pl.api.register_series_namespace("ak")
+@pl.api.register_dataframe_namespace("ak")
+class AwkwardOperations(ArithmeticMixin):
+    """Accessor registered as ``.ak`` on polars Series and DataFrames."""
+
+    def __init__(self, df: Union[pl.DataFrame, pl.Series]):
+        self._df = df
+
+    def __array_function__(self, *args, **kwargs):
+        return self.array.__array_function__(*args, **kwargs)
+
+    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
+        """Apply a numpy ufunc, keeping every input in its own position.
+
+        Only plain calls are supported; any other ufunc method raises
+        ``NotImplementedError``.
+        """
+        if method != "__call__":
+            raise NotImplementedError
+        return ufunc(*(_to_awkward(arg) for arg in inputs), **kwargs)
+
+    def __dir__(self) -> Iterable[str]:
+        return [
+            name
+            for name in dir(ak)
+            if not name.startswith(("_", "ak_")) and not name[0].isupper()
+        ] + ["apply", "array"]
+
+    def apply(self, fn: Callable) -> Union[pl.DataFrame, pl.Series]:
+        """Call ``fn`` on the awkward array and convert the result to polars."""
+        return ak_to_polars(fn(self.array))
+
+    def __getitem__(self, item):
+        # TODO: scalar results are not special-cased before conversion
+        return ak_to_polars(self.array[item])
+
+    @property
+    def array(self):
+        """The wrapped polars data as an awkward array, converted via arrow."""
+        return ak.from_arrow(self._df.to_arrow())
+
+    def __getattr__(self, item):
+        """Look up ``item`` in the ``awkward`` namespace and wrap it.
+
+        The wrapper passes the awkward array as first argument, converts
+        polars arguments to awkward, and converts an ``ak.Array`` result
+        back to polars.
+        """
+        if item not in dir(self):
+            raise AttributeError(item)
+        func = getattr(ak, item, None)
+        if not func:
+            raise AttributeError(item)
+
+        @functools.wraps(func)
+        def f(*others, **kwargs):
+            others = [_to_awkward(other) for other in others]
+            kwargs = {k: _to_awkward(v) for k, v in kwargs.items()}
+            ak_arr = func(self.array, *others, **kwargs)
+            if isinstance(ak_arr, ak.Array):
+                return ak_to_polars(ak_arr)
+            return ak_arr
+
+        return f
+
+    @classmethod
+    def _create_op(cls, op):
+        """Build the method that applies ``op`` and returns polars data."""
+
+        def run(self, *args, **kwargs):
+            args = [_to_awkward(arg) for arg in args]
+            return ak_to_polars(op(self.array, *args, **kwargs))
+
+        return run
+
+    # one factory serves every operator family that ArithmeticMixin installs
+    _create_arithmetic_method = _create_op
+    _create_comparison_method = _create_op
+    _create_logical_method = _create_op
+
+
+AwkwardOperations._add_all()
+
+
+def ak_to_polars(arr: ak.Array) -> Union[pl.DataFrame, pl.Series]:
+    """Convert an awkward array to polars through arrow."""
+    return pl.from_arrow(ak.to_arrow(arr, extensionarray=False))
diff --git a/tests/test_polars.py b/tests/test_polars.py
new file mode 100644
index 0000000..ff5778a
--- /dev/null
+++ b/tests/test_polars.py
@@ -0,0 +1,44 @@
+import numpy as np
+import pytest
+
+# polars is optional: skip this module before importing code that needs it
+pl = pytest.importorskip("polars")
+
+import awkward_pandas.polars  # noqa: E402, F401
+
+
+def test_simple():
+    s = pl.Series([[1, 2, 3], [], [4, 5]])
+    s2 = s.ak[:, -1:]
+    assert s2.to_list() == [[3], [], [5]]
+
+
+def test_apply():
+    s = pl.Series([[1, 2, 3], [], [4, 5]])
+    s2 = s.ak.apply(np.negative)
+    assert s2.to_list() == [[-1, -2, -3], [], [-4, -5]]
+
+
+def test_operator():
+    s = pl.Series([[1, 2, 3], [], [4, 5]])
+    s2 = s.ak + 1
+    assert s2.to_list() == [[2, 3, 4], [], [5, 6]]
+
+
+def test_operator_polars_operand():
+    s = pl.Series([[1, 2, 3], [], [4, 5]])
+    s2 = s.ak + s
+    assert s2.to_list() == [[2, 4, 6], [], [8, 10]]
+
+
+def test_ufunc():
+    s = pl.Series([[1, 2, 3], [], [4, 5]])
+    s2 = np.negative(s.ak)
+    assert s2.to_list() == [[-1, -2, -3], [], [-4, -5]]
+
+    s2 = np.add(s.ak, 1)
+    assert s2.to_list() == [[2, 3, 4], [], [5, 6]]
+
+    # the accessor does not have to be the first ufunc input
+    s2 = np.subtract(10, s.ak)
+    assert s2.to_list() == [[9, 8, 7], [], [6, 5]]