Merge 6d2a6cd into e03bb4a
mamonu committed Nov 3, 2021
2 parents e03bb4a + 6d2a6cd commit 1495e6b
Showing 5 changed files with 177 additions and 6 deletions.
41 changes: 37 additions & 4 deletions splink/case_statements.py
@@ -17,14 +17,47 @@ def _check_jaro_registered(spark):
return True

warnings.warn(
"Custom string comparison functions such as jaro_winkler_sim are available in"
" Spark Or you did not pass 'spark' (the SparkSession) into 'Model' You can"
" import these functions using the scala-udf-similarity-0.0.7.jar provided with"
" Splink"
"\n\nCustom string comparison functions such as jaro_winkler_sim are not available in Spark\n"
"Or you did not pass 'spark' (the SparkSession) into 'Model' \n"
"You can import these functions using the scala-udf-similarity-0.0.9.jar provided with"
" Splink.\n" + _get_spark_jars_string()
)
return False


def _get_spark_jars_string():
"""
Outputs the exact string needed in the SparkSession config variable `spark.jars`
in order to use the custom functions in the scala-udf-similarity-0.0.9.jar
"""

import splink

path = splink.__file__[0:-11] + "jars/scala-udf-similarity-0.0.9.jar"

message = (
"You will need to add it by correctly configuring your spark config\n"
"For example in Spark 2.4.5\n"
"\n"
"from pyspark.sql import SparkSession, types\n"
"from pyspark import SparkConf, SparkContext\n"
"conf = SparkConf()\n"
f"conf.set('spark.driver.extraClassPath', '{path}')  # Not needed in Spark 3\n"
f"conf.set('spark.jars', '{path}')\n"
"sc = SparkContext.getOrCreate(conf=conf)\n"
"spark = SparkSession(sc)\n"
"spark.udf.registerJavaFunction('jaro_winkler_sim', 'uk.gov.moj.dash.linkage.JaroWinklerSimilarity', types.DoubleType())\n"
"\n"
"Alternatively, for Jaro-Winkler, you can register a less efficient"
" Python implementation using\n"
"\n"
"from splink.jar_fallback import jw_sim_py\n"
"spark.udf.register('jaro_winkler_sim', jw_sim_py)\n"
)

return message


def _find_last_end_position(case_statement):
# Since we're only interested in the position, case shouldn't matter.
case_statement = case_statement.lower()
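Taken together with the new splink/jar_fallback.py module below, the check above can be wired to a pure-Python fallback roughly as follows. This is a minimal illustrative sketch, not code from this commit, and it assumes a SparkSession named spark already exists:

from pyspark.sql import types
from splink.case_statements import _check_jaro_registered
from splink.jar_fallback import jw_sim_py

# If the scala-udf-similarity jar is not registered, fall back to the slower
# pure-Python Jaro-Winkler implementation under the same UDF name.
if not _check_jaro_registered(spark):
    spark.udf.register("jaro_winkler_sim", jw_sim_py, types.DoubleType())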
110 changes: 110 additions & 0 deletions splink/jar_fallback.py
@@ -0,0 +1,110 @@
import math


def jc_sim_py(str1, str2):
"""
Jaccard similarity, calculated as in the Apache Commons stringutils Jaccard similarity
"""

if not str1 or not str2:
return 0.0

k = 2  # default k in stringutils is 2, so keep it for compatibility

# break strings into sets of rolling k-char syllables
a = set([str1[i : i + k] for i in range(len(str1) - k + 1)])
b = set([str2[i : i + k] for i in range(len(str2) - k + 1)])

# calculate the intersection of the two sets
c = a.intersection(b)

# return Jaccard similarity
return float(len(c)) / (len(a) + len(b) - len(c))
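
# Worked example of the shingling above (illustrative, not part of this commit):
# with the default k = 2, "night" gives bigrams {"ni", "ig", "gh", "ht"} and
# "nacht" gives {"na", "ac", "ch", "ht"}; they share only "ht", so
# jc_sim_py("night", "nacht") returns 1 / (4 + 4 - 1) = 1/7, roughly 0.14.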


def jw_sim_py(
first,
second,  # modification from original to not use the other input parameters
):
"""
Jaro-Winkler similarity calculated exactly as in stringutils.similarity_jaro_winkler in Apache Commons
using a modified version of the algorithm implemented by 'Jean-Bernard Ratte - jean.bernard.ratte@unary.ca'
found at https://github.com/nap/jaro-winkler-distance
used under the Apache License, Version 2.0, with modifications explicitly marked
"""

if not first or not second:
return 0.0  # modification from original to return 0.0 for nulls instead of raising an exception

scaling = 0.1 # modification from original to have constant scaling in order to have a comparable result to the stringutils implementation

def _get_diff_index(first, second):
if first == second:
return -1

if not first or not second:
return 0

max_len = min(len(first), len(second))
for i in range(0, max_len):
if not first[i] == second[i]:
return i

return max_len

def _score(first, second):
shorter, longer = first.lower(), second.lower()

if len(first) > len(second):
longer, shorter = shorter, longer

m1 = _get_matching_characters(shorter, longer)
m2 = _get_matching_characters(longer, shorter)

if len(m1) == 0 or len(m2) == 0:
return 0.0

return (
float(len(m1)) / len(shorter)
+ float(len(m2)) / len(longer)
+ float(len(m1) - _transpositions(m1, m2)) / len(m1)
) / 3.0

def _get_prefix(first, second):
if not first or not second:
return ""

index = _get_diff_index(first, second)
if index == -1:
return first

elif index == 0:
return ""

else:
return first[0:index]

def _transpositions(first, second):
return math.floor(
len([(f, s) for f, s in zip(first, second) if not f == s]) / 2.0
)

def _get_matching_characters(first, second):
common = []
limit = math.floor(min(len(first), len(second)) / 2)

for i, l in enumerate(first):
left, right = int(max(0, i - limit)), int(min(i + limit + 1, len(second)))
if l in second[left:right]:
common.append(l)
second = (
second[0 : second.index(l)] + "*" + second[second.index(l) + 1 :]
)

return "".join(common)

jaro = _score(first, second)
cl = min(len(_get_prefix(first, second)), 4)

return round((jaro + (scaling * cl * (1.0 - jaro))) * 100.0) / 100.0
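
A couple of illustrative calls to the fallbacks (the elephant/hippo value matches the assertion in tests/test_jar_fallback.py below; the MARTHA/MARHTA pair is the standard Jaro-Winkler textbook example and is not asserted anywhere in this commit):

from splink.jar_fallback import jw_sim_py, jc_sim_py

jw_sim_py("elephant", "hippo")  # 0.44, as asserted in the new tests
jw_sim_py("MARTHA", "MARHTA")   # 0.96, the classic Jaro-Winkler textbook pair
jc_sim_py("fly", "ant")         # 0.0, no shared bigrams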
File renamed without changes.
6 changes: 4 additions & 2 deletions tests/conftest.py
@@ -17,8 +17,10 @@ def spark():
conf = SparkConf()

conf.set("spark.jars.ivy", "/home/jovyan/.ivy2/")
conf.set("spark.driver.extraClassPath", "jars/scala-udf-similarity-0.0.9.jar")
conf.set("spark.jars", "jars/scala-udf-similarity-0.0.9.jar")
conf.set(
"spark.driver.extraClassPath", "splink/jars/scala-udf-similarity-0.0.9.jar"
)
conf.set("spark.jars", "splink/jars/scala-udf-similarity-0.0.9.jar")
conf.set("spark.driver.memory", "4g")
conf.set("spark.sql.shuffle.partitions", "12")

26 changes: 26 additions & 0 deletions tests/test_jar_fallback.py
@@ -0,0 +1,26 @@
import pytest
from splink.jar_fallback import jw_sim_py, jc_sim_py


def test_fallback_jw_nodata():
assert jw_sim_py(None, None) == 0.0
assert jw_sim_py("something", None) == 0.0
assert jw_sim_py(None, "Something") == 0.0


def test_fallback_jc_nodata():
assert jc_sim_py(None, None) == 0.0
assert jc_sim_py("something", None) == 0.0
assert jc_sim_py(None, "Something") == 0.0


def test_fallback_jw_wikipedia_examples():
assert jw_sim_py("fly", "ant") == 0.0
assert jw_sim_py("elephant", "hippo") == 0.44
assert jw_sim_py("ABC Corporation", "ABC Corp") == 0.91
assert jw_sim_py("PENNSYLVANIA", "PENNCISYLVNIA") == 0.9
assert jw_sim_py("D N H Enterprises Inc", "D & H Enterprises, Inc.") == 0.93
assert (
jw_sim_py("My Gym Children's Fitness Center", "My Gym. Childrens Fitness")
== 0.94
)
