Rename getDataType and getDataChannels funcs in HiveDataSink
This comes from cherry-picking facebookincubator#8089 and is the delta
of that PR on top of the main branch: when we merged the PR, we failed
to merge its latest version. With this delta, the naming is consistent
everywhere.

This change renames the functions getDataType and getDataChannels in
HiveDataSink to getNonPartitionTypes and getNonPartitionChannels, and
makes them static.
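
For reference, the sketch below is not part of the commit: it shows what the two renamed helpers compute, as they would appear inside the anonymous namespace of HiveDataSink.cpp (where the needed Velox headers are already in scope). The body of getNonPartitionTypes follows the visible parts of the diff; the body of getNonPartitionChannels is elided in the diff, so its loop here is an assumed reconstruction of the described behavior.

// Sketch: builds the row type of the non-partition (data) columns by
// picking the children of inputType at the given data-column indices.
static RowTypePtr getNonPartitionTypes(
    const std::vector<column_index_t>& dataCols,
    const RowTypePtr& inputType) {
  std::vector<std::string> childNames;
  std::vector<TypePtr> childTypes;
  childNames.reserve(dataCols.size());
  childTypes.reserve(dataCols.size());
  for (const column_index_t idx : dataCols) {
    childNames.push_back(inputType->nameOf(idx));
    childTypes.push_back(inputType->childAt(idx));
  }
  return ROW(std::move(childNames), std::move(childTypes));
}

// Sketch: returns every column index in [0, childrenSize) that is not a
// partition channel. This body is elided in the diff below, so this is an
// assumed reconstruction of the same behavior.
static std::vector<column_index_t> getNonPartitionChannels(
    const std::vector<column_index_t>& partitionChannels,
    const column_index_t childrenSize) {
  std::vector<column_index_t> dataChannels;
  dataChannels.reserve(childrenSize - partitionChannels.size());
  for (column_index_t i = 0; i < childrenSize; ++i) {
    if (std::find(partitionChannels.begin(), partitionChannels.end(), i) ==
        partitionChannels.end()) {
      dataChannels.push_back(i);
    }
  }
  return dataChannels;
}

As the diff shows, HiveDataSink's constructor uses getNonPartitionChannels(partitionChannels_, inputType_->size()) to populate dataChannels_, and getNonPartitionTypes(dataChannels_, inputType_) to derive the writer schema and the sort buffer's input type.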
gggrace14 committed Jan 17, 2024
1 parent 4e36849 commit 919973b
Showing 2 changed files with 28 additions and 23 deletions.
24 changes: 13 additions & 11 deletions velox/connectors/hive/HiveDataSink.cpp
@@ -38,8 +38,8 @@ namespace facebook::velox::connector::hive {

namespace {

-// Returns the type corresponding to non-partition keys.
-RowTypePtr getDataType(
+// Returns the type of non-partition keys.
+static RowTypePtr getNonPartitionTypes(
const std::vector<column_index_t>& dataCols,
const RowTypePtr& inputType) {
std::vector<std::string> childNames;
@@ -57,7 +57,8 @@ RowTypePtr getDataType(
return ROW(std::move(childNames), std::move(childTypes));
}

-RowVectorPtr makeDataInput(
+// Filters out partition columns, if there are any.
+static RowVectorPtr makeDataInput(
const std::vector<column_index_t>& dataCols,
const RowVectorPtr& input) {
std::vector<VectorPtr> childVectors;
@@ -70,15 +71,15 @@

return std::make_shared<RowVector>(
input->pool(),
-getDataType(dataCols, asRowType(input->type())),
+getNonPartitionTypes(dataCols, asRowType(input->type())),
input->nulls(),
input->size(),
std::move(childVectors),
input->getNullCount());
}

// Returns a subset of column indices corresponding to partition keys.
-std::vector<column_index_t> getPartitionChannels(
+static std::vector<column_index_t> getPartitionChannels(
const std::shared_ptr<const HiveInsertTableHandle>& insertTableHandle) {
std::vector<column_index_t> channels;

@@ -92,8 +93,8 @@ std::vector<column_index_t> getPartitionChannels(
return channels;
}

-// Returns a subset of column indices corresponding to non-partition keys.
-std::vector<column_index_t> getDataChannels(
+// Returns the column indices of non-partition keys.
+static std::vector<column_index_t> getNonPartitionChannels(
const std::vector<column_index_t>& partitionChannels,
const column_index_t childrenSize) {
std::vector<column_index_t> dataChannels;
@@ -366,7 +367,8 @@ HiveDataSink::HiveDataSink(
hiveConfig_->isFileColumnNamesReadAsLowerCase(
connectorQueryCtx->sessionProperties()))
: nullptr),
-dataChannels_(getDataChannels(partitionChannels_, inputType_->size())),
+dataChannels_(
+    getNonPartitionChannels(partitionChannels_, inputType_->size())),
bucketCount_(
insertTableHandle_->bucketProperty() == nullptr
? 0
@@ -400,7 +402,7 @@ HiveDataSink::HiveDataSink(
sortCompareFlags_.reserve(sortedProperty.size());
for (int i = 0; i < sortedProperty.size(); ++i) {
auto columnIndex =
-getDataType(dataChannels_, inputType_)
+getNonPartitionTypes(dataChannels_, inputType_)
->getChildIdxIfExists(sortedProperty.at(i)->sortColumn());
if (columnIndex.has_value()) {
sortColumnIndices_.push_back(columnIndex.value());
@@ -664,7 +666,7 @@ uint32_t HiveDataSink::appendWriter(const HiveWriterId& id) {
dwio::common::WriterOptions options;
const auto* connectorSessionProperties =
connectorQueryCtx_->sessionProperties();
-options.schema = getDataType(dataChannels_, inputType_);
+options.schema = getNonPartitionTypes(dataChannels_, inputType_);

options.memoryPool = writerInfo_.back()->writerPool.get();
options.compressionKind = insertTableHandle_->compressionKind();
@@ -711,7 +713,7 @@ HiveDataSink::maybeCreateBucketSortWriter(
auto* sortPool = writerInfo_.back()->sortPool.get();
VELOX_CHECK_NOT_NULL(sortPool);
auto sortBuffer = std::make_unique<exec::SortBuffer>(
-getDataType(dataChannels_, inputType_),
+getNonPartitionTypes(dataChannels_, inputType_),
sortColumnIndices_,
sortCompareFlags_,
sortPool,
27 changes: 15 additions & 12 deletions velox/exec/tests/TableWriteTest.cpp
@@ -108,7 +108,7 @@ std::function<PlanNodePtr(std::string, PlanNodePtr)> addTableWriter(
};
}

-RowTypePtr getScanOutput(
+RowTypePtr getNonPartitionsColumns(
const std::vector<std::string>& partitionedKeys,
const RowTypePtr& rowType) {
std::vector<std::string> dataColumnNames;
@@ -1511,7 +1511,7 @@ TEST_P(AllTableWriterTest, scanFilterProjectWrite) {
// We create a new plan that only read that file and then
// compare that against a duckDB query that runs the whole query.
if (partitionedBy_.size() > 0) {
-auto newOutputType = getScanOutput(partitionedBy_, outputType);
+auto newOutputType = getNonPartitionsColumns(partitionedBy_, outputType);
assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
makeHiveConnectorSplits(outputDirectory),
@@ -1573,7 +1573,7 @@ TEST_P(AllTableWriterTest, renameAndReorderColumns) {
assertQueryWithWriterConfigs(plan, filePaths, "SELECT count(*) FROM tmp");

if (partitionedBy_.size() > 0) {
-auto newOutputType = getScanOutput(partitionedBy_, tableSchema_);
+auto newOutputType = getNonPartitionsColumns(partitionedBy_, tableSchema_);
HiveConnectorTestBase::assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
makeHiveConnectorSplits(outputDirectory),
@@ -1619,7 +1619,7 @@ TEST_P(AllTableWriterTest, directReadWrite) {
// compare that against a duckDB query that runs the whole query.

if (partitionedBy_.size() > 0) {
-auto newOutputType = getScanOutput(partitionedBy_, tableSchema_);
+auto newOutputType = getNonPartitionsColumns(partitionedBy_, tableSchema_);
assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
makeHiveConnectorSplits(outputDirectory),
@@ -1660,7 +1660,7 @@ TEST_P(AllTableWriterTest, constantVectors) {
assertQuery(op, fmt::format("SELECT {}", size));

if (partitionedBy_.size() > 0) {
-auto newOutputType = getScanOutput(partitionedBy_, tableSchema_);
+auto newOutputType = getNonPartitionsColumns(partitionedBy_, tableSchema_);
assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
makeHiveConnectorSplits(outputDirectory),
@@ -1719,7 +1719,8 @@ TEST_P(AllTableWriterTest, commitStrategies) {
assertQuery(plan, "SELECT count(*) FROM tmp");

if (partitionedBy_.size() > 0) {
-auto newOutputType = getScanOutput(partitionedBy_, tableSchema_);
+auto newOutputType =
+    getNonPartitionsColumns(partitionedBy_, tableSchema_);
assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
makeHiveConnectorSplits(outputDirectory),
@@ -1755,7 +1756,8 @@ TEST_P(AllTableWriterTest, commitStrategies) {
assertQuery(plan, "SELECT count(*) FROM tmp");

if (partitionedBy_.size() > 0) {
-auto newOutputType = getScanOutput(partitionedBy_, tableSchema_);
+auto newOutputType =
+    getNonPartitionsColumns(partitionedBy_, tableSchema_);
assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
makeHiveConnectorSplits(outputDirectory),
@@ -1950,7 +1952,7 @@ TEST_P(PartitionedTableWriterTest, multiplePartitions) {
// Verify distribution of records in partition directories.
auto iterPartitionDirectory = actualPartitionDirectories.begin();
auto iterPartitionName = partitionNames.begin();
-auto newOutputType = getScanOutput(partitionKeys, rowType);
+auto newOutputType = getNonPartitionsColumns(partitionKeys, rowType);
while (iterPartitionDirectory != actualPartitionDirectories.end()) {
assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
@@ -2025,7 +2027,7 @@ TEST_P(PartitionedTableWriterTest, singlePartition) {
fs::path(outputDirectory->path) / "p0=365");

// Verify all data is written to the single partition directory.
-auto newOutputType = getScanOutput(partitionKeys, rowType);
+auto newOutputType = getNonPartitionsColumns(partitionKeys, rowType);
assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
makeHiveConnectorSplits(outputDirectory),
@@ -2072,7 +2074,7 @@ TEST_P(PartitionedWithoutBucketTableWriterTest, fromSinglePartitionToMultiple) {

assertQueryWithWriterConfigs(plan, "SELECT count(*) FROM tmp");

-auto newOutputType = getScanOutput(partitionKeys, rowType);
+auto newOutputType = getNonPartitionsColumns(partitionKeys, rowType);
assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
makeHiveConnectorSplits(outputDirectory),
@@ -2416,7 +2418,8 @@ TEST_P(BucketedTableOnlyWriteTest, bucketCountLimit) {
assertQueryWithWriterConfigs(plan, "SELECT count(*) FROM tmp");

if (partitionedBy_.size() > 0) {
-auto newOutputType = getScanOutput(partitionedBy_, tableSchema_);
+auto newOutputType =
+    getNonPartitionsColumns(partitionedBy_, tableSchema_);
assertQuery(
PlanBuilder().tableScan(newOutputType).planNode(),
makeHiveConnectorSplits(outputDirectory),
@@ -3206,7 +3209,7 @@ TEST_P(BucketSortOnlyTableWriterTest, sortWriterSpill) {
auto task =
assertQueryWithWriterConfigs(op, fmt::format("SELECT {}", 5 * 500), true);
if (partitionedBy_.size() > 0) {
-rowType_ = getScanOutput(partitionedBy_, rowType_);
+rowType_ = getNonPartitionsColumns(partitionedBy_, rowType_);
verifyTableWriterOutput(outputDirectory->path, rowType_);
} else {
verifyTableWriterOutput(outputDirectory->path, rowType_);
