Skip to content

Commit

Permalink
Internal duckdb#330: Quantile Merge Sort Trees
Browse files Browse the repository at this point in the history
Toggle 32/64 bit versions.
  • Loading branch information
hawkfish committed Oct 6, 2023
1 parent 5bc021c commit 8c915da
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 137 deletions.
315 changes: 195 additions & 120 deletions src/core_functions/aggregate/holistic/quantile.cpp
Expand Up @@ -56,38 +56,6 @@ struct QuantileIncluded {
const idx_t bias;
};

template <typename SAVE_TYPE>
struct QuantileState {
using SaveType = SAVE_TYPE;

// Regular aggregation
vector<SaveType> v;

// Windowed Quantile indirection
FrameBounds prev;
vector<idx_t> w;
idx_t pos;

// Windowed Quantile merge sort tree
unique_ptr<QuantileSortTree> qst;

// Windowed MAD indirection
vector<idx_t> m;

QuantileState() : pos(0) {
}

~QuantileState() {
}

inline void SetPos(size_t pos_p) {
pos = pos_p;
if (pos >= w.size()) {
w.resize(pos);
}
}
};

void ReuseIndexes(idx_t *index, const FrameBounds &frame, const FrameBounds &prev) {
idx_t j = 0;

Expand Down Expand Up @@ -547,6 +515,170 @@ struct QuantileBindData : public FunctionData {
bool desc;
};

template <typename IDX>
struct QuantileSortTree : public MergeSortTree<IDX, IDX> {

using BaseTree = MergeSortTree<IDX, IDX>;
using Elements = typename BaseTree::Elements;

explicit QuantileSortTree(Elements &&lowest_level) : BaseTree(std::move(lowest_level)) {
}

template <class INPUT_TYPE>
static unique_ptr<QuantileSortTree> WindowInit(const INPUT_TYPE *data, AggregateInputData &aggr_input_data,
const ValidityMask &data_mask, const ValidityMask &filter_mask,
idx_t count) {
// Build the indirection array
using ElementType = typename QuantileSortTree::ElementType;
vector<ElementType> sorted(count);
if (filter_mask.AllValid() && data_mask.AllValid()) {
std::iota(sorted.begin(), sorted.end(), 0);
} else {
size_t valid = 0;
QuantileIncluded included(filter_mask, data_mask, 0);
for (ElementType i = 0; i < count; ++i) {
if (included(i)) {
sorted[valid++] = i;
}
}
sorted.resize(valid);
}

// Sort it
auto &bind_data = aggr_input_data.bind_data->Cast<QuantileBindData>();
using Accessor = QuantileIndirect<INPUT_TYPE>;
Accessor indirect(data);
QuantileCompare<Accessor> cmp(indirect, bind_data.desc);
std::sort(sorted.begin(), sorted.end(), cmp);

return make_uniq<QuantileSortTree>(std::move(sorted));
}

template <typename INPUT_TYPE, typename RESULT_TYPE, bool DISCRETE>
void WindowScalar(const INPUT_TYPE *data, const QuantileIncluded &included, const FrameBounds &frame,
Vector &result, const idx_t ridx, const QuantileBindData &bind_data) {
auto rdata = FlatVector::GetData<RESULT_TYPE>(result);
auto &rmask = FlatVector::Validity(result);

// Count the number of valid values
auto n = frame.end - frame.start;
if (!included.AllValid()) {
// NULLs or FILTERed values,
n = 0;
for (auto i = frame.start; i < frame.end; ++i) {
n += included(i);
}
}

const auto &q = bind_data.quantiles[0];
if (n) {
// Find the interpolated indicies within the frame
Interpolator<DISCRETE> interp(q, n, false);
const auto lo_idx = BaseTree::SelectNth(frame, interp.FRN);
const auto lo_data = BaseTree::NthElement(lo_idx);
auto hi_idx = lo_idx;
auto hi_data = lo_data;
if (interp.CRN != interp.FRN) {
hi_idx = BaseTree::SelectNth(frame, interp.CRN);
hi_data = BaseTree::NthElement(hi_idx);
}

// Interpolate indirectly
using ID = QuantileIndirect<INPUT_TYPE>;
ID indirect(data);
rdata[ridx] = interp.template Interpolate<size_t, RESULT_TYPE, ID>(lo_data, hi_data, result, indirect);
} else {
rmask.Set(ridx, false);
}
}

template <typename INPUT_TYPE, typename CHILD_TYPE, bool DISCRETE>
void WindowList(const INPUT_TYPE *data, const QuantileIncluded &included, const FrameBounds &frame, Vector &list,
const idx_t lidx, const QuantileBindData &bind_data) {

// Result is a constant LIST<CHILD_TYPE> with a fixed length
auto ldata = FlatVector::GetData<list_entry_t>(list);
auto &lmask = FlatVector::Validity(list);
auto &lentry = ldata[lidx];
lentry.offset = ListVector::GetListSize(list);
lentry.length = bind_data.quantiles.size();

ListVector::Reserve(list, lentry.offset + lentry.length);
ListVector::SetListSize(list, lentry.offset + lentry.length);
auto &result = ListVector::GetEntry(list);
auto rdata = FlatVector::GetData<CHILD_TYPE>(result);

// Count the number of valid values
auto n = frame.end - frame.start;
if (!included.AllValid()) {
// NULLs or FILTERed values,
n = 0;
for (auto i = frame.start; i < frame.end; ++i) {
n += included(i);
}
}

if (n) {
using ID = QuantileIndirect<INPUT_TYPE>;
ID indirect(data);
for (const auto &q : bind_data.order) {
const auto &quantile = bind_data.quantiles[q];
Interpolator<DISCRETE> interp(quantile, n, false);

const auto lo_idx = BaseTree::SelectNth(frame, interp.FRN);
const auto lo_data = BaseTree::NthElement(lo_idx);
auto hi_idx = lo_idx;
auto hi_data = lo_data;
if (interp.CRN != interp.FRN) {
hi_idx = BaseTree::SelectNth(frame, interp.CRN);
hi_data = BaseTree::NthElement(hi_idx);
}

// Interpolate indirectly
rdata[lentry.offset + q] =
interp.template Interpolate<idx_t, CHILD_TYPE, ID>(lo_data, hi_data, result, indirect);
}
} else {
lmask.Set(lidx, false);
}
}
};

template <typename SAVE_TYPE>
struct QuantileState {
using SaveType = SAVE_TYPE;
using QuantileSortTree32 = QuantileSortTree<uint32_t>;
using QuantileSortTree64 = QuantileSortTree<uint64_t>;

// Regular aggregation
vector<SaveType> v;

// Windowed Quantile indirection
FrameBounds prev;
vector<idx_t> w;
idx_t pos;

// Windowed Quantile merge sort trees
unique_ptr<QuantileSortTree32> qst32;
unique_ptr<QuantileSortTree64> qst64;

// Windowed MAD indirection
vector<idx_t> m;

QuantileState() : pos(0) {
}

~QuantileState() {
}

inline void SetPos(size_t pos_p) {
pos = pos_p;
if (pos >= w.size()) {
w.resize(pos);
}
}
};

struct QuantileOperation {
template <class STATE>
static void Initialize(STATE &state) {
Expand Down Expand Up @@ -590,32 +722,15 @@ struct QuantileOperation {
const auto data = FlatVector::GetData<const INPUT_TYPE>(inputs[0]);
const auto &data_mask = FlatVector::Validity(inputs[0]);

// Build the indirection array
using ElementType = typename QuantileSortTree::ElementType;
vector<ElementType> sorted(count);
if (filter_mask.AllValid() && data_mask.AllValid()) {
std::iota(sorted.begin(), sorted.end(), 0);
} else {
size_t valid = 0;
QuantileIncluded included(filter_mask, data_mask, 0);
for (ElementType i = 0; i < count; ++i) {
if (included(i)) {
sorted[valid++] = i;
}
}
sorted.resize(valid);
}

// Sort it
auto &bind_data = aggr_input_data.bind_data->Cast<QuantileBindData>();
using Accessor = QuantileIndirect<INPUT_TYPE>;
Accessor indirect(data);
QuantileCompare<Accessor> cmp(indirect, bind_data.desc);
std::sort(sorted.begin(), sorted.end(), cmp);

// Build the tree
auto &state = *reinterpret_cast<STATE *>(state_p);
state.qst = make_uniq<QuantileSortTree>(std::move(sorted));
if (count < std::numeric_limits<uint32_t>::max()) {
state.qst32 = QuantileSortTree<uint32_t>::WindowInit<INPUT_TYPE>(data, aggr_input_data, data_mask,
filter_mask, count);
} else {
state.qst64 = QuantileSortTree<uint64_t>::WindowInit<INPUT_TYPE>(data, aggr_input_data, data_mask,
filter_mask, count);
}
}
};

Expand Down Expand Up @@ -649,47 +764,29 @@ struct QuantileScalarOperation : public QuantileOperation {
static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask,
AggregateInputData &aggr_input_data, STATE &state, const FrameBounds &frame, Vector &result,
idx_t ridx, const STATE *gstate) {
auto rdata = FlatVector::GetData<RESULT_TYPE>(result);
auto &rmask = FlatVector::Validity(result);

QuantileIncluded included(fmask, dmask, 0);

D_ASSERT(aggr_input_data.bind_data);
auto &bind_data = aggr_input_data.bind_data->Cast<QuantileBindData>();

// Find the two positions needed
const auto &q = bind_data.quantiles[0];
#if 1
// Count the number of valid values
auto n = frame.end - frame.start;
if (!included.AllValid()) {
// NULLs or FILTERed values,
n = 0;
for (auto i = frame.start; i < frame.end; ++i) {
n += included(i);
}
}

if (n) {
// Find the interpolated indicies within the frame
Interpolator<DISCRETE> interp(q, n, false);
const auto lo_idx = gstate->qst->SelectNth(frame, interp.FRN);
const auto lo_data = gstate->qst->NthElement(lo_idx);
auto hi_idx = lo_idx;
auto hi_data = lo_data;
if (interp.CRN != interp.FRN) {
hi_idx = gstate->qst->SelectNth(frame, interp.CRN);
hi_data = gstate->qst->NthElement(hi_idx);
}

// Interpolate indirectly
using ID = QuantileIndirect<INPUT_TYPE>;
ID indirect(data);
rdata[ridx] = interp.template Interpolate<idx_t, RESULT_TYPE, ID>(lo_data, hi_data, result, indirect);
if (gstate->qst32) {
gstate->qst32->template WindowScalar<INPUT_TYPE, RESULT_TYPE, DISCRETE>(data, included, frame, result, ridx,
bind_data);
} else if (gstate->qst64) {
gstate->qst64->template WindowScalar<INPUT_TYPE, RESULT_TYPE, DISCRETE>(data, included, frame, result, ridx,
bind_data);
} else {
auto &rmask = FlatVector::Validity(result);
rmask.Set(ridx, false);
}
#else
auto rdata = FlatVector::GetData<RESULT_TYPE>(result);
auto &rmask = FlatVector::Validity(result);
const auto &q = bind_data.quantiles[0];
// Lazily initialise frame state
auto prev_pos = state.pos;
state.SetPos(frame.end - frame.start);
Expand Down Expand Up @@ -837,6 +934,18 @@ struct QuantileListOperation : public QuantileOperation {

QuantileIncluded included(fmask, dmask, 0);

#if 1
if (gstate->qst32) {
gstate->qst32->template WindowList<INPUT_TYPE, CHILD_TYPE, DISCRETE>(data, included, frame, list, lidx,
bind_data);
} else if (gstate->qst64) {
gstate->qst64->template WindowList<INPUT_TYPE, CHILD_TYPE, DISCRETE>(data, included, frame, list, lidx,
bind_data);
} else {
auto &lmask = FlatVector::Validity(list);
lmask.Set(lidx, false);
}
#else
// Result is a constant LIST<RESULT_TYPE> with a fixed length
auto ldata = FlatVector::GetData<RESULT_TYPE>(list);
auto &lmask = FlatVector::Validity(list);
Expand All @@ -848,41 +957,7 @@ struct QuantileListOperation : public QuantileOperation {
ListVector::SetListSize(list, lentry.offset + lentry.length);
auto &result = ListVector::GetEntry(list);
auto rdata = FlatVector::GetData<CHILD_TYPE>(result);
#if 1
// Count the number of valid values
auto n = frame.end - frame.start;
if (!included.AllValid()) {
// NULLs or FILTERed values,
n = 0;
for (auto i = frame.start; i < frame.end; ++i) {
n += included(i);
}
}
if (n) {
using ID = QuantileIndirect<INPUT_TYPE>;
ID indirect(data);
for (const auto &q : bind_data.order) {
const auto &quantile = bind_data.quantiles[q];
Interpolator<DISCRETE> interp(quantile, n, false);

const auto lo_idx = gstate->qst->SelectNth(frame, interp.FRN);
const auto lo_data = gstate->qst->NthElement(lo_idx);
auto hi_idx = lo_idx;
auto hi_data = lo_data;
if (interp.CRN != interp.FRN) {
hi_idx = gstate->qst->SelectNth(frame, interp.CRN);
hi_data = gstate->qst->NthElement(hi_idx);
}

// Interpolate indirectly
rdata[lentry.offset + q] =
interp.template Interpolate<idx_t, CHILD_TYPE, ID>(lo_data, hi_data, result, indirect);
}
} else {
lmask.Set(lidx, false);
}
#else
// Lazily initialise frame state
auto prev_pos = state.pos;
state.SetPos(frame.end - frame.start);
Expand Down

0 comments on commit 8c915da

Please sign in to comment.