Skip to content

Commit

Permalink
Add ngrams Presto function. (#8209)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #8209

Reviewed By: laithsakka, mbasmanova

Differential Revision: D52410110

fbshipit-source-id: bd91abc84dbf8a7ae965b9ec19278653e53801e3
  • Loading branch information
Amit Dutta authored and facebook-github-bot committed Jan 15, 2024
1 parent 3cb4ec9 commit 5dc23c6
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 0 deletions.
13 changes: 13 additions & 0 deletions velox/docs/functions/presto/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,19 @@ Array Functions

Flattens an ``array(array(T))`` to an ``array(T)`` by concatenating the contained arrays.

.. function:: ngrams(array(T), n) -> array(array(T))

Returns `n-grams <https://en.wikipedia.org/wiki/N-gram>`_ for the array.
Throws if ``n`` is zero or negative. If ``n`` is greater than or equal to the
size of the input array, the result contains the input array as its only item. ::

SELECT ngrams(ARRAY['foo', 'bar', 'baz', 'foo'], 2); -- [['foo', 'bar'], ['bar', 'baz'], ['baz', 'foo']]
SELECT ngrams(ARRAY['foo', 'bar', 'baz', 'foo'], 3); -- [['foo', 'bar', 'baz'], ['bar', 'baz', 'foo']]
SELECT ngrams(ARRAY['foo', 'bar', 'baz', 'foo'], 4); -- [['foo', 'bar', 'baz', 'foo']]
SELECT ngrams(ARRAY['foo', 'bar', 'baz', 'foo'], 5); -- [['foo', 'bar', 'baz', 'foo']]
SELECT ngrams(ARRAY[1, 2, 3, 4], 2); -- [[1, 2], [2, 3], [3, 4]]
SELECT ngrams(ARRAY['foo', NULL, 'bar'], 2); -- [['foo', NULL], [NULL, 'bar']]

.. function:: reduce(array(T), initialState S, inputFunction(S,T,S), outputFunction(S,R)) -> R

Returns a single value reduced from ``array``. ``inputFunction`` will
Expand Down
88 changes: 88 additions & 0 deletions velox/functions/prestosql/ArrayFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,94 @@ struct ArrayRemoveNullFunctionString {
}
};

/// Implements ngrams(array(T), n) -> array(array(T)).
///
/// Emits every contiguous run of `n` elements of the input array, in input
/// order, keeping null elements as nulls. If `n` is larger than the input
/// array, the result holds a single item that is a copy of the whole input.
/// Raises a user error if `n` is not greater than zero.
template <typename T>
struct ArrayNGramsFunction {
  VELOX_DEFINE_FUNCTION_TYPES(T)

  // Fast path for primitives: each element is written by plain assignment.
  template <typename Out, typename In>
  void call(Out& out, const In& input, int64_t n) {
    appendNGrams(out, input, n, [](auto& target, const auto& value) {
      target = value;
    });
  }

  // Generic implementation: each element is written with copy_from.
  void call(
      out_type<Array<Array<Generic<T1>>>>& out,
      const arg_type<Array<Generic<T1>>>& input,
      int64_t n) {
    appendNGrams(out, input, n, [](auto& target, const auto& value) {
      target.copy_from(value);
    });
  }

 private:
  // Shared windowing logic for both overloads. 'writeElement(target, value)'
  // stores one non-null input element into a freshly added output slot; it is
  // the only step that differs between the primitive and the generic path.
  template <typename Out, typename In, typename WriteElement>
  static void appendNGrams(
      Out& out,
      const In& input,
      int64_t n,
      WriteElement writeElement) {
    VELOX_USER_CHECK_GT(n, 0, "N must be greater than zero.");

    // Input is too short to form a full window: return it as the only n-gram.
    if (n > input.size()) {
      auto& wholeInput = out.add_item();
      wholeInput.copy_from(input);
      return;
    }

    // Non-negative here because n <= input.size().
    const auto lastStart = input.size() - n;
    for (auto start = 0; start <= lastStart; ++start) {
      auto& ngram = out.add_item();
      for (auto offset = 0; offset < n; ++offset) {
        if (input[start + offset].has_value()) {
          auto& slot = ngram.add_item();
          writeElement(slot, input[start + offset].value());
        } else {
          ngram.add_null();
        }
      }
    }
  }
};

/// Varchar flavor of ngrams(array(varchar), n) -> array(array(varchar)).
///
/// Produces the same windows as the non-string implementation, but output
/// strings are set with setNoCopy so the string bytes are not duplicated.
template <typename T>
struct ArrayNGramsFunctionFunctionString {
  VELOX_DEFINE_FUNCTION_TYPES(T);

  // NOTE(review): presumably tells the function framework that output strings
  // may reference the string buffers of argument 0 -- confirm against the
  // simple-function documentation.
  static constexpr int32_t reuse_strings_from_arg = 0;

  // Emits all windows of 'n' consecutive strings without copying string data.
  // Raises a user error when 'n' is not greater than zero.
  void call(
      out_type<Array<Array<Varchar>>>& out,
      const arg_type<Array<Varchar>>& input,
      int64_t n) {
    VELOX_USER_CHECK_GT(n, 0, "N must be greater than zero.");

    if (n <= input.size()) {
      const auto lastStart = input.size() - n;
      for (auto start = 0; start <= lastStart; ++start) {
        auto& ngram = out.add_item();
        for (auto offset = 0; offset < n; ++offset) {
          const auto index = start + offset;
          if (!input[index].has_value()) {
            ngram.add_null();
            continue;
          }
          auto& slot = ngram.add_item();
          slot.setNoCopy(input[index].value());
        }
      }
      return;
    }

    // Fewer than 'n' elements: the whole input becomes the single n-gram.
    auto& wholeInput = out.add_item();
    wholeInput.copy_from(input);
  }
};

/// This class implements the array flatten function.
///
/// DEFINITION:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ inline void registerArrayRemoveFunctions(const std::string& prefix) {
{prefix + "array_remove"});
}

template <typename T>
inline void registerArrayNGramsFunctions(const std::string& prefix) {
registerFunction<ArrayNGramsFunction, Array<Array<T>>, Array<T>, int64_t>(
{prefix + "ngrams"});
}

void registerInternalArrayFunctions() {
VELOX_REGISTER_VECTOR_FUNCTION(
udf_$internal$canonicalize, "$internal$canonicalize");
Expand Down Expand Up @@ -238,6 +244,24 @@ void registerArrayFunctions(const std::string& prefix) {
Array<Varchar>,
Array<Varchar>>({prefix + "remove_nulls"});

registerArrayNGramsFunctions<int8_t>(prefix);
registerArrayNGramsFunctions<int16_t>(prefix);
registerArrayNGramsFunctions<int32_t>(prefix);
registerArrayNGramsFunctions<int64_t>(prefix);
registerArrayNGramsFunctions<int128_t>(prefix);
registerArrayNGramsFunctions<float>(prefix);
registerArrayNGramsFunctions<double>(prefix);
registerArrayNGramsFunctions<bool>(prefix);
registerArrayNGramsFunctions<Timestamp>(prefix);
registerArrayNGramsFunctions<Date>(prefix);
registerArrayNGramsFunctions<Varbinary>(prefix);
registerArrayNGramsFunctions<Generic<T1>>(prefix);
registerFunction<
ArrayNGramsFunctionFunctionString,
Array<Array<Varchar>>,
Array<Varchar>,
int64_t>({prefix + "ngrams"});

registerArrayUnionFunctions<int8_t>(prefix);
registerArrayUnionFunctions<int16_t>(prefix);
registerArrayUnionFunctions<int32_t>(prefix);
Expand Down
111 changes: 111 additions & 0 deletions velox/functions/prestosql/tests/ArrayNGramsTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <limits>
#include <optional>
#include <string>
#include <vector>

#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"

using namespace facebook::velox::test;

namespace facebook::velox::functions {
namespace {
class ArrayNGramsTest : public test::FunctionBaseTest {
 protected:
  /// Evaluates `ngrams(c0, n)` on a single-row batch whose only column holds
  /// `inputArray`, and asserts that the result equals `expectedOutput`.
  template <typename T>
  void testNgram(
      const std::vector<std::optional<T>>& inputArray,
      int64_t n,
      const std::vector<std::optional<std::vector<std::optional<T>>>>&
          expectedOutput) {
    // Wrap the array under test into a one-row array column.
    std::vector<std::optional<std::vector<std::optional<T>>>> rows;
    rows.emplace_back(inputArray);
    auto arrayColumn = makeNullableArrayVector<T>(rows);

    const auto expression = fmt::format("ngrams(c0, {})", n);
    auto actual = evaluate(expression, makeRowVector({arrayColumn}));

    auto wanted = makeNullableNestedArrayVector<T>({expectedOutput});
    assertEqualVectors(wanted, actual);
  }
};

// Covers window sizes below, equal to, and above the input length (including
// values at and beyond the int32 maximum), as well as empty input.
TEST_F(ArrayNGramsTest, integers) {
  const std::vector<std::optional<int64_t>> oneToFour{1, 2, 3, 4};

  testNgram<int64_t>(oneToFour, 1, {{{1}}, {{2}}, {{3}}, {{4}}});
  testNgram<int64_t>(oneToFour, 2, {{{1, 2}}, {{2, 3}}, {{3, 4}}});
  testNgram<int64_t>(oneToFour, 3, {{{1, 2, 3}}, {{2, 3, 4}}});
  testNgram<int64_t>(oneToFour, 4, {{{1, 2, 3, 4}}});
  testNgram<int64_t>(oneToFour, 5, {{{1, 2, 3, 4}}});

  // Widen to 64 bits up front so that adding to the int32 maximum cannot
  // overflow and no cast is needed at the call site.
  constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();
  testNgram<int64_t>(oneToFour, kInt32Max, {{{1, 2, 3, 4}}});
  testNgram<int64_t>(oneToFour, kInt32Max + 1000, {{{1, 2, 3, 4}}});

  // Empty input yields a single empty n-gram for any valid n.
  testNgram<int64_t>({}, 1, {{{}}});
  testNgram<int64_t>({}, 10, {{{}}});
}

// Zero and negative n must be rejected for both non-empty and empty arrays.
TEST_F(ArrayNGramsTest, invalidN) {
  auto fourInts = makeArrayVector<int64_t>({{1, 2, 3, 4}});
  auto emptyArray = makeArrayVector<int64_t>({{}});

  VELOX_ASSERT_THROW(
      evaluate("ngrams(c0, 0)", makeRowVector({fourInts})),
      "(0 vs. 0) N must be greater than zero");
  VELOX_ASSERT_THROW(
      evaluate("ngrams(c0, -5)", makeRowVector({fourInts})),
      "(-5 vs. 0) N must be greater than zero");

  VELOX_ASSERT_THROW(
      evaluate("ngrams(c0, 0)", makeRowVector({emptyArray})),
      "(0 vs. 0) N must be greater than zero");
}

// Exercises the varchar implementation with a mix of short strings and one
// longer string, for window sizes 1 through 5 on a four-element input.
TEST_F(ArrayNGramsTest, strings) {
  const std::vector<std::optional<std::string>> words{
      "foo", "bar", "baz", "this is a very long sentence"};

  testNgram<std::string>(
      words,
      1,
      {{{"foo"}}, {{"bar"}}, {{"baz"}}, {{"this is a very long sentence"}}});
  testNgram<std::string>(
      words,
      2,
      {{{"foo", "bar"}},
       {{"bar", "baz"}},
       {{"baz", "this is a very long sentence"}}});
  testNgram<std::string>(
      words,
      3,
      {{{"foo", "bar", "baz"}},
       {{"bar", "baz", "this is a very long sentence"}}});

  // n equal to or greater than the input size returns the input as-is.
  testNgram<std::string>(
      words, 4, {{{"foo", "bar", "baz", "this is a very long sentence"}}});
  testNgram<std::string>(
      words, 5, {{{"foo", "bar", "baz", "this is a very long sentence"}}});
}

// Null elements must be carried into the n-grams at their original positions.
TEST_F(ArrayNGramsTest, nulls) {
  const std::vector<std::optional<std::string>> nullInMiddle{
      "foo", std::nullopt, "bar"};
  testNgram<std::string>(
      nullInMiddle, 2, {{{"foo", std::nullopt}}, {{std::nullopt, "bar"}}});

  const std::vector<std::optional<std::string>> allNulls(3, std::nullopt);
  testNgram<std::string>(
      allNulls,
      2,
      {{{std::nullopt, std::nullopt}}, {{std::nullopt, std::nullopt}}});
}
} // namespace

} // namespace facebook::velox::functions
1 change: 1 addition & 0 deletions velox/functions/prestosql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ add_executable(
ArrayIntersectTest.cpp
ArrayMaxTest.cpp
ArrayMinTest.cpp
ArrayNGramsTest.cpp
ArrayNoneMatchTest.cpp
ArrayNormalizeTest.cpp
ArrayPositionTest.cpp
Expand Down

0 comments on commit 5dc23c6

Please sign in to comment.