Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement initcap/concat functions #3161

Merged
merged 2 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions src/function/base_lower_upper_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,27 @@ uint32_t BaseLowerUpperFunction::getResultLen(char* inputStr, uint32_t inputLen,
return outputLength;
}

uint64_t BaseLowerUpperFunction::convertCharCase(
char* result, const char* input, int32_t charPos, bool toUpper) {
if (input[charPos] & 0x80) {
int size = 0u, newSize = 0u;
auto codepoint = utf8proc_codepoint(input + charPos, size);
KU_ASSERT(codepoint >= 0); // Validity ensured by getResultLen.
int convertedCodepoint =
toUpper ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint);
utf8proc_codepoint_to_utf8(convertedCodepoint, newSize, result);
return size;
} else {
*result = toUpper ? toupper(input[charPos]) : tolower(input[charPos]);
return 1;
}
}

void BaseLowerUpperFunction::convertCase(char* result, uint32_t len, char* input, bool toUpper) {
for (auto i = 0u; i < len;) {
if (input[i] & 0x80) {
int size = 0, newSize = 0;
int codepoint = utf8proc_codepoint(input + i, size);
KU_ASSERT(codepoint >= 0); // Validity ensured by getResultLen.
int convertedCodepoint =
toUpper ? utf8proc_toupper(codepoint) : utf8proc_tolower(codepoint);
utf8proc_codepoint_to_utf8(convertedCodepoint, newSize, result);
result += newSize;
i += size;
} else {
*result = toUpper ? toupper(input[i]) : tolower(input[i]);
i++;
result++;
}
auto charWidth = convertCharCase(result, input, i, toUpper);
i += charWidth;
result += charWidth;
}
}

Expand Down
1 change: 1 addition & 0 deletions src/function/function_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ FunctionCollection* FunctionCollection::getFunctions() {
SCALAR_FUNCTION(RegexpFullMatchFunction), SCALAR_FUNCTION(RegexpMatchesFunction),
SCALAR_FUNCTION(RegexpReplaceFunction), SCALAR_FUNCTION(RegexpExtractFunction),
SCALAR_FUNCTION(RegexpExtractAllFunction), SCALAR_FUNCTION(LevenshteinFunction),
SCALAR_FUNCTION(InitcapFunction),

// Array Functions
SCALAR_FUNCTION(ArrayValueFunction), SCALAR_FUNCTION(ArrayCrossProductFunction),
Expand Down
5 changes: 2 additions & 3 deletions src/function/vector_arithmetic_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "function/list/functions/list_concat_function.h"
#include "function/list/vector_list_functions.h"
#include "function/scalar_function.h"
#include "function/string/functions/concat_function.h"
#include "function/string/vector_string_functions.h"

using namespace kuzu::common;

Expand Down Expand Up @@ -155,8 +155,7 @@ function_set AddFunction::getFunctionSet() {
// string + string -> string
result.push_back(std::make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING, LogicalTypeID::STRING},
LogicalTypeID::STRING,
ScalarFunction::BinaryStringExecFunction<ku_string_t, ku_string_t, ku_string_t, Concat>));
LogicalTypeID::STRING, ConcatFunction::execFunc));
// interval + interval → interval
result.push_back(getBinaryFunction<Add, interval_t, interval_t>(
name, LogicalTypeID::INTERVAL, LogicalTypeID::INTERVAL));
Expand Down
69 changes: 47 additions & 22 deletions src/function/vector_string_functions.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include "function/string/vector_string_functions.h"

#include "function/string/functions/array_extract_function.h"
#include "function/string/functions/concat_function.h"
#include "function/string/functions/contains_function.h"
#include "function/string/functions/ends_with_function.h"
#include "function/string/functions/initcap_function.h"
#include "function/string/functions/left_operation.h"
#include "function/string/functions/levenshtein_function.h"
#include "function/string/functions/lpad_function.h"
Expand Down Expand Up @@ -52,22 +52,6 @@
}
}

void Concat::concat(const char* left, uint32_t leftLen, const char* right, uint32_t rightLen,
ku_string_t& result, ValueVector& resultValueVector) {
auto len = leftLen + rightLen;
if (len <= ku_string_t::SHORT_STR_LENGTH /* concat result is short */) {
memcpy(result.prefix, left, leftLen);
memcpy(result.prefix + leftLen, right, rightLen);
} else {
StringVector::reserveString(&resultValueVector, result, len);
auto buffer = reinterpret_cast<char*>(result.overflowPtr);
memcpy(buffer, left, leftLen);
memcpy(buffer + leftLen, right, rightLen);
memcpy(result.prefix, buffer, ku_string_t::PREFIX_LENGTH);
}
result.len = len;
}

void Repeat::operation(
ku_string_t& left, int64_t& right, ku_string_t& result, ValueVector& resultValueVector) {
result.len = left.len * right;
Expand Down Expand Up @@ -121,13 +105,45 @@
return functionSet;
}

void ConcatFunction::execFunc(const std::vector<std::shared_ptr<ValueVector>>& parameters,
ValueVector& result, void* /*dataPtr*/) {
result.resetAuxiliaryBuffer();
for (auto selectedPos = 0u; selectedPos < result.state->selVector->selectedSize;
++selectedPos) {
auto pos = result.state->selVector->selectedPositions[selectedPos];
auto strLen = 0u;
for (auto i = 0u; i < parameters.size(); i++) {
const auto& parameter = parameters[i];
auto paramPos = parameter->state->isFlat() ?
parameter->state->selVector->selectedPositions[0] :
pos;
strLen += parameter->getValue<ku_string_t>(paramPos).len;
}
auto& resultStr = result.getValue<ku_string_t>(pos);
StringVector::reserveString(&result, resultStr, strLen);
auto dstData = strLen <= ku_string_t::SHORT_STR_LENGTH ?
resultStr.prefix :
reinterpret_cast<uint8_t*>(resultStr.overflowPtr);
for (auto i = 0u; i < parameters.size(); i++) {
const auto& parameter = parameters[i];
auto paramPos = parameter->state->isFlat() ?
parameter->state->selVector->selectedPositions[0] :
pos;
auto srcStr = parameter->getValue<ku_string_t>(paramPos);
memcpy(dstData, srcStr.getData(), srcStr.len);
dstData += srcStr.len;
}
if (strLen > ku_string_t::SHORT_STR_LENGTH) {
memcpy(resultStr.prefix, resultStr.getData(), ku_string_t::PREFIX_LENGTH);
}
}
}

function_set ConcatFunction::getFunctionSet() {
function_set functionSet;
functionSet.emplace_back(make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING, LogicalTypeID::STRING},
LogicalTypeID::STRING,
ScalarFunction::BinaryStringExecFunction<ku_string_t, ku_string_t, ku_string_t, Concat>,
false /* isVarLength */));
functionSet.emplace_back(
make_unique<ScalarFunction>(name, std::vector<LogicalTypeID>{LogicalTypeID::STRING},
LogicalTypeID::STRING, execFunc, true /* isVarLength */));
return functionSet;
}

Expand Down Expand Up @@ -317,5 +333,14 @@
return functionSet;
}

function_set InitcapFunction::getFunctionSet() {
function_set functionSet;
functionSet.emplace_back(make_unique<ScalarFunction>(name,
std::vector<LogicalTypeID>{LogicalTypeID::STRING}, LogicalTypeID::STRING,
ScalarFunction::UnaryStringExecFunction<ku_string_t, ku_string_t, Initcap>, nullptr,
nullptr, false /* isVarLength */));
return functionSet;
}

Check warning on line 343 in src/function/vector_string_functions.cpp

View check run for this annotation

Codecov / codecov/patch

src/function/vector_string_functions.cpp#L343

Added line #L343 was not covered by tests

} // namespace function
} // namespace kuzu
4 changes: 4 additions & 0 deletions src/include/common/types/ku_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ struct KUZU_API ku_string_t {
return isShortString(len) ? prefix : reinterpret_cast<uint8_t*>(overflowPtr);
}

uint8_t* getDataUnsafe() {
return isShortString(len) ? prefix : reinterpret_cast<uint8_t*>(overflowPtr);
}

// These functions do *NOT* allocate/resize the overflow buffer, it only copies the content and
// set the length.
void set(const std::string& value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ struct BaseLowerUpperFunction {
KUZU_API static void operation(common::ku_string_t& input, common::ku_string_t& result,
common::ValueVector& resultValueVector, bool isUpper);

static uint64_t convertCharCase(char* result, const char* input, int32_t charPos, bool toUpper);

private:
static uint32_t getResultLen(char* inputStr, uint32_t inputLen, bool isUpper);
static void convertCase(char* result, uint32_t len, char* input, bool toUpper);
Expand Down
35 changes: 0 additions & 35 deletions src/include/function/string/functions/concat_function.h

This file was deleted.

20 changes: 20 additions & 0 deletions src/include/function/string/functions/initcap_function.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#pragma once

#include "base_lower_upper_function.h"
#include "lower_function.h"

namespace kuzu {
namespace function {

struct Initcap {
public:
acquamarin marked this conversation as resolved.
Show resolved Hide resolved
static void operation(common::ku_string_t& operand, common::ku_string_t& result,
common::ValueVector& resultVector) {
Lower::operation(operand, result, resultVector);
BaseLowerUpperFunction::convertCharCase(reinterpret_cast<char*>(result.getDataUnsafe()),
reinterpret_cast<const char*>(result.getData()), 0 /* charPos */, true /* toUpper */);
}
};

} // namespace function
} // namespace kuzu
9 changes: 9 additions & 0 deletions src/include/function/string/vector_string_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ struct ArrayExtractFunction {
struct ConcatFunction : public VectorStringFunction {
static constexpr const char* name = "CONCAT";

static void execFunc(const std::vector<std::shared_ptr<common::ValueVector>>& parameters,
common::ValueVector& result, void* /*dataPtr*/);

static function_set getFunctionSet();
};

Expand Down Expand Up @@ -175,5 +178,11 @@ struct LevenshteinFunction : public VectorStringFunction {
static function_set getFunctionSet();
};

struct InitcapFunction : public VectorStringFunction {
static constexpr const char* name = "INITCAP";

static function_set getFunctionSet();
};

} // namespace function
} // namespace kuzu
20 changes: 17 additions & 3 deletions test/main/udf_test.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#include "function/string/functions/concat_function.h"
#include "main_test_helper/main_test_helper.h"

using namespace kuzu::common;
Expand Down Expand Up @@ -294,14 +293,29 @@ TEST_F(ApiTest, vectorizedBinaryAddDate) {
sortAndCheckTestResults(actualResult, expectedResult);
}

static void concat(const ku_string_t& left, const ku_string_t& right, ku_string_t& result,
ValueVector& resultValueVector) {
result.len = left.len + right.len;
if (result.len <= ku_string_t::SHORT_STR_LENGTH /* concat result is short */) {
memcpy(result.prefix, left.getData(), left.len);
memcpy(result.prefix + left.len, right.getData(), right.len);
} else {
StringVector::reserveString(&resultValueVector, result, result.len);
auto buffer = reinterpret_cast<char*>(result.overflowPtr);
memcpy(buffer, left.getData(), left.len);
memcpy(buffer + left.len, right.getData(), right.len);
memcpy(result.prefix, buffer, ku_string_t::PREFIX_LENGTH);
}
}

struct ConditionalConcat {
static inline void operation(
ku_string_t& a, bool& b, ku_string_t& c, ku_string_t& result, ValueVector& resultVector) {
// Concat a,c if b is true, otherwise concat c,a.
if (b) {
function::Concat::operation(a, c, result, resultVector);
concat(a, c, result, resultVector);
} else {
function::Concat::operation(c, a, result, resultVector);
concat(c, a, result, resultVector);
}
}
};
Expand Down
42 changes: 42 additions & 0 deletions test/test_files/tinysnb/function/string.test
Original file line number Diff line number Diff line change
Expand Up @@ -627,3 +627,45 @@ hellolongstringtest
-STATEMENT return levenshtein('horse', 'ros');
---- 1
3

-LOG ConcatMultipleStrings
-STATEMENT match (p:person) return concat("name: ", p.fName, ' ,age: ', "5");
---- 8
name: Alice ,age: 5
name: Bob ,age: 5
name: Carol ,age: 5
name: Dan ,age: 5
name: Elizabeth ,age: 5
name: Farooq ,age: 5
name: Greg ,age: 5
name: Hubert Blaine Wolfeschlegelsteinhausenbergerdorff ,age: 5
-STATEMENT match (p:person)-[:knows]->(p1:person) return concat("From: ", p.fName, ', To: ', p1.fName);
---- 14
From: Alice, To: Bob
From: Alice, To: Carol
From: Alice, To: Dan
From: Bob, To: Alice
From: Bob, To: Carol
From: Bob, To: Dan
From: Carol, To: Alice
From: Carol, To: Bob
From: Carol, To: Dan
From: Dan, To: Alice
From: Dan, To: Bob
From: Dan, To: Carol
From: Elizabeth, To: Farooq
From: Elizabeth, To: Greg

-LOG InitCapStrings
-STATEMENT match (o:organisation) return initcap(o.name);
---- 3
Abfsuni
Cswork
Deswork

-LOG InitCapUTF8String
-STATEMENT match (m:movies) return initcap(m.name);
---- 3
Sóló cón tu párejâ
The 😂😃🧘🏻‍♂️🌍🌦️🍞🚗 movie
Roma
Loading