diff --git a/CMakeLists.txt b/CMakeLists.txt index 67444dbc8cc..0a67b80373c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -483,6 +483,7 @@ include_directories(third_party/fsst) include_directories(third_party/fmt/include) include_directories(third_party/hyperloglog) include_directories(third_party/fastpforlib) +include_directories(third_party/skiplist) include_directories(third_party/fast_float) include_directories(third_party/re2) include_directories(third_party/miniz) diff --git a/scripts/package_build.py b/scripts/package_build.py index f98eeeeb377..a60a5d2d5f8 100644 --- a/scripts/package_build.py +++ b/scripts/package_build.py @@ -17,6 +17,7 @@ def third_party_includes(): includes += [os.path.join('third_party', 'utf8proc', 'include')] includes += [os.path.join('third_party', 'utf8proc')] includes += [os.path.join('third_party', 'hyperloglog')] + includes += [os.path.join('third_party', 'skiplist')] includes += [os.path.join('third_party', 'fastpforlib')] includes += [os.path.join('third_party', 'tdigest')] includes += [os.path.join('third_party', 'libpg_query', 'include')] @@ -40,6 +41,7 @@ def third_party_sources(): sources += [os.path.join('third_party', 'miniz')] sources += [os.path.join('third_party', 're2')] sources += [os.path.join('third_party', 'hyperloglog')] + sources += [os.path.join('third_party', 'skiplist')] sources += [os.path.join('third_party', 'fastpforlib')] sources += [os.path.join('third_party', 'utf8proc')] sources += [os.path.join('third_party', 'libpg_query')] diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index cda2d86cf88..071f584a041 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -75,6 +75,7 @@ else() duckdb_utf8proc duckdb_hyperloglog duckdb_fastpforlib + duckdb_skiplistlib duckdb_mbedtls) add_library(duckdb SHARED ${ALL_OBJECT_FILES}) diff --git a/src/common/types/validity_mask.cpp b/src/common/types/validity_mask.cpp index 2cf28ce3b98..3faa811e44c 100644 --- a/src/common/types/validity_mask.cpp +++ b/src/common/types/validity_mask.cpp @@ -67,9 +67,6 @@ void ValidityMask::Resize(idx_t old_size, idx_t new_size) { } validity_data = std::move(new_validity_data); validity_mask = validity_data->owned_data.get(); - } else { - // TODO: We shouldn't have to initialize here, just update the target count - Initialize(new_size); } } diff --git a/src/common/vector_operations/vector_copy.cpp b/src/common/vector_operations/vector_copy.cpp index a2e0f523d48..52016757fe9 100644 --- a/src/common/vector_operations/vector_copy.cpp +++ b/src/common/vector_operations/vector_copy.cpp @@ -114,8 +114,7 @@ void VectorOperations::Copy(const Vector &source_p, Vector &target, const Select } else { // set invalid if (tmask.AllValid()) { - auto init_size = MaxValue(STANDARD_VECTOR_SIZE, target_offset + copy_count); - tmask.Initialize(init_size); + tmask.Initialize(); } tmask.SetInvalidUnsafe(target_offset + i); } diff --git a/src/core_functions/aggregate/holistic/mode.cpp b/src/core_functions/aggregate/holistic/mode.cpp index 7eaf323fc1b..4174dcc91c1 100644 --- a/src/core_functions/aggregate/holistic/mode.cpp +++ b/src/core_functions/aggregate/holistic/mode.cpp @@ -44,7 +44,7 @@ struct ModeState { ModeState() { } - vector prevs; + SubFrames prevs; Counts *frequency_map = nullptr; KEY_TYPE *mode = nullptr; size_t nonzero = 0; @@ -237,13 +237,11 @@ struct ModeFunction { template static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, - AggregateInputData &aggr_input_data, STATE &state, const vector &frames, - Vector &result, idx_t rid) { - + AggregateInputData &aggr_input_data, STATE &state, const SubFrames &frames, Vector &result, + idx_t rid, const STATE *gstate) { auto rdata = FlatVector::GetData(result); auto &rmask = FlatVector::Validity(result); auto &prevs = state.prevs; - // TODO: Hack around PerfectAggregateHashTable memory leak if (prevs.empty()) { prevs.resize(1); } @@ -254,7 +252,8 @@ struct ModeFunction { state.frequency_map = new typename STATE::Counts; } const double tau = .25; - if (state.nonzero <= tau * state.frequency_map->size()) { + if (state.nonzero <= tau * state.frequency_map->size() || prevs.back().end <= frames.front().start || + frames.back().end <= prevs.front().start) { state.Reset(); // for f ∈ F do for (const auto &frame : frames) { diff --git a/src/core_functions/aggregate/holistic/quantile.cpp b/src/core_functions/aggregate/holistic/quantile.cpp index f34087325a6..bfb9bfd4e48 100644 --- a/src/core_functions/aggregate/holistic/quantile.cpp +++ b/src/core_functions/aggregate/holistic/quantile.cpp @@ -1,5 +1,6 @@ #include "duckdb/execution/expression_executor.hpp" #include "duckdb/core_functions/aggregate/holistic_functions.hpp" +#include "duckdb/execution/merge_sort_tree.hpp" #include "duckdb/planner/expression.hpp" #include "duckdb/common/operator/cast_operators.hpp" #include "duckdb/common/operator/abs.hpp" @@ -10,7 +11,10 @@ #include "duckdb/common/serializer/serializer.hpp" #include "duckdb/common/serializer/deserializer.hpp" +#include "SkipList.h" + #include +#include #include #include @@ -37,9 +41,7 @@ inline interval_t operator-(const interval_t &lhs, const interval_t &rhs) { } struct FrameSet { - using Frames = vector; - - inline explicit FrameSet(const Frames &frames_p) : frames(frames_p) { + inline explicit FrameSet(const SubFrames &frames_p) : frames(frames_p) { } inline idx_t Size() const { @@ -60,42 +62,7 @@ struct FrameSet { } return false; } - const Frames &frames; -}; - -template -struct QuantileState { - using SaveType = SAVE_TYPE; - - // Regular aggregation - vector v; - - // Windowing state - vector prevs; - - // Windowed Quantile indirection - vector w; - idx_t count; - - // Windowed MAD indirection - vector m; - - QuantileState() : count(0) { - } - - ~QuantileState() { - } - - inline void SetCount(const vector &frames) { - // TODO: Hack around PerfectAggregateHashTable memory leak - if (prevs.empty()) { - prevs.resize(1); - } - count = FrameSet(frames).Size(); - if (count >= w.size()) { - w.resize(count); - } - } + const SubFrames &frames; }; struct QuantileIncluded { @@ -138,7 +105,7 @@ struct QuantileReuseUpdater { } }; -void ReuseIndexes(idx_t *index, const vector &currs, const vector &prevs) { +void ReuseIndexes(idx_t *index, const SubFrames &currs, const SubFrames &prevs) { // Copy overlapping indices by scanning the previous set and copying down into holes. // We copy instead of leaving gaps in case there are fewer values in the current frame. @@ -174,50 +141,6 @@ void ReuseIndexes(idx_t *index, const vector &currs, const vector -static inline int CanReplace(const idx_t *index, const INPUT_TYPE *fdata, const idx_t j, const idx_t k0, const idx_t k1, - const QuantileIncluded &validity) { - D_ASSERT(index); - - // NULLs sort to the end, so if we have inserted a NULL, - // it must be past the end of the quantile to be replaceable. - // Note that the quantile values are never NULL. - const auto ij = index[j]; - if (!validity(ij)) { - return k1 < j ? 1 : 0; - } - - auto curr = fdata[ij]; - if (k1 < j) { - auto hi = fdata[index[k0]]; - return hi < curr ? 1 : 0; - } else if (j < k0) { - auto lo = fdata[index[k1]]; - return curr < lo ? -1 : 0; - } - - return 0; -} - template struct IndirectLess { inline explicit IndirectLess(const INPUT_TYPE *inputs_p) : inputs(inputs_p) { @@ -381,6 +304,18 @@ struct Interpolator { : desc(desc_p), RN((double)(n_p - 1) * q.dbl), FRN(floor(RN)), CRN(ceil(RN)), begin(0), end(n_p) { } + template > + TARGET_TYPE Interpolate(INPUT_TYPE lidx, INPUT_TYPE hidx, Vector &result, const ACCESSOR &accessor) const { + using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; + if (lidx == hidx) { + return CastInterpolation::Cast(accessor(lidx), result); + } else { + auto lo = CastInterpolation::Cast(accessor(lidx), result); + auto hi = CastInterpolation::Cast(accessor(hidx), result); + return CastInterpolation::Interpolate(lo, RN - FRN, hi); + } + } + template > TARGET_TYPE Operation(INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; @@ -397,14 +332,13 @@ struct Interpolator { } } - template > - TARGET_TYPE Replace(const INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { - using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; + template + inline TARGET_TYPE Extract(const INPUT_TYPE **dest, Vector &result) const { if (CRN == FRN) { - return CastInterpolation::Cast(accessor(v_t[FRN]), result); + return CastInterpolation::Cast(*dest[0], result); } else { - auto lo = CastInterpolation::Cast(accessor(v_t[FRN]), result); - auto hi = CastInterpolation::Cast(accessor(v_t[CRN]), result); + auto lo = CastInterpolation::Cast(*dest[0], result); + auto hi = CastInterpolation::Cast(*dest[1], result); return CastInterpolation::Interpolate(lo, RN - FRN, hi); } } @@ -446,6 +380,12 @@ struct Interpolator { : desc(desc_p), FRN(Index(q, n_p)), CRN(FRN), begin(0), end(n_p) { } + template > + TARGET_TYPE Interpolate(INPUT_TYPE lidx, INPUT_TYPE hidx, Vector &result, const ACCESSOR &accessor) const { + using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; + return CastInterpolation::Cast(accessor(lidx), result); + } + template > TARGET_TYPE Operation(INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; @@ -454,10 +394,9 @@ struct Interpolator { return CastInterpolation::Cast(accessor(v_t[FRN]), result); } - template > - TARGET_TYPE Replace(const INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { - using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; - return CastInterpolation::Cast(accessor(v_t[FRN]), result); + template + TARGET_TYPE Extract(const INPUT_TYPE **dest, Vector &result) const { + return CastInterpolation::Cast(*dest[0], result); } const bool desc; @@ -580,6 +519,254 @@ struct QuantileBindData : public FunctionData { bool desc; }; +template +struct QuantileSortTree : public MergeSortTree { + + using BaseTree = MergeSortTree; + using Elements = typename BaseTree::Elements; + + explicit QuantileSortTree(Elements &&lowest_level) : BaseTree(std::move(lowest_level)) { + } + + template + static unique_ptr WindowInit(const INPUT_TYPE *data, AggregateInputData &aggr_input_data, + const ValidityMask &data_mask, const ValidityMask &filter_mask, + idx_t count) { + // Build the indirection array + using ElementType = typename QuantileSortTree::ElementType; + vector sorted(count); + if (filter_mask.AllValid() && data_mask.AllValid()) { + std::iota(sorted.begin(), sorted.end(), 0); + } else { + size_t valid = 0; + QuantileIncluded included(filter_mask, data_mask); + for (ElementType i = 0; i < count; ++i) { + if (included(i)) { + sorted[valid++] = i; + } + } + sorted.resize(valid); + } + + // Sort it + auto &bind_data = aggr_input_data.bind_data->Cast(); + using Accessor = QuantileIndirect; + Accessor indirect(data); + QuantileCompare cmp(indirect, bind_data.desc); + std::sort(sorted.begin(), sorted.end(), cmp); + + return make_uniq(std::move(sorted)); + } + + inline IDX SelectNth(const SubFrames &frames, size_t n) const { + return BaseTree::NthElement(BaseTree::SelectNth(frames, n)); + } + + template + RESULT_TYPE WindowScalar(const INPUT_TYPE *data, const SubFrames &frames, const idx_t n, Vector &result, + const QuantileValue &q) const { + D_ASSERT(n > 0); + + // Find the interpolated indicies within the frame + Interpolator interp(q, n, false); + const auto lo_data = SelectNth(frames, interp.FRN); + auto hi_data = lo_data; + if (interp.CRN != interp.FRN) { + hi_data = SelectNth(frames, interp.CRN); + } + + // Interpolate indirectly + using ID = QuantileIndirect; + ID indirect(data); + return interp.template Interpolate(lo_data, hi_data, result, indirect); + } + + template + void WindowList(const INPUT_TYPE *data, const SubFrames &frames, const idx_t n, Vector &list, const idx_t lidx, + const QuantileBindData &bind_data) const { + D_ASSERT(n > 0); + + // Result is a constant LIST with a fixed length + auto ldata = FlatVector::GetData(list); + auto &lentry = ldata[lidx]; + lentry.offset = ListVector::GetListSize(list); + lentry.length = bind_data.quantiles.size(); + + ListVector::Reserve(list, lentry.offset + lentry.length); + ListVector::SetListSize(list, lentry.offset + lentry.length); + auto &result = ListVector::GetEntry(list); + auto rdata = FlatVector::GetData(result); + + using ID = QuantileIndirect; + ID indirect(data); + for (const auto &q : bind_data.order) { + const auto &quantile = bind_data.quantiles[q]; + Interpolator interp(quantile, n, false); + + const auto lo_data = SelectNth(frames, interp.FRN); + auto hi_data = lo_data; + if (interp.CRN != interp.FRN) { + hi_data = SelectNth(frames, interp.CRN); + } + + // Interpolate indirectly + rdata[lentry.offset + q] = + interp.template Interpolate(lo_data, hi_data, result, indirect); + } + } +}; + +template +struct PointerLess { + inline bool operator()(const T &lhi, const T &rhi) const { + return *lhi < *rhi; + } +}; + +template +struct QuantileState { + using SaveType = SAVE_TYPE; + using InputType = INPUT_TYPE; + + // Regular aggregation + vector v; + + // Windowed Quantile merge sort trees + using QuantileSortTree32 = QuantileSortTree; + using QuantileSortTree64 = QuantileSortTree; + unique_ptr qst32; + unique_ptr qst64; + + // Windowed Quantile skip lists + using PointerType = const InputType *; + using SkipListType = duckdb_skiplistlib::skip_list::HeadNode>; + SubFrames prevs; + unique_ptr s; + mutable vector dest; + + // Windowed MAD indirection + idx_t count; + vector m; + + QuantileState() : count(0) { + } + + ~QuantileState() { + } + + inline void SetCount(size_t count_p) { + count = count_p; + if (count >= m.size()) { + m.resize(count); + } + } + + inline SkipListType &GetSkipList(bool reset = false) { + if (reset || !s) { + s.reset(); + s = make_uniq(); + } + return *s; + } + + struct SkipListUpdater { + SkipListType &skip; + const INPUT_TYPE *data; + const QuantileIncluded &included; + + inline SkipListUpdater(SkipListType &skip, const INPUT_TYPE *data, const QuantileIncluded &included) + : skip(skip), data(data), included(included) { + } + + inline void Neither(idx_t begin, idx_t end) { + } + + inline void Left(idx_t begin, idx_t end) { + for (; begin < end; ++begin) { + if (included(begin)) { + skip.remove(data + begin); + } + } + } + + inline void Right(idx_t begin, idx_t end) { + for (; begin < end; ++begin) { + if (included(begin)) { + skip.insert(data + begin); + } + } + } + + inline void Both(idx_t begin, idx_t end) { + } + }; + + void UpdateSkip(const INPUT_TYPE *data, const SubFrames &frames, const QuantileIncluded &included) { + // No overlap, or no data + if (!s || prevs.back().end <= frames.front().start || frames.back().end <= prevs.front().start) { + auto &skip = GetSkipList(true); + for (const auto &frame : frames) { + for (auto i = frame.start; i < frame.end; ++i) { + if (included(i)) { + skip.insert(data + i); + } + } + } + } else { + auto &skip = GetSkipList(); + SkipListUpdater updater(skip, data, included); + AggregateExecutor::IntersectFrames(prevs, frames, updater); + } + } + + bool HasTrees() const { + return qst32 || qst64; + } + + template + RESULT_TYPE WindowScalar(const INPUT_TYPE *data, const SubFrames &frames, const idx_t n, Vector &result, + const QuantileValue &q) const { + D_ASSERT(n > 0); + if (qst32) { + return qst32->WindowScalar(data, frames, n, result, q); + } else if (qst64) { + return qst64->WindowScalar(data, frames, n, result, q); + } else if (s) { + // Find the position(s) needed + try { + Interpolator interp(q, s->size(), false); + s->at(interp.FRN, interp.CRN - interp.FRN + 1, dest); + return interp.template Extract(dest.data(), result); + } catch (const duckdb_skiplistlib::skip_list::IndexError &idx_err) { + throw InternalException(idx_err.message()); + } + } else { + throw InternalException("No accelerator for scalar QUANTILE"); + } + } + + template + void WindowList(const INPUT_TYPE *data, const SubFrames &frames, const idx_t n, Vector &list, const idx_t lidx, + const QuantileBindData &bind_data) const { + D_ASSERT(n > 0); + // Result is a constant LIST with a fixed length + auto ldata = FlatVector::GetData(list); + auto &lentry = ldata[lidx]; + lentry.offset = ListVector::GetListSize(list); + lentry.length = bind_data.quantiles.size(); + + ListVector::Reserve(list, lentry.offset + lentry.length); + ListVector::SetListSize(list, lentry.offset + lentry.length); + auto &result = ListVector::GetEntry(list); + auto rdata = FlatVector::GetData(result); + + for (const auto &q : bind_data.order) { + const auto &quantile = bind_data.quantiles[q]; + rdata[lentry.offset + q] = WindowScalar(data, frames, n, result, quantile); + } + } +}; + struct QuantileOperation { template static void Initialize(STATE &state) { @@ -615,6 +802,60 @@ struct QuantileOperation { static bool IgnoreNull() { return true; } + + template + static void WindowInit(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, + data_ptr_t g_state) { + D_ASSERT(partition.input_count == 1); + + auto inputs = partition.inputs; + const auto count = partition.count; + const auto &filter_mask = partition.filter_mask; + const auto &stats = partition.stats; + + // If frames overlap significantly, then use local skip lists. + if (stats[0].end <= stats[1].begin) { + // Frames can overlap + const auto overlap = double(stats[1].begin - stats[0].end); + const auto cover = double(stats[1].end - stats[0].begin); + const auto ratio = overlap / cover; + if (ratio > .75) { + return; + } + } + + const auto data = FlatVector::GetData(inputs[0]); + const auto &data_mask = FlatVector::Validity(inputs[0]); + + // Build the tree + auto &state = *reinterpret_cast(g_state); + if (count < std::numeric_limits::max()) { + state.qst32 = QuantileSortTree::WindowInit(data, aggr_input_data, data_mask, + filter_mask, count); + } else { + state.qst64 = QuantileSortTree::WindowInit(data, aggr_input_data, data_mask, + filter_mask, count); + } + } + + static idx_t FrameSize(const QuantileIncluded &included, const SubFrames &frames) { + // Count the number of valid values + idx_t n = 0; + if (included.AllValid()) { + for (const auto &frame : frames) { + n += frame.end - frame.start; + } + } else { + // NULLs or FILTERed values, + for (const auto &frame : frames) { + for (auto i = frame.start; i < frame.end; ++i) { + n += included(i); + } + } + } + + return n; + } }; template @@ -645,68 +886,45 @@ struct QuantileScalarOperation : public QuantileOperation { template static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, - AggregateInputData &aggr_input_data, STATE &state, const vector &frames, - Vector &result, idx_t ridx) { - auto rdata = FlatVector::GetData(result); - auto &rmask = FlatVector::Validity(result); - + AggregateInputData &aggr_input_data, STATE &state, const SubFrames &frames, Vector &result, + idx_t ridx, const STATE *gstate) { QuantileIncluded included(fmask, dmask); - - // Lazily initialise frame state - const auto prev_count = state.count; - auto &prevs = state.prevs; - state.SetCount(frames); - - auto index = state.w.data(); - D_ASSERT(index); + const auto n = FrameSize(included, frames); D_ASSERT(aggr_input_data.bind_data); auto &bind_data = aggr_input_data.bind_data->Cast(); - // Find the two positions needed - const auto &q = bind_data.quantiles[0]; - - bool replace = false; - if (frames.size() == 1 && frames[0].start == prevs[0].start + 1 && frames[0].end == prevs[0].end + 1) { - // Fixed frame size - const auto j = ReplaceIndex(index, frames[0], prevs[0]); - // We can only replace if the number of NULLs has not changed - if (included.AllValid() || included(prevs[0].start) == included(prevs[0].end)) { - Interpolator interp(q, prev_count, false); - replace = CanReplace(index, data, j, interp.FRN, interp.CRN, included); - if (replace) { - state.count = prev_count; - } - } - } else { - ReuseIndexes(index, frames, prevs); - } + auto rdata = FlatVector::GetData(result); + auto &rmask = FlatVector::Validity(result); - if (!replace && !included.AllValid()) { - // Remove the NULLs - state.count = std::partition(index, index + state.count, included) - index; + if (!n) { + rmask.Set(ridx, false); + return; } - if (state.count) { - Interpolator interp(q, state.count, false); - using ID = QuantileIndirect; - ID indirect(data); - rdata[ridx] = replace ? interp.template Replace(index, result, indirect) - : interp.template Operation(index, result, indirect); + const auto &quantile = bind_data.quantiles[0]; + if (gstate && gstate->HasTrees()) { + rdata[ridx] = gstate->template WindowScalar(data, frames, n, result, quantile); } else { - rmask.Set(ridx, false); - } + // Update the skip list + state.UpdateSkip(data, frames, included); - prevs = frames; + // Find the position(s) needed + rdata[ridx] = state.template WindowScalar(data, frames, n, result, quantile); + + // Save the previous state for next time + state.prevs = frames; + } } }; template AggregateFunction GetTypedDiscreteQuantileAggregateFunction(const LogicalType &type) { - using STATE = QuantileState; + using STATE = QuantileState; using OP = QuantileScalarOperation; auto fun = AggregateFunction::UnaryAggregateDestructor(type, type); fun.window = AggregateFunction::UnaryWindow; + fun.window_init = OP::WindowInit; return fun; } @@ -796,109 +1014,40 @@ struct QuantileListOperation : public QuantileOperation { template static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, - AggregateInputData &aggr_input_data, STATE &state, const vector &frames, - Vector &list, idx_t lidx) { + AggregateInputData &aggr_input_data, STATE &state, const SubFrames &frames, Vector &list, + idx_t lidx, const STATE *gstate) { D_ASSERT(aggr_input_data.bind_data); auto &bind_data = aggr_input_data.bind_data->Cast(); QuantileIncluded included(fmask, dmask); + const auto n = FrameSize(included, frames); // Result is a constant LIST with a fixed length - auto ldata = FlatVector::GetData(list); - auto &lmask = FlatVector::Validity(list); - auto &lentry = ldata[lidx]; - lentry.offset = ListVector::GetListSize(list); - lentry.length = bind_data.quantiles.size(); - - ListVector::Reserve(list, lentry.offset + lentry.length); - ListVector::SetListSize(list, lentry.offset + lentry.length); - auto &result = ListVector::GetEntry(list); - auto rdata = FlatVector::GetData(result); - - // Lazily initialise frame state - const auto prev_count = state.count; - auto &prevs = state.prevs; - state.SetCount(frames); - - auto index = state.w.data(); - - // We can generalise replacement for quantile lists by observing that when a replacement is - // valid for a single quantile, it is valid for all quantiles greater/less than that quantile - // based on whether the insertion is below/above the quantile location. - // So if a replaced index in an IQR is located between Q25 and Q50, but has a value below Q25, - // then Q25 must be recomputed, but Q50 and Q75 are unaffected. - // For a single element list, this reduces to the scalar case. - std::pair replaceable {state.count, 0}; - if (frames.size() == 1 && frames[0].start == prevs[0].start + 1 && frames[0].end == prevs[0].end + 1) { - // Fixed frame size - const auto j = ReplaceIndex(index, frames[0], prevs[0]); - // We can only replace if the number of NULLs has not changed - if (included.AllValid() || included(prevs[0].start) == included(prevs[0].end)) { - for (const auto &q : bind_data.order) { - const auto &quantile = bind_data.quantiles[q]; - Interpolator interp(quantile, prev_count, false); - const auto replace = CanReplace(index, data, j, interp.FRN, interp.CRN, included); - if (replace < 0) { - // Replacement is before this quantile, so the rest will be replaceable too. - replaceable.first = MinValue(replaceable.first, interp.FRN); - replaceable.second = prev_count; - break; - } else if (replace > 0) { - // Replacement is after this quantile, so everything before it is replaceable too. - replaceable.first = 0; - replaceable.second = MaxValue(replaceable.second, interp.CRN); - } - } - if (replaceable.first < replaceable.second) { - state.count = prev_count; - } - } - } else { - ReuseIndexes(index, frames, prevs); - } - - if (replaceable.first >= replaceable.second && !included.AllValid()) { - // Remove the NULLs - state.count = std::partition(index, index + state.count, included) - index; + if (!n) { + auto &lmask = FlatVector::Validity(list); + lmask.Set(lidx, false); + return; } - if (state.count) { - using ID = QuantileIndirect; - ID indirect(data); - for (const auto &q : bind_data.order) { - const auto &quantile = bind_data.quantiles[q]; - Interpolator interp(quantile, state.count, false); - if (replaceable.first <= interp.FRN && interp.CRN <= replaceable.second) { - rdata[lentry.offset + q] = interp.template Replace(index, result, indirect); - } else { - // Make sure we don't disturb any replacements - if (replaceable.first < replaceable.second) { - if (interp.FRN < replaceable.first) { - interp.end = replaceable.first; - } - if (replaceable.second < interp.CRN) { - interp.begin = replaceable.second; - } - } - rdata[lentry.offset + q] = - interp.template Operation(index, result, indirect); - } - } + if (gstate && gstate->HasTrees()) { + gstate->template WindowList(data, frames, n, list, lidx, bind_data); } else { - lmask.Set(lidx, false); + // + state.UpdateSkip(data, frames, included); + state.template WindowList(data, frames, n, list, lidx, bind_data); + state.prevs = frames; } - - prevs = frames; } }; template AggregateFunction GetTypedDiscreteQuantileListAggregateFunction(const LogicalType &type) { - using STATE = QuantileState; + using STATE = QuantileState; using OP = QuantileListOperation; auto fun = QuantileListAggregate(type, type); fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; fun.window = AggregateFunction::UnaryWindow; + fun.window_init = OP::template WindowInit; return fun; } @@ -951,11 +1100,12 @@ AggregateFunction GetDiscreteQuantileListAggregateFunction(const LogicalType &ty template AggregateFunction GetTypedContinuousQuantileAggregateFunction(const LogicalType &input_type, const LogicalType &target_type) { - using STATE = QuantileState; + using STATE = QuantileState; using OP = QuantileScalarOperation; auto fun = AggregateFunction::UnaryAggregateDestructor(input_type, target_type); fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; fun.window = AggregateFunction::UnaryWindow; + fun.window_init = OP::template WindowInit; return fun; } @@ -1005,11 +1155,12 @@ AggregateFunction GetContinuousQuantileAggregateFunction(const LogicalType &type template AggregateFunction GetTypedContinuousQuantileListAggregateFunction(const LogicalType &input_type, const LogicalType &result_type) { - using STATE = QuantileState; + using STATE = QuantileState; using OP = QuantileListOperation; auto fun = QuantileListAggregate(input_type, result_type); fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; fun.window = AggregateFunction::UnaryWindow; + fun.window_init = OP::template WindowInit; return fun; } @@ -1157,82 +1308,59 @@ struct MedianAbsoluteDeviationOperation : public QuantileOperation { template static void Window(const INPUT_TYPE *data, const ValidityMask &fmask, const ValidityMask &dmask, - AggregateInputData &aggr_input_data, STATE &state, const vector &frames, - Vector &result, idx_t ridx) { + AggregateInputData &aggr_input_data, STATE &state, const SubFrames &frames, Vector &result, + idx_t ridx, const STATE *gstate) { auto rdata = FlatVector::GetData(result); - auto &rmask = FlatVector::Validity(result); QuantileIncluded included(fmask, dmask); + const auto n = FrameSize(included, frames); - // Lazily initialise frame state - auto prev_count = state.count; - auto &prevs = state.prevs; - state.SetCount(frames); + if (!n) { + auto &rmask = FlatVector::Validity(result); + rmask.Set(ridx, false); + return; + } - auto index = state.w.data(); - D_ASSERT(index); + // Compute the median + D_ASSERT(aggr_input_data.bind_data); + auto &bind_data = aggr_input_data.bind_data->Cast(); - // We need a second index for the second pass. - if (state.count > state.m.size()) { - state.m.resize(state.count); + D_ASSERT(bind_data.quantiles.size() == 1); + const auto &quantile = bind_data.quantiles[0]; + MEDIAN_TYPE med; + if (gstate && gstate->HasTrees()) { + med = gstate->template WindowScalar(data, frames, n, result, quantile); + } else { + state.UpdateSkip(data, frames, included); + med = state.template WindowScalar(data, frames, n, result, quantile); } + // Lazily initialise frame state + state.SetCount(frames.back().end - frames.front().start); auto index2 = state.m.data(); D_ASSERT(index2); // The replacement trick does not work on the second index because if // the median has changed, the previous order is not correct. // It is probably close, however, and so reuse is helpful. + auto &prevs = state.prevs; ReuseIndexes(index2, frames, prevs); std::partition(index2, index2 + state.count, included); - // Find the two positions needed for the median - D_ASSERT(aggr_input_data.bind_data); - auto &bind_data = aggr_input_data.bind_data->Cast(); - D_ASSERT(bind_data.quantiles.size() == 1); - const auto &q = bind_data.quantiles[0]; - - bool replace = false; - if (frames.size() == 1 && frames[0].start == prevs[0].start + 1 && frames[0].end == prevs[0].end + 1) { - // Fixed frame size - const auto j = ReplaceIndex(index, frames[0], prevs[0]); - // We can only replace if the number of NULLs has not changed - if (included.AllValid() || included(prevs[0].start) == included(prevs[0].end)) { - Interpolator interp(q, prev_count, false); - replace = CanReplace(index, data, j, interp.FRN, interp.CRN, included); - if (replace) { - state.count = prev_count; - } - } - } else { - ReuseIndexes(index, frames, prevs); - } - - if (!replace && !included.AllValid()) { - // Remove the NULLs - state.count = std::partition(index, index + state.count, included) - index; - } - - if (state.count) { - Interpolator interp(q, state.count, false); + Interpolator interp(quantile, n, false); - // Compute or replace median from the first index - using ID = QuantileIndirect; - ID indirect(data); - const auto med = replace ? interp.template Replace(index, result, indirect) - : interp.template Operation(index, result, indirect); + // Compute mad from the second index + using ID = QuantileIndirect; + ID indirect(data); - // Compute mad from the second index - using MAD = MadAccessor; - MAD mad(med); + using MAD = MadAccessor; + MAD mad(med); - using MadIndirect = QuantileComposed; - MadIndirect mad_indirect(mad, indirect); - rdata[ridx] = interp.template Operation(index2, result, mad_indirect); - } else { - rmask.Set(ridx, false); - } + using MadIndirect = QuantileComposed; + MadIndirect mad_indirect(mad, indirect); + rdata[ridx] = interp.template Operation(index2, result, mad_indirect); + // Prev is used by both skip lists and increments prevs = frames; } }; @@ -1245,12 +1373,13 @@ unique_ptr BindMedian(ClientContext &context, AggregateFunction &f template AggregateFunction GetTypedMedianAbsoluteDeviationAggregateFunction(const LogicalType &input_type, const LogicalType &target_type) { - using STATE = QuantileState; + using STATE = QuantileState; using OP = MedianAbsoluteDeviationOperation; auto fun = AggregateFunction::UnaryAggregateDestructor(input_type, target_type); fun.bind = BindMedian; fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; fun.window = AggregateFunction::UnaryWindow; + fun.window_init = OP::template WindowInit; return fun; } diff --git a/src/core_functions/aggregate/nested/list.cpp b/src/core_functions/aggregate/nested/list.cpp index 20f167a75bd..22feb8f9b4c 100644 --- a/src/core_functions/aggregate/nested/list.cpp +++ b/src/core_functions/aggregate/nested/list.cpp @@ -147,8 +147,8 @@ static void ListFinalize(Vector &states_vector, AggregateInputData &aggr_input_d ListVector::SetListSize(result, total_len); } -static void ListWindow(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data, - idx_t input_count, data_ptr_t state, const vector &frames, Vector &result, +static void ListWindow(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, + const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result, idx_t rid) { auto &list_bind_data = aggr_input_data.bind_data->Cast(); @@ -156,8 +156,9 @@ static void ListWindow(Vector inputs[], const ValidityMask &filter_mask, Aggrega // UPDATE step - D_ASSERT(input_count == 1); - auto &input = inputs[0]; + D_ASSERT(partition.input_count == 1); + // FIXME: We are modifying the window operator's data here + auto &input = const_cast(partition.inputs[0]); // FIXME: we unify more values than necessary (count is frame.end) const auto count = frames.back().end; diff --git a/src/execution/window_executor.cpp b/src/execution/window_executor.cpp index 53841b72794..9406200e09d 100644 --- a/src/execution/window_executor.cpp +++ b/src/execution/window_executor.cpp @@ -3,6 +3,8 @@ #include "duckdb/common/operator/add.hpp" #include "duckdb/common/operator/subtract.hpp" +#include "duckdb/common/array.hpp" + namespace duckdb { static idx_t FindNextStart(const ValidityMask &mask, idx_t l, const idx_t r, idx_t &n) { @@ -545,6 +547,7 @@ WindowBoundariesState::WindowBoundariesState(BoundWindowExpression &wexpr, const partition_count(wexpr.partitions.size()), order_count(wexpr.orders.size()), range_sense(wexpr.orders.empty() ? OrderType::INVALID : wexpr.orders[0].type), has_preceding_range(HasPrecedingRange(wexpr)), has_following_range(HasFollowingRange(wexpr)), + // if we have EXCLUDE GROUP / TIES, we also need peer boundaries needs_peer(BoundaryNeedsPeer(wexpr.end) || ExpressionNeedsPeer(wexpr.type) || wexpr.exclude_clause >= WindowExcludeMode::GROUP) { } @@ -874,7 +877,6 @@ WindowAggregateExecutor::WindowAggregateExecutor(BoundWindowExpression &wexpr, C const idx_t count, const ValidityMask &partition_mask, const ValidityMask &order_mask, WindowAggregationMode mode) : WindowExecutor(wexpr, context, count, partition_mask, order_mask), mode(mode), filter_executor(context) { - // TODO we could evaluate those expressions in parallel // Check for constant aggregate if (IsConstantAggregate()) { @@ -898,6 +900,7 @@ WindowAggregateExecutor::WindowAggregateExecutor(BoundWindowExpression &wexpr, C } void WindowAggregateExecutor::Sink(DataChunk &input_chunk, const idx_t input_idx, const idx_t total_count) { + // TODO we could evaluate those expressions in parallel idx_t filtered = 0; SelectionVector *filtering = nullptr; if (wexpr.filter_expr) { @@ -922,7 +925,77 @@ void WindowAggregateExecutor::Sink(DataChunk &input_chunk, const idx_t input_idx void WindowAggregateExecutor::Finalize() { D_ASSERT(aggregator); - aggregator->Finalize(); + + // Estimate the frame statistics + // Default to the entire partition if we don't know anything + FrameStats stats; + const int64_t count = aggregator->GetInputs().size(); + + // First entry is the frame start + stats[0] = FrameDelta(-count, count); + auto base = wexpr.expr_stats.empty() ? nullptr : wexpr.expr_stats[0].get(); + switch (wexpr.start) { + case WindowBoundary::UNBOUNDED_PRECEDING: + stats[0].end = 0; + break; + case WindowBoundary::CURRENT_ROW_ROWS: + stats[0].begin = stats[0].end = 0; + break; + case WindowBoundary::EXPR_PRECEDING_ROWS: + if (base && base->GetStatsType() == StatisticsType::NUMERIC_STATS && NumericStats::HasMinMax(*base)) { + // Preceding so negative offset from current row + stats[0].begin = -NumericStats::GetMax(*base); + stats[0].end = -NumericStats::GetMin(*base) + 1; + } + break; + case WindowBoundary::EXPR_FOLLOWING_ROWS: + if (base && base->GetStatsType() == StatisticsType::NUMERIC_STATS && NumericStats::HasMinMax(*base)) { + stats[0].begin = NumericStats::GetMin(*base); + stats[0].end = NumericStats::GetMax(*base) + 1; + } + break; + + case WindowBoundary::CURRENT_ROW_RANGE: + case WindowBoundary::EXPR_PRECEDING_RANGE: + case WindowBoundary::EXPR_FOLLOWING_RANGE: + break; + default: + throw InternalException("Unsupported window start boundary"); + } + + // Second entry is the frame end + stats[1] = FrameDelta(-count, count); + base = wexpr.expr_stats.empty() ? nullptr : wexpr.expr_stats[1].get(); + switch (wexpr.end) { + case WindowBoundary::UNBOUNDED_FOLLOWING: + stats[1].begin = 0; + break; + case WindowBoundary::CURRENT_ROW_ROWS: + stats[1].begin = stats[1].end = 0; + break; + case WindowBoundary::EXPR_PRECEDING_ROWS: + if (base && base->GetStatsType() == StatisticsType::NUMERIC_STATS && NumericStats::HasMinMax(*base)) { + // Preceding so negative offset from current row + stats[1].begin = -NumericStats::GetMax(*base); + stats[1].end = -NumericStats::GetMin(*base) + 1; + } + break; + case WindowBoundary::EXPR_FOLLOWING_ROWS: + if (base && base->GetStatsType() == StatisticsType::NUMERIC_STATS && NumericStats::HasMinMax(*base)) { + stats[1].begin = NumericStats::GetMin(*base); + stats[1].end = NumericStats::GetMax(*base) + 1; + } + break; + + case WindowBoundary::CURRENT_ROW_RANGE: + case WindowBoundary::EXPR_PRECEDING_RANGE: + case WindowBoundary::EXPR_FOLLOWING_RANGE: + break; + default: + throw InternalException("Unsupported window end boundary"); + } + + aggregator->Finalize(stats); } class WindowAggregateState : public WindowExecutorBoundsState { diff --git a/src/execution/window_segment_tree.cpp b/src/execution/window_segment_tree.cpp index b783e4ef079..adb182e4d51 100644 --- a/src/execution/window_segment_tree.cpp +++ b/src/execution/window_segment_tree.cpp @@ -45,7 +45,7 @@ void WindowAggregator::Sink(DataChunk &payload_chunk, SelectionVector *filter_se } } -void WindowAggregator::Finalize() { +void WindowAggregator::Finalize(const FrameStats &stats) { } //===--------------------------------------------------------------------===// @@ -183,7 +183,7 @@ void WindowConstantAggregator::Sink(DataChunk &payload_chunk, SelectionVector *f } } -void WindowConstantAggregator::Finalize() { +void WindowConstantAggregator::Finalize(const FrameStats &stats) { AggegateFinal(*results, partition++); } @@ -248,26 +248,24 @@ WindowCustomAggregator::~WindowCustomAggregator() { class WindowCustomAggregatorState : public WindowAggregatorState { public: - WindowCustomAggregatorState(const AggregateObject &aggr, DataChunk &inputs, const WindowExcludeMode exclude_mode); + WindowCustomAggregatorState(const AggregateObject &aggr, const WindowExcludeMode exclude_mode); ~WindowCustomAggregatorState() override; public: //! The aggregate function const AggregateObject &aggr; - //! The aggregate function - DataChunk &inputs; //! Data pointer that contains a single state, shared by all the custom evaluators vector state; //! Reused result state container for the window functions Vector statef; //! The frame boundaries, used for the window functions - vector frames; + SubFrames frames; }; -WindowCustomAggregatorState::WindowCustomAggregatorState(const AggregateObject &aggr, DataChunk &inputs, +WindowCustomAggregatorState::WindowCustomAggregatorState(const AggregateObject &aggr, const WindowExcludeMode exclude_mode) - : aggr(aggr), inputs(inputs), state(aggr.function.state_size()), - statef(Value::POINTER(CastPointerToValue(state.data()))), frames(3, {0, 0}) { + : aggr(aggr), state(aggr.function.state_size()), statef(Value::POINTER(CastPointerToValue(state.data()))), + frames(3, {0, 0}) { // if we have a frame-by-frame method, share the single state aggr.function.initialize(state.data()); @@ -294,8 +292,22 @@ WindowCustomAggregatorState::~WindowCustomAggregatorState() { } } +void WindowCustomAggregator::Finalize(const FrameStats &stats) { + WindowAggregator::Finalize(stats); + partition_input = + make_uniq(inputs.data.data(), inputs.ColumnCount(), inputs.size(), filter_mask, stats); + + if (aggr.function.window_init) { + gstate = GetLocalState(); + auto &gcstate = gstate->Cast(); + + AggregateInputData aggr_input_data(aggr.GetFunctionData(), gcstate.allocator); + aggr.function.window_init(aggr_input_data, *partition_input, gcstate.state.data()); + } +} + unique_ptr WindowCustomAggregator::GetLocalState() const { - return make_uniq(aggr, const_cast(inputs), exclude_mode); + return make_uniq(aggr, exclude_mode); } void WindowCustomAggregator::Evaluate(WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, @@ -305,11 +317,14 @@ void WindowCustomAggregator::Evaluate(WindowAggregatorState &lstate, const DataC auto peer_begin = FlatVector::GetData(bounds.data[PEER_BEGIN]); auto peer_end = FlatVector::GetData(bounds.data[PEER_END]); - // TODO: window should take a const Vector* auto &lcstate = lstate.Cast(); auto &frames = lcstate.frames; - auto params = lcstate.inputs.data.data(); auto &rmask = FlatVector::Validity(result); + const_data_ptr_t gstate_p = nullptr; + if (gstate) { + auto &gcstate = gstate->Cast(); + gstate_p = gcstate.state.data(); + } for (idx_t i = 0, cur_row = row_idx; i < count; ++i, ++cur_row) { idx_t nframes = 0; idx_t non_empty = 0; @@ -368,8 +383,7 @@ void WindowCustomAggregator::Evaluate(WindowAggregatorState &lstate, const DataC // Extract the range AggregateInputData aggr_input_data(aggr.GetFunctionData(), lstate.allocator); - aggr.function.window(params, filter_mask, aggr_input_data, inputs.ColumnCount(), lcstate.state.data(), frames, - result, i); + aggr.function.window(aggr_input_data, *partition_input, gstate_p, lcstate.state.data(), frames, result, i); } } @@ -381,7 +395,9 @@ WindowSegmentTree::WindowSegmentTree(AggregateObject aggr, const LogicalType &re : WindowAggregator(std::move(aggr), result_type, exclude_mode_p, count), internal_nodes(0), mode(mode_p) { } -void WindowSegmentTree::Finalize() { +void WindowSegmentTree::Finalize(const FrameStats &stats) { + WindowAggregator::Finalize(stats); + gstate = GetLocalState(); if (inputs.ColumnCount() > 0) { if (aggr.function.combine && UseCombineAPI()) { diff --git a/src/function/aggregate/distributive/count.cpp b/src/function/aggregate/distributive/count.cpp index 9631665c9d1..99f02017883 100644 --- a/src/function/aggregate/distributive/count.cpp +++ b/src/function/aggregate/distributive/count.cpp @@ -34,10 +34,9 @@ struct CountStarFunction : public BaseCountFunction { } template - static void Window(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data, - idx_t input_count, data_ptr_t state, const vector &frames, Vector &result, - idx_t rid) { - D_ASSERT(input_count == 0); + static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, const_data_ptr_t, + data_ptr_t l_state, const SubFrames &frames, Vector &result, idx_t rid) { + D_ASSERT(partition.input_count == 0); auto data = FlatVector::GetData(result); RESULT_TYPE total = 0; @@ -46,12 +45,12 @@ struct CountStarFunction : public BaseCountFunction { const auto end = frame.end; // Slice to any filtered rows - if (filter_mask.AllValid()) { + if (partition.filter_mask.AllValid()) { total += end - begin; continue; } for (auto i = begin; i < end; ++i) { - total += filter_mask.RowIsValid(i); + total += partition.filter_mask.RowIsValid(i); } } data[rid] = total; diff --git a/src/function/aggregate/sorted_aggregate_function.cpp b/src/function/aggregate/sorted_aggregate_function.cpp index 55e20470fb5..a0c4951d58e 100644 --- a/src/function/aggregate/sorted_aggregate_function.cpp +++ b/src/function/aggregate/sorted_aggregate_function.cpp @@ -351,8 +351,8 @@ struct SortedAggregateFunction { target.Combine(order_bind, other); } - static void Window(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data, - idx_t input_count, data_ptr_t state, const vector &frames, Vector &result, + static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, + const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &subframes, Vector &result, idx_t rid) { throw InternalException("Sorted aggregates should not be generated for window clauses"); } diff --git a/src/include/duckdb/common/vector_operations/aggregate_executor.hpp b/src/include/duckdb/common/vector_operations/aggregate_executor.hpp index 60627608044..312dd065580 100644 --- a/src/include/duckdb/common/vector_operations/aggregate_executor.hpp +++ b/src/include/duckdb/common/vector_operations/aggregate_executor.hpp @@ -17,6 +17,8 @@ namespace duckdb { // structs struct AggregateInputData; + +// The bounds of a window frame struct FrameBounds { FrameBounds() : start(0), end(0) {}; FrameBounds(idx_t start, idx_t end) : start(start), end(end) {}; @@ -24,6 +26,9 @@ struct FrameBounds { idx_t end = 0; }; +// A set of window subframes for windowed EXCLUDE +using SubFrames = vector; + class AggregateExecutor { private: template @@ -382,21 +387,21 @@ class AggregateExecutor { } } - using Frames = vector; - template - static void UnaryWindow(Vector &input, const ValidityMask &ifilter, AggregateInputData &aggr_input_data, - data_ptr_t state_p, const Frames &frames, Vector &result, idx_t rid) { + static void UnaryWindow(const Vector &input, const ValidityMask &ifilter, AggregateInputData &aggr_input_data, + data_ptr_t state_p, const SubFrames &frames, Vector &result, idx_t ridx, + const_data_ptr_t gstate_p) { auto idata = FlatVector::GetData(input); const auto &ivalid = FlatVector::Validity(input); auto &state = *reinterpret_cast(state_p); + auto gstate = reinterpret_cast(gstate_p); OP::template Window(idata, ifilter, ivalid, aggr_input_data, state, frames, - result, rid); + result, ridx, gstate); } template - static void IntersectFrames(const Frames &lefts, const Frames &rights, OP &op) { + static void IntersectFrames(const SubFrames &lefts, const SubFrames &rights, OP &op) { const auto cover_start = MinValue(rights[0].start, lefts[0].start); const auto cover_end = MaxValue(rights.back().end, lefts.back().end); const FrameBounds last(cover_end, cover_end); diff --git a/src/include/duckdb/execution/merge_sort_tree.hpp b/src/include/duckdb/execution/merge_sort_tree.hpp new file mode 100644 index 00000000000..c568b1563fe --- /dev/null +++ b/src/include/duckdb/execution/merge_sort_tree.hpp @@ -0,0 +1,397 @@ +//===----------------------------------------------------------------------===// +// DuckDB +// +// duckdb/execution/merge_sort_tree.hpp +// +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include "duckdb/common/array.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/pair.hpp" +#include "duckdb/common/typedefs.hpp" +#include "duckdb/common/vector.hpp" +#include "duckdb/common/vector_operations/aggregate_executor.hpp" + +namespace duckdb { + +// MIT License Text: +// +// Copyright 2022 salesforce.com, inc. +// +// Permission is hereby granted, free of charge, to any person obtaining +// a copy of this software and associated documentation files (the +// "Software"), to deal in the Software without restriction, including +// without limitation the rights to use, copy, modify, merge, publish, +// distribute, sublicense, and/or sell copies of the Software, and to +// permit persons to whom the Software is furnished to do so, subject to +// the following conditions: +// +// The above copyright notice and this permission notice shall be +// included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS +// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN +// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// Implementation of a generic merge-sort-tree +// Rewrite of the original, which was in C++17 and targeted for research, +// instead of deployment. +template , uint64_t F = 32, uint64_t C = 32> +struct MergeSortTree { + using ElementType = E; + using OffsetType = O; + using Elements = vector; + using Offsets = vector; + using Level = pair; + using Tree = vector; + + using RunElement = pair; + using RunElements = array; + using Games = array; + + struct CompareElements { + explicit CompareElements(const CMP &cmp) : cmp(cmp) { + } + + bool operator()(const RunElement &lhs, const RunElement &rhs) { + if (cmp(lhs.first, rhs.first)) { + return true; + } + if (cmp(rhs.first, lhs.first)) { + return false; + } + return lhs.second < rhs.second; + } + + CMP cmp; + }; + + MergeSortTree() { + } + explicit MergeSortTree(Elements &&lowest_level, const CMP &cmp = CMP()); + + idx_t SelectNth(const SubFrames &frames, idx_t n) const; + + inline ElementType NthElement(idx_t i) const { + return tree.front().first[i]; + } + +protected: + RunElement StartGames(Games &losers, const RunElements &elements, const RunElement &sentinel) { + const auto elem_nodes = elements.size(); + const auto game_nodes = losers.size(); + Games winners; + + // Play the first round of games, + // placing the losers at the bottom of the game + const auto base_offset = game_nodes / 2; + auto losers_base = losers.data() + base_offset; + auto winners_base = winners.data() + base_offset; + + const auto base_count = elem_nodes / 2; + for (idx_t i = 0; i < base_count; ++i) { + const auto &e0 = elements[i * 2 + 0]; + const auto &e1 = elements[i * 2 + 1]; + if (cmp(e0, e1)) { + losers_base[i] = e1; + winners_base[i] = e0; + } else { + losers_base[i] = e0; + winners_base[i] = e1; + } + } + + // Fill in any byes + if (elem_nodes % 2) { + winners_base[base_count] = elements.back(); + losers_base[base_count] = sentinel; + } + + // Pad to a power of 2 + const auto base_size = (game_nodes + 1) / 2; + for (idx_t i = (elem_nodes + 1) / 2; i < base_size; ++i) { + winners_base[i] = sentinel; + losers_base[i] = sentinel; + } + + // Play the winners against each other + // and stick the losers in the upper levels of the tournament tree + for (idx_t i = base_offset; i-- > 0;) { + // Indexing backwards + const auto &e0 = winners[i * 2 + 1]; + const auto &e1 = winners[i * 2 + 2]; + if (cmp(e0, e1)) { + losers[i] = e1; + winners[i] = e0; + } else { + losers[i] = e0; + winners[i] = e1; + } + } + + // Return the final winner + return winners[0]; + } + + RunElement ReplayGames(Games &losers, idx_t slot_idx, const RunElement &insert_val) { + RunElement smallest = insert_val; + // Start at a fake level below + auto idx = slot_idx + losers.size(); + do { + // Parent index + idx = (idx - 1) / 2; + // swap if out of order + if (cmp(losers[idx], smallest)) { + std::swap(losers[idx], smallest); + } + } while (idx); + + return smallest; + } + + Tree tree; + CompareElements cmp; + + static constexpr auto FANOUT = F; + static constexpr auto CASCADING = C; + + static idx_t LowestCascadingLevel() { + idx_t level = 0; + idx_t level_width = 1; + while (level_width <= CASCADING) { + ++level; + level_width *= FANOUT; + } + return level; + } +}; + +template +MergeSortTree::MergeSortTree(Elements &&lowest_level, const CMP &cmp) : cmp(cmp) { + const auto fanout = F; + const auto cascading = C; + const auto count = lowest_level.size(); + tree.emplace_back(Level(lowest_level, Offsets())); + + const RunElement SENTINEL(std::numeric_limits::max(), std::numeric_limits::max()); + + // Fan in parent levels until we are at the top + // Note that we don't build the top layer as that would just be all the data. + for (idx_t child_run_length = 1; child_run_length < count;) { + const auto run_length = child_run_length * fanout; + const auto num_runs = (count + run_length - 1) / run_length; + + Elements elements; + elements.reserve(count); + + // Allocate cascading pointers only if there is room + Offsets cascades; + if (cascading > 0 && run_length > cascading) { + const auto num_cascades = fanout * num_runs * (run_length / cascading + 2); + cascades.reserve(num_cascades); + } + + // Create each parent run by merging the child runs using a tournament tree + // https://en.wikipedia.org/wiki/K-way_merge_algorithm + // TODO: Because the runs are independent, they can be parallelised with parallel_for + const auto &child_level = tree.back(); + for (idx_t run_idx = 0; run_idx < num_runs; ++run_idx) { + // Position markers for scanning the children. + using Bounds = pair; + array bounds; + // Start with first element of each (sorted) child run + RunElements players; + const auto child_base = run_idx * run_length; + for (idx_t child_run = 0; child_run < fanout; ++child_run) { + const auto child_idx = child_base + child_run * child_run_length; + bounds[child_run] = {MinValue(child_idx, count), + MinValue(child_idx + child_run_length, count)}; + if (bounds[child_run].first != bounds[child_run].second) { + players[child_run] = {child_level.first[child_idx], child_run}; + } else { + // Empty child + players[child_run] = SENTINEL; + } + } + + // Play the first round and extract the winner + Games games; + auto winner = StartGames(games, players, SENTINEL); + while (winner != SENTINEL) { + // Add fractional cascading pointers + // if we are on a fraction boundary + if (cascading > 0 && run_length > cascading && elements.size() % cascading == 0) { + for (idx_t i = 0; i < fanout; ++i) { + cascades.emplace_back(bounds[i].first); + } + } + + // Insert new winner element into the current run + elements.emplace_back(winner.first); + const auto child_run = winner.second; + auto &child_idx = bounds[child_run].first; + ++child_idx; + + // Move to the next entry in the child run (if any) + if (child_idx < bounds[child_run].second) { + winner = ReplayGames(games, child_run, {child_level.first[child_idx], child_run}); + } else { + winner = ReplayGames(games, child_run, SENTINEL); + } + } + + // Add terminal cascade pointers to the end + if (cascading > 0 && run_length > cascading) { + for (idx_t j = 0; j < 2; ++j) { + for (idx_t i = 0; i < fanout; ++i) { + cascades.emplace_back(bounds[i].first); + } + } + } + } + + // Insert completed level and move up to the next one + tree.emplace_back(std::move(elements), std::move(cascades)); + child_run_length = run_length; + } +} + +template +idx_t MergeSortTree::SelectNth(const SubFrames &frames, idx_t n) const { + // Handle special case of a one-element tree + if (tree.size() < 2) { + return 0; + } + + // The first level contains a single run, + // so the only thing we need is any cascading pointers + auto level_no = tree.size() - 2; + auto level_width = 1; + for (idx_t i = 0; i < level_no; ++i) { + level_width *= FANOUT; + } + + // Find Nth element in a top-down traversal + idx_t result = 0; + + // First, handle levels with cascading pointers + const auto min_cascaded = LowestCascadingLevel(); + if (level_no > min_cascaded) { + // Initialise the cascade indicies from the previous level + using CascadeRange = pair; + std::array cascades; + const auto &level = tree[level_no + 1].first; + for (idx_t f = 0; f < frames.size(); ++f) { + const auto &frame = frames[f]; + auto &cascade_idx = cascades[f]; + const auto lower_idx = std::lower_bound(level.begin(), level.end(), frame.start) - level.begin(); + cascade_idx.first = lower_idx / CASCADING * FANOUT; + const auto upper_idx = std::lower_bound(level.begin(), level.end(), frame.end) - level.begin(); + cascade_idx.second = upper_idx / CASCADING * FANOUT; + } + + // Walk the cascaded levels + for (; level_no >= min_cascaded; --level_no) { + // The cascade indicies into this level are in the previous level + const auto &level_cascades = tree[level_no + 1].second; + + // Go over all children until we found enough in range + const auto *level_data = tree[level_no].first.data(); + while (true) { + idx_t matched = 0; + std::array matches; + for (idx_t f = 0; f < frames.size(); ++f) { + const auto &frame = frames[f]; + auto &cascade_idx = cascades[f]; + auto &match = matches[f]; + + const auto lower_begin = level_data + level_cascades[cascade_idx.first]; + const auto lower_end = level_data + level_cascades[cascade_idx.first + FANOUT]; + match.first = std::lower_bound(lower_begin, lower_end, frame.start) - level_data; + + const auto upper_begin = level_data + level_cascades[cascade_idx.second]; + const auto upper_end = level_data + level_cascades[cascade_idx.second + FANOUT]; + match.second = std::lower_bound(upper_begin, upper_end, frame.end) - level_data; + + matched += idx_t(match.second - match.first); + } + if (matched > n) { + // Too much in this level, so move down to leftmost child candidate within the cascade range + for (idx_t f = 0; f < frames.size(); ++f) { + auto &cascade_idx = cascades[f]; + auto &match = matches[f]; + cascade_idx.first = (match.first / CASCADING + 2 * result) * FANOUT; + cascade_idx.second = (match.second / CASCADING + 2 * result) * FANOUT; + } + break; + } + + // Not enough in this child, so move right + for (idx_t f = 0; f < frames.size(); ++f) { + auto &cascade_idx = cascades[f]; + ++cascade_idx.first; + ++cascade_idx.second; + } + ++result; + n -= matched; + } + result *= FANOUT; + level_width /= FANOUT; + } + } + + // Continue with the uncascaded levels (except the first) + for (; level_no > 0; --level_no) { + const auto &level = tree[level_no].first; + auto range_begin = level.begin() + result * level_width; + auto range_end = range_begin + level_width; + while (range_end < level.end()) { + idx_t matched = 0; + for (idx_t f = 0; f < frames.size(); ++f) { + const auto &frame = frames[f]; + const auto lower_match = std::lower_bound(range_begin, range_end, frame.start); + const auto upper_match = std::lower_bound(lower_match, range_end, frame.end); + matched += idx_t(upper_match - lower_match); + } + if (matched > n) { + // Too much in this level, so move down to leftmost child candidate + // Since we have no cascade pointers left, this is just the start of the next level. + break; + } + // Not enough in this child, so move right + range_begin = range_end; + range_end += level_width; + ++result; + n -= matched; + } + result *= FANOUT; + level_width /= FANOUT; + } + + // The last level + const auto *level_data = tree[level_no].first.data(); + ++n; + + const auto count = tree[level_no].first.size(); + for (const auto limit = MinValue(result + FANOUT, count); result < limit; ++result) { + const auto v = level_data[result]; + for (const auto &frame : frames) { + n -= (v >= frame.start) && (v < frame.end); + } + if (!n) { + break; + } + } + + return result; +} + +} // namespace duckdb diff --git a/src/include/duckdb/execution/window_segment_tree.hpp b/src/include/duckdb/execution/window_segment_tree.hpp index 0cfce8af568..28609e193a7 100644 --- a/src/include/duckdb/execution/window_segment_tree.hpp +++ b/src/include/duckdb/execution/window_segment_tree.hpp @@ -44,9 +44,14 @@ class WindowAggregator { idx_t partition_count); virtual ~WindowAggregator(); + // Access + const DataChunk &GetInputs() const { + return inputs; + } + // Build virtual void Sink(DataChunk &payload_chunk, SelectionVector *filter_sel, idx_t filtered); - virtual void Finalize(); + virtual void Finalize(const FrameStats &stats); // Probe virtual unique_ptr GetLocalState() const = 0; @@ -85,7 +90,7 @@ class WindowConstantAggregator : public WindowAggregator { } void Sink(DataChunk &payload_chunk, SelectionVector *filter_sel, idx_t filtered) override; - void Finalize() override; + void Finalize(const FrameStats &stats) override; unique_ptr GetLocalState() const override; void Evaluate(WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, idx_t count, @@ -117,9 +122,17 @@ class WindowCustomAggregator : public WindowAggregator { const WindowExcludeMode exclude_mode_p, idx_t partition_count); ~WindowCustomAggregator() override; + void Finalize(const FrameStats &stats) override; + unique_ptr GetLocalState() const override; void Evaluate(WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const override; + + //! Partition description + unique_ptr partition_input; + + //! Data pointer that contains a single state, used for global custom window state + unique_ptr gstate; }; class WindowSegmentTree : public WindowAggregator { @@ -129,7 +142,7 @@ class WindowSegmentTree : public WindowAggregator { const WindowExcludeMode exclude_mode_p, idx_t count); ~WindowSegmentTree() override; - void Finalize() override; + void Finalize(const FrameStats &stats) override; unique_ptr GetLocalState() const override; void Evaluate(WindowAggregatorState &lstate, const DataChunk &bounds, Vector &result, idx_t count, diff --git a/src/include/duckdb/function/aggregate_function.hpp b/src/include/duckdb/function/aggregate_function.hpp index 7b23a1ab900..aff5a71bff6 100644 --- a/src/include/duckdb/function/aggregate_function.hpp +++ b/src/include/duckdb/function/aggregate_function.hpp @@ -8,6 +8,7 @@ #pragma once +#include "duckdb/common/array.hpp" #include "duckdb/common/vector_operations/aggregate_executor.hpp" #include "duckdb/function/aggregate_state.hpp" #include "duckdb/planner/bound_result_modifier.hpp" @@ -15,6 +16,31 @@ namespace duckdb { +//! A half-open range of frame boundary values _relative to the current row_ +//! This is why they are signed values. +struct FrameDelta { + FrameDelta() : begin(0), end(0) {}; + FrameDelta(int64_t begin, int64_t end) : begin(begin), end(end) {}; + int64_t begin = 0; + int64_t end = 0; +}; + +//! The half-open ranges of frame boundary values relative to the current row +using FrameStats = array; + +//! The partition data for custom window functions +struct WindowPartitionInput { + WindowPartitionInput(const Vector inputs[], idx_t input_count, idx_t count, const ValidityMask &filter_mask, + const FrameStats &stats) + : inputs(inputs), input_count(input_count), count(count), filter_mask(filter_mask), stats(stats) { + } + const Vector *inputs; + idx_t input_count; + idx_t count; + const ValidityMask &filter_mask; + const FrameStats stats; +}; + //! The type used for sizing hashed aggregate function states typedef idx_t (*aggregate_size_t)(); //! The type used for initializing hashed aggregate function states @@ -40,10 +66,14 @@ typedef void (*aggregate_destructor_t)(Vector &state, AggregateInputData &aggr_i typedef void (*aggregate_simple_update_t)(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, idx_t count); -//! The type used for updating complex windowed aggregate functions (optional) -typedef void (*aggregate_window_t)(Vector inputs[], const ValidityMask &filter_mask, - AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, - const vector &frames, Vector &result, idx_t rid); +//! The type used for computing complex/custom windowed aggregate functions (optional) +typedef void (*aggregate_window_t)(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, + const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &subframes, + Vector &result, idx_t rid); + +//! The type used for initializing shared complex/custom windowed aggregate state (optional) +typedef void (*aggregate_wininit_t)(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, + data_ptr_t g_state); typedef void (*aggregate_serialize_t)(Serializer &serializer, const optional_ptr bind_data, const AggregateFunction &function); @@ -115,8 +145,10 @@ class AggregateFunction : public BaseScalarFunction { aggregate_finalize_t finalize; //! The simple aggregate update function (may be null) aggregate_simple_update_t simple_update; - //! The windowed aggregate frame update function (may be null) + //! The windowed aggregate custom function (may be null) aggregate_window_t window; + //! The windowed aggregate custom initialization function (may be null) + aggregate_wininit_t window_init = nullptr; //! The bind function (may be null) bind_aggregate_function_t bind; @@ -217,12 +249,13 @@ class AggregateFunction : public BaseScalarFunction { } template - static void UnaryWindow(Vector inputs[], const ValidityMask &filter_mask, AggregateInputData &aggr_input_data, - idx_t input_count, data_ptr_t state, const vector &frames, Vector &result, + static void UnaryWindow(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, + const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &subframes, Vector &result, idx_t rid) { - D_ASSERT(input_count == 1); - AggregateExecutor::UnaryWindow(inputs[0], filter_mask, aggr_input_data, - state, frames, result, rid); + + D_ASSERT(partition.input_count == 1); + AggregateExecutor::UnaryWindow( + partition.inputs[0], partition.filter_mask, aggr_input_data, l_state, subframes, result, rid, g_state); } template diff --git a/src/include/duckdb/planner/expression/bound_window_expression.hpp b/src/include/duckdb/planner/expression/bound_window_expression.hpp index 4bc17cac7db..8fbfcb89dd2 100644 --- a/src/include/duckdb/planner/expression/bound_window_expression.hpp +++ b/src/include/duckdb/planner/expression/bound_window_expression.hpp @@ -52,6 +52,9 @@ class BoundWindowExpression : public Expression { unique_ptr offset_expr; unique_ptr default_expr; + //! Statistics belonging to the other expressions (start, end, offset, default) + vector> expr_stats; + public: bool IsWindow() const override { return true; diff --git a/src/optimizer/statistics/operator/propagate_window.cpp b/src/optimizer/statistics/operator/propagate_window.cpp index 1bfd3fc0610..9b6dcc493df 100644 --- a/src/optimizer/statistics/operator/propagate_window.cpp +++ b/src/optimizer/statistics/operator/propagate_window.cpp @@ -11,13 +11,37 @@ unique_ptr StatisticsPropagator::PropagateStatistics(LogicalWind // then propagate to each of the order expressions for (auto &window_expr : window.expressions) { - auto over_expr = reinterpret_cast(window_expr.get()); - for (auto &expr : over_expr->partitions) { - over_expr->partitions_stats.push_back(PropagateExpression(expr)); + auto &over_expr = window_expr->Cast(); + for (auto &expr : over_expr.partitions) { + over_expr.partitions_stats.push_back(PropagateExpression(expr)); } - for (auto &bound_order : over_expr->orders) { + for (auto &bound_order : over_expr.orders) { bound_order.stats = PropagateExpression(bound_order.expression); } + + if (over_expr.start_expr) { + over_expr.expr_stats.push_back(PropagateExpression(over_expr.start_expr)); + } else { + over_expr.expr_stats.push_back(nullptr); + } + + if (over_expr.end_expr) { + over_expr.expr_stats.push_back(PropagateExpression(over_expr.end_expr)); + } else { + over_expr.expr_stats.push_back(nullptr); + } + + if (over_expr.offset_expr) { + over_expr.expr_stats.push_back(PropagateExpression(over_expr.offset_expr)); + } else { + over_expr.expr_stats.push_back(nullptr); + } + + if (over_expr.default_expr) { + over_expr.expr_stats.push_back(PropagateExpression(over_expr.default_expr)); + } else { + over_expr.expr_stats.push_back(nullptr); + } } return std::move(node_stats); } diff --git a/src/planner/expression/bound_window_expression.cpp b/src/planner/expression/bound_window_expression.cpp index 58826f587bb..5b52b3be6c6 100644 --- a/src/planner/expression/bound_window_expression.cpp +++ b/src/planner/expression/bound_window_expression.cpp @@ -109,6 +109,13 @@ unique_ptr BoundWindowExpression::Copy() { new_window->default_expr = default_expr ? default_expr->Copy() : nullptr; new_window->ignore_nulls = ignore_nulls; + for (auto &es : expr_stats) { + if (es) { + new_window->expr_stats.push_back(es->ToUnique()); + } else { + new_window->expr_stats.push_back(nullptr); + } + } return std::move(new_window); } diff --git a/src/storage/compression/validity_uncompressed.cpp b/src/storage/compression/validity_uncompressed.cpp index 66b8f4f857c..feb46b7a6aa 100644 --- a/src/storage/compression/validity_uncompressed.cpp +++ b/src/storage/compression/validity_uncompressed.cpp @@ -242,7 +242,7 @@ void ValidityScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t s for (idx_t i = 0; i < scan_count; i++) { if (!source_mask.RowIsValid(start + i)) { if (result_mask.AllValid()) { - result_mask.Initialize(MaxValue(STANDARD_VECTOR_SIZE, result_offset + scan_count)); + result_mask.Initialize(); } result_mask.SetInvalid(result_offset + i); } @@ -323,7 +323,7 @@ void ValidityScanPartial(ColumnSegment &segment, ColumnScanState &state, idx_t s // now finally we can merge the input mask with the result mask if (input_mask != ValidityMask::ValidityBuffer::MAX_ENTRY) { if (!result_data) { - result_mask.Initialize(MaxValue(STANDARD_VECTOR_SIZE, result_offset + scan_count)); + result_mask.Initialize(); result_data = (validity_t *)result_mask.GetData(); } result_data[current_result_idx] &= input_mask; @@ -363,7 +363,7 @@ void ValidityScan(ColumnSegment &segment, ColumnScanState &state, idx_t scan_cou continue; } if (!result_data) { - result_mask.Initialize(MaxValue(STANDARD_VECTOR_SIZE, scan_count)); + result_mask.Initialize(); result_data = result_mask.GetData(); } result_data[i] = input_entry; diff --git a/test/sql/window/test_quantile_window.test_coverage b/test/sql/window/test_quantile_window.test_coverage new file mode 100644 index 00000000000..054a0d052a2 --- /dev/null +++ b/test/sql/window/test_quantile_window.test_coverage @@ -0,0 +1,80 @@ +# name: test/sql/window/test_quantile_window.test_coverage +# description: Moving QUANTILE coverage, fixed or variable 100 element frame for MEDIAN, IQR, and MAD +# group: [window] + +# Common table +statement ok +create table rank100 as + select b % 100 as a, b from range(10000000) tbl(b) + +# window_median_fixed_100 +query I +select sum(m) +from ( + select median(a) over ( + order by b asc + rows between 100 preceding and current row) as m + from rank100 + ) q; +---- +494997500 + +# window_median_variable_100 +query I +select sum(m) +from ( + select median(a) over ( + order by b asc + rows between mod(b * 47, 521) preceding and 100 - mod(b * 47, 521) following) as m + from rank100 + ) q; +---- +494989867 + +# window_iqr_fixed_100 +query II +select min(iqr), max(iqr) +from ( + select quantile_cont(a, [0.25, 0.5, 0.75]) over ( + order by b asc + rows between 100 preceding and current row) as iqr + from rank100 + ) q; +---- +[0.0, 0.0, 0.0] [25.0, 50.0, 75.0] + +# window_iqr_variable_100 +query II +select min(iqr), max(iqr) +from ( + select quantile_cont(a, [0.25, 0.5, 0.75]) over ( + order by b asc + rows between mod(b * 47, 521) preceding and 100 - mod(b * 47, 521) following) as iqr + from rank100 + ) q; +---- +[0.0, 0.0, 0.0] [76.5, 84.0, 91.5] + +# window_mad_fixed_100 +query I +select sum(m) +from ( + select mad(a) over ( + order by b asc + rows between 100 preceding and current row) as m + from rank100 + ) q; +---- +249998762.5 + +# +query I +select sum(m) +from ( + select mad(a) over ( + order by b asc + rows between mod(b * 47, 521) preceding and 100 - mod(b * 47, 521) following) as m + from rank100 + ) q; +---- +249994596.000000 diff --git a/test/sql/window/test_window_exclude.test b/test/sql/window/test_window_exclude.test index c2e8b39c5bd..90417317d38 100644 --- a/test/sql/window/test_window_exclude.test +++ b/test/sql/window/test_window_exclude.test @@ -704,6 +704,38 @@ ORDER BY i; 5 4.5 5 5.0 +# Test Merge Sort Trees with exclusions +query III +WITH t1(x, y) AS (VALUES + ( 1, 3 ), + ( 2, 2 ), + ( 3, 1 ) +) +SELECT x, y, QUANTILE_DISC(y, 0) OVER ( + ORDER BY x + ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW + EXCLUDE CURRENT ROW) +FROM t1; +---- +1 3 NULL +2 2 3 +3 1 2 + +query III +WITH t1(x, y) AS (VALUES + ( 1, 3 ), + ( 2, 2 ), + ( 3, 1 ) +) +SELECT x, y, QUANTILE_DISC(y, 0) OVER ( + ORDER BY x + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + EXCLUDE CURRENT ROW) +FROM t1; +---- +1 3 1 +2 2 1 +3 1 2 # PG test diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index e2248f8ac24..7452dc78bac 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -8,6 +8,7 @@ if(NOT AMALGAMATION_BUILD) add_subdirectory(miniz) add_subdirectory(utf8proc) add_subdirectory(hyperloglog) + add_subdirectory(skiplist) add_subdirectory(fastpforlib) add_subdirectory(mbedtls) add_subdirectory(fsst) diff --git a/third_party/skiplist/CMakeLists.txt b/third_party/skiplist/CMakeLists.txt new file mode 100644 index 00000000000..1fa1e839122 --- /dev/null +++ b/third_party/skiplist/CMakeLists.txt @@ -0,0 +1,17 @@ +if(POLICY CMP0063) + cmake_policy(SET CMP0063 NEW) +endif() + +add_library(duckdb_skiplistlib STATIC SkipList.cpp) + +target_include_directories( + duckdb_skiplistlib + PUBLIC $) +set_target_properties(duckdb_skiplistlib PROPERTIES EXPORT_NAME duckdb_skiplistlib) + +install(TARGETS duckdb_skiplistlib + EXPORT "${DUCKDB_EXPORT_SET}" + LIBRARY DESTINATION "${INSTALL_LIB_DIR}" + ARCHIVE DESTINATION "${INSTALL_LIB_DIR}") + +disable_target_warnings(duckdb_skiplistlib) diff --git a/third_party/skiplist/HeadNode.h b/third_party/skiplist/HeadNode.h new file mode 100755 index 00000000000..36cd866a48c --- /dev/null +++ b/third_party/skiplist/HeadNode.h @@ -0,0 +1,934 @@ +/** + * @file + * + * Project: skiplist + * + * Created by Paul Ross on 03/12/2015. + * + * Copyright (c) 2015-2023 Paul Ross. All rights reserved. + * + * @code + * MIT License + * + * Copyright (c) 2017-2023 Paul Ross + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * @endcode + */ + +#ifndef SkipList_HeadNode_h +#define SkipList_HeadNode_h + +#include +//#ifdef SKIPLIST_THREAD_SUPPORT +// #include +//#endif +#include + +#ifdef INCLUDE_METHODS_THAT_USE_STREAMS +#include +#endif // INCLUDE_METHODS_THAT_USE_STREAMS + +#include "IntegrityEnums.h" + +/** HeadNode + * + * @brief A HeadNode is a skip list. This is the single node leading to all other content Nodes. + * + * Example: + * + * @code + * OrderedStructs::SkipList::HeadNode sl; + * for (int i = 0; i < 100; ++i) { + * sl.insert(i * 22.0 / 7.0); + * } + * sl.size(); // 100 + * sl.at(50); // Value of 50 pi + * sl.remove(sl.at(50)); // Remove 50 pi + * @endcode + * + * Created by Paul Ross on 03/12/2015. + * + * Copyright (c) 2015-2023 Paul Ross. All rights reserved. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + */ +template > +class HeadNode { +public: + /** + * Constructor for and Empty Skip List. + * + * @param cmp The comparison function for comparing Node values. + */ + HeadNode(_Compare cmp=_Compare()) : _count(0), _compare(cmp), _pool(cmp) { +#ifdef INCLUDE_METHODS_THAT_USE_STREAMS + _dot_file_subgraph = 0; +#endif + } + // Const methods + // + // Returns true if the value is present in the skip list. + bool has(const T &value) const; + // Returns the value at the index in the skip list. + // Will throw an OrderedStructs::SkipList::IndexError if index out of range. + const T &at(size_t index) const; + // Find the value at index and write count values to dest. + // Will throw a SkipList::IndexError if any index out of range. + // This is useful for rolling median on even length lists where + // the caller might want to implement the mean of two values. + void at(size_t index, size_t count, std::vector &dest) const; + // Computes index of the first occurrence of a value + // Will throw a ValueError if the value does not exist in the skip list + size_t index(const T& value) const; + // Number of values in the skip list. + size_t size() const; + // Non-const methods + // + // Insert a value. + void insert(const T &value); + // Remove a value and return it. + // Will throw a ValueError is value not present. + T remove(const T &value); + + // Const methods that are mostly used for debugging and visualisation. + // + // Number of linked lists that are in the skip list. + size_t height() const; + // Number of linked lists that the node at index has. + // Will throw a SkipList::IndexError if idx out of range. + size_t height(size_t idx) const; + // The skip width of the node at index has. + // May throw a SkipList::IndexError + size_t width(size_t idx, size_t level) const; + +#ifdef INCLUDE_METHODS_THAT_USE_STREAMS + void dotFile(std::ostream &os) const; + void dotFileFinalise(std::ostream &os) const; +#endif // INCLUDE_METHODS_THAT_USE_STREAMS + + // Returns non-zero if the integrity of this data structure is compromised + // This is a thorough but expensive check! + IntegrityCheck lacksIntegrity() const; + // Estimate of the number of bytes used by the skip list + size_t size_of() const; + virtual ~HeadNode(); + +protected: + void _adjRemoveRefs(size_t level, Node *pNode); + const Node *_nodeAt(size_t idx) const; + +protected: + // Standardised way of throwing a ValueError + void _throwValueErrorNotFound(const T &value) const; + void _throwIfValueDoesNotCompare(const T &value) const; + // Internal integrity checks + IntegrityCheck _lacksIntegrityCyclicReferences() const; + IntegrityCheck _lacksIntegrityWidthAccumulation() const; + IntegrityCheck _lacksIntegrityNodeReferencesNotInList() const; + IntegrityCheck _lacksIntegrityOrder() const; +protected: + /// Number of nodes in the list. + size_t _count; + /// My node references, the size of this is the largest height in the list + SwappableNodeRefStack _nodeRefs; + /// Comparison function. + _Compare _compare; + typename Node::_Pool _pool; +#ifdef INCLUDE_METHODS_THAT_USE_STREAMS + /// Used to count how many sub-graphs have been plotted + mutable size_t _dot_file_subgraph; +#endif + +private: + /// Prevent cctor and operator= + HeadNode(const HeadNode &that); + HeadNode &operator=(const HeadNode &that) const; +}; + +/** + * Returns true if the value is present in the skip list. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @param value Value to check if it is in the Skip List. + * @return true if in the Skip List. + */ +template +bool HeadNode::has(const T &value) const { + _throwIfValueDoesNotCompare(value); +#ifdef SKIPLIST_THREAD_SUPPORT + std::lock_guard lock(gSkipListMutex); +#endif + for (size_t l = _nodeRefs.height(); l-- > 0;) { + assert(_nodeRefs[l].pNode); + if (_nodeRefs[l].pNode->has(value)) { + return true; + } + } + return false; +} + +/** + * Returns the value at a particular index. + * Will throw an OrderedStructs::SkipList::IndexError if index out of range. + * + * If @ref SKIPLIST_THREAD_SUPPORT is defined this will block. + * + * See _throw_exceeds_size() that does the throw. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @param index The index. + * @return The value at that index. + */ +template +const T &HeadNode::at(size_t index) const { +#ifdef SKIPLIST_THREAD_SUPPORT + std::lock_guard lock(gSkipListMutex); +#endif + const Node *pNode = _nodeAt(index); + assert(pNode); + return pNode->value(); +} + +/** + * Find the count number of value starting at index and write them to dest. + * + * Will throw a OrderedStructs::SkipList::IndexError if any index out of range. + * + * This is useful for rolling median on even length lists where the caller might want to implement the mean of two + * values. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @param index The index. + * @param count The number of values to retrieve. + * @param dest The vector of values + */ +template +void HeadNode::at(size_t index, size_t count, + std::vector &dest) const { +#ifdef SKIPLIST_THREAD_SUPPORT + std::lock_guard lock(gSkipListMutex); +#endif + dest.clear(); + const Node *pNode = _nodeAt(index); + // _nodeAt will (should) throw an IndexError so this + // assert should always be true + assert(pNode); + while (count) { + if (! pNode) { + _throw_exceeds_size(_count); + } + dest.push_back(pNode->value()); + pNode = pNode->next(); + --count; + } +} + +/** + * Computes index of the first occurrence of a value + * Will throw a OrderedStructs::SkipList::ValueError if the value does not exist in the skip list + * Will throw a OrderedStructs::SkipList::FailedComparison if the value is not comparable. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @param value The value to search for. + * @return + */ +template +size_t HeadNode::index(const T& value) const { + _throwIfValueDoesNotCompare(value); + size_t idx; + +#ifdef SKIPLIST_THREAD_SUPPORT + std::lock_guard lock(gSkipListMutex); +#endif + for (size_t l = _nodeRefs.height(); l-- > 0;) { + assert(_nodeRefs[l].pNode); + if (_nodeRefs[l].pNode->index(value, idx, l)) { + idx += _nodeRefs[l].width; + assert(idx > 0); + return idx - 1; + } + } + _throwValueErrorNotFound(value); + return 0; +} + +/** + * Return the number of values in the Skip List. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @return The number of values in the Skip List. + */ +template +size_t HeadNode::size() const { + return _count; +} + +template +size_t HeadNode::height() const { +#ifdef SKIPLIST_THREAD_SUPPORT + std::lock_guard lock(gSkipListMutex); +#endif + size_t val = _nodeRefs.height(); + return val; +} + +/** + * Return the number of linked lists that the node at index has. + * + * Will throw a OrderedStructs::SkipList::IndexError if the index out of range. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @param idx The index of the Skip List node. + * @return The number of linked lists that the node at the index has. + */ +template +size_t HeadNode::height(size_t idx) const { +#ifdef SKIPLIST_THREAD_SUPPORT + std::lock_guard lock(gSkipListMutex); +#endif + const Node *pNode = _nodeAt(idx); + assert(pNode); + return pNode->height(); +} + +/** + * The skip width of the Node at index has at the given level. + * Will throw an IndexError if the index is out of range. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @param idx The index. + * @param level The level. + * @return Width of Node. + */ +template +size_t HeadNode::width(size_t idx, size_t level) const { +#ifdef SKIPLIST_THREAD_SUPPORT + std::lock_guard lock(gSkipListMutex); +#endif + // Will throw if out of range. + const Node *pNode = _nodeAt(idx); + assert(pNode); + if (level >= pNode->height()) { + _throw_exceeds_size(pNode->height()); + } + return pNode->nodeRefs()[level].width; +} + +/** + * Find the Node at the given index. + * Will throw an IndexError if the index is out of range. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @param idx The index. + * @return The Node. + */ +template +const Node *HeadNode::_nodeAt(size_t idx) const { + if (idx < _count) { + for (size_t l = _nodeRefs.height(); l-- > 0;) { + if (_nodeRefs[l].pNode && _nodeRefs[l].width <= idx + 1) { + size_t new_index = idx + 1 - _nodeRefs[l].width; + const Node *pNode = _nodeRefs[l].pNode->at(new_index); + if (pNode) { + return pNode; + } + } + } + } + assert(idx >= _count); + _throw_exceeds_size(_count); + // Should not get here as _throw_exceeds_size() will always throw. + return NULL; +} + +/** + * Insert a value. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @param value + */ +template +void HeadNode::insert(const T &value) { +#ifdef SKIPLIST_THREAD_SUPPORT + std::lock_guard lock(gSkipListMutex); +#ifdef SKIPLIST_THREAD_SUPPORT_TRACE + std::cout << "HeadNode insert() thread: " << std::this_thread::get_id() << std::endl; +#endif +#endif + Node *pNode = nullptr; + size_t level = _nodeRefs.height(); + + _throwIfValueDoesNotCompare(value); + while (level-- > 0) { + assert(_nodeRefs[level].pNode); + pNode = _nodeRefs[level].pNode->insert(value); + if (pNode) { + break; + } + } + if (! pNode) { + pNode = _pool.Allocate(value); + level = 0; + } + assert(pNode); + SwappableNodeRefStack &thatRefs = pNode->nodeRefs(); + if (thatRefs.canSwap()) { + // Expand this to that + while (_nodeRefs.height() < thatRefs.height()) { + _nodeRefs.push_back(nullptr, _count + 1); + } + if (level < thatRefs.swapLevel()) { + // Happens when we were originally, say 3 high (max height of any + // previously seen node). Then a node is created + // say 5 high. In that case this will be at level 2 and + // thatRefs.swapLevel() will be 3 + assert(level + 1 == thatRefs.swapLevel()); + thatRefs[thatRefs.swapLevel()].width += _nodeRefs[level].width; + ++level; + } + // Now swap + while (level < _nodeRefs.height() && thatRefs.canSwap()) { + assert(thatRefs.canSwap()); + assert(level == thatRefs.swapLevel()); + _nodeRefs[level].width -= thatRefs[level].width - 1; + thatRefs.swap(_nodeRefs); + if (thatRefs.canSwap()) { + assert(thatRefs[thatRefs.swapLevel()].width == 0); + thatRefs[thatRefs.swapLevel()].width = _nodeRefs[level].width; + } + ++level; + } + // Check all references swapped + assert(! thatRefs.canSwap()); + // Check that all 'this' pointers created on construction have been moved + assert(thatRefs.noNodePointerMatches(pNode)); + } + if (level < thatRefs.swapLevel()) { + // Happens when we are, say 5 high then a node is created + // and consumed by the next node say 3 high. In that case this will be + // at level 2 and thatRefs.swapLevel() will be 3 + assert(level + 1 == thatRefs.swapLevel()); + ++level; + } + // Increment my widths as my references are now going over the top of + // pNode. + while (level < _nodeRefs.height() && level >= thatRefs.height()) { + _nodeRefs[level++].width += 1; + } + ++_count; +#ifdef SKIPLIST_THREAD_SUPPORT +#ifdef SKIPLIST_THREAD_SUPPORT_TRACE + std::cout << "HeadNode insert() thread: " << std::this_thread::get_id() << " DONE" << std::endl; +#endif +#endif +} + +/** + * Adjust references >= level for removal of the node pNode. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @param level Current level. + * @param pNode Node to swap references with. + */ +template +void HeadNode::_adjRemoveRefs(size_t level, + Node *pNode) { + assert(pNode); + SwappableNodeRefStack &thatRefs = pNode->nodeRefs(); + + // Swap all remaining levels + // This assertion checks that if swapping can take place we must be at the + // same level. + assert(! thatRefs.canSwap() || level == thatRefs.swapLevel()); + while (level < _nodeRefs.height() && thatRefs.canSwap()) { + assert(level == thatRefs.swapLevel()); + // Compute the new width for the new node + thatRefs[level].width += _nodeRefs[level].width - 1; + thatRefs.swap(_nodeRefs); + ++level; + if (! thatRefs.canSwap()) { + break; + } + } + assert(! thatRefs.canSwap()); + // Decrement my widths as my references are now going over the top of + // pNode. + while (level < _nodeRefs.height()) { + _nodeRefs[level++].width -= 1; + } + // Decrement my stack while top has a NULL pointer. + while (_nodeRefs.height() && ! _nodeRefs[_nodeRefs.height() - 1].pNode) { + _nodeRefs.pop_back(); + } +} + +/** + * Remove a Node with a value. + * May throw a ValueError if the value is not found. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @param value The value in the Node to remove. + * @return The value removed. + */ +template +T HeadNode::remove(const T &value) { +#ifdef SKIPLIST_THREAD_SUPPORT + std::lock_guard lock(gSkipListMutex); +#ifdef SKIPLIST_THREAD_SUPPORT_TRACE + std::cout << "HeadNode remove() thread: " << std::this_thread::get_id() << std::endl; +#endif +#endif + Node *pNode = nullptr; + size_t level; + + _throwIfValueDoesNotCompare(value); + for (level = _nodeRefs.height(); level-- > 0;) { + assert(_nodeRefs[level].pNode); + pNode = _nodeRefs[level].pNode->remove(level, value); + if (pNode) { + break; + } + } + if (! pNode) { + _throwValueErrorNotFound(value); + } + // Take swap level as some swaps will have been dealt with by the remove() above. + _adjRemoveRefs(pNode->nodeRefs().swapLevel(), pNode); + --_count; + T ret_val = _pool.Release(pNode); +#ifdef SKIPLIST_THREAD_SUPPORT_TRACE + std::cout << "HeadNode remove() thread: " << std::this_thread::get_id() << " DONE" << std::endl; +#endif + return ret_val; +} + +/** + * Throw a ValueError in a consistent fashion. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @param value The value to put into the ValueError. + */ +template +void HeadNode::_throwValueErrorNotFound(const T &value) const { +#ifdef INCLUDE_METHODS_THAT_USE_STREAMS + std::ostringstream oss; + oss << "Value " << value << " not found."; + std::string err_msg = oss.str(); +#else + std::string err_msg = "Value not found."; +#endif + throw ValueError(err_msg); +} + +/** + * Checks that the value == value. + * This will throw a FailedComparison if that is not the case, for example NaN. + * + * @note + * The Node class is (should be) not directly accessible by the user so we can just assert(value == value) in Node. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @param value + */ +template +void HeadNode::_throwIfValueDoesNotCompare(const T &value) const { + if (value != value) { + throw FailedComparison( + "Can not work with something that does not compare equal to itself."); + } +} + +/** + * This tests that at every level >= 0 the sequence of node pointers + * at that level does not contain a cyclic reference. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @return An IntegrityCheck enum. + */ +template +IntegrityCheck HeadNode::_lacksIntegrityCyclicReferences() const { + assert(_nodeRefs.height()); + // Check for cyclic references at each level + for (size_t level = 0; level < _nodeRefs.height(); ++level) { + Node *p1 = _nodeRefs[level].pNode; + Node *p2 = _nodeRefs[level].pNode; + while (p1 && p2) { + p1 = p1->nodeRefs()[level].pNode; + if (p2->nodeRefs()[level].pNode) { + p2 = p2->nodeRefs()[level].pNode->nodeRefs()[level].pNode; + } else { + p2 = nullptr; + } + if (p1 && p2 && p1 == p2) { + return HEADNODE_DETECTS_CYCLIC_REFERENCE; + } + } + } + return INTEGRITY_SUCCESS; +} + +/** + * This tests that at every level > 0 the node to node width is the same + * as the accumulated node to node widths at level - 1. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @return An IntegrityCheck enum. + */ +template +IntegrityCheck HeadNode::_lacksIntegrityWidthAccumulation() const { + assert(_nodeRefs.height()); + for (size_t level = 1; level < _nodeRefs.height(); ++level) { + const Node *pl = _nodeRefs[level].pNode; + const Node *pl_1 = _nodeRefs[level - 1].pNode; + assert(pl && pl_1); // No nulls allowed in HeadNode + size_t wl = _nodeRefs[level].width; + size_t wl_1 = _nodeRefs[level - 1].width; + while (true) { + while (pl != pl_1) { + assert(pl_1); // Could only happen if a lower reference was NULL and the higher non-NULL. + wl_1 += pl_1->width(level - 1); + pl_1 = pl_1->pNode(level - 1); + } + if (wl != wl_1) { + return HEADNODE_LEVEL_WIDTHS_MISMATCH; + } + if (pl == nullptr && pl_1 == nullptr) { + break; + } + wl = pl->width(level); + wl_1 = pl_1->width(level - 1); + pl = pl->pNode(level); + pl_1 = pl_1->pNode(level - 1); + } + } + return INTEGRITY_SUCCESS; +} + +/** + * This tests the integrity of each Node. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @return An IntegrityCheck enum. + */ +template +IntegrityCheck HeadNode::_lacksIntegrityNodeReferencesNotInList() const { + assert(_nodeRefs.height()); + + IntegrityCheck result; + std::set*> nodeSet; + const Node *pNode = _nodeRefs[0].pNode; + assert(pNode); + + // First gather all nodes, slightly awkward code here is so that + // NULL is always included. + nodeSet.insert(pNode); + do { + pNode = pNode->next(); + nodeSet.insert(pNode); + } while (pNode); + assert(nodeSet.size() == _count + 1); // All nodes plus NULL + // Then test each node does not have pointers that are not in nodeSet + pNode = _nodeRefs[0].pNode; + while (pNode) { + result = pNode->lacksIntegrityRefsInSet(nodeSet); + if (result) { + return result; + } + pNode = pNode->next(); + } + return INTEGRITY_SUCCESS; +} + +/** + * Integrity check. Traverse the lowest level and check that the ordering + * is correct according to the compare function. The HeadNode checks that the + * Node(s) have correctly applied the compare function. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @return An IntegrityCheck enum. + */ +template +IntegrityCheck HeadNode::_lacksIntegrityOrder() const { + if (_nodeRefs.height()) { + // Traverse the lowest level list iteratively deleting as we go + // Doing this recursivley could be expensive as we are at level 0. + const Node *node = _nodeRefs[0].pNode; + const Node *next; + while (node) { + next = node->next(); + if (next && _compare(next->value(), node->value())) { + return HEADNODE_DETECTS_OUT_OF_ORDER; + } + node = next; + } + } + return INTEGRITY_SUCCESS; +} + +/** + * Full integrity check. + * This calls the other integrity check functions. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @return An IntegrityCheck enum. + */ +template +IntegrityCheck HeadNode::lacksIntegrity() const { +#ifdef SKIPLIST_THREAD_SUPPORT + std::lock_guard lock(gSkipListMutex); +#endif + if (_nodeRefs.height()) { + IntegrityCheck result = _nodeRefs.lacksIntegrity(); + if (result) { + return result; + } + if (! _nodeRefs.noNodePointerMatches(nullptr)) { + return HEADNODE_CONTAINS_NULL; + } + // Check all nodes for integrity + const Node *pNode = _nodeRefs[0].pNode; + while (pNode) { + result = pNode->lacksIntegrity(_nodeRefs.height()); + if (result) { + return result; + } + pNode = pNode->next(); + } + // Check count against total number of nodes + pNode = _nodeRefs[0].pNode; + size_t total = 0; + while (pNode) { + total += pNode->nodeRefs()[0].width; + pNode = pNode->next(); + } + if (total != _count) { + return HEADNODE_COUNT_MISMATCH; + } + result = _lacksIntegrityWidthAccumulation(); + if (result) { + return result; + } + result = _lacksIntegrityCyclicReferences(); + if (result) { + return result; + } + result = _lacksIntegrityNodeReferencesNotInList(); + if (result) { + return result; + } + result = _lacksIntegrityOrder(); + if (result) { + return result; + } + } + return INTEGRITY_SUCCESS; +} + +/** + * Returns an estimate of the memory usage of an instance. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @return The size of the memory estimate. + */ +template +size_t HeadNode::size_of() const { +#ifdef SKIPLIST_THREAD_SUPPORT + std::lock_guard lock(gSkipListMutex); +#endif + // sizeof(*this) includes the size of _nodeRefs but _nodeRefs.size_of() + // includes sizeof(_nodeRefs) so we need to subtract to avoid double counting + size_t ret_val = sizeof(*this) + _nodeRefs.size_of() - sizeof(_nodeRefs); + if (_nodeRefs.height()) { + const Node *node = _nodeRefs[0].pNode; + while (node) { + ret_val += node->size_of(); + node = node->next(); + } + } + return ret_val; +} + +/** + * Destructor. + * This deletes all Nodes. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + */ +template +HeadNode::~HeadNode() { + // Hmm could this deadlock? +#ifdef SKIPLIST_THREAD_SUPPORT + std::lock_guard lock(gSkipListMutex); +#endif + if (_nodeRefs.height()) { + // Traverse the lowest level list iteratively deleting as we go + // Doing this recursivley could be expensive as we are at level 0. + const Node *node = _nodeRefs[0].pNode; + const Node *next; + while (node) { + next = node->next(); + delete node; + --_count; + node = next; + } + } + assert(_count == 0); +} + +#ifdef INCLUDE_METHODS_THAT_USE_STREAMS + +/** + * Create a DOT file of the internal representation. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @param os Where to write the DOT file. + */ +template +void HeadNode::dotFile(std::ostream &os) const { +#ifdef SKIPLIST_THREAD_SUPPORT + std::lock_guard lock(gSkipListMutex); +#endif + if (_dot_file_subgraph == 0) { + os << "digraph SkipList {" << std::endl; + os << "label = \"SkipList.\"" << std::endl; + os << "graph [rankdir = \"LR\"];" << std::endl; + os << "node [fontsize = \"12\" shape = \"ellipse\"];" << std::endl; + os << "edge [];" << std::endl; + os << std::endl; + } + os << "subgraph cluster" << _dot_file_subgraph << " {" << std::endl; + os << "style=dashed" << std::endl; + os << "label=\"Skip list iteration " << _dot_file_subgraph << "\"" << std::endl; + os << std::endl; + os << "\"HeadNode" << _dot_file_subgraph; + os << "\" [" << std::endl; + os << "label = \""; + // Write out the fields + if (_nodeRefs.height()) { + for (size_t level = _nodeRefs.height(); level-- > 0;) { + os << "{ " << _nodeRefs[level].width << " | "; + os << " "; + os << std::hex << _nodeRefs[level].pNode << std::dec; + os << "}"; + if (level > 0) { + os << " | "; + } + } + } else { + os << "Empty HeadNode"; + } + os << "\"" << std::endl; + os << "shape = \"record\"" << std::endl; + os << "];" << std::endl; + // Edges for head node + for (size_t level = 0; level < _nodeRefs.height(); ++level) { + os << "\"HeadNode"; + os << _dot_file_subgraph; + os << "\":f" << level + 1 << " -> "; + _nodeRefs[level].pNode->writeNode(os, _dot_file_subgraph); + os << ":w" << level + 1 << " [];" << std::endl; + } + os << std::endl; + // Now all nodes via level 0, if non-empty + if (_nodeRefs.height()) { + Node *pNode = this->_nodeRefs[0].pNode; + pNode->dotFile(os, _dot_file_subgraph); + } + os << std::endl; + // NULL, the sentinal node + if (_nodeRefs.height()) { + os << "\"node"; + os << _dot_file_subgraph; + os << "0x0\" [label = \""; + for (size_t level = _nodeRefs.height(); level-- > 0;) { + os << " NULL"; + if (level) { + os << " | "; + } + } + os << "\" shape = \"record\"];" << std::endl; + } + // End: "subgraph cluster1 {" + os << "}" << std::endl; + os << std::endl; + _dot_file_subgraph += 1; +} + +/** + * Finalise the DOT file of the internal representation. + * + * @tparam T Type of the values in the Skip List. + * @tparam _Compare Compare function. + * @param os Where to write the DOT file. + */ +template +void HeadNode::dotFileFinalise(std::ostream &os) const { +#ifdef SKIPLIST_THREAD_SUPPORT + std::lock_guard lock(gSkipListMutex); +#endif + if (_dot_file_subgraph > 0) { + // Link the nodes together with an invisible node. + // node0 [shape=record, label = " | | | | | | | | | ", + // style=invis, + // width=0.01]; + os << "node0 [shape=record, label = \""; + for (size_t i = 0; i < _dot_file_subgraph; ++i) { + os << " | "; + } + os << "\", style=invis, width=0.01];" << std::endl; + // Now: + // node0:f0 -> HeadNode [style=invis]; + // node0:f1 -> HeadNode1 [style=invis]; + for (size_t i = 0; i < _dot_file_subgraph; ++i) { + os << "node0:f" << i << " -> HeadNode" << i; + os << " [style=invis];" << std::endl; + } + _dot_file_subgraph = 0; + } + os << "}" << std::endl; +} + +#endif // INCLUDE_METHODS_THAT_USE_STREAMS + +/************************** END: HeadNode *******************************/ + +#endif // SkipList_HeadNode_h diff --git a/third_party/skiplist/IntegrityEnums.h b/third_party/skiplist/IntegrityEnums.h new file mode 100755 index 00000000000..143d0a3e7f0 --- /dev/null +++ b/third_party/skiplist/IntegrityEnums.h @@ -0,0 +1,62 @@ +#ifndef SkipList_IntegrityEnums_h +#define SkipList_IntegrityEnums_h + +/** + * @file + * + * Project: skiplist + * + * Integrity codes for structures in this code. + * + * Created by Paul Ross on 11/12/2015. + * + * Copyright (c) 2015-2023 Paul Ross. All rights reserved. + * + * @code + * MIT License + * + * Copyright (c) 2015-2023 Paul Ross + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * @endcode + */ + +/** + * Various integrity codes for structures in this code. + */ +enum IntegrityCheck { + INTEGRITY_SUCCESS = 0, + // SwappableNodeRefStack integrity checks + NODEREFS_WIDTH_ZERO_NOT_UNITY = 100, + NODEREFS_WIDTH_DECREASING, + // Node integrity checks + NODE_HEIGHT_ZERO = 200, + NODE_HEIGHT_EXCEEDS_HEADNODE, + NODE_NON_NULL_AFTER_NULL, + NODE_SELF_REFERENCE, + NODE_REFERENCES_NOT_IN_GLOBAL_SET, + // HeadNode integrity checks + HEADNODE_CONTAINS_NULL = 300, + HEADNODE_COUNT_MISMATCH, + HEADNODE_LEVEL_WIDTHS_MISMATCH, + HEADNODE_DETECTS_CYCLIC_REFERENCE, + HEADNODE_DETECTS_OUT_OF_ORDER, +}; + +#endif diff --git a/third_party/skiplist/LICENSE b/third_party/skiplist/LICENSE new file mode 100644 index 00000000000..cee9bac95d5 --- /dev/null +++ b/third_party/skiplist/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017-2023 Paul Ross + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/skiplist/Node.h b/third_party/skiplist/Node.h new file mode 100755 index 00000000000..73c8024eca4 --- /dev/null +++ b/third_party/skiplist/Node.h @@ -0,0 +1,640 @@ +/** + * @file + * + * Project: skiplist + * + * Concurrency Tests. + * + * Created by Paul Ross on 03/12/2015. + * + * Copyright (c) 2015-2023 Paul Ross. All rights reserved. + * + * @code + * MIT License + * + * Copyright (c) 2015-2023 Paul Ross + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * @endcode + */ + +#ifndef SkipList_Node_h +#define SkipList_Node_h + +#include "IntegrityEnums.h" + +#if __cplusplus < 201103L +#define nullptr NULL +#endif + +/**************************** Node *********************************/ + +/** + * @brief A single node in a Skip List containing a value and references to other downstream Node objects. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + */ +template +class Node { +public: + struct _Pool { + explicit _Pool(_Compare _cmp) : _compare(_cmp), cache(nullptr) { + } + ~_Pool() { + delete cache; + } + Node *Allocate(const T &value) { + if (cache) { + Node *result = cache; + cache = nullptr; + result->Initialize(value); + return result; + } + + return new Node(value, _compare, *this); + } + + T Release(Node *pNode) { + T result = pNode->value(); + std::swap(pNode, cache); + delete pNode; + return result; + } + + _Compare _compare; + Node* cache; + }; + + Node(const T &value, _Compare _cmp, _Pool &pool); + // Const methods + // + /// Returns the node value + const T &value() const { return _value; } + // Returns true if the value is present in the skip list from this node onwards. + bool has(const T &value) const; + // Returns the value at the index in the skip list from this node onwards. + // Will return nullptr is not found. + const Node *at(size_t idx) const; + // Computes index of the first occurrence of a value + bool index(const T& value, size_t &idx, size_t level) const; + /// Number of linked lists that this node engages in, minimum 1. + size_t height() const { return _nodeRefs.height(); } + // Return the pointer to the next node at level 0 + const Node *next() const; + // Return the width at given level. + size_t width(size_t level) const; + // Return the node pointer at given level, only used for HeadNode + // integrity checks. + const Node *pNode(size_t level) const; + + // Non-const methods + /// Get a reference to the node references + SwappableNodeRefStack &nodeRefs() { return _nodeRefs; } + /// Get a reference to the node references + const SwappableNodeRefStack &nodeRefs() const { return _nodeRefs; } + // Insert a node + Node *insert(const T &value); + // Remove a node + Node *remove(size_t call_level, const T &value); + // An estimate of the number of bytes used by this node + size_t size_of() const; + +#ifdef INCLUDE_METHODS_THAT_USE_STREAMS + void dotFile(std::ostream &os, size_t suffix = 0) const; + void writeNode(std::ostream &os, size_t suffix = 0) const; +#endif // INCLUDE_METHODS_THAT_USE_STREAMS + + // Integrity checks, returns non-zero on failure + IntegrityCheck lacksIntegrity(size_t headnode_height) const; + IntegrityCheck lacksIntegrityRefsInSet(const std::set*> &nodeSet) const; + +protected: + Node *_adjRemoveRefs(size_t level, Node *pNode); + + void Initialize(const T &value) { + _value = value; + _nodeRefs.clear(); + do { + _nodeRefs.push_back(this, _nodeRefs.height() ? 0 : 1); + } while (tossCoin()); + } + +protected: + T _value; + SwappableNodeRefStack _nodeRefs; + // Comparison function + _Compare _compare; + _Pool &_pool; +private: + // Prevent cctor and operator= + Node(const Node &that); + Node &operator=(const Node &that) const; +}; + +/** + * Constructor. + * This also creates a SwappableNodeRefStack of random height by tossing a virtual coin. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param value The value of the Node. + * @param _cmp The comparison function. + */ +template +Node::Node(const T &value, _Compare _cmp, _Pool &pool) : \ + _value(value), _compare(_cmp), _pool(pool) { + Initialize(value); +} + +/** + * Returns true if the value is present in the skip list from this node onwards. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param value The value to look for. + * @return true if the value is present in the skip list from this node onwards. + */ +template +bool Node::has(const T &value) const { + assert(_nodeRefs.height()); + assert(value == value); // value can not be NaN for example + // Effectively: if (value > _value) { + if (_compare(_value, value)) { + for (size_t l = _nodeRefs.height(); l-- > 0;) { + if (_nodeRefs[l].pNode && _nodeRefs[l].pNode->has(value)) { + return true; + } + } + return false; + } + // Effectively: return value == _value; // false if value smaller + return !_compare(value, _value) && !_compare(_value, value); +} + +/** + * Return a pointer to the n'th node. + * Start (or continue) from the highest level, drop down a level if not found. + * Return nullptr if not found at level 0. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param idx The index from hereon. If zero return this. + * @return Pointer to the Node or nullptr. + */ +template +const Node *Node::at(size_t idx) const { + assert(_nodeRefs.height()); + if (idx == 0) { + return this; + } + for (size_t l = _nodeRefs.height(); l-- > 0;) { + if (_nodeRefs[l].pNode && _nodeRefs[l].width <= idx) { + return _nodeRefs[l].pNode->at(idx - _nodeRefs[l].width); + } + } + return nullptr; +} + +/** + * Computes index of the first occurrence of a value. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param value The value to find. + * @param idx The current index, this will be updated. + * @param level The current level to search from. + * @return true if found, false otherwise. + */ +template +bool Node::index(const T& value, size_t &idx, size_t level) const { + assert(_nodeRefs.height()); + assert(value == value); // value can not be NaN for example + assert(level < _nodeRefs.height()); + // Search has overshot, try again at a lower level. + //if (_value > value) { + if (_compare(value, _value)) { + return false; + } + // First check if we match but we have been approached at a high level + // as there may be an earlier node of the same value but with fewer + // node references. In that case this search has to fail and try at a + // lower level. + // If however the level is 0 and we match then set the idx to 0 to mark us. + // Effectively: if (_value == value) { + if (!_compare(value, _value) && !_compare(_value, value)) { + if (level > 0) { + return false; + } + idx = 0; + return true; + } + // Now work our way down + // NOTE: We initialise l as level + 1 because l-- > 0 will decrement it to + // the correct initial value + for (size_t l = level + 1; l-- > 0;) { + assert(l < _nodeRefs.height()); + if (_nodeRefs[l].pNode && _nodeRefs[l].pNode->index(value, idx, l)) { + idx += _nodeRefs[l].width; + return true; + } + } + return false; +} + +/** + * Return the pointer to the next node at level 0. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @return The next node at level 0. + */ +template +const Node *Node::next() const { + assert(_nodeRefs.height()); + return _nodeRefs[0].pNode; +} + +/** + * Return the width at given level. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param level The requested level. + * @return The width. + */ +template +size_t Node::width(size_t level) const { + assert(level < _nodeRefs.height()); + return _nodeRefs[level].width; +} + +/** + * Return the node pointer at given level, only used for HeadNode integrity checks. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param level The requested level. + * @return The Node. + */ +template +const Node *Node::pNode(size_t level) const { + assert(level < _nodeRefs.height()); + return _nodeRefs[level].pNode; +} + +/** + * Insert a new node with a value. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param value The value of the Node to insert. + * @return Pointer to the new Node or nullptr on failure. + */ +template +Node *Node::insert(const T &value) { + assert(_nodeRefs.height()); + assert(_nodeRefs.noNodePointerMatches(this)); + assert(! _nodeRefs.canSwap()); + assert(value == value); // NaN check for double + + // Effectively: if (value < _value) { + if (_compare(value, _value)) { + return nullptr; + } + // Recursive search for where to put the node + Node *pNode = nullptr; + size_t level = _nodeRefs.height(); + // Effectively: if (value >= _value) { + if (! _compare(value, _value)) { + for (level = _nodeRefs.height(); level-- > 0;) { + if (_nodeRefs[level].pNode) { + pNode = _nodeRefs[level].pNode->insert(value); + if (pNode) { + break; + } + } + } + } + // Effectively: if (! pNode && value >= _value) { + if (! pNode && !_compare(value, _value)) { + // Insert new node here + pNode = _pool.Allocate(value); + level = 0; + } + assert(pNode); // Should never get here unless a NaN has slipped through + // Adjust references by marching up and recursing back. + SwappableNodeRefStack &thatRefs = pNode->_nodeRefs; + if (! thatRefs.canSwap()) { + // Have an existing node or new node that is all swapped. + // All I need to do is adjust my overshooting nodes and return + // this for the caller to do the same. + level = thatRefs.height(); + while (level < _nodeRefs.height()) { + _nodeRefs[level].width += 1; + ++level; + } + // The caller just has to increment its references that overshoot this + assert(! _nodeRefs.canSwap()); + return this; + } + // March upwards + if (level < thatRefs.swapLevel()) { + assert(level == thatRefs.swapLevel() - 1); + // This will happen when say a 3 high node, A, finds a 2 high + // node, B, that creates a new 2+ high node. A will be at + // level 1 and the new node will have swapLevel == 2 after + // B has swapped. + // Add the level to the accumulator at the next level + thatRefs[thatRefs.swapLevel()].width += _nodeRefs[level].width; + ++level; + } + size_t min_height = std::min(_nodeRefs.height(), thatRefs.height()); + while (level < min_height) { + assert(thatRefs.canSwap()); + assert(level == thatRefs.swapLevel()); + assert(level < thatRefs.height()); + assert(_nodeRefs[level].width > 0); + assert(thatRefs[level].width > 0); + _nodeRefs[level].width -= thatRefs[level].width - 1; + assert(_nodeRefs[level].width > 0); + thatRefs.swap(_nodeRefs); + if (thatRefs.canSwap()) { + assert(thatRefs[thatRefs.swapLevel()].width == 0); + thatRefs[thatRefs.swapLevel()].width = _nodeRefs[level].width; + } + ++level; + } + // Upwards march complete, now recurse back ('left'). + if (! thatRefs.canSwap()) { + // All done with pNode locally. + assert(level == thatRefs.height()); + assert(thatRefs.height() <= _nodeRefs.height()); + assert(level == thatRefs.swapLevel()); + // Adjust my overshooting nodes + while (level < _nodeRefs.height()) { + _nodeRefs[level].width += 1; + ++level; + } + // The caller just has to increment its references that overshoot this + assert(! _nodeRefs.canSwap()); + pNode = this; + } + return pNode; +} + +/** + * Adjust the Node references. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param level The level of the caller's node. + * @param pNode The Node to swap references with. + * @return The Node with swapped references. + */ +template +Node *Node::_adjRemoveRefs(size_t level, Node *pNode) { + assert(pNode); + SwappableNodeRefStack &thatRefs = pNode->_nodeRefs; + + assert(pNode != this); + if (level < thatRefs.swapLevel()) { + assert(level == thatRefs.swapLevel() - 1); + ++level; + } + if (thatRefs.canSwap()) { + assert(level == thatRefs.swapLevel()); + while (level < _nodeRefs.height() && thatRefs.canSwap()) { + assert(level == thatRefs.swapLevel()); + // Compute the new width for the new node + thatRefs[level].width += _nodeRefs[level].width - 1; + thatRefs.swap(_nodeRefs); + ++level; + } + assert(thatRefs.canSwap() || thatRefs.allNodePointerMatch(pNode)); + } + // Decrement my widths as my refs are over the top of the missing pNode. + while (level < _nodeRefs.height()) { + _nodeRefs[level].width -= 1; + ++level; + thatRefs.incSwapLevel(); + } + assert(! _nodeRefs.canSwap()); + return pNode; +} + +/** + * Remove a Node with the given value to be removed. + * The return value must be deleted, the other Nodes have been adjusted as required. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param call_level Level the caller Node is at. + * @param value Value of the detached Node to remove. + * @return A pointer to the Node to be free'd or nullptr on failure. + */ +template +Node *Node::remove(size_t call_level, + const T &value) { + assert(_nodeRefs.height()); + assert(_nodeRefs.noNodePointerMatches(this)); + + Node *pNode = nullptr; + // Effectively: if (value >= _value) { + if (!_compare(value, _value)) { + for (size_t level = call_level + 1; level-- > 0;) { + if (_nodeRefs[level].pNode) { + // Make progress to the right + pNode = _nodeRefs[level].pNode->remove(level, value); + if (pNode) { + return _adjRemoveRefs(level, pNode); + } + } + // Make progress down + } + } + if (! pNode) { // Base case + // We only admit to being the node to remove if the caller is + // approaching us from level 0. It is entirely likely that + // the same (or an other) caller can see us at a higher level + // but the recursion stack will not have been set up in the correct + // step wise fashion so that the lower level references will + // not be swapped. + // Effectively: if (call_level == 0 && value == _value) { + if (call_level == 0 && !_compare(value, _value) && !_compare(_value, value)) { + _nodeRefs.resetSwapLevel(); + return this; + } + } + assert(pNode == nullptr); + return nullptr; +} + +/* + * This checks the internal concistency of a Node. It returns 0 + * if succesful, non-zero on error. The tests are: + * + * - Height must be >= 1 + * - Height must not exceed HeadNode height. + * - NULL pointer must not have a non-NULL above them. + * - Node pointers must not be self-referential. + */ +/** + * This checks the internal concistency of a Node. It returns 0 + * if succesful, non-zero on error. The tests are: + * + * - Height must be >= 1 + * - Height must not exceed HeadNode height. + * - NULL pointer must not have a non-NULL above them. + * - Node pointers must not be self-referential. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param headnode_height Height of HeadNode. + * @return An IntegrityCheck enum. + */ +template +IntegrityCheck Node::lacksIntegrity(size_t headnode_height) const { + IntegrityCheck result = _nodeRefs.lacksIntegrity(); + if (result) { + return result; + } + if (_nodeRefs.height() == 0) { + return NODE_HEIGHT_ZERO; + } + if (_nodeRefs.height() > headnode_height) { + return NODE_HEIGHT_EXCEEDS_HEADNODE; + } + // Test: All nodes above a nullprt must be nullptr + size_t level = 0; + while (level < _nodeRefs.height()) { + if (! _nodeRefs[level].pNode) { + break; + } + ++level; + } + while (level < _nodeRefs.height()) { + if (_nodeRefs[level].pNode) { + return NODE_NON_NULL_AFTER_NULL; + } + ++level; + } + // No reference should be to self. + if (! _nodeRefs.noNodePointerMatches(this)) { + return NODE_SELF_REFERENCE; + } + return INTEGRITY_SUCCESS; +} + +/** + * Checks that this Node is in the set held by the HeadNode. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param nodeSet Set of Nodes held by the HeadNode. + * @return An IntegrityCheck enum. + */ +template +IntegrityCheck Node::lacksIntegrityRefsInSet(const std::set*> &nodeSet) const { + size_t level = 0; + while (level < _nodeRefs.height()) { + if (nodeSet.count(_nodeRefs[level].pNode) == 0) { + return NODE_REFERENCES_NOT_IN_GLOBAL_SET; + } + ++level; + } + return INTEGRITY_SUCCESS; +} + +/** + * Returns an estimate of the memory usage of an instance. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @return The memory estimate of this Node. + */ +template +size_t Node::size_of() const { + // sizeof(*this) includes the size of _nodeRefs but _nodeRefs.size_of() + // includes sizeof(_nodeRefs) so we need to subtract to avoid double counting + return sizeof(*this) + _nodeRefs.size_of() - sizeof(_nodeRefs); +} + + +#ifdef INCLUDE_METHODS_THAT_USE_STREAMS + +/** + * Writes out this Node address. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param os Where to write. + * @param suffix The suffix (node number). + */ +template +void Node::writeNode(std::ostream &os, size_t suffix) const { + os << "\"node"; + os << suffix; + os << std::hex << this << std::dec << "\""; +} + +/** + * Writes out a fragment of a DOT file representing this Node. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param os Wheere to write. + * @param suffix The node number. + */ +template +void Node::dotFile(std::ostream &os, size_t suffix) const { + assert(_nodeRefs.height()); + writeNode(os, suffix); + os << " [" << std::endl; + os << "label = \""; + for (size_t level = _nodeRefs.height(); level-- > 0;) { + os << " { " << _nodeRefs[level].width; + os << " | "; + os << std::hex << _nodeRefs[level].pNode << std::dec; + os << " }"; + os << " |"; + } + os << " " << _value << "\"" << std::endl; + os << "shape = \"record\"" << std::endl; + os << "];" << std::endl; + // Now edges + for (size_t level = 0; level < _nodeRefs.height(); ++level) { + writeNode(os, suffix); + os << ":f" << level + 1 << " -> "; + _nodeRefs[level].pNode->writeNode(os, suffix); + // writeNode(os, suffix); + // os << ":f" << i + 1 << " [];" << std::endl; + os << ":w" << level + 1 << " [];" << std::endl; + } + assert(_nodeRefs.height()); + if (_nodeRefs[0].pNode) { + _nodeRefs[0].pNode->dotFile(os, suffix); + } +} + +#endif // INCLUDE_METHODS_THAT_USE_STREAMS + +/************************** END: Node *******************************/ + +#endif // SkipList_Node_h diff --git a/third_party/skiplist/NodeRefs.h b/third_party/skiplist/NodeRefs.h new file mode 100755 index 00000000000..839752e435d --- /dev/null +++ b/third_party/skiplist/NodeRefs.h @@ -0,0 +1,251 @@ +// +// NodeRefs.h +// SkipList +// +// Created by Paul Ross on 03/12/2015. +// Copyright (c) 2017 Paul Ross. All rights reserved. +// + +#ifndef SkipList_NodeRefs_h +#define SkipList_NodeRefs_h + +#include "IntegrityEnums.h" + +/// Forward reference +template +class Node; + +/** + * @brief A PoD struct that contains a pointer to a Node and a width that represents the coarser linked list span to the + * next Node. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + */ +template > +struct NodeRef { + Node *pNode; + size_t width; +}; + +/******************** SwappableNodeRefStack **********************/ + +/** + * @brief Class that represents a stack of references to other nodes. + * + * Each reference is a NodeRef so a pointer to a Node and a width. + * This just does simple bookkeeping on this stack. + * + * It also facilitates swapping references with another SwappableNodeRefStack when inserting or removing a Node. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + */ +template +class SwappableNodeRefStack { +public: + /** + * Constructor. Initialises the swap level to 0. + */ + SwappableNodeRefStack() : _swapLevel(0) {} + + // Const methods + // ------------- + // Subscript read/write + const NodeRef &operator[](size_t level) const; + + NodeRef &operator[](size_t level); + + /// Number of nodes referenced. + size_t height() const { + return _nodes.size(); + } + + /// The current swap level + size_t swapLevel() const { return _swapLevel; } + + /// true if a swap can take place _swapLevel < height() + bool canSwap() const { return _swapLevel < height(); } + + // Returns true if there is no record of p in my data that + // could lead to circular references + bool noNodePointerMatches(const Node *p) const; + + // Returns true if all pointers in my data are equal to p. + bool allNodePointerMatch(const Node *p) const; + + // Non-const methods + // ----------------- + /// Add a new reference + void push_back(Node *p, size_t w) { + struct NodeRef val = {p, w}; + _nodes.push_back(val); + } + + /// Remove top reference + void pop_back() { + _nodes.pop_back(); + } + + // Swap reference at current swap level with another SwappableNodeRefStack + void swap(SwappableNodeRefStack &val); + + /// Reset the swap level (for example before starting a remove). + void resetSwapLevel() { _swapLevel = 0; } + + /// Increment the swap level. + /// This is used when removing nodes where the parent node can record to what level it has made its adjustments + /// so the grand parent knows where to start. + /// + /// For this reason the _swapLevel can easily be >= _nodes.size(). + void incSwapLevel() { ++_swapLevel; } + + IntegrityCheck lacksIntegrity() const; + + // Returns an estimate of the memory usage of an instance + size_t size_of() const; + + // Resets to the construction state + void clear() { _swapLevel = 0; _nodes.clear(); } + +protected: + /// Stack of NodeRef node references. + std::vector > _nodes; + /// The current swap level. + size_t _swapLevel; + +private: + /// Prevent cctor + SwappableNodeRefStack(const SwappableNodeRefStack &that); + + /// Prevent operator= + SwappableNodeRefStack &operator=(const SwappableNodeRefStack &that) const; +}; + +/** + * The readable NodeRef at the given level. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param level The level. + * @return A reference to the Node. + */ +template +const NodeRef &SwappableNodeRefStack::operator[](size_t level) const { + // NOTE: No bounds checking on vector::operator[], so this assert will do + assert(level < _nodes.size()); + return _nodes[level]; +} + +/** + * The writeable NodeRef at the given level. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param level The level. + * @return A reference to the Node. + */ +template +NodeRef &SwappableNodeRefStack::operator[](size_t level) { + // NOTE: No bounds checking on vector::operator[], so this assert will do + assert(level < _nodes.size()); + return _nodes[level]; +} + +/** + * Whether all node references are swapped. + * Should be true after an insert operation. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param p The Node. + * @return true if all the Node references are swapped (none are referring to the given Node). + */ +template +bool SwappableNodeRefStack::noNodePointerMatches(const Node *p) const { + for (size_t level = height(); level-- > 0;) { + if (p == _nodes[level].pNode) { + return false; + } + } + return true; +} + +/** + * Returns true if all pointers in my data are equal to p. + * Should be true after a remove operation. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param p The Node. + * @return true if all the Node references are un-swapped (all are referring to the given Node). + */ +template +bool SwappableNodeRefStack::allNodePointerMatch(const Node *p) const { + for (size_t level = height(); level-- > 0;) { + if (p != _nodes[level].pNode) { + return false; + } + } + return true; +} + +/** + * Swap references with another SwappableNodeRefStack at the current swap level. + * This also increments the swap level. + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @param val The SwappableNodeRefStack. + */ +template +void SwappableNodeRefStack::swap(SwappableNodeRefStack &val) { + assert(_swapLevel < height()); + NodeRef temp = val[_swapLevel]; + val[_swapLevel] = _nodes[_swapLevel]; + _nodes[_swapLevel] = temp; + ++_swapLevel; +} + +/** + * This checks the internal consistency of the object. It returns + * INTEGRITY_SUCCESS [0] if successful or non-zero on error. + * The tests are: + * + * - Widths must all be >= 1 + * - Widths must be weakly increasing with increasing level + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @return An IntegrityCheck enum. + */ +template +IntegrityCheck SwappableNodeRefStack::lacksIntegrity() const { + if (height()) { + if (_nodes[0].width != 1) { + return NODEREFS_WIDTH_ZERO_NOT_UNITY; + } + for (size_t level = 1; level < height(); ++level) { + if (_nodes[level].width < _nodes[level - 1].width) { + return NODEREFS_WIDTH_DECREASING; + } + } + } + return INTEGRITY_SUCCESS; +} + +/** + * Returns an estimate of the memory usage of an instance + * + * @tparam T The type of the Skip List Node values. + * @tparam _Compare A comparison function for type T. + * @return The memory estimate. + */ +template +size_t SwappableNodeRefStack::size_of() const { + return sizeof(*this) + _nodes.capacity() * sizeof(struct NodeRef); +} + +/********************* END: SwappableNodeRefStack ****************************/ + +#endif // SkipList_NodeRefs_h diff --git a/third_party/skiplist/RollingMedian.h b/third_party/skiplist/RollingMedian.h new file mode 100755 index 00000000000..604a1f221c0 --- /dev/null +++ b/third_party/skiplist/RollingMedian.h @@ -0,0 +1,202 @@ +#ifndef __SkipList__RollingMedian__ +#define __SkipList__RollingMedian__ + +/** + * @file + * + * Project: skiplist + * + * Rolling Median. + * + * Created by Paul Ross on 18/12/2015. + * + * Copyright (c) 2015-2023 Paul Ross. All rights reserved. + * + * @code + * MIT License + * + * Copyright (c) 2015-2023 Paul Ross + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * @endcode + */ + +#include + +#include "SkipList.h" + +namespace OrderedStructs { + /** + * @brief Namespace for the C++ Rolling Median. + */ + namespace RollingMedian { + +/** + * Error codes. + */ + enum RollingMedianResult { + ROLLING_MEDIAN_SUCCESS = 0, + ROLLING_MEDIAN_SOURCE_STRIDE, + ROLLING_MEDIAN_DESTINATION_STRIDE, + ROLLING_MEDIAN_WIN_LENGTH, + }; + +/** + * Return an error code. + */ +#define ROLLING_MEDIAN_ERROR_CHECK \ + do { \ + if (src_stride == 0) { \ + return ROLLING_MEDIAN_SOURCE_STRIDE; \ + } \ + if (dest_stride == 0) { \ + return ROLLING_MEDIAN_DESTINATION_STRIDE; \ + } \ + if (win_length == 0) { \ + return ROLLING_MEDIAN_WIN_LENGTH; \ + } \ + } while (0) + +/* Helpers for the destination memory area. + * Iterating through the destination to see the replaced values is done thus: + * + * for (int i = 0; + * i < RollingMedian::dest_size(COUNT, WIN_LENGTH, DEST_STRIDE); + * i += DEST_STRIDE) { + * ... + * } + */ + +/** + * Returns the size of the destination array for a rolling median on an array + * of count values with a window of win_length and a destination stride. + * + * @param count Number of input values. + * @param win_length Window length. + * @return Number of destination values. + */ + size_t dest_count(size_t count, size_t win_length) { + return 1 + count - win_length; + } + +/** + * Returns the size of the destination array for a rolling median on an array + * of count values with a window of win_length and a destination stride. + * + * @param count Number of input values. + * @param win_length Window length. + * @param dest_stride The destination stride given a 2D array. + * @return Size of destination array. + */ + size_t dest_size(size_t count, + size_t win_length, + size_t dest_stride) { + return dest_count(count, win_length) * dest_stride; + } + +/** + * Rolling median where only the odd mid-index is considered. + * If the win_length is even then (win_length - 1) / 2 value is used. + * See even_odd_index() for a different treatment of even lengths. + * This is valid for all types T. + * It is up to the caller to ensure that there is enough space in dest for + * the results, use dest_size() for this. + * + * @tparam T Type of the value(s). + * @param src Source array of values. + * @param src_stride Source stride for 2D arrays. + * @param count Number of input values. + * @param win_length Window length. + * @param dest The destination array. + * @param dest_stride The destination stride given a 2D array. + * @return The result of the Rolling Median operation as a RollingMedianResult enum. + */ + template + RollingMedianResult odd_index(const T *src, size_t src_stride, + size_t count, size_t win_length, + T *dest, size_t dest_stride) { + SkipList::HeadNode sl; + const T *tail = src; + + ROLLING_MEDIAN_ERROR_CHECK; + for (size_t i = 0; i < count; ++i) { + sl.insert(*src); + if (i + 1 >= win_length) { + *dest = sl.at(win_length / 2); + dest += dest_stride; + sl.remove(*tail); + tail += src_stride; + } + src += src_stride; + } + return ROLLING_MEDIAN_SUCCESS; + } + +/* + */ +/** + * Rolling median where the mean of adjacent values is used + * when the window size is even length. + * This requires T / 2 to be meaningful. + * It is up to the caller to ensure that there is enough space in dest for + * the results, use dest_size() for this. + * + * @tparam T Type of the value(s). + * @param src Source array of values. + * @param src_stride Source stride for 2D arrays. + * @param count Number of input values. + * @param win_length Window length. + * @param dest The destination array. + * @param dest_stride The destination stride given a 2D array. + * @return The result of the Rolling Median operation as a RollingMedianResult enum. + */ + template + RollingMedianResult even_odd_index(const T *src, size_t src_stride, + size_t count, size_t win_length, + T *dest, size_t dest_stride) { + if (win_length % 2 == 1) { + return odd_index(src, src_stride, + count, win_length, + dest, dest_stride); + } else { + ROLLING_MEDIAN_ERROR_CHECK; + SkipList::HeadNode sl; + std::vector buffer; + + const T *tail = src; + for (size_t i = 0; i < count; ++i) { + sl.insert(*src); + if (i + 1 >= win_length) { + sl.at((win_length - 1) / 2, 2, buffer); + assert(buffer.size() == 2); + *dest = buffer[0] / 2 + buffer[1] / 2; + dest += dest_stride; + sl.remove(*tail); + tail += src_stride; + } + src += src_stride; + } + } + return ROLLING_MEDIAN_SUCCESS; + } + + } // namespace RollingMedian +} // namespace OrderedStructs + +#endif /* defined(__SkipList__RollingMedian__) */ diff --git a/third_party/skiplist/SkipList.cpp b/third_party/skiplist/SkipList.cpp new file mode 100755 index 00000000000..9635af30471 --- /dev/null +++ b/third_party/skiplist/SkipList.cpp @@ -0,0 +1,90 @@ +// +// SkipList.cpp +// SkipList +// +// Created by Paul Ross on 19/12/2015. +// Copyright (c) 2017 Paul Ross. All rights reserved. +// + +#include +#ifdef SKIPLIST_THREAD_SUPPORT +#include +#endif +#include + +#include "SkipList.h" + +namespace duckdb_skiplistlib { +namespace skip_list { + +/** Tosses a virtual coin, returns true if 'heads'. + * + * No heads, ever: + * @code + * return false; + * @endcode + * + * 6.25% heads: + * @code + * return rand() < RAND_MAX / 16; + * @endcode + * + * 12.5% heads: + * @code + * return rand() < RAND_MAX / 8; + * @endcode + * + * 25% heads: + * @code + * return rand() < RAND_MAX / 4; + * @endcode + * + * Fair coin: + * @code + * return rand() < RAND_MAX / 2; + * @endcode + * + * 75% heads: + * @code + * return rand() < RAND_MAX - RAND_MAX / 4; + * @endcode + * + * 87.5% heads: + * @code + * return rand() < RAND_MAX - RAND_MAX / 8; + * @endcode + * + * 93.75% heads: + * @code + * @return rand() < RAND_MAX - RAND_MAX / 16; + * @endcode + */ +bool tossCoin() { + return rand() < RAND_MAX / 2; +} + +void seedRand(unsigned seed) { + srand(seed); +} + +// This throws an IndexError when the index value >= size. +// If possible the error will have an informative message. +#ifdef INCLUDE_METHODS_THAT_USE_STREAMS +void _throw_exceeds_size(size_t index) { + std::ostringstream oss; + oss << "Index out of range 0 <= index < " << index; + std::string err_msg = oss.str(); +#else +void _throw_exceeds_size(size_t /* index */) { + std::string err_msg = "Index out of range."; +#endif + throw IndexError(err_msg); +} + +#ifdef SKIPLIST_THREAD_SUPPORT + std::mutex gSkipListMutex; +#endif + + +} // namespace SkipList +} // namespace OrderedStructs diff --git a/third_party/skiplist/SkipList.h b/third_party/skiplist/SkipList.h new file mode 100755 index 00000000000..619208e7dfa --- /dev/null +++ b/third_party/skiplist/SkipList.h @@ -0,0 +1,553 @@ +#ifndef __SkipList__SkipList__ +#define __SkipList__SkipList__ + +/** + * @file + * + * Project: skiplist + * + * Created by Paul Ross on 15/11/2015. + * + * Copyright (c) 2015-2023 Paul Ross. All rights reserved. + * + * @code + * MIT License + * + * Copyright (c) 2017-2023 Paul Ross + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * @endcode + */ + +/** @mainpage + * + * General + * ======= + * This is a generic skip list implementation for any type T. + * There only restriction on the size of this skip list is the available memory. + * + * A skip list is a singly linked list of ordered nodes with a series of other, coarser, lists that reference a subset + * of nodes in order. + * 'Level' is an size_t that specifies the coarseness of the linked list, level 0 is the linked list to every node. + * + * Typically: + * - The list at level 1 links (ideally) to every other node. + * - The list at level 2 links (ideally) to every fourth node and so on. + * + * In general the list at level n links (ideally) to every 2**n node. + * + * These additional lists allow rapid location, insertion and removal of nodes. + * These lists are created and updated in a probabilistic manner and this is achieved at node creation time by tossing a + * virtual coin. + * These lists are not explicit, they are implied by the references between Nodes at a particular level. + * + * Skip lists are alternatives to balanced trees for operations such as a rolling median. + * The disadvantages of skip lists are: + - Less space efficient than balanced trees (see 'Space Complexity' below). + - performance is similar to balanced trees except finding the mid-point which is @c O(log(N)) for a skip list + compared with @c O(1) for a balanced tree. + * + * The advantages claimed for skip lists are: + - The insert() and remove() logic is simpler (I do not subscribe to this). + * + * Examples of Usage + * ================= + * + * C++ + * --- + * @code + * #include "SkipList.h" + * + * OrderedStructs::SkipList::HeadNode sl; + * + * sl.insert(42.0); + * sl.insert(21.0); + * sl.insert(84.0); + * sl.has(42.0) // true + * sl.size() // 3 + * sl.at(1) // 42.0 + * @endcode + * + * Python + * ------ + * @code + * import orderedstructs + * + * sl = orderedstructs.SkipList(float) + * sl.insert(42.0) + * sl.insert(21.0) + * sl.insert(84.0) + * sl.has(42.0) # True + * sl.size() # 3 + * sl.at(1) # 42.0 + * @endcode + * + * Design + * ====== + * + * This skip list design has the coarser lists implemented as optional additional links between the nodes themselves. + * The drawing below shows a well formed skip list with a head node ('HED') linked to the ordered nodes A to H. + * + * @code + * + | 5 E |------------------------------------->| 4 0 |---------------------------->| NULL | + | 1 A |->| 2 C |---------->| 2 E |---------->| 2 G |---------->| 2 0 |---------->| NULL | + | 1 A |->| 1 B |->| 1 C |->| 1 D |->| 1 E |->| 1 F |->| 1 G |->| 1 H |->| 1 0 |->| NULL | + | HED | | A | | B | | C | | D | | E | | F | | G | | H | + * @endcode + * + * Each node has a stack of values that consist of a 'width' and a reference to another node (or NULL). + * At the lowest level is a singly linked list and all widths are 1. + * At level 1 the links are (ideally) to every other node and at level 2 the links are (ideally) to every fourth node. + * The 'widths' at each node/level specify how many level 0 nodes the node reference skips over. + * The widths are used to rapidly index into the skip list starting from the highest level and working down. + * + * To understand how the skip list is maintained, consider insertion; before inserting node 'E' the skip list would look + * like this: + * + * @code + * + | 1 A |->| 2 C |---------->| 3 G |------------------->| 2 0 |---------->| NULL | + | 1 A |->| 1 B |->| 1 C |->| 1 D |->| 1 F |->| 1 G |->| 1 H |->| 1 0 |->| NULL | + | HED | | A | | B | | C | | D | | F | | G | | H | + * + * @endcode + * + * Inserting 'E' means: + * - Finding where 'E' should be inserted (after 'D'). + * - Creating node 'E' with a random height (heads/heads/tails so 3 high). + * - Updating 'D' to refer to 'E' at level 0. + * - Updating 'C' to refer to 'E' at level 1 and decreasing C's width to 2, increasing 'E' width at level 1 to 2. + * - Expanding HED to level 2 with a reference to 'E' and a width of 5. + * - Updating 'E' with a reference to NULL and a width of 4. + * + * Recursive Search for the Node Position + * -------------------------------------- + * The first two operations are done by a recursive search. + * This creates the chain HED[1], A[1], C[1], C[0], D[0] thus E will be created at level 0 and inserted after D. + * + * Node Creation + * ------------- + * Node E is created with a stack containing a single pointer to the next node F. + * Then a virtual coin is tossed, for each 'head' and extra NULL reference is added to the stack. + * If a 'tail' is thrown the stack is complete. + * In the example above when creating Node E we have encountered tosses of 'head', 'head', 'tail'. + * + * Recursive Unwinding + * ------------------- + * The remaining operations are done as recursion unwinds: + * + * - D[0] and C[0] update E[1] with their cumulative width (2). + * - C[1] adds 1 to width (a new node is inserted) then subtracts E[1]. + * - Then C[1]/E[1] are swapped so that the pointers and widths are correct. + * - And so on until HED is reached, in this case a new level is added and HED[2] swapped with E[2]. + * + * A similar procedure will be followed, in reverse, when removing E to restore the state of the skip list to the + * picture above. + * + * Algorithms + * ========== + * There doesn't seem to be much literature that I could find about the algorithms used for a skip list so these have + * all been invented here. + * + * In these descriptions: + * + * - 'right' is used to mean move to a higher ordinal node. + * - 'left' means to move to a lower ordinal node. + * - 'up' means to move to a coarser grained list, 'top' is the highest. + * - 'down' means to move to a finer grained list, 'bottom' is the level 0. + * + * has(T &val) const; + * ------------------ + * This returns true/false is the skip list has the value val. + * Starting at the highest possible level search rightwards until a larger value is encountered, then drop down. + * At level 0 return true if the Node value is the supplied value. + * This is @c O(log(N)) for well formed skip lists. + * + * at(size_t index) const; + * ----------------------- + * This returns the value of type T at the given index. + * The algorithm is similar to has(T &val) but the search moves rightwards if the width is less than the index and + * decrementing the index by the width. + * + * If progress can not be made to the right, drop down a level. + * If the index is 0 return the node value. + * This is @c O(log(N)) for well formed skip lists. + * + * insert(T &val) + * -------------- + * Finding the place to insert a node follows the has(T &val) algorithm to find the place in the skip list to create a + * new node. + * A duplicate value is inserted after any existing duplicate values. + * + * - All nodes are inserted at level 0 even if the insertion point can be seen at a higher level. + * - The search for an insertion location creates a recursion stack that, when unwound, updates the traversed nodes + * {width, Node*} data. + * - Once an insert position is found a Node is created whose height is determined by repeatedly tossing a virtual coin + * until a 'tails' is thrown. + * - This node initially has all node references to be to itself (this), and the widths set to 1 for level 0 and 0 for + * the remaining levels, they will be used to sum the widths at one level down. + * - On recursion ('left') each node adds its width to the new node at the level above the current level. + * - On moving up a level the current node swaps its width and node pointer with the new node at that new level. + * + * remove(T &val) + * -------------- + * + * If there are duplicate values the last one is removed first, this is for symmetry with insert(). + * Essentially this is the same as insert() but once the node is found the insert() updating algorithm is reversed and + * the node deleted. + * + * Code Layout + * =========== + * There are three classes defined in their own .h files and these are all included into the SkipList.h file. + * + * The classes are: + * + * SwappableNodeRefStack + * + * This is simple bookkeeping class that has a vector of [{skip_width, Node*}, ...]. + * This vector can be expanded or contracted at will. + * Both HeadNode and Node classes have one of these to manage their references. + * + * Node + * + * This represents a single value in the skip list. + * The height of a Node is determined at construction by tossing a virtual coin, this determines how many coarser + * lists this node participates in. + * A Node has a SwappableNodeRefStack object and a value of type T. + * + * HeadNode + * + * There is one of these per skip list and this provides the API to the entire skip list. + * The height of the HeadNode expands and contracts as required when Nodes are inserted or removed (it is the height + * of the highest Node). + * A HeadNode has a SwappableNodeRefStack object and an independently maintained count of the number of Node objects + * in the skip list. + * + * A Node and HeadNode have specialised methods such as has(), at(), insert(), remove() that traverse the skip lis + * recursively. + * + * Other Files of Significance + * --------------------------- + * SkipList.cpp exposes the random number generator (rand()) and seeder (srand()) so that they can be accessed + * CPython for deterministic testing. + * + * cSkipList.h and cSkipList.cpp contains a CPython module with a SkipList implementation for a number of builtin + * Python types. + * + * IntegrityEnums.h has definitions of error codes that can be created by the skip list integrity checking functions. + * + * Code Idioms + * =========== + * + * Prevent Copying + * --------------- + * Copying operations are (mostly) prohibited for performance reasons. + * The only class that allows copying is struct NodeRef that contains fundamental types. + * All other classes declare their copying operation private and unimplemented (rather than using C++11 delete) for + * compatibility with older compilers. + * + * Reverse Loop of Unsigned int + * ---------------------------- + * In a lot of the code we have to count down from some value to 0 + * with a size_t (an unsigned integer type) The idiom used is this: + * + * @code + * + * for (size_t l = height(); l-- > 0;) { + * // ... + * } + * + * @endcode + * + * The "l-- > 0" means test l against 0 then decrement it. + * l will thus start at the value height() - 1 down to 0 then exit the loop. + * + * @note If l is declared before the loop it will have the maximum value of a size_t unless a break statement is + * encountered. + * + * Roads not Travelled + * =================== + * Certain designs were not explored, here they are and why. + * + * Key/Value Implementation + * ------------------------ + * Skip lists are commonly used for key/value dictionaries. Given things + * like map or unorderedmap I see no reason why a SkipList should be used + * as an alternative. + * + * Adversarial Users + * ----------------- + * If the user knows the behaviour of the random number generator it is possible that they can change the order of + * insertion to create a poor distribution of nodes which will make operations tend to O(N) rather than O(log(N)). + * + * Probability != 0.5 + * ------------------ + * This implementation uses a fair coin to decide the height of the node. + * + * Some literature suggests other values such as p = 0.25 might be more efficient. + * Some experiments seem to show that this is the case with this implementation. + * Here are some results when using a vector of 1 million doubles and a sliding window of 101 where each value is + * inserted and removed and the cental value recovered: + * + * @code + * + Probability calculation p Time compared to p = 0.5 + rand() < RAND_MAX / 16; 0.0625 90% + rand() < RAND_MAX / 8; 0.125 83% + rand() < RAND_MAX / 4; 0.25 80% + rand() < RAND_MAX / 2; 0.5 100% + rand() > RAND_MAX / 4; 0.75 143% + rand() > RAND_MAX / 8; 0.875 201% + * + * @endcode + * + * Optimisation: Re-index Nodes on Complete Traversal + * -------------------------------------------------- + * + * @todo Re-index Nodes on Complete Traversal ??? + * + * Optimisation: Reuse removed nodes for insert() + * ---------------------------------------------- + * @todo Reuse removed nodes for insert() ??? + * + * Reference Counting + * ------------------ + * Some time (and particularly space) improvement could be obtained by reference counting nodes so that duplicate + * values could be eliminated. + * Since the primary use case for this skip list is for computing the rolling median of doubles the chances of + * duplicates are slim. + * For int, long and string there is a higher probability so reference counting might be implemented in the future if + * these types become commonly used. + * + * Use and Array of {skip_width, Node*} rather than a vector + * ---------------------------------------------------------------------- + * + * Less space would be used for each Node if the SwappableNodeRefStack used a dynamically allocated array of + * [{skip_width, Node*}, ...] rather than a vector. + * + * Performance + * =========== + * + * Reference platform: Macbook Pro, 13" running OS X 10.9.5. LLVM version 6.0 targeting x86_64-apple-darwin13.4.0 + * Compiled with -Os (small fast). + * + * Performance of at() and has() + * ----------------------------- + * + * Performance is O(log(N)) where N is the position in the skip list. + * + * On the reference platform this tests as t = 200 log2(N) in nanoseconds for skip lists of doubles. + * This factor of 200 can be between 70 and 500 for the same data but different indices because of the probabilistic + * nature of a skip list. + * For example finding the mid value of 1M doubles takes 3 to 4 microseconds. + * + * @note + * On Linux RHEL5 with -O3 this is much faster with t = 12 log2(N) + * [main.cpp perf_at_in_one_million(), main.cpp perf_has_in_one_million()] + * + * Performance of insert() and remove() + * ------------------------------------ + * A test that inserts then removes a single value in an empty list takes 440 nanoseconds (around 2.3 million per + * second). + * This should be fast as the search space is small. + * + * @note + * Linux RHEL5 with -O3 this is 4.2 million per second. [main.cpp perf_single_insert_remove()] + * + * A test that inserts 1M doubles into a skip list (no removal) takes 0.9 seconds (around 1.1 million per second). + * + * @note + * Linux RHEL5 with -O3 this is similar. [main.cpp perf_large_skiplist_ins_only()] + * + * A test that inserts 1M doubles into a skip list then removes all of them takes 1.0 seconds (around 1 million per second). + * + * @note + * Linux RHEL5 with -O3 this is similar. [main.cpp perf_large_skiplist_ins_rem()] + * + * A test that creates a skip list of 1M doubles then times how long it takes to insert and remove a value at the + * mid-point takes 1.03 microseconds per item (around 1 million per second). + * + * @note + * Linux RHEL5 with -O3 this is around 0.8 million per second. [main.cpp perf_single_ins_rem_middle()] + * + * A test that creates a skip list of 1M doubles then times how long it takes to insert a value, find the value at the + * mid point then remove that value (using insert()/at()/remove()) takes 1.2 microseconds per item (around 0.84 million + * per second). + * + * @note + * Linux RHEL5 with -O3 this is around 0.7 million per second. [main.cpp perf_single_ins_at_rem_middle()] + * + * Performance of a rolling median + * ------------------------------- + * On the reference platform a rolling median (using insert()/at()/remove()) on 1M random values takes about 0.93 + * seconds. + * + * @note + * Linux RHEL5 with -O3 this is about 0.7 seconds. + * [main.cpp perf_1m_median_values(), main.cpp perf_1m_medians_1000_vectors(), main.cpp perf_simulate_real_use()] + * + * The window size makes little difference, a rolling median on 1m items with a window size of 1 takes 0.491 seconds, + * with a window size of 524288 it takes 1.03 seconds. + * + * @note + * Linux RHEL5 with -O3 this is about 0.5 seconds. [main.cpp perf_roll_med_odd_index_wins()] + * + * Space Complexity + * ---------------- + * Given: + * + * - t = sizeof(T) ~ typ. 8 bytes for a double + * - v = sizeof(std::vector>) ~ typ. 32 bytes + * - p = sizeof(Node*) ~ typ. 8 bytes + * - e = sizeof(struct NodeRef) ~ typ. 8 + p = 16 bytes + * + * Then each node: is t + v bytes. + * + * Linked list at level 0 is e bytes per node. + * + * Linked list at level 1 is, typically, e / 2 bytes per node and so on. + * + * So the totality of linked lists is about 2e bytes per node. + * + * The total is N * (t + v + 2 * e) which for T as a double is typically 72 bytes per item. + * + * In practice this has been measured on the reference platform as a bit larger at 86.0 Mb for 1024*1024 doubles. + * + ***************** END: SkipList Documentation *****************/ + +/// Defined if you want the SkipList to have methods that can output +/// to stream (for debugging for example). +/// Defining this will mean that classes grow methods that use streams. +/// Undef this if you want a smaller binary in production as using streams +/// adds typically around 30kb to the binary. +/// However you may loose useful information such as formatted +/// exception messages with extra data. +//#define INCLUDE_METHODS_THAT_USE_STREAMS +#undef INCLUDE_METHODS_THAT_USE_STREAMS + +#include +#include +#include // Used for HeadNode::_lacksIntegrityNodeReferencesNotInList() +#include // Used for class Exception +#include + +#ifdef DEBUG +#include +#else +#ifndef assert +#define assert(x) +#endif +#endif // DEBUG + +#ifdef INCLUDE_METHODS_THAT_USE_STREAMS + +#include +#include + +#endif // INCLUDE_METHODS_THAT_USE_STREAMS + +//#define SKIPLIST_THREAD_SUPPORT +//#define SKIPLIST_THREAD_SUPPORT_TRACE + +#ifdef SKIPLIST_THREAD_SUPPORT +#ifdef SKIPLIST_THREAD_SUPPORT_TRACE +#include +#endif +#include +#endif + +/** + * @brief Namespace for all the C++ ordered structures. + */ +namespace duckdb_skiplistlib { + /** + * @brief Namespace for the C++ Slip List. + */ + namespace skip_list { + +/************************ Exceptions ****************************/ + +/** + * @brief Base exception class for all exceptions in the OrderedStructs::SkipList namespace. + */ + class Exception : public std::exception { + public: + explicit Exception(const std::string &in_msg) : msg(in_msg) {} + + const std::string &message() const { return msg; } + + virtual ~Exception() throw() {} + + protected: + std::string msg; + }; + +/** + * @brief Specialised exception case for an index out of range error. + */ + class IndexError : public Exception { + public: + explicit IndexError(const std::string &in_msg) : Exception(in_msg) {} + }; + +/** + * @brief Specialised exception for an value error where the given value does not exist in the Skip List. + */ + class ValueError : public Exception { + public: + explicit ValueError(const std::string &in_msg) : Exception(in_msg) {} + }; + +/** @brief Specialised exception used for NaN detection where value != value (example NaNs). */ + class FailedComparison : public Exception { + public: + explicit FailedComparison(const std::string &in_msg) : Exception(in_msg) {} + }; + +/** + * This throws an IndexError when the index value >= the size of Skip List. + * If @ref INCLUDE_METHODS_THAT_USE_STREAMS is defined then the error will have an informative message. + * + * @param index The out of range index. + */ + void _throw_exceeds_size(size_t index); + +/************************ END: Exceptions ****************************/ + + bool tossCoin(); + + /** Seed the random number generator for coin tosses. */ + void seedRand(unsigned seed); + +#ifdef SKIPLIST_THREAD_SUPPORT + /** + * Mutex used in a multi-threaded environment. + */ + extern std::mutex gSkipListMutex; +#endif + +#include "NodeRefs.h" +#include "Node.h" +#include "HeadNode.h" + + } // namespace skip_list +} // namespace duckdb_skiplistlib + +#endif /* defined(__SkipList__SkipList__) */