Support removing partition columns when writing bucketed tables
JkSelf committed Jan 3, 2024
1 parent a160a7b commit 2a4ef01
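
The gist of the change: partition-key columns are now projected out of the input before rows reach the file writer for bucketed tables as well, and sort column indices are re-based onto the partition-free schema. The snippet below is a minimal standalone sketch of that index bookkeeping, assuming a 6-column input (c0..c5) partitioned on c0 and c1; it mirrors the getDataChannels() helper added in HiveDataSink.cpp but uses plain standard-library types rather than the Velox API.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

// Simplified stand-in for Velox's column_index_t.
using column_index_t = uint32_t;

// Returns the indices of the non-partition (data) columns; this mirrors the
// getDataChannels() helper added in HiveDataSink.cpp.
std::vector<column_index_t> getDataChannels(
    const std::vector<column_index_t>& partitionChannels,
    column_index_t childrenSize) {
  std::vector<column_index_t> dataChannels;
  dataChannels.reserve(childrenSize - partitionChannels.size());
  for (column_index_t i = 0; i < childrenSize; ++i) {
    if (std::find(partitionChannels.cbegin(), partitionChannels.cend(), i) ==
        partitionChannels.cend()) {
      dataChannels.push_back(i);
    }
  }
  return dataChannels;
}

int main() {
  // Input schema c0..c5, partitioned on c0 and c1: only c2..c5 are written
  // into the data files.
  const std::vector<column_index_t> partitionChannels{0, 1};
  const auto dataChannels = getDataChannels(partitionChannels, 6);

  for (auto channel : dataChannels) {
    std::cout << "data channel: c" << channel << "\n"; // prints c2, c3, c4, c5
  }

  // A sort column declared on c4 must be addressed by its position in the
  // partition-free schema, i.e. index 2.
  const column_index_t sortColumn = 4;
  const auto it =
      std::find(dataChannels.cbegin(), dataChannels.cend(), sortColumn);
  std::cout << "c4 maps to data index "
            << std::distance(dataChannels.cbegin(), it) << "\n";
  return 0;
}

With partition channels {0, 1} removed, the data channels are {2, 3, 4, 5}, so a sort column on c4 lands at index 2 in the data schema, which is why the test below now expects sortColumnIndices_ = {2}.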
Showing 3 changed files with 104 additions and 73 deletions.
97 changes: 61 additions & 36 deletions velox/connectors/hive/HiveDataSink.cpp
@@ -39,14 +39,13 @@ namespace facebook::velox::connector::hive {
namespace {

RowVectorPtr makeDataInput(
const std::vector<column_index_t>& partitonCols,
const std::vector<column_index_t>& dataCols,
const RowVectorPtr& input,
const RowTypePtr& dataType) {
std::vector<VectorPtr> childVectors;
childVectors.reserve(dataType->size());
childVectors.reserve(dataCols.size());
for (uint32_t i = 0; i < input->childrenSize(); i++) {
if (std::find(partitonCols.cbegin(), partitonCols.cend(), i) ==
partitonCols.cend()) {
if (std::find(dataCols.cbegin(), dataCols.cend(), i) != dataCols.cend()) {
childVectors.push_back(input->childAt(i));
}
}
@@ -75,6 +74,42 @@ std::vector<column_index_t> getPartitionChannels(
return channels;
}

// Returns a subset of column indices corresponding to non-partition keys.
std::vector<column_index_t> getDataChannels(
const std::vector<column_index_t>& partitionChannels,
const column_index_t childrenSize) {
std::vector<column_index_t> dataChannels;
dataChannels.reserve(childrenSize - partitionChannels.size());

for (column_index_t i = 0; i < childrenSize; i++) {
if (std::find(partitionChannels.cbegin(), partitionChannels.cend(), i) ==
partitionChannels.cend()) {
dataChannels.push_back(i);
}
}

return dataChannels;
}

// Returns the type corresponding to non-partition keys.
RowTypePtr getDataType(
const std::vector<column_index_t>& dataCols,
const RowTypePtr& inputType) {
std::vector<std::string> childNames;
std::vector<TypePtr> childTypes;
const auto& dataSize = dataCols.size();
childNames.reserve(dataSize);
childTypes.reserve(dataSize);
for (uint32_t i = 0; i < inputType->size(); i++) {
if (std::find(dataCols.cbegin(), dataCols.cend(), i) != dataCols.cend()) {
childNames.push_back(inputType->nameOf(i));
childTypes.push_back(inputType->childAt(i));
}
}

return ROW(std::move(childNames), std::move(childTypes));
}

std::string makePartitionDirectory(
const std::string& tableDirectory,
const std::optional<std::string>& partitionSubdirectory) {
@@ -329,6 +364,7 @@ HiveDataSink::HiveDataSink(
maxOpenWriters_,
connectorQueryCtx_->memoryPool())
: nullptr),
dataChannels_(getDataChannels(partitionChannels_, inputType_->size())),
bucketCount_(
insertTableHandle_->bucketProperty() == nullptr
? 0
@@ -353,18 +389,6 @@ HiveDataSink::HiveDataSink(
"Unsupported commit strategy: {}",
commitStrategyToString(commitStrategy_));

// Get the data input type based on the inputType and the parition index.
std::vector<TypePtr> childTypes;
std::vector<std::string> childNames;
for (auto i = 0; i < inputType_->size(); i++) {
if (std::find(partitionChannels_.cbegin(), partitionChannels_.cend(), i) ==
partitionChannels_.end()) {
childNames.push_back(inputType_->nameOf(i));
childTypes.push_back(inputType_->childAt(i));
}
}
dataType_ = ROW(std::move(childNames), std::move(childTypes));

if (!isBucketed()) {
return;
}
@@ -373,13 +397,21 @@ HiveDataSink::HiveDataSink(
sortColumnIndices_.reserve(sortedProperty.size());
sortCompareFlags_.reserve(sortedProperty.size());
for (int i = 0; i < sortedProperty.size(); ++i) {
sortColumnIndices_.push_back(
inputType_->getChildIdx(sortedProperty.at(i)->sortColumn()));
sortCompareFlags_.push_back(
{sortedProperty.at(i)->sortOrder().isNullsFirst(),
sortedProperty.at(i)->sortOrder().isAscending(),
false,
CompareFlags::NullHandlingMode::kNullAsValue});
auto columnIndex =
inputType_->getChildIdx(sortedProperty.at(i)->sortColumn());
if (std::find(
partitionChannels_.cbegin(),
partitionChannels_.cend(),
columnIndex) == partitionChannels_.cend()) {
sortColumnIndices_.push_back(
getDataType(dataChannels_, inputType_)
->getChildIdx(sortedProperty.at(i)->sortColumn()));
sortCompareFlags_.push_back(
{sortedProperty.at(i)->sortOrder().isNullsFirst(),
sortedProperty.at(i)->sortOrder().isAscending(),
false,
CompareFlags::NullHandlingMode::kNullAsValue});
}
}
}
}
@@ -435,13 +467,10 @@ void HiveDataSink::appendData(RowVectorPtr input) {
void HiveDataSink::write(size_t index, const VectorPtr& input) {
WRITER_NON_RECLAIMABLE_SECTION_GUARD(index);
// Skip the partition columns before writing.
auto dataInput = input;
if (!isBucketed()) {
dataInput = makeDataInput(
partitionChannels_,
std::dynamic_pointer_cast<RowVector>(input),
dataType_);
}
auto dataInput = makeDataInput(
dataChannels_,
std::dynamic_pointer_cast<RowVector>(input),
getDataType(dataChannels_, inputType_));

writers_[index]->write(dataInput);
writerInfo_[index]->numWrittenRows += dataInput->size();
@@ -640,11 +669,7 @@ uint32_t HiveDataSink::appendWriter(const HiveWriterId& id) {
dwio::common::WriterOptions options;
const auto* connectorSessionProperties =
connectorQueryCtx_->sessionProperties();
if (!isBucketed()) {
options.schema = dataType_;
} else {
options.schema = inputType_;
}
options.schema = getDataType(dataChannels_, inputType_);

options.memoryPool = writerInfo_.back()->writerPool.get();
options.compressionKind = insertTableHandle_->compressionKind();
@@ -692,7 +717,7 @@ HiveDataSink::maybeCreateBucketSortWriter(
auto* sortPool = writerInfo_.back()->sortPool.get();
VELOX_CHECK_NOT_NULL(sortPool);
auto sortBuffer = std::make_unique<exec::SortBuffer>(
inputType_,
getDataType(dataChannels_, inputType_),
sortColumnIndices_,
sortCompareFlags_,
sortPool,
3 changes: 1 addition & 2 deletions velox/connectors/hive/HiveDataSink.h
@@ -541,15 +541,14 @@ class HiveDataSink : public DataSink {
void closeInternal();

const RowTypePtr inputType_;
// Written data columns into file.
RowTypePtr dataType_;
const std::shared_ptr<const HiveInsertTableHandle> insertTableHandle_;
const ConnectorQueryCtx* const connectorQueryCtx_;
const CommitStrategy commitStrategy_;
const std::shared_ptr<const HiveConfig> hiveConfig_;
const uint32_t maxOpenWriters_;
const std::vector<column_index_t> partitionChannels_;
const std::unique_ptr<PartitionIdGenerator> partitionIdGenerator_;
const std::vector<column_index_t> dataChannels_;
const int32_t bucketCount_{0};
const std::unique_ptr<core::PartitionFunction> bucketFunction_;
const std::shared_ptr<dwio::common::WriterFactory> writerFactory_;
77 changes: 42 additions & 35 deletions velox/exec/tests/TableWriteTest.cpp
@@ -225,13 +225,12 @@ class TableWriteTest : public HiveConnectorTestBase {
std::vector<TypePtr> bucketedTypes = {REAL(), VARCHAR()};
std::vector<std::shared_ptr<const HiveSortingColumn>> sortedBy;
if (testParam_.bucketSort()) {
sortedBy = {
std::make_shared<const HiveSortingColumn>(
"c4", core::SortOrder{true, true}),
std::make_shared<const HiveSortingColumn>(
"c1", core::SortOrder{false, false})};
sortColumnIndices_ = {4, 1};
sortedFlags_ = {{true, true}, {false, false}};
// The sortedBy keys shouldn't contain the partitionBy keys.
sortedBy = {std::make_shared<const HiveSortingColumn>(
"c4", core::SortOrder{true, true})};
// The sortColumnIndices_ are the indices after removing the partition keys.
sortColumnIndices_ = {2};
sortedFlags_ = {{true, true}};
}
bucketProperty_ = std::make_shared<HiveBucketProperty>(
testParam_.bucketKind(), 4, bucketedBy, bucketedTypes, sortedBy);
@@ -840,22 +839,12 @@ class TableWriteTest : public HiveConnectorTestBase {
void verifyPartitionedFilesData(
const std::vector<std::filesystem::path>& filePaths,
const std::filesystem::path& dirPath) {
if (bucketProperty_ != nullptr) {
HiveConnectorTestBase::assertQuery(
PlanBuilder().tableScan(rowType_).planNode(),
{makeHiveConnectorSplits(filePaths)},
fmt::format(
"SELECT * FROM tmp WHERE {}",
partitionNameToPredicate(getPartitionDirNames(dirPath))));

} else {
HiveConnectorTestBase::assertQuery(
PlanBuilder().tableScan(rowType_).planNode(),
{makeHiveConnectorSplits(filePaths)},
fmt::format(
"SELECT c2, c3, c4, c5 FROM tmp WHERE {}",
partitionNameToPredicate(getPartitionDirNames(dirPath))));
}
HiveConnectorTestBase::assertQuery(
PlanBuilder().tableScan(rowType_).planNode(),
{makeHiveConnectorSplits(filePaths)},
fmt::format(
"SELECT c2, c3, c4, c5 FROM tmp WHERE {}",
partitionNameToPredicate(getPartitionDirNames(dirPath))));
}

// Gets the hash function used by the production code to build bucket id.
@@ -1506,7 +1495,7 @@ TEST_P(AllTableWriterTest, scanFilterProjectWrite) {
// To test the correctness of the generated output,
// We create a new plan that only read that file and then
// compare that against a duckDB query that runs the whole query.
if (partitionedBy_.size() > 0 && bucketProperty_ == nullptr) {
if (partitionedBy_.size() > 0) {
auto newOutputType = getScanOutput(partitionedBy_, outputType);
assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
@@ -1568,7 +1557,7 @@ TEST_P(AllTableWriterTest, renameAndReorderColumns) {

assertQueryWithWriterConfigs(plan, filePaths, "SELECT count(*) FROM tmp");

if (partitionedBy_.size() > 0 && bucketProperty_ == nullptr) {
if (partitionedBy_.size() > 0) {
auto newOutputType = getScanOutput(partitionedBy_, tableSchema_);
HiveConnectorTestBase::assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
@@ -1614,7 +1603,7 @@ TEST_P(AllTableWriterTest, directReadWrite) {
// We create a new plan that only read that file and then
// compare that against a duckDB query that runs the whole query.

if (partitionedBy_.size() > 0 && bucketProperty_ == nullptr) {
if (partitionedBy_.size() > 0) {
auto newOutputType = getScanOutput(partitionedBy_, tableSchema_);
assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
@@ -1655,7 +1644,7 @@ TEST_P(AllTableWriterTest, constantVectors) {

assertQuery(op, fmt::format("SELECT {}", size));

if (partitionedBy_.size() > 0 && bucketProperty_ == nullptr) {
if (partitionedBy_.size() > 0) {
auto newOutputType = getScanOutput(partitionedBy_, tableSchema_);
assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
@@ -1714,7 +1703,7 @@ TEST_P(AllTableWriterTest, commitStrategies) {

assertQuery(plan, "SELECT count(*) FROM tmp");

if (partitionedBy_.size() > 0 && bucketProperty_ == nullptr) {
if (partitionedBy_.size() > 0) {
auto newOutputType = getScanOutput(partitionedBy_, tableSchema_);
assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
@@ -1750,7 +1739,7 @@ TEST_P(AllTableWriterTest, commitStrategies) {

assertQuery(plan, "SELECT count(*) FROM tmp");

if (partitionedBy_.size() > 0 && bucketProperty_ == nullptr) {
if (partitionedBy_.size() > 0) {
auto newOutputType = getScanOutput(partitionedBy_, tableSchema_);
assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
@@ -2411,11 +2400,23 @@ TEST_P(BucketedTableOnlyWriteTest, bucketCountLimit) {
} else {
assertQueryWithWriterConfigs(plan, "SELECT count(*) FROM tmp");

assertQuery(
PlanBuilder().tableScan(rowType_).planNode(),
makeHiveConnectorSplits(outputDirectory),
"SELECT * FROM tmp");
verifyTableWriterOutput(outputDirectory->path, rowType_);
if (partitionedBy_.size() > 0) {
auto newOutputType = getScanOutput(partitionedBy_, tableSchema_);
assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
makeHiveConnectorSplits(outputDirectory),
"SELECT c2, c3, c4, c5 FROM tmp");
auto originalRowType = rowType_;
rowType_ = newOutputType;
verifyTableWriterOutput(outputDirectory->path, rowType_);
rowType_ = originalRowType;
} else {
assertQuery(
PlanBuilder().tableScan(rowType_).planNode(),
makeHiveConnectorSplits(outputDirectory),
"SELECT * FROM tmp");
verifyTableWriterOutput(outputDirectory->path, rowType_);
}
}
}
}
@@ -3180,7 +3181,13 @@ TEST_P(BucketSortOnlyTableWriterTest, sortWriterSpill) {
const auto spillStats = globalSpillStats();
auto task =
assertQueryWithWriterConfigs(op, fmt::format("SELECT {}", 5 * 500), true);
verifyTableWriterOutput(outputDirectory->path, rowType_);
if (partitionedBy_.size() > 0) {
rowType_ = getScanOutput(partitionedBy_, rowType_);
verifyTableWriterOutput(outputDirectory->path, rowType_);
} else {
verifyTableWriterOutput(outputDirectory->path, rowType_);
}

const auto updatedSpillStats = globalSpillStats();
ASSERT_GT(updatedSpillStats.spilledBytes, spillStats.spilledBytes);
ASSERT_GT(updatedSpillStats.spilledPartitions, spillStats.spilledPartitions);