Skip to content
Permalink
Browse files

Merge pull request #275 from hawkfish/hawkfish-combinefunc

Implement aggregate state merging
  • Loading branch information
hannesmuehleisen committed Sep 1, 2019
2 parents 958295e + a7b3c81 commit 297e39cdcb5e5c462bf2bec9318f49881d15df91
@@ -52,16 +52,8 @@ string ExpressionTypeToString(ExpressionType type) {
return "SCALAR";
case ExpressionType::AGGREGATE:
return "AGGREGATE";
case ExpressionType::WINDOW_SUM:
return "SUM";
case ExpressionType::WINDOW_COUNT_STAR:
return "COUNT_STAR";
case ExpressionType::WINDOW_MIN:
return "MIN";
case ExpressionType::WINDOW_MAX:
return "MAX";
case ExpressionType::WINDOW_AVG:
return "AVG";
case ExpressionType::WINDOW_AGGREGATE:
return "WINDOW_AGGREGATE";
case ExpressionType::WINDOW_RANK:
return "RANK";
case ExpressionType::WINDOW_RANK_DENSE:
@@ -52,15 +52,25 @@ static index_t BinarySearchRightmost(ChunkCollection &input, vector<Value> row,
return l - 1;
}

static void MaterializeExpression(ClientContext &context, Expression *expr, ChunkCollection &input,
static void MaterializeExpressions(ClientContext &context, Expression** exprs, index_t expr_count, ChunkCollection &input,
ChunkCollection &output, bool scalar = false) {
ChunkCollection boundary_start_collection;
vector<TypeId> types = {expr->return_type};
if (expr_count == 0 ) {
return;
}

vector<TypeId> types;
for ( index_t expr_idx = 0; expr_idx < expr_count; ++expr_idx ) {
types.push_back(exprs[expr_idx]->return_type);
}

for (index_t i = 0; i < input.chunks.size(); i++) {
DataChunk chunk;
chunk.Initialize(types);
ExpressionExecutor executor(*input.chunks[i]);
executor.ExecuteExpression(*expr, chunk.data[0]);
for ( index_t expr_idx = 0; expr_idx < expr_count; ++expr_idx ) {
auto expr = exprs[expr_idx];
executor.ExecuteExpression(*expr, chunk.data[expr_idx]);
}

chunk.Verify();
output.Append(chunk);
@@ -71,6 +81,11 @@ static void MaterializeExpression(ClientContext &context, Expression *expr, Chun
}
}

// Materialize a single expression over `input` into `output`.
// Thin convenience wrapper that forwards to the batch version,
// MaterializeExpressions, with an expression count of one.
static void MaterializeExpression(ClientContext &context, Expression* expr, ChunkCollection &input,
                                  ChunkCollection &output, bool scalar = false) {
	Expression* exprs[] = {expr};
	MaterializeExpressions(context, exprs, 1, input, output, scalar);
}

static void SortCollectionForWindow(ClientContext &context, BoundWindowExpression *wexpr, ChunkCollection &input,
ChunkCollection &output, ChunkCollection &sort_collection) {
vector<TypeId> sort_types;
@@ -260,10 +275,12 @@ static void ComputeWindowExpression(ClientContext &context, BoundWindowExpressio

// evaluate inner expressions of window functions, could be more complex
ChunkCollection payload_collection;
if (wexpr->child) {
// TODO: child[0] may be a scalar, don't need to materialize the whole collection then
MaterializeExpression(context, wexpr->child.get(), input, payload_collection);
vector<Expression*> exprs;
for (auto& child : wexpr->children) {
exprs.push_back(child.get());
}
// TODO: child may be a scalar, don't need to materialize the whole collection then
MaterializeExpressions(context, exprs.data(), exprs.size(), input, payload_collection);

ChunkCollection leadlag_offset_collection;
ChunkCollection leadlag_default_collection;
@@ -296,16 +313,8 @@ static void ComputeWindowExpression(ClientContext &context, BoundWindowExpressio
// see http://www.vldb.org/pvldb/vol8/p1058-leis.pdf
unique_ptr<WindowSegmentTree> segment_tree = nullptr;

switch (wexpr->type) {
case ExpressionType::WINDOW_SUM:
case ExpressionType::WINDOW_MIN:
case ExpressionType::WINDOW_MAX:
case ExpressionType::WINDOW_AVG:
segment_tree = make_unique<WindowSegmentTree>(wexpr->type, wexpr->return_type, &payload_collection);
break;
default:
break;
// nothing
if (wexpr->aggregate) {
segment_tree = make_unique<WindowSegmentTree>(*(wexpr->aggregate), wexpr->return_type, &payload_collection);
}

WindowBoundariesState bounds;
@@ -342,18 +351,10 @@ static void ComputeWindowExpression(ClientContext &context, BoundWindowExpressio
}

switch (wexpr->type) {
case ExpressionType::WINDOW_SUM:
case ExpressionType::WINDOW_MIN:
case ExpressionType::WINDOW_MAX:
case ExpressionType::WINDOW_AVG: {
assert(segment_tree);
case ExpressionType::WINDOW_AGGREGATE: {
res = segment_tree->Compute(bounds.window_start, bounds.window_end);
break;
}
case ExpressionType::WINDOW_COUNT_STAR: {
res = Value::Numeric(wexpr->return_type, bounds.window_end - bounds.window_start);
break;
}
case ExpressionType::WINDOW_ROW_NUMBER: {
res = Value::Numeric(wexpr->return_type, row_idx - bounds.partition_start + 1);
break;
@@ -8,89 +8,74 @@
using namespace duckdb;
using namespace std;

void WindowSegmentTree::AggregateInit() {
switch (window_type) {
case ExpressionType::WINDOW_SUM:
case ExpressionType::WINDOW_AVG:
aggregate = Value::Numeric(payload_type, 0);
break;
case ExpressionType::WINDOW_MIN:
aggregate = Value::MaximumValue(payload_type);
break;
case ExpressionType::WINDOW_MAX:
aggregate = Value::MinimumValue(payload_type);
break;
default:
throw NotImplementedException("Window Type");
WindowSegmentTree::WindowSegmentTree(AggregateFunction& aggregate, TypeId result_type, ChunkCollection *input)
: aggregate(aggregate), state(aggregate.state_size(result_type)), statep(TypeId::POINTER, true, false), result_type(result_type),
input_ref(input) {
statep.count = STANDARD_VECTOR_SIZE;
VectorOperations::Set(statep, Value::POINTER((index_t) state.data()));

if (input_ref && input_ref->column_count() > 0) {
inputs = unique_ptr<Vector[]>(new Vector[input_ref->column_count()]);
}

if (aggregate.combine && inputs) {
ConstructTree();
}
assert(aggregate.type == payload_type);
n_aggregated = 0;
}

// Reset the tree's single scratch aggregate state to the function's initial
// value; called before each (re-)accumulation over a frame or tree node.
void WindowSegmentTree::AggregateInit() {
	aggregate.initialize(state.data(), result_type);
}

Value WindowSegmentTree::AggegateFinal() {
if (n_aggregated == 0) {
return Value(payload_type);
}
switch (window_type) {
case ExpressionType::WINDOW_AVG:
return aggregate / Value::Numeric(payload_type, n_aggregated);
default:
return aggregate;
}
ConstantVector statev(Value::POINTER((index_t) state.data()));

Value r(result_type);
ConstantVector result(r);
result.SetNull(0, false);
aggregate.finalize(statev, result);

return aggregate;
return result.GetValue(0);
}

void WindowSegmentTree::WindowSegmentValue(index_t l_idx, index_t begin, index_t end) {
assert(begin <= end);
if (begin == end) {
return;
}
Vector s;
s.Reference(statep);
s.count = end - begin;
Vector v;
if (l_idx == 0) {
auto &vec = input_ref->GetChunk(begin).data[0];
v.Reference(vec);
index_t start_in_vector = begin % STANDARD_VECTOR_SIZE;
v.data = v.data + GetTypeIdSize(v.type) * start_in_vector;
v.count = end - begin;
v.nullmask <<= start_in_vector;
assert(v.count + start_in_vector <=
vec.count); // if STANDARD_VECTOR_SIZE is not divisible by fanout this will trip
const auto input_count = input_ref->column_count();
auto &chunk = input_ref->GetChunk(begin);
for (index_t i = 0; i < input_count; ++i) {
auto &v = inputs[i];
auto &vec = chunk.data[i];
v.Reference(vec);
index_t start_in_vector = begin % STANDARD_VECTOR_SIZE;
v.data = v.data + GetTypeIdSize(v.type) * start_in_vector;
v.count = end - begin;
v.nullmask <<= start_in_vector;
assert(!v.sel_vector);
v.Verify();
}
aggregate.update(&inputs[0], input_count, s);
} else {
assert(end - begin < STANDARD_VECTOR_SIZE);
v.data = levels_flat_native.get() + GetTypeIdSize(payload_type) * (begin + levels_flat_start[l_idx - 1]);
v.data = levels_flat_native.get() + state.size() * (begin + levels_flat_start[l_idx - 1]);
v.count = end - begin;
v.type = payload_type;
v.type = result_type;
assert(!v.sel_vector);
v.Verify();
aggregate.combine(v, s);
}

assert(!v.sel_vector);
v.Verify();

switch (window_type) {
case ExpressionType::WINDOW_SUM:
case ExpressionType::WINDOW_AVG:
aggregate = aggregate + VectorOperations::Sum(v);
break;
case ExpressionType::WINDOW_MIN: {
auto val = VectorOperations::Min(v);
aggregate = aggregate > val ? val : aggregate;
break;
}
case ExpressionType::WINDOW_MAX: {
auto val = VectorOperations::Max(v);
aggregate = aggregate < val ? val : aggregate;
break;
}
default:
throw NotImplementedException("Window Type");
}

n_aggregated += end - begin;
}

void WindowSegmentTree::ConstructTree() {
assert(input_ref);
assert(input_ref->column_count() == 1);
assert(inputs);

// compute space required to store internal nodes of segment tree
index_t internal_nodes = 0;
@@ -99,7 +84,7 @@ void WindowSegmentTree::ConstructTree() {
level_nodes = (index_t)ceil((double)level_nodes / TREE_FANOUT);
internal_nodes += level_nodes;
} while (level_nodes > 1);
levels_flat_native = unique_ptr<data_t[]>(new data_t[internal_nodes * GetTypeIdSize(payload_type)]);
levels_flat_native = unique_ptr<data_t[]>(new data_t[internal_nodes * state.size()]);
levels_flat_start.push_back(0);

index_t levels_flat_offset = 0;
@@ -112,11 +97,7 @@ void WindowSegmentTree::ConstructTree() {
AggregateInit();
WindowSegmentValue(level_current, pos, min(level_size, pos + TREE_FANOUT));

ConstantVector res_vec(AggegateFinal());
assert(res_vec.type == payload_type);
ConstantVector ptr_vec(Value::POINTER(
(index_t)(levels_flat_native.get() + (levels_flat_offset * GetTypeIdSize(payload_type)))));
VectorOperations::Scatter::Set(res_vec, ptr_vec);
memcpy(levels_flat_native.get() + (levels_flat_offset * state.size()), state.data(), state.size());

levels_flat_offset++;
}
@@ -128,7 +109,20 @@ void WindowSegmentTree::ConstructTree() {

Value WindowSegmentTree::Compute(index_t begin, index_t end) {
assert(input_ref);

// No arguments, so just count
if (!inputs) {
return Value::Numeric(result_type, end - begin);
}

AggregateInit();

// Aggregate everything at once if we can't combine states
if (!aggregate.combine) {
WindowSegmentValue(0, begin, end);
return AggegateFinal();
}

for (index_t l_idx = 0; l_idx < levels_flat_start.size() + 1; l_idx++) {
index_t parent_begin = begin / TREE_FANOUT;
index_t parent_end = end / TREE_FANOUT;
@@ -37,6 +37,24 @@ static void avg_update(Vector inputs[], index_t input_count, Vector &state) {
});
}

// Merge streaming AVG partial states: for each row i, fold the inline state
// in `state` into the accumulator that `combined` points at.
// `combined` holds pointers to target states; `state` holds source states inline.
static void avg_combine(Vector &state, Vector &combined) {
	auto targets = (avg_state_t**) combined.data;
	auto sources = (avg_state_t*) state.data;

	VectorOperations::Exec(state, [&](uint64_t i, uint64_t k) {
		auto target = targets[i];
		auto source = sources + i;

		if (target->count == 0) {
			// target is empty: adopt the source state wholesale
			*target = *source;
		} else if (source->count != 0) {
			// both non-empty: sums and counts add component-wise
			target->sum += source->sum;
			target->count += source->count;
		}
	});
}

static void avg_finalize(Vector &state, Vector &result) {
// compute finalization of streaming avg
VectorOperations::Exec(state, [&](uint64_t i, uint64_t k) {
@@ -51,5 +69,5 @@ static void avg_finalize(Vector &state, Vector &result) {
}

void Avg::RegisterFunction(BuiltinFunctions &set) {
set.AddFunction(AggregateFunction("avg", {SQLType::DOUBLE}, SQLType::DOUBLE, avg_payload_size, avg_initialize, avg_update, avg_finalize));
set.AddFunction(AggregateFunction("avg", {SQLType::DOUBLE}, SQLType::DOUBLE, avg_payload_size, avg_initialize, avg_update, avg_combine, avg_finalize));
}
@@ -30,11 +30,6 @@ static void covar_update(Vector inputs[], index_t input_count, Vector &state) {
if (inputs[0].nullmask[i] || inputs[1].nullmask[i]) {
return;
}
// Layout of state for online covariance:
// uint64_t count
// double meanx
// double meany
// double co-moment

auto state_ptr = (covar_state_t*) ((data_ptr_t *)state.data)[i];

@@ -58,6 +53,33 @@ static void covar_update(Vector inputs[], index_t input_count, Vector &state) {
});
}

// Merge streaming covariance partial states: for each row i, fold the inline
// state in `state` into the accumulator pointed to by `combined`.
// `combined` holds pointers to target states; `state` holds source states inline.
static void covar_combine(Vector &state, Vector &combined) {
	// combine streaming covar states
	auto combined_data = (covar_state_t**) combined.data;
	auto state_data = (covar_state_t*) state.data;

	VectorOperations::Exec(state, [&](uint64_t i, uint64_t k) {
		auto combined_ptr = combined_data[i];
		auto state_ptr = state_data + i;

		if (0 == combined_ptr->count) {
			// target is empty: adopt the source state wholesale
			*combined_ptr = *state_ptr;
		} else if (state_ptr->count) {
			// pooled count and count-weighted pooled means of x and y
			const auto count = combined_ptr->count + state_ptr->count;
			const auto meanx = ( state_ptr->count * state_ptr->meanx + combined_ptr->count * combined_ptr->meanx ) / count;
			const auto meany = ( state_ptr->count * state_ptr->meany + combined_ptr->count * combined_ptr->meany ) / count;

			// Schubert and Gertz SSDBM 2018, equation 21
			// The co-moment must be updated from the OLD means, so compute the
			// mean deltas before overwriting meanx/meany below.
			const auto deltax = combined_ptr->meanx - state_ptr->meanx;
			const auto deltay = combined_ptr->meany - state_ptr->meany;
			combined_ptr->co_moment = state_ptr->co_moment + combined_ptr->co_moment + deltax * deltay * state_ptr->count * combined_ptr->count / count;
			combined_ptr->meanx = meanx;
			combined_ptr->meany = meany;
			combined_ptr->count = count;
		}
	});
}

static void covarpop_finalize(Vector &state, Vector &result) {
// compute finalization of streaming population covariance
VectorOperations::Exec(result, [&](uint64_t i, uint64_t k) {
@@ -89,9 +111,9 @@ static void covarsamp_finalize(Vector &state, Vector &result) {
}

void CovarSamp::RegisterFunction(BuiltinFunctions &set) {
set.AddFunction(AggregateFunction("covar_samp", {SQLType::DOUBLE, SQLType::DOUBLE}, SQLType::DOUBLE, covar_state_size, covar_initialize, covar_update, covarsamp_finalize));
set.AddFunction(AggregateFunction("covar_samp", {SQLType::DOUBLE, SQLType::DOUBLE}, SQLType::DOUBLE, covar_state_size, covar_initialize, covar_update, covar_combine, covarsamp_finalize));
}

void CovarPop::RegisterFunction(BuiltinFunctions &set) {
set.AddFunction(AggregateFunction("covar_pop", {SQLType::DOUBLE, SQLType::DOUBLE}, SQLType::DOUBLE, covar_state_size, covar_initialize, covar_update, covarpop_finalize));
set.AddFunction(AggregateFunction("covar_pop", {SQLType::DOUBLE, SQLType::DOUBLE}, SQLType::DOUBLE, covar_state_size, covar_initialize, covar_update, covar_combine, covarpop_finalize));
}

0 comments on commit 297e39c

Please sign in to comment.
You can’t perform that action at this time.