Permalink
Browse files

Merge pull request #918 from ksaito7/round_func

Add round function
  • Loading branch information...
hzxa21 committed Dec 7, 2017
2 parents 0afeb62 + 4a1cbf4 commit 6de9e1533cfa32be651733afdff502e7e36c4d11
View
@@ -1016,6 +1016,12 @@ void Catalog::InitializeFunctions() {
function::BuiltInFuncType{OperatorId::Floor,
function::DecimalFunctions::_Floor},
txn);
AddBuiltinFunction(
"round", {type::TypeId::DECIMAL}, type::TypeId::DECIMAL,
internal_lang, "Round",
function::BuiltInFuncType{OperatorId::Round,
function::DecimalFunctions::_Round},
txn);
/**
* date functions
@@ -18,7 +18,9 @@
namespace peloton {
namespace codegen {
DEFINE_METHOD(peloton::function, DecimalFunctions, Floor);
DEFINE_METHOD(peloton::function, DecimalFunctions, Floor);
DEFINE_METHOD(peloton::function, DecimalFunctions, Round);
} // namespace codegen
} // namespace peloton
@@ -14,6 +14,7 @@
#include "codegen/lang/if.h"
#include "codegen/value.h"
#include "codegen/proxy/decimal_functions_proxy.h"
#include "codegen/proxy/values_runtime_proxy.h"
#include "codegen/proxy/decimal_functions_proxy.h"
#include "codegen/type/boolean_type.h"
@@ -170,11 +171,10 @@ struct Floor : public TypeSystem::UnaryOperator {
}
Value DoWork(CodeGen &codegen, const Value &val) const override {
llvm::Value *raw_ret = codegen.Call(DecimalFunctionsProxy::Floor,
{val.GetValue()});
llvm::Value *raw_ret =
codegen.Call(DecimalFunctionsProxy::Floor, {val.GetValue()});
return Value{Integer::Instance(), raw_ret};
}
};
// Addition
@@ -354,6 +354,23 @@ struct Modulo : public TypeSystem::BinaryOperator {
}
};
// Round
struct Round : public TypeSystem::UnaryOperator {
bool SupportsType(const Type &type) const override {
return type.GetSqlType() == Decimal::Instance();
}
Type ResultType(UNUSED_ATTRIBUTE const Type &val_type) const override {
return Decimal::Instance();
}
Value DoWork(CodeGen &codegen, const Value &val) const override {
llvm::Value *raw_ret =
codegen.Call(DecimalFunctionsProxy::Round, {val.GetValue()});
return Value{Decimal::Instance(), raw_ret};
}
};
//===----------------------------------------------------------------------===//
// TYPE SYSTEM CONSTRUCTION
//===----------------------------------------------------------------------===//
@@ -386,8 +403,11 @@ static std::vector<TypeSystem::ComparisonInfo> kComparisonTable = {
// Unary operators
static Negate kNegOp;
static Floor kFloorOp;
static Round kRound;
static std::vector<TypeSystem::UnaryOpInfo> kUnaryOperatorTable = {
{OperatorId::Negation, kNegOp}, {OperatorId::Floor, kFloorOp}};
{OperatorId::Negation, kNegOp},
{OperatorId::Floor, kFloorOp},
{OperatorId::Round, kRound}};
// Binary operations
static Add kAddOp;
@@ -415,8 +435,8 @@ static std::vector<TypeSystem::NaryOpInfo> kNaryOperatorTable = {};
Decimal::Decimal()
: SqlType(peloton::type::TypeId::DECIMAL),
type_system_(kImplicitCastingTable, kExplicitCastingTable,
kComparisonTable, kUnaryOperatorTable,
kBinaryOperatorTable, kNaryOperatorTable) {}
kComparisonTable, kUnaryOperatorTable, kBinaryOperatorTable,
kNaryOperatorTable) {}
Value Decimal::GetMinValue(CodeGen &codegen) const {
auto *raw_val = codegen.ConstDouble(peloton::type::PELOTON_DECIMAL_MIN);
@@ -447,4 +467,4 @@ llvm::Function *Decimal::GetOutputFunction(
} // namespace type
} // namespace codegen
} // namespace peloton
} // namespace peloton
@@ -32,7 +32,7 @@ type::Value DecimalFunctions::_Floor(const std::vector<type::Value> &args) {
return type::ValueFactory::GetNullValueByType(type::TypeId::DECIMAL);
}
double res;
switch(args[0].GetElementType()) {
switch (args[0].GetElementType()) {
case type::TypeId::DECIMAL:
res = Floor(args[0].GetAs<double>());
break;
@@ -54,8 +54,16 @@ type::Value DecimalFunctions::_Floor(const std::vector<type::Value> &args) {
return type::ValueFactory::GetDecimalValue(res);
}
double DecimalFunctions::Floor(const double val) {
return floor(val);
double DecimalFunctions::Floor(const double val) { return floor(val); }
// Round to nearest integer
double DecimalFunctions::Round(double arg) { return round(arg); }
type::Value DecimalFunctions::_Round(const std::vector<type::Value> &args) {
PL_ASSERT(args.size() == 1);
if (args[0].IsNull()) {
return type::ValueFactory::GetNullValueByType(type::TypeId::DECIMAL);
}
return type::ValueFactory::GetDecimalValue(Round(args[0].GetAs<double>()));
}
} // namespace function
@@ -19,7 +19,9 @@ namespace codegen {
PROXY(DecimalFunctions) {
// Proxy everything in function::DecimalFunctions
DECLARE_METHOD(Floor);
DECLARE_METHOD(Round);
};
} // namespace codegen
@@ -21,9 +21,16 @@ namespace function {
class DecimalFunctions {
public:
// Sqrt
static type::Value Sqrt(const std::vector<type::Value>& args);
static type::Value _Floor(const std::vector<type::Value>& args);
// Floor
static double Floor(const double val);
static type::Value _Floor(const std::vector<type::Value>& args);
// Round
static double Round(double arg);
static type::Value _Round(const std::vector<type::Value>& args);
};
} // namespace function
View
@@ -1040,6 +1040,7 @@ enum class OperatorId : uint32_t {
RTrim,
BTrim,
Sqrt,
Round,
Extract,
Floor,
DateTrunc,
@@ -51,7 +51,7 @@ TEST_F(DecimalFunctionsTests, FloorTest) {
// Testing Floor with DecimalTypes
std::vector<double> inputs = {9.5, 3.3, -4.4, 0.0};
std::vector<type::Value> args;
for(double in: inputs) {
for (double in : inputs) {
args = {type::ValueFactory::GetDecimalValue(in)};
auto result = function::DecimalFunctions::_Floor(args);
EXPECT_FALSE(result.IsNull());
@@ -89,5 +89,21 @@ TEST_F(DecimalFunctionsTests, FloorTest) {
EXPECT_TRUE(result.IsNull());
}
TEST_F(DecimalFunctionsTests, RoundTest) {
std::vector<double> column_vals = {9.5, 3.3, -4.4, -5.5, 0.0};
std::vector<type::Value> args;
for (double val : column_vals) {
args = {type::ValueFactory::GetDecimalValue(val)};
auto result = function::DecimalFunctions::_Round(args);
EXPECT_FALSE(result.IsNull());
EXPECT_EQ(round(val), result.GetAs<double>());
}
// NULL CHECK
args = {type::ValueFactory::GetNullValueByType(type::TypeId::DECIMAL)};
auto result = function::DecimalFunctions::_Round(args);
EXPECT_TRUE(result.IsNull());
}
} // namespace test
} // namespace peloton

0 comments on commit 6de9e15

Please sign in to comment.