Skip to content
This repository has been archived by the owner on Jul 25, 2022. It is now read-only.

Commit

Permalink
#45 Add register_table/deregister_table and expose some public mod (#46)
Browse files Browse the repository at this point in the history
* Add register_table and deregister_table

* expose public modules and methods for PyTable inheritors
  • Loading branch information
jychen7 committed Apr 4, 2022
1 parent bfd67e3 commit 0fe08fe
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 54 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ uuid = { version = "0.8", features = ["v4"] }
mimalloc = { version = "*", default-features = false }

[lib]
name = "_internal"
crate-type = ["cdylib"]
name = "datafusion_python"
crate-type = ["cdylib", "rlib"]

[package.metadata.maturin]
name = "datafusion._internal"
Expand Down
33 changes: 33 additions & 0 deletions datafusion/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
from datafusion import ExecutionContext
import pyarrow as pa


@pytest.fixture
def ctx():
    """Build a new ``ExecutionContext`` and hand it to the requesting test."""
    context = ExecutionContext()
    return context


@pytest.fixture
def database(ctx, tmp_path):
    """Write a small CSV file and register it with ``ctx`` under three names.

    The file has four rows and three columns (``int``, ``str``, ``float``).
    It is registered as:

    * ``csv``  -- passing the path object as-is,
    * ``csv1`` -- passing the path converted to ``str``,
    * ``csv2`` -- passing the path object plus explicit reader options.

    The fixture returns nothing; tests request it only for these
    registrations.
    """
    # ``import pyarrow as pa`` does not by itself guarantee that the ``csv``
    # submodule is loaded, so ``pa.csv`` can raise AttributeError depending on
    # what else has been imported. Import the submodule explicitly.
    from pyarrow import csv as pa_csv

    path = tmp_path / "test.csv"

    table = pa.Table.from_arrays(
        [
            [1, 2, 3, 4],
            ["a", "b", "c", "d"],
            [1.1, 2.2, 3.3, 4.4],
        ],
        names=["int", "str", "float"],
    )
    pa_csv.write_csv(table, path)

    ctx.register_csv("csv", path)
    ctx.register_csv("csv1", str(path))
    ctx.register_csv(
        "csv2",
        path,
        has_header=True,
        delimiter=",",
        schema_infer_max_records=10,
    )
32 changes: 0 additions & 32 deletions datafusion/tests/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,6 @@
import pyarrow as pa
import pytest

from datafusion import ExecutionContext


@pytest.fixture
def ctx():
return ExecutionContext()


@pytest.fixture
def database(ctx, tmp_path):
path = tmp_path / "test.csv"

table = pa.Table.from_arrays(
[
[1, 2, 3, 4],
["a", "b", "c", "d"],
[1.1, 2.2, 3.3, 4.4],
],
names=["int", "str", "float"],
)
pa.csv.write_csv(table, path)

ctx.register_csv("csv", path)
ctx.register_csv("csv1", str(path))
ctx.register_csv(
"csv2",
path,
has_header=True,
delimiter=",",
schema_infer_max_records=10,
)


def test_basic(ctx, database):
with pytest.raises(KeyError):
Expand Down
27 changes: 19 additions & 8 deletions datafusion/tests/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,6 @@
# under the License.

import pyarrow as pa
import pytest

from datafusion import ExecutionContext


@pytest.fixture
def ctx():
return ExecutionContext()


def test_register_record_batches(ctx):
Expand Down Expand Up @@ -61,3 +53,22 @@ def test_create_dataframe_registers_unique_table_name(ctx):
# only hexadecimal numbers
for c in tables[0][1:]:
assert c in "0123456789abcdef"


def test_register_table(ctx, database):
default = ctx.catalog()
public = default.database("public")
assert public.names() == {"csv", "csv1", "csv2"}
table = public.table("csv")

ctx.register_table("csv3", table)
assert public.names() == {"csv", "csv1", "csv2", "csv3"}


def test_deregister_table(ctx, database):
default = ctx.catalog()
public = default.database("public")
assert public.names() == {"csv", "csv1", "csv2"}

ctx.deregister_table("csv")
assert public.names() == {"csv1", "csv2"}
7 changes: 1 addition & 6 deletions datafusion/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,11 @@
import pyarrow as pa
import pytest

from datafusion import ExecutionContext, udf
from datafusion import udf

from . import generic as helpers


@pytest.fixture
def ctx():
return ExecutionContext()


def test_no_table(ctx):
with pytest.raises(Exception, match="DataFusion error"):
ctx.sql("SELECT a FROM b").collect()
Expand Down
6 changes: 5 additions & 1 deletion src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pub(crate) struct PyDatabase {
}

#[pyclass(name = "Table", module = "datafusion", subclass)]
pub(crate) struct PyTable {
pub struct PyTable {
table: Arc<dyn TableProvider>,
}

Expand All @@ -58,6 +58,10 @@ impl PyTable {
pub fn new(table: Arc<dyn TableProvider>) -> Self {
Self { table }
}

/// Returns a new shared handle to the wrapped table provider.
///
/// Only the `Arc` is cloned; the provider itself is not copied.
pub fn table(&self) -> Arc<dyn TableProvider> {
    Arc::clone(&self.table)
}
}

#[pymethods]
Expand Down
16 changes: 15 additions & 1 deletion src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use datafusion::datasource::MemTable;
use datafusion::execution::context::ExecutionContext;
use datafusion::prelude::CsvReadOptions;

use crate::catalog::PyCatalog;
use crate::catalog::{PyCatalog, PyTable};
use crate::dataframe::PyDataFrame;
use crate::errors::DataFusionError;
use crate::udf::PyScalarUDF;
Expand Down Expand Up @@ -80,6 +80,20 @@ impl PyExecutionContext {
Ok(df)
}

/// Registers the provider held by `table` under `name` in the wrapped context.
///
/// Any error reported by the context is converted through the crate's
/// `DataFusionError` before being returned; the success value of the
/// underlying call is not needed and is discarded.
fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> {
    let provider = table.table();
    let outcome = self.ctx.register_table(name, provider);
    outcome.map_err(DataFusionError::from)?;
    Ok(())
}

/// Deregisters the table called `name` from the wrapped context.
///
/// Any error reported by the context is converted through the crate's
/// `DataFusionError` before being returned; the success value of the
/// underlying call is not needed and is discarded.
fn deregister_table(&mut self, name: &str) -> PyResult<()> {
    let outcome = self.ctx.deregister_table(name);
    outcome.map_err(DataFusionError::from)?;
    Ok(())
}

fn register_record_batches(
&mut self,
name: &str,
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
use mimalloc::MiMalloc;
use pyo3::prelude::*;

mod catalog;
pub mod catalog;
mod context;
mod dataframe;
mod errors;
pub mod errors;
mod expression;
mod functions;
mod udaf;
mod udf;
mod utils;
pub mod utils;

#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;
Expand Down
2 changes: 1 addition & 1 deletion src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use datafusion::physical_plan::functions::Volatility;
use crate::errors::DataFusionError;

/// Utility to collect rust futures with GIL released
pub(crate) fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
pub fn wait_for_future<F: Future>(py: Python, f: F) -> F::Output
where
F: Send,
F::Output: Send,
Expand Down

0 comments on commit 0fe08fe

Please sign in to comment.