diff --git a/src/codegen/codegen.cpp b/src/codegen/codegen.cpp index b4bf416ddd..a91a208f71 100644 --- a/src/codegen/codegen.cpp +++ b/src/codegen/codegen.cpp @@ -106,6 +106,21 @@ llvm::Value *CodeGen::CallPrintf(const std::string &format, return CallFunc(printf_fn, printf_args); } +llvm::Value *CodeGen::CallStrlen(llvm::Value *str) { + auto *strlen_fn = LookupBuiltin("strlen"); + if (strlen_fn == nullptr) { + strlen_fn = RegisterBuiltin( + "strlen", llvm::TypeBuilder::get(GetContext()), + reinterpret_cast(strlen)); + } + + // Collect all the arguments into a vector + std::vector strlen_args{str}; + + // Call the function + return CallFunc(strlen_fn, strlen_args); +} + llvm::Value *CodeGen::CallAddWithOverflow(llvm::Value *left, llvm::Value *right, llvm::Value *&overflow_bit) { PL_ASSERT(left->getType() == right->getType()); diff --git a/src/codegen/expression/function_translator.cpp b/src/codegen/expression/function_translator.cpp index 6a977cefaa..0e2833a3c8 100644 --- a/src/codegen/expression/function_translator.cpp +++ b/src/codegen/expression/function_translator.cpp @@ -1,6 +1,7 @@ #include "codegen/expression/function_translator.h" - +#include "codegen/type/sql_type.h" #include "expression/function_expression.h" +#include "codegen/proxy/function_wrapper_proxy.h" namespace peloton { namespace codegen { @@ -13,7 +14,76 @@ FunctionTranslator::FunctionTranslator( codegen::Value FunctionTranslator::DeriveValue( CodeGen &codegen, RowBatch::Row &row) const { const auto &func_expr = GetExpressionAs(); - func_expr-> + + // get the number of arguments + size_t child_num = func_expr.GetChildrenSize(); + std::vector args; + + // store function pointer + args.push_back(codegen.Const64((int64_t)func_expr.func_ptr_)); + // store argument number + args.push_back(codegen.Const32((int32_t) child_num)); + + // store arguments + for (size_t i = 0; i < child_num; ++i) { + args.push_back(codegen.Const32( + static_cast(func_expr.func_arg_types_[i]))); + args.push_back(row.DeriveValue(codegen, *func_expr.GetChild(i)) + .GetValue()); + } + + return CallWrapperFunction(func_expr.GetValueType(), args, codegen); +} + +codegen::Value FunctionTranslator::CallWrapperFunction( + peloton::type::TypeId ret_type, + std::vector &args, + CodeGen &codegen) const { + llvm::Function *wrapper = nullptr; + std::vector arg_types{codegen.Int32Type()}; + + // get the wrapper function with certain return type + switch (ret_type) { + case peloton::type::TypeId::TINYINT: + wrapper = FunctionWrapperProxy::TinyIntWrapper.GetFunction(codegen); + break; + case peloton::type::TypeId::SMALLINT: + wrapper = FunctionWrapperProxy::SmallIntWrapper.GetFunction(codegen); + break; + case peloton::type::TypeId::INTEGER: + wrapper = FunctionWrapperProxy::IntegerWrapper.GetFunction(codegen); + break; + case peloton::type::TypeId::BIGINT: + wrapper = FunctionWrapperProxy::BigIntWrapper.GetFunction(codegen); + break; + case peloton::type::TypeId::DECIMAL: + wrapper = FunctionWrapperProxy::DecimalWrapper.GetFunction(codegen); + break; + case peloton::type::TypeId::DATE: + wrapper = FunctionWrapperProxy::DateWrapper.GetFunction(codegen); + break; + case peloton::type::TypeId::TIMESTAMP: + wrapper = FunctionWrapperProxy::TimestampWrapper.GetFunction(codegen); + break; + case peloton::type::TypeId::VARCHAR: + wrapper = FunctionWrapperProxy::VarcharWrapper.GetFunction(codegen); + break; + default: + break; + } + if (wrapper != nullptr) { + // call the function + llvm::Value *ret_val = codegen.CallFunc(wrapper, args); + // call strlen to get the length of a varchar + llvm::Value *ret_len = nullptr; + if (ret_type == peloton::type::TypeId::VARCHAR) { + ret_len = codegen.CallStrlen(ret_val); + } + return codegen::Value(type::SqlType::LookupType(ret_type), ret_val, ret_len); + } + else { + return codegen::Value(type::SqlType::LookupType(ret_type)); + } } } // namespace codegen diff --git a/src/codegen/function_wrapper.cpp b/src/codegen/function_wrapper.cpp new file mode 100644 index 0000000000..0f65450ea0 --- /dev/null +++ b/src/codegen/function_wrapper.cpp @@ -0,0 +1,132 @@ +#include +#include "codegen/function_wrapper.h" +#include "type/value_factory.h" +#include "type/ephemeral_pool.h" + +namespace peloton { +namespace codegen { + +inline std::vector GetArguments( + int n_args, va_list &ap) { + std::vector args; + for (int i = 0; i < n_args; ++i) { + peloton::type::TypeId type_id = peloton::type::TypeId(va_arg(ap, int32_t)); + switch (type_id) { + case peloton::type::TypeId::TINYINT: + args.push_back(peloton::type::ValueFactory::GetTinyIntValue( + va_arg(ap, int32_t) + )); + break; + case peloton::type::TypeId::SMALLINT: + args.push_back(peloton::type::ValueFactory::GetSmallIntValue( + va_arg(ap, int32_t) + )); + break; + case peloton::type::TypeId::INTEGER: + args.push_back(peloton::type::ValueFactory::GetIntegerValue( + va_arg(ap, int32_t) + )); + break; + case peloton::type::TypeId::BIGINT: + args.push_back(peloton::type::ValueFactory::GetBigIntValue( + va_arg(ap, int64_t) + )); + break; + case peloton::type::TypeId::DECIMAL: + args.push_back(peloton::type::ValueFactory::GetDecimalValue( + va_arg(ap, double) + )); + break; + case peloton::type::TypeId::DATE: + args.push_back(peloton::type::ValueFactory::GetDateValue( + va_arg(ap, int32_t) + )); + break; + case peloton::type::TypeId::TIMESTAMP: + args.push_back(peloton::type::ValueFactory::GetTimestampValue( + va_arg(ap, int64_t) + )); + break; + case peloton::type::TypeId::VARCHAR: + args.push_back(peloton::type::ValueFactory::GetVarcharValue( + va_arg(ap, const char*) + )); + break; + default: + args.push_back(peloton::type::ValueFactory::GetNullValueByType(type_id)); + break; + } + } + return args; +} + +int8_t FunctionWrapper::TinyIntWrapper(int64_t func, int n_args, ...) { + va_list ap; + va_start(ap, n_args); + std::vector args = GetArguments(n_args, ap); + peloton::type::Value ret = ((BuiltInFuncType)func)(args); + return ret.GetAs(); +} + +int16_t FunctionWrapper::SmallIntWrapper(int64_t func, int n_args, ...) { + va_list ap; + va_start(ap, n_args); + std::vector args = GetArguments(n_args, ap); + peloton::type::Value ret = ((BuiltInFuncType)func)(args); + return ret.GetAs(); +} + +int32_t FunctionWrapper::IntegerWrapper(int64_t func, int n_args, ...) { + va_list ap; + va_start(ap, n_args); + std::vector args = GetArguments(n_args, ap); + peloton::type::Value ret = ((BuiltInFuncType)func)(args); + return ret.GetAs(); +} + +int64_t FunctionWrapper::BigIntWrapper(int64_t func, int n_args, ...) { + va_list ap; + va_start(ap, n_args); + std::vector args = GetArguments(n_args, ap); + peloton::type::Value ret = ((BuiltInFuncType)func)(args); + return ret.GetAs(); +} + +double FunctionWrapper::DecimalWrapper(int64_t func, int n_args, ...) { + va_list ap; + va_start(ap, n_args); + std::vector args = GetArguments(n_args, ap); + peloton::type::Value ret = ((BuiltInFuncType)func)(args); + return ret.GetAs(); +} + +int32_t FunctionWrapper::DateWrapper(int64_t func, int n_args, ...) { + va_list ap; + va_start(ap, n_args); + std::vector args = GetArguments(n_args, ap); + peloton::type::Value ret = ((BuiltInFuncType)func)(args); + return ret.GetAs(); +} + +int64_t FunctionWrapper::TimestampWrapper(int64_t func, int n_args, ...) { + va_list ap; + va_start(ap, n_args); + std::vector args = GetArguments(n_args, ap); + peloton::type::Value ret = ((BuiltInFuncType)func)(args); + return ret.GetAs(); +} + +const char* FunctionWrapper::VarcharWrapper(int64_t func, int n_args, ...) { + // TODO: put this pool to a proper place + static type::EphemeralPool pool; + va_list ap; + va_start(ap, n_args); + std::vector args = GetArguments(n_args, ap); + peloton::type::Value ret = ((BuiltInFuncType)func)(args); + char* str = static_cast(pool.Allocate(ret.GetLength())); + strcpy(str, ret.GetData()); + return str; +} + +} +} diff --git a/src/codegen/proxy/function_wrapper_proxy.cpp b/src/codegen/proxy/function_wrapper_proxy.cpp new file mode 100644 index 0000000000..49b9f248d2 --- /dev/null +++ b/src/codegen/proxy/function_wrapper_proxy.cpp @@ -0,0 +1,16 @@ +#include "codegen/proxy/function_wrapper_proxy.h" + +namespace peloton { +namespace codegen { + +DEFINE_METHOD(peloton::codegen, FunctionWrapper, TinyIntWrapper); +DEFINE_METHOD(peloton::codegen, FunctionWrapper, SmallIntWrapper); +DEFINE_METHOD(peloton::codegen, FunctionWrapper, IntegerWrapper); +DEFINE_METHOD(peloton::codegen, FunctionWrapper, BigIntWrapper); +DEFINE_METHOD(peloton::codegen, FunctionWrapper, DecimalWrapper); +DEFINE_METHOD(peloton::codegen, FunctionWrapper, DateWrapper); +DEFINE_METHOD(peloton::codegen, FunctionWrapper, TimestampWrapper); +DEFINE_METHOD(peloton::codegen, FunctionWrapper, VarcharWrapper); + +} // namespace codegen +} // namespace peloton diff --git a/src/codegen/translator_factory.cpp b/src/codegen/translator_factory.cpp index 59fd2c99d0..63e521f3e6 100644 --- a/src/codegen/translator_factory.cpp +++ b/src/codegen/translator_factory.cpp @@ -28,13 +28,12 @@ #include "codegen/expression/tuple_value_translator.h" #include "codegen/expression/function_translator.h" #include "expression/case_expression.h" -#include "expression/comparison_expression.h" #include "expression/conjunction_expression.h" #include "expression/constant_value_expression.h" #include "expression/operator_expression.h" #include "expression/tuple_value_expression.h" #include "expression/aggregate_expression.h" -#include "planner/aggregate_plan.h" +#include "expression/function_expression.h" #include "planner/delete_plan.h" #include "planner/hash_join_plan.h" #include "planner/order_by_plan.h" diff --git a/src/include/codegen/codegen.h b/src/include/codegen/codegen.h index a10659965d..d0dba99422 100644 --- a/src/include/codegen/codegen.h +++ b/src/include/codegen/codegen.h @@ -98,6 +98,8 @@ class CodeGen { llvm::Value *CallPrintf(const std::string &format, const std::vector &args); + llvm::Value *CallStrlen(llvm::Value *str); + //===--------------------------------------------------------------------===// // Arithmetic with overflow logic - These methods perform the desired math op, // on the provided left and right argument and return the result of the op diff --git a/src/include/codegen/expression/function_translator.h b/src/include/codegen/expression/function_translator.h index 7ce1460ae8..9a2ea55258 100644 --- a/src/include/codegen/expression/function_translator.h +++ b/src/include/codegen/expression/function_translator.h @@ -1,6 +1,8 @@ #pragma once #include "codegen/expression/expression_translator.h" +#include "codegen/function_wrapper.h" +#include namespace peloton { @@ -22,6 +24,11 @@ class FunctionTranslator : public ExpressionTranslator { // Return the result of the function call codegen::Value DeriveValue(CodeGen &codegen, RowBatch::Row &row) const override; + private: + codegen::Value CallWrapperFunction( + peloton::type::TypeId ret_type, + std::vector &args, + CodeGen &codegen) const; }; } // namespace codegen diff --git a/src/include/codegen/function_wrapper.h b/src/include/codegen/function_wrapper.h new file mode 100644 index 0000000000..7e247b3538 --- /dev/null +++ b/src/include/codegen/function_wrapper.h @@ -0,0 +1,23 @@ +#pragma once + +#include "type/types.h" +#include "type/value.h" + +namespace peloton { +namespace codegen { + +class FunctionWrapper { + typedef peloton::type::Value (*BuiltInFuncType)(const std::vector &); + public: + static int8_t TinyIntWrapper(int64_t func, int n_args, ...); + static int16_t SmallIntWrapper(int64_t func, int n_args, ...); + static int32_t IntegerWrapper(int64_t func, int n_args, ...); + static int64_t BigIntWrapper(int64_t func, int n_args, ...); + static double DecimalWrapper(int64_t func, int n_args, ...); + static int32_t DateWrapper(int64_t func, int n_args, ...); + static int64_t TimestampWrapper(int64_t func, int n_args, ...); + static const char *VarcharWrapper(int64_t func, int n_args, ...); +}; + +} +} diff --git a/src/include/codegen/proxy/function_wrapper_proxy.h b/src/include/codegen/proxy/function_wrapper_proxy.h new file mode 100644 index 0000000000..2e9883f1d9 --- /dev/null +++ b/src/include/codegen/proxy/function_wrapper_proxy.h @@ -0,0 +1,22 @@ +#pragma once + +#include "codegen/function_wrapper.h" +#include "codegen/proxy/proxy.h" +#include "codegen/proxy/type_builder.h" + +namespace peloton { +namespace codegen { + +PROXY(FunctionWrapper) { + DECLARE_METHOD(TinyIntWrapper); + DECLARE_METHOD(SmallIntWrapper); + DECLARE_METHOD(IntegerWrapper); + DECLARE_METHOD(BigIntWrapper); + DECLARE_METHOD(DecimalWrapper); + DECLARE_METHOD(DateWrapper); + DECLARE_METHOD(TimestampWrapper); + DECLARE_METHOD(VarcharWrapper); +}; + +} // namespace codegen +} // namespace peloton diff --git a/src/include/codegen/proxy/proxy.h b/src/include/codegen/proxy/proxy.h index cff95ce218..021bf33c20 100644 --- a/src/include/codegen/proxy/proxy.h +++ b/src/include/codegen/proxy/proxy.h @@ -182,6 +182,11 @@ struct MemFn { static void *Get() { return reinterpret_cast(F); } }; +template +struct MemFn { + static void *Get() { return reinterpret_cast(F); } +}; + } // namespace detail } // namespace proxy diff --git a/src/include/codegen/proxy/type_builder.h b/src/include/codegen/proxy/type_builder.h index e71ee7527d..77ed416b50 100644 --- a/src/include/codegen/proxy/type_builder.h +++ b/src/include/codegen/proxy/type_builder.h @@ -103,6 +103,28 @@ struct TypeBuilder { } }; +/// Regular C-style functions with variable arguments +template +struct TypeBuilder { + static llvm::Type *GetType(CodeGen &codegen) ALWAYS_INLINE { + llvm::Type *ret_type = TypeBuilder::GetType(codegen); + std::vector arg_types = { + TypeBuilder::GetType(codegen)...}; + return llvm::FunctionType::get(ret_type, arg_types, true); + } +}; + +/// C-style function pointer with variable arguments +template +struct TypeBuilder { + static llvm::Type *GetType(CodeGen &codegen) ALWAYS_INLINE { + llvm::Type *ret_type = TypeBuilder::GetType(codegen); + std::vector arg_types = { + TypeBuilder::GetType(codegen)...}; + return llvm::FunctionType::get(ret_type, arg_types, true)->getPointerTo(); + } +}; + /// Member functions template struct TypeBuilder { diff --git a/src/include/expression/function_expression.h b/src/include/expression/function_expression.h index 6e1cf1eb1f..66a21dab52 100644 --- a/src/include/expression/function_expression.h +++ b/src/include/expression/function_expression.h @@ -40,11 +40,12 @@ class FunctionExpression : public AbstractExpression { const std::vector& arg_types, const std::vector& children) : AbstractExpression(ExpressionType::FUNCTION, return_type), - func_ptr_(func_ptr) { + func_ptr_(func_ptr), + func_arg_types_(arg_types) { for (auto& child : children) { children_.push_back(std::unique_ptr(child)); } - CheckChildrenTypes(arg_types, children_, func_name_); + CheckChildrenTypes(children_, func_name_); } void SetFunctionExpressionParameters( @@ -53,7 +54,8 @@ class FunctionExpression : public AbstractExpression { const std::vector& arg_types) { func_ptr_ = func_ptr; return_value_type_ = val_type; - CheckChildrenTypes(arg_types, children_, func_name_); + func_arg_types_ = arg_types; + CheckChildrenTypes(children_, func_name_); } type::Value Evaluate( @@ -81,37 +83,40 @@ class FunctionExpression : public AbstractExpression { std::string func_name_; + type::Value (*func_ptr_)(const std::vector&) = nullptr; + + std::vector func_arg_types_; + virtual void Accept(SqlNodeVisitor* v) override { v->Visit(this); } protected: FunctionExpression(const FunctionExpression& other) : AbstractExpression(other), func_name_(other.func_name_), - func_ptr_(other.func_ptr_) {} + func_ptr_(other.func_ptr_), + func_arg_types_(other.func_arg_types_) {} private: - type::Value (*func_ptr_)(const std::vector&) = nullptr; // throws an exception if children return unexpected types - static void CheckChildrenTypes( - const std::vector& arg_types, + void CheckChildrenTypes( const std::vector>& children, const std::string& func_name) { - if (arg_types.size() != children.size()) { + if (func_arg_types_.size() != children.size()) { throw Exception(EXCEPTION_TYPE_EXPRESSION, "Unexpected number of arguments to function: " + func_name + ". Expected: " + - std::to_string(arg_types.size()) + " Actual: " + + std::to_string(func_arg_types_.size()) + " Actual: " + std::to_string(children.size())); } // check that the types are correct - for (size_t i = 0; i < arg_types.size(); i++) { - if (children[i]->GetValueType() != arg_types[i]) { + for (size_t i = 0; i < func_arg_types_.size(); i++) { + if (children[i]->GetValueType() != func_arg_types_[i]) { throw Exception(EXCEPTION_TYPE_EXPRESSION, "Incorrect argument type to fucntion: " + func_name + ". Argument " + std::to_string(i) + " expected type " + - type::Type::GetInstance(arg_types[i])->ToString() + + type::Type::GetInstance(func_arg_types_[i])->ToString() + " but found " + type::Type::GetInstance(children[i]->GetValueType()) ->ToString() +