diff --git a/Cargo.toml b/Cargo.toml index a3abe04..9eeabb9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,8 +38,8 @@ uuid = { version = "0.8", features = ["v4"] } mimalloc = { version = "*", default-features = false } [lib] -name = "_internal" -crate-type = ["cdylib"] +name = "datafusion_python" +crate-type = ["cdylib", "rlib"] [package.metadata.maturin] name = "datafusion._internal" diff --git a/datafusion/tests/conftest.py b/datafusion/tests/conftest.py new file mode 100644 index 0000000..ab25508 --- /dev/null +++ b/datafusion/tests/conftest.py @@ -0,0 +1,33 @@ +import pytest +from datafusion import ExecutionContext +import pyarrow as pa + + +@pytest.fixture +def ctx(): + return ExecutionContext() + + +@pytest.fixture +def database(ctx, tmp_path): + path = tmp_path / "test.csv" + + table = pa.Table.from_arrays( + [ + [1, 2, 3, 4], + ["a", "b", "c", "d"], + [1.1, 2.2, 3.3, 4.4], + ], + names=["int", "str", "float"], + ) + pa.csv.write_csv(table, path) + + ctx.register_csv("csv", path) + ctx.register_csv("csv1", str(path)) + ctx.register_csv( + "csv2", + path, + has_header=True, + delimiter=",", + schema_infer_max_records=10, + ) diff --git a/datafusion/tests/test_catalog.py b/datafusion/tests/test_catalog.py index 2e64a81..a9bdf72 100644 --- a/datafusion/tests/test_catalog.py +++ b/datafusion/tests/test_catalog.py @@ -18,38 +18,6 @@ import pyarrow as pa import pytest -from datafusion import ExecutionContext - - -@pytest.fixture -def ctx(): - return ExecutionContext() - - -@pytest.fixture -def database(ctx, tmp_path): - path = tmp_path / "test.csv" - - table = pa.Table.from_arrays( - [ - [1, 2, 3, 4], - ["a", "b", "c", "d"], - [1.1, 2.2, 3.3, 4.4], - ], - names=["int", "str", "float"], - ) - pa.csv.write_csv(table, path) - - ctx.register_csv("csv", path) - ctx.register_csv("csv1", str(path)) - ctx.register_csv( - "csv2", - path, - has_header=True, - delimiter=",", - schema_infer_max_records=10, - ) - def test_basic(ctx, database): with pytest.raises(KeyError): diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py index 60beea4..4d4a38c 100644 --- a/datafusion/tests/test_context.py +++ b/datafusion/tests/test_context.py @@ -16,14 +16,6 @@ # under the License. import pyarrow as pa -import pytest - -from datafusion import ExecutionContext - - -@pytest.fixture -def ctx(): - return ExecutionContext() def test_register_record_batches(ctx): @@ -61,3 +53,22 @@ def test_create_dataframe_registers_unique_table_name(ctx): # only hexadecimal numbers for c in tables[0][1:]: assert c in "0123456789abcdef" + + +def test_register_table(ctx, database): + default = ctx.catalog() + public = default.database("public") + assert public.names() == {"csv", "csv1", "csv2"} + table = public.table("csv") + + ctx.register_table("csv3", table) + assert public.names() == {"csv", "csv1", "csv2", "csv3"} + + +def test_deregister_table(ctx, database): + default = ctx.catalog() + public = default.database("public") + assert public.names() == {"csv", "csv1", "csv2"} + + ctx.deregister_table("csv") + assert public.names() == {"csv1", "csv2"} diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py index 23f2007..cde5425 100644 --- a/datafusion/tests/test_sql.py +++ b/datafusion/tests/test_sql.py @@ -19,16 +19,11 @@ import pyarrow as pa import pytest -from datafusion import ExecutionContext, udf +from datafusion import udf from . import generic as helpers -@pytest.fixture -def ctx(): - return ExecutionContext() - - def test_no_table(ctx): with pytest.raises(Exception, match="DataFusion error"): ctx.sql("SELECT a FROM b").collect() diff --git a/src/catalog.rs b/src/catalog.rs index f93c795..d7a6b8a 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -38,7 +38,7 @@ pub(crate) struct PyDatabase { } #[pyclass(name = "Table", module = "datafusion", subclass)] -pub(crate) struct PyTable { +pub struct PyTable { table: Arc, } @@ -58,6 +58,10 @@ impl PyTable { pub fn new(table: Arc) -> Self { Self { table } } + + pub fn table(&self) -> Arc { + self.table.clone() + } } #[pymethods] diff --git a/src/context.rs b/src/context.rs index ebd893c..274005e 100644 --- a/src/context.rs +++ b/src/context.rs @@ -29,7 +29,7 @@ use datafusion::datasource::MemTable; use datafusion::execution::context::ExecutionContext; use datafusion::prelude::CsvReadOptions; -use crate::catalog::PyCatalog; +use crate::catalog::{PyCatalog, PyTable}; use crate::dataframe::PyDataFrame; use crate::errors::DataFusionError; use crate::udf::PyScalarUDF; @@ -80,6 +80,20 @@ impl PyExecutionContext { Ok(df) } + fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> { + self.ctx + .register_table(name, table.table()) + .map_err(DataFusionError::from)?; + Ok(()) + } + + fn deregister_table(&mut self, name: &str) -> PyResult<()> { + self.ctx + .deregister_table(name) + .map_err(DataFusionError::from)?; + Ok(()) + } + fn register_record_batches( &mut self, name: &str, diff --git a/src/lib.rs b/src/lib.rs index ab528a1..977d9e8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,15 +18,15 @@ use mimalloc::MiMalloc; use pyo3::prelude::*; -mod catalog; +pub mod catalog; mod context; mod dataframe; -mod errors; +pub mod errors; mod expression; mod functions; mod udaf; mod udf; -mod utils; +pub mod utils; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; diff --git a/src/utils.rs b/src/utils.rs index c8e1c63..7e7d0a1 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -25,7 +25,7 @@ use datafusion::physical_plan::functions::Volatility; use crate::errors::DataFusionError; /// Utility to collect rust futures with GIL released -pub(crate) fn wait_for_future(py: Python, f: F) -> F::Output +pub fn wait_for_future(py: Python, f: F) -> F::Output where F: Send, F::Output: Send,