Skip to content

Commit

Permalink
Internal duckdb#330: Quantile EXCLUDE
Browse files Browse the repository at this point in the history
Push subframes into the Merge Sort Tree.
  • Loading branch information
hawkfish committed Oct 21, 2023
1 parent eb396b3 commit 37f2ea0
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 56 deletions.
28 changes: 2 additions & 26 deletions src/core_functions/aggregate/holistic/quantile.cpp
Expand Up @@ -558,32 +558,8 @@ struct QuantileSortTree : public MergeSortTree<IDX, IDX> {
return make_uniq<QuantileSortTree>(std::move(sorted));
}

IDX SelectNth(const SubFrames &frames, size_t n) const {
size_t result = 0;

// To find the element at position n we need to find n+1 elements.
auto needed = n + 1;
for (const auto &frame : frames) {
// Skip empty frames
if (frame.end <= frame.start) {
continue;
}

// The frame can't supply more elements than it contains
// Don't look for more than we need
const auto available = MinValue<size_t>(needed, frame.end - frame.start);
auto frame_n = available - 1;
result = BaseTree::SelectNth(frame, frame_n);

// Reduce the count by the number we found
needed -= (available - frame_n);
if (!needed) {
break;
}
}
D_ASSERT(!needed);

return BaseTree::NthElement(result);
// Select across a set of subframes in a single call.
// BaseTree::SelectNth(frames, n) returns a position (size_t); that position is then
// passed to BaseTree::NthElement to fetch the stored IDX value that is returned.
// NOTE(review): presumably n is the zero-based rank within the combined frames - confirm against callers.
inline IDX SelectNth(const SubFrames &frames, size_t n) const {
return BaseTree::NthElement(BaseTree::SelectNth(frames, n));
}

template <typename INPUT_TYPE, typename RESULT_TYPE, bool DISCRETE>
Expand Down
78 changes: 48 additions & 30 deletions src/include/duckdb/execution/merge_sort_tree.hpp
Expand Up @@ -78,7 +78,7 @@ struct MergeSortTree {
}
explicit MergeSortTree(Elements &&lowest_level, const CMP &cmp = CMP());

size_t SelectNth(const FrameBounds &frame, size_t &requested) const;
size_t SelectNth(const SubFrames &frames, size_t n) const;

inline ElementType NthElement(size_t i) const {
return tree.front().first[i];
Expand Down Expand Up @@ -265,10 +265,7 @@ MergeSortTree<E, O, CMP, F, C>::MergeSortTree(Elements &&lowest_level, const CMP
}

template <typename E, typename O, typename CMP, uint64_t F, uint64_t C>
size_t MergeSortTree<E, O, CMP, F, C>::SelectNth(const FrameBounds &frame, size_t &requested) const {
// Empty frames should be handled by the caller
D_ASSERT(frame.start < frame.end);

size_t MergeSortTree<E, O, CMP, F, C>::SelectNth(const SubFrames &frames, size_t n) const {
// Handle special case of a one-element tree
if (tree.size() < 2) {
return 0;
Expand All @@ -284,15 +281,17 @@ size_t MergeSortTree<E, O, CMP, F, C>::SelectNth(const FrameBounds &frame, size_

// Find Nth element in a top-down traversal
size_t result = 0;
auto n = requested;

// First, handle levels with cascading pointers
const auto min_cascaded = LowestCascadingLevel();
if (level_no > min_cascaded) {
// Initialise the cascade indices from the previous level
pair<size_t, size_t> cascade_idx;
{
const auto &level = tree[level_no + 1].first;
using CascadeRange = pair<size_t, size_t>;
std::array<CascadeRange, 3> cascades;
const auto &level = tree[level_no + 1].first;
for (size_t f = 0; f < frames.size(); ++f) {
const auto &frame = frames[f];
auto &cascade_idx = cascades[f];
const auto lower_idx = std::lower_bound(level.begin(), level.end(), frame.start) - level.begin();
cascade_idx.first = lower_idx / CASCADING * FANOUT;
const auto upper_idx = std::lower_bound(level.begin(), level.end(), frame.end) - level.begin();
Expand All @@ -307,25 +306,40 @@ size_t MergeSortTree<E, O, CMP, F, C>::SelectNth(const FrameBounds &frame, size_
// Go over all children until we have found enough in range
const auto *level_data = tree[level_no].first.data();
while (true) {
const auto lower_begin = level_data + level_cascades[cascade_idx.first];
const auto lower_end = level_data + level_cascades[cascade_idx.first + FANOUT];
const auto lower_idx = std::lower_bound(lower_begin, lower_end, frame.start) - level_data;

const auto upper_begin = level_data + level_cascades[cascade_idx.second];
const auto upper_end = level_data + level_cascades[cascade_idx.second + FANOUT];
const auto upper_idx = std::lower_bound(upper_begin, upper_end, frame.end) - level_data;

const auto matched = size_t(upper_idx - lower_idx);
size_t matched = 0;
std::array<CascadeRange, 3> matches;
for (size_t f = 0; f < frames.size(); ++f) {
const auto &frame = frames[f];
auto &cascade_idx = cascades[f];
auto &match = matches[f];

const auto lower_begin = level_data + level_cascades[cascade_idx.first];
const auto lower_end = level_data + level_cascades[cascade_idx.first + FANOUT];
match.first = std::lower_bound(lower_begin, lower_end, frame.start) - level_data;

const auto upper_begin = level_data + level_cascades[cascade_idx.second];
const auto upper_end = level_data + level_cascades[cascade_idx.second + FANOUT];
match.second = std::lower_bound(upper_begin, upper_end, frame.end) - level_data;

matched += size_t(match.second - match.first);
}
if (matched > n) {
// Enough in this level, so move down to leftmost child candidate within the cascade range
cascade_idx.first = (lower_idx / CASCADING + 2 * result) * FANOUT;
cascade_idx.second = (upper_idx / CASCADING + 2 * result) * FANOUT;
// Too much in this level, so move down to leftmost child candidate within the cascade range
for (size_t f = 0; f < frames.size(); ++f) {
auto &cascade_idx = cascades[f];
auto &match = matches[f];
cascade_idx.first = (match.first / CASCADING + 2 * result) * FANOUT;
cascade_idx.second = (match.second / CASCADING + 2 * result) * FANOUT;
}
break;
}

// Not enough in this child, so move right
++cascade_idx.first;
++cascade_idx.second;
for (size_t f = 0; f < frames.size(); ++f) {
auto &cascade_idx = cascades[f];
++cascade_idx.first;
++cascade_idx.second;
}
++result;
n -= matched;
}
Expand All @@ -340,11 +354,15 @@ size_t MergeSortTree<E, O, CMP, F, C>::SelectNth(const FrameBounds &frame, size_
auto range_begin = level.begin() + result * level_width;
auto range_end = range_begin + level_width;
while (range_end < level.end()) {
const auto lower_match = std::lower_bound(range_begin, range_end, frame.start);
const auto upper_match = std::lower_bound(lower_match, range_end, frame.end);
const auto matched = size_t(upper_match - lower_match);
size_t matched = 0;
for (size_t f = 0; f < frames.size(); ++f) {
const auto &frame = frames[f];
const auto lower_match = std::lower_bound(range_begin, range_end, frame.start);
const auto upper_match = std::lower_bound(lower_match, range_end, frame.end);
matched += size_t(upper_match - lower_match);
}
if (matched > n) {
// Enough in this level, so move down to leftmost child candidate
// Too much in this level, so move down to leftmost child candidate
// Since we have no cascade pointers left, this is just the start of the next level.
break;
}
Expand All @@ -365,14 +383,14 @@ size_t MergeSortTree<E, O, CMP, F, C>::SelectNth(const FrameBounds &frame, size_
const auto count = tree[level_no].first.size();
for (const auto limit = MinValue<size_t>(result + FANOUT, count); result < limit; ++result) {
const auto v = level_data[result];
n -= (v >= frame.start) && (v < frame.end);
for (const auto &frame : frames) {
n -= (v >= frame.start) && (v < frame.end);
}
if (!n) {
break;
}
}

requested = n;

return result;
}

Expand Down
32 changes: 32 additions & 0 deletions test/sql/window/test_window_exclude.test
Expand Up @@ -704,6 +704,38 @@ ORDER BY i;
5 4.5
5 5.0

# Test Merge Sort Trees with exclusions
query III
WITH t1(x, y) AS (VALUES
( 1, 3 ),
( 2, 2 ),
( 3, 1 )
)
SELECT x, y, QUANTILE_DISC(y, 0) OVER (
ORDER BY x
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
EXCLUDE CURRENT ROW)
FROM t1;
----
1 3 NULL
2 2 3
3 1 2

query III
WITH t1(x, y) AS (VALUES
( 1, 3 ),
( 2, 2 ),
( 3, 1 )
)
SELECT x, y, QUANTILE_DISC(y, 0) OVER (
ORDER BY x
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
EXCLUDE CURRENT ROW)
FROM t1;
----
1 3 1
2 2 1
3 1 2


# PG test
Expand Down

0 comments on commit 37f2ea0

Please sign in to comment.