Skip to content

Commit

Permalink
Implement scan pandas
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Nov 14, 2023
1 parent 28c99e7 commit 9f99c26
Show file tree
Hide file tree
Showing 46 changed files with 934 additions and 35 deletions.
15 changes: 13 additions & 2 deletions src/binder/bind/bind_reading_clause.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
#include "common/string_format.h"
#include "common/string_utils.h"
#include "function/table_functions/bind_input.h"
#include "main/client_context.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/expression/parsed_literal_expression.h"
#include "parser/parsed_expression_visitor.h"
#include "parser/query/reading_clause/in_query_call_clause.h"
#include "parser/query/reading_clause/load_from.h"
#include "parser/query/reading_clause/match_clause.h"
Expand Down Expand Up @@ -98,6 +100,17 @@ std::unique_ptr<BoundReadingClause> Binder::bindUnwindClause(const ReadingClause
std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause& readingClause) {
auto& call = reinterpret_cast<const InQueryCallClause&>(readingClause);
auto funcExpr = reinterpret_cast<ParsedFunctionExpression*>(call.getFunctionExpression());
auto funcName = funcExpr->getFunctionName();
StringUtils::toUpper(funcName);
if (funcName == common::READ_PANDAS_FUNC_NAME && clientContext->replaceFunc) {
auto replacedValue = clientContext->replaceFunc(
reinterpret_cast<ParsedLiteralExpression*>(funcExpr->getChild(0))->getValue());
auto parameterExpression =
std::make_unique<ParsedLiteralExpression>(std::move(replacedValue), "pd");
auto inQueryCallParameterReplacer = std::make_unique<InQueryCallParameterReplacer>(
std::make_pair(funcName, parameterExpression.get()));
funcExpr = inQueryCallParameterReplacer->visit(funcExpr);
}
std::vector<std::unique_ptr<Value>> inputValues;
std::vector<LogicalType*> inputTypes;
for (auto i = 0u; i < funcExpr->getNumChildren(); i++) {
Expand All @@ -109,8 +122,6 @@ std::unique_ptr<BoundReadingClause> Binder::bindInQueryCall(const ReadingClause&
inputTypes.push_back(expressionValue->getDataType());
inputValues.push_back(expressionValue->copy());
}
auto funcName = funcExpr->getFunctionName();
StringUtils::toUpper(funcName);
// TODO: this is dangerous because we could match to a scan function.
auto tableFunction = reinterpret_cast<function::TableFunction*>(
catalog.getBuiltInFunctions()->matchScalarFunction(std::move(funcName), inputTypes));
Expand Down
7 changes: 7 additions & 0 deletions src/common/string_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,12 @@ std::string StringUtils::removeEscapedCharacters(const std::string& input) {
return resultStr;
}

bool StringUtils::startsWith(std::string str, std::string prefix) {
if (prefix.size() > str.size()) {
return false;
}
return equal(prefix.begin(), prefix.end(), str.begin());
}

} // namespace common
} // namespace kuzu
11 changes: 11 additions & 0 deletions src/common/types/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace kuzu {
namespace common {

std::string PhysicalTypeUtils::physicalTypeToString(PhysicalTypeID physicalType) {
// LCOV_EXCL_START
switch (physicalType) {
case PhysicalTypeID::BOOL:
return "BOOL";
Expand Down Expand Up @@ -55,9 +56,12 @@ std::string PhysicalTypeUtils::physicalTypeToString(PhysicalTypeID physicalType)
return "STRUCT";
case PhysicalTypeID::VAR_LIST:
return "VAR_LIST";
case PhysicalTypeID::POINTER:
return "POINTER";
default:
KU_UNREACHABLE;
}
// LCOV_EXCL_STOP
}

uint32_t PhysicalTypeUtils::getFixedTypeSize(PhysicalTypeID physicalType) {
Expand Down Expand Up @@ -511,6 +515,9 @@ void LogicalType::setPhysicalType() {
case LogicalTypeID::RDF_VARIANT: {
physicalType = PhysicalTypeID::STRUCT;
} break;
case LogicalTypeID::POINTER: {
physicalType = PhysicalTypeID::POINTER;
} break;
default:
KU_UNREACHABLE;
}
Expand Down Expand Up @@ -583,6 +590,7 @@ LogicalTypeID LogicalTypeUtils::dataTypeIDFromString(const std::string& dataType
}

std::string LogicalTypeUtils::toString(LogicalTypeID dataTypeID) {
// LCOV_EXCL_START
switch (dataTypeID) {
case LogicalTypeID::ANY:
return "ANY";
Expand Down Expand Up @@ -642,9 +650,12 @@ std::string LogicalTypeUtils::toString(LogicalTypeID dataTypeID) {
return "MAP";
case LogicalTypeID::UNION:
return "UNION";
case LogicalTypeID::POINTER:
return "POINTER";
default:
KU_UNREACHABLE;
}
// LCOV_EXCL_STOP
}

std::string LogicalTypeUtils::toString(const std::vector<LogicalType*>& dataTypes) {
Expand Down
16 changes: 11 additions & 5 deletions src/common/types/value/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ Value::Value(const char* val_) : isNull_{false} {
strVal = std::string(val_);
}

Value::Value(uint8_t* val_) : isNull_{false} {
dataType = std::make_unique<LogicalType>(LogicalTypeID::POINTER);
val.pointer = val_;
}

Value::Value(LogicalType type, const std::string& val_) : isNull_{false} {
dataType = type.copy();
strVal = val_;
Expand All @@ -212,11 +217,6 @@ Value::Value(LogicalType dataType_, std::vector<std::unique_ptr<Value>> children
childrenSize = this->children.size();
}

Value::Value(LogicalType dataType_, const uint8_t* val_) : isNull_{false} {
dataType = dataType_.copy();
copyValueFrom(val_);
}

Value::Value(const Value& other) : isNull_{other.isNull_} {
dataType = other.dataType->copy();
copyValueFrom(other);
Expand Down Expand Up @@ -293,6 +293,9 @@ void Value::copyValueFrom(const uint8_t* value) {
case LogicalTypeID::RDF_VARIANT: {
copyFromStruct(value);
} break;
case LogicalTypeID::POINTER: {
val.pointer = *((uint8_t**)value);
} break;

Check warning on line 298 in src/common/types/value/value.cpp

View check run for this annotation

Codecov / codecov/patch

src/common/types/value/value.cpp#L296-L298

Added lines #L296 - L298 were not covered by tests
default:
KU_UNREACHABLE;
}
Expand Down Expand Up @@ -358,6 +361,9 @@ void Value::copyValueFrom(const Value& other) {
children.push_back(child->copy());
}
} break;
case PhysicalTypeID::POINTER: {
val.pointer = other.val.pointer;
} break;
default:
KU_UNREACHABLE;
}
Expand Down
2 changes: 1 addition & 1 deletion src/function/table_functions/scan_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
namespace kuzu {
namespace function {

std::pair<uint64_t, uint64_t> ScanSharedTableFuncState::getNext() {
std::pair<uint64_t, uint64_t> ScanSharedState::getNext() {
std::lock_guard<std::mutex> guard{lock};
if (fileIdx >= readerConfig.getNumFiles()) {
return {UINT64_MAX, UINT64_MAX};
Expand Down
1 change: 1 addition & 0 deletions src/include/common/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ struct CopyConstants {
static constexpr char DEFAULT_CSV_LIST_END_CHAR = ']';
static constexpr char DEFAULT_CSV_LINE_BREAK = '\n';
static constexpr const char* ROW_IDX_COLUMN_NAME = "ROW_IDX";
static constexpr uint64_t PANDAS_PARTITION_COUNT = 50 * DEFAULT_VECTOR_CAPACITY;
};

struct LoggerConstants {
Expand Down
1 change: 1 addition & 0 deletions src/include/common/enums/expression_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ const std::string READ_NPY_FUNC_NAME = "READ_NPY";
const std::string READ_CSV_SERIAL_FUNC_NAME = "READ_CSV_SERIAL";
const std::string READ_CSV_PARALLEL_FUNC_NAME = "READ_CSV_PARALLEL";
const std::string READ_RDF_FUNC_NAME = "READ_RDF";
const std::string READ_PANDAS_FUNC_NAME = "READ_PANDAS";

enum class ExpressionType : uint8_t {

Expand Down
2 changes: 2 additions & 0 deletions src/include/common/string_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class StringUtils {
char delimiterEnd, bool includeDelimiter = false);

static std::string removeEscapedCharacters(const std::string& input);

static bool startsWith(std::string str, std::string prefix);
};

} // namespace common
Expand Down
2 changes: 2 additions & 0 deletions src/include/common/types/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ enum class KUZU_API LogicalTypeID : uint8_t {
MAP = 54,
UNION = 55,
RDF_VARIANT = 56,
POINTER = 57,
};

enum class PhysicalTypeID : uint8_t {
Expand All @@ -132,6 +133,7 @@ enum class PhysicalTypeID : uint8_t {
FIXED_LIST = 21,
VAR_LIST = 22,
STRUCT = 23,
POINTER = 24,
};

class LogicalType;
Expand Down
34 changes: 34 additions & 0 deletions src/include/common/types/value/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ class Value {
* @return a Value with STRING type and val_ value.
*/
KUZU_API explicit Value(const char* val_);
/**
* @param val_ the uint8_t* value to set.
* @return a Value with POINTER type and val_ value.
*/
KUZU_API explicit Value(uint8_t* val_);
/**
* @param val_ the string value to set.
* @return a Value with type and val_ value.
Expand Down Expand Up @@ -249,6 +254,8 @@ class Value {
uint8_t uint8Val;
double doubleVal;
float floatVal;
// TODO(Ziyi): Should we remove the val suffix from all values in Val? Looks redundant.
uint8_t* pointer;
interval_t intervalVal;
internalID_t internalIDVal;
} val;
Expand Down Expand Up @@ -418,6 +425,15 @@ KUZU_API inline std::string Value::getValue() const {
return strVal;
}

/**
* @return uint8_t* value.
*/
template<>
KUZU_API inline uint8_t* Value::getValue() const {
KU_ASSERT(dataType->getLogicalTypeID() == LogicalTypeID::POINTER);
return val.pointer;
}

/**
* @return the reference to the boolean value.
*/
Expand Down Expand Up @@ -571,6 +587,15 @@ KUZU_API inline std::string& Value::getValueReference() {
return strVal;
}

/**
* @return the reference to the uint8_t* value.
*/
template<>
KUZU_API inline uint8_t*& Value::getValueReference() {
KU_ASSERT(dataType->getLogicalTypeID() == LogicalTypeID::POINTER);
return val.pointer;
}

/**
* @param val the boolean value
* @return a Value with BOOL type and val value.
Expand Down Expand Up @@ -720,5 +745,14 @@ KUZU_API inline Value Value::createValue(const char* value) {
return Value(LogicalType{LogicalTypeID::STRING}, std::string(value));
}

/**
* @param val the uint8_t* val
* @return a Value with POINTER type and val val.
*/
template<>
KUZU_API inline Value Value::createValue(uint8_t* val) {
return Value(val);
}

} // namespace common
} // namespace kuzu
13 changes: 9 additions & 4 deletions src/include/function/table_functions/scan_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,20 @@
namespace kuzu {
namespace function {

struct ScanSharedTableFuncState : public TableFuncSharedState {
struct BaseScanSharedState : public TableFuncSharedState {
std::mutex lock;
uint64_t fileIdx;
uint64_t blockIdx;
const common::ReaderConfig readerConfig;
uint64_t numRows;

ScanSharedTableFuncState(const common::ReaderConfig readerConfig, uint64_t numRows)
: fileIdx{0}, blockIdx{0}, readerConfig{std::move(readerConfig)}, numRows{numRows} {}
BaseScanSharedState(uint64_t numRows) : fileIdx{0}, blockIdx{0}, numRows{numRows} {}
};

struct ScanSharedState : public BaseScanSharedState {

Check warning on line 17 in src/include/function/table_functions/scan_functions.h

View check run for this annotation

Codecov / codecov/patch

src/include/function/table_functions/scan_functions.h#L17

Added line #L17 was not covered by tests
const common::ReaderConfig readerConfig;

ScanSharedState(const common::ReaderConfig readerConfig, uint64_t numRows)
: BaseScanSharedState{numRows}, readerConfig{std::move(readerConfig)} {}

std::pair<uint64_t, uint64_t> getNext();
};
Expand Down
8 changes: 8 additions & 0 deletions src/include/main/client_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>

#include "common/timer.h"
Expand All @@ -24,6 +25,8 @@ struct ActiveQuery {
void reset();
};

using replace_func_t = std::function<std::unique_ptr<common::Value>(common::Value*)>;

/**
* @brief Contain client side configuration. We make profiler associated per query, so profiler is
* not maintained in client context.
Expand Down Expand Up @@ -62,6 +65,10 @@ class ClientContext {
transaction::Transaction* getActiveTransaction() const;
transaction::TransactionContext* getTransactionContext() const;

inline void setReplaceFunc(replace_func_t replaceFunc) {
this->replaceFunc = std::move(replaceFunc);
}

private:
inline void resetActiveQuery() { activeQuery.reset(); }

Expand All @@ -71,6 +78,7 @@ class ClientContext {
uint32_t varLengthExtendMaxDepth;
std::unique_ptr<transaction::TransactionContext> transactionContext;
bool enableSemiMask;
replace_func_t replaceFunc;
};

} // namespace main
Expand Down
4 changes: 4 additions & 0 deletions src/include/main/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ class Connection {
std::move(parameterTypes), returnType));
}

inline void setReplaceFunc(replace_func_t replaceFunc) {
clientContext->setReplaceFunc(std::move(replaceFunc));
}

private:
std::unique_ptr<QueryResult> query(const std::string& query, const std::string& encodedJoin);

Expand Down
9 changes: 9 additions & 0 deletions src/include/main/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@

#include <memory>
#include <string>
#include <vector>

#include "common/api.h"
#include "kuzu_fwd.h"

namespace kuzu {
namespace function {
struct Function;
} // namespace function

namespace main {

/**
Expand Down Expand Up @@ -69,6 +74,10 @@ class Database {
*/
KUZU_API static void setLoggingLevel(std::string loggingLevel);

// TODO(Ziyi): Instead of exposing a dedicated API for adding a new function, we should consider
// add function through the extension module.
void addFunction(std::string name, std::vector<std::unique_ptr<function::Function>> tableFunc);

private:
void openLockFile();
void initDBDirAndCoreFilesIfNecessary();
Expand Down
14 changes: 14 additions & 0 deletions src/include/parser/parsed_expression_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include <unordered_map>

#include "parser/expression/parsed_expression.h"
#include "parser/expression/parsed_function_expression.h"
#include "parser/expression/parsed_literal_expression.h"

namespace kuzu {
namespace parser {
Expand Down Expand Up @@ -34,5 +36,17 @@ class MacroParameterReplacer {
const std::unordered_map<std::string, ParsedExpression*>& expressionNamesToReplace;
};

class InQueryCallParameterReplacer {
public:
explicit InQueryCallParameterReplacer(
std::pair<std::string, ParsedLiteralExpression*> literalExpressionToReplace)
: literalExpressionToReplace{std::move(literalExpressionToReplace)} {}

ParsedFunctionExpression* visit(ParsedFunctionExpression* tableFuncExpression) const;

private:
std::pair<std::string, ParsedLiteralExpression*> literalExpressionToReplace;
};

} // namespace parser
} // namespace kuzu
Loading

0 comments on commit 9f99c26

Please sign in to comment.