Skip to content

Commit

Permalink
Internal duckdb#330: Merge Sort Tree EXCLUDE
Browse files Browse the repository at this point in the history
Finish implementing and testing merge sort trees
by supporting EXCLUDE/multiple frames.
  • Loading branch information
hawkfish committed Oct 17, 2023
1 parent b20a841 commit a9ef0e9
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 28 deletions.
70 changes: 46 additions & 24 deletions src/core_functions/aggregate/holistic/quantile.cpp
Expand Up @@ -558,20 +558,45 @@ struct QuantileSortTree : public MergeSortTree<IDX, IDX> {
return make_uniq<QuantileSortTree>(std::move(sorted));
}

IDX SelectNth(const IncludedFrames &frames, size_t n) const {
size_t result = 0;

// To find the element at position n we need to find n+1 elements.
auto needed = n + 1;
for (const auto &frame : frames) {
// Skip empty frames
if (frame.end <= frame.start) {
continue;
}

// The frame can't supply more elements than it contains
// Don't look for more than we need
const auto available = MinValue<size_t>(needed, frame.end - frame.start);
auto frame_n = available - 1;
result = BaseTree::SelectNth(frame, frame_n);

// Reduce the count by the number we found
needed -= (available - frame_n);
if (!needed) {
break;
}
}
D_ASSERT(!needed);

return BaseTree::NthElement(result);
}

template <typename INPUT_TYPE, typename RESULT_TYPE, bool DISCRETE>
RESULT_TYPE WindowScalar(const INPUT_TYPE *data, const FrameBounds &frame, const idx_t n, Vector &result,
RESULT_TYPE WindowScalar(const INPUT_TYPE *data, const IncludedFrames &frames, const idx_t n, Vector &result,
const QuantileValue &q) const {
D_ASSERT(n > 0);

// Find the interpolated indicies within the frame
Interpolator<DISCRETE> interp(q, n, false);
const auto lo_idx = BaseTree::SelectNth(frame, interp.FRN);
const auto lo_data = BaseTree::NthElement(lo_idx);
auto hi_idx = lo_idx;
const auto lo_data = SelectNth(frames, interp.FRN);
auto hi_data = lo_data;
if (interp.CRN != interp.FRN) {
hi_idx = BaseTree::SelectNth(frame, interp.CRN);
hi_data = BaseTree::NthElement(hi_idx);
hi_data = SelectNth(frames, interp.CRN);
}

// Interpolate indirectly
Expand All @@ -581,7 +606,7 @@ struct QuantileSortTree : public MergeSortTree<IDX, IDX> {
}

template <typename INPUT_TYPE, typename CHILD_TYPE, bool DISCRETE>
void WindowList(const INPUT_TYPE *data, const FrameBounds &frame, const idx_t n, Vector &list, const idx_t lidx,
void WindowList(const INPUT_TYPE *data, const IncludedFrames &frames, const idx_t n, Vector &list, const idx_t lidx,
const QuantileBindData &bind_data) const {
D_ASSERT(n > 0);

Expand All @@ -602,13 +627,10 @@ struct QuantileSortTree : public MergeSortTree<IDX, IDX> {
const auto &quantile = bind_data.quantiles[q];
Interpolator<DISCRETE> interp(quantile, n, false);

const auto lo_idx = BaseTree::SelectNth(frame, interp.FRN);
const auto lo_data = BaseTree::NthElement(lo_idx);
auto hi_idx = lo_idx;
const auto lo_data = SelectNth(frames, interp.FRN);
auto hi_data = lo_data;
if (interp.CRN != interp.FRN) {
hi_idx = BaseTree::SelectNth(frame, interp.CRN);
hi_data = BaseTree::NthElement(hi_idx);
hi_data = SelectNth(frames, interp.CRN);
}

// Interpolate indirectly
Expand Down Expand Up @@ -730,9 +752,9 @@ struct QuantileState {
const QuantileValue &q) const {
D_ASSERT(n > 0);
if (qst32) {
return qst32->WindowScalar<INPUT_TYPE, RESULT_TYPE, DISCRETE>(data, frames[0], n, result, q);
return qst32->WindowScalar<INPUT_TYPE, RESULT_TYPE, DISCRETE>(data, frames, n, result, q);
} else if (qst64) {
return qst64->WindowScalar<INPUT_TYPE, RESULT_TYPE, DISCRETE>(data, frames[0], n, result, q);
return qst64->WindowScalar<INPUT_TYPE, RESULT_TYPE, DISCRETE>(data, frames, n, result, q);
} else if (s) {
// Find the position(s) needed
Interpolator<DISCRETE> interp(q, s->size(), false);
Expand Down Expand Up @@ -806,7 +828,7 @@ struct QuantileOperation {
const ValidityMask &filter_mask, data_ptr_t state_p, idx_t count, const FrameStats *stats) {
D_ASSERT(input_count == 1);

// If frames overlap significantly, then use local skip lists.
// If frames overlap significantly, then use local skip lists.
D_ASSERT(stats);
if (stats[0].end <= stats[1].begin) {
// Frames can overlap
Expand All @@ -818,7 +840,7 @@ struct QuantileOperation {
}
}

const auto data = FlatVector::GetData<const INPUT_TYPE>(inputs[0]);
const auto data = FlatVector::GetData<const INPUT_TYPE>(inputs[0]);
const auto &data_mask = FlatVector::Validity(inputs[0]);

// Build the tree
Expand All @@ -837,7 +859,7 @@ struct QuantileOperation {
idx_t n = 0;
if (included.AllValid()) {
for (const auto &frame : frames) {
n += frame.end - frame.start;
n += frame.end - frame.start;
}
} else {
// NULLs or FILTERed values,
Expand Down Expand Up @@ -880,8 +902,8 @@ struct QuantileScalarOperation : public QuantileOperation {

template <class STATE, class INPUT_TYPE, class RESULT_TYPE>
static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask,
AggregateInputData &aggr_input_data, STATE &state, const IncludedFrames &frames,
Vector &result, idx_t ridx, const STATE *gstate) {
AggregateInputData &aggr_input_data, STATE &state, const IncludedFrames &frames, Vector &result,
idx_t ridx, const STATE *gstate) {
QuantileIncluded included(fmask, dmask);
const auto n = FrameSize(included, frames);

Expand All @@ -897,7 +919,7 @@ struct QuantileScalarOperation : public QuantileOperation {
}

const auto &quantile = bind_data.quantiles[0];
if (gstate && gstate->HasTrees() && frames.size() == 1) {
if (gstate && gstate->HasTrees()) {
rdata[ridx] = gstate->template WindowScalar<RESULT_TYPE, DISCRETE>(data, frames, n, result, quantile);
} else {
// Update the skip list
Expand Down Expand Up @@ -1008,8 +1030,8 @@ struct QuantileListOperation : public QuantileOperation {

template <class STATE, class INPUT_TYPE, class RESULT_TYPE>
static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask,
AggregateInputData &aggr_input_data, STATE &state, const IncludedFrames &frames,
Vector &list, idx_t lidx, const STATE *gstate) {
AggregateInputData &aggr_input_data, STATE &state, const IncludedFrames &frames, Vector &list,
idx_t lidx, const STATE *gstate) {
D_ASSERT(aggr_input_data.bind_data);
auto &bind_data = aggr_input_data.bind_data->Cast<QuantileBindData>();

Expand All @@ -1023,7 +1045,7 @@ struct QuantileListOperation : public QuantileOperation {
return;
}

if (gstate && gstate->HasTrees() && frames.size() == 1) {
if (gstate && gstate->HasTrees()) {
gstate->template WindowList<CHILD_TYPE, DISCRETE>(data, frames, n, list, lidx, bind_data);
} else {
//
Expand Down Expand Up @@ -1322,7 +1344,7 @@ struct MedianAbsoluteDeviationOperation : public QuantileOperation {
D_ASSERT(bind_data.quantiles.size() == 1);
const auto &quantile = bind_data.quantiles[0];
MEDIAN_TYPE med;
if (gstate && gstate->HasTrees() && frames.size() == 1) {
if (gstate && gstate->HasTrees()) {
med = gstate->template WindowScalar<MEDIAN_TYPE, false>(data, frames, n, result, quantile);
} else {
state.UpdateSkip(data, frames, included);
Expand Down
12 changes: 8 additions & 4 deletions src/include/duckdb/execution/merge_sort_tree.hpp
Expand Up @@ -78,7 +78,7 @@ struct MergeSortTree {
}
explicit MergeSortTree(Elements &&lowest_level, const CMP &cmp = CMP());

size_t SelectNth(const FrameBounds &frame, size_t n) const;
size_t SelectNth(const FrameBounds &frame, size_t &requested) const;

inline ElementType NthElement(size_t i) const {
return tree.front().first[i];
Expand Down Expand Up @@ -265,7 +265,7 @@ MergeSortTree<E, O, CMP, F, C>::MergeSortTree(Elements &&lowest_level, const CMP
}

template <typename E, typename O, typename CMP, uint64_t F, uint64_t C>
size_t MergeSortTree<E, O, CMP, F, C>::SelectNth(const FrameBounds &frame, size_t n) const {
size_t MergeSortTree<E, O, CMP, F, C>::SelectNth(const FrameBounds &frame, size_t &requested) const {
// Empty frames should be handled by the caller
D_ASSERT(frame.start < frame.end);

Expand All @@ -284,6 +284,7 @@ size_t MergeSortTree<E, O, CMP, F, C>::SelectNth(const FrameBounds &frame, size_

// Find Nth element in a top-down traversal
size_t result = 0;
auto n = requested;

// First, handle levels with cascading pointers
const auto min_cascaded = LowestCascadingLevel();
Expand Down Expand Up @@ -360,15 +361,18 @@ size_t MergeSortTree<E, O, CMP, F, C>::SelectNth(const FrameBounds &frame, size_
// The last level
const auto *level_data = tree[level_no].first.data();
++n;
while (true) {

const auto count = tree[level_no].first.size();
for (const auto limit = MinValue<size_t>(result + FANOUT, count); result < limit; ++result) {
const auto v = level_data[result];
n -= (v >= frame.start) && (v < frame.end);
if (!n) {
break;
}
++result;
}

requested = n;

return result;
}

Expand Down

0 comments on commit a9ef0e9

Please sign in to comment.