diff --git a/velox/exec/AggregateCompanionAdapter.cpp b/velox/exec/AggregateCompanionAdapter.cpp index c1a087f24d7c..8b79a5e05d06 100644 --- a/velox/exec/AggregateCompanionAdapter.cpp +++ b/velox/exec/AggregateCompanionAdapter.cpp @@ -325,81 +325,12 @@ bool CompanionFunctionsRegistrar::registerMergeFunction( .mainFunction; } -bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix( - const std::string& name, - const std::vector& signatures, - bool overwrite) { - auto groupedSignatures = - CompanionSignatures::groupSignaturesByReturnType(signatures); - bool registered = false; - for (const auto& [type, signatureGroup] : groupedSignatures) { - auto mergeExtractSignatures = - CompanionSignatures::mergeExtractFunctionSignatures(signatureGroup); - if (mergeExtractSignatures.empty()) { - continue; - } - - auto mergeExtractFunctionName = - CompanionSignatures::mergeExtractFunctionNameWithSuffix(name, type); - - registered |= - exec::registerAggregateFunction( - mergeExtractFunctionName, - std::move(mergeExtractSignatures), - [name, mergeExtractFunctionName]( - core::AggregationNode::Step /*step*/, - const std::vector& argTypes, - const TypePtr& resultType, - const core::QueryConfig& config) -> std::unique_ptr { - const auto& [originalResultType, _] = - resolveAggregateFunction(mergeExtractFunctionName, argTypes); - if (!originalResultType) { - // TODO: limitation -- result type must be resolveable given - // intermediate type of the original UDAF. - VELOX_UNREACHABLE( - "Signatures whose result types are not resolvable given intermediate types should have been excluded."); - } - - if (auto func = getAggregateFunctionEntry(name)) { - auto fn = func->factory( - core::AggregationNode::Step::kFinal, - argTypes, - originalResultType, - config); - VELOX_CHECK_NOT_NULL(fn); - return std::make_unique< - AggregateCompanionAdapter::MergeExtractFunction>( - std::move(fn), resultType); - } - VELOX_FAIL( - "Original aggregation function {} not found: {}", - name, - mergeExtractFunctionName); - }, - /*registerCompanionFunctions*/ false, - overwrite) - .mainFunction; - } - return registered; -} - -bool CompanionFunctionsRegistrar::registerMergeExtractFunction( +bool registerAggregateFunction( const std::string& name, - const std::vector& signatures, + const std::string& mergeExtractFunctionName, + const std::vector>& + mergeExtractSignatures, bool overwrite) { - if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures( - signatures)) { - return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite); - } - - auto mergeExtractSignatures = - CompanionSignatures::mergeExtractFunctionSignatures(signatures); - if (mergeExtractSignatures.empty()) { - return false; - } - - auto mergeExtractFunctionName = - CompanionSignatures::mergeExtractFunctionName(name); return exec::registerAggregateFunction( mergeExtractFunctionName, std::move(mergeExtractSignatures), @@ -439,84 +370,63 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction( .mainFunction; } -bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix( - const std::string& originalName, +bool CompanionFunctionsRegistrar::registerMergeExtractFunctionWithSuffix( + const std::string& name, const std::vector& signatures, bool overwrite) { auto groupedSignatures = CompanionSignatures::groupSignaturesByReturnType(signatures); bool registered = false; for (const auto& [type, signatureGroup] : groupedSignatures) { - auto extractSignatures = - CompanionSignatures::extractFunctionSignatures(signatureGroup); - if (extractSignatures.empty()) { + auto mergeExtractSignatures = + CompanionSignatures::mergeExtractFunctionSignatures(signatureGroup); + if (mergeExtractSignatures.empty()) { continue; } - auto factory = [originalName]( - const std::string& name, - const std::vector& inputArgs, - const core::QueryConfig& config) - -> std::shared_ptr { - std::vector argTypes{inputArgs.size()}; - std::transform( - inputArgs.begin(), - inputArgs.end(), - argTypes.begin(), - [](auto inputArg) { return inputArg.type; }); - - auto resultType = resolveVectorFunction(name, argTypes); - if (!resultType) { - // TODO: limitation -- result type must be resolveable given - // intermediate type of the original UDAF. - VELOX_UNREACHABLE( - "Signatures whose result types are not resolvable given intermediate types should have been excluded."); - } - - if (auto func = getAggregateFunctionEntry(originalName)) { - auto fn = func->factory( - core::AggregationNode::Step::kFinal, argTypes, resultType, config); - VELOX_CHECK_NOT_NULL(fn); - return std::make_shared( - std::move(fn)); - } - VELOX_FAIL( - "Original aggregation function {} not found: {}", originalName, name); - }; + auto mergeExtractFunctionName = + CompanionSignatures::mergeExtractFunctionNameWithSuffix(name, type); - registered |= exec::registerStatefulVectorFunction( - CompanionSignatures::extractFunctionNameWithSuffix(originalName, type), - extractSignatures, - factory, - exec::VectorFunctionMetadataBuilder() - .defaultNullBehavior(false) - .build(), + registered |= registerAggregateFunction( + name, + mergeExtractFunctionName, + std::move(mergeExtractSignatures), overwrite); } return registered; } -bool CompanionFunctionsRegistrar::registerExtractFunction( - const std::string& originalName, +bool CompanionFunctionsRegistrar::registerMergeExtractFunction( + const std::string& name, const std::vector& signatures, bool overwrite) { if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures( signatures)) { - return registerExtractFunctionWithSuffix( - originalName, signatures, overwrite); + return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite); } - auto extractSignatures = - CompanionSignatures::extractFunctionSignatures(signatures); - if (extractSignatures.empty()) { + auto mergeExtractSignatures = + CompanionSignatures::mergeExtractFunctionSignatures(signatures); + if (mergeExtractSignatures.empty()) { return false; } - auto factory = - [originalName]( - const std::string& name, - const std::vector& inputArgs, - const core::QueryConfig& config) -> std::shared_ptr { + auto mergeExtractFunctionName = + CompanionSignatures::mergeExtractFunctionName(name); + return registerAggregateFunction( + name, + mergeExtractFunctionName, + std::move(mergeExtractSignatures), + overwrite); +} + +VectorFunctionFactory getVectorFunctionFactory( + const std::string& originalName) { + return [originalName]( + const std::string& name, + const std::vector& inputArgs, + const core::QueryConfig& config) + -> std::shared_ptr { std::vector argTypes{inputArgs.size()}; std::transform( inputArgs.begin(), @@ -542,6 +452,52 @@ bool CompanionFunctionsRegistrar::registerExtractFunction( VELOX_FAIL( "Original aggregation function {} not found: {}", originalName, name); }; +} + +bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix( + const std::string& originalName, + const std::vector& signatures, + bool overwrite) { + auto groupedSignatures = + CompanionSignatures::groupSignaturesByReturnType(signatures); + bool registered = false; + for (const auto& [type, signatureGroup] : groupedSignatures) { + auto extractSignatures = + CompanionSignatures::extractFunctionSignatures(signatureGroup); + if (extractSignatures.empty()) { + continue; + } + + auto factory = getVectorFunctionFactory(originalName); + registered |= exec::registerStatefulVectorFunction( + CompanionSignatures::extractFunctionNameWithSuffix(originalName, type), + extractSignatures, + factory, + exec::VectorFunctionMetadataBuilder() + .defaultNullBehavior(false) + .build(), + overwrite); + } + return registered; +} + +bool CompanionFunctionsRegistrar::registerExtractFunction( + const std::string& originalName, + const std::vector& signatures, + bool overwrite) { + if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures( + signatures)) { + return registerExtractFunctionWithSuffix( + originalName, signatures, overwrite); + } + + auto extractSignatures = + CompanionSignatures::extractFunctionSignatures(signatures); + if (extractSignatures.empty()) { + return false; + } + + auto factory = getVectorFunctionFactory(originalName); return exec::registerStatefulVectorFunction( CompanionSignatures::extractFunctionName(originalName), extractSignatures,