Skip to content

Commit

Permalink
Merge pull request #3067 from kuzudb/table-function-copy
Browse files: browse the repository at this point in the history
Copy table function instead of passing raw pointer
  • Loading branch information
andyfengHKU authored Mar 16, 2024
2 parents c3556e2 + 28bd03b commit a612c0f
Show file tree
Hide file tree
Showing 15 changed files with 39 additions and 36 deletions.
2 changes: 1 addition & 1 deletion src/binder/bind/bind_file_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ std::unique_ptr<BoundBaseScanSource> Binder::bindScanSource(BaseScanSource* sour
auto func = getScanFunction(config->fileType, *config);
auto bindInput = std::make_unique<ScanTableFuncBindInput>(config->copy(),
std::move(expectedColumnNames), std::move(expectedColumnTypes), clientContext);
auto bindData = func->bindFunc(clientContext, bindInput.get());
auto bindData = func.bindFunc(clientContext, bindInput.get());
// Bind input columns
expression_vector inputColumns;
for (auto i = 0u; i < bindData->columnTypes.size(); i++) {
Expand Down
8 changes: 4 additions & 4 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
auto offset = expressionBinder.createVariableExpression(
*LogicalType::INT64(), std::string(InternalKeyword::ROW_OFFSET));
auto boundInQueryCall = std::make_unique<BoundInQueryCall>(
tableFunc, std::move(bindData), std::move(columns), offset);
*tableFunc, std::move(bindData), std::move(columns), offset);
if (call.hasWherePredicate()) {
auto wherePredicate = expressionBinder.bindExpression(*call.getWherePredicate());
boundInQueryCall->setPredicate(std::move(wherePredicate));
Expand All @@ -153,7 +153,7 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&

std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(const ReadingClause& readingClause) {
auto& loadFrom = ku_dynamic_cast<const ReadingClause&, const LoadFrom&>(readingClause);
function::TableFunction* scanFunction;
function::TableFunction scanFunction;
std::unique_ptr<TableFuncBindInput> bindInput;
auto source = loadFrom.getSource();
switch (source->type) {
Expand All @@ -167,7 +167,7 @@ std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(const ReadingClause& re
auto functions = clientContext->getCatalog()->getFunctions(clientContext->getTx());
auto func = BuiltInFunctionsUtils::matchFunction(READ_PANDAS_FUNC_NAME,
std::vector<LogicalType>{objectExpr->getDataType()}, functions);
scanFunction = ku_dynamic_cast<Function*, TableFunction*>(func);
scanFunction = *ku_dynamic_cast<Function*, TableFunction*>(func);
bindInput = std::make_unique<function::TableFuncBindInput>();
bindInput->inputs.push_back(*literalExpr->getValue());
} else {
Expand Down Expand Up @@ -219,7 +219,7 @@ std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(const ReadingClause& re
default:
throw BinderException(stringFormat("LOAD FROM subquery is not supported."));
}
auto bindData = scanFunction->bindFunc(clientContext, bindInput.get());
auto bindData = scanFunction.bindFunc(clientContext, bindInput.get());
expression_vector columns;
for (auto i = 0u; i < bindData->columnTypes.size(); i++) {
columns.push_back(createVariable(bindData->columnNames[i], bindData->columnTypes[i]));
Expand Down
10 changes: 5 additions & 5 deletions src/binder/bind/copy/bind_copy_rdf_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(
BuiltInFunctionsUtils::matchFunction(READ_RDF_RESOURCE_FUNC_NAME, functions);
auto rScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto rColumns = expression_vector{r};
auto rFileScanInfo = BoundFileScanInfo(rScanFunc, bindData->copy(), std::move(rColumns));
auto rFileScanInfo = BoundFileScanInfo(*rScanFunc, bindData->copy(), std::move(rColumns));
auto rSource = std::make_unique<BoundFileScanSource>(std::move(rFileScanInfo));
auto rTableID = rdfGraphEntry->getResourceTableID();
auto rEntry = catalog->getTableCatalogEntry(clientContext->getTx(), rTableID);
Expand All @@ -68,7 +68,7 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(
BuiltInFunctionsUtils::matchFunction(READ_RDF_LITERAL_FUNC_NAME, functions);
auto lScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto lColumns = expression_vector{l, lang};
auto lFileScanInfo = BoundFileScanInfo(lScanFunc, bindData->copy(), std::move(lColumns));
auto lFileScanInfo = BoundFileScanInfo(*lScanFunc, bindData->copy(), std::move(lColumns));
auto lSource = std::make_unique<BoundFileScanSource>(std::move(lFileScanInfo));
auto lTableID = rdfGraphEntry->getLiteralTableID();
auto lEntry = catalog->getTableCatalogEntry(clientContext->getTx(), lTableID);
Expand All @@ -80,7 +80,7 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(
BuiltInFunctionsUtils::matchFunction(READ_RDF_RESOURCE_TRIPLE_FUNC_NAME, functions);
auto rrrScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto rrrColumns = expression_vector{s, p, o};
auto rrrFileScanInfo = BoundFileScanInfo(rrrScanFunc, bindData->copy(), rrrColumns);
auto rrrFileScanInfo = BoundFileScanInfo(*rrrScanFunc, bindData->copy(), rrrColumns);
auto rrrSource = std::make_unique<BoundFileScanSource>(std::move(rrrFileScanInfo));
auto rrrTableID = rdfGraphEntry->getResourceTripleTableID();
auto rrrEntry = catalog->getTableCatalogEntry(clientContext->getTx(), rrrTableID);
Expand All @@ -102,7 +102,7 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(
BuiltInFunctionsUtils::matchFunction(READ_RDF_LITERAL_TRIPLE_FUNC_NAME, functions);
auto rrlScanFunc = ku_dynamic_cast<Function*, TableFunction*>(func);
auto rrlColumns = expression_vector{s, p, oOffset};
auto rrlFileScanInfo = BoundFileScanInfo(rrlScanFunc, bindData->copy(), rrlColumns);
auto rrlFileScanInfo = BoundFileScanInfo(*rrlScanFunc, bindData->copy(), rrlColumns);
auto rrlSource = std::make_unique<BoundFileScanSource>(std::move(rrlFileScanInfo));
auto rrlTableID = rdfGraphEntry->getLiteralTripleTableID();
auto rrlEntry = catalog->getTableCatalogEntry(clientContext->getTx(), rrlTableID);
Expand All @@ -119,7 +119,7 @@ std::unique_ptr<BoundStatement> Binder::bindCopyRdfFrom(
std::move(rCopyInfo), std::move(lCopyInfo), std::move(rrrCopyInfo), std::move(rrLCopyInfo));
std::unique_ptr<BoundBaseScanSource> source;
if (inMemory) {
auto fileScanInfo = BoundFileScanInfo(scanFunc, bindData->copy(), expression_vector{});
auto fileScanInfo = BoundFileScanInfo(*scanFunc, bindData->copy(), expression_vector{});
source = std::make_unique<BoundFileScanSource>(std::move(fileScanInfo));
} else {
source = std::make_unique<BoundEmptyScanSource>();
Expand Down
4 changes: 2 additions & 2 deletions src/binder/binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ void Binder::restoreScope(std::unique_ptr<BinderScope> prevVariableScope) {
scope = std::move(prevVariableScope);
}

function::TableFunction* Binder::getScanFunction(FileType fileType, const ReaderConfig& config) {
function::TableFunction Binder::getScanFunction(FileType fileType, const ReaderConfig& config) {
function::Function* func;
auto stringType = LogicalType(LogicalTypeID::STRING);
std::vector<LogicalType> inputTypes;
Expand All @@ -240,7 +240,7 @@ function::TableFunction* Binder::getScanFunction(FileType fileType, const Reader
default:
KU_UNREACHABLE;
}
return ku_dynamic_cast<function::Function*, function::TableFunction*>(func);
return *ku_dynamic_cast<function::Function*, function::TableFunction*>(func);
}

} // namespace binder
Expand Down
2 changes: 1 addition & 1 deletion src/include/binder/binder.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ class Binder {
std::unique_ptr<BinderScope> saveScope();
void restoreScope(std::unique_ptr<BinderScope> prevVariableScope);

function::TableFunction* getScanFunction(
function::TableFunction getScanFunction(
common::FileType fileType, const common::ReaderConfig& config);

private:
Expand Down
8 changes: 4 additions & 4 deletions src/include/binder/copy/bound_file_scan_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ namespace kuzu {
namespace binder {

struct BoundFileScanInfo {
function::TableFunction* copyFunc;
function::TableFunction func;
std::unique_ptr<function::TableFuncBindData> bindData;
binder::expression_vector columns;

BoundFileScanInfo(function::TableFunction* copyFunc,
BoundFileScanInfo(function::TableFunction func,
std::unique_ptr<function::TableFuncBindData> bindData, binder::expression_vector columns)
: copyFunc{copyFunc}, bindData{std::move(bindData)}, columns{std::move(columns)} {}
: func{func}, bindData{std::move(bindData)}, columns{std::move(columns)} {}
EXPLICIT_COPY_DEFAULT_MOVE(BoundFileScanInfo);

private:
BoundFileScanInfo(const BoundFileScanInfo& other)
: copyFunc{other.copyFunc}, bindData{other.bindData->copy()}, columns{other.columns} {}
: func{other.func}, bindData{other.bindData->copy()}, columns{other.columns} {}
};

} // namespace binder
Expand Down
12 changes: 6 additions & 6 deletions src/include/binder/query/reading_clause/bound_in_query_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,23 @@ namespace binder {

class BoundInQueryCall : public BoundReadingClause {
public:
BoundInQueryCall(function::TableFunction* tableFunc,
BoundInQueryCall(function::TableFunction tableFunc,
std::unique_ptr<function::TableFuncBindData> bindData, expression_vector outExprs,
std::shared_ptr<Expression> rowIdxExpr)
: BoundReadingClause{common::ClauseType::IN_QUERY_CALL}, tableFunc{tableFunc},
bindData{std::move(bindData)}, outExprs{std::move(outExprs)}, rowIdxExpr{std::move(
rowIdxExpr)} {}

inline function::TableFunction* getTableFunc() const { return tableFunc; }
function::TableFunction getTableFunc() const { return tableFunc; }

inline const function::TableFuncBindData* getBindData() const { return bindData.get(); }
const function::TableFuncBindData* getBindData() const { return bindData.get(); }

inline expression_vector getOutExprs() const { return outExprs; }
expression_vector getOutExprs() const { return outExprs; }

inline std::shared_ptr<Expression> getRowIdxExpr() const { return rowIdxExpr; }
std::shared_ptr<Expression> getRowIdxExpr() const { return rowIdxExpr; }

private:
function::TableFunction* tableFunc;
function::TableFunction tableFunc;
std::unique_ptr<function::TableFuncBindData> bindData;
expression_vector outExprs;
std::shared_ptr<Expression> rowIdxExpr;
Expand Down
3 changes: 2 additions & 1 deletion src/include/function/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ struct Function;
using scalar_bind_func = std::function<std::unique_ptr<FunctionBindData>(
const binder::expression_vector&, Function* definition)>;

enum class FunctionType : uint8_t { SCALAR, AGGREGATE, TABLE };
enum class FunctionType : uint8_t { UNKNOWN = 0, SCALAR = 1, AGGREGATE = 2, TABLE = 3 };

struct Function {
Function() : type{FunctionType::UNKNOWN} {};
Function(
FunctionType type, std::string name, std::vector<common::LogicalTypeID> parameterTypeIDs)
: type{type}, name{std::move(name)}, parameterTypeIDs{std::move(parameterTypeIDs)} {}
Expand Down
3 changes: 3 additions & 0 deletions src/include/function/table_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ struct TableFunction : public Function {
table_func_init_local_t initLocalStateFunc;
table_func_can_parallel_t canParallelFunc = [] { return true; };

TableFunction()
: Function{}, tableFunc{nullptr}, bindFunc{nullptr}, initSharedStateFunc{nullptr},
initLocalStateFunc{nullptr} {};
TableFunction(std::string name, table_func_t tableFunc, table_func_bind_t bindFunc,
table_func_init_shared_t initSharedFunc, table_func_init_local_t initLocalFunc,
std::vector<common::LogicalTypeID> inputTypes)
Expand Down
3 changes: 1 addition & 2 deletions src/include/main/attached_database.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@ namespace kuzu {
namespace main {

class AttachedDatabase {

public:
AttachedDatabase(std::string dbName, function::TableFunction scanFunction)
: dbName{std::move(dbName)}, scanFunction{std::move(scanFunction)} {}

std::string getDBName() { return dbName; }

function::TableFunction* getScanFunction() { return &scanFunction; }
function::TableFunction getScanFunction() { return scanFunction; }

private:
std::string dbName;
Expand Down
6 changes: 3 additions & 3 deletions src/include/planner/operator/logical_in_query_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ namespace planner {

class LogicalInQueryCall : public LogicalOperator {
public:
LogicalInQueryCall(function::TableFunction* tableFunc,
LogicalInQueryCall(function::TableFunction tableFunc,
std::unique_ptr<function::TableFuncBindData> bindData,
binder::expression_vector outputExpressions,
std::shared_ptr<binder::Expression> rowIDExpression)
: LogicalOperator{LogicalOperatorType::IN_QUERY_CALL}, tableFunc{tableFunc},
bindData{std::move(bindData)}, outputExpressions{std::move(outputExpressions)},
rowIDExpression{std::move(rowIDExpression)} {}

inline function::TableFunction* getTableFunc() const { return tableFunc; }
inline function::TableFunction getTableFunc() const { return tableFunc; }

inline function::TableFuncBindData* getBindData() const { return bindData.get(); }

Expand All @@ -37,7 +37,7 @@ class LogicalInQueryCall : public LogicalOperator {
}

private:
function::TableFunction* tableFunc;
function::TableFunction tableFunc;
std::unique_ptr<function::TableFuncBindData> bindData;
binder::expression_vector outputExpressions;
std::shared_ptr<binder::Expression> rowIDExpression;
Expand Down
4 changes: 2 additions & 2 deletions src/include/processor/operator/call/in_query_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ enum class TableScanOutputType : uint8_t {
};

struct InQueryCallInfo {
function::TableFunction* function;
function::TableFunction function;
std::unique_ptr<function::TableFuncBindData> bindData;
std::vector<DataPos> outPosV;
DataPos rowOffsetPos;
Expand Down Expand Up @@ -66,7 +66,7 @@ class InQueryCall : public PhysicalOperator {

bool isSource() const override { return true; }

bool canParallel() const override { return info.function->canParallelFunc(); }
bool canParallel() const override { return info.function.canParallelFunc(); }

void initLocalStateInternal(ResultSet* resultSet, ExecutionContext* context) override;

Expand Down
2 changes: 1 addition & 1 deletion src/processor/map/create_factorized_table_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ std::unique_ptr<PhysicalOperator> PlanMapper::createFTableScan(const expression_
auto function = function::BuiltInFunctionsUtils::matchFunction(
READ_FTABLE_FUNC_NAME, catalog->getFunctions(clientContext->getTx()));
auto info = InQueryCallInfo();
info.function = ku_dynamic_cast<Function*, TableFunction*>(function);
info.function = *ku_dynamic_cast<Function*, TableFunction*>(function);
info.bindData = std::move(bindData);
info.outPosV = std::move(outPosV);
if (offset != nullptr) {
Expand Down
2 changes: 1 addition & 1 deletion src/processor/map/map_scan_file.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ std::unique_ptr<PhysicalOperator> PlanMapper::mapScanFile(LogicalOperator* logic
outPosV.emplace_back(getDataPos(*expr, *outSchema));
}
auto info = InQueryCallInfo();
info.function = scanFileInfo->copyFunc;
info.function = scanFileInfo->func;
info.bindData = scanFileInfo->bindData->copy();
info.outPosV = outPosV;
if (scanFile->hasOffset()) {
Expand Down
6 changes: 3 additions & 3 deletions src/processor/operator/call/in_query_call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,21 @@ void InQueryCall::initLocalStateInternal(ResultSet* resultSet, ExecutionContext*
}
// Init table function input.
function::TableFunctionInitInput tableFunctionInitInput{info.bindData.get()};
localState.funcState = info.function->initLocalStateFunc(tableFunctionInitInput,
localState.funcState = info.function.initLocalStateFunc(tableFunctionInitInput,
sharedState->funcState.get(), context->clientContext->getMemoryManager());
localState.funcInput = function::TableFuncInput{
info.bindData.get(), localState.funcState.get(), sharedState->funcState.get()};
}

void InQueryCall::initGlobalStateInternal(ExecutionContext*) {
function::TableFunctionInitInput tableFunctionInitInput{info.bindData.get()};
sharedState->funcState = info.function->initSharedStateFunc(tableFunctionInitInput);
sharedState->funcState = info.function.initSharedStateFunc(tableFunctionInitInput);
}

bool InQueryCall::getNextTuplesInternal(ExecutionContext*) {
localState.funcOutput.dataChunk.state->selVector->selectedSize = 0;
localState.funcOutput.dataChunk.resetAuxiliaryBuffer();
auto numTuplesScanned = info.function->tableFunc(localState.funcInput, localState.funcOutput);
auto numTuplesScanned = info.function.tableFunc(localState.funcInput, localState.funcOutput);
localState.funcOutput.dataChunk.state->selVector->selectedSize = numTuplesScanned;
if (localState.rowOffsetVector != nullptr) {
auto rowIdx = sharedState->getAndIncreaseRowIdx(numTuplesScanned);
Expand Down

0 comments on commit a612c0f

Please sign in to comment.