Skip to content

Commit

Permalink
Unify csv and rdf reader config
Browse files Browse the repository at this point in the history
  • Loading branch information
andyfengHKU committed Nov 30, 2023
1 parent ff5e450 commit c2acb9c
Show file tree
Hide file tree
Showing 32 changed files with 340 additions and 343 deletions.
12 changes: 6 additions & 6 deletions src/binder/bind/bind_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ std::unique_ptr<BoundStatement> Binder::bindCopyToClause(const Statement& statem
if (fileType != FileType::CSV && copyToStatement.getParsingOptionsRef().size() != 0) {
throw BinderException{"Only copy to csv can have options."};
}
auto csvOption = bindParsingOptions(copyToStatement.getParsingOptionsRef());
auto csvConfig = bindParsingOptions(copyToStatement.getParsingOptionsRef());
return std::make_unique<BoundCopyTo>(boundFilePath, fileType, std::move(columnNames),
std::move(columnTypes), std::move(query), std::move(csvOption));
std::move(columnTypes), std::move(query), std::move(csvConfig->option.copy()));
}

// As a temporary constraint, we require npy files loaded with COPY FROM BY COLUMN keyword.
Expand Down Expand Up @@ -77,11 +77,11 @@ std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& stat
default:
break;
}
auto csvReaderConfig = bindParsingOptions(copyStatement.getParsingOptionsRef());
auto csvConfig = bindParsingOptions(copyStatement.getParsingOptionsRef());
auto filePaths = bindFilePaths(copyStatement.getFilePaths());
auto fileType = bindFileType(filePaths);
auto readerConfig =
std::make_unique<ReaderConfig>(fileType, std::move(filePaths), std::move(csvReaderConfig));
std::make_unique<ReaderConfig>(fileType, std::move(filePaths), std::move(csvConfig));
validateByColumnKeyword(readerConfig->fileType, copyStatement.byColumn());
if (readerConfig->fileType == FileType::NPY) {
validateCopyNpyNotForRelTables(tableSchema);
Expand Down Expand Up @@ -111,7 +111,7 @@ std::unique_ptr<BoundStatement> Binder::bindCopyFromClause(const Statement& stat
std::unique_ptr<BoundStatement> Binder::bindCopyNodeFrom(const Statement& statement,
std::unique_ptr<common::ReaderConfig> config, TableSchema* tableSchema) {
auto& copyStatement = reinterpret_cast<const CopyFrom&>(statement);
auto func = getScanFunction(config->fileType, config->csvReaderConfig->parallel);
auto func = getScanFunction(config->fileType, *config);
// For table with SERIAL columns, we need to read in serial from files.
auto containsSerial = tableSchema->containsColumnType(LogicalType(LogicalTypeID::SERIAL));
std::vector<std::string> expectedColumnNames;
Expand All @@ -137,7 +137,7 @@ std::unique_ptr<BoundStatement> Binder::bindCopyNodeFrom(const Statement& statem
std::unique_ptr<BoundStatement> Binder::bindCopyRelFrom(const parser::Statement& statement,
std::unique_ptr<common::ReaderConfig> config, TableSchema* tableSchema) {
auto& copyStatement = reinterpret_cast<const CopyFrom&>(statement);
auto func = getScanFunction(config->fileType, config->csvReaderConfig->parallel);
auto func = getScanFunction(config->fileType, *config);
// For table with SERIAL columns, we need to read in serial from files.
auto containsSerial = tableSchema->containsColumnType(LogicalType(LogicalTypeID::SERIAL));
KU_ASSERT(containsSerial == false);
Expand Down
15 changes: 8 additions & 7 deletions src/binder/bind/bind_file_scan.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "binder/binder.h"
#include "binder/expression/literal_expression.h"
#include "common/copier_config/csv_reader_config.h"
#include "common/exception/binder.h"
#include "common/exception/copy.h"
#include "common/string_format.h"
Expand Down Expand Up @@ -63,23 +64,23 @@ static char bindParsingOptionValue(std::string value) {
}

static void bindBoolParsingOption(
CSVReaderConfig& csvReaderConfig, const std::string& optionName, bool optionValue) {
CSVReaderConfig& config, const std::string& optionName, bool optionValue) {
if (optionName == "HEADER") {
csvReaderConfig.hasHeader = optionValue;
config.option.hasHeader = optionValue;
} else if (optionName == "PARALLEL") {
csvReaderConfig.parallel = optionValue;
config.parallel = optionValue;
}
}

static void bindStringParsingOption(
CSVReaderConfig& csvReaderConfig, const std::string& optionName, std::string& optionValue) {
CSVReaderConfig& config, const std::string& optionName, std::string& optionValue) {
auto parsingOptionValue = bindParsingOptionValue(optionValue);
if (optionName == "ESCAPE") {
csvReaderConfig.escapeChar = parsingOptionValue;
config.option.escapeChar = parsingOptionValue;
} else if (optionName == "DELIM") {
csvReaderConfig.delimiter = parsingOptionValue;
config.option.delimiter = parsingOptionValue;
} else if (optionName == "QUOTE") {
csvReaderConfig.quoteChar = parsingOptionValue;
config.option.quoteChar = parsingOptionValue;
}
}

Expand Down
3 changes: 1 addition & 2 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,7 @@ std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(
expectedColumnNames.push_back(name);
expectedColumnTypes.push_back(bindDataType(type));
}
auto scanFunction =
getScanFunction(readerConfig->fileType, readerConfig->csvReaderConfig->parallel);
auto scanFunction = getScanFunction(readerConfig->fileType, *readerConfig);
auto bindInput = std::make_unique<function::ScanTableFuncBindInput>(memoryManager,
*readerConfig, std::move(expectedColumnNames), std::move(expectedColumnTypes));
auto bindData = scanFunction->bindFunc(clientContext, bindInput.get(), (Catalog*)&catalog);
Expand Down
12 changes: 6 additions & 6 deletions src/binder/bind/copy/bind_copy_rdf_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace binder {

std::unique_ptr<BoundStatement> Binder::bindCopyRdfNodeFrom(const Statement& /*statement*/,
std::unique_ptr<ReaderConfig> config, TableSchema* tableSchema) {
auto func = getScanFunction(config->fileType, config->csvReaderConfig->parallel);
auto func = getScanFunction(config->fileType, *config);
bool containsSerial;
auto stringType = LogicalType{LogicalTypeID::STRING};
std::vector<std::string> columnNames;
Expand All @@ -25,13 +25,13 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfNodeFrom(const Statement& /*s
if (tableSchema->tableName.ends_with(rdf::RESOURCE_TABLE_SUFFIX)) {
containsSerial = false;
columnTypes.push_back(stringType.copy());
config->rdfReaderConfig =
config->extraConfig =
std::make_unique<RdfReaderConfig>(RdfReaderMode::RESOURCE, nullptr /* index */);
} else {
KU_ASSERT(tableSchema->tableName.ends_with(rdf::LITERAL_TABLE_SUFFIX));
containsSerial = true;
columnTypes.push_back(RdfVariantType::getType());
config->rdfReaderConfig =
config->extraConfig =
std::make_unique<RdfReaderConfig>(RdfReaderMode::LITERAL, nullptr /* index */);
}
auto bindInput = std::make_unique<function::ScanTableFuncBindInput>(
Expand All @@ -52,7 +52,7 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfNodeFrom(const Statement& /*s

std::unique_ptr<BoundStatement> Binder::bindCopyRdfRelFrom(const Statement& /*statement*/,
std::unique_ptr<ReaderConfig> config, TableSchema* tableSchema) {
auto func = getScanFunction(config->fileType, config->csvReaderConfig->parallel);
auto func = getScanFunction(config->fileType, *config);
auto containsSerial = false;
std::vector<std::string> columnNames;
columnNames.emplace_back(InternalKeyword::SRC_OFFSET);
Expand All @@ -67,10 +67,10 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfRelFrom(const Statement& /*st
auto resourceTableID = relTableSchema->getSrcTableID();
auto index = storageManager->getPKIndex(resourceTableID);
if (tableSchema->tableName.ends_with(rdf::RESOURCE_TRIPLE_TABLE_SUFFIX)) {
config->rdfReaderConfig =
config->extraConfig =
std::make_unique<RdfReaderConfig>(RdfReaderMode::RESOURCE_TRIPLE, index);
} else {
config->rdfReaderConfig =
config->extraConfig =
std::make_unique<RdfReaderConfig>(RdfReaderMode::LITERAL_TRIPLE, index);
}
auto bindInput = std::make_unique<function::ScanTableFuncBindInput>(
Expand Down
17 changes: 10 additions & 7 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "binder/binder.h"

#include "binder/bound_statement_rewriter.h"
#include "common/copier_config/csv_reader_config.h"
#include "common/exception/binder.h"
#include "common/string_format.h"
#include "function/table_functions.h"
Expand Down Expand Up @@ -209,25 +210,27 @@ void Binder::restoreScope(std::unique_ptr<BinderScope> prevVariableScope) {
scope = std::move(prevVariableScope);
}

function::TableFunction* Binder::getScanFunction(common::FileType fileType, bool isParallel) {
function::TableFunction* Binder::getScanFunction(FileType fileType, const ReaderConfig& config) {
function::Function* func;
auto stringType = LogicalType(LogicalTypeID::STRING);
std::vector<LogicalType*> inputTypes;
inputTypes.push_back(&stringType);
auto functions = catalog.getBuiltInFunctions();
switch (fileType) {
case common::FileType::PARQUET: {
func =
catalog.getBuiltInFunctions()->matchScalarFunction(READ_PARQUET_FUNC_NAME, inputTypes);
func = functions->matchScalarFunction(READ_PARQUET_FUNC_NAME, inputTypes);
} break;
case common::FileType::NPY: {
func = catalog.getBuiltInFunctions()->matchScalarFunction(READ_NPY_FUNC_NAME, inputTypes);
func = functions->matchScalarFunction(READ_NPY_FUNC_NAME, inputTypes);
} break;
case common::FileType::CSV: {
func = catalog.getBuiltInFunctions()->matchScalarFunction(
isParallel ? READ_CSV_PARALLEL_FUNC_NAME : READ_CSV_SERIAL_FUNC_NAME, inputTypes);
auto csvConfig = reinterpret_cast<CSVReaderConfig*>(config.extraConfig.get());
func = functions->matchScalarFunction(
csvConfig->parallel ? READ_CSV_PARALLEL_FUNC_NAME : READ_CSV_SERIAL_FUNC_NAME,
inputTypes);
} break;
case common::FileType::TURTLE: {
func = catalog.getBuiltInFunctions()->matchScalarFunction(READ_RDF_FUNC_NAME, inputTypes);
func = functions->matchScalarFunction(READ_RDF_FUNC_NAME, inputTypes);
} break;
default:
KU_UNREACHABLE;
Expand Down
2 changes: 1 addition & 1 deletion src/common/copier_config/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
add_library(kuzu_common_copier_config
OBJECT
copier_config.cpp)
reader_config.cpp)

set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:kuzu_common_copier_config>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "common/copier_config/copier_config.h"
#include "common/copier_config/reader_config.h"

#include "common/assert.h"
#include "common/exception/copy.h"
Expand Down
15 changes: 6 additions & 9 deletions src/function/cast/cast_fixed_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,21 +166,20 @@ void CastFixedList::stringtoFixedListCastExecFunction<UnaryFunctionExecutor>(
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
KU_ASSERT(params.size() == 1);
const auto& param = params[0];
auto csvReaderConfig = &reinterpret_cast<CastFunctionBindData*>(dataPtr)->csvConfig;
auto option = &reinterpret_cast<CastFunctionBindData*>(dataPtr)->csvConfig.option;
if (param->state->isFlat()) {
auto inputPos = param->state->selVector->selectedPositions[0];
auto resultPos = result.state->selVector->selectedPositions[0];
result.setNull(resultPos, param->isNull(inputPos));
if (!result.isNull(inputPos)) {
CastString::castToFixedList(
param->getValue<ku_string_t>(inputPos), &result, resultPos, csvReaderConfig);
param->getValue<ku_string_t>(inputPos), &result, resultPos, option);
}
} else if (param->state->selVector->isUnfiltered()) {
for (auto i = 0u; i < param->state->selVector->selectedSize; i++) {
result.setNull(i, param->isNull(i));
if (!result.isNull(i)) {
CastString::castToFixedList(
param->getValue<ku_string_t>(i), &result, i, csvReaderConfig);
CastString::castToFixedList(param->getValue<ku_string_t>(i), &result, i, option);
}
}
} else {
Expand All @@ -189,7 +188,7 @@ void CastFixedList::stringtoFixedListCastExecFunction<UnaryFunctionExecutor>(
result.setNull(pos, param->isNull(pos));
if (!result.isNull(pos)) {
CastString::castToFixedList(
param->getValue<ku_string_t>(pos), &result, pos, csvReaderConfig);
param->getValue<ku_string_t>(pos), &result, pos, option);
}
}
}
Expand All @@ -209,14 +208,12 @@ void CastFixedList::stringtoFixedListCastExecFunction<CastChildFunctionExecutor>
const std::vector<std::shared_ptr<ValueVector>>& params, ValueVector& result, void* dataPtr) {
KU_ASSERT(params.size() == 1);
auto numOfEntries = reinterpret_cast<CastFunctionBindData*>(dataPtr)->numOfEntries;
auto csvReaderConfig = &reinterpret_cast<CastFunctionBindData*>(dataPtr)->csvConfig;

auto option = &reinterpret_cast<CastFunctionBindData*>(dataPtr)->csvConfig.option;
auto inputVector = params[0].get();
for (auto i = 0u; i < numOfEntries; i++) {
result.setNull(i, inputVector->isNull(i));
if (!result.isNull(i)) {
CastString::castToFixedList(
inputVector->getValue<ku_string_t>(i), &result, i, csvReaderConfig);
CastString::castToFixedList(inputVector->getValue<ku_string_t>(i), &result, i, option);
}
}
}
Expand Down
Loading

0 comments on commit c2acb9c

Please sign in to comment.