From 618f0e2f2bef3f1a9d83cf5a7ec4e3a4dc073119 Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Tue, 16 Apr 2024 13:07:12 -0400 Subject: [PATCH] Add ops and tests --- pyproject.toml | 3 +- src/awkward_pandas/mixin.py | 117 +++++++++++++++++++++++++++++++++++ src/awkward_pandas/polars.py | 26 +++++++- tests/test_polars.py | 33 ++++++++++ 4 files changed, 177 insertions(+), 2 deletions(-) create mode 100644 src/awkward_pandas/mixin.py create mode 100644 tests/test_polars.py diff --git a/pyproject.toml b/pyproject.toml index eca4a6c..1c56f82 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,8 @@ test = [ "awkward-pandas[complete]", "pytest >=6.0", "pytest-cov >=3.0.0", - "distributed" + "distributed", + "polars" ] docs = [ "sphinx", diff --git a/src/awkward_pandas/mixin.py b/src/awkward_pandas/mixin.py new file mode 100644 index 0000000..a2168e2 --- /dev/null +++ b/src/awkward_pandas/mixin.py @@ -0,0 +1,117 @@ +import operator + + +def radd(left, right): + return right + left + + +def rsub(left, right): + return right - left + + +def rmul(left, right): + return right * left + + +def rdiv(left, right): + return right / left + + +def rtruediv(left, right): + return right / left + + +def rfloordiv(left, right): + return right // left + + +def rmod(left, right): + # check if right is a string as % is the string + # formatting operation; this is a TypeError + # otherwise perform the op + if isinstance(right, str): + typ = type(left).__name__ + raise TypeError(f"{typ} cannot perform the operation mod") + + return right % left + + +def rdivmod(left, right): + return divmod(right, left) + + +def rpow(left, right): + return right**left + + +def rand_(left, right): + return operator.and_(right, left) + + +def ror_(left, right): + return operator.or_(right, left) + + +def rxor(left, right): + return operator.xor(right, left) + + +class AbstractMethodError(NotImplementedError): + pass + + +class ArithmeticMixin: + @classmethod + def _create_arithmetic_method(cls, op): + raise AbstractMethodError(cls) + + @classmethod + def _create_comparison_method(cls, op): + raise AbstractMethodError(cls) + + @classmethod + def _create_logical_method(cls, op): + raise AbstractMethodError(cls) + + @classmethod + def _add_arithmetic_ops(cls) -> None: + setattr(cls, "__add__", cls._create_arithmetic_method(operator.add)) + setattr(cls, "__radd__", cls._create_arithmetic_method(radd)) + setattr(cls, "__sub__", cls._create_arithmetic_method(operator.sub)) + setattr(cls, "__rsub__", cls._create_arithmetic_method(rsub)) + setattr(cls, "__mul__", cls._create_arithmetic_method(operator.mul)) + setattr(cls, "__rmul__", cls._create_arithmetic_method(rmul)) + setattr(cls, "__pow__", cls._create_arithmetic_method(operator.pow)) + setattr(cls, "__rpow__", cls._create_arithmetic_method(rpow)) + setattr(cls, "__mod__", cls._create_arithmetic_method(operator.mod)) + setattr(cls, "__rmod__", cls._create_arithmetic_method(rmod)) + setattr(cls, "__floordiv__", cls._create_arithmetic_method(operator.floordiv)) + setattr(cls, "__rfloordiv__", cls._create_arithmetic_method(rfloordiv)) + setattr(cls, "__truediv__", cls._create_arithmetic_method(operator.truediv)) + setattr(cls, "__rtruediv__", cls._create_arithmetic_method(rtruediv)) + setattr(cls, "__divmod__", cls._create_arithmetic_method(divmod)) + setattr(cls, "__rdivmod__", cls._create_arithmetic_method(rdivmod)) + + @classmethod + def _add_comparison_ops(cls) -> None: + setattr(cls, "__eq__", cls._create_comparison_method(operator.eq)) + setattr(cls, "__ne__", cls._create_comparison_method(operator.ne)) + setattr(cls, "__lt__", cls._create_comparison_method(operator.lt)) + setattr(cls, "__gt__", cls._create_comparison_method(operator.gt)) + setattr(cls, "__le__", cls._create_comparison_method(operator.le)) + setattr(cls, "__ge__", cls._create_comparison_method(operator.ge)) + + @classmethod + def _add_logical_ops(cls) -> None: + setattr(cls, "__and__", cls._create_logical_method(operator.and_)) + setattr(cls, "__rand__", cls._create_logical_method(rand_)) + setattr(cls, "__or__", cls._create_logical_method(operator.or_)) + setattr(cls, "__ror__", cls._create_logical_method(ror_)) + setattr(cls, "__xor__", cls._create_logical_method(operator.xor)) + setattr(cls, "__rxor__", cls._create_logical_method(rxor)) + + @classmethod + def _add_all(cls): + cls._add_logical_ops() + cls._add_arithmetic_ops() + cls._add_comparison_ops() diff --git a/src/awkward_pandas/polars.py b/src/awkward_pandas/polars.py index 0d14ae8..c2e90a6 100644 --- a/src/awkward_pandas/polars.py +++ b/src/awkward_pandas/polars.py @@ -4,13 +4,23 @@ import awkward as ak import polars as pl +from awkward_pandas.mixin import ArithmeticMixin + @pl.api.register_series_namespace("ak") @pl.api.register_dataframe_namespace("ak") -class AwkwardOperations: +class AwkwardOperations(ArithmeticMixin): def __init__(self, df: pl.DataFrame): self._df = df + def __array_function__(self, *args, **kwargs): + return self.array.__array_function__(*args, **kwargs) + + def __array_ufunc__(self, *args, **kwargs): + if args[1] == "__call__": + return args[0](self.array, *args[3:], **kwargs) + raise NotImplementedError + def __dir__(self) -> Iterable[str]: return [ _ @@ -62,6 +72,20 @@ def f(*others, **kwargs): raise AttributeError(item) return f + @classmethod + def _create_op(cls, op): + def run(self, *args, **kwargs): + return ak_to_polars(op(self.array, *args, **kwargs)) + + return run + + _create_arithmetic_method = _create_op + _create_comparison_method = _create_op + _create_logical_method = _create_op + + +AwkwardOperations._add_all() + def ak_to_polars(arr: ak.Array) -> pl.DataFrame | pl.Series: return pl.from_arrow(ak.to_arrow(arr, extensionarray=False)) diff --git a/tests/test_polars.py b/tests/test_polars.py new file mode 100644 index 0000000..ff5778a --- /dev/null +++ b/tests/test_polars.py @@ -0,0 +1,33 @@ +import numpy as np +import pytest + +import awkward_pandas.polars # noqa: F401 + +pl = pytest.importorskip("polars") + + +def test_simple(): + s = pl.Series([[1, 2, 3], [], [4, 5]]) + s2 = s.ak[:, -1:] + assert s2.to_list() == [[3], [], [5]] + + +def test_apply(): + s = pl.Series([[1, 2, 3], [], [4, 5]]) + s2 = s.ak.apply(np.negative) + assert s2.to_list() == [[-1, -2, -3], [], [-4, -5]] + + +def test_operator(): + s = pl.Series([[1, 2, 3], [], [4, 5]]) + s2 = s.ak + 1 + assert s2.to_list() == [[2, 3, 4], [], [5, 6]] + + +def test_ufunc(): + s = pl.Series([[1, 2, 3], [], [4, 5]]) + s2 = np.negative(s.ak) + assert s2.to_list() == [[-1, -2, -3], [], [-4, -5]] + + s2 = np.add(s.ak, 1) + assert s2.to_list() == [[2, 3, 4], [], [5, 6]]