Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix overflow in sparkbar function #48121

Merged
merged 5 commits into from
Apr 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
65 changes: 53 additions & 12 deletions src/AggregateFunctions/AggregateFunctionSparkbar.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <base/arithmeticOverflow.h>

#include <array>
#include <string_view>
#include <DataTypes/DataTypeString.h>
Expand Down Expand Up @@ -43,7 +45,19 @@ struct AggregateFunctionSparkbarData

auto [it, inserted] = points.insert({x, y});
if (!inserted)
it->getMapped() += y;
{
if constexpr (std::is_floating_point_v<Y>)
{
it->getMapped() += y;
return it->getMapped();
}
else
{
Y res;
bool has_overfllow = common::addOverflow(it->getMapped(), y, res);
it->getMapped() = has_overfllow ? std::numeric_limits<Y>::max() : res;
}
}
return it->getMapped();
}

Expand Down Expand Up @@ -117,6 +131,7 @@ class AggregateFunctionSparkbar final
{

private:
static constexpr size_t BAR_LEVELS = 8;
const size_t width = 0;

/// Range for x specified in parameters.
Expand All @@ -126,8 +141,8 @@ class AggregateFunctionSparkbar final

size_t updateFrame(ColumnString::Chars & frame, Y value) const
{
static constexpr std::array<std::string_view, 9> bars{" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"};
const auto & bar = (isNaN(value) || value < 1 || 8 < value) ? bars[0] : bars[static_cast<UInt8>(value)];
static constexpr std::array<std::string_view, BAR_LEVELS + 1> bars{" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"};
const auto & bar = (isNaN(value) || value < 1 || static_cast<Y>(BAR_LEVELS) < value) ? bars[0] : bars[static_cast<UInt8>(value)];
frame.insert(bar.begin(), bar.end());
return bar.size();
}
Expand Down Expand Up @@ -161,7 +176,7 @@ class AggregateFunctionSparkbar final
}

PaddedPODArray<Y> histogram(width, 0);
PaddedPODArray<UInt64> fhistogram(width, 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

PaddedPODArray<UInt64> count_histogram(width, 0); /// The number of points in each bucket

for (const auto & point : data.points)
{
Expand All @@ -176,22 +191,30 @@ class AggregateFunctionSparkbar final
Float64 w = histogram.size();
size_t index = std::min<size_t>(static_cast<size_t>(w / delta * value), histogram.size() - 1);

if (std::numeric_limits<Y>::max() - histogram[index] > point.getMapped())
{
histogram[index] += point.getMapped();
fhistogram[index] += 1;
}
Y res;
bool has_overfllow = false;
if constexpr (std::is_floating_point_v<Y>)
res = histogram[index] + point.getMapped();
else
has_overfllow = common::addOverflow(histogram[index], point.getMapped(), res);

if (unlikely(has_overfllow))
{
/// In case of overflow, just saturate
/// Do not count new values, because we do not know how many of them were added
histogram[index] = std::numeric_limits<Y>::max();
}
else
{
histogram[index] = res;
count_histogram[index] += 1;
}
}

for (size_t i = 0; i < histogram.size(); ++i)
{
if (fhistogram[i] > 0)
histogram[i] /= fhistogram[i];
if (count_histogram[i] > 0)
histogram[i] /= count_histogram[i];
}

Y y_max = 0;
Expand All @@ -209,12 +232,30 @@ class AggregateFunctionSparkbar final
return;
}

/// Scale the histogram to the range [0, BAR_LEVELS]
for (auto & y : histogram)
{
if (isNaN(y) || y <= 0)
{
y = 0;
continue;
}

constexpr auto levels_num = static_cast<Y>(BAR_LEVELS - 1);
if constexpr (std::is_floating_point_v<Y>)
{
y = y / (y_max / levels_num) + 1;
}
else
y = y * 7 / y_max + 1;
{
Y scaled;
bool has_overfllow = common::mulOverflow<Y>(y, levels_num, scaled);

if (has_overfllow)
y = y / (y_max / levels_num) + 1;
else
y = scaled / y_max + 1;
}
}

size_t sz = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ SELECT sparkbar(5,toDate('2020-01-01'),toDate('2020-01-10'))(event_date,cnt) FRO
▃▄▆█
SELECT sparkbar(9,toDate('2020-01-01'),toDate('2020-01-10'))(event_date,cnt) FROM spark_bar_test;
▂▅▂▃▇▆█
WITH number DIV 50 AS k, number % 50 AS value SELECT k, sparkbar(50, 0, 99)(number, value) FROM numbers(100) GROUP BY k ORDER BY k;
WITH number DIV 50 AS k, toUInt32(number % 50) AS value SELECT k, sparkbar(50, 0, 99)(number, value) FROM numbers(100) GROUP BY k ORDER BY k;
0 ▁▁▁▁▂▂▂▃▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇█
1 ▁▁▁▁▂▂▂▃▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇█
SELECT sparkbar(128, 0, 9223372036854775806)(toUInt64(9223372036854775806), number % 65535) FROM numbers(100);
SELECT sparkbar(128)(toUInt64(9223372036854775806), number % 65535) FROM numbers(100);
SELECT sparkbar(9)(x, y) FROM (SELECT * FROM Values('x UInt64, y UInt8', (18446744073709551615,255), (0,0), (0,0), (4036797895307271799,254)));
SELECT sparkbar(8, 0, 7)((number + 1) % 8, 1), sparkbar(8, 0, 7)((number + 2) % 8, 1), sparkbar(8, 0, 7)((number + 3) % 8, 1) FROM numbers(7);
███████ █ ██████ ██ █████
SELECT sparkbar(2)(number, -number) FROM numbers(10);
Expand Down
8 changes: 7 additions & 1 deletion tests/queries/0_stateless/02016_aggregation_spark_bar.sql
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ SELECT sparkbar(4,toDate('2020-01-01'),toDate('2020-01-08'))(event_date,cnt) FRO
SELECT sparkbar(5,toDate('2020-01-01'),toDate('2020-01-10'))(event_date,cnt) FROM spark_bar_test;
SELECT sparkbar(9,toDate('2020-01-01'),toDate('2020-01-10'))(event_date,cnt) FROM spark_bar_test;

WITH number DIV 50 AS k, number % 50 AS value SELECT k, sparkbar(50, 0, 99)(number, value) FROM numbers(100) GROUP BY k ORDER BY k;
WITH number DIV 50 AS k, toUInt32(number % 50) AS value SELECT k, sparkbar(50, 0, 99)(number, value) FROM numbers(100) GROUP BY k ORDER BY k;

SELECT sparkbar(128, 0, 9223372036854775806)(toUInt64(9223372036854775806), number % 65535) FROM numbers(100);
SELECT sparkbar(128)(toUInt64(9223372036854775806), number % 65535) FROM numbers(100);
Expand All @@ -59,4 +59,10 @@ SELECT sparkbar(2)(toInt32(number), number) FROM numbers(10); -- { serverError
SELECT sparkbar(2, 0)(number, number) FROM numbers(10); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }
SELECT sparkbar(2, 0, 5, 8)(number, number) FROM numbers(10); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATCH }

-- it causes overflow, just check that it doesn't crash under UBSan, do not check the result it's not really reasonable
SELECT sparkbar(10)(number, toInt64(number)) FROM numbers(toUInt64(9223372036854775807), 20) FORMAT Null;
SELECT sparkbar(10)(number, -number) FROM numbers(toUInt64(9223372036854775807), 7) FORMAT Null;
SELECT sparkbar(10)(number, number) FROM numbers(18446744073709551615, 7) FORMAT Null;
SELECT sparkbar(16)(number, number) FROM numbers(18446744073709551600, 16) FORMAT Null;

DROP TABLE IF EXISTS spark_bar_test;