Permalink
Browse files

rewrite the llvm function expression

  • Loading branch information...
lixupeng authored and apavlo committed Sep 7, 2017
1 parent 2fe07ec commit 51919db1bd91699d7b5f4e1c2f0891492bc9b9df
View
@@ -28,8 +28,6 @@
#include "catalog/trigger_catalog.h"
#include "catalog/proc_catalog.h"
#include "catalog/language_catalog.h"
#include "concurrency/transaction_manager_factory.h"
#include "function/functions.h"
#include "index/index_factory.h"
#include "storage/storage_manager.h"
#include "storage/table_factory.h"
@@ -812,13 +810,15 @@ void Catalog::AddFunction(const std::string &name,
oid_t prolang,
const std::string &func_name,
function::BuiltInFuncType func_ptr,
codegen::function::BuiltInFuncType codegen_func_ptr_,
concurrency::Transaction *txn) {
if (!ProcCatalog::GetInstance()->
InsertProc(name, return_type, argument_types,
prolang, func_name, pool_.get(), txn)) {
throw CatalogException("Failed to add function " + func_name);
}
function::BuiltInFunctions::AddFunction(func_name, func_ptr);
codegen::function::BuiltInFunctions::AddFunction(func_name, codegen_func_ptr_);
}
const FunctionData Catalog::GetFunction(
@@ -836,7 +836,8 @@ const FunctionData Catalog::GetFunction(
result.func_name_ = ProcCatalog::GetInstance()->GetProSrc(name, argument_types, txn);
result.return_type_ = ProcCatalog::GetInstance()->GetProRetType(name, argument_types, txn);
result.func_ptr_ = function::BuiltInFunctions::GetFuncByName(result.func_name_);
if (result.func_ptr_ != nullptr) {
result.codegen_func_ptr_ = codegen::function::BuiltInFunctions::GetFuncByName(result.func_name_);
if (result.func_ptr_ != nullptr || result.codegen_func_ptr_ != nullptr) {
txn_manager.CommitTransaction(txn);
return result;
}
@@ -877,47 +878,64 @@ void Catalog::InitializeFunctions() {
* string functions
*/
AddFunction("ascii", {type::TypeId::VARCHAR}, type::TypeId::INTEGER, prolang,
"Ascii", function::StringFunctions::Ascii, txn);
"Ascii", function::StringFunctions::Ascii,
codegen::function::StringFunctions::Ascii, txn);
AddFunction("chr", {type::TypeId::INTEGER}, type::TypeId::VARCHAR, prolang,
"Chr", function::StringFunctions::Chr, txn);
"Chr", function::StringFunctions::Chr,
codegen::function::StringFunctions::Chr, txn);
AddFunction("concat", {type::TypeId::VARCHAR, type::TypeId::VARCHAR},
type::TypeId::VARCHAR, prolang,
"Concat", function::StringFunctions::Concat, txn);
"Concat", function::StringFunctions::Concat,
codegen::function::StringFunctions::Concat, txn);
AddFunction("substr", {type::TypeId::VARCHAR, type::TypeId::INTEGER,
type::TypeId::INTEGER},
type::TypeId::VARCHAR, prolang,
"Substr", function::StringFunctions::Substr, txn);
"Substr", function::StringFunctions::Substr,
codegen::function::StringFunctions::Substr, txn);
AddFunction("char_length", {type::TypeId::VARCHAR}, type::TypeId::INTEGER,
prolang,
"CharLength", function::StringFunctions::CharLength, txn);
"CharLength", function::StringFunctions::CharLength,
codegen::function::StringFunctions::CharLength, txn);
AddFunction("octet_length", {type::TypeId::VARCHAR}, type::TypeId::INTEGER,
prolang,
"OctetLength", function::StringFunctions::OctetLength, txn);
"OctetLength", function::StringFunctions::OctetLength,
codegen::function::StringFunctions::OctetLength, txn);
AddFunction("repeat", {type::TypeId::VARCHAR, type::TypeId::INTEGER},
type::TypeId::VARCHAR, prolang,
"Repeat", function::StringFunctions::Repeat, txn);
"Repeat", function::StringFunctions::Repeat,
codegen::function::StringFunctions::Repeat, txn);
AddFunction("replace", {type::TypeId::VARCHAR, type::TypeId::VARCHAR,
type::TypeId::VARCHAR},
type::TypeId::VARCHAR, prolang,
"Replace", function::StringFunctions::Replace, txn);
"Replace", function::StringFunctions::Replace,
codegen::function::StringFunctions::Replace, txn);
AddFunction("ltrim", {type::TypeId::VARCHAR, type::TypeId::VARCHAR},
type::TypeId::VARCHAR, prolang,
"LTrim", function::StringFunctions::LTrim, txn);
"LTrim", function::StringFunctions::LTrim,
codegen::function::StringFunctions::LTrim, txn);
AddFunction("rtrim", {type::TypeId::VARCHAR, type::TypeId::VARCHAR},
type::TypeId::VARCHAR, prolang,
"RTrim", function::StringFunctions::RTrim, txn);
"RTrim", function::StringFunctions::RTrim,
codegen::function::StringFunctions::RTrim, txn);
AddFunction("btrim", {type::TypeId::VARCHAR, type::TypeId::VARCHAR},
type::TypeId::VARCHAR, prolang,
"btrim", function::StringFunctions::BTrim, txn);
"btrim", function::StringFunctions::BTrim,
codegen::function::StringFunctions::BTrim, txn);
/**
* decimal functions
*/
AddFunction("sqrt", {type::TypeId::DECIMAL}, type::TypeId::DECIMAL,
prolang, "Sqrt", function::DecimalFunctions::Sqrt, txn);
prolang, "Sqrt", function::DecimalFunctions::Sqrt,
codegen::function::DecimalFunctions::Sqrt, txn);
/**
* date functions
*/
AddFunction("extract", {type::TypeId::INTEGER, type::TypeId::TIMESTAMP},
type::TypeId::DECIMAL, prolang,
"Extract", function::DateFunctions::Extract, txn);
"Extract", function::DateFunctions::Extract,
codegen::function::DateFunctions::Extract, txn);
}
catch (CatalogException e) {
txn_manager.AbortTransaction(txn);
@@ -14,89 +14,30 @@
#include "codegen/expression/function_translator.h"
#include "codegen/type/sql_type.h"
#include "expression/function_expression.h"
#include "codegen/proxy/function_wrapper_proxy.h"
namespace peloton {
namespace codegen {
// Constructor
FunctionTranslator::FunctionTranslator(
const expression::FunctionExpression &exp, CompilationContext &ctx)
: ExpressionTranslator(exp, ctx) {}
: ExpressionTranslator(exp, ctx), context_(ctx) {}
codegen::Value FunctionTranslator::DeriveValue(
CodeGen &codegen, RowBatch::Row &row) const {
const auto &func_expr = GetExpressionAs<expression::FunctionExpression>();
// get the number of arguments
size_t child_num = func_expr.GetChildrenSize();
std::vector<llvm::Value*> args;
// store function pointer
args.push_back(codegen.Const64((int64_t)func_expr.func_ptr_));
// store argument number
args.push_back(codegen.Const32((int32_t) child_num));
std::vector<codegen::Value> args;
// store arguments
for (size_t i = 0; i < child_num; ++i) {
args.push_back(codegen.Const32(
static_cast<int32_t>(func_expr.func_arg_types_[i])));
args.push_back(row.DeriveValue(codegen, *func_expr.GetChild(i))
.GetValue());
args.push_back(row.DeriveValue(codegen, *func_expr.GetChild(i)));
}
return CallWrapperFunction(func_expr.GetValueType(), args, codegen);
}
codegen::Value FunctionTranslator::CallWrapperFunction(
peloton::type::TypeId ret_type,
std::vector<llvm::Value*> &args,
CodeGen &codegen) const {
llvm::Function *wrapper = nullptr;
std::vector<llvm::Type *> arg_types{codegen.Int32Type()};
// get the wrapper function with certain return type
switch (ret_type) {
case peloton::type::TypeId::TINYINT:
wrapper = FunctionWrapperProxy::TinyIntWrapper.GetFunction(codegen);
break;
case peloton::type::TypeId::SMALLINT:
wrapper = FunctionWrapperProxy::SmallIntWrapper.GetFunction(codegen);
break;
case peloton::type::TypeId::INTEGER:
wrapper = FunctionWrapperProxy::IntegerWrapper.GetFunction(codegen);
break;
case peloton::type::TypeId::BIGINT:
wrapper = FunctionWrapperProxy::BigIntWrapper.GetFunction(codegen);
break;
case peloton::type::TypeId::DECIMAL:
wrapper = FunctionWrapperProxy::DecimalWrapper.GetFunction(codegen);
break;
case peloton::type::TypeId::DATE:
wrapper = FunctionWrapperProxy::DateWrapper.GetFunction(codegen);
break;
case peloton::type::TypeId::TIMESTAMP:
wrapper = FunctionWrapperProxy::TimestampWrapper.GetFunction(codegen);
break;
case peloton::type::TypeId::VARCHAR:
wrapper = FunctionWrapperProxy::VarcharWrapper.GetFunction(codegen);
break;
default:
break;
}
if (wrapper != nullptr) {
// call the function
llvm::Value *ret_val = codegen.CallFunc(wrapper, args);
// call strlen to get the length of a varchar
llvm::Value *ret_len = nullptr;
if (ret_type == peloton::type::TypeId::VARCHAR) {
ret_len = codegen.CallStrlen(ret_val);
}
return codegen::Value(type::SqlType::LookupType(ret_type), ret_val, ret_len);
}
else {
return codegen::Value(type::SqlType::LookupType(ret_type));
}
return func_expr.codegen_func_ptr_(codegen, context_, args);
}
} // namespace codegen
@@ -0,0 +1,136 @@
#include "codegen/function/functions.h"
#include "codegen/proxy/builtin_function_proxy.h"
#include "codegen/type/sql_type.h"
#include <date/date.h>
#include <date/iso_week.h>
namespace peloton {
namespace codegen {
namespace function {
codegen::Value DateFunctions::Extract(CodeGen &codegen, CompilationContext &ctx,
const std::vector<codegen::Value> &args) {
llvm::Function *func = DateFunctionsProxy::Extract_.GetFunction(codegen);
llvm::Value *ret_val = BuiltInFunctions::CallFunction(codegen, ctx, func, args);
return codegen::Value(type::SqlType::LookupType(peloton::type::TypeId::DECIMAL),
ret_val, nullptr);
}
double DateFunctions::Extract_(
UNUSED_ATTRIBUTE executor::ExecutorContext *executor_context,
int32_t date_part, uint64_t timestamp) {
if (timestamp == 0) {
return 0.0;
}
uint32_t micro = timestamp % 1000000;
timestamp /= 1000000;
uint32_t hour_min_sec = timestamp % 100000;
uint16_t sec = hour_min_sec % 60;
hour_min_sec /= 60;
uint16_t min = hour_min_sec % 60;
hour_min_sec /= 60;
uint16_t hour = hour_min_sec % 24;
timestamp /= 100000;
uint16_t year = timestamp % 10000;
timestamp /= 10000;
timestamp /= 27; // skip time zone
uint16_t day = timestamp % 32;
timestamp /= 32;
uint16_t month = timestamp;
uint16_t millennium = (year - 1) / 1000 + 1;
uint16_t century = (year - 1) / 100 + 1;
uint16_t decade = year / 10;
uint8_t quarter = (month - 1) / 3 + 1;
double microsecond = sec * 1000000 + micro;
double millisecond = sec * 1000 + micro / 1000.0;
double second = sec + micro / 1000000.0;
date::year_month_day ymd = date::year_month_day{
date::year{year}, date::month{month}, date::day{day}};
iso_week::year_weeknum_weekday yww = iso_week::year_weeknum_weekday{ymd};
date::year_month_day year_begin =
date::year_month_day{date::year{year}, date::month{1}, date::day{1}};
date::days duration = date::sys_days{ymd} - date::sys_days{year_begin};
uint16_t dow = ((unsigned) yww.weekday()) == 7 ? 0 : (unsigned) yww.weekday();
uint16_t doy = duration.count() + 1;
uint16_t week = (unsigned) yww.weeknum();
double result;
switch (static_cast<DatePartType>(date_part)) {
case DatePartType::CENTURY: {
result = century;
break;
}
case DatePartType::DAY: {
result = day;
break;
}
case DatePartType::DECADE: {
result = decade;
break;
}
case DatePartType::DOW: {
result = dow;
break;
}
case DatePartType::DOY: {
result = doy;
break;
}
case DatePartType::HOUR: {
result = hour;
break;
}
case DatePartType::MICROSECOND: {
result = microsecond;
break;
}
case DatePartType::MILLENNIUM: {
result = millennium;
break;
}
case DatePartType::MILLISECOND: {
result = millisecond;
break;
}
case DatePartType::MINUTE: {
result = min;
break;
}
case DatePartType::MONTH: {
result = month;
break;
}
case DatePartType::QUARTER: {
result = quarter;
break;
}
case DatePartType::SECOND: {
result = second;
break;
}
case DatePartType::WEEK: {
result = week;
break;
}
case DatePartType::YEAR: {
result = year;
break;
}
default: {
result = 0.0;
}
};
return result;
}
} // namespace function
} // namespace expression
} // namespace peloton
@@ -0,0 +1,29 @@
#include <cmath>
#include "codegen/type/sql_type.h"
#include "codegen/function/decimal_functions.h"
#include "codegen/proxy/builtin_function_proxy.h"
namespace peloton {
namespace codegen {
namespace function {
codegen::Value DecimalFunctions::Sqrt(CodeGen &codegen, CompilationContext &ctx,
const std::vector<codegen::Value> &args) {
llvm::Function *func = DecimalFunctionsProxy::Sqrt_.GetFunction(codegen);
llvm::Value *ret_val = BuiltInFunctions::CallFunction(codegen, ctx, func, args);
return codegen::Value(type::SqlType::LookupType(peloton::type::TypeId::DECIMAL),
ret_val, nullptr);
}
double DecimalFunctions::Sqrt_(
UNUSED_ATTRIBUTE executor::ExecutorContext *executor_context,
double val) {
if (val < 0) {
return 0.0;
}
return sqrt(val);
}
} // namespace function
} // namespace expression
} // namespace peloton
@@ -0,0 +1,24 @@
#include "codegen/function/functions.h"
namespace peloton {
namespace codegen {
namespace function {
std::unordered_map<std::string, BuiltInFuncType>
BuiltInFunctions::func_map;
void BuiltInFunctions::AddFunction(const std::string &func_name,
BuiltInFuncType func) {
func_map.emplace(func_name, func);
}
BuiltInFuncType BuiltInFunctions::GetFuncByName(const std::string &func_name) {
auto func = func_map.find(func_name);
if (func == func_map.end())
return nullptr;
return func->second;
}
} // namespace function
} // namespace codegen
} // namespace peloton
Oops, something went wrong.

0 comments on commit 51919db

Please sign in to comment.