Add more MultiFileReader features/hooks #11984

Merged
merged 7 commits on May 14, 2024
Changes from 6 commits
2 changes: 1 addition & 1 deletion extension/json/json_functions/read_json.cpp
@@ -347,7 +347,7 @@ static void ReadJSONFunction(ClientContext &context, TableFunctionInput &data_p,
}

if (output.size() != 0) {
MultiFileReader().FinalizeChunk(context, gstate.bind_data.reader_bind, lstate.GetReaderData(), output);
MultiFileReader().FinalizeChunk(context, gstate.bind_data.reader_bind, lstate.GetReaderData(), output, nullptr);
}
}

2 changes: 1 addition & 1 deletion extension/json/json_functions/read_json_objects.cpp
@@ -47,7 +47,7 @@ static void ReadJSONObjectsFunction(ClientContext &context, TableFunctionInput &
output.SetCardinality(count);

if (output.size() != 0) {
MultiFileReader().FinalizeChunk(context, gstate.bind_data.reader_bind, lstate.GetReaderData(), output);
MultiFileReader().FinalizeChunk(context, gstate.bind_data.reader_bind, lstate.GetReaderData(), output, nullptr);
}
}

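Both JSON table functions use a default-constructed MultiFileReader and carry no custom global state, so they pass nullptr for the new trailing parameter. The actual declaration lives in the MultiFileReader header, which is not part of this diff; the following is only a sketch of its shape as implied by these call sites and by the definition in src/common/multi_file_reader.cpp further down (treating the method as virtual is an assumption):

// Sketch inferred from the call sites in this PR, not copied from the header.
// Scans without a custom MultiFileReaderGlobalState (JSON, CSV) simply pass nullptr.
virtual void FinalizeChunk(ClientContext &context, const MultiFileReaderBindData &bind_data,
                           const MultiFileReaderData &reader_data, DataChunk &chunk,
                           optional_ptr<MultiFileReaderGlobalState> global_state);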
2 changes: 1 addition & 1 deletion extension/json/json_scan.cpp
@@ -207,7 +207,7 @@ unique_ptr<GlobalTableFunctionState> JSONGlobalTableFunctionState::Init(ClientCo
for (auto &reader : gstate.json_readers) {
MultiFileReader().FinalizeBind(reader->GetOptions().file_options, gstate.bind_data.reader_bind,
reader->GetFileName(), gstate.names, dummy_types, bind_data.names,
input.column_ids, reader->reader_data, context);
input.column_ids, reader->reader_data, context, nullptr);
}

return std::move(result);
67 changes: 46 additions & 21 deletions extension/parquet/parquet_extension.cpp
@@ -78,7 +78,7 @@ struct ParquetReadLocalState : public LocalTableFunctionState {
bool is_parallel;
idx_t batch_index;
idx_t file_index;
//! The DataChunk containing all read columns (even filter columns that are immediately removed)
//! The DataChunk containing all read columns (even columns that are immediately removed)
DataChunk all_columns;
};

@@ -110,6 +110,8 @@ struct ParquetReadGlobalState : public GlobalTableFunctionState {
//! The scan over the file_list
MultiFileListScanData file_list_scan;

unique_ptr<MultiFileReaderGlobalState> multi_file_reader_state;

mutex lock;

//! The current set of parquet readers
@@ -135,7 +137,7 @@ struct ParquetReadGlobalState : public GlobalTableFunctionState {
return max_threads;
}

bool CanRemoveFilterColumns() const {
bool CanRemoveColumns() const {
return !projection_ids.empty();
}
};
@@ -211,6 +213,9 @@ static void ParseFileRowNumberOption(MultiFileReaderBindData &bind_data, Parquet
static MultiFileReaderBindData BindSchema(ClientContext &context, vector<LogicalType> &return_types,
vector<string> &names, ParquetReadBindData &result, ParquetOptions &options) {
D_ASSERT(!options.schema.empty());

options.file_options.AutoDetectHivePartitioning(*result.file_list, context);

auto &file_options = options.file_options;
if (file_options.union_by_name || file_options.hive_partitioning) {
throw BinderException("Parquet schema cannot be combined with union_by_name=true or hive_partitioning=true");
@@ -241,13 +246,18 @@ static MultiFileReaderBindData BindSchema(ClientContext &context, vector<Logical

static void InitializeParquetReader(ParquetReader &reader, const ParquetReadBindData &bind_data,
const vector<column_t> &global_column_ids,
optional_ptr<TableFilterSet> table_filters, ClientContext &context) {
optional_ptr<TableFilterSet> table_filters, ClientContext &context, idx_t file_idx,
optional_ptr<MultiFileReaderGlobalState> reader_state) {
auto &parquet_options = bind_data.parquet_options;
auto &reader_data = reader.reader_data;

// Mark the file in the file list we are scanning here
reader_data.file_list_idx = file_idx;

if (bind_data.parquet_options.schema.empty()) {
bind_data.multi_file_reader->InitializeReader(reader, parquet_options.file_options, bind_data.reader_bind,
bind_data.types, bind_data.names, global_column_ids,
table_filters, bind_data.file_list->GetFirstFile(), context);
bind_data.multi_file_reader->InitializeReader(
reader, parquet_options.file_options, bind_data.reader_bind, bind_data.types, bind_data.names,
global_column_ids, table_filters, bind_data.file_list->GetFirstFile(), context, reader_state);
return;
}

@@ -256,7 +266,7 @@ static void InitializeParquetReader(ParquetReader &reader, const ParquetReadBind
// this deals with hive partitioning and filename=true
bind_data.multi_file_reader->FinalizeBind(parquet_options.file_options, bind_data.reader_bind, reader.GetFileName(),
reader.GetNames(), bind_data.types, bind_data.names, global_column_ids,
reader_data, context);
reader_data, context, reader_state);

// create a mapping from field id to column index in file
unordered_map<uint32_t, idx_t> field_id_to_column_index;
@@ -315,7 +325,7 @@ static void InitializeParquetReader(ParquetReader &reader, const ParquetReadBind
reader_data.empty_columns = reader_data.column_ids.empty();

// Finally, initialize the filters
bind_data.multi_file_reader->CreateFilterMap(bind_data.types, table_filters, reader_data);
bind_data.multi_file_reader->CreateFilterMap(bind_data.types, table_filters, reader_data, reader_state);
reader_data.filters = table_filters;
}

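Recording reader_data.file_list_idx gives a MultiFileReader implementation a way to relate a reader, and every chunk it produces, back to the MultiFileList entry it was created for. A minimal sketch of a hypothetical override that consumes it; this assumes FinalizeChunk is virtual and is for illustration only:

class ExampleMultiFileReader : public MultiFileReader { // hypothetical subclass, not part of this PR
public:
	void FinalizeChunk(ClientContext &context, const MultiFileReaderBindData &bind_data,
	                   const MultiFileReaderData &reader_data, DataChunk &chunk,
	                   optional_ptr<MultiFileReaderGlobalState> global_state) override {
		// keep the default behaviour (constant columns, hive partitions, filename)
		MultiFileReader::FinalizeChunk(context, bind_data, reader_data, chunk, global_state);
		// reader_data.file_list_idx was set in InitializeParquetReader above; it identifies which
		// file-list entry produced this chunk, e.g. to look up per-file metadata for these rows
	}
};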
@@ -475,15 +485,18 @@ class ParquetScanFunction {
bool bound_on_first_file = true;
if (result->multi_file_reader->Bind(parquet_options.file_options, *result->file_list, result->types,
result->names, result->reader_bind)) {
// The MultiFileReader has performed a full bind
ParseFileRowNumberOption(result->reader_bind, parquet_options, result->types, result->names);
result->multi_file_reader->BindOptions(parquet_options.file_options, *result->file_list, result->types,
result->names, result->reader_bind);
// Enable the parquet file_row_number on the parquet options if the file_row_number_idx was set
if (result->reader_bind.file_row_number_idx != DConstants::INVALID_INDEX) {
parquet_options.file_row_number = true;
}
bound_on_first_file = false;
} else if (!parquet_options.schema.empty()) {
// A schema was suppliedParquetProgress
// A schema was supplied: use the schema for binding
result->reader_bind = BindSchema(context, result->types, result->names, *result, parquet_options);
} else {
parquet_options.file_options.AutoDetectHivePartitioning(*result->file_list, context);
// Default bind
result->reader_bind = result->multi_file_reader->BindReader<ParquetReader>(
context, result->types, result->names, *result->file_list, *result, parquet_options);
@@ -543,8 +556,6 @@ class ParquetScanFunction {
}

auto file_list = multi_file_reader->CreateFileList(context, input.inputs[0]);
parquet_options.file_options.AutoDetectHivePartitioning(*file_list, context);

return ParquetScanBindInternal(context, std::move(multi_file_reader), std::move(file_list), return_types, names,
parquet_options);
}
@@ -575,8 +586,7 @@ class ParquetScanFunction {
result->is_parallel = true;
result->batch_index = 0;

// TODO: needs lock?
if (input.CanRemoveFilterColumns()) {
if (gstate.CanRemoveColumns()) {
result->all_columns.Initialize(context.client, gstate.scanned_types);
}
if (!ParquetParallelStateNext(context.client, bind_data, *result, gstate)) {
@@ -591,6 +601,9 @@ class ParquetScanFunction {
auto result = make_uniq<ParquetReadGlobalState>();
bind_data.file_list->InitializeScan(result->file_list_scan);

result->multi_file_reader_state = bind_data.multi_file_reader->InitializeGlobalState(
context, bind_data.parquet_options.file_options, bind_data.reader_bind, *bind_data.file_list,
bind_data.types, bind_data.names, input.column_ids);
if (bind_data.file_list->IsEmpty()) {
result->readers = {};
} else if (!bind_data.union_readers.empty()) {
@@ -619,11 +632,13 @@ class ParquetScanFunction {
// Ensure all readers are initialized and FileListScan is sync with readers list
for (auto &reader_data : result->readers) {
string file_name;
idx_t file_idx = result->file_list_scan.current_file_idx;
bind_data.file_list->Scan(result->file_list_scan, file_name);
if (file_name != reader_data.reader->file_name) {
throw InternalException("Mismatch in filename order and reader order in parquet scan");
}
InitializeParquetReader(*reader_data.reader, bind_data, input.column_ids, input.filters, context);
InitializeParquetReader(*reader_data.reader, bind_data, input.column_ids, input.filters, context, file_idx,
result->multi_file_reader_state);
}

result->column_ids = input.column_ids;
@@ -632,7 +647,10 @@ class ParquetScanFunction {
result->file_index = 0;
result->batch_index = 0;
result->max_threads = ParquetScanMaxThreads(context, input.bind_data.get());
if (input.CanRemoveFilterColumns()) {

bool require_extra_columns =
result->multi_file_reader_state && result->multi_file_reader_state->RequiresExtraColumns();
if (input.CanRemoveFilterColumns() || require_extra_columns) {
result->projection_ids = input.projection_ids;
const auto table_types = bind_data.types;
for (const auto &col_idx : input.column_ids) {
@@ -643,6 +661,13 @@ class ParquetScanFunction {
}
}
}

if (require_extra_columns) {
for (const auto &column_type : result->multi_file_reader_state->extra_columns) {
result->scanned_types.push_back(column_type);
}
}

return std::move(result);
}

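When the multi-file reader global state requests extra columns, they are appended to scanned_types so the readers materialize them, while projection_ids keeps the output chunk limited to the user-visible columns. A worked illustration with hypothetical values, not taken from this PR:

// Table columns {a BIGINT, b VARCHAR}; both are scanned because a feeds a pushed-down filter,
// projection pushdown keeps only b, and the global state asks for one extra VARCHAR column:
//   scanned_types  = {BIGINT, VARCHAR, VARCHAR}   // every scanned column plus the extra column
//   projection_ids = {1}                          // only b is referenced into the output chunk
// The extra column is visible to FinalizeChunk via all_columns but never reaches the output.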
@@ -691,16 +716,16 @@ class ParquetScanFunction {
auto &bind_data = data_p.bind_data->CastNoConst<ParquetReadBindData>();

do {
if (gstate.CanRemoveFilterColumns()) {
if (gstate.CanRemoveColumns()) {
data.all_columns.Reset();
data.reader->Scan(data.scan_state, data.all_columns);
bind_data.multi_file_reader->FinalizeChunk(context, bind_data.reader_bind, data.reader->reader_data,
data.all_columns);
data.all_columns, gstate.multi_file_reader_state);
output.ReferenceColumns(data.all_columns, gstate.projection_ids);
} else {
data.reader->Scan(data.scan_state, output);
bind_data.multi_file_reader->FinalizeChunk(context, bind_data.reader_bind, data.reader->reader_data,
output);
output, gstate.multi_file_reader_state);
}

bind_data.chunk_count++;
@@ -858,7 +883,7 @@ class ParquetScanFunction {
try {
reader = make_shared_ptr<ParquetReader>(context, current_reader_data.file_to_be_opened, pq_options);
InitializeParquetReader(*reader, bind_data, parallel_state.column_ids, parallel_state.filters,
context);
context, i, parallel_state.multi_file_reader_state);
} catch (...) {
parallel_lock.lock();
parallel_state.error_opening_file = true;
35 changes: 27 additions & 8 deletions src/common/multi_file_reader.cpp
@@ -206,7 +206,7 @@ void MultiFileReader::FinalizeBind(const MultiFileReaderOptions &file_options, c
const string &filename, const vector<string> &local_names,
const vector<LogicalType> &global_types, const vector<string> &global_names,
const vector<column_t> &global_column_ids, MultiFileReaderData &reader_data,
ClientContext &context) {
ClientContext &context, optional_ptr<MultiFileReaderGlobalState> global_state) {

// create a map of name -> column index
case_insensitive_map_t<idx_t> name_map;
@@ -258,10 +258,20 @@ void MultiFileReader::FinalizeBind(const MultiFileReaderOptions &file_options, c
}
}

unique_ptr<MultiFileReaderGlobalState>
MultiFileReader::InitializeGlobalState(ClientContext &context, const MultiFileReaderOptions &file_options,
const MultiFileReaderBindData &bind_data, const MultiFileList &file_list,
const vector<LogicalType> &global_types, const vector<string> &global_names,
const vector<column_t> &global_column_ids) {
// By default, the multifilereader does not require any global state
return nullptr;
}

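This default, which returns nullptr, is the new hook point: a subclass can return its own state and, through extra_columns, ask the scanning function to materialize additional columns. A minimal sketch under stated assumptions (the hook is virtual, and MultiFileReaderGlobalState can be constructed from the extra column types plus an optional file list; its header is not part of this diff and all names below are hypothetical):

class ExampleMultiFileReader : public MultiFileReader { // hypothetical subclass, for illustration only
public:
	unique_ptr<MultiFileReaderGlobalState>
	InitializeGlobalState(ClientContext &context, const MultiFileReaderOptions &file_options,
	                      const MultiFileReaderBindData &bind_data, const MultiFileList &file_list,
	                      const vector<LogicalType> &global_types, const vector<string> &global_names,
	                      const vector<column_t> &global_column_ids) override {
		// request one extra BIGINT column; RequiresExtraColumns() then returns true and the
		// parquet scan above appends it to scanned_types
		vector<LogicalType> extra_columns {LogicalType::BIGINT};
		return make_uniq<MultiFileReaderGlobalState>(std::move(extra_columns), file_list);
	}
	// a matching FinalizeChunk override (see the earlier sketch) would fill the extra column in
	// every chunk before the projection drops it from the user-visible output
};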
void MultiFileReader::CreateNameMapping(const string &file_name, const vector<LogicalType> &local_types,
const vector<string> &local_names, const vector<LogicalType> &global_types,
const vector<string> &global_names, const vector<column_t> &global_column_ids,
MultiFileReaderData &reader_data, const string &initial_file) {
MultiFileReaderData &reader_data, const string &initial_file,
optional_ptr<MultiFileReaderGlobalState> global_state) {
D_ASSERT(global_types.size() == global_names.size());
D_ASSERT(local_types.size() == local_names.size());
// we have expected types: create a map of name -> column index
@@ -318,23 +328,31 @@ void MultiFileReader::CreateNameMapping(const string &file_name, const vector<Lo
reader_data.column_mapping.push_back(i);
reader_data.column_ids.push_back(local_id);
}

reader_data.empty_columns = reader_data.column_ids.empty();
}

void MultiFileReader::CreateMapping(const string &file_name, const vector<LogicalType> &local_types,
const vector<string> &local_names, const vector<LogicalType> &global_types,
const vector<string> &global_names, const vector<column_t> &global_column_ids,
optional_ptr<TableFilterSet> filters, MultiFileReaderData &reader_data,
const string &initial_file) {
const string &initial_file, const MultiFileReaderBindData &options,
optional_ptr<MultiFileReaderGlobalState> global_state) {
CreateNameMapping(file_name, local_types, local_names, global_types, global_names, global_column_ids, reader_data,
initial_file);
CreateFilterMap(global_types, filters, reader_data);
initial_file, global_state);
CreateFilterMap(global_types, filters, reader_data, global_state);
}

void MultiFileReader::CreateFilterMap(const vector<LogicalType> &global_types, optional_ptr<TableFilterSet> filters,
MultiFileReaderData &reader_data) {
MultiFileReaderData &reader_data,
optional_ptr<MultiFileReaderGlobalState> global_state) {
if (filters) {
reader_data.filter_map.resize(global_types.size());
auto filter_map_size = global_types.size();
if (global_state) {
filter_map_size += global_state->extra_columns.size();
}
reader_data.filter_map.resize(filter_map_size);

for (idx_t c = 0; c < reader_data.column_mapping.size(); c++) {
auto map_index = reader_data.column_mapping[c];
reader_data.filter_map[map_index].index = c;
@@ -349,7 +367,8 @@
}

void MultiFileReader::FinalizeChunk(ClientContext &context, const MultiFileReaderBindData &bind_data,
const MultiFileReaderData &reader_data, DataChunk &chunk) {
const MultiFileReaderData &reader_data, DataChunk &chunk,
optional_ptr<MultiFileReaderGlobalState> global_state) {
// reference all the constants set up in MultiFileReader::FinalizeBind
for (auto &entry : reader_data.constant_map) {
chunk.data[entry.column_id].Reference(entry.value);
@@ -20,23 +20,23 @@ CSVFileScan::CSVFileScan(ClientContext &context, shared_ptr<CSVBufferManager> bu
options = union_reader.options;
types = union_reader.GetTypes();
multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types,
bind_data.return_names, column_ids, nullptr, file_path, context);
bind_data.return_names, column_ids, nullptr, file_path, context, nullptr);
InitializeFileNamesTypes();
return;
} else if (!bind_data.column_info.empty()) {
// Serialized Union By name
names = bind_data.column_info[0].names;
types = bind_data.column_info[0].types;
multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types,
bind_data.return_names, column_ids, nullptr, file_path, context);
bind_data.return_names, column_ids, nullptr, file_path, context, nullptr);
InitializeFileNamesTypes();
return;
}
names = bind_data.return_names;
types = bind_data.return_types;
file_schema = bind_data.return_types;
multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types,
bind_data.return_names, column_ids, nullptr, file_path, context);
bind_data.return_names, column_ids, nullptr, file_path, context, nullptr);

InitializeFileNamesTypes();
}
@@ -68,7 +68,7 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, cons
state_machine = union_reader.state_machine;
multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind,
bind_data.return_types, bind_data.return_names, column_ids, nullptr,
file_path, context);
file_path, context, nullptr);

InitializeFileNamesTypes();
return;
@@ -96,7 +96,7 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, cons
state_machine_cache.Get(options.dialect_options.state_machine_options), options);

multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types,
bind_data.return_names, column_ids, nullptr, file_path, context);
bind_data.return_names, column_ids, nullptr, file_path, context, nullptr);
InitializeFileNamesTypes();
return;
}
@@ -127,7 +127,7 @@ CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, cons
state_machine_cache.Get(options.dialect_options.state_machine_options), options);

multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types,
bind_data.return_names, column_ids, nullptr, file_path, context);
bind_data.return_names, column_ids, nullptr, file_path, context, nullptr);
InitializeFileNamesTypes();
}

2 changes: 1 addition & 1 deletion src/function/table/read_csv.cpp
@@ -224,7 +224,7 @@ static void ReadCSVFunction(ClientContext &context, TableFunctionInput &data_p,
do {
if (output.size() != 0) {
MultiFileReader().FinalizeChunk(context, bind_data.reader_bind,
csv_local_state.csv_reader->csv_file_scan->reader_data, output);
csv_local_state.csv_reader->csv_file_scan->reader_data, output, nullptr);
break;
}
if (csv_local_state.csv_reader->FinishedIterator()) {