Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement scan pandas #2403

Merged
merged 1 commit into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/antlr4/Cypher.g4
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ LOAD : ( 'L' | 'l' ) ( 'O' | 'o' ) ( 'A' | 'a' ) ( 'D' | 'd' ) ;
HEADERS : ( 'H' | 'h' ) ( 'E' | 'e' ) ( 'A' | 'a' ) ( 'D' | 'd' ) ( 'E' | 'e' ) ( 'R' | 'r' ) ( 'S' | 's' ) ;

kU_InQueryCall
: CALL SP oC_FunctionInvocation ;
: CALL SP oC_FunctionInvocation (SP? oC_Where)? ;

oC_Match
: ( OPTIONAL SP )? MATCH SP? oC_Pattern (SP? oC_Where)? ;
Expand Down
22 changes: 19 additions & 3 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
#include "common/string_format.h"
#include "common/string_utils.h"
#include "function/table_functions/bind_input.h"
#include "main/client_context.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/expression/parsed_literal_expression.h"
#include "parser/parsed_expression_visitor.h"
#include "parser/query/reading_clause/in_query_call_clause.h"
#include "parser/query/reading_clause/load_from.h"
#include "parser/query/reading_clause/match_clause.h"
Expand Down Expand Up @@ -98,6 +100,17 @@ std::unique_ptr<BoundReadingClause> Binder::bindUnwindClause(const ReadingClause
std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause& readingClause) {
auto& call = reinterpret_cast<const InQueryCallClause&>(readingClause);
auto funcExpr = reinterpret_cast<ParsedFunctionExpression*>(call.getFunctionExpression());
auto funcName = funcExpr->getFunctionName();
StringUtils::toUpper(funcName);
if (funcName == common::READ_PANDAS_FUNC_NAME && clientContext->replaceFunc) {
auto replacedValue = clientContext->replaceFunc(
reinterpret_cast<ParsedLiteralExpression*>(funcExpr->getChild(0))->getValue());
auto parameterExpression =
std::make_unique<ParsedLiteralExpression>(std::move(replacedValue), "pd");
auto inQueryCallParameterReplacer = std::make_unique<InQueryCallParameterReplacer>(
acquamarin marked this conversation as resolved.
Show resolved Hide resolved
std::make_pair(funcName, parameterExpression.get()));
funcExpr = inQueryCallParameterReplacer->visit(funcExpr);
}
std::vector<std::unique_ptr<Value>> inputValues;
std::vector<LogicalType*> inputTypes;
for (auto i = 0u; i < funcExpr->getNumChildren(); i++) {
Expand All @@ -109,8 +122,6 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
inputTypes.push_back(expressionValue->getDataType());
inputValues.push_back(expressionValue->copy());
}
auto funcName = funcExpr->getFunctionName();
StringUtils::toUpper(funcName);
// TODO: this is dangerous because we could match to a scan function.
auto tableFunction = reinterpret_cast<function::TableFunction*>(
catalog.getBuiltInFunctions()->matchScalarFunction(std::move(funcName), inputTypes));
Expand All @@ -124,8 +135,13 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
}
auto offset = expressionBinder.createVariableExpression(
LogicalType(LogicalTypeID::INT64), common::InternalKeyword::ANONYMOUS);
return std::make_unique<BoundInQueryCall>(
auto boundInQueryCall = std::make_unique<BoundInQueryCall>(
std::move(tableFunction), std::move(bindData), std::move(columns), offset);
if (call.hasWherePredicate()) {
auto wherePredicate = expressionBinder.bindExpression(*call.getWherePredicate());
boundInQueryCall->setWherePredicate(std::move(wherePredicate));
}
return boundInQueryCall;
}

std::unique_ptr<BoundReadingClause> Binder::bindLoadFrom(
Expand Down
8 changes: 8 additions & 0 deletions src/binder/visitor/property_collector.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "binder/visitor/property_collector.h"

#include "binder/expression_visitor.h"
#include "binder/query/reading_clause/bound_in_query_call.h"
#include "binder/query/reading_clause/bound_load_from.h"
#include "binder/query/reading_clause/bound_match_clause.h"
#include "binder/query/reading_clause/bound_unwind_clause.h"
Expand Down Expand Up @@ -46,6 +47,13 @@ void PropertyCollector::visitLoadFrom(const BoundReadingClause& readingClause) {
}
}

void PropertyCollector::visitInQueryCall(const BoundReadingClause& readingClause) {
auto& inQueryCallClause = reinterpret_cast<const BoundInQueryCall&>(readingClause);
if (inQueryCallClause.hasWherePredicate()) {
collectPropertyExpressions(inQueryCallClause.getWherePredicate());
}
}

void PropertyCollector::visitSet(const BoundUpdatingClause& updatingClause) {
auto& boundSetClause = (BoundSetClause&)updatingClause;
for (auto& info : boundSetClause.getInfosRef()) {
Expand Down
11 changes: 11 additions & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace kuzu {
namespace common {

std::string PhysicalTypeUtils::physicalTypeToString(PhysicalTypeID physicalType) {
// LCOV_EXCL_START
switch (physicalType) {
case PhysicalTypeID::BOOL:
return "BOOL";
Expand Down Expand Up @@ -55,9 +56,12 @@ std::string PhysicalTypeUtils::physicalTypeToString(PhysicalTypeID physicalType)
return "STRUCT";
case PhysicalTypeID::VAR_LIST:
return "VAR_LIST";
case PhysicalTypeID::POINTER:
return "POINTER";
default:
KU_UNREACHABLE;
}
// LCOV_EXCL_STOP
}

uint32_t PhysicalTypeUtils::getFixedTypeSize(PhysicalTypeID physicalType) {
Expand Down Expand Up @@ -511,6 +515,9 @@ void LogicalType::setPhysicalType() {
case LogicalTypeID::RDF_VARIANT: {
physicalType = PhysicalTypeID::STRUCT;
} break;
case LogicalTypeID::POINTER: {
physicalType = PhysicalTypeID::POINTER;
} break;
default:
KU_UNREACHABLE;
}
Expand Down Expand Up @@ -583,6 +590,7 @@ LogicalTypeID LogicalTypeUtils::dataTypeIDFromString(const std::string& dataType
}

std::string LogicalTypeUtils::toString(LogicalTypeID dataTypeID) {
// LCOV_EXCL_START
switch (dataTypeID) {
case LogicalTypeID::ANY:
return "ANY";
Expand Down Expand Up @@ -642,9 +650,12 @@ std::string LogicalTypeUtils::toString(LogicalTypeID dataTypeID) {
return "MAP";
case LogicalTypeID::UNION:
return "UNION";
case LogicalTypeID::POINTER:
return "POINTER";
default:
KU_UNREACHABLE;
}
// LCOV_EXCL_STOP
}

std::string LogicalTypeUtils::toString(const std::vector<LogicalType*>& dataTypes) {
Expand Down
16 changes: 11 additions & 5 deletions src/common/types/value/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@
strVal = std::string(val_);
}

Value::Value(uint8_t* val_) : isNull_{false} {
dataType = std::make_unique<LogicalType>(LogicalTypeID::POINTER);
val.pointer = val_;
}

Value::Value(LogicalType type, const std::string& val_) : isNull_{false} {
dataType = type.copy();
strVal = val_;
Expand All @@ -212,11 +217,6 @@
childrenSize = this->children.size();
}

Value::Value(LogicalType dataType_, const uint8_t* val_) : isNull_{false} {
dataType = dataType_.copy();
copyValueFrom(val_);
}

Value::Value(const Value& other) : isNull_{other.isNull_} {
dataType = other.dataType->copy();
copyValueFrom(other);
Expand Down Expand Up @@ -293,6 +293,9 @@
case LogicalTypeID::RDF_VARIANT: {
copyFromStruct(value);
} break;
case LogicalTypeID::POINTER: {
val.pointer = *((uint8_t**)value);
} break;

Check warning on line 298 in src/common/types/value/value.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/value/value.cpp#L296-L298

Added lines #L296 - L298 were not covered by tests
default:
KU_UNREACHABLE;
}
Expand Down Expand Up @@ -358,6 +361,9 @@
children.push_back(child->copy());
}
} break;
case PhysicalTypeID::POINTER: {
val.pointer = other.val.pointer;
} break;
default:
KU_UNREACHABLE;
}
Expand Down
2 changes: 1 addition & 1 deletion src/function/table_functions/scan_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
namespace kuzu {
namespace function {

std::pair<uint64_t, uint64_t> ScanSharedTableFuncState::getNext() {
std::pair<uint64_t, uint64_t> ScanSharedState::getNext() {
std::lock_guard<std::mutex> guard{lock};
if (fileIdx >= readerConfig.getNumFiles()) {
return {UINT64_MAX, UINT64_MAX};
Expand Down
10 changes: 10 additions & 0 deletions src/include/binder/query/reading_clause/bound_in_query_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ class BoundInQueryCall : public BoundReadingClause {

inline std::shared_ptr<Expression> getRowIdxExpression() const { return rowIdxExpression; }

inline void setWherePredicate(std::shared_ptr<Expression> expression) {
wherePredicate = std::move(expression);
}
inline bool hasWherePredicate() const { return wherePredicate != nullptr; }
inline std::shared_ptr<Expression> getWherePredicate() const { return wherePredicate; }
inline expression_vector getPredicatesSplitOnAnd() const {
return hasWherePredicate() ? wherePredicate->splitOnAND() : expression_vector{};
}

inline std::unique_ptr<BoundReadingClause> copy() override {
return std::make_unique<BoundInQueryCall>(
tableFunc, bindData->copy(), outputExpressions, rowIdxExpression);
Expand All @@ -35,6 +44,7 @@ class BoundInQueryCall : public BoundReadingClause {
std::unique_ptr<function::TableFuncBindData> bindData;
expression_vector outputExpressions;
std::shared_ptr<Expression> rowIdxExpression;
std::shared_ptr<Expression> wherePredicate;
};

} // namespace binder
Expand Down
1 change: 1 addition & 0 deletions src/include/binder/visitor/property_collector.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class PropertyCollector : public BoundStatementVisitor {
void visitMatch(const BoundReadingClause& readingClause) final;
void visitUnwind(const BoundReadingClause& readingClause) final;
void visitLoadFrom(const BoundReadingClause& readingClause) final;
void visitInQueryCall(const BoundReadingClause& readingClause) final;

void visitSet(const BoundUpdatingClause& updatingClause) final;
void visitDelete(const BoundUpdatingClause& updatingClause) final;
Expand Down
1 change: 1 addition & 0 deletions src/include/common/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ struct CopyConstants {
static constexpr char DEFAULT_CSV_LIST_END_CHAR = ']';
static constexpr char DEFAULT_CSV_LINE_BREAK = '\n';
static constexpr const char* ROW_IDX_COLUMN_NAME = "ROW_IDX";
static constexpr uint64_t PANDAS_PARTITION_COUNT = 50 * DEFAULT_VECTOR_CAPACITY;
};

struct LoggerConstants {
Expand Down
1 change: 1 addition & 0 deletions src/include/common/enums/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ const std::string READ_NPY_FUNC_NAME = "READ_NPY";
const std::string READ_CSV_SERIAL_FUNC_NAME = "READ_CSV_SERIAL";
const std::string READ_CSV_PARALLEL_FUNC_NAME = "READ_CSV_PARALLEL";
const std::string READ_RDF_FUNC_NAME = "READ_RDF";
const std::string READ_PANDAS_FUNC_NAME = "READ_PANDAS";

enum class ExpressionType : uint8_t {

Expand Down
2 changes: 2 additions & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ enum class KUZU_API LogicalTypeID : uint8_t {
MAP = 54,
UNION = 55,
RDF_VARIANT = 56,
POINTER = 57,
};

enum class PhysicalTypeID : uint8_t {
Expand All @@ -132,6 +133,7 @@ enum class PhysicalTypeID : uint8_t {
FIXED_LIST = 21,
VAR_LIST = 22,
STRUCT = 23,
POINTER = 24,
};

class LogicalType;
Expand Down
34 changes: 34 additions & 0 deletions src/include/common/types/value/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ class Value {
* @return a Value with STRING type and val_ value.
*/
KUZU_API explicit Value(const char* val_);
/**
* @param val_ the uint8_t* value to set.
* @return a Value with POINTER type and val_ value.
*/
KUZU_API explicit Value(uint8_t* val_);
/**
* @param val_ the string value to set.
* @return a Value with type and val_ value.
Expand Down Expand Up @@ -249,6 +254,8 @@ class Value {
uint8_t uint8Val;
double doubleVal;
float floatVal;
// TODO(Ziyi): Should we remove the val suffix from all values in Val? Looks redundant.
uint8_t* pointer;
interval_t intervalVal;
internalID_t internalIDVal;
} val;
Expand Down Expand Up @@ -418,6 +425,15 @@ KUZU_API inline std::string Value::getValue() const {
return strVal;
}

/**
* @return uint8_t* value.
*/
template<>
KUZU_API inline uint8_t* Value::getValue() const {
KU_ASSERT(dataType->getLogicalTypeID() == LogicalTypeID::POINTER);
return val.pointer;
}

/**
* @return the reference to the boolean value.
*/
Expand Down Expand Up @@ -571,6 +587,15 @@ KUZU_API inline std::string& Value::getValueReference() {
return strVal;
}

/**
* @return the reference to the uint8_t* value.
*/
template<>
KUZU_API inline uint8_t*& Value::getValueReference() {
KU_ASSERT(dataType->getLogicalTypeID() == LogicalTypeID::POINTER);
return val.pointer;
}

/**
* @param val the boolean value
* @return a Value with BOOL type and val value.
Expand Down Expand Up @@ -720,5 +745,14 @@ KUZU_API inline Value Value::createValue(const char* value) {
return Value(LogicalType{LogicalTypeID::STRING}, std::string(value));
}

/**
* @param val the uint8_t* val
* @return a Value with POINTER type and val val.
*/
template<>
KUZU_API inline Value Value::createValue(uint8_t* val) {
return Value(val);
}

} // namespace common
} // namespace kuzu
13 changes: 9 additions & 4 deletions src/include/function/table_functions/scan_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@
namespace kuzu {
namespace function {

struct ScanSharedTableFuncState : public TableFuncSharedState {
struct BaseScanSharedState : public TableFuncSharedState {
std::mutex lock;
uint64_t fileIdx;
uint64_t blockIdx;
const common::ReaderConfig readerConfig;
uint64_t numRows;

ScanSharedTableFuncState(const common::ReaderConfig readerConfig, uint64_t numRows)
: fileIdx{0}, blockIdx{0}, readerConfig{std::move(readerConfig)}, numRows{numRows} {}
BaseScanSharedState(uint64_t numRows) : fileIdx{0}, blockIdx{0}, numRows{numRows} {}
};

struct ScanSharedState : public BaseScanSharedState {

Check warning on line 17 in src/include/function/table_functions/scan_functions.h

View check run for this annotation

Codecov / codecov/patch

src/include/function/table_functions/scan_functions.h#L17

Added line #L17 was not covered by tests
const common::ReaderConfig readerConfig;

ScanSharedState(const common::ReaderConfig readerConfig, uint64_t numRows)
: BaseScanSharedState{numRows}, readerConfig{std::move(readerConfig)} {}

std::pair<uint64_t, uint64_t> getNext();
};
Expand Down
8 changes: 8 additions & 0 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>

#include "common/timer.h"
Expand All @@ -24,6 +25,8 @@ struct ActiveQuery {
void reset();
};

using replace_func_t = std::function<std::unique_ptr<common::Value>(common::Value*)>;

/**
* @brief Contain client side configuration. We make profiler associated per query, so profiler is
* not maintained in client context.
Expand Down Expand Up @@ -62,6 +65,10 @@ class ClientContext {
transaction::Transaction* getActiveTransaction() const;
transaction::TransactionContext* getTransactionContext() const;

inline void setReplaceFunc(replace_func_t replaceFunc) {
this->replaceFunc = std::move(replaceFunc);
}

private:
inline void resetActiveQuery() { activeQuery.reset(); }

Expand All @@ -71,6 +78,7 @@ class ClientContext {
uint32_t varLengthExtendMaxDepth;
std::unique_ptr<transaction::TransactionContext> transactionContext;
bool enableSemiMask;
replace_func_t replaceFunc;
};

} // namespace main
Expand Down
4 changes: 4 additions & 0 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ class Connection {
std::move(parameterTypes), returnType));
}

inline void setReplaceFunc(replace_func_t replaceFunc) {
clientContext->setReplaceFunc(std::move(replaceFunc));
}

private:
std::unique_ptr<QueryResult> query(const std::string& query, const std::string& encodedJoin);

Expand Down
Loading
Loading