Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyArrow] Fix bug in timestamp pushdown #9377

Merged
merged 1 commit into from Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions src/function/table/arrow.cpp
Expand Up @@ -266,6 +266,7 @@ unique_ptr<ArrowArrayStreamWrapper> ProduceArrowScan(const ArrowScanFunctionData
auto &schema = *function.schema_root.arrow_schema.children[col_idx];
parameters.projected_columns.projection_map[idx] = schema.name;
parameters.projected_columns.columns.emplace_back(schema.name);
parameters.projected_columns.filter_to_col[idx] = col_idx;
}
}
parameters.filters = filters;
Expand Down
2 changes: 2 additions & 0 deletions src/include/duckdb/function/table/arrow.hpp
Expand Up @@ -33,6 +33,8 @@ struct ArrowInterval {
// Bookkeeping for the columns requested from an Arrow scan. All three members are filled
// together, one entry per projected column, by the code that prepares the scan parameters.
struct ArrowProjectedColumns {
// Index used by the scan/filter set -> name of the corresponding Arrow column
// NOTE(review): the key looks like the projection index of the scan - confirm against the caller
unordered_map<idx_t, string> projection_map;
// Names of the projected columns, in the order they were added
vector<string> columns;
// Map from filter index to column index: the value is the position of the column among the
// children of the Arrow schema, so it can be used to look up that column's Arrow type
unordered_map<idx_t, idx_t> filter_to_col;
};

struct ArrowStreamParameters {
Expand Down
10 changes: 6 additions & 4 deletions tools/pythonpkg/src/arrow/arrow_array_stream.cpp
Expand Up @@ -66,11 +66,12 @@ py::object PythonTableArrowArrayStreamFactory::ProduceScanner(py::object &arrow_

auto filters = parameters.filters;
auto &column_list = parameters.projected_columns.columns;
auto &filter_to_col = parameters.projected_columns.filter_to_col;
bool has_filter = filters && !filters->filters.empty();
py::list projection_list = py::cast(column_list);
if (has_filter) {
auto filter =
TransformFilter(*filters, parameters.projected_columns.projection_map, client_properties, arrow_table);
auto filter = TransformFilter(*filters, parameters.projected_columns.projection_map, filter_to_col,
client_properties, arrow_table);
if (column_list.empty()) {
return arrow_scanner(arrow_obj_handle, py::arg("filter") = filter);
} else {
Expand Down Expand Up @@ -176,7 +177,7 @@ string ConvertTimestampUnit(ArrowDateTimeType unit) {
case ArrowDateTimeType::SECONDS:
return "s";
default:
throw NotImplementedException("DatetimeType not recognized in ConvertTimestampUnit");
throw NotImplementedException("DatetimeType not recognized in ConvertTimestampUnit: %d", (int)unit);
}
}

Expand Down Expand Up @@ -365,12 +366,13 @@ py::object TransformFilterRecursive(TableFilter *filter, const string &column_na

py::object PythonTableArrowArrayStreamFactory::TransformFilter(TableFilterSet &filter_collection,
std::unordered_map<idx_t, string> &columns,
unordered_map<idx_t, idx_t> filter_to_col,
const ClientProperties &config,
const ArrowTableType &arrow_table) {
auto filters_map = &filter_collection.filters;
auto it = filters_map->begin();
D_ASSERT(columns.find(it->first) != columns.end());
auto &arrow_type = *arrow_table.GetColumns().at(it->first);
auto &arrow_type = *arrow_table.GetColumns().at(filter_to_col.at(it->first));
py::object expression =
TransformFilterRecursive(it->second.get(), columns[it->first], config.time_zone, arrow_type);
while (it != filters_map->end()) {
Expand Down
Expand Up @@ -76,6 +76,7 @@ class PythonTableArrowArrayStreamFactory {
private:
//! We transform a TableFilterSet to an Arrow Expression Object
//! \param filters The filter set to convert; entries are keyed by filter index
//! \param columns Map from filter index to the name of the Arrow column the filter applies to
//! \param filter_to_col Map from filter index to column index; used to find the Arrow type of
//!        the filtered column in arrow_table
//! \param client_properties Client settings; the time zone is forwarded to the per-filter conversion
//! \param arrow_table Arrow type information for the columns of the scanned table
//! NOTE(review): filter_to_col is taken by value, so the map is copied on every call - a const
//! reference would avoid the copy, but the definition has to change at the same time
static py::object TransformFilter(TableFilterSet &filters, std::unordered_map<idx_t, string> &columns,
unordered_map<idx_t, idx_t> filter_to_col,
const ClientProperties &client_properties, const ArrowTableType &arrow_table);

static py::object ProduceScanner(py::object &arrow_scanner, py::handle &arrow_obj_handle,
Expand Down
26 changes: 26 additions & 0 deletions tools/pythonpkg/tests/fast/arrow/test_filter_pushdown.py
Expand Up @@ -504,6 +504,32 @@ def test_filter_pushdown_integers(self, duckdb_cursor, data_type, value, create_
actual = duckdb_cursor.execute("select * from arrow_table where i = ?", (value,)).fetchall()
assert expected == actual

def test_9371(self, duckdb_cursor, tmp_path):
    """Regression test: timestamp filter pushdown on an Arrow dataset whose
    filtered column is the pandas index of the written parquet file.

    The filter on ``ts`` must be pushed down against the correct Arrow column
    and return all three rows.
    """
    import datetime

    # Pin the session time zone so the tz-aware timestamp compares predictably
    duckdb_cursor.execute("SET TimeZone='UTC';")

    base_path = tmp_path / "parquet_folder"
    base_path.mkdir(exist_ok=True)
    file_path = base_path / "test.parquet"

    # Example data
    dt = datetime.datetime(2023, 8, 29, 1, tzinfo=datetime.timezone.utc)

    my_arrow_table = pa.Table.from_pydict({'ts': [dt, dt, dt], 'value': [1, 2, 3]})
    df = my_arrow_table.to_pandas()
    df = df.set_index("ts")  # SET INDEX! (It all works correctly when the index is not set)
    df.to_parquet(str(file_path))

    # NOTE: `my_arrow_dataset` and `res` are referenced by name inside the SQL
    # strings below, so these local variable names must not be changed.
    my_arrow_dataset = ds.dataset(str(file_path))
    res = duckdb_cursor.execute("SELECT * FROM my_arrow_dataset WHERE ts = ?", parameters=[dt]).arrow()
    output = duckdb_cursor.sql("select * from res").fetchall()

    # The index column `ts` comes back after `value` in the scanned dataset
    expected = [(1, dt), (2, dt), (3, dt)]
    assert output == expected

@pytest.mark.parametrize('create_table', [create_pyarrow_pandas, create_pyarrow_table])
def test_filter_pushdown_date(self, duckdb_cursor, create_table):
duckdb_cursor.execute(
Expand Down