Skip to content

Commit

Permalink
Implement initcap/concat functions (#3161)
Browse files Browse the repository at this point in the history
  • Loading branch information
acquamarin committed Mar 28, 2024
1 parent 20bde3a commit 956b3e3
Show file tree
Hide file tree
Showing 11 changed files with 162 additions and 77 deletions.
33 changes: 19 additions & 14 deletions src/function/base_lower_upper_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,27 @@ uint32_t BaseLowerUpperFunction::getResultLen(char* inputStr, uint32_t inputLen,
return outputLength;
}

uint64_t BaseLowerUpperFunction::convertCharCase(
char* result, const char* input, int32_t charPos, bool toUpper) {
if (input[charPos] & 0x80) {
int size = 0u, newSize = 0u;
auto codepoint = utf8proc_codepoint(input + charPos, size);
KU_ASSERT(codepoint >= 0); // Validity ensured by getResultLen.
int convertedCodepoint =
toUpper ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint);
utf8proc_codepoint_to_utf8(convertedCodepoint, newSize, result);
return size;
} else {
*result = toUpper ? toupper(input[charPos]) : tolower(input[charPos]);
return 1;
}
}

void BaseLowerUpperFunction::convertCase(char* result, uint32_t len, char* input, bool toUpper) {
for (auto i = 0u; i < len;) {
if (input[i] & 0x80) {
int size = 0, newSize = 0;
int codepoint = utf8proc_codepoint(input + i, size);
KU_ASSERT(codepoint >= 0); // Validity ensured by getResultLen.
int convertedCodepoint =
toUpper ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint);
utf8proc_codepoint_to_utf8(convertedCodepoint, newSize, result);
result += newSize;
i += size;
} else {
*result = toUpper ? toupper(input[i]) : tolower(input[i]);
i++;
result++;
}
auto charWidth = convertCharCase(result, input, i, toUpper);
i += charWidth;
result += charWidth;
}
}

Expand Down
1 change: 1 addition & 0 deletions src/function/function_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ FunctionCollection* FunctionCollection::getFunctions() {
SCALAR_FUNCTION(RegexpFullMatchFunction), SCALAR_FUNCTION(RegexpMatchesFunction),
SCALAR_FUNCTION(RegexpReplaceFunction), SCALAR_FUNCTION(RegexpExtractFunction),
SCALAR_FUNCTION(RegexpExtractAllFunction), SCALAR_FUNCTION(LevenshteinFunction),
SCALAR_FUNCTION(InitcapFunction),

// Array Functions
SCALAR_FUNCTION(ArrayValueFunction), SCALAR_FUNCTION(ArrayCrossProductFunction),
Expand Down
5 changes: 2 additions & 3 deletions src/function/vector_arithmetic_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "function/list/functions/list_concat_function.h"
#include "function/list/vector_list_functions.h"
#include "function/scalar_function.h"
#include "function/string/functions/concat_function.h"
#include "function/string/vector_string_functions.h"

using namespace kuzu::common;

Expand Down Expand Up @@ -155,8 +155,7 @@ function_set AddFunction::getFunctionSet() {
// string + string -> string
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING, LogicalTypeID::STRING},
LogicalTypeID::STRING,
ScalarFunction::BinaryStringExecFunction<ku_string_t, ku_string_t, ku_string_t, Concat>));
LogicalTypeID::STRING, ConcatFunction::execFunc));
// interval + interval → interval
result.push_back(getBinaryFunction<Add, interval_t, interval_t>(
name, LogicalTypeID::INTERVAL, LogicalTypeID::INTERVAL));
Expand Down
69 changes: 47 additions & 22 deletions src/function/vector_string_functions.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include "function/string/vector_string_functions.h"

#include "function/string/functions/array_extract_function.h"
#include "function/string/functions/concat_function.h"
#include "function/string/functions/contains_function.h"
#include "function/string/functions/ends_with_function.h"
#include "function/string/functions/initcap_function.h"
#include "function/string/functions/left_operation.h"
#include "function/string/functions/levenshtein_function.h"
#include "function/string/functions/lpad_function.h"
Expand Down Expand Up @@ -52,22 +52,6 @@ void BaseStrOperation::operation(ku_string_t& input, ku_string_t& result,
}
}

void Concat::concat(const char* left, uint32_t leftLen, const char* right, uint32_t rightLen,
ku_string_t& result, ValueVector& resultValueVector) {
auto len = leftLen + rightLen;
if (len <= ku_string_t::SHORT_STR_LENGTH /* concat result is short */) {
memcpy(result.prefix, left, leftLen);
memcpy(result.prefix + leftLen, right, rightLen);
} else {
StringVector::reserveString(&resultValueVector, result, len);
auto buffer = reinterpret_cast<char*>(result.overflowPtr);
memcpy(buffer, left, leftLen);
memcpy(buffer + leftLen, right, rightLen);
memcpy(result.prefix, buffer, ku_string_t::PREFIX_LENGTH);
}
result.len = len;
}

void Repeat::operation(
ku_string_t& left, int64_t& right, ku_string_t& result, ValueVector& resultValueVector) {
result.len = left.len * right;
Expand Down Expand Up @@ -121,13 +105,45 @@ function_set ArrayExtractFunction::getFunctionSet() {
return functionSet;
}

void ConcatFunction::execFunc(const std::vector<std::shared_ptr<ValueVector>>& parameters,
ValueVector& result, void* /*dataPtr*/) {
result.resetAuxiliaryBuffer();
for (auto selectedPos = 0u; selectedPos < result.state->selVector->selectedSize;
++selectedPos) {
auto pos = result.state->selVector->selectedPositions[selectedPos];
auto strLen = 0u;
for (auto i = 0u; i < parameters.size(); i++) {
const auto& parameter = parameters[i];
auto paramPos = parameter->state->isFlat() ?
parameter->state->selVector->selectedPositions[0] :
pos;
strLen += parameter->getValue<ku_string_t>(paramPos).len;
}
auto& resultStr = result.getValue<ku_string_t>(pos);
StringVector::reserveString(&result, resultStr, strLen);
auto dstData = strLen <= ku_string_t::SHORT_STR_LENGTH ?
resultStr.prefix :
reinterpret_cast<uint8_t*>(resultStr.overflowPtr);
for (auto i = 0u; i < parameters.size(); i++) {
const auto& parameter = parameters[i];
auto paramPos = parameter->state->isFlat() ?
parameter->state->selVector->selectedPositions[0] :
pos;
auto srcStr = parameter->getValue<ku_string_t>(paramPos);
memcpy(dstData, srcStr.getData(), srcStr.len);
dstData += srcStr.len;
}
if (strLen > ku_string_t::SHORT_STR_LENGTH) {
memcpy(resultStr.prefix, resultStr.getData(), ku_string_t::PREFIX_LENGTH);
}
}
}

function_set ConcatFunction::getFunctionSet() {
function_set functionSet;
functionSet.emplace_back(make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING, LogicalTypeID::STRING},
LogicalTypeID::STRING,
ScalarFunction::BinaryStringExecFunction<ku_string_t, ku_string_t, ku_string_t, Concat>,
false /* isVarLength */));
functionSet.emplace_back(
make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::STRING},
LogicalTypeID::STRING, execFunc, true /* isVarLength */));
return functionSet;
}

Expand Down Expand Up @@ -317,5 +333,14 @@ function_set LevenshteinFunction::getFunctionSet() {
return functionSet;
}

function_set InitcapFunction::getFunctionSet() {
function_set functionSet;
functionSet.emplace_back(make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}, LogicalTypeID::STRING,
ScalarFunction::UnaryStringExecFunction<ku_string_t, ku_string_t, Initcap>, nullptr,
nullptr, false /* isVarLength */));
return functionSet;
}

} // namespace function
} // namespace kuzu
4 changes: 4 additions & 0 deletions src/include/common/types/ku_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ struct KUZU_API ku_string_t {
return isShortString(len) ? prefix : reinterpret_cast<uint8_t*>(overflowPtr);
}

uint8_t* getDataUnsafe() {
return isShortString(len) ? prefix : reinterpret_cast<uint8_t*>(overflowPtr);
}

// These functions do *NOT* allocate/resize the overflow buffer, it only copies the content and
// set the length.
void set(const std::string& value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct BaseLowerUpperFunction {
KUZU_API static void operation(common::ku_string_t& input, common::ku_string_t& result,
common::ValueVector& resultValueVector, bool isUpper);

static uint64_t convertCharCase(char* result, const char* input, int32_t charPos, bool toUpper);

private:
static uint32_t getResultLen(char* inputStr, uint32_t inputLen, bool isUpper);
static void convertCase(char* result, uint32_t len, char* input, bool toUpper);
Expand Down
35 changes: 0 additions & 35 deletions src/include/function/string/functions/concat_function.h

This file was deleted.

19 changes: 19 additions & 0 deletions src/include/function/string/functions/initcap_function.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#include "base_lower_upper_function.h"
#include "lower_function.h"

namespace kuzu {
namespace function {

struct Initcap {
static void operation(common::ku_string_t& operand, common::ku_string_t& result,
common::ValueVector& resultVector) {
Lower::operation(operand, result, resultVector);
BaseLowerUpperFunction::convertCharCase(reinterpret_cast<char*>(result.getDataUnsafe()),
reinterpret_cast<const char*>(result.getData()), 0 /* charPos */, true /* toUpper */);
}
};

} // namespace function
} // namespace kuzu
9 changes: 9 additions & 0 deletions src/include/function/string/vector_string_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ struct ArrayExtractFunction {
struct ConcatFunction : public VectorStringFunction {
static constexpr const char* name = "CONCAT";

static void execFunc(const std::vector<std::shared_ptr<common::ValueVector>>& parameters,
common::ValueVector& result, void* /*dataPtr*/);

static function_set getFunctionSet();
};

Expand Down Expand Up @@ -175,5 +178,11 @@ struct LevenshteinFunction : public VectorStringFunction {
static function_set getFunctionSet();
};

struct InitcapFunction : public VectorStringFunction {
static constexpr const char* name = "INITCAP";

static function_set getFunctionSet();
};

} // namespace function
} // namespace kuzu
20 changes: 17 additions & 3 deletions test/main/udf_test.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#include "function/string/functions/concat_function.h"
#include "main_test_helper/main_test_helper.h"

using namespace kuzu::common;
Expand Down Expand Up @@ -294,14 +293,29 @@ TEST_F(ApiTest, vectorizedBinaryAddDate) {
sortAndCheckTestResults(actualResult, expectedResult);
}

static void concat(const ku_string_t& left, const ku_string_t& right, ku_string_t& result,
ValueVector& resultValueVector) {
result.len = left.len + right.len;
if (result.len <= ku_string_t::SHORT_STR_LENGTH /* concat result is short */) {
memcpy(result.prefix, left.getData(), left.len);
memcpy(result.prefix + left.len, right.getData(), right.len);
} else {
StringVector::reserveString(&resultValueVector, result, result.len);
auto buffer = reinterpret_cast<char*>(result.overflowPtr);
memcpy(buffer, left.getData(), left.len);
memcpy(buffer + left.len, right.getData(), right.len);
memcpy(result.prefix, buffer, ku_string_t::PREFIX_LENGTH);
}
}

struct ConditionalConcat {
static inline void operation(
ku_string_t& a, bool& b, ku_string_t& c, ku_string_t& result, ValueVector& resultVector) {
// Concat a,c if b is true, otherwise concat c,a.
if (b) {
function::Concat::operation(a, c, result, resultVector);
concat(a, c, result, resultVector);
} else {
function::Concat::operation(c, a, result, resultVector);
concat(c, a, result, resultVector);
}
}
};
Expand Down
42 changes: 42 additions & 0 deletions test/test_files/tinysnb/function/string.test
Original file line number Diff line number Diff line change
Expand Up @@ -627,3 +627,45 @@ hellolongstringtest
-STATEMENT return levenshtein('horse', 'ros');
---- 1
3

-LOG ConcatMultipleStrings
-STATEMENT match (p:person) return concat("name: ", p.fName, ' ,age: ', "5");
---- 8
name: Alice ,age: 5
name: Bob ,age: 5
name: Carol ,age: 5
name: Dan ,age: 5
name: Elizabeth ,age: 5
name: Farooq ,age: 5
name: Greg ,age: 5
name: Hubert Blaine Wolfeschlegelsteinhausenbergerdorff ,age: 5
-STATEMENT match (p:person)-[:knows]->(p1:person) return concat("From: ", p.fName, ', To: ', p1.fName);
---- 14
From: Alice, To: Bob
From: Alice, To: Carol
From: Alice, To: Dan
From: Bob, To: Alice
From: Bob, To: Carol
From: Bob, To: Dan
From: Carol, To: Alice
From: Carol, To: Bob
From: Carol, To: Dan
From: Dan, To: Alice
From: Dan, To: Bob
From: Dan, To: Carol
From: Elizabeth, To: Farooq
From: Elizabeth, To: Greg

-LOG InitCapStrings
-STATEMENT match (o:organisation) return initcap(o.name);
---- 3
Abfsuni
Cswork
Deswork

-LOG InitCapUTF8String
-STATEMENT match (m:movies) return initcap(m.name);
---- 3
Sóló cón tu párejâ
The 😂😃🧘🏻‍♂️🌍🌦️🍞🚗 movie
Roma

0 comments on commit 956b3e3

Please sign in to comment.