Skip to content

Commit

Permalink
Refactor AggregateCompanionAdapter to remove duplicate code
Browse files Browse the repository at this point in the history
  • Loading branch information
wypb committed Jun 4, 2024
1 parent 61c3379 commit 608beaf
Showing 1 changed file with 85 additions and 129 deletions.
214 changes: 85 additions & 129 deletions velox/exec/AggregateCompanionAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,81 +325,12 @@ bool CompanionFunctionsRegistrar::registerMergeFunction(
.mainFunction;
}

bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
bool registerAggregateFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const std::string& mergeExtractFunctionName,
const std::vector<std::shared_ptr<AggregateFunctionSignature>>&
mergeExtractSignatures,
bool overwrite) {
auto groupedSignatures =
CompanionSignatures::groupSignaturesByReturnType(signatures);
bool registered = false;
for (const auto& [type, signatureGroup] : groupedSignatures) {
auto mergeExtractSignatures =
CompanionSignatures::mergeExtractFunctionSignatures(signatureGroup);
if (mergeExtractSignatures.empty()) {
continue;
}

auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionNameWithSuffix(name, type);

registered |=
exec::registerAggregateFunction(
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
[name, mergeExtractFunctionName](
core::AggregationNode::Step /*step*/,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType,
const core::QueryConfig& config) -> std::unique_ptr<Aggregate> {
const auto& [originalResultType, _] =
resolveAggregateFunction(mergeExtractFunctionName, argTypes);
if (!originalResultType) {
// TODO: limitation -- result type must be resolveable given
// intermediate type of the original UDAF.
VELOX_UNREACHABLE(
"Signatures whose result types are not resolvable given intermediate types should have been excluded.");
}

if (auto func = getAggregateFunctionEntry(name)) {
auto fn = func->factory(
core::AggregationNode::Step::kFinal,
argTypes,
originalResultType,
config);
VELOX_CHECK_NOT_NULL(fn);
return std::make_unique<
AggregateCompanionAdapter::MergeExtractFunction>(
std::move(fn), resultType);
}
VELOX_FAIL(
"Original aggregation function {} not found: {}",
name,
mergeExtractFunctionName);
},
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
}
return registered;
}

bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
bool overwrite) {
if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures(
signatures)) {
return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
}

auto mergeExtractSignatures =
CompanionSignatures::mergeExtractFunctionSignatures(signatures);
if (mergeExtractSignatures.empty()) {
return false;
}

auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionName(name);
return exec::registerAggregateFunction(
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
Expand Down Expand Up @@ -439,84 +370,63 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
.mainFunction;
}

bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix(
const std::string& originalName,
bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
bool overwrite) {
auto groupedSignatures =
CompanionSignatures::groupSignaturesByReturnType(signatures);
bool registered = false;
for (const auto& [type, signatureGroup] : groupedSignatures) {
auto extractSignatures =
CompanionSignatures::extractFunctionSignatures(signatureGroup);
if (extractSignatures.empty()) {
auto mergeExtractSignatures =
CompanionSignatures::mergeExtractFunctionSignatures(signatureGroup);
if (mergeExtractSignatures.empty()) {
continue;
}

auto factory = [originalName](
const std::string& name,
const std::vector<VectorFunctionArg>& inputArgs,
const core::QueryConfig& config)
-> std::shared_ptr<VectorFunction> {
std::vector<TypePtr> argTypes{inputArgs.size()};
std::transform(
inputArgs.begin(),
inputArgs.end(),
argTypes.begin(),
[](auto inputArg) { return inputArg.type; });

auto resultType = resolveVectorFunction(name, argTypes);
if (!resultType) {
// TODO: limitation -- result type must be resolveable given
// intermediate type of the original UDAF.
VELOX_UNREACHABLE(
"Signatures whose result types are not resolvable given intermediate types should have been excluded.");
}

if (auto func = getAggregateFunctionEntry(originalName)) {
auto fn = func->factory(
core::AggregationNode::Step::kFinal, argTypes, resultType, config);
VELOX_CHECK_NOT_NULL(fn);
return std::make_shared<AggregateCompanionAdapter::ExtractFunction>(
std::move(fn));
}
VELOX_FAIL(
"Original aggregation function {} not found: {}", originalName, name);
};
auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionNameWithSuffix(name, type);

registered |= exec::registerStatefulVectorFunction(
CompanionSignatures::extractFunctionNameWithSuffix(originalName, type),
extractSignatures,
factory,
exec::VectorFunctionMetadataBuilder()
.defaultNullBehavior(false)
.build(),
registered |= registerAggregateFunction(
name,
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
overwrite);
}
return registered;
}

bool CompanionFunctionsRegistrar::registerExtractFunction(
const std::string& originalName,
bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
bool overwrite) {
if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures(
signatures)) {
return registerExtractFunctionWithSuffix(
originalName, signatures, overwrite);
return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
}

auto extractSignatures =
CompanionSignatures::extractFunctionSignatures(signatures);
if (extractSignatures.empty()) {
auto mergeExtractSignatures =
CompanionSignatures::mergeExtractFunctionSignatures(signatures);
if (mergeExtractSignatures.empty()) {
return false;
}

auto factory =
[originalName](
const std::string& name,
const std::vector<VectorFunctionArg>& inputArgs,
const core::QueryConfig& config) -> std::shared_ptr<VectorFunction> {
auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionName(name);
return registerAggregateFunction(
name,
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
overwrite);
}

VectorFunctionFactory getVectorFunctionFactory(
const std::string& originalName) {
return [originalName](
const std::string& name,
const std::vector<VectorFunctionArg>& inputArgs,
const core::QueryConfig& config)
-> std::shared_ptr<VectorFunction> {
std::vector<TypePtr> argTypes{inputArgs.size()};
std::transform(
inputArgs.begin(),
Expand All @@ -542,10 +452,56 @@ bool CompanionFunctionsRegistrar::registerExtractFunction(
VELOX_FAIL(
"Original aggregation function {} not found: {}", originalName, name);
};
}

bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix(
const std::string& originalName,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
bool overwrite) {
auto groupedSignatures =
CompanionSignatures::groupSignaturesByReturnType(signatures);
bool registered = false;
for (const auto& [type, signatureGroup] : groupedSignatures) {
auto extractSignatures =
CompanionSignatures::extractFunctionSignatures(signatureGroup);
if (extractSignatures.empty()) {
continue;
}

auto factory = getVectorFunctionFactory(originalName);
registered |= exec::registerStatefulVectorFunction(
CompanionSignatures::extractFunctionNameWithSuffix(originalName, type),
std::move(extractSignatures),
std::move(factory),
exec::VectorFunctionMetadataBuilder()
.defaultNullBehavior(false)
.build(),
overwrite);
}
return registered;
}

bool CompanionFunctionsRegistrar::registerExtractFunction(
const std::string& originalName,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
bool overwrite) {
if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures(
signatures)) {
return registerExtractFunctionWithSuffix(
originalName, signatures, overwrite);
}

auto extractSignatures =
CompanionSignatures::extractFunctionSignatures(signatures);
if (extractSignatures.empty()) {
return false;
}

auto factory = getVectorFunctionFactory(originalName);
return exec::registerStatefulVectorFunction(
CompanionSignatures::extractFunctionName(originalName),
extractSignatures,
factory,
std::move(extractSignatures),
std::move(factory),
exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(),
overwrite);
}
Expand Down

0 comments on commit 608beaf

Please sign in to comment.