53 changes: 28 additions & 25 deletions ibis/backends/impala/tests/test_sql.py
@@ -28,7 +28,7 @@ def test_join_no_predicates_for_impala(con, join_type, snapshot):
t1 = con.table("star1")
t2 = con.table("star2")

joined = getattr(t1, join_type)(t2)[[t1]]
joined = getattr(t1, join_type)(t2).select(t1)
result = ibis.to_sql(joined, dialect="impala")
snapshot.assert_match(result, "out.sql")

@@ -76,16 +76,16 @@ def test_nested_join_multiple_ctes(snapshot):
movies = ibis.table(dict(movieid="int64", title="string"), name="movies")

expr = ratings.timestamp.cast("timestamp")
ratings2 = ratings["userid", "movieid", "rating", expr.name("datetime")]
joined2 = ratings2.join(movies, ["movieid"])[ratings2, movies["title"]]
ratings2 = ratings.select("userid", "movieid", "rating", expr.name("datetime"))
joined2 = ratings2.join(movies, ["movieid"]).select(ratings2, movies["title"])
joined3 = joined2.filter([joined2.userid == 118205, joined2.datetime.year() > 2001])
top_user_old_movie_ids = joined3.filter(
[joined3.userid == 118205, joined3.datetime.year() < 2009]
)[["movieid"]]
# projection from a filter was hiding an insidious bug, so we're disabling
# that for now see issue #1295
cond = joined3.movieid.isin(top_user_old_movie_ids.movieid)
result = joined3[cond]
result = joined3.filter(cond)
compiled_result = ibis.to_sql(result, dialect="impala")
snapshot.assert_match(compiled_result, "out.sql")

@@ -109,7 +109,7 @@ def test_join_with_nested_or_condition(snapshot):
t2 = t1.view()

joined = t1.join(t2, [t1.a == t2.a, (t1.a != t2.b) | (t1.b != t2.a)])
expr = joined[t1]
expr = joined.select(t1)
result = ibis.to_sql(expr, dialect="impala")
snapshot.assert_match(result, "out.sql")

@@ -119,7 +119,7 @@ def test_join_with_nested_xor_condition(snapshot):
t2 = t1.view()

joined = t1.join(t2, [t1.a == t2.a, (t1.a != t2.b) ^ (t1.b != t2.a)])
expr = joined[t1]
expr = joined.select(t1)
result = ibis.to_sql(expr, dialect="impala")
snapshot.assert_match(result, "out.sql")

@@ -128,15 +128,15 @@ def test_join_with_nested_xor_condition(snapshot):
def test_is_parens(method, snapshot):
t = ibis.table([("a", "string"), ("b", "string")], "table")
func = operator.methodcaller(method)
expr = t[func(t.a) == func(t.b)]
expr = t.filter(func(t.a) == func(t.b))

result = ibis.to_sql(expr, dialect="impala")
snapshot.assert_match(result, "out.sql")


def test_is_parens_identical_to(snapshot):
t = ibis.table([("a", "string"), ("b", "string")], "table")
expr = t[t.a.identical_to(None) == t.b.identical_to(None)]
expr = t.filter(t.a.identical_to(None) == t.b.identical_to(None))

result = ibis.to_sql(expr, dialect="impala")
snapshot.assert_match(result, "out.sql")
@@ -147,37 +147,37 @@ def test_join_aliasing(snapshot):
[("a", "int64"), ("b", "int64"), ("c", "int64")], name="test_table"
)
test = test.mutate(d=test.a + 20)
test2 = test[test.d, test.c]
test2 = test.select(test.d, test.c)
idx = (test2.d / 15).cast("int64").name("idx")
test3 = test2.group_by([test2.d, idx, test2.c]).aggregate(row_count=test2.count())
test3_totals = test3.group_by(test3.d).aggregate(total=test3.row_count.sum())
test4 = test3.join(test3_totals, test3.d == test3_totals.d)[
test4 = test3.join(test3_totals, test3.d == test3_totals.d).select(
test3, test3_totals.total
]
test5 = test4[test4.row_count < test4.total / 2]
)
test5 = test4.filter(test4.row_count < test4.total / 2)
agg = (
test.group_by([test.d, test.b])
.aggregate(count=test.count(), unique=test.c.nunique())
.view()
)
result = agg.join(test5, agg.d == test5.d)[agg, test5.total]
result = agg.join(test5, agg.d == test5.d).select(agg, test5.total)
result = ibis.to_sql(result, dialect="impala")
snapshot.assert_match(result, "out.sql")


def test_multiple_filters(snapshot):
t = ibis.table([("a", "int64"), ("b", "string")], name="t0")
filt = t[t.a < 100]
expr = filt[filt.a == filt.a.max()]
filt = t.filter(t.a < 100)
expr = filt.filter(filt.a == filt.a.max())
result = ibis.to_sql(expr, dialect="impala")
snapshot.assert_match(result, "out.sql")


def test_multiple_filters2(snapshot):
t = ibis.table([("a", "int64"), ("b", "string")], name="t0")
filt = t[t.a < 100]
expr = filt[filt.a == filt.a.max()]
expr = expr[expr.b == "a"]
filt = t.filter(t.a < 100)
expr = filt.filter(filt.a == filt.a.max())
expr = expr.filter(expr.b == "a")
result = ibis.to_sql(expr, dialect="impala")
snapshot.assert_match(result, "out.sql")

@@ -250,7 +250,8 @@ def tpch(region, nation, customer, orders):
return (
region.join(nation, region.r_regionkey == nation.n_regionkey)
.join(customer, customer.c_nationkey == nation.n_nationkey)
.join(orders, orders.o_custkey == customer.c_custkey)[fields_of_interest]
.join(orders, orders.o_custkey == customer.c_custkey)
.select(fields_of_interest)
)


@@ -259,18 +260,20 @@ def test_join_key_name(tpch, snapshot):

pre_sizes = tpch.group_by(year).size()
t2 = tpch.view()
conditional_avg = t2[t2.region == tpch.region].o_totalprice.mean().name("mean")
conditional_avg = (
t2.filter(t2.region == tpch.region).o_totalprice.mean().name("mean")
)
amount_filter = tpch.o_totalprice > conditional_avg
post_sizes = tpch[amount_filter].group_by(year).size()
post_sizes = tpch.filter(amount_filter).group_by(year).size()

percent = (post_sizes[1] / pre_sizes[1].cast("double")).name("fraction")

expr = pre_sizes.join(post_sizes, pre_sizes.year == post_sizes.year)[
expr = pre_sizes.join(post_sizes, pre_sizes.year == post_sizes.year).select(
pre_sizes.year,
pre_sizes[1].name("pre_count"),
post_sizes[1].name("post_count"),
percent,
]
)
result = ibis.impala.compile(expr)
snapshot.assert_match(result, "out.sql")

@@ -281,11 +284,11 @@ def test_join_key_name2(tpch, snapshot):
pre_sizes = tpch.group_by(year).size()
post_sizes = tpch.group_by(year).size().view()

expr = pre_sizes.join(post_sizes, pre_sizes.year == post_sizes.year)[
expr = pre_sizes.join(post_sizes, pre_sizes.year == post_sizes.year).select(
pre_sizes.year,
pre_sizes[1].name("pre_count"),
post_sizes[1].name("post_count"),
]
)
result = ibis.impala.compile(expr)
snapshot.assert_match(result, "out.sql")

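The edits in this file (and in the other test modules in this diff) apply one mechanical migration: Table.__getitem__ projections and filters become explicit .select() / .filter() calls. A minimal sketch of the idiom on a hypothetical table:

    import ibis

    t = ibis.table({"a": "int64", "b": "string"}, name="t")

    t.select(t.a, t.b)  # formerly t[t.a, t.b]
    t.filter(t.a > 1)   # formerly t[t.a > 1]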
14 changes: 7 additions & 7 deletions ibis/backends/impala/tests/test_value_exprs.py
@@ -157,7 +157,7 @@ def test_decimal_casts(table, expr_fn, snapshot):
snapshot.assert_match(result, "out.sql")


@pytest.mark.parametrize("colname", ["a", "f", "h"])
@pytest.mark.parametrize("colname", ["a", "f"])
def test_negate(table, colname, snapshot):
result = translate(-table[colname])
snapshot.assert_match(result, "out.sql")
@@ -175,11 +175,11 @@ def test_timestamp_extract_field(table, field, snapshot):

def test_sql_extract(table, snapshot):
# integration with SQL translation
expr = table[
expr = table.select(
table.i.year().name("year"),
table.i.month().name("month"),
table.i.day().name("day"),
]
)

result = ibis.to_sql(expr, dialect="impala")
snapshot.assert_match(result, "out.sql")
@@ -252,8 +252,8 @@ def test_correlated_predicate_subquery(table, snapshot):
t1 = t0.view()

# both are valid constructions
expr1 = t0[t0.g == t1.g]
expr2 = t1[t0.g == t1.g]
expr1 = t0.filter(t0.g == t1.g)
expr2 = t1.filter(t0.g == t1.g)

snapshot.assert_match(translate(expr1), "out1.sql")
snapshot.assert_match(translate(expr2), "out2.sql")
"expr_fn",
[
param(lambda b: b.any(), id="any"),
param(lambda b: -b.any(), id="not_any"),
param(lambda b: ~b.any(), id="not_any"),
param(lambda b: b.all(), id="all"),
param(lambda b: -b.all(), id="not_all"),
param(lambda b: ~b.all(), id="not_all"),
],
)
def test_any_all(table, expr_fn, snapshot):
6 changes: 3 additions & 3 deletions ibis/backends/impala/tests/test_window.py
@@ -22,7 +22,7 @@ def assert_sql_equal(expr, snapshot, out="out.sql"):

def test_aggregate_in_projection(alltypes, snapshot):
t = alltypes
proj = t[t, (t.f / t.f.sum()).name("normed_f")]
proj = t.select(t, (t.f / t.f.sum()).name("normed_f"))
assert_sql_equal(proj, snapshot)


@@ -93,7 +93,7 @@ def test_nested_analytic_function(alltypes, snapshot):
def test_rank_functions(alltypes, snapshot):
t = alltypes

proj = t[t.g, t.f.rank().name("minr"), t.f.dense_rank().name("denser")]
proj = t.select(t.g, t.f.rank().name("minr"), t.f.dense_rank().name("denser"))
assert_sql_equal(proj, snapshot)


@@ -113,7 +113,7 @@ def test_order_by_desc(alltypes, snapshot):

w = window(order_by=ibis.desc(t.f))

proj = t[t.f, ibis.row_number().over(w).name("revrank")]
proj = t.select(t.f, ibis.row_number().over(w).name("revrank"))
assert_sql_equal(proj, snapshot, "out1.sql")

expr = t.group_by("g").order_by(ibis.desc(t.f))[t.d.lag().name("foo"), t.a.max()]
132 changes: 75 additions & 57 deletions ibis/backends/mssql/__init__.py
@@ -125,6 +125,41 @@ def do_connect(
See https://learn.microsoft.com/en-us/sql/connect/odbc/windows/system-requirements-installation-and-driver-files
kwargs
Additional keyword arguments to pass to PyODBC.
Examples
--------
>>> import os
>>> import ibis
>>> host = os.environ.get("IBIS_TEST_MSSQL_HOST", "localhost")
>>> user = os.environ.get("IBIS_TEST_MSSQL_USER", "sa")
>>> password = os.environ.get("IBIS_TEST_MSSQL_PASSWORD", "1bis_Testing!")
>>> database = os.environ.get("IBIS_TEST_MSSQL_DATABASE", "ibis_testing")
>>> driver = os.environ.get("IBIS_TEST_MSSQL_PYODBC_DRIVER", "FreeTDS")
>>> con = ibis.mssql.connect(
... database=database,
... host=host,
... user=user,
... password=password,
... driver=driver,
... )
>>> con.list_tables() # doctest: +ELLIPSIS
[...]
>>> t = con.table("functional_alltypes")
>>> t
DatabaseTable: functional_alltypes
id int32
bool_col boolean
tinyint_col int16
smallint_col int16
int_col int32
bigint_col int64
float_col float32
double_col float64
date_string_col string
string_col string
timestamp_col timestamp(7)
year int32
month int32
"""

# If no user/password given, assume Windows Integrated Authentication
@@ -307,6 +342,9 @@ def _get_schema_using_query(self, query: str) -> sch.Schema:
elif newtyp.is_timestamp():
newtyp = newtyp.copy(scale=scale)

if name is None:
name = util.gen_name("col")

schema[name] = newtyp

return sch.Schema(schema)
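# With the gen_name fallback above, unnamed result columns from raw .sql() queries
# get generated names instead of all collapsing onto a missing name; roughly
# (illustrative only, see the test_dot_sql_with_unnamed_columns test further down):
#
#     expr = con.sql("SELECT 1 + 1, 2 AS two")
#     expr.schema().names  # first entry is generated and starts with "ibis_col"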
@@ -622,11 +660,9 @@ def create_table(
properties.append(sge.TemporaryProperty())
catalog, db = None, None

temp_memtable_view = None
if obj is not None:
if not isinstance(obj, ir.Expr):
table = ibis.memtable(obj)
temp_memtable_view = table.op().name
else:
table = obj

else:
query = None

column_defs = [
sge.ColumnDef(
this=sg.to_identifier(colname, quoted=self.compiler.quoted),
kind=self.compiler.type_mapper.from_ibis(typ),
constraints=(
None
if typ.nullable
else [sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())]
),
)
for colname, typ in (schema or table.schema()).items()
]

if overwrite:
temp_name = util.gen_name(f"{self.name}_table")
else:
temp_name = name

table = sg.table(
"#" * temp + temp_name, catalog=catalog, db=db, quoted=self.compiler.quoted
)
if not schema:
schema = table.schema()

quoted = self.compiler.quoted
raw_table = sg.table(temp_name, catalog=catalog, db=db, quoted=False)
target = sge.Schema(this=table, expressions=column_defs)
target = sge.Schema(
this=sg.table(
"#" * temp + temp_name, catalog=catalog, db=db, quoted=quoted
),
expressions=schema.to_sqlglot(self.dialect),
)
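# Schema.to_sqlglot(dialect) now supplies the column definitions that the removed
# hand-rolled ColumnDef loop used to build. Rough sketch (the dialect is passed as
# a string here purely for illustration):
#
#     import ibis
#     ibis.schema({"id": "!int64", "name": "string"}).to_sqlglot("tsql")
#     # -> sqlglot ColumnDef expressions, with NOT NULL for the non-nullable "id"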

create_stmt = sge.Create(
kind="TABLE",
this=target,
properties=sge.Properties(expressions=properties),
)

this = sg.table(name, catalog=catalog, db=db, quoted=self.compiler.quoted)
this = sg.table(name, catalog=catalog, db=db, quoted=quoted)
raw_this = sg.table(name, catalog=catalog, db=db, quoted=False)
with self._safe_ddl(create_stmt) as cur:
if query is not None:
Expand Down Expand Up @@ -699,10 +728,6 @@ def create_table(
db = "dbo"

if schema is None:
# Clean up temporary memtable if we've created one
# for in-memory reads
if temp_memtable_view is not None:
self.drop_table(temp_memtable_view)
return self.table(name, database=(catalog, db))

# preserve the input schema if it was provided
namespace=ops.Namespace(catalog=catalog, database=db),
).to_expr()

def _in_memory_table_exists(self, name: str) -> bool:
# The single character U here means user-defined table
# see https://learn.microsoft.com/en-us/sql/relational-databases/system-catalog-views/sys-objects-transact-sql?view=sql-server-ver16
sql = sg.select(sg.func("object_id", sge.convert(name), sge.convert("U"))).sql(
self.dialect
)
with self.begin() as cur:
[(result,)] = cur.execute(sql).fetchall()
return result is not None
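# The sqlglot expression above renders to a probe along these lines (rendering is
# approximate and the table name illustrative); OBJECT_ID returns NULL when no
# user-defined table by that name exists:
#
#     SELECT OBJECT_ID('ibis_memtable_abc123', 'U')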

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
f"got null typed columns: {null_columns}"
)

# only register if we haven't already done so
if (name := op.name) not in self.list_tables():
quoted = self.compiler.quoted
column_defs = [
sg.exp.ColumnDef(
this=sg.to_identifier(colname, quoted=quoted),
kind=self.compiler.type_mapper.from_ibis(typ),
constraints=(
None
if typ.nullable
else [
sg.exp.ColumnConstraint(
kind=sg.exp.NotNullColumnConstraint()
)
]
),
)
for colname, typ in schema.items()
]

create_stmt = sg.exp.Create(
kind="TABLE",
this=sg.exp.Schema(
this=sg.to_identifier(name, quoted=quoted), expressions=column_defs
),
# properties=sg.exp.Properties(expressions=[sge.TemporaryProperty()]),
)
name = op.name
quoted = self.compiler.quoted

df = op.data.to_frame()
data = df.itertuples(index=False)
create_stmt = sg.exp.Create(
kind="TABLE",
this=sg.exp.Schema(
this=sg.to_identifier(name, quoted=quoted),
expressions=schema.to_sqlglot(self.dialect),
),
)

insert_stmt = self._build_insert_template(name, schema=schema, columns=True)
with self._safe_ddl(create_stmt) as cur:
if not df.empty:
cur.executemany(insert_stmt, data)
df = op.data.to_frame()
data = df.itertuples(index=False)

insert_stmt = self._build_insert_template(name, schema=schema, columns=True)
with self._safe_ddl(create_stmt) as cur:
if not df.empty:
cur.executemany(insert_stmt, data)

def _cursor_batches(
self,
26 changes: 25 additions & 1 deletion ibis/backends/mssql/tests/test_client.py
@@ -159,7 +159,7 @@ def count_big(x, where: bool = True) -> int:
expr = count_big(ft.id)

expr = count_big(ft.id, where=ft.id == 1)
assert expr.execute() == ft[ft.id == 1].count().execute()
assert expr.execute() == ft.filter(ft.id == 1).count().execute()


@pytest.mark.parametrize("string", ["a", " ", "a ", " a", ""])
@@ -241,3 +241,27 @@ def test_from_url():
)
result = new_con.sql("SELECT 1 AS [a]").to_pandas().a.iat[0]
assert result == 1


def test_dot_sql_with_unnamed_columns(con):
expr = con.sql(
"SELECT CAST('2024-01-01 00:00:00' AS DATETIMEOFFSET), 'a' + 'b', 1 AS [col42]"
)

schema = expr.schema()
names = schema.names

assert len(names) == 3

assert names[0].startswith("ibis_col")
assert names[1].startswith("ibis_col")
assert names[2] == "col42"

assert schema.types == (
dt.Timestamp(timezone="UTC", scale=7),
dt.String(nullable=False),
dt.Int32(nullable=False),
)

df = expr.execute()
assert len(df) == 1
154 changes: 77 additions & 77 deletions ibis/backends/mysql/__init__.py
@@ -13,6 +13,7 @@
import pymysql
import sqlglot as sg
import sqlglot.expressions as sge
from pymysql.constants import ER

import ibis
import ibis.backends.sql.compilers as sc
@@ -122,33 +123,30 @@ def do_connect(
Examples
--------
>>> import os
>>> import getpass
>>> import ibis
>>> host = os.environ.get("IBIS_TEST_MYSQL_HOST", "localhost")
>>> user = os.environ.get("IBIS_TEST_MYSQL_USER", getpass.getuser())
>>> password = os.environ.get("IBIS_TEST_MYSQL_PASSWORD")
>>> user = os.environ.get("IBIS_TEST_MYSQL_USER", "ibis")
>>> password = os.environ.get("IBIS_TEST_MYSQL_PASSWORD", "ibis")
>>> database = os.environ.get("IBIS_TEST_MYSQL_DATABASE", "ibis_testing")
>>> con = connect(database=database, host=host, user=user, password=password)
>>> con = ibis.mysql.connect(database=database, host=host, user=user, password=password)
>>> con.list_tables() # doctest: +ELLIPSIS
[...]
>>> t = con.table("functional_alltypes")
>>> t
MySQLTable[table]
name: functional_alltypes
schema:
id : int32
bool_col : int8
tinyint_col : int8
smallint_col : int16
int_col : int32
bigint_col : int64
float_col : float32
double_col : float64
date_string_col : string
string_col : string
timestamp_col : timestamp
year : int32
month : int32
DatabaseTable: functional_alltypes
id int32
bool_col int8
tinyint_col int8
smallint_col int16
int_col int32
bigint_col int64
float_col float32
double_col float64
date_string_col string
string_col string
timestamp_col timestamp
year int32
month int32
"""
self.con = pymysql.connect(
user=user,
@@ -417,26 +415,18 @@ def create_table(
else:
query = None

column_defs = [
sge.ColumnDef(
this=sg.to_identifier(colname, quoted=self.compiler.quoted),
kind=self.compiler.type_mapper.from_ibis(typ),
constraints=(
None
if typ.nullable
else [sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())]
),
)
for colname, typ in (schema or table.schema()).items()
]

if overwrite:
temp_name = util.gen_name(f"{self.name}_table")
else:
temp_name = name

table = sg.table(temp_name, catalog=database, quoted=self.compiler.quoted)
target = sge.Schema(this=table, expressions=column_defs)
if not schema:
schema = table.schema()

table_expr = sg.table(temp_name, catalog=database, quoted=self.compiler.quoted)
target = sge.Schema(
this=table_expr, expressions=schema.to_sqlglot(self.dialect)
)

create_stmt = sge.Create(
kind="TABLE",
this = sg.table(name, catalog=database, quoted=self.compiler.quoted)
with self._safe_raw_sql(create_stmt) as cur:
if query is not None:
insert_stmt = sge.Insert(this=table, expression=query).sql(self.name)
insert_stmt = sge.Insert(this=table_expr, expression=query).sql(
self.name
)
cur.execute(insert_stmt)

if overwrite:
cur.execute(
sge.Drop(kind="TABLE", this=this, exists=True).sql(self.name)
)
cur.execute(
f"ALTER TABLE IF EXISTS {table.sql(self.name)} RENAME TO {this.sql(self.name)}"
f"ALTER TABLE IF EXISTS {table_expr.sql(self.name)} RENAME TO {this.sql(self.name)}"
)

if schema is None:
name, schema=schema, source=self, namespace=ops.Namespace(database=database)
).to_expr()

def _in_memory_table_exists(self, name: str) -> bool:
name = sg.to_identifier(name, quoted=self.compiler.quoted).sql(self.dialect)
# just return the single field with column names; no need to bring back
# everything if the command succeeds
sql = f"SHOW COLUMNS FROM {name} LIKE 'Field'"
try:
with self.begin() as cur:
cur.execute(sql)
cur.fetchall()
except pymysql.err.ProgrammingError as e:
err_code, _ = e.args
if err_code == ER.NO_SUCH_TABLE:
return False
raise
else:
return True
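# Illustrative equivalent of the probe above: run a cheap metadata query and treat
# MySQL's "no such table" error code as "does not exist" (table name hypothetical):
#
#     SHOW COLUMNS FROM `ibis_memtable_abc123` LIKE 'Field'
#     -- raises ProgrammingError with ER.NO_SUCH_TABLE when the table is missing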

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
f"got null typed columns: {null_columns}"
)

# only register if we haven't already done so
if (name := op.name) not in self.list_tables():
quoted = self.compiler.quoted
column_defs = [
sg.exp.ColumnDef(
this=sg.to_identifier(colname, quoted=quoted),
kind=self.compiler.type_mapper.from_ibis(typ),
constraints=(
None
if typ.nullable
else [
sg.exp.ColumnConstraint(
kind=sg.exp.NotNullColumnConstraint()
)
]
),
)
for colname, typ in schema.items()
]

create_stmt = sg.exp.Create(
kind="TABLE",
this=sg.exp.Schema(
this=sg.to_identifier(name, quoted=quoted), expressions=column_defs
),
properties=sg.exp.Properties(expressions=[sge.TemporaryProperty()]),
)
create_stmt_sql = create_stmt.sql(self.name)
name = op.name
quoted = self.compiler.quoted

df = op.data.to_frame()
# nan can not be used with MySQL
df = df.replace(float("nan"), None)
create_stmt = sg.exp.Create(
kind="TABLE",
this=sg.exp.Schema(
this=sg.to_identifier(name, quoted=quoted),
expressions=schema.to_sqlglot(self.dialect),
),
properties=sg.exp.Properties(expressions=[sge.TemporaryProperty()]),
)
create_stmt_sql = create_stmt.sql(self.name)

data = df.itertuples(index=False)
sql = self._build_insert_template(
name, schema=schema, columns=True, placeholder="%s"
)
with self.begin() as cur:
cur.execute(create_stmt_sql)
df = op.data.to_frame()
# nan can not be used with MySQL
df = df.replace(float("nan"), None)

if not df.empty:
cur.executemany(sql, data)
data = df.itertuples(index=False)
sql = self._build_insert_template(
name, schema=schema, columns=True, placeholder="%s"
)
with self.begin() as cur:
cur.execute(create_stmt_sql)

if not df.empty:
cur.executemany(sql, data)

@util.experimental
def to_pyarrow_batches(
@@ -564,3 +557,10 @@ def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:
raise
df = MySQLPandasData.convert_table(df, schema)
return df

def _finalize_memtable(self, name: str) -> None:
"""No-op.
Executing **any** SQL in a finalizer causes the underlying connection
socket to be set to `None`. It is unclear why this happens.
"""
142 changes: 77 additions & 65 deletions ibis/backends/oracle/__init__.py
@@ -2,7 +2,6 @@

from __future__ import annotations

import atexit
import contextlib
import re
import warnings
@@ -25,7 +24,7 @@
from ibis import util
from ibis.backends import CanListDatabase, CanListSchema
from ibis.backends.sql import SQLBackend
from ibis.backends.sql.compilers.base import STAR, C
from ibis.backends.sql.compilers.base import NULL, STAR, C

if TYPE_CHECKING:
from urllib.parse import ParseResult
@@ -122,6 +121,33 @@ def do_connect(
An Oracle Data Source Name. If provided, overrides all other
connection arguments except username and password.
Examples
--------
>>> import os
>>> import ibis
>>> host = os.environ.get("IBIS_TEST_ORACLE_HOST", "localhost")
>>> user = os.environ.get("IBIS_TEST_ORACLE_USER", "ibis")
>>> password = os.environ.get("IBIS_TEST_ORACLE_PASSWORD", "ibis")
>>> database = os.environ.get("IBIS_TEST_ORACLE_DATABASE", "IBIS_TESTING")
>>> con = ibis.oracle.connect(database=database, host=host, user=user, password=password)
>>> con.list_tables() # doctest: +ELLIPSIS
[...]
>>> t = con.table("functional_alltypes")
>>> t
DatabaseTable: functional_alltypes
id int64
bool_col int64
tinyint_col int64
smallint_col int64
int_col int64
bigint_col int64
float_col float64
double_col float64
date_string_col string
string_col string
timestamp_col timestamp(3)
year int64
month int64
"""
# SID: unique name of an INSTANCE running an oracle process (a single, identifiable machine)
# service name: an ALIAS to one (or many) individual instances that can
@@ -419,11 +445,9 @@ def create_table(
if temp:
properties.append(sge.TemporaryProperty())

temp_memtable_view = None
if obj is not None:
if not isinstance(obj, ir.Expr):
table = ibis.memtable(obj)
temp_memtable_view = table.op().name
else:
table = obj

else:
query = None

column_defs = [
sge.ColumnDef(
this=sg.to_identifier(colname, quoted=self.compiler.quoted),
kind=self.compiler.type_mapper.from_ibis(typ),
constraints=(
None
if typ.nullable
else [sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())]
),
)
for colname, typ in (schema or table.schema()).items()
]

if overwrite:
temp_name = util.gen_name(f"{self.name}_table")
else:
temp_name = name

initial_table = sg.table(temp_name, db=database, quoted=self.compiler.quoted)
target = sge.Schema(this=initial_table, expressions=column_defs)
target = sge.Schema(
this=initial_table,
expressions=(schema or table.schema()).to_sqlglot(self.dialect),
)

create_stmt = sge.Create(
kind="TABLE",
)

if schema is None:
# Clean up temporary memtable if we've created one
# for in-memory reads
if temp_memtable_view is not None:
self.drop_table(temp_memtable_view)
return self.table(name, database=database)

# preserve the input schema if it was provided
@@ -512,49 +522,45 @@ def drop_table(

super().drop_table(name, database=(catalog, db), force=force)

def _in_memory_table_exists(self, name: str) -> bool:
sql = (
sg.select(NULL)
.from_(sg.to_identifier("USER_OBJECTS", quoted=self.compiler.quoted))
.where(
C.OBJECT_TYPE.eq(sge.convert("TABLE")),
C.OBJECT_NAME.eq(sge.convert(name)),
)
.limit(sge.convert(1))
.sql(self.dialect)
)
with self.begin() as cur:
results = cur.execute(sql).fetchall()
return bool(results)

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema

# only register if we haven't already done so
if (name := op.name) not in self.list_tables():
quoted = self.compiler.quoted
column_defs = [
sge.ColumnDef(
this=sg.to_identifier(colname, quoted=quoted),
kind=self.compiler.type_mapper.from_ibis(typ),
constraints=(
None
if typ.nullable
else [
sg.exp.ColumnConstraint(
kind=sg.exp.NotNullColumnConstraint()
)
]
),
name = op.name
quoted = self.compiler.quoted
create_stmt = sge.Create(
kind="TABLE",
this=sg.exp.Schema(
this=sg.to_identifier(name, quoted=quoted),
expressions=schema.to_sqlglot(self.dialect),
),
properties=sge.Properties(expressions=[sge.TemporaryProperty()]),
).sql(self.name)

data = op.data.to_frame().replace(float("nan"), None)
insert_stmt = self._build_insert_template(
name, schema=schema, placeholder=":{i:d}"
)
with self.begin() as cur:
cur.execute(create_stmt)
for start, end in util.chunks(len(data), chunk_size=128):
cur.executemany(
insert_stmt, list(data.iloc[start:end].itertuples(index=False))
)
for colname, typ in schema.items()
]

create_stmt = sge.Create(
kind="TABLE",
this=sg.exp.Schema(
this=sg.to_identifier(name, quoted=quoted), expressions=column_defs
),
properties=sge.Properties(expressions=[sge.TemporaryProperty()]),
).sql(self.name)

data = op.data.to_frame().replace(float("nan"), None)
insert_stmt = self._build_insert_template(
name, schema=schema, placeholder=":{i:d}"
)
with self.begin() as cur:
cur.execute(create_stmt)
for start, end in util.chunks(len(data), chunk_size=128):
cur.executemany(
insert_stmt, list(data.iloc[start:end].itertuples(index=False))
)

atexit.register(self._clean_up_tmp_table, name)

def _get_schema_using_query(self, query: str) -> sch.Schema:
name = util.gen_name("oracle_metadata")
@@ -635,6 +641,13 @@ def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:
return OraclePandasData.convert_table(df, schema)

def _clean_up_tmp_table(self, name: str) -> None:
dialect = self.dialect

ident = sg.to_identifier(name, quoted=self.compiler.quoted)

truncate = sge.TruncateTable(expressions=[ident]).sql(dialect)
drop = sge.Drop(kind="TABLE", this=ident).sql(dialect)

with self.begin() as bind:
# global temporary tables cannot be dropped without first truncating them
#
Expand All @@ -643,9 +656,8 @@ def _clean_up_tmp_table(self, name: str) -> None:
# ignore DatabaseError exceptions because the table may not exist
# because it's already been deleted
with contextlib.suppress(oracledb.DatabaseError):
bind.execute(f'TRUNCATE TABLE "{name}"')
bind.execute(truncate)
with contextlib.suppress(oracledb.DatabaseError):
bind.execute(f'DROP TABLE "{name}"')
bind.execute(drop)

def _drop_cached_table(self, name):
self._clean_up_tmp_table(name)
_finalize_memtable = _drop_cached_table = _clean_up_tmp_table
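# Sketch of the two-step cleanup the helper above enforces (table name illustrative);
# Oracle global temporary tables must be truncated before they can be dropped, and
# both statements run inside contextlib.suppress(oracledb.DatabaseError):
#
#     TRUNCATE TABLE "ibis_memtable_abc123"
#     DROP TABLE "ibis_memtable_abc123"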
8 changes: 5 additions & 3 deletions ibis/backends/pandas/__init__.py
@@ -48,9 +48,8 @@ def do_connect(
Examples
--------
>>> import ibis
>>> ibis.pandas.connect({"t": pd.DataFrame({"a": [1, 2, 3]})})
<ibis.backends.pandas.Backend at 0x...>
>>> ibis.pandas.connect({"t": pd.DataFrame({"a": [1, 2, 3]})}) # doctest: +ELLIPSIS
<ibis.backends.pandas.Backend object at 0x...>
"""
warnings.warn(
f"The {self.name} backend is slated for removal in 10.0.",
@@ -331,6 +330,9 @@ def execute(self, query, params=None, limit="default", **kwargs):
def _create_cached_table(self, name, expr):
return self.create_table(name, expr.execute())

def _finalize_memtable(self, name: str) -> None:
"""No-op, let Python handle clean up."""


@lazy_singledispatch
def _convert_object(obj: Any, _conn):
4 changes: 2 additions & 2 deletions ibis/backends/pandas/tests/test_arrays.py
@@ -74,7 +74,7 @@ def test_array_collect_grouped(t, df):
def test_array_collect_rolling_partitioned(t, df):
window = ibis.trailing_window(1, order_by=t.plain_int64)
colexpr = t.plain_float64.collect().over(window)
expr = t["dup_strings", "plain_int64", colexpr.name("collected")]
expr = t.select("dup_strings", "plain_int64", colexpr.name("collected"))
result = expr.execute()
expected = pd.DataFrame(
{
@@ -134,7 +134,7 @@ def test_array_slice_scalar(client, start, stop):

@pytest.mark.parametrize("index", [1, 3, 4, 11, -11])
def test_array_index(t, df, index):
expr = t[t.array_of_float64[index].name("indexed")]
expr = t.select(t.array_of_float64[index].name("indexed"))
result = expr.execute()
expected = pd.DataFrame(
{
48 changes: 26 additions & 22 deletions ibis/backends/pandas/tests/test_join.py
@@ -17,16 +17,16 @@

@mutating_join_type
def test_join(how, left, right, df1, df2):
expr = left.join(right, left.key == right.key, how=how)[
expr = left.join(right, left.key == right.key, how=how).select(
left, right.other_value, right.key3
]
)
result = expr.execute()
expected = pd.merge(df1, df2, how=how, on="key")
tm.assert_frame_equal(result[expected.columns], expected)


def test_cross_join(left, right, df1, df2):
expr = left.cross_join(right)[left, right.other_value, right.key3]
expr = left.cross_join(right).select(left, right.other_value, right.key3)
result = expr.execute()
expected = pd.merge(
df1.assign(dummy=1), df2.assign(dummy=1), how="inner", on="dummy"
Expand All @@ -37,14 +37,14 @@ def test_cross_join(left, right, df1, df2):

@mutating_join_type
def test_join_project_left_table(how, left, right, df1, df2):
expr = left.join(right, left.key == right.key, how=how)[left, right.key3]
expr = left.join(right, left.key == right.key, how=how).select(left, right.key3)
result = expr.execute()
expected = pd.merge(df1, df2, how=how, on="key")[list(left.columns) + ["key3"]]
tm.assert_frame_equal(result[expected.columns], expected)


def test_cross_join_project_left_table(left, right, df1, df2):
expr = left.cross_join(right)[left, right.key3]
expr = left.cross_join(right).select(left, right.key3)
result = expr.execute()
expected = pd.merge(
df1.assign(dummy=1), df2.assign(dummy=1), how="inner", on="dummy"
],
)
def test_join_with_multiple_predicates(how, left, right, df1, df2):
expr = left.join(right, [left.key == right.key, left.key2 == right.key3], how=how)[
left, right.key3, right.other_value
]
expr = left.join(
right, [left.key == right.key, left.key2 == right.key3], how=how
).select(left, right.key3, right.other_value)
result = expr.execute()
expected = pd.merge(
df1,
@@ -110,7 +110,9 @@
)
def test_join_with_multiple_predicates_written_as_one(how, left, right, df1, df2):
predicate = (left.key == right.key) & (left.key2 == right.key3)
expr = left.join(right, predicate, how=how)[left, right.key3, right.other_value]
expr = left.join(right, predicate, how=how).select(
left, right.key3, right.other_value
)
result = expr.execute()
expected = pd.merge(
df1, df2, how=how, left_on=["key", "key2"], right_on=["key", "key3"]
@@ -155,7 +157,9 @@ def test_join_with_duplicate_non_key_columns_not_selected(how, left, right, df1,
left = left.mutate(x=left.value * 2)
right = right.mutate(x=right.other_value * 3)
right = right[["key", "other_value"]]
expr = left.join(right, left.key == right.key, how=how)[left, right.other_value]
expr = left.join(right, left.key == right.key, how=how).select(
left, right.other_value
)
result = expr.execute()
expected = pd.merge(
df1.assign(x=df1.value * 2),
@mutating_join_type
def test_join_with_post_expression_selection(how, left, right, df1, df2):
join = left.join(right, left.key == right.key, how=how)
expr = join[left.key, left.value, right.other_value]
expr = join.select(left.key, left.value, right.other_value)
result = expr.execute()
expected = pd.merge(df1, df2, on="key", how=how)[["key", "value", "other_value"]]
tm.assert_frame_equal(result[expected.columns], expected)
rhs = left[["key2", "value"]]

joined = lhs.join(rhs, "key2", how=how)
projected = joined[lhs, rhs.value]
expr = projected[projected.value == 4]
projected = joined.select(lhs, rhs.value)
expr = projected.filter(projected.value == 4)
result = expr.execute()

df1 = lhs.execute()
rhs2 = left[["key2", "value"]].rename(value2="value")

joined = lhs.join(rhs, "key2", how=how)
projected = joined[lhs, rhs.value]
filtered = projected[projected.value == 4]
projected = joined.select(lhs, rhs.value)
filtered = projected.filter(projected.value == 4)

joined2 = filtered.join(rhs2, "key2")
projected2 = joined2[filtered.key, rhs2.value2]
expr = projected2[projected2.value2 == 3]
projected2 = joined2.select(filtered.key, rhs2.value2)
expr = projected2.filter(projected2.value2 == 3)

result = expr.execute()

def test_join_with_non_trivial_key(how, left, right, df1, df2):
# also test that the order of operands in the predicate doesn't matter
join = left.join(right, right.key.length() == left.key.length(), how=how)
expr = join[left.key, left.value, right.other_value]
expr = join.select(left.key, left.value, right.other_value)
result = expr.execute()

expected = (
def test_join_with_non_trivial_key_project_table(how, left, right, df1, df2):
# also test that the order of operands in the predicate doesn't matter
join = left.join(right, right.key.length() == left.key.length(), how=how)
expr = join[left, right.other_value]
expr = expr[expr.key.length() == 1]
expr = join.select(left, right.other_value)
expr = expr.filter(expr.key.length() == 1)
result = expr.execute()

expected = (
# also test that the order of operands in the predicate doesn't matter
right = client.table("df3")
join = left.join(right, ["key"], how=how)
expr = join[left.key, right.key2, right.other_value]
expr = join.select(left.key, right.key2, right.other_value)
result = expr.execute()

expected = (

# this should be semi_join
tbl = batting.left_join(players, ["playerID"])
t = tbl[batting.G, batting.playerID, batting.teamID]
t = tbl.select(batting.G, batting.playerID, batting.teamID)
expr = t.group_by(t.teamID).mutate(
team_avg=lambda d: d.G.mean(),
demeaned_by_player=lambda d: d.G - d.G.mean(),
14 changes: 7 additions & 7 deletions ibis/backends/pandas/tests/test_operations.py
@@ -28,7 +28,7 @@ def test_literal(client):


def test_selection(t, df):
expr = t[((t.plain_strings == "a") | (t.plain_int64 == 3)) & (t.dup_strings == "d")]
expr = t.filter(
((t.plain_strings == "a") | (t.plain_int64 == 3)) & (t.dup_strings == "d")
)
result = expr.execute()
expected = df[
((df.plain_strings == "a") | (df.plain_int64 == 3)) & (df.dup_strings == "d")

def test_project_scope_does_not_override(t, df):
col = t.plain_int64
expr = t[
[
col.name("new_col"),
col.sum().over(ibis.window(group_by="dup_strings")).name("grouped"),
]
]
expr = t.select(
col.name("new_col"),
col.sum().over(ibis.window(group_by="dup_strings")).name("grouped"),
)
result = expr.execute()
expected = pd.concat(
[
10 changes: 5 additions & 5 deletions ibis/backends/pandas/tests/test_window.py
@@ -172,7 +172,7 @@ def test_players(players, players_df):


def test_batting_filter_mean(batting, batting_df):
expr = batting[batting.G > batting.G.mean()]
expr = batting.filter(batting.G > batting.G.mean())
result = expr.execute()
expected = batting_df[batting_df.G > batting_df.G.mean()].reset_index(drop=True)
tm.assert_frame_equal(result[expected.columns], expected)
@@ -361,7 +361,7 @@ def test_mutate_with_window_after_join(sort_kind):
left, right = map(con.table, ("left", "right"))

joined = left.outer_join(right, left.ints == right.group)
proj = joined[left, right.value]
proj = joined.select(left, right.value)
expr = proj.group_by("ints").mutate(sum=proj.value.sum())
result = expr.execute()
expected = pd.DataFrame(
@@ -390,7 +390,7 @@ def test_mutate_scalar_with_window_after_join():
left, right = map(con.table, ("left", "right"))

joined = left.outer_join(right, left.ints == right.group)
proj = joined[left, right.value]
proj = joined.select(left, right.value)
expr = proj.mutate(sum=proj.value.sum(), const=ibis.literal(1))
result = expr.execute()
expected = pd.DataFrame(
left, right = map(con.table, ("left", "right"))

joined = left.outer_join(right, left.ints == right.group)
proj = joined[left, right.value]
expr = proj[proj.value.sum().name("sum"), ibis.literal(1).name("const")]
proj = joined.select(left, right.value)
expr = proj.select(proj.value.sum().name("sum"), ibis.literal(1).name("const"))
result = expr.execute()
expected = pd.DataFrame(
{
40 changes: 38 additions & 2 deletions ibis/backends/polars/__init__.py
@@ -51,6 +51,25 @@ def do_connect(
tables
An optional mapping of string table names to polars LazyFrames.
Examples
--------
>>> import ibis
>>> import polars as pl
>>> ibis.options.interactive = True
>>> lazy_frame = pl.LazyFrame(
... {"name": ["Jimmy", "Keith"], "band": ["Led Zeppelin", "Stones"]}
... )
>>> con = ibis.polars.connect(tables={"band_members": lazy_frame})
>>> t = con.table("band_members")
>>> t
┏━━━━━━━━┳━━━━━━━━━━━━━━┓
┃ name ┃ band ┃
┡━━━━━━━━╇━━━━━━━━━━━━━━┩
│ string │ string │
├────────┼──────────────┤
│ Jimmy │ Led Zeppelin │
│ Keith │ Stones │
└────────┴──────────────┘
"""
if tables is not None and not isinstance(tables, Mapping):
raise TypeError("Input to ibis.polars.connect must be a mapping")
schema = sch.infer(self._tables[name])
return ops.DatabaseTable(name, schema, self).to_expr()

def _in_memory_table_exists(self, name: str) -> bool:
return name in self._tables

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
self._add_table(op.name, op.data.to_polars(op.schema).lazy())

def _finalize_memtable(self, name: str) -> None:
self.drop_table(name, force=True)

@deprecated(
as_of="9.1",
instead="use the explicit `read_*` method for the filetype you are trying to read, e.g., read_parquet, read_csv, etc.",
@@ -466,12 +494,20 @@ def _to_dataframe(
streaming: bool = False,
**kwargs: Any,
) -> pl.DataFrame:
lf = self.compile(expr, params=params, **kwargs)
self._run_pre_execute_hooks(expr)
table_expr = expr.as_table()
lf = self.compile(table_expr, params=params, **kwargs)
if limit == "default":
limit = ibis.options.sql.default_limit
if limit is not None:
lf = lf.limit(limit)
return lf.collect(streaming=streaming)
df = lf.collect(streaming=streaming)
# XXX: Polars sometimes returns data with the incorrect column names.
# For now we catch this case and rename them here if needed.
expected_cols = tuple(table_expr.columns)
if tuple(df.columns) != expected_cols:
df = df.rename(dict(zip(df.columns, expected_cols)))
return df

def execute(
self,
8 changes: 6 additions & 2 deletions ibis/backends/polars/compiler.py
@@ -9,11 +9,14 @@
from math import isnan

import polars as pl
import sqlglot as sg

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.pandas.rewrites import PandasAsofJoin, PandasJoin, PandasRename
from ibis.backends.sql.compilers.base import STAR
from ibis.backends.sql.dialects import Polars
from ibis.expr.operations.udf import InputType
from ibis.formats.polars import PolarsType
from ibis.util import gen_name
@@ -64,8 +67,9 @@ def dummy_table(op, **kw):


@translate.register(ops.InMemoryTable)
def in_memory_table(op, **_):
return op.data.to_polars(op.schema).lazy()
def in_memory_table(op, *, ctx, **_):
sql = sg.select(STAR).from_(sg.to_identifier(op.name, quoted=True)).sql(Polars)
return ctx.execute(sql, eager=False)


def _make_duration(value, dtype):
16 changes: 5 additions & 11 deletions ibis/backends/polars/tests/conftest.py
@@ -3,7 +3,6 @@
from typing import Any

import numpy as np
import polars as pl
import pytest

import ibis
con = self.connection
for table_name in TEST_TABLES:
path = self.data_dir / "parquet" / f"{table_name}.parquet"
with pytest.warns(FutureWarning, match="v9.1"):
con.register(path, table_name=table_name)
# TODO: remove warnings and replace register when implementing 8858
with pytest.warns(FutureWarning, match="v9.1"):
con.register(array_types, table_name="array_types")
con.register(struct_types, table_name="struct")
con.register(win, table_name="win")

# TODO: remove when pyarrow inputs are supported
con._add_table("topk", pl.from_arrow(topk).lazy())
con.read_parquet(path, table_name=table_name)
con.create_table("array_types", array_types)
con.create_table("struct", struct_types)
con.create_table("win", win)
con.create_table("topk", topk)

@staticmethod
def connect(*, tmpdir, worker_id, **kw):
23 changes: 23 additions & 0 deletions ibis/backends/polars/tests/test_client.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import polars as pl
import polars.testing
import pytest

import ibis
@@ -37,3 +39,24 @@ def test_array_flatten(con):
{"id": data["id"], "flat": [row[0] for row in data["happy"]]}
)
tm.assert_frame_equal(result.to_pandas(), expected)


def test_memtable_polars_types(con):
# Check that we can create a memtable with some polars-specific types,
# and that those columns then work in downstream operations
df = pl.DataFrame(
{
"x": ["a", "b", "a"],
"y": ["c", "d", "c"],
"z": ["e", "f", "e"],
},
schema={
"x": pl.String,
"y": pl.Categorical,
"z": pl.Enum(["e", "f"]),
},
)
t = ibis.memtable(df)
res = con.to_polars((t.x + t.y + t.z).name("test"))
sol = (df["x"] + df["y"] + df["z"]).rename("test")
pl.testing.assert_series_equal(res, sol)
125 changes: 58 additions & 67 deletions ibis/backends/postgres/__init__.py
@@ -89,6 +89,21 @@ def _from_url(self, url: ParseResult, **kwargs):

return self.connect(**kwargs)

def _in_memory_table_exists(self, name: str) -> bool:
import psycopg2.errors

ident = sg.to_identifier(name, quoted=self.compiler.quoted)
sql = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect)

try:
with self.begin() as cur:
cur.execute(sql)
cur.fetchall()
except psycopg2.errors.UndefinedTable:
return False
else:
return True
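# The probe compiles to roughly the following (table name illustrative); psycopg2
# raises UndefinedTable when the temporary table was never created in this session:
#
#     SELECT 1 FROM "ibis_memtable_abc123" LIMIT 0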

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
from psycopg2.extras import execute_batch

Expand All @@ -99,54 +114,37 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
f"got null typed columns: {null_columns}"
)

# only register if we haven't already done so
if (name := op.name) not in self.list_tables():
quoted = self.compiler.quoted
column_defs = [
sg.exp.ColumnDef(
this=sg.to_identifier(colname, quoted=quoted),
kind=self.compiler.type_mapper.from_ibis(typ),
constraints=(
None
if typ.nullable
else [
sg.exp.ColumnConstraint(
kind=sg.exp.NotNullColumnConstraint()
)
]
),
)
for colname, typ in schema.items()
]

create_stmt = sg.exp.Create(
kind="TABLE",
this=sg.exp.Schema(
this=sg.to_identifier(name, quoted=quoted), expressions=column_defs
),
properties=sg.exp.Properties(expressions=[sge.TemporaryProperty()]),
)
create_stmt_sql = create_stmt.sql(self.dialect)

df = op.data.to_frame()
# nan gets compiled into 'NaN'::float which throws errors in non-float columns
# In order to hold NaN values, pandas automatically converts integer columns
# to float columns if there are NaN values in them. Therefore, we need to convert
# them to their original dtypes (that support pd.NA) to figure out which columns
# are actually non-float, then fill the NaN values in those columns with None.
convert_df = df.convert_dtypes()
for col in convert_df.columns:
if not is_float_dtype(convert_df[col]):
df[col] = df[col].replace(float("nan"), None)

data = df.itertuples(index=False)
sql = self._build_insert_template(
name, schema=schema, columns=True, placeholder="%s"
)
name = op.name
quoted = self.compiler.quoted
create_stmt = sg.exp.Create(
kind="TABLE",
this=sg.exp.Schema(
this=sg.to_identifier(name, quoted=quoted),
expressions=schema.to_sqlglot(self.dialect),
),
properties=sg.exp.Properties(expressions=[sge.TemporaryProperty()]),
)
create_stmt_sql = create_stmt.sql(self.dialect)

df = op.data.to_frame()
# nan gets compiled into 'NaN'::float which throws errors in non-float columns
# In order to hold NaN values, pandas automatically converts integer columns
# to float columns if there are NaN values in them. Therefore, we need to convert
# them to their original dtypes (that support pd.NA) to figure out which columns
# are actually non-float, then fill the NaN values in those columns with None.
convert_df = df.convert_dtypes()
for col in convert_df.columns:
if not is_float_dtype(convert_df[col]):
df[col] = df[col].replace(float("nan"), None)

data = df.itertuples(index=False)
sql = self._build_insert_template(
name, schema=schema, columns=True, placeholder="%s"
)

with self.begin() as cur:
cur.execute(create_stmt_sql)
execute_batch(cur, sql, data, 128)
with self.begin() as cur:
cur.execute(create_stmt_sql)
execute_batch(cur, sql, data, 128)

@contextlib.contextmanager
def begin(self):
@@ -225,11 +223,10 @@ def do_connect(
Examples
--------
>>> import os
>>> import getpass
>>> import ibis
>>> host = os.environ.get("IBIS_TEST_POSTGRES_HOST", "localhost")
>>> user = os.environ.get("IBIS_TEST_POSTGRES_USER", getpass.getuser())
>>> password = os.environ.get("IBIS_TEST_POSTGRES_PASSWORD")
>>> user = os.environ.get("IBIS_TEST_POSTGRES_USER", "postgres")
>>> password = os.environ.get("IBIS_TEST_POSTGRES_PASSWORD", "postgres")
>>> database = os.environ.get("IBIS_TEST_POSTGRES_DATABASE", "ibis_testing")
>>> con = ibis.postgres.connect(database=database, host=host, user=user, password=password)
>>> con.list_tables() # doctest: +ELLIPSIS
@@ -672,26 +669,18 @@ def create_table(
else:
query = None

column_defs = [
sge.ColumnDef(
this=sg.to_identifier(colname, quoted=self.compiler.quoted),
kind=self.compiler.type_mapper.from_ibis(typ),
constraints=(
None
if typ.nullable
else [sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())]
),
)
for colname, typ in (schema or table.schema()).items()
]

if overwrite:
temp_name = util.gen_name(f"{self.name}_table")
else:
temp_name = name

table = sg.table(temp_name, db=database, quoted=self.compiler.quoted)
target = sge.Schema(this=table, expressions=column_defs)
if not schema:
schema = table.schema()

table_expr = sg.table(temp_name, db=database, quoted=self.compiler.quoted)
target = sge.Schema(
this=table_expr, expressions=schema.to_sqlglot(self.dialect)
)

create_stmt = sge.Create(
kind="TABLE",
this = sg.table(name, catalog=database, quoted=self.compiler.quoted)
with self._safe_raw_sql(create_stmt) as cur:
if query is not None:
insert_stmt = sge.Insert(this=table, expression=query).sql(self.dialect)
insert_stmt = sge.Insert(this=table_expr, expression=query).sql(
self.dialect
)
cur.execute(insert_stmt)

if overwrite:
cur.execute(
sge.Drop(kind="TABLE", this=this, exists=True).sql(self.dialect)
)
cur.execute(
f"ALTER TABLE IF EXISTS {table.sql(self.dialect)} RENAME TO {this.sql(self.dialect)}"
f"ALTER TABLE IF EXISTS {table_expr.sql(self.dialect)} RENAME TO {this.sql(self.dialect)}"
)

if schema is None:
22 changes: 6 additions & 16 deletions ibis/backends/postgres/tests/test_functions.py
@@ -647,7 +647,7 @@ def test_not_exists(alltypes, df):
t = alltypes
t2 = t.view()

expr = t[~((t.string_col == t2.string_col).any())]
expr = t.filter(~((t.string_col == t2.string_col).any()))
result = expr.execute()

left, right = df, t2.execute()
@@ -855,7 +855,7 @@ def test_window_with_arithmetic(alltypes, df):

def test_anonymous_aggregate(alltypes, df):
t = alltypes
expr = t[t.double_col > t.double_col.mean()]
expr = t.filter(t.double_col > t.double_col.mean())
result = expr.execute()
expected = df[df.double_col > df.double_col.mean()].reset_index(drop=True)
tm.assert_frame_equal(result, expected)
@@ -908,7 +908,7 @@ def test_array_collect(array_types):

@pytest.mark.parametrize("index", [0, 1, 3, 4, 11, -1, -3, -4, -11])
def test_array_index(array_types, index):
expr = array_types[array_types.y[index].name("indexed")]
expr = array_types.select(array_types.y[index].name("indexed"))
result = expr.execute()
expected = pd.DataFrame(
{
@@ -1011,13 +1011,11 @@ def test_analytic_functions(alltypes, assert_sql):
assert_sql(expr)


@pytest.mark.parametrize("opname", ["invert", "neg"])
def test_not_and_negate_bool(con, opname, df):
op = getattr(operator, opname)
def test_invert_bool(con, df):
t = con.table("functional_alltypes").limit(10)
expr = t.select(op(t.bool_col).name("bool_col"))
expr = t.select((~t.bool_col).name("bool_col"))
result = expr.execute().bool_col
expected = op(df.head(10).bool_col)
expected = ~df.head(10).bool_col
tm.assert_series_equal(result, expected)


tm.assert_series_equal(result, expected)


def test_negate_boolean(con, df):
t = con.table("functional_alltypes").limit(10)
expr = t.select((-t.bool_col).name("bool_col"))
result = expr.execute().bool_col
expected = -df.head(10).bool_col
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("opname", ["sum", "mean", "min", "max", "std", "var"])
def test_boolean_reduction(alltypes, opname, df):
op = operator.methodcaller(opname)
12 changes: 6 additions & 6 deletions ibis/backends/postgres/tests/test_geospatial.py
@@ -232,7 +232,7 @@ def test_get_point(geotable, expr_fn, expected):
# boundaries with the contains predicate. Work around this by adding a
# small buffer.
expr = geotable["geo_linestring"].buffer(0.01).contains(arg)
result = geotable[geotable, expr.name("tmp")].execute()["tmp"]
result = geotable.select(geotable, expr.name("tmp")).execute()["tmp"]
testing.assert_almost_equal(result, expected, decimal=2)


)
def test_srid(geotable, condition, expected):
"""Testing for geo spatial srid operation."""
expr = geotable[geotable.id, condition(geotable).name("tmp")]
expr = geotable.select(geotable.id, condition(geotable).name("tmp"))
result = expr.execute()["tmp"][[0]]
assert np.all(result == expected)

)
def test_set_srid(geotable, condition, expected):
"""Testing for geo spatial set_srid operation."""
expr = geotable[geotable.id, condition(geotable).name("tmp")]
expr = geotable.select(geotable.id, condition(geotable).name("tmp"))
result = expr.execute()["tmp"][[0]]
assert np.all(result == expected)

@@ -305,7 +305,7 @@ def test_set_srid(geotable, condition, expected):
)
def test_transform(geotable, condition, expected):
"""Testing for geo spatial transform operation."""
expr = geotable[geotable.id, condition(geotable).name("tmp")]
expr = geotable.select(geotable.id, condition(geotable).name("tmp"))
result = expr.execute()["tmp"][[0]]
assert np.all(result == expected)

def test_cast_geography(geotable, expr_fn):
"""Testing for geo spatial transform operation."""
p = expr_fn(geotable).cast("geography")
expr = geotable[geotable.id, p.distance(p).name("tmp")]
expr = geotable.select(geotable.id, p.distance(p).name("tmp"))
result = expr.execute()["tmp"][[0]]
# distance from a point to a same point should be 0
assert np.all(result == 0)
def test_cast_geometry(geotable, expr_fn):
"""Testing for geo spatial transform operation."""
p = expr_fn(geotable).cast("geometry")
expr = geotable[geotable.id, p.distance(p).name("tmp")]
expr = geotable.select(geotable.id, p.distance(p).name("tmp"))
result = expr.execute()["tmp"][[0]]
# distance from a point to a same point should be 0
assert np.all(result == 0)
2 changes: 1 addition & 1 deletion ibis/backends/postgres/tests/test_json.py
@@ -23,7 +23,7 @@ def jsonb_t(con):
@pytest.mark.parametrize("data", [param({"status": True}, id="status")])
def test_json(data, alltypes):
lit = ibis.literal(json.dumps(data), type="json").name("tmp")
expr = alltypes[[alltypes.id, lit]].head(1)
expr = alltypes.select(alltypes.id, lit).head(1)
df = expr.execute()
assert df["tmp"].iloc[0] == data

2 changes: 1 addition & 1 deletion ibis/backends/postgres/tests/test_postgis.py
@@ -21,7 +21,7 @@ def test_load_geodata(con):


def test_empty_select(geotable):
expr = geotable[geotable.geo_point.geo_equals(geotable.geo_linestring)]
expr = geotable.filter(geotable.geo_point.geo_equals(geotable.geo_linestring))
result = expr.execute()
assert len(result) == 0

2 changes: 1 addition & 1 deletion ibis/backends/postgres/tests/test_string.py
@@ -15,6 +15,6 @@
@pytest.mark.usefixtures("con")
def test_special_strings(alltypes, data, data_type):
lit = ibis.literal(data, type=data_type).name("tmp")
expr = alltypes[[alltypes.id, lit]].head(1)
expr = alltypes.select(alltypes.id, lit).head(1)
df = expr.execute()
assert df["tmp"].iloc[0] == uuid.UUID(data)
8 changes: 6 additions & 2 deletions ibis/backends/postgres/tests/test_udf.py
@@ -85,15 +85,19 @@ def test_existing_sql_udf(con_for_udf, test_database, table):
"""Test creating ibis UDF object based on existing UDF in the database."""
# Create ibis UDF objects referring to UDFs already created in the database
custom_length_udf = con_for_udf.function("custom_len", database=test_database)
result_obj = table[table, custom_length_udf(table["user_name"]).name("custom_len")]
result_obj = table.select(
table, custom_length_udf(table["user_name"]).name("custom_len")
)
result = result_obj.execute()
assert result["custom_len"].sum() == result["name_length"].sum()


def test_existing_plpython_udf(con_for_udf, test_database, table):
# Create ibis UDF objects referring to UDFs already created in the database
py_length_udf = con_for_udf.function("pylen", database=test_database)
result_obj = table[table, py_length_udf(table["user_name"]).name("custom_len")]
result_obj = table.select(
table, py_length_udf(table["user_name"]).name("custom_len")
)
result = result_obj.execute()
assert result["custom_len"].sum() == result["name_length"].sum()

16 changes: 8 additions & 8 deletions ibis/backends/pyspark/__init__.py
@@ -411,10 +411,17 @@ def _register_udfs(self, expr: ir.Expr) -> None:
self._session.udf.register(f"unwrap_json_{typ.__name__}", unwrap_json(typ))
self._session.udf.register("unwrap_json_float", unwrap_json_float)

def _in_memory_table_exists(self, name: str) -> bool:
sql = f"SHOW TABLES IN {self.current_database} LIKE '{name}'"
return bool(self._session.sql(sql).count())

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = PySparkSchema.from_ibis(op.schema)
df = self._session.createDataFrame(data=op.data.to_frame(), schema=schema)
df.createOrReplaceTempView(op.name)
df.createTempView(op.name)

def _finalize_memtable(self, name: str) -> None:
self._session.catalog.dropTempView(name)

@contextlib.contextmanager
def _safe_raw_sql(self, query: str) -> Any:
@@ -594,13 +601,11 @@ def create_table(
table_loc = self._to_sqlglot_table(database)
catalog, db = self._to_catalog_db_tuple(table_loc)

temp_memtable_view = None
if obj is not None:
if isinstance(obj, ir.Expr):
table = obj
else:
table = ibis.memtable(obj)
temp_memtable_view = table.op().name
query = self.compile(table)
mode = "overwrite" if overwrite else "error"
with self._active_catalog_database(catalog, db):
@@ -615,11 +620,6 @@
else:
raise com.IbisError("The schema or obj parameter is required")

# Clean up temporary memtable if we've created one
# for in-memory reads
if temp_memtable_view is not None:
self.drop_table(temp_memtable_view)

return self.table(name, database=(catalog, db))

def create_view(
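Taken together, the PySpark changes give memtables a full lifecycle: _register_in_memory_table creates a temp view with createTempView, _in_memory_table_exists probes SHOW TABLES so registration happens at most once, and _finalize_memtable drops the view (the generic SQL backend default, added later in this diff, drops a table instead). A rough sketch of how a caller might drive these hooks; the orchestration below is illustrative and not the actual ibis call sites:

def ensure_memtable(backend, op):
    # op is an ops.InMemoryTable; register it at most once per session
    if not backend._in_memory_table_exists(op.name):
        backend._register_in_memory_table(op)

def release_memtable(backend, name):
    # invoked once the memtable is no longer referenced
    backend._finalize_memtable(name)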
2 changes: 1 addition & 1 deletion ibis/backends/pyspark/tests/test_array.py
@@ -82,7 +82,7 @@ def test_array_slice_scalar(con, start, stop):

@pytest.mark.parametrize("index", [1, 3, 4, 11, -11])
def test_array_index(t, df, index):
expr = t[t.array_int[index].name("indexed")]
expr = t.select(t.array_int[index].name("indexed"))
result = expr.execute()

expected = pd.DataFrame(
8 changes: 4 additions & 4 deletions ibis/backends/pyspark/tests/test_ddl.py
@@ -134,16 +134,16 @@ def test_insert_validate_types(con, alltypes, test_data_db, temp_table):
database=db,
)

to_insert = expr[
to_insert = expr.select(
expr.tinyint_col, expr.smallint_col.name("int_col"), expr.string_col
]
)
con.insert(temp_table, to_insert.limit(10))

to_insert = expr[
to_insert = expr.select(
expr.tinyint_col,
expr.smallint_col.cast("int32").name("int_col"),
expr.string_col,
]
)
con.insert(temp_table, to_insert.limit(10))


4 changes: 2 additions & 2 deletions ibis/backends/pyspark/tests/test_null.py
@@ -11,7 +11,7 @@ def test_isnull(con):
table_pandas = table.execute()

for col, _ in table_pandas.items():
result = table[table[col].isnull()].execute().reset_index(drop=True)
result = table.filter(table[col].isnull()).execute().reset_index(drop=True)
expected = table_pandas[table_pandas[col].isnull()].reset_index(drop=True)
tm.assert_frame_equal(result, expected)

@@ -21,6 +21,6 @@ def test_notnull(con):
table_pandas = table.execute()

for col, _ in table_pandas.items():
result = table[table[col].notnull()].execute().reset_index(drop=True)
result = table.filter(table[col].notnull()).execute().reset_index(drop=True)
expected = table_pandas[table_pandas[col].notnull()].reset_index(drop=True)
tm.assert_frame_equal(result, expected)
160 changes: 68 additions & 92 deletions ibis/backends/risingwave/__init__.py
@@ -78,34 +78,36 @@ def do_connect(
Examples
--------
>>> import os
>>> import getpass
>>> import ibis
>>> host = os.environ.get("IBIS_TEST_RISINGWAVE_HOST", "localhost")
>>> user = os.environ.get("IBIS_TEST_RISINGWAVE_USER", getpass.getuser())
>>> password = os.environ.get("IBIS_TEST_RISINGWAVE_PASSWORD")
>>> user = os.environ.get("IBIS_TEST_RISINGWAVE_USER", "root")
>>> password = os.environ.get("IBIS_TEST_RISINGWAVE_PASSWORD", "")
>>> database = os.environ.get("IBIS_TEST_RISINGWAVE_DATABASE", "dev")
>>> con = connect(database=database, host=host, user=user, password=password)
>>> con = ibis.risingwave.connect(
... database=database,
... host=host,
... user=user,
... password=password,
... port=4566,
... )
>>> con.list_tables() # doctest: +ELLIPSIS
[...]
>>> t = con.table("functional_alltypes")
>>> t
RisingWaveTable[table]
name: functional_alltypes
schema:
id : int32
bool_col : boolean
tinyint_col : int16
smallint_col : int16
int_col : int32
bigint_col : int64
float_col : float32
double_col : float64
date_string_col : string
string_col : string
timestamp_col : timestamp
year : int32
month : int32
DatabaseTable: functional_alltypes
id int32
bool_col boolean
tinyint_col int16
smallint_col int16
int_col int32
bigint_col int64
float_col float32
double_col float64
date_string_col string
string_col string
timestamp_col timestamp(6)
year int32
month int32
"""

self.con = psycopg2.connect(
@@ -195,11 +197,9 @@ def create_table(
f"Creating temp tables is not supported by {self.name}"
)

temp_memtable_view = None
if obj is not None:
if not isinstance(obj, ir.Expr):
table = ibis.memtable(obj)
temp_memtable_view = table.op().name
else:
table = obj

@@ -209,26 +209,18 @@
else:
query = None

column_defs = [
sge.ColumnDef(
this=sg.to_identifier(colname, quoted=self.compiler.quoted),
kind=self.compiler.type_mapper.from_ibis(typ),
constraints=(
None
if typ.nullable
else [sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())]
),
)
for colname, typ in (schema or table.schema()).items()
]

if overwrite:
temp_name = util.gen_name(f"{self.name}_table")
else:
temp_name = name

table = sg.table(temp_name, db=database, quoted=self.compiler.quoted)
target = sge.Schema(this=table, expressions=column_defs)
if not schema:
schema = table.schema()

table_expr = sg.table(temp_name, db=database, quoted=self.compiler.quoted)
target = sge.Schema(
this=table_expr, expressions=schema.to_sqlglot(self.dialect)
)

if connector_properties is None:
create_stmt = sge.Create(
@@ -251,27 +243,40 @@
this = sg.table(name, db=database, quoted=self.compiler.quoted)
with self._safe_raw_sql(create_stmt) as cur:
if query is not None:
insert_stmt = sge.Insert(this=table, expression=query).sql(self.dialect)
insert_stmt = sge.Insert(this=table_expr, expression=query).sql(
self.dialect
)
cur.execute(insert_stmt)

if overwrite:
self.drop_table(name, database=database, force=True)
cur.execute(
f"ALTER TABLE {table.sql(self.dialect)} RENAME TO {this.sql(self.dialect)}"
f"ALTER TABLE {table_expr.sql(self.dialect)} RENAME TO {this.sql(self.dialect)}"
)

if schema is None:
# Clean up temporary memtable if we've created one
# for in-memory reads
if temp_memtable_view is not None:
self.drop_table(temp_memtable_view)
return self.table(name, database=database)

# preserve the input schema if it was provided
return ops.DatabaseTable(
name, schema=schema, source=self, namespace=ops.Namespace(database=database)
).to_expr()

def _in_memory_table_exists(self, name: str) -> bool:
import psycopg2.errors

ident = sg.to_identifier(name, quoted=self.compiler.quoted)
sql = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect)

try:
with self.begin() as cur:
cur.execute(sql)
cur.fetchall()
except psycopg2.errors.InternalError:
return False
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
schema = op.schema
if null_columns := [col for col, dtype in schema.items() if dtype.is_null()]:
@@ -280,42 +285,26 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
f"got null typed columns: {null_columns}"
)

# only register if we haven't already done so
if (name := op.name) not in self.list_tables():
quoted = self.compiler.quoted
column_defs = [
sg.exp.ColumnDef(
this=sg.to_identifier(colname, quoted=quoted),
kind=self.compiler.type_mapper.from_ibis(typ),
constraints=(
None
if typ.nullable
else [
sg.exp.ColumnConstraint(
kind=sg.exp.NotNullColumnConstraint()
)
]
),
)
for colname, typ in schema.items()
]
name = op.name
quoted = self.compiler.quoted

create_stmt = sg.exp.Create(
kind="TABLE",
this=sg.exp.Schema(
this=sg.to_identifier(name, quoted=quoted), expressions=column_defs
),
)
create_stmt_sql = create_stmt.sql(self.dialect)
create_stmt = sg.exp.Create(
kind="TABLE",
this=sg.exp.Schema(
this=sg.to_identifier(name, quoted=quoted),
expressions=schema.to_sqlglot(self.dialect),
),
)
create_stmt_sql = create_stmt.sql(self.dialect)

df = op.data.to_frame()
data = df.itertuples(index=False)
sql = self._build_insert_template(
name, schema=schema, columns=True, placeholder="%s"
)
with self.begin() as cur:
cur.execute(create_stmt_sql)
extras.execute_batch(cur, sql, data, 128)
df = op.data.to_frame()
data = df.itertuples(index=False)
sql = self._build_insert_template(
name, schema=schema, columns=True, placeholder="%s"
)
with self.begin() as cur:
cur.execute(create_stmt_sql)
extras.execute_batch(cur, sql, data, 128)

def list_databases(
self, *, like: str | None = None, catalog: str | None = None
@@ -442,21 +431,8 @@ def create_source(
Table
Table expression
"""
column_defs = [
sge.ColumnDef(
this=sg.to_identifier(colname, quoted=self.compiler.quoted),
kind=self.compiler.type_mapper.from_ibis(typ),
constraints=(
None
if typ.nullable
else [sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())]
),
)
for colname, typ in schema.items()
]

table = sg.table(name, db=database, quoted=self.compiler.quoted)
target = sge.Schema(this=table, expressions=column_defs)
target = sge.Schema(this=table, expressions=schema.to_sqlglot(self.dialect))

create_stmt = sge.Create(
kind="SOURCE",
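The hand-rolled lists of sge.ColumnDef that used to build CREATE TABLE and CREATE SOURCE targets are collapsed into a single schema.to_sqlglot(self.dialect) call. A sketch of the consolidated pattern, assuming Schema.to_sqlglot behaves as it is used in this diff (returning the sqlglot column definitions with NOT NULL constraints included):

import ibis
import sqlglot as sg
import sqlglot.expressions as sge

schema = ibis.schema({"id": "int64", "name": "!string"})  # "!" marks a non-nullable column
target = sge.Schema(
    this=sg.table("t", db="public", quoted=True),
    expressions=schema.to_sqlglot("postgres"),  # signature assumed, mirroring its use in this diff
)
create_stmt = sge.Create(kind="TABLE", this=target)
print(create_stmt.sql("postgres"))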
20 changes: 5 additions & 15 deletions ibis/backends/risingwave/tests/test_functions.py
@@ -448,7 +448,7 @@ def test_not_exists(alltypes, df):
t = alltypes
t2 = t.view()

expr = t[~((t.string_col == t2.string_col).any())]
expr = t.filter(~((t.string_col == t2.string_col).any()))
result = expr.execute()

left, right = df, t2.execute()
@@ -615,7 +615,7 @@ def test_window_with_arithmetic(alltypes, df):

def test_anonymous_aggregate(alltypes, df):
t = alltypes
expr = t[t.double_col > t.double_col.mean()]
expr = t.filter(t.double_col > t.double_col.mean())
result = expr.execute()
expected = df[df.double_col > df.double_col.mean()].reset_index(drop=True)
tm.assert_frame_equal(result, expected)
@@ -673,13 +673,11 @@ def test_identical_to(con, df):
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("opname", ["invert", "neg"])
def test_not_and_negate_bool(con, opname, df):
op = getattr(operator, opname)
def test_invert_bool(con, df):
t = con.table("functional_alltypes").limit(10)
expr = t.select(op(t.bool_col).name("bool_col"))
expr = t.select((~t.bool_col).name("bool_col"))
result = expr.execute().bool_col
expected = op(df.head(10).bool_col)
expected = ~df.head(10).bool_col
tm.assert_series_equal(result, expected)


@@ -704,14 +702,6 @@ def test_negate_non_boolean(con, field, df):
tm.assert_series_equal(result, expected)


def test_negate_boolean(con, df):
t = con.table("functional_alltypes").limit(10)
expr = t.select((-t.bool_col).name("bool_col"))
result = expr.execute().bool_col
expected = -df.head(10).bool_col
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("opname", ["sum", "mean", "min", "max", "std", "var"])
def test_boolean_reduction(alltypes, opname, df):
op = operator.methodcaller(opname)
2 changes: 1 addition & 1 deletion ibis/backends/risingwave/tests/test_json.py
@@ -13,6 +13,6 @@
@pytest.mark.parametrize("data", [param({"status": True}, id="status")])
def test_json(data, alltypes):
lit = ibis.literal(json.dumps(data), type="json").name("tmp")
expr = alltypes[[alltypes.id, lit]].head(1)
expr = alltypes.select(alltypes.id, lit).head(1)
df = expr.execute()
assert df["tmp"].iloc[0] == data
81 changes: 31 additions & 50 deletions ibis/backends/snowflake/__init__.py
@@ -6,7 +6,6 @@
import itertools
import json
import os
import shutil
import tempfile
import warnings
from operator import itemgetter
@@ -645,25 +644,37 @@ def list_tables(

return self._filter_with_like(tables + views, like=like)

def _in_memory_table_exists(self, name: str) -> bool:
import snowflake.connector

ident = sg.to_identifier(name, quoted=self.compiler.quoted)
sql = sg.select(sge.convert(1)).from_(ident).limit(0).sql(self.dialect)

try:
with self.con.cursor() as cur:
cur.execute(sql).fetchall()
except snowflake.connector.errors.ProgrammingError as e:
# this cryptic sqlstate is the only generic and reliable way
# to tell whether the error means "table not found" for any reason;
# otherwise, we need to reraise the exception
if e.sqlstate == "42S02":
return False
raise
else:
return True

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
import pyarrow.parquet as pq

raw_name = op.name
name = op.name
data = op.data.to_pyarrow(schema=op.schema)

with self.con.cursor() as con:
if not con.execute(f"SHOW TABLES LIKE '{raw_name}'").fetchone():
tmpdir = tempfile.TemporaryDirectory()
try:
path = os.path.join(tmpdir.name, f"{raw_name}.parquet")
# optimize for bandwidth so use zstd which typically compresses
# better than the other options without much loss in speed
pq.write_table(
op.data.to_pyarrow(schema=op.schema), path, compression="zstd"
)
self.read_parquet(path, table_name=raw_name)
finally:
with contextlib.suppress(Exception):
shutil.rmtree(tmpdir.name)
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdir:
path = Path(tmpdir, f"{name}.parquet")
# optimize for bandwidth so use zstd which typically compresses
# better than the other options without much loss in speed
pq.write_table(data, path, compression="zstd")
self.read_parquet(path, table_name=name)

def create_catalog(self, name: str, force: bool = False) -> None:
current_catalog = self.current_catalog
@@ -811,21 +822,10 @@ def create_table(
db = db.name
target = sg.table(name, db=db, catalog=catalog, quoted=quoted)

column_defs = [
sge.ColumnDef(
this=sg.to_identifier(name, quoted=quoted),
kind=self.compiler.type_mapper.from_ibis(typ),
constraints=(
None
if typ.nullable
else [sge.ColumnConstraint(kind=sge.NotNullColumnConstraint())]
),
if schema:
target = sge.Schema(
this=target, expressions=schema.to_sqlglot(self.dialect)
)
for name, typ in (schema or {}).items()
]

if column_defs:
target = sge.Schema(this=target, expressions=column_defs)

properties = []

@@ -835,11 +835,9 @@
if comment is not None:
properties.append(sge.SchemaCommentProperty(this=sge.convert(comment)))

temp_memtable_view = None
if obj is not None:
if not isinstance(obj, ir.Expr):
table = ibis.memtable(obj)
temp_memtable_view = table.op().name
else:
table = obj

@@ -860,11 +858,6 @@
with self._safe_raw_sql(create_stmt):
pass

# Clean up temporary memtable if we've created one
# for in-memory reads
if temp_memtable_view is not None:
self.drop_table(temp_memtable_view)

return self.table(name, database=(catalog, db))

def read_csv(
Expand Down Expand Up @@ -1134,19 +1127,7 @@ def read_parquet(
sge.Create(
kind="TABLE",
this=sge.Schema(
this=qtable,
expressions=[
sge.ColumnDef(
this=sg.to_identifier(col, quoted=quoted),
kind=type_mapper.from_ibis(typ),
constraints=(
[sge.NotNullColumnConstraint()]
if not typ.nullable
else None
),
)
for col, typ in schema.items()
],
this=qtable, expressions=schema.to_sqlglot(self.dialect)
),
properties=sge.Properties(expressions=[sge.TemporaryProperty()]),
).sql(self.dialect),
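Both this backend and RisingWave now answer "was this memtable already registered?" with the same trick: compile a cheap SELECT 1 ... LIMIT 0 probe against the table name and treat the driver's "table not found" error as False. A distilled version of the pattern; the exception class below is a placeholder, since each backend catches its own driver-specific error (sqlstate 42S02 here, an InternalError for RisingWave):

import sqlglot as sg
import sqlglot.expressions as sge


class DriverTableNotFoundError(Exception):
    """Placeholder for the driver-specific 'table not found' error."""


def in_memory_table_exists(cursor_factory, dialect, name):
    ident = sg.to_identifier(name, quoted=True)
    sql = sg.select(sge.convert(1)).from_(ident).limit(0).sql(dialect)
    try:
        with cursor_factory() as cur:
            cur.execute(sql)
            cur.fetchall()
    except DriverTableNotFoundError:
        return False
    return True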
12 changes: 6 additions & 6 deletions ibis/backends/snowflake/tests/test_client.py
@@ -97,8 +97,8 @@ def test_repeated_memtable_registration(simple_con, mocker):
for _ in range(n):
tm.assert_frame_equal(simple_con.execute(t), expected)

# assert that we called _register_in_memory_table exactly n times
assert spy.call_count == n
# assert that we called _register_in_memory_table exactly once
spy.assert_called_once()


def test_timestamp_tz_column(simple_con):
@@ -312,17 +312,17 @@ def test_compile_does_not_make_requests(con, mocker):
expr = astronauts.year_of_selection.value_counts()
spy = mocker.spy(con.con, "cursor")
assert expr.compile() is not None
assert spy.call_count == 0
spy.assert_not_called()

t = ibis.memtable({"a": [1, 2, 3]})
assert con.compile(t) is not None
assert spy.call_count == 0
spy.assert_not_called()

assert ibis.to_sql(t, dialect="snowflake") is not None
assert spy.call_count == 0
spy.assert_not_called()

assert ibis.to_sql(expr) is not None
assert spy.call_count == 0
spy.assert_not_called()


# this won't be hit in CI, but folks can test locally
3 changes: 3 additions & 0 deletions ibis/backends/sql/__init__.py
@@ -618,3 +618,6 @@ def _register_pandas_udf(self, udf_node: ops.ScalarUDF) -> str:
raise NotImplementedError(
f"pandas UDFs are not supported in the {self.dialect} backend"
)

def _finalize_memtable(self, name: str) -> None:
self.drop_table(name, force=True)
217 changes: 67 additions & 150 deletions ibis/backends/sql/compilers/base.py
@@ -32,6 +32,7 @@
from ibis.config import options
from ibis.expr.operations.udf import InputType
from ibis.expr.rewrites import lower_stringslice
from ibis.util import get_subclasses

try:
from sqlglot.expressions import Alter
Expand All @@ -51,15 +52,7 @@ def AlterTable(*args, kind="TABLE", **kwargs):
from ibis.backends.sql.datatypes import SqlglotType


def get_leaf_classes(op):
for child_class in op.__subclasses__():
if not child_class.__subclasses__():
yield child_class
else:
yield from get_leaf_classes(child_class)


ALL_OPERATIONS = frozenset(get_leaf_classes(ops.Node))
ALL_OPERATIONS = frozenset(get_subclasses(ops.Node))


class AggGen:
@@ -239,20 +232,6 @@ def __getitem__(self, key: str) -> sge.Column:
STAR = sge.Star()


def parenthesize_inputs(f):
"""Decorate a translation rule to parenthesize inputs."""

def wrapper(self, op, *, left, right):
return f(
self,
op,
left=self._add_parens(op.left, left),
right=self._add_parens(op.right, right),
)

return wrapper


@public
class SQLGlotCompiler(abc.ABC):
__slots__ = "f", "v"
@@ -393,45 +372,50 @@ class SQLGlotCompiler(abc.ABC):
ops.Uppercase: "upper",
}

BINARY_INFIX_OPS = (
# Binary operations
ops.Add,
ops.Subtract,
ops.Multiply,
ops.Divide,
ops.Modulus,
ops.Power,
BINARY_INFIX_OPS = {
# Numeric
ops.Add: (sge.Add, True),
ops.Subtract: (sge.Sub, False),
ops.Multiply: (sge.Mul, True),
ops.Divide: (sge.Div, False),
ops.Modulus: (sge.Mod, False),
ops.Power: (sge.Pow, False),
# Comparisons
ops.GreaterEqual,
ops.Greater,
ops.LessEqual,
ops.Less,
ops.Equals,
ops.NotEquals,
# Boolean comparisons
ops.And,
ops.Or,
ops.Xor,
# Bitwise business
ops.BitwiseLeftShift,
ops.BitwiseRightShift,
ops.BitwiseAnd,
ops.BitwiseOr,
ops.BitwiseXor,
# Time arithmetic
ops.DateAdd,
ops.DateSub,
ops.DateDiff,
ops.TimestampAdd,
ops.TimestampSub,
ops.TimestampDiff,
# Interval Marginalia
ops.IntervalAdd,
ops.IntervalMultiply,
ops.IntervalSubtract,
)
ops.GreaterEqual: (sge.GTE, False),
ops.Greater: (sge.GT, False),
ops.LessEqual: (sge.LTE, False),
ops.Less: (sge.LT, False),
ops.Equals: (sge.EQ, False),
ops.NotEquals: (sge.NEQ, False),
# Logical
ops.And: (sge.And, True),
ops.Or: (sge.Or, True),
ops.Xor: (sge.Xor, True),
# Bitwise
ops.BitwiseLeftShift: (sge.BitwiseLeftShift, False),
ops.BitwiseRightShift: (sge.BitwiseRightShift, False),
ops.BitwiseAnd: (sge.BitwiseAnd, True),
ops.BitwiseOr: (sge.BitwiseOr, True),
ops.BitwiseXor: (sge.BitwiseXor, True),
# Date
ops.DateAdd: (sge.Add, True),
ops.DateSub: (sge.Sub, False),
ops.DateDiff: (sge.Sub, False),
# Time
ops.TimeAdd: (sge.Add, True),
ops.TimeSub: (sge.Sub, False),
ops.TimeDiff: (sge.Sub, False),
# Timestamp
ops.TimestampAdd: (sge.Add, True),
ops.TimestampSub: (sge.Sub, False),
ops.TimestampDiff: (sge.Sub, False),
# Interval
ops.IntervalAdd: (sge.Add, True),
ops.IntervalMultiply: (sge.Mul, True),
ops.IntervalSubtract: (sge.Sub, False),
}

NEEDS_PARENS = BINARY_INFIX_OPS + (ops.IsNull,)
NEEDS_PARENS = tuple(BINARY_INFIX_OPS) + (ops.IsNull,)

# Constructed dynamically in `__init_subclass__` from their respective
# UPPERCASE values to handle inheritance, do not modify directly here.
@@ -469,6 +453,19 @@ def impl(self, _, *, _name: str = target_name, **kw):
for op, target_name in cls.SIMPLE_OPS.items():
setattr(cls, methodname(op), make_impl(op, target_name))

# Define binary op methods, only if BINARY_INFIX_OPS is set on the
# compiler class.
if binops := cls.__dict__.get("BINARY_INFIX_OPS", {}):

def make_binop(sge_cls, associative):
def impl(self, op, *, left, right):
return self.binop(sge_cls, op, left, right, associative=associative)

return impl

for op, (sge_cls, associative) in binops.items():
setattr(cls, methodname(op), make_binop(sge_cls, associative))

# unconditionally raise an exception for unsupported operations
#
# these *must* be defined after SIMPLE_OPS to handle compilers that
Expand Down Expand Up @@ -685,9 +682,6 @@ def visit_Cast(self, op, *, arg, to):
def visit_ScalarSubquery(self, op, *, rel):
return rel.this.subquery(copy=False)

def visit_Alias(self, op, *, arg, name):
return arg

def visit_Literal(self, op, *, value, dtype):
"""Compile a literal value.
@@ -1505,93 +1499,16 @@ def visit_SQLQueryResult(self, op, *, query, schema, source):
def visit_RegexExtract(self, op, *, arg, pattern, index):
return self.f.regexp_extract(arg, pattern, index, dialect=self.dialect)

@parenthesize_inputs
def visit_Add(self, op, *, left, right):
return sge.Add(this=left, expression=right)

visit_DateAdd = visit_TimestampAdd = visit_IntervalAdd = visit_Add

@parenthesize_inputs
def visit_Subtract(self, op, *, left, right):
return sge.Sub(this=left, expression=right)

visit_DateSub = visit_DateDiff = visit_TimestampSub = visit_TimestampDiff = (
visit_IntervalSubtract
) = visit_Subtract

@parenthesize_inputs
def visit_Multiply(self, op, *, left, right):
return sge.Mul(this=left, expression=right)

visit_IntervalMultiply = visit_Multiply

@parenthesize_inputs
def visit_Divide(self, op, *, left, right):
return sge.Div(this=left, expression=right)

@parenthesize_inputs
def visit_Modulus(self, op, *, left, right):
return sge.Mod(this=left, expression=right)

@parenthesize_inputs
def visit_Power(self, op, *, left, right):
return sge.Pow(this=left, expression=right)

@parenthesize_inputs
def visit_GreaterEqual(self, op, *, left, right):
return sge.GTE(this=left, expression=right)

@parenthesize_inputs
def visit_Greater(self, op, *, left, right):
return sge.GT(this=left, expression=right)

@parenthesize_inputs
def visit_LessEqual(self, op, *, left, right):
return sge.LTE(this=left, expression=right)

@parenthesize_inputs
def visit_Less(self, op, *, left, right):
return sge.LT(this=left, expression=right)

@parenthesize_inputs
def visit_Equals(self, op, *, left, right):
return sge.EQ(this=left, expression=right)

@parenthesize_inputs
def visit_NotEquals(self, op, *, left, right):
return sge.NEQ(this=left, expression=right)

@parenthesize_inputs
def visit_And(self, op, *, left, right):
return sge.And(this=left, expression=right)

@parenthesize_inputs
def visit_Or(self, op, *, left, right):
return sge.Or(this=left, expression=right)

@parenthesize_inputs
def visit_Xor(self, op, *, left, right):
return sge.Xor(this=left, expression=right)

@parenthesize_inputs
def visit_BitwiseLeftShift(self, op, *, left, right):
return sge.BitwiseLeftShift(this=left, expression=right)

@parenthesize_inputs
def visit_BitwiseRightShift(self, op, *, left, right):
return sge.BitwiseRightShift(this=left, expression=right)

@parenthesize_inputs
def visit_BitwiseAnd(self, op, *, left, right):
return sge.BitwiseAnd(this=left, expression=right)

@parenthesize_inputs
def visit_BitwiseOr(self, op, *, left, right):
return sge.BitwiseOr(this=left, expression=right)

@parenthesize_inputs
def visit_BitwiseXor(self, op, *, left, right):
return sge.BitwiseXor(this=left, expression=right)
def binop(self, sg_expr, op, left, right, *, associative=False):
# If the op is associative we can skip parenthesizing ops of the same
# type when they appear on the left, since they evaluate the same either way.
# SQLGlot has an optimizer for generating long SQL chains of the same op
# without recursion; by avoiding parentheses in this common case we can
# take advantage of that optimization to handle large operator chains.
if not associative or type(op) is not type(op.left):
left = self._add_parens(op.left, left)
return sg_expr(this=left, expression=self._add_parens(op.right, right))

def visit_Undefined(self, op, **_):
raise com.OperationNotDefinedError(
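The single binop helper relies on associativity to decide when parentheses are required. A minimal sketch, assuming only that sqlglot is installed; it illustrates why a left-nested chain of an associative op needs no parentheses while a non-associative op keeps them, and it is not the ibis compiler itself:

from sqlglot import column
import sqlglot.expressions as sge

a, b, c = column("a"), column("b"), column("c")

# Addition is associative: (a + b) + c evaluates the same as a + (b + c),
# so the left operand can be emitted without wrapping parentheses.
chained = sge.Add(this=sge.Add(this=a, expression=b), expression=c)
print(chained.sql())  # a + b + c

# Subtraction is not associative, so the right operand stays wrapped to
# preserve a - (b - c).
nested = sge.Sub(this=a, expression=sge.Paren(this=sge.Sub(this=b, expression=c)))
print(nested.sql())  # a - (b - c)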
14 changes: 10 additions & 4 deletions ibis/backends/sql/compilers/bigquery/__init__.py
@@ -1017,7 +1017,14 @@ def visit_DropColumns(self, op, *, parent, columns_to_drop):
return sg.select(column).from_(parent)

def visit_TableUnnest(
self, op, *, parent, column, offset: str | None, keep_empty: bool
self,
op,
*,
parent,
column,
column_name: str,
offset: str | None,
keep_empty: bool,
):
quoted = self.quoted

Expand All @@ -1029,9 +1036,8 @@ def visit_TableUnnest(

table = sg.to_identifier(parent.alias_or_name, quoted=quoted)

opname = op.column.name
overlaps_with_parent = opname in op.parent.schema
computed_column = column_alias.as_(opname, quoted=quoted)
overlaps_with_parent = column_name in op.parent.schema
computed_column = column_alias.as_(column_name, quoted=quoted)

# replace the existing column if the unnested column hasn't been
# renamed
23 changes: 16 additions & 7 deletions ibis/backends/sql/compilers/clickhouse.py
@@ -2,6 +2,7 @@

import calendar
import math
from string import whitespace
from typing import TYPE_CHECKING, Any

import sqlglot as sg
@@ -96,7 +97,6 @@ class ClickHouseCompiler(SQLGlotCompiler):
ops.IsInf: "isInfinite",
ops.IsNan: "isNaN",
ops.IsNull: "isNull",
ops.LStrip: "trimLeft",
ops.Ln: "log",
ops.Log10: "log10",
ops.MapKeys: "mapKeys",
@@ -106,15 +106,13 @@
ops.Median: "quantileExactExclusive",
ops.NotNull: "isNotNull",
ops.NullIf: "nullIf",
ops.RStrip: "trimRight",
ops.RegexReplace: "replaceRegexpAll",
ops.RowNumber: "row_number",
ops.StartsWith: "startsWith",
ops.StrRight: "right",
ops.Strftime: "formatDateTime",
ops.StringLength: "length",
ops.StringReplace: "replaceAll",
ops.Strip: "trimBoth",
ops.TimestampNow: "now",
ops.TypeOf: "toTypeName",
ops.Unnest: "arrayJoin",
@@ -477,6 +475,11 @@ def visit_Repeat(self, op, *, arg, times):
def visit_StringContains(self, op, haystack, needle):
return self.f.position(haystack, needle) > 0

def visit_Strip(self, op, *, arg):
return sge.Trim(
this=arg, position="BOTH", expression=sge.Literal.string(whitespace)
)

def visit_DayOfWeekIndex(self, op, *, arg):
weekdays = len(calendar.day_name)
return (((self.f.toDayOfWeek(arg) - 1) % weekdays) + weekdays) % weekdays
@@ -685,7 +688,14 @@ def visit_DropColumns(self, op, *, parent, columns_to_drop):
return sg.select(column).from_(parent)

def visit_TableUnnest(
self, op, *, parent, column, offset: str | None, keep_empty: bool
self,
op,
*,
parent,
column,
column_name: str,
offset: str | None,
keep_empty: bool,
):
quoted = self.quoted

@@ -697,9 +707,8 @@ def visit_TableUnnest(

selcols = []

opname = op.column.name
overlaps_with_parent = opname in op.parent.schema
computed_column = column_alias.as_(opname, quoted=quoted)
overlaps_with_parent = column_name in op.parent.schema
computed_column = column_alias.as_(column_name, quoted=quoted)

if offset is not None:
if overlaps_with_parent: