24 changes: 15 additions & 9 deletions ibis/backends/base/sql/__init__.py
@@ -4,7 +4,7 @@
import contextlib
import os
from functools import lru_cache
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional

import toolz

@@ -136,12 +136,16 @@ def raw_sql(self, query: str):
::: {.callout-tip}
## Consider using [`.sql`](#ibis.backends.base.sql.BaseSQLBackend.sql) instead
If your query is a SELECT statement, you should use the
If your query is a `SELECT` statement you can use the
[backend `.sql`](#ibis.backends.base.sql.BaseSQLBackend.sql) method to avoid
having to release the cursor returned from this method manually.
having to manually release the cursor returned from this method.
::: {.callout-warning collapse="true"}
## The returned cursor object must be **manually released** if you use `raw_sql`.
::: {.callout-warning}
## The cursor returned from this method must be **manually released**
You **do not** need to call `.close()` on the cursor when running DDL
or DML statements like `CREATE`, `INSERT` or `DROP`, only when using
`SELECT` statements.
To release a cursor, call the `close` method on the returned cursor
object.
@@ -166,14 +170,13 @@ def raw_sql(self, query: str):
Parameters
----------
query
DDL or DML statement
SQL query string
Examples
--------
>>> con = ibis.connect("duckdb://")
>>> with con.raw_sql("SELECT 1") as cursor:
... result = cursor.fetchall()
...
>>> result
[(1,)]
>>> cursor.closed
@@ -251,18 +254,21 @@ def _register_udfs(self, expr: ir.Expr) -> None:
if self.supports_python_udfs:
raise NotImplementedError(self.name)

def _gen_udf_name(self, name: str, schema: Optional[str]) -> str:
return ".".join(filter(None, (schema, name)))

def _gen_udf_rule(self, op: ops.ScalarUDF):
@self.add_operation(type(op))
def _(t, op):
func = ".".join(filter(None, (op.__udf_namespace__, op.__func_name__)))
func = self._gen_udf_name(op.__func_name__, schema=op.__udf_namespace__)
return f"{func}({', '.join(map(t.translate, op.args))})"

def _gen_udaf_rule(self, op: ops.AggUDF):
from ibis import NA

@self.add_operation(type(op))
def _(t, op):
func = ".".join(filter(None, (op.__udf_namespace__, op.__func_name__)))
func = self._gen_udf_name(op.__func_name__, schema=op.__udf_namespace__)
args = ", ".join(
t.translate(
ops.IfElse(where, arg, NA)
66 changes: 63 additions & 3 deletions ibis/backends/base/sql/alchemy/__init__.py
@@ -554,7 +554,7 @@ def _schema_from_sqla_table(
dtype = schema[name]
else:
dtype = cls.compiler.translator_class.get_ibis_type(
column.type, nullable=column.nullable
column.type, nullable=column.nullable or column.nullable is None
)
pairs.append((name, dtype))
return sch.schema(pairs)
@@ -589,7 +589,58 @@ def _handle_failed_column_type_inference(
)
return table

def raw_sql(self, query):
def raw_sql(self, query: str | sa.sql.ClauseElement):
"""Execute a query and return the cursor used for execution.
::: {.callout-tip}
## Consider using [`.sql`](#ibis.backends.base.sql.BaseSQLBackend.sql) instead
If your query is a `SELECT` statement you can use the
[backend `.sql`](#ibis.backends.base.sql.BaseSQLBackend.sql) method to avoid
having to manually release the cursor returned from this method.
::: {.callout-warning}
## The cursor returned from this method must be **manually released**
You **do not** need to call `.close()` on the cursor when running DDL
or DML statements like `CREATE`, `INSERT` or `DROP`, only when using
`SELECT` statements.
To release a cursor, call the `close` method on the returned cursor
object.
You can close the cursor by explicitly calling its `close` method:
```python
cursor = con.raw_sql("SELECT ...")
cursor.close()
```
Or you can use a context manager:
```python
with con.raw_sql("SELECT ...") as cursor:
...
```
:::
:::
Parameters
----------
query
SQL query or SQLAlchemy expression to execute
Examples
--------
>>> con = ibis.connect("duckdb://")
>>> with con.raw_sql("SELECT 1") as cursor:
... result = cursor.fetchall()
>>> result
[(1,)]
>>> cursor.closed
True
"""
return self.con.connect().execute(
sa.text(query) if isinstance(query, str) else query
)
@@ -939,6 +990,13 @@ def _get_table_identifier(self, *, name, namespace):
db=schema,
catalog=db,
quoted=self.compiler.translator_class._quote_table_names,
).transform(
lambda node: node.__class__(
this=node.this,
quoted=node.quoted or self.compiler.translator_class._quote_table_names,
)
if isinstance(node, sg.exp.Identifier)
else node
)
return table
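The `transform` above forces quoting on every `sg.exp.Identifier` in the generated table reference when the translator requires quoted table names. A standalone sketch of the same idea (table and schema names are made up):

```python
import sqlglot as sg

tbl = sg.table("events", db="main", quoted=False)
forced = tbl.transform(
    lambda node: node.__class__(this=node.this, quoted=True)
    if isinstance(node, sg.exp.Identifier)
    else node
)
print(forced.sql(dialect="duckdb"))  # e.g. "main"."events"
```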

@@ -963,7 +1021,9 @@ def _get_sqla_table(
def drop_table(
self, name: str, database: str | None = None, force: bool = False
) -> None:
table = sg.table(name, db=database)
table = sg.table(
name, db=database, quoted=self.compiler.translator_class._quote_table_names
)
drop_table = sg.exp.Drop(kind="TABLE", exists=force, this=table)
drop_table_sql = drop_table.sql(dialect=self.name)
with self.begin() as con:
5 changes: 4 additions & 1 deletion ibis/backends/base/sql/alchemy/query_builder.py
@@ -109,7 +109,10 @@ def _format_table(self, op):
)
dialect = translator._dialect_name
result.fullname = sg.table(
name, db=namespace.schema, catalog=namespace.database
name,
db=namespace.schema,
catalog=namespace.database,
quoted=translator._quote_table_names,
).sql(dialect=_SQLALCHEMY_TO_SQLGLOT_DIALECT.get(dialect, dialect))
elif isinstance(op, ops.SQLQueryResult):
columns = translator._schema_to_sqlalchemy_columns(op.schema)
110 changes: 91 additions & 19 deletions ibis/backends/base/sql/alchemy/registry.py
@@ -7,6 +7,7 @@

import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.elements import RANGE_CURRENT, RANGE_UNBOUNDED
from sqlalchemy.sql.functions import FunctionElement, GenericFunction

import ibis.common.exceptions as com
@@ -308,17 +309,84 @@ def _endswith(t, op):
return t.translate(op.arg).endswith(t.translate(op.end))


def _translate_window_boundary(boundary):
if boundary is None:
return None
def _reinterpret_range_bound(bound):
if bound is None:
return RANGE_UNBOUNDED

if isinstance(boundary.value, ops.Literal):
if boundary.preceding:
return -boundary.value.value
else:
return boundary.value.value
try:
lower = int(bound)
except ValueError as err:
sa.util.raise_(
sa.exc.ArgumentError(
"Integer, None or expression expected for range value"
),
replace_context=err,
)
except TypeError:
return bound
else:
return RANGE_CURRENT if lower == 0 else lower


def _interpret_range(self, range_):
if not isinstance(range_, tuple) or len(range_) != 2:
raise sa.exc.ArgumentError("2-tuple expected for range/rows")

lower = _reinterpret_range_bound(range_[0])
upper = _reinterpret_range_bound(range_[1])
return lower, upper


# monkeypatch to allow expressions in range and rows bounds
sa.sql.elements.Over._interpret_range = _interpret_range


def _compile_bounds(processor, left, right) -> str:
if left is RANGE_UNBOUNDED:
left = "UNBOUNDED PRECEDING"
elif left is RANGE_CURRENT:
left = "CURRENT ROW"
else:
left = f"{processor(left)} PRECEDING"

if right is RANGE_UNBOUNDED:
right = "UNBOUNDED FOLLOWING"
elif right is RANGE_CURRENT:
right = "CURRENT ROW"
else:
right = f"{processor(right)} FOLLOWING"

return f"BETWEEN {left} AND {right}"

raise com.TranslationError("Window boundaries must be literal values")

@compiles(sa.sql.elements.Over)
def compile_over(over, compiler, **kw) -> str:
processor = functools.partial(compiler.process, **kw)

text = processor(over.element)

if over.range_:
bounds = _compile_bounds(processor, *over.range_)
range_ = f"RANGE {bounds}"
elif over.rows:
bounds = _compile_bounds(processor, *over.rows)
range_ = f"ROWS {bounds}"
else:
range_ = None

args = [
f"{word} BY {processor(clause)}"
for word, clause in (
("PARTITION", over.partition_by),
("ORDER", over.order_by),
)
if clause is not None and len(clause)
]

if range_ is not None:
args.append(range_)

return f"{text} OVER ({' '.join(args)})"


def _window_function(t, window):
@@ -347,22 +415,26 @@ def _window_function(t, window):
else:
raise NotImplementedError(type(window.frame))

if t._forbids_frame_clause and isinstance(func, t._forbids_frame_clause):
# some functions on some backends don't support frame clauses
additional_params = {}
else:
start = _translate_window_boundary(window.frame.start)
end = _translate_window_boundary(window.frame.end)
additional_params = {how: (start, end)}
additional_params = {}

# some functions on some backends don't support frame clauses
if not t._forbids_frame_clause or not isinstance(func, t._forbids_frame_clause):
if (start := window.frame.start) is not None:
start = t.translate(start.value)

if (end := window.frame.end) is not None:
end = t.translate(end.value)

additional_params[how] = (start, end)

result = sa.over(
reduction, partition_by=partition_by, order_by=order_by, **additional_params
)

if isinstance(func, (ops.RowNumber, ops.DenseRank, ops.MinRank, ops.NTile)):
return result - 1
else:
return result
result -= 1

return result


def _lag(t, op):
16 changes: 0 additions & 16 deletions ibis/backends/base/sql/registry/identifiers.py
@@ -1,21 +1,5 @@
from __future__ import annotations

# Copyright 2014 Cloudera Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Base identifiers

base_identifiers = [
"add",
"aggregate",
2 changes: 1 addition & 1 deletion ibis/backends/base/sqlglot/__init__.py
@@ -46,7 +46,7 @@ def exists(self, query):
return sg.exp.Exists(this=query)

def concat(self, *args):
return sg.exp.Concat.from_arg_list(list(map(sg.exp.convert, args)))
return sg.exp.Concat(expressions=list(map(sg.exp.convert, args)))

def map(self, keys, values):
return sg.exp.Map(keys=keys, values=values)
8 changes: 5 additions & 3 deletions ibis/backends/base/sqlglot/datatypes.py
@@ -257,21 +257,23 @@ def _from_sqlglot_DECIMAL(
@classmethod
def _from_ibis_Array(cls, dtype: dt.Array) -> sge.DataType:
value_type = cls.from_ibis(dtype.value_type)
return sge.DataType(this=typecode.ARRAY, expressions=[value_type])
return sge.DataType(this=typecode.ARRAY, expressions=[value_type], nested=True)

@classmethod
def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
key_type = cls.from_ibis(dtype.key_type)
value_type = cls.from_ibis(dtype.value_type)
return sge.DataType(this=typecode.MAP, expressions=[key_type, value_type])
return sge.DataType(
this=typecode.MAP, expressions=[key_type, value_type], nested=True
)

@classmethod
def _from_ibis_Struct(cls, dtype: dt.Struct) -> sge.DataType:
fields = [
sge.ColumnDef(this=str(name), kind=cls.from_ibis(field))
for name, field in dtype.items()
]
return sge.DataType(this=typecode.STRUCT, expressions=fields)
return sge.DataType(this=typecode.STRUCT, expressions=fields, nested=True)

@classmethod
def _from_ibis_Decimal(cls, dtype: dt.Decimal) -> sge.DataType:
27 changes: 17 additions & 10 deletions ibis/backends/bigquery/__init__.py
@@ -9,7 +9,7 @@
import re
import warnings
from functools import partial
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Optional
from urllib.parse import parse_qs, urlparse

import google.auth.credentials
@@ -36,7 +36,6 @@
)
from ibis.backends.bigquery.compiler import BigQueryCompiler
from ibis.backends.bigquery.datatypes import BigQuerySchema, BigQueryType
from ibis.formats.pandas import PandasData

with contextlib.suppress(ImportError):
from ibis.backends.bigquery.udf import udf # noqa: F401
@@ -95,7 +94,7 @@ def _anonymous_unnest_to_explode(node: sg.exp.Expression):
return node


_MEMTABLE_PATTERN = re.compile(r"^_ibis_(?:pandas|pyarrow)_memtable_[a-z0-9]{26}$")
_MEMTABLE_PATTERN = re.compile(r"^_?ibis_(?:pandas|pyarrow)_memtable_[a-z0-9]{26}$")


def _qualify_memtable(
@@ -559,9 +558,12 @@ def table(
return rename_partitioned_column(table_expr, bq_table, self.partition_column)

def _make_session(self) -> tuple[str, str]:
if self._session_dataset is None:
if (
self._session_dataset is None
and (client := getattr(self, "client", None)) is not None
):
job_config = bq.QueryJobConfig(use_query_cache=False)
query = self.client.query(
query = client.query(
"SELECT 1", job_config=job_config, project=self.billing_project
)
query.result()
@@ -616,6 +618,7 @@ def compile(
The output of compilation. The type of this value depends on the
backend.
"""
self._make_session()
self._define_udf_translation_rules(expr)
sql = self.compiler.to_ast_ensure_limit(expr, limit, params=params).compile()

@@ -709,6 +712,8 @@ def execute(self, expr, params=None, limit="default", **kwargs):
return expr.__pandas_result__(result)

def fetch_from_cursor(self, cursor, schema):
from ibis.formats.pandas import PandasData

arrow_t = self._cursor_to_arrow(cursor)
df = arrow_t.to_pandas(timestamp_as_object=True)
return PandasData.convert_table(df, schema)
@@ -780,6 +785,12 @@ def to_pyarrow_batches(
)
return pa.RecordBatchReader.from_batches(schema.to_pyarrow(), batch_iter)

def _gen_udf_name(self, name: str, schema: Optional[str]) -> str:
func = ".".join(filter(None, (schema, name)))
if "." in func:
return ".".join(f"`{part}`" for part in func.split("."))
return func
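The override backtick-quotes every dotted component whenever a schema is present; a standalone reproduction of the quoting logic for illustration:

```python
def _quote_udf_name(name: str, schema: str | None) -> str:
    # same logic as the method above, reproduced for illustration only
    func = ".".join(filter(None, (schema, name)))
    if "." in func:
        return ".".join(f"`{part}`" for part in func.split("."))
    return func

assert _quote_udf_name("farm_fingerprint", None) == "farm_fingerprint"
assert _quote_udf_name("from_hex", "bqutil.fn") == "`bqutil`.`fn`.`from_hex`"
```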

def get_schema(self, name, schema: str | None = None, database: str | None = None):
table_ref = bq.TableReference(
bq.DatasetReference(
@@ -988,11 +999,7 @@ def create_table(
column_defs = [
sg.exp.ColumnDef(
this=name,
kind=sg.parse_one(
BigQueryType.from_ibis(typ),
into=sg.exp.DataType,
read=self.name,
),
kind=BigQueryType.from_ibis(typ),
constraints=(
None
if typ.nullable or typ.is_array()
11 changes: 5 additions & 6 deletions ibis/backends/bigquery/client.py
@@ -83,22 +83,21 @@ def bq_param_array(dtype: dt.Array, value, name):
value_type = dtype.value_type

try:
bigquery_type = BigQueryType.from_ibis(value_type)
bigquery_type = BigQueryType.to_string(value_type)
except NotImplementedError:
raise com.UnsupportedBackendType(dtype)
else:
if isinstance(value_type, dt.Struct):
if isinstance(value_type, dt.Array):
raise TypeError("ARRAY<ARRAY<T>> is not supported in BigQuery")
elif isinstance(value_type, dt.Struct):
query_value = [
bigquery_param(dtype.value_type, struct, f"element_{i:d}")
for i, struct in enumerate(value)
]
bigquery_type = "STRUCT"
elif isinstance(value_type, dt.Array):
raise TypeError("ARRAY<ARRAY<T>> is not supported in BigQuery")
else:
query_value = value
result = bq.ArrayQueryParameter(name, bigquery_type, query_value)
return result
return bq.ArrayQueryParameter(name, bigquery_type, query_value)


@bigquery_param.register
185 changes: 104 additions & 81 deletions ibis/backends/bigquery/datatypes.py
@@ -1,101 +1,124 @@
from __future__ import annotations

import google.cloud.bigquery as bq
import sqlglot as sg
import sqlglot.expressions as sge

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.formats import SchemaMapper, TypeMapper

_from_bigquery_types = {
"INT64": dt.Int64,
"INTEGER": dt.Int64,
"FLOAT": dt.Float64,
"FLOAT64": dt.Float64,
"BOOL": dt.Boolean,
"BOOLEAN": dt.Boolean,
"STRING": dt.String,
"DATE": dt.Date,
"TIME": dt.Time,
"BYTES": dt.Binary,
"JSON": dt.JSON,
}


class BigQueryType(TypeMapper):
@classmethod
def to_ibis(cls, typ: str, nullable: bool = True) -> dt.DataType:
if typ == "DATETIME":
return dt.Timestamp(timezone=None, nullable=nullable)
elif typ == "TIMESTAMP":
return dt.Timestamp(timezone="UTC", nullable=nullable)
elif typ == "NUMERIC":
return dt.Decimal(38, 9, nullable=nullable)
elif typ == "BIGNUMERIC":
return dt.Decimal(76, 38, nullable=nullable)
elif typ == "GEOGRAPHY":
return dt.GeoSpatial(geotype="geography", srid=4326, nullable=nullable)
else:
try:
return _from_bigquery_types[typ](nullable=nullable)
except KeyError:
raise TypeError(f"Unable to convert BigQuery type to ibis: {typ}")
from ibis.backends.base.sqlglot.datatypes import SqlglotType
from ibis.formats import SchemaMapper


class BigQueryType(SqlglotType):
dialect = "bigquery"

default_decimal_precision = 38
default_decimal_scale = 9

@classmethod
def _from_sqlglot_NUMERIC(cls) -> dt.Decimal:
return dt.Decimal(
cls.default_decimal_precision,
cls.default_decimal_scale,
nullable=cls.default_nullable,
)

@classmethod
def _from_sqlglot_BIGNUMERIC(cls) -> dt.Decimal:
return dt.Decimal(76, 38, nullable=cls.default_nullable)

@classmethod
def _from_sqlglot_DATETIME(cls) -> dt.Decimal:
return dt.Timestamp(timezone=None, nullable=cls.default_nullable)

@classmethod
def _from_sqlglot_TIMESTAMP(cls) -> dt.Decimal:
return dt.Timestamp(timezone="UTC", nullable=cls.default_nullable)

@classmethod
def _from_sqlglot_GEOGRAPHY(cls) -> dt.Decimal:
return dt.GeoSpatial(
geotype="geography", srid=4326, nullable=cls.default_nullable
)

@classmethod
def _from_sqlglot_TINYINT(cls) -> dt.Int64:
return dt.Int64(nullable=cls.default_nullable)

_from_sqlglot_UINT = (
_from_sqlglot_USMALLINT
) = (
_from_sqlglot_UTINYINT
) = _from_sqlglot_INT = _from_sqlglot_SMALLINT = _from_sqlglot_TINYINT

@classmethod
def _from_sqlglot_UBIGINT(cls) -> dt.Int64:
raise TypeError("Unsigned BIGINT isn't representable in BigQuery INT64")

@classmethod
def _from_sqlglot_FLOAT(cls) -> dt.Double:
return dt.Float64(nullable=cls.default_nullable)

@classmethod
def from_ibis(cls, dtype: dt.DataType) -> str:
if dtype.is_floating():
return "FLOAT64"
elif dtype.is_uint64():
def _from_sqlglot_MAP(cls) -> dt.Map:
raise NotImplementedError(
"Cannot convert sqlglot Map type to ibis type: maps are not supported in BigQuery"
)

@classmethod
def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
raise NotImplementedError(
"Cannot convert Ibis Map type to BigQuery type: maps are not supported in BigQuery"
)

@classmethod
def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType:
if dtype.timezone is None:
return sge.DataType(this=sge.DataType.Type.DATETIME)
elif dtype.timezone == "UTC":
return sge.DataType(this=sge.DataType.Type.TIMESTAMPTZ)
else:
raise TypeError(
"Conversion from uint64 to BigQuery integer type (int64) is lossy"
"BigQuery does not support timestamps with timezones other than 'UTC'"
)
elif dtype.is_integer():
return "INT64"
elif dtype.is_binary():
return "BYTES"
elif dtype.is_date():
return "DATE"
elif dtype.is_timestamp():
if dtype.timezone is None:
return "DATETIME"
elif dtype.timezone == "UTC":
return "TIMESTAMP"
else:
raise TypeError(
"BigQuery does not support timestamps with timezones other than 'UTC'"
)
elif dtype.is_decimal():
if (dtype.precision, dtype.scale) == (76, 38):
return "BIGNUMERIC"
if (dtype.precision, dtype.scale) in [(38, 9), (None, None)]:
return "NUMERIC"

@classmethod
def _from_ibis_Decimal(cls, dtype: dt.Decimal) -> sge.DataType:
precision = dtype.precision
scale = dtype.scale
if (precision, scale) == (76, 38):
return sge.DataType(this=sge.DataType.Type.BIGDECIMAL)
elif (precision, scale) in ((38, 9), (None, None)):
return sge.DataType(this=sge.DataType.Type.DECIMAL)
else:
raise TypeError(
"BigQuery only supports decimal types with precision of 38 and "
f"scale of 9 (NUMERIC) or precision of 76 and scale of 38 (BIGNUMERIC). "
f"Current precision: {dtype.precision}. Current scale: {dtype.scale}"
)
elif dtype.is_array():
return f"ARRAY<{cls.from_ibis(dtype.value_type)}>"
elif dtype.is_struct():
fields = (
f"{sg.to_identifier(k).sql('bigquery')} {cls.from_ibis(v)}"
for k, v in dtype.fields.items()
)
return "STRUCT<{}>".format(", ".join(fields))
elif dtype.is_json():
return "JSON"
elif dtype.is_geospatial():
if (dtype.geotype, dtype.srid) == ("geography", 4326):
return "GEOGRAPHY"

@classmethod
def _from_ibis_UInt64(cls, dtype: dt.UInt64) -> sge.DataType:
raise TypeError(
f"Conversion from {dtype} to BigQuery integer type (Int64) is lossy"
)

@classmethod
def _from_ibis_UInt32(cls, dtype: dt.UInt32) -> sge.DataType:
return sge.DataType(this=sge.DataType.Type.BIGINT)

_from_ibis_UInt8 = _from_ibis_UInt16 = _from_ibis_UInt32

@classmethod
def _from_ibis_GeoSpatial(cls, dtype: dt.GeoSpatial) -> sge.DataType:
if (dtype.geotype, dtype.srid) == ("geography", 4326):
return sge.DataType(this=sge.DataType.Type.GEOGRAPHY)
else:
raise TypeError(
"BigQuery geography uses points on WGS84 reference ellipsoid."
f"Current geotype: {dtype.geotype}, Current srid: {dtype.srid}"
)
elif dtype.is_map():
raise NotImplementedError("Maps are not supported in BigQuery")
else:
return str(dtype).upper()
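With the switch to `SqlglotType`, string conversions go through sqlglot's BigQuery dialect; the unit tests added later in this diff exercise the array and unsigned-integer cases:

```python
import pytest
import ibis.expr.datatypes as dt

assert BigQueryType.to_string(dt.Array(dt.int64)) == "ARRAY<INT64>"

# unsigned 64-bit integers are still rejected as lossy
with pytest.raises(TypeError):
    BigQueryType.to_string(dt.uint64)
```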


class BigQuerySchema(SchemaMapper):
@@ -112,7 +135,7 @@ def from_ibis(cls, schema: sch.Schema) -> list[bq.SchemaField]:
is_struct = value_type.is_struct()

field_type = (
"RECORD" if is_struct else BigQueryType.from_ibis(typ.value_type)
"RECORD" if is_struct else BigQueryType.to_string(typ.value_type)
)
mode = "REPEATED"
fields = cls.from_ibis(ibis.schema(getattr(value_type, "fields", {})))
@@ -121,7 +144,7 @@ def from_ibis(cls, schema: sch.Schema) -> list[bq.SchemaField]:
mode = "NULLABLE" if typ.nullable else "REQUIRED"
fields = cls.from_ibis(ibis.schema(typ.fields))
else:
field_type = BigQueryType.from_ibis(typ)
field_type = BigQueryType.to_string(typ)
mode = "NULLABLE" if typ.nullable else "REQUIRED"
fields = ()

@@ -138,7 +161,7 @@ def _dtype_from_bigquery_field(cls, field: bq.SchemaField) -> dt.DataType:
fields = {f.name: cls._dtype_from_bigquery_field(f) for f in field.fields}
dtype = dt.Struct(fields)
else:
dtype = BigQueryType.to_ibis(typ)
dtype = BigQueryType.from_string(typ)

mode = field.mode
if mode == "NULLABLE":
70 changes: 65 additions & 5 deletions ibis/backends/bigquery/registry.py
@@ -75,7 +75,7 @@ def bigquery_cast_floating_to_integer(compiled_arg, from_, to):
@bigquery_cast.register(str, dt.DataType, dt.DataType)
def bigquery_cast_generate(compiled_arg, from_, to):
"""Cast to desired type."""
sql_type = BigQueryType.from_ibis(to)
sql_type = BigQueryType.to_string(to)
return f"CAST({compiled_arg} AS {sql_type})"


@@ -337,7 +337,7 @@ def _literal(t, op):

if value is None:
if not dtype.is_null():
return f"CAST(NULL AS {BigQueryType.from_ibis(dtype)})"
return f"CAST(NULL AS {BigQueryType.to_string(dtype)})"
return "NULL"
elif dtype.is_boolean():
return str(value).upper()
@@ -350,7 +350,7 @@ def _literal(t, op):
prefix = "-" * value.is_signed()
return f"CAST('{prefix}inf' AS FLOAT64)"
else:
return f"{BigQueryType.from_ibis(dtype)} '{value}'"
return f"{BigQueryType.to_string(dtype)} '{value}'"
elif dtype.is_uuid():
return _sg_literal(str(value))
elif dtype.is_numeric():
@@ -564,7 +564,7 @@ def compiles_string_to_timestamp(translator, op):


def compiles_floor(t, op):
bigquery_type = BigQueryType.from_ibis(op.dtype)
bigquery_type = BigQueryType.to_string(op.dtype)
arg = op.arg
return f"CAST(FLOOR({t.translate(arg)}) AS {bigquery_type})"

@@ -776,6 +776,64 @@ def _group_concat(translator, op):
return f"STRING_AGG({arg}, {sep})"


def _zero(dtype):
if dtype.is_interval():
return "MAKE_INTERVAL()"
return "0"


def _sign(value, dtype):
if dtype.is_interval():
zero = _zero(dtype)
return f"""\
CASE
WHEN {value} < {zero} THEN -1
WHEN {value} = {zero} THEN 0
WHEN {value} > {zero} THEN 1
ELSE NULL
END"""
return f"SIGN({value})"


def _nullifzero(step, zero, step_dtype):
if step_dtype.is_interval():
return f"IF({step} = {zero}, NULL, {step})"
return f"NULLIF({step}, {zero})"


def _make_range(func):
def _range(translator, op):
start = translator.translate(op.start)
stop = translator.translate(op.stop)
step = translator.translate(op.step)

step_dtype = op.step.dtype
step_sign = _sign(step, step_dtype)
delta_sign = _sign(step, step_dtype)
zero = _zero(step_dtype)
nullifzero = _nullifzero(step, zero, step_dtype)

condition = f"{nullifzero} IS NOT NULL AND {step_sign} = {delta_sign}"
gen_array = f"{func}({start}, {stop}, {step})"
inner = f"SELECT x FROM UNNEST({gen_array}) x WHERE x <> {stop}"
return f"IF({condition}, ARRAY({inner}), [])"

return _range


def _timestamp_range(translator, op):
start = op.start
stop = op.stop

if start.dtype.timezone is None or stop.dtype.timezone is None:
raise com.IbisTypeError(
"Timestamps without timezone values are not supported when generating timestamp ranges"
)

rule = _make_range("GENERATE_TIMESTAMP_ARRAY")
return rule(translator, op)


OPERATION_REGISTRY = {
**operation_registry,
# Literal
@@ -934,11 +992,13 @@ def _group_concat(translator, op):
ops.EndsWith: fixed_arity("ENDS_WITH", 2),
ops.TableColumn: table_column,
ops.CountDistinctStar: _count_distinct_star,
ops.Argument: lambda _, op: op.name,
ops.Argument: lambda _, op: op.param,
ops.Unnest: unary("UNNEST"),
ops.TimeDelta: _time_delta,
ops.DateDelta: _date_delta,
ops.TimestampDelta: _timestamp_delta,
ops.IntegerRange: _make_range("GENERATE_ARRAY"),
ops.TimestampRange: _timestamp_range,
}

_invalid_operations = {
10 changes: 4 additions & 6 deletions ibis/backends/bigquery/tests/conftest.py
@@ -15,7 +15,7 @@
from ibis.backends.bigquery import EXTERNAL_DATA_SCOPES, Backend
from ibis.backends.bigquery.datatypes import BigQuerySchema
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero, UnorderedComparator
from ibis.backends.tests.base import BackendTest
from ibis.backends.tests.data import json_types, non_null_array_types, struct_types, win

DATASET_ID = "ibis_gbq_testing"
@@ -25,18 +25,16 @@
PROJECT_ID_ENV_VAR = "GOOGLE_BIGQUERY_PROJECT_ID"


class TestConf(UnorderedComparator, BackendTest, RoundAwayFromZero):
class TestConf(BackendTest):
"""Backend-specific class with information for testing."""

# These were moved from TestConf for use in common test suite.
# TODO: Indicate RoundAwayFromZero and UnorderedComparator.
# https://github.com/ibis-project/ibis-bigquery/issues/30
supports_divide_by_zero = True
supports_floating_modulus = False
returned_timestamp_unit = "us"
supports_structs = True
supports_json = True
check_names = False
force_sort = True
deps = ("google.cloud.bigquery",)

@staticmethod
@@ -206,7 +204,7 @@ def _load_data(self, **_: Any) -> None:
job_config=bq.LoadJobConfig(
write_disposition=write_disposition,
schema=BigQuerySchema.from_ibis(
ibis.schema(dict(g="string", x="int64", y="int64"))
ibis.schema(dict(g="string", x="!int64", y="int64"))
),
),
)
16 changes: 16 additions & 0 deletions ibis/backends/bigquery/tests/system/test_client.py
@@ -389,6 +389,22 @@ def test_fully_qualified_table_creation(con, project_id, dataset_id, temp_table)
assert t.get_name() == f"{project_id}.{dataset_id}.{temp_table}"


def test_fully_qualified_memtable_compile(project_id, dataset_id):
new_bq_con = ibis.bigquery.connect(project_id=project_id, dataset_id=dataset_id)
# New connection shouldn't have _session_dataset populated after connection
assert new_bq_con._session_dataset is None

t = ibis.memtable(
{"a": [1, 2, 3], "b": [4, 5, 6]},
schema=ibis.schema({"a": "int64", "b": "int64"}),
)

# call to compile should fill in _session_dataset
sql = new_bq_con.compile(t)
assert new_bq_con._session_dataset is not None
assert project_id in sql


def test_create_table_with_options(con):
name = gen_name("bigquery_temp_table")
schema = ibis.schema(dict(a="int64", b="int64", c="array<string>", d="date"))
@@ -1,8 +1,8 @@
SELECT
t0.`rowindex`,
IF(pos = pos_2, repeated_struct_col, NULL) AS repeated_struct_col
IF(pos = pos_2, `repeated_struct_col`, NULL) AS `repeated_struct_col`
FROM array_test AS t0, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(t0.`repeated_struct_col`)) - 1)) AS pos
CROSS JOIN UNNEST(t0.`repeated_struct_col`) AS repeated_struct_col WITH OFFSET AS pos_2
CROSS JOIN UNNEST(t0.`repeated_struct_col`) AS `repeated_struct_col` WITH OFFSET AS pos_2
WHERE
pos = pos_2
OR (
@@ -1,11 +1,11 @@
SELECT
IF(pos = pos_2, level_two, NULL) AS level_two
IF(pos = pos_2, `level_two`, NULL) AS `level_two`
FROM (
SELECT
t1.`rowindex`,
IF(pos = pos_2, level_one, NULL).`nested_struct_col` AS level_one
IF(pos = pos_2, `level_one`, NULL).`nested_struct_col` AS `level_one`
FROM array_test AS t1, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(t1.`repeated_struct_col`)) - 1)) AS pos
CROSS JOIN UNNEST(t1.`repeated_struct_col`) AS level_one WITH OFFSET AS pos_2
CROSS JOIN UNNEST(t1.`repeated_struct_col`) AS `level_one` WITH OFFSET AS pos_2
WHERE
pos = pos_2
OR (
@@ -17,7 +17,7 @@ FROM (
)
)
) AS t0, UNNEST(GENERATE_ARRAY(0, GREATEST(ARRAY_LENGTH(t0.`level_one`)) - 1)) AS pos
CROSS JOIN UNNEST(t0.`level_one`) AS level_two WITH OFFSET AS pos_2
CROSS JOIN UNNEST(t0.`level_one`) AS `level_two` WITH OFFSET AS pos_2
WHERE
pos = pos_2
OR (
15 changes: 13 additions & 2 deletions ibis/backends/bigquery/tests/unit/test_datatypes.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import pytest
import sqlglot as sg
from pytest import param

import ibis.expr.datatypes as dt
@@ -69,13 +70,13 @@
],
)
def test_simple(datatype, expected):
assert BigQueryType.from_ibis(datatype) == expected
assert BigQueryType.to_string(datatype) == expected


@pytest.mark.parametrize("datatype", [dt.uint64, dt.Decimal(8, 3)])
def test_simple_failure_mode(datatype):
with pytest.raises(TypeError):
BigQueryType.from_ibis(datatype)
BigQueryType.to_string(datatype)


@pytest.mark.parametrize(
@@ -101,3 +102,13 @@ def test_simple_failure_mode(datatype):
)
def test_spread_type(type_, expected):
assert list(spread_type(type_)) == expected


def test_struct_type():
dtype = dt.Array(dt.int64)
parsed_type = sg.parse_one("BIGINT[]", into=sg.exp.DataType, read="duckdb")

expected = "ARRAY<INT64>"

assert parsed_type.sql(dialect="bigquery") == expected
assert BigQueryType.to_string(dtype) == expected
@@ -0,0 +1,2 @@
SELECT
`bqutil`.`fn`.from_hex('face') AS `from_hex_'face'`
@@ -0,0 +1,2 @@
SELECT
farm_fingerprint(b'Hello, World!') AS `farm_fingerprint_b'Hello_ World_'`
31 changes: 31 additions & 0 deletions ibis/backends/bigquery/tests/unit/udf/test_builtin.py
@@ -0,0 +1,31 @@
from __future__ import annotations

import ibis

to_sql = ibis.bigquery.compile


@ibis.udf.scalar.builtin
def farm_fingerprint(value: bytes) -> int:
...


@ibis.udf.scalar.builtin(schema="bqutil.fn")
def from_hex(value: str) -> int:
"""Community function to convert from hex string to integer.
See:
https://github.com/GoogleCloudPlatform/bigquery-utils/tree/master/udfs/community#from_hexvalue-string
"""


def test_bqutil_fn_from_hex(snapshot):
# Project ID should be enclosed in backticks.
expr = from_hex("face")
snapshot.assert_match(to_sql(expr), "out.sql")


def test_farm_fingerprint(snapshot):
# No backticks needed if there is no schema defined.
expr = farm_fingerprint(b"Hello, World!")
snapshot.assert_match(to_sql(expr), "out.sql")
17 changes: 6 additions & 11 deletions ibis/backends/bigquery/udf/__init__.py
@@ -73,7 +73,6 @@ def python(
>>> @udf.python(input_type=[dt.double], output_type=dt.double)
... def add_one(x):
... return x + 1
...
>>> print(add_one.sql)
CREATE TEMPORARY FUNCTION add_one_0(x FLOAT64)
RETURNS FLOAT64
@@ -84,9 +83,7 @@ def python(
}
return add_one(x);
""";
>>> @udf.python(
... input_type=[dt.double, dt.double], output_type=dt.Array(dt.double)
... )
>>> @udf.python(input_type=[dt.double, dt.double], output_type=dt.Array(dt.double))
... def my_range(start, stop):
... def gen(start, stop):
... curr = start
@@ -121,9 +118,7 @@ def python(
""";
>>> @udf.python(
... input_type=[dt.double, dt.double],
... output_type=dt.Struct.from_tuples(
... [("width", "double"), ("height", "double")]
... ),
... output_type=dt.Struct.from_tuples([("width", "double"), ("height", "double")]),
... )
... def my_rectangle(width, height):
... class Rectangle:
@@ -261,10 +256,10 @@ def js(
libraries = []

bigquery_signature = ", ".join(
f"{name} {BigQueryType.from_ibis(dt.dtype(type_))}"
f"{name} {BigQueryType.to_string(dt.dtype(type_))}"
for name, type_ in params.items()
)
return_type = BigQueryType.from_ibis(dt.dtype(output_type))
return_type = BigQueryType.to_string(dt.dtype(output_type))
libraries_opts = (
f"\nOPTIONS (\n library={list(libraries)!r}\n)" if libraries else ""
)
@@ -361,14 +356,14 @@ def sql(
name: rlz.ValueOf(None if type_ == "ANY TYPE" else type_)
for name, type_ in params.items()
}
return_type = BigQueryType.from_ibis(dt.dtype(output_type))
return_type = BigQueryType.to_string(dt.dtype(output_type))

bigquery_signature = ", ".join(
"{name} {type}".format(
name=name,
type="ANY TYPE"
if type_ == "ANY TYPE"
else BigQueryType.from_ibis(dt.dtype(type_)),
else BigQueryType.to_string(dt.dtype(type_)),
)
for name, type_ in params.items()
)
28 changes: 17 additions & 11 deletions ibis/backends/clickhouse/__init__.py
@@ -3,7 +3,6 @@
import ast
import atexit
import glob
import json
from contextlib import closing, suppress
from functools import partial
from typing import TYPE_CHECKING, Any, Literal
@@ -123,6 +122,7 @@ def do_connect(
client_name: str = "ibis",
secure: bool | None = None,
compression: str | bool = True,
settings: Mapping[str, Any] | None = None,
**kwargs: Any,
):
"""Create a ClickHouse client for use with Ibis.
@@ -148,6 +148,8 @@
The kind of compression to use for requests. See
https://clickhouse.com/docs/en/integrations/python#compression for
more information.
settings
ClickHouse session settings
kwargs
Client specific keyword arguments
@@ -158,6 +160,10 @@
>>> client
<ibis.clickhouse.client.ClickhouseClient object at 0x...>
"""
if settings is None:
settings = {}
settings.setdefault("session_timezone", "UTC")

self.con = cc.get_client(
host=host,
# 8123 is the default http port 443 is https
@@ -516,16 +522,16 @@ def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema
return sch.Schema(dict(zip(names, map(ClickhouseType.from_string, types))))

def _get_schema_using_query(self, query: str) -> sch.Schema:
query = f"EXPLAIN json = 1, description = 0, header = 1 {query}"
with closing(self.raw_sql(query)) as results:
[[raw_plans]] = results.result_columns
[plan] = json.loads(raw_plans)
return sch.Schema(
{
field["Name"]: ClickhouseType.from_string(field["Type"])
for field in plan["Plan"]["Header"]
}
)
name = util.gen_name("get_schema_using_query")
with closing(self.raw_sql(f"CREATE VIEW {name} AS {query}")):
pass
try:
with closing(self.raw_sql(f"DESCRIBE {name}")) as results:
names, types, *_ = results.result_columns
finally:
with closing(self.raw_sql(f"DROP VIEW {name}")):
pass
return sch.Schema(dict(zip(names, map(ClickhouseType.from_string, types))))
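In other words, instead of parsing `EXPLAIN json = 1` output, the backend now creates a throwaway view, `DESCRIBE`s it, and drops it. A sketch of the statement sequence, assuming a backend instance `con` and a query string `query` (the view name is generated in practice):

```python
con.raw_sql(f"CREATE VIEW ibis_tmp_view AS {query}")
try:
    with closing(con.raw_sql("DESCRIBE ibis_tmp_view")) as results:
        names, types, *_ = results.result_columns
finally:
    con.raw_sql("DROP VIEW ibis_tmp_view")
```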

@classmethod
def has_operation(cls, operation: type[ops.Value]) -> bool:
53 changes: 44 additions & 9 deletions ibis/backends/clickhouse/compiler/values.py
@@ -346,26 +346,32 @@ def _literal(op, *, value, dtype, **kw):

return interval(value, unit=dtype.resolution.upper())
elif dtype.is_timestamp():
funcname = "toDateTime"
fmt = "%Y-%m-%dT%H:%M:%S"

funcname = "makeDateTime"
if micros := value.microsecond:
funcname += "64"
fmt += ".%f"

args = [value.strftime(fmt)]
args = [
value.year,
value.month,
value.day,
value.hour,
value.minute,
value.second,
]

if micros % 1000:
args.append(micros)
args.append(6)
elif micros // 1000:
elif millis := micros // 1000:
args.append(millis)
args.append(3)

if (timezone := dtype.timezone) is not None:
args.append(timezone)

return F[funcname](*args)
elif dtype.is_date():
return F.toDate(value.strftime("%Y-%m-%d"))
return F.toDate(value.isoformat())
elif dtype.is_array():
value_type = dtype.value_type
values = [
@@ -816,6 +822,8 @@ def formatter(op, *, left, right, **_):
ops.ExtractPath: "path",
ops.ExtractFragment: "fragment",
ops.ArrayPosition: "indexOf",
ops.ArrayFlatten: "arrayFlatten",
ops.IntegerRange: "range",
}


@@ -956,8 +964,8 @@ def _array_string_join(op, *, arg, sep, **_):


@translate_val.register(ops.Argument)
def _argument(op, *, name, **_):
return sg.to_identifier(name)
def _argument(op, **_):
return sg.to_identifier(op.param)


@translate_val.register(ops.ArrayMap)
@@ -1015,3 +1023,30 @@ def _agg_udf(op, *, where, **kw) -> str:
@translate_val.register(ops.TimestampDelta)
def _delta(op, *, part, left, right, **_):
return sg.exp.DateDiff(this=left, expression=right, unit=part)


@translate_val.register(ops.TimestampRange)
def _timestamp_range(op, *, start, stop, step, **_):
unit = op.step.dtype.unit.name.lower()

if not isinstance(op.step, ops.Literal):
raise com.UnsupportedOperationError(
"ClickHouse doesn't support non-literal step values"
)

step_value = op.step.value

offset = sg.to_identifier("offset")

# e.g., offset -> dateAdd(DAY, offset, start)
func = sg.exp.Lambda(
this=F.dateAdd(sg.to_identifier(unit), offset, start), expressions=[offset]
)

if step_value == 0:
return F.array()

result = F.arrayMap(
func, F.range(0, F.timestampDiff(unit, start, stop), step_value)
)
return result
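Roughly, for a literal one-day step the rule produces an expression of this shape (illustrative only, not exact compiler output):

```python
# offset -> dateAdd(<unit>, offset, start), mapped over a range whose length
# is timestampDiff(<unit>, start, stop); a zero step short-circuits to array()
generated_shape = (
    "arrayMap((offset) -> dateAdd(day, offset, start), "
    "range(0, timestampDiff('day', start, stop), 1))"
)
```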
10 changes: 4 additions & 6 deletions ibis/backends/clickhouse/tests/conftest.py
@@ -9,11 +9,7 @@
import ibis
import ibis.expr.types as ir
from ibis import util
from ibis.backends.tests.base import (
RoundHalfToEven,
ServiceBackendTest,
UnorderedComparator,
)
from ibis.backends.tests.base import ServiceBackendTest

if TYPE_CHECKING:
from collections.abc import Iterable
@@ -26,13 +22,15 @@
IBIS_TEST_CLICKHOUSE_DB = os.environ.get("IBIS_TEST_DATA_DB", "ibis_testing")


class TestConf(UnorderedComparator, ServiceBackendTest, RoundHalfToEven):
class TestConf(ServiceBackendTest):
check_dtype = False
supports_window_operations = False
returned_timestamp_unit = "s"
supported_to_timestamp_units = {"s"}
supports_floating_modulus = False
supports_json = False
force_sort = True
rounding_method = "half_to_even"
data_volume = "/var/lib/clickhouse/user_files/ibis"
service_name = "clickhouse"
deps = ("clickhouse_connect",)
@@ -1,2 +1,2 @@
SELECT
now()
now() AS "TimestampNow()"
@@ -1,2 +1,2 @@
SELECT
toDate(toDateTime('2009-05-17T12:34:56'))
toDate(makeDateTime(2009, 5, 17, 12, 34, 56)) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))"
@@ -1,2 +1,2 @@
SELECT
toStartOfHour(toDateTime('2009-05-17T12:34:56'))
toStartOfHour(makeDateTime(2009, 5, 17, 12, 34, 56)) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))"
@@ -1,2 +1,2 @@
SELECT
toStartOfMinute(toDateTime('2009-05-17T12:34:56'))
toStartOfMinute(makeDateTime(2009, 5, 17, 12, 34, 56)) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))"
@@ -1,2 +1,2 @@
SELECT
toStartOfMinute(toDateTime('2009-05-17T12:34:56'))
toStartOfMinute(makeDateTime(2009, 5, 17, 12, 34, 56)) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))"
@@ -1,2 +1,2 @@
SELECT
toMonday(toDateTime('2009-05-17T12:34:56'))
toMonday(makeDateTime(2009, 5, 17, 12, 34, 56)) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))"
@@ -1,2 +1,2 @@
SELECT
toStartOfYear(toDateTime('2009-05-17T12:34:56'))
toStartOfYear(makeDateTime(2009, 5, 17, 12, 34, 56)) AS "TimestampTruncate(datetime.datetime(2009, 5, 17, 12, 34, 56))"
@@ -1,2 +1,2 @@
SELECT
toDateTime64('2015-01-01T12:34:56.789321', 6)
makeDateTime64(2015, 1, 1, 12, 34, 56, 789321, 6) AS "datetime.datetime(2015, 1, 1, 12, 34, 56, 789321)"
@@ -1,2 +1,2 @@
SELECT
toDateTime64('2015-01-01T12:34:56.789321', 6, 'UTC')
makeDateTime64(2015, 1, 1, 12, 34, 56, 789321, 6, 'UTC') AS "datetime.datetime(2015, 1, 1, 12, 34, 56, 789321, tzinfo=tzutc())"
@@ -1,2 +1,2 @@
SELECT
toDateTime64('2015-01-01T12:34:56.789000', 3)
makeDateTime64(2015, 1, 1, 12, 34, 56, 789, 3) AS "datetime.datetime(2015, 1, 1, 12, 34, 56, 789000)"
@@ -1,2 +1,2 @@
SELECT
toDateTime64('2015-01-01T12:34:56.789000', 3, 'UTC')
makeDateTime64(2015, 1, 1, 12, 34, 56, 789, 3, 'UTC') AS "datetime.datetime(2015, 1, 1, 12, 34, 56, 789000, tzinfo=tzutc())"
@@ -1,2 +1,2 @@
SELECT
FALSE
FALSE AS False
@@ -1,2 +1,2 @@
SELECT
1.5
1.5 AS "1.5"
@@ -1,2 +1,2 @@
SELECT
5
5 AS "5"
@@ -1,2 +1,2 @@
SELECT
'I can''t'
'I can''t' AS """I can't"""
@@ -1,2 +1,2 @@
SELECT
'An "escape"'
'An "escape"' AS "'An ""escape""'"
@@ -1,2 +1,2 @@
SELECT
'simple'
'simple' AS "'simple'"
@@ -1,2 +1,2 @@
SELECT
TRUE
TRUE AS True
@@ -1,2 +1,2 @@
SELECT
toDateTime('2015-01-01T12:34:56')
makeDateTime(2015, 1, 1, 12, 34, 56) AS "datetime.datetime(2015, 1, 1, 12, 34, 56)"
@@ -1,2 +1,2 @@
SELECT
toDateTime('2015-01-01T12:34:56')
makeDateTime(2015, 1, 1, 12, 34, 56) AS "datetime.datetime(2015, 1, 1, 12, 34, 56)"
@@ -1,2 +1,2 @@
SELECT
toDateTime('2015-01-01T12:34:56')
makeDateTime(2015, 1, 1, 12, 34, 56) AS "datetime.datetime(2015, 1, 1, 12, 34, 56)"
@@ -1,2 +1,2 @@
SELECT
1 + 2
1 + 2 AS "Add(1, 2)"
@@ -1,2 +1,2 @@
SELECT
now()
now() AS "TimestampNow()"
16 changes: 16 additions & 0 deletions ibis/backends/clickhouse/tests/test_datatypes.py
@@ -5,6 +5,7 @@
import pytest
from pytest import param

import ibis
import ibis.expr.datatypes as dt
import ibis.tests.strategies as its
from ibis.backends.clickhouse.datatypes import ClickhouseType
@@ -35,6 +36,21 @@ def test_columns_types_with_additional_argument(con):
assert df.datetime_ns_col.dtype.name == "datetime64[ns, UTC]"


def test_array_discovery_clickhouse(con):
t = con.tables.array_types
expected = ibis.schema(
dict(
x=dt.Array(dt.int64, nullable=False),
y=dt.Array(dt.string, nullable=False),
z=dt.Array(dt.float64, nullable=False),
grouper=dt.string,
scalar_column=dt.float64,
multi_dim=dt.Array(dt.Array(dt.int64, nullable=False), nullable=False),
)
)
assert t.schema() == expected


@pytest.mark.parametrize(
("ch_type", "ibis_type"),
[
4 changes: 2 additions & 2 deletions ibis/backends/clickhouse/tests/test_functions.py
@@ -169,7 +169,7 @@ def test_typeof(con, value, expected):


@pytest.mark.parametrize(("value", "expected"), [("foo_bar", 7), ("", 0)])
def test_string_length(con, value, expected):
def test_tuple_string_length(con, value, expected):
assert con.execute(L(value).length()) == expected


@@ -208,7 +208,7 @@ def test_string_lower(con):
assert con.execute(L("FOO").lower()) == "foo"


def test_string_lenght(con):
def test_string_length(con):
assert con.execute(L("FOO").length()) == 3


1 change: 1 addition & 0 deletions ibis/backends/conftest.py
@@ -539,6 +539,7 @@ def ddl_con(ddl_backend):
params=_get_backends_to_test(
keep=(
"duckdb",
"exasol",
"mssql",
"mysql",
"oracle",
2 changes: 1 addition & 1 deletion ibis/backends/dask/aggcontext.py
@@ -140,7 +140,7 @@ def agg(
# (1) windowed.count() will exclude NaN observations
# , which results in incorrect window sizes.
# (2) windowed.apply(len, raw=True) will include NaN
# obversations, but doesn't work on non-numeric types.
# observations, but doesn't work on non-numeric types.
# https://github.com/pandas-dev/pandas/issues/23002
# To deal with this, we create a _placeholder column
windowed_frame = self.construct_window(grouped_frame)
4 changes: 2 additions & 2 deletions ibis/backends/dask/execution/temporal.py
@@ -32,7 +32,7 @@
execute_date_sub_diff_series_date,
execute_day_of_week_index_series,
execute_day_of_week_name_series,
execute_epoch_seconds,
execute_epoch_seconds_series,
execute_extract_microsecond_series,
execute_extract_millisecond_series,
execute_extract_timestamp_field_series,
@@ -61,7 +61,7 @@
ops.ExtractTemporalField: [((dd.Series,), execute_extract_timestamp_field_series)],
ops.ExtractMicrosecond: [((dd.Series,), execute_extract_microsecond_series)],
ops.ExtractMillisecond: [((dd.Series,), execute_extract_millisecond_series)],
ops.ExtractEpochSeconds: [((dd.Series,), execute_epoch_seconds)],
ops.ExtractEpochSeconds: [((dd.Series,), execute_epoch_seconds_series)],
ops.IntervalFromInteger: [((dd.Series,), execute_interval_from_integer_series)],
ops.IntervalAdd: [
(
4 changes: 2 additions & 2 deletions ibis/backends/dask/tests/test_client.py
@@ -119,10 +119,10 @@ def test_invalid_connection_parameter_types(npartitions):
}
)

expeced_msg = re.escape(
expected_msg = re.escape(
"Expected an instance of 'dask.dataframe.DataFrame' for 'df', "
"got an instance of 'str' instead."
)
con = ibis.dask.connect()
with pytest.raises(TypeError, match=expeced_msg):
with pytest.raises(TypeError, match=expected_msg):
con.from_dataframe("file.csv")
49 changes: 36 additions & 13 deletions ibis/backends/datafusion/__init__.py
@@ -468,27 +468,50 @@ def to_pyarrow_batches(
self,
expr: ir.Expr,
*,
params: Mapping[ir.Scalar, Any] | None = None,
limit: int | str | None = None,
chunk_size: int = 1_000_000,
**kwargs: Any,
) -> pa.ipc.RecordBatchReader:
pa = self._import_pyarrow()

self._register_udfs(expr)
self._register_in_memory_tables(expr)

frame = self.con.sql(self.compile(expr.as_table(), params, **kwargs))
return pa.ipc.RecordBatchReader.from_batches(frame.schema(), frame.collect())
table_expr = expr.as_table()
raw_sql = self.compile(table_expr, **kwargs)

frame = self.con.sql(raw_sql)

schema = table_expr.schema()
names = schema.names

struct_schema = schema.as_struct().to_pyarrow()

return pa.ipc.RecordBatchReader.from_batches(
schema.to_pyarrow(),
(
# convert the renamed + casted columns into a record batch
pa.RecordBatch.from_struct_array(
# rename columns to match schema because datafusion lowercases things
pa.RecordBatch.from_arrays(batch.columns, names=names)
# cast the struct array to the desired types to work around
# https://github.com/apache/arrow-datafusion-python/issues/534
.to_struct_array()
.cast(struct_schema)
)
for batch in frame.collect()
),
)

def execute(
self,
expr: ir.Expr,
params: Mapping[ir.Expr, object] | None = None,
limit: int | str | None = "default",
**kwargs: Any,
):
output = self.to_pyarrow(expr.as_table(), params=params, limit=limit, **kwargs)
return expr.__pandas_result__(output.to_pandas(timestamp_as_object=True))
def to_pyarrow(self, expr: ir.Expr, **kwargs: Any) -> pa.Table:
batch_reader = self.to_pyarrow_batches(expr, **kwargs)
arrow_table = batch_reader.read_all()
return expr.__pyarrow_result__(arrow_table)

def execute(self, expr: ir.Expr, **kwargs: Any):
batch_reader = self.to_pyarrow_batches(expr, **kwargs)
return expr.__pandas_result__(
batch_reader.read_pandas(timestamp_as_object=True)
)
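The rename-then-cast dance works around datafusion lower-casing column names and the struct-cast issue linked in the comment; a self-contained pyarrow sketch of the same transformation (column names and types are made up):

```python
import pyarrow as pa

batch = pa.RecordBatch.from_arrays([pa.array([1, 2], type=pa.int32())], names=["mycol"])
desired_names = ["MyCol"]                            # names from the ibis schema
struct_schema = pa.struct([("MyCol", pa.int64())])   # desired field types

fixed = pa.RecordBatch.from_struct_array(
    pa.RecordBatch.from_arrays(batch.columns, names=desired_names)
    .to_struct_array()
    .cast(struct_schema)
)
assert fixed.schema.names == ["MyCol"]
```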

def _to_sqlglot(
self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any
32 changes: 28 additions & 4 deletions ibis/backends/datafusion/compiler/relations.py
@@ -139,13 +139,37 @@ def _limit(op: ops.Limit, *, table, n, offset, **_):
def _aggregation(
op: ops.Aggregation, *, table, metrics, by, having, predicates, sort_keys, **_
):
selections = (by + metrics) or (STAR,)
if by:
# datafusion doesn't support count distinct aggregations alongside
# computed grouping keys so create a projection of the key and all
# existing columns first, followed by the usual group by
#
# analogous to a user calling mutate -> group_by
by_names = frozenset(b.alias_or_name for b in by)
cols = [
sg.column(
name,
table=sg.to_identifier(table.alias_or_name, quoted=True),
quoted=True,
)
for name in op.table.schema.keys() - by_names
]
table = sg.select(*cols, *by).from_(table).subquery()

# datafusion lower cases all column names internally unless quoted so
# quoted=True is required here for correctness
by_names_quoted = tuple(
sg.column(b.alias_or_name, table=getattr(b, "table", None), quoted=True)
for b in by
)
selections = by_names_quoted + metrics
else:
selections = metrics or (STAR,)

sel = sg.select(*selections).from_(table)

if by:
sel = sel.group_by(
*(key.this if isinstance(key, sg.exp.Alias) else key for key in by)
)
sel = sel.group_by(*by_names_quoted)

if predicates:
sel = sel.where(*predicates)
12 changes: 11 additions & 1 deletion ibis/backends/datafusion/compiler/values.py
@@ -11,6 +11,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.base.sqlglot import (
FALSE,
NULL,
AggGen,
F,
Expand Down Expand Up @@ -441,7 +442,16 @@ def string_find(op, *, arg, substr, start, end, **_):

@translate_val.register(ops.RegexSearch)
def regex_search(op, *, arg, pattern, **_):
return F.array_length(F.regexp_match(arg, pattern)) > 0
return if_(
sg.or_(arg.is_(NULL), pattern.is_(NULL)),
NULL,
F.coalesce(
# null is returned for non-matching patterns, so coalesce to false
# because that is the desired behavior for ops.RegexSearch
F.array_length(F.regexp_match(arg, pattern)) > 0,
FALSE,
),
)


@translate_val.register(ops.StringContains)
7 changes: 4 additions & 3 deletions ibis/backends/datafusion/tests/conftest.py
@@ -6,11 +6,11 @@

import ibis
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero
from ibis.backends.tests.data import array_types
from ibis.backends.tests.base import BackendTest
from ibis.backends.tests.data import array_types, win


class TestConf(BackendTest, RoundAwayFromZero):
class TestConf(BackendTest):
# check_names = False
# supports_divide_by_zero = True
# returned_timestamp_unit = 'ns'
@@ -26,6 +26,7 @@ def _load_data(self, **_: Any) -> None:
path = self.data_dir / "parquet" / f"{table_name}.parquet"
con.register(path, table_name=table_name)
con.register(array_types, table_name="array_types")
con.register(win, table_name="win")

@staticmethod
def connect(*, tmpdir, worker_id, **kw):
5 changes: 3 additions & 2 deletions ibis/backends/druid/tests/conftest.py
@@ -12,7 +12,7 @@
from requests import Session

import ibis
from ibis.backends.tests.base import RoundHalfToEven, ServiceBackendTest
from ibis.backends.tests.base import ServiceBackendTest

if TYPE_CHECKING:
from collections.abc import Iterable
@@ -89,7 +89,7 @@ def run_query(session: Session, query: str) -> None:
time.sleep(REQUEST_INTERVAL)


class TestConf(ServiceBackendTest, RoundHalfToEven):
class TestConf(ServiceBackendTest):
# druid has the same rounding behavior as postgres
check_dtype = False
supports_window_operations = False
@@ -99,6 +99,7 @@ class TestConf(ServiceBackendTest, RoundHalfToEven):
native_bool = True
supports_structs = False
supports_json = False # it does, but we haven't implemented it
rounding_method = "half_to_even"
service_name = "druid-middlemanager"
deps = ("pydruid.db.sqlalchemy",)

129 changes: 94 additions & 35 deletions ibis/backends/duckdb/__init__.py
@@ -24,10 +24,10 @@
from ibis import util
from ibis.backends.base import CanCreateSchema
from ibis.backends.base.sql.alchemy import AlchemyCrossSchemaBackend
from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported
from ibis.backends.base.sqlglot import C, F
from ibis.backends.duckdb.compiler import DuckDBSQLCompiler
from ibis.backends.duckdb.datatypes import DuckDBType
from ibis.expr.operations.relations import PandasDataFrameProxy
from ibis.expr.operations.udf import InputType
from ibis.formats.pandas import PandasData

@@ -494,6 +494,54 @@ def read_csv(
con.exec_driver_sql(view)
return self.table(table_name)

def read_geo(
self,
source: str,
table_name: str | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a GEO file as a table in the current database.

Parameters
----------
source
The data source. A path to a file in a geospatial format supported
by DuckDB.
See https://duckdb.org/docs/extensions/spatial.html#st_read---read-spatial-data-from-files
table_name
An optional name to use for the created table. This defaults to
a sequentially generated name.
**kwargs
Additional keyword arguments passed to DuckDB loading function.
See https://duckdb.org/docs/extensions/spatial.html#st_read---read-spatial-data-from-files
for more information.

Returns
-------
ir.Table
The just-registered table
"""

if not table_name:
table_name = util.gen_name("read_geo")

# load geospatial extension
self.load_extension("spatial")

source = util.normalize_filename(source)

if source.startswith(("http://", "https://", "s3://")):
self._load_extensions(["httpfs"])

source_expr = sa.select(sa.literal_column("*")).select_from(
sa.func.st_read(source, _format_kwargs(kwargs))
)

view = self._compile_temp_view(table_name, source_expr)
with self.begin() as con:
con.exec_driver_sql(view)
return self.table(table_name)
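A hypothetical usage sketch (the file path is made up):

```python
import ibis

con = ibis.connect("duckdb://")
zones = con.read_geo("zones.geojson")  # any format st_read understands
```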

def read_parquet(
self,
source_list: str | Iterable[str],
@@ -682,10 +730,7 @@ def list_tables(
>>> con.list_tables(schema="my_schema")
[]
>>> with con.begin() as c:
... c.exec_driver_sql(
... "CREATE TABLE my_schema.baz (a INTEGER)"
... ) # doctest: +ELLIPSIS
...
... c.exec_driver_sql("CREATE TABLE my_schema.baz (a INTEGER)") # doctest: +ELLIPSIS
<...>
>>> con.list_tables(schema="my_schema")
['baz']
@@ -772,7 +817,6 @@ def read_sqlite(self, path: str | Path, table_name: str | None = None) -> ir.Tab
... con.execute(
... "INSERT INTO t VALUES (1, 'a'), (2, 'b'), (3, 'c')"
... ) # doctest: +ELLIPSIS
...
<...>
>>> con = ibis.connect("duckdb://")
>>> t = con.read_sqlite("/tmp/sqlite.db", table_name="t")
@@ -864,7 +908,6 @@ def attach_sqlite(
... con.execute(
... "INSERT INTO t VALUES (1, 'a'), (2, 'b'), (3, 'c')"
... ) # doctest: +ELLIPSIS
...
<...>
>>> con = ibis.connect("duckdb://")
>>> con.list_tables()
@@ -1083,9 +1126,7 @@ def to_parquet(

Partition on multiple columns.

>>> con.to_parquet(
... penguins, tempfile.mkdtemp(), partition_by=("year", "island")
... )
>>> con.to_parquet(penguins, tempfile.mkdtemp(), partition_by=("year", "island"))
"""
self._run_pre_execute_hooks(expr)
query = self._to_sql(expr, params=params)
@@ -1157,7 +1198,30 @@ def fetch_from_cursor(
for name, col in zip(table.column_names, table.columns)
}
)
return PandasData.convert_table(df, schema)
df = PandasData.convert_table(df, schema)
if not df.empty and geospatial_supported:
return self._to_geodataframe(df, schema)
return df

# TODO(gforsyth): this may not need to be specialized in the future
@staticmethod
def _to_geodataframe(df, schema):
"""Convert `df` to a `GeoDataFrame`.

Assumes that the libraries required for geospatial support are
installed and that a geospatial column is present in the dataframe.
"""
import geopandas as gpd

geom_col = None
for name, dtype in schema.items():
if dtype.is_geospatial():
if not geom_col:
geom_col = name
df[name] = gpd.GeoSeries.from_wkb(df[name])
if geom_col:
df = gpd.GeoDataFrame(df, geometry=geom_col)
return df
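
Outside of the backend, the same WKB-to-GeoSeries conversion can be sketched as follows (assuming geopandas and shapely are installed; the data here is made up):

```python
import geopandas as gpd
import pandas as pd
import shapely

# a column of WKB bytes, as DuckDB returns geometry columns
df = pd.DataFrame({"name": ["origin"], "geom": [shapely.Point(0, 0).wkb]})

# decode WKB into shapely geometries and promote the frame to a GeoDataFrame
df["geom"] = gpd.GeoSeries.from_wkb(df["geom"])
gdf = gpd.GeoDataFrame(df, geometry="geom")
print(gdf.geom_type)
```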

def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]:
with self.begin() as con:
@@ -1171,10 +1235,6 @@ def _metadata(self, query: str) -> Iterator[tuple[str, dt.DataType]]:
yield name, ibis_type

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
# in theory we could use pandas dataframes, but when using dataframes
# with pyarrow datatypes later reads of this data segfault
import pandas as pd

schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
raise exc.IbisTypeError(
@@ -1184,32 +1244,15 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:

# only register if we haven't already done so
if (name := op.name) not in self.list_tables():
if isinstance(data := op.data, PandasDataFrameProxy):
table = data.to_frame()

# convert to object string dtypes because duckdb is either
# 1. extremely slow to register DataFrames with not-pyarrow
# string dtypes
# 2. broken for string[pyarrow] dtypes (segfault)
if conversions := {
colname: "str"
for colname, col in table.items()
if isinstance(col.dtype, pd.StringDtype)
}:
table = table.astype(conversions)
else:
table = data.to_pyarrow(schema)
table = op.data.to_pyarrow(schema)

# register creates a transaction, and we can't nest transactions so
# we create a function to encapsulate the whole shebang
def _register(name, table):
with self.begin() as con:
con.connection.register(name, table)

try:
_register(name, table)
except duckdb.NotImplementedException:
_register(name, data.to_pyarrow(schema))
_register(name, table)

def _get_temp_view_definition(
self, name: str, definition: sa.sql.compiler.Compiled
@@ -1234,7 +1277,10 @@ def _register_udfs(self, expr: ir.Expr) -> None:
def _compile_udf(self, udf_node: ops.ScalarUDF) -> None:
func = udf_node.__func__
name = func.__name__
input_types = [DuckDBType.to_string(arg.dtype) for arg in udf_node.args]
input_types = [
DuckDBType.to_string(param.annotation.pattern.dtype)
for param in udf_node.__signature__.parameters.values()
]
output_type = DuckDBType.to_string(udf_node.dtype)

def register_udf(con):
@@ -1277,3 +1323,16 @@ def _insert_dataframe(
if overwrite:
con.execute(t.delete())
con.execute(t.insert().from_select(columns, sa.select(source)))

def table(
self,
name: str,
database: str | None = None,
schema: str | None = None,
) -> ir.Table:
expr = super().table(name=name, database=database, schema=schema)
# load the spatial extension only if the schema has geometry columns
if any(typ.is_geospatial() for typ in expr.op().schema.types):
self.load_extension("spatial")

return expr
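
The new `table` override means the spatial extension is only loaded when it is actually needed, e.g. (database and table names here are illustrative):

```python
import ibis

con = ibis.duckdb.connect("geo.ddb")  # hypothetical database file

# if the table's schema contains a geometry column, `table` calls
# load_extension("spatial") before returning the expression
zones = con.table("zones")
```
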
14 changes: 14 additions & 0 deletions ibis/backends/duckdb/datatypes.py
@@ -7,6 +7,18 @@
from ibis.backends.base.sql.alchemy.datatypes import AlchemyType
from ibis.backends.base.sqlglot.datatypes import DuckDBType as SqlglotDuckdbType

try:
from geoalchemy2 import Geometry

class Geometry_WKB(Geometry):
as_binary = "ST_AsWKB"

except ImportError:

class Geometry_WKB:
...


_from_duckdb_types = {
psql.BYTEA: dt.Binary,
psql.UUID: dt.UUID,
@@ -35,6 +47,8 @@
dt.UInt16: ducktypes.USmallInteger,
dt.UInt32: ducktypes.UInteger,
dt.UInt64: ducktypes.UBigInteger,
# Handle projections with geometry columns
dt.Geometry: Geometry_WKB,
}


182 changes: 124 additions & 58 deletions ibis/backends/duckdb/registry.py
@@ -20,6 +20,7 @@
reduction,
try_cast,
)
from ibis.backends.duckdb.datatypes import Geometry_WKB
from ibis.backends.postgres.registry import (
_array_index,
_array_slice,
@@ -55,6 +56,70 @@ def _round(t, op):
}


def _centroid(t, op):
arg = t.translate(op.arg)
return sa.func.st_centroid(arg, type_=Geometry_WKB)


def _geo_end_point(t, op):
arg = t.translate(op.arg)
return sa.func.st_endpoint(arg, type_=Geometry_WKB)


def _geo_start_point(t, op):
arg = t.translate(op.arg)
return sa.func.st_startpoint(arg, type_=Geometry_WKB)


def _envelope(t, op):
arg = t.translate(op.arg)
return sa.func.st_envelope(arg, type_=Geometry_WKB)


def _geo_buffer(t, op):
arg = t.translate(op.arg)
radius = t.translate(op.radius)
return sa.func.st_buffer(arg, radius, type_=Geometry_WKB)


def _geo_unary_union(t, op):
arg = t.translate(op.arg)
return sa.func.st_union_agg(arg, type_=Geometry_WKB)


def _geo_point(t, op):
left = t.translate(op.left)
right = t.translate(op.right)
return sa.func.st_point(left, right, type_=Geometry_WKB)


def _geo_difference(t, op):
left = t.translate(op.left)
right = t.translate(op.right)
return sa.func.st_difference(left, right, type_=Geometry_WKB)


def _geo_intersection(t, op):
left = t.translate(op.left)
right = t.translate(op.right)
return sa.func.st_intersection(left, right, type_=Geometry_WKB)


def _geo_union(t, op):
left = t.translate(op.left)
right = t.translate(op.right)
return sa.func.st_union(left, right, type_=Geometry_WKB)


def _geo_convert(t, op):
arg = t.translate(op.arg)
source = op.source
target = op.target

# sa.true() sets always_xy=True
return sa.func.st_transform(arg, source, target, sa.true(), type_=Geometry_WKB)


def _generic_log(arg, base, *, type_):
return sa.func.ln(arg, type_=type_) / sa.func.ln(base, type_=type_)

@@ -120,12 +185,14 @@ def _literal(t, op):
value = op.value

if value is None:
return sa.null()
return (
sa.null() if dtype.is_null() else sa.cast(sa.null(), t.get_sqla_type(dtype))
)

sqla_type = t.get_sqla_type(dtype)

if dtype.is_interval():
return sa.literal_column(f"INTERVAL '{value} {dtype.resolution}'")
return getattr(sa.func, f"to_{dtype.unit.plural}")(value)
elif dtype.is_array():
values = value.tolist() if isinstance(value, np.ndarray) else value
return sa.cast(sa.func.list_value(*values), sqla_type)
@@ -152,8 +219,14 @@ def _literal(t, op):
return sa.func.map(
sa.func.list_value(*value.keys()), sa.func.list_value(*value.values())
)
elif dtype.is_timestamp():
return sa.cast(sa.literal(value.isoformat()), t.get_sqla_type(dtype))
elif dtype.is_date():
return sa.cast(sa.literal(str(value)), sqla_type)
return sa.func.make_date(value.year, value.month, value.day)
elif dtype.is_time():
return sa.func.make_time(
value.hour, value.minute, value.second + value.microsecond / 1e6
)
else:
return sa.cast(sa.literal(value), sqla_type)
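
The effect of the new date/time literal rules can be checked end to end with a small sketch (the exact Python type returned by `execute` is assumed):

```python
import datetime

import ibis

con = ibis.duckdb.connect()

# now compiles to make_date(2023, 4, 5) rather than a string cast
expr = ibis.literal(datetime.date(2023, 4, 5))
print(con.execute(expr))
```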

@@ -260,48 +333,8 @@ def _array_intersect(t, op):
)
return t.translate(
ops.ArrayFilter(
op.left, param=name, body=ops.ArrayContains(op.right, parameter)
)
)


def _map_keys(t, op):
m = t.translate(op.arg)
return sa.cast(sa.func.json_keys(sa.func.to_json(m)), t.get_sqla_type(op.dtype))


def _is_map_literal(op):
return isinstance(op, ops.Literal) or (
isinstance(op, ops.Map)
and isinstance(op.keys, ops.Literal)
and isinstance(op.values, ops.Literal)
)


def _map_values(t, op):
if not _is_map_literal(arg := op.arg):
raise UnsupportedOperationError(
"Extracting values of non-literal maps is not yet supported by DuckDB"
)
m_json = sa.func.to_json(t.translate(arg))
return sa.cast(
sa.func.json_extract_string(m_json, sa.func.json_keys(m_json)),
t.get_sqla_type(op.dtype),
)


def _map_merge(t, op):
if not (_is_map_literal(op.left) and _is_map_literal(op.right)):
raise UnsupportedOperationError(
"Merging non-literal maps is not yet supported by DuckDB"
op.left, param=parameter.param, body=ops.ArrayContains(op.right, parameter)
)
left = sa.func.to_json(t.translate(op.left))
right = sa.func.to_json(t.translate(op.right))
pairs = sa.func.json_merge_patch(left, right)
keys = sa.func.json_keys(pairs)
return sa.cast(
sa.func.map(keys, sa.func.json_extract_string(pairs, keys)),
t.get_sqla_type(op.dtype),
)


@@ -347,6 +380,15 @@ def _to_json_collection(t, op):
return try_cast(t.translate(op.arg), typ, type_=typ)


def _array_remove(t, op):
arg = op.arg
param = ops.Argument(name="x", shape=arg.shape, dtype=arg.dtype.value_type)
return _array_filter(
t,
ops.ArrayFilter(arg, param=param.param, body=ops.NotEquals(param, op.other)),
)


operation_registry.update(
{
ops.ArrayColumn: (
@@ -395,19 +437,7 @@ def _to_json_collection(t, op):
1,
),
ops.ArraySort: fixed_arity(sa.func.list_sort, 1),
ops.ArrayRemove: lambda t, op: _array_filter(
t,
ops.ArrayFilter(
op.arg,
param="x",
body=ops.NotEquals(
ops.Argument(
name="x", shape=op.arg.shape, dtype=op.arg.dtype.value_type
),
op.other,
),
),
),
ops.ArrayRemove: _array_remove,
ops.ArrayUnion: lambda t, op: t.translate(
ops.ArrayDistinct(ops.ArrayConcat((op.left, op.right)))
),
@@ -470,7 +500,7 @@ def _to_json_collection(t, op):
),
ops.StartsWith: fixed_arity(sa.func.prefix, 2),
ops.EndsWith: fixed_arity(sa.func.suffix, 2),
ops.Argument: lambda _, op: sa.literal_column(op.name),
ops.Argument: lambda _, op: sa.literal_column(op.param),
ops.Unnest: unary(sa.func.unnest),
ops.MapGet: fixed_arity(
lambda arg, key, default: sa.func.coalesce(
@@ -496,6 +526,42 @@ def _to_json_collection(t, op):
ops.TimestampDelta: _temporal_delta,
ops.ToJSONMap: _to_json_collection,
ops.ToJSONArray: _to_json_collection,
ops.ArrayFlatten: unary(sa.func.flatten),
ops.IntegerRange: fixed_arity(sa.func.range, 3),
# geospatial
ops.GeoPoint: _geo_point,
ops.GeoAsText: unary(sa.func.ST_AsText),
ops.GeoArea: unary(sa.func.ST_Area),
ops.GeoBuffer: _geo_buffer,
ops.GeoCentroid: _centroid,
ops.GeoContains: fixed_arity(sa.func.ST_Contains, 2),
ops.GeoCovers: fixed_arity(sa.func.ST_Covers, 2),
ops.GeoCoveredBy: fixed_arity(sa.func.ST_CoveredBy, 2),
ops.GeoCrosses: fixed_arity(sa.func.ST_Crosses, 2),
ops.GeoDifference: _geo_difference,
ops.GeoDisjoint: fixed_arity(sa.func.ST_Disjoint, 2),
ops.GeoDistance: fixed_arity(sa.func.ST_Distance, 2),
ops.GeoDWithin: fixed_arity(sa.func.ST_DWithin, 3),
ops.GeoEndPoint: _geo_end_point,
ops.GeoEnvelope: _envelope,
ops.GeoEquals: fixed_arity(sa.func.ST_Equals, 2),
ops.GeoGeometryType: unary(sa.func.ST_GeometryType),
ops.GeoIntersection: _geo_intersection,
ops.GeoIntersects: fixed_arity(sa.func.ST_Intersects, 2),
ops.GeoIsValid: unary(sa.func.ST_IsValid),
ops.GeoLength: unary(sa.func.ST_Length),
ops.GeoNPoints: unary(sa.func.ST_NPoints),
ops.GeoOverlaps: fixed_arity(sa.func.ST_Overlaps, 2),
ops.GeoStartPoint: _geo_start_point,
ops.GeoTouches: fixed_arity(sa.func.ST_Touches, 2),
ops.GeoUnion: _geo_union,
ops.GeoUnaryUnion: _geo_unary_union,
ops.GeoWithin: fixed_arity(sa.func.ST_Within, 2),
ops.GeoX: unary(sa.func.ST_X),
ops.GeoY: unary(sa.func.ST_Y),
ops.GeoConvert: _geo_convert,
# other ops
ops.TimestampRange: fixed_arity(sa.func.range, 3),
}
)

81 changes: 78 additions & 3 deletions ibis/backends/duckdb/tests/conftest.py
@@ -6,16 +6,35 @@

import ibis
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero
from ibis.backends.tests.base import BackendTest
from ibis.conftest import SANDBOXED

if TYPE_CHECKING:
from collections.abc import Iterator

from ibis.backends.base import BaseBackend

TEST_TABLES_GEO = {
"zones": ibis.schema(
{
"zone": "string",
"LocationID": "int32",
"borough": "string",
"geom": "geometry",
"x_cent": "float32",
"y_cent": "float32",
}
),
"lines": ibis.schema(
{
"loc_id": "int32",
"geom": "geometry",
}
),
}

class TestConf(BackendTest, RoundAwayFromZero):

class TestConf(BackendTest):
supports_map = True
deps = "duckdb", "duckdb_engine"
stateful = False
@@ -24,19 +43,39 @@ class TestConf(BackendTest, RoundAwayFromZero):
def preload(self):
if not SANDBOXED:
self.connection._load_extensions(
["httpfs", "postgres_scanner", "sqlite_scanner"]
["httpfs", "postgres_scanner", "sqlite_scanner", "spatial"]
)

@property
def ddl_script(self) -> Iterator[str]:
from ibis.backends.base.sql.alchemy.geospatial import geospatial_supported

parquet_dir = self.data_dir / "parquet"
geojson_dir = self.data_dir / "geojson"
for table in TEST_TABLES:
yield (
f"""
CREATE OR REPLACE TABLE {table} AS
SELECT * FROM read_parquet('{parquet_dir / f'{table}.parquet'}')
"""
)
if geospatial_supported and not SANDBOXED:
for table in TEST_TABLES_GEO:
yield (
f"""
CREATE OR REPLACE TABLE {table} AS
SELECT * FROM st_read('{geojson_dir / f'{table}.geojson'}')
"""
)
yield (
"""
CREATE or REPLACE TABLE geo (name VARCHAR, geom GEOMETRY);
INSERT INTO geo VALUES
('Point', ST_GeomFromText('POINT(-100 40)')),
('Linestring', ST_GeomFromText('LINESTRING(0 0, 1 1, 2 1, 2 2)')),
('Polygon', ST_GeomFromText('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'));
"""
)
yield from super().ddl_script

@staticmethod
Expand All @@ -55,3 +94,39 @@ def load_tpch(self) -> None:
@pytest.fixture(scope="session")
def con(data_dir, tmp_path_factory, worker_id):
return TestConf.load_data(data_dir, tmp_path_factory, worker_id).connection


@pytest.fixture(scope="session")
def zones(con, data_dir):
zones = con.read_geo(data_dir / "geojson" / "zones.geojson")
return zones


@pytest.fixture(scope="session")
def lines(con, data_dir):
lines = con.read_geo(data_dir / "geojson" / "lines.geojson")
return lines


@pytest.fixture(scope="session")
def zones_gdf(data_dir):
gpd = pytest.importorskip("geopandas")
zones_gdf = gpd.read_file(data_dir / "geojson" / "zones.geojson")
return zones_gdf


@pytest.fixture(scope="session")
def lines_gdf(data_dir):
gpd = pytest.importorskip("geopandas")
lines_gdf = gpd.read_file(data_dir / "geojson" / "lines.geojson")
return lines_gdf


@pytest.fixture(scope="session")
def geotable(con):
return con.table("geo")


@pytest.fixture(scope="session")
def gdf(geotable):
return geotable.execute()
@@ -0,0 +1,3 @@
SELECT
ST_DWITHIN(t0.geom, t0.geom, CAST(3.0 AS REAL(53))) AS tmp
FROM t AS t0
@@ -0,0 +1,3 @@
SELECT
ST_ASTEXT(t0.geom) AS tmp
FROM t AS t0
@@ -0,0 +1,3 @@
SELECT
ST_NPOINTS(t0.geom) AS tmp
FROM t AS t0
212 changes: 212 additions & 0 deletions ibis/backends/duckdb/tests/test_geospatial.py
@@ -0,0 +1,212 @@
from __future__ import annotations

import numpy.testing as npt
import pandas.testing as tm
import pyarrow as pa
import pytest
from pytest import param

import ibis

gpd = pytest.importorskip("geopandas")
gtm = pytest.importorskip("geopandas.testing")
shapely = pytest.importorskip("shapely")


def test_geospatial_point(zones, zones_gdf):
coord = zones.x_cent.point(zones.y_cent).name("coord")
# this returns a GeometryArray
gp_coord = gpd.points_from_xy(zones_gdf.x_cent, zones_gdf.y_cent)

npt.assert_array_equal(coord.to_pandas().values, gp_coord)


# these functions are not implemented in geopandas
@pytest.mark.parametrize(
("operation", "keywords"),
[
param("as_text", {}, id="as_text"),
param("n_points", {}, id="n_points"),
],
)
def test_geospatial_unary_snapshot(operation, keywords, snapshot):
t = ibis.table([("geom", "geometry")], name="t")
expr = getattr(t.geom, operation)(**keywords).name("tmp")
snapshot.assert_match(ibis.to_sql(expr), "out.sql")


def test_geospatial_dwithin(snapshot):
t = ibis.table([("geom", "geometry")], name="t")
expr = t.geom.d_within(t.geom, 3.0).name("tmp")

snapshot.assert_match(ibis.to_sql(expr), "out.sql")


# geospatial unary functions that return a non-geometry series
# we can test using pd.testing (tm)
@pytest.mark.parametrize(
("op", "keywords", "gp_op"),
[
param("area", {}, "area", id="area"),
param("is_valid", {}, "is_valid", id="is_valid"),
param(
"geometry_type",
{},
"geom_type",
id="geometry_type",
marks=pytest.mark.xfail(raises=pa.lib.ArrowTypeError),
),
],
)
def test_geospatial_unary_tm(op, keywords, gp_op, zones, zones_gdf):
expr = getattr(zones.geom, op)(**keywords).name("tmp")
gp_expr = getattr(zones_gdf.geometry, gp_op)

tm.assert_series_equal(expr.to_pandas(), gp_expr, check_names=False)


@pytest.mark.parametrize(
("op", "keywords", "gp_op"),
[
param("x", {}, "x", id="x_coord"),
param("y", {}, "y", id="y_coord"),
],
)
def test_geospatial_xy(op, keywords, gp_op, zones, zones_gdf):
cen = zones.geom.centroid().name("centroid")
gp_cen = zones_gdf.geometry.centroid

expr = getattr(cen, op)(**keywords).name("tmp")
gp_expr = getattr(gp_cen, gp_op)

tm.assert_series_equal(expr.to_pandas(), gp_expr, check_names=False)


def test_geospatial_length(lines, lines_gdf):
# note: ST_LENGTH returns 0 for polygons and multipolygons,
# while geopandas returns the perimeter.
length = lines.geom.length().name("length")
gp_length = lines_gdf.geometry.length

tm.assert_series_equal(length.to_pandas(), gp_length, check_names=False)


# geospatial binary functions that return a non-geometry series
# we can test using pd.testing (tm)
@pytest.mark.parametrize(
("op", "gp_op"),
[
param("contains", "contains", id="contains"),
param("geo_equals", "geom_equals", id="geo_eqs"),
param("covers", "covers", id="covers"),
param("covered_by", "covered_by", id="covered_by"),
param("crosses", "crosses", id="crosses"),
param("disjoint", "disjoint", id="disjoint"),
param("distance", "distance", id="distance"),
param("intersects", "intersects", id="intersects"),
param("overlaps", "overlaps", id="overlaps"),
param("touches", "touches", id="touches"),
param("within", "within", id="within"),
],
)
def test_geospatial_binary_tm(op, gp_op, zones, zones_gdf):
expr = getattr(zones.geom, op)(zones.geom).name("tmp")
gp_func = getattr(zones_gdf.geometry, gp_op)(zones_gdf.geometry)

tm.assert_series_equal(expr.to_pandas(), gp_func, check_names=False)


# geospatial unary functions that return a geometry series
# we can test using gpd.testing (gtm)
@pytest.mark.parametrize(
("op", "gp_op"),
[
param("centroid", "centroid", id="centroid"),
param("envelope", "envelope", id="envelope"),
],
)
def test_geospatial_unary_gtm(op, gp_op, zones, zones_gdf):
expr = getattr(zones.geom, op)().name("tmp")
gp_expr = getattr(zones_gdf.geometry, gp_op)

gtm.assert_geoseries_equal(expr.to_pandas(), gp_expr, check_crs=False)


# geospatial binary functions that return a geometry series
# we can test using gpd.testing (gtm)
@pytest.mark.parametrize(
("op", "gp_op"),
[
param("difference", "difference", id="difference"),
param("intersection", "intersection", id="intersection"),
param("union", "union", id=""),
],
)
def test_geospatial_binary_gtm(op, gp_op, zones, zones_gdf):
expr = getattr(zones.geom, op)(zones.geom).name("tmp")
gp_func = getattr(zones_gdf.geometry, gp_op)(zones_gdf.geometry)

gtm.assert_geoseries_equal(expr.to_pandas(), gp_func, check_crs=False)


def test_geospatial_end_point(lines, lines_gdf):
epoint = lines.geom.end_point().name("end_point")
# geopandas does not have end_point; this is a workaround to get it
gp_epoint = lines_gdf.geometry.boundary.explode(index_parts=True).xs(1, level=1)

gtm.assert_geoseries_equal(epoint.to_pandas(), gp_epoint, check_crs=False)


def test_geospatial_start_point(lines, lines_gdf):
spoint = lines.geom.start_point().name("start_point")
# geopandas does not have start_point; this is a workaround to get it
gp_spoint = lines_gdf.geometry.boundary.explode(index_parts=True).xs(0, level=1)

gtm.assert_geoseries_equal(spoint.to_pandas(), gp_spoint, check_crs=False)


# this one takes a bit longer than the rest.
def test_geospatial_unary_union(zones, zones_gdf):
unary_union = zones.geom.unary_union().name("unary_union")
# this returns a shapely geometry object
gp_unary_union = zones_gdf.geometry.unary_union

# using set_precision because https://github.com/duckdb/duckdb_spatial/issues/189
assert shapely.equals(
shapely.set_precision(unary_union.to_pandas(), grid_size=1e-7),
shapely.set_precision(gp_unary_union, grid_size=1e-7),
)


def test_geospatial_buffer_point(zones, zones_gdf):
cen = zones.geom.centroid().name("centroid")
gp_cen = zones_gdf.geometry.centroid

buffer = cen.buffer(100.0)
# geopandas resolution default is 16, while duckdb is 8.
gp_buffer = gp_cen.buffer(100.0, resolution=8)

gtm.assert_geoseries_equal(buffer.to_pandas(), gp_buffer, check_crs=False)


def test_geospatial_buffer(zones, zones_gdf):
buffer = zones.geom.buffer(100.0)
# geopandas resolution default is 16, while duckdb is 8.
gp_buffer = zones_gdf.geometry.buffer(100.0, resolution=8)

gtm.assert_geoseries_equal(buffer.to_pandas(), gp_buffer, check_crs=False)


# using a smaller dataset to keep runtime down
def test_geospatial_convert(geotable, gdf):
# geotable is fabricated but let's say the
# data is in CRS: EPSG:2263
# let's transform to EPSG:4326 (latitude-longitude projection)
geo_ll = geotable.geom.convert("EPSG:2263", "EPSG:4326")

gdf.crs = "EPSG:2263"
gdf_ll = gdf.geometry.to_crs(crs=4326)

gtm.assert_geoseries_equal(
geo_ll.to_pandas(), gdf_ll, check_less_precise=True, check_crs=False
)
84 changes: 84 additions & 0 deletions ibis/backends/duckdb/tests/test_register.py
@@ -46,6 +46,82 @@ def test_read_parquet(data_dir):
assert t.count().execute()


@pytest.mark.xfail(
LINUX and SANDBOXED,
reason="nix on linux cannot download duckdb extensions or data due to sandboxing",
)
def test_load_spatial_when_geo_column(tmpdir):
pytest.importorskip("geoalchemy2")

path = str(tmpdir.join("test_load_spatial.ddb"))

with duckdb.connect(
# windows is horrible and cannot download in parallel without
# clobbering existing files, so give a temporary custom directory for
# extensions
path,
config={"extension_directory": str(tmpdir.join("extensions"))},
) as con:
con.install_extension("spatial")
con.load_extension("spatial")
con.execute(
# create a table with a geom column
"""
CREATE or REPLACE TABLE samples (name VARCHAR, geom GEOMETRY);

INSERT INTO samples VALUES
('Point', ST_GeomFromText('POINT(-100 40)')),
('Linestring', ST_GeomFromText('LINESTRING(0 0, 1 1, 2 1, 2 2)')),
('Polygon', ST_GeomFromText('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'));
"""
)

# load data from ibis and check for the spatial extension
con = ibis.duckdb.connect(path)

query = """\
SELECT extension_name AS name
FROM duckdb_extensions()
WHERE installed AND loaded"""

assert "spatial" not in con.sql(query).name.to_pandas().values

# trigger spatial extension load
assert not con.tables.samples.head(1).geom.to_pandas().empty

assert "spatial" in con.sql(query).name.to_pandas().values


def test_read_geo_to_pyarrow(con, data_dir):
pytest.importorskip("geopandas")
shapely = pytest.importorskip("shapely")

t = con.read_geo(data_dir / "geojson" / "zones.geojson")
raw_geometry = t.head().to_pyarrow()["geom"].to_pandas()
assert len(shapely.from_wkb(raw_geometry))


def test_read_geo_to_geopandas(con, data_dir):
gpd = pytest.importorskip("geopandas")
t = con.read_geo(data_dir / "geojson" / "zones.geojson")
gdf = t.head().to_pandas()
assert isinstance(gdf, gpd.GeoDataFrame)


def test_read_geo_from_url(con, monkeypatch):
loaded_exts = []
monkeypatch.setattr(con, "_load_extensions", lambda x, **kw: loaded_exts.extend(x))

with pytest.raises((sa.exc.OperationalError, sa.exc.ProgrammingError)):
# The read will fail, either because the URL is bogus (which it is) or
# because the current connection doesn't have the spatial extension
# installed and so the call to `st_read` will raise a catalog error.
con.read_geo("https://...")

assert "spatial" in loaded_exts
assert "httpfs" in loaded_exts


@pytest.mark.xfail_version(
duckdb=["duckdb<0.7.0"], reason="read_json_auto doesn't exist", raises=exc.IbisError
)
@@ -367,3 +443,11 @@ def test_register_filesystem_gcs(con):
)

assert band_members.count().to_pyarrow()


def test_memtable_null_column_parquet_dtype_roundtrip(con, tmp_path):
before = ibis.memtable({"a": [None, None, None]}, schema={"a": "string"})
before.to_parquet(tmp_path / "tmp.parquet")
after = ibis.read_parquet(tmp_path / "tmp.parquet")

assert before.a.type() == after.a.type()
19 changes: 19 additions & 0 deletions ibis/backends/duckdb/tests/test_udf.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import pytest
from pytest import param

from ibis import udf

@@ -85,3 +86,21 @@ def test_builtin_agg(con, func):
).scalar()

assert con.execute(expr) == expected


@udf.scalar.python
def dont_intercept_null(x: int) -> int:
assert x is not None
return x


@pytest.mark.parametrize(
("expr", "expected"),
[
param(dont_intercept_null(5), 5, id="notnull"),
param(dont_intercept_null(None), None, id="null"),
param(dont_intercept_null(5) + dont_intercept_null(None), None, id="mixed"),
],
)
def test_dont_intercept_null(con, expr, expected):
assert con.execute(expr) == expected
234 changes: 234 additions & 0 deletions ibis/backends/exasol/__init__.py
@@ -0,0 +1,234 @@
from __future__ import annotations

import re
import warnings
from collections import ChainMap
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any

import sqlalchemy as sa
import sqlglot as sg

from ibis import util
from ibis.backends.base.sql.alchemy import AlchemyCanCreateSchema, BaseAlchemyBackend
from ibis.backends.base.sqlglot.datatypes import PostgresType
from ibis.backends.exasol.compiler import ExasolCompiler

if TYPE_CHECKING:
from collections.abc import Iterable, MutableMapping

from ibis.backends.base import BaseBackend
from ibis.expr import datatypes as dt


class Backend(BaseAlchemyBackend, AlchemyCanCreateSchema):
name = "exasol"
compiler = ExasolCompiler
supports_temporary_tables = False
supports_create_or_replace = False
supports_in_memory_tables = False
supports_python_udfs = False

def do_connect(
self,
user: str,
password: str,
host: str = "localhost",
port: int = 8563,
schema: str | None = None,
encryption: bool = True,
certificate_validation: bool = True,
encoding: str = "en_US.UTF-8",
) -> None:
"""Create an Ibis client connected to an Exasol database.

Parameters
----------
user
Username used for authentication.
password
Password used for authentication.
host
Hostname to connect to (default: "localhost").
port
Port number to connect to (default: 8563).
schema
Database schema to open; if `None`, no schema will be opened.
encryption
Enables/disables transport layer encryption (default: True).
certificate_validation
Enables/disables certificate validation (default: True).
encoding
The encoding format (default: "en_US.UTF-8").
"""
options = [
"SSLCertificate=SSL_VERIFY_NONE" if not certificate_validation else "",
f"ENCRYPTION={'yes' if encryption else 'no'}",
f"CONNECTIONCALL={encoding}",
]
url_template = (
"exa+websocket://{user}:{password}@{host}:{port}/{schema}?{options}"
)
url = sa.engine.url.make_url(
url_template.format(
user=user,
password=password,
host=host,
port=port,
schema=schema,
options="&".join(options),
)
)
engine = sa.create_engine(url, poolclass=sa.pool.StaticPool)
super().do_connect(engine)

def _convert_kwargs(self, kwargs: MutableMapping) -> None:
def convert_sqla_to_ibis(keyword_arguments):
sqla_to_ibis = {"tls": "encryption", "username": "user"}
for sqla_kwarg, ibis_kwarg in sqla_to_ibis.items():
if sqla_kwarg in keyword_arguments:
keyword_arguments[ibis_kwarg] = keyword_arguments.pop(sqla_kwarg)

def filter_kwargs(keyword_arguments):
allowed_parameters = [
"user",
"password",
"host",
"port",
"schema",
"encryption",
"certificate",
"encoding",
]
to_be_removed = [
key for key in keyword_arguments if key not in allowed_parameters
]
for parameter_name in to_be_removed:
del keyword_arguments[parameter_name]

convert_sqla_to_ibis(kwargs)
filter_kwargs(kwargs)

def _from_url(self, url: str, **kwargs) -> BaseBackend:
"""Construct an ibis backend from a SQLAlchemy-conforming URL."""
kwargs = ChainMap(kwargs)
_, new_kwargs = self.inspector.dialect.create_connect_args(url)
kwargs = kwargs.new_child(new_kwargs)
kwargs = dict(kwargs)
self._convert_kwargs(kwargs)

return self.connect(**kwargs)

@property
def inspector(self):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=sa.exc.RemovedIn20Warning)
return super().inspector

@contextmanager
def begin(self):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=sa.exc.RemovedIn20Warning)
with super().begin() as con:
yield con

def list_tables(self, like=None, database=None):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=sa.exc.RemovedIn20Warning)
return super().list_tables(like=like, database=database)

def _get_sqla_table(
self,
name: str,
autoload: bool = True,
**kwargs: Any,
) -> sa.Table:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=sa.exc.RemovedIn20Warning)
return super()._get_sqla_table(name=name, autoload=autoload, **kwargs)

def _metadata(self, query: str) -> Iterable[tuple[str, dt.DataType]]:
table = sg.table(util.gen_name("exasol_metadata"))
create_view = sg.exp.Create(
kind="VIEW", this=table, expression=sg.parse_one(query, dialect="postgres")
)
drop_view = sg.exp.Drop(kind="VIEW", this=table)
describe = sg.exp.Describe(this=table).sql(dialect="postgres")
# strip trailing encodings e.g., UTF8
varchar_regex = re.compile(r"^(VARCHAR(?:\(\d+\)))?(?:\s+.+)?$")
with self.begin() as con:
con.exec_driver_sql(create_view.sql(dialect="postgres"))
try:
yield from (
(
name,
PostgresType.from_string(varchar_regex.sub(r"\1", typ)),
)
for name, typ, *_ in con.exec_driver_sql(describe)
)
finally:
con.exec_driver_sql(drop_view.sql(dialect="postgres"))
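
The sqlglot expressions built in `_metadata` can be previewed in isolation (a sketch; the view name and query are illustrative):

```python
import sqlglot as sg

table = sg.table("ibis_exasol_metadata")
create_view = sg.exp.Create(
    kind="VIEW",
    this=table,
    expression=sg.parse_one("SELECT 1 AS x", dialect="postgres"),
)
print(create_view.sql(dialect="postgres"))
# roughly: CREATE VIEW ibis_exasol_metadata AS SELECT 1 AS x
print(sg.exp.Describe(this=table).sql(dialect="postgres"))
# roughly: DESCRIBE ibis_exasol_metadata
```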

@property
def current_schema(self) -> str:
return self._scalar_query(sa.select(sa.text("CURRENT_SCHEMA")))

@property
def current_database(self) -> str:
return None

def drop_schema(
self, name: str, database: str | None = None, force: bool = False
) -> None:
if database is not None:
raise NotImplementedError(
"`database` argument is not supported for the Exasol backend"
)
drop_schema = sg.exp.Drop(
kind="SCHEMA", this=sg.to_identifier(name), exists=force
)
with self.begin() as con:
con.exec_driver_sql(drop_schema.sql(dialect="postgres"))

def create_schema(
self, name: str, database: str | None = None, force: bool = False
) -> None:
if database is not None:
raise NotImplementedError(
"`database` argument is not supported for the Exasol backend"
)
create_schema = sg.exp.Create(
kind="SCHEMA", this=sg.to_identifier(name), exists=force
)
with self.begin() as con:
open_schema = self.current_schema
con.exec_driver_sql(create_schema.sql(dialect="postgres"))
# Exasol implicitly opens the created schema, therefore we need to restore
# the previous context.
action = (
sa.text(f"OPEN SCHEMA {open_schema}")
if open_schema
else sa.text(f"CLOSE SCHEMA {name}")
)
con.exec_driver_sql(action)

def list_schemas(
self, like: str | None = None, database: str | None = None
) -> list[str]:
if database is not None:
raise NotImplementedError(
"`database` argument is not supported for the Exasol backend"
)

schema, table = "SYS", "EXA_SCHEMAS"
sch = sa.table(
table,
sa.column("schema_name", sa.TEXT()),
schema=schema,
)

query = sa.select(sch.c.schema_name)

with self.begin() as con:
schemas = list(con.execute(query).scalars())
return self._filter_with_like(schemas, like=like)
24 changes: 24 additions & 0 deletions ibis/backends/exasol/compiler.py
@@ -0,0 +1,24 @@
from __future__ import annotations

import sqlalchemy as sa

from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.exasol import registry
from ibis.backends.exasol.datatypes import ExasolSQLType


class ExasolExprTranslator(AlchemyExprTranslator):
_registry = registry.create()
_rewrites = AlchemyExprTranslator._rewrites.copy()
_integer_to_timestamp = sa.func.from_unixtime
_dialect_name = "exa.websocket"
native_json_type = False
type_mapper = ExasolSQLType


rewrites = ExasolExprTranslator.rewrites


class ExasolCompiler(AlchemyCompiler):
translator_class = ExasolExprTranslator
support_values_syntax_in_select = False
26 changes: 26 additions & 0 deletions ibis/backends/exasol/datatypes.py
@@ -0,0 +1,26 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import sqlalchemy.types as sa_types

from ibis.backends.base.sql.alchemy.datatypes import AlchemyType

if TYPE_CHECKING:
import ibis.expr.datatypes as dt


class ExasolSQLType(AlchemyType):
dialect = "exa.websocket"

@classmethod
def from_ibis(cls, dtype: dt.DataType) -> sa_types.TypeEngine:
if dtype.is_string():
# see also: https://docs.exasol.com/db/latest/sql_references/data_types/datatypesoverview.htm
MAX_VARCHAR_SIZE = 2_000_000
return sa_types.VARCHAR(MAX_VARCHAR_SIZE)
return super().from_ibis(dtype)

@classmethod
def to_ibis(cls, typ: sa_types.TypeEngine, nullable: bool = True) -> dt.DataType:
return super().to_ibis(typ, nullable=nullable)