Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add function bankerRound #8112

Merged
merged 7 commits into from
Dec 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions dbms/src/Functions/FunctionsRound.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ namespace DB
void registerFunctionsRound(FunctionFactory & factory)
{
factory.registerFunction<FunctionRound>("round", FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionRoundBankers>("roundBankers", FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionFloor>("floor", FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionCeil>("ceil", FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionTrunc>("trunc", FunctionFactory::CaseInsensitive);
Expand Down
56 changes: 40 additions & 16 deletions dbms/src/Functions/FunctionsRound.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ namespace ErrorCodes

/** Rounding Functions:
* round(x, N) - rounding to nearest (N = 0 by default). Use banker's rounding for floating point numbers.
* roundBankers(x, N) - rounding to nearest (N = 0 by default). Use banker's rounding for all numbers.
* floor(x, N) is the largest number <= x (N = 0 by default).
* ceil(x, N) is the smallest number >= x (N = 0 by default).
* trunc(x, N) - is the largest by absolute value number that is not greater than x by absolute value (N = 0 by default).
Expand Down Expand Up @@ -76,10 +77,16 @@ enum class RoundingMode
#endif
};

enum class TieBreakingMode
{
Auto, // use banker's rounding for floating point numbers, round up otherwise
Bankers, // use banker's rounding
};


/** Rounding functions for integer values.
*/
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, TieBreakingMode tie_breaking_mode>
struct IntegerRoundingComputation
{
static const size_t data_count = 1;
Expand Down Expand Up @@ -114,7 +121,21 @@ struct IntegerRoundingComputation
bool negative = x < 0;
if (negative)
x = -x;
x = (x + scale / 2) / scale * scale;
switch (tie_breaking_mode)
{
case TieBreakingMode::Auto:
x = (x + scale / 2) / scale * scale;
break;
case TieBreakingMode::Bankers:
{
T quotient = (x + scale / 2) / scale;
if (quotient * scale == x + scale / 2)
x = (quotient & ~1) * scale;
else
x = quotient * scale;
break;
}
}
if (negative)
x = -x;
return x;
Expand Down Expand Up @@ -323,11 +344,11 @@ struct FloatRoundingImpl
}
};

template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, TieBreakingMode tie_breaking_mode>
struct IntegerRoundingImpl
{
private:
using Op = IntegerRoundingComputation<T, rounding_mode, scale_mode>;
using Op = IntegerRoundingComputation<T, rounding_mode, scale_mode, tie_breaking_mode>;

public:
template <size_t scale>
Expand Down Expand Up @@ -379,11 +400,12 @@ struct IntegerRoundingImpl
};


template <typename T, RoundingMode rounding_mode>
class DecimalRounding
template <typename T, RoundingMode rounding_mode, TieBreakingMode tie_breaking_mode>
class DecimalRoundingImpl
{
private:
using NativeType = typename T::NativeType;
using Op = IntegerRoundingComputation<NativeType, rounding_mode, ScaleMode::Negative>;
using Op = IntegerRoundingComputation<NativeType, rounding_mode, ScaleMode::Negative, tie_breaking_mode>;
using Container = typename ColumnDecimal<T>::Container;

public:
Expand Down Expand Up @@ -413,13 +435,13 @@ class DecimalRounding

/** Select the appropriate processing algorithm depending on the scale.
*/
template <typename T, RoundingMode rounding_mode>
template <typename T, RoundingMode rounding_mode, TieBreakingMode tie_breaking_mode>
class Dispatcher
{
template <ScaleMode scale_mode>
using FunctionRoundingImpl = std::conditional_t<std::is_floating_point_v<T>,
FloatRoundingImpl<T, rounding_mode, scale_mode>,
IntegerRoundingImpl<T, rounding_mode, scale_mode>>;
IntegerRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>>;

static void apply(Block & block, const ColumnVector<T> * col, Int64 scale_arg, size_t result)
{
Expand Down Expand Up @@ -458,7 +480,7 @@ class Dispatcher
auto & vec_res = col_res->getData();

if (!vec_res.empty())
DecimalRounding<T, rounding_mode>::apply(col->getData(), vec_res, scale_arg);
DecimalRoundingImpl<T, rounding_mode, tie_breaking_mode>::apply(col->getData(), vec_res, scale_arg);

block.getByPosition(result).column = std::move(col_res);
}
Expand All @@ -476,7 +498,7 @@ class Dispatcher
/** A template for functions that round the value of an input parameter of type
* (U)Int8/16/32/64, Float32/64 or Decimal32/64/128, and accept an additional optional parameter (default is 0).
*/
template <typename Name, RoundingMode rounding_mode>
template <typename Name, RoundingMode rounding_mode, TieBreakingMode tie_breaking_mode>
class FunctionRounding : public IFunction
{
public:
Expand Down Expand Up @@ -542,7 +564,7 @@ class FunctionRounding : public IFunction
if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>)
{
using FieldType = typename DataType::FieldType;
Dispatcher<FieldType, rounding_mode>::apply(block, column.column.get(), scale_arg, result);
Dispatcher<FieldType, rounding_mode, tie_breaking_mode>::apply(block, column.column.get(), scale_arg, result);
return true;
}
return false;
Expand Down Expand Up @@ -716,13 +738,15 @@ class FunctionRoundDown : public IFunction


struct NameRound { static constexpr auto name = "round"; };
struct NameRoundBankers { static constexpr auto name = "roundBankers"; };
struct NameCeil { static constexpr auto name = "ceil"; };
struct NameFloor { static constexpr auto name = "floor"; };
struct NameTrunc { static constexpr auto name = "trunc"; };

using FunctionRound = FunctionRounding<NameRound, RoundingMode::Round>;
using FunctionFloor = FunctionRounding<NameFloor, RoundingMode::Floor>;
using FunctionCeil = FunctionRounding<NameCeil, RoundingMode::Ceil>;
using FunctionTrunc = FunctionRounding<NameTrunc, RoundingMode::Trunc>;
using FunctionRound = FunctionRounding<NameRound, RoundingMode::Round, TieBreakingMode::Auto>;
using FunctionRoundBankers = FunctionRounding<NameRoundBankers, RoundingMode::Round, TieBreakingMode::Bankers>;
using FunctionFloor = FunctionRounding<NameFloor, RoundingMode::Floor, TieBreakingMode::Auto>;
using FunctionCeil = FunctionRounding<NameCeil, RoundingMode::Ceil, TieBreakingMode::Auto>;
using FunctionTrunc = FunctionRounding<NameTrunc, RoundingMode::Trunc, TieBreakingMode::Auto>;

}
35 changes: 35 additions & 0 deletions dbms/tests/performance/round_methods.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
<test>
<type>once</type>

<stop_conditions>
<all_of>
<total_time_ms>10000</total_time_ms>
</all_of>
<any_of>
<average_speed_not_changing_for_ms>5000</average_speed_not_changing_for_ms>
<total_time_ms>20000</total_time_ms>
</any_of>
</stop_conditions>

<main_metric>
<avg_rows_per_second/>
</main_metric>

<query>SELECT count() FROM system.numbers WHERE NOT ignore(round(toInt64(number), -2))</query>
<query>SELECT count() FROM system.numbers WHERE NOT ignore(roundBankers(toInt64(number), -2))</query>
<query>SELECT count() FROM system.numbers WHERE NOT ignore(floor(toInt64(number), -2))</query>
<query>SELECT count() FROM system.numbers WHERE NOT ignore(ceil(toInt64(number), -2))</query>
<query>SELECT count() FROM system.numbers WHERE NOT ignore(trunc(toInt64(number), -2))</query>

<query>SELECT count() FROM system.numbers WHERE NOT ignore(round(toFloat64(number), -2))</query>
<query>SELECT count() FROM system.numbers WHERE NOT ignore(roundBankers(toFloat64(number), -2))</query>
<query>SELECT count() FROM system.numbers WHERE NOT ignore(floor(toFloat64(number), -2))</query>
<query>SELECT count() FROM system.numbers WHERE NOT ignore(ceil(toFloat64(number), -2))</query>
<query>SELECT count() FROM system.numbers WHERE NOT ignore(trunc(toFloat64(number), -2))</query>

<query>SELECT count() FROM system.numbers WHERE NOT ignore(round(toDecimal128(number, 0), -2))</query>
<query>SELECT count() FROM system.numbers WHERE NOT ignore(roundBankers(toDecimal128(number, 0), -2))</query>
<query>SELECT count() FROM system.numbers WHERE NOT ignore(floor(toDecimal128(number, 0), -2))</query>
<query>SELECT count() FROM system.numbers WHERE NOT ignore(ceil(toDecimal128(number, 0), -2))</query>
<query>SELECT count() FROM system.numbers WHERE NOT ignore(trunc(toDecimal128(number, 0), -2))</query>
</test>