171 changes: 164 additions & 7 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def test_array_slice(con, start, stop):
@pytest.mark.notimpl(
["sqlite", "mysql"],
raises=com.IbisTypeError,
reason="argument passes none of the following rules:....",
reason="argument passes none of the following rules: ...",
)
def test_array_map(backend, con):
t = ibis.memtable(
Expand All @@ -493,24 +493,21 @@ def test_array_map(backend, con):
@pytest.mark.notimpl(
[
"bigquery",
"dask",
"datafusion",
"impala",
"mssql",
"pandas",
"polars",
"postgres",
"snowflake",
],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
["dask", "pandas"],
raises=com.OperationNotDefinedError,
reason="Operation 'ArrayMap' is not implemented for this backend'",
)
@pytest.mark.notimpl(
["sqlite", "mysql"],
raises=com.IbisTypeError,
reason="argument passes none of the following rules:....",
reason="argument passes none of the following rules: ...",
)
def test_array_filter(backend, con):
t = ibis.memtable(
Expand All @@ -520,3 +517,163 @@ def test_array_filter(backend, con):
result = con.execute(expr)
expected = pd.DataFrame({"a": [[2], [4]]})
backend.assert_frame_equal(result, expected)


@pytest.mark.notimpl(
    ["bigquery", "datafusion", "mssql", "pandas", "polars", "postgres"],
    raises=com.OperationNotDefinedError,
)
# NOTE(review): "datafusion" is also listed in the marker above with a
# different expected exception -- confirm which of the two is intended.
@pytest.mark.notimpl(["datafusion"], raises=Exception)
@pytest.mark.notimpl(
    ["dask"], raises=KeyError, reason="array_types table isn't defined"
)
@pytest.mark.never(["impala"], reason="array_types table isn't defined")
@pytest.mark.notimpl(
    ["sqlite", "mysql"],
    raises=com.IbisTypeError,
    reason="argument passes none of the following rules:....",
)
def test_array_contains(backend, con):
    """Compare ``contains(1)`` on ``array_types.x`` with a Python ``in`` check.

    The expected series is built by executing the ``x`` column and testing
    ``1 in lst`` for every row on the Python side.
    """
    array_types = backend.array_types
    contains_one = array_types.x.contains(1)
    actual = con.execute(contains_one)
    wanted = array_types.x.execute().map(lambda values: 1 in values)
    backend.assert_series_equal(actual, wanted, check_names=False)


@pytest.mark.notimpl(
    [
        "bigquery",
        "dask",
        "datafusion",
        "impala",
        "mssql",
        "pandas",
        "polars",
        "postgres",
    ],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
    ["sqlite", "mysql"],
    raises=com.IbisTypeError,
    reason="argument passes none of the following rules:....",
)
def test_array_position(backend, con):
    """Check ``index(42)`` on an in-memory array column.

    Rows that do not hold 42 (including empty arrays) are expected to give
    -1, and the row ``[42, 42]`` is expected to give 0. Dtypes are not
    compared.
    """
    arrays = ibis.memtable({"a": [[1], [], [42, 42], []]})
    position_of_42 = arrays.a.index(42)
    actual = con.execute(position_of_42)
    wanted = pd.Series([-1, -1, 0, -1], dtype="object")
    backend.assert_series_equal(
        actual,
        wanted,
        check_names=False,
        check_dtype=False,
    )


@pytest.mark.notimpl(
    [
        "bigquery",
        "dask",
        "datafusion",
        "impala",
        "mssql",
        "pandas",
        "polars",
        "postgres",
    ],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
    ["sqlite", "mysql"],
    raises=com.IbisTypeError,
    reason="argument passes none of the following rules:....",
)
def test_array_remove(backend, con):
    """Check ``remove(2)`` on an in-memory array column.

    Every occurrence of 2 is expected to be dropped (``[2, 2]`` becomes
    ``[]``) and empty arrays are expected to stay empty.
    """
    arrays = ibis.memtable({"a": [[3, 2], [], [42, 2], [2, 2], []]})
    without_twos = arrays.a.remove(2)
    actual = con.execute(without_twos)
    wanted = pd.Series([[3], [], [42], [], []], dtype="object")
    backend.assert_series_equal(actual, wanted, check_names=False)


@pytest.mark.notimpl(
    [
        "bigquery",
        "dask",
        "datafusion",
        "impala",
        "mssql",
        "pandas",
        "polars",
        "postgres",
    ],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
    ["sqlite", "mysql"],
    raises=com.IbisTypeError,
    reason="argument passes none of the following rules:....",
)
def test_array_unique(backend, con):
    """Check ``unique()`` on an in-memory array column.

    Each result row is converted to a ``set`` before comparison, so the test
    does not depend on the order of the elements the backend returns.
    """
    arrays = ibis.memtable({"a": [[1, 3, 3], [], [42, 42], []]})
    deduplicated = arrays.a.unique()
    actual = con.execute(deduplicated).map(set, na_action="ignore")
    wanted = pd.Series([{3, 1}, set(), {42}, set()], dtype="object")
    backend.assert_series_equal(actual, wanted, check_names=False)


@pytest.mark.notimpl(
    [
        "bigquery",
        "dask",
        "datafusion",
        "impala",
        "mssql",
        "pandas",
        "polars",
        "postgres",
    ],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
    ["sqlite", "mysql"],
    raises=com.IbisTypeError,
    reason="argument passes none of the following rules:....",
)
def test_array_sort(backend, con):
    """Check ``sort()`` on an in-memory array column.

    ``[3, 2]`` is expected to come back as ``[2, 3]``; empty arrays and
    arrays of equal elements are expected to be unchanged.
    """
    arrays = ibis.memtable({"a": [[3, 2], [], [42, 42], []]})
    ordered = arrays.a.sort()
    actual = con.execute(ordered)
    wanted = pd.Series([[2, 3], [], [42, 42], []], dtype="object")
    backend.assert_series_equal(actual, wanted, check_names=False)


@pytest.mark.notimpl(
    [
        "bigquery",
        "dask",
        "datafusion",
        "impala",
        "mssql",
        "pandas",
        "polars",
        "postgres",
    ],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.broken(
    ["trino", "pyspark"],
    raises=AssertionError,
    reason="array_distinct([NULL]) seems to differ from other backends",
)
@pytest.mark.notimpl(
    ["sqlite", "mysql"],
    raises=com.IbisTypeError,
    reason="argument passes none of the following rules:....",
)
def test_array_union(con):
    """Check ``union`` of two in-memory array columns row by row.

    Result rows are converted to sets, so element order is ignored. The
    comparison is done by hand: first the lengths, then each row in turn so
    that a failure names the offending row.
    """
    arrays = ibis.memtable({"a": [[3, 2], [], []], "b": [[1, 3], [None], [5]]})
    unioned = arrays.a.union(arrays.b)
    actual = con.execute(unioned).map(set, na_action="ignore")
    wanted = pd.Series([{1, 2, 3}, set(), {5}], dtype="object")
    assert len(actual) == len(wanted)

    for i, (got_row, wanted_row) in enumerate(zip(actual, wanted)):
        assert got_row == wanted_row, f"row {i:d} differs"
24 changes: 20 additions & 4 deletions ibis/backends/trino/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,10 @@ def _timestamp_from_unix(t, op):
return sa.cast(res, t.get_sqla_type(op.output_dtype))


# `if` is a Python keyword, so `sa.func.if` cannot be written as an attribute
# access; fetch the SQL function generator by name once at module level.
if_ = getattr(sa.func, "if")


def _neg_idx_to_pos(array, idx):
    """Build a SQL expression that maps a negative array index to a position.

    The generated expression is
    ``if(idx < 0, cardinality(array) + greatest(idx, -cardinality(array)), idx)``:
    a negative ``idx`` is clamped to ``-cardinality(array)`` and then offset
    by the array length, while a non-negative ``idx`` is passed through
    unchanged.

    Parameters
    ----------
    array
        SQLAlchemy expression for the array being indexed.
    idx
        Index expression; it must support ``<`` comparison with an integer.

    Returns
    -------
    The SQLAlchemy function-call expression described above.
    """
    # Uses the module-level `if_` rather than re-resolving it on every call.
    arg_length = sa.func.cardinality(array)
    return if_(idx < 0, arg_length + sa.func.greatest(idx, -arg_length), idx)

Expand Down Expand Up @@ -183,7 +185,6 @@ def _unnest(t, op):


def _where(t, op):
if_ = getattr(sa.func, "if")
return if_(
t.translate(op.bool_expr),
t.translate(op.true_expr),
Expand Down Expand Up @@ -277,6 +278,23 @@ def _array_filter(t, op):
lambda arg, times: sa.func.flatten(sa.func.repeat(arg, times)), 2
),
ops.ArraySlice: _array_slice,
ops.ArrayMap: _array_map,
ops.ArrayFilter: _array_filter,
ops.ArrayContains: fixed_arity(
lambda arr, el: if_(
arr != sa.null(),
sa.func.coalesce(sa.func.contains(arr, el), sa.false()),
sa.null(),
),
2,
),
ops.ArrayPosition: fixed_arity(
lambda lst, el: sa.func.array_position(lst, el) - 1, 2
),
ops.ArrayDistinct: fixed_arity(sa.func.array_distinct, 1),
ops.ArraySort: fixed_arity(sa.func.array_sort, 1),
ops.ArrayRemove: fixed_arity(sa.func.array_remove, 2),
ops.ArrayUnion: fixed_arity(sa.func.array_union, 2),
ops.JSONGetItem: _json_get_item,
ops.ExtractDayOfYear: unary(sa.func.day_of_year),
ops.ExtractWeekOfYear: unary(sa.func.week_of_year),
Expand Down Expand Up @@ -381,8 +399,6 @@ def _array_filter(t, op):
lambda sep, arr: sa.func.array_join(arr, sep), 2
),
ops.StartsWith: fixed_arity(sa.func.starts_with, 2),
ops.ArrayMap: _array_map,
ops.ArrayFilter: _array_filter,
ops.Argument: lambda _, op: sa.literal_column(op.name),
}
)
Expand Down
3 changes: 3 additions & 0 deletions ibis/expr/datatypes/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ def is_multipoint(self) -> bool:
def is_multipolygon(self) -> bool:
    """Return whether this object is an instance of `MultiPolygon`."""
    return isinstance(self, MultiPolygon)

def is_nested(self) -> bool:
    """Return whether this object is an `Array`, `Map`, `Struct` or `Set`."""
    nested_types = (Array, Map, Struct, Set)
    return isinstance(self, nested_types)

def is_null(self) -> bool:
    """Return whether this object is an instance of `Null`."""
    return isinstance(self, Null)

Expand Down
51 changes: 51 additions & 0 deletions ibis/expr/operations/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,54 @@ def output_dtype(self):
return self.arg.output_dtype.value_type

output_shape = rlz.Shape.COLUMNAR


@public
class ArrayContains(Value):
    """Operation node for testing whether an array contains a value.

    Takes an array `arg` and a value `other`; the result dtype is boolean.
    """

    # Field order is significant: nodes are constructed positionally as
    # `ArrayContains(array, value)`.
    arg = rlz.array
    other = rlz.any

    output_dtype = dt.boolean
    # NOTE(review): presumably columnar if any argument is columnar --
    # confirm against `rlz.shape_like`.
    output_shape = rlz.shape_like("args")


@public
class ArrayPosition(Value):
    """Operation node for the position of a value within an array.

    Takes an array `arg` and a value `other`; the result dtype is int64.
    """

    # Field order is significant: nodes are constructed positionally as
    # `ArrayPosition(array, value)`.
    arg = rlz.array
    other = rlz.any

    output_dtype = dt.int64
    # NOTE(review): presumably columnar if any argument is columnar --
    # confirm against `rlz.shape_like`.
    output_shape = rlz.shape_like("args")


@public
class ArrayRemove(Value):
    """Operation node for removing a value from an array.

    Takes an array `arg` and a value `other`; the result dtype is derived
    from `arg`.
    """

    # Field order is significant: nodes are constructed positionally as
    # `ArrayRemove(array, value)`.
    arg = rlz.array
    other = rlz.any

    # The output dtype is taken from `arg` only, not from `other`.
    output_dtype = rlz.dtype_like("arg")
    # NOTE(review): presumably columnar if any argument is columnar --
    # confirm against `rlz.shape_like`.
    output_shape = rlz.shape_like("args")


@public
class ArrayDistinct(Value):
    """Operation node for the distinct elements of an array.

    Takes a single array `arg`; both the result dtype and the result shape
    are derived from `arg`.
    """

    arg = rlz.array
    output_dtype = rlz.dtype_like("arg")
    output_shape = rlz.shape_like("arg")


@public
class ArraySort(Value):
    """Operation node for sorting the elements of an array.

    Takes a single array `arg`; both the result dtype and the result shape
    are derived from `arg`.
    """

    arg = rlz.array

    output_dtype = rlz.dtype_like("arg")
    output_shape = rlz.shape_like("arg")


@public
class ArrayUnion(Value):
    """Operation node for the union of two arrays.

    Takes two arrays, `left` and `right`; the result dtype and shape are
    derived from both arguments.
    """

    # Field order is significant: nodes are constructed positionally as
    # `ArrayUnion(left, right)`.
    left = rlz.array
    right = rlz.array

    # NOTE(review): presumably resolves to a dtype common to both arrays --
    # confirm against `rlz.dtype_like`.
    output_dtype = rlz.dtype_like("args")
    output_shape = rlz.shape_like("args")
288 changes: 288 additions & 0 deletions ibis/expr/types/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,294 @@ def filter(self, predicate: Callable[[ir.Value], ir.BooleanValue]) -> ir.ArrayVa
"""
return ops.ArrayFilter(self, func=predicate).to_expr()

def contains(self, other: ir.Value) -> ir.BooleanValue:
    """Return whether the array contains `other`.

    Parameters
    ----------
    other
        Ibis expression to look for in `self`

    Returns
    -------
    BooleanValue
        Whether `other` is contained in `self`

    Examples
    --------
    >>> import ibis
    >>> ibis.options.interactive = True
    >>> t = ibis.memtable({"arr": [[1], [], [42, 42], None]})
    >>> t
    ┏━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ arr                  ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━┩
    │ array<int8>          │
    ├──────────────────────┤
    │ [1]                  │
    │ []                   │
    │ [42, 42]             │
    │ NULL                 │
    └──────────────────────┘
    >>> t.arr.contains(42)
    ┏━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ ArrayContains(arr, 42) ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━━━┩
    │ boolean                │
    ├────────────────────────┤
    │ False                  │
    │ False                  │
    │ True                   │
    │ NULL                   │
    └────────────────────────┘
    >>> t.arr.contains(None)
    ┏━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ ArrayContains(arr, None) ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │ boolean                  │
    ├──────────────────────────┤
    │ NULL                     │
    │ NULL                     │
    │ NULL                     │
    │ NULL                     │
    └──────────────────────────┘
    """
    node = ops.ArrayContains(self, other)
    return node.to_expr()

def index(self, other: ir.Value) -> ir.IntegerValue:
    """Return the position of `other` in an array.

    Parameters
    ----------
    other
        Ibis expression to look for in `self`

    Returns
    -------
    IntegerValue
        The position of `other` in `self`

    Examples
    --------
    >>> import ibis
    >>> ibis.options.interactive = True
    >>> t = ibis.memtable({"arr": [[1], [], [42, 42], None]})
    >>> t
    ┏━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ arr                  ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━┩
    │ array<int8>          │
    ├──────────────────────┤
    │ [1]                  │
    │ []                   │
    │ [42, 42]             │
    │ NULL                 │
    └──────────────────────┘
    >>> t.arr.index(42)
    ┏━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ ArrayPosition(arr, 42) ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━━━┩
    │ int64                  │
    ├────────────────────────┤
    │                     -1 │
    │                     -1 │
    │                      0 │
    │                   NULL │
    └────────────────────────┘
    >>> t.arr.index(800)
    ┏━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ ArrayPosition(arr, 800) ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │ int64                   │
    ├─────────────────────────┤
    │                      -1 │
    │                      -1 │
    │                      -1 │
    │                    NULL │
    └─────────────────────────┘
    >>> t.arr.index(None)
    ┏━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ ArrayPosition(arr, None) ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │ int64                    │
    ├──────────────────────────┤
    │                     NULL │
    │                     NULL │
    │                     NULL │
    │                     NULL │
    └──────────────────────────┘
    """
    node = ops.ArrayPosition(self, other)
    return node.to_expr()

def remove(self, other: ir.Value) -> ir.ArrayValue:
    """Remove `other` from `self`.

    Parameters
    ----------
    other
        Element to remove from `self`.

    Returns
    -------
    ArrayValue
        `self` with `other` removed

    Examples
    --------
    >>> import ibis
    >>> ibis.options.interactive = True
    >>> t = ibis.memtable({"arr": [[3, 2], [], [42, 2], [2, 2], None]})
    >>> t
    ┏━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ arr                  ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━┩
    │ array<int8>          │
    ├──────────────────────┤
    │ [3, 2]               │
    │ []                   │
    │ [42, 2]              │
    │ [2, 2]               │
    │ NULL                 │
    └──────────────────────┘
    >>> t.arr.remove(2)
    ┏━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ ArrayRemove(arr, 2)  ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━┩
    │ array<int8>          │
    ├──────────────────────┤
    │ [3]                  │
    │ []                   │
    │ [42]                 │
    │ []                   │
    │ NULL                 │
    └──────────────────────┘
    """
    node = ops.ArrayRemove(self, other)
    return node.to_expr()

def unique(self) -> ir.ArrayValue:
    """Return the unique values in an array.

    !!! note "Element ordering in array may not be retained."

    Returns
    -------
    ArrayValue
        Unique values in an array

    Examples
    --------
    >>> import ibis
    >>> ibis.options.interactive = True
    >>> t = ibis.memtable({"arr": [[1, 3, 3], [], [42, 42], None]})
    >>> t
    ┏━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ arr                  ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━┩
    │ array<int8>          │
    ├──────────────────────┤
    │ [1, 3, ... +1]       │
    │ []                   │
    │ [42, 42]             │
    │ NULL                 │
    └──────────────────────┘
    >>> t.arr.unique()
    ┏━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ ArrayDistinct(arr)   ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━┩
    │ array<int8>          │
    ├──────────────────────┤
    │ [3, 1]               │
    │ []                   │
    │ [42]                 │
    │ NULL                 │
    └──────────────────────┘
    """
    node = ops.ArrayDistinct(self)
    return node.to_expr()

def sort(self) -> ir.ArrayValue:
    """Sort the elements in an array.

    Returns
    -------
    ArrayValue
        Sorted values in an array

    Examples
    --------
    >>> import ibis
    >>> ibis.options.interactive = True
    >>> t = ibis.memtable({"arr": [[3, 2], [], [42, 42], None]})
    >>> t
    ┏━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ arr                  ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━┩
    │ array<int8>          │
    ├──────────────────────┤
    │ [3, 2]               │
    │ []                   │
    │ [42, 42]             │
    │ NULL                 │
    └──────────────────────┘
    >>> t.arr.sort()
    ┏━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ ArraySort(arr)       ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━┩
    │ array<int8>          │
    ├──────────────────────┤
    │ [2, 3]               │
    │ []                   │
    │ [42, 42]             │
    │ NULL                 │
    └──────────────────────┘
    """
    node = ops.ArraySort(self)
    return node.to_expr()

def union(self, other: ir.ArrayValue) -> ir.ArrayValue:
    """Union two arrays.

    Parameters
    ----------
    other
        Another array to union with `self`

    Returns
    -------
    ArrayValue
        Unioned arrays

    Examples
    --------
    >>> import ibis
    >>> ibis.options.interactive = True
    >>> t = ibis.memtable({"arr1": [[3, 2], [], None], "arr2": [[1, 3], [None], [5]]})
    >>> t
    ┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ arr1                 ┃ arr2                 ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
    │ array<int8>          │ array<int8>          │
    ├──────────────────────┼──────────────────────┤
    │ [3, 2]               │ [1, 3]               │
    │ []                   │ [None]               │
    │ NULL                 │ [5]                  │
    └──────────────────────┴──────────────────────┘
    >>> t.arr1.union(t.arr2)
    ┏━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ ArrayUnion(arr1, arr2) ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━━━┩
    │ array<int8>            │
    ├────────────────────────┤
    │ [1, 2, ... +1]         │
    │ []                     │
    │ [5]                    │
    └────────────────────────┘
    >>> t.arr1.union(t.arr2).contains(3)
    ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ ArrayContains(ArrayUnion(arr1, arr2), 3) ┃
    ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
    │ boolean                                  │
    ├──────────────────────────────────────────┤
    │ True                                     │
    │ False                                    │
    │ False                                    │
    └──────────────────────────────────────────┘
    """
    node = ops.ArrayUnion(self, other)
    return node.to_expr()


@public
class ArrayScalar(Scalar, ArrayValue):
Expand Down