Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RIGHT and FULL JOIN with LowCardinality #9610

Merged
merged 3 commits into from
Mar 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
11 changes: 5 additions & 6 deletions dbms/src/Interpreters/InterpreterSelectQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -824,12 +824,13 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS

if (expressions.hasJoin())
{
Block header_before_join;
Block join_result_sample;
JoinPtr join = expressions.before_join->getTableJoinAlgo();

if constexpr (pipeline_with_processors)
{
header_before_join = pipeline.getHeader();
join_result_sample = ExpressionBlockInputStream(
std::make_shared<OneBlockInputStream>(pipeline.getHeader()), expressions.before_join).getHeader();

/// In case joined subquery has totals, and we don't, add default chunk to totals.
bool default_totals = false;
Expand All @@ -855,17 +856,15 @@ void InterpreterSelectQuery::executeImpl(TPipeline & pipeline, const BlockInputS
}
else
{
header_before_join = pipeline.firstStream()->getHeader();
/// Applies to all sources except stream_with_non_joined_data.
for (auto & stream : pipeline.streams)
stream = std::make_shared<InflatingExpressionBlockInputStream>(stream, expressions.before_join);

join_result_sample = pipeline.firstStream()->getHeader();
}

if (join)
{
Block join_result_sample = ExpressionBlockInputStream(
std::make_shared<OneBlockInputStream>(header_before_join), expressions.before_join).getHeader();

if (auto stream = join->createStreamWithNonJoinedRows(join_result_sample, settings.max_block_size))
{
if constexpr (pipeline_with_processors)
Expand Down
113 changes: 94 additions & 19 deletions dbms/src/Interpreters/Join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,68 @@ static void changeNullability(MutableColumnPtr & mutable_column)
mutable_column = (*std::move(column)).mutate();
}

static ColumnPtr emptyNotNullableClone(const ColumnPtr & column)
{
if (column->isNullable())
return checkAndGetColumn<ColumnNullable>(*column)->getNestedColumnPtr()->cloneEmpty();
return column->cloneEmpty();
}

static ColumnPtr changeLowCardinality(const ColumnPtr & column, const ColumnPtr & dst_sample)
{
if (dst_sample->lowCardinality())
{
MutableColumnPtr lc = dst_sample->cloneEmpty();
typeid_cast<ColumnLowCardinality &>(*lc).insertRangeFromFullColumn(*column, 0, column->size());
return lc;
}

return column->convertToFullColumnIfLowCardinality();
}

/// Change both column nullability and low cardinality
static void changeColumnRepresentation(const ColumnPtr & src_column, ColumnPtr & dst_column)
{
bool nullable_src = src_column->isNullable();
bool nullable_dst = dst_column->isNullable();

ColumnPtr dst_not_null = emptyNotNullableClone(dst_column);
bool lowcard_src = emptyNotNullableClone(src_column)->lowCardinality();
bool lowcard_dst = dst_not_null->lowCardinality();
bool change_lowcard = (!lowcard_src && lowcard_dst) || (lowcard_src && !lowcard_dst);

if (nullable_src && !nullable_dst)
{
auto * nullable = checkAndGetColumn<ColumnNullable>(*src_column);
if (change_lowcard)
dst_column = changeLowCardinality(nullable->getNestedColumnPtr(), dst_column);
else
dst_column = nullable->getNestedColumnPtr();
}
else if (!nullable_src && nullable_dst)
{
if (change_lowcard)
dst_column = makeNullable(changeLowCardinality(src_column, dst_not_null));
else
dst_column = makeNullable(src_column);
}
else /// same nullability
{
if (change_lowcard)
{
if (auto * nullable = checkAndGetColumn<ColumnNullable>(*src_column))
{
dst_column = makeNullable(changeLowCardinality(nullable->getNestedColumnPtr(), dst_not_null));
assert_cast<ColumnNullable &>(*dst_column->assumeMutable()).applyNullMap(nullable->getNullMapColumn());
}
else
dst_column = changeLowCardinality(src_column, dst_not_null);
}
else
dst_column = src_column;
}
}


Join::Join(std::shared_ptr<AnalyzedJoin> table_join_, const Block & right_sample_block, bool any_take_last_row_)
: table_join(table_join_)
Expand Down Expand Up @@ -315,11 +377,15 @@ void Join::setSampleBlock(const Block & block)
if (!empty())
return;

ColumnRawPtrs key_columns = JoinCommon::extractKeysForJoin(key_names_right, block, right_table_keys, sample_block_with_columns_to_add);
JoinCommon::splitAdditionalColumns(block, key_names_right, right_table_keys, sample_block_with_columns_to_add);

initRightBlockStructure(data->sample_block);
initRequiredRightKeys();

JoinCommon::removeLowCardinalityInplace(right_table_keys);
initRightBlockStructure(data->sample_block);

ColumnRawPtrs key_columns = JoinCommon::extractKeysForJoin(right_table_keys, key_names_right);

JoinCommon::createMissedColumns(sample_block_with_columns_to_add);
if (nullable_right_side)
JoinCommon::convertColumnsToNullable(sample_block_with_columns_to_add);
Expand Down Expand Up @@ -1249,7 +1315,10 @@ class NonJoinedBlockInputStream : public IBlockInputStream
///
std::unordered_map<size_t, size_t> same_result_keys;
/// Which right columns (saved in parent) need nullability change before placing them in result block
std::vector<size_t> right_nullability_changes;
std::vector<size_t> right_nullability_adds;
std::vector<size_t> right_nullability_removes;
/// Which right columns (saved in parent) need LowCardinality change before placing them in result block
std::vector<std::pair<size_t, ColumnPtr>> right_lowcard_changes;

std::any position;
std::optional<Join::BlockNullmapList::const_iterator> nulls_position;
Expand All @@ -1259,19 +1328,28 @@ class NonJoinedBlockInputStream : public IBlockInputStream
if (!column_indices_right.count(right_pos))
{
column_indices_right[right_pos] = result_position;

if (hasNullabilityChange(right_pos, result_position))
right_nullability_changes.push_back(right_pos);
extractColumnChanges(right_pos, result_position);
}
else
same_result_keys[result_position] = column_indices_right[right_pos];
}

bool hasNullabilityChange(size_t right_pos, size_t result_pos) const
void extractColumnChanges(size_t right_pos, size_t result_pos)
{
const auto & src = parent.savedBlockSample().getByPosition(right_pos).column;
const auto & dst = result_sample_block.getByPosition(result_pos).column;
return src->isNullable() != dst->isNullable();

if (!src->isNullable() && dst->isNullable())
right_nullability_adds.push_back(right_pos);

if (src->isNullable() && !dst->isNullable())
right_nullability_removes.push_back(right_pos);

ColumnPtr src_not_null = emptyNotNullableClone(src);
ColumnPtr dst_not_null = emptyNotNullableClone(dst);

if (src_not_null->lowCardinality() != dst_not_null->lowCardinality())
right_lowcard_changes.push_back({right_pos, dst_not_null});
}

Block createBlock()
Expand All @@ -1293,7 +1371,13 @@ class NonJoinedBlockInputStream : public IBlockInputStream
if (!rows_added)
return {};

for (size_t pos : right_nullability_changes)
for (size_t pos : right_nullability_removes)
changeNullability(columns_right[pos]);

for (auto & [pos, dst_sample] : right_lowcard_changes)
columns_right[pos] = changeLowCardinality(std::move(columns_right[pos]), dst_sample)->assumeMutable();

for (size_t pos : right_nullability_adds)
changeNullability(columns_right[pos]);

Block res = result_sample_block.cloneEmpty();
Expand All @@ -1318,16 +1402,7 @@ class NonJoinedBlockInputStream : public IBlockInputStream
{
auto & src_column = res.getByPosition(pr.second).column;
auto & dst_column = res.getByPosition(pr.first).column;

if (src_column->isNullable() && !dst_column->isNullable())
{
auto * nullable = checkAndGetColumn<ColumnNullable>(*src_column);
dst_column = nullable->getNestedColumnPtr();
}
else if (!src_column->isNullable() && dst_column->isNullable())
dst_column = makeNullable(src_column);
else
dst_column = src_column;
changeColumnRepresentation(src_column, dst_column);
}

return res;
Expand Down
3 changes: 2 additions & 1 deletion dbms/src/Interpreters/MergeJoin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,8 @@ MergeJoin::MergeJoin(std::shared_ptr<AnalyzedJoin> table_join_, const Block & ri
ErrorCodes::PARAMETER_OUT_OF_BOUND);
}

JoinCommon::extractKeysForJoin(table_join->keyNamesRight(), right_sample_block, right_table_keys, right_columns_to_add);
JoinCommon::splitAdditionalColumns(right_sample_block, table_join->keyNamesRight(), right_table_keys, right_columns_to_add);
JoinCommon::removeLowCardinalityInplace(right_table_keys);

const NameSet required_right_keys = table_join->requiredRightKeys();
for (const auto & column : right_table_keys)
Expand Down
39 changes: 18 additions & 21 deletions dbms/src/Interpreters/join_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,34 +98,31 @@ void removeLowCardinalityInplace(Block & block)
}
}

ColumnRawPtrs extractKeysForJoin(const Names & key_names_right, const Block & right_sample_block,
Block & sample_block_with_keys, Block & sample_block_with_columns_to_add)
void splitAdditionalColumns(const Block & sample_block, const Names & key_names, Block & block_keys, Block & block_others)
{
size_t keys_size = key_names_right.size();
ColumnRawPtrs key_columns(keys_size);

sample_block_with_columns_to_add = materializeBlock(right_sample_block);
block_others = materializeBlock(sample_block);

for (size_t i = 0; i < keys_size; ++i)
for (const String & column_name : key_names)
{
const String & column_name = key_names_right[i];

/// there could be the same key names
if (sample_block_with_keys.has(column_name))
/// Extract right keys with correct keys order. There could be the same key names.
if (!block_keys.has(column_name))
{
key_columns[i] = sample_block_with_keys.getByName(column_name).column.get();
continue;
auto & col = block_others.getByName(column_name);
block_keys.insert(col);
block_others.erase(column_name);
}
}
}

auto & col = sample_block_with_columns_to_add.getByName(column_name);
col.column = recursiveRemoveLowCardinality(col.column);
col.type = recursiveRemoveLowCardinality(col.type);

/// Extract right keys with correct keys order.
sample_block_with_keys.insert(col);
sample_block_with_columns_to_add.erase(column_name);
ColumnRawPtrs extractKeysForJoin(const Block & block_keys, const Names & key_names)
{
size_t keys_size = key_names.size();
ColumnRawPtrs key_columns(keys_size);

key_columns[i] = sample_block_with_keys.getColumns().back().get();
for (size_t i = 0; i < keys_size; ++i)
{
const String & column_name = key_names[i];
key_columns[i] = block_keys.getByName(column_name).column.get();

/// We will join only keys, where all components are not NULL.
if (auto * nullable = checkAndGetColumn<ColumnNullable>(*key_columns[i]))
Expand Down
4 changes: 2 additions & 2 deletions dbms/src/Interpreters/join_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ ColumnRawPtrs getRawPointers(const Columns & columns);
void removeLowCardinalityInplace(Block & block);

/// Split key and other columns by keys name list
ColumnRawPtrs extractKeysForJoin(const Names & key_names_right, const Block & right_sample_block,
Block & sample_block_with_keys, Block & sample_block_with_columns_to_add);
void splitAdditionalColumns(const Block & sample_block, const Names & key_names, Block & block_keys, Block & block_others);
ColumnRawPtrs extractKeysForJoin(const Block & block_keys, const Names & key_names_right);

/// Throw an exception if blocks have different types of key columns. Compare up to Nullability.
void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right, const Names & key_names_right);
Expand Down
80 changes: 80 additions & 0 deletions dbms/tests/queries/0_stateless/01049_join_low_card_bug.reference
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
LowCardinality(String)
str LowCardinality(String)
LowCardinality(String)
str LowCardinality(String)
str LowCardinality(String) LowCardinality(String)
str LowCardinality(String) LowCardinality(String)
str LowCardinality(String) LowCardinality(String)
str LowCardinality(String) LowCardinality(String)
LowCardinality(String)
str LowCardinality(String)
LowCardinality(String)
str LowCardinality(String)
str LowCardinality(String) String
str LowCardinality(String) String
str LowCardinality(String) String
str LowCardinality(String) String
String
str String
String
str String
str String LowCardinality(String)
str String LowCardinality(String)
str String LowCardinality(String)
str String LowCardinality(String)
LowCardinality(String)
str LowCardinality(String)
LowCardinality(String)
str LowCardinality(String)
str LowCardinality(String) Nullable(String)
str LowCardinality(String) Nullable(String)
str LowCardinality(String) Nullable(String)
str LowCardinality(String) Nullable(String)
\N Nullable(String)
str Nullable(String)
\N Nullable(String)
str Nullable(String)
\N str Nullable(String) LowCardinality(String)
\N str Nullable(String) LowCardinality(String)
\N str Nullable(String) LowCardinality(String)
\N str Nullable(String) LowCardinality(String)
LowCardinality(String)
str LowCardinality(String)
LowCardinality(String)
str LowCardinality(String)
str LowCardinality(String)
str LowCardinality(String)
str LowCardinality(String)
str LowCardinality(String)
LowCardinality(String)
str LowCardinality(String)
LowCardinality(String)
str LowCardinality(String)
str LowCardinality(String)
str LowCardinality(String)
str LowCardinality(String)
str LowCardinality(String)
\N Nullable(String)
str Nullable(String)
\N Nullable(String)
str Nullable(String)
\N str Nullable(String)
\N str Nullable(String)
\N str Nullable(String)
\N str Nullable(String)
LowCardinality(String)
str LowCardinality(String)
LowCardinality(String)
str LowCardinality(String)
str LowCardinality(String)
str LowCardinality(String)
str LowCardinality(String)
str LowCardinality(String)
\N Nullable(String)
str Nullable(String)
\N Nullable(String)
str Nullable(String)
\N str Nullable(String)
\N str Nullable(String)
\N str Nullable(String)
\N str Nullable(String)