Skip to content

Commit

Permalink
Re-pick concat_ws, PR 6292
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE committed Nov 3, 2023
1 parent 4a5b77e commit c5eec03
Show file tree
Hide file tree
Showing 5 changed files with 475 additions and 1 deletion.
6 changes: 5 additions & 1 deletion velox/expression/tests/SparkExpressionFuzzerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ int main(int argc, char** argv) {
"chr",
"replace",
"might_contain",
"unix_timestamp"};
"unix_timestamp",
// Skip concat_ws as it triggers a test failure due to an incorrect
// expression generation from fuzzer:
// https://github.com/facebookincubator/velox/issues/6590
"concat_ws"};
return FuzzerRunner::run(
FLAGS_only, FLAGS_seed, skipFunctions, FLAGS_special_forms);
}
2 changes: 2 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ void registerFunctions(const std::string& prefix) {
exec::registerStatefulVectorFunction(
prefix + "length", lengthSignatures(), makeLength);
VELOX_REGISTER_VECTOR_FUNCTION(udf_str_to_map, prefix + "str_to_map");
exec::registerStatefulVectorFunction(
prefix + "concat_ws", concatWsSignatures(), makeConcatWs);

registerFunction<Md5Function, Varchar, Varbinary>({prefix + "md5"});
registerFunction<Sha1HexStringFunction, Varchar, Varbinary>(
Expand Down
262 changes: 262 additions & 0 deletions velox/functions/sparksql/String.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,230 @@ class Length : public exec::VectorFunction {
}
};

void concatWsVariableParameters(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
exec::EvalCtx& context,
const std::string& connector,
FlatVector<StringView>& flatResult) {
std::vector<column_index_t> argMapping;
std::vector<std::string> constantStrings;
std::vector<StringView> constantStringViews;
auto numArgs = args.size();

// Save constant values to constantStrings_.
// Identify and combine consecutive constant inputs.
argMapping.reserve(numArgs - 1);
constantStrings.reserve(numArgs - 1);
for (auto i = 1; i < numArgs; ++i) {
argMapping.push_back(i);
if (args[i] && args[i]->as<ConstantVector<StringView>>() &&
!args[i]->as<ConstantVector<StringView>>()->isNullAt(0)) {
std::string value =
args[i]->as<ConstantVector<StringView>>()->valueAt(0).str();
column_index_t j = i + 1;
for (; j < args.size(); ++j) {
if (!args[j] || !args[j]->as<ConstantVector<StringView>>() ||
args[j]->as<ConstantVector<StringView>>()->isNullAt(0)) {
break;
}

value += connector +
args[j]->as<ConstantVector<StringView>>()->valueAt(0).str();
}
constantStrings.push_back(std::string(value.data(), value.size()));
i = j - 1;
} else {
constantStrings.push_back(std::string());
}
}

// Create StringViews for constant strings.
constantStringViews.reserve(numArgs - 1);
for (const auto& constantString : constantStrings) {
constantStringViews.push_back(
StringView(constantString.data(), constantString.size()));
}

auto numCols = argMapping.size();
std::vector<exec::LocalDecodedVector> decodedArgs;
decodedArgs.reserve(numCols);

for (auto i = 0; i < numCols; ++i) {
auto index = argMapping[i];
if (constantStringViews[i].empty()) {
decodedArgs.emplace_back(context, *args[index], rows);
} else {
// Do not decode constant inputs.
decodedArgs.emplace_back(context);
}
}

size_t totalResultBytes = 0;
rows.applyToSelected([&](auto row) {
auto isFirst = true;
for (int i = 0; i < numCols; i++) {
auto value = constantStringViews[i].empty()
? decodedArgs[i]->valueAt<StringView>(row)
: constantStringViews[i];
if (!value.empty()) {
if (isFirst) {
isFirst = false;
} else {
totalResultBytes += connector.size();
}
totalResultBytes += value.size();
}
}
});

// Allocate a string buffer.
auto rawBuffer = flatResult.getRawStringBufferWithSpace(totalResultBytes);
size_t offset = 0;
rows.applyToSelected([&](int row) {
const char* start = rawBuffer + offset;
size_t combinedSize = 0;
auto isFirst = true;
for (int i = 0; i < numCols; i++) {
StringView value;
if (constantStringViews[i].empty()) {
value = decodedArgs[i]->valueAt<StringView>(row);
} else {
value = constantStringViews[i];
}
auto size = value.size();
if (size > 0) {
if (isFirst) {
isFirst = false;
} else {
memcpy(rawBuffer + offset, connector.data(), connector.size());
offset += connector.size();
combinedSize += connector.size();
}
memcpy(rawBuffer + offset, value.data(), size);
combinedSize += size;
offset += size;
}
}
flatResult.setNoCopy(row, StringView(start, combinedSize));
});
}

void concatWsArray(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
exec::EvalCtx& context,
const std::string& connector,
FlatVector<StringView>& flatResult) {
exec::LocalDecodedVector arrayHolder(context, *args[1], rows);
auto& arrayDecoded = *arrayHolder.get();
auto baseArray = arrayDecoded.base()->as<ArrayVector>();
auto rawSizes = baseArray->rawSizes();
auto rawOffsets = baseArray->rawOffsets();
auto indices = arrayDecoded.indices();

auto elements = arrayHolder.get()->base()->as<ArrayVector>()->elements();
exec::LocalSelectivityVector nestedRows(context, elements->size());
nestedRows.get()->setAll();
exec::LocalDecodedVector elementsHolder(
context, *elements, *nestedRows.get());
auto& elementsDecoded = *elementsHolder.get();
auto elementsBase = elementsDecoded.base();

size_t totalResultBytes = 0;
rows.applyToSelected([&](auto row) {
auto size = rawSizes[indices[row]];
auto offset = rawOffsets[indices[row]];

auto isFirst = true;
for (auto i = 0; i < size; ++i) {
if (!elementsBase->isNullAt(offset + i)) {
auto element = elementsDecoded.valueAt<StringView>(offset + i);
if (!element.empty()) {
if (isFirst) {
isFirst = false;
} else {
totalResultBytes += connector.size();
}
totalResultBytes += element.size();
}
}
}
});

// Allocate a string buffer.
auto rawBuffer = flatResult.getRawStringBufferWithSpace(totalResultBytes);
size_t bufferOffset = 0;
rows.applyToSelected([&](int row) {
auto size = rawSizes[indices[row]];
auto offset = rawOffsets[indices[row]];

const char* start = rawBuffer + bufferOffset;
size_t combinedSize = 0;
auto isFirst = true;
for (auto i = 0; i < size; ++i) {
if (!elementsBase->isNullAt(offset + i)) {
auto element = elementsDecoded.valueAt<StringView>(offset + i);
if (!element.empty()) {
if (isFirst) {
isFirst = false;
} else {
memcpy(
rawBuffer + bufferOffset, connector.data(), connector.size());
bufferOffset += connector.size();
combinedSize += connector.size();
}
memcpy(rawBuffer + bufferOffset, element.data(), element.size());
bufferOffset += element.size();
combinedSize += element.size();
}
}
flatResult.setNoCopy(row, StringView(start, combinedSize));
}
});
}

/// Vector function implementing concat_ws. The connector is fixed when the
/// function object is created; apply() dispatches on the type of the second
/// argument to either the array or the variadic varchar implementation.
class ConcatWs : public exec::VectorFunction {
 public:
  explicit ConcatWs(const std::string& connector) : connector_(connector) {}

  /// Returns false; a null connector is handled explicitly inside apply().
  bool isDefaultNullBehavior() const override {
    return false;
  }

  /// Writes the concat_ws result for the 'selected' rows into 'result'.
  void apply(
      const SelectivityVector& selected,
      std::vector<VectorPtr>& args,
      const TypePtr& /* outputType */,
      exec::EvalCtx& context,
      VectorPtr& result) const override {
    context.ensureWritable(selected, VARCHAR(), result);
    auto* output = result->asFlatVector<StringView>();

    // Only the connector was passed: every selected row is an empty string.
    if (args.size() == 1) {
      selected.applyToSelected([&](int row) {
        output->setNoCopy(row, StringView(""));
      });
      return;
    }

    // A null connector makes every selected row null.
    const bool nullConnector = args[0]->isNullAt(0);
    if (nullConnector) {
      selected.applyToSelected([&](int row) { result->setNull(row, true); });
      return;
    }

    // Pick the implementation matching the type of the second argument.
    if (args[1]->typeKind() != TypeKind::ARRAY) {
      concatWsVariableParameters(
          selected, args, context, connector_, *output);
      return;
    }
    concatWsArray(selected, args, context, connector_, *output);
  }

 private:
  const std::string connector_;
};

} // namespace

std::vector<std::shared_ptr<exec::FunctionSignature>> instrSignatures() {
Expand Down Expand Up @@ -142,6 +366,44 @@ std::shared_ptr<exec::VectorFunction> makeLength(
return kLengthFunction;
}

/// Builds the two signatures supported by concat_ws. In both, the first
/// argument (the connector) is declared as a constant varchar.
std::vector<std::shared_ptr<exec::FunctionSignature>> concatWsSignatures() {
  std::vector<std::shared_ptr<exec::FunctionSignature>> signatures;
  signatures.reserve(2);

  // varchar, varchar,... -> varchar.
  signatures.push_back(exec::FunctionSignatureBuilder()
                           .returnType("varchar")
                           .constantArgumentType("varchar")
                           .argumentType("varchar")
                           .variableArity()
                           .build());

  // varchar, array(varchar) -> varchar.
  signatures.push_back(exec::FunctionSignatureBuilder()
                           .returnType("varchar")
                           .constantArgumentType("varchar")
                           .argumentType("array(varchar)")
                           .build());

  return signatures;
}

/// Creates the ConcatWs vector function for the given call site.
///
/// @param name Registered function name (unused).
/// @param inputArgs Call-site arguments; inputArgs[0] must carry a constant
/// value, which is used as the connector.
/// @return A ConcatWs instance bound to the connector string.
/// @throws A Velox user error if no argument is given or the connector is not
/// constant.
std::shared_ptr<exec::VectorFunction> makeConcatWs(
    const std::string& name,
    const std::vector<exec::VectorFunctionArg>& inputArgs,
    const core::QueryConfig& /*config*/) {
  const auto numArgs = inputArgs.size();
  VELOX_USER_CHECK(
      numArgs >= 1,
      "concat_ws requires one arguments at least, but got {}.",
      numArgs);

  BaseVector* constantPattern = inputArgs[0].constantValue.get();
  VELOX_USER_CHECK(
      nullptr != constantPattern,
      "concat_ws requires constant connector arguments.");

  // Only read the constant's value when it is not null. With a null
  // connector, ConcatWs::apply returns null for every row and never uses the
  // connector string, so an empty placeholder is sufficient.
  std::string connector;
  if (!constantPattern->isNullAt(0)) {
    connector =
        constantPattern->as<ConstantVector<StringView>>()->valueAt(0).str();
  }
  return std::make_shared<ConcatWs>(connector);
}

void encodeDigestToBase16(uint8_t* output, int digestSize) {
static unsigned char const kHexCodes[] = "0123456789abcdef";
for (int i = digestSize - 1; i >= 0; --i) {
Expand Down
7 changes: 7 additions & 0 deletions velox/functions/sparksql/String.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ std::shared_ptr<exec::VectorFunction> makeLength(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs);

/// Returns the signatures supported by concat_ws:
///   varchar, varchar,... -> varchar
///   varchar, array(varchar) -> varchar
/// In both signatures the first argument is declared as a constant varchar.
std::vector<std::shared_ptr<exec::FunctionSignature>> concatWsSignatures();

/// Creates the vector function that evaluates concat_ws. The first entry of
/// 'inputArgs' must carry a constant value, which is used as the connector;
/// a user error is raised if 'inputArgs' is empty or that value is missing.
std::shared_ptr<exec::VectorFunction> makeConcatWs(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& config);

/// Expands each char of the digest data to two chars,
/// representing the hex value of each digest char, in order.
/// Note: digestSize must be one-half of outputSize.
Expand Down
Loading

0 comments on commit c5eec03

Please sign in to comment.