Skip to content

Commit

Permalink
fix: support dtype in __array__ methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist authored and cpcloud committed Nov 14, 2022
1 parent cbb5fea commit 1294b76
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
24 changes: 14 additions & 10 deletions ibis/backends/tests/test_client.py
Expand Up @@ -753,20 +753,24 @@ def test_default_backend():
assert re.match(rx, sql) is not None


def test_dunder_array_table(alltypes, df):
@pytest.mark.parametrize("dtype", [None, "f8"])
def test_dunder_array_table(alltypes, df, dtype):
expr = alltypes.group_by("string_col").int_col.sum().order_by("string_col")
result = np.array(expr)
expected = np.array(
df.groupby("string_col").int_col.sum().reset_index().sort_values(["string_col"])
)
result = np.asarray(expr, dtype=dtype)
expected = np.asarray(expr.execute(), dtype=dtype)
np.testing.assert_array_equal(result, expected)


@pytest.mark.broken(["dask"], reason="Dask backend duplicates data")
def test_dunder_array_column(alltypes, df):
expr = alltypes.order_by("id").head(10).int_col
result = np.array(expr)
expected = df.sort_values(["id"]).head(10).int_col
@pytest.mark.parametrize("dtype", [None, "f8"])
def test_dunder_array_column(alltypes, df, dtype):
expr = (
alltypes.group_by("string_col")
.agg(int_col=lambda _: _.int_col.sum())
.order_by("string_col")
.int_col
)
result = np.asarray(expr, dtype=dtype)
expected = np.asarray(expr.execute(), dtype=dtype)
np.testing.assert_array_equal(result, expected)


Expand Down
4 changes: 2 additions & 2 deletions ibis/expr/types/generic.py
Expand Up @@ -514,8 +514,8 @@ class Column(Value, JupyterMixin):

__array_ufunc__ = None

def __array__(self):
return self.execute().__array__()
def __array__(self, dtype=None):
return self.execute().__array__(dtype)

def __rich_console__(self, console, options):
named = self.name(self.op().name)
Expand Down
4 changes: 2 additions & 2 deletions ibis/expr/types/relations.py
Expand Up @@ -90,8 +90,8 @@ class Table(Expr, JupyterMixin):

__array_ufunc__ = None

def __array__(self):
return self.execute().__array__()
def __array__(self, dtype=None):
return self.execute().__array__(dtype)

def __contains__(self, name):
return name in self.schema()
Expand Down

0 comments on commit 1294b76

Please sign in to comment.