@@ -1,26 +1,26 @@
SELECT
"t2"."playerID",
"t2"."yearID",
"t2"."stint",
"t2"."teamID",
"t2"."lgID",
"t2"."G",
"t2"."AB",
"t2"."R",
"t2"."H",
"t2"."X2B",
"t2"."X3B",
"t2"."HR",
"t2"."RBI",
"t2"."SB",
"t2"."CS",
"t2"."BB",
"t2"."SO",
"t2"."IBB",
"t2"."HBP",
"t2"."SH",
"t2"."SF",
"t2"."GIDP"
"t2"."playerID" AS "playerID",
"t2"."yearID" AS "yearID",
"t2"."stint" AS "stint",
"t2"."teamID" AS "teamID",
"t2"."lgID" AS "lgID",
"t2"."G" AS "G",
"t2"."AB" AS "AB",
"t2"."R" AS "R",
"t2"."H" AS "H",
"t2"."X2B" AS "X2B",
"t2"."X3B" AS "X3B",
"t2"."HR" AS "HR",
"t2"."RBI" AS "RBI",
"t2"."SB" AS "SB",
"t2"."CS" AS "CS",
"t2"."BB" AS "BB",
"t2"."SO" AS "SO",
"t2"."IBB" AS "IBB",
"t2"."HBP" AS "HBP",
"t2"."SH" AS "SH",
"t2"."SF" AS "SF",
"t2"."GIDP" AS "GIDP"
FROM "batting" AS "t2"
LEFT OUTER JOIN "awards_players" AS "t3"
ON "t2"."playerID" = "t3"."awardID"
@@ -1,26 +1,26 @@
SELECT
"t2"."playerID",
"t2"."yearID",
"t2"."stint",
"t2"."teamID",
"t2"."lgID",
"t2"."G",
"t2"."AB",
"t2"."R",
"t2"."H",
"t2"."X2B",
"t2"."X3B",
"t2"."HR",
"t2"."RBI",
"t2"."SB",
"t2"."CS",
"t2"."BB",
"t2"."SO",
"t2"."IBB",
"t2"."HBP",
"t2"."SH",
"t2"."SF",
"t2"."GIDP"
"t2"."playerID" AS "playerID",
"t2"."yearID" AS "yearID",
"t2"."stint" AS "stint",
"t2"."teamID" AS "teamID",
"t2"."lgID" AS "lgID",
"t2"."G" AS "G",
"t2"."AB" AS "AB",
"t2"."R" AS "R",
"t2"."H" AS "H",
"t2"."X2B" AS "X2B",
"t2"."X3B" AS "X3B",
"t2"."HR" AS "HR",
"t2"."RBI" AS "RBI",
"t2"."SB" AS "SB",
"t2"."CS" AS "CS",
"t2"."BB" AS "BB",
"t2"."SO" AS "SO",
"t2"."IBB" AS "IBB",
"t2"."HBP" AS "HBP",
"t2"."SH" AS "SH",
"t2"."SF" AS "SF",
"t2"."GIDP" AS "GIDP"
FROM "batting" AS "t2"
ANY JOIN "awards_players" AS "t3"
ON "t2"."playerID" = "t3"."playerID"
@@ -1,26 +1,26 @@
SELECT
"t2"."playerID",
"t2"."yearID",
"t2"."stint",
"t2"."teamID",
"t2"."lgID",
"t2"."G",
"t2"."AB",
"t2"."R",
"t2"."H",
"t2"."X2B",
"t2"."X3B",
"t2"."HR",
"t2"."RBI",
"t2"."SB",
"t2"."CS",
"t2"."BB",
"t2"."SO",
"t2"."IBB",
"t2"."HBP",
"t2"."SH",
"t2"."SF",
"t2"."GIDP"
"t2"."playerID" AS "playerID",
"t2"."yearID" AS "yearID",
"t2"."stint" AS "stint",
"t2"."teamID" AS "teamID",
"t2"."lgID" AS "lgID",
"t2"."G" AS "G",
"t2"."AB" AS "AB",
"t2"."R" AS "R",
"t2"."H" AS "H",
"t2"."X2B" AS "X2B",
"t2"."X3B" AS "X3B",
"t2"."HR" AS "HR",
"t2"."RBI" AS "RBI",
"t2"."SB" AS "SB",
"t2"."CS" AS "CS",
"t2"."BB" AS "BB",
"t2"."SO" AS "SO",
"t2"."IBB" AS "IBB",
"t2"."HBP" AS "HBP",
"t2"."SH" AS "SH",
"t2"."SF" AS "SF",
"t2"."GIDP" AS "GIDP"
FROM "batting" AS "t2"
LEFT ANY JOIN "awards_players" AS "t3"
ON "t2"."playerID" = "t3"."playerID"
@@ -1,26 +1,26 @@
SELECT
"t2"."playerID",
"t2"."yearID",
"t2"."stint",
"t2"."teamID",
"t2"."lgID",
"t2"."G",
"t2"."AB",
"t2"."R",
"t2"."H",
"t2"."X2B",
"t2"."X3B",
"t2"."HR",
"t2"."RBI",
"t2"."SB",
"t2"."CS",
"t2"."BB",
"t2"."SO",
"t2"."IBB",
"t2"."HBP",
"t2"."SH",
"t2"."SF",
"t2"."GIDP"
"t2"."playerID" AS "playerID",
"t2"."yearID" AS "yearID",
"t2"."stint" AS "stint",
"t2"."teamID" AS "teamID",
"t2"."lgID" AS "lgID",
"t2"."G" AS "G",
"t2"."AB" AS "AB",
"t2"."R" AS "R",
"t2"."H" AS "H",
"t2"."X2B" AS "X2B",
"t2"."X3B" AS "X3B",
"t2"."HR" AS "HR",
"t2"."RBI" AS "RBI",
"t2"."SB" AS "SB",
"t2"."CS" AS "CS",
"t2"."BB" AS "BB",
"t2"."SO" AS "SO",
"t2"."IBB" AS "IBB",
"t2"."HBP" AS "HBP",
"t2"."SH" AS "SH",
"t2"."SF" AS "SF",
"t2"."GIDP" AS "GIDP"
FROM "batting" AS "t2"
INNER JOIN "awards_players" AS "t3"
ON "t2"."playerID" = "t3"."playerID"
@@ -1,26 +1,26 @@
SELECT
"t2"."playerID",
"t2"."yearID",
"t2"."stint",
"t2"."teamID",
"t2"."lgID",
"t2"."G",
"t2"."AB",
"t2"."R",
"t2"."H",
"t2"."X2B",
"t2"."X3B",
"t2"."HR",
"t2"."RBI",
"t2"."SB",
"t2"."CS",
"t2"."BB",
"t2"."SO",
"t2"."IBB",
"t2"."HBP",
"t2"."SH",
"t2"."SF",
"t2"."GIDP"
"t2"."playerID" AS "playerID",
"t2"."yearID" AS "yearID",
"t2"."stint" AS "stint",
"t2"."teamID" AS "teamID",
"t2"."lgID" AS "lgID",
"t2"."G" AS "G",
"t2"."AB" AS "AB",
"t2"."R" AS "R",
"t2"."H" AS "H",
"t2"."X2B" AS "X2B",
"t2"."X3B" AS "X3B",
"t2"."HR" AS "HR",
"t2"."RBI" AS "RBI",
"t2"."SB" AS "SB",
"t2"."CS" AS "CS",
"t2"."BB" AS "BB",
"t2"."SO" AS "SO",
"t2"."IBB" AS "IBB",
"t2"."HBP" AS "HBP",
"t2"."SH" AS "SH",
"t2"."SF" AS "SF",
"t2"."GIDP" AS "GIDP"
FROM "batting" AS "t2"
LEFT OUTER JOIN "awards_players" AS "t3"
ON "t2"."playerID" = "t3"."playerID"
@@ -1,8 +1,8 @@
SELECT
"t2"."string_col"
"t2"."string_col" AS "string_col"
FROM (
SELECT
"t1"."string_col",
"t1"."string_col" AS "string_col",
SUM("t1"."float_col") AS "total"
FROM (
SELECT
@@ -1,5 +1,5 @@
SELECT
"t0"."uuid",
"t0"."uuid" AS "uuid",
minIf("t0"."ts", "t0"."search_level" = 1) AS "min_date"
FROM "t" AS "t0"
GROUP BY
63 changes: 63 additions & 0 deletions ibis/backends/clickhouse/tests/test_client.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import os
from urllib.parse import quote_plus

import pandas as pd
import pandas.testing as tm
@@ -12,6 +13,14 @@
import ibis.expr.datatypes as dt
import ibis.expr.types as ir
from ibis import config, udf
from ibis.backends.clickhouse.tests.conftest import (
CLICKHOUSE_HOST,
CLICKHOUSE_PASS,
CLICKHOUSE_PORT,
CLICKHOUSE_USER,
IBIS_TEST_CLICKHOUSE_DB,
)
from ibis.backends.tests.errors import ClickHouseDatabaseError
from ibis.util import gen_name

cc = pytest.importorskip("clickhouse_connect")
@@ -346,3 +355,57 @@ def test_create_table_no_syntax_error(con):
)
t = con.create_table(gen_name("clickhouse_temp_table"), schema=schema, temp=True)
assert t.count().execute() == 0


def test_password_with_bracket():
password = f'{os.environ.get("IBIS_TEST_CLICKHOUSE_PASSWORD", "")}[]'
quoted_pass = quote_plus(password)
host = os.environ.get("IBIS_TEST_CLICKHOUSE_HOST", "localhost")
user = os.environ.get("IBIS_TEST_CLICKHOUSE_USER", "default")
port = int(os.environ.get("IBIS_TEST_CLICKHOUSE_PORT", 8123))
with pytest.raises(
cc.driver.exceptions.DatabaseError, match="password is incorrect"
):
ibis.clickhouse.connect(host=host, user=user, port=port, password=quoted_pass)


def test_from_url(con):
assert ibis.connect(
f"clickhouse://{CLICKHOUSE_USER}:{CLICKHOUSE_PASS}@{CLICKHOUSE_HOST}:{CLICKHOUSE_PORT}/{IBIS_TEST_CLICKHOUSE_DB}"
)


def test_invalid_port(con):
port = 9999
url = f"clickhouse://{CLICKHOUSE_USER}:{CLICKHOUSE_PASS}@{CLICKHOUSE_HOST}:{port}/{IBIS_TEST_CLICKHOUSE_DB}"
with pytest.raises(cc.driver.exceptions.DatabaseError):
ibis.connect(url)


def test_subquery_with_join(con):
name = gen_name("clickhouse_tmp_table")

s = con.create_table(name, pa.Table.from_pydict({"a": [1, 2, 3]}), temp=True)

sql = f"""
SELECT
"o"."a"
FROM (
SELECT
"w"."a"
FROM "{name}" AS "s"
INNER JOIN "{name}" AS "w"
USING ("a")
) AS "o"
"""
with pytest.raises(
ClickHouseDatabaseError, match="Identifier 'o.a' cannot be resolved"
):
# https://github.com/ClickHouse/ClickHouse/issues/66133
con.sql(sql)

# this works because we add the additional alias in the inner query
w = s.view()
expr = s.join(w, "a").select(a=w.a).select(b=lambda t: t.a + 1)
result = expr.to_pandas()
assert set(result["b"].tolist()) == {2, 3, 4}
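
The URL-based tests above depend on the password surviving a quote/unquote round trip through the connection string. A minimal sketch of that flow, assuming a locally running ClickHouse and illustrative credentials that are not part of this diff:

    from urllib.parse import quote_plus

    import ibis

    password = "p@ss/word[]"  # characters that are not URL-safe
    url = f"clickhouse://default:{quote_plus(password)}@localhost:8123/default"
    con = ibis.connect(url)  # the backend should receive the decoded password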
1 change: 1 addition & 0 deletions ibis/backends/clickhouse/tests/test_datatypes.py
@@ -244,6 +244,7 @@ def test_array_discovery_clickhouse(con):
),
id="nested",
),
param("Date32", dt.Date(nullable=False), id="date32"),
param("DateTime", dt.Timestamp(scale=0, nullable=False), id="datetime"),
param(
"DateTime('Europe/Budapest')",
27 changes: 3 additions & 24 deletions ibis/backends/conftest.py
@@ -10,7 +10,6 @@
from typing import TYPE_CHECKING, Any

import _pytest
import numpy as np
import pandas as pd
import pytest
from packaging.requirements import Requirement
@@ -237,11 +236,6 @@ def pytest_collection_modifyitems(session, config, items):
all_backends = _get_backend_names()
additional_markers = []

try:
import pyspark
except ImportError:
pyspark = None

unrecognized_backends = set()
for item in items:
# Yell loudly if unrecognized backend in notimpl, notyet or never
@@ -271,23 +265,6 @@ def pytest_collection_modifyitems(session, config, items):
if not any(item.iter_markers(name="benchmark")):
item.add_marker(pytest.mark.core)

# skip or xfail pyspark tests that run afoul of our non-ancient stack
for _ in item.iter_markers(name="pyspark"):
if not isinstance(item, pytest.DoctestItem):
additional_markers.append(
(
item,
[
pytest.mark.skipif(
pyspark is not None
and vparse(pyspark.__version__) < vparse("3.3.3")
and vparse(np.__version__) >= vparse("1.24"),
reason="PySpark doesn't support numpy >= 1.24",
),
],
)
)

if unrecognized_backends:
raise pytest.PytestCollectionWarning("\n" + "\n".join(unrecognized_backends))

@@ -420,7 +397,9 @@ def _filter_none_from_raises(kwargs):
failing_specs = []
for spec in specs:
req = Requirement(spec)
if req.specifier.contains(importlib.import_module(req.name).__version__):
if req.specifier.contains(
importlib.import_module(req.name).__version__
) and ((not req.marker) or req.marker.evaluate()):
failing_specs.append(spec)
reason = f"{backend} backend test fails with {backend}{specs}"
if provided_reason is not None:
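
The extra `req.marker.evaluate()` check above means a pinned spec only counts as failing when its environment marker applies to the running interpreter. A small sketch of the packaging API being relied on, using a made-up requirement string:

    from packaging.requirements import Requirement

    req = Requirement("somepackage<2; python_version < '3.0'")
    # the marker is False on any modern interpreter, so the spec is ignored
    applies = (not req.marker) or req.marker.evaluate()
    matches = req.specifier.contains("1.5.0")
    failing = matches and applies  # False here, ruled out by the marker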
3 changes: 2 additions & 1 deletion ibis/backends/dask/convert.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import dask.dataframe as dd
import numpy as np
import pandas as pd
import pandas.api.types as pdt

@@ -25,7 +26,7 @@ def convert_column(cls, obj, dtype):

@classmethod
def convert_default(cls, s, dtype, pandas_type):
if pandas_type == object:
if pandas_type == np.object_:
func = lambda x: x if x is pd.NA else dt.normalize(dtype, x)
meta = (s.name, pandas_type)
return s.map(func, na_action="ignore", meta=meta).astype(pandas_type)
17 changes: 16 additions & 1 deletion ibis/backends/dask/executor.py
@@ -364,10 +364,25 @@ def visit(cls, op: ops.Sort, parent, keys):
# 2. sort the dataframe using those columns
# 3. drop the sort key columns
ascending = [key.ascending for key in op.keys]
nulls_first = [key.nulls_first for key in op.keys]

if all(nulls_first):
na_position = "first"
elif not any(nulls_first):
na_position = "last"
else:
raise ValueError(
"dask does not support specifying null ordering for individual columns"
)

newcols = {gen_name("sort_key"): col for col in keys}
names = list(newcols.keys())
df = parent.assign(**newcols)
df = df.sort_values(by=names, ascending=ascending)
df = df.sort_values(
by=names,
ascending=ascending,
na_position=na_position,
)
return df.drop(names, axis=1)
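
For reference, the `na_position` mapping above follows the pandas/dask `sort_values` signature, which takes a single null-placement flag for the whole sort rather than one per key; that is why mixed per-column `nulls_first` values raise. A pandas-only illustration, not part of the diff:

    import pandas as pd

    df = pd.DataFrame({"a": [1.0, None, 3.0], "b": [2, 1, 3]})
    # nulls_first=True on every sort key maps to na_position="first"
    df.sort_values(by=["a", "b"], ascending=[True, False], na_position="first")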

@classmethod
10 changes: 5 additions & 5 deletions ibis/backends/dask/helpers.py
@@ -113,14 +113,14 @@ def elementwise(cls, func, operands, name, dtype):
def partitionwise(cls, func, operands, name, dtype):
cols = {}
kwargs = {}
for name, operand in operands.items():
for opname, operand in operands.items():
if isinstance(operand, (tuple, list)):
for i, v in enumerate(operand):
cols[f"{name}_{i}"] = v
kwargs[name] = tuple(f"{name}_{i}" for i in range(len(operand)))
cols[f"{opname}_{i}"] = v
kwargs[opname] = tuple(f"{opname}_{i}" for i in range(len(operand)))
else:
cols[name] = operand
kwargs[name] = name
cols[opname] = operand
kwargs[opname] = opname

def mapper(df):
unpacked = {}
58 changes: 0 additions & 58 deletions ibis/backends/dask/tests/test_operations.py
@@ -773,64 +773,6 @@ def q_fun(x, quantile):
tm.assert_series_equal(result, expected, check_index=False)


def test_searched_case_scalar(client):
expr = ibis.case().when(True, 1).when(False, 2).end()
result = client.execute(expr)
expected = np.int8(1)
assert result == expected


def test_searched_case_column(batting, batting_pandas_df):
t = batting
df = batting_pandas_df
expr = (
ibis.case()
.when(t.RBI < 5, "really bad team")
.when(t.teamID == "PH1", "ph1 team")
.else_(t.teamID)
.end()
)
result = expr.execute()
expected = pd.Series(
np.select(
[df.RBI < 5, df.teamID == "PH1"],
["really bad team", "ph1 team"],
df.teamID,
)
)
tm.assert_series_equal(result, expected, check_names=False)


def test_simple_case_scalar(client):
x = ibis.literal(2)
expr = x.case().when(2, x - 1).when(3, x + 1).when(4, x + 2).end()
result = client.execute(expr)
expected = np.int8(1)
assert result == expected


def test_simple_case_column(batting, batting_pandas_df):
t = batting
df = batting_pandas_df
expr = (
t.RBI.case()
.when(5, "five")
.when(4, "four")
.when(3, "three")
.else_("could be good?")
.end()
)
result = expr.execute()
expected = pd.Series(
np.select(
[df.RBI == 5, df.RBI == 4, df.RBI == 3],
["five", "four", "three"],
"could be good?",
)
)
tm.assert_series_equal(result, expected, check_names=False)


def test_table_distinct(t, df):
expr = t[["dup_strings"]].distinct()
result = expr.compile()
49 changes: 30 additions & 19 deletions ibis/backends/datafusion/__init__.py
@@ -14,15 +14,17 @@
import sqlglot as sg
import sqlglot.expressions as sge

import ibis
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis import util
from ibis.backends import CanCreateCatalog, CanCreateDatabase, CanCreateSchema, NoUrl
from ibis.backends.datafusion.compiler import DataFusionCompiler
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import C
from ibis.backends.sql.compilers import DataFusionCompiler
from ibis.backends.sql.compilers.base import C
from ibis.common.dispatch import lazy_singledispatch
from ibis.expr.operations.udf import InputType
from ibis.formats.pyarrow import PyArrowSchema, PyArrowType
@@ -77,12 +79,13 @@ def version(self):
def do_connect(
self, config: Mapping[str, str | Path] | SessionContext | None = None
) -> None:
"""Create a Datafusion backend for use with Ibis.
"""Create a DataFusion `Backend` for use with Ibis.
Parameters
----------
config
Mapping of table names to files.
Mapping of table names to files or a `SessionContext`
instance.
Examples
--------
@@ -112,6 +115,18 @@ def do_connect(
for name, path in config.items():
self.register(path, table_name=name)

@util.experimental
@classmethod
def from_connection(cls, con: SessionContext) -> Backend:
"""Create a DataFusion `Backend` from an existing `SessionContext` instance.
Parameters
----------
con
A `SessionContext` instance.
"""
return ibis.datafusion.connect(con)
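
A usage sketch for the new constructor, assuming the `datafusion` package is installed and the classmethod is reachable through the `ibis.datafusion` entry point:

    from datafusion import SessionContext

    import ibis

    ctx = SessionContext()  # an existing DataFusion session
    con = ibis.datafusion.from_connection(ctx)  # wraps it, same as connect(ctx)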

def disconnect(self) -> None:
pass

Expand All @@ -125,7 +140,7 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
src = sge.Create(
this=table,
kind="VIEW",
expression=sg.parse_one(query, read="datafusion"),
expression=sg.parse_one(query, read=self.dialect),
properties=sge.Properties(expressions=[sge.TemporaryProperty()]),
)

@@ -329,7 +344,7 @@ def register(
table_name
The name of the table
kwargs
Datafusion-specific keyword arguments
DataFusion-specific keyword arguments
Examples
--------
@@ -410,11 +425,6 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
empty_dataset = ds.dataset([], schema=schema.to_pyarrow())
self.con.register_dataset(name=name, dataset=empty_dataset)

def _register_in_memory_tables(self, expr: ir.Expr) -> None:
if self.supports_in_memory_tables:
for memtable in expr.op().find(ops.InMemoryTable):
self._register_in_memory_table(memtable)

def read_csv(
self, path: str | Path, table_name: str | None = None, **kwargs: Any
) -> ir.Table:
@@ -428,7 +438,7 @@ def read_csv(
An optional name to use for the created table. This defaults to
a sequentially generated name.
**kwargs
Additional keyword arguments passed to Datafusion loading function.
Additional keyword arguments passed to DataFusion loading function.
Returns
-------
@@ -456,7 +466,7 @@ def read_parquet(
An optional name to use for the created table. This defaults to
a sequentially generated name.
**kwargs
Additional keyword arguments passed to Datafusion loading function.
Additional keyword arguments passed to DataFusion loading function.
Returns
-------
@@ -542,13 +552,13 @@ def make_gen():
# convert the renamed + casted columns into a record batch
pa.RecordBatch.from_struct_array(
# rename columns to match schema because datafusion lowercases things
pa.RecordBatch.from_arrays(batch.columns, names=names)
pa.RecordBatch.from_arrays(batch.to_pyarrow().columns, names=names)
# cast the struct array to the desired types to work around
# https://github.com/apache/arrow-datafusion-python/issues/534
.to_struct_array()
.cast(struct_schema, safe=False)
)
for batch in frame.collect()
for batch in frame.execute_stream()
)

return pa.ipc.RecordBatchReader.from_batches(schema.to_pyarrow(), make_gen())
@@ -581,7 +591,7 @@ def create_table(
temp: bool = False,
overwrite: bool = False,
):
"""Create a table in Datafusion.
"""Create a table in DataFusion.
Parameters
----------
@@ -633,7 +643,8 @@ def create_table(
)
)
elif obj is not None:
_read_in_memory(obj, name, self, overwrite=overwrite)
table_ident = sg.table(name, db=database, quoted=quoted).sql(self.dialect)
_read_in_memory(obj, table_ident, self, overwrite=overwrite)
return self.table(name, database=database)
else:
query = None
@@ -692,7 +703,7 @@ def truncate_table(
table_loc = self._warn_and_create_table_loc(database, schema)
catalog, db = self._to_catalog_db_tuple(table_loc)

ident = sg.table(name, db=db, catalog=catalog).sql(self.name)
ident = sg.table(name, db=db, catalog=catalog).sql(self.dialect)
with self._safe_raw_sql(sge.delete(ident)):
pass

@@ -701,7 +712,7 @@ def truncate_table(
def _create_and_drop_memtable(_conn, table_name, tmp_name, overwrite):
"""Workaround inability to overwrite tables in dataframe API.
Datafusion has helper methods for loading in-memory data, but these methods
DataFusion has helper methods for loading in-memory data, but these methods
don't allow overwriting tables.
The SQL interface allows creating tables from existing tables, so we register
the data as a table using the dataframe API, then run a
41 changes: 35 additions & 6 deletions ibis/backends/datafusion/tests/conftest.py
@@ -3,6 +3,7 @@
from typing import Any

import pytest
import sqlglot as sg

import ibis
from ibis.backends.conftest import TEST_TABLES
@@ -17,10 +18,11 @@ class TestConf(BackendTest):
supports_json = False
supports_arrays = True
supports_tpch = True
supports_tpcds = True
stateful = False
deps = ("datafusion",)
# Query 1 seems to require a bit more room here
tpch_absolute_tolerance = 0.11
tpc_absolute_tolerance = 0.11

def _load_data(self, **_: Any) -> None:
con = self.connection
@@ -38,13 +40,40 @@ def _load_data(self, **_: Any) -> None:
def connect(*, tmpdir, worker_id, **kw):
return ibis.datafusion.connect(**kw)

def load_tpch(self) -> None:
def _load_tpc(self, *, suite, scale_factor):
con = self.connection
for path in self.data_dir.joinpath("tpch", "sf=0.17", "parquet").glob(
"*.parquet"
):
schema = f"tpc{suite}"
con.create_database(schema)
for path in self.data_dir.joinpath(
schema, f"sf={scale_factor}", "parquet"
).glob("*.parquet"):
table_name = path.with_suffix("").name
con.read_parquet(path, table_name=table_name)
con.con.sql(
# datafusion can't create an external table in a specific schema it seems
# so hack around that by
#
# 1. creating an external table in the current schema
# 2. create an internal table in the desired schema using a
# CTAS from the external table
# 3. drop the external table
f"CREATE EXTERNAL TABLE {table_name} STORED AS PARQUET LOCATION '{path}'"
)

con.con.sql(
f"CREATE TABLE {schema}.{table_name} AS SELECT * FROM {table_name}"
)
con.con.sql(f"DROP TABLE {table_name}")

def _transform_tpc_sql(self, parsed, *, suite, leaves):
def add_catalog_and_schema(node):
if isinstance(node, sg.exp.Table) and node.name in leaves:
return node.__class__(
catalog=f"tpc{suite}",
**{k: v for k, v in node.args.items() if k != "catalog"},
)
return node

return parsed.transform(add_catalog_and_schema)
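
The `_transform_tpc_sql` hook above qualifies table references in the benchmark queries with the `tpch`/`tpcds` catalog. A standalone sketch of the same sqlglot pattern, with a made-up table name:

    import sqlglot as sg

    def qualify(node):
        # attach a catalog to bare table references, leave everything else alone
        if isinstance(node, sg.exp.Table) and node.name == "lineitem":
            return node.__class__(
                catalog="tpch",
                **{k: v for k, v in node.args.items() if k != "catalog"},
            )
        return node

    rewritten = sg.parse_one("SELECT * FROM lineitem").transform(qualify)
    print(rewritten.sql())  # the lineitem reference is now catalog-qualified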


@pytest.fixture(scope="session")
6 changes: 6 additions & 0 deletions ibis/backends/datafusion/tests/test_register.py
@@ -50,3 +50,9 @@ def test_register_dataset(conn):
with pytest.warns(FutureWarning, match="v9.1"):
conn.register(dataset, "my_table")
assert conn.table("my_table").x.sum().execute() == 6


def test_create_table_with_uppercase_name(conn):
tab = pa.table({"x": [1, 2, 3]})
conn.create_table("MY_TABLE", tab)
assert conn.table("MY_TABLE").x.sum().execute() == 6
2 changes: 1 addition & 1 deletion ibis/backends/datafusion/tests/test_udf.py
@@ -70,7 +70,7 @@ def median(a: float) -> float:


@pytest.mark.xfail(
condition=vparse(datafusion.__version__) == vparse("38.0.1"),
condition=vparse(datafusion.__version__) >= vparse("38.0.1"),
reason="internal error about MEDIAN(G) naming",
)
def test_builtin_agg_udf_filtered(con):
47 changes: 28 additions & 19 deletions ibis/backends/druid/__init__.py
@@ -5,25 +5,27 @@
import contextlib
import json
from typing import TYPE_CHECKING, Any
from urllib.parse import parse_qs, urlparse
from urllib.parse import unquote_plus

import pydruid.db
import sqlglot as sg

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.backends.druid.compiler import DruidCompiler
from ibis import util
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import STAR
from ibis.backends.sql.compilers import DruidCompiler
from ibis.backends.sql.compilers.base import STAR
from ibis.backends.sql.datatypes import DruidType

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
from urllib.parse import ParseResult

import pandas as pd
import pyarrow as pa

import ibis.expr.operations as ops
import ibis.expr.types as ir


@@ -39,7 +41,7 @@ def version(self) -> str:
[(version,)] = result.fetchall()
return version

def _from_url(self, url: str, **kwargs):
def _from_url(self, url: ParseResult, **kwargs):
"""Connect to a backend using a URL `url`.
Parameters
@@ -55,24 +57,16 @@ def _from_url(self, url: str, **kwargs):
A backend instance
"""

url = urlparse(url)
query_params = parse_qs(url.query)
kwargs = {
"user": url.username,
"password": url.password,
"password": unquote_plus(url.password)
if url.password is not None
else None,
"host": url.hostname,
"path": url.path,
"port": url.port,
} | kwargs

for name, value in query_params.items():
if len(value) > 1:
kwargs[name] = value
elif len(value) == 1:
kwargs[name] = value[0]
else:
raise com.IbisError(f"Invalid URL parameter: {name}")
**kwargs,
}

self._convert_kwargs(kwargs)
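
The password handling above relies on the standard-library round trip between quote_plus and unquote_plus, so reserved characters survive being embedded in the connection URL. For example:

    from urllib.parse import quote_plus, unquote_plus

    password = "sec ret/pass[]"
    encoded = quote_plus(password)  # 'sec+ret%2Fpass%5B%5D'
    assert unquote_plus(encoded) == password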

@@ -88,6 +82,21 @@ def do_connect(self, **kwargs: Any) -> None:
header = kwargs.pop("header", True)
self.con = pydruid.db.connect(**kwargs, header=header)

@util.experimental
@classmethod
def from_connection(cls, con: pydruid.db.api.Connection) -> Backend:
"""Create an Ibis client from an existing connection to a Druid database.
Parameters
----------
con
An existing connection to a Druid database.
"""
new_backend = cls()
new_backend._can_reconnect = False
new_backend.con = con
return new_backend

@contextlib.contextmanager
def _safe_raw_sql(self, query, *args, **kwargs):
with contextlib.suppress(AttributeError):
@@ -179,7 +188,7 @@ def list_tables(
tables = result.fetchall()
return self._filter_with_like([table.TABLE_NAME for table in tables], like=like)

def _register_in_memory_tables(self, expr):
def _register_in_memory_table(self, op: ops.InMemoryTable):
"""No-op. Table are inlined, for better or worse."""

def _cursor_batches(
92 changes: 56 additions & 36 deletions ibis/backends/duckdb/__init__.py
@@ -24,10 +24,10 @@
import ibis.expr.types as ir
from ibis import util
from ibis.backends import CanCreateDatabase, CanCreateSchema, UrlFromPath
from ibis.backends.duckdb.compiler import DuckDBCompiler
from ibis.backends.duckdb.converter import DuckDBPandasData
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import STAR, C
from ibis.backends.sql.compilers import DuckDBCompiler
from ibis.backends.sql.compilers.base import STAR, C
from ibis.common.dispatch import lazy_singledispatch
from ibis.expr.operations.udf import InputType
from ibis.util import deprecated
@@ -278,12 +278,6 @@ def create_table(

return self.table(name, database=(catalog, database))

def _load_into_cache(self, name, expr):
self.create_table(name, expr, schema=expr.schema(), temp=True)

def _clean_up_cached_table(self, op):
self.drop_table(op.name)

def table(
self, name: str, schema: str | None = None, database: str | None = None
) -> ir.Table:
@@ -479,6 +473,31 @@ def do_connect(

self.con = duckdb.connect(str(database), config=config, read_only=read_only)

self._post_connect(extensions)

@util.experimental
@classmethod
def from_connection(
cls,
con: duckdb.DuckDBPyConnection,
extensions: Sequence[str] | None = None,
) -> Backend:
"""Create an Ibis client from an existing connection to a DuckDB database.
Parameters
----------
con
An existing connection to a DuckDB database.
extensions
A list of duckdb extensions to install/load upon connection.
"""
new_backend = cls(extensions=extensions)
new_backend._can_reconnect = False
new_backend.con = con
new_backend._post_connect(extensions)
return new_backend
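
A usage sketch for the new constructor, assuming the classmethod is exposed through the `ibis.duckdb` entry point (names here are illustrative):

    import duckdb

    import ibis

    raw = duckdb.connect(":memory:")  # an existing DuckDB connection
    con = ibis.duckdb.from_connection(raw)
    con.create_table("t", ibis.memtable({"x": [1, 2, 3]}))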

def _post_connect(self, extensions: Sequence[str] | None = None) -> None:
# Load any pre-specified extensions
if extensions is not None:
self._load_extensions(extensions)
@@ -891,23 +910,20 @@ def read_in_memory(
return self.table(table_name)

def read_delta(
self,
source_table: str,
table_name: str | None = None,
**kwargs: Any,
self, source_table: str, table_name: str | None = None, **kwargs: Any
) -> ir.Table:
"""Register a Delta Lake table as a table in the current database.
Parameters
----------
source_table
The data source. Must be a directory
containing a Delta Lake table.
The data source. Must be a directory containing a Delta Lake table.
table_name
An optional name to use for the created table. This defaults to
a sequentially generated name.
**kwargs
Additional keyword arguments passed to deltalake.DeltaTable.
An optional name to use for the created table. This defaults to a
generated name.
kwargs
Additional keyword arguments passed to the `delta` extension's
`delta_scan` function.
Returns
-------
@@ -919,21 +935,25 @@ def read_delta(

table_name = table_name or util.gen_name("read_delta")

try:
from deltalake import DeltaTable
except ImportError:
raise ImportError(
"The deltalake extra is required to use the "
"read_delta method. You can install it using pip:\n\n"
"pip install 'ibis-framework[deltalake]'\n"
)
# always try to load the delta extension
extensions = ["delta"]

delta_table = DeltaTable(source_table, **kwargs)
# delta handles s3 itself, not with httpfs
if source_table.startswith(("http://", "https://")):
extensions.append("httpfs")

return self.read_in_memory(
delta_table.to_pyarrow_dataset(), table_name=table_name
self._load_extensions(extensions)

options = [
sg.to_identifier(key).eq(sge.convert(val)) for key, val in kwargs.items()
]
self._create_temp_view(
table_name,
sg.select(STAR).from_(self.compiler.f.delta_scan(source_table, *options)),
)

return self.table(table_name)
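
With this change, reading a Delta table goes through DuckDB's own `delta` extension (via `delta_scan`) rather than the `deltalake` Python package, so only a path plus optional scan options is needed. A hedged usage sketch with a made-up path:

    import ibis

    con = ibis.duckdb.connect()
    t = con.read_delta("path/to/delta_table")  # registers a view over delta_scan
    print(t.count().execute())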

def list_tables(
self,
like: str | None = None,
@@ -1399,7 +1419,7 @@ def execute(
# but calling `to_pylist()` will render it as None
col.null_count
)
else col.to_pandas(timestamp_as_object=True)
else col.to_pandas()
)
for name, col in zip(table.column_names, table.columns)
}
@@ -1551,13 +1571,13 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
}
)

def _register_in_memory_tables(self, expr: ir.Expr) -> None:
for memtable in expr.op().find(ops.InMemoryTable):
self._register_in_memory_table(memtable)

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
# only register if we haven't already done so
if (name := op.name) not in self.list_tables():
name = op.name
try:
# this handles tables _and_ views
self.con.table(name)
except (duckdb.CatalogException, duckdb.InvalidInputException):
# only register if we haven't already done so
self.con.register(name, op.data.to_pyarrow(op.schema))

def _register_udfs(self, expr: ir.Expr) -> None:
32 changes: 25 additions & 7 deletions ibis/backends/duckdb/tests/conftest.py
@@ -3,6 +3,7 @@
from typing import TYPE_CHECKING

import pytest
import sqlglot as sg

import ibis
from ibis.backends.conftest import TEST_TABLES
@@ -48,6 +49,7 @@ class TestConf(BackendTest):
deps = ("duckdb",)
stateful = False
supports_tpch = True
supports_tpcds = True
driver_supports_multiple_statements = True

def preload(self):
@@ -107,15 +109,31 @@ def connect(*, tmpdir, worker_id, **kw) -> BaseBackend:
kw["extension_directory"] = extension_directory
return ibis.duckdb.connect(**kw)

def load_tpch(self) -> None:
"""Load TPC-H data."""
def _load_tpc(self, *, suite, scale_factor):
con = self.connection
for path in self.data_dir.joinpath("tpch", "sf=0.17", "parquet").glob(
"*.parquet"
):
schema = f"tpc{suite}"
con.con.execute(f"CREATE OR REPLACE SCHEMA {schema}")
parquet_dir = self.data_dir.joinpath(schema, f"sf={scale_factor}", "parquet")
assert parquet_dir.exists(), parquet_dir
for path in parquet_dir.glob("*.parquet"):
table_name = path.with_suffix("").name
# duckdb automatically infers the sf=0.17 as a hive partition
con.read_parquet(path, table_name=table_name, hive_partitioning=False)
# duckdb automatically infers the sf= as a hive partition so we
# need to disable it
con.con.execute(
f"CREATE OR REPLACE VIEW {schema}.{table_name} AS "
f"FROM read_parquet({str(path)!r}, hive_partitioning=false)"
)

def _transform_tpc_sql(self, parsed, *, suite, leaves):
def add_catalog_and_schema(node):
if isinstance(node, sg.exp.Table) and node.name in leaves:
return node.__class__(
catalog=f"tpc{suite}",
**{k: v for k, v in node.args.items() if k != "catalog"},
)
return node

return parsed.transform(add_catalog_and_schema)


@pytest.fixture(scope="session")
33 changes: 30 additions & 3 deletions ibis/backends/duckdb/tests/test_geospatial.py
@@ -1,11 +1,12 @@
from __future__ import annotations

from operator import methodcaller
from operator import attrgetter, methodcaller

import numpy.testing as npt
import pandas.testing as tm
import pyarrow as pa
import pytest
from packaging.version import parse as vparse
from pytest import param

import ibis
@@ -168,10 +169,36 @@ def test_geospatial_start_point(lines, lines_gdf):


# this one takes a bit longer than the rest.
def test_geospatial_unary_union(zones, zones_gdf):
@pytest.mark.parametrize(
"expected_func",
[
param(
attrgetter("unary_union"),
marks=pytest.mark.xfail(
condition=vparse(gpd.__version__) >= vparse("1"),
raises=Warning,
reason="unary_union property is deprecated",
),
id="version<1",
),
param(
methodcaller("union_all"),
marks=pytest.mark.xfail(
condition=(
vparse(gpd.__version__) < vparse("1")
or vparse(shapely.__version__) >= vparse("2.0.5")
),
raises=(AttributeError, AssertionError),
reason="union_all doesn't exist; shapely 2.0.5 results in a different value for union_all",
),
id="version>=1",
),
],
)
def test_geospatial_unary_union(zones, zones_gdf, expected_func):
unary_union = zones.geom.unary_union().name("unary_union")
# this returns a shapely geometry object
gp_unary_union = zones_gdf.geometry.unary_union
gp_unary_union = expected_func(zones_gdf.geometry)

# using set_precision because https://github.com/duckdb/duckdb_spatial/issues/189
assert shapely.equals(
10 changes: 8 additions & 2 deletions ibis/backends/duckdb/tests/test_register.py
@@ -166,7 +166,9 @@ def test_temp_directory(tmp_path):
@pytest.fixture(scope="session")
def pgurl(): # pragma: no cover
pgcon = ibis.postgres.connect(
user="postgres", password="postgres", host="localhost"
user="postgres",
password="postgres", # noqa: S106
host="localhost",
)

df = pd.DataFrame({"x": [1.0, 2.0, 3.0, 1.0], "y": ["a", "b", "c", "a"]})
Expand All @@ -193,7 +195,11 @@ def test_read_postgres(con, pgurl): # pragma: no cover

@pytest.fixture(scope="session")
def mysqlurl(): # pragma: no cover
mysqlcon = ibis.mysql.connect(user="ibis", password="ibis", database="ibis_testing")
mysqlcon = ibis.mysql.connect(
user="ibis",
password="ibis", # noqa: S106
database="ibis_testing",
)

df = pd.DataFrame({"x": [1.0, 2.0, 3.0, 1.0], "y": ["a", "b", "c", "a"]})
s = ibis.schema(dict(x="float64", y="str"))
52 changes: 36 additions & 16 deletions ibis/backends/exasol/__init__.py
@@ -5,7 +5,7 @@
import datetime
import re
from typing import TYPE_CHECKING, Any
from urllib.parse import parse_qs, urlparse
from urllib.parse import unquote_plus

import pyexasol
import sqlglot as sg
@@ -19,12 +19,13 @@
import ibis.expr.types as ir
from ibis import util
from ibis.backends import CanCreateDatabase, CanCreateSchema
from ibis.backends.exasol.compiler import ExasolCompiler
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import STAR, C
from ibis.backends.sql.compilers import ExasolCompiler
from ibis.backends.sql.compilers.base import STAR, C

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
from urllib.parse import ParseResult

import pandas as pd
import polars as pl
@@ -96,28 +97,47 @@ def do_connect(
quote_ident=True,
**kwargs,
)
self._post_connect(timezone)

@util.experimental
@classmethod
def from_connection(
cls, con: pyexasol.ExaConnection, timezone: str | None = None
) -> Backend:
"""Create an Ibis client from an existing connection to an Exasol database.
Parameters
----------
con
An existing connection to an Exasol database.
timezone
The session timezone.
"""
if timezone is None:
timezone = (con.execute("SELECT SESSIONTIMEZONE").fetchone() or ("UTC",))[0]

new_backend = cls(timezone=timezone)
new_backend._can_reconnect = False
new_backend.con = con
new_backend._post_connect(timezone)
return new_backend

def _post_connect(self, timezone: str = "UTC") -> None:
with self.begin() as con:
con.execute(f"ALTER SESSION SET TIME_ZONE = {timezone!r}")

def _from_url(self, url: str, **kwargs) -> BaseBackend:
def _from_url(self, url: ParseResult, **kwargs) -> BaseBackend:
"""Construct an ibis backend from a URL."""
url = urlparse(url)
query_params = parse_qs(url.query)
kwargs = {
"user": url.username,
"password": url.password,
"password": unquote_plus(url.password)
if url.password is not None
else None,
"schema": url.path[1:] or None,
"host": url.hostname,
"port": url.port,
} | kwargs

for name, value in query_params.items():
if len(value) > 1:
kwargs[name] = value
elif len(value) == 1:
kwargs[name] = value[0]
else:
raise com.IbisError(f"Invalid URL parameter: {name}")
**kwargs,
}

self._convert_kwargs(kwargs)

1 change: 0 additions & 1 deletion ibis/backends/exasol/tests/conftest.py
@@ -35,7 +35,6 @@ class TestConf(ServiceBackendTest):
reduction_tolerance = 1e-7
stateful = True
service_name = "exasol"
supports_tpch = False
force_sort = True
deps = ("pyexasol",)

50 changes: 29 additions & 21 deletions ibis/backends/flink/__init__.py
@@ -6,12 +6,13 @@
import sqlglot as sg
import sqlglot.expressions as sge

import ibis
import ibis.common.exceptions as exc
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis import util
from ibis.backends import CanCreateDatabase, NoUrl
from ibis.backends.flink.compiler import FlinkCompiler
from ibis.backends.flink.ddl import (
CreateDatabase,
CreateTableWithSchema,
@@ -22,6 +23,7 @@
RenameTable,
)
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers import FlinkCompiler
from ibis.backends.tests.errors import Py4JJavaError
from ibis.expr.operations.udf import InputType
from ibis.util import gen_name
@@ -70,6 +72,18 @@ def do_connect(self, table_env: TableEnvironment) -> None:
"""
self._table_env = table_env

@util.experimental
@classmethod
def from_connection(cls, table_env: TableEnvironment) -> Backend:
"""Create a Flink `Backend` from an existing table environment.
Parameters
----------
table_env
A table environment.
"""
return ibis.flink.connect(table_env)

def disconnect(self) -> None:
pass

@@ -445,27 +459,21 @@ def create_table(

# In-memory data is created as views in `pyflink`
if obj is not None:
if isinstance(obj, pd.DataFrame):
dataframe = obj

elif isinstance(obj, pa.Table):
dataframe = obj.to_pandas()

elif isinstance(obj, ir.Table):
# Note (mehmet): If obj points to in-memory data, we create a view.
# Other cases are unsupported for now, e.g., obj is of UnboundTable.
# See TODO right below for more context on how we handle in-memory data.
op = obj.op()
if isinstance(op, ops.InMemoryTable):
dataframe = op.data.to_frame()
else:
raise exc.IbisError(
"`obj` is of type ibis.expr.types.Table but it is not in-memory. "
"Currently, only in-memory tables are supported. "
"See ibis.memtable() for info on creating in-memory table."
)
if not isinstance(obj, ir.Table):
obj = ibis.memtable(obj)

# Note (mehmet): If obj points to in-memory data, we create a view.
# Other cases are unsupported for now, e.g., obj is of UnboundTable.
# See TODO right below for more context on how we handle in-memory data.
op = obj.op()
if isinstance(op, ops.InMemoryTable):
dataframe = op.data.to_frame()
else:
raise exc.IbisError(f"Unsupported `obj` type: {type(obj)}")
raise exc.IbisError(
"`obj` is of type ibis.expr.types.Table but it is not in-memory. "
"Currently, only in-memory tables are supported. "
"See ibis.memtable() for info on creating in-memory table."
)

# TODO (mehmet): Flink requires a source connector to create regular tables.
# In-memory data can only be created as a view (virtual table). So we decided
68 changes: 50 additions & 18 deletions ibis/backends/flink/tests/conftest.py
@@ -13,6 +13,25 @@
if TYPE_CHECKING:
from pyflink.table import StreamTableEnvironment

TEST_TABLES["functional_alltypes"] = ibis.schema(
{
"id": "int32",
"bool_col": "boolean",
"tinyint_col": "int8",
"smallint_col": "int16",
"int_col": "int32",
"bigint_col": "int64",
"float_col": "float32",
"double_col": "float64",
"date_string_col": "string",
"string_col": "string",
"timestamp_col": "timestamp(3)", # overriding the higher level fixture with precision because Flink's
# watermark must use a field of type TIMESTAMP(p) or TIMESTAMP_LTZ(p), where 'p' is from 0 to 3
"year": "int32",
"month": "int32",
}
)


def get_table_env(
local_env: bool,
@@ -152,24 +171,7 @@ def awards_players_schema():

@pytest.fixture
def functional_alltypes_schema():
return ibis.schema(
{
"id": "int32",
"bool_col": "boolean",
"tinyint_col": "int8",
"smallint_col": "int16",
"int_col": "int32",
"bigint_col": "int64",
"float_col": "float32",
"double_col": "float64",
"date_string_col": "string",
"string_col": "string",
"timestamp_col": "timestamp(3)", # overriding the higher level fixture with precision because Flink's
# watermark must use a field of type TIMESTAMP(p) or TIMESTAMP_LTZ(p), where 'p' is from 0 to 3
"year": "int32",
"month": "int32",
}
)
return TEST_TABLES["functional_alltypes"]


@pytest.fixture
@@ -188,3 +190,33 @@ def generate_csv_configs(csv_file):
}

return generate_csv_configs


@pytest.fixture(scope="session")
def functional_alltypes_no_header(tmpdir_factory, data_dir):
file = tmpdir_factory.mktemp("data") / "functional_alltypes.csv"
with (
open(data_dir / "csv" / "functional_alltypes.csv") as reader,
open(str(file), mode="w") as writer,
):
reader.readline() # read the first line and discard it
for line in reader:
writer.write(line)
return file


@pytest.fixture(scope="session", autouse=True)
def functional_alltypes_with_watermark(con, functional_alltypes_no_header):
# create a streaming table with watermark for testing event-time based ops
t = con.create_table(
"functional_alltypes_with_watermark",
schema=TEST_TABLES["functional_alltypes"],
tbl_properties={
"connector": "filesystem",
"path": functional_alltypes_no_header,
"format": "csv",
},
watermark=ibis.watermark("timestamp_col", ibis.interval(seconds=10)),
temp=True,
)
return t
50 changes: 0 additions & 50 deletions ibis/backends/flink/tests/test_compiler.py
@@ -1,13 +1,8 @@
from __future__ import annotations

from operator import methodcaller

import pytest
from pytest import param

import ibis
from ibis.common.deferred import _


def test_sum(simple_table, assert_sql):
expr = simple_table.a.sum()
@@ -103,48 +98,3 @@ def test_having(simple_table, assert_sql):
.aggregate(simple_table.b.sum().name("b_sum"))
)
assert_sql(expr)


@pytest.mark.parametrize(
"method",
[
methodcaller("tumble", window_size=ibis.interval(minutes=15)),
methodcaller(
"hop",
window_size=ibis.interval(minutes=15),
window_slide=ibis.interval(minutes=1),
),
methodcaller(
"cumulate",
window_size=ibis.interval(minutes=1),
window_step=ibis.interval(seconds=10),
),
],
ids=["tumble", "hop", "cumulate"],
)
def test_windowing_tvf(simple_table, method, assert_sql):
expr = method(simple_table.window_by(time_col=simple_table.i))
assert_sql(expr)


def test_window_aggregation(simple_table, assert_sql):
expr = (
simple_table.window_by(time_col=simple_table.i)
.tumble(window_size=ibis.interval(minutes=15))
.group_by(["window_start", "window_end", "g"])
.aggregate(mean=_.d.mean())
)
assert_sql(expr)


def test_window_topn(simple_table, assert_sql):
expr = simple_table.window_by(time_col="i").tumble(
window_size=ibis.interval(seconds=600),
)["a", "b", "c", "d", "g", "window_start", "window_end"]
expr = expr.mutate(
rownum=ibis.row_number().over(
group_by=["window_start", "window_end"], order_by=ibis.desc("g")
)
)
expr = expr[expr.rownum <= 3]
assert_sql(expr)
71 changes: 65 additions & 6 deletions ibis/backends/flink/tests/test_ddl.py
@@ -8,6 +8,7 @@
import pandas.testing as tm
import pyarrow as pa
import pytest
from pytest import param

import ibis
import ibis.common.exceptions as exc
@@ -62,6 +63,66 @@ def test_create_table(con, awards_players_schema, temp_table, csv_source_configs
assert temp_table not in con.list_tables()


@pytest.mark.parametrize(
"obj, table_name",
[
param(lambda: pa.table({"a": ["a"], "b": [1]}), "df_arrow", id="pyarrow table"),
param(lambda: pd.DataFrame({"a": ["a"], "b": [1]}), "df_pandas", id="pandas"),
param(
lambda: pytest.importorskip("polars").DataFrame({"a": ["a"], "b": [1]}),
"df_polars_eager",
id="polars dataframe",
),
param(
lambda: pytest.importorskip("polars").LazyFrame({"a": ["a"], "b": [1]}),
"df_polars_lazy",
id="polars lazyframe",
),
param(
lambda: ibis.memtable([("a", 1)], columns=["a", "b"]),
"memtable",
id="memtable_list",
),
param(
lambda: ibis.memtable(pa.table({"a": ["a"], "b": [1]})),
"memtable_pa",
id="memtable pyarrow",
),
param(
lambda: ibis.memtable(pd.DataFrame({"a": ["a"], "b": [1]})),
"memtable_pandas",
id="memtable pandas",
),
param(
lambda: ibis.memtable(
pytest.importorskip("polars").DataFrame({"a": ["a"], "b": [1]})
),
"memtable_polars_eager",
id="memtable polars dataframe",
),
param(
lambda: ibis.memtable(
pytest.importorskip("polars").LazyFrame({"a": ["a"], "b": [1]})
),
"memtable_polars_lazy",
id="memtable polars lazyframe",
),
],
)
def test_create_table_in_memory(con, obj, table_name, monkeypatch):
"""Same as in ibis/backends/tests/test_client.py, with temp=True."""
monkeypatch.setattr(ibis.options, "default_backend", con)
obj = obj()
t = con.create_table(table_name, obj, temp=True)

result = pa.table({"a": ["a"], "b": [1]})
assert table_name in con.list_tables()

assert result.equals(t.to_pyarrow())

con.drop_table(table_name, force=True)


def test_recreate_table_from_schema(
con, awards_players_schema, temp_table, csv_source_configs
):
@@ -362,25 +423,23 @@ def test_rename_table(con, awards_players_schema, temp_table, csv_source_configs
@pytest.mark.parametrize(
"obj",
[
pytest.param(
[("fred flintstone", 35, 1.28), ("barney rubble", 32, 2.32)], id="list"
),
pytest.param(
param([("fred flintstone", 35, 1.28), ("barney rubble", 32, 2.32)], id="list"),
param(
{
"name": ["fred flintstone", "barney rubble"],
"age": [35, 32],
"gpa": [1.28, 2.32],
},
id="dict",
),
pytest.param(
param(
pd.DataFrame(
[("fred flintstone", 35, 1.28), ("barney rubble", 32, 2.32)],
columns=["name", "age", "gpa"],
),
id="pandas_dataframe",
),
pytest.param(
param(
pa.Table.from_arrays(
[
pa.array(["fred flintstone", "barney rubble"]),
166 changes: 0 additions & 166 deletions ibis/backends/flink/tests/test_join.py

This file was deleted.

7 changes: 4 additions & 3 deletions ibis/backends/flink/tests/test_memtable.py
@@ -8,7 +8,7 @@


@pytest.mark.parametrize(
"data,schema,expected",
("data", "schema", "expected"),
[
pytest.param(
{"value": [{"a": 1}, {"a": 2}]},
@@ -35,8 +35,9 @@ def test_create_memtable(con, data, schema, expected):
# cannot use con.execute(t) directly because of some behavioral discrepancy between
# `TableEnvironment.execute_sql()` and `TableEnvironment.sql_query()`; this doesn't
# seem to be an issue if we don't execute memtable directly
result = con.raw_sql(con.compile(t)).collect()
assert all(element in result for element in expected)
result = list(con.raw_sql(con.compile(t)).collect())
for element in expected:
assert element in result


@pytest.mark.notyet(
29 changes: 27 additions & 2 deletions ibis/backends/flink/tests/test_window.py
@@ -5,6 +5,7 @@
from pytest import param

import ibis
from ibis import _
from ibis.backends.tests.errors import Py4JJavaError


@@ -53,13 +54,37 @@ def test_window_invalid_start_end(con, window):
con.execute(expr)


def test_range_window(con, simple_table, assert_sql):
def test_range_window(simple_table, assert_sql):
expr = simple_table.f.sum().over(
range=(-ibis.interval(minutes=500), 0), order_by=simple_table.f
)
assert_sql(expr)


def test_rows_window(con, simple_table, assert_sql):
def test_rows_window(simple_table, assert_sql):
expr = simple_table.f.sum().over(rows=(-1000, 0), order_by=simple_table.f)
assert_sql(expr)


def test_tumble_window_by_grouped_agg(con):
t = con.table("functional_alltypes_with_watermark")
expr = (
t.window_by(time_col=t.timestamp_col)
.tumble(size=ibis.interval(seconds=30))
.agg(by=["string_col"], avg=_.float_col.mean())
)
result = expr.to_pandas()
assert list(result.columns) == ["window_start", "window_end", "string_col", "avg"]
assert result.shape == (610, 4)


def test_tumble_window_by_ungrouped_agg(con):
t = con.table("functional_alltypes_with_watermark")
expr = (
t.window_by(time_col=t.timestamp_col)
.tumble(size=ibis.interval(seconds=30))
.agg(avg=_.float_col.mean())
)
result = expr.to_pandas()
assert list(result.columns) == ["window_start", "window_end", "avg"]
assert result.shape == (610, 3)
42 changes: 24 additions & 18 deletions ibis/backends/impala/__init__.py
@@ -7,7 +7,6 @@
import os
from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal
from urllib.parse import parse_qs, urlparse

import impala.dbapi as impyla
import sqlglot as sg
@@ -21,7 +20,6 @@
from ibis import util
from ibis.backends.impala import ddl, udf
from ibis.backends.impala.client import ImpalaTable
from ibis.backends.impala.compiler import ImpalaCompiler
from ibis.backends.impala.ddl import (
CTAS,
CreateDatabase,
@@ -40,11 +38,14 @@
wrap_udf,
)
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers import ImpalaCompiler

if TYPE_CHECKING:
from collections.abc import Mapping
from pathlib import Path
from urllib.parse import ParseResult

import impala.hiveserver2 as hs2
import pandas as pd
import polars as pl
import pyarrow as pa
@@ -67,7 +68,7 @@ class Backend(SQLBackend):

supports_in_memory_tables = True

def _from_url(self, url: str, **kwargs: Any) -> Backend:
def _from_url(self, url: ParseResult, **kwargs: Any) -> Backend:
"""Connect to a backend using a URL `url`.
Parameters
@@ -83,8 +84,6 @@ def _from_url(self, url: ParseResult, **kwargs: Any) -> Backend:
A backend instance
"""
url = urlparse(url)

for name in ("username", "hostname", "port", "password"):
if value := (
getattr(url, name, None)
@@ -99,16 +98,6 @@ def _from_url(self, url: str, **kwargs: Any) -> Backend:
if database:
kwargs["database"] = database

query_params = parse_qs(url.query)

for name, value in query_params.items():
if len(value) > 1:
kwargs[name] = value
elif len(value) == 1:
kwargs[name] = value[0]
else:
raise com.IbisError(f"Invalid URL parameter: {name}")

self._convert_kwargs(kwargs)
return self.connect(**kwargs)

@@ -195,6 +184,25 @@ def do_connect(
cur.ping()

self.con = con
self._post_connect()

@util.experimental
@classmethod
def from_connection(cls, con: hs2.HiveServer2Connection) -> Backend:
"""Create an Impala `Backend` from an existing HS2 connection.
Parameters
----------
con
An existing connection to HiveServer2 (HS2).
"""
new_backend = cls()
new_backend._can_reconnect = False
new_backend.con = con
new_backend._post_connect()
return new_backend

def _post_connect(self) -> None:
self.options = {}

@cached_property
@@ -1239,9 +1247,7 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
).sql(self.name, pretty=True)

data = op.data.to_frame().itertuples(index=False)
specs = ", ".join("?" * len(schema))
table = sg.table(name, quoted=quoted).sql(self.name)
insert_stmt = f"INSERT INTO {table} VALUES ({specs})"
insert_stmt = self._build_insert_template(name, schema=schema)
with self._safe_raw_sql(create_stmt) as cur:
for row in data:
cur.execute(insert_stmt, row)
2 changes: 1 addition & 1 deletion ibis/backends/impala/tests/mocks.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from ibis.backends.impala.compiler import ImpalaCompiler
from ibis.backends.sql.compilers import ImpalaCompiler
from ibis.tests.expr.mocks import MockBackend


@@ -1,3 +1,3 @@
SELECT
FIRST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC NULLS LAST) AS `First(double_col)`
FIRST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `First(double_col)`
FROM `functional_alltypes` AS `t0`
@@ -1,3 +1,3 @@
SELECT
LAG(`t0`.`string_col`, 2) OVER (ORDER BY NULL ASC NULLS LAST) AS `Lag(string_col, 2)`
LAG(`t0`.`string_col`, 2) OVER (ORDER BY NULL ASC) AS `Lag(string_col, 2)`
FROM `functional_alltypes` AS `t0`
@@ -1,3 +1,3 @@
SELECT
LAG(`t0`.`string_col`) OVER (ORDER BY NULL ASC NULLS LAST) AS `Lag(string_col)`
LAG(`t0`.`string_col`) OVER (ORDER BY NULL ASC) AS `Lag(string_col)`
FROM `functional_alltypes` AS `t0`
@@ -1,3 +1,3 @@
SELECT
LAG(`t0`.`string_col`, 1, 0) OVER (ORDER BY NULL ASC NULLS LAST) AS `Lag(string_col, 0)`
LAG(`t0`.`string_col`, 1, 0) OVER (ORDER BY NULL ASC) AS `Lag(string_col, 0)`
FROM `functional_alltypes` AS `t0`
@@ -1,3 +1,3 @@
SELECT
LAST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC NULLS LAST) AS `Last(double_col)`
LAST_VALUE(`t0`.`double_col`) OVER (ORDER BY `t0`.`id` ASC) AS `Last(double_col)`
FROM `functional_alltypes` AS `t0`
@@ -1,3 +1,3 @@
SELECT
LEAD(`t0`.`string_col`, 2) OVER (ORDER BY NULL ASC NULLS LAST) AS `Lead(string_col, 2)`
LEAD(`t0`.`string_col`, 2) OVER (ORDER BY NULL ASC) AS `Lead(string_col, 2)`
FROM `functional_alltypes` AS `t0`
@@ -1,3 +1,3 @@
SELECT
LEAD(`t0`.`string_col`) OVER (ORDER BY NULL ASC NULLS LAST) AS `Lead(string_col)`
LEAD(`t0`.`string_col`) OVER (ORDER BY NULL ASC) AS `Lead(string_col)`
FROM `functional_alltypes` AS `t0`
@@ -1,3 +1,3 @@
SELECT
LEAD(`t0`.`string_col`, 1, 0) OVER (ORDER BY NULL ASC NULLS LAST) AS `Lead(string_col, 0)`
LEAD(`t0`.`string_col`, 1, 0) OVER (ORDER BY NULL ASC) AS `Lead(string_col, 0)`
FROM `functional_alltypes` AS `t0`
@@ -1,3 +1,3 @@
SELECT
NTILE(3) OVER (ORDER BY `t0`.`double_col` ASC NULLS LAST) - 1 AS `NTile(3)`
NTILE(3) OVER (ORDER BY `t0`.`double_col` ASC) - 1 AS `NTile(3)`
FROM `functional_alltypes` AS `t0`
@@ -1,3 +1,3 @@
SELECT
PERCENT_RANK() OVER (ORDER BY `t0`.`double_col` ASC NULLS LAST) AS `PercentRank()`
PERCENT_RANK() OVER (ORDER BY `t0`.`double_col` ASC) AS `PercentRank()`
FROM `functional_alltypes` AS `t0`
@@ -19,5 +19,5 @@ SELECT
)
THEN 2
ELSE CAST(NULL AS TINYINT)
END AS `Bucket(f)`
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
@@ -19,5 +19,5 @@ SELECT
)
THEN 2
ELSE CAST(NULL AS TINYINT)
END AS `Bucket(f)`
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
@@ -23,5 +23,5 @@ SELECT
WHEN 50 <= `t0`.`f`
THEN 4
ELSE CAST(NULL AS TINYINT)
END AS `Bucket(f)`
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
@@ -19,5 +19,5 @@ SELECT
)
THEN 2
ELSE CAST(NULL AS TINYINT)
END AS `Bucket(f)`
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
@@ -21,5 +21,5 @@ SELECT
)
THEN 3
ELSE CAST(NULL AS TINYINT)
END AS `Bucket(f)`
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
@@ -5,5 +5,5 @@ SELECT
WHEN 10 < `t0`.`f`
THEN 1
ELSE CAST(NULL AS TINYINT)
END AS `Bucket(f)`
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
@@ -19,5 +19,5 @@ SELECT
)
THEN 2
ELSE CAST(NULL AS TINYINT)
END AS `Bucket(f)`
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
@@ -5,5 +5,5 @@ SELECT
WHEN 10 <= `t0`.`f`
THEN 1
ELSE CAST(NULL AS TINYINT)
END AS `Bucket(f)`
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
@@ -5,5 +5,5 @@ SELECT
WHEN 10 <= `t0`.`f`
THEN 1
ELSE CAST(NULL AS TINYINT)
END AS INT) AS `Cast(Bucket(f), int32)`
END AS INT) AS `Cast(Bucket(f, ()), int32)`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -5,5 +5,5 @@ SELECT
WHEN 10 <= `t0`.`f`
THEN 1
ELSE CAST(NULL AS TINYINT)
END AS DOUBLE) AS `Cast(Bucket(f), float64)`
END AS DOUBLE) AS `Cast(Bucket(f, ()), float64)`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -21,5 +21,5 @@ SELECT
)
THEN 3
ELSE CAST(NULL AS TINYINT)
END AS `Bucket(f)`
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -23,5 +23,5 @@ SELECT
WHEN 50 < `t0`.`f`
THEN 4
ELSE CAST(NULL AS TINYINT)
END AS `Bucket(f)`
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
COALESCE(`t0`.`l_extendedprice`, 0) AS `Coalesce()`
COALESCE(`t0`.`l_extendedprice`, 0) AS `Coalesce((l_extendedprice, 0))`
FROM `tpch_lineitem` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
COALESCE(`t0`.`l_extendedprice`, 0.0) AS `Coalesce()`
COALESCE(`t0`.`l_extendedprice`, 0.0) AS `Coalesce((l_extendedprice, 0.0))`
FROM `tpch_lineitem` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
COALESCE(`t0`.`l_quantity`, 0) AS `Coalesce()`
COALESCE(`t0`.`l_quantity`, 0) AS `Coalesce((l_quantity, 0))`
FROM `tpch_lineitem` AS `t0`
Original file line number Diff line number Diff line change
@@ -5,5 +5,5 @@ SELECT
WHEN `t0`.`c` < 0
THEN `t0`.`a` * 2
ELSE CAST(NULL AS BIGINT)
END AS `SearchedCase(Cast(None, int64))`
END AS `SearchedCase((Greater(f, 0), Less(c, 0)), (Multiply(d, 2), Multiply(a, 2)), Cast(None, int64))`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
CASE `t0`.`g` WHEN 'foo' THEN 'bar' WHEN 'baz' THEN 'qux' ELSE 'default' END AS `SimpleCase(g, 'default')`
CASE `t0`.`g` WHEN 'foo' THEN 'bar' WHEN 'baz' THEN 'qux' ELSE 'default' END AS `SimpleCase(g, ('foo', 'baz'), ('bar', 'qux'), 'default')`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
COALESCE(`t0`.`int_col`, `t0`.`bigint_col`) AS `Coalesce()`
COALESCE(`t0`.`int_col`, `t0`.`bigint_col`) AS `Coalesce((int_col, bigint_col))`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
COALESCE(`t0`.`string_col`, 'foo') AS `Coalesce()`
COALESCE(`t0`.`string_col`, 'foo') AS `Coalesce((string_col, 'foo'))`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
GREATEST(`t0`.`int_col`, `t0`.`bigint_col`) AS `Greatest()`
GREATEST(`t0`.`int_col`, `t0`.`bigint_col`) AS `Greatest((int_col, bigint_col))`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
GREATEST(`t0`.`string_col`, 'foo') AS `Greatest()`
GREATEST(`t0`.`string_col`, 'foo') AS `Greatest((string_col, 'foo'))`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
LEAST(`t0`.`int_col`, `t0`.`bigint_col`) AS `Least()`
LEAST(`t0`.`int_col`, `t0`.`bigint_col`) AS `Least((int_col, bigint_col))`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
LEAST(`t0`.`string_col`, 'foo') AS `Least()`
LEAST(`t0`.`string_col`, 'foo') AS `Least((string_col, 'foo'))`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1 +1 @@
SELECT * FROM (SELECT `t1`.`col`, COUNT(*) OVER (ORDER BY NULL ASC NULLS LAST) AS `analytic` FROM (SELECT `t0`.`col`, NULL AS `filter` FROM `x` AS `t0` WHERE NULL IS NULL) AS `t1`) AS `t2`
SELECT * FROM (SELECT `t1`.`col`, COUNT(*) OVER (ORDER BY NULL ASC) AS `analytic` FROM (SELECT `t0`.`col`, NULL AS `filter` FROM `x` AS `t0` WHERE NULL IS NULL) AS `t1`) AS `t2`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
`t0`.`g` IN ('foo', 'bar', 'baz') AS `InValues(g)`
`t0`.`g` IN ('foo', 'bar', 'baz') AS `InValues(g, ('foo', 'bar', 'baz'))`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
SELECT
NOT (
`t0`.`g` IN ('foo', 'bar', 'baz')
) AS `Not(InValues(g))`
) AS `Not(InValues(g, ('foo', 'bar', 'baz')))`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
2 IN (`t0`.`a`, `t0`.`b`, `t0`.`c`) AS `InValues(2)`
2 IN (`t0`.`a`, `t0`.`b`, `t0`.`c`) AS `InValues(2, (a, b, c))`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
SELECT
NOT (
2 IN (`t0`.`a`, `t0`.`b`, `t0`.`c`)
) AS `Not(InValues(2))`
) AS `Not(InValues(2, (a, b, c)))`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1 +1 @@
SELECT `t0`.`one`, `t0`.`two`, `t0`.`three`, SUM(`t0`.`two`) OVER (PARTITION BY `t0`.`three` ORDER BY `t0`.`one` ASC NULLS LAST) AS `four` FROM `my_data` AS `t0`
SELECT `t0`.`one`, `t0`.`two`, `t0`.`three`, SUM(`t0`.`two`) OVER (PARTITION BY `t0`.`three` ORDER BY `t0`.`one` ASC) AS `four` FROM `my_data` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
FIND_IN_SET(`t0`.`string_col`, CONCAT_WS(',', 'a', 'b')) - 1 AS `FindInSet(string_col)`
FIND_IN_SET(`t0`.`string_col`, CONCAT_WS(',', 'a', 'b')) - 1 AS `FindInSet(string_col, ('a', 'b'))`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
FIND_IN_SET(`t0`.`string_col`, CONCAT_WS(',', 'a')) - 1 AS `FindInSet(string_col)`
FIND_IN_SET(`t0`.`string_col`, CONCAT_WS(',', 'a')) - 1 AS `FindInSet(string_col, ('a',))`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SELECT
CONCAT_WS(',', 'a', 'b') AS `StringJoin(',')`
CONCAT_WS(',', 'a', 'b') AS `StringJoin(('a', 'b'), ',')`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
COALESCE(`t0`.`double_col`, 0) AS `Coalesce()`
COALESCE(`t0`.`double_col`, 0) AS `Coalesce((double_col, 0))`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
COALESCE(`t0`.`int_col`, 0) AS `Coalesce()`
COALESCE(`t0`.`int_col`, 0) AS `Coalesce((int_col, 0))`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -10,9 +10,9 @@ SELECT
`t0`.`i`,
`t0`.`j`,
`t0`.`k`,
LAG(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) AS `lag`,
LEAD(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) - `t0`.`f` AS `fwd_diff`,
FIRST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) AS `first`,
LAST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) AS `last`,
LAG(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`d` ASC NULLS LAST) AS `lag2`
LAG(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC) AS `lag`,
LEAD(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC) - `t0`.`f` AS `fwd_diff`,
FIRST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC) AS `first`,
LAST_VALUE(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC) AS `last`,
LAG(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`d` ASC) AS `lag2`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -10,5 +10,5 @@ SELECT
`t0`.`i`,
`t0`.`j`,
`t0`.`k`,
`t0`.`f` / SUM(`t0`.`f`) OVER (ORDER BY NULL ASC NULLS LAST) AS `normed_f`
`t0`.`f` / SUM(`t0`.`f`) OVER (ORDER BY NULL ASC) AS `normed_f`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
MAX(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC NULLS LAST) AS `foo`
MAX(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
MAX(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC NULLS LAST) AS `foo`
MAX(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
AVG(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC NULLS LAST) AS `foo`
AVG(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
AVG(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC NULLS LAST) AS `foo`
AVG(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
MIN(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC NULLS LAST) AS `foo`
MIN(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
MIN(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC NULLS LAST) AS `foo`
MIN(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC NULLS LAST) AS `foo`
SUM(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC NULLS LAST) AS `foo`
SUM(`t0`.`f`) OVER (ORDER BY `t0`.`d` ASC) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
SELECT
`t0`.`g`,
SUM(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) - SUM(`t0`.`f`) OVER (ORDER BY NULL ASC NULLS LAST) AS `result`
SUM(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC) - SUM(`t0`.`f`) OVER (ORDER BY NULL ASC) AS `result`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
LAG(`t0`.`f` - LAG(`t0`.`f`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST)) OVER (ORDER BY `t0`.`f` ASC NULLS LAST) AS `foo`
LAG(`t0`.`f` - LAG(`t0`.`f`) OVER (ORDER BY `t0`.`f` ASC)) OVER (ORDER BY `t0`.`f` ASC) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
SELECT
`t0`.`f`,
ROW_NUMBER() OVER (ORDER BY `t0`.`f` DESC) - 1 AS `revrank`
ROW_NUMBER() OVER (ORDER BY `t0`.`f` DESC NULLS LAST) - 1 AS `revrank`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
SELECT
LAG(`t0`.`d`) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`f` DESC) AS `foo`,
MAX(`t0`.`a`) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`f` DESC) AS `Max(a)`
LAG(`t0`.`d`) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`f` DESC NULLS LAST) AS `foo`,
MAX(`t0`.`a`) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`f` DESC NULLS LAST) AS `Max(a)`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
SELECT
LAG(
`t0`.`f` - LAG(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`f` ASC NULLS LAST)
) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`f` ASC NULLS LAST) AS `foo`
LAG(`t0`.`f` - LAG(`t0`.`f`) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`f` ASC)) OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`f` ASC) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
SELECT
`t0`.`g`,
RANK() OVER (ORDER BY `t0`.`f` ASC NULLS LAST) - 1 AS `minr`,
DENSE_RANK() OVER (ORDER BY `t0`.`f` ASC NULLS LAST) - 1 AS `denser`
RANK() OVER (ORDER BY `t0`.`f` ASC) - 1 AS `minr`,
DENSE_RANK() OVER (ORDER BY `t0`.`f` ASC) - 1 AS `denser`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -10,5 +10,5 @@ SELECT
`t0`.`i`,
`t0`.`j`,
`t0`.`k`,
ROW_NUMBER() OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC NULLS LAST) AS `foo`
ROW_NUMBER() OVER (PARTITION BY `t0`.`g` ORDER BY NULL ASC) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -10,5 +10,5 @@ SELECT
`t0`.`i`,
`t0`.`j`,
`t0`.`k`,
ROW_NUMBER() OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`f` ASC NULLS LAST) - 1 AS `foo`
ROW_NUMBER() OVER (PARTITION BY `t0`.`g` ORDER BY `t0`.`f` ASC) - 1 AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -11,6 +11,6 @@ SELECT
`t0`.`j`,
`t0`.`k`,
(
ROW_NUMBER() OVER (ORDER BY `t0`.`f` ASC NULLS LAST) - 1
ROW_NUMBER() OVER (ORDER BY `t0`.`f` ASC) - 1
) / 2 AS `new`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST) AS `foo`
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST) AS `foo`
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST ROWS BETWEEN 10 preceding AND 5 preceding) AS `foo`
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC ROWS BETWEEN 10 preceding AND 5 preceding) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST ROWS BETWEEN UNBOUNDED PRECEDING AND 2 following) AS `foo`
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND 2 following) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST ROWS BETWEEN CURRENT ROW AND 2 following) AS `foo`
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC ROWS BETWEEN CURRENT ROW AND 2 following) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST ROWS BETWEEN 5 following AND 10 following) AS `foo`
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC ROWS BETWEEN 5 following AND 10 following) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS `foo`
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST ROWS BETWEEN 5 preceding AND UNBOUNDED FOLLOWING) AS `foo`
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC ROWS BETWEEN 5 preceding AND UNBOUNDED FOLLOWING) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST ROWS BETWEEN 5 preceding AND CURRENT ROW) AS `foo`
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC ROWS BETWEEN 5 preceding AND CURRENT ROW) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST ROWS BETWEEN 5 preceding AND 2 following) AS `foo`
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC ROWS BETWEEN 5 preceding AND 2 following) AS `foo`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC NULLS LAST ROWS BETWEEN 10 preceding AND CURRENT ROW) AS `foo`
SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC ROWS BETWEEN 10 preceding AND CURRENT ROW) AS `foo`
FROM `alltypes` AS `t0`
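For reference, a minimal ibis sketch of the kind of expression that compiles to the windowed SQL in the snapshots above. It is not part of this diff; the table schema is invented and only the public `ibis.table`, `ibis.window`, and `ibis.to_sql` APIs are assumed.

import ibis

# hypothetical stand-in for the `alltypes` test table
t = ibis.table({"d": "float64", "f": "float64"}, name="alltypes")

# cumulative sum ordered by `f`, mirroring
# SUM(`t0`.`d`) OVER (ORDER BY `t0`.`f` ASC) AS `foo`
expr = t.select(foo=t.d.sum().over(ibis.window(order_by=t.f)))
print(ibis.to_sql(expr, dialect="impala"))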
10 changes: 5 additions & 5 deletions ibis/backends/impala/tests/test_udf.py
Original file line number Diff line number Diff line change
@@ -144,10 +144,10 @@ def test_udf_primitive_output_types(ty, value, column, table):
ibis_type = dt.validate_type(ty)

expr = func(value)
assert type(expr) == getattr(ir, ibis_type.scalar)
assert type(expr) is getattr(ir, ibis_type.scalar)

expr = func(table[column])
assert type(expr) == getattr(ir, ibis_type.column)
assert type(expr) is getattr(ir, ibis_type.column)


@pytest.mark.parametrize(
@@ -184,9 +184,9 @@ def test_uda_primitive_output_types(ty, value):
def test_decimal(dec):
func = _register_udf(["decimal(12, 2)"], "decimal(12, 2)", "test")
expr = func(1.0)
assert type(expr) == ir.DecimalScalar
assert type(expr) is ir.DecimalScalar
expr = func(dec)
assert type(expr) == ir.DecimalColumn
assert type(expr) is ir.DecimalColumn


@pytest.mark.parametrize(
@@ -382,7 +382,7 @@ def identity_func_testing(udf_ll, con, test_data_db, datatype, literal, column):
result = con.execute(expr)
# Hacky
if datatype == "timestamp":
assert type(result) == pd.Timestamp
assert type(result) is pd.Timestamp
else:
lop = literal.op()
if isinstance(lop, ir.Literal):
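A generic Python aside on the `==` to `is` change in the assertions above: `is` compares the class by identity, while `==` can be hijacked by a metaclass `__eq__`. A self-contained sketch, not taken from this repository.

class AlwaysEqualMeta(type):
    # pathological metaclass: classes using it compare equal to anything
    def __eq__(cls, other):
        return True

    __hash__ = type.__hash__


class Weird(metaclass=AlwaysEqualMeta):
    pass


assert type(Weird()) == int      # passes misleadingly via the custom __eq__
assert type(Weird()) is not int  # identity comparison is not fooled
assert type(Weird()) is Weird    # the exact-type check the tests now use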
245 changes: 177 additions & 68 deletions ibis/backends/mssql/__init__.py
Original file line number Diff line number Diff line change
@@ -6,8 +6,6 @@
import datetime
import struct
from contextlib import closing
from functools import partial
from itertools import repeat
from operator import itemgetter
from typing import TYPE_CHECKING, Any

@@ -23,9 +21,9 @@
import ibis.expr.types as ir
from ibis import util
from ibis.backends import CanCreateCatalog, CanCreateDatabase, CanCreateSchema, NoUrl
from ibis.backends.mssql.compiler import MSSQLCompiler
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compiler import C
from ibis.backends.sql.compilers import MSSQLCompiler
from ibis.backends.sql.compilers.base import STAR, C

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
@@ -114,27 +112,49 @@ def do_connect(
if user is None and password is None:
kwargs.setdefault("Trusted_Connection", "yes")

con = pyodbc.connect(
self.con = pyodbc.connect(
user=user,
server=host,
port=port,
server=f"{host},{port}",
password=password,
database=database,
driver=driver,
**kwargs,
)

self._post_connect()

@util.experimental
@classmethod
def from_connection(cls, con: pyodbc.Connection) -> Backend:
"""Create an Ibis client from an existing connection to a MSSQL database.

Parameters
----------
con
An existing connection to a MSSQL database.
"""
new_backend = cls()
new_backend._can_reconnect = False
new_backend.con = con
new_backend._post_connect()
return new_backend

def _post_connect(self):
# -155 is the code for datetimeoffset
con.add_output_converter(-155, datetimeoffset_to_datetime)
self.con.add_output_converter(-155, datetimeoffset_to_datetime)

with closing(con.cursor()) as cur:
with closing(self.con.cursor()) as cur:
cur.execute("SET DATEFIRST 1")

self.con = con

def get_schema(
self, name: str, *, catalog: str | None = None, database: str | None = None
) -> sch.Schema:
# TODO: this is brittle and should be improved. We want to be able to
# identify if a given table is a temp table and update the search
# location accordingly.
if name.startswith("ibis_cache_"):
catalog, database = ("tempdb", "dbo")
name = "##" + name
conditions = [sg.column("table_name").eq(sge.convert(name))]

if database is not None:
@@ -248,6 +268,11 @@ def current_database(self) -> str:

@contextlib.contextmanager
def begin(self):
with contextlib.closing(self.con.cursor()) as cur:
yield cur

@contextlib.contextmanager
def _ddl_begin(self):
con = self.con
cur = con.cursor()
try:
@@ -269,94 +294,131 @@ def _safe_raw_sql(self, query, *args, **kwargs):
cur.execute(query, *args, **kwargs)
yield cur

@contextlib.contextmanager
def _safe_ddl(self, query, *args, **kwargs):
with contextlib.suppress(AttributeError):
query = query.sql(self.dialect)

with self._ddl_begin() as cur:
cur.execute(query, *args, **kwargs)
yield cur

def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any:
with contextlib.suppress(AttributeError):
query = query.sql(self.dialect)

con = self.con
cursor = con.cursor()

try:
cursor.execute(query, **kwargs)
except Exception:
con.rollback()
cursor.close()
raise
else:
con.commit()
return cursor
cursor.execute(query, **kwargs)
return cursor

def create_catalog(self, name: str, force: bool = False) -> None:
name = self._quote(name)
expr = (
sg.select(STAR)
.from_(sg.table("databases", db="sys"))
.where(C.name.eq(sge.convert(name)))
)
stmt = sge.Create(
kind="DATABASE", this=sg.to_identifier(name, quoted=self.compiler.quoted)
).sql(self.dialect)
create_stmt = (
f"""\
IF NOT EXISTS (SELECT name FROM sys.databases WHERE name = {name})
IF NOT EXISTS ({expr.sql(self.dialect)})
BEGIN
CREATE DATABASE {name};
{stmt};
END;
GO"""
if force
else f"CREATE DATABASE {name}"
else stmt
)
with self._safe_raw_sql(create_stmt):
with self._safe_ddl(create_stmt):
pass

def drop_catalog(self, name: str, force: bool = False) -> None:
name = self._quote(name)
if_exists = "IF EXISTS " * force

with self._safe_raw_sql(f"DROP DATABASE {if_exists}{name}"):
with self._safe_ddl(
sge.Drop(
kind="DATABASE",
this=sg.to_identifier(name, quoted=self.compiler.quoted),
exists=force,
)
):
pass

def create_database(
self, name: str, catalog: str | None = None, force: bool = False
) -> None:
current_catalog = self.current_catalog
should_switch_catalog = catalog is not None and catalog != current_catalog
quoted = self.compiler.quoted

name = self._quote(name)
expr = (
sg.select(STAR)
.from_(sg.table("schemas", db="sys"))
.where(C.name.eq(sge.convert(name)))
)
stmt = sge.Create(
kind="SCHEMA", this=sg.to_identifier(name, quoted=quoted)
).sql(self.dialect)

create_stmt = (
f"""\
IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = {name})
IF NOT EXISTS ({expr.sql(self.dialect)})
BEGIN
CREATE SCHEMA {name};
{stmt};
END;
GO"""
if force
else f"CREATE SCHEMA {name}"
else stmt
)

with self.begin() as cur:
if should_switch_catalog:
cur.execute(f"USE {self._quote(catalog)}")
cur.execute(
sge.Use(this=sg.to_identifier(catalog, quoted=quoted)).sql(
self.dialect
)
)

cur.execute(create_stmt)

if should_switch_catalog:
cur.execute(f"USE {self._quote(current_catalog)}")

def _quote(self, name: str):
return sg.to_identifier(name, quoted=True).sql(self.dialect)
cur.execute(
sge.Use(this=sg.to_identifier(current_catalog, quoted=quoted)).sql(
self.dialect
)
)

def drop_database(
self, name: str, catalog: str | None = None, force: bool = False
) -> None:
current_catalog = self.current_catalog
should_switch_catalog = catalog is not None and catalog != current_catalog

name = self._quote(name)

if_exists = "IF EXISTS " * force
quoted = self.compiler.quoted

with self.begin() as cur:
if should_switch_catalog:
cur.execute(f"USE {self._quote(catalog)}")
cur.execute(
sge.Use(this=sg.to_identifier(catalog, quoted=quoted)).sql(
self.dialect
)
)

cur.execute(f"DROP SCHEMA {if_exists}{name}")
cur.execute(
sge.Drop(
kind="SCHEMA",
exists=force,
this=sg.to_identifier(name, quoted=quoted),
).sql(self.dialect)
)

if should_switch_catalog:
cur.execute(f"USE {self._quote(current_catalog)}")
cur.execute(
sge.Use(this=sg.to_identifier(current_catalog, quoted=quoted)).sql(
self.dialect
)
)

def list_tables(
self,
@@ -448,20 +510,56 @@ def create_table(
temp: bool = False,
overwrite: bool = False,
) -> ir.Table:
"""Create a new table.

Parameters
----------
name
Name of the new table.
obj
An Ibis table expression or pandas table that will be used to
extract the schema and the data of the new table. If not provided,
`schema` must be given.
schema
The schema for the new table. Only one of `schema` or `obj` can be
provided.
database
Name of the database where the table will be created, if not the
default.

To specify a location in a separate catalog, you can pass in the
catalog and database as a string `"catalog.database"`, or as a tuple of
strings `("catalog", "database")`.
temp
Whether a table is temporary or not.
All created temp tables are "Global Temporary Tables". They will be
created in "tempdb.dbo" and will be prefixed with "##".
overwrite
Whether to clobber existing data.
`overwrite` and `temp` cannot be used together with MSSQL.

Returns
-------
Table
The table that was created.

"""
if obj is None and schema is None:
raise ValueError("Either `obj` or `schema` must be specified")

if database is not None and database != self.current_database:
raise com.UnsupportedOperationError(
"Creating tables in other databases is not supported by Postgres"
if temp and overwrite:
raise ValueError(
"MSSQL doesn't support overwriting temp tables, create a new temp table instead."
)
else:
database = None

table_loc = self._to_sqlglot_table(database)
catalog, db = self._to_catalog_db_tuple(table_loc)

properties = []

if temp:
properties.append(sge.TemporaryProperty())
catalog, db = None, None

temp_memtable_view = None
if obj is not None:
@@ -495,8 +593,10 @@ def create_table(
else:
temp_name = name

table = sg.table(temp_name, catalog=database, quoted=self.compiler.quoted)
raw_table = sg.table(temp_name, catalog=database, quoted=False)
table = sg.table(
"#" * temp + temp_name, catalog=catalog, db=db, quoted=self.compiler.quoted
)
raw_table = sg.table(temp_name, catalog=catalog, db=db, quoted=False)
target = sge.Schema(this=table, expressions=column_defs)

create_stmt = sge.Create(
@@ -505,11 +605,22 @@ def create_table(
properties=sge.Properties(expressions=properties),
)

this = sg.table(name, catalog=database, quoted=self.compiler.quoted)
raw_this = sg.table(name, catalog=database, quoted=False)
with self._safe_raw_sql(create_stmt) as cur:
this = sg.table(name, catalog=catalog, db=db, quoted=self.compiler.quoted)
raw_this = sg.table(name, catalog=catalog, db=db, quoted=False)
with self._safe_ddl(create_stmt) as cur:
if query is not None:
insert_stmt = sge.Insert(this=table, expression=query).sql(self.dialect)
# You can specify that a table is temporary for the sqlglot `Create` but not
# for the subsequent `Insert`, so we need to shove a `#` in
# front of the table identifier.
_table = sg.table(
"##" * temp + temp_name,
catalog=catalog,
db=db,
quoted=self.compiler.quoted,
)
insert_stmt = sge.Insert(this=_table, expression=query).sql(
self.dialect
)
cur.execute(insert_stmt)

if overwrite:
@@ -525,11 +636,17 @@ def create_table(
# for in-memory reads
if temp_memtable_view is not None:
self.drop_table(temp_memtable_view)
return self.table(name, database=database)
return self.table(
"##" * temp + name,
database=("tempdb" * temp or catalog, "dbo" * temp or db),
)

# preserve the input schema if it was provided
return ops.DatabaseTable(
name, schema=schema, source=self, namespace=ops.Namespace(database=database)
name,
schema=schema,
source=self,
namespace=ops.Namespace(catalog=catalog, database=db),
).to_expr()

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
@@ -570,19 +687,11 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:

df = op.data.to_frame()
data = df.itertuples(index=False)
cols = ", ".join(
ident.sql(self.dialect)
for ident in map(
partial(sg.to_identifier, quoted=quoted), schema.keys()
)
)
specs = ", ".join(repeat("?", len(schema)))
table = sg.table(name, quoted=quoted)
sql = f"INSERT INTO {table.sql(self.dialect)} ({cols}) VALUES ({specs})"

with self._safe_raw_sql(create_stmt) as cur:
insert_stmt = self._build_insert_template(name, schema=schema, columns=True)
with self._safe_ddl(create_stmt) as cur:
if not df.empty:
cur.executemany(sql, data)
cur.executemany(insert_stmt, data)

def _to_sqlglot(
self, expr: ir.Expr, *, limit: str | None = None, params=None, **_: Any
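A hedged usage sketch for the new `from_connection` entry point. The driver string, credentials, and the `ibis.mssql.from_connection` spelling are assumptions, not taken from this diff; the keyword arguments mirror what `do_connect` itself passes to `pyodbc.connect`.

import pyodbc
import ibis

# an existing connection created outside ibis (placeholder parameters)
raw = pyodbc.connect(
    driver="ODBC Driver 18 for SQL Server",
    server="localhost,1433",
    user="sa",
    password="<password>",
    database="ibis_testing",
)

# wrap it without going through do_connect; reconnection is disabled for such clients
con = ibis.mssql.from_connection(raw)
print(con.list_tables())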
18 changes: 18 additions & 0 deletions ibis/backends/mssql/tests/test_client.py
Original file line number Diff line number Diff line change
@@ -134,3 +134,21 @@ def test_list_tables_schema_warning_refactor(con):

assert con.list_tables(database="msdb.dbo", like="restore") == restore_tables
assert con.list_tables(database=("msdb", "dbo"), like="restore") == restore_tables


def test_create_temp_table_from_obj(con):
obj = {"team": ["john", "joe"]}

t = con.create_table("team", obj, temp=True)

t2 = con.table("##team", database="tempdb.dbo")

assert t.to_pyarrow().equals(t2.to_pyarrow())

persisted_from_temp = con.create_table("fuhreal", t2)

assert "fuhreal" in con.list_tables()

assert persisted_from_temp.to_pyarrow().equals(t2.to_pyarrow())

con.drop_table("fuhreal")
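Building on the test above, a small hedged sketch of the new guard in `create_table`: `temp=True` and `overwrite=True` are rejected together, so the expected pattern is to create a fresh global temp table instead. The table name here is illustrative only.

obj = {"roster": ["ann", "bob"]}

# lands in tempdb.dbo as ##roster, per the new temp-table handling
t = con.create_table("roster", obj, temp=True)

try:
    # the new check raises rather than clobbering an existing temp table
    con.create_table("roster", obj, temp=True, overwrite=True)
except ValueError:
    pass  # create a new temp table under a different name instead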