diff --git a/ibis/backends/snowflake/registry.py b/ibis/backends/snowflake/registry.py index f47969859809..47dec57e6014 100644 --- a/ibis/backends/snowflake/registry.py +++ b/ibis/backends/snowflake/registry.py @@ -4,8 +4,13 @@ import numpy as np import sqlalchemy as sa +from snowflake.sqlalchemy import ARRAY +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql import sqltypes +from sqlalchemy.sql.functions import GenericFunction import ibis.expr.operations as ops +from ibis import util from ibis.backends.base.sql.alchemy.registry import ( fixed_arity, geospatial_functions, @@ -251,7 +256,11 @@ def _unnest(t, op): *map(t.translate, op.cols) ), ops.ArraySlice: _array_slice, - ops.ArrayCollect: reduction(sa.func.array_agg), + ops.ArrayCollect: reduction( + lambda arg: sa.func.array_agg( + sa.func.ifnull(arg, sa.func.parse_json("null")), type_=ARRAY + ) + ), ops.StringSplit: fixed_arity(sa.func.split, 2), ops.TypeOf: unary(lambda arg: sa.func.typeof(sa.func.to_variant(arg))), ops.All: reduction(sa.func.booland_agg), @@ -284,13 +293,7 @@ def _unnest(t, op): ops.StructColumn: lambda t, op: sa.func.object_construct_keep_null( *itertools.chain.from_iterable(zip(op.names, map(t.translate, op.values))) ), - ops.Unnest: unary( - lambda arg: ( - sa.func.table(sa.func.flatten(arg)) - .table_valued("value") - .columns["value"] - ) - ), + ops.Unnest: _unnest, } )