@@ -0,0 +1,173 @@
from __future__ import annotations

import re
from typing import Mapping

import datafusion as df
import pyarrow as pa

import ibis.common.exceptions as com
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.backends.base import BaseBackend

from .compiler import translate
def _to_pyarrow_table(frame):
    batches = frame.collect()
    if batches:
        return pa.Table.from_batches(batches)
    else:
        # TODO(kszucs): file a bug to datafusion because the fields'
        # nullability from frame.schema() is not always consistent
        # with the first record batch's schema
        return pa.Table.from_batches(batches, schema=frame.schema())
class Backend(BaseBackend):
    name = 'datafusion'
    builder = None

    @property
    def version(self):
        try:
            import importlib.metadata as importlib_metadata
        except ImportError:
            # TODO: remove this when Python 3.7 support is dropped
            import importlib_metadata
        return importlib_metadata.version('datafusion')
    def do_connect(self, config):
        """Create a DataFusion backend for use with Ibis.

        Parameters
        ----------
        config : datafusion.ExecutionContext or dict
            An existing execution context, or a mapping of table names
            to paths of CSV or Parquet files to register.
        """
        if isinstance(config, df.ExecutionContext):
            self._context = config
            return

        self._context = df.ExecutionContext()
        for name, path in config.items():
            strpath = str(path)
            if strpath.endswith('.csv'):
                self.register_csv(name, path)
            elif strpath.endswith('.parquet'):
                self.register_parquet(name, path)
            else:
                raise ValueError(
                    "Currently the DataFusion backend only supports CSV "
                    "files with the extension .csv and Parquet files with "
                    "the .parquet extension."
                )
    def current_database(self):
        raise NotImplementedError()

    def list_databases(self, like: str = None) -> list[str]:
        raise NotImplementedError()

    def list_tables(self, like: str = None, database: str = None) -> list[str]:
        """List the available tables."""
        tables = list(self._context.tables())
        if like is not None:
            pattern = re.compile(like)
            return list(filter(lambda t: pattern.findall(t), tables))
        return tables
    def table(self, name, schema=None):
        """Get an ibis expression representing a DataFusion table.

        Parameters
        ----------
        name
            The name of the table to retrieve
        schema
            An optional schema; if not provided it is inferred from the
            registered table

        Returns
        -------
        ibis.expr.types.TableExpr
            A table expression
        """
        catalog = self._context.catalog()
        database = catalog.database('public')
        table = database.table(name)
        if schema is None:
            schema = sch.infer(table.schema)
        return self.table_class(name, schema, self).to_expr()
    def register_csv(self, name, path, schema=None):
        """Register a CSV file with `name` located at `path`.

        Parameters
        ----------
        name
            The name of the table
        path
            The path to the CSV file
        schema
            An optional schema
        """
        self._context.register_csv(name, path, schema=schema)

    def register_parquet(self, name, path, schema=None):
        """Register a Parquet file with `name` located at `path`.

        Parameters
        ----------
        name
            The name of the table
        path
            The path to the Parquet file
        schema
            An optional schema
        """
        self._context.register_parquet(name, path, schema=schema)
    def execute(
        self,
        expr: ir.Expr,
        params: Mapping[ir.Expr, object] = None,
        limit: str = 'default',
        **kwargs,
    ):
        if isinstance(expr, ir.TableExpr):
            frame = self.compile(expr, params, **kwargs)
            table = _to_pyarrow_table(frame)
            return table.to_pandas()
        elif isinstance(expr, ir.ColumnExpr):
            # expression must be named for the projection
            expr = expr.name('tmp').to_projection()
            frame = self.compile(expr, params, **kwargs)
            table = _to_pyarrow_table(frame)
            return table['tmp'].to_pandas()
        elif isinstance(expr, ir.ScalarExpr):
            if expr.op().root_tables():
                # there are associated datafusion tables so convert the expr
                # to a selection which we can directly convert to a datafusion
                # plan
                expr = expr.name('tmp').to_projection()
                frame = self.compile(expr, params, **kwargs)
            else:
                # doesn't have any tables associated so create a plan from a
                # dummy datafusion table
                compiled = self.compile(expr, params, **kwargs)
                frame = self._context.empty_table().select(compiled)
            table = _to_pyarrow_table(frame)
            return table[0][0].as_py()
        else:
            raise com.IbisError(
                f"Cannot execute expression of type: {type(expr)}"
            )
    def compile(
        self, expr: ir.Expr, params: Mapping[ir.Expr, object] = None, **kwargs
    ):
        return translate(expr)
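
A brief usage sketch of the backend above (not part of the diff; the file path and the playerID/yearID column names are assumptions borrowed from the batting test data registered in conftest.py below):

import ibis

# connect() dispatches to do_connect(); a mapping of table names to
# file paths registers each file eagerly, as implemented above
con = ibis.datafusion.connect({'batting': 'data/batting.csv'})

t = con.table('batting')
# compiled to a datafusion plan via translate() and collected to pandas
result = t[t.yearID > 2000][['playerID', 'yearID']].execute()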
@@ -0,0 +1,396 @@
import functools
import operator

import datafusion as df
import datafusion.functions
import pyarrow as pa

import ibis.common.exceptions as com
import ibis.expr.operations as ops
import ibis.expr.types as ir

from .datatypes import to_pyarrow_type
@functools.singledispatch
def translate(expr):
    raise NotImplementedError(expr)


@translate.register(ir.Expr)
def expression(expr):
    return translate(expr.op(), expr)


@translate.register(ops.Node)
def operation(op, expr):
    raise com.OperationNotDefinedError(f'No translation rule for {type(op)}')


@translate.register(ops.DatabaseTable)
def table(op, expr):
    name, _, client = op.args
    return client._context.table(name)


@translate.register(ops.Literal)
def literal(op, expr):
    if isinstance(op.value, (set, frozenset)):
        value = list(op.value)
    else:
        value = op.value

    arrow_type = to_pyarrow_type(op.dtype)
    arrow_scalar = pa.scalar(value, type=arrow_type)

    return df.literal(arrow_scalar)


@translate.register(ops.Cast)
def cast(op, expr):
    arg = translate(op.arg)
    typ = to_pyarrow_type(op.to)
    return arg.cast(to=typ)


@translate.register(ops.TableColumn)
def column(op, expr):
    table_op = op.table.op()

    if hasattr(table_op, "name"):
        return df.column(f'{table_op.name}."{op.name}"')
    else:
        return df.column(op.name)


@translate.register(ops.SortKey)
def sort_key(op, expr):
    arg = translate(op.expr)
    return arg.sort(ascending=op.ascending)
@translate.register(ops.Selection)
def selection(op, expr):
    plan = translate(op.table)

    selections = []
    for expr in op.selections or [op.table]:
        # TODO(kszucs): it would be nice if we wouldn't need to handle the
        # specific cases in the backend implementations; we could add a
        # new operator which retrieves all of the TableExpr columns
        # (e.g. Asterisk) so translate() would handle this automatically
        if isinstance(expr, ir.TableExpr):
            for name in expr.columns:
                column = expr.get_column(name)
                field = translate(column)
                if column.has_name():
                    field = field.alias(column.get_name())
                selections.append(field)
        elif isinstance(expr, ir.ValueExpr):
            field = translate(expr)
            if expr.has_name():
                field = field.alias(expr.get_name())
            selections.append(field)
        else:
            raise com.TranslationError(
                "DataFusion backend is unable to compile selection with "
                f"expression type of {type(expr)}"
            )

    plan = plan.select(*selections)

    if op.predicates:
        predicates = map(translate, op.predicates)
        predicate = functools.reduce(operator.and_, predicates)
        plan = plan.filter(predicate)

    if op.sort_keys:
        sort_keys = map(translate, op.sort_keys)
        plan = plan.sort(*sort_keys)

    return plan
@translate.register(ops.Aggregation)
def aggregation(op, expr):
    table = translate(op.table)
    group_by = [translate(expr) for expr in op.by]

    metrics = []
    for expr in op.metrics:
        agg = translate(expr)
        if expr.has_name():
            agg = agg.alias(expr.get_name())
        metrics.append(agg)

    return table.aggregate(group_by, metrics)


@translate.register(ops.Not)
def invert(op, expr):
    arg = translate(op.arg)
    return ~arg


@translate.register(ops.Abs)
def abs(op, expr):
    arg = translate(op.arg)
    return df.functions.abs(arg)


@translate.register(ops.Ceil)
def ceil(op, expr):
    arg = translate(op.arg)
    return df.functions.ceil(arg).cast(pa.int64())


@translate.register(ops.Floor)
def floor(op, expr):
    arg = translate(op.arg)
    return df.functions.floor(arg).cast(pa.int64())


@translate.register(ops.Round)
def round(op, expr):
    arg = translate(op.arg)
    if op.digits is not None:
        raise com.UnsupportedOperationError(
            'Rounding to specific digits is not supported in datafusion'
        )
    return df.functions.round(arg).cast(pa.int64())


@translate.register(ops.Ln)
def ln(op, expr):
    arg = translate(op.arg)
    return df.functions.ln(arg)


@translate.register(ops.Log2)
def log2(op, expr):
    arg = translate(op.arg)
    return df.functions.log2(arg)


@translate.register(ops.Log10)
def log10(op, expr):
    arg = translate(op.arg)
    return df.functions.log10(arg)


@translate.register(ops.Sqrt)
def sqrt(op, expr):
    arg = translate(op.arg)
    return df.functions.sqrt(arg)


@translate.register(ops.Strip)
def strip(op, expr):
    arg = translate(op.arg)
    return df.functions.trim(arg)


@translate.register(ops.LStrip)
def lstrip(op, expr):
    arg = translate(op.arg)
    return df.functions.ltrim(arg)


@translate.register(ops.RStrip)
def rstrip(op, expr):
    arg = translate(op.arg)
    return df.functions.rtrim(arg)


@translate.register(ops.Lowercase)
def lower(op, expr):
    arg = translate(op.arg)
    return df.functions.lower(arg)


@translate.register(ops.Uppercase)
def upper(op, expr):
    arg = translate(op.arg)
    return df.functions.upper(arg)


@translate.register(ops.Reverse)
def reverse(op, expr):
    arg = translate(op.arg)
    return df.functions.reverse(arg)


@translate.register(ops.StringLength)
def strlen(op, expr):
    arg = translate(op.arg)
    return df.functions.character_length(arg)


@translate.register(ops.Capitalize)
def capitalize(op, expr):
    arg = translate(op.arg)
    return df.functions.initcap(arg)
@translate.register(ops.Substring)
def substring(op, expr):
    arg = translate(op.arg)
    # datafusion's substr is one-indexed while ibis is zero-indexed
    start = translate(op.start + 1)
    length = translate(op.length)
    return df.functions.substr(arg, start, length)
@translate.register(ops.RegexExtract)
def regex_extract(op, expr):
    arg = translate(op.arg)
    pattern = translate(op.pattern)
    return df.functions.regexp_match(arg, pattern)


@translate.register(ops.Repeat)
def repeat(op, expr):
    arg = translate(op.arg)
    times = translate(op.times)
    return df.functions.repeat(arg, times)


@translate.register(ops.LPad)
def lpad(op, expr):
    arg = translate(op.arg)
    length = translate(op.length)
    pad = translate(op.pad)
    return df.functions.lpad(arg, length, pad)


@translate.register(ops.RPad)
def rpad(op, expr):
    arg = translate(op.arg)
    length = translate(op.length)
    pad = translate(op.pad)
    return df.functions.rpad(arg, length, pad)


@translate.register(ops.GreaterEqual)
def ge(op, expr):
    return translate(op.left) >= translate(op.right)


@translate.register(ops.LessEqual)
def le(op, expr):
    return translate(op.left) <= translate(op.right)


@translate.register(ops.Greater)
def gt(op, expr):
    return translate(op.left) > translate(op.right)


@translate.register(ops.Less)
def lt(op, expr):
    return translate(op.left) < translate(op.right)


@translate.register(ops.Equals)
def eq(op, expr):
    return translate(op.left) == translate(op.right)


@translate.register(ops.NotEquals)
def ne(op, expr):
    return translate(op.left) != translate(op.right)


@translate.register(ops.Add)
def add(op, expr):
    return translate(op.left) + translate(op.right)


@translate.register(ops.Subtract)
def sub(op, expr):
    return translate(op.left) - translate(op.right)


@translate.register(ops.Multiply)
def mul(op, expr):
    return translate(op.left) * translate(op.right)


@translate.register(ops.Divide)
def div(op, expr):
    return translate(op.left) / translate(op.right)


@translate.register(ops.FloorDivide)
def floordiv(op, expr):
    return df.functions.floor(translate(op.left) / translate(op.right))


@translate.register(ops.Modulus)
def mod(op, expr):
    return translate(op.left) % translate(op.right)


@translate.register(ops.Sum)
def sum(op, expr):
    arg = translate(op.arg)
    return df.functions.sum(arg)


@translate.register(ops.Min)
def min(op, expr):
    arg = translate(op.arg)
    return df.functions.min(arg)


@translate.register(ops.Max)
def max(op, expr):
    arg = translate(op.arg)
    return df.functions.max(arg)


@translate.register(ops.Mean)
def mean(op, expr):
    arg = translate(op.arg)
    return df.functions.avg(arg)
def _prepare_contains_options(options):
    if isinstance(options, ir.AnyScalar):
        # TODO(kszucs): it would be better if we could pass an arrow
        # ListScalar to datafusion's in_list function
        return [df.literal(v) for v in options.op().value]
    else:
        return translate(options)
@translate.register(ops.ValueList)
def value_list(op, expr):
    return list(map(translate, op.values))


@translate.register(ops.Contains)
def contains(op, expr):
    value = translate(op.value)
    options = _prepare_contains_options(op.options)
    return df.functions.in_list(value, options, negated=False)


@translate.register(ops.NotContains)
def not_contains(op, expr):
    value = translate(op.value)
    options = _prepare_contains_options(op.options)
    return df.functions.in_list(value, options, negated=True)


@translate.register(ops.ElementWiseVectorizedUDF)
def elementwise_udf(op, expr):
    udf = df.udf(
        op.func,
        input_types=list(map(to_pyarrow_type, op.input_type)),
        return_type=to_pyarrow_type(op.return_type),
        volatility="volatile",
    )
    args = map(translate, op.func_args)

    return udf(*args)
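
Because translate() is a functools.singledispatch function, covering a new operation is just another registration. A hypothetical sketch for ops.Negate (not part of this diff), written only with constructs the module already uses and ignoring numeric type coercion concerns:

@translate.register(ops.Negate)
def negate(op, expr):
    # hypothetical: rewrite -x as 0 - x, reusing df.literal and the
    # binary subtraction operator the rules above already rely on
    return df.literal(pa.scalar(0)) - translate(op.arg)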
@@ -0,0 +1,90 @@
import functools

import pyarrow as pa

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch

# TODO(kszucs): the following conversions are really rudimentary;
# we should have a pyarrow backend which would be responsible
# for conversions between ibis types and pyarrow types

# TODO(kszucs): support nested and parametric types;
# consolidate with the logic from the parquet backend
_to_ibis_dtypes = {
    pa.int8(): dt.Int8,
    pa.int16(): dt.Int16,
    pa.int32(): dt.Int32,
    pa.int64(): dt.Int64,
    pa.uint8(): dt.UInt8,
    pa.uint16(): dt.UInt16,
    pa.uint32(): dt.UInt32,
    pa.uint64(): dt.UInt64,
    pa.float16(): dt.Float16,
    pa.float32(): dt.Float32,
    pa.float64(): dt.Float64,
    pa.string(): dt.String,
    pa.binary(): dt.Binary,
    pa.bool_(): dt.Boolean,
}
@dt.dtype.register(pa.DataType)
def from_pyarrow_primitive(arrow_type, nullable=True):
    return _to_ibis_dtypes[arrow_type](nullable=nullable)


@dt.dtype.register(pa.TimestampType)
def from_pyarrow_timestamp(arrow_type, nullable=True):
    return dt.Timestamp(timezone=arrow_type.tz, nullable=nullable)
@sch.infer.register(pa.Schema)
def infer_pyarrow_schema(schema):
    fields = [(f.name, dt.dtype(f.type, nullable=f.nullable)) for f in schema]
    return sch.schema(fields)


_to_pyarrow_types = {
    dt.Int8: pa.int8(),
    dt.Int16: pa.int16(),
    dt.Int32: pa.int32(),
    dt.Int64: pa.int64(),
    dt.UInt8: pa.uint8(),
    dt.UInt16: pa.uint16(),
    dt.UInt32: pa.uint32(),
    dt.UInt64: pa.uint64(),
    dt.Float16: pa.float16(),
    dt.Float32: pa.float32(),
    dt.Float64: pa.float64(),
    dt.String: pa.string(),
    dt.Binary: pa.binary(),
    dt.Boolean: pa.bool_(),
    dt.Timestamp: pa.timestamp('ns'),
}
@functools.singledispatch
def to_pyarrow_type(dtype):
    return _to_pyarrow_types[dtype.__class__]


@to_pyarrow_type.register(dt.Array)
def from_ibis_array(dtype):
    return pa.list_(to_pyarrow_type(dtype.value_type))


@to_pyarrow_type.register(dt.Set)
def from_ibis_set(dtype):
    return pa.list_(to_pyarrow_type(dtype.value_type))


@to_pyarrow_type.register(dt.Interval)
def from_ibis_interval(dtype):
    try:
        return pa.duration(dtype.unit)
    except ValueError:
        raise com.IbisTypeError(f"Unsupported interval unit: {dtype.unit}")
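
A minimal sketch of both conversion directions defined in this module, assuming it has been imported:

import pyarrow as pa

import ibis.expr.datatypes as dt
import ibis.expr.schema as sch

# ibis -> pyarrow, via the to_pyarrow_type dispatcher above
assert to_pyarrow_type(dt.int64) == pa.int64()
assert to_pyarrow_type(dt.Array(dt.string)) == pa.list_(pa.string())

# pyarrow -> ibis, via the sch.infer rule registered above
ibis_schema = sch.infer(pa.schema([('a', pa.int32()), ('b', pa.string())]))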
@@ -0,0 +1,77 @@
from pathlib import Path

import pyarrow as pa
import pytest

import ibis
import ibis.expr.types as ir
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero
class TestConf(BackendTest, RoundAwayFromZero):
    # check_names = False
    # additional_skipped_operations = frozenset({ops.StringSQLLike})
    # supports_divide_by_zero = True
    # returned_timestamp_unit = 'ns'
    bool_is_int = True

    @staticmethod
    def connect(data_directory: Path):
        # can be various types:
        #   pyarrow.RecordBatch
        #   parquet file path
        #   csv file path
        client = ibis.datafusion.connect({})
        client.register_csv(
            name='functional_alltypes',
            path=data_directory / 'functional_alltypes.csv',
            schema=pa.schema(
                [
                    ('index', 'int64'),
                    ('Unnamed 0', 'int64'),
                    ('id', 'int64'),
                    ('bool_col', 'int8'),
                    ('tinyint_col', 'int8'),
                    ('smallint_col', 'int16'),
                    ('int_col', 'int32'),
                    ('bigint_col', 'int64'),
                    ('float_col', 'float32'),
                    ('double_col', 'float64'),
                    ('date_string_col', 'string'),
                    ('string_col', 'string'),
                    ('timestamp_col', 'string'),
                    ('year', 'int64'),
                    ('month', 'int64'),
                ]
            ),
        )
        client.register_csv(
            name='batting', path=data_directory / 'batting.csv'
        )
        client.register_csv(
            name='awards_players', path=data_directory / 'awards_players.csv'
        )
        return client
    @property
    def functional_alltypes(self) -> ir.TableExpr:
        t = self.connection.table('functional_alltypes')
        return t.mutate(
            bool_col=t.bool_col == 1,
            timestamp_col=t.timestamp_col.cast('timestamp'),
        )


@pytest.fixture(scope='session')
def client(data_directory):
    return TestConf.connect(data_directory)


@pytest.fixture(scope='session')
def alltypes(client):
    return client.table("functional_alltypes")


@pytest.fixture(scope='session')
def alltypes_df(alltypes):
    return alltypes.execute()
@@ -0,0 +1,6 @@
def test_list_tables(client):
    assert set(client.list_tables()) == {
        'awards_players',
        'batting',
        'functional_alltypes',
    }
@@ -0,0 +1,20 @@
from .conftest import BackendTest


def test_where_multiple_conditions(alltypes, alltypes_df):
    expr = alltypes.filter(
        [
            alltypes.float_col > 0,
            alltypes.smallint_col == 9,
            alltypes.int_col < alltypes.float_col * 2,
        ]
    )
    result = expr.execute()

    expected = alltypes_df[
        (alltypes_df['float_col'] > 0)
        & (alltypes_df['smallint_col'] == 9)
        & (alltypes_df['int_col'] < alltypes_df['float_col'] * 2)
    ]

    BackendTest.assert_frame_equal(result, expected)
@@ -0,0 +1,43 @@
import pandas.testing as tm
import pyarrow.compute as pc

import ibis.expr.datatypes as dt
import ibis.expr.types as ir
from ibis.udf.vectorized import elementwise, reduction


@elementwise(input_type=['string'], output_type='int64')
def my_string_length(arr, **kwargs):
    # arr is a pyarrow.StringArray
    return pc.cast(pc.multiply(pc.utf8_length(arr), 2), target_type='int64')


@elementwise(input_type=[dt.int64, dt.int64], output_type=dt.int64)
def my_add(arr1, arr2, **kwargs):
    return pc.add(arr1, arr2)


@reduction(input_type=[dt.float64], output_type=dt.float64)
def my_mean(arr):
    return pc.mean(arr)


def test_udf(alltypes):
    data_string_col = alltypes.date_string_col.execute()
    expected = data_string_col.str.len() * 2

    expr = my_string_length(alltypes.date_string_col)
    assert isinstance(expr, ir.ColumnExpr)

    result = expr.execute()
    tm.assert_series_equal(result, expected, check_names=False)


def test_multiple_argument_udf(alltypes):
    expr = my_add(alltypes.smallint_col, alltypes.int_col)
    result = expr.execute()

    df = alltypes[['smallint_col', 'int_col']].execute()
    expected = (df.smallint_col + df.int_col).astype('int64')

    tm.assert_series_equal(result, expected.rename('tmp'))
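
For context on how these UDFs reach DataFusion: @elementwise wraps the decorated function in an ops.ElementWiseVectorizedUDF node, and the compiler's elementwise_udf rule converts that node into a df.udf applied to the translated argument columns. A sketch of the intermediate step, assuming the alltypes fixture from conftest.py:

expr = my_add(alltypes.smallint_col, alltypes.int_col)
op = expr.op()  # an ops.ElementWiseVectorizedUDF node
# translate(op, expr) wraps op.func with df.udf(...) and applies it to
# the translated op.func_args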
@@ -33,7 +33,7 @@ | |
| } | ||
| ) | ||
| con = ibis.pandas.connect({"table1": df}) | ||
| @elementwise( | ||
| input_type=[dt.double], | ||
|
|
||