Skip to content

Commit

Permalink
Add binary dt methods
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant committed May 8, 2024
1 parent d3ba4c4 commit afd1761
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 53 deletions.
93 changes: 40 additions & 53 deletions src/awkward_pandas/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@ def func(x, **kwargs):
return ak.transform(func, arr)


def _run_binary(layout1, layout2, op, kind=None, **kw):
if layout1.is_leaf and (kind is None or layout1.dtype.kind == kind):
return ak.str._apply_through_arrow(op, layout1, layout2, **kw)


def run_binary(arr: ak.Array, other, op, kind=None, **kw) -> ak.Array:
def func(arrays, **kwargs):
x, y = arrays
return _run_binary(x, y, op, kind=kind, **kw)

return ak.transform(func, arr, other)


def dec(func, mode="unary"):
# TODO: require kind= on functions that need timestamps

Expand All @@ -36,6 +49,18 @@ def f(self, *args, **kwargs):
run_unary(self.accessor.array, func, **kwargs)
)

elif mode == "binary":

@functools.wraps(func)
def f(self, other, *args, **kwargs):
if args:
sig = list(inspect.signature(func).parameters)[2:]
kwargs.update({k: arg for k, arg in zip(sig, args)})

return self.accessor.to_output(
run_binary(self.accessor.array, other.ak.array, func, **kwargs)
)

else:
raise NotImplementedError
return f
Expand Down Expand Up @@ -73,59 +98,21 @@ def __init__(self, accessor) -> None:
year = dec(pc.year)
year_month_day = dec(pc.year_month_day)

# the rest are binary
def day_time_interval_between(self, end):
raise NotImplementedError("TODO")

def days_between(self, end):
raise NotImplementedError("TODO")

def hours_between(self, end):
raise NotImplementedError("TODO")

def microseconds_between(self, end):
raise NotImplementedError("TODO")

def milliseconds_between(self, end):
raise NotImplementedError("TODO")

def minutes_between(self, end):
raise NotImplementedError("TODO")

def month_day_nano_interval_between(self, end):
raise NotImplementedError("TODO")

def month_interval_between(self, end):
raise NotImplementedError("TODO")

def nanoseconds_between(self, end):
return self.accessor.to_output(
pc.nanoseconds_between(self.accessor.arrow, end.ak.arrow),
)

def quarters_between(self, end):
raise NotImplementedError("TODO")

def seconds_between(self, end):
return self.accessor.to_output(
pc.seconds_between(self.accessor.arrow, end.ak.arrow)
)

def weeks_between(
self,
end,
/,
*,
count_from_zero=True,
week_start=1,
options=None,
):
raise NotImplementedError("TODO")

def years_between(self, end):
return self.accessor.to_output(
pc.years_between(self.accessor.arrow, end.ak.arrow)
)
day_time_interval_between = dec(pc.day_time_interval_between, mode="binary")
days_between = dec(pc.days_between, mode="binary")
hours_between = dec(pc.hours_between, mode="binary")
microseconds_between = dec(pc.microseconds_between, mode="binary")
milliseconds_between = dec(pc.milliseconds_between, mode="binary")
minutes_between = dec(pc.minutes_between, mode="binary")
month_day_nano_interval_between = dec(
pc.month_day_nano_interval_between, mode="binary"
)
month_interval_between = dec(pc.month_interval_between, mode="binary")
nanoseconds_between = dec(pc.nanoseconds_between, mode="binary")
quarters_between = dec(pc.quarters_between, mode="binary")
seconds_between = dec(pc.seconds_between, mode="binary")
weeks_between = dec(pc.weeks_between, mode="binary")
years_between = dec(pc.years_between, mode="binary")


def _to_arrow(array):
Expand Down
27 changes: 27 additions & 0 deletions tests/test_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,30 @@ def test_bad_type():
s = pd.Series([[0, 1], [1, 0], [2]])
with pytest.raises(NotImplementedError):
s.ak.dt.second()


def test_binary():
s = pd.Series([[0, 1], [1, 0], [2]])
s2 = s.ak + 1
ts1 = s.ak.dt.cast("timestamp[s]")
ts2 = s2.ak.dt.cast("timestamp[s]")

out = ts1.ak.dt.nanoseconds_between(ts2)
assert out.tolist() == [
[1000000000, 1000000000],
[1000000000, 1000000000],
[1000000000],
]
assert str(out.dtype) == "list<item: int64>[pyarrow]"


def test_binary_with_kwargs():
s = pd.Series([[0, 1], [1, 0], [2]])
s2 = s.ak + int(24 * 3600 * 7 * 2.5)
ts1 = s.ak.dt.cast("timestamp[s]")
ts2 = s2.ak.dt.cast("timestamp[s]")

out = ts1.ak.dt.weeks_between(ts2, count_from_zero=False, week_start=2)
assert out.tolist() == [[2, 2], [2, 2], [2]]
out = ts1.ak.dt.weeks_between(ts2, count_from_zero=False, week_start=5)
assert out.tolist() == [[3, 3], [3, 3], [3]]

0 comments on commit afd1761

Please sign in to comment.