651 changes: 524 additions & 127 deletions ibis/backends/bigquery/__init__.py

Large diffs are not rendered by default.

17 changes: 5 additions & 12 deletions ibis/backends/bigquery/client.py
@@ -9,7 +9,6 @@

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.bigquery.datatypes import BigQuerySchema, BigQueryType

NATIVE_PARTITION_COL = "_PARTITIONTIME"
@@ -19,10 +18,10 @@ def schema_from_bigquery_table(table):
schema = BigQuerySchema.to_ibis(table.schema)

# Check for partitioning information
partition_info = table._properties.get("timePartitioning", None)
partition_info = table.time_partitioning
if partition_info is not None:
# We have a partitioned table
partition_field = partition_info.get("field", NATIVE_PARTITION_COL)
partition_field = partition_info.field or NATIVE_PARTITION_COL
# Only add a new column if it's not already a column in the schema
if partition_field not in schema:
schema |= {partition_field: dt.timestamp}
@@ -136,21 +135,17 @@ def bq_param_date(_: dt.Date, value, name):
)


class BigQueryTable(ops.DatabaseTable):
pass


def rename_partitioned_column(table_expr, bq_table, partition_col):
"""Rename native partition column to user-defined name."""
partition_info = bq_table._properties.get("timePartitioning", None)
partition_info = bq_table.time_partitioning

# If we don't have any partition information, the table isn't partitioned
if partition_info is None:
return table_expr

# If we have a partition, but no "field" field in the table properties,
# then use NATIVE_PARTITION_COL as the default
partition_field = partition_info.get("field", NATIVE_PARTITION_COL)
partition_field = partition_info.field or NATIVE_PARTITION_COL

# The partition field must be in table_expr columns
assert partition_field in table_expr.columns
@@ -201,9 +196,7 @@ def parse_project_and_dataset(project: str, dataset: str = "") -> tuple[str, str
"""
if dataset.count(".") > 1:
raise ValueError(
"{} is not a BigQuery dataset. More info https://cloud.google.com/bigquery/docs/datasets-intro".format(
dataset
)
f"{dataset} is not a BigQuery dataset. More info https://cloud.google.com/bigquery/docs/datasets-intro"
)
elif dataset.count(".") == 1:
data_project, dataset = dataset.split(".")
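
A minimal sketch (not part of this PR) of what the switch from table._properties.get("timePartitioning") to the table.time_partitioning property amounts to; bq_table stands in for a google.cloud.bigquery Table object:

NATIVE_PARTITION_COL = "_PARTITIONTIME"

def partition_field(bq_table):
    # time_partitioning is a google.cloud.bigquery TimePartitioning or None
    info = bq_table.time_partitioning
    if info is None:
        return None  # not a time-partitioned table
    # column-based partitioning exposes the column name via .field;
    # ingestion-time partitioning leaves it as None, so fall back to the
    # _PARTITIONTIME pseudo-column, as the code above does
    return info.field or NATIVE_PARTITION_COL
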
34 changes: 2 additions & 32 deletions ibis/backends/bigquery/compiler.py
@@ -9,12 +9,10 @@
import toolz

import ibis.common.graph as lin
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.backends.base.sql import compiler as sql_compiler
from ibis.backends.bigquery import operations, registry, rewrites
from ibis.backends.bigquery.datatypes import BigQueryType


class BigQueryUDFDefinition(sql_compiler.DDL):
@@ -35,7 +33,7 @@ class BigQueryUnion(sql_compiler.Union):

@classmethod
def keyword(cls, distinct):
"""Use disctinct UNION if distinct is True."""
"""Use distinct UNION if distinct is True."""
return "UNION DISTINCT" if distinct else "UNION ALL"


@@ -113,39 +111,10 @@ def _trans_param(self, op):
compiles = BigQueryExprTranslator.compiles


@BigQueryExprTranslator.rewrites(ops.NotAll)
def _rewrite_notall(op):
return ops.Any(ops.Not(op.arg), where=op.where)


@BigQueryExprTranslator.rewrites(ops.NotAny)
def _rewrite_notany(op):
return ops.All(ops.Not(op.arg), where=op.where)


class BigQueryTableSetFormatter(sql_compiler.TableSetFormatter):
def _quote_identifier(self, name):
return sg.to_identifier(name).sql("bigquery")

def _format_in_memory_table(self, op):
import ibis

schema = op.schema
names = schema.names
types = schema.types

raw_rows = []
for row in op.data.to_frame().itertuples(index=False):
raw_row = ", ".join(
f"{self._translate(lit.op())} AS {name}"
for lit, name in zip(
map(ibis.literal, row, types), map(self._quote_identifier, names)
)
)
raw_rows.append(f"STRUCT({raw_row})")
array_type = BigQueryType.from_ibis(dt.Array(op.schema.as_struct()))
return f"UNNEST({array_type}[{', '.join(raw_rows)}])"


class BigQueryCompiler(sql_compiler.Compiler):
translator_class = BigQueryExprTranslator
@@ -156,6 +125,7 @@ class BigQueryCompiler(sql_compiler.Compiler):

support_values_syntax_in_select = False
null_limit = None
cheap_in_memory_tables = True

@staticmethod
def _generate_setup_queries(expr, context):
43 changes: 32 additions & 11 deletions ibis/backends/bigquery/datatypes.py
@@ -3,6 +3,7 @@
import google.cloud.bigquery as bq
import sqlglot as sg

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.formats import SchemaMapper, TypeMapper
@@ -91,23 +92,43 @@ def from_ibis(cls, dtype: dt.DataType) -> str:
"BigQuery geography uses points on WGS84 reference ellipsoid."
f"Current geotype: {dtype.geotype}, Current srid: {dtype.srid}"
)
elif dtype.is_map():
raise NotImplementedError("Maps are not supported in BigQuery")
else:
return str(dtype).upper()


class BigQuerySchema(SchemaMapper):
@classmethod
def from_ibis(cls, schema: sch.Schema) -> list[bq.SchemaField]:
result = []
for name, dtype in schema.items():
if isinstance(dtype, dt.Array):
schema_fields = []

for name, typ in ibis.schema(schema).items():
if typ.is_array():
value_type = typ.value_type
if value_type.is_array():
raise TypeError("Nested arrays are not supported in BigQuery")

is_struct = value_type.is_struct()

field_type = (
"RECORD" if is_struct else BigQueryType.from_ibis(typ.value_type)
)
mode = "REPEATED"
dtype = dtype.value_type
fields = cls.from_ibis(ibis.schema(getattr(value_type, "fields", {})))
elif typ.is_struct():
field_type = "RECORD"
mode = "NULLABLE" if typ.nullable else "REQUIRED"
fields = cls.from_ibis(ibis.schema(typ.fields))
else:
mode = "REQUIRED" if not dtype.nullable else "NULLABLE"
field = bq.SchemaField(name, BigQueryType.from_ibis(dtype), mode=mode)
result.append(field)
return result
field_type = BigQueryType.from_ibis(typ)
mode = "NULLABLE" if typ.nullable else "REQUIRED"
fields = ()

schema_fields.append(
bq.SchemaField(name, field_type=field_type, mode=mode, fields=fields)
)
return schema_fields

@classmethod
def _dtype_from_bigquery_field(cls, field: bq.SchemaField) -> dt.DataType:
@@ -125,7 +146,8 @@ def _dtype_from_bigquery_field(cls, field: bq.SchemaField) -> dt.DataType:
elif mode == "REQUIRED":
return dtype.copy(nullable=False)
elif mode == "REPEATED":
return dt.Array(dtype)
# arrays with NULL elements aren't supported
return dt.Array(dtype.copy(nullable=False))
else:
raise TypeError(f"Unknown BigQuery field.mode: {mode}")

@@ -148,6 +170,5 @@ def spread_type(dt: dt.DataType):
for type_ in dt.types:
yield from spread_type(type_)
elif dt.is_map():
yield from spread_type(dt.key_type)
yield from spread_type(dt.value_type)
raise NotImplementedError("Maps are not supported in BigQuery")
yield dt
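
A rough usage sketch (not part of this PR) of the reworked BigQuerySchema.from_ibis, showing how nested types now map onto REPEATED and RECORD fields; the exact SchemaField reprs may differ slightly:

import ibis
from ibis.backends.bigquery.datatypes import BigQuerySchema

schema = ibis.schema({"tags": "array<string>", "point": "struct<x: float64, y: int64>"})
for field in BigQuerySchema.from_ibis(schema):
    print(field.name, field.field_type, field.mode, field.fields)
# approximately:
#   tags STRING REPEATED ()
#   point RECORD NULLABLE (SchemaField('x', 'FLOAT64', 'NULLABLE', ...), SchemaField('y', 'INT64', 'NULLABLE', ...))
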
25 changes: 20 additions & 5 deletions ibis/backends/bigquery/registry.py
@@ -119,7 +119,7 @@ def _struct_field(translator, op):
def _struct_column(translator, op):
cols = (
f"{translator.translate(value)} AS {name}"
for name, value, in zip(op.names, op.values)
for name, value in zip(op.names, op.values)
)
return "STRUCT({})".format(", ".join(cols))

@@ -764,6 +764,18 @@ def _timestamp_delta(t, op):
)


def _group_concat(translator, op):
arg = op.arg
where = op.where

if where is not None:
arg = ops.IfElse(where, arg, ibis.NA)

arg = translator.translate(arg)
sep = translator.translate(op.sep)
return f"STRING_AGG({arg}, {sep})"


OPERATION_REGISTRY = {
**operation_registry,
# Literal
@@ -785,8 +797,10 @@ def _timestamp_delta(t, op):
ops.BitwiseXor: lambda t, op: f"{t.translate(op.left)} ^ {t.translate(op.right)}",
ops.BitwiseOr: lambda t, op: f"{t.translate(op.left)} | {t.translate(op.right)}",
ops.BitwiseAnd: lambda t, op: f"{t.translate(op.left)} & {t.translate(op.right)}",
ops.BitwiseLeftShift: lambda t, op: f"{t.translate(op.left)} << {t.translate(op.right)}",
ops.BitwiseRightShift: lambda t, op: f"{t.translate(op.left)} >> {t.translate(op.right)}",
ops.BitwiseLeftShift: lambda t,
op: f"{t.translate(op.left)} << {t.translate(op.right)}",
ops.BitwiseRightShift: lambda t,
op: f"{t.translate(op.left)} >> {t.translate(op.right)}",
# Temporal functions
ops.Date: unary("DATE"),
ops.DateFromYMD: fixed_arity("DATE", 3),
@@ -834,7 +848,7 @@ def _timestamp_delta(t, op):
ops.RegexSearch: _regex_search,
ops.RegexExtract: _regex_extract,
ops.RegexReplace: _regex_replace,
ops.GroupConcat: reduction("STRING_AGG"),
ops.GroupConcat: _group_concat,
ops.Cast: _cast,
ops.StructField: _struct_field,
ops.StructColumn: _struct_column,
@@ -914,7 +928,8 @@ def _timestamp_delta(t, op):
ops.RandomScalar: fixed_arity("RAND", 0),
ops.NthValue: _nth_value,
ops.JSONGetItem: lambda t, op: f"{t.translate(op.arg)}[{t.translate(op.index)}]",
ops.ArrayStringJoin: lambda t, op: f"ARRAY_TO_STRING({t.translate(op.arg)}, {t.translate(op.sep)})",
ops.ArrayStringJoin: lambda t,
op: f"ARRAY_TO_STRING({t.translate(op.arg)}, {t.translate(op.sep)})",
ops.StartsWith: fixed_arity("STARTS_WITH", 2),
ops.EndsWith: fixed_arity("ENDS_WITH", 2),
ops.TableColumn: table_column,
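
An approximate example (not part of this PR) of what the new _group_concat rule emits for a filtered aggregation; the table and column names are invented and the quoting may differ:

expr = t.s.group_concat(",", where=t.b)
# compiles to roughly
#   STRING_AGG(IF(`b`, `s`, NULL), ',')
# the filter is folded into an IF(...) so STRING_AGG simply skips the NULLs
# produced for non-matching rows, and the separator is passed through explicitly
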
2 changes: 0 additions & 2 deletions ibis/backends/bigquery/rewrites.py
@@ -29,6 +29,4 @@ def bq_mean(op):
ops.Mean: bq_mean,
ops.Any: toolz.identity,
ops.All: toolz.identity,
ops.NotAny: toolz.identity,
ops.NotAll: toolz.identity,
}
100 changes: 35 additions & 65 deletions ibis/backends/bigquery/tests/conftest.py
@@ -2,71 +2,29 @@

import concurrent.futures
import contextlib
import functools
import io
import os
from typing import TYPE_CHECKING, Any
from typing import Any

import google.api_core.exceptions as gexc
import google.auth
import pytest
from google.cloud import bigquery as bq

import ibis
import ibis.expr.datatypes as dt
from ibis.backends.bigquery import EXTERNAL_DATA_SCOPES, Backend
from ibis.backends.bigquery.datatypes import BigQueryType
from ibis.backends.bigquery.datatypes import BigQuerySchema
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero, UnorderedComparator
from ibis.backends.tests.data import json_types, non_null_array_types, struct_types, win

if TYPE_CHECKING:
from collections.abc import Mapping

DATASET_ID = "ibis_gbq_testing"
DATASET_ID_TOKYO = "ibis_gbq_testing_tokyo"
REGION_TOKYO = "asia-northeast1"
DEFAULT_PROJECT_ID = "ibis-gbq"
PROJECT_ID_ENV_VAR = "GOOGLE_BIGQUERY_PROJECT_ID"


@functools.singledispatch
def ibis_type_to_bq_field(typ: dt.DataType) -> Mapping[str, Any]:
raise NotImplementedError(typ)


@ibis_type_to_bq_field.register(dt.DataType)
def _(typ: dt.DataType) -> Mapping[str, Any]:
return {"field_type": BigQueryType.from_ibis(typ)}


@ibis_type_to_bq_field.register(dt.Array)
def _(typ: dt.Array) -> Mapping[str, Any]:
return {
"field_type": BigQueryType.from_ibis(typ.value_type),
"mode": "REPEATED",
}


@ibis_type_to_bq_field.register(dt.Struct)
def _(typ: dt.Struct) -> Mapping[str, Any]:
return {
"field_type": "RECORD",
"mode": "NULLABLE" if typ.nullable else "REQUIRED",
"fields": ibis_schema_to_bq_schema(ibis.schema(typ.fields)),
}


def ibis_schema_to_bq_schema(schema):
return [
bq.SchemaField(
name.replace(":", "").replace(" ", "_"),
**ibis_type_to_bq_field(typ),
)
for name, typ in ibis.schema(schema).items()
]


class TestConf(UnorderedComparator, BackendTest, RoundAwayFromZero):
"""Backend-specific class with information for testing."""

@@ -129,9 +87,13 @@ def _load_data(self, **_: Any) -> None:
timestamp_table = bq.Table(
bq.TableReference(testing_dataset, "timestamp_column_parted")
)
timestamp_table.schema = ibis_schema_to_bq_schema(
dict(
my_timestamp_parted_col="timestamp", string_col="string", int_col="int"
timestamp_table.schema = BigQuerySchema.from_ibis(
ibis.schema(
dict(
my_timestamp_parted_col="timestamp",
string_col="string",
int_col="int",
)
)
)
timestamp_table.time_partitioning = bq.TimePartitioning(
@@ -141,8 +103,10 @@ def _load_data(self, **_: Any) -> None:

# ingestion date partitioning
date_table = bq.Table(bq.TableReference(testing_dataset, "date_column_parted"))
date_table.schema = ibis_schema_to_bq_schema(
dict(my_date_parted_col="date", string_col="string", int_col="int")
date_table.schema = BigQuerySchema.from_ibis(
ibis.schema(
dict(my_date_parted_col="date", string_col="string", int_col="int")
)
)
date_table.time_partitioning = bq.TimePartitioning(field="my_date_parted_col")
client.create_table(date_table, exists_ok=True)
@@ -161,8 +125,10 @@ def _load_data(self, **_: Any) -> None:
bq.TableReference(testing_dataset, "struct"),
job_config=bq.LoadJobConfig(
write_disposition=write_disposition,
schema=ibis_schema_to_bq_schema(
dict(abc="struct<a: float64, b: string, c: int64>")
schema=BigQuerySchema.from_ibis(
ibis.schema(
dict(abc="struct<a: float64, b: string, c: int64>")
)
),
),
)
@@ -176,13 +142,15 @@ def _load_data(self, **_: Any) -> None:
bq.TableReference(testing_dataset, "array_types"),
job_config=bq.LoadJobConfig(
write_disposition=write_disposition,
schema=ibis_schema_to_bq_schema(
dict(
x="array<int64>",
y="array<string>",
z="array<float64>",
grouper="string",
scalar_column="float64",
schema=BigQuerySchema.from_ibis(
ibis.schema(
dict(
x="array<int64>",
y="array<string>",
z="array<float64>",
grouper="string",
scalar_column="float64",
)
)
),
),
@@ -219,8 +187,10 @@ def _load_data(self, **_: Any) -> None:
bq.TableReference(testing_dataset, "numeric_table"),
job_config=bq.LoadJobConfig(
write_disposition=write_disposition,
schema=ibis_schema_to_bq_schema(
dict(string_col="string", numeric_col="decimal(38, 9)")
schema=BigQuerySchema.from_ibis(
ibis.schema(
dict(string_col="string", numeric_col="decimal(38, 9)")
)
),
source_format=bq.SourceFormat.NEWLINE_DELIMITED_JSON,
),
@@ -235,8 +205,8 @@ def _load_data(self, **_: Any) -> None:
bq.TableReference(testing_dataset, "win"),
job_config=bq.LoadJobConfig(
write_disposition=write_disposition,
schema=ibis_schema_to_bq_schema(
dict(g="string", x="int64", y="int64")
schema=BigQuerySchema.from_ibis(
ibis.schema(dict(g="string", x="int64", y="int64"))
),
),
)
@@ -250,7 +220,7 @@ def _load_data(self, **_: Any) -> None:
bq.TableReference(testing_dataset, "json_t"),
job_config=bq.LoadJobConfig(
write_disposition=write_disposition,
schema=ibis_schema_to_bq_schema(dict(js="json")),
schema=BigQuerySchema.from_ibis(ibis.schema(dict(js="json"))),
source_format=bq.SourceFormat.NEWLINE_DELIMITED_JSON,
),
)
@@ -267,7 +237,7 @@ def _load_data(self, **_: Any) -> None:
),
bq.TableReference(testing_dataset, table),
job_config=bq.LoadJobConfig(
schema=ibis_schema_to_bq_schema(schema),
schema=BigQuerySchema.from_ibis(ibis.schema(schema)),
write_disposition=write_disposition,
source_format=bq.SourceFormat.PARQUET,
),
@@ -288,7 +258,7 @@ def _load_data(self, **_: Any) -> None:
),
bq.TableReference(testing_dataset_tokyo, table),
job_config=bq.LoadJobConfig(
schema=ibis_schema_to_bq_schema(schema),
schema=BigQuerySchema.from_ibis(ibis.schema(schema)),
write_disposition=write_disposition,
source_format=bq.SourceFormat.PARQUET,
),
33 changes: 18 additions & 15 deletions ibis/backends/bigquery/tests/system/test_client.py
@@ -40,16 +40,11 @@ def test_current_database(con, dataset_id):
db = con.current_database
assert db == dataset_id
assert db == con.dataset_id
assert con.list_tables(database=db, like="alltypes") == con.list_tables(
assert con.list_tables(schema=db, like="alltypes") == con.list_tables(
like="alltypes"
)


def test_database(con):
database = con.database(con.dataset_id)
assert database.list_tables(like="alltypes") == con.list_tables(like="alltypes")


def test_array_collect(struct_table):
key = struct_table.array_of_structs_col[0]["string_field"]
expr = struct_table.group_by(key=key).aggregate(
@@ -242,25 +237,29 @@ def test_set_database(con2):

def test_exists_table_different_project(con):
name = "co_daily_summary"
database = "bigquery-public-data.epa_historical_air_quality"
dataset = "bigquery-public-data.epa_historical_air_quality"

assert name in con.list_tables(database=database)
assert "foobar" not in con.list_tables(database=database)
assert name in con.list_tables(schema=dataset)
assert "foobar" not in con.list_tables(schema=dataset)


def test_multiple_project_queries(con, snapshot):
so = con.table("posts_questions", database="bigquery-public-data.stackoverflow")
trips = con.table("trips", database="nyc-tlc.yellow")
with pytest.warns(FutureWarning, match="`database` is deprecated as of v7.1"):
so = con.table("posts_questions", database="bigquery-public-data.stackoverflow")
with pytest.warns(FutureWarning, match="`database` is deprecated as of v7.1"):
trips = con.table("trips", database="nyc-tlc.yellow")
join = so.join(trips, so.tags == trips.rate_code)[[so.title]]
result = join.compile()
snapshot.assert_match(result, "out.sql")


def test_multiple_project_queries_database_api(con, snapshot):
stackoverflow = con.database("bigquery-public-data.stackoverflow")
posts_questions = stackoverflow.posts_questions
with pytest.warns(FutureWarning, match="`database` is deprecated as of v7.1"):
posts_questions = stackoverflow.posts_questions
yellow = con.database("nyc-tlc.yellow")
trips = yellow.trips
with pytest.warns(FutureWarning, match="`database` is deprecated as of v7.1"):
trips = yellow.trips
predicate = posts_questions.tags == trips.rate_code
join = posts_questions.join(trips, predicate)[[posts_questions.title]]
result = join.compile()
@@ -269,9 +268,13 @@ def test_multiple_project_queries_database_api(con, snapshot):

def test_multiple_project_queries_execute(con):
stackoverflow = con.database("bigquery-public-data.stackoverflow")
posts_questions = stackoverflow.posts_questions.limit(5)
with pytest.warns(FutureWarning, match="`database` is deprecated as of v7.1"):
posts_questions = stackoverflow.posts_questions
posts_questions = posts_questions.limit(5)
yellow = con.database("nyc-tlc.yellow")
trips = yellow.trips.limit(5)
with pytest.warns(FutureWarning, match="`database` is deprecated as of v7.1"):
trips = yellow.trips
trips = trips.limit(5)
predicate = posts_questions.tags == trips.rate_code
cols = [posts_questions.title]
join = posts_questions.left_join(trips, predicate)[cols]
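
A short usage sketch (not part of this PR) of the API surface these updated tests exercise; con is assumed to be an authenticated BigQuery connection:

# cross-project datasets are now addressed via the schema argument
con.list_tables(schema="bigquery-public-data.epa_historical_air_quality")

# database= still works for table access but now emits a FutureWarning
# ("`database` is deprecated as of v7.1")
t = con.table("posts_questions", database="bigquery-public-data.stackoverflow")
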

This file was deleted.

@@ -1,3 +1,3 @@
SELECT
sum(t0.`foo`) AS `Sum_foo`
FROM t0 AS t0
FROM t0
@@ -7,5 +7,5 @@ FROM (
EXCEPT DISTINCT
SELECT
t1.*
FROM t1 AS t1
FROM t1
) AS t0
@@ -7,5 +7,5 @@ FROM (
INTERSECT DISTINCT
SELECT
t1.*
FROM t1 AS t1
FROM t1
) AS t0
@@ -7,5 +7,5 @@ FROM (
UNION ALL
SELECT
t1.*
FROM t1 AS t1
FROM t1
) AS t0
@@ -7,5 +7,5 @@ FROM (
UNION DISTINCT
SELECT
t1.*
FROM t1 AS t1
FROM t1
) AS t0
6 changes: 0 additions & 6 deletions ibis/backends/bigquery/tests/unit/test_compiler.py
@@ -628,9 +628,3 @@ def test_unnest(snapshot):
).select(level_two=lambda t: t.level_one.unnest())
)
snapshot.assert_match(result, "out_two_unnests.sql")


def test_compile_in_memory_table(snapshot):
t = ibis.memtable({"Column One": [1, 2, 3]})
result = ibis.bigquery.compile(t)
snapshot.assert_match(result, "out.sql")
89 changes: 38 additions & 51 deletions ibis/backends/bigquery/udf/__init__.py
@@ -20,24 +20,10 @@
_udf_name_cache: dict[str, Iterable[int]] = collections.defaultdict(itertools.count)


def _create_udf_node(name, fields):
"""Create a new UDF node type.
Parameters
----------
name : str
Then name of the UDF node
fields : OrderedDict
Mapping of class member name to definition
Returns
-------
result : type
A new BigQueryUDFNode subclass
"""
def _make_udf_name(name):
definition = next(_udf_name_cache[name])
external_name = f"{name}_{definition:d}"
return type(external_name, (BigQueryUDFNode,), fields)
return external_name


class _BigQueryUDF:
@@ -274,24 +260,6 @@ def js(
if libraries is None:
libraries = []

udf_node_fields = {
name: rlz.ValueOf(None if type_ == "ANY TYPE" else type_)
for name, type_ in params.items()
}

udf_node_fields["dtype"] = output_type
udf_node_fields["shape"] = rlz.shape_like("args")
udf_node_fields["__slots__"] = ("sql",)

udf_node = _create_udf_node(name, udf_node_fields)

from ibis.backends.bigquery.compiler import compiles

@compiles(udf_node)
def compiles_udf_node(t, op):
args = ", ".join(map(t.translate, op.args))
return f"{udf_node.__name__}({args})"

bigquery_signature = ", ".join(
f"{name} {BigQueryType.from_ibis(dt.dtype(type_))}"
for name, type_ in params.items()
@@ -305,16 +273,35 @@ def compiles_udf_node(t, op):
False: "NOT DETERMINISTIC\n",
None: "",
}.get(determinism)

name = _make_udf_name(name)
sql_code = f'''\
CREATE TEMPORARY FUNCTION {udf_node.__name__}({bigquery_signature})
CREATE TEMPORARY FUNCTION {name}({bigquery_signature})
RETURNS {return_type}
{determinism_formatted}LANGUAGE js AS """
{body}
"""{libraries_opts};'''

udf_node_fields = {
name: rlz.ValueOf(None if type_ == "ANY TYPE" else type_)
for name, type_ in params.items()
}

udf_node_fields["dtype"] = output_type
udf_node_fields["shape"] = rlz.shape_like("args")
udf_node_fields["sql"] = sql_code

udf_node = type(name, (BigQueryUDFNode,), udf_node_fields)

from ibis.backends.bigquery.compiler import compiles

@compiles(udf_node)
def compiles_udf_node(t, op):
args = ", ".join(map(t.translate, op.args))
return f"{udf_node.__name__}({args})"

def wrapped(*args, **kwargs):
node = udf_node(*args, **kwargs)
object.__setattr__(node, "sql", sql_code)
return node.to_expr()

wrapped.__signature__ = inspect.Signature(
@@ -376,19 +363,6 @@ def sql(
}
return_type = BigQueryType.from_ibis(dt.dtype(output_type))

udf_node_fields["dtype"] = output_type
udf_node_fields["shape"] = rlz.shape_like("args")
udf_node_fields["__slots__"] = ("sql",)

udf_node = _create_udf_node(name, udf_node_fields)

from ibis.backends.bigquery.compiler import compiles

@compiles(udf_node)
def compiles_udf_node(t, op):
args_formatted = ", ".join(map(t.translate, op.args))
return f"{udf_node.__name__}({args_formatted})"

bigquery_signature = ", ".join(
"{name} {type}".format(
name=name,
@@ -398,14 +372,27 @@ def compiles_udf_node(t, op):
)
for name, type_ in params.items()
)
name = _make_udf_name(name)
sql_code = f"""\
CREATE TEMPORARY FUNCTION {udf_node.__name__}({bigquery_signature})
CREATE TEMPORARY FUNCTION {name}({bigquery_signature})
RETURNS {return_type}
AS ({sql_expression});"""

udf_node_fields["dtype"] = output_type
udf_node_fields["shape"] = rlz.shape_like("args")
udf_node_fields["sql"] = sql_code

udf_node = type(name, (BigQueryUDFNode,), udf_node_fields)

from ibis.backends.bigquery.compiler import compiles

@compiles(udf_node)
def compiles_udf_node(t, op):
args = ", ".join(map(t.translate, op.args))
return f"{udf_node.__name__}({args})"

def wrapper(*args, **kwargs):
node = udf_node(*args, **kwargs)
object.__setattr__(node, "sql", sql_code)
return node.to_expr()

return wrapper
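
For reference (not part of this PR), the sql_code template above yields DDL shaped roughly like the following for a hypothetical JavaScript UDF named add_one; the _0 suffix comes from _make_udf_name and the body is whatever JavaScript the user supplied:

CREATE TEMPORARY FUNCTION add_one_0(x FLOAT64)
RETURNS FLOAT64
LANGUAGE js AS """
<JavaScript function body>
""";
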
128 changes: 82 additions & 46 deletions ibis/backends/clickhouse/__init__.py
@@ -10,6 +10,7 @@

import clickhouse_connect as cc
import pyarrow as pa
import pyarrow_hotfix # noqa: F401
import sqlalchemy as sa
import sqlglot as sg
import toolz
@@ -23,7 +24,7 @@
import ibis.expr.types as ir
from ibis import util
from ibis.backends.base import BaseBackend, CanCreateDatabase
from ibis.backends.base.sqlglot import STAR, C, F, lit
from ibis.backends.base.sqlglot import STAR, C, F
from ibis.backends.clickhouse.compiler import translate
from ibis.backends.clickhouse.datatypes import ClickhouseType

@@ -201,7 +202,7 @@ def list_tables(
if database is None:
database = F.currentDatabase()
else:
database = lit(database)
database = sg.exp.convert(database)

query = query.where(C.database.eq(database).or_(C.is_temporary))

@@ -431,7 +432,10 @@ def table(self, name: str, database: str | None = None) -> ir.Table:
"""
schema = self.get_schema(name, database=database)
op = ops.DatabaseTable(
name=name, schema=schema, source=self, namespace=database
name=name,
schema=schema,
source=self,
namespace=ops.Namespace(database=database),
)
return op.to_expr()

@@ -486,21 +490,10 @@ def raw_sql(
self._log(query)
return self.con.query(query, external_data=external_data, **kwargs)

def fetch_from_cursor(self, cursor, schema):
import pandas as pd

from ibis.formats.pandas import PandasData

df = pd.DataFrame.from_records(iter(cursor), columns=schema.names)
return PandasData.convert_table(df, schema)

def close(self) -> None:
"""Close ClickHouse connection."""
self.con.close()

def _fully_qualified_name(self, name: str, database: str | None) -> str:
return sg.table(name, db=database).sql(dialect="clickhouse")

def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema:
"""Return a Schema object for the indicated table and database.
@@ -673,61 +666,104 @@ def create_table(
Table
The new table
"""
tmp = "TEMPORARY " * temp
replace = "OR REPLACE " * overwrite

if temp and overwrite:
raise com.IbisInputError("Cannot specify both temp and overwrite")

if not temp:
table = self._fully_qualified_name(name, database)
else:
table = name
database = None
code = f"CREATE {replace}{tmp}TABLE {table}"
raise com.IbisInputError(
"Cannot specify both `temp=True` and `overwrite=True` for ClickHouse"
)

if obj is None and schema is None:
raise com.IbisError("The schema or obj parameter is required")
raise com.IbisError("The `schema` or `obj` parameter is required")

if obj is not None and not isinstance(obj, ir.Expr):
obj = ibis.memtable(obj, schema=schema)

if schema is None:
schema = obj.schema()

serialized_schema = ", ".join(
f"`{name}` {ClickhouseType.to_string(typ)}" for name, typ in schema.items()
this = sg.exp.Schema(
this=sg.table(name, db=database),
expressions=[
sg.exp.ColumnDef(
this=sg.to_identifier(name), kind=ClickhouseType.from_ibis(typ)
)
for name, typ in schema.items()
],
)

code += f" ({serialized_schema}) ENGINE = {engine}"

if order_by is not None:
code += f" ORDER BY {', '.join(util.promote_list(order_by))}"
elif engine == "MergeTree":
# empty tuple to indicate no specific order when engine is
# MergeTree
code += " ORDER BY tuple()"
properties = [
# the engine cannot be quoted, since clickhouse won't allow e.g.,
# "File(Native)"
sg.exp.EngineProperty(this=sg.to_identifier(engine, quoted=False))
]

if temp:
properties.append(sg.exp.TemporaryProperty())

if order_by is not None or engine == "MergeTree":
# engine == "MergeTree" requires an order by clause, which is the
# empty tuple if order_by is False-y
properties.append(
sg.exp.Order(
expressions=[
sg.exp.Ordered(
this=sg.exp.Tuple(
expressions=list(map(sg.column, order_by or ()))
)
)
]
)
)

if partition_by is not None:
code += f" PARTITION BY {', '.join(util.promote_list(partition_by))}"
properties.append(
sg.exp.PartitionedByProperty(
this=sg.exp.Schema(
expressions=list(map(sg.to_identifier, partition_by))
)
)
)

if sample_by is not None:
code += f" SAMPLE BY {sample_by}"
properties.append(
sg.exp.SampleProperty(
this=sg.exp.Tuple(expressions=list(map(sg.column, sample_by)))
)
)

if settings:
kvs = ", ".join(f"{name}={value!r}" for name, value in settings.items())
code += f" SETTINGS {kvs}"
properties.append(
sg.exp.SettingsProperty(
expressions=[
sg.exp.SetItem(
this=sg.exp.EQ(
this=sg.to_identifier(name),
expression=sg.exp.convert(value),
)
)
for name, value in settings.items()
]
)
)

external_tables = {}
expression = None

if obj is not None:
code += f" AS {self.compile(obj)}"
external_tables = self._collect_in_memory_tables(obj)
else:
external_tables = {}
expression = self._to_sqlglot(obj)
external_tables.update(self._collect_in_memory_tables(obj))

code = sg.exp.Create(
this=this,
kind="TABLE",
replace=overwrite,
expression=expression,
properties=sg.exp.Properties(expressions=properties),
)

external_data = self._normalize_external_tables(external_tables)

# create the table
self.con.raw_query(code, external_data=external_data)
sql = code.sql(self.name, pretty=True)
self.con.raw_query(sql, external_data=external_data)

return self.table(name, database=database)

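
A minimal standalone sketch (not part of this PR) of the kind of sqlglot expression create_table now builds instead of concatenating SQL strings; the table, column, and engine values are invented:

import sqlglot as sg

create = sg.exp.Create(
    this=sg.exp.Schema(
        this=sg.table("events"),
        expressions=[
            sg.exp.ColumnDef(
                this=sg.to_identifier("a"),
                kind=sg.exp.DataType.build("String", dialect="clickhouse"),
            )
        ],
    ),
    kind="TABLE",
    properties=sg.exp.Properties(
        expressions=[
            # the engine identifier stays unquoted, as in the code above
            sg.exp.EngineProperty(this=sg.to_identifier("MergeTree", quoted=False))
        ]
    ),
)
print(create.sql("clickhouse", pretty=True))
# roughly: CREATE TABLE events (a String) ENGINE=MergeTree
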
51 changes: 26 additions & 25 deletions ibis/backends/clickhouse/compiler/core.py
@@ -23,21 +23,18 @@

import sqlglot as sg

import ibis.expr.analysis as an
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.backends.clickhouse.compiler.relations import translate_rel
from ibis.backends.clickhouse.compiler.values import translate_val
from ibis.common.patterns import Call, _
from ibis.expr.analysis import c, p, x, y
from ibis.common.deferred import _
from ibis.expr.analysis import c, find_first_base_table, p, x, y
from ibis.expr.rewrites import rewrite_dropna, rewrite_fillna, rewrite_sample

if TYPE_CHECKING:
from collections.abc import Mapping


a = Call.namespace(an)


def _translate_node(node, **kwargs):
if isinstance(node, ops.Value):
return translate_val(node, **kwargs)
@@ -88,44 +85,48 @@ def fn(node, _, **kwargs):
# `translate_val` rule
params = {param.op(): value for param, value in params.items()}
replace_literals = p.ScalarParameter(dtype=x) >> (
lambda op, ctx: ops.Literal(value=params[op], dtype=ctx[x])
lambda _, x: ops.Literal(value=params[_], dtype=x)
)

# replace the right side of InColumn into a scalar subquery for sql
# backends
replace_in_column_with_table_array_view = p.InColumn(..., y) >> _.copy(
replace_in_column_with_table_array_view = p.InColumn(options=y) >> _.copy(
options=c.TableArrayView(
c.Selection(table=a.find_first_base_table(y), selections=(y,))
c.Selection(table=lambda _, y: find_first_base_table(y), selections=(y,))
),
)

# replace any checks against an empty right side of the IN operation with
# `False`
replace_empty_in_values_with_false = p.InValues(..., ()) >> c.Literal(
replace_empty_in_values_with_false = p.InValues(options=()) >> c.Literal(
False, dtype="bool"
)

# replace `NotExistsSubquery` with `Not(ExistsSubquery)`
#
# this allows to avoid having another rule to negate ExistsSubquery
replace_notexists_subquery_with_not_exists = p.NotExistsSubquery(...) >> c.Not(
c.ExistsSubquery(...)
)
# subtract one from one-based functions to convert to zero-based indexing
subtract_one_from_one_indexed_functions = (
p.WindowFunction(p.RankBase | p.NTile)
| p.StringFind
| p.FindInSet
| p.ArrayPosition
) >> c.Subtract(_, 1)

add_one_to_nth_value_input = p.NthValue >> _.copy(nth=c.Add(_.nth, 1))

# clickhouse-specific rewrite to turn notany/notall into equivalent
# already-defined operations
replace_notany_with_min_not = p.NotAny(x, where=y) >> c.Min(c.Not(x), where=y)
replace_notall_with_max_not = p.NotAll(x, where=y) >> c.Max(c.Not(x), where=y)
nullify_empty_string_results = (p.ExtractURLField | p.DayOfWeekName) >> c.NullIf(
_, ""
)

op = op.replace(
replace_literals
| replace_in_column_with_table_array_view
| replace_empty_in_values_with_false
| replace_notexists_subquery_with_not_exists
| replace_notany_with_min_not
| replace_notall_with_max_not
| subtract_one_from_one_indexed_functions
| add_one_to_nth_value_input
| nullify_empty_string_results
| rewrite_fillna
| rewrite_dropna
| rewrite_sample
)
# apply translate rules in topological order
results = op.map(fn, filter=(ops.TableNode, ops.Value))
node = results[op]
node = op.map(fn)[op]
return node.this if isinstance(node, sg.exp.Subquery) else node
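
A conceptual note (not part of this PR): ClickHouse's search functions are 1-based while ibis is 0-based, so the subtract_one_from_one_indexed_functions pattern above shifts the result once at the operation level rather than in every translator, e.g.

# ops.StringFind    -> locate(arg, substr) - 1
# ops.FindInSet     -> indexOf([...], needle) - 1
# ops.ArrayPosition -> indexOf(arg, other) - 1
# rank/dense_rank/ntile window results are shifted down by one as well,
# while ops.NthValue goes the other way and has 1 added to its input
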
31 changes: 2 additions & 29 deletions ibis/backends/clickhouse/compiler/relations.py
@@ -7,7 +7,7 @@

import ibis.common.exceptions as com
import ibis.expr.operations as ops
from ibis.backends.base.sqlglot import FALSE, NULL, STAR
from ibis.backends.base.sqlglot import STAR


@functools.singledispatch
@@ -28,7 +28,7 @@ def _physical_table(op: ops.PhysicalTable, **_):

@translate_rel.register
def _database_table(op: ops.DatabaseTable, *, name, namespace, **_):
return sg.table(name, db=namespace)
return sg.table(name, db=namespace.schema, catalog=namespace.database)


def replace_tables_with_star_selection(node, alias=None):
@@ -200,33 +200,6 @@ def _distinct(op: ops.Distinct, *, table, **_):
return sg.select(STAR).distinct().from_(table)


@translate_rel.register
def _dropna(op: ops.DropNa, *, table, how, subset, **_):
colnames = op.schema.names
alias = table.alias_or_name

if subset is None:
columns = [sg.column(name, table=alias) for name in colnames]
else:
columns = subset

if columns:
func = sg.and_ if how == "any" else sg.or_
predicate = func(*(sg.not_(col.is_(NULL)) for col in columns))
elif how == "all":
predicate = FALSE
else:
predicate = None

if predicate is None:
return table

try:
return table.where(predicate)
except AttributeError:
return sg.select(STAR).from_(table).where(predicate)


@translate_rel.register
def _sql_string_view(op: ops.SQLStringView, query: str, **_: Any):
table = sg.table(op.name)
148 changes: 48 additions & 100 deletions ibis/backends/clickhouse/compiler/values.py
@@ -2,38 +2,20 @@

import calendar
import functools
import math
import operator
from functools import partial
from typing import Any

import sqlglot as sg
from sqlglot.dialects.dialect import rename_func

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis import util
from ibis.backends.base.sqlglot import (
NULL,
STAR,
AggGen,
C,
F,
interval,
lit,
make_cast,
)
from ibis.backends.base.sqlglot import NULL, STAR, AggGen, C, F, interval, make_cast
from ibis.backends.clickhouse.datatypes import ClickhouseType

# TODO: This is a hack to get around the fact that sqlglot 17.8.6 is broken for
# ClickHouse's isNaN
sg.dialects.clickhouse.ClickHouse.Generator.TRANSFORMS.update(
{
sg.exp.IsNan: rename_func("isNaN"),
sg.exp.StartsWith: rename_func("startsWith"),
}
)


def _aggregate(funcname, *args, where):
has_filter = where is not None
@@ -224,9 +206,9 @@ def _string_find(op, *, arg, substr, start, end, **_):
raise com.UnsupportedOperationError("String find doesn't support end argument")

if start is not None:
return F.locate(arg, substr, start) - 1
return F.locate(arg, substr, start)

return F.locate(arg, substr) - 1
return F.locate(arg, substr)


@translate_val.register(ops.RegexSearch)
@@ -252,7 +234,7 @@ def _regex_extract(op, *, arg, pattern, index, **_):

@translate_val.register(ops.FindInSet)
def _index_of(op, *, needle, values, **_):
return F.indexOf(F.array(*values), needle) - 1
return F.indexOf(F.array(*values), needle)


@translate_val.register(ops.Round)
@@ -325,14 +307,14 @@ def _literal(op, *, value, dtype, **kw):
return NULL
return cast(NULL, dtype)
elif dtype.is_boolean():
return lit(bool(value))
return sg.exp.convert(bool(value))
elif dtype.is_inet():
v = str(value)
return F.toIPv6(v) if ":" in v else F.toIPv4(v)
elif dtype.is_string():
return lit(str(value).replace(r"\0", r"\\0"))
return sg.exp.convert(str(value).replace(r"\0", r"\\0"))
elif dtype.is_macaddr():
return lit(str(value))
return sg.exp.convert(str(value))
elif dtype.is_decimal():
precision = dtype.precision
if precision is None or not 1 <= precision <= 76:
@@ -350,10 +332,14 @@ def _literal(op, *, value, dtype, **kw):
type_name = F.toDecimal256
return type_name(value, dtype.scale)
elif dtype.is_numeric():
return lit(value)
if math.isnan(value):
return sg.exp.Literal(this="NaN", is_string=False)
elif math.isinf(value):
inf = sg.exp.Literal(this="inf", is_string=False)
return -inf if value < 0 else inf
return sg.exp.convert(value)
elif dtype.is_interval():
dtype = op.dtype
if dtype.unit.short in {"ms", "us", "ns"}:
if dtype.unit.short in ("ms", "us", "ns"):
raise com.UnsupportedOperationError(
"Clickhouse doesn't support subsecond interval resolutions"
)
@@ -393,7 +379,7 @@ def _literal(op, *, value, dtype, **kw):
values = []

for k, v in value.items():
keys.append(lit(k))
keys.append(sg.exp.convert(k))
values.append(
_literal(
ops.Literal(v, dtype=value_type), value=v, dtype=value_type, **kw
@@ -450,6 +436,16 @@ def _truncate(op, *, arg, unit, **_):
return converter(arg)


@translate_val.register(ops.TimestampBucket)
def _timestamp_bucket(op, *, arg, interval, offset, **_):
if offset is not None:
raise com.UnsupportedOperationError(
"Timestamp bucket with offset is not supported"
)

return F.toStartOfInterval(arg, interval)


@translate_val.register(ops.DateFromYMD)
def _date_from_ymd(op, *, year, month, day, **_):
return F.toDate(
@@ -578,12 +574,7 @@ def _clip(op, *, arg, lower, upper, **_):
def _struct_field(op, *, arg, field: str, **_):
arg_dtype = op.arg.dtype
idx = arg_dtype.names.index(field)
return cast(sg.exp.Dot(this=arg, expression=lit(idx + 1)), op.dtype)


@translate_val.register(ops.NthValue)
def _nth_value(op, *, arg, nth, **_):
return F.nth_value(arg, _parenthesize(op.nth, nth) + 1)
return cast(sg.exp.Dot(this=arg, expression=sg.exp.convert(idx + 1)), op.dtype)


@translate_val.register(ops.Repeat)
@@ -611,7 +602,8 @@ def _in_column(op, *, value, options, **_):
return value.isin(options.this if isinstance(options, sg.exp.Subquery) else options)


_NUM_WEEKDAYS = 7
_DAYS = calendar.day_name
_NUM_WEEKDAYS = len(_DAYS)


@translate_val.register(ops.DayOfWeekIndex)
@@ -632,15 +624,11 @@ def day_of_week_name(op, *, arg, **_):
#
# We test against 20 in CI, so we implement day_of_week_name as follows
num_weekdays = _NUM_WEEKDAYS
weekdays = range(num_weekdays)
base = (((F.toDayOfWeek(arg) - 1) % num_weekdays) + num_weekdays) % num_weekdays
return F.nullIf(
sg.exp.Case(
this=base,
ifs=[if_(day, calendar.day_name[day]) for day in weekdays],
default=lit(""),
),
"",
return sg.exp.Case(
this=base,
ifs=[if_(i, day) for i, day in enumerate(_DAYS)],
default=sg.exp.convert(""),
)


@@ -818,6 +806,16 @@ def formatter(op, *, left, right, **_):
ops.NTile: "ntile",
ops.ArrayIntersect: "arrayIntersect",
ops.ExtractEpochSeconds: "toRelativeSecondNum",
ops.NthValue: "nth_value",
ops.MinRank: "rank",
ops.DenseRank: "dense_rank",
ops.RowNumber: "row_number",
ops.ExtractProtocol: "protocol",
ops.ExtractAuthority: "netloc",
ops.ExtractHost: "domain",
ops.ExtractPath: "path",
ops.ExtractFragment: "fragment",
ops.ArrayPosition: "indexOf",
}


@@ -912,12 +910,8 @@ def _window_frame(op, *, group_by, order_by, start, end, max_lookback=None, **_)

@translate_val.register(ops.WindowFunction)
def _window(op: ops.WindowFunction, *, func, frame, **_: Any):
window = frame(this=func)

# preserve zero-based indexing
if isinstance(op.func, ops.RankBase):
return window - 1
return window
# frame is a partial call to sg.exp.Window
return frame(this=func)


def shift_like(op_class, func):
@@ -943,58 +937,17 @@ def formatter(op, *, arg, offset, default, **_):
shift_like(ops.Lead, F.leadInFrame)


@translate_val.register(ops.RowNumber)
def _row_number(op, **_):
return F.row_number()


@translate_val.register(ops.DenseRank)
def _dense_rank(op, **_):
return F.dense_rank()


@translate_val.register(ops.MinRank)
def _rank(op, **_):
return F.rank()


@translate_val.register(ops.ExtractProtocol)
def _extract_protocol(op, *, arg, **_):
return F.nullIf(F.protocol(arg), "")


@translate_val.register(ops.ExtractAuthority)
def _extract_authority(op, *, arg, **_):
return F.nullIf(F.netloc(arg), "")


@translate_val.register(ops.ExtractHost)
def _extract_host(op, *, arg, **_):
return F.nullIf(F.domain(arg), "")


@translate_val.register(ops.ExtractFile)
def _extract_file(op, *, arg, **_):
return F.nullIf(F.cutFragment(F.pathFull(arg)), "")


@translate_val.register(ops.ExtractPath)
def _extract_path(op, *, arg, **_):
return F.nullIf(F.path(arg), "")
return F.cutFragment(F.pathFull(arg))


@translate_val.register(ops.ExtractQuery)
def _extract_query(op, *, arg, key, **_):
if key is not None:
input = F.extractURLParameter(arg, key)
return F.extractURLParameter(arg, key)
else:
input = F.queryString(arg)
return F.nullIf(input, "")


@translate_val.register(ops.ExtractFragment)
def _extract_fragment(op, *, arg, **_):
return F.nullIf(F.fragment(arg), "")
return F.queryString(arg)


@translate_val.register(ops.ArrayStringJoin)
Expand All @@ -1019,11 +972,6 @@ def _array_filter(op, *, arg, param, body, **_):
return F.arrayFilter(func, arg)


@translate_val.register(ops.ArrayPosition)
def _array_position(op, *, arg, other, **_):
return F.indexOf(arg, other) - 1


@translate_val.register(ops.ArrayRemove)
def _array_remove(op, *, arg, other, **_):
x = sg.to_identifier("x")
@@ -1,2 +1,2 @@
SELECT
False
FALSE
@@ -1,2 +1,2 @@
SELECT
True
TRUE
15 changes: 15 additions & 0 deletions ibis/backends/clickhouse/tests/test_client.py
@@ -224,6 +224,21 @@ def test_create_table_data(con, data, engine, temp_table):
assert len(t.execute()) == 3


def test_create_table_with_properties(con, temp_table):
data = pd.DataFrame({"a": list("abcde" * 20), "b": [1, 2, 3, 4, 5] * 20})
n = len(data)
t = con.create_table(
temp_table,
data,
schema=ibis.schema(dict(a="string", b="!uint32")),
order_by=["a", "b"],
partition_by=["a"],
sample_by=["b"],
settings={"allow_nullable_key": "1"},
)
assert t.count().execute() == n


@pytest.mark.parametrize(
"engine",
[
5 changes: 0 additions & 5 deletions ibis/backends/conftest.py
@@ -4,7 +4,6 @@
import importlib
import importlib.metadata
import itertools
import sys
from functools import cache
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -332,10 +331,6 @@ def pytest_collection_modifyitems(session, config, items):
(
item,
[
pytest.mark.xfail(
sys.version_info >= (3, 11),
reason="PySpark doesn't support Python 3.11",
),
pytest.mark.xfail(
vparse(pd.__version__) >= vparse("2"),
reason="PySpark doesn't support pandas>=2",
61 changes: 60 additions & 1 deletion ibis/backends/dask/__init__.py
@@ -14,12 +14,14 @@
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis import util
from ibis.backends.dask.core import execute_and_reset
from ibis.backends.pandas import BasePandasBackend
from ibis.backends.pandas.core import _apply_schema
from ibis.formats.pandas import DaskData

if TYPE_CHECKING:
import pathlib
from collections.abc import Mapping, MutableMapping

# Make sure that the pandas backend options have been loaded
@@ -29,6 +31,7 @@
class Backend(BasePandasBackend):
name = "dask"
backend_table_type = dd.DataFrame
supports_in_memory_tables = False

def do_connect(
self,
@@ -103,7 +106,7 @@ def compile(
Returns
-------
dask.dataframe.core.DataFrame | dask.dataframe.core.Series | das.dataframe.core.Scalar
dask.dataframe.core.DataFrame | dask.dataframe.core.Series | dask.dataframe.core.Scalar
Dask graph.
"""
params = {
@@ -113,6 +116,62 @@ def compile(

return execute_and_reset(query.op(), params=params, **kwargs)

def read_csv(
self, source: str | pathlib.Path, table_name: str | None = None, **kwargs: Any
):
"""Register a CSV file as a table in the current session.
Parameters
----------
source
The data source. Can be a local or remote file, pathlike objects
also accepted.
table_name
An optional name to use for the created table. This defaults to
a generated name.
**kwargs
Additional keyword arguments passed to Pandas loading function.
See https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html
for more information.
Returns
-------
ir.Table
The just-registered table
"""
table_name = table_name or util.gen_name("read_csv")
df = dd.read_csv(source, **kwargs)
self.dictionary[table_name] = df
return self.table(table_name)

def read_parquet(
self, source: str | pathlib.Path, table_name: str | None = None, **kwargs: Any
):
"""Register a parquet file as a table in the current session.
Parameters
----------
source
The data source(s). May be a path to a file, an iterable of files,
or directory of parquet files.
table_name
An optional name to use for the created table. This defaults to
a generated name.
**kwargs
Additional keyword arguments passed to Pandas loading function.
See https://pandas.pydata.org/docs/reference/api/pandas.read_parquet.html
for more information.
Returns
-------
ir.Table
The just-registered table
"""
table_name = table_name or util.gen_name("read_parquet")
df = dd.read_parquet(source, **kwargs)
self.dictionary[table_name] = df
return self.table(table_name)

def table(self, name: str, schema: sch.Schema | None = None):
df = self.dictionary[name]
schema = schema or self.schemas.get(name, None)
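
A quick usage sketch (not part of this PR) for the new loaders; the file path and table name are hypothetical:

import ibis

con = ibis.dask.connect({})  # empty mapping of table name -> dask DataFrame
t = con.read_parquet("data/events.parquet", table_name="events")
print(t.schema())
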
33 changes: 0 additions & 33 deletions ibis/backends/dask/execution/aggregations.py
@@ -2,10 +2,7 @@
- ops.Aggregation
- ops.Any
- ops.NotAny
- ops.All
- ops.NotAll
"""

from __future__ import annotations
@@ -132,33 +129,3 @@ def execute_any_all_series_group_by(op, data, mask, aggcontext=None, **kwargs):
# here for future scaffolding.
result = aggcontext.agg(data, operator.methodcaller(name))
return result


@execute_node.register((ops.NotAny, ops.NotAll), dd.Series, (dd.Series, type(None)))
def execute_notany_series(op, data, mask, aggcontext=None, **kwargs):
if mask is not None:
data = data.loc[mask]

name = type(op).__name__[len("Not") :].lower()
if isinstance(aggcontext, (agg_ctx.Summarize, agg_ctx.Transform)):
result = ~aggcontext.agg(data, name)
else:
# Note this branch is not currently hit in the dask backend but is
# here for future scaffolding.
method = operator.methodcaller(name)
result = aggcontext.agg(data, lambda data: ~method(data))
return result


@execute_node.register((ops.NotAny, ops.NotAll), ddgb.SeriesGroupBy, type(None))
def execute_notany_series_group_by(op, data, mask, aggcontext=None, **kwargs):
name = type(op).__name__[len("Not") :].lower()
if isinstance(aggcontext, (agg_ctx.Summarize, agg_ctx.Transform)):
result = ~aggcontext.agg(data, name)
else:
# Note this branch is not currently hit in the dask backend but is
# here for future scaffolding.
method = operator.methodcaller(name)
result = aggcontext.agg(data, lambda data: ~method(data))

return result
5 changes: 5 additions & 0 deletions ibis/backends/dask/execution/generic.py
@@ -552,3 +552,8 @@ def execute_table_array_view(op, _, **kwargs):
# Need to compute dataframe in order to squeeze into a scalar
ddf = execute(op.table)
return ddf.compute().squeeze()


@execute_node.register(ops.Sample, dd.DataFrame, object, object)
def execute_sample(op, data, fraction, seed, **kwargs):
return data.sample(frac=fraction, random_state=seed)
4 changes: 1 addition & 3 deletions ibis/backends/dask/execution/util.py
@@ -272,9 +272,7 @@ def compute_sorted_frame(
new_columns[computed_sort_key] = temporary_column

result = df.assign(**new_columns)
result = result.sort_values(
computed_sort_keys, ascending=ascending, kind="mergesort"
)
result = result.sort_values(computed_sort_keys, ascending=ascending)
# TODO: we'll eventually need to return this frame with the temporary
# columns and drop them in the caller (maybe using post_execute?)
ngrouping_keys = len(group_by)
2 changes: 1 addition & 1 deletion ibis/backends/dask/execution/window.py
@@ -282,7 +282,7 @@ def execute_window_op(
if frame.group_by:
if frame.order_by:
raise NotImplementedError("Grouped and order windows not supported yet")
# TODO finish implementeing grouped/order windows.
# TODO finish implementing grouped/order windows.
else:
if len(grouping_keys) == 1 and isinstance(grouping_keys[0], dd.Series):
# Dask will raise an exception about not supporting multiple Series in group by key
4 changes: 2 additions & 2 deletions ibis/backends/dask/tests/execution/test_functions.py
@@ -146,7 +146,7 @@ def test_round_decimal_with_negative_places(t, df):

@pytest.mark.xfail(
raises=OperationNotDefinedError,
reason="TODO - arrays - #2553"
reason="TODO - arrays - #2553",
# Need an ops.MultiQuantile execution func that dispatches on ndarrays
)
@pytest.mark.parametrize(
@@ -203,7 +203,7 @@ def test_arraylike_functions_transform_errors(t, df, ibis_func, exc):

@pytest.mark.xfail(
raises=OperationNotDefinedError,
reason="TODO - arrays - #2553"
reason="TODO - arrays - #2553",
# Need an ops.MultiQuantile execution func that dispatches on ndarrays
)
def test_quantile_array_access(client, t, df):
4 changes: 1 addition & 3 deletions ibis/backends/dask/tests/execution/test_operations.py
@@ -427,9 +427,7 @@ def test_nullif_inf(npartitions):
expected = dd.from_pandas(
pd.Series([np.nan, 3.14, np.nan, 42.0], name="a"),
npartitions=npartitions,
).reset_index(
drop=True
) # match dask reset index behavior
).reset_index(drop=True) # match dask reset index behavior
tm.assert_series_equal(result.compute(), expected.compute(), check_index=False)


11 changes: 5 additions & 6 deletions ibis/backends/dask/tests/execution/test_window.py
@@ -200,8 +200,7 @@ def test_batting_avg_change_in_games_per_year(players, players_df):


@pytest.mark.xfail(
raises=NotImplementedError,
reason="Grouped and order windows not supported yet",
raises=AttributeError, reason="'Series' object has no attribute 'rank'"
)
def test_batting_most_hits(players, players_df):
expr = players.mutate(
@@ -237,7 +236,7 @@ def test_batting_specific_cumulative(batting, batting_df, op, sort_kind):
pandas_method = methodcaller(op)
expected = pandas_method(
batting_df[["G", "yearID"]]
.sort_values("yearID", kind=sort_kind)
.sort_values("yearID")
.G.rolling(len(batting_df), min_periods=1)
)
expected = expected.compute().sort_index().reset_index(drop=True)
@@ -253,7 +252,7 @@ def test_batting_cumulative(batting, batting_df, sort_kind):
columns = ["G", "yearID"]
more_values = (
batting_df[columns]
.sort_values("yearID", kind=sort_kind)
.sort_values("yearID")
.G.rolling(len(batting_df), min_periods=1)
.sum()
.astype("int64")
@@ -300,7 +299,7 @@ def test_batting_rolling(batting, batting_df, sort_kind):
columns = ["G", "yearID"]
more_values = (
batting_df[columns]
.sort_values("yearID", kind=sort_kind)
.sort_values("yearID")
.G.rolling(6, min_periods=1)
.sum()
.astype("int64")
@@ -401,7 +400,7 @@ def test_mutate_with_window_after_join(sort_kind, npartitions):
{
"dates": dd.concat([left_df.dates] * 3)
.compute()
.sort_values(kind=sort_kind)
.sort_values()
.reset_index(drop=True),
"ints": [0] * 3 + [1] * 3 + [2] * 3,
"strings": ["a"] * 3 + ["b"] * 3 + ["c"] * 3,
276 changes: 203 additions & 73 deletions ibis/backends/datafusion/__init__.py

Large diffs are not rendered by default.

1,049 changes: 0 additions & 1,049 deletions ibis/backends/datafusion/compiler.py

This file was deleted.

File renamed without changes.
128 changes: 128 additions & 0 deletions ibis/backends/datafusion/compiler/core.py
@@ -0,0 +1,128 @@
"""Ibis expression to sqlglot compiler.
The compiler is built with a few `singledispatch` functions:
1. `translate_rel` for compiling `ops.TableNode`s
1. `translate_val` for compiling `ops.Value`s
## `translate`
### Node Implementation
There's a single `ops.Node` implementation for `ops.TableNode`s instances.
This function compiles each node in topological order. The topological sorting,
result caching, and iteration are all handled by
`ibis.expr.operations.core.Node.map`.
"""

from __future__ import annotations

import itertools
from typing import TYPE_CHECKING, Any

import sqlglot as sg

import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.backends.datafusion.compiler.relations import translate_rel
from ibis.backends.datafusion.compiler.values import translate_val
from ibis.common.deferred import _
from ibis.expr.analysis import c, find_first_base_table, p, x, y
from ibis.expr.rewrites import rewrite_dropna, rewrite_fillna, rewrite_sample

if TYPE_CHECKING:
from collections.abc import Mapping


def _translate_node(node, **kwargs):
if isinstance(node, ops.Value):
return translate_val(node, **kwargs)
assert isinstance(node, ops.TableNode)
return translate_rel(node, **kwargs)


def translate(op: ops.TableNode, params: Mapping[ir.Value, Any]) -> sg.exp.Expression:
"""Translate an ibis operation to a sqlglot expression.
Parameters
----------
op
An ibis `TableNode`
params
A mapping of expressions to concrete values
Returns
-------
sqlglot.expressions.Expression
A sqlglot expression
"""

gen_alias_index = itertools.count()

def fn(node, _, **kwargs):
result = _translate_node(node, **kwargs)

# don't alias root nodes or value ops
if node is op or isinstance(node, ops.Value):
return result

assert isinstance(node, ops.TableNode)

alias_index = next(gen_alias_index)
alias = f"t{alias_index:d}"

try:
return result.subquery(alias)
except AttributeError:
return sg.alias(result, alias)

# substitute parameters immediately to avoid having to define a
# ScalarParameter translation rule
#
# this lets us avoid threading `params` through every `translate_val` call
# only to be used in the one place it would be needed: the ScalarParameter
# `translate_val` rule
params = {param.op(): value for param, value in params.items()}
replace_literals = p.ScalarParameter(dtype=x) >> (
lambda _, x: ops.Literal(value=params[_], dtype=x)
)

# replace the right side of InColumn into a scalar subquery for sql
# backends
replace_in_column_with_table_array_view = p.InColumn(..., y) >> _.copy(
options=c.TableArrayView(
c.Selection(table=lambda _, y: find_first_base_table(y), selections=(y,))
),
)

# replace any checks against an empty right side of the IN operation with
# `False`
replace_empty_in_values_with_false = p.InValues(..., ()) >> c.Literal(
False, dtype="bool"
)

# subtract one from one-based functions to convert to zero-based indexing
subtract_one_from_one_indexed_functions = (
p.WindowFunction(p.RankBase | p.NTile)
| p.StringFind
| p.FindInSet
| p.ArrayPosition
) >> c.Subtract(_, 1)

add_one_to_nth_value_input = p.NthValue >> _.copy(nth=c.Add(_.nth, 1))

op = op.replace(
replace_literals
| replace_in_column_with_table_array_view
| replace_empty_in_values_with_false
| subtract_one_from_one_indexed_functions
| add_one_to_nth_value_input
| rewrite_fillna
| rewrite_dropna
| rewrite_sample
)

# apply translate rules in topological order
node = op.map(fn)[op]
return node.this if isinstance(node, sg.exp.Subquery) else node
187 changes: 187 additions & 0 deletions ibis/backends/datafusion/compiler/relations.py
@@ -0,0 +1,187 @@
from __future__ import annotations

import functools

import sqlglot as sg

import ibis.common.exceptions as com
import ibis.expr.operations as ops
from ibis.backends.base.sqlglot import STAR


@functools.singledispatch
def translate_rel(op, **_):
"""Translate a value expression into sqlglot."""
raise com.OperationNotDefinedError(f"No translation rule for {type(op)}")


@translate_rel.register(ops.DummyTable)
def dummy_table(op, *, values, **_):
return sg.select(*values)


@translate_rel.register
def _physical_table(op: ops.PhysicalTable, **_):
return sg.table(op.name)


@translate_rel.register(ops.DatabaseTable)
def table(op, *, name, namespace, **_):
return sg.table(name, db=namespace.schema, catalog=namespace.database)


@translate_rel.register(ops.SelfReference)
def _self_ref(op, *, table, **_):
return sg.alias(table, op.name)


_JOIN_TYPES = {
ops.InnerJoin: "inner",
ops.LeftJoin: "left",
ops.RightJoin: "right",
ops.OuterJoin: "full",
ops.LeftAntiJoin: "left anti",
ops.LeftSemiJoin: "left semi",
}


@translate_rel.register
def _join(op: ops.Join, *, left, right, predicates, **_):
on = sg.and_(*predicates) if predicates else None
join_type = _JOIN_TYPES[type(op)]
try:
return left.join(right, join_type=join_type, on=on)
except AttributeError:
select_args = [f"{left.alias_or_name}.*"]

# select from both the left and right side of the join if the join
# is not a filtering join (semi join or anti join); filtering joins
# only return the left side columns
if not isinstance(op, (ops.LeftSemiJoin, ops.LeftAntiJoin)):
select_args.append(f"{right.alias_or_name}.*")
return (
sg.select(*select_args).from_(left).join(right, join_type=join_type, on=on)
)
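
# Rough sketch of the fallback branch above (sqlglot only; table and column
# names are made up): when `left` does not expose a `.join` method, the join is
# rebuilt as an explicit SELECT over both sides.
#
#   left, right = sg.table("t0"), sg.table("t1")
#   on = sg.condition("t0.a = t1.a")
#   sg.select("t0.*", "t1.*").from_(left).join(right, join_type="inner", on=on).sql()
#   # -> roughly: SELECT t0.*, t1.* FROM t0 INNER JOIN t1 ON t0.a = t1.a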


def replace_tables_with_star_selection(node, alias=None):
if isinstance(node, (sg.exp.Subquery, sg.exp.Table, sg.exp.CTE)):
return sg.exp.Column(
this=STAR,
table=sg.to_identifier(alias if alias is not None else node.alias_or_name),
)
return node


@translate_rel.register
def _selection(op: ops.Selection, *, table, selections, predicates, sort_keys, **_):
# needs_alias should never be true here explicitly, but it may get
# passed via a (recursive) call to translate_val
if isinstance(op.table, ops.Join) and not isinstance(
op.table, (ops.LeftSemiJoin, ops.LeftAntiJoin)
):
args = table.this.args
from_ = args["from"]
(join,) = args["joins"]
else:
from_ = join = None

alias = table.alias_or_name
selections = tuple(
replace_tables_with_star_selection(
node,
# replace the table name with the alias if the table is **not** a
# join, because we may be selecting from a subquery or an aliased
# table; otherwise we'll select from the _unaliased_ table or the
# _child_ table, which may have a different alias than the one we
# generated for the input table
alias if from_ is None and join is None else None,
)
for node in selections
) or (STAR,)

sel = sg.select(*selections).from_(from_ if from_ is not None else table)

if join is not None:
sel = sel.join(join)

if predicates:
if join is not None:
sel = sg.select(STAR).from_(sel.subquery(alias))
sel = sel.where(*predicates)

if sort_keys:
sel = sel.order_by(*sort_keys)

return sel
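
# Reduced illustration of the non-join path (sqlglot only; names are made up):
# star-select from the aliased table, then apply filters and ordering.
#
#   sel = sg.select("t0.*").from_("t0").where("x > 1").order_by("y")
#   sel.sql()
#   # -> roughly: SELECT t0.* FROM t0 WHERE x > 1 ORDER BY y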


@translate_rel.register
def _limit(op: ops.Limit, *, table, n, offset, **_):
result = sg.select(STAR).from_(table)

if n is not None:
if not isinstance(n, int):
limit = sg.select(n).from_(table).subquery()
else:
limit = n
result = result.limit(limit)

if not isinstance(offset, int):
return result.offset(
sg.select(offset).from_(table).subquery().sql("clickhouse")
)

return result.offset(offset) if offset != 0 else result
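
# Reduced illustration of the integer fast path (sqlglot only): literal limit
# and offset values are attached directly, while non-literal ones are compiled
# as scalar subqueries above.
#
#   sg.select(STAR).from_("t0").limit(5).offset(10).sql()
#   # -> roughly: SELECT * FROM t0 LIMIT 5 OFFSET 10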


@translate_rel.register(ops.Aggregation)
def _aggregation(
op: ops.Aggregation, *, table, metrics, by, having, predicates, sort_keys, **_
):
selections = (by + metrics) or (STAR,)
sel = sg.select(*selections).from_(table)

if by:
sel = sel.group_by(
*(key.this if isinstance(key, sg.exp.Alias) else key for key in by)
)

if predicates:
sel = sel.where(*predicates)

if having:
sel = sel.having(*having)

if sort_keys:
sel = sel.order_by(*sort_keys)

return sel


_SET_OP_FUNC = {
ops.Union: sg.union,
ops.Intersection: sg.intersect,
ops.Difference: sg.except_,
}


@translate_rel.register
def _set_op(op: ops.SetOp, *, left, right, distinct: bool = False, **_):
if isinstance(left, sg.exp.Table):
left = sg.select(STAR).from_(left)

if isinstance(right, sg.exp.Table):
right = sg.select(STAR).from_(right)

func = _SET_OP_FUNC[type(op)]

left = left.args.get("this", left)
right = right.args.get("this", right)

return func(left, right, distinct=distinct)
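
# Reduced illustration (sqlglot only): bare tables are wrapped in SELECT * so a
# set operation always combines two SELECTs.
#
#   lhs = sg.select(STAR).from_("t0")
#   rhs = sg.select(STAR).from_("t1")
#   sg.union(lhs, rhs, distinct=False).sql()
#   # -> roughly: SELECT * FROM t0 UNION ALL SELECT * FROM t1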


@translate_rel.register
def _distinct(op: ops.Distinct, *, table, **_):
return sg.select(STAR).distinct().from_(table)
807 changes: 807 additions & 0 deletions ibis/backends/datafusion/compiler/values.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion ibis/backends/datafusion/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ibis
from ibis.backends.conftest import TEST_TABLES
from ibis.backends.tests.base import BackendTest, RoundAwayFromZero
from ibis.backends.tests.data import array_types


class TestConf(BackendTest, RoundAwayFromZero):
Expand All @@ -15,7 +16,7 @@ class TestConf(BackendTest, RoundAwayFromZero):
# returned_timestamp_unit = 'ns'
supports_structs = False
supports_json = False
supports_arrays = False
supports_arrays = True
stateful = False
deps = ("datafusion",)

Expand All @@ -24,6 +25,7 @@ def _load_data(self, **_: Any) -> None:
for table_name in TEST_TABLES:
path = self.data_dir / "parquet" / f"{table_name}.parquet"
con.register(path, table_name=table_name)
con.register(array_types, table_name="array_types")

@staticmethod
def connect(*, tmpdir, worker_id, **kw):
Expand Down
43 changes: 43 additions & 0 deletions ibis/backends/datafusion/tests/test_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

import pytest
from datafusion import (
SessionContext,
)

import ibis
from ibis.backends.conftest import TEST_TABLES


@pytest.fixture
def name_to_path(data_dir):
return {
table_name: data_dir / "parquet" / f"{table_name}.parquet"
for table_name in TEST_TABLES
}


def test_none_config():
config = None
conn = ibis.datafusion.connect(config)
assert conn.list_tables() == []


def test_str_config(name_to_path):
config = {name: str(path) for name, path in name_to_path.items()}
conn = ibis.datafusion.connect(config)
assert sorted(conn.list_tables()) == sorted(name_to_path)


def test_path_config(name_to_path):
config = name_to_path
conn = ibis.datafusion.connect(config)
assert sorted(conn.list_tables()) == sorted(name_to_path)


def test_context_config(name_to_path):
ctx = SessionContext()
for name, path in name_to_path.items():
ctx.register_parquet(name, str(path))
conn = ibis.datafusion.connect(ctx)
assert sorted(conn.list_tables()) == sorted(name_to_path)
8 changes: 8 additions & 0 deletions ibis/backends/datafusion/tests/test_string.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import annotations

import ibis


def test_string_length(con):
t = ibis.memtable({"s": ["aaa", "a", "aa"]})
assert con.execute(t.s.length()).gt(0).all()
38 changes: 38 additions & 0 deletions ibis/backends/datafusion/tests/test_temporal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from __future__ import annotations

from operator import methodcaller

import pytest
from pytest import param

import ibis


@pytest.mark.parametrize(
("func", "expected"),
[
param(
methodcaller("hour"),
14,
id="hour",
),
param(
methodcaller("minute"),
48,
id="minute",
),
param(
methodcaller("second"),
5,
id="second",
),
param(
methodcaller("millisecond"),
359,
id="millisecond",
),
],
)
def test_time_extract_literal(con, func, expected):
value = ibis.time("14:48:05.359")
assert con.execute(func(value).name("tmp")) == expected
5 changes: 1 addition & 4 deletions ibis/backends/datafusion/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pandas.testing as tm
import pytest

import ibis.common.exceptions as exc
import ibis.expr.datatypes as dt
import ibis.expr.types as ir
from ibis import udf
Expand Down Expand Up @@ -75,6 +74,4 @@ def test_builtin_agg_udf_filtered(con):
def median(a: float, where: bool = True) -> float:
"""Median of a column."""

expr = median(con.tables.batting.G)
with pytest.raises(exc.OperationNotDefinedError, match="No translation rule for"):
con.execute(expr)
median(con.tables.batting.G).execute()
115 changes: 115 additions & 0 deletions ibis/backends/datafusion/udfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from __future__ import annotations

import itertools
from urllib.parse import parse_qs, urlsplit

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow_hotfix # noqa: F401

import ibis.expr.datatypes as dt # noqa: TCH001


def _extract_epoch_seconds(array) -> dt.int32:
return pc.cast(pc.divide(pc.cast(array, pa.int64()), 1_000_000), pa.int32())


def extract_epoch_seconds_date(array: dt.date) -> dt.int32:
return _extract_epoch_seconds(array)


def extract_epoch_seconds_timestamp(array: dt.Timestamp(scale=6)) -> dt.int32:
return _extract_epoch_seconds(array)


def _extract_second(array):
return pc.cast(pc.second(array), pa.int32())


def extract_second_timestamp(array: dt.Timestamp(scale=9)) -> dt.int32:
return _extract_second(array)


def extract_second_time(array: dt.time) -> dt.int32:
return _extract_second(array)


def _extract_millisecond(array) -> dt.int32:
return pc.cast(pc.millisecond(array), pa.int32())


def extract_millisecond_timestamp(array: dt.Timestamp(scale=9)) -> dt.int32:
return _extract_millisecond(array)


def extract_millisecond_time(array: dt.time) -> dt.int32:
return _extract_millisecond(array)


def extract_microsecond(array: dt.Timestamp(scale=9)) -> dt.int32:
arr = pc.multiply(pc.millisecond(array), 1000)
return pc.cast(pc.add(pc.microsecond(array), arr), pa.int32())


def _extract_query_arrow(
arr: pa.StringArray, *, param: str | None = None
) -> pa.StringArray:
if param is None:

def _extract_query(url, param):
return urlsplit(url).query

params = itertools.repeat(None)
else:

def _extract_query(url, param):
query = urlsplit(url).query
value = parse_qs(query)[param]
return value if len(value) > 1 else value[0]

params = param.to_pylist()

return pa.array(map(_extract_query, arr.to_pylist(), params))


def extract_query(array: str) -> str:
return _extract_query_arrow(array)


def extract_query_param(array: str, param: str) -> str:
return _extract_query_arrow(array, param=param)


def extract_user_info(arr: str) -> str:
def _extract_user_info(url):
url_parts = urlsplit(url)
username = url_parts.username or ""
password = url_parts.password or ""
return f"{username}:{password}"

return pa.array(map(_extract_user_info, arr.to_pylist()))


def extract_url_field(arr: str, field: str) -> str:
field = field.to_pylist()[0]
return pa.array(getattr(url, field, "") for url in map(urlsplit, arr.to_pylist()))


def sign(arr: dt.float64) -> dt.float64:
return pc.sign(arr)


def _extract_minute(array) -> dt.int32:
return pc.cast(pc.minute(array), pa.int32())


def extract_minute_time(array: dt.time) -> dt.int32:
return _extract_minute(array)


def extract_minute_timestamp(array: dt.Timestamp(scale=9)) -> dt.int32:
return _extract_minute(array)


def extract_hour_time(array: dt.time) -> dt.int32:
return pc.cast(pc.hour(array), pa.int32())
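
# Minimal sanity check (illustrative, not part of this diff): these helpers
# operate on plain pyarrow arrays, so they can be exercised without a running
# DataFusion session.
#
#   import datetime
#
#   hours = extract_hour_time(pa.array([datetime.time(14, 48, 5)]))
#   assert hours.to_pylist() == [14]
#
#   urls = pa.array(["https://example.com/?q=1"])
#   assert extract_query(urls).to_pylist() == ["q=1"]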
6 changes: 2 additions & 4 deletions ibis/backends/druid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _has_table(self, connection, table_name: str, schema) -> bool:
return bool(connection.execute(query).scalar())

def _get_sqla_table(
self, name: str, schema: str | None = None, autoload: bool = True, **kwargs: Any
self, name: str, autoload: bool = True, **kwargs: Any
) -> sa.Table:
with warnings.catch_warnings():
warnings.filterwarnings(
Expand All @@ -136,6 +136,4 @@ def _get_sqla_table(
),
category=sa.exc.SAWarning,
)
return super()._get_sqla_table(
name, schema=schema, autoload=autoload, **kwargs
)
return super()._get_sqla_table(name, autoload=autoload, **kwargs)
2 changes: 2 additions & 0 deletions ibis/backends/druid/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ibis.backends.druid.datatypes as ddt
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.druid.registry import operation_registry
from ibis.expr.rewrites import rewrite_sample


class DruidExprTranslator(AlchemyExprTranslator):
Expand All @@ -29,3 +30,4 @@ def translate(self, op):
class DruidCompiler(AlchemyCompiler):
translator_class = DruidExprTranslator
null_limit = sa.literal_column("ALL")
rewrites = AlchemyCompiler.rewrites | rewrite_sample
294 changes: 195 additions & 99 deletions ibis/backends/duckdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
import os
import warnings
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
)
from typing import TYPE_CHECKING, Any

import duckdb
import pyarrow as pa
import pyarrow_hotfix # noqa: F401
import sqlalchemy as sa
import sqlglot as sg
import toolz

import ibis.common.exceptions as exc
Expand All @@ -24,7 +23,8 @@
import ibis.expr.types as ir
from ibis import util
from ibis.backends.base import CanCreateSchema
from ibis.backends.base.sql.alchemy import BaseAlchemyBackend
from ibis.backends.base.sql.alchemy import AlchemyCrossSchemaBackend
from ibis.backends.base.sqlglot import C, F
from ibis.backends.duckdb.compiler import DuckDBSQLCompiler
from ibis.backends.duckdb.datatypes import DuckDBType
from ibis.expr.operations.relations import PandasDataFrameProxy
Expand All @@ -36,6 +36,7 @@

import pandas as pd
import torch
from fsspec import AbstractFileSystem


def normalize_filenames(source_list):
Expand Down Expand Up @@ -67,11 +68,41 @@ def _format_kwargs(kwargs: Mapping[str, Any]):
}


class Backend(BaseAlchemyBackend, CanCreateSchema):
class _Settings:
def __init__(self, con):
self.con = con

def __getitem__(self, key):
try:
with self.con.begin() as con:
return con.exec_driver_sql(
f"select value from duckdb_settings() where name = '{key}'"
).one()
except sa.exc.NoResultFound:
raise KeyError(key)

def __setitem__(self, key, value):
with self.con.begin() as con:
con.exec_driver_sql(f"SET {key}='{value}'")

def __repr__(self):
with self.con.begin() as con:
kv = con.exec_driver_sql(
"select map(array_agg(name), array_agg(value)) from duckdb_settings()"
).scalar()

return repr(dict(zip(kv["key"], kv["value"])))
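
# Illustrative use of the settings mapping defined above (the connection here
# is hypothetical); reads and writes round-trip through duckdb_settings():
#
#   con = ibis.duckdb.connect()
#   con.settings["null_order"] = "nulls_first"
#   con.settings["null_order"]  # -> the stored value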


class Backend(AlchemyCrossSchemaBackend, CanCreateSchema):
name = "duckdb"
compiler = DuckDBSQLCompiler
supports_create_or_replace = True

@property
def settings(self) -> _Settings:
return _Settings(self)

@property
def current_database(self) -> str:
return self._scalar_query(sa.select(sa.func.current_database()))
Expand Down Expand Up @@ -234,10 +265,13 @@ def configure_connection(dbapi_connection, connection_record):
with contextlib.suppress(duckdb.InvalidInputException):
duckdb.execute("SELECT ?", (1,))

engine.dialect._backslash_escapes = False
super().do_connect(engine)

@staticmethod
def _sa_load_extensions(dbapi_con, extensions):
def _sa_load_extensions(
dbapi_con, extensions: list[str], force_install: bool = False
) -> None:
query = """
WITH exts AS (
SELECT extension_name AS name, aliases FROM duckdb_extensions()
Expand All @@ -250,22 +284,28 @@ def _sa_load_extensions(dbapi_con, extensions):
# Install and load all other extensions
todo = set(extensions).difference(installed)
for extension in todo:
dbapi_con.install_extension(extension)
dbapi_con.install_extension(extension, force_install=force_install)
dbapi_con.load_extension(extension)

def _load_extensions(self, extensions):
def _load_extensions(
self, extensions: list[str], force_install: bool = False
) -> None:
with self.begin() as con:
self._sa_load_extensions(con.connection, extensions)
self._sa_load_extensions(
con.connection, extensions, force_install=force_install
)

def load_extension(self, extension: str) -> None:
def load_extension(self, extension: str, force_install: bool = False) -> None:
"""Install and load a duckdb extension by name or path.
Parameters
----------
extension
The extension name or path.
force_install
Force reinstallation of the extension.
"""
self._load_extensions([extension])
self._load_extensions([extension], force_install=force_install)
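
# Example (hypothetical extension name; installing may require network access):
#
#   con = ibis.duckdb.connect()
#   con.load_extension("httpfs")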

def create_schema(
self, name: str, database: str | None = None, force: bool = False
Expand Down Expand Up @@ -454,85 +494,6 @@ def read_csv(
con.exec_driver_sql(view)
return self.table(table_name)

def _get_sqla_table(
self,
name: str,
schema: str | None = None,
database: str | None = None,
**_: Any,
) -> sa.Table:
if schema is None:
schema = self.current_schema
*db, schema = schema.split(".")
db = "".join(db) or database
ident = ".".join(
map(
self._quote,
filter(None, (db if db != self.current_database else None, schema)),
)
)

s = sa.table(
"columns",
sa.column("table_catalog", sa.TEXT()),
sa.column("table_schema", sa.TEXT()),
sa.column("table_name", sa.TEXT()),
sa.column("column_name", sa.TEXT()),
sa.column("data_type", sa.TEXT()),
sa.column("is_nullable", sa.TEXT()),
sa.column("ordinal_position", sa.INTEGER()),
schema="information_schema",
)

where = s.c.table_name == name

if db:
where &= s.c.table_catalog == db

if schema:
where &= s.c.table_schema == schema

query = (
sa.select(
s.c.column_name,
s.c.data_type,
(s.c.is_nullable == "YES").label("nullable"),
)
.where(where)
.order_by(sa.asc(s.c.ordinal_position))
)

with self.begin() as con:
# fetch metadata with pyarrow, it's much faster for wide tables
meta = con.execute(query).cursor.fetch_arrow_table()

if not meta:
raise sa.exc.NoSuchTableError(name)

names = meta["column_name"].to_pylist()
types = meta["data_type"].to_pylist()
nullables = meta["nullable"].to_pylist()

ibis_schema = sch.Schema(
{
name: DuckDBType.from_string(typ, nullable=nullable)
for name, typ, nullable in zip(names, types, nullables)
}
)
columns = self._columns_from_schema(name, ibis_schema)
return sa.table(name, *columns, schema=ident)

def drop_table(
self, name: str, database: str | None = None, force: bool = False
) -> None:
name = self._quote(name)
# TODO: handle database quoting
if database is not None:
name = f"{database}.{name}"
drop_stmt = "DROP TABLE" + (" IF EXISTS" * force) + f" {name}"
with self.begin() as con:
con.exec_driver_sql(drop_stmt)

def read_parquet(
self,
source_list: str | Iterable[str],
Expand Down Expand Up @@ -684,14 +645,73 @@ def read_delta(
delta_table.to_pyarrow_dataset(), table_name=table_name
)

def list_tables(self, like=None, database=None):
tables = self.inspector.get_table_names(schema=database)
views = self.inspector.get_view_names(schema=database)
# workaround for GH5503
temp_views = self.inspector.get_view_names(
schema="temp" if database is None else database
def list_tables(
self,
like: str | None = None,
database: str | None = None,
schema: str | None = None,
) -> list[str]:
"""List tables and views.
Parameters
----------
like
Regex to filter by table/view name.
database
Database name. If not passed, uses the current database. Only
supported with MotherDuck.
schema
Schema name. If not passed, uses the current schema.
Returns
-------
list[str]
List of table and view names.
Examples
--------
>>> import ibis
>>> con = ibis.duckdb.connect()
>>> foo = con.create_table("foo", schema=ibis.schema(dict(a="int")))
>>> con.list_tables()
['foo']
>>> bar = con.create_view("bar", foo)
>>> con.list_tables()
['bar', 'foo']
>>> con.create_schema("my_schema")
>>> con.list_tables(schema="my_schema")
[]
>>> with con.begin() as c:
... c.exec_driver_sql(
... "CREATE TABLE my_schema.baz (a INTEGER)"
... ) # doctest: +ELLIPSIS
...
<...>
>>> con.list_tables(schema="my_schema")
['baz']
"""
database = (
F.current_database() if database is None else sg.exp.convert(database)
)
schema = F.current_schema() if schema is None else sg.exp.convert(schema)

sql = (
sg.select(C.table_name)
.from_(sg.table("tables", db="information_schema"))
.distinct()
.where(
C.table_catalog.eq(database).or_(
C.table_catalog.eq(sg.exp.convert("temp"))
),
C.table_schema.eq(schema),
)
.sql(self.name, pretty=True)
)
return self._filter_with_like(tables + views + temp_views, like)

with self.begin() as con:
out = con.exec_driver_sql(sql).cursor.fetch_arrow_table()

return self._filter_with_like(out["table_name"].to_pylist(), like)

def read_postgres(
self, uri: str, table_name: str | None = None, schema: str = "public"
Expand Down Expand Up @@ -781,6 +801,44 @@ def read_sqlite(self, path: str | Path, table_name: str | None = None) -> ir.Tab

return self.table(table_name)

def attach(
self, path: str | Path, name: str | None = None, read_only: bool = False
) -> None:
"""Attach another DuckDB database to the current DuckDB session.
Parameters
----------
path
Path to the database to attach.
name
Name to attach the database as. Defaults to the basename of `path`.
read_only
Whether to attach the database as read-only.
"""
code = f"ATTACH '{path}'"

if name is not None:
name = sg.to_identifier(name).sql(self.name)
code += f" AS {name}"

if read_only:
code += " (READ_ONLY)"

with self.begin() as con:
con.exec_driver_sql(code)

def detach(self, name: str) -> None:
"""Detach a database from the current DuckDB session.
Parameters
----------
name
The name of the database to detach.
"""
name = sg.to_identifier(name).sql(self.name)
with self.begin() as con:
con.exec_driver_sql(f"DETACH {name}")

def attach_sqlite(
self, path: str | Path, overwrite: bool = False, all_varchar: bool = False
) -> None:
Expand Down Expand Up @@ -820,6 +878,44 @@ def attach_sqlite(
con.execute(sa.text(f"SET GLOBAL sqlite_all_varchar={all_varchar}"))
con.execute(sa.text(f"CALL sqlite_attach('{path}', overwrite={overwrite})"))

def register_filesystem(self, filesystem: AbstractFileSystem):
"""Register an `fsspec` filesystem object with DuckDB.
This allow a user to read from any `fsspec` compatible filesystem using
`read_csv`, `read_parquet`, `read_json`, etc.
::: {.callout-note}
Creating an `fsspec` filesystem requires that the corresponding
backend-specific `fsspec` helper library is installed.
e.g. to connect to Google Cloud Storage, `gcsfs` must be installed.
:::
Parameters
----------
filesystem
The fsspec filesystem object to register with DuckDB.
See https://duckdb.org/docs/guides/python/filesystems for details.
Examples
--------
>>> import ibis
>>> import fsspec
>>> gcs = fsspec.filesystem("gcs")
>>> con = ibis.duckdb.connect()
>>> con.register_filesystem(gcs)
>>> t = con.read_csv(
... "gcs://ibis-examples/data/band_members.csv.gz",
... table_name="band_members",
... )
DatabaseTable: band_members
name string
band string
"""
with self.begin() as con:
con.connection.register_filesystem(filesystem)

def _run_pre_execute_hooks(self, expr: ir.Expr) -> None:
# Warn for any tables depending on RecordBatchReaders that have already
# started being consumed.
Expand Down Expand Up @@ -1123,7 +1219,7 @@ def _get_temp_view_definition(
def _register_udfs(self, expr: ir.Expr) -> None:
import ibis.expr.operations as ops

with self.begin() as con:
with self.con.connect() as con:
for udf_node in expr.op().find(ops.ScalarUDF):
compile_func = getattr(
self, f"_compile_{udf_node.__input_type__.name.lower()}_udf"
Expand Down
16 changes: 14 additions & 2 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ibis.backends.base.sql.alchemy.datatypes as sat
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.base.sql.alchemy.query_builder import _AlchemyTableSetFormatter
from ibis.backends.duckdb.datatypes import DuckDBType
from ibis.backends.duckdb.registry import operation_registry

Expand Down Expand Up @@ -55,13 +56,24 @@ def compile_array(element, compiler, **kw):

@rewrites(ops.Any)
@rewrites(ops.All)
@rewrites(ops.NotAny)
@rewrites(ops.NotAll)
@rewrites(ops.StringContains)
def _no_op(expr):
return expr


class DuckDBTableSetFormatter(_AlchemyTableSetFormatter):
def _format_sample(self, op, table):
if op.method == "row":
method = sa.func.bernoulli
else:
method = sa.func.system
return table.tablesample(
sampling=method(sa.literal_column(f"{op.fraction * 100} PERCENT")),
seed=(None if op.seed is None else sa.literal_column(str(op.seed))),
)
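
# For reference (illustrative rendering, not verified output): an expression
# like t.sample(0.5, method="row", seed=42) should produce a FROM clause along
# the lines of
#   FROM t TABLESAMPLE bernoulli(50.0 PERCENT) REPEATABLE (42)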


class DuckDBSQLCompiler(AlchemyCompiler):
cheap_in_memory_tables = True
translator_class = DuckDBSQLExprTranslator
table_set_formatter_class = DuckDBTableSetFormatter
24 changes: 19 additions & 5 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,17 @@ def _timestamp_from_unix(t, op):
raise UnsupportedOperationError(f"{unit!r} unit is not supported!")


def _timestamp_bucket(t, op):
arg = t.translate(op.arg)
interval = t.translate(op.interval)

origin = sa.literal_column("'epoch'::TIMESTAMP")

if op.offset is not None:
origin += t.translate(op.offset)
return sa.func.time_bucket(interval, arg, origin)
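
# Roughly the SQL this builds (illustrative): for a five-minute bucket with a
# two-minute offset, something like
#   time_bucket(INTERVAL '5 minutes', ts, 'epoch'::TIMESTAMP + INTERVAL '2 minutes')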


class struct_pack(GenericFunction):
def __init__(self, values: Mapping[str, Any], *, type: StructType) -> None:
super().__init__()
Expand Down Expand Up @@ -331,6 +342,11 @@ def _try_cast(t, op):
)


def _to_json_collection(t, op):
typ = t.get_sqla_type(op.dtype)
return try_cast(t.translate(op.arg), typ, type_=typ)


operation_registry.update(
{
ops.ArrayColumn: (
Expand Down Expand Up @@ -413,6 +429,7 @@ def _try_cast(t, op):
),
ops.TableColumn: _table_column,
ops.TimestampFromUNIX: _timestamp_from_unix,
ops.TimestampBucket: _timestamp_bucket,
ops.TimestampNow: fixed_arity(
# duckdb 0.6.0 changes now to be a timestamp with time zone force
# it back to the original for backwards compatibility
Expand Down Expand Up @@ -477,18 +494,15 @@ def _try_cast(t, op):
ops.TimeDelta: _temporal_delta,
ops.DateDelta: _temporal_delta,
ops.TimestampDelta: _temporal_delta,
ops.ToJSONMap: _to_json_collection,
ops.ToJSONArray: _to_json_collection,
}
)


_invalid_operations = {
# ibis.expr.operations.analytic
ops.NTile,
# ibis.expr.operations.strings
ops.Translate,
# ibis.expr.operations.json
ops.ToJSONMap,
ops.ToJSONArray,
}

operation_registry = {
Expand Down
9 changes: 5 additions & 4 deletions ibis/backends/duckdb/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ def ddl_script(self) -> Iterator[str]:

@staticmethod
def connect(*, tmpdir, worker_id, **kw) -> BaseBackend:
# extension directory per test worker to prevent simultaneous downloads
return ibis.duckdb.connect(
extension_directory=str(tmpdir.mktemp(f"{worker_id}_exts")), **kw
)
# use an extension directory per test worker to prevent simultaneous
# downloads
extension_directory = tmpdir.getbasetemp().joinpath("duckdb_extensions")
extension_directory.mkdir(exist_ok=True)
return ibis.duckdb.connect(extension_directory=extension_directory, **kw)

def load_tpch(self) -> None:
with self.connection.begin() as con:
Expand Down
103 changes: 103 additions & 0 deletions ibis/backends/duckdb/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from __future__ import annotations

import duckdb
import pyarrow as pa
import pytest
import sqlalchemy as sa
from pytest import param

import ibis
import ibis.expr.datatypes as dt
from ibis.conftest import LINUX, SANDBOXED
from ibis.util import gen_name


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -55,3 +59,102 @@ def test_load_extension(ext_directory):
"""
).fetchall()
assert all(loaded for (loaded,) in results)


def test_cross_db(tmpdir):
import duckdb

path1 = str(tmpdir.join("test1.ddb"))
with duckdb.connect(path1) as con1:
con1.execute("CREATE SCHEMA foo")
con1.execute("CREATE TABLE t1 (x BIGINT)")
con1.execute("CREATE TABLE foo.t1 (x BIGINT)")

path2 = str(tmpdir.join("test2.ddb"))
con2 = ibis.duckdb.connect(path2)
t2 = con2.create_table("t2", schema=ibis.schema(dict(x="int")))

con2.attach(path1, name="test1", read_only=True)

t1_from_con2 = con2.table("t1", schema="test1.main")
assert t1_from_con2.schema() == t2.schema()
assert t1_from_con2.execute().equals(t2.execute())

foo_t1_from_con2 = con2.table("t1", schema="test1.foo")
assert foo_t1_from_con2.schema() == t2.schema()
assert foo_t1_from_con2.execute().equals(t2.execute())


def test_attach_detach(tmpdir):
import duckdb

path1 = str(tmpdir.join("test1.ddb"))
with duckdb.connect(path1):
pass

path2 = str(tmpdir.join("test2.ddb"))
con2 = ibis.duckdb.connect(path2)

# default name
name = "test1"
assert name not in con2.list_databases()

con2.attach(path1)
assert name in con2.list_databases()

con2.detach(name)
assert name not in con2.list_databases()

# passed-in name
name = "test_foo"
assert name not in con2.list_databases()

con2.attach(path1, name=name)
assert name in con2.list_databases()

con2.detach(name)
assert name not in con2.list_databases()

with pytest.raises(sa.exc.ProgrammingError):
con2.detach(name)


@pytest.mark.parametrize(
"scale",
[
None,
param(0, id="seconds"),
param(3, id="millis"),
param(6, id="micros"),
param(9, id="nanos"),
],
)
def test_create_table_with_timestamp_scales(con, scale):
schema = ibis.schema(dict(ts=dt.Timestamp(scale=scale)))
t = con.create_table(gen_name("duckdb_timestamp_scale"), schema=schema, temp=True)
assert t.schema() == schema


def test_config_options(con):
a_first = {"a": [None, 1]}
a_last = {"a": [1, None]}
nulls_first = pa.Table.from_pydict(a_first, schema=pa.schema([("a", pa.float64())]))
nulls_last = pa.Table.from_pydict(a_last, schema=pa.schema([("a", pa.float64())]))

t = ibis.memtable(a_last)

expr = t.order_by("a")

assert con.to_pyarrow(expr) == nulls_last

con.settings["null_order"] = "nulls_first"

assert con.to_pyarrow(expr) == nulls_first


def test_config_options_bad_option(con):
with pytest.raises(sa.exc.ProgrammingError):
con.settings["not_a_valid_option"] = "oopsie"

with pytest.raises(KeyError):
con.settings["i_didnt_set_this"]
7 changes: 5 additions & 2 deletions ibis/backends/duckdb/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
("SMALLINT", dt.int16),
("TIME", dt.time),
("TIME WITH TIME ZONE", dt.time),
("TIMESTAMP", dt.timestamp),
("TIMESTAMP WITH TIME ZONE", dt.Timestamp("UTC")),
("TIMESTAMP", dt.Timestamp(scale=6)),
("TIMESTAMP WITH TIME ZONE", dt.Timestamp(scale=6, timezone="UTC")),
("TINYINT", dt.int8),
("UBIGINT", dt.uint64),
("UINTEGER", dt.uint32),
Expand All @@ -53,6 +53,9 @@
("INTEGER[][]", dt.Array(dt.Array(dt.int32))),
("JSON", dt.json),
("HUGEINT", dt.Decimal(38, 0)),
("TIMESTAMP_S", dt.Timestamp(scale=0)),
("TIMESTAMP_MS", dt.Timestamp(scale=3)),
("TIMESTAMP_NS", dt.Timestamp(scale=9)),
]
],
)
Expand Down
28 changes: 27 additions & 1 deletion ibis/backends/duckdb/tests/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas.testing as tm
import pyarrow as pa
import pytest
import sqlalchemy as sa

import ibis
import ibis.common.exceptions as exc
Expand Down Expand Up @@ -279,9 +280,17 @@ def test_set_temp_dir(tmp_path):
assert path.exists()


@pytest.mark.xfail(
LINUX and SANDBOXED,
reason=(
"nix on linux cannot download duckdb extensions or data due to sandboxing; "
"duckdb will try to automatically install and load read_parquet"
),
raises=(duckdb.IOException, sa.exc.DBAPIError),
)
def test_s3_403_fallback(con, httpserver, monkeypatch):
# monkeypatch to avoid downloading extensions in tests
monkeypatch.setattr(con, "_load_extensions", lambda x: True)
monkeypatch.setattr(con, "_load_extensions", lambda _: True)

# Throw a 403 to trigger fallback to pyarrow.dataset
httpserver.expect_request("/myfile").respond_with_data(
Expand Down Expand Up @@ -341,3 +350,20 @@ def test_csv_with_slash_n_null(con, tmp_path):
t = con.read_csv(data_path, nullstr="\\N")
col = t.a.execute()
assert pd.isna(col.iat[-1])


@pytest.mark.xfail(
LINUX and SANDBOXED,
reason=("nix can't hit GCS because it is sandboxed."),
)
def test_register_filesystem_gcs(con):
import fsspec

gcs = fsspec.filesystem("gcs")

con.register_filesystem(gcs)
band_members = con.read_csv(
"gcs://ibis-examples/data/band_members.csv.gz", table_name="band_members"
)

assert band_members.count().to_pyarrow()
9 changes: 7 additions & 2 deletions ibis/backends/flink/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from pyflink.table import TableEnvironment
from pyflink.table.table_result import TableResult

from ibis.expr.streaming import Watermark
from ibis.api import Watermark


class Backend(BaseBackend, CanCreateDatabase):
Expand Down Expand Up @@ -191,7 +191,10 @@ def table(
_, quoted, unquoted = fully_qualified_re.search(qualified_name).groups()
unqualified_name = quoted or unquoted
node = ops.DatabaseTable(
unqualified_name, schema, self, namespace=database
unqualified_name,
schema,
self,
namespace=ops.Namespace(schema=database, database=catalog),
) # TODO(chloeh13q): look into namespacing with catalog + db
return node.to_expr()

Expand Down Expand Up @@ -309,6 +312,7 @@ def create_table(
"""
import pandas as pd
import pyarrow as pa
import pyarrow_hotfix # noqa: F401

import ibis.expr.types as ir

Expand Down Expand Up @@ -482,6 +486,7 @@ def insert(
"""
import pandas as pd
import pyarrow as pa
import pyarrow_hotfix # noqa: F401

if isinstance(obj, ir.Table):
expr = obj
Expand Down
9 changes: 0 additions & 9 deletions ibis/backends/flink/compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +0,0 @@
from __future__ import annotations

from public import public

from ibis.backends.flink.compiler.core import translate

public(
translate=translate,
)