From e929f85b4ffadad432e6d8ee7267c58aea062454 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Sun, 10 Apr 2022 08:52:47 +0200 Subject: [PATCH] refactor(ir): simplify expressions by not storing dtype and name Expression classes should only provide user facing API and the underlying operations should hold all data. This will enable a simpler operation hierarchy (without intermediate expressions) wrapped with a single user facing expression. BREAKING CHANGE: The following are breaking changes due to simplifying expression internals - `ibis.expr.datatypes.DataType.scalar_type` and `DataType.column_type` factory methods have been removed, `DataType.scalar` and `DataType.column` class fields can be used to directly construct a corresponding expression instance (though prefer to use `operation.to_expr()`) - `ibis.expr.types.ValueExpr._name` and `ValueExpr._dtype` fields are not accessible anymore. While these were not supposed to be used directly, now `ValueExpr.has_name()`, `ValueExpr.get_name()` and `ValueExpr.type()` methods are the only way to retrieve the expression's name and datatype. 
- `ibis.expr.operations.Node.output_type` is a property now not a method, decorate those methods with `@property` - `ibis.expr.operations.ValueOp` subclasses must define `output_shape` and `output_dtype` properties from now on (note the datatype abbreviation `dtype` in the property name) - `ibis.expr.rules.cast()`, `scalar_like()` and `array_like()` rules have been removed --- docs/user_guide/design.md | 11 +- ibis/backends/base/sql/alchemy/registry.py | 12 +- ibis/backends/base/sql/alchemy/translator.py | 4 +- .../base/sql/compiler/query_builder.py | 11 +- .../base/sql/compiler/select_builder.py | 22 ++- ibis/backends/base/sql/compiler/translator.py | 54 +++--- ibis/backends/base/sql/ddl.py | 3 +- ibis/backends/base/sql/registry/main.py | 8 + ibis/backends/base/sql/registry/window.py | 4 +- ibis/backends/clickhouse/registry.py | 10 + ibis/backends/dask/execution/generic.py | 7 + ibis/backends/dask/execution/selection.py | 6 +- ibis/backends/dask/execution/util.py | 9 +- ibis/backends/datafusion/compiler.py | 12 +- ibis/backends/impala/tests/test_exprs.py | 2 +- ibis/backends/impala/tests/test_udf.py | 18 +- ibis/backends/impala/udf.py | 6 +- ibis/backends/pandas/execution/generic.py | 7 + ibis/backends/pandas/execution/join.py | 4 +- ibis/backends/pandas/execution/selection.py | 9 +- ibis/backends/pandas/execution/util.py | 4 +- .../pandas/tests/execution/test_window.py | 29 +-- ibis/backends/postgres/compiler.py | 3 +- .../backends/postgres/tests/test_functions.py | 4 +- ibis/backends/postgres/udf.py | 18 +- ibis/backends/pyspark/compiler.py | 7 + .../pyspark/tests/test_timecontext.py | 4 +- ibis/expr/analysis.py | 43 ++--- ibis/expr/datatypes/core.py | 34 +--- ibis/expr/format.py | 12 ++ ibis/expr/lineage.py | 13 +- ibis/expr/operations/analytic.py | 109 ++++++----- ibis/expr/operations/arrays.py | 32 +++- ibis/expr/operations/core.py | 73 +++++++- ibis/expr/operations/generic.py | 176 ++++++++++-------- ibis/expr/operations/geospatial.py | 115 ++++++------ 
ibis/expr/operations/histograms.py | 17 +- ibis/expr/operations/logical.py | 63 +++---- ibis/expr/operations/maps.py | 53 +++--- ibis/expr/operations/numeric.py | 115 ++++++------ ibis/expr/operations/reductions.py | 84 ++++----- ibis/expr/operations/relations.py | 24 ++- ibis/expr/operations/sortkeys.py | 3 +- ibis/expr/operations/strings.py | 77 +++++--- ibis/expr/operations/temporal.py | 126 ++++++++----- ibis/expr/operations/vectorized.py | 16 +- ibis/expr/rules.py | 136 +++++++------- ibis/expr/timecontext.py | 8 + ibis/expr/types/analytic.py | 20 +- ibis/expr/types/core.py | 10 +- ibis/expr/types/generic.py | 89 ++++----- ibis/expr/types/relations.py | 30 +-- ibis/expr/types/sortkeys.py | 3 - ibis/expr/types/structs.py | 10 +- ibis/expr/types/temporal.py | 10 +- ibis/expr/visualize.py | 2 +- ibis/tests/expr/conftest.py | 7 - ibis/tests/expr/test_format.py | 3 - ibis/tests/expr/test_lineage.py | 11 +- ibis/tests/expr/test_operations.py | 5 +- ibis/tests/expr/test_rules.py | 6 - ibis/tests/expr/test_table.py | 37 ++-- ibis/tests/expr/test_temporal.py | 2 +- ibis/tests/expr/test_timestamp.py | 15 +- ibis/tests/expr/test_udf.py | 5 +- ibis/tests/expr/test_value_exprs.py | 11 +- ibis/tests/expr/test_visualize.py | 5 +- ibis/tests/expr/test_window_functions.py | 20 +- 68 files changed, 1081 insertions(+), 837 deletions(-) diff --git a/docs/user_guide/design.md b/docs/user_guide/design.md index e444acc6445b..b22919861127 100644 --- a/docs/user_guide/design.md +++ b/docs/user_guide/design.md @@ -77,9 +77,14 @@ import ibis.expr.rules as rlz from ibis.expr.operations import ValueOp class Log(ValueOp): - arg = rlz.double # A double scalar or column - base = rlz.optional(rlz.double) # Optional argument, defaults to None - output_type = rlz.typeof('arg') + # A double scalar or column + arg = rlz.double + # Optional argument, defaults to None + base = rlz.optional(rlz.double) + # Output expression's datatype will correspond to arg's datatype + output_dtype = 
rlz.dtype_like('arg') + # Output expression will be scalar if arg is scalar, column otherwise + output_shape = rlz.shape_like('arg') ``` This class describes an operation called `Log` that takes one required diff --git a/ibis/backends/base/sql/alchemy/registry.py b/ibis/backends/base/sql/alchemy/registry.py index 57a79ee1a348..2d29b202f608 100644 --- a/ibis/backends/base/sql/alchemy/registry.py +++ b/ibis/backends/base/sql/alchemy/registry.py @@ -207,6 +207,13 @@ def _group_concat(t, expr): return sa.func.group_concat(arg, sep) +def _alias(t, expr): + # just compile the underlying argument because the naming is handled + # by the translator for the top level expression + op = expr.op() + return t.translate(op.arg) + + def _literal(t, expr): dtype = expr.type() value = expr.op().value @@ -327,7 +334,9 @@ def _cumulative_to_window(translator, expr, window): klass = _cumulative_to_reduction[type(op)] new_op = klass(*op.args) - new_expr = expr._factory(new_op, name=expr._name) + new_expr = new_op.to_expr() + if expr.has_name(): + new_expr = new_expr.name(expr.get_name()) if type(new_op) in translator._rewrites: new_expr = translator._rewrites[type(new_op)](new_expr) @@ -442,6 +451,7 @@ def _string_join(t, expr): sqlalchemy_operation_registry: Dict[Any, Any] = { + ops.Alias: _alias, ops.And: fixed_arity(sql.and_, 2), ops.Or: fixed_arity(sql.or_, 2), ops.Not: unary(sa.not_), diff --git a/ibis/backends/base/sql/alchemy/translator.py b/ibis/backends/base/sql/alchemy/translator.py index 8b54864667f2..6618568aa2aa 100644 --- a/ibis/backends/base/sql/alchemy/translator.py +++ b/ibis/backends/base/sql/alchemy/translator.py @@ -36,9 +36,7 @@ class AlchemyExprTranslator(ExprTranslator): context_class = AlchemyContext def name(self, translated, name, force=True): - if hasattr(translated, 'label'): - return translated.label(name) - return translated + return translated.label(name) def get_sqla_type(self, data_type): return to_sqla_type(data_type, type_map=self._type_map) diff 
--git a/ibis/backends/base/sql/compiler/query_builder.py b/ibis/backends/base/sql/compiler/query_builder.py index fca88ea965b1..500a60057e85 100644 --- a/ibis/backends/base/sql/compiler/query_builder.py +++ b/ibis/backends/base/sql/compiler/query_builder.py @@ -525,8 +525,15 @@ class Compiler: @classmethod def make_context(cls, params=None): params = params or {} - params = {expr.op(): value for expr, value in params.items()} - return cls.context_class(compiler=cls, params=params) + + unaliased_params = {} + for expr, value in params.items(): + op = expr.op() + if isinstance(op, ops.Alias): + op = op.arg.op() + unaliased_params[op] = value + + return cls.context_class(compiler=cls, params=unaliased_params) @classmethod def to_ast(cls, expr, context=None): diff --git a/ibis/backends/base/sql/compiler/select_builder.py b/ibis/backends/base/sql/compiler/select_builder.py index 7df2f0135ed8..7f4cd040b314 100644 --- a/ibis/backends/base/sql/compiler/select_builder.py +++ b/ibis/backends/base/sql/compiler/select_builder.py @@ -34,7 +34,7 @@ def get_result(self): else: op = ops.NotExistsSubquery(self.foreign_table, self.predicates) - expr_type = dt.boolean.column_type() + expr_type = dt.boolean.column return expr_type(op) def _visit(self, expr): @@ -467,7 +467,8 @@ def _visit_select_expr(self, expr): new_args.append(arg) if not unchanged: - return expr._factory(type(op)(*new_args)) + new_op = type(op)(*new_args) + return new_op.to_expr() else: return expr else: @@ -500,8 +501,11 @@ def _visit_select_Histogram(self, expr): binwidth = op.binwidth base = op.base - bucket = (op.arg - base) / binwidth - return bucket.floor().name(expr._name) + bucket = ((op.arg - base) / binwidth).floor() + if expr.has_name(): + bucket = bucket.name(expr.get_name()) + + return bucket def _analyze_filter_exprs(self): # What's semantically contained in the filter predicates may need to be @@ -543,10 +547,11 @@ def _visit_filter(self, expr): left = self._visit_filter(op.left) right = 
self._visit_filter(op.right) unchanged = left is op.left and right is op.right - if not unchanged: - return expr._factory(type(op)(left, right)) - else: + if unchanged: return expr + else: + new_op = type(op)(left, right) + return new_op.to_expr() elif isinstance(op, (ops.Any, ops.TableColumn, ops.Literal)): return expr elif isinstance(op, ops.ValueOp): @@ -559,7 +564,8 @@ def _visit_filter(self, expr): if new is not old: unchanged = False if not unchanged: - return expr._factory(type(op)(*visited)) + new_op = type(op)(*visited) + return new_op.to_expr() else: return expr else: diff --git a/ibis/backends/base/sql/compiler/translator.py b/ibis/backends/base/sql/compiler/translator.py index 905cfcd49520..d950dd74a5f8 100644 --- a/ibis/backends/base/sql/compiler/translator.py +++ b/ibis/backends/base/sql/compiler/translator.py @@ -13,6 +13,7 @@ operation_registry, quote_identifier, ) +from ibis.expr.types.core import unnamed class QueryContext: @@ -200,6 +201,22 @@ def __init__(self, expr, context, named=False, permit_subquery=False): # For now, governing whether the result will have a name self.named = named + def _needs_name(self, expr): + if not self.named: + return False + + op = expr.op() + if isinstance(op, ops.TableColumn): + # This column has been given an explicitly different name + return expr.get_name() != op.name + + return expr.get_name() is not unnamed + + def name(self, translated, name, force=True): + return '{} AS {}'.format( + translated, quote_identifier(name, force=force) + ) + def get_result(self): """Compile SQL expression into a string.""" translated = self.translate(self.expr) @@ -221,27 +238,6 @@ def add_operation(cls, operation, translate_function): """ cls._registry[operation] = translate_function - def _needs_name(self, expr): - if not self.named: - return False - - op = expr.op() - if isinstance(op, ops.TableColumn): - # This column has been given an explicitly different name - if expr.get_name() != op.name: - return True - return False - 
- if expr.get_name() is ir.core.unnamed: - return False - - return True - - def name(self, translated: str, name: str, force: bool = True) -> str: - return '{} AS {}'.format( - translated, quote_identifier(name, force=force) - ) - def translate(self, expr): # The operation node type the typed expression wraps op = expr.op() @@ -331,7 +327,11 @@ def _bucket(expr): stmt = stmt.when(cmp(op.buckets[-1], op.arg), bucket_id) bucket_id += 1 - return stmt.end().name(expr._name) + result = stmt.end() + if expr.has_name(): + result = result.name(expr.get_name()) + + return result @rewrites(ops.CategoryLabel) @@ -345,7 +345,11 @@ def _category_label(expr): if op.nulls is not None: stmt = stmt.else_(op.nulls) - return stmt.end().name(expr._name) + result = stmt.end() + if expr.has_name(): + result = result.name(expr.get_name()) + + return result @rewrites(ops.Any) @@ -357,7 +361,7 @@ def _any_expand(expr): @rewrites(ops.NotAny) def _notany_expand(expr): arg = expr.op().args[0] - return arg.max() == 0 + return arg.max() == ibis.literal(0, type=arg.type()) @rewrites(ops.All) @@ -369,7 +373,7 @@ def _all_expand(expr): @rewrites(ops.NotAll) def _notall_expand(expr): arg = expr.op().args[0] - return arg.min() == 0 + return arg.min() == ibis.literal(0, type=arg.type()) @rewrites(ops.Cast) diff --git a/ibis/backends/base/sql/ddl.py b/ibis/backends/base/sql/ddl.py index aec57d527c49..e895b5d10670 100644 --- a/ibis/backends/base/sql/ddl.py +++ b/ibis/backends/base/sql/ddl.py @@ -220,7 +220,8 @@ def _partitioned_by(self): if self.partition is not None: return 'PARTITIONED BY ({})'.format( ', '.join( - quote_identifier(expr._name) for expr in self.partition + quote_identifier(expr.get_name()) + for expr in self.partition ) ) return None diff --git a/ibis/backends/base/sql/registry/main.py b/ibis/backends/base/sql/registry/main.py index 295ef05ffa3a..fba25caca758 100644 --- a/ibis/backends/base/sql/registry/main.py +++ b/ibis/backends/base/sql/registry/main.py @@ -8,6 +8,13 @@ from 
.literal import literal, null_literal +def alias(translator, expr): + # just compile the underlying argument because the naming is handled + # by the translator for the top level expression + op = expr.op() + return translator.translate(op.arg) + + def fixed_arity(func_name, arity): def formatter(translator, expr): op = expr.op() @@ -237,6 +244,7 @@ def hash(translator, expr): operation_registry = { + ops.Alias: alias, # Unary operations ops.NotNull: not_null, ops.IsNull: is_null, diff --git a/ibis/backends/base/sql/registry/window.py b/ibis/backends/base/sql/registry/window.py index 1826d0604c50..520623ab3e26 100644 --- a/ibis/backends/base/sql/registry/window.py +++ b/ibis/backends/base/sql/registry/window.py @@ -95,7 +95,9 @@ def cumulative_to_window(translator, expr, window): klass = _cumulative_to_reduction[type(op)] new_op = klass(*op.args) - new_expr = expr._factory(new_op, name=expr._name) + new_expr = new_op.to_expr() + if expr.has_name(): + new_expr = new_expr.name(expr.get_name()) if type(new_op) in translator._rewrites: new_expr = translator._rewrites[type(new_op)](new_expr) diff --git a/ibis/backends/clickhouse/registry.py b/ibis/backends/clickhouse/registry.py index 1bc0b6556499..c0f744869038 100644 --- a/ibis/backends/clickhouse/registry.py +++ b/ibis/backends/clickhouse/registry.py @@ -9,6 +9,15 @@ from .identifiers import quote_identifier +# TODO(kszucs): should inherit operation registry from the base compiler + + +def _alias(translator, expr): + # just compile the underlying argument because the naming is handled + # by the translator for the top level expression + op = expr.op() + return translator.translate(op.arg) + def _cast(translator, expr): from .client import ClickhouseDataType @@ -624,6 +633,7 @@ def _string_right(translator, expr): operation_registry = { + ops.Alias: _alias, # Unary operations ops.TypeOf: _unary('toTypeName'), ops.IsNan: _unary('isNaN'), diff --git a/ibis/backends/dask/execution/generic.py 
b/ibis/backends/dask/execution/generic.py index 376dc13cd14e..c829694c56d9 100644 --- a/ibis/backends/dask/execution/generic.py +++ b/ibis/backends/dask/execution/generic.py @@ -157,6 +157,13 @@ def execute_node_value_list(op, _, **kwargs): return [execute(arg, **kwargs) for arg in op.values] +@execute_node.register(ops.Alias, object) +def execute_alias_series(op, _, **kwargs): + # just compile the underlying argument because the naming is handled + # by the translator for the top level expression + return execute(op.arg, **kwargs) + + @execute_node.register(ops.Arbitrary, dd.Series, (dd.Series, type(None))) def execute_arbitrary_series_mask(op, data, mask, aggcontext=None, **kwargs): """ diff --git a/ibis/backends/dask/execution/selection.py b/ibis/backends/dask/execution/selection.py index e87fa7bc5dd1..e60a7af9aaed 100644 --- a/ibis/backends/dask/execution/selection.py +++ b/ibis/backends/dask/execution/selection.py @@ -42,7 +42,7 @@ def compute_projection_scalar_expr( timecontext: Optional[TimeContext] = None, **kwargs, ): - name = expr._name + name = expr.get_name() assert name is not None, 'Scalar selection name is None' op = expr.op() @@ -76,7 +76,7 @@ def compute_projection_column_expr( timecontext: Optional[TimeContext], **kwargs, ): - result_name = getattr(expr, '_name', None) + result_name = expr._safe_name op = expr.op() parent_table_op = parent.table.op() @@ -111,8 +111,6 @@ def compute_projection_column_expr( result = execute(expr, scope=scope, timecontext=timecontext, **kwargs) result = coerce_to_output(result, expr, data.index) - assert result_name is not None, 'Column selection name is None' - return result diff --git a/ibis/backends/dask/execution/util.py b/ibis/backends/dask/execution/util.py index 71368d9601d7..924051b43a37 100644 --- a/ibis/backends/dask/execution/util.py +++ b/ibis/backends/dask/execution/util.py @@ -92,9 +92,8 @@ def coerce_to_output( Examples -------- - For dataframe outputs, see ``_coerce_to_dataframe``. 
Examples below use - pandas objects for legibility, but functionality is the same on dask - objects. + Examples below use pandas objects for legibility, but functionality is the + same on dask objects. >>> coerce_to_output(pd.Series(1), expr) 0 1 @@ -111,7 +110,7 @@ def coerce_to_output( 0 [1, 2, 3] Name: result, dtype: object """ - result_name = expr.get_name() + result_name = expr._safe_name dataframe_exprs = ( ir.DestructColumn, ir.StructColumn, @@ -125,7 +124,7 @@ def coerce_to_output( elif isinstance(result, (pd.Series, dd.Series)): # Series from https://github.com/ibis-project/ibis/issues/2711 return result.rename(result_name) - elif isinstance(expr.op(), ops.Reduction): + elif isinstance(expr, ir.ScalarExpr): if isinstance(result, dd.core.Scalar): # wrap the scalar in a series out_dtype = _pandas_dtype_from_dd_scalar(result) diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index 02b69559c3c3..5638034a117c 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -33,6 +33,12 @@ def table(op, expr): return client._context.table(name) +@translate.register(ops.Alias) +def alias(op, expr): + arg = translate(op.arg) + return arg.alias(op.name) + + @translate.register(ops.Literal) def literal(op, expr): if isinstance(op.value, (set, frozenset)): @@ -84,13 +90,9 @@ def selection(op, expr): for name in expr.columns: column = expr.get_column(name) field = translate(column) - if column.has_name(): - field = field.alias(column.get_name()) selections.append(field) elif isinstance(expr, ir.ValueExpr): field = translate(expr) - if expr.has_name(): - field = field.alias(expr.get_name()) selections.append(field) else: raise com.TranslationError( @@ -120,8 +122,6 @@ def aggregation(op, expr): metrics = [] for expr in op.metrics: agg = translate(expr) - if expr.has_name(): - agg = agg.alias(expr.get_name()) metrics.append(agg) return table.aggregate(group_by, metrics) diff --git 
a/ibis/backends/impala/tests/test_exprs.py b/ibis/backends/impala/tests/test_exprs.py index 1f7075673e1e..6415ced48103 100644 --- a/ibis/backends/impala/tests/test_exprs.py +++ b/ibis/backends/impala/tests/test_exprs.py @@ -53,7 +53,7 @@ def test_decimal_metadata(con): # TODO: what if user impyla version does not have decimal Metadata? -def test_builtins_1(con, alltypes): +def test_builtins(con, alltypes): table = alltypes i1 = table.tinyint_col diff --git a/ibis/backends/impala/tests/test_udf.py b/ibis/backends/impala/tests/test_udf.py index 0c53f0a07a3d..9245f1ffb6a5 100644 --- a/ibis/backends/impala/tests/test_udf.py +++ b/ibis/backends/impala/tests/test_udf.py @@ -143,13 +143,10 @@ def test_udf_primitive_output_types(ty, value, column, table): ibis_type = dt.validate_type(ty) expr = func(value) - assert type(expr) == type( # noqa: E501, E721 - ibis_type.scalar_type()(expr.op()) - ) + assert type(expr) == ibis_type.scalar + expr = func(table[column]) - assert type(expr) == type( # noqa: E501, E721 - ibis_type.column_type()(expr.op()) - ) + assert type(expr) == ibis_type.column @pytest.mark.parametrize( @@ -170,17 +167,16 @@ def test_udf_primitive_output_types(ty, value, column, table): ), ], ) -def test_uda_primitive_output_types(ty, value, table): +def test_uda_primitive_output_types(ty, value): func = _register_uda([ty], ty, 'test') ibis_type = dt.validate_type(ty) expr1 = func(value) + assert isinstance(expr1, ibis_type.scalar) + expr2 = func(value) - expected_type1 = type(ibis_type.scalar_type()(expr1.op())) - expected_type2 = type(ibis_type.scalar_type()(expr2.op())) - assert isinstance(expr1, expected_type1) - assert isinstance(expr2, expected_type2) + assert isinstance(expr2, ibis_type.scalar) def test_decimal(dec): diff --git a/ibis/backends/impala/udf.py b/ibis/backends/impala/udf.py index a876758f7cd5..904ec899b08c 100644 --- a/ibis/backends/impala/udf.py +++ b/ibis/backends/impala/udf.py @@ -71,7 +71,8 @@ def _create_operation_class(self): fields = { 
f'_{i}': rlz.value(dtype) for i, dtype in enumerate(self.inputs) } - fields['output_type'] = rlz.shape_like('args', self.output) + fields['output_dtype'] = self.output + fields['output_shape'] = rlz.shape_like('args') return type(f"UDF_{self.name}", (ops.ValueOp,), fields) @@ -80,7 +81,8 @@ def _create_operation_class(self): fields = { f'_{i}': rlz.value(dtype) for i, dtype in enumerate(self.inputs) } - fields['output_type'] = lambda op: self.output.scalar_type() + fields['output_dtype'] = self.output + fields['output_shape'] = rlz.Shape.SCALAR fields['_reduction'] = True return type(f"UDA_{self.name}", (ops.ValueOp,), fields) diff --git a/ibis/backends/pandas/execution/generic.py b/ibis/backends/pandas/execution/generic.py index d050a9e5c36b..5d0e905e04b5 100644 --- a/ibis/backends/pandas/execution/generic.py +++ b/ibis/backends/pandas/execution/generic.py @@ -848,6 +848,13 @@ def execute_node_self_reference_dataframe(op, data, **kwargs): return data +@execute_node.register(ops.Alias, object) +def execute_alias(op, _, **kwargs): + # just compile the underlying argument because the naming is handled + # by the translator for the top level expression + return execute(op.arg, **kwargs) + + @execute_node.register(ops.ValueList, collections.abc.Sequence) def execute_node_value_list(op, _, **kwargs): return [execute(arg, **kwargs) for arg in op.values] diff --git a/ibis/backends/pandas/execution/join.py b/ibis/backends/pandas/execution/join.py index 48edb9995363..f967436c4334 100644 --- a/ibis/backends/pandas/execution/join.py +++ b/ibis/backends/pandas/execution/join.py @@ -168,8 +168,8 @@ def _extract_predicate_names(predicates): raise TypeError( 'Only equality join predicates supported with pandas' ) - left_name = predicate.left._name - right_name = predicate.right._name + left_name = predicate.left.get_name() + right_name = predicate.right.get_name() lefts.append(left_name) rights.append(right_name) return lefts, rights diff --git 
a/ibis/backends/pandas/execution/selection.py b/ibis/backends/pandas/execution/selection.py index 9ae6b0b6964e..2b396c92a6c7 100644 --- a/ibis/backends/pandas/execution/selection.py +++ b/ibis/backends/pandas/execution/selection.py @@ -57,7 +57,7 @@ def compute_projection_scalar_expr( timecontext: Optional[TimeContext] = None, **kwargs, ): - name = expr._name + name = expr.get_name() assert name is not None, 'Scalar selection name is None' op = expr.op() @@ -97,7 +97,7 @@ def compute_projection_column_expr( timecontext: Optional[TimeContext], **kwargs, ): - result_name = getattr(expr, '_name', None) + result_name = expr._safe_name op = expr.op() parent_table_op = parent.table.op() @@ -137,7 +137,6 @@ def compute_projection_column_expr( expr, data.index, ) - assert result_name is not None, 'Column selection name is None' return result @@ -323,7 +322,9 @@ def build_df_from_selection( if selection + join_suffix not in data: raise KeyError(selection) selection += join_suffix - cols[selection].append(getattr(expr, "_name", selection)) + cols[selection].append( + expr.get_name() if expr.has_name() else selection + ) result = data[list(cols.keys())] diff --git a/ibis/backends/pandas/execution/util.py b/ibis/backends/pandas/execution/util.py index e257e6bd3769..8615b70c388c 100644 --- a/ibis/backends/pandas/execution/util.py +++ b/ibis/backends/pandas/execution/util.py @@ -115,7 +115,7 @@ def coerce_to_output( 0 [1, 2, 3] Name: result, dtype: object """ - result_name = getattr(expr, '_name', None) + result_name = expr._safe_name if isinstance(expr, (ir.DestructColumn, ir.StructColumn)): return _coerce_to_dataframe(result, expr.type()) @@ -126,7 +126,7 @@ def coerce_to_output( return _coerce_to_dataframe(result, expr.type()) elif isinstance(result, pd.Series): return result.rename(result_name) - elif isinstance(expr.op(), ops.Reduction): + elif isinstance(expr, ir.ScalarExpr): if index is None: # Wrap `result` into a single-element Series. 
return pd.Series([result], name=result_name) diff --git a/ibis/backends/pandas/tests/execution/test_window.py b/ibis/backends/pandas/tests/execution/test_window.py index 3cd861bd6a41..e9381b4a16cd 100644 --- a/ibis/backends/pandas/tests/execution/test_window.py +++ b/ibis/backends/pandas/tests/execution/test_window.py @@ -500,7 +500,7 @@ def test_project_scalar_after_join(): 'value': [0, 1, np.nan, 3, 4, np.nan, 6, 7, 8], } ) - con = Backend().connect({'left': left_df, 'right': right_df}) + con = ibis.pandas.connect({'left': left_df, 'right': right_df}) left, right = map(con.table, ('left', 'right')) joined = left.outer_join(right, left.ints == right.group) @@ -513,12 +513,13 @@ def test_project_scalar_after_join(): def test_project_list_scalar(): df = pd.DataFrame({'ints': range(3)}) - con = Backend().connect({'df': df}) - expr = con.table('df') - result = expr.mutate(res=expr.ints.quantile([0.5, 0.95])).execute() - tm.assert_series_equal( - result.res, pd.Series([[1.0, 1.9] for _ in range(0, 3)], name='res') - ) + con = ibis.pandas.connect({'df': df}) + table = con.table('df') + expr = table.mutate(res=table.ints.quantile([0.5, 0.95])) + result = expr.execute() + + expected = pd.Series([[1.0, 1.9] for _ in range(0, 3)], name='res') + tm.assert_series_equal(result.res, expected) @pytest.mark.parametrize( @@ -533,7 +534,7 @@ def test_window_with_preceding_expr(index): start = 2 data = np.arange(start, start + len(time)) df = pd.DataFrame({'value': data, 'time': time}, index=index(time)) - client = Backend().connect({'df': df}) + client = ibis.pandas.connect({'df': df}) t = client.table('df') expected = ( df.set_index('time') @@ -562,7 +563,7 @@ def test_window_with_mlb(): .rename_axis('time') .reset_index(drop=False) ) - client = Backend().connect({'df': df}) + client = ibis.pandas.connect({'df': df}) t = client.table('df') rows_with_mlb = rows_with_max_lookback(5, ibis.interval(days=10)) expr = t.mutate( @@ -592,17 +593,16 @@ def test_window_with_mlb(): def 
test_window_has_pre_execute_scope(): - signature = ops.Lag, Backend called = [0] - @pre_execute.register(*signature) + @pre_execute.register(ops.Lag, Backend) def test_pre_execute(op, client, **kwargs): called[0] += 1 return Scope() data = {'key': list('abc'), 'value': [1, 2, 3], 'dup': list('ggh')} df = pd.DataFrame(data, columns=['key', 'value', 'dup']) - client = Backend().connect({'df': df}) + client = ibis.pandas.connect({'df': df}) t = client.table('df') window = ibis.window(order_by='value') expr = t.key.lag(1).over(window).name('foo') @@ -612,7 +612,10 @@ def test_pre_execute(op, client, **kwargs): # once in window op at the top to pickup any scope changes before computing # twice in window op when calling execute on the ops.Lag node at the # beginning of execute and once before the actual computation - assert called[0] == 3 + # + # this process happens twice because of the pre_execute call on the Alias + # operation + assert called[0] == 3 + 3 def test_window_grouping_key_has_scope(t, df): diff --git a/ibis/backends/postgres/compiler.py b/ibis/backends/postgres/compiler.py index eb9bef1061fa..35e5885e8f9d 100644 --- a/ibis/backends/postgres/compiler.py +++ b/ibis/backends/postgres/compiler.py @@ -3,6 +3,7 @@ import ibis.expr.datatypes as dt import ibis.expr.operations as ops +import ibis.expr.rules as rlz from ibis.backends.base.sql.alchemy import ( AlchemyCompiler, AlchemyExprTranslator, @@ -13,7 +14,7 @@ class PostgresUDFNode(ops.ValueOp): - pass + output_shape = rlz.shape_like("args") class PostgreSQLExprTranslator(AlchemyExprTranslator): diff --git a/ibis/backends/postgres/tests/test_functions.py b/ibis/backends/postgres/tests/test_functions.py index f0c7a30cf311..e911cc67e98e 100644 --- a/ibis/backends/postgres/tests/test_functions.py +++ b/ibis/backends/postgres/tests/test_functions.py @@ -1241,11 +1241,11 @@ def test_anti_join(t, s): def test_create_table_from_expr(con, trunc, guid2): con.create_table(guid2, expr=trunc) t = con.table(guid2) - assert 
list(t.name.execute()) == list('abc') + assert list(t['name'].execute()) == list('abc') def test_truncate_table(con, trunc): - assert list(trunc.name.execute()) == list('abc') + assert list(trunc['name'].execute()) == list('abc') con.truncate_table(trunc.op().name) assert not len(trunc.execute()) diff --git a/ibis/backends/postgres/udf.py b/ibis/backends/postgres/udf.py index 7b5a5e097d1b..5a9f0e83fd3d 100644 --- a/ibis/backends/postgres/udf.py +++ b/ibis/backends/postgres/udf.py @@ -100,20 +100,10 @@ def existing_udf(name, input_types, output_type, schema=None, parameters=None): v.validate_output_type(output_type) - udf_node_fields = collections.OrderedDict( - [ - (name, rlz.value(type_)) - for name, type_ in zip(parameters, input_types) - ] - + [ - ( - 'output_type', - lambda self, output_type=output_type: rlz.shape_like( - self.args, dtype=output_type - ), - ) - ] - ) + udf_node_fields = { + name: rlz.value(type_) for name, type_ in zip(parameters, input_types) + } + udf_node_fields['output_dtype'] = output_type udf_node_fields['resolve_name'] = lambda self: name udf_node = _create_udf_node(name, udf_node_fields) diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py index 17a0abf5916f..4090711770a1 100644 --- a/ibis/backends/pyspark/compiler.py +++ b/ibis/backends/pyspark/compiler.py @@ -113,6 +113,13 @@ def _can_be_replaced_by_column_name(column_expr, table): ) +@compiles(ops.Alias) +def compile_alias(t, expr, scope, timecontext, **kwargs): + op = expr.op() + arg = t.translate(op.arg, scope, timecontext, **kwargs) + return arg.alias(op.name) + + @compiles(ops.Selection) def compile_selection(t, expr, scope, timecontext, **kwargs): op = expr.op() diff --git a/ibis/backends/pyspark/tests/test_timecontext.py b/ibis/backends/pyspark/tests/test_timecontext.py index 88def5bc2bc6..173ed90c9af7 100644 --- a/ibis/backends/pyspark/tests/test_timecontext.py +++ b/ibis/backends/pyspark/tests/test_timecontext.py @@ -110,7 +110,9 @@ def 
adjust_context_window_check_scope( order_by='time', group_by='key', ) - value_count_over_win = CustomWindowOp(value_count, win).to_expr() + # the argument needs to be pull out from the alias + # any extensions must do the same + value_count_over_win = CustomWindowOp(value_count.op().arg, win).to_expr() expr = table.mutate(value_count_over_win.name('value_count_over_win')) diff --git a/ibis/expr/analysis.py b/ibis/expr/analysis.py index 493c673b5b85..ac635bf52a48 100644 --- a/ibis/expr/analysis.py +++ b/ibis/expr/analysis.py @@ -89,11 +89,11 @@ def _substitute(self, expr, mapping): except IbisTypeError: return expr - try: - name = expr.get_name() - except ExpressionError: - name = None - return expr._factory(new_node, name=name) + new_expr = type(expr)(new_node) + if expr.has_name(): + new_expr = new_expr.name(expr.get_name()) + + return new_expr class ScalarAggregate: @@ -144,13 +144,12 @@ def _visit(self, expr): subbed_arg = self._visit(arg) subbed_args.append(subbed_arg) - subbed_node = type(node)(*subbed_args) - if isinstance(expr, ir.ValueExpr): - result = expr._factory(subbed_node, name=expr._name) - else: - result = expr._factory(subbed_node) + new_op = type(node)(*subbed_args) + new_expr = new_op.to_expr() + if expr.has_name(): + new_expr = new_expr.name(name=expr.get_name()) - return result + return new_expr def _key(self, expr): return repr(expr.op()) @@ -278,10 +277,10 @@ def get_result(self): return expr lifted_node = type(node)(*lifted_args) - if isinstance(expr, ir.ValueExpr): - result = expr._factory(lifted_node, name=expr._name) - else: - result = expr._factory(lifted_node) + + result = type(expr)(lifted_node) + if isinstance(expr, ir.ValueExpr) and expr.has_name(): + result = result.name(expr.get_name()) return result @@ -351,7 +350,7 @@ def _lift_TableColumn(self, expr, block=None): if can_lift and not block: lifted_node = ops.TableColumn(lifted_root, node.name) - result = expr._factory(lifted_node, name=expr._name) + result = 
type(expr)(lifted_node) return result @@ -532,7 +531,7 @@ def get_mutation_exprs( is_first_overwrite = True expr_contains_overwrite = False if isinstance(expr, ir.DestructColumn): - if expr.get_name(): + if expr.has_name(): raise ExpressionError( f"Cannot name a destruct column: {expr.get_name()}" ) @@ -762,9 +761,8 @@ def _windowize(x, w): walked_child = _walk(window_arg, w) if walked_child is not window_arg: - walked = x._factory( - ops.WindowOp(walked_child, window_w), name=x._name - ) + op = ops.WindowOp(walked_child, window_w) + walked = op.to_expr().name(x.get_name()) else: walked = x @@ -797,7 +795,10 @@ def _walk(x, w): if not unchanged: new_op = type(op)(*windowed_args) - return x._factory(new_op, name=x._name) + expr = new_op.to_expr() + if x.has_name(): + expr = expr.name(x.get_name()) + return expr else: return x diff --git a/ibis/expr/datatypes/core.py b/ibis/expr/datatypes/core.py index 9be0334be1a9..f1e3fdeae545 100644 --- a/ibis/expr/datatypes/core.py +++ b/ibis/expr/datatypes/core.py @@ -102,7 +102,9 @@ def __call__(self, nullable: bool = True) -> DataType: "Please construct a new instance of the type to change the " "values of the attributes." 
) - return self._factory(nullable=nullable) + kwargs = dict(zip(self.argnames, self.args)) + kwargs["nullable"] = nullable + return self.__class__(**kwargs) @property def _pretty_piece(self) -> str: @@ -131,11 +133,6 @@ def equals(self, other): ) return super().__cached_equals__(other) - def _factory(self, nullable: bool = True) -> DataType: - kwargs = dict(zip(self.argnames, self.args)) - kwargs["nullable"] = nullable - return self.__class__(**kwargs) - def castable(self, target, **kwargs): """Return whether this data type is castable to `target`.""" return castable(self, target, **kwargs) @@ -144,14 +141,6 @@ def cast(self, target, **kwargs): """Cast this data type to `target`.""" return cast(self, target, **kwargs) - def scalar_type(self): - """Return a scalar expression with this data type.""" - return functools.partial(self.scalar, dtype=self) - - def column_type(self): - """Return a column expression with this data type.""" - return functools.partial(self.column, dtype=self) - @dtype.register(DataType) def from_ibis_dtype(value: DataType) -> DataType: @@ -1297,14 +1286,10 @@ def infer_floating(value: float) -> Float64: @infer.register(int) -def infer_integer(value: int, allow_overflow: bool = False) -> Integer: +def infer_integer(value: int) -> Integer: for dtype in (int8, int16, int32, int64): if dtype.bounds.lower <= value <= dtype.bounds.upper: return dtype - - if not allow_overflow: - raise OverflowError(value) - return int64 @@ -1445,7 +1430,7 @@ def can_cast_integer_to_boolean( @castable.register(Integer, Interval) def can_cast_integer_to_interval( - source: Interval, target: Interval, **kwargs + source: Integer, target: Interval, **kwargs ) -> bool: return castable(source, target.value_type) @@ -1536,14 +1521,13 @@ def can_cast_special_string(source, target, **kwargs): def cast(source: str | DataType, target: str | DataType, **kwargs) -> DataType: """Attempts to implicitly cast from source dtype to target dtype""" - source, result_target = 
dtype(source), dtype(target) + source, target = dtype(source), dtype(target) - if not castable(source, result_target, **kwargs): + if not castable(source, target, **kwargs): raise IbisTypeError( - 'Datatype {} cannot be implicitly ' - 'casted to {}'.format(source, result_target) + f'Datatype {source} cannot be implicitly casted to {target}' ) - return result_target + return target same_kind = Dispatcher( diff --git a/ibis/expr/format.py b/ibis/expr/format.py index abe0b7260d30..b95a42bfbd35 100644 --- a/ibis/expr/format.py +++ b/ibis/expr/format.py @@ -576,6 +576,11 @@ def _fmt_value_value_op(op: ops.ValueOp, *, aliases: Aliases) -> str: return f"{op.__class__.__name__}({', '.join(args)})" +@fmt_value.register +def _fmt_value_alias(op: ops.Alias, *, aliases: Aliases) -> str: + return fmt_value(op.arg, aliases=aliases) + + @fmt_value.register def _fmt_value_table_column(op: ops.TableColumn, *, aliases: Aliases) -> str: return f"{aliases[op.table.op()]}.{op.name}" @@ -593,6 +598,13 @@ def _fmt_value_sort_key(op: ops.SortKey, *, aliases: Aliases) -> str: return f"{sort_direction}|{expr}" +@fmt_value.register +def _fmt_topk(op: ops.TopK, *, aliases: Aliases) -> str: + arg = fmt_value(op.arg, aliases=aliases) + by = fmt_value(op.by, aliases=aliases) + return f"TopK({arg}, {op.k}, {by})" + + @fmt_value.register def _fmt_value_physical_table(op: ops.PhysicalTable, **_: Any) -> str: """Format a table as value. 
diff --git a/ibis/expr/lineage.py b/ibis/expr/lineage.py index bc0d4e98179c..3118af5b2e5e 100644 --- a/ibis/expr/lineage.py +++ b/ibis/expr/lineage.py @@ -75,13 +75,13 @@ def _get_args(op, name): # if Selection.selections is always columnar, could use an # OrderedDict to prevent scanning the whole thing - return [col for col in result if col._name == name] + return [col for col in result if col.get_name() == name] elif isinstance(op, ops.Aggregation): assert name is not None, 'name is None' return [ col for col in itertools.chain(op.by, op.metrics) - if col._name == name + if col.get_name() == name ] else: return op.args @@ -111,7 +111,7 @@ def lineage(expr, container=Stack): if not isinstance(expr, ir.ColumnExpr): raise TypeError('Input expression must be an instance of ColumnExpr') - c = container([(expr, expr._name)]) + c = container([(expr, expr.get_name() if expr.has_name() else None)]) seen = set() @@ -126,7 +126,7 @@ def lineage(expr, container=Stack): # add our dependencies to the container if they match our name # and are ibis expressions c.extend( - (arg, getattr(arg, '_name', name)) + (arg, arg.get_name() if arg.has_name() else name) for arg in c.visitor(_get_args(node.op(), name)) if isinstance(arg, ir.Expr) ) @@ -149,9 +149,10 @@ def traverse(fn, expr, type=ir.Expr, container=Stack): expr: ir.Expr The traversable expression or a list of expressions. type: Type - Only the instances if this type are traversed. + Only the instances if this expression type gets traversed. container: Union[Stack, Queue], default Stack - Defines the traversing order. + Defines the traversing order. Use Stack for depth-first and + Queue for breadth-first search. 
""" args = expr if isinstance(expr, collections.abc.Iterable) else [expr] todo = container(arg for arg in args if isinstance(arg, type)) diff --git a/ibis/expr/operations/analytic.py b/ibis/expr/operations/analytic.py index 98bde9738c6c..6a88c948eef0 100644 --- a/ibis/expr/operations/analytic.py +++ b/ibis/expr/operations/analytic.py @@ -1,5 +1,6 @@ from public import public +from ...common.validators import immutable_property from .. import datatypes as dt from .. import rules as rlz from .. import types as ir @@ -7,18 +8,13 @@ from .core import ValueOp, distinct_roots -@public -class AnalyticOp(ValueOp): - pass - - @public class WindowOp(ValueOp): expr = rlz.analytic window = rlz.window(from_base_table_of="expr") - output_type = rlz.array_like('expr') - display_argnames = False + output_dtype = rlz.dtype_like("expr") + output_shape = rlz.Shape.COLUMNAR def __init__(self, expr, window): expr = propagate_down_window(expr, window) @@ -38,12 +34,19 @@ def root_tables(self): ) +@public +class AnalyticOp(ValueOp): + output_shape = rlz.Shape.COLUMNAR + + @public class ShiftBase(AnalyticOp): arg = rlz.column(rlz.any) + offset = rlz.optional(rlz.one_of((rlz.integer, rlz.interval))) default = rlz.optional(rlz.any) - output_type = rlz.typeof('arg') + + output_dtype = rlz.dtype_like("arg") @public @@ -58,14 +61,14 @@ class Lead(ShiftBase): @public class RankBase(AnalyticOp): - def output_type(self): - return dt.int64.column_type() + output_dtype = dt.int64 @public class MinRank(RankBase): - """Compute position of first element within each equal-value group in sorted - order. + """ + Compute position of first element within each equal-value group in sorted + order. Equivalent to SQL RANK(). Examples -------- @@ -83,14 +86,14 @@ class MinRank(RankBase): The min rank """ - # Equivalent to SQL RANK() arg = rlz.column(rlz.any) @public class DenseRank(RankBase): - """Compute the position of first element within each equal-value group in - sorted order, ignoring duplicate values. 
+ """ + Compute position of first element within each equal-value group in sorted + order, ignoring duplicate values. Equivalent to SQL DENSE_RANK(). Examples -------- @@ -108,13 +111,14 @@ class DenseRank(RankBase): The rank """ - # Equivalent to SQL DENSE_RANK() arg = rlz.column(rlz.any) @public class RowNumber(RankBase): - """Compute the row number starting from 0. + """ + Compute row number starting from 0 after sorting by column expression. + Equivalent to SQL ROW_NUMBER(). Examples -------- @@ -142,12 +146,12 @@ class CumulativeSum(CumulativeOp): arg = rlz.column(rlz.numeric) - def output_type(self): + @immutable_property + def output_dtype(self): if isinstance(self.arg, ir.BooleanValue): - dtype = dt.int64 + return dt.int64 else: - dtype = self.arg.type().largest - return dtype.column_type() + return self.arg.type().largest @public @@ -156,12 +160,12 @@ class CumulativeMean(CumulativeOp): arg = rlz.column(rlz.numeric) - def output_type(self): + @immutable_property + def output_dtype(self): if isinstance(self.arg, ir.DecimalValue): - dtype = self.arg.type().largest + return self.arg.type().largest else: - dtype = dt.float64 - return dtype.column_type() + return dt.float64 @public @@ -169,7 +173,7 @@ class CumulativeMax(CumulativeOp): """Cumulative max. Requires an order window.""" arg = rlz.column(rlz.any) - output_type = rlz.array_like('arg') + output_dtype = rlz.dtype_like("arg") @public @@ -177,39 +181,57 @@ class CumulativeMin(CumulativeOp): """Cumulative min. 
Requires an order window.""" arg = rlz.column(rlz.any) - output_type = rlz.array_like('arg') + output_dtype = rlz.dtype_like("arg") + + +@public +class CumulativeAny(CumulativeOp): + arg = rlz.column(rlz.boolean) + output_dtype = rlz.dtype_like("arg") + + +@public +class CumulativeAll(CumulativeOp): + arg = rlz.column(rlz.boolean) + output_dtype = rlz.dtype_like("arg") @public class PercentRank(AnalyticOp): arg = rlz.column(rlz.any) - output_type = rlz.shape_like('arg', dt.double) + output_dtype = dt.double @public class NTile(AnalyticOp): arg = rlz.column(rlz.any) - buckets = rlz.integer - output_type = rlz.shape_like('arg', dt.int64) + buckets = rlz.scalar(rlz.integer) + output_dtype = dt.int64 @public class FirstValue(AnalyticOp): + """Retrieve the first element.""" + arg = rlz.column(rlz.any) - output_type = rlz.typeof('arg') + output_dtype = rlz.dtype_like("arg") @public class LastValue(AnalyticOp): + """Retrieve the last element.""" + arg = rlz.column(rlz.any) - output_type = rlz.typeof('arg') + output_dtype = rlz.dtype_like("arg") @public class NthValue(AnalyticOp): + """Retrieve the Nth element.""" + arg = rlz.column(rlz.any) nth = rlz.integer - output_type = rlz.typeof('arg') + output_dtype = rlz.dtype_like("arg") @public @@ -219,16 +241,19 @@ class Any(ValueOp): # array-like (an existence-type predicate) or scalar (a reduction) arg = rlz.column(rlz.boolean) + output_dtype = dt.boolean + @property def _reduction(self): roots = self.arg.op().root_tables() return len(roots) < 2 - def output_type(self): + @immutable_property + def output_shape(self): if self._reduction: - return dt.boolean.scalar_type() + return rlz.Shape.SCALAR else: - return dt.boolean.column_type() + return rlz.Shape.COLUMNAR def negate(self): return NotAny(self.arg) @@ -238,15 +263,3 @@ def negate(self): class NotAny(Any): def negate(self): return Any(self.arg) - - -@public -class CumulativeAny(CumulativeOp): - arg = rlz.column(rlz.boolean) - output_type = rlz.typeof('arg') - - -@public 
-class CumulativeAll(CumulativeOp): - arg = rlz.column(rlz.boolean) - output_type = rlz.typeof('arg') diff --git a/ibis/expr/operations/arrays.py b/ibis/expr/operations/arrays.py index 1595f9263ae6..ffa096588715 100644 --- a/ibis/expr/operations/arrays.py +++ b/ibis/expr/operations/arrays.py @@ -1,6 +1,7 @@ from public import public from ...common import exceptions as com +from ...common.validators import immutable_property from .. import datatypes as dt from .. import rules as rlz from .core import UnaryOp, ValueOp @@ -10,6 +11,8 @@ class ArrayColumn(ValueOp): cols = rlz.value_list_of(rlz.column(rlz.any), min_length=1) + output_shape = rlz.Shape.COLUMNAR + def __init__(self, cols): if len({col.type() for col in cols}) > 1: raise com.IbisTypeError( @@ -18,15 +21,18 @@ def __init__(self, cols): ) super().__init__(cols=cols) - def output_type(self): + @immutable_property + def output_dtype(self): first_dtype = self.cols[0].type() - return dt.Array(first_dtype).column_type() + return dt.Array(first_dtype) @public class ArrayLength(UnaryOp): arg = rlz.array - output_type = rlz.shape_like('arg', dt.int64) + + output_dtype = dt.int64 + output_shape = rlz.shape_like("args") @public @@ -34,7 +40,9 @@ class ArraySlice(ValueOp): arg = rlz.array start = rlz.integer stop = rlz.optional(rlz.integer) - output_type = rlz.typeof('arg') + + output_dtype = rlz.dtype_like("arg") + output_shape = rlz.shape_like("arg") @public @@ -42,16 +50,20 @@ class ArrayIndex(ValueOp): arg = rlz.array index = rlz.integer - def output_type(self): - value_dtype = self.arg.type().value_type - return rlz.shape_like(self.arg, value_dtype) + output_shape = rlz.shape_like("args") + + @immutable_property + def output_dtype(self): + return self.arg.type().value_type @public class ArrayConcat(ValueOp): left = rlz.array right = rlz.array - output_type = rlz.shape_like('left') + + output_dtype = rlz.dtype_like("left") + output_shape = rlz.shape_like("args") def __init__(self, left, right): left_dtype, 
right_dtype = left.type(), right.type() @@ -69,4 +81,6 @@ def __init__(self, left, right): class ArrayRepeat(ValueOp): arg = rlz.array times = rlz.integer - output_type = rlz.typeof('arg') + + output_dtype = rlz.dtype_like("arg") + output_shape = rlz.shape_like("args") diff --git a/ibis/expr/operations/core.py b/ibis/expr/operations/core.py index cb4fa9668413..e76119784f8e 100644 --- a/ibis/expr/operations/core.py +++ b/ibis/expr/operations/core.py @@ -1,13 +1,17 @@ from __future__ import annotations +from abc import abstractmethod + import toolz from public import public from ...common.exceptions import ExpressionError from ...common.grounds import Comparable from ...common.validators import immutable_property -from ...util import is_iterable +from ...util import UnnamedMarker, is_iterable from .. import rules as rlz +from .. import types as ir +from ..rules import Shape from ..schema import Schema from ..signature import Annotable @@ -64,8 +68,6 @@ def inputs(self): @property def exprs(self): - from .. import types as ir - return [arg for arg in self.args if isinstance(arg, ir.Expr)] def blocks(self): @@ -85,12 +87,15 @@ def is_ancestor(self, other): return self.equals(other) def to_expr(self): - return self._make_expr() + return self.output_type(self) + + def resolve_name(self): + raise ExpressionError(f'Expression is not named: {type(self)}') - def _make_expr(self): - klass = self.output_type() - return klass(self) + def has_resolved_name(self): + return False + @property def output_type(self): """Resolve the output type of the expression.""" raise NotImplementedError( @@ -110,11 +115,51 @@ class ValueOp(Node): def root_tables(self): return distinct_roots(*self.exprs) - def resolve_name(self): - raise ExpressionError(f'Expression is not named: {type(self)}') + @property + @abstractmethod + def output_dtype(self): + """ + Ibis datatype of the produced value expression. 
+ + Returns + ------- + dt.DataType + """ + + @property + @abstractmethod + def output_shape(self): + """ + Shape of the produced value expression. + + Possible values are: "scalar" and "columnar" + + Returns + ------- + rlz.Shape + """ + + @property + def output_type(self): + if self.output_shape is Shape.COLUMNAR: + return self.output_dtype.column + else: + return self.output_dtype.scalar + + +@public +class Alias(ValueOp): + arg = rlz.any + name = rlz.instance_of((str, UnnamedMarker)) + + output_shape = rlz.shape_like("arg") + output_dtype = rlz.dtype_like("arg") def has_resolved_name(self): - return False + return True + + def resolve_name(self): + return self.name @public @@ -123,6 +168,10 @@ class UnaryOp(ValueOp): arg = rlz.any + @property + def output_shape(self): + return self.arg.op().output_shape + @public class BinaryOp(ValueOp): @@ -130,3 +179,7 @@ class BinaryOp(ValueOp): left = rlz.any right = rlz.any + + @property + def output_shape(self): + return max(self.left.op().output_shape, self.right.op().output_shape) diff --git a/ibis/expr/operations/generic.py b/ibis/expr/operations/generic.py index b70c474dd570..c92ec8121668 100644 --- a/ibis/expr/operations/generic.py +++ b/ibis/expr/operations/generic.py @@ -3,20 +3,21 @@ import datetime import decimal import enum -import functools import itertools import uuid +from operator import attrgetter import numpy as np import pandas as pd from public import public from ...common import exceptions as com +from ...common.validators import immutable_property from ...util import frozendict from .. import datatypes as dt from .. import rules as rlz from .. 
import types as ir -from .core import UnaryOp, ValueOp, distinct_roots +from .core import Node, UnaryOp, ValueOp, distinct_roots try: import shapely @@ -33,6 +34,8 @@ class TableColumn(ValueOp): table = rlz.table name = rlz.instance_of((str, int)) + output_shape = rlz.Shape.COLUMNAR + def __init__(self, table, name): schema = table.schema() @@ -58,18 +61,17 @@ def has_resolved_name(self): def root_tables(self): return self.table.op().root_tables() - def _make_expr(self): - dtype = self.table._get_type(self.name) - klass = dtype.column_type() - return klass(self, name=self.name) + @property + def output_dtype(self): + return self.table._get_type(self.name) @public class RowID(ValueOp): """The row number (an autonumeric) of the returned result.""" - def output_type(self): - return dt.int64.column_type() + output_shape = rlz.Shape.COLUMNAR + output_dtype = dt.int64 def resolve_name(self): return 'rowid' @@ -107,30 +109,33 @@ class TableArrayView(ValueOp): table = rlz.table + output_shape = rlz.Shape.COLUMNAR + + @property + def output_dtype(self): + return self.table._get_type(self.name) + @property def name(self): return self.table.schema().names[0] - def _make_expr(self): - ctype = self.table._get_type(self.name) - klass = ctype.column_type() - return klass(self, name=self.name) - @public class Cast(ValueOp): + """Explicitly cast value to a specific data type.""" + arg = rlz.any to = rlz.datatype - # see #396 for the issue preventing an implementation of resolve_name + output_shape = rlz.shape_like("arg") + output_dtype = property(attrgetter("to")) - def output_type(self): - return rlz.shape_like(self.arg, dtype=self.to) + # see #396 for the issue preventing an implementation of resolve_name @public class TypeOf(UnaryOp): - output_type = rlz.shape_like('arg', dt.string) + output_dtype = dt.string @public @@ -143,7 +148,7 @@ class IsNull(UnaryOp): Value expression indicating whether values are null """ - output_type = rlz.shape_like('arg', dt.boolean) + output_dtype 
= dt.boolean @public @@ -156,17 +161,18 @@ class NotNull(UnaryOp): Value expression indicating whether values are not null """ - output_type = rlz.shape_like('arg', dt.boolean) + output_dtype = dt.boolean @public class ZeroIfNull(UnaryOp): - output_type = rlz.typeof('arg') + output_dtype = rlz.dtype_like("arg") @public class IfNull(ValueOp): - """Equivalent to (but perhaps implemented differently): + """ + Equivalent to (but perhaps implemented differently): case().when(expr.notnull(), expr) .else_(null_substitute_expr) @@ -174,7 +180,9 @@ class IfNull(ValueOp): arg = rlz.any ifnull_expr = rlz.any - output_type = rlz.shape_like('args') + + output_dtype = rlz.dtype_like("args") + output_shape = rlz.shape_like("args") @public @@ -183,7 +191,8 @@ class NullIf(ValueOp): arg = rlz.any null_if_expr = rlz.any - output_type = rlz.shape_like('args') + output_dtype = rlz.dtype_like("args") + output_shape = rlz.shape_like("args") @public @@ -195,14 +204,16 @@ class CoalesceLike(ValueOp): # DOUBLE; use CAST() when inserting into a smaller numeric column arg = rlz.value_list_of(rlz.any) - def output_type(self): + output_shape = rlz.shape_like('arg') + + @immutable_property + def output_dtype(self): # filter out null types non_null_exprs = [arg for arg in self.arg if arg.type() != dt.null] - if not non_null_exprs: - dtype = dt.null + if non_null_exprs: + return rlz.highest_precedence_dtype(non_null_exprs) else: - dtype = rlz.highest_precedence_dtype(non_null_exprs) - return rlz.shape_like(self.arg, dtype) + return dt.null @public @@ -254,8 +265,8 @@ class Literal(ValueOp): ) dtype = rlz.datatype - def output_type(self): - return self.dtype.scalar_type() + output_shape = rlz.Shape.SCALAR + output_dtype = property(attrgetter("dtype")) def root_tables(self): return [] @@ -278,15 +289,15 @@ class ScalarParameter(ValueOp): rlz.instance_of(int), default=lambda: next(ScalarParameter._counter) ) + output_shape = rlz.Shape.SCALAR + output_dtype = property(attrgetter("dtype")) + def 
resolve_name(self): return f'param_{self.counter:d}' def __hash__(self): return hash((self.dtype, self.counter)) - def output_type(self): - return self.dtype.scalar_type() - @property def inputs(self): return () @@ -299,12 +310,13 @@ def root_tables(self): class ValueList(ValueOp): """Data structure for a list of value expressions""" + # NOTE: this proxies the ValueOp behaviour to the underlying values + values = rlz.tuple_of(rlz.any) - display_argnames = False # disable showing argnames in repr - def output_type(self): - dtype = rlz.highest_precedence_dtype(self.values) - return functools.partial(ir.ListExpr, dtype=dtype) + output_type = ir.ListExpr + output_dtype = rlz.dtype_like("values") + output_shape = rlz.shape_like("values") def root_tables(self): return distinct_roots(*self.values) @@ -312,34 +324,27 @@ def root_tables(self): @public class Constant(ValueOp): - pass + output_shape = rlz.Shape.SCALAR @public class TimestampNow(Constant): - def output_type(self): - return dt.timestamp.scalar_type() - - -@public -class FloatConstant(Constant): - def output_type(self): - return dt.float64.scalar_type() + output_dtype = dt.timestamp @public -class RandomScalar(FloatConstant): - pass +class RandomScalar(Constant): + output_dtype = dt.float64 @public -class E(FloatConstant): - pass +class E(Constant): + output_dtype = dt.float64 @public -class Pi(FloatConstant): - pass +class Pi(Constant): + output_dtype = dt.float64 @public @@ -347,51 +352,63 @@ class StructField(ValueOp): arg = rlz.struct field = rlz.instance_of(str) - def output_type(self): + output_shape = rlz.shape_like("arg") + + @immutable_property + def output_dtype(self): struct_dtype = self.arg.type() value_dtype = struct_dtype[self.field] - return rlz.shape_like(self.arg, value_dtype) + return value_dtype + def resolve_name(self): + return self.field -@public -class DecimalUnaryOp(UnaryOp): - arg = rlz.decimal + def has_resolved_name(self): + return True @public class DecimalPrecision(UnaryOp): - 
output_type = rlz.shape_like('arg', dt.int32) + arg = rlz.decimal + output_dtype = dt.int32 @public -class DecimalScale(DecimalUnaryOp): - output_type = rlz.shape_like('arg', dt.int32) +class DecimalScale(UnaryOp): + arg = rlz.decimal + output_dtype = dt.int32 @public class Hash(ValueOp): arg = rlz.any how = rlz.isin({'fnv', 'farm_fingerprint'}) - output_type = rlz.shape_like('arg', dt.int64) + + output_dtype = dt.int64 + output_shape = rlz.shape_like("arg") @public class HashBytes(ValueOp): arg = rlz.one_of({rlz.value(dt.string), rlz.value(dt.binary)}) how = rlz.isin({'md5', 'sha1', 'sha256', 'sha512'}) - output_type = rlz.shape_like('arg', dt.binary) + + output_dtype = dt.binary + output_shape = rlz.shape_like("arg") @public class SummaryFilter(ValueOp): expr = rlz.instance_of(ir.TopKExpr) - def output_type(self): - return dt.boolean.column_type() + output_dtype = dt.boolean + output_shape = rlz.Shape.COLUMNAR +# TODO(kszucs): shouldn't we move this operation to either +# analytic.py or reductions.py? 
@public -class TopK(ValueOp): +class TopK(Node): arg = rlz.column(rlz.any) k = rlz.non_negative_integer by = rlz.one_of( @@ -401,13 +418,19 @@ class TopK(ValueOp): ) ) - def output_type(self): - return ir.TopKExpr + output_type = ir.TopKExpr def blocks(self): return True + def root_tables(self): + args = (arg for arg in self.flat_args() if isinstance(arg, ir.Expr)) + return distinct_roots(*args) + +# TODO(kszucs): we should merge the case operations by making the +# cases, results and default optional arguments like they are in +# api.py @public class SimpleCase(ValueOp): base = rlz.any @@ -415,6 +438,8 @@ class SimpleCase(ValueOp): results = rlz.value_list_of(rlz.any) default = rlz.any + output_shape = rlz.shape_like("base") + def __init__(self, cases, results, **kwargs): assert len(cases) == len(results) super().__init__(cases=cases, results=results, **kwargs) @@ -422,10 +447,13 @@ def __init__(self, cases, results, **kwargs): def root_tables(self): return distinct_roots(*self.flat_args()) - def output_type(self): + @immutable_property + def output_dtype(self): + # TODO(kszucs): we could extend the functionality of + # rlz.shape_like to support varargs with .flat_args() + # to define a subset of input arguments values = self.results + [self.default] - dtype = rlz.highest_precedence_dtype(values) - return rlz.shape_like(self.base, dtype=dtype) + return rlz.highest_precedence_dtype(values) @public @@ -434,6 +462,8 @@ class SearchedCase(ValueOp): results = rlz.value_list_of(rlz.any) default = rlz.any + output_shape = rlz.shape_like("cases") + def __init__(self, cases, results, default): assert len(cases) == len(results) super().__init__(cases=cases, results=results, default=default) @@ -441,7 +471,7 @@ def __init__(self, cases, results, default): def root_tables(self): return distinct_roots(*self.flat_args()) - def output_type(self): + @immutable_property + def output_dtype(self): exprs = self.results + [self.default] - dtype = rlz.highest_precedence_dtype(exprs) - 
return rlz.shape_like(self.cases, dtype) + return rlz.highest_precedence_dtype(exprs) diff --git a/ibis/expr/operations/geospatial.py b/ibis/expr/operations/geospatial.py index 0290ff5146df..14989d0c5dd2 100644 --- a/ibis/expr/operations/geospatial.py +++ b/ibis/expr/operations/geospatial.py @@ -25,14 +25,14 @@ class GeoSpatialUnOp(UnaryOp): class GeoDistance(GeoSpatialBinOp): """Returns minimum distance between two geo spatial data""" - output_type = rlz.shape_like('args', dt.float64) + output_dtype = dt.float64 @public class GeoContains(GeoSpatialBinOp): """Check if the first geo spatial data contains the second one""" - output_type = rlz.shape_like('args', dt.boolean) + output_dtype = dt.boolean @public @@ -40,14 +40,14 @@ class GeoContainsProperly(GeoSpatialBinOp): """Check if the first geo spatial data contains the second one, and no boundary points are shared.""" - output_type = rlz.shape_like('args', dt.boolean) + output_dtype = dt.boolean @public class GeoCovers(GeoSpatialBinOp): """Returns True if no point in Geometry B is outside Geometry A""" - output_type = rlz.shape_like('args', dt.boolean) + output_dtype = dt.boolean @public @@ -55,7 +55,7 @@ class GeoCoveredBy(GeoSpatialBinOp): """Returns True if no point in Geometry/Geography A is outside Geometry/Geography B""" - output_type = rlz.shape_like('args', dt.boolean) + output_dtype = dt.boolean @public @@ -63,7 +63,7 @@ class GeoCrosses(GeoSpatialBinOp): """Returns True if the supplied geometries have some, but not all, interior points in common.""" - output_type = rlz.shape_like('args', dt.boolean) + output_dtype = dt.boolean @public @@ -71,14 +71,14 @@ class GeoDisjoint(GeoSpatialBinOp): """Returns True if the Geometries do not “spatially intersect” - if they do not share any space together.""" - output_type = rlz.shape_like('args', dt.boolean) + output_dtype = dt.boolean @public class GeoEquals(GeoSpatialBinOp): """Returns True if the given geometries represent the same geometry.""" - output_type = 
rlz.shape_like('args', dt.boolean) + output_dtype = dt.boolean @public @@ -86,15 +86,14 @@ class GeoGeometryN(GeoSpatialUnOp): """Returns the Nth Geometry of a Multi geometry.""" n = rlz.integer - - output_type = rlz.shape_like('args', dt.geometry) + output_dtype = dt.geometry @public class GeoGeometryType(GeoSpatialUnOp): """Returns the type of the geometry.""" - output_type = rlz.shape_like('args', dt.string) + output_dtype = dt.string @public @@ -103,14 +102,14 @@ class GeoIntersects(GeoSpatialBinOp): - (share any portion of space) and False if they don’t (they are Disjoint). """ - output_type = rlz.shape_like('args', dt.boolean) + output_dtype = dt.boolean @public class GeoIsValid(GeoSpatialUnOp): """Returns true if the geometry is well-formed.""" - output_type = rlz.shape_like('args', dt.boolean) + output_dtype = dt.boolean @public @@ -126,7 +125,7 @@ class GeoLineLocatePoint(GeoSpatialBinOp): left = rlz.linestring right = rlz.point - output_type = rlz.shape_like('args', dt.halffloat) + output_dtype = dt.halffloat @public @@ -140,7 +139,7 @@ class GeoLineMerge(GeoSpatialUnOp): geometry collection. """ - output_type = rlz.shape_like('args', dt.geometry) + output_dtype = dt.geometry @public @@ -158,7 +157,7 @@ class GeoLineSubstring(GeoSpatialUnOp): start = rlz.floating end = rlz.floating - output_type = rlz.shape_like('args', dt.linestring) + output_dtype = dt.linestring @public @@ -170,7 +169,7 @@ class GeoOrderingEquals(GeoSpatialBinOp): are in the same order. 
""" - output_type = rlz.shape_like('args', dt.boolean) + output_dtype = dt.boolean @public @@ -178,7 +177,7 @@ class GeoOverlaps(GeoSpatialBinOp): """Returns True if the Geometries share space, are of the same dimension, but are not completely contained by each other.""" - output_type = rlz.shape_like('args', dt.boolean) + output_dtype = dt.boolean @public @@ -186,45 +185,42 @@ class GeoTouches(GeoSpatialBinOp): """Returns True if the geometries have at least one point in common, but their interiors do not intersect.""" - output_type = rlz.shape_like('args', dt.boolean) + output_dtype = dt.boolean @public -class GeoUnaryUnion(Reduction): +class GeoUnaryUnion(Reduction, GeoSpatialUnOp): """Returns the pointwise union of the geometries in the column.""" - arg = rlz.column(rlz.geospatial) - - def output_type(self): - return dt.geometry.scalar_type() + output_dtype = dt.geometry @public class GeoUnion(GeoSpatialBinOp): """Returns the pointwise union of the two geometries.""" - output_type = rlz.shape_like('args', dt.geometry) + output_dtype = dt.geometry @public class GeoArea(GeoSpatialUnOp): """Area of the geo spatial data""" - output_type = rlz.shape_like('args', dt.float64) + output_dtype = dt.float64 @public class GeoPerimeter(GeoSpatialUnOp): """Perimeter of the geo spatial data""" - output_type = rlz.shape_like('args', dt.float64) + output_dtype = dt.float64 @public class GeoLength(GeoSpatialUnOp): """Length of geo spatial data""" - output_type = rlz.shape_like('args', dt.float64) + output_dtype = dt.float64 @public @@ -235,7 +231,7 @@ class GeoMaxDistance(GeoSpatialBinOp): in that geometry """ - output_type = rlz.shape_like('args', dt.float64) + output_dtype = dt.float64 @public @@ -244,7 +240,7 @@ class GeoX(GeoSpatialUnOp): Input must be a point """ - output_type = rlz.shape_like('args', dt.float64) + output_dtype = dt.float64 @public @@ -253,35 +249,35 @@ class GeoY(GeoSpatialUnOp): Input must be a point """ - output_type = rlz.shape_like('args', dt.float64) 
+ output_dtype = dt.float64 @public class GeoXMin(GeoSpatialUnOp): """Returns Y minima of a bounding box 2d or 3d or a geometry""" - output_type = rlz.shape_like('args', dt.float64) + output_dtype = dt.float64 @public class GeoXMax(GeoSpatialUnOp): """Returns X maxima of a bounding box 2d or 3d or a geometry""" - output_type = rlz.shape_like('args', dt.float64) + output_dtype = dt.float64 @public class GeoYMin(GeoSpatialUnOp): """Returns Y minima of a bounding box 2d or 3d or a geometry""" - output_type = rlz.shape_like('args', dt.float64) + output_dtype = dt.float64 @public class GeoYMax(GeoSpatialUnOp): """Returns Y maxima of a bounding box 2d or 3d or a geometry""" - output_type = rlz.shape_like('args', dt.float64) + output_dtype = dt.float64 @public @@ -290,7 +286,7 @@ class GeoStartPoint(GeoSpatialUnOp): NULL if the input parameter is not a LINESTRING """ - output_type = rlz.shape_like('arg', dt.point) + output_dtype = dt.point @public @@ -299,7 +295,7 @@ class GeoEndPoint(GeoSpatialUnOp): NULL if the input parameter is not a LINESTRING """ - output_type = rlz.shape_like('arg', dt.point) + output_dtype = dt.point @public @@ -311,7 +307,8 @@ class GeoPoint(GeoSpatialBinOp): left = rlz.numeric right = rlz.numeric - output_type = rlz.shape_like('args', dt.point) + + output_dtype = dt.point @public @@ -323,14 +320,14 @@ class GeoPointN(GeoSpatialUnOp): """ n = rlz.integer - output_type = rlz.shape_like('args', dt.point) + output_dtype = dt.point @public class GeoNPoints(GeoSpatialUnOp): """Return the number of points in a geometry. Works for all geometries""" - output_type = rlz.shape_like('args', dt.int64) + output_dtype = dt.int64 @public @@ -339,14 +336,14 @@ class GeoNRings(GeoSpatialUnOp): rings. 
It counts the outer rings as well """ - output_type = rlz.shape_like('args', dt.int64) + output_dtype = dt.int64 @public class GeoSRID(GeoSpatialUnOp): """Returns the spatial reference identifier for the ST_Geometry.""" - output_type = rlz.shape_like('args', dt.int64) + output_dtype = dt.int64 @public @@ -354,7 +351,8 @@ class GeoSetSRID(GeoSpatialUnOp): """Set the spatial reference identifier for the ST_Geometry.""" srid = rlz.integer - output_type = rlz.shape_like('args', dt.geometry) + + output_dtype = dt.geometry @public @@ -365,15 +363,14 @@ class GeoBuffer(GeoSpatialUnOp): """ radius = rlz.floating - - output_type = rlz.shape_like('args', dt.geometry) + output_dtype = dt.geometry @public class GeoCentroid(GeoSpatialUnOp): """Returns the geometric center of a geometry.""" - output_type = rlz.shape_like('arg', dt.point) + output_dtype = dt.point @public @@ -384,7 +381,7 @@ class GeoDFullyWithin(GeoSpatialBinOp): distance = rlz.floating - output_type = rlz.shape_like('args', dt.boolean) + output_dtype = dt.boolean @public @@ -395,14 +392,14 @@ class GeoDWithin(GeoSpatialBinOp): distance = rlz.floating - output_type = rlz.shape_like('args', dt.boolean) + output_dtype = dt.boolean @public class GeoEnvelope(GeoSpatialUnOp): """Represents the bounding box of the supplied geometry.""" - output_type = rlz.shape_like('arg', dt.polygon) + output_dtype = dt.polygon @public @@ -415,14 +412,14 @@ class GeoAzimuth(GeoSpatialBinOp): left = rlz.point right = rlz.point - output_type = rlz.shape_like('args', dt.float64) + output_dtype = dt.float64 @public class GeoWithin(GeoSpatialBinOp): """Returns True if the geometry A is completely inside geometry B""" - output_type = rlz.shape_like('args', dt.boolean) + output_dtype = dt.boolean @public @@ -431,7 +428,7 @@ class GeoIntersection(GeoSpatialBinOp): of the Geometries. 
""" - output_type = rlz.shape_like('args', dt.geometry) + output_dtype = dt.geometry @public @@ -440,7 +437,7 @@ class GeoDifference(GeoSpatialBinOp): that does not intersect with geometry B """ - output_type = rlz.shape_like('args', dt.geometry) + output_dtype = dt.geometry @public @@ -450,7 +447,7 @@ class GeoSimplify(GeoSpatialUnOp): tolerance = rlz.floating preserve_collapsed = rlz.boolean - output_type = rlz.shape_like('arg', dt.geometry) + output_dtype = dt.geometry @public @@ -459,7 +456,7 @@ class GeoTransform(GeoSpatialUnOp): srid = rlz.integer - output_type = rlz.shape_like('arg', dt.geometry) + output_dtype = dt.geometry @public @@ -468,7 +465,7 @@ class GeoAsBinary(GeoSpatialUnOp): geometry/geography without SRID meta data. """ - output_type = rlz.shape_like('arg', dt.binary) + output_dtype = dt.binary @public @@ -477,7 +474,7 @@ class GeoAsEWKB(GeoSpatialUnOp): geometry/geography with SRID meta data. """ - output_type = rlz.shape_like('arg', dt.binary) + output_dtype = dt.binary @public @@ -486,7 +483,7 @@ class GeoAsEWKT(GeoSpatialUnOp): geometry/geography with SRID meta data. """ - output_type = rlz.shape_like('arg', dt.string) + output_dtype = dt.string @public @@ -495,4 +492,4 @@ class GeoAsText(GeoSpatialUnOp): geometry/geography without SRID metadata. 
""" - output_type = rlz.shape_like('arg', dt.string) + output_dtype = dt.string diff --git a/ibis/expr/operations/histograms.py b/ibis/expr/operations/histograms.py index 64b1e21e6783..0c32c165ea34 100644 --- a/ibis/expr/operations/histograms.py +++ b/ibis/expr/operations/histograms.py @@ -7,13 +7,15 @@ @public class BucketLike(ValueOp): + output_shape = rlz.Shape.COLUMNAR + @property def nbuckets(self): return None - def output_type(self): - dtype = dt.Category(self.nbuckets) - return dtype.column_type() + @property + def output_dtype(self): + return dt.Category(self.nbuckets) @public @@ -63,9 +65,10 @@ def __init__(self, nbins, binwidth, **kwargs): raise ValueError('nbins and binwidth are mutually exclusive') super().__init__(nbins=nbins, binwidth=binwidth, **kwargs) - def output_type(self): + @property + def output_dtype(self): # always undefined cardinality (for now) - return dt.category.column_type() + return dt.category @public @@ -73,7 +76,9 @@ class CategoryLabel(ValueOp): arg = rlz.category labels = rlz.tuple_of(rlz.instance_of(str)) nulls = rlz.optional(rlz.instance_of(str)) - output_type = rlz.shape_like('arg', dt.string) + + output_dtype = dt.string + output_shape = rlz.shape_like("arg") def __init__(self, arg, labels, **kwargs): cardinality = arg.type().cardinality diff --git a/ibis/expr/operations/logical.py b/ibis/expr/operations/logical.py index ec1059c310a0..ccbd4f6d3373 100644 --- a/ibis/expr/operations/logical.py +++ b/ibis/expr/operations/logical.py @@ -1,8 +1,5 @@ -from contextlib import suppress - from public import public -from ...common import exceptions as com from .. import datatypes as dt from .. 
import rules as rlz from .core import BinaryOp, UnaryOp, ValueOp @@ -12,13 +9,15 @@ class LogicalBinaryOp(BinaryOp): left = rlz.boolean right = rlz.boolean - output_type = rlz.shape_like('args', dt.boolean) + + output_dtype = dt.boolean @public class Not(UnaryOp): arg = rlz.boolean - output_type = rlz.shape_like('arg', dt.boolean) + + output_dtype = dt.boolean @public @@ -41,37 +40,23 @@ class Comparison(BinaryOp): left = rlz.any right = rlz.any + output_dtype = dt.boolean + def __init__(self, left, right): """ Casting rules for type promotions (for resolving the output type) may depend in some cases on the target backend. - TODO: how will overflows be handled? Can we provide anything useful in Ibis to help the user avoid them? - :param left: :param right: """ - left, right = self._maybe_cast_args(left, right) - super().__init__(left=left, right=right) - - def _maybe_cast_args(self, left, right): - # it might not be necessary? - with suppress(com.IbisTypeError): - return left, rlz.cast(right, left) - - with suppress(com.IbisTypeError): - return rlz.cast(left, right), right - - return left, right - - def output_type(self): - if not rlz.comparable(self.left, self.right): + if not rlz.comparable(left, right): raise TypeError( 'Arguments with datatype {} and {} are ' - 'not comparable'.format(self.left.type(), self.right.type()) + 'not comparable'.format(left.type(), right.type()) ) - return rlz.shape_like(self.args, dt.boolean) + super().__init__(left=left, right=right) @public @@ -115,13 +100,23 @@ class Between(ValueOp): lower_bound = rlz.any upper_bound = rlz.any - def output_type(self): - arg, lower, upper = self.args - - if not (rlz.comparable(arg, lower) and rlz.comparable(arg, upper)): - raise TypeError('Arguments are not comparable') + output_dtype = dt.boolean + output_shape = rlz.shape_like("args") - return rlz.shape_like(self.args, dt.boolean) + def __init__(self, arg, lower_bound, upper_bound): + if not rlz.comparable(arg, lower_bound): + raise TypeError( 
+ f'Argument with datatype {arg.type()} and lower bound ' + f'with datatype {lower_bound.type()} are not comparable' + ) + if not rlz.comparable(arg, upper_bound): + raise TypeError( + f'Argument with datatype {arg.type()} and upper bound ' + f'with datatype {upper_bound.type()} are not comparable' + ) + super().__init__( + arg=arg, lower_bound=lower_bound, upper_bound=upper_bound + ) @public @@ -136,8 +131,8 @@ class Contains(ValueOp): ] ) - def output_type(self): - return rlz.shape_like(self.flat_args(), dt.boolean) + output_dtype = dt.boolean + output_shape = rlz.shape_like("args") @public @@ -160,5 +155,5 @@ class Where(ValueOp): true_expr = rlz.any false_null_expr = rlz.any - def output_type(self): - return rlz.shape_like(self.bool_expr, self.true_expr.type()) + output_dtype = rlz.dtype_like("true_expr") + output_shape = rlz.shape_like("bool_expr") diff --git a/ibis/expr/operations/maps.py b/ibis/expr/operations/maps.py index cd7d1be94fdd..e57bb79db25c 100644 --- a/ibis/expr/operations/maps.py +++ b/ibis/expr/operations/maps.py @@ -1,15 +1,16 @@ from public import public from ...common import exceptions as com +from ...common.validators import immutable_property from .. import datatypes as dt from .. 
import rules as rlz -from .core import ValueOp +from .core import UnaryOp, ValueOp @public -class MapLength(ValueOp): +class MapLength(UnaryOp): arg = rlz.mapping - output_type = rlz.shape_like('arg', dt.int64) + output_dtype = dt.int64 @public @@ -17,8 +18,11 @@ class MapValueForKey(ValueOp): arg = rlz.mapping key = rlz.one_of([rlz.string, rlz.integer]) - def output_type(self): - return rlz.shape_like(tuple(self.args), self.arg.type().value_type) + output_shape = rlz.shape_like("args") + + @immutable_property + def output_dtype(self): + return self.arg.type().value_type @public @@ -27,43 +31,44 @@ class MapValueOrDefaultForKey(ValueOp): key = rlz.one_of([rlz.string, rlz.integer]) default = rlz.any - def output_type(self): - arg = self.arg - default = self.default - map_type = arg.type() - value_type = map_type.value_type - default_type = default.type() + output_shape = rlz.shape_like("args") + + @property + def output_dtype(self): + value_type = self.arg.type().value_type + default_type = self.default.type() - if default is not None and not dt.same_kind(default_type, value_type): + if not dt.same_kind(default_type, value_type): raise com.IbisTypeError( "Default value\n{}\nof type {} cannot be cast to map's value " - "type {}".format(default, default_type, value_type) + "type {}".format(self.default, default_type, value_type) ) - result_type = dt.highest_precedence((default_type, value_type)) - return rlz.shape_like(tuple(self.args), result_type) + return dt.highest_precedence((default_type, value_type)) @public -class MapKeys(ValueOp): +class MapKeys(UnaryOp): arg = rlz.mapping - def output_type(self): - arg = self.arg - return rlz.shape_like(arg, dt.Array(arg.type().key_type)) + @immutable_property + def output_dtype(self): + return dt.Array(self.arg.type().key_type) @public -class MapValues(ValueOp): +class MapValues(UnaryOp): arg = rlz.mapping - def output_type(self): - arg = self.arg - return rlz.shape_like(arg, dt.Array(arg.type().value_type)) + 
@immutable_property + def output_dtype(self): + return dt.Array(self.arg.type().value_type) @public class MapConcat(ValueOp): left = rlz.mapping right = rlz.mapping - output_type = rlz.typeof('left') + + output_shape = rlz.shape_like("args") + output_dtype = rlz.dtype_like("args") diff --git a/ibis/expr/operations/numeric.py b/ibis/expr/operations/numeric.py index f36fe66966d3..d32bf37c689e 100644 --- a/ibis/expr/operations/numeric.py +++ b/ibis/expr/operations/numeric.py @@ -3,6 +3,7 @@ from public import public from ... import util +from ...common.validators import immutable_property from .. import datatypes as dt from .. import rules as rlz from .. import types as ir @@ -17,51 +18,53 @@ class NumericBinaryOp(BinaryOp): @public class Add(NumericBinaryOp): - output_type = rlz.numeric_like('args', operator.add) + output_dtype = rlz.numeric_like("args", operator.add) @public class Multiply(NumericBinaryOp): - output_type = rlz.numeric_like('args', operator.mul) + output_dtype = rlz.numeric_like("args", operator.mul) @public class Power(NumericBinaryOp): - def output_type(self): + @property + def output_dtype(self): if util.all_of(self.args, ir.IntegerValue): - return rlz.shape_like(self.args, dt.float64) + return dt.float64 else: - return rlz.shape_like(self.args) + return rlz.highest_precedence_dtype(self.args) @public class Subtract(NumericBinaryOp): - output_type = rlz.numeric_like('args', operator.sub) + output_dtype = rlz.numeric_like("args", operator.sub) @public class Divide(NumericBinaryOp): - output_type = rlz.shape_like('args', dt.float64) + output_dtype = dt.float64 @public class FloorDivide(Divide): - output_type = rlz.shape_like('args', dt.int64) + output_dtype = dt.int64 @public class Modulus(NumericBinaryOp): - output_type = rlz.numeric_like('args', operator.mod) + output_dtype = rlz.numeric_like("args", operator.mod) @public class Negate(UnaryOp): arg = rlz.one_of((rlz.numeric, rlz.interval)) - output_type = rlz.typeof('arg') + + output_dtype = 
rlz.dtype_like("arg") @public -class NullIfZero(ValueOp): +class NullIfZero(UnaryOp): """Set values to NULL if they are equal to zero. Commonly used in cases where divide-by-zero would produce an overflow or @@ -80,31 +83,31 @@ class NullIfZero(ValueOp): """ arg = rlz.numeric - output_type = rlz.typeof('arg') + output_dtype = rlz.dtype_like("arg") @public -class IsNan(ValueOp): +class IsNan(UnaryOp): arg = rlz.floating - output_type = rlz.shape_like('arg', dt.boolean) + output_dtype = dt.boolean @public -class IsInf(ValueOp): +class IsInf(UnaryOp): arg = rlz.floating - output_type = rlz.shape_like('arg', dt.boolean) + output_dtype = dt.boolean @public class Abs(UnaryOp): """Absolute value""" - output_type = rlz.typeof('arg') + arg = rlz.numeric + output_dtype = rlz.dtype_like("arg") @public class Ceil(UnaryOp): - """ Round up to the nearest integer value greater than or equal to this value @@ -117,15 +120,16 @@ class Ceil(UnaryOp): arg = rlz.numeric - def output_type(self): + @property + def output_dtype(self): if isinstance(self.arg.type(), dt.Decimal): - return self.arg._factory - return rlz.shape_like(self.arg, dt.int64) + return self.arg.type() + else: + return dt.int64 @public class Floor(UnaryOp): - """ Round down to the nearest integer value less than or equal to this value @@ -138,10 +142,12 @@ class Floor(UnaryOp): arg = rlz.numeric - def output_type(self): + @property + def output_dtype(self): if isinstance(self.arg.type(), dt.Decimal): - return self.arg._factory - return rlz.shape_like(self.arg, dt.int64) + return self.arg.type() + else: + return dt.int64 @public @@ -149,13 +155,16 @@ class Round(ValueOp): arg = rlz.numeric digits = rlz.optional(rlz.numeric) - def output_type(self): - if isinstance(self.arg, ir.DecimalValue): - return self.arg._factory + output_shape = rlz.shape_like("arg") + + @property + def output_dtype(self): + if isinstance(self.arg.type(), dt.Decimal): + return self.arg.type() elif self.digits is None: - return 
rlz.shape_like(self.arg, dt.int64) + return dt.int64 else: - return rlz.shape_like(self.arg, dt.double) + return dt.double @public @@ -163,7 +172,9 @@ class Clip(ValueOp): arg = rlz.strict_numeric lower = rlz.optional(rlz.strict_numeric) upper = rlz.optional(rlz.strict_numeric) - output_type = rlz.typeof('arg') + + output_dtype = rlz.dtype_like("arg") + output_shape = rlz.shape_like("args") @public @@ -172,41 +183,41 @@ class BaseConvert(ValueOp): from_base = rlz.integer to_base = rlz.integer - def output_type(self): - return rlz.shape_like(tuple(self.flat_args()), dt.string) + output_dtype = dt.string + output_shape = rlz.shape_like("args") @public class MathUnaryOp(UnaryOp): arg = rlz.numeric - def output_type(self): - arg = self.arg - if isinstance(self.arg, ir.DecimalValue): - dtype = arg.type() + @immutable_property + def output_dtype(self): + if isinstance(self.arg.type(), dt.Decimal): + return self.arg.type() else: - dtype = dt.double - return rlz.shape_like(arg, dtype) + return dt.double @public -class ExpandingTypeMathUnaryOp(MathUnaryOp): - def output_type(self): - if not isinstance(self.arg, ir.DecimalValue): - return super().output_type() - arg = self.arg - return rlz.shape_like(arg, arg.type().largest) +class ExpandingMathUnaryOp(MathUnaryOp): + @immutable_property + def output_dtype(self): + if isinstance(self.arg.type(), dt.Decimal): + return self.arg.type().largest + else: + return dt.double @public -class Exp(ExpandingTypeMathUnaryOp): +class Exp(ExpandingMathUnaryOp): pass @public class Sign(UnaryOp): arg = rlz.numeric - output_type = rlz.typeof('arg') + output_dtype = rlz.dtype_like("arg") @public @@ -241,18 +252,14 @@ class Log10(Logarithm): @public -class Degrees(ExpandingTypeMathUnaryOp): +class Degrees(ExpandingMathUnaryOp): """Converts radians to degrees""" - arg = rlz.numeric - @public class Radians(MathUnaryOp): """Converts degrees to radians""" - arg = rlz.numeric - # TRIGONOMETRIC OPERATIONS @@ -261,8 +268,6 @@ class 
Radians(MathUnaryOp): class TrigonometricUnary(MathUnaryOp): """Trigonometric base unary""" - arg = rlz.numeric - @public class TrigonometricBinary(BinaryOp): @@ -270,7 +275,7 @@ class TrigonometricBinary(BinaryOp): left = rlz.numeric right = rlz.numeric - output_type = rlz.shape_like('args', dt.float64) + output_dtype = dt.float64 @public diff --git a/ibis/expr/operations/reductions.py b/ibis/expr/operations/reductions.py index c18c01841df0..b46f6beeda96 100644 --- a/ibis/expr/operations/reductions.py +++ b/ibis/expr/operations/reductions.py @@ -1,17 +1,18 @@ -import functools - from public import public +from ...common.validators import immutable_property from .. import datatypes as dt from .. import rules as rlz from .. import types as ir -from .core import ValueOp +from .core import ValueOp, distinct_roots @public class Reduction(ValueOp): _reduction = True + output_shape = rlz.Shape.SCALAR + class Filterable(ValueOp): where = rlz.optional(rlz.boolean) @@ -20,16 +21,14 @@ class Filterable(ValueOp): @public class Count(Filterable, Reduction): arg = rlz.one_of((rlz.column(rlz.any), rlz.table)) - - def output_type(self): - return functools.partial(ir.IntegerScalar, dtype=dt.int64) + output_dtype = dt.int64 @public class Arbitrary(Filterable, Reduction): arg = rlz.column(rlz.any) how = rlz.optional(rlz.isin({'first', 'last', 'heavy'})) - output_type = rlz.scalar_like('arg') + output_dtype = rlz.dtype_like('arg') @public @@ -47,7 +46,7 @@ class BitAnd(Filterable, Reduction): """ # noqa: E501 arg = rlz.column(rlz.integer) - output_type = rlz.scalar_like('arg') + output_dtype = rlz.dtype_like('arg') @public @@ -64,7 +63,7 @@ class BitOr(Filterable, Reduction): """ # noqa: E501 arg = rlz.column(rlz.integer) - output_type = rlz.scalar_like('arg') + output_dtype = rlz.dtype_like('arg') @public @@ -81,31 +80,34 @@ class BitXor(Filterable, Reduction): """ # noqa: E501 arg = rlz.column(rlz.integer) - output_type = rlz.scalar_like('arg') + output_dtype = 
rlz.dtype_like('arg') @public class Sum(Filterable, Reduction): arg = rlz.column(rlz.numeric) - def output_type(self): + @immutable_property + def output_dtype(self): if isinstance(self.arg, ir.BooleanValue): - dtype = dt.int64 + return dt.int64 else: - dtype = self.arg.type().largest - return dtype.scalar_type() + return self.arg.type().largest @public class Mean(Filterable, Reduction): arg = rlz.column(rlz.numeric) - def output_type(self): + @immutable_property + def output_dtype(self): if isinstance(self.arg, ir.DecimalValue): - dtype = self.arg.type() + return self.arg.type() else: - dtype = dt.float64 - return dtype.scalar_type() + return dt.float64 + + def root_tables(self): + return distinct_roots(self.arg) @public @@ -116,8 +118,7 @@ class Quantile(Reduction): {'linear', 'lower', 'higher', 'midpoint', 'nearest'} ) - def output_type(self): - return dt.float64.scalar_type() + output_dtype = dt.float64 @public @@ -128,8 +129,7 @@ class MultiQuantile(Quantile): {'linear', 'lower', 'higher', 'midpoint', 'nearest'} ) - def output_type(self): - return dt.Array(dt.float64).scalar_type() + output_dtype = dt.Array(dt.float64) @public @@ -137,12 +137,12 @@ class VarianceBase(Filterable, Reduction): arg = rlz.column(rlz.numeric) how = rlz.isin({'sample', 'pop'}) - def output_type(self): + @immutable_property + def output_dtype(self): if isinstance(self.arg, ir.DecimalValue): - dtype = self.arg.type().largest + return self.arg.type().largest else: - dtype = dt.float64 - return dtype.scalar_type() + return dt.float64 @public @@ -163,8 +163,7 @@ class Correlation(Filterable, Reduction): right = rlz.column(rlz.numeric) how = rlz.isin({'sample', 'pop'}) - def output_type(self): - return dt.float64.scalar_type() + output_dtype = dt.float64 @public @@ -175,20 +174,19 @@ class Covariance(Filterable, Reduction): right = rlz.column(rlz.numeric) how = rlz.isin({'sample', 'pop'}) - def output_type(self): - return dt.float64.scalar_type() + output_dtype = dt.float64 @public class 
Max(Filterable, Reduction): arg = rlz.column(rlz.any) - output_type = rlz.scalar_like('arg') + output_dtype = rlz.dtype_like('arg') @public class Min(Filterable, Reduction): arg = rlz.column(rlz.any) - output_type = rlz.scalar_like('arg') + output_dtype = rlz.dtype_like('arg') @public @@ -200,10 +198,8 @@ class HLLCardinality(Filterable, Reduction): arg = rlz.column(rlz.any) - def output_type(self): - # Impala 2.0 and higher returns a DOUBLE - # return ir.DoubleScalar - return functools.partial(ir.IntegerScalar, dtype=dt.int64) + # NOTE: Impala 2.0 and higher returns a DOUBLE (ir.DoubleScalar), not int64 + output_dtype = dt.int64 @public @@ -211,8 +207,7 @@ class GroupConcat(Filterable, Reduction): arg = rlz.column(rlz.any) sep = rlz.string - def output_type(self): - return dt.string.scalar_type() + output_dtype = dt.string @public @@ -223,13 +218,13 @@ class CMSMedian(Filterable, Reduction): """ arg = rlz.column(rlz.any) - output_type = rlz.scalar_like('arg') + output_dtype = rlz.dtype_like('arg') @public class All(Reduction): arg = rlz.column(rlz.boolean) - output_type = rlz.scalar_like('arg') + output_dtype = dt.boolean def negate(self): return NotAll(self.arg) @@ -245,14 +240,13 @@ def negate(self): class CountDistinct(Filterable, Reduction): arg = rlz.column(rlz.any) - def output_type(self): - return dt.int64.scalar_type() + output_dtype = dt.int64 @public class ArrayCollect(Reduction): arg = rlz.column(rlz.any) - def output_type(self): - dtype = dt.Array(self.arg.type()) - return dtype.scalar_type() + @immutable_property + def output_dtype(self): + return dt.Array(self.arg.type()) diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index 62f6bf24429e..259ba2c67b73 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -8,10 +8,11 @@ from ... import util from ...common import exceptions as com +from .. import datatypes as dt from .. import rules as rlz from .. import schema as sch from .. 
import types as ir -from .core import Node, distinct_roots +from .core import Node, ValueOp, distinct_roots from .sortkeys import _maybe_convert_sort_keys _table_names = (f'unbound_table_{i:d}' for i in itertools.count()) @@ -24,12 +25,11 @@ def genname(): @public class TableNode(Node): + output_type = ir.TableExpr + def get_type(self, name): return self.schema[name] - def output_type(self): - return ir.TableExpr - def aggregate(self, this, metrics, by=None, having=None): return Aggregation(this, metrics, by=by, having=having) @@ -71,6 +71,12 @@ class UnboundTable(PhysicalTable): schema = rlz.instance_of(sch.Schema) name = rlz.optional(rlz.instance_of(str), default=genname) + def has_resolved_name(self): + return True + + def resolve_name(self): + return self.name + @public class DatabaseTable(PhysicalTable): @@ -746,12 +752,13 @@ def blocks(self): @public -class ExistsSubquery(Node): +class ExistsSubquery(ValueOp): foreign_table = rlz.table predicates = rlz.tuple_of(rlz.boolean) - def output_type(self): - return ir.ExistsExpr + output_dtype = dt.boolean + output_shape = rlz.Shape.COLUMNAR + output_type = ir.ExistsExpr @public @@ -759,8 +766,7 @@ class NotExistsSubquery(Node): foreign_table = rlz.table predicates = rlz.tuple_of(rlz.boolean) - def output_type(self): - return ir.ExistsExpr + output_type = ir.ExistsExpr @public diff --git a/ibis/expr/operations/sortkeys.py b/ibis/expr/operations/sortkeys.py index c46d979a50b2..3f196f36db74 100644 --- a/ibis/expr/operations/sortkeys.py +++ b/ibis/expr/operations/sortkeys.py @@ -70,8 +70,7 @@ class SortKey(Node): default=True, ) - def output_type(self): - return ir.SortExpr + output_type = ir.SortExpr def root_tables(self): return self.expr.op().root_tables() diff --git a/ibis/expr/operations/strings.py b/ibis/expr/operations/strings.py index 018824a1aea8..ae5f3bc8d2ba 100644 --- a/ibis/expr/operations/strings.py +++ b/ibis/expr/operations/strings.py @@ -8,7 +8,7 @@ @public class StringUnaryOp(UnaryOp): arg = rlz.string 
- output_type = rlz.shape_like('arg', dt.string) + output_dtype = dt.string @public @@ -51,21 +51,25 @@ class Substring(ValueOp): arg = rlz.string start = rlz.integer length = rlz.optional(rlz.integer) - output_type = rlz.shape_like('arg', dt.string) + + output_dtype = dt.string + output_shape = rlz.shape_like('arg') @public class StrRight(ValueOp): arg = rlz.string nchars = rlz.integer - output_type = rlz.shape_like('arg', dt.string) + output_shape = rlz.shape_like("arg") + output_dtype = dt.string @public class Repeat(ValueOp): arg = rlz.string times = rlz.integer - output_type = rlz.shape_like('arg', dt.string) + output_shape = rlz.shape_like("arg") + output_dtype = dt.string @public @@ -74,7 +78,9 @@ class StringFind(ValueOp): substr = rlz.string start = rlz.optional(rlz.integer) end = rlz.optional(rlz.integer) - output_type = rlz.shape_like('arg', dt.int64) + + output_shape = rlz.shape_like("arg") + output_dtype = dt.int64 @public @@ -82,7 +88,9 @@ class Translate(ValueOp): arg = rlz.string from_str = rlz.string to_str = rlz.string - output_type = rlz.shape_like('arg', dt.string) + + output_shape = rlz.shape_like("arg") + output_dtype = dt.string @public @@ -90,7 +98,9 @@ class LPad(ValueOp): arg = rlz.string length = rlz.integer pad = rlz.optional(rlz.string) - output_type = rlz.shape_like('arg', dt.string) + + output_shape = rlz.shape_like("arg") + output_dtype = dt.string @public @@ -98,14 +108,18 @@ class RPad(ValueOp): arg = rlz.string length = rlz.integer pad = rlz.optional(rlz.string) - output_type = rlz.shape_like('arg', dt.string) + + output_shape = rlz.shape_like("arg") + output_dtype = dt.string @public class FindInSet(ValueOp): needle = rlz.string values = rlz.value_list_of(rlz.string, min_length=1) - output_type = rlz.shape_like('needle', dt.int64) + + output_shape = rlz.shape_like("needle") + output_dtype = dt.int64 @public @@ -113,29 +127,32 @@ class StringJoin(ValueOp): sep = rlz.string arg = rlz.value_list_of(rlz.string, min_length=1) - def 
output_type(self): - return rlz.shape_like(tuple(self.flat_args()), dt.string) + output_dtype = dt.string + output_shape = rlz.shape_like("arg") @public class StartsWith(ValueOp): arg = rlz.string - start = rlz.string - output_type = rlz.shape_like("arg", dt.boolean) + start = rlz.scalar(rlz.string) + output_dtype = dt.boolean + output_shape = rlz.shape_like("arg") @public class EndsWith(ValueOp): arg = rlz.string - end = rlz.string - output_type = rlz.shape_like("arg", dt.boolean) + end = rlz.scalar(rlz.string) + output_dtype = dt.boolean + output_shape = rlz.shape_like("arg") @public class FuzzySearch(ValueOp): arg = rlz.string pattern = rlz.string - output_type = rlz.shape_like('arg', dt.boolean) + output_dtype = dt.boolean + output_shape = rlz.shape_like('arg') @public @@ -160,7 +177,9 @@ class RegexExtract(ValueOp): arg = rlz.string pattern = rlz.string index = rlz.integer - output_type = rlz.shape_like('arg', dt.string) + + output_shape = rlz.shape_like("arg") + output_dtype = dt.string @public @@ -168,7 +187,9 @@ class RegexReplace(ValueOp): arg = rlz.string pattern = rlz.string replacement = rlz.string - output_type = rlz.shape_like('arg', dt.string) + + output_shape = rlz.shape_like("arg") + output_dtype = dt.string @public @@ -176,20 +197,26 @@ class StringReplace(ValueOp): arg = rlz.string pattern = rlz.string replacement = rlz.string - output_type = rlz.shape_like('arg', dt.string) + + output_shape = rlz.shape_like("arg") + output_dtype = dt.string @public class StringSplit(ValueOp): arg = rlz.string delimiter = rlz.string - output_type = rlz.shape_like('arg', dt.Array(dt.string)) + + output_shape = rlz.shape_like("arg") + output_dtype = dt.Array(dt.string) @public class StringConcat(ValueOp): arg = rlz.value_list_of(rlz.string) - output_type = rlz.shape_like('arg', dt.string) + + output_shape = rlz.shape_like("arg") + output_dtype = dt.string @public @@ -208,14 +235,16 @@ class ParseURL(ValueOp): } ) key = rlz.optional(rlz.string) - output_type = 
rlz.shape_like('arg', dt.string) + + output_shape = rlz.shape_like("arg") + output_dtype = dt.string @public class StringLength(UnaryOp): - output_type = rlz.shape_like('arg', dt.int32) + output_dtype = dt.int32 @public class StringAscii(UnaryOp): - output_type = rlz.shape_like('arg', dt.int32) + output_dtype = dt.int32 diff --git a/ibis/expr/operations/temporal.py b/ibis/expr/operations/temporal.py index e6eee573ed97..650408ce9654 100644 --- a/ibis/expr/operations/temporal.py +++ b/ibis/expr/operations/temporal.py @@ -3,6 +3,8 @@ import toolz from public import public +from ... import util +from ...common.validators import immutable_property from .. import datatypes as dt from .. import rules as rlz from .. import types as ir @@ -78,28 +80,36 @@ class TimestampUnaryOp(UnaryOp): class TimestampTruncate(ValueOp): arg = rlz.timestamp unit = rlz.isin(_timestamp_units) - output_type = rlz.shape_like('arg', dt.timestamp) + + output_shape = rlz.shape_like("arg") + output_dtype = dt.timestamp @public class DateTruncate(ValueOp): arg = rlz.date unit = rlz.isin(_date_units) - output_type = rlz.shape_like('arg', dt.date) + + output_shape = rlz.shape_like("arg") + output_dtype = dt.date @public class TimeTruncate(ValueOp): arg = rlz.time unit = rlz.isin(_time_units) - output_type = rlz.shape_like('arg', dt.time) + + output_shape = rlz.shape_like("arg") + output_dtype = dt.time @public class Strftime(ValueOp): arg = rlz.temporal format_str = rlz.string - output_type = rlz.shape_like('arg', dt.string) + + output_shape = rlz.shape_like("arg") + output_dtype = dt.string @public @@ -107,12 +117,14 @@ class StringToTimestamp(ValueOp): arg = rlz.string format_str = rlz.string timezone = rlz.optional(rlz.string) - output_type = rlz.shape_like('arg', dt.Timestamp(timezone='UTC')) + + output_shape = rlz.shape_like("arg") + output_dtype = dt.Timestamp(timezone='UTC') @public class ExtractTemporalField(TemporalUnaryOp): - output_type = rlz.shape_like('arg', dt.int32) + output_dtype = 
dt.int32 ExtractTimestampField = ExtractTemporalField @@ -186,31 +198,29 @@ class ExtractMillisecond(ExtractTimeField): @public class DayOfWeekIndex(UnaryOp): arg = rlz.one_of([rlz.date, rlz.timestamp]) - output_type = rlz.shape_like('arg', dt.int16) + output_dtype = dt.int16 @public class DayOfWeekName(UnaryOp): arg = rlz.one_of([rlz.date, rlz.timestamp]) - output_type = rlz.shape_like('arg', dt.string) + output_dtype = dt.string @public class DayOfWeekNode(Node): arg = rlz.one_of([rlz.date, rlz.timestamp]) - - def output_type(self): - return ir.DayOfWeek + output_type = ir.DayOfWeek @public class Time(UnaryOp): - output_type = rlz.shape_like('arg', dt.time) + output_dtype = dt.time @public class Date(UnaryOp): - output_type = rlz.shape_like('arg', dt.date) + output_dtype = dt.date @public @@ -218,7 +228,9 @@ class DateFromYMD(ValueOp): year = rlz.integer month = rlz.integer day = rlz.integer - output_type = rlz.shape_like('args', dt.date) + + output_dtype = dt.date + output_shape = rlz.shape_like("args") @public @@ -226,7 +238,9 @@ class TimeFromHMS(ValueOp): hours = rlz.integer minutes = rlz.integer seconds = rlz.integer - output_type = rlz.shape_like('args', dt.time) + + output_dtype = dt.time + output_shape = rlz.shape_like("args") @public @@ -238,7 +252,9 @@ class TimestampFromYMDHMS(ValueOp): minutes = rlz.integer seconds = rlz.integer timezone = rlz.optional(rlz.string) - output_type = rlz.shape_like('args', dt.timestamp) + + output_dtype = dt.timestamp + output_shape = rlz.shape_like("args") @public @@ -246,49 +262,52 @@ class TimestampFromUNIX(ValueOp): arg = rlz.any # Only pandas-based backends support 'ns' unit = rlz.isin({'s', 'ms', 'us', 'ns'}) - output_type = rlz.shape_like('arg', dt.timestamp) + output_shape = rlz.shape_like('arg') + + output_dtype = dt.timestamp + output_shape = rlz.shape_like("args") @public class DateAdd(BinaryOp): left = rlz.date right = rlz.interval(units={'Y', 'Q', 'M', 'W', 'D'}) - output_type = rlz.shape_like('left') + 
output_dtype = rlz.dtype_like('left') @public class DateSub(BinaryOp): left = rlz.date right = rlz.interval(units={'Y', 'Q', 'M', 'W', 'D'}) - output_type = rlz.shape_like('left') + output_dtype = rlz.dtype_like('left') @public class DateDiff(BinaryOp): left = rlz.date right = rlz.date - output_type = rlz.shape_like('left', dt.Interval('D')) + output_dtype = dt.Interval('D') @public class TimeAdd(BinaryOp): left = rlz.time right = rlz.interval(units={'h', 'm', 's', 'ms', 'us', 'ns'}) - output_type = rlz.shape_like('left') + output_dtype = rlz.dtype_like('left') @public class TimeSub(BinaryOp): left = rlz.time right = rlz.interval(units={'h', 'm', 's', 'ms', 'us', 'ns'}) - output_type = rlz.shape_like('left') + output_dtype = rlz.dtype_like('left') @public class TimeDiff(BinaryOp): left = rlz.time right = rlz.time - output_type = rlz.shape_like('left', dt.Interval('s')) + output_dtype = dt.Interval('s') @public @@ -297,7 +316,7 @@ class TimestampAdd(BinaryOp): right = rlz.interval( units={'Y', 'Q', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns'} ) - output_type = rlz.shape_like('left') + output_dtype = rlz.dtype_like('left') @public @@ -306,35 +325,56 @@ class TimestampSub(BinaryOp): right = rlz.interval( units={'Y', 'Q', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns'} ) - output_type = rlz.shape_like('left') + output_dtype = rlz.dtype_like('left') @public class TimestampDiff(BinaryOp): left = rlz.timestamp right = rlz.timestamp - output_type = rlz.shape_like('left', dt.Interval('s')) + output_dtype = dt.Interval('s') + + +@public +class ToIntervalUnit(ValueOp): + arg = rlz.interval + unit = rlz.isin({'Y', 'Q', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns'}) + + output_shape = rlz.shape_like("arg") + + def __init__(self, arg, unit): + dtype = arg.type() + if dtype.unit != unit: + arg = util.convert_unit(arg, dtype.unit, unit) + super().__init__(arg=arg, unit=unit) + + @immutable_property + def output_dtype(self): + dtype = self.arg.type() + return dt.Interval( + 
unit=self.unit, + value_type=dtype.value_type, + nullable=dtype.nullable, + ) @public class IntervalBinaryOp(BinaryOp): - def output_type(self): - args = [ + @immutable_property + def output_dtype(self): + integer_args = [ arg.cast(arg.type().value_type) if isinstance(arg.type(), dt.Interval) else arg for arg in self.args ] - expr = rlz.numeric_like(args, self.__class__.op)(self) + value_dtype = rlz._promote_numeric_binop(integer_args, self.op) left_dtype = self.left.type() - dtype_type = type(left_dtype) - additional_args = { - attr: getattr(left_dtype, attr) - for attr in left_dtype.argnames - if attr not in ("unit", "value_type") - } - dtype = dtype_type(left_dtype.unit, expr.type(), **additional_args) - return rlz.shape_like(self.args, dtype=dtype) + return dt.Interval( + unit=left_dtype.unit, + value_type=value_dtype, + nullable=left_dtype.nullable, + ) @public @@ -370,13 +410,15 @@ class IntervalFromInteger(ValueOp): arg = rlz.integer unit = rlz.isin({'Y', 'Q', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns'}) + output_shape = rlz.shape_like("arg") + + @immutable_property + def output_dtype(self): + return dt.Interval(self.unit, self.arg.type()) + @property def resolution(self): - return dt.Interval(self.unit).resolution - - def output_type(self): - dtype = dt.Interval(self.unit, self.arg.type()) - return rlz.shape_like(self.arg, dtype=dtype) + return self.output_dtype.resolution @public diff --git a/ibis/expr/operations/vectorized.py b/ibis/expr/operations/vectorized.py index 95f00afa0051..70b6dfab04c4 100644 --- a/ibis/expr/operations/vectorized.py +++ b/ibis/expr/operations/vectorized.py @@ -11,6 +11,8 @@ class VectorizedUDF(ValueOp): func = rlz.instance_of((FunctionType, LambdaType)) func_args = rlz.tuple_of(rlz.column(rlz.any)) + # TODO(kszucs): should rename these arguments to + # input_dtypes and return_dtype input_type = rlz.tuple_of(rlz.datatype) return_type = rlz.datatype @@ -18,6 +20,10 @@ class VectorizedUDF(ValueOp): def inputs(self): return 
self.func_args + @property + def output_dtype(self): + return self.return_type + def root_tables(self): return distinct_roots(*self.func_args) @@ -26,21 +32,19 @@ def root_tables(self): class ElementWiseVectorizedUDF(VectorizedUDF): """Node for element wise UDF.""" - def output_type(self): - return self.return_type.column_type() + output_shape = rlz.Shape.COLUMNAR @public class ReductionVectorizedUDF(VectorizedUDF, Reduction): """Node for reduction UDF.""" - def output_type(self): - return self.return_type.scalar_type() + output_shape = rlz.Shape.SCALAR +# TODO(kszucs): revisit @public class AnalyticVectorizedUDF(VectorizedUDF, AnalyticOp): """Node for analytics UDF.""" - def output_type(self): - return self.return_type.column_type() + output_shape = rlz.Shape.COLUMNAR diff --git a/ibis/expr/rules.py b/ibis/expr/rules.py index 4e0629868341..a2585fdc966c 100644 --- a/ibis/expr/rules.py +++ b/ibis/expr/rules.py @@ -10,6 +10,7 @@ import ibis.expr.types as ir import ibis.util as util from ibis.common.validators import ( # noqa: F401 + immutable_property, instance_of, isin, list_of, @@ -21,6 +22,12 @@ ) +class Shape(enum.IntEnum): + SCALAR = 0 + COLUMNAR = 1 + # TABULAR = 2 + + def highest_precedence_dtype(exprs): """Return the highest precedence type from the passed expressions @@ -58,23 +65,6 @@ def comparable(left, right): return castable(left, right) or castable(right, left) -def cast(source, target): - """Currently Literal to *Scalar implicit casts are allowed""" - import ibis.expr.operations as ops # TODO: don't use ops here - - if not castable(source, target): - raise com.IbisTypeError('Source is not castable to target type!') - - # currently it prevents column -> scalar implicit castings - # however the datatypes are matching - op = source.op() - if not isinstance(op, ops.Literal): - raise com.IbisTypeError('Only able to implicitly cast literals!') - - out_type = target.type().scalar_type() - return out_type(op) - - # 
--------------------------------------------------------------------- # Input type validators / coercion functions @@ -251,41 +241,71 @@ def wrapper(name_or_value, *args, **kwargs): return wrapper -@promoter -def shape_like(arg, dtype=None): - if util.is_iterable(arg): - datatype = dtype or highest_precedence_dtype(arg) - columnar = util.any_of(arg, ir.AnyColumn) - else: - datatype = dtype or arg.type() - columnar = isinstance(arg, ir.AnyColumn) - - dtype = dt.dtype(datatype) +def dtype_like(name): + @immutable_property + def output_dtype(self): + arg = getattr(self, name) + if util.is_iterable(arg): + return highest_precedence_dtype(arg) + else: + return arg.type() + + return output_dtype + + +def shape_like(name): + @immutable_property + def output_shape(self): + arg = getattr(self, name) + if util.is_iterable(arg): + for expr in arg: + try: + op = expr.op() + except AttributeError: + continue + if op.output_shape is Shape.COLUMNAR: + return Shape.COLUMNAR + return Shape.SCALAR + else: + return arg.op().output_shape - if columnar: - return dtype.column_type() - else: - return dtype.scalar_type() + return output_shape -@promoter -def scalar_like(arg): - output_dtype = arg.type() - return output_dtype.scalar_type() +# TODO(kszucs): might just use bounds instead of actual literal values +# that could simplify interval binop output_type methods +# TODO(kszucs): pre-generate mapping? +def _promote_numeric_binop(exprs, op): + bounds, dtypes = [], [] + for arg in exprs: + dtypes.append(arg.type()) + if hasattr(arg.op(), 'value'): + # arg.op() is a literal + bounds.append([arg.op().value]) + else: + bounds.append(arg.type().bounds) + # In some cases, the bounding type might be int8, even though neither + # of the types are that small. We want to ensure the containing type is + # _at least_ as large as the smallest type in the expression. 
+ values = starmap(op, product(*bounds)) + dtypes += [dt.infer(value) for value in values] -@promoter -def array_like(arg): - output_dtype = arg.type() - return output_dtype.column_type() + return dt.highest_precedence(dtypes) -column_like = array_like +def numeric_like(name, op): + @immutable_property + def output_dtype(self): + args = getattr(self, name) + if util.all_of(args, ir.IntegerValue): + result = _promote_numeric_binop(args, op) + else: + result = highest_precedence_dtype(args) + return result -@promoter -def typeof(arg): - return arg._factory + return output_dtype @validator @@ -491,34 +511,4 @@ def window(win, *, from_base_table_of, this): return win -# TODO: might just use bounds instead of actual literal values -# that could simplify interval binop output_type methods -def _promote_numeric_binop(exprs, op): - bounds, dtypes = [], [] - for arg in exprs: - dtypes.append(arg.type()) - if hasattr(arg.op(), 'value'): - # arg.op() is a literal - bounds.append([arg.op().value]) - else: - bounds.append(arg.type().bounds) - - # In some cases, the bounding type might be int8, even though neither - # of the types are that small. We want to ensure the containing type is - # _at least_ as large as the smallest type in the expression. 
- values = starmap(op, product(*bounds)) - dtypes += [dt.infer(value, allow_overflow=True) for value in values] - - return dt.highest_precedence(dtypes) - - -@promoter -def numeric_like(args, op): - if util.all_of(args, ir.IntegerValue): - dtype = _promote_numeric_binop(args, op) - return shape_like(args, dtype=dtype) - else: - return shape_like(args) - - # TODO: create varargs marker for impala udfs diff --git a/ibis/expr/timecontext.py b/ibis/expr/timecontext.py index 51fd27b11d1c..749a8be2cc14 100644 --- a/ibis/expr/timecontext.py +++ b/ibis/expr/timecontext.py @@ -281,6 +281,14 @@ def adjust_context_node( return timecontext +@adjust_context.register(ops.Alias) +def adjust_context_alias( + op: ops.Node, scope: 'Scope', timecontext: TimeContext +) -> TimeContext: + # For any node, by default, do not adjust time context + return adjust_context(op.arg.op(), scope, timecontext) + + @adjust_context.register(ops.AsOfJoin) def adjust_context_asof_join( op: ops.AsOfJoin, scope: 'Scope', timecontext: TimeContext diff --git a/ibis/expr/types/analytic.py b/ibis/expr/types/analytic.py index c621d11ad3bc..e96178f223a2 100644 --- a/ibis/expr/types/analytic.py +++ b/ibis/expr/types/analytic.py @@ -3,25 +3,22 @@ from public import public import ibis.common.exceptions as com +import ibis.expr.types as ir from .core import Expr @public class AnalyticExpr(Expr): - @property - def _factory(self): - def factory(arg): - return type(self)(arg) - - return factory + # TODO(kszucs): should be removed def type(self): return 'analytic' @public class ExistsExpr(AnalyticExpr): + # TODO(kszucs): should be removed def type(self): return 'exists' @@ -35,6 +32,7 @@ def _table_getitem(self): return self.to_filter() def to_filter(self): + # TODO: move to api.py import ibis.expr.operations as ops return ops.SummaryFilter(self).to_expr() @@ -42,19 +40,19 @@ def to_filter(self): def to_aggregation( self, metric_name=None, parent_table=None, backup_metric_name=None ): - """Convert the TopK operation 
to a table aggregation.""" - from .relations import find_base_table - + """ + Convert the TopK operation to a table aggregation + """ op = self.op() - arg_table = find_base_table(op.arg) + arg_table = ir.relations.find_base_table(op.arg) by = op.by if not isinstance(by, Expr): by = by(arg_table) by_table = arg_table else: - by_table = find_base_table(op.by) + by_table = ir.relations.find_base_table(op.by) if metric_name is None: if by.get_name() == op.arg.get_name(): diff --git a/ibis/expr/types/core.py b/ibis/expr/types/core.py index b30d857171b4..cec3696d15f2 100644 --- a/ibis/expr/types/core.py +++ b/ibis/expr/types/core.py @@ -71,6 +71,12 @@ def __bool__(self) -> bool: __nonzero__ = __bool__ + def has_name(self): + return self.op().has_resolved_name() + + def get_name(self): + return self.op().resolve_name() + @cached_property def _safe_name(self) -> str | None: """Get the name of an expression `expr` if one exists @@ -183,10 +189,6 @@ def pipe(self, f, *args: Any, **kwargs: Any) -> Expr: def op(self) -> ops.Node: return self._arg - @property - def _factory(self) -> type[Expr]: - return type(self) - def _find_backends(self) -> list[BaseBackend]: """Return the possible backends for an expression. 
diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 5d9e2437c31e..bd18251516b4 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1,10 +1,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Iterable, Sequence if TYPE_CHECKING: import ibis.expr.types as ir - import ibis.expr.operations as ops import ibis.expr.window as win from public import public @@ -18,48 +17,13 @@ @public class ValueExpr(Expr): - """Base class for an expression having a known type.""" - - _name: str | None - _dtype: dt.DataType - def __init__( - self, - arg: ops.ValueOp, - dtype: dt.DataType, - name: str | None = None, - ) -> None: - super().__init__(arg) - self._name = name - self._dtype = dtype - - def equals(self, other): - if not isinstance(other, Expr): - raise TypeError( - "invalid equality comparison between Expr and " - f"{type(other)}" - ) - return ( - self._safe_name == other._safe_name - and self._dtype.equals(other._dtype) - and super().equals(other) - ) - - def has_name(self) -> bool: - if self._name is not None: - return True - return self.op().has_resolved_name() - - def get_name(self) -> str: - if self._name is not None: - # This value has been explicitly named - return self._name - - # In some but not all cases we can get a name from the node that - # produces the value - return self.op().resolve_name() + """ + Base class for a data generating expression having a fixed and known type, + either a single value (scalar) + """ - def name(self, name: str) -> ValueExpr: + def name(self, name): """Rename an expression to `name`. Parameters @@ -81,24 +45,25 @@ def name(self, name: str) -> ValueExpr: a int64 b: r0.a """ - return self._factory(self._arg, name=name) + import ibis.expr.operations as ops - def type(self) -> dt.DataType: - """Return the data type of an expression. 
+ # TODO(kszucs): shouldn't do simplification here, but rather later + # when simplifying the whole operation tree + # the expression's name is idendical to the new one + if self.has_name() and self.get_name() == name: + return self - Returns - ------- - DataType - Type of the expression - """ - return self._dtype + if isinstance(self.op(), ops.Alias): + # only keep a single alias operation + op = ops.Alias(arg=self.op().arg, name=name) + else: + op = ops.Alias(arg=self, name=name) - @property - def _factory(self) -> Callable[[ops.ValueOp, str | None], ValueExpr]: - def factory(arg: ops.ValueOp, name: str | None = None) -> ValueExpr: - return type(self)(arg, dtype=self.type(), name=name) + return op.to_expr() - return factory + # TODO(kszucs): should rename to dtype + def type(self): + return self.op().output_dtype @public @@ -153,6 +118,11 @@ def _repr_html_(self) -> str | None: return self.execute().to_frame()._repr_html_() +# TODO(kszucs): keep either ValueExpr or AnyValue +# TODO(kszucs): keep either ColumnExpr or AnyColumn +# TODO(kszucs): keep either ScalarExpr or ScalarColumn + + @public class AnyValue(ValueExpr): def hash(self, how: str = "fnv") -> ir.IntegerValue: @@ -441,6 +411,10 @@ def over(self, window: win.Window) -> ValueExpr: prior_op = self.op() + # TODO(kszucs): fix this ugly hack + if isinstance(prior_op, ops.Alias): + return prior_op.arg.over(window).name(prior_op.name) + if isinstance(prior_op, ops.WindowOp): op = prior_op.over(window) else: @@ -571,7 +545,7 @@ def group_concat( return ops.GroupConcat(self, sep=sep, where=where).to_expr() def __hash__(self) -> int: - return hash((self._name, self._dtype, self._arg)) + return super().__hash__() def __eq__(self, other: AnyValue) -> ir.BooleanValue: import ibis.expr.operations as ops @@ -866,6 +840,7 @@ class NullColumn(AnyColumn, NullValue): pass # noqa: E701,E302 +# TODO(kszucs): should remove the ColumnExpr base class? 
@public class ListExpr(ColumnExpr, AnyValue): @property diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 25e287356dcb..b56bd34291c1 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -76,13 +76,6 @@ def f( @public class TableExpr(Expr): - @property - def _factory(self): - def factory(arg): - return TableExpr(arg) - - return factory - def _is_valid(self, exprs): try: self._assert_valid(util.promote_list(exprs)) @@ -171,10 +164,12 @@ def __getattr__(self, key): def __dir__(self): return sorted(frozenset(dir(type(self)) + self.columns)) + # TODO(kszucs): should be removed def _resolve(self, exprs): exprs = util.promote_list(exprs) - return list(self._ensure_expr(x) for x in exprs) + return list(map(self._ensure_expr, exprs)) + # TODO(kszucs): should be removed def _ensure_expr(self, expr): if isinstance(expr, str): return self[expr] @@ -185,11 +180,13 @@ def _ensure_expr(self, expr): else: return expr + # TODO(kszucs): should be removed def _get_type(self, name): return self._arg.get_type(name) def get_columns(self, iterable: Iterable[str]) -> list[ColumnExpr]: - """Get multiple columns from the table. + """ + Get multiple columns from the table Examples -------- @@ -213,7 +210,8 @@ def get_columns(self, iterable: Iterable[str]) -> list[ColumnExpr]: return [self.get_column(x) for x in iterable] def get_column(self, name: str) -> ColumnExpr: - """Get a reference to a single column from the table. + """ + Get a reference to a single column from the table Returns ------- @@ -230,7 +228,8 @@ def columns(self): return list(self.schema().names) def schema(self) -> sch.Schema: - """Return the table's schema. + """ + Get the schema for this table (if one is known) Returns ------- @@ -244,7 +243,8 @@ def group_by( by=None, **additional_grouping_expressions: Any, ) -> GroupedTableExpr: - """Create a grouped table expression. + """ + Create a grouped table expression. 
Parameters ---------- @@ -557,8 +557,8 @@ def mutate( baz: 5 qux: r0.foo + r0.bar - Use the [`ValueExpr.name`][ibis.expr.types.generic.ValueExpr.name] - method to name the new columns. + Use the [`name`][ibis.expr.types.generic.ValueExpr.name] method to name + the new columns. >>> new_columns = [ibis.literal(5).name('baz',), ... (table.foo + table.bar).name('qux')] @@ -892,7 +892,7 @@ def set_column(self, name: str, expr: ir.ValueExpr) -> TableExpr: """ expr = self._ensure_expr(expr) - if expr._name != name: + if expr._safe_name != name: expr = expr.name(name) if name not in self: diff --git a/ibis/expr/types/sortkeys.py b/ibis/expr/types/sortkeys.py index 8f10646d245a..0baca8efa52c 100644 --- a/ibis/expr/types/sortkeys.py +++ b/ibis/expr/types/sortkeys.py @@ -12,8 +12,5 @@ @public class SortExpr(Expr): - def get_name(self) -> str | None: - return self.op().resolve_name() - def type(self) -> dt.DataType: return self.op().expr.type() diff --git a/ibis/expr/types/structs.py b/ibis/expr/types/structs.py index a026bb76b84e..5d17683dbb64 100644 --- a/ibis/expr/types/structs.py +++ b/ibis/expr/types/structs.py @@ -89,7 +89,7 @@ def destructure(self) -> DestructValue: DestructValue A destruct value expression. """ - return DestructValue(self._arg, self._dtype).name("") + return DestructValue(self._arg) @public @@ -105,7 +105,7 @@ def destructure(self) -> DestructScalar: DestructScalar A destruct scalar expression. """ - return DestructScalar(self._arg, self._dtype).name("") + return DestructScalar(self._arg) @public @@ -121,7 +121,7 @@ def destructure(self) -> DestructColumn: DestructColumn A destruct column expression. """ - return DestructColumn(self._arg, self._dtype).name("") + return DestructColumn(self._arg) @public @@ -132,6 +132,10 @@ class DestructValue(AnyValue): will be destructured and assigned to multiple columnns. 
""" + def name(self, name): + res = super().name(name) + return self.__class__(res.op()) + @public class DestructScalar(AnyScalar, DestructValue): diff --git a/ibis/expr/types/temporal.py b/ibis/expr/types/temporal.py index f1a2488c0f71..2500112bb9e8 100644 --- a/ibis/expr/types/temporal.py +++ b/ibis/expr/types/temporal.py @@ -9,7 +9,6 @@ import pandas as pd from .. import types as ir -from ... import util from .core import Expr, _binop from .generic import AnyColumn, AnyScalar, AnyValue @@ -487,12 +486,11 @@ class TimestampColumn(TemporalColumn, TimestampValue): class IntervalValue(AnyValue): def to_unit(self, target_unit: str) -> IntervalValue: """Convert this interval to units of `target_unit`.""" - if self._dtype.unit == target_unit: - return self + import ibis.expr.operations as ops - result = util.convert_unit(self, self._dtype.unit, target_unit) - object.__setattr__(result.type(), "unit", target_unit) - return result + # TODO(kszucs): should use a separate operation for unit conversion + # which we can rewrite/simplify to integer multiplication/division + return ops.ToIntervalUnit(self, unit=target_unit).to_expr() @property def years(self) -> ir.IntegerValue: diff --git a/ibis/expr/visualize.py b/ibis/expr/visualize.py index da21144a898c..6c62e96f4fc4 100644 --- a/ibis/expr/visualize.py +++ b/ibis/expr/visualize.py @@ -20,7 +20,7 @@ def get_type(expr): except (AttributeError, NotImplementedError): try: # As a last resort try get the name of the output_type class - return expr.op().output_type().__name__ + return expr.op().output_type.__name__ except (AttributeError, NotImplementedError): return '\u2205' # empty set character except com.IbisError: diff --git a/ibis/tests/expr/conftest.py b/ibis/tests/expr/conftest.py index fc72b3691c5d..c5ae292e6888 100644 --- a/ibis/tests/expr/conftest.py +++ b/ibis/tests/expr/conftest.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import collections - import pytest import ibis @@ -38,11 +36,6 @@ def schema(): ] -@pytest.fixture -def schema_dict(schema): - return collections.OrderedDict(schema) - - @pytest.fixture def table(schema): return ibis.table(schema, name='table') diff --git a/ibis/tests/expr/test_format.py b/ibis/tests/expr/test_format.py index 4517006d2dfe..3edea4dcf547 100644 --- a/ibis/tests/expr/test_format.py +++ b/ibis/tests/expr/test_format.py @@ -203,9 +203,6 @@ def test_scalar_parameter_repr(): value = ibis.param(dt.timestamp).name('value') assert repr(value) == "value: $(timestamp)" - value_op = value.op() - assert "ScalarParameter" in repr(value_op) - def test_repr_exact(): # NB: This is the only exact repr test. Do diff --git a/ibis/tests/expr/test_lineage.py b/ibis/tests/expr/test_lineage.py index b080645aad01..8cf26b8c3b9c 100644 --- a/ibis/tests/expr/test_lineage.py +++ b/ibis/tests/expr/test_lineage.py @@ -104,9 +104,11 @@ def test_lineage(companies): mutated.bucket, mutated, bucket.name('bucket'), + bucket, companies.funding_total_usd, companies, ] + assert len(results) == len(expected) for r, e in zip(results, expected): assert_equal(r, e) @@ -115,9 +117,11 @@ def test_lineage(companies): filtered.bucket, filtered, bucket.name('bucket'), + bucket, companies.funding_total_usd, companies, ] + assert len(results) == len(expected) for r, e in zip(results, expected): assert_equal(r, e) @@ -128,9 +132,11 @@ def test_lineage(companies): filtered.bucket, filtered, bucket.name('bucket'), + bucket, companies.funding_total_usd, companies, ] + assert len(results) == len(expected) for r, e in zip(results, expected): assert_equal(r, e) @@ -171,9 +177,7 @@ def test_lineage_join(companies, rounds): rounds.company_city, rounds.raised_amount_usd, ] - perc_raised = (expr.raised_amount_usd / expr.funding_total_usd).name( - 'perc_raised' - ) + perc_raised = expr.raised_amount_usd / expr.funding_total_usd results = list(lin.lineage(perc_raised)) expected = [ @@ -183,7 +187,6 @@ def 
test_lineage_join(companies, rounds): rounds.raised_amount_usd, rounds, expr.funding_total_usd, - # expr, # *could* appear here as well, but we've already traversed it companies.funding_total_usd, companies, ] diff --git a/ibis/tests/expr/test_operations.py b/ibis/tests/expr/test_operations.py index 5093710e11dc..12350f97e2c1 100644 --- a/ibis/tests/expr/test_operations.py +++ b/ibis/tests/expr/test_operations.py @@ -21,7 +21,6 @@ class Log(ops.Node): def test_ops_smoke(): expr = ir.literal(3) - ops.UnaryOp(expr) ops.Cast(expr, to='int64') ops.TypeOf(arg=2) ops.Negate(4) @@ -90,7 +89,8 @@ class MyOperation(ops.Node): def test_array_input(): class MyOp(ops.ValueOp): value = rlz.value(dt.Array(dt.double)) - output_type = rlz.typeof('value') + output_dtype = rlz.dtype_like('value') + output_shape = rlz.shape_like('value') raw_value = [1.0, 2.0, 3.0] op = MyOp(raw_value) @@ -104,6 +104,7 @@ class MyTableExpr(ir.TableExpr): pass class SpecialTable(ops.DatabaseTable): + @property def output_type(self): return MyTableExpr diff --git a/ibis/tests/expr/test_rules.py b/ibis/tests/expr/test_rules.py index d4423fbcf9bc..2da13051da3b 100644 --- a/ibis/tests/expr/test_rules.py +++ b/ibis/tests/expr/test_rules.py @@ -355,12 +355,6 @@ def test_table_with_schema_invalid(table): validator(table) -def test_shape_like_with_no_arguments(): - with pytest.raises(ValueError) as e: - rlz.shape_like([]) - assert str(e.value) == 'Must pass at least one expression' - - @pytest.mark.parametrize( ('rule', 'input'), [ diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index e4fc81661c13..f26153272366 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -75,13 +75,13 @@ def test_view_new_relation(table): assert roots[0] is tview.op() -def test_get_type(table, schema_dict): - for k, v in schema_dict.items(): +def test_get_type(table, schema): + for k, v in schema: assert table._get_type(k) == dt.dtype(v) -def test_getitem_column_select(table, 
schema_dict): - for k, v in schema_dict.items(): +def test_getitem_column_select(table): + for k in table.columns: col = table[k] # Make sure it's the right type @@ -424,7 +424,8 @@ def test_slice_convenience(table): def test_table_count(table): result = table.count() assert isinstance(result, ir.IntegerScalar) - assert isinstance(result.op(), ops.Count) + assert isinstance(result.op(), ops.Alias) + assert isinstance(result.op().arg.op(), ops.Count) assert result.get_name() == 'count' @@ -435,24 +436,28 @@ def test_len_raises_expression_error(table): def test_sum_expr_basics(table, int_col): # Impala gives bigint for all integer types - ex_class = ir.IntegerScalar result = table[int_col].sum() - assert isinstance(result, ex_class) - assert isinstance(result.op(), ops.Sum) + assert isinstance(result, ir.IntegerScalar) + assert isinstance(result.op(), ops.Alias) + assert isinstance(result.op().arg.op(), ops.Sum) + assert result.get_name() == "sum" def test_sum_expr_basics_floats(table, float_col): # Impala gives double for all floating point types - ex_class = ir.FloatingScalar result = table[float_col].sum() - assert isinstance(result, ex_class) - assert isinstance(result.op(), ops.Sum) + assert isinstance(result, ir.FloatingScalar) + assert isinstance(result.op(), ops.Alias) + assert isinstance(result.op().arg.op(), ops.Sum) + assert result.get_name() == "sum" def test_mean_expr_basics(table, numeric_col): result = table[numeric_col].mean() assert isinstance(result, ir.FloatingScalar) - assert isinstance(result.op(), ops.Mean) + assert isinstance(result.op(), ops.Alias) + assert isinstance(result.op().arg.op(), ops.Mean) + assert result.get_name() == "mean" def test_aggregate_no_keys(table): @@ -1357,9 +1362,13 @@ def test_mutate_chain(): a, b = three.op().selections # we can't fuse these correctly yet - assert isinstance(a.op(), ops.IfNull) + assert isinstance(a.op(), ops.Alias) + assert isinstance(a.op().arg.op(), ops.IfNull) assert isinstance(b.op(), 
ops.TableColumn) - assert isinstance(b.op().table.op().selections[1].op(), ops.IfNull) + + expr = b.op().table.op().selections[1] + assert isinstance(expr.op(), ops.Alias) + assert isinstance(expr.op().arg.op(), ops.IfNull) def test_multiple_dbcon(): diff --git a/ibis/tests/expr/test_temporal.py b/ibis/tests/expr/test_temporal.py index 7002ef1264a2..ac48cc3b1755 100644 --- a/ibis/tests/expr/test_temporal.py +++ b/ibis/tests/expr/test_temporal.py @@ -624,7 +624,7 @@ def test_complex_date_comparisons( def test_interval_column_name(table): c = table.i expr = (c - c).name('foo') - assert expr._name == 'foo' + assert expr.get_name() == 'foo' @pytest.mark.parametrize( diff --git a/ibis/tests/expr/test_timestamp.py b/ibis/tests/expr/test_timestamp.py index 4c88fe626a88..a03871d12535 100644 --- a/ibis/tests/expr/test_timestamp.py +++ b/ibis/tests/expr/test_timestamp.py @@ -42,7 +42,8 @@ def test_extract_fields(field, expected_operation, expected_type, alltypes): result = getattr(alltypes.i, field)() assert result.get_name() == field assert isinstance(result, expected_type) - assert isinstance(result.op(), expected_operation) + assert isinstance(result.op(), ops.Alias) + assert isinstance(result.op().arg.op(), expected_operation) def test_now(): @@ -77,12 +78,12 @@ def test_comparisons_string(alltypes): val = '2015-01-01 00:00:00' expr = alltypes.i > val op = expr.op() - assert isinstance(op.right, ir.TimestampScalar) + assert isinstance(op.right, ir.StringScalar) expr2 = val < alltypes.i op = expr2.op() assert isinstance(op, ops.Greater) - assert isinstance(op.right, ir.TimestampScalar) + assert isinstance(op.right, ir.StringScalar) def test_comparisons_pandas_timestamp(alltypes): @@ -119,8 +120,10 @@ def test_timestamp_field_access_on_date( ): date_col = alltypes.i.date() result = getattr(date_col, field)() + assert result.get_name() == field assert isinstance(result, expected_type) - assert isinstance(result.op(), expected_operation) + assert isinstance(result.op(), 
ops.Alias) + assert isinstance(result.op().arg.op(), expected_operation) @pytest.mark.parametrize( @@ -154,8 +157,10 @@ def test_timestamp_field_access_on_time( ): time_col = alltypes.i.time() result = getattr(time_col, field)() + assert result.get_name() == field assert isinstance(result, expected_type) - assert isinstance(result.op(), expected_operation) + assert isinstance(result.op(), ops.Alias) + assert isinstance(result.op().arg.op(), expected_operation) @pytest.mark.parametrize( diff --git a/ibis/tests/expr/test_udf.py b/ibis/tests/expr/test_udf.py index 4b2795eeb098..c89f7781290c 100644 --- a/ibis/tests/expr/test_udf.py +++ b/ibis/tests/expr/test_udf.py @@ -40,9 +40,8 @@ def test_vectorized_udf_operations(table, klass, output_type): assert udf.input_type == tuple([dt.int8(), dt.string(), dt.boolean()]) assert udf.return_type == dt.int8() - factory = udf.output_type() - expr = factory(udf) - assert isinstance(expr, output_type) + expr = udf.to_expr() + assert isinstance(expr, udf.output_type) with pytest.raises(com.IbisTypeError): # wrong function type diff --git a/ibis/tests/expr/test_value_exprs.py b/ibis/tests/expr/test_value_exprs.py index 06630ae036b2..967634d89694 100644 --- a/ibis/tests/expr/test_value_exprs.py +++ b/ibis/tests/expr/test_value_exprs.py @@ -1,4 +1,3 @@ -import functools import operator import uuid from collections import OrderedDict @@ -323,7 +322,8 @@ def test_distinct_table(functional_alltypes): def test_nunique(functional_alltypes): expr = functional_alltypes.string_col.nunique() - assert isinstance(expr.op(), ops.CountDistinct) + assert isinstance(expr.op(), ops.Alias) + assert isinstance(expr.op().arg.op(), ops.CountDistinct) def test_isnull(table): @@ -1019,11 +1019,12 @@ def __add__(self, other): __radd__ = __add__ - class FooNode(ops.ValueOp): + class FooNode(ops.Node): value = rlz.integer + @property def output_type(self): - return functools.partial(Foo, dtype=dt.int64) + return Foo left = ibis.literal(2) right = 
FooNode(3).to_expr() @@ -1042,7 +1043,7 @@ def test_empty_array_as_argument(): class Foo(ir.Expr): pass - class FooNode(ops.ValueOp): + class FooNode(ops.Node): value = rlz.value(dt.Array(dt.int64)) def output_type(self): diff --git a/ibis/tests/expr/test_visualize.py b/ibis/tests/expr/test_visualize.py index c3f8db42db4a..394df6d7cd5b 100644 --- a/ibis/tests/expr/test_visualize.py +++ b/ibis/tests/expr/test_visualize.py @@ -49,9 +49,7 @@ class MyExpr(ir.Expr): class MyExprNode(ops.Node): foo = rlz.string bar = rlz.numeric - - def output_type(self): - return MyExpr + output_type = MyExpr op = MyExprNode('Hello!', 42.3) expr = op.to_expr() @@ -71,6 +69,7 @@ class MyExprNode(ops.Node): foo = rlz.string bar = rlz.numeric + @property def output_type(self): return MyExpr diff --git a/ibis/tests/expr/test_window_functions.py b/ibis/tests/expr/test_window_functions.py index f7549a26ff05..a7f52ef42fa8 100644 --- a/ibis/tests/expr/test_window_functions.py +++ b/ibis/tests/expr/test_window_functions.py @@ -18,6 +18,8 @@ import ibis import ibis.common.exceptions as com +import ibis.expr.rules as rlz +import ibis.expr.types as ir from ibis.expr.window import _determine_how, rows_with_max_lookback from ibis.tests.util import assert_equal @@ -337,4 +339,20 @@ def test_combine_preserves_existing_window(): ) w = ibis.cumulative_window(order_by=t.one) mut = t.group_by(t.three).mutate(four=t.two.sum().over(w)) - assert mut.op().selections[1].op().window.following == 0 + + assert mut.op().selections[1].op().arg.op().window.following == 0 + + +def test_quantile_shape(): + t = ibis.table([("a", "float64")]) + + b1 = t.a.quantile(0.25).name("br2") + assert isinstance(b1, ir.ScalarExpr) + + projs = [b1] + expr = t.projection(projs) + (b1,) = expr.op().selections + + assert b1.op().output_shape == rlz.Shape.COLUMNAR + + assert isinstance(b1, ir.ColumnExpr)