Skip to content

Commit

Permalink
fix(api): support passing functools.partial objects to array .map
Browse files Browse the repository at this point in the history
…/`.filter` methods
  • Loading branch information
jcrist committed Sep 8, 2023
1 parent 519a9e0 commit 28f45d0
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 7 deletions.
15 changes: 13 additions & 2 deletions ibis/backends/tests/test_array.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import functools
import os

import numpy as np
Expand Down Expand Up @@ -527,9 +528,14 @@ def test_array_slice(backend, start, stop):
)
def test_array_map(backend, con, input, output):
t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))
expected = pd.DataFrame(output)

expr = t.select(a=t.a.map(lambda x: x + 1))
result = con.execute(expr)
expected = pd.DataFrame(output)
backend.assert_frame_equal(result, expected)

expr = t.select(a=t.a.map(functools.partial(lambda x, y: x + y, y=1)))
result = con.execute(expr)
backend.assert_frame_equal(result, expected)


Expand All @@ -555,9 +561,14 @@ def test_array_map(backend, con, input, output):
)
def test_array_filter(backend, con, input, output):
t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))
expected = pd.DataFrame(output)

expr = t.select(a=t.a.filter(lambda x: x > 1))
result = con.execute(expr)
expected = pd.DataFrame(output)
backend.assert_frame_equal(result, expected)

expr = t.select(a=t.a.filter(functools.partial(lambda x, y: x > y, y=1)))
result = con.execute(expr)
backend.assert_frame_equal(result, expected)


Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/operations/arrays.py
Expand Up @@ -80,7 +80,7 @@ class ArrayApply(Value):

@attribute
def parameter(self):
(name,) = self.func.__signature__.parameters.keys()
name = next(iter(self.func.__signature__.parameters.keys()))
return name

@attribute
Expand Down
8 changes: 4 additions & 4 deletions ibis/expr/types/arrays.py
Expand Up @@ -396,8 +396,8 @@ def map(self, func: Callable[[ir.Value], ir.Value]) -> ir.ArrayValue:
"""

@functools.wraps(func)
def wrapped(x):
return func(x.to_expr())
def wrapped(x, **kwargs):
return func(x.to_expr(), **kwargs)

return ops.ArrayMap(self, func=wrapped).to_expr()

Expand Down Expand Up @@ -444,8 +444,8 @@ def filter(
"""

@functools.wraps(predicate)
def wrapped(x):
return predicate(x.to_expr())
def wrapped(x, **kwargs):
return predicate(x.to_expr(), **kwargs)

return ops.ArrayFilter(self, func=wrapped).to_expr()

Expand Down
21 changes: 21 additions & 0 deletions ibis/tests/expr/test_value_exprs.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import functools
import ipaddress
import operator
import uuid
Expand Down Expand Up @@ -1536,12 +1537,32 @@ def test_array_map():
assert result_float.type() == dt.Array(dt.float64)


def test_array_map_partial():
arr = ibis.array([1, 2, 3])

def add(x, y):
return x + y

result = arr.map(functools.partial(add, y=2))
assert result.type() == dt.Array(dt.int16)


def test_array_filter():
arr = ibis.array([1, 2, 3])
result = arr.filter(is_negative)
assert result.type() == arr.type()


def test_array_filter_partial():
arr = ibis.array([1, 2, 3])

def equal(x, y):
return x == y

result = arr.filter(functools.partial(equal, y=2))
assert result.type() == arr.type()


@pytest.mark.parametrize(
("func", "expected_type"),
[
Expand Down

0 comments on commit 28f45d0

Please sign in to comment.