Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AggregateCompanionAdapter to remove duplicate code #9920

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 85 additions & 129 deletions velox/exec/AggregateCompanionAdapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,81 +325,12 @@ bool CompanionFunctionsRegistrar::registerMergeFunction(
.mainFunction;
}

bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
bool registerAggregateFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
const std::string& mergeExtractFunctionName,
const std::vector<std::shared_ptr<AggregateFunctionSignature>>&
mergeExtractSignatures,
bool overwrite) {
auto groupedSignatures =
CompanionSignatures::groupSignaturesByReturnType(signatures);
bool registered = false;
for (const auto& [type, signatureGroup] : groupedSignatures) {
auto mergeExtractSignatures =
CompanionSignatures::mergeExtractFunctionSignatures(signatureGroup);
if (mergeExtractSignatures.empty()) {
continue;
}

auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionNameWithSuffix(name, type);

registered |=
exec::registerAggregateFunction(
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
[name, mergeExtractFunctionName](
core::AggregationNode::Step /*step*/,
const std::vector<TypePtr>& argTypes,
const TypePtr& resultType,
const core::QueryConfig& config) -> std::unique_ptr<Aggregate> {
const auto& [originalResultType, _] =
resolveAggregateFunction(mergeExtractFunctionName, argTypes);
if (!originalResultType) {
// TODO: limitation -- result type must be resolveable given
// intermediate type of the original UDAF.
VELOX_UNREACHABLE(
"Signatures whose result types are not resolvable given intermediate types should have been excluded.");
}

if (auto func = getAggregateFunctionEntry(name)) {
auto fn = func->factory(
core::AggregationNode::Step::kFinal,
argTypes,
originalResultType,
config);
VELOX_CHECK_NOT_NULL(fn);
return std::make_unique<
AggregateCompanionAdapter::MergeExtractFunction>(
std::move(fn), resultType);
}
VELOX_FAIL(
"Original aggregation function {} not found: {}",
name,
mergeExtractFunctionName);
},
/*registerCompanionFunctions*/ false,
overwrite)
.mainFunction;
}
return registered;
}

bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
bool overwrite) {
if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures(
signatures)) {
return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
}

auto mergeExtractSignatures =
CompanionSignatures::mergeExtractFunctionSignatures(signatures);
if (mergeExtractSignatures.empty()) {
return false;
}

auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionName(name);
return exec::registerAggregateFunction(
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
Expand Down Expand Up @@ -439,84 +370,63 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
.mainFunction;
}

bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix(
const std::string& originalName,
bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
bool overwrite) {
auto groupedSignatures =
CompanionSignatures::groupSignaturesByReturnType(signatures);
bool registered = false;
for (const auto& [type, signatureGroup] : groupedSignatures) {
auto extractSignatures =
CompanionSignatures::extractFunctionSignatures(signatureGroup);
if (extractSignatures.empty()) {
auto mergeExtractSignatures =
CompanionSignatures::mergeExtractFunctionSignatures(signatureGroup);
if (mergeExtractSignatures.empty()) {
continue;
}

auto factory = [originalName](
const std::string& name,
const std::vector<VectorFunctionArg>& inputArgs,
const core::QueryConfig& config)
-> std::shared_ptr<VectorFunction> {
std::vector<TypePtr> argTypes{inputArgs.size()};
std::transform(
inputArgs.begin(),
inputArgs.end(),
argTypes.begin(),
[](auto inputArg) { return inputArg.type; });

auto resultType = resolveVectorFunction(name, argTypes);
if (!resultType) {
// TODO: limitation -- result type must be resolveable given
// intermediate type of the original UDAF.
VELOX_UNREACHABLE(
"Signatures whose result types are not resolvable given intermediate types should have been excluded.");
}

if (auto func = getAggregateFunctionEntry(originalName)) {
auto fn = func->factory(
core::AggregationNode::Step::kFinal, argTypes, resultType, config);
VELOX_CHECK_NOT_NULL(fn);
return std::make_shared<AggregateCompanionAdapter::ExtractFunction>(
std::move(fn));
}
VELOX_FAIL(
"Original aggregation function {} not found: {}", originalName, name);
};
auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionNameWithSuffix(name, type);

registered |= exec::registerStatefulVectorFunction(
CompanionSignatures::extractFunctionNameWithSuffix(originalName, type),
extractSignatures,
factory,
exec::VectorFunctionMetadataBuilder()
.defaultNullBehavior(false)
.build(),
registered |= registerAggregateFunction(
name,
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
overwrite);
}
return registered;
}

bool CompanionFunctionsRegistrar::registerExtractFunction(
const std::string& originalName,
bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
const std::string& name,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
bool overwrite) {
if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures(
signatures)) {
return registerExtractFunctionWithSuffix(
originalName, signatures, overwrite);
return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
}

auto extractSignatures =
CompanionSignatures::extractFunctionSignatures(signatures);
if (extractSignatures.empty()) {
auto mergeExtractSignatures =
CompanionSignatures::mergeExtractFunctionSignatures(signatures);
if (mergeExtractSignatures.empty()) {
return false;
}

auto factory =
[originalName](
const std::string& name,
const std::vector<VectorFunctionArg>& inputArgs,
const core::QueryConfig& config) -> std::shared_ptr<VectorFunction> {
auto mergeExtractFunctionName =
CompanionSignatures::mergeExtractFunctionName(name);
return registerAggregateFunction(
name,
mergeExtractFunctionName,
std::move(mergeExtractSignatures),
overwrite);
}

VectorFunctionFactory getVectorFunctionFactory(
const std::string& originalName) {
return [originalName](
const std::string& name,
const std::vector<VectorFunctionArg>& inputArgs,
const core::QueryConfig& config)
-> std::shared_ptr<VectorFunction> {
std::vector<TypePtr> argTypes{inputArgs.size()};
std::transform(
inputArgs.begin(),
Expand All @@ -542,10 +452,56 @@ bool CompanionFunctionsRegistrar::registerExtractFunction(
VELOX_FAIL(
"Original aggregation function {} not found: {}", originalName, name);
};
}

bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix(
const std::string& originalName,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
bool overwrite) {
auto groupedSignatures =
CompanionSignatures::groupSignaturesByReturnType(signatures);
bool registered = false;
for (const auto& [type, signatureGroup] : groupedSignatures) {
auto extractSignatures =
CompanionSignatures::extractFunctionSignatures(signatureGroup);
if (extractSignatures.empty()) {
continue;
}

auto factory = getVectorFunctionFactory(originalName);
registered |= exec::registerStatefulVectorFunction(
CompanionSignatures::extractFunctionNameWithSuffix(originalName, type),
std::move(extractSignatures),
std::move(factory),
exec::VectorFunctionMetadataBuilder()
.defaultNullBehavior(false)
.build(),
overwrite);
}
return registered;
}

bool CompanionFunctionsRegistrar::registerExtractFunction(
const std::string& originalName,
const std::vector<AggregateFunctionSignaturePtr>& signatures,
bool overwrite) {
if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures(
signatures)) {
return registerExtractFunctionWithSuffix(
originalName, signatures, overwrite);
}

auto extractSignatures =
CompanionSignatures::extractFunctionSignatures(signatures);
if (extractSignatures.empty()) {
return false;
}

auto factory = getVectorFunctionFactory(originalName);
return exec::registerStatefulVectorFunction(
CompanionSignatures::extractFunctionName(originalName),
extractSignatures,
factory,
std::move(extractSignatures),
std::move(factory),
exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build(),
overwrite);
}
Expand Down
Loading