Skip to content

Commit

Permalink
apacheGH-40183: [C++] Fix cast function bind failed after add an alia…
Browse files Browse the repository at this point in the history
…s name through AddAlias (apache#40200)

### Rationale for this change

Cast function bind failed after add a alias name through AddAlias.

### What changes are included in this PR?

Add a const `cast_function` which registered in `AddFunction` for check cast alias in `arrow::compute::GetFunction`.

### Are these changes tested?
Yes

### Are there any user-facing changes?
Yes, cast's alias name can also execute with expression system.

* GitHub Issue: apache#40183

Authored-by: hugo.zhang <hugo.zhang@openpie.com>
Signed-off-by: Benjamin Kietzman <bengilgit@gmail.com>
  • Loading branch information
ZhangHuiGui authored and mapleFU committed Mar 7, 2024
1 parent a12207d commit 3558087
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 2 deletions.
7 changes: 5 additions & 2 deletions cpp/src/arrow/compute/expression_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,12 @@ struct FlattenedAssociativeChain {

inline Result<std::shared_ptr<compute::Function>> GetFunction(
const Expression::Call& call, compute::ExecContext* exec_context) {
if (call.function_name != "cast") {
return exec_context->func_registry()->GetFunction(call.function_name);
ARROW_ASSIGN_OR_RAISE(auto function,
exec_context->func_registry()->GetFunction(call.function_name));
if (function.get() != exec_context->func_registry()->cast_function()) {
return function;
}

// XXX this special case is strange; why not make "cast" a ScalarFunction?
const TypeHolder& to_type =
::arrow::internal::checked_cast<const compute::CastOptions&>(*call.options).to_type;
Expand Down
11 changes: 11 additions & 0 deletions cpp/src/arrow/compute/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,17 @@ TEST(Expression, BindCall) {
add(cast(field_ref("i32"), float32()), literal(3.5F)));
}

TEST(Expression, BindWithAliasCasts) {
auto fm = GetFunctionRegistry();
EXPECT_OK(fm->AddAlias("alias_cast", "cast"));

auto expr = call("alias_cast", {field_ref("f1")}, CastOptions::Unsafe(arrow::int32()));
EXPECT_FALSE(expr.IsBound());

auto schema = arrow::schema({field("f1", decimal128(30, 3))});
ExpectBindsTo(expr, no_change, &expr, *schema);
}

TEST(Expression, BindWithDecimalArithmeticOps) {
for (std::string arith_op : {"add", "subtract", "multiply", "divide"}) {
auto expr = call(arith_op, {field_ref("d1"), field_ref("d2")});
Expand Down
9 changes: 9 additions & 0 deletions cpp/src/arrow/compute/registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ class FunctionRegistry::FunctionRegistryImpl {
static_cast<int>(name_to_function_.size());
}

const Function* cast_function() { return cast_function_; }

private:
// must not acquire mutex
Status CanAddFunctionName(const std::string& name, bool allow_overwrite) {
Expand Down Expand Up @@ -169,6 +171,9 @@ class FunctionRegistry::FunctionRegistryImpl {
RETURN_NOT_OK(CanAddFunctionName(name, allow_overwrite));
if (add) {
name_to_function_[name] = std::move(function);
if (name == "cast") {
cast_function_ = name_to_function_[name].get();
}
}
return Status::OK();
}
Expand Down Expand Up @@ -205,6 +210,8 @@ class FunctionRegistry::FunctionRegistryImpl {
std::mutex lock_;
std::unordered_map<std::string, std::shared_ptr<Function>> name_to_function_;
std::unordered_map<std::string, const FunctionOptionsType*> name_to_options_type_;

const Function* cast_function_;
};

std::unique_ptr<FunctionRegistry> FunctionRegistry::Make() {
Expand Down Expand Up @@ -268,6 +275,8 @@ Result<const FunctionOptionsType*> FunctionRegistry::GetFunctionOptionsType(

int FunctionRegistry::num_functions() const { return impl_->num_functions(); }

const Function* FunctionRegistry::cast_function() const { return impl_->cast_function(); }

namespace internal {

static std::unique_ptr<FunctionRegistry> CreateBuiltInRegistry() {
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/compute/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ class ARROW_EXPORT FunctionRegistry {
/// \brief The number of currently registered functions.
int num_functions() const;

/// \brief The cast function object registered in AddFunction.
///
/// Helpful for get cast function as needed.
const Function* cast_function() const;

private:
FunctionRegistry();

Expand Down

0 comments on commit 3558087

Please sign in to comment.