Skip to content

Commit

Permalink
Rework merge join to support all join types, not just INNER and MARK …
Browse files Browse the repository at this point in the history
…joins
  • Loading branch information
Mytherin committed Jun 29, 2020
1 parent d9dc1e0 commit 4c56c88
Show file tree
Hide file tree
Showing 12 changed files with 283 additions and 275 deletions.
4 changes: 2 additions & 2 deletions src/execution/merge_join/CMakeLists.txt
@@ -1,8 +1,8 @@
add_library_unity(duckdb_merge_join
OBJECT
merge_join.cpp
merge_join_inner.cpp
merge_join_mark.cpp)
merge_join_complex.cpp
merge_join_simple.cpp)
set(ALL_OBJECT_FILES
${ALL_OBJECT_FILES} $<TARGET_OBJECTS:duckdb_merge_join>
PARENT_SCOPE)
21 changes: 11 additions & 10 deletions src/execution/merge_join/merge_join.cpp
Expand Up @@ -2,9 +2,10 @@

#include "duckdb/parser/expression/comparison_expression.hpp"

using namespace duckdb;
using namespace std;

namespace duckdb {

template <class MJ, class L_ARG, class R_ARG> static idx_t merge_join(L_ARG &l, R_ARG &r) {
switch (l.type) {
case TypeId::BOOL:
Expand All @@ -30,39 +31,39 @@ template <class MJ, class L_ARG, class R_ARG> static idx_t merge_join(L_ARG &l,
// Selects the comparison-specific operator nested in T (T::Equality,
// T::LessThan, T::LessThanEquals, T::GreaterThan, T::GreaterThanEquals)
// that matches comparison_type and forwards both sides to merge_join with it.
// Returns whatever idx_t merge_join returns for the chosen operator.
template <class T, class L_ARG, class R_ARG>
static idx_t perform_merge_join(L_ARG &l, R_ARG &r, ExpressionType comparison_type) {
switch (comparison_type) {
case ExpressionType::COMPARE_EQUAL:
return merge_join<typename T::Equality, L_ARG, R_ARG>(l, r);
case ExpressionType::COMPARE_LESSTHAN:
return merge_join<typename T::LessThan, L_ARG, R_ARG>(l, r);
case ExpressionType::COMPARE_LESSTHANOREQUALTO:
return merge_join<typename T::LessThanEquals, L_ARG, R_ARG>(l, r);
case ExpressionType::COMPARE_GREATERTHAN:
return merge_join<typename T::GreaterThan, L_ARG, R_ARG>(l, r);
default:
// "Unimplemented comparison type for merge join!"
assert(comparison_type == ExpressionType::COMPARE_GREATERTHANOREQUALTO);
case ExpressionType::COMPARE_GREATERTHANOREQUALTO:
return merge_join<typename T::GreaterThanEquals, L_ARG, R_ARG>(l, r);
default:
// comparison types without a matching case are rejected with an exception
throw NotImplementedException("Unimplemented comparison type for merge join!");
}
}

// Merge join entry point where both sides are scalar merge infos.
// Both MergeInfo arguments are reinterpreted as ScalarMergeInfo (their
// info_type is asserted first) and must share the same type. Returns 0 when
// either side has no entries; otherwise returns the result of
// perform_merge_join for the requested comparison_type.
idx_t MergeJoinInner::Perform(MergeInfo &l, MergeInfo &r, ExpressionType comparison_type) {
idx_t MergeJoinComplex::Perform(MergeInfo &l, MergeInfo &r, ExpressionType comparison_type) {
assert(l.info_type == MergeInfoType::SCALAR_MERGE_INFO && r.info_type == MergeInfoType::SCALAR_MERGE_INFO);
auto &left = (ScalarMergeInfo &)l;
auto &right = (ScalarMergeInfo &)r;
assert(left.type == right.type);
// an empty side cannot produce any result: skip the dispatch entirely
if (left.order.count == 0 || right.order.count == 0) {
return 0;
}
return perform_merge_join<MergeJoinInner, ScalarMergeInfo, ScalarMergeInfo>(left, right, comparison_type);
return perform_merge_join<MergeJoinComplex, ScalarMergeInfo, ScalarMergeInfo>(left, right, comparison_type);
}

// Merge join entry point where the left side is a scalar merge info and the
// right side is a chunk merge info. The info_type of both arguments is
// asserted before they are reinterpreted as ScalarMergeInfo / ChunkMergeInfo,
// and both must share the same type. Returns 0 when the left order or the
// right data_chunks collection is empty; otherwise returns the result of
// perform_merge_join for the requested comparison_type.
idx_t MergeJoinMark::Perform(MergeInfo &l, MergeInfo &r, ExpressionType comparison_type) {
idx_t MergeJoinSimple::Perform(MergeInfo &l, MergeInfo &r, ExpressionType comparison_type) {
assert(l.info_type == MergeInfoType::SCALAR_MERGE_INFO && r.info_type == MergeInfoType::CHUNK_MERGE_INFO);
auto &left = (ScalarMergeInfo &)l;
auto &right = (ChunkMergeInfo &)r;
assert(left.type == right.type);
// an empty side cannot produce any result: skip the dispatch entirely
if (left.order.count == 0 || right.data_chunks.count == 0) {
return 0;
}
return perform_merge_join<MergeJoinMark, ScalarMergeInfo, ChunkMergeInfo>(left, right, comparison_type);
return perform_merge_join<MergeJoinSimple, ScalarMergeInfo, ChunkMergeInfo>(left, right, comparison_type);
}

}
63 changes: 63 additions & 0 deletions src/execution/merge_join/merge_join_complex.cpp
@@ -0,0 +1,63 @@
#include "duckdb/common/operator/comparison_operators.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/execution/merge_join.hpp"
#include "duckdb/parser/expression/comparison_expression.hpp"

using namespace std;

namespace duckdb {

// Less-than style merge join kernel over two scalar merge infos.
//
// T  - value type stored in both sides' vdata.data buffers
// OP - comparator; a pair is emitted whenever OP::Operation(left, right) holds
//
// For the current right entry (r.pos) the left entries are visited in the
// order given by l.order.order, starting at l.pos, for as long as OP holds.
// Every hit stores the left/right order indices in l.result / r.result.
// Once OP fails or the left side runs out, the left cursor is rewound to 0
// and the right cursor advances by one.
//
// l.pos and r.pos are read and updated in place; when the output reaches
// STANDARD_VECTOR_SIZE entries the function returns with both cursors left
// at the point where the scan stopped.
//
// Returns the number of pairs written (0 if the right side is already
// exhausted on entry).
template<class T, class OP>
idx_t merge_join_complex_lt(ScalarMergeInfo &l, ScalarMergeInfo &r) {
	if (r.pos >= r.order.count) {
		return 0;
	}
	auto left_values = (T *)l.order.vdata.data;
	auto right_values = (T *)r.order.vdata.data;
	auto &left_order = l.order.order;
	auto &right_order = r.order.order;
	idx_t match_count = 0;
	do {
		// walk the left side for the current right entry
		while (l.pos < l.order.count) {
			auto left_idx = left_order.get_index(l.pos);
			auto right_idx = right_order.get_index(r.pos);
			auto left_data_idx = l.order.vdata.sel->get_index(left_idx);
			auto right_data_idx = r.order.vdata.sel->get_index(right_idx);
			if (!OP::Operation(left_values[left_data_idx], right_values[right_data_idx])) {
				// comparison failed: done with this right entry
				break;
			}
			// comparison holds: record the pair and step the left cursor
			l.result.set_index(match_count, left_idx);
			r.result.set_index(match_count, right_idx);
			match_count++;
			l.pos++;
			if (match_count == STANDARD_VECTOR_SIZE) {
				// output is full: stop here, cursors stay where they are
				return match_count;
			}
		}
		// comparison failed or left side exhausted: rewind the left cursor
		// and move on to the next right entry
		l.pos = 0;
		r.pos++;
	} while (r.pos != r.order.count);
	return match_count;
}

// Strict less-than variant of the complex merge join: runs
// merge_join_complex_lt with duckdb::LessThan as the comparator and returns
// its result unchanged.
template <class T> idx_t MergeJoinComplex::LessThan::Operation(ScalarMergeInfo &l, ScalarMergeInfo &r) {
	using Comparator = duckdb::LessThan;
	return merge_join_complex_lt<T, Comparator>(l, r);
}

// Less-than-or-equal variant of the complex merge join: runs
// merge_join_complex_lt with duckdb::LessThanEquals as the comparator and
// returns its result unchanged.
template <class T> idx_t MergeJoinComplex::LessThanEquals::Operation(ScalarMergeInfo &l, ScalarMergeInfo &r) {
	using Comparator = duckdb::LessThanEquals;
	return merge_join_complex_lt<T, Comparator>(l, r);
}

INSTANTIATE_MERGEJOIN_TEMPLATES(MergeJoinComplex, LessThan, ScalarMergeInfo, ScalarMergeInfo);
INSTANTIATE_MERGEJOIN_TEMPLATES(MergeJoinComplex, LessThanEquals, ScalarMergeInfo, ScalarMergeInfo);

}
134 changes: 0 additions & 134 deletions src/execution/merge_join/merge_join_inner.cpp

This file was deleted.

Expand Up @@ -3,14 +3,11 @@
#include "duckdb/execution/merge_join.hpp"
#include "duckdb/parser/expression/comparison_expression.hpp"

using namespace duckdb;
using namespace std;

// Equality comparison is not supported here: every call throws
// NotImplementedException and never returns a count.
template <class T> idx_t MergeJoinMark::Equality::Operation(ScalarMergeInfo &l, ChunkMergeInfo &r) {
throw NotImplementedException("Merge Join with Equality not implemented");
}
namespace duckdb {

template <class T, class OP> static idx_t merge_join_mark_gt(ScalarMergeInfo &l, ChunkMergeInfo &r) {
template <class T, class OP> static idx_t merge_join_simple_gt(ScalarMergeInfo &l, ChunkMergeInfo &r) {
auto ldata = (T *)l.order.vdata.data;
auto &lorder = l.order.order;
l.pos = l.order.count;
Expand Down Expand Up @@ -43,15 +40,15 @@ template <class T, class OP> static idx_t merge_join_mark_gt(ScalarMergeInfo &l,
}
return 0;
}
// Strict greater-than operator: forwards both sides to the greater-than
// kernel instantiated with duckdb::GreaterThan and returns its result.
template <class T> idx_t MergeJoinMark::GreaterThan::Operation(ScalarMergeInfo &l, ChunkMergeInfo &r) {
return merge_join_mark_gt<T, duckdb::GreaterThan>(l, r);
template <class T> idx_t MergeJoinSimple::GreaterThan::Operation(ScalarMergeInfo &l, ChunkMergeInfo &r) {
return merge_join_simple_gt<T, duckdb::GreaterThan>(l, r);
}

// Greater-than-or-equal operator: forwards both sides to the greater-than
// kernel instantiated with duckdb::GreaterThanEquals and returns its result.
template <class T> idx_t MergeJoinMark::GreaterThanEquals::Operation(ScalarMergeInfo &l, ChunkMergeInfo &r) {
return merge_join_mark_gt<T, duckdb::GreaterThanEquals>(l, r);
template <class T> idx_t MergeJoinSimple::GreaterThanEquals::Operation(ScalarMergeInfo &l, ChunkMergeInfo &r) {
return merge_join_simple_gt<T, duckdb::GreaterThanEquals>(l, r);
}

template <class T, class OP> static idx_t merge_join_mark_lt(ScalarMergeInfo &l, ChunkMergeInfo &r) {
template <class T, class OP> static idx_t merge_join_simple_lt(ScalarMergeInfo &l, ChunkMergeInfo &r) {
auto ldata = (T *)l.order.vdata.data;
auto &lorder = l.order.order;
l.pos = 0;
Expand Down Expand Up @@ -85,16 +82,17 @@ template <class T, class OP> static idx_t merge_join_mark_lt(ScalarMergeInfo &l,
return 0;
}

// Strict less-than operator: forwards both sides to the less-than kernel
// instantiated with duckdb::LessThan and returns its result.
template <class T> idx_t MergeJoinMark::LessThan::Operation(ScalarMergeInfo &l, ChunkMergeInfo &r) {
return merge_join_mark_lt<T, duckdb::LessThan>(l, r);
template <class T> idx_t MergeJoinSimple::LessThan::Operation(ScalarMergeInfo &l, ChunkMergeInfo &r) {
return merge_join_simple_lt<T, duckdb::LessThan>(l, r);
}

// Less-than-or-equal operator: forwards both sides to the less-than kernel
// instantiated with duckdb::LessThanEquals and returns its result.
template <class T> idx_t MergeJoinMark::LessThanEquals::Operation(ScalarMergeInfo &l, ChunkMergeInfo &r) {
return merge_join_mark_lt<T, duckdb::LessThanEquals>(l, r);
template <class T> idx_t MergeJoinSimple::LessThanEquals::Operation(ScalarMergeInfo &l, ChunkMergeInfo &r) {
return merge_join_simple_lt<T, duckdb::LessThanEquals>(l, r);
}

INSTANTIATE_MERGEJOIN_TEMPLATES(MergeJoinMark, Equality, ScalarMergeInfo, ChunkMergeInfo);
INSTANTIATE_MERGEJOIN_TEMPLATES(MergeJoinMark, LessThan, ScalarMergeInfo, ChunkMergeInfo);
INSTANTIATE_MERGEJOIN_TEMPLATES(MergeJoinMark, LessThanEquals, ScalarMergeInfo, ChunkMergeInfo);
INSTANTIATE_MERGEJOIN_TEMPLATES(MergeJoinMark, GreaterThan, ScalarMergeInfo, ChunkMergeInfo);
INSTANTIATE_MERGEJOIN_TEMPLATES(MergeJoinMark, GreaterThanEquals, ScalarMergeInfo, ChunkMergeInfo);
INSTANTIATE_MERGEJOIN_TEMPLATES(MergeJoinSimple, LessThan, ScalarMergeInfo, ChunkMergeInfo);
INSTANTIATE_MERGEJOIN_TEMPLATES(MergeJoinSimple, LessThanEquals, ScalarMergeInfo, ChunkMergeInfo);
INSTANTIATE_MERGEJOIN_TEMPLATES(MergeJoinSimple, GreaterThan, ScalarMergeInfo, ChunkMergeInfo);
INSTANTIATE_MERGEJOIN_TEMPLATES(MergeJoinSimple, GreaterThanEquals, ScalarMergeInfo, ChunkMergeInfo);

}
34 changes: 20 additions & 14 deletions src/execution/operator/join/physical_nested_loop_join.cpp
Expand Up @@ -237,6 +237,23 @@ void PhysicalNestedLoopJoin::ResolveSimpleJoin(ClientContext &context, DataChunk
} while (chunk.size() == 0);
}

// Emits the rows of `left` that have no match, padded with NULLs.
//
// left        - chunk whose rows are being checked
// result      - output chunk; only written when at least one row is unmatched
// found_match - one flag per row of `left`; false marks an unmatched row
//
// Unmatched rows are gathered into a selection vector and `result` is sliced
// from `left` with it. Every result column at index >= left.column_count()
// (presumably the right-hand side of the join) becomes a constant NULL vector.
void PhysicalJoin::ConstructLeftJoinResult(DataChunk &left, DataChunk &result, bool found_match[]) {
	SelectionVector unmatched_sel(STANDARD_VECTOR_SIZE);
	idx_t unmatched_count = 0;
	for (idx_t row = 0; row < left.size(); row++) {
		if (found_match[row]) {
			continue;
		}
		unmatched_sel.set_index(unmatched_count, row);
		unmatched_count++;
	}
	if (unmatched_count == 0) {
		// every row found a match: leave the result untouched
		return;
	}
	result.Slice(left, unmatched_sel, unmatched_count);
	for (idx_t col = left.column_count(); col < result.column_count(); col++) {
		auto &padding_column = result.data[col];
		padding_column.vector_type = VectorType::CONSTANT_VECTOR;
		ConstantVector::SetNull(padding_column, true);
	}
}

void PhysicalNestedLoopJoin::ResolveComplexJoin(ClientContext &context, DataChunk &chunk, PhysicalOperatorState *state_) {
auto state = reinterpret_cast<PhysicalNestedLoopJoinState *>(state_);
auto &gstate = (NestedLoopJoinGlobalState &)*sink_state;
Expand All @@ -257,20 +274,9 @@ void PhysicalNestedLoopJoin::ResolveComplexJoin(ClientContext &context, DataChun
// left join: before we move to the next chunk, see if we need to output any vectors that didn't
// have a match found
if (state->left_found_match) {
SelectionVector remaining_sel(STANDARD_VECTOR_SIZE);
idx_t remaining_count = 0;
for (idx_t i = 0; i < state->child_chunk.size(); i++) {
if (!state->left_found_match[i]) {
remaining_sel.set_index(remaining_count++, i);
}
}
state->left_found_match.reset();
if (remaining_count > 0) {
chunk.Slice(state->child_chunk, remaining_sel, remaining_count);
for (idx_t idx = state->child_chunk.column_count(); idx < chunk.column_count(); idx++) {
chunk.data[idx].vector_type = VectorType::CONSTANT_VECTOR;
ConstantVector::SetNull(chunk.data[idx], true);
}
PhysicalJoin::ConstructLeftJoinResult(state->child_chunk, chunk, state->left_found_match.get());
state->left_found_match.reset();
if (chunk.size() > 0) {
return;
}
}
Expand Down

0 comments on commit 4c56c88

Please sign in to comment.