Skip to content

Commit

Permalink
Fix min/max(x, n) (facebookincubator#8311)
Browse files Browse the repository at this point in the history
Summary:

In Presto, accumulators of min/max(x, n) do not clear the heap when values are extracted from 
accumulator. But in Velox they do. Fix this bug to make Velox behavior align with Presto.

This diff fixes facebookincubator#8138.

Differential Revision: D52638334
  • Loading branch information
kagamiori authored and facebook-github-bot committed Jan 9, 2024
1 parent 4fd8186 commit 087fd23
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 16 deletions.
34 changes: 18 additions & 16 deletions velox/functions/prestosql/aggregates/MinMaxAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -545,17 +545,17 @@ std::pair<vector_size_t*, vector_size_t*> rawOffsetAndSizes(
template <typename T, typename Compare>
struct MinMaxNAccumulator {
int64_t n{0};
std::priority_queue<T, std::vector<T, StlAllocator<T>>, Compare> topValues;
std::vector<T, StlAllocator<T>> heapValues;

explicit MinMaxNAccumulator(HashStringAllocator* allocator)
: topValues{Compare{}, StlAllocator<T>(allocator)} {}
: heapValues{StlAllocator<T>(allocator)} {}

int64_t getN() const {
return n;
}

size_t size() const {
return topValues.size();
return heapValues.size();
}

void checkAndSetN(DecodedVector& decodedN, vector_size_t row) {
Expand Down Expand Up @@ -584,25 +584,27 @@ struct MinMaxNAccumulator {
}

void compareAndAdd(T value, Compare& comparator) {
if (topValues.size() < n) {
topValues.push(value);
if (heapValues.size() < n) {
heapValues.push_back(value);
std::push_heap(heapValues.begin(), heapValues.end(), comparator);
} else {
const auto& topValue = topValues.top();
const auto& topValue = heapValues.front();
if (comparator(value, topValue)) {
topValues.pop();
topValues.push(value);
std::pop_heap(heapValues.begin(), heapValues.end(), comparator);
heapValues.back() = value;
std::push_heap(heapValues.begin(), heapValues.end(), comparator);
}
}
}

/// Moves all values from 'topValues' into 'rawValues' buffer. The queue of
/// 'topValues' will be empty after this call.
void extractValues(T* rawValues, vector_size_t offset) {
const vector_size_t size = topValues.size();
for (auto i = size - 1; i >= 0; --i) {
rawValues[offset + i] = topValues.top();
topValues.pop();
/// Copy all values from 'topValues' into 'rawValues' buffer. The heap remains
/// unchanged after the call.
void extractValues(T* rawValues, vector_size_t offset, Compare& comparator) {
std::sort_heap(heapValues.begin(), heapValues.end(), comparator);
for (int64_t i = heapValues.size() - 1; i >= 0; --i) {
rawValues[offset + i] = heapValues[i];
}
std::make_heap(heapValues.begin(), heapValues.end(), comparator);
}
};

Expand Down Expand Up @@ -775,7 +777,7 @@ class MinMaxNAggregateBase : public exec::Aggregate {
if (rawNs != nullptr) {
rawNs[i] = accumulator->n;
}
accumulator->extractValues(rawValues, offset);
accumulator->extractValues(rawValues, offset, comparator_);

offset += size;
}
Expand Down
35 changes: 35 additions & 0 deletions velox/functions/prestosql/aggregates/tests/MinMaxTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -766,4 +766,39 @@ TEST_F(MinMaxNTest, double) {
testNumericGroupBy<double>();
}

TEST_F(MinMaxNTest, incrementalWindow) {
// SELECT
// c0, c1, c2, c3,
// max(c0, c1) over (partition by c2 order by c3 asc)
// FROM (
// VALUES
// (1, 10, false, 0),
// (2, 10, false, 1)
// ) AS t(c0, c1, c2, c3)
auto data = makeRowVector({
makeFlatVector<int64_t>({1, 2}),
makeFlatVector<int64_t>({10, 10}),
makeFlatVector<bool>({false, false}),
makeFlatVector<int64_t>({0, 1}),
});

auto plan =
PlanBuilder()
.values({data})
.window({"max(c0, c1) over (partition by c2 order by c3 asc)"})
.planNode();

auto result = AssertQueryBuilder(plan).copyResults(pool());

// Expected result: {1, 10, false, 0, [1]}, {2, 10, false, 1, [2, 1]}.
auto expected = makeRowVector({
makeFlatVector<int64_t>({1, 2}),
makeFlatVector<int64_t>({10, 10}),
makeFlatVector<bool>({false, false}),
makeFlatVector<int64_t>({0, 1}),
makeArrayVector<int64_t>({{1}, {2, 1}}),
});
facebook::velox::test::assertEqualVectors(expected, result);
}

} // namespace

0 comments on commit 087fd23

Please sign in to comment.