Skip to content

Commit

Permalink
[oap-native-sql] Support extra castVARCHAR and castBYTE in gandiva (#84)
Browse files Browse the repository at this point in the history
* [oap-native-sql] Add castToString support to various types

int[8-64], float32, float64, date32

Signed-off-by: Chendi Xue <chendi.xue@intel.com>

* [oap-native-sql] Using %g to format, added DIGS support if needed

Signed-off-by: Chendi Xue <chendi.xue@intel.com>

* [oap-native-sql] Add castByte support

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
  • Loading branch information
xuechendi authored and zhouyuan committed Feb 3, 2021
1 parent ca2d426 commit 73005fc
Show file tree
Hide file tree
Showing 10 changed files with 302 additions and 0 deletions.
3 changes: 3 additions & 0 deletions cpp/src/gandiva/function_registry_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
UNARY_SAFE_NULL_IF_NULL(castINT, {}, date32, int32),
UNARY_SAFE_NULL_IF_NULL(castINT, {}, float32, int32),
UNARY_SAFE_NULL_IF_NULL(castINT, {}, float64, int32),
UNARY_SAFE_NULL_IF_NULL(castBYTE, {}, int16, int8),
UNARY_SAFE_NULL_IF_NULL(castBYTE, {}, int32, int8),
UNARY_SAFE_NULL_IF_NULL(castBYTE, {}, int64, int8),
UNARY_SAFE_NULL_IF_NULL(castBIGINT, {}, decimal128, int64),

// cast to float32
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/gandiva/function_registry_datetime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ std::vector<NativeFunction> GetDateTimeFunctionRegistry() {
kResultNullIfNull, "castTIMESTAMP_utf8",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),

NativeFunction("castVARCHAR", {}, DataTypeVector{date32(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_date32_int64",
NativeFunction::kNeedsContext),

NativeFunction("castVARCHAR", {}, DataTypeVector{timestamp(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_timestamp_int64",
NativeFunction::kNeedsContext),
Expand Down
23 changes: 23 additions & 0 deletions cpp/src/gandiva/function_registry_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,29 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
NativeFunction("castFLOAT8", {}, DataTypeVector{utf8()}, float64(),
kResultNullIfNull, "gdv_fn_castFLOAT8_utf8",
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
NativeFunction("castVARCHAR", {}, DataTypeVector{int8(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_int8_int64",
NativeFunction::kNeedsContext),

NativeFunction("castVARCHAR", {}, DataTypeVector{int16(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_int16_int64",
NativeFunction::kNeedsContext),

NativeFunction("castVARCHAR", {}, DataTypeVector{int32(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_int32_int64",
NativeFunction::kNeedsContext),

NativeFunction("castVARCHAR", {}, DataTypeVector{int64(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_int64_int64",
NativeFunction::kNeedsContext),

NativeFunction("castVARCHAR", {}, DataTypeVector{float32(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_float32_int64",
NativeFunction::kNeedsContext),

NativeFunction("castVARCHAR", {}, DataTypeVector{float64(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_float64_int64",
NativeFunction::kNeedsContext),

NativeFunction("castVARCHAR", {}, DataTypeVector{utf8(), int64()}, utf8(),
kResultNullIfNull, "castVARCHAR_utf8_int64",
Expand Down
46 changes: 46 additions & 0 deletions cpp/src/gandiva/precompiled/arithmetic_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
extern "C" {

#include <math.h>
#include <stdio.h>
#include <string.h>
#include <cfloat>
#include "./types.h"

// Expand inner macro for all numeric types.
Expand Down Expand Up @@ -119,6 +122,9 @@ CAST_UNARY(castINT, int64, int32)
CAST_UNARY(castINT, date32, int32)
CAST_UNARY(castINT, float32, int32)
CAST_UNARY(castINT, float64, int32)
CAST_UNARY(castBYTE, int16, int8)
CAST_UNARY(castBYTE, int32, int8)
CAST_UNARY(castBYTE, int64, int8)
CAST_UNARY(castFLOAT4, int32, float32)
CAST_UNARY(castFLOAT4, int64, float32)
CAST_UNARY(castFLOAT8, int32, float64)
Expand All @@ -127,6 +133,46 @@ CAST_UNARY(castFLOAT8, float32, float64)
CAST_UNARY(castFLOAT4, float64, float32)

#undef CAST_UNARY
// Generates `const char* NAME_<IN_TYPE>_int64(context, in, length, out_len)`, which
// formats the numeric input `in` with the printf conversion FMT and returns at most
// `length` characters of that text in memory obtained from
// gdv_fn_context_arena_malloc. The returned characters are not NUL-terminated;
// their count is stored in *out_len.
//
// PRINT_TYPE is the type the value is converted to before it is handed to
// snprintf, so the variadic argument always matches FMT no matter how
// gdv_<IN_TYPE> is defined on the target platform.
//
// On a negative `length`, a formatting failure or an allocation failure, an error
// message is set on the context, *out_len is 0 and "" is returned. A `length` of
// zero yields an empty result without reporting an error.
#define CAST_UNARY_UTF8(NAME, IN_TYPE, FMT, PRINT_TYPE)                              \
  FORCE_INLINE                                                                       \
  const char* NAME##_##IN_TYPE##_int64(gdv_int64 context, gdv_##IN_TYPE in,          \
                                       gdv_int64 length, gdv_int32* out_len) {       \
    /* Fixed scratch space instead of a buffer sized by the caller: an int64 */      \
    /* needs at most 20 characters and "%g" output is shorter still.         */      \
    char char_buffer[64];                                                            \
    const gdv_int64 max_formatted_len = sizeof(char_buffer) - 1;                     \
                                                                                     \
    *out_len = 0;                                                                    \
    if (length < 0) {                                                                \
      gdv_fn_context_set_error_msg(context,                                          \
                                   "Length of output string cannot be negative");    \
      return "";                                                                     \
    }                                                                                \
                                                                                     \
    const int res = snprintf(char_buffer, sizeof(char_buffer), FMT,                  \
                             static_cast<PRINT_TYPE>(in));                           \
    if (res < 0) {                                                                   \
      gdv_fn_context_set_error_msg(context, "Could not format the " #IN_TYPE);       \
      return "";                                                                     \
    }                                                                                \
                                                                                     \
    /* snprintf reports the untruncated length; clamp it to what was stored, */      \
    /* then to the number of characters the caller asked for.                */      \
    gdv_int64 copy_len = res < max_formatted_len ? res : max_formatted_len;          \
    if (length < copy_len) {                                                         \
      copy_len = length;                                                             \
    }                                                                                \
    if (copy_len == 0) {                                                             \
      return "";                                                                     \
    }                                                                                \
                                                                                     \
    char* ret = reinterpret_cast<char*>(                                             \
        gdv_fn_context_arena_malloc(context, static_cast<gdv_int32>(copy_len)));     \
    if (ret == nullptr) {                                                            \
      gdv_fn_context_set_error_msg(context,                                          \
                                   "Could not allocate memory for output string");   \
      return "";                                                                     \
    }                                                                                \
                                                                                     \
    memcpy(ret, char_buffer, copy_len);                                              \
    *out_len = static_cast<gdv_int32>(copy_len);                                     \
    return ret;                                                                      \
  }

// Integers narrower than int are promoted to int by the variadic call anyway;
// int64 is printed as long long so the conversion is correct on every platform.
CAST_UNARY_UTF8(castVARCHAR, int8, "%d", int)
CAST_UNARY_UTF8(castVARCHAR, int16, "%d", int)
CAST_UNARY_UTF8(castVARCHAR, int32, "%d", int)
CAST_UNARY_UTF8(castVARCHAR, int64, "%lld", long long)  // NOLINT(runtime/int)
// "%g" prints up to 6 significant digits and switches to exponent notation for
// very large or very small magnitudes.
CAST_UNARY_UTF8(castVARCHAR, float32, "%g", double)
CAST_UNARY_UTF8(castVARCHAR, float64, "%g", double)

#undef CAST_UNARY_UTF8

// simple nullable functions, result value = fn(input validity)
#define VALIDITY_OP(NAME, TYPE, OP) \
Expand Down
14 changes: 14 additions & 0 deletions cpp/src/gandiva/precompiled/arithmetic_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,18 @@ TEST(TestArithmeticOps, TestBitwiseOps) {
EXPECT_EQ(bitwise_not_int64(0x0000000000000000), 0xFFFFFFFFFFFFFFFF);
}

// Checks the numeric castVARCHAR precompiled functions: the second argument is the
// value to format, the third the requested output length, and the text length is
// reported through out_len.
TEST(TestArithmeticOps, TestCastVarhcar) {
  gandiva::ExecutionContext ctx;
  // The precompiled functions take the context as a signed 64-bit handle.
  gdv_int64 ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
  gdv_int32 out_len = 0;

  const char* out_str = castVARCHAR_int32_int64(ctx_ptr, 88, 11L, &out_len);
  EXPECT_EQ(std::string(out_str, out_len), "88");
  EXPECT_FALSE(ctx.has_error());

  // Floating point values are formatted with "%g", i.e. 6 significant digits.
  out_str = castVARCHAR_float64_int64(ctx_ptr, 8.712128, 21L, &out_len);
  EXPECT_EQ(std::string(out_str, out_len), "8.71213");
  EXPECT_FALSE(ctx.has_error());
}

} // namespace gandiva
43 changes: 43 additions & 0 deletions cpp/src/gandiva/precompiled/time.cc
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,49 @@ gdv_date64 castDATE_utf8(int64_t context, const char* input, gdv_int32 length) {
.count();
}

// Formats a date32 value as "yyyy-MM-dd" and returns at most `length` characters of
// that text in memory obtained from gdv_fn_context_arena_malloc. The returned
// characters are not NUL-terminated; their count is stored in *out_len.
//
// On a negative `length`, a formatting failure or an allocation failure, an error
// message is set on the context, *out_len is 0 and "" is returned. A `length` of
// zero yields an empty result without reporting an error.
const char* castVARCHAR_date32_int64(gdv_int64 context, gdv_date32 in_day,
                                     gdv_int64 length, gdv_int32* out_len) {
  *out_len = 0;
  if (length < 0) {
    gdv_fn_context_set_error_msg(context, "Length of output string cannot be negative");
    return "";
  }

  // Convert to the representation the extract*_timestamp helpers operate on.
  gdv_timestamp in = castDATE_date32(in_day);
  gdv_int64 year = extractYear_timestamp(in);
  gdv_int64 month = extractMonth_timestamp(in);
  gdv_int64 day = extractDay_timestamp(in);

  // "yyyy-MM-dd" is 10 characters; one spare character leaves room for a sign or
  // a fifth year digit.
  static const int kDateStringLen = 11;
  const int char_buffer_length = kDateStringLen + 1;  // snprintf adds \0
  char char_buffer[char_buffer_length];

  // yyyy-MM-dd
  int res = snprintf(char_buffer, char_buffer_length,
                     "%04" PRId64 "-%02" PRId64 "-%02" PRId64, year, month, day);
  if (res < 0) {
    gdv_fn_context_set_error_msg(context, "Could not format the date");
    return "";
  }

  // snprintf reports the untruncated length. Clamp it to what actually fits in the
  // buffer so the terminating \0 is never counted, then to the requested length.
  gdv_int64 copy_len = res < kDateStringLen ? res : kDateStringLen;
  if (length < copy_len) {
    copy_len = length;
  }
  if (copy_len == 0) {
    return "";
  }

  char* ret = reinterpret_cast<char*>(
      gdv_fn_context_arena_malloc(context, static_cast<gdv_int32>(copy_len)));
  if (ret == nullptr) {
    gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
    return "";
  }

  memcpy(ret, char_buffer, copy_len);
  *out_len = static_cast<gdv_int32>(copy_len);
  return ret;
}

/*
* Input consists of mandatory and optional fields.
* Mandatory fields are year, month and day.
Expand Down
9 changes: 9 additions & 0 deletions cpp/src/gandiva/precompiled/time_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -663,4 +663,13 @@ TEST(TestTime, TestCastTimestampToDate) {
EXPECT_EQ(StringToTimestamp("2000-05-01 00:00:00"), out);
}

// Round-trips a date through castDATE_utf8 and castVARCHAR_date32_int64 and checks
// that the text comes back zero-padded.
TEST(TestTime, castVarcharDate) {
  ExecutionContext context;
  int64_t context_ptr = reinterpret_cast<int64_t>(&context);
  gdv_int32 out_len = 0;

  // castDATE_utf8 returns a gdv_date64, while the function under test takes a
  // gdv_date32 day count, so convert explicitly instead of narrowing implicitly.
  // NOTE(review): assumes gdv_date64 is milliseconds since the epoch - confirm.
  const int64_t kMillisPerDay = 86400000LL;
  gdv_date64 date_millis = castDATE_utf8(context_ptr, "1967-12-1", 9);
  gdv_date32 date = static_cast<gdv_date32>(date_millis / kMillisPerDay);

  const char* out = castVARCHAR_date32_int64(context_ptr, date, 10L, &out_len);
  EXPECT_EQ(std::string(out, out_len), "1967-12-01");
  EXPECT_FALSE(context.has_error());
}

} // namespace gandiva
14 changes: 14 additions & 0 deletions cpp/src/gandiva/precompiled/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,20 @@ gdv_timestamp castTIMESTAMP_int64(gdv_int64);
gdv_date64 castDATE_timestamp(gdv_timestamp);
const char* castVARCHAR_timestamp_int64(int64_t, gdv_timestamp, gdv_int64, gdv_int32*);

// castVARCHAR variants for date32 and the numeric types. The arguments are, in
// order: the execution context handle, the value to convert, the requested output
// length, and an out-parameter that receives the number of characters in the
// returned string.
const char* castVARCHAR_date32_int64(int64_t, gdv_date32, gdv_int64, gdv_int32*);

const char* castVARCHAR_int8_int64(int64_t, gdv_int8, gdv_int64, gdv_int32*);

const char* castVARCHAR_int16_int64(int64_t, gdv_int16, gdv_int64, gdv_int32*);

const char* castVARCHAR_int32_int64(int64_t, gdv_int32, gdv_int64, gdv_int32*);

const char* castVARCHAR_int64_int64(int64_t, gdv_int64, gdv_int64, gdv_int32*);

const char* castVARCHAR_float32_int64(int64_t, gdv_float32, gdv_int64, gdv_int32*);

const char* castVARCHAR_float64_int64(int64_t, gdv_float64, gdv_int64, gdv_int32*);

gdv_int64 truncate_int64_int32(gdv_int64 in, gdv_int32 out_scale);

const char* substr_utf8_int64_int64(gdv_int64 context, const char* input,
Expand Down
44 changes: 44 additions & 0 deletions cpp/src/gandiva/tests/date_time_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

#include <arrow/type.h>
#include <gtest/gtest.h>
#include <math.h>
#include <time.h>
Expand Down Expand Up @@ -537,4 +538,47 @@ TEST_F(TestProjector, TestMonthsBetween) {
EXPECT_ARROW_ARRAY_EQUALS(exp_output, outputs.at(0));
}

// Projects castVARCHAR(date32, 10) over four dates and expects "yyyy-MM-dd" text.
TEST_F(TestProjector, TestCastToUTF8) {
  // Input column and the schema that carries it.
  auto date_field = field("f_date32", arrow::date32());
  auto date_node = TreeExprBuilder::MakeField(date_field);
  auto schema = arrow::schema({date_field});

  // Result column.
  auto result_field = field("date32_str", arrow::utf8());

  // castVARCHAR(f_date32, 10) -> date32_str
  auto max_len_node = TreeExprBuilder::MakeLiteral(10L);
  auto cast_node = TreeExprBuilder::MakeFunction("castVARCHAR", {date_node, max_len_node},
                                                 arrow::utf8());
  auto cast_expr = TreeExprBuilder::MakeExpression(cast_node, result_field);

  std::shared_ptr<Projector> projector;
  auto status = Projector::Make(schema, {cast_expr}, TestConfiguration(), &projector);
  EXPECT_TRUE(status.ok()) << status.message();

  // Sample input: four valid dates expressed as days since the epoch.
  int row_count = 4;
  time_t epoch = Epoch();
  auto all_valid = {true, true, true, true};
  std::vector<int32_t> day_values = {DaysSince(epoch, 2000, 1, 1, 5, 0, 0, 0),
                                     DaysSince(epoch, 1999, 12, 31, 5, 0, 0, 0),
                                     DaysSince(epoch, 2015, 6, 30, 20, 0, 0, 0),
                                     DaysSince(epoch, 2015, 7, 1, 20, 0, 0, 0)};
  auto date_array =
      MakeArrowTypeArray<arrow::Date32Type, int32_t>(date32(), day_values, all_valid);

  // What the projection should produce.
  auto expected = MakeArrowArray<arrow::StringType, std::string>(
      {"2000-01-01", "1999-12-31", "2015-06-30", "2015-07-01"}, {true, true, true, true});

  // Run the projection over a single record batch.
  auto batch = arrow::RecordBatch::Make(schema, row_count, {date_array});
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*batch, pool_, &outputs);
  EXPECT_TRUE(status.ok()) << status.message();

  EXPECT_ARROW_ARRAY_EQUALS(expected, outputs.at(0));
}
} // namespace gandiva
102 changes: 102 additions & 0 deletions cpp/src/gandiva/tests/projector_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "gandiva/projector.h"

#include <arrow/type.h>
#include <gtest/gtest.h>

#include <cmath>
Expand Down Expand Up @@ -862,4 +863,105 @@ TEST_F(TestProjector, TestToDate) {
EXPECT_ARROW_ARRAY_EQUALS(exp, outputs.at(0));
}

// Projects castVARCHAR over a float64 column and an int32 column and compares the
// resulting strings.
TEST_F(TestProjector, TestCastToUTF8) {
  // Input columns.
  auto float64_field = field("f_float64", arrow::float64());
  auto int32_field = field("f_int32", arrow::int32());
  auto schema = arrow::schema({float64_field, int32_field});

  // Result columns.
  auto float64_str_field = field("float64_str", arrow::utf8());
  auto int32_str_field = field("int32_str", arrow::utf8());

  // castVARCHAR(f_float64, 21) and castVARCHAR(f_int32, 11)
  auto float64_node = TreeExprBuilder::MakeField(float64_field);
  auto int32_node = TreeExprBuilder::MakeField(int32_field);
  auto float64_len_node = TreeExprBuilder::MakeLiteral(21L);
  auto int32_len_node = TreeExprBuilder::MakeLiteral(11L);

  auto float64_cast = TreeExprBuilder::MakeFunction(
      "castVARCHAR", {float64_node, float64_len_node}, arrow::utf8());
  auto float64_expr = TreeExprBuilder::MakeExpression(float64_cast, float64_str_field);

  auto int32_cast = TreeExprBuilder::MakeFunction("castVARCHAR",
                                                  {int32_node, int32_len_node},
                                                  arrow::utf8());
  auto int32_expr = TreeExprBuilder::MakeExpression(int32_cast, int32_str_field);

  std::shared_ptr<Projector> projector;
  auto status = Projector::Make(schema, {float64_expr, int32_expr}, TestConfiguration(),
                                &projector);
  EXPECT_TRUE(status.ok()) << status.message();

  // Sample input: four valid rows per column.
  int row_count = 4;
  auto float64_array = MakeArrowArrayFloat64(
      {1989278888.23f, 5.892732f, -23487.3f, 9.712717f}, {true, true, true, true});
  auto int32_array = MakeArrowArrayInt32({5, 6, 7, 8}, {true, true, true, true});

  // What the projection should produce.
  auto expected_float64_str = MakeArrowArray<arrow::StringType, std::string>(
      {"1.98928e+09", "5.89273", "-23487.3", "9.71272"}, {true, true, true, true});
  auto expected_int32_str = MakeArrowArray<arrow::StringType, std::string>(
      {"5", "6", "7", "8"}, {true, true, true, true});

  // Run the projection over a single record batch.
  auto batch = arrow::RecordBatch::Make(schema, row_count, {float64_array, int32_array});
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*batch, pool_, &outputs);
  EXPECT_TRUE(status.ok()) << status.message();

  EXPECT_ARROW_ARRAY_EQUALS(expected_float64_str, outputs.at(0));
  EXPECT_ARROW_ARRAY_EQUALS(expected_int32_str, outputs.at(1));
}

// Projects castBYTE over int16, int64 and int32 columns; the int64 column contains
// 257, for which the expected int8 result is 1.
TEST_F(TestProjector, TestCastToByte) {
  // Input columns.
  auto int16_field = field("f_int16", arrow::int16());
  auto int64_field = field("f_int64", arrow::int64());
  auto int32_field = field("f_int32", arrow::int32());
  auto schema = arrow::schema({int16_field, int64_field, int32_field});

  // Result columns, all int8.
  auto out_field_0 = field("out_0", arrow::int8());
  auto out_field_1 = field("out_1", arrow::int8());
  auto out_field_2 = field("out_2", arrow::int8());

  // One castBYTE expression per input column.
  auto int16_node = TreeExprBuilder::MakeField(int16_field);
  auto int64_node = TreeExprBuilder::MakeField(int64_field);
  auto int32_node = TreeExprBuilder::MakeField(int32_field);

  auto cast_int16 = TreeExprBuilder::MakeFunction("castBYTE", {int16_node}, arrow::int8());
  auto cast_int16_expr = TreeExprBuilder::MakeExpression(cast_int16, out_field_0);

  auto cast_int64 = TreeExprBuilder::MakeFunction("castBYTE", {int64_node}, arrow::int8());
  auto cast_int64_expr = TreeExprBuilder::MakeExpression(cast_int64, out_field_1);

  auto cast_int32 = TreeExprBuilder::MakeFunction("castBYTE", {int32_node}, arrow::int8());
  auto cast_int32_expr = TreeExprBuilder::MakeExpression(cast_int32, out_field_2);

  std::shared_ptr<Projector> projector;
  auto status =
      Projector::Make(schema, {cast_int16_expr, cast_int64_expr, cast_int32_expr},
                      TestConfiguration(), &projector);
  EXPECT_TRUE(status.ok()) << status.message();

  // Sample input: four valid rows per column.
  int row_count = 4;
  auto int16_array = MakeArrowArrayInt16({5, 6, 7, 8}, {true, true, true, true});
  auto int64_array = MakeArrowArrayInt64({5L, 6L, 7L, 257L}, {true, true, true, true});
  auto int32_array = MakeArrowArrayInt32({5, 6, 7, 8}, {true, true, true, true});

  // What the projection should produce.
  auto expected_0 = MakeArrowArrayInt8({5, 6, 7, 8}, {true, true, true, true});
  auto expected_1 = MakeArrowArrayInt8({5, 6, 7, 1}, {true, true, true, true});
  auto expected_2 = MakeArrowArrayInt8({5, 6, 7, 8}, {true, true, true, true});

  // Run the projection over a single record batch.
  auto batch = arrow::RecordBatch::Make(schema, row_count,
                                        {int16_array, int64_array, int32_array});
  arrow::ArrayVector outputs;
  status = projector->Evaluate(*batch, pool_, &outputs);
  EXPECT_TRUE(status.ok()) << status.message();

  EXPECT_ARROW_ARRAY_EQUALS(expected_0, outputs.at(0));
  EXPECT_ARROW_ARRAY_EQUALS(expected_1, outputs.at(1));
  EXPECT_ARROW_ARRAY_EQUALS(expected_2, outputs.at(2));
}
} // namespace gandiva

0 comments on commit 73005fc

Please sign in to comment.