diff --git a/src/core_functions/function_list.cpp b/src/core_functions/function_list.cpp
index e62330b50c3..237b2bddd0c 100644
--- a/src/core_functions/function_list.cpp
+++ b/src/core_functions/function_list.cpp
@@ -198,8 +198,8 @@ static const StaticFunctionDefinition internal_functions[] = {
 	DUCKDB_SCALAR_FUNCTION_SET(ISODayOfWeekFun),
 	DUCKDB_SCALAR_FUNCTION_SET(ISOYearFun),
 	DUCKDB_SCALAR_FUNCTION(JaccardFun),
-	DUCKDB_SCALAR_FUNCTION(JaroSimilarityFun),
-	DUCKDB_SCALAR_FUNCTION(JaroWinklerSimilarityFun),
+	DUCKDB_SCALAR_FUNCTION_SET(JaroSimilarityFun),
+	DUCKDB_SCALAR_FUNCTION_SET(JaroWinklerSimilarityFun),
 	DUCKDB_SCALAR_FUNCTION_SET(JulianDayFun),
 	DUCKDB_AGGREGATE_FUNCTION(KahanSumFun),
 	DUCKDB_AGGREGATE_FUNCTION(KurtosisFun),
diff --git a/src/core_functions/scalar/string/jaro_winkler.cpp b/src/core_functions/scalar/string/jaro_winkler.cpp
index 3c54b411f7b..4cebbb3c91f 100644
--- a/src/core_functions/scalar/string/jaro_winkler.cpp
+++ b/src/core_functions/scalar/string/jaro_winkler.cpp
@@ -4,22 +4,31 @@
 
 namespace duckdb {
 
-static inline double JaroScalarFunction(const string_t &s1, const string_t &s2) {
+//! Jaro similarity of s1 and s2; score_cutoff is forwarded to the library (0.0 disables the cutoff)
+static inline double JaroScalarFunction(const string_t &s1, const string_t &s2, const double_t &score_cutoff = 0.0) {
 	auto s1_begin = s1.GetData();
 	auto s2_begin = s2.GetData();
-	return duckdb_jaro_winkler::jaro_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin, s2_begin + s2.GetSize());
+	return duckdb_jaro_winkler::jaro_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin, s2_begin + s2.GetSize(),
+	                                            score_cutoff);
 }
 
-static inline double JaroWinklerScalarFunction(const string_t &s1, const string_t &s2) {
+//! Jaro-Winkler similarity of s1 and s2; score_cutoff is forwarded to the library (0.0 disables the cutoff).
+//! NOTE: the library takes the prefix weight BEFORE the score cutoff, so the default weight (0.1) is passed
+//! explicitly - otherwise score_cutoff would be consumed as the prefix weight.
+static inline double JaroWinklerScalarFunction(const string_t &s1, const string_t &s2,
+                                               const double_t &score_cutoff = 0.0) {
 	auto s1_begin = s1.GetData();
 	auto s2_begin = s2.GetData();
 	return duckdb_jaro_winkler::jaro_winkler_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin,
-	                                                    s2_begin + s2.GetSize());
+	                                                    s2_begin + s2.GetSize(), 0.1, score_cutoff);
 }
 
+//! Compares one constant string against every row of `other`, constructing CACHED_SIMILARITY only once.
+//! A NULL constant makes the whole result NULL; a third column in `args`, if present, is the per-row score cutoff.
 template <class CACHED_SIMILARITY>
-static void CachedFunction(Vector &constant, Vector &other, Vector &result, idx_t count) {
+static void CachedFunction(Vector &constant, Vector &other, Vector &result, DataChunk &args) {
 	auto val = constant.GetValue(0);
+	idx_t count = args.size();
 	if (val.IsNull()) {
 		auto &result_validity = FlatVector::Validity(result);
 		result_validity.SetAllInvalid(count);
@@ -28,26 +37,49 @@ static void CachedFunction(Vector &constant, Vector &other, Vector &result, idx_
 
 	auto str_val = StringValue::Get(val);
 	auto cached = CACHED_SIMILARITY(str_val);
-	UnaryExecutor::Execute<string_t, double>(other, result, count, [&](const string_t &other_str) {
-		auto other_str_begin = other_str.GetData();
-		return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize());
-	});
+
+	D_ASSERT(args.ColumnCount() == 2 || args.ColumnCount() == 3);
+	if (args.ColumnCount() == 2) {
+		UnaryExecutor::Execute<string_t, double>(other, result, count, [&](const string_t &other_str) {
+			auto other_str_begin = other_str.GetData();
+			return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize());
+		});
+	} else {
+		// bind by reference and use a distinct name so the lambda parameter below does not shadow it
+		auto &cutoff_vector = args.data[2];
+		BinaryExecutor::Execute<string_t, double_t, double>(
+		    other, cutoff_vector, result, count, [&](const string_t &other_str, const double_t score_cutoff) {
+			    auto other_str_begin = other_str.GetData();
+			    return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize(), score_cutoff);
+		    });
+	}
 }
 
-template <class CACHED_SIMILARITY, class SIMILARITY_FUNCTION = std::function<double(string_t, string_t)>>
+//! Dispatches on constness of the two string arguments: exactly one constant string takes the cached path,
+//! otherwise every row is computed with `fun` (using a score cutoff of 0.0 when only two columns are given).
+template <class CACHED_SIMILARITY, class SIMILARITY_FUNCTION>
 static void TemplatedJaroWinklerFunction(DataChunk &args, Vector &result, SIMILARITY_FUNCTION fun) {
 	bool arg0_constant = args.data[0].GetVectorType() == VectorType::CONSTANT_VECTOR;
 	bool arg1_constant = args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR;
 	if (!(arg0_constant ^ arg1_constant)) {
 		// We can't optimize by caching one of the two strings
-		BinaryExecutor::Execute<string_t, string_t, double>(args.data[0], args.data[1], result, args.size(), fun);
-		return;
+		D_ASSERT(args.ColumnCount() == 2 || args.ColumnCount() == 3);
+		if (args.ColumnCount() == 2) {
+			BinaryExecutor::Execute<string_t, string_t, double>(
+			    args.data[0], args.data[1], result, args.size(),
+			    [&](const string_t &s1, const string_t &s2) { return fun(s1, s2, 0.0); });
+			return;
+		} else {
+			TernaryExecutor::Execute<string_t, string_t, double_t, double>(args.data[0], args.data[1], args.data[2],
+			                                                               result, args.size(), fun);
+			return;
+		}
 	}
 
 	if (arg0_constant) {
-		CachedFunction<CACHED_SIMILARITY>(args.data[0], args.data[1], result, args.size());
+		CachedFunction<CACHED_SIMILARITY>(args.data[0], args.data[1], result, args);
 	} else {
-		CachedFunction<CACHED_SIMILARITY>(args.data[1], args.data[0], result, args.size());
+		CachedFunction<CACHED_SIMILARITY>(args.data[1], args.data[0], result, args);
 	}
 }
 
@@ -60,12 +92,32 @@ static void JaroWinklerFunction(DataChunk &args, ExpressionState &state, Vector
 	                                                                                     JaroWinklerScalarFunction);
 }
 
-ScalarFunction JaroSimilarityFun::GetFunction() {
-	return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroFunction);
+ScalarFunctionSet JaroSimilarityFun::GetFunctions() {
+	ScalarFunctionSet jaro;
+
+	// jaro_similarity(str1, str2)
+	auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroFunction);
+	jaro.AddFunction(fun);
+
+	// jaro_similarity(str1, str2, score_cutoff)
+	fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::DOUBLE}, LogicalType::DOUBLE,
+	                     JaroFunction);
+	jaro.AddFunction(fun);
+	return jaro;
 }
 
-ScalarFunction JaroWinklerSimilarityFun::GetFunction() {
-	return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroWinklerFunction);
+ScalarFunctionSet JaroWinklerSimilarityFun::GetFunctions() {
+	ScalarFunctionSet jaroWinkler;
+
+	// jaro_winkler_similarity(str1, str2)
+	auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroWinklerFunction);
+	jaroWinkler.AddFunction(fun);
+
+	// jaro_winkler_similarity(str1, str2, score_cutoff)
+	fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::DOUBLE}, LogicalType::DOUBLE,
+	                     JaroWinklerFunction);
+	jaroWinkler.AddFunction(fun);
+	return jaroWinkler;
 }
 
 } // namespace duckdb
diff --git a/src/include/duckdb/core_functions/scalar/string_functions.hpp b/src/include/duckdb/core_functions/scalar/string_functions.hpp
index f9a60fcad7a..1bdff3da4a0 100644
--- a/src/include/duckdb/core_functions/scalar/string_functions.hpp
+++ b/src/include/duckdb/core_functions/scalar/string_functions.hpp
@@ -176,20 +176,20 @@ struct JaccardFun {
 
 struct JaroSimilarityFun {
 	static constexpr const char *Name = "jaro_similarity";
-	static constexpr const char *Parameters = "str1,str2";
+	static constexpr const char *Parameters = "str1,str2,score_cutoff";
 	static constexpr const char *Description = "The Jaro similarity between two strings. Different case is considered different. Returns a number between 0 and 1";
-	static constexpr const char *Example = "jaro_similarity('duck','duckdb')";
+	static constexpr const char *Example = "jaro_similarity('duck','duckdb', 0.5)";
 
-	static ScalarFunction GetFunction();
+	static ScalarFunctionSet GetFunctions();
 };
 
 struct JaroWinklerSimilarityFun {
 	static constexpr const char *Name = "jaro_winkler_similarity";
-	static constexpr const char *Parameters = "str1,str2";
+	static constexpr const char *Parameters = "str1,str2,score_cutoff";
 	static constexpr const char *Description = "The Jaro-Winkler similarity between two strings. Different case is considered different. Returns a number between 0 and 1";
-	static constexpr const char *Example = "jaro_winkler_similarity('duck','duckdb')";
+	static constexpr const char *Example = "jaro_winkler_similarity('duck','duckdb', 0.5)";
 
-	static ScalarFunction GetFunction();
+	static ScalarFunctionSet GetFunctions();
 };
 
 struct LeftFun {
diff --git a/test/sql/function/string/test_jaro_winkler.test b/test/sql/function/string/test_jaro_winkler.test
index 3b170d92b43..f6e387953f2 100644
--- a/test/sql/function/string/test_jaro_winkler.test
+++ b/test/sql/function/string/test_jaro_winkler.test
@@ -171,6 +171,27 @@ select jaro_winkler_similarity('PENNSYLVANIA', 'PENNCISYLVNIA')
 ----
 0.8980186480186481
 
+# test score cutoff
+query T
+select jaro_winkler_similarity('CRATE', 'TRACE', 0.7)
+----
+0.733333
+
+query T
+select jaro_winkler_similarity('CRATE', 'TRACE', 0.75)
+----
+0.0
+
+query T
+select jaro_winkler_similarity('000000000000000000000000000000000000000000000000000000000000000', '00000000000000000000000000000000000000000000000000000000000000000', 0.9)
+----
+0.9938
+
+query T
+select jaro_winkler_similarity('000000000000000000000000000000000000000000000000000000000000000', '00000000000000000000000000000000000000000000000000000000000000000', 0.995)
+----
+0.0
+
 # test with table just in case
 statement ok
 create table test as select '0000' || range::varchar s from range(10000);