Skip to content

Commit

Permalink
Merge pull request #712 from tiagokepe/master
Browse files Browse the repository at this point in the history
C++ UDF API
  • Loading branch information
Mytherin committed Jul 10, 2020
2 parents 0ae3c1c + cbcac68 commit 0c69daf
Show file tree
Hide file tree
Showing 19 changed files with 1,079 additions and 35 deletions.
4 changes: 4 additions & 0 deletions src/common/exception.cpp
Expand Up @@ -184,3 +184,7 @@ FatalException::FatalException(string msg, ...) : Exception(ExceptionType::FATAL
InternalException::InternalException(string msg, ...) : Exception(ExceptionType::INTERNAL, msg) {
FORMAT_CONSTRUCTOR(msg);
}

InvalidInputException::InvalidInputException(string msg, ...) : Exception(ExceptionType::INVALID_INPUT, msg) {
FORMAT_CONSTRUCTOR(msg);
}
1 change: 1 addition & 0 deletions src/execution/expression_executor/execute_function.cpp
Expand Up @@ -43,6 +43,7 @@ void ExpressionExecutor::Execute(BoundFunctionExpression &expr, ExpressionState
arguments.Verify();
}
expr.function.function(arguments, *state, result);

if (result.type != expr.return_type) {
throw TypeMismatchException(expr.return_type, result.type,
"expected function to return the former "
Expand Down
3 changes: 2 additions & 1 deletion src/function/CMakeLists.txt
Expand Up @@ -5,7 +5,8 @@ add_subdirectory(table)
add_library_unity(duckdb_function
OBJECT
cast_rules.cpp
function.cpp)
function.cpp
udf_function.cpp)
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:duckdb_function>
PARENT_SCOPE)
14 changes: 10 additions & 4 deletions src/function/scalar/math/numeric.cpp
Expand Up @@ -10,18 +10,24 @@ using namespace std;
namespace duckdb {

template <class TR, class OP> static scalar_function_t GetScalarIntegerUnaryFunctionFixedReturn(SQLType type) {
scalar_function_t function;
switch (type.id) {
case SQLTypeId::TINYINT:
return ScalarFunction::UnaryFunction<int8_t, TR, OP>;
function = &ScalarFunction::UnaryFunction<int8_t, TR, OP>;
break;
case SQLTypeId::SMALLINT:
return ScalarFunction::UnaryFunction<int16_t, TR, OP>;
function = &ScalarFunction::UnaryFunction<int16_t, TR, OP>;
break;
case SQLTypeId::INTEGER:
return ScalarFunction::UnaryFunction<int32_t, TR, OP>;
function = &ScalarFunction::UnaryFunction<int32_t, TR, OP>;
break;
case SQLTypeId::BIGINT:
return ScalarFunction::UnaryFunction<int64_t, TR, OP>;
function = &ScalarFunction::UnaryFunction<int64_t, TR, OP>;
break;
default:
throw NotImplementedException("Unimplemented type for GetScalarIntegerUnaryFunctionFixedReturn");
}
return function;
}

struct UnaryDoubleWrapper {
Expand Down
23 changes: 16 additions & 7 deletions src/function/scalar/operators/arithmetic.cpp
Expand Up @@ -7,24 +7,33 @@ using namespace std;
namespace duckdb {

template <class OP> static scalar_function_t GetScalarBinaryFunction(SQLType type) {
scalar_function_t function;
switch (type.id) {
case SQLTypeId::TINYINT:
return ScalarFunction::BinaryFunction<int8_t, int8_t, int8_t, OP>;
function = &ScalarFunction::BinaryFunction<int8_t, int8_t, int8_t, OP>;
break;
case SQLTypeId::SMALLINT:
return ScalarFunction::BinaryFunction<int16_t, int16_t, int16_t, OP>;
function = &ScalarFunction::BinaryFunction<int16_t, int16_t, int16_t, OP>;
break;
case SQLTypeId::INTEGER:
return ScalarFunction::BinaryFunction<int32_t, int32_t, int32_t, OP>;
function = &ScalarFunction::BinaryFunction<int32_t, int32_t, int32_t, OP>;
break;
case SQLTypeId::BIGINT:
return ScalarFunction::BinaryFunction<int64_t, int64_t, int64_t, OP>;
function = &ScalarFunction::BinaryFunction<int64_t, int64_t, int64_t, OP>;
break;
case SQLTypeId::FLOAT:
return ScalarFunction::BinaryFunction<float, float, float, OP, true>;
function = &ScalarFunction::BinaryFunction<float, float, float, OP, true>;
break;
case SQLTypeId::DOUBLE:
return ScalarFunction::BinaryFunction<double, double, double, OP, true>;
function = &ScalarFunction::BinaryFunction<double, double, double, OP, true>;
break;
case SQLTypeId::DECIMAL:
return ScalarFunction::BinaryFunction<double, double, double, OP, true>;
function = &ScalarFunction::BinaryFunction<double, double, double, OP, true>;
break;
default:
throw NotImplementedException("Unimplemented type for GetScalarBinaryFunction");
}
return function;
}

//===--------------------------------------------------------------------===//
Expand Down
16 changes: 16 additions & 0 deletions src/function/udf_function.cpp
@@ -0,0 +1,16 @@
#include "duckdb/function/udf_function.hpp"

#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp"
#include "duckdb/main/client_context.hpp"

namespace duckdb {

void UDFWrapper::RegisterFunction(string name, vector<SQLType> args, SQLType ret_type,
scalar_function_t udf_function, ClientContext &context) {

ScalarFunction scalar_function = ScalarFunction(name, args, ret_type, udf_function);
CreateScalarFunctionInfo info(scalar_function);
context.RegisterFunction(&info);
}

} // namespace duckdb
8 changes: 7 additions & 1 deletion src/include/duckdb/common/exception.hpp
Expand Up @@ -66,7 +66,8 @@ enum class ExceptionType {
INTERRUPT = 29, // interrupt
FATAL = 30, // Fatal exception: fatal exceptions are non-recoverable, and render the entire DB in an unusable state
INTERNAL =
31 // Internal exception: exception that indicates something went wrong internally (i.e. bug in the code base)
31, // Internal exception: exception that indicates something went wrong internally (i.e. bug in the code base)
INVALID_INPUT = 32 // Input or arguments error
};

class Exception : public std::exception {
Expand Down Expand Up @@ -196,4 +197,9 @@ class InternalException : public Exception {
InternalException(string msg, ...);
};

class InvalidInputException : public Exception {
public:
InvalidInputException(string msg, ...);
};

} // namespace duckdb
79 changes: 59 additions & 20 deletions src/include/duckdb/function/scalar_function.hpp
Expand Up @@ -20,7 +20,7 @@ class BoundFunctionExpression;
class ScalarFunctionCatalogEntry;

//! The type used for scalar functions
typedef void (*scalar_function_t)(DataChunk &input, ExpressionState &state, Vector &result);
typedef std::function<void(DataChunk &, ExpressionState &, Vector &)> scalar_function_t;
//! Binds the scalar function and creates the function data
typedef unique_ptr<FunctionData> (*bind_scalar_function_t)(BoundFunctionExpression &expr, ClientContext &context);
//! Adds the dependencies of this BoundFunctionExpression to the set of dependencies
Expand Down Expand Up @@ -57,12 +57,27 @@ class ScalarFunction : public SimpleFunction {
vector<unique_ptr<Expression>> children, bool is_operator = false);

bool operator==(const ScalarFunction &rhs) const {
return function == rhs.function && bind == rhs.bind && dependency == rhs.dependency;
return CompareScalarFunctionT(rhs.function) && bind == rhs.bind && dependency == rhs.dependency;
}
bool operator!=(const ScalarFunction &rhs) const {
return !(*this == rhs);
}

private:
bool CompareScalarFunctionT(const scalar_function_t other) const {
typedef void(funcTypeT)(DataChunk &, ExpressionState &, Vector &);

funcTypeT **func_ptr = (funcTypeT**) function.template target<funcTypeT*>();
funcTypeT **other_ptr = (funcTypeT**) other.template target<funcTypeT*>();

//Case the functions were created from lambdas the target will return a nullptr
if(func_ptr == nullptr || other_ptr == nullptr) {
//scalar_function_t (std::functions) from lambdas cannot be compared
return false;
}
return ((size_t)*func_ptr == (size_t)*other_ptr);
}

public:
static void NopFunction(DataChunk &input, ExpressionState &state, Vector &result) {
assert(input.column_count() >= 1);
Expand All @@ -84,60 +99,84 @@ class ScalarFunction : public SimpleFunction {

public:
template <class OP> static scalar_function_t GetScalarUnaryFunction(SQLType type) {
scalar_function_t function;
switch (type.id) {
case SQLTypeId::TINYINT:
return ScalarFunction::UnaryFunction<int8_t, int8_t, OP>;
function = &ScalarFunction::UnaryFunction<int8_t, int8_t, OP>;
break;
case SQLTypeId::SMALLINT:
return ScalarFunction::UnaryFunction<int16_t, int16_t, OP>;
function = &ScalarFunction::UnaryFunction<int16_t, int16_t, OP>;
break;
case SQLTypeId::INTEGER:
return ScalarFunction::UnaryFunction<int32_t, int32_t, OP>;
function = &ScalarFunction::UnaryFunction<int32_t, int32_t, OP>;
break;
case SQLTypeId::BIGINT:
return ScalarFunction::UnaryFunction<int64_t, int64_t, OP>;
function = &ScalarFunction::UnaryFunction<int64_t, int64_t, OP>;
break;
case SQLTypeId::FLOAT:
return ScalarFunction::UnaryFunction<float, float, OP>;
function = &ScalarFunction::UnaryFunction<float, float, OP>;
break;
case SQLTypeId::DOUBLE:
return ScalarFunction::UnaryFunction<double, double, OP>;
function = &ScalarFunction::UnaryFunction<double, double, OP>;
break;
case SQLTypeId::DECIMAL:
return ScalarFunction::UnaryFunction<double, double, OP>;
function = &ScalarFunction::UnaryFunction<double, double, OP>;
break;
default:
throw NotImplementedException("Unimplemented type for GetScalarUnaryFunction");
}
return function;
}

template <class TR, class OP> static scalar_function_t GetScalarUnaryFunctionFixedReturn(SQLType type) {
scalar_function_t function;
switch (type.id) {
case SQLTypeId::TINYINT:
return ScalarFunction::UnaryFunction<int8_t, TR, OP>;
function = &ScalarFunction::UnaryFunction<int8_t, TR, OP>;
break;
case SQLTypeId::SMALLINT:
return ScalarFunction::UnaryFunction<int16_t, TR, OP>;
function = &ScalarFunction::UnaryFunction<int16_t, TR, OP>;
break;
case SQLTypeId::INTEGER:
return ScalarFunction::UnaryFunction<int32_t, TR, OP>;
function = &ScalarFunction::UnaryFunction<int32_t, TR, OP>;
break;
case SQLTypeId::BIGINT:
return ScalarFunction::UnaryFunction<int64_t, TR, OP>;
function = &ScalarFunction::UnaryFunction<int64_t, TR, OP>;
break;
case SQLTypeId::FLOAT:
return ScalarFunction::UnaryFunction<float, TR, OP>;
function = &ScalarFunction::UnaryFunction<float, TR, OP>;
break;
case SQLTypeId::DOUBLE:
return ScalarFunction::UnaryFunction<double, TR, OP>;
function = &ScalarFunction::UnaryFunction<double, TR, OP>;
break;
case SQLTypeId::DECIMAL:
return ScalarFunction::UnaryFunction<double, TR, OP>;
function = &ScalarFunction::UnaryFunction<double, TR, OP>;
break;
default:
throw NotImplementedException("Unimplemented type for GetScalarUnaryFunctionFixedReturn");
}
return function;
}

template <class OP> static scalar_function_t GetScalarIntegerBinaryFunction(SQLType type) {
scalar_function_t function;
switch (type.id) {
case SQLTypeId::TINYINT:
return ScalarFunction::BinaryFunction<int8_t, int8_t, int8_t, OP>;
function = &ScalarFunction::BinaryFunction<int8_t, int8_t, int8_t, OP>;
break;
case SQLTypeId::SMALLINT:
return ScalarFunction::BinaryFunction<int16_t, int16_t, int16_t, OP>;
function = &ScalarFunction::BinaryFunction<int16_t, int16_t, int16_t, OP>;
break;
case SQLTypeId::INTEGER:
return ScalarFunction::BinaryFunction<int32_t, int32_t, int32_t, OP>;
function = &ScalarFunction::BinaryFunction<int32_t, int32_t, int32_t, OP>;
break;
case SQLTypeId::BIGINT:
return ScalarFunction::BinaryFunction<int64_t, int64_t, int64_t, OP>;
function = &ScalarFunction::BinaryFunction<int64_t, int64_t, int64_t, OP>;
break;
default:
throw NotImplementedException("Unimplemented type for GetScalarIntegerBinaryFunction");
}
return function;
}
};

Expand Down

0 comments on commit 0c69daf

Please sign in to comment.