Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose threshold argument of Jaro-Winkler similarity #12079

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/core_functions/function_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ static const StaticFunctionDefinition internal_functions[] = {
DUCKDB_SCALAR_FUNCTION_SET(ISODayOfWeekFun),
DUCKDB_SCALAR_FUNCTION_SET(ISOYearFun),
DUCKDB_SCALAR_FUNCTION(JaccardFun),
DUCKDB_SCALAR_FUNCTION(JaroSimilarityFun),
DUCKDB_SCALAR_FUNCTION(JaroWinklerSimilarityFun),
DUCKDB_SCALAR_FUNCTION_SET(JaroSimilarityFun),
DUCKDB_SCALAR_FUNCTION_SET(JaroWinklerSimilarityFun),
DUCKDB_SCALAR_FUNCTION_SET(JulianDayFun),
DUCKDB_AGGREGATE_FUNCTION(KahanSumFun),
DUCKDB_AGGREGATE_FUNCTION(KurtosisFun),
Expand Down
77 changes: 59 additions & 18 deletions src/core_functions/scalar/string/jaro_winkler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,25 @@

namespace duckdb {

static inline double JaroScalarFunction(const string_t &s1, const string_t &s2) {
static inline double JaroScalarFunction(const string_t &s1, const string_t &s2, const double_t &score_cutoff = 0.0) {
auto s1_begin = s1.GetData();
auto s2_begin = s2.GetData();
return duckdb_jaro_winkler::jaro_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin, s2_begin + s2.GetSize());
return duckdb_jaro_winkler::jaro_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin, s2_begin + s2.GetSize(),
score_cutoff);
}

static inline double JaroWinklerScalarFunction(const string_t &s1, const string_t &s2) {
static inline double JaroWinklerScalarFunction(const string_t &s1, const string_t &s2,
const double_t &score_cutoff = 0.0) {
auto s1_begin = s1.GetData();
auto s2_begin = s2.GetData();
return duckdb_jaro_winkler::jaro_winkler_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin,
s2_begin + s2.GetSize());
s2_begin + s2.GetSize(), score_cutoff);
}

template <class CACHED_SIMILARITY>
static void CachedFunction(Vector &constant, Vector &other, Vector &result, idx_t count) {
static void CachedFunction(Vector &constant, Vector &other, Vector &result, DataChunk &args) {
auto val = constant.GetValue(0);
idx_t count = args.size();
if (val.IsNull()) {
auto &result_validity = FlatVector::Validity(result);
result_validity.SetAllInvalid(count);
Expand All @@ -28,26 +31,46 @@ static void CachedFunction(Vector &constant, Vector &other, Vector &result, idx_

auto str_val = StringValue::Get(val);
auto cached = CACHED_SIMILARITY(str_val);
UnaryExecutor::Execute<string_t, double>(other, result, count, [&](const string_t &other_str) {
auto other_str_begin = other_str.GetData();
return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize());
});

D_ASSERT(args.ColumnCount() == 2 || args.ColumnCount() == 3);
if (args.ColumnCount() == 2) {
UnaryExecutor::Execute<string_t, double>(other, result, count, [&](const string_t &other_str) {
auto other_str_begin = other_str.GetData();
return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize());
});
} else {
auto score_cutoff = args.data[2];
BinaryExecutor::Execute<string_t, double_t, double>(
other, score_cutoff, result, count, [&](const string_t &other_str, const double_t score_cutoff) {
auto other_str_begin = other_str.GetData();
return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize(), score_cutoff);
});
}
}

template <class CACHED_SIMILARITY, class SIMILARITY_FUNCTION = std::function<double(string_t, string_t)>>
template <class CACHED_SIMILARITY, class SIMILARITY_FUNCTION>
static void TemplatedJaroWinklerFunction(DataChunk &args, Vector &result, SIMILARITY_FUNCTION fun) {
bool arg0_constant = args.data[0].GetVectorType() == VectorType::CONSTANT_VECTOR;
bool arg1_constant = args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR;
if (!(arg0_constant ^ arg1_constant)) {
// We can't optimize by caching one of the two strings
BinaryExecutor::Execute<string_t, string_t, double>(args.data[0], args.data[1], result, args.size(), fun);
return;
D_ASSERT(args.ColumnCount() == 2 || args.ColumnCount() == 3);
if (args.ColumnCount() == 2) {
BinaryExecutor::Execute<string_t, string_t, double>(
args.data[0], args.data[1], result, args.size(),
[&](const string_t &s1, const string_t &s2) { return fun(s1, s2, 0.0); });
return;
} else {
TernaryExecutor::Execute<string_t, string_t, double_t, double>(args.data[0], args.data[1], args.data[2],
result, args.size(), fun);
return;
}
}

if (arg0_constant) {
CachedFunction<CACHED_SIMILARITY>(args.data[0], args.data[1], result, args.size());
CachedFunction<CACHED_SIMILARITY>(args.data[0], args.data[1], result, args);
} else {
CachedFunction<CACHED_SIMILARITY>(args.data[1], args.data[0], result, args.size());
CachedFunction<CACHED_SIMILARITY>(args.data[1], args.data[0], result, args);
}
}

Expand All @@ -60,12 +83,30 @@ static void JaroWinklerFunction(DataChunk &args, ExpressionState &state, Vector
JaroWinklerScalarFunction);
}

ScalarFunction JaroSimilarityFun::GetFunction() {
return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroFunction);
ScalarFunctionSet JaroSimilarityFun::GetFunctions() {
ScalarFunctionSet jaro;

const auto list_type = LogicalType::LIST(LogicalType::VARCHAR);
auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroFunction);
jaro.AddFunction(fun);

fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::DOUBLE}, LogicalType::DOUBLE,
JaroFunction);
jaro.AddFunction(fun);
return jaro;
}

ScalarFunction JaroWinklerSimilarityFun::GetFunction() {
return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroWinklerFunction);
ScalarFunctionSet JaroWinklerSimilarityFun::GetFunctions() {
ScalarFunctionSet jaroWinkler;

const auto list_type = LogicalType::LIST(LogicalType::VARCHAR);
auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroWinklerFunction);
jaroWinkler.AddFunction(fun);

fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::DOUBLE}, LogicalType::DOUBLE,
JaroWinklerFunction);
jaroWinkler.AddFunction(fun);
return jaroWinkler;
}

} // namespace duckdb
12 changes: 6 additions & 6 deletions src/include/duckdb/core_functions/scalar/string_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,20 +176,20 @@ struct JaccardFun {

//! Catalog metadata for the jaro_similarity scalar function.
//! score_cutoff is the optional third argument of the function.
//! NOTE(review): presumably similarities below score_cutoff are returned as 0 - confirm and consider
//! mentioning it in Description.
struct JaroSimilarityFun {
	static constexpr const char *Name = "jaro_similarity";
	static constexpr const char *Parameters = "str1,str2,score_cutoff";
	static constexpr const char *Description = "The Jaro similarity between two strings. Different case is considered different. Returns a number between 0 and 1";
	static constexpr const char *Example = "jaro_similarity('duck','duckdb', 0.5)";

	//! Returns the overload set (with and without score_cutoff)
	static ScalarFunctionSet GetFunctions();
};

//! Catalog metadata for the jaro_winkler_similarity scalar function.
//! score_cutoff is the optional third argument of the function.
//! NOTE(review): presumably similarities below score_cutoff are returned as 0 - confirm and consider
//! mentioning it in Description.
struct JaroWinklerSimilarityFun {
	static constexpr const char *Name = "jaro_winkler_similarity";
	static constexpr const char *Parameters = "str1,str2,score_cutoff";
	static constexpr const char *Description = "The Jaro-Winkler similarity between two strings. Different case is considered different. Returns a number between 0 and 1";
	static constexpr const char *Example = "jaro_winkler_similarity('duck','duckdb', 0.5)";

	//! Returns the overload set (with and without score_cutoff)
	static ScalarFunctionSet GetFunctions();
};

struct LeftFun {
Expand Down
21 changes: 21 additions & 0 deletions test/sql/function/string/test_jaro_winkler.test
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,27 @@ select jaro_winkler_similarity('PENNSYLVANIA', 'PENNCISYLVNIA')
----
0.8980186480186481

# test score cutoff
query T
select jaro_winkler_similarity('CRATE', 'TRACE', 0.7)
----
0.733333

query T
select jaro_winkler_similarity('CRATE', 'TRACE', 0.75)
----
0.0

query T
select jaro_winkler_similarity('000000000000000000000000000000000000000000000000000000000000000', '00000000000000000000000000000000000000000000000000000000000000000', 0.9)
----
0.9938

query T
select jaro_winkler_similarity('000000000000000000000000000000000000000000000000000000000000000', '00000000000000000000000000000000000000000000000000000000000000000', 0.995)
----
0.0

# test with table just in case
statement ok
create table test as select '0000' || range::varchar s from range(10000);
Expand Down