Skip to content

Commit

Permalink
Add ops and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant committed Apr 16, 2024
1 parent c6b0677 commit 618f0e2
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 2 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ test = [
"awkward-pandas[complete]",
"pytest >=6.0",
"pytest-cov >=3.0.0",
"distributed"
"distributed",
"polars"
]
docs = [
"sphinx",
Expand Down
117 changes: 117 additions & 0 deletions src/awkward_pandas/mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import operator


def radd(left, right):
return right + left


def rsub(left, right):
return right - left


def rmul(left, right):
return right * left


def rdiv(left, right):
return right / left


def rtruediv(left, right):
return right / left


def rfloordiv(left, right):
return right // left


def rmod(left, right):
# check if right is a string as % is the string
# formatting operation; this is a TypeError
# otherwise perform the op
if isinstance(right, str):
typ = type(left).__name__
raise TypeError(f"{typ} cannot perform the operation mod")

return right % left


def rdivmod(left, right):
return divmod(right, left)


def rpow(left, right):
return right**left


def rand_(left, right):
return operator.and_(right, left)


def ror_(left, right):
return operator.or_(right, left)


def rxor(left, right):
return operator.xor(right, left)


class AbstractMethodError(NotImplementedError):
pass


class ArithmeticMixin:
@classmethod
def _create_arithmetic_method(cls, op):
raise AbstractMethodError(cls)

@classmethod
def _create_comparison_method(cls, op):
raise AbstractMethodError(cls)

@classmethod
def _create_logical_method(cls, op):
raise AbstractMethodError(cls)

@classmethod
def _add_arithmetic_ops(cls) -> None:
setattr(cls, "__add__", cls._create_arithmetic_method(operator.add))
setattr(cls, "__radd__", cls._create_arithmetic_method(radd))
setattr(cls, "__sub__", cls._create_arithmetic_method(operator.sub))
setattr(cls, "__rsub__", cls._create_arithmetic_method(rsub))
setattr(cls, "__mul__", cls._create_arithmetic_method(operator.mul))
setattr(cls, "__rmul__", cls._create_arithmetic_method(rmul))
setattr(cls, "__pow__", cls._create_arithmetic_method(operator.pow))
setattr(cls, "__rpow__", cls._create_arithmetic_method(rpow))
setattr(cls, "__mod__", cls._create_arithmetic_method(operator.mod))
setattr(cls, "__rmod__", cls._create_arithmetic_method(rmod))
setattr(cls, "__floordiv__", cls._create_arithmetic_method(operator.floordiv))
setattr(cls, "__rfloordiv__", cls._create_arithmetic_method(rfloordiv))
setattr(cls, "__truediv__", cls._create_arithmetic_method(operator.truediv))
setattr(cls, "__rtruediv__", cls._create_arithmetic_method(rtruediv))
setattr(cls, "__divmod__", cls._create_arithmetic_method(divmod))
setattr(cls, "__rdivmod__", cls._create_arithmetic_method(rdivmod))

@classmethod
def _add_comparison_ops(cls) -> None:
setattr(cls, "__eq__", cls._create_comparison_method(operator.eq))
setattr(cls, "__ne__", cls._create_comparison_method(operator.ne))
setattr(cls, "__lt__", cls._create_comparison_method(operator.lt))
setattr(cls, "__gt__", cls._create_comparison_method(operator.gt))
setattr(cls, "__le__", cls._create_comparison_method(operator.le))
setattr(cls, "__ge__", cls._create_comparison_method(operator.ge))

@classmethod
def _add_logical_ops(cls) -> None:
setattr(cls, "__and__", cls._create_logical_method(operator.and_))
setattr(cls, "__rand__", cls._create_logical_method(rand_))
setattr(cls, "__or__", cls._create_logical_method(operator.or_))
setattr(cls, "__ror__", cls._create_logical_method(ror_))
setattr(cls, "__xor__", cls._create_logical_method(operator.xor))
setattr(cls, "__rxor__", cls._create_logical_method(rxor))

@classmethod
def _add_all(cls):
cls._add_logical_ops()
cls._add_arithmetic_ops()
cls._add_comparison_ops()
26 changes: 25 additions & 1 deletion src/awkward_pandas/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,23 @@
import awkward as ak
import polars as pl

from awkward_pandas.mixin import ArithmeticMixin


@pl.api.register_series_namespace("ak")
@pl.api.register_dataframe_namespace("ak")
class AwkwardOperations:
class AwkwardOperations(ArithmeticMixin):
def __init__(self, df: pl.DataFrame):
self._df = df

def __array_function__(self, *args, **kwargs):
return self.array.__array_function__(*args, **kwargs)

def __array_ufunc__(self, *args, **kwargs):
if args[1] == "__call__":
return args[0](self.array, *args[3:], **kwargs)
raise NotImplementedError

def __dir__(self) -> Iterable[str]:
return [
_
Expand Down Expand Up @@ -62,6 +72,20 @@ def f(*others, **kwargs):
raise AttributeError(item)
return f

@classmethod
def _create_op(cls, op):
def run(self, *args, **kwargs):
return ak_to_polars(op(self.array, *args, **kwargs))

return run

_create_arithmetic_method = _create_op
_create_comparison_method = _create_op
_create_logical_method = _create_op


AwkwardOperations._add_all()


def ak_to_polars(arr: ak.Array) -> pl.DataFrame | pl.Series:
return pl.from_arrow(ak.to_arrow(arr, extensionarray=False))
33 changes: 33 additions & 0 deletions tests/test_polars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np
import pytest

import awkward_pandas.polars # noqa: F401

pl = pytest.importorskip("polars")


def test_simple():
s = pl.Series([[1, 2, 3], [], [4, 5]])
s2 = s.ak[:, -1:]
assert s2.to_list() == [[3], [], [5]]


def test_apply():
s = pl.Series([[1, 2, 3], [], [4, 5]])
s2 = s.ak.apply(np.negative)
assert s2.to_list() == [[-1, -2, -3], [], [-4, -5]]


def test_operator():
s = pl.Series([[1, 2, 3], [], [4, 5]])
s2 = s.ak + 1
assert s2.to_list() == [[2, 3, 4], [], [5, 6]]


def test_ufunc():
s = pl.Series([[1, 2, 3], [], [4, 5]])
s2 = np.negative(s.ak)
assert s2.to_list() == [[-1, -2, -3], [], [-4, -5]]

s2 = np.add(s.ak, 1)
assert s2.to_list() == [[2, 3, 4], [], [5, 6]]

0 comments on commit 618f0e2

Please sign in to comment.