feat(duckdb): warn when querying an already consumed RecordBatchReader
jcrist authored and cpcloud committed May 3, 2023
1 parent e8a065c commit 5a013ff
Showing 6 changed files with 82 additions and 16 deletions.
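
In short: a table registered from a `pyarrow.RecordBatchReader` can only be scanned once, so DuckDB now warns when a later query touches it again. A minimal sketch of the user-facing behaviour, assuming a local DuckDB connection (the exact warning text lives in ibis/backends/duckdb/__init__.py below):

import pyarrow as pa

import ibis

con = ibis.duckdb.connect()

data = pa.Table.from_pydict({"x": [1, 2, 3, 4]})
reader = data.to_reader(max_chunksize=2)

t = con.read_in_memory(reader)

t.execute()           # first execution consumes the reader; this is fine
t.limit(2).execute()  # the reader is already consumed; emits a UserWarning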
10 changes: 6 additions & 4 deletions ibis/backends/base/sql/__init__.py
@@ -148,7 +148,7 @@ def _cursor_batches(
limit: int | str | None = None,
chunk_size: int = 1_000_000,
) -> Iterable[list]:
- self._register_in_memory_tables(expr)
+ self._run_pre_execute_hooks(expr)
query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
sql = query_ast.compile()

@@ -246,9 +246,7 @@ def execute(

schema = self.ast_schema(query_ast, **kwargs)

- # register all in memory tables if the backend supports cheap access
- # to them
- self._register_in_memory_tables(expr)
+ self._run_pre_execute_hooks(expr)

with self._safe_raw_sql(sql, **kwargs) as cursor:
result = self.fetch_from_cursor(cursor, schema)
@@ -266,6 +264,10 @@ def _register_in_memory_tables(self, expr: ir.Expr) -> None:
for memtable in an.find_memtables(expr.op()):
self._register_in_memory_table(memtable)

def _run_pre_execute_hooks(self, expr: ir.Expr) -> None:
"""Backend-specific hooks to run before an expression is executed."""
self._register_in_memory_tables(expr)

@abc.abstractmethod
def fetch_from_cursor(self, cursor, schema):
"""Fetch data from cursor."""
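The base hook above only registers in-memory tables; a backend that needs extra pre-execution work overrides `_run_pre_execute_hooks` and delegates to `super()`, which is exactly what the DuckDB backend does further down in this commit. A rough sketch of that pattern (the class name is purely illustrative and the import path is an assumption about the current module layout):

import ibis.expr.types as ir
from ibis.backends.base.sql import BaseSQLBackend


class ExampleBackend(BaseSQLBackend):  # hypothetical backend, for illustration only
    def _run_pre_execute_hooks(self, expr: ir.Expr) -> None:
        # backend-specific bookkeeping or validation would go here ...
        super()._run_pre_execute_hooks(expr)  # still registers in-memory tables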
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/alchemy/__init__.py
@@ -278,7 +278,7 @@ def create_table(
if has_expr := obj is not None:
# this has to happen outside the `begin` block, so that in-memory
# tables are visible inside the transaction created by it
- self._register_in_memory_tables(obj)
+ self._run_pre_execute_hooks(obj)

table = self._table_from_schema(
name, schema, database=database or self.current_database, temp=temp
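For context, the branch above is hit when `create_table` receives an expression backed by in-memory data; the hooks run before the transaction opens so the registered tables are visible inside it. A hedged sketch of a call that takes this path on a SQLAlchemy-based backend (DuckDB here; the table name is arbitrary):

import pandas as pd

import ibis

con = ibis.duckdb.connect()
df = pd.DataFrame({"x": [1, 2], "y": ["a", "b"]})

# `obj` is an in-memory expression, so _run_pre_execute_hooks(obj) runs
# before the CREATE TABLE transaction begins
t = con.create_table("example_table", ibis.memtable(df))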
44 changes: 35 additions & 9 deletions ibis/backends/duckdb/__init__.py
@@ -125,6 +125,7 @@ def configure_connection(dbapi_connection, connection_record):
# the progress bar causes kernel crashes in jupyterlab ¯\_(ツ)_/¯
dbapi_connection.execute("SET enable_progress_bar = false")

self._record_batch_readers_consumed = {}
super().do_connect(engine)

def _load_extensions(self, extensions):
@@ -395,13 +396,15 @@ def _read_parquet_pyarrow_dataset(
con.connection.register(table_name, dataset)

def read_in_memory(
- self, dataframe: pd.DataFrame | pa.Table, table_name: str | None = None
+ self,
+ source: pd.DataFrame | pa.Table | pa.RecordBatchReader,
+ table_name: str | None = None,
) -> ir.Table:
"""Register a Pandas DataFrame or pyarrow Table as a table in the current database.
"""Register a Pandas DataFrame or pyarrow object as a table in the current database.
Parameters
----------
- dataframe
+ source
The data source.
table_name
An optional name to use for the created table. This defaults to
@@ -414,7 +417,12 @@ def read_in_memory(
"""
table_name = table_name or util.gen_name("read_in_memory")
with self.begin() as con:
- con.connection.register(table_name, dataframe)
+ con.connection.register(table_name, source)

if isinstance(source, pa.RecordBatchReader):
# Ensure the reader isn't marked as started, in case the name is
# being overwritten.
self._record_batch_readers_consumed[table_name] = False

return self.table(table_name)

@@ -534,6 +542,26 @@ def attach_sqlite(
sa.text(f"CALL sqlite_attach('{str(path)}', overwrite={overwrite})")
)

def _run_pre_execute_hooks(self, expr: ir.Expr) -> None:
from ibis.expr.analysis import find_physical_tables

# Warn for any tables depending on RecordBatchReaders that have already
# started being consumed.
for t in find_physical_tables(expr.op()):
started = self._record_batch_readers_consumed.get(t.name)
if started is True:
warnings.warn(
f"Table {t.name!r} is backed by a `pyarrow.RecordBatchReader` "
"that has already been partially consumed. This may lead to "
"unexpected results. Either recreate the table from a new "
"`pyarrow.RecordBatchReader`, or use `Table.cache()`/"
"`con.create_table()` to consume and store the results in "
"the backend to reuse later."
)
elif started is False:
self._record_batch_readers_consumed[t.name] = True
super()._run_pre_execute_hooks(expr)

def to_pyarrow_batches(
self,
expr: ir.Expr,
@@ -542,7 +570,7 @@ def to_pyarrow_batches(
limit: int | str | None = None,
chunk_size: int = 1_000_000,
**_: Any,
- ) -> pa.ipc.RecordBatchReader:
+ ) -> pa.RecordBatchReader:
"""Return a stream of record batches.
The returned `RecordBatchReader` contains a cursor with an unbounded lifetime.
@@ -561,7 +589,7 @@ def to_pyarrow_batches(
chunk_size
!!! warning "DuckDB returns 1024 size batches regardless of what argument is passed."
"""
- self._register_in_memory_tables(expr)
+ self._run_pre_execute_hooks(expr)
query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
sql = query_ast.compile()

@@ -584,9 +612,7 @@ def to_pyarrow(
limit: int | str | None = None,
**_: Any,
) -> pa.Table:
- pa = self._import_pyarrow()
-
- self._register_in_memory_tables(expr)
+ self._run_pre_execute_hooks(expr)
query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
sql = query_ast.compile()

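For background on why the backend now tracks consumption at all: a `pyarrow.RecordBatchReader` is a one-shot stream, so once DuckDB has pulled its batches there is nothing left to scan on the next query. A standalone illustration in plain pyarrow, based on my understanding of its stream semantics (no ibis involved):

import pyarrow as pa

data = pa.Table.from_pydict({"x": [1, 2, 3, 4]})
reader = data.to_reader(max_chunksize=2)

print(reader.read_all().num_rows)  # 4: this drains the stream
print(reader.read_all().num_rows)  # 0: nothing left, the reader cannot be rewound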
26 changes: 26 additions & 0 deletions ibis/backends/duckdb/tests/test_register.py
@@ -288,3 +288,29 @@ def test_register_numpy_str(con):
data = pd.DataFrame({"a": [np.str_("xyz"), None]})
result = con.read_in_memory(data)
tm.assert_frame_equal(result.execute(), data)


def test_register_recordbatchreader_warns(con):
table = pa.Table.from_batches(
[
pa.RecordBatch.from_pydict({"x": [1, 2]}),
pa.RecordBatch.from_pydict({"x": [3, 4]}),
]
)
reader = table.to_reader()
sol = table.to_pandas()
t = con.read_in_memory(reader)

# First execute is fine
res = t.execute()
tm.assert_frame_equal(res, sol)

# Later executes warn
with pytest.warns(UserWarning, match="RecordBatchReader"):
t.limit(2).execute()

# Re-registering over the name with a new reader is fine
reader = table.to_reader()
t = con.read_in_memory(reader, table_name=t.get_name())
res = t.execute()
tm.assert_frame_equal(res, sol)
4 changes: 2 additions & 2 deletions ibis/backends/snowflake/__init__.py
@@ -236,7 +236,7 @@ def to_pyarrow(

import pyarrow as pa

- self._register_in_memory_tables(expr)
+ self._run_pre_execute_hooks(expr)
query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
sql = query_ast.compile()
with self.begin() as con:
@@ -282,7 +282,7 @@ def to_pyarrow_batches(

import pyarrow as pa

- self._register_in_memory_tables(expr)
+ self._run_pre_execute_hooks(expr)
query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
sql = query_ast.compile()
target_schema = expr.as_table().schema().to_pyarrow()
12 changes: 12 additions & 0 deletions ibis/expr/analysis.py
@@ -118,6 +118,18 @@ def reduction_to_aggregation(node):
return agg


def find_physical_tables(node):
"""Find every first occurrence of a `ir.PhysicalTable` object in `node`."""

def finder(node):
if isinstance(node, ops.PhysicalTable):
return g.halt, node
else:
return g.proceed, None

return list(toolz.unique(g.traverse(finder, node)))


def find_immediate_parent_tables(input_node, keep_input=True):
"""Find every first occurrence of a `ir.Table` object in `input_node`.
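A quick sketch of what the new `find_physical_tables` helper returns when called directly: it walks the operation graph and reports each physical table op once, however many times the expression refers to it. Treating `ops.UnboundTable` as one of those physical table ops is an assumption of this example:

import ibis
from ibis.expr.analysis import find_physical_tables

t = ibis.table({"x": "int64", "y": "string"}, name="t")
expr = t.filter(t.x > 0).select("x", "y")

# `t` is referenced several times in `expr` but is reported only once
print(find_physical_tables(expr.op()))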
