diff --git a/splink/case_statements.py b/splink/case_statements.py
index a85080c12e..fe7342c56e 100644
--- a/splink/case_statements.py
+++ b/splink/case_statements.py
@@ -17,14 +17,47 @@ def _check_jaro_registered(spark):
         return True
 
     warnings.warn(
-        "Custom string comparison functions such as jaro_winkler_sim are available in"
-        " Spark Or you did not pass 'spark' (the SparkSession) into 'Model' You can"
-        " import these functions using the scala-udf-similarity-0.0.7.jar provided with"
-        " Splink"
+        "\n\nCustom string comparison functions such as jaro_winkler_sim are not"
+        " available in Spark, or you did not pass 'spark' (the SparkSession) into"
+        " 'Model'.\n"
+        "You can import these functions using the scala-udf-similarity-0.0.9.jar"
+        " provided with Splink.\n" + _get_spark_jars_string()
     )
     return False
 
 
+def _get_spark_jars_string():
+    """
+    Outputs the exact string needed in the SparkSession config variable
+    `spark.jars` in order to use the custom functions in
+    scala-udf-similarity-0.0.9.jar.
+    """
+
+    import os
+
+    import splink
+
+    path = os.path.join(
+        os.path.dirname(splink.__file__), "jars", "scala-udf-similarity-0.0.9.jar"
+    )
+
+    message = (
+        "You will need to add it by correctly configuring your spark config\n"
+        "For example in Spark 2.4.5:\n"
+        "\n"
+        "from pyspark.sql import SparkSession, types\n"
+        "from pyspark import SparkConf, SparkContext\n"
+        "conf = SparkConf()\n"
+        f"conf.set('spark.driver.extraClassPath', '{path}')  # Not needed in Spark 3\n"
+        f"conf.set('spark.jars', '{path}')\n"
+        "sc = SparkContext.getOrCreate(conf=conf)\n"
+        "spark = SparkSession(sc)\n"
+        "spark.udf.registerJavaFunction('jaro_winkler_sim',"
+        " 'uk.gov.moj.dash.linkage.JaroWinklerSimilarity', types.DoubleType())\n"
+        "\n"
+        "Alternatively, for Jaro-Winkler, you can register a less efficient"
+        " Python implementation using\n"
+        "\n"
+        "from splink.jar_fallback import jw_sim_py\n"
+        "spark.udf.register('jaro_winkler_sim', jw_sim_py)\n"
+    )
+
+    return message
+
+
 def _find_last_end_position(case_statement):
     # Since we're only interested in the position, case shouldn't matter.
     stmt = case_statement.lower()
     case_statement = case_statement.lower()
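For orientation, here is a minimal sketch of the pure-Python fallback route this warning now points to, assuming a plain local SparkSession (none of this setup is part of the diff itself). The expected score comes from the Wikipedia examples in the new tests further down.

```python
from pyspark.sql import SparkSession, types

from splink.jar_fallback import jw_sim_py

spark = SparkSession.builder.getOrCreate()

# Register with an explicit DoubleType so SQL comparisons behave numerically
spark.udf.register("jaro_winkler_sim", jw_sim_py, types.DoubleType())

spark.sql("select jaro_winkler_sim('elephant', 'hippo') as sim").show()
# Expected: 0.44, matching test_fallback_jw_wikipedia_examples below
```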
diff --git a/splink/jar_fallback.py b/splink/jar_fallback.py
new file mode 100644
index 0000000000..0957da193a
--- /dev/null
+++ b/splink/jar_fallback.py
@@ -0,0 +1,110 @@
+import math
+
+
+def jc_sim_py(str1, str2):
+    """
+    Jaccard similarity calculated exactly as in stringutils' similarity_jaccard
+    in Apache Commons
+    """
+
+    if not str1 or not str2:
+        return 0.0
+
+    k = 2  # default k in stringutils is 2, so keep it for compatibility
+
+    # break strings into sets of rolling k-char shingles
+    a = set([str1[i : i + k] for i in range(len(str1) - k + 1)])
+    b = set([str2[i : i + k] for i in range(len(str2) - k + 1)])
+
+    # guard against strings shorter than k, which yield empty shingle sets
+    if not a and not b:
+        return 0.0
+
+    # calculate intersection of the two sets
+    c = a.intersection(b)
+
+    # return Jaccard similarity
+    return float(len(c)) / (len(a) + len(b) - len(c))
+
+
+def jw_sim_py(
+    first,
+    second,  # modification from the original: no other input parameters
+):
+    """
+    Jaro-Winkler similarity calculated exactly as in stringutils'
+    similarity_jaro_winkler in Apache Commons, using a modified version of the
+    algorithm implemented by 'Jean-Bernard Ratte - jean.bernard.ratte@unary.ca'
+    found at https://github.com/nap/jaro-winkler-distance
+
+    Used under the Apache License, Version 2.0, with modifications explicitly marked.
+    """
+
+    if not first or not second:
+        # modification from the original: return 0.0 for nulls instead of raising
+        return 0.0
+
+    # modification from the original: constant scaling, so results are
+    # comparable to the stringutils implementation
+    scaling = 0.1
+
+    def _get_diff_index(first, second):
+        if first == second:
+            return -1
+
+        if not first or not second:
+            return 0
+
+        max_len = min(len(first), len(second))
+        for i in range(0, max_len):
+            if not first[i] == second[i]:
+                return i
+
+        return max_len
+
+    def _score(first, second):
+        shorter, longer = first.lower(), second.lower()
+
+        if len(first) > len(second):
+            longer, shorter = shorter, longer
+
+        m1 = _get_matching_characters(shorter, longer)
+        m2 = _get_matching_characters(longer, shorter)
+
+        if len(m1) == 0 or len(m2) == 0:
+            return 0.0
+
+        return (
+            float(len(m1)) / len(shorter)
+            + float(len(m2)) / len(longer)
+            + float(len(m1) - _transpositions(m1, m2)) / len(m1)
+        ) / 3.0
+
+    def _get_prefix(first, second):
+        if not first or not second:
+            return ""
+
+        index = _get_diff_index(first, second)
+        if index == -1:
+            return first
+
+        elif index == 0:
+            return ""
+
+        else:
+            return first[0:index]
+
+    def _transpositions(first, second):
+        return math.floor(
+            len([(f, s) for f, s in zip(first, second) if not f == s]) / 2.0
+        )
+
+    def _get_matching_characters(first, second):
+        common = []
+        limit = math.floor(min(len(first), len(second)) / 2)
+
+        for i, l in enumerate(first):
+            left, right = int(max(0, i - limit)), int(min(i + limit + 1, len(second)))
+            if l in second[left:right]:
+                common.append(l)
+                second = (
+                    second[0 : second.index(l)] + "*" + second[second.index(l) + 1 :]
+                )
+
+        return "".join(common)
+
+    jaro = _score(first, second)
+    cl = min(len(_get_prefix(first, second)), 4)
+
+    return round((jaro + (scaling * cl * (1.0 - jaro))) * 100.0) / 100.0
diff --git a/jars/scala-udf-similarity-0.0.9.jar b/splink/jars/scala-udf-similarity-0.0.9.jar
similarity index 100%
rename from jars/scala-udf-similarity-0.0.9.jar
rename to splink/jars/scala-udf-similarity-0.0.9.jar
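A quick usage sketch of the two fallbacks added above. The Jaro-Winkler scores are the ones asserted in the new tests; the Jaccard value for the made-up pair 'splink'/'spark' follows from 2-character shingles.

```python
from splink.jar_fallback import jc_sim_py, jw_sim_py

# Jaro-Winkler: values match test_fallback_jw_wikipedia_examples
print(jw_sim_py("elephant", "hippo"))            # 0.44
print(jw_sim_py("ABC Corporation", "ABC Corp"))  # 0.91

# Jaccard on 2-char shingles: {'sp','pl','li','in','nk'} vs {'sp','pa','ar','rk'}
# share only 'sp', so the score is 1 / (5 + 4 - 1) = 0.125
print(jc_sim_py("splink", "spark"))              # 0.125
```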
diff --git a/tests/conftest.py b/tests/conftest.py
index 078f6b111d..7d59a75d08 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,8 +17,10 @@ def spark():
     conf = SparkConf()
     conf.set("spark.jars.ivy", "/home/jovyan/.ivy2/")
 
-    conf.set("spark.driver.extraClassPath", "jars/scala-udf-similarity-0.0.9.jar")
-    conf.set("spark.jars", "jars/scala-udf-similarity-0.0.9.jar")
+    conf.set(
+        "spark.driver.extraClassPath", "splink/jars/scala-udf-similarity-0.0.9.jar"
+    )
+    conf.set("spark.jars", "splink/jars/scala-udf-similarity-0.0.9.jar")
 
     conf.set("spark.driver.memory", "4g")
     conf.set("spark.sql.shuffle.partitions", "12")
diff --git a/tests/test_jar_fallback.py b/tests/test_jar_fallback.py
new file mode 100644
index 0000000000..048ef94ed6
--- /dev/null
+++ b/tests/test_jar_fallback.py
@@ -0,0 +1,26 @@
+import pytest
+from splink.jar_fallback import jw_sim_py, jc_sim_py
+
+
+def test_fallback_jw_nodata():
+    assert jw_sim_py(None, None) == 0.0
+    assert jw_sim_py("something", None) == 0.0
+    assert jw_sim_py(None, "Something") == 0.0
+
+
+def test_fallback_jc_nodata():
+    assert jc_sim_py(None, None) == 0.0
+    assert jc_sim_py("something", None) == 0.0
+    assert jc_sim_py(None, "Something") == 0.0
+
+
+def test_fallback_jw_wikipedia_examples():
+    assert jw_sim_py("fly", "ant") == 0.0
+    assert jw_sim_py("elephant", "hippo") == 0.44
+    assert jw_sim_py("ABC Corporation", "ABC Corp") == 0.91
+    assert jw_sim_py("PENNSYLVANIA", "PENNCISYLVNIA") == 0.9
+    assert jw_sim_py("D N H Enterprises Inc", "D & H Enterprises, Inc.") == 0.93
+    assert (
+        jw_sim_py("My Gym Children's Fitness Center", "My Gym. Childrens Fitness")
+        == 0.94
+    )
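Beyond the unit tests, one could sanity-check the fallback against the jar-backed UDF. This is a sketch only, assuming the renamed jar resolves from the repo root (as in conftest.py) and the Java class name quoted in the warning message:

```python
from pyspark.sql import SparkSession, types

from splink.jar_fallback import jw_sim_py

spark = (
    SparkSession.builder
    .config("spark.jars", "splink/jars/scala-udf-similarity-0.0.9.jar")
    .getOrCreate()
)
spark.udf.registerJavaFunction(
    "jaro_winkler_sim",
    "uk.gov.moj.dash.linkage.JaroWinklerSimilarity",
    types.DoubleType(),
)

jar_score = spark.sql(
    "select jaro_winkler_sim('PENNSYLVANIA', 'PENNCISYLVNIA') as sim"
).first().sim
print(jar_score, jw_sim_py("PENNSYLVANIA", "PENNCISYLVNIA"))
# Both should be roughly 0.9; the fallback rounds to two decimal places,
# so small discrepancies are possible.
```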